pytorch實(shí)現(xiàn)圖像識別(實(shí)戰(zhàn))
1. 代碼講解
1.1 導(dǎo)庫
import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.nn import AdaptiveAvgPool2d from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data import Dataset import torchvision.transforms as transforms from sklearn.model_selection import train_test_split
1.2 標(biāo)準(zhǔn)化、transform、設(shè)置GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(
? ?mean=[0.485, 0.456, 0.406],
? ?std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize]) ?# 轉(zhuǎn)換1.3 預(yù)處理數(shù)據(jù)
class DogDataset(Dataset):
# 定義變量
? ? def __init__(self, img_paths, img_labels, size_of_images): ?
? ? ? ? self.img_paths = img_paths
? ? ? ? self.img_labels = img_labels
? ? ? ? self.size_of_images = size_of_images
# 多少長圖片
? ? def __len__(self):
? ? ? ? return len(self.img_paths)
# 打開每組圖片并處理每張圖片
? ? def __getitem__(self, index):
? ? ? ? PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)
? ? ? ? TENSOR_IMAGE = transform(PIL_IMAGE)
? ? ? ? label = self.img_labels[index]
? ? ? ? return TENSOR_IMAGE, label
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train')))
print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')))
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test')))
train_paths = []
test_paths = []
labels = []
# 訓(xùn)練集圖片路徑
train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'
for path in listdir(train_paths_lir):
? ? train_paths.append(os.path.join(train_paths_lir, path)) ?
# 測試集圖片路徑
labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')
labels_data = pd.DataFrame(labels_data) ?
# 把字符標(biāo)簽離散化,因?yàn)閿?shù)據(jù)有120種狗,不離散化后面把數(shù)據(jù)給模型時會報錯:字符標(biāo)簽過多。把字符標(biāo)簽從0-119編號
size_mapping = {}
value = 0
size_mapping = dict(labels_data['breed'].value_counts())
for kay in size_mapping:
? ? size_mapping[kay] = value
? ? value += 1
# print(size_mapping)
labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))
# 劃分訓(xùn)練集和測試集
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)
train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)1.4 建立模型
class LeNet(nn.Module): ? ? def __init__(self): ? ? ? ? super(LeNet, self).__init__() ? ? ? ? self.features = nn.Sequential( ? ? ? ? ? ? nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5), ? ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2), ? ? ? ? ? ? nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2) ? ? ? ? ) ? ? ? ? self.classifier = nn.Sequential( ? ? ? ? ? ? nn.Linear(16 * 5 * 5, 120), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(120, 84), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(84, 120) ? ? ? ? ) ? ? def forward(self, x): ? ? ? ? batch_size = x.shape[0] ? ? ? ? x = self.features(x) ? ? ? ? x = x.view(batch_size, -1) ? ? ? ? x = self.classifier(x) ? ? ? ? return x model = LeNet().to(device) criterion = nn.CrossEntropyLoss().to(device) optimizer = optim.Adam(model.parameters()) TRAIN_LOSS = [] ?# 損失 TRAIN_ACCURACY = [] ?# 準(zhǔn)確率
1.5 訓(xùn)練模型
def train(epoch):
? ? model.train()
? ? epoch_loss = 0.0 # 損失
? ? correct = 0 ?# 精確率
? ? for batch_index, (Data, Label) in enumerate(train_loader):
? ? # 扔到GPU中
? ? ? ? Data = Data.to(device)
? ? ? ? Label = Label.to(device)
? ? ? ? output_train = model(Data)
? ? # 計(jì)算損失
? ? ? ? loss_train = criterion(output_train, Label)
? ? ? ? epoch_loss = epoch_loss + loss_train.item()
? ? # 計(jì)算精確率
? ? ? ? pred = torch.max(output_train, 1)[1]
? ? ? ? train_correct = (pred == Label).sum()
? ? ? ? correct = correct + train_correct.item()
? ? # 梯度歸零、反向傳播、更新參數(shù)
? ? ? ? optimizer.zero_grad()
? ? ? ? loss_train.backward()
? ? ? ? optimizer.step()
? ? print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))1.6 測試模型
和訓(xùn)練集差不多。
def test():
? ? model.eval()
? ? correct = 0.0
? ? test_loss = 0.0
? ? with torch.no_grad():
? ? ? ? for Data, Label in test_loader:
? ? ? ? ? ? Data = Data.to(device)
? ? ? ? ? ? Label = Label.to(device)
? ? ? ? ? ? test_output = model(Data)
? ? ? ? ? ? loss = criterion(test_output, Label)
? ? ? ? ? ? pred = torch.max(test_output, 1)[1]
? ? ? ? ? ? test_correct = (pred == Label).sum()
? ? ? ? ? ? correct = correct + test_correct.item()
? ? ? ? ? ? test_loss = test_loss + loss.item()
? ? print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))1.7結(jié)果
epoch = 10 for n_epoch in range(epoch): ? ? train(n_epoch) test()

到此這篇關(guān)于pytorch實(shí)現(xiàn)圖像識別(實(shí)戰(zhàn))的文章就介紹到這了,更多相關(guān)pytorch實(shí)現(xiàn)圖像識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python+OpenCV數(shù)字圖像處理之ROI區(qū)域的提取
ROI區(qū)域又叫感興趣區(qū)域。在機(jī)器視覺、圖像處理中,從被處理的圖像以方框、圓、橢圓、不規(guī)則多邊形等方式勾勒出需要處理的區(qū)域,稱為感興趣區(qū)域,ROI。本文主要為大家介紹如何通過Python+OpenCV提取ROI區(qū)域,需要的朋友可以了解一下2021-12-12
python+tkinter編寫電腦桌面放大鏡程序?qū)嵗a
這篇文章主要介紹了Python+tkinter編寫電腦桌面放大鏡程序?qū)嵗a,具有一定借鑒價值,需要的朋友可以參考下2018-01-01
詳細(xì)聊聊為什么Python中0.2+0.1不等于0.3
最近在學(xué)習(xí)過程中發(fā)現(xiàn)在計(jì)算機(jī)JS時發(fā)現(xiàn)了一個非常有意思事,0.1+0.2的結(jié)果不是0.3,而是0.30000000000000004,下面這篇文章主要給大家介紹了關(guān)于為什么Python中0.2+0.1不等于0.3的相關(guān)資料,需要的朋友可以參考下2022-12-12
python使用matplotlib:subplot繪制多個子圖的示例
這篇文章主要介紹了python使用matplotlib:subplot繪制多個子圖的示例,幫助大家更好的利用python繪制圖像,感興趣的朋友可以了解下2020-09-09
Python通過zookeeper實(shí)現(xiàn)分布式服務(wù)代碼解析
這篇文章主要介紹了Python通過zookeeper實(shí)現(xiàn)分布式服務(wù)代碼解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-07-07
tensorflow構(gòu)建BP神經(jīng)網(wǎng)絡(luò)的方法
這篇文章主要為大家詳細(xì)介紹了tensorflow構(gòu)建BP神經(jīng)網(wǎng)絡(luò)的方法,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-03-03
python 統(tǒng)計(jì)數(shù)組中元素出現(xiàn)次數(shù)并進(jìn)行排序的實(shí)例
今天小編就為大家分享一篇python 統(tǒng)計(jì)數(shù)組中元素出現(xiàn)次數(shù)并進(jìn)行排序的實(shí)例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07
django admin實(shí)現(xiàn)動態(tài)多選框表單的示例代碼
借助django-admin,可以快速得到CRUD界面,但若需要創(chuàng)建多選標(biāo)簽字段時,需要對表單進(jìn)行調(diào)整,本文通過示例代碼給大家介紹django admin多選框表單的實(shí)現(xiàn)方法,感興趣的朋友跟隨小編一起看看吧2021-05-05

