亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch使用voc分割數(shù)據(jù)集訓(xùn)練FCN流程講解

 更新時(shí)間:2022年12月08日 11:21:05   作者:專(zhuān)業(yè)女神殺手  
這篇文章主要介紹了pytorch使用voc分割數(shù)據(jù)集訓(xùn)練FCN流程,圖像分割發(fā)展過(guò)程也經(jīng)歷了傳統(tǒng)算法到深度學(xué)習(xí)算法的轉(zhuǎn)變,傳統(tǒng)的分割算法包括閾值分割、分水嶺、邊緣檢測(cè)等等

語(yǔ)義分割是對(duì)圖像中的每一個(gè)像素進(jìn)行分類(lèi),從而完成圖像分割的過(guò)程。分割主要用于醫(yī)學(xué)圖像領(lǐng)域和無(wú)人駕駛領(lǐng)域。

和其他算法一樣,圖像分割發(fā)展過(guò)程也經(jīng)歷了傳統(tǒng)算法到深度學(xué)習(xí)算法的轉(zhuǎn)變,傳統(tǒng)的分割算法包括閾值分割、分水嶺、邊緣檢測(cè)等等,面臨的問(wèn)題也跟其他傳統(tǒng)圖像處理算法一樣,就是魯棒性不夠,但在一些場(chǎng)景單一不變的場(chǎng)合,傳統(tǒng)圖像處理依舊用的較多。

FCN是2014年的一篇論文,深度學(xué)習(xí)語(yǔ)義分割的開(kāi)山之作,從思想上奠定了語(yǔ)義分割的基礎(chǔ)。

Fully Convolutional Networks for Semantic Segmentation

Submitted on 14 Nov 2014

https://arxiv.org/abs/1411.4038

一、FCN理論介紹

上圖是原論文中的截圖,從整體架構(gòu)上描繪了FCN的網(wǎng)絡(luò)架構(gòu)。其實(shí)就是圖像經(jīng)過(guò)一系列卷積運(yùn)算,然后再上采樣成原圖大小,輸出每一個(gè)像素的類(lèi)別概率。

上圖更加細(xì)致的描述了FCN的網(wǎng)絡(luò)。backbone采用VGG16,把VGG的fully-connect層用卷積來(lái)表示,即conv6-7(一個(gè)大小和feature_map同樣size的卷積核,就相當(dāng)于全連接)??偟膩?lái)說(shuō),網(wǎng)絡(luò)有下列幾個(gè)關(guān)鍵點(diǎn):

1. Fully Convolution: 用于解決像素的預(yù)測(cè)問(wèn)題。通過(guò)將基礎(chǔ)網(wǎng)絡(luò)(如VGG16)最后全連接層替換為卷積層,可實(shí)現(xiàn)任意大小的圖像輸入,并且輸出圖像大小與輸入相對(duì)應(yīng);

2.Transpose Convolution: 上采樣過(guò)程,用于恢復(fù)圖片尺寸,方便后續(xù)進(jìn)行逐個(gè)像素的預(yù)測(cè);

3. Skip Architecture : 用于融合高底層特征信息。因?yàn)榫矸e是個(gè)下采樣操作,而轉(zhuǎn)置卷積雖然恢復(fù)了圖像尺寸,但畢竟不是卷積的逆操作,所以信息肯定有丟失,而skip architecture可以融合千層的細(xì)粒度信息和深層的粗粒度信息,提高分割的精細(xì)程度。

FCN-32s: 沒(méi)有跳連接,按照每層轉(zhuǎn)置卷積放大2倍的速度放大,經(jīng)過(guò)五層后放大32倍復(fù)原原圖大小。

FCN-16s: 一個(gè)skip-connect,(1/32)放大為(1/16)后,再與vgg的(1/16)相加,然后繼續(xù)放大,直到原圖大小。

FCN-8s: 兩個(gè)skip-connect,一個(gè)是(1/32)放大為(1/16)后,再與vgg的(1/16)相加;另外一個(gè)是(1/16)放大為(1/8)之后,再與vgg的(1/8)相加,然后繼續(xù)放大,直到原圖大小。

二、訓(xùn)練過(guò)程

pytorch訓(xùn)練深度學(xué)習(xí)模型主要實(shí)現(xiàn)三個(gè)文件即可,分別為data.py, model.py, train.py。其中data.py里實(shí)現(xiàn)數(shù)據(jù)批量處理功能,model.py定義網(wǎng)絡(luò)模型,train.py實(shí)現(xiàn)訓(xùn)練步驟。

2.1 voc數(shù)據(jù)集介紹

下載地址:Pascal VOC Dataset Mirror

圖片的名稱(chēng)在/ImageSets/Segmentation/train.txt ans val.txt里

圖片都在./data/VOC2012/JPEGImages文件夾下面,需要在train.txt讀取的每一行后面加上.jpg

標(biāo)簽都在./data/VOC2012/SegmentationClass文件夾下面,需要在讀取的每一行后面加上.png

voc_seg_data.py

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader,Dataset
import numpy as np
import os
from PIL import Image
from datetime import datetime
class VOC_SEG(Dataset):
    def __init__(self, root, width, height, train=True, transforms=None):
        # 圖像統(tǒng)一剪切尺寸(width, height)
        self.width = width
        self.height = height
        # VOC數(shù)據(jù)集中對(duì)應(yīng)的標(biāo)簽
        self.classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
        # 各種標(biāo)簽所對(duì)應(yīng)的顏色
        self.colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]
        # 輔助變量
        self.fnum = 0
        if transforms is None:
            normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            self.transforms = T.Compose([
                T.ToTensor(),
                normalize
            ])
        # 像素值(RGB)與類(lèi)別label(0,1,3...)一一對(duì)應(yīng)
        self.cm2lbl = np.zeros(256**3)
        for i, cm in enumerate(self.colormap):
            self.cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i
        if train:
            txt_fname = root+"/ImageSets/Segmentation/train.txt"
        else:
            txt_fname = root+"/ImageSets/Segmentation/val.txt"
        with open(txt_fname, 'r') as f:
            images = f.read().split()
        imgs = [os.path.join(root, "JPEGImages", item+".jpg") for item in images]
        labels = [os.path.join(root, "SegmentationClass", item+".png") for item in images]
        self.imgs = self._filter(imgs)
        self.labels = self._filter(labels)
        if train:
            print("訓(xùn)練集:加載了 " + str(len(self.imgs)) + " 張圖片和標(biāo)簽" + ",過(guò)濾了" + str(self.fnum) + "張圖片")
        else:
            print("測(cè)試集:加載了 " + str(len(self.imgs)) + " 張圖片和標(biāo)簽" + ",過(guò)濾了" + str(self.fnum) + "張圖片")
    def _crop(self, data, label):
        """
        切割函數(shù),默認(rèn)都是從圖片的左上角開(kāi)始切割。切割后的圖片寬是width,高是height
        data和label都是Image對(duì)象
        """
        box = (0,0,self.width,self.height)
        data = data.crop(box)
        label = label.crop(box)
        return data, label
    def _image2label(self, im):
        data = np.array(im, dtype="int32")
        idx = (data[:,:,0]*256+data[:,:,1])*256+data[:,:,2]
        return np.array(self.cm2lbl[idx], dtype="int64")
    def _image_transforms(self, data, label):
        data, label = self._crop(data,label)
        data = self.transforms(data)
        label = self._image2label(label)
        label = torch.from_numpy(label)
        return data, label
    def _filter(self, imgs): 
        img = []
        for im in imgs:
            if (Image.open(im).size[1] >= self.height and 
               Image.open(im).size[0] >= self.width):
                img.append(im)
            else:
                self.fnum  = self.fnum+1
        return img
    def __getitem__(self, index: int):
        img_path = self.imgs[index]
        label_path = self.labels[index]
        img = Image.open(img_path)
        label = Image.open(label_path).convert("RGB")
        img, label = self._image_transforms(img, label)
        return img, label
    def __len__(self) :
        return len(self.imgs)
if __name__=="__main__":
    root = "./VOCdevkit/VOC2012"
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    # train_data = DataLoader(voc_train, batch_size=8, shuffle=True)
    # valid_data = DataLoader(voc_test, batch_size=8)
    for data, label in voc_train:
        print(data.shape)
        print(label.shape)
        break
  • 我這里為了省事把一些輔助函數(shù),如_crop(), _filter(),還是有變量colormap等都寫(xiě)到類(lèi)里面了。實(shí)際上脫離出來(lái)另外寫(xiě)一個(gè)數(shù)據(jù)預(yù)處理的文件比較好,這樣在訓(xùn)練結(jié)束后,推理測(cè)試時(shí)可以直接調(diào)用相應(yīng)的處理函數(shù)。
  • 數(shù)據(jù)處理的結(jié)果是得到data, label。data是tensor格式的圖像,label也是tensor,且已經(jīng)把像素(RGB)替換為了int類(lèi)別號(hào)。這樣在訓(xùn)練時(shí)候,交叉熵函數(shù)直接會(huì)實(shí)現(xiàn)one-hot處理,就跟訓(xùn)練分類(lèi)網(wǎng)絡(luò)一樣。

2.2 網(wǎng)絡(luò)定義

fcn8s_net.py

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchsummary import summary
from torchvision import models
class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super(FCN8s,self).__init__()
        net = models.vgg16(pretrained=True)   # 從預(yù)訓(xùn)練模型加載VGG16網(wǎng)絡(luò)參數(shù)
        self.premodel = net.features          # 只使用Vgg16的五層卷積層(特征提取層)(3,224,224)----->(512,7,7)
        # self.conv6 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1) 
        # self.conv7 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # (512,7,7)
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512,512,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn1 = nn.BatchNorm2d(512)
        # (512, 14, 14)
        self.deconv2 = nn.ConvTranspose2d(512,256,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn2 = nn.BatchNorm2d(256)
        # (256, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(256,128,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn3 = nn.BatchNorm2d(128)
        # (128, 56, 56)
        self.deconv4 = nn.ConvTranspose2d(128,64,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)   # x2
        self.bn4 = nn.BatchNorm2d(64)
        # (64, 112, 112)
        self.deconv5 = nn.ConvTranspose2d(64,32,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)    # x2
        self.bn5 = nn.BatchNorm2d(32)
        # (32, 224, 224)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
        # (num_classes, 224, 224)
    def forward(self, input):
        x = input
        for i in range(len(self.premodel)):
            x = self.premodel[i](x)
            if i == 16:
                x3 = x  # maxpooling3的feature map (1/8)
            if i == 23:
                x4 = x  # maxpooling4的feature map (1/16)
            if i == 30:
                x5 = x  # maxpooling5的feature map (1/32)
        # 五層轉(zhuǎn)置卷積,每層size放大2倍,與VGG16剛好相反。兩個(gè)skip-connect
        score = self.relu(self.deconv1(x5))   # out_size = 2*in_size (1/16)
        score = self.bn1(score + x4)
        score = self.relu(self.deconv2(score)) # out_size = 2*in_size (1/8)  
        score = self.bn2(score + x3)
        score = self.bn3(self.relu(self.deconv3(score)))  # out_size = 2*in_size (1/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # out_size = 2*in_size (1/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # out_size = 2*in_size (1)
        score = self.classifier(score)                    # size不變,使輸出的channel等于類(lèi)別數(shù)
        return score
if __name__ == "__main__":
    model = FCN8s()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(model)

FCN的網(wǎng)絡(luò)代碼實(shí)現(xiàn)上,在網(wǎng)上查的都有所差異,不過(guò)總體都是卷積+轉(zhuǎn)置卷積+跳鏈接的結(jié)構(gòu)。實(shí)際上只要實(shí)現(xiàn)特征提取(提取抽象特征)——轉(zhuǎn)置卷積(恢復(fù)原圖大?。?mdash;—給每一個(gè)像素分類(lèi)的過(guò)程就夠了。

本次實(shí)驗(yàn)采用vgg16的五層卷積層作為特征提取網(wǎng)絡(luò),然后接五個(gè)轉(zhuǎn)置卷積(2x)恢復(fù)到原圖大小,然后再接一個(gè)卷積層把feature map的通道調(diào)整為類(lèi)別個(gè)數(shù)(21)。最后再softmax分類(lèi)就行了。

2.3 訓(xùn)練

train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from voc_seg_data import VOC_SEG
from fcn_net import FCN8s
import os
import numpy as np
# 計(jì)算混淆矩陣
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
# 根據(jù)混淆矩陣計(jì)算Acc和mIou
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    return acc, acc_cls, mean_iu
def main():
    # 1. load dataset
    root = "./VOCdevkit/VOC2012"
    batch_size = 32
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    train_dataloader = DataLoader(voc_train,batch_size=batch_size,shuffle=True)
    val_dataloader = DataLoader(voc_test,batch_size=batch_size,shuffle=True)
    # 2. load model
    num_class = 21
    model = FCN8s(num_classes=num_class)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # 3. prepare super parameters
    criterion = nn.CrossEntropyLoss() 
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.7)
    epoch = 50
    # 4. train
    val_acc_list = []
    out_dir = "./checkpoints/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for epoch in range(0, epoch):
        print('\nEpoch: %d' % (epoch + 1))
        model.train()
        sum_loss = 0.0
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            length = len(train_dataloader)
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images) # torch.size([batch_size, num_class, width, height])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            predicted = torch.argmax(outputs.data, 1)
            label_pred = predicted.data.cpu().numpy()
            label_true = labels.data.cpu().numpy()
            acc, acc_cls, mean_iu = label_accuracy_score(label_true,label_pred,num_class)
            print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% | Acc_cls: %.03f%% |Mean_iu: %.3f' 
                % (epoch + 1, (batch_idx + 1 + epoch * length), sum_loss / (batch_idx + 1), 
                100. *acc, 100.*acc_cls, mean_iu))
        #get the ac with testdataset in each epoch
        print('Waiting Val...')
        mean_iu_epoch = 0.0
        mean_acc = 0.0
        mean_acc_cls = 0.0
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_dataloader):
                model.eval()
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = torch.argmax(outputs.data, 1)
                label_pred = predicted.data.cpu().numpy()
                label_true = labels.data.cpu().numpy()
                acc, acc_cls, mean_iu = label_accuracy_score(label_true,label_pred,num_class)
                # total += labels.size(0)
                # iou = torch.sum((predicted == labels.data), (1,2)) / float(width*height)
                # iou = torch.sum(iou)
                # correct += iou
                mean_iu_epoch += mean_iu
                mean_acc += acc
                mean_acc_cls += acc_cls
            print('Acc_epoch: %.3f%% | Acc_cls_epoch: %.03f%% |Mean_iu_epoch: %.3f' 
                % ((100. *mean_acc / len(val_dataloader)), (100.*mean_acc_cls/len(val_dataloader)), mean_iu_epoch/len(val_dataloader)) )
            val_acc_list.append(mean_iu_epoch/len(val_dataloader))
        torch.save(model.state_dict(), out_dir+"last.pt")
        if mean_iu_epoch/len(val_dataloader) == max(val_acc_list):
            torch.save(model.state_dict(), out_dir+"best.pt")
            print("save epoch {} model".format(epoch))
if __name__ == "__main__":
    main()

整體訓(xùn)練流程沒(méi)問(wèn)題,讀者可以根據(jù)需要更改其模型評(píng)價(jià)標(biāo)準(zhǔn)和相關(guān)代碼。在本次訓(xùn)練中,主要使用Acc作為評(píng)價(jià)指標(biāo),其實(shí)就是分類(lèi)正確的像素個(gè)數(shù)除以全部像素個(gè)數(shù)。最終訓(xùn)練結(jié)果如下:

0.8

訓(xùn)練集的Acc來(lái)到了0.8, 驗(yàn)證集的Acc來(lái)到了0.77。由于有一些函數(shù)是復(fù)制過(guò)來(lái)的,如_hist等,所以其他指標(biāo)暫時(shí)不參考。

到此這篇關(guān)于pytorch使用voc分割數(shù)據(jù)集訓(xùn)練FCN流程講解的文章就介紹到這了,更多相關(guān)pytorch訓(xùn)練FCN內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評(píng)論