PyTorch實現(xiàn)卷積神經(jīng)網(wǎng)絡的搭建詳解
PyTorch中實現(xiàn)卷積的重要基礎函數(shù)
1、nn.Conv2d:
nn.Conv2d在pytorch中用于實現(xiàn)卷積。
nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, )
1、in_channels為輸入通道數(shù)。
2、out_channels為輸出通道數(shù)。
3、kernel_size為卷積核大小。
4、stride為步數(shù)。
5、padding為padding情況。
6、dilation表示空洞卷積情況。
2、nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d在pytorch中用于實現(xiàn)最大池化。
具體使用方式如下:
MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
1、kernel_size為池化核的大小
2、stride為步長
3、padding為填充情況
3、nn.ReLU()
nn.ReLU()用來實現(xiàn)Relu函數(shù),實現(xiàn)非線性。
4、x.view()
x.view用于reshape特征層的形狀。
全部代碼
這是一個簡單的CNN模型,用于預測mnist手寫體。
import os import numpy as np import torch import torch.nn as nn import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # 循環(huán)世代 EPOCH = 20 BATCH_SIZE = 50 # 下載mnist數(shù)據(jù)集 train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,) # (60000, 28, 28) print(train_data.train_data.size()) # (60000) print(train_data.train_labels.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 測試集 test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # (2000, 1, 28, 28) # 標準化 test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255. test_y = test_data.test_labels[:2000] # 建立pytorch神經(jīng)網(wǎng)絡 class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() #----------------------------# # 第一部分卷積 #----------------------------# self.conv1 = nn.Sequential( nn.Conv2d( in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 第二部分卷積 #----------------------------# self.conv2 = nn.Sequential( nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 全連接+池化+全連接 #----------------------------# self.ful1 = nn.Linear(64 * 7 * 7, 512) self.drop = nn.Dropout(0.5) self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax()) #----------------------------# # 前向傳播 #----------------------------# def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.ful1(x) x = self.drop(x) output = self.ful2(x) return output cnn = CNN() # 指定優(yōu)化器 optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) # 指定loss函數(shù) loss_func = nn.CrossEntropyLoss() for epoch in range(EPOCH): for step, (b_x, b_y) in enumerate(train_loader): #----------------------------# # 計算loss并修正權值 #----------------------------# output = cnn(b_x) loss = loss_func(output, b_y) optimizer.zero_grad() loss.backward() optimizer.step() #----------------------------# # 打印 #----------------------------# if step % 50 == 0: test_output = cnn(test_x) pred_y = torch.max(test_output, 1)[1].data.numpy() accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
以上就是PyTorch實現(xiàn)卷積神經(jīng)網(wǎng)絡的搭建詳解的詳細內(nèi)容,更多關于PyTorch搭建卷積神經(jīng)網(wǎng)絡的資料請關注腳本之家其它相關文章!
- PyTorch中的神經(jīng)網(wǎng)絡 Mnist 分類任務
- 使用Pytorch構建第一個神經(jīng)網(wǎng)絡模型?附案例實戰(zhàn)
- pytorch簡單實現(xiàn)神經(jīng)網(wǎng)絡功能
- pytorch深度神經(jīng)網(wǎng)絡入門準備自己的圖片數(shù)據(jù)
- Pytorch卷積神經(jīng)網(wǎng)絡遷移學習的目標及好處
- Pytorch深度學習經(jīng)典卷積神經(jīng)網(wǎng)絡resnet模塊訓練
- Pytorch卷積神經(jīng)網(wǎng)絡resent網(wǎng)絡實踐
- Pytorch神經(jīng)網(wǎng)絡參數(shù)管理方法詳細講解
相關文章
Python開啟Http Server的實現(xiàn)步驟
本文主要介紹了Python開啟Http Server的實現(xiàn)步驟,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-07-07在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例
這篇文章主要介紹了在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例,matplotlib模塊經(jīng)常被用來實現(xiàn)數(shù)據(jù)的可視化,需要的朋友可以參考下2015-05-05PyCharm插件開發(fā)實踐之PyGetterAndSetter詳解
這篇文章主要介紹了PyCharm插件開發(fā)實踐-PyGetterAndSetter,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-10-10pyenv虛擬環(huán)境管理python多版本和軟件庫的方法
這篇文章主要介紹了pyenv虛擬環(huán)境管理python多版本和軟件庫,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-12-12