PyTorch CNN: recognizing handwritten digits with a self-built image dataset
Updated: 2018-05-20 17:03:26 Author: 瓦力冫
This article shows how to train a PyTorch CNN to recognize handwritten digits from a self-built image dataset. The details are as follows:
# third-party libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # number of passes over the training data; 1 epoch to save time
BATCH_SIZE = 50
LR = 0.001              # learning rate
root = "./mnist/raw/"


def default_loader(path):
    # return Image.open(path).convert('RGB')
    return Image.open(path)


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        # each line of the txt file holds "<image path> <integer label>"
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:                   # skip blank lines
                    continue
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn).convert('L')      # force single-channel grayscale
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


train_data = MyDataset(txt=root + 'train.txt', transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = MyDataset(txt=root + 'test.txt', transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input channels (grayscale)
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # keeps width and height unchanged: padding=(kernel_size-1)/2 when stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # max over each 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x                    # return x for visualization


cnn = CNN()
print(cnn)  # net architecture

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # expects class indices, not one-hot targets

# training and testing
for epoch in range(EPOCH):
    # each batch is already scaled to [0, 1] by ToTensor
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)[0]                # cnn output
        loss = loss_func(output, b_y)       # cross entropy loss
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

        if step % 50 == 0:
            cnn.eval()
            eval_loss = 0.
            eval_acc = 0.
            with torch.no_grad():           # no gradients needed for evaluation
                for t_x, t_y in test_loader:
                    test_output = cnn(t_x)[0]
                    test_loss = loss_func(test_output, t_y)
                    eval_loss += test_loss.item() * t_x.size(0)   # batch mean -> batch sum
                    pred = torch.max(test_output, 1)[1]
                    eval_acc += (pred == t_y).sum().item()
            acc_rate = eval_acc / len(test_data)
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
            cnn.train()                     # switch back to training mode
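The shape comments in the network (28 → 28 → 14 → 14 → 7) follow from the usual convolution and pooling size arithmetic. Below is a minimal sketch of that arithmetic in plain Python, assuming square inputs and floor rounding. `conv_out` and `pool_out` are illustrative helper names introduced here; they are not part of PyTorch or of the original code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride=None):
    """Spatial size after max pooling; the stride defaults to the kernel size."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


size = 28                                             # MNIST images are 28x28
size = conv_out(size, kernel=5, stride=1, padding=2)  # conv1 keeps the size at 28
size = pool_out(size, kernel=2)                       # first pooling halves it to 14
size = conv_out(size, kernel=5, stride=1, padding=2)  # conv2 keeps the size at 14
size = pool_out(size, kernel=2)                       # second pooling halves it to 7
flat_features = 32 * size * size                      # input width of the final linear layer
print(size, flat_features)
```

If the kernel size, padding, or input resolution changes, the same arithmetic gives the value that has to replace `32 * 7 * 7` in the final `nn.Linear`.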
For the images and labels, see the previous article, "pytorch 把MNIST數據集轉換成圖片和txt" (PyTorch: converting the MNIST dataset into images and a txt file).
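The dataset class reads `train.txt` and `test.txt` as plain index files in which every line holds an image path and an integer label separated by whitespace. The sketch below writes and parses such a file in a temporary directory. The image file names and the `read_index` helper are hypothetical and only illustrate the format; the actual paths come from the previous article.

```python
import os
import tempfile


def read_index(txt_path):
    """Parse '<image path> <label>' lines into (path, label) tuples, skipping blank lines."""
    samples = []
    with open(txt_path, "r") as fh:
        for line in fh:
            words = line.split()
            if not words:
                continue
            samples.append((words[0], int(words[1])))
    return samples


# Hypothetical file names, for illustration only; no image files are created here.
entries = [("./mnist/raw/train/img_0.png", 5), ("./mnist/raw/train/img_1.png", 0)]

with tempfile.TemporaryDirectory() as tmp:
    index_path = os.path.join(tmp, "train.txt")
    with open(index_path, "w") as fh:
        for path, label in entries:
            fh.write("{} {}\n".format(path, label))
        fh.write("\n")  # a trailing blank line is ignored by the parser
    samples = read_index(index_path)

print(samples)
```

Because the line is split on whitespace, image paths that contain spaces would not parse correctly with this format.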
The results are as follows: