亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch cnn 識別手寫的字實現(xiàn)自建圖片數(shù)據(jù)

 更新時間:2018年05月20日 17:03:26   作者:瓦力冫  
這篇文章主要介紹了pytorch cnn 識別手寫的字實現(xiàn)自建圖片數(shù)據(jù),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧

本文主要介紹了pytorch cnn 識別手寫的字實現(xiàn)自建圖片數(shù)據(jù),分享給大家,具體如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

圖片和label 見上一篇文章《pytorch 把MNIST數(shù)據(jù)集轉(zhuǎn)換成圖片和txt

結(jié)果如下:

以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python深度學(xué)習(xí)albumentations數(shù)據(jù)增強庫

    Python深度學(xué)習(xí)albumentations數(shù)據(jù)增強庫

    下面開始albumenations的正式介紹,在這里我強烈建議英語基礎(chǔ)還好的讀者去官方網(wǎng)站跟著教程一步步學(xué)習(xí),而這里的內(nèi)容主要是我自己的一個總結(jié)以及方便英語能力較弱的讀者學(xué)習(xí)
    2021-09-09
  • Python測試WebService接口的實現(xiàn)示例

    Python測試WebService接口的實現(xiàn)示例

    webService接口是走soap協(xié)議通過http傳輸,請求報文和返回報文都是xml格式的,本文主要介紹了Python測試WebService接口,具有一定的參考價值,感興趣的可以了解一下
    2024-03-03
  • 用Python selenium實現(xiàn)淘寶搶單機器人

    用Python selenium實現(xiàn)淘寶搶單機器人

    今天給大家?guī)淼氖顷P(guān)于Python實戰(zhàn)的相關(guān)知識,文章圍繞著用Python selenium實現(xiàn)淘寶搶單機器人展開,文中有非常詳細的介紹及代碼示例,需要的朋友可以參考下
    2021-06-06
  • python cx_Oracle的基礎(chǔ)使用方法(連接和增刪改查)

    python cx_Oracle的基礎(chǔ)使用方法(連接和增刪改查)

    這篇文章主要給大家介紹了關(guān)于python cx_Oracle的基礎(chǔ)使用方法,其中包括連接、增刪改查等基本操作,并給大家分享了python 連接Oracle 亂碼問題的解決方法,需要的朋友可以參考借鑒,下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧。
    2017-11-11
  • python中的參數(shù)類型匹配提醒

    python中的參數(shù)類型匹配提醒

    這篇文章主要介紹了python中的參數(shù)類型匹配提醒,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-12-12
  • python3大文件解壓和基本操作

    python3大文件解壓和基本操作

    這篇文章主要為大家詳細介紹了python3大文件解壓和基本操作,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2017-12-12
  • python按照多個字符對字符串進行分割的方法

    python按照多個字符對字符串進行分割的方法

    這篇文章主要介紹了python按照多個字符對字符串進行分割的方法,涉及Python中正則表達式匹配的技巧,非常具有實用價值,需要的朋友可以參考下
    2015-03-03
  • 利用python進行數(shù)據(jù)加載

    利用python進行數(shù)據(jù)加載

    今天給大家?guī)淼氖顷P(guān)于Python的相關(guān)知識,文章圍繞著python數(shù)據(jù)加載展開,文中有非常詳細的介紹及代碼示例,需要的朋友可以參考下
    2021-06-06
  • Django上線部署之IIS的配置方法

    Django上線部署之IIS的配置方法

    這篇文章主要介紹了Django上線部署之IIS的配置方法,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-08-08
  • Python讀寫操作csv和excle文件代碼實例

    Python讀寫操作csv和excle文件代碼實例

    這篇文章主要介紹了python讀寫操作csv和excle文件代碼實例,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2020-03-03

最新評論