pytorch通過自己的數據集訓練Unet網絡架構
在圖像分割這個問題上,主要有兩個流派:Encoder-Decoder和Dilated Conv。本文介紹的是編解碼網絡中最為經典的U-Net。隨著骨干網絡的進化,很多相應衍生出來的網絡大多都是對U-Net進行了改進,但本質上的思路還是沒有太多變化,比如結合DenseNet和U-Net的FCDenseNet,以及Unet++。
一、Unet網絡介紹
論文:https://arxiv.org/abs/1505.04597v1(2015)
UNet的設計最初是應用于醫(yī)學圖像的分割。由于醫(yī)學影像處理中數據量較少,本文提出的方法有效提升了使用少量數據集訓練檢測的效果,并提出了處理大尺寸圖像的有效方法。
UNet的網絡架構繼承自FCN,并在此基礎上做了些改變。提出了Encoder-Decoder概念,實際上就是FCN那個先卷積再上采樣的思想。

上圖是Unet的網絡結構,從圖中可以看出,
結構左邊為Encoder,即下采樣提取特征的過程。Encoder基本模塊為雙卷積形式,即輸入經過兩個
conv 3x3。原文使用的是valid卷積;在代碼實現(xiàn)時我們可以增加padding使用same卷積,來適應Skip Architecture。下采樣采用的池化層直接縮小2倍。
結構右邊是Decoder,即上采樣恢復圖像尺寸并預測的過程。Decoder一樣采用雙卷積的形式,其中上采樣使用轉置卷積實現(xiàn),每次轉置卷積放大2倍。
結構中間copy and crop是一個cat操作,即feature map的通道疊加。
二、VOC訓練Unet
2.1 Unet代碼實現(xiàn)
根據上面對于Unet網絡結構的介紹,可見其結構非常對稱、簡單,代碼Unet.py實現(xiàn)如下:
# NOTE(review): `from turtle import forward` is an accidental IDE auto-import
# (it pulls in turtle's procedural `forward` command); it is unused in this
# file and can be removed.
from turtle import forward
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
    """Two successive (Conv3x3 -> BatchNorm -> ReLU) stages.

    This is the basic U-Net building block; padding=1 ("same" convolution)
    keeps the spatial size unchanged so skip connections line up.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x` (N, in_ch, H, W) -> (N, out_ch, H, W)."""
        return self.conv(x)
class Unet(nn.Module):
    """Symmetric U-Net: a 4-level encoder/decoder with skip connections.

    Channels double at each encoder level (64 up to 1024 at the bottleneck)
    and halve back up the decoder; a final 1x1 conv maps to `out_ch` class
    scores per pixel.  Same-padding convs preserve H and W, so input spatial
    dims should be divisible by 16 for the skip concatenations to line up.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder: double-conv blocks separated by 2x2 max-pooling.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: transpose-conv upsampling (x2), then a double-conv on the
        # channel-wise concatenation with the matching encoder output.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv producing per-pixel class scores.
        self.output = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        """Return per-pixel class logits of shape (N, out_ch, H, W)."""
        # Encoder path; keep each stage's output for the skip connections.
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool1(e1))
        e3 = self.conv3(self.pool2(e2))
        e4 = self.conv4(self.pool3(e3))
        bottom = self.conv5(self.pool4(e4))
        # Decoder path: upsample, concat the skip along channels, double-conv.
        d6 = self.conv6(torch.cat([self.up6(bottom), e4], dim=1))
        d7 = self.conv7(torch.cat([self.up7(d6), e3], dim=1))
        d8 = self.conv8(torch.cat([self.up8(d7), e2], dim=1))
        d9 = self.conv9(torch.cat([self.up9(d8), e1], dim=1))
        return self.output(d9)
if __name__ == "__main__":
    # Smoke test: build a 3-channel-input / 21-class U-Net (VOC classes)
    # and move it to the GPU when one is available.
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)

2.2 數據集處理


數據來源于kaggle,下載地址我忘了。包含2個類別,1個車,還有1個背景類,共有5k+的數據,按照比例分為訓練集和驗證集即可。具體見carnava.py
from PIL import Image
# NOTE(review): accidental IDE auto-import -- `check_compatibility` is an
# internal requests helper and is never used in this file; it can be removed.
from requests import check_compatibility
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt
class Car(Dataset):
    """Carvana segmentation dataset: car photos with binary masks.

    Expects `<root>/train_hq/` images and `<root>/train_masks/<name>_mask.gif`
    masks.  Every 8th image path goes to the validation split, the rest to
    training.  Classes: 0 = background (black), 1 = car (white).
    """

    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        self.imgs_list = train_path_list if train else val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Image transform: resize/crop to 256x256, tensor, ImageNet normalization.
        self.transforms = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
            normalize
        ])
        # Mask transform: geometric ops only -- masks must stay un-normalized.
        self.transforms_val = T.Compose([
            T.Resize(256),
            T.CenterCrop(256)
        ])
        # RGB color of each class index.
        self.color_map = [[0, 0, 0], [255, 255, 255]]
        # BUG FIX: precompute the RGB -> class-index lookup table once.
        # The original rebuilt this 256**3-entry float array (~128 MB)
        # inside every __getitem__ call.
        self._cm2lb = np.zeros(256 ** 3)
        for i, cm in enumerate(self.color_map):
            self._cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

    def __getitem__(self, index: int):
        """Return (normalized image tensor, HxW LongTensor of class indices)."""
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        _, filename = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path + "/" + filename + "_mask.gif").convert("RGB")
        label = self.transforms_val(label)
        # Map each RGB pixel to its class index via the precomputed table.
        # (renamed from `image`, which shadowed the input image above)
        label_arr = np.array(label, dtype=np.int64)
        idx = (label_arr[:, :, 0] * 256 + label_arr[:, :, 1]) * 256 + label_arr[:, :, 2]
        label = np.array(self._cm2lb[idx], dtype=np.int64)
        label = torch.from_numpy(label).long()
        return data, label

    def label2img(self, label):
        """Convert an HxW class-index array back to an RGB uint8 image."""
        cmap = np.array(self.color_map).astype(np.uint8)
        return cmap[label]

    def __len__(self):
        return len(self.imgs_list)

    def _split_data_set(self, img_path_list):
        """Every 8th path -> validation list; all remaining paths -> training list."""
        val_path_list = img_path_list[::8]
        # Set membership: the original list-membership loop was O(n^2).
        val_set = set(val_path_list)
        train_path_list = [item for item in img_path_list if item not in val_set]
        return train_path_list, val_path_list
if __name__ == "__main__":
    root = "../dataset/carvana"
    car_train = Car(root, train=True)
    loader = DataLoader(car_train, batch_size=8, shuffle=True)
    # Quick sanity check: sample count and number of batches.
    print(len(car_train))
    print(len(loader))
    # Decode one sample's label mask back to an RGB image and display it.
    sample_data, sample_label = car_train[190]
    label_im = car_train.label2img(sample_label.data.numpy())
    plt.figure()
    plt.imshow(label_im)
    plt.show()

2.3 訓練過程
分割其實就是給每個像素分類而已,所以損失函數依舊是交叉熵函數,正確率為分類正確的像素點個數/全部的像素點個數
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
# NOTE(review): VOC is imported but unused in the final script (the article
# mentions an earlier, abandoned attempt to train on the VOC dataset).
from voc import VOC
from carnava import Car
from unet import Unet
import os
import numpy as np
from torch import optim
# NOTE(review): duplicate of `import torch.nn as nn` above -- redundant.
import torch.nn as nn
# Project-local helper module (also re-imported locally inside evaluate()).
import util
# 計算混淆矩陣
def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def label_accuracy_score(label_trues, label_preds, n_class):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
"""
hist = np.zeros((n_class, n_class))
for lt, lp in zip(label_trues, label_preds):
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
acc = np.diag(hist).sum() / hist.sum()
with np.errstate(divide='ignore', invalid='ignore'):
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
with np.errstate(divide='ignore', invalid='ignore'):
iu = np.diag(hist) / (
hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
)
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / hist.sum()
return acc, acc_cls, mean_iu
# ---------------------------------------------------------------------------
# Script-level configuration shared by train_model() and evaluate().
# ---------------------------------------------------------------------------
out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)

# Per-epoch metrics are appended here; start from a clean file every run.
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)

# The best-scoring checkpoint (see train_model) is stored here.
model_path = os.path.join(out_path, "best_model.pth")

root = "../dataset/carvana"
epochs = 5
numclasses = 2  # background + car

train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)

net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
def train_model():
    """Train `net` for `epochs` epochs; keep the best checkpoint.

    Relies on the module-level globals: net, optimizer, criterion, device,
    train_dataloader, val_dataloader, epochs, numclasses, log_path and
    model_path.  The checkpoint maximizing (val_acc_cls + val_mean_iu) / 2
    is saved to `model_path`; per-epoch metrics are appended to `log_path`.
    """
    best_score = 0.0
    for e in range(epochs):
        # ---------- training phase ----------
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            # BUG FIX: the original called .squeeze() on the prediction.
            # argmax(dim=1) already yields (N, H, W); squeezing drops the
            # batch dim when the last batch has size 1 and breaks torch.cat.
            pred = output.argmax(dim=1).data.cpu()
            real = label.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.cpu().item()
            label_true = torch.cat((label_true, real), dim=0)
            label_pred = torch.cat((label_pred, pred), dim=0)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(
            label_true.numpy(), label_pred.numpy(), numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e+1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e+1, train_loss, acc, acc_cls, mean_iu))
        # ---------- validation phase ----------
        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).data.cpu()  # same .squeeze() fix as above
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)
            val_loss /= len(val_dataloader)
            val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(
                val_label_true.numpy(), val_label_pred.numpy(), numclasses)
            print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))
            with open(log_path, 'a') as f:
                f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                    e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))
            # Save whenever the combined validation score improves.
            score = (val_acc_cls + val_mean_iu) / 2
            if score > best_score:
                best_score = score
                torch.save(net.state_dict(), model_path)
def evaluate():
    """Visualize the best checkpoint's prediction on one random validation sample.

    Shows input image, ground-truth mask and predicted mask side by side.
    Uses module-level globals: net, device, val_data, model_path.
    """
    import random
    import matplotlib.pyplot as plt
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
    # BUG FIX: the original never switched to eval mode, so BatchNorm used
    # the batch statistics of a single image instead of the running stats.
    net.eval()
    index = random.randint(0, len(val_data) - 1)
    val_image, val_label = val_data[index]
    with torch.no_grad():  # inference only -- no gradient bookkeeping needed
        out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    # Roughly undo the normalization (min-max rescale to [0, 255]) for display.
    temp = val_image.numpy()
    temp = (temp - np.min(temp)) / (np.max(temp) - np.min(temp)) * 255
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(temp.transpose(1, 2, 0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()
if __name__ == "__main__":
    # train_model()
    evaluate()

最終訓練結果是:

由于數據比較簡單,訓練到epoch為5時,mIOU就已經達到0.97了。
最后測試一下效果:

從左到右分別是:原圖、真實label、預測label
備注:
其實最開始使用voc數據集訓練的,但效果極差,也沒發(fā)現(xiàn)哪里有問題。換個數據集效果就好了,可能有兩個原因:
1. voc數據我在處理數據時出錯了,沒檢查出來
2. 這個數據集比較簡單,容易學習,所以效果差不多。
到此這篇關于pytorch通過自己的數據集訓練Unet網絡架構的文章就介紹到這了,更多相關pytorch Unet內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python實現(xiàn)批量讀取圖片并存入mongodb數據庫的方法示例
這篇文章主要介紹了Python實現(xiàn)批量讀取圖片并存入mongodb數據庫的方法,涉及Python文件讀取及數據庫寫入相關操作技巧,需要的朋友可以參考下2018-04-04
Django在視圖中使用表單并和數據庫進行數據交互的實現(xiàn)
本文主要介紹了Django在視圖中使用表單并和數據庫進行數據交互,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2022-07-07
pytorch繪制并顯示loss曲線和acc曲線,LeNet5識別圖像準確率
今天小編就為大家分享一篇pytorch繪制并顯示loss曲線和acc曲線,LeNet5識別圖像準確率,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01
CentOS 7 安裝python3.7.1的方法及注意事項
這篇文章主要介紹了CentOS 7 安裝python3.7.1的方法,文中給大家提到了注意事項,需要的朋友可以參考下2018-11-11

