使用PyTorch實(shí)現(xiàn)MNIST手寫體識(shí)別代碼
實(shí)驗(yàn)環(huán)境
win10 + anaconda + jupyter notebook
Pytorch1.1.0
Python3.7
gpu環(huán)境(可選)
MNIST數(shù)據(jù)集介紹
MNIST 包括6萬(wàn)張28x28的訓(xùn)練樣本,1萬(wàn)張測(cè)試樣本,可以說是CV里的“Hello Word”。本文使用的CNN網(wǎng)絡(luò)將MNIST數(shù)據(jù)的識(shí)別率提高到了99%。下面我們就開始進(jìn)行實(shí)戰(zhàn)。
導(dǎo)入包
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms torch.__version__
定義超參數(shù)
BATCH_SIZE=512 EPOCHS=20 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
數(shù)據(jù)集
我們直接使用PyTorch中自帶的dataset,并使用DataLoader對(duì)訓(xùn)練數(shù)據(jù)和測(cè)試數(shù)據(jù)分別進(jìn)行讀取。如果下載過數(shù)據(jù)集這里download可選擇False
train_loader = torch.utils.data.DataLoader( datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=BATCH_SIZE, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST('data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=BATCH_SIZE, shuffle=True)
定義網(wǎng)絡(luò)
該網(wǎng)絡(luò)包括兩個(gè)卷積層和兩個(gè)線性層,最后輸出10個(gè)維度,即代表0-9十個(gè)數(shù)字。
class ConvNet(nn.Module): def __init__(self): super().__init__() self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24) self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10) self.fc1 = nn.Linear(20*10*10,500) self.fc2 = nn.Linear(500,10) def forward(self,x): in_size = x.size(0) out = self.conv1(x) out = F.relu(out) out = F.max_pool2d(out, 2, 2) out = self.conv2(out) out = F.relu(out) out = out.view(in_size,-1) out = self.fc1(out) out = F.relu(out) out = self.fc2(out) out = F.log_softmax(out,dim=1) return out
實(shí)例化網(wǎng)絡(luò)
model = ConvNet().to(DEVICE) # 將網(wǎng)絡(luò)移動(dòng)到gpu上 optimizer = optim.Adam(model.parameters()) # 使用Adam優(yōu)化器
定義訓(xùn)練函數(shù)
def train(model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() if(batch_idx+1)%30 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()))
定義測(cè)試函數(shù)
def test(model, device, test_loader): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += F.nll_loss(output, target, reduction='sum').item() # 將一批的損失相加 pred = output.max(1, keepdim=True)[1] # 找到概率最大的下標(biāo) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
開始訓(xùn)練
for epoch in range(1, EPOCHS + 1): train(model, DEVICE, train_loader, optimizer, epoch) test(model, DEVICE, test_loader)
實(shí)驗(yàn)結(jié)果
Train Epoch: 1 [14848/60000 (25%)] Loss: 0.375058 Train Epoch: 1 [30208/60000 (50%)] Loss: 0.255248 Train Epoch: 1 [45568/60000 (75%)] Loss: 0.128060 Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%) Train Epoch: 2 [14848/60000 (25%)] Loss: 0.093066 Train Epoch: 2 [30208/60000 (50%)] Loss: 0.087888 Train Epoch: 2 [45568/60000 (75%)] Loss: 0.068078 Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%) Train Epoch: 3 [14848/60000 (25%)] Loss: 0.043926 Train Epoch: 3 [30208/60000 (50%)] Loss: 0.037321 Train Epoch: 3 [45568/60000 (75%)] Loss: 0.068404 Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%) Train Epoch: 4 [14848/60000 (25%)] Loss: 0.031654 Train Epoch: 4 [30208/60000 (50%)] Loss: 0.041341 Train Epoch: 4 [45568/60000 (75%)] Loss: 0.036493 Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%) Train Epoch: 5 [14848/60000 (25%)] Loss: 0.027688 Train Epoch: 5 [30208/60000 (50%)] Loss: 0.019488 Train Epoch: 5 [45568/60000 (75%)] Loss: 0.018023 Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%) Train Epoch: 6 [14848/60000 (25%)] Loss: 0.024212 Train Epoch: 6 [30208/60000 (50%)] Loss: 0.018689 Train Epoch: 6 [45568/60000 (75%)] Loss: 0.040412 Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%) Train Epoch: 7 [14848/60000 (25%)] Loss: 0.030426 Train Epoch: 7 [30208/60000 (50%)] Loss: 0.026939 Train Epoch: 7 [45568/60000 (75%)] Loss: 0.010722 Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%) Train Epoch: 8 [14848/60000 (25%)] Loss: 0.021109 Train Epoch: 8 [30208/60000 (50%)] Loss: 0.034845 Train Epoch: 8 [45568/60000 (75%)] Loss: 0.011223 Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%) Train Epoch: 9 [14848/60000 (25%)] Loss: 0.011391 Train Epoch: 9 [30208/60000 (50%)] Loss: 0.008091 Train Epoch: 9 [45568/60000 (75%)] Loss: 0.039870 Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%) Train Epoch: 10 [14848/60000 (25%)] Loss: 0.026813 Train Epoch: 10 [30208/60000 (50%)] Loss: 0.011159 Train Epoch: 10 [45568/60000 (75%)] Loss: 0.024884 Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%) Train Epoch: 11 [14848/60000 (25%)] Loss: 0.006420 Train Epoch: 11 [30208/60000 (50%)] Loss: 0.003641 Train Epoch: 11 [45568/60000 (75%)] Loss: 0.003402 Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%) Train Epoch: 12 [14848/60000 (25%)] Loss: 0.006866 Train Epoch: 12 [30208/60000 (50%)] Loss: 0.012617 Train Epoch: 12 [45568/60000 (75%)] Loss: 0.008548 Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%) Train Epoch: 13 [14848/60000 (25%)] Loss: 0.010539 Train Epoch: 13 [30208/60000 (50%)] Loss: 0.002952 Train Epoch: 13 [45568/60000 (75%)] Loss: 0.002313 Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%) Train Epoch: 14 [14848/60000 (25%)] Loss: 0.002100 Train Epoch: 14 [30208/60000 (50%)] Loss: 0.000779 Train Epoch: 14 [45568/60000 (75%)] Loss: 0.005952 Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%) Train Epoch: 15 [14848/60000 (25%)] Loss: 0.006053 Train Epoch: 15 [30208/60000 (50%)] Loss: 0.002559 Train Epoch: 15 [45568/60000 (75%)] Loss: 0.002555 Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%) Train Epoch: 16 [14848/60000 (25%)] Loss: 0.000895 Train Epoch: 16 [30208/60000 (50%)] Loss: 0.004923 Train Epoch: 16 [45568/60000 (75%)] Loss: 0.002339 Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%) Train Epoch: 17 [14848/60000 (25%)] Loss: 0.004136 Train Epoch: 17 [30208/60000 (50%)] Loss: 0.000927 Train Epoch: 17 [45568/60000 (75%)] Loss: 0.002084 Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%) Train Epoch: 18 [14848/60000 (25%)] Loss: 0.004508 Train Epoch: 18 [30208/60000 (50%)] Loss: 0.001272 Train Epoch: 18 [45568/60000 (75%)] Loss: 0.000543 Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%) Train Epoch: 19 [14848/60000 (25%)] Loss: 0.001699 Train Epoch: 19 [30208/60000 (50%)] Loss: 0.000661 Train Epoch: 19 [45568/60000 (75%)] Loss: 0.000275 Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%) Train Epoch: 20 [14848/60000 (25%)] Loss: 0.000441 Train Epoch: 20 [30208/60000 (50%)] Loss: 0.000695 Train Epoch: 20 [45568/60000 (75%)] Loss: 0.000467 Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
總結(jié)
一個(gè)實(shí)際項(xiàng)目的工作流程:找到數(shù)據(jù)集,對(duì)數(shù)據(jù)做預(yù)處理,定義我們的模型,調(diào)整超參數(shù),測(cè)試訓(xùn)練,再通過訓(xùn)練結(jié)果對(duì)超參數(shù)進(jìn)行調(diào)整或者對(duì)模型進(jìn)行調(diào)整。
以上這篇使用PyTorch實(shí)現(xiàn)MNIST手寫體識(shí)別代碼就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python點(diǎn)云地面點(diǎn)濾波(Progressive Morphological Filter)算法介紹(PCL庫(kù))
這篇文章主要介紹了python點(diǎn)云地面點(diǎn)濾波(Progressive Morphological Filter)算法介紹(PCL庫(kù)),了解膨脹/腐蝕這兩個(gè)基礎(chǔ)操作,可以通過對(duì)其進(jìn)行簡(jiǎn)單組合來形成開/閉操作,需要的朋友可以參考下2021-08-08教你使用Pandas直接核算Excel中的快遞費(fèi)用
文中仔細(xì)說明了怎么根據(jù)賬單核算運(yùn)費(fèi).首先要確定運(yùn)費(fèi)規(guī)則,然后根據(jù)運(yùn)費(fèi)規(guī)則編寫代碼,生成核算列(快遞費(fèi) = 省份*重量),最后輸入賬單,進(jìn)行核算.將腳本件生成EXE文件,就可以使用啦,需要的朋友可以參考下2021-05-05python opencv 圖像拼接的實(shí)現(xiàn)方法
高級(jí)圖像拼接也叫作基于特征匹配的圖像拼接,拼接時(shí)消去兩幅圖像相同的部分,實(shí)現(xiàn)拼接合成全景圖。這篇文章主要介紹了python opencv 圖像拼接,需要的朋友可以參考下2019-06-06Python中%是什么意思?python中百分號(hào)如何使用?
最近在學(xué)習(xí)python過程中,發(fā)現(xiàn)了%的一些情況,這里就簡(jiǎn)單介紹一下,,需要的朋友可以參考下2018-03-03Python遞歸函數(shù) 二分查找算法實(shí)現(xiàn)解析
這篇文章主要介紹了Python遞歸函數(shù) 二分查找算法實(shí)現(xiàn)解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-08-08python 創(chuàng)建一維的0向量實(shí)例
今天小編就為大家分享一篇python 創(chuàng)建一維的0向量實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-12-12