pytorch實現(xiàn)CNN卷積神經(jīng)網(wǎng)絡(luò)
本文為大家講解了pytorch實現(xiàn)CNN卷積神經(jīng)網(wǎng)絡(luò),供大家參考,具體內(nèi)容如下
我對卷積神經(jīng)網(wǎng)絡(luò)的一些認(rèn)識
卷積神經(jīng)網(wǎng)絡(luò)是時下最為流行的一種深度學(xué)習(xí)網(wǎng)絡(luò),由于其具有局部感受野等特性,讓其與人眼識別圖像具有相似性,因此被廣泛應(yīng)用于圖像識別中,本人是研究機械故障診斷方面的,一般利用旋轉(zhuǎn)機械的振動信號作為數(shù)據(jù)。
對一維信號,通常采取的方法有兩種,第一,直接對其做一維卷積,第二,反映到時頻圖像上,這就變成了圖像識別,此前一直都在利用keras搭建網(wǎng)絡(luò),最近學(xué)了pytroch搭建cnn的方法,進(jìn)行一下代碼的嘗試。所用數(shù)據(jù)為經(jīng)典的minist手寫字體數(shù)據(jù)集
import torch import torch.nn as nn import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt `EPOCH = 1 BATCH_SIZE = 50 LR = 0.001 DOWNLOAD_MNIST = True 從網(wǎng)上下載數(shù)據(jù)集: ```python train_data = torchvision.datasets.MNIST( root="./mnist/", train = True, transform=torchvision.transforms.ToTensor(), download = DOWNLOAD_MNIST, ) print(train_data.train_data.size()) print(train_data.train_labels.size()) ```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn') plt.title("%i" % train_data.train_labels[0]) plt.show() train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) test_data = torchvision.datasets.MNIST(root="./mnist/", train=False) test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255. test_y = test_data.test_labels[:2000] class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d( in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2, ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) self.conv2 = nn.Sequential( nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2), ) self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7) output = self.out(x) return output optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) loss_func = nn.CrossEntropyLoss() from matplotlib import cm try: from sklearn.manifold import TSNE; HAS_SK = True except: HAS_SK = False; print('Please install sklearn for layer visualization') def plot_with_labels(lowDWeights, labels): plt.cla() X, Y = lowDWeights[:, 0], lowDWeights[:, 1] for x, y, s in zip(X, Y, labels): c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9) plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01) plt.ion() for epoch in range(EPOCH): for step, (b_x, b_y) in enumerate(train_loader): output = cnn(b_x) loss = loss_func(output, b_y) optimizer.zero_grad() loss.backward() optimizer.step() if step % 50 == 0: test_output = cnn(test_x) pred_y = torch.max(test_output, 1)[1].data.numpy() accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), "| test accuracy: %.2f" % accuracy) plt.ioff()
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
tensorboard 可以顯示graph,卻不能顯示scalar的解決方式
今天小編就為大家分享一篇tensorboard 可以顯示graph,卻不能顯示scalar的解決方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02Pytorch實現(xiàn)圖像識別之?dāng)?shù)字識別(附詳細(xì)注釋)
這篇文章主要介紹了Pytorch實現(xiàn)圖像識別之?dāng)?shù)字識別(附詳細(xì)注釋),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-05-05