Pytorch使用MNIST數據集實現CGAN和生成指定的數字方式
CGAN的全稱是Conditional Generative Adversarial Networks,即條件生成對抗網絡,在初始GAN的基礎上增加了圖片的相應信息(例如類標籤)。
這裡用傳統的卷積方式實現CGAN。
# Train a conditional GAN (CGAN) on MNIST with convolutional networks.
#
# The generator receives 100-d Gaussian noise concatenated with the 10-d
# one-hot class label; the discriminator outputs a 10-d vector that is trained
# towards the one-hot label for real images and towards all zeros for fakes.
import copy
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

# NOTE: matplotlib and torchvision are imported inside the functions that use
# them, so the model classes can be imported with only torch installed.

NUM_CLASSES = 10   # MNIST digits 0-9
NOISE_DIM = 100    # size of the random part of the generator input
# Positions inside a batch that are rendered in the preview grid.
SAMPLE_INDICES = (6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96)


def preview_indices(batch_size):
    """Return the preview positions that exist in a batch of ``batch_size``.

    Falls back to every position when the batch is too small to contain any
    of ``SAMPLE_INDICES``.
    """
    picks = [i for i in SAMPLE_INDICES if i < batch_size]
    return picks or list(range(batch_size))


def one_hot(labels, num_classes=NUM_CLASSES):
    """One-hot encode a 1-D integer label tensor.

    Returns a float tensor of shape ``(len(labels), num_classes)`` on the same
    device as ``labels``. The size follows the actual batch, so a short last
    batch is handled correctly.
    """
    labels = labels.long()
    encoded = torch.zeros(labels.size(0), num_classes, device=labels.device)
    encoded[torch.arange(labels.size(0), device=labels.device), labels] = 1
    return encoded


def sample_generator_input(labels_onehot, noise_dim=NOISE_DIM):
    """Build generator inputs: fresh Gaussian noise followed by the labels.

    Returns a tensor of shape ``(batch, noise_dim + num_classes)``.
    """
    noise = torch.randn(labels_onehot.size(0), noise_dim,
                        device=labels_onehot.device)
    return torch.cat((noise, labels_onehot), dim=1)


def save_model(model, filename):
    """Save ``model``'s state dict with every tensor copied to the CPU.

    The resulting file can be loaded on a machine without CUDA.
    """
    cpu_state = model.state_dict().copy()
    for key in cpu_state:
        cpu_state[key] = cpu_state[key].clone().cpu()
    torch.save(cpu_state, filename)


def showimg(images, count):
    """Save a square preview grid of generated images.

    Args:
        images: tensor of shape (batch, 1, H, H) with values in [-1, 1].
        count: integer used in the output file name
            ``./CGAN/images/<count>.png``.
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt

    images = images.detach().to('cpu').numpy()
    if images.shape[0] == 0:
        return
    images = images[preview_indices(images.shape[0])]
    # Map the Tanh range [-1, 1] to displayable 8-bit grey values.
    images = (255 * (0.5 * images + 0.5)).astype(np.uint8)
    grid_length = int(np.ceil(np.sqrt(images.shape[0])))
    fig = plt.figure(figsize=(4, 4))
    width = images.shape[2]
    gs = gridspec.GridSpec(grid_length, grid_length, wspace=0, hspace=0)
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape(width, width), cmap=plt.cm.gray)
        plt.axis('off')
    plt.tight_layout()  # once per figure is enough
    os.makedirs(r'./CGAN/images', exist_ok=True)
    plt.savefig(r'./CGAN/images/%d.png' % count, bbox_inches='tight')
    # One figure is created per epoch; close it so they do not accumulate.
    plt.close(fig)


def loadMNIST(batch_size, num_workers=10):
    """Download MNIST (28x28 images) and wrap it in data loaders.

    Pixels are normalised to [-1, 1] so that real images share the value range
    of the generator's Tanh output.

    Returns:
        ``(trainset, testset, trainloader, testloader)``.
    """
    from torchvision import transforms
    from torchvision.datasets import MNIST

    trans_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    trainset = MNIST('./data', train=True, transform=trans_img, download=True)
    testset = MNIST('./data', train=False, transform=trans_img, download=True)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    return trainset, testset, trainloader, testloader


class discriminator(nn.Module):
    """Convolutional discriminator: (batch, 1, 28, 28) -> (batch, 10) in [0, 1]."""

    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        )
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 10),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.dis(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


class generator(nn.Module):
    """Convolutional generator: (batch, input_size) -> (batch, 1, 28, 28).

    ``num_feature`` must be 3136 (= 1 * 56 * 56) because the linear output is
    reshaped to a 56x56 single-channel map before the convolutions.
    """

    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # 1*56*56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        self.gen = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),
            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),
            nn.Conv2d(25, 1, 2, stride=2),  # 56x56 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        return self.gen(x)


def main():
    """Alternately train D and G, saving checkpoints and preview images."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.BCELoss()
    num_img = 100
    z_dimension = NOISE_DIM + NUM_CLASSES  # 110
    D = discriminator().to(device)
    G = generator(z_dimension, 3136).to(device)  # 1*56*56
    trainset, testset, trainloader, testloader = loadMNIST(num_img)
    d_optimizer = optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = optim.Adam(G.parameters(), lr=0.0003)
    os.makedirs(r'./CGAN', exist_ok=True)

    # D and G are trained alternately; the number of G updates per D update
    # (gepoch) is a hyper-parameter.
    count = 0
    epoch = 119
    gepoch = 1
    for i in range(epoch):
        for img, label in trainloader:
            img = img.to(device)
            real_label = one_hot(label).to(device)     # target for real images
            fake_label = torch.zeros_like(real_label)  # target for fake images

            # ---- discriminator step (G is held fixed) ----
            real_out = D(img)
            d_loss_real = criterion(real_out, real_label)
            real_scores = real_out
            # Use the same noise+label input distribution G is trained on, and
            # detach so that no gradient is propagated into G here.
            fake_img = G(sample_generator_input(real_label))
            fake_out = D(fake_img.detach())
            d_loss_fake = criterion(fake_out, fake_label)
            fake_scores = fake_out
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- generator step(s) ----
            for _ in range(gepoch):
                fake_img = G(sample_generator_input(real_label))
                output = D(fake_img)
                # G succeeds when D assigns the requested class to the fake.
                g_loss = criterion(output, real_label)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

        if (i % 10 == 0) and (i != 0):
            print(i)
            torch.save(G.state_dict(), r'./CGAN/Generator_cuda_%d.pkl' % i)
            torch.save(D.state_dict(), r'./CGAN/Discriminator_cuda_%d.pkl' % i)
            # Copies that can be opened on a CPU-only machine.
            save_model(G, r'./CGAN/Generator_cpu_%d.pkl' % i)
            save_model(D, r'./CGAN/Discriminator_cpu_%d.pkl' % i)
        # .item() is required: indexing a 0-dim loss tensor with [0] raises.
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
              'D real: {:.6f}, D fake: {:.6f}'.format(
                  i, epoch, d_loss.item(), g_loss.item(),
                  real_scores.mean().item(), fake_scores.mean().item()))
        # Print the class labels requested for the images shown in the grid.
        _, digits = torch.max(real_label.to('cpu'), 1)
        digits = digits.numpy()
        print(digits[preview_indices(len(digits))])
        showimg(fake_img, count)
        count += 1


if __name__ == "__main__":
    main()
和基礎GAN(《Pytorch使用MNIST數據集實現基礎GAN》)裡面的卷積版網絡比較起來,這裡修改的主要是這幾個地方:
生成網絡的輸入值增加了真實圖片的類標籤,生成網絡的初始向量z_dimension之前用的是100維,由於MNIST有10類,Onehot以後一張圖片的類標籤是10維,所以將類標籤拼接在後面,z_dimension=100+10=110維;
訓練生成器的時候,由於生成網絡的輸入向量z_dimension=110維,而且是100維隨機向量和10維真實圖片標籤拼接,需要做相應的拼接操作;
# Build the 110-d generator input: 100-d random noise followed by the 10-d
# one-hot label of each real image.
z = torch.randn(num_img, 100)  # random noise vectors
z = np.concatenate((z.numpy(), labels_onehot), axis=1)
z = Variable(torch.from_numpy(z).float()).cuda()
由於計算Loss和生成網絡的輸入向量都需要用到真實圖片的類標籤,需要重新生成real_label,對label進行onehot。其中real_label就是真實圖片的標籤,當num_img=100時,real_label的維度是(100,10);
# One-hot encode the class labels of the real images.
labels_onehot = np.zeros((num_img, 10))
labels_onehot[np.arange(num_img), label.numpy()] = 1
img = Variable(img).cuda()
# Target for real images: their one-hot class label.
real_label = Variable(torch.from_numpy(labels_onehot).float()).cuda()
# Target for fake images: all zeros.
fake_label = Variable(torch.zeros(num_img, 10)).cuda()
real_label的維度是(100,10),計算Loss的時候也要有對應的維度,判別網絡的輸出也不再是標量,而是要修改為10維;
# Last linear layer of the discriminator: 10 outputs, matching the width of
# the one-hot real_label.
nn.Linear(1024, 10)
在輸出圖片的同時輸出期望的類標籤。
temp = temp.to('cpu')
# torch.max along dim 1 returns two values: the maximum of each row and the
# column index at which that maximum occurs (i.e. the digit class).
_, x = torch.max(temp, 1)
x = x.numpy()
print(x[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]])
epoch等於0、25、50、75、100時訓練的結果:
可以看到訓練到後面圖像反而變模糊,可能是訓練過擬合。
用模型生成指定的數字:
在訓練的過程中保存了訓練好的模型,根據輸出圖片的清晰度,選用清晰度較高的模型,使用隨機向量和10維類標籤來指定生成的數字。
# Generate chosen digits with a trained CGAN generator.
#
# For every digit 0-9 the generator is fed 100-d Gaussian noise concatenated
# with that digit's 10-d one-hot label.
import os
import pickle

import numpy as np
import torch
import torch.nn as nn

# NOTE: matplotlib is imported inside the functions that plot, so the model
# classes can be imported with only torch installed.

num_img = 9  # images generated per digit (= columns of the combined grid)


class discriminator(nn.Module):
    """Convolutional discriminator: (batch, 1, 28, 28) -> (batch, 10) in [0, 1]."""

    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        )
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 10),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.dis(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


class generator(nn.Module):
    """Convolutional generator: (batch, input_size) -> (batch, 1, 28, 28).

    ``num_feature`` must be 3136 (= 1 * 56 * 56) because the linear output is
    reshaped to a 56x56 single-channel map before the convolutions.
    """

    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # 1*56*56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        self.gen = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),
            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),
            nn.Conv2d(25, 1, 2, stride=2),  # 56x56 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        return self.gen(x)


def conditioned_noise(digit, count, noise_dim=100, num_classes=10):
    """Return ``count`` generator inputs that all request ``digit``.

    Each row is ``noise_dim`` Gaussian values followed by the one-hot encoding
    of ``digit``; the result has shape ``(count, noise_dim + num_classes)``.
    """
    noise = torch.randn(count, noise_dim)
    labels = torch.zeros(count, num_classes)
    labels[:, digit] = 1
    return torch.cat((noise, labels), 1)


def _draw_grid(images, rows, cols, figsize):
    """Draw 8-bit images of shape (n, 1, W, W) into a rows x cols grid."""
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    width = images.shape[2]
    gs = gridspec.GridSpec(rows, cols, wspace=0, hspace=0)
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape(width, width), cmap=plt.cm.gray)
        plt.axis('off')
    plt.tight_layout()  # once per figure is enough
    return width


def show(images):
    """Plot one row of generated images (tensor with values in [-1, 1]).

    Returns the image width in pixels. The figure is left open as the current
    matplotlib figure so that the caller can save it.
    """
    images = images.detach().cpu().numpy()
    images = (255 * (0.5 * images + 0.5)).astype(np.uint8)
    return _draw_grid(images, 1, max(len(images), 1), (4, 4))


def show_all(images_all):
    """Plot every batch in ``images_all`` as one grid, ``num_img`` per row.

    ``images_all`` is a list of numpy arrays of shape (n, 1, W, W) with values
    in [-1, 1]. The figure is left open as the current matplotlib figure.
    """
    x = np.concatenate(images_all, 0)
    print(x.shape)
    x = (255 * (0.5 * x + 0.5)).astype(np.uint8)
    rows = int(np.ceil(len(x) / num_img))
    _draw_grid(x, rows, num_img, (9, 10))


def main():
    """Load the trained generator and render the digits 0 to 9."""
    import matplotlib.pyplot as plt

    z_dimension = 110
    G = generator(z_dimension, 3136)  # 1*56*56
    # map_location lets a checkpoint saved from the GPU load on a CPU-only
    # machine. torch.load unpickles the file: only load trusted checkpoints.
    G.load_state_dict(torch.load(r'./CGAN/Generator.pkl', map_location='cpu'))
    # NOTE(review): G stays in training mode as in the original script, so
    # BatchNorm uses the statistics of each 9-image batch; consider G.eval().
    os.makedirs('./CGAN/generator', exist_ok=True)

    lis = []
    with torch.no_grad():  # inference only, no autograd graph needed
        for i in range(10):
            fake_img = G(conditioned_noise(i, num_img))
            lis.append(fake_img.numpy())
            show(fake_img)
            plt.savefig('./CGAN/generator/%d.png' % i, bbox_inches='tight')
    show_all(lis)
    plt.savefig('./CGAN/generator/all.png', bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()
生成的結果是:
以上這篇Pytorch使用MNIST數據集實現CGAN和生成指定的數字方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python基于百度AI實(shí)現(xiàn)抓取表情包
本文先抓取網(wǎng)絡(luò)上的表情圖像,然后利用百度 AI 識(shí)別表情包上的說(shuō)明文字,并利用表情文字重命名文件,感興趣的小伙伴們可以參考一下2021-06-06Python列表list常用內(nèi)建函數(shù)實(shí)例小結(jié)
這篇文章主要介紹了Python列表list常用內(nèi)建函數(shù),結(jié)合實(shí)例形式總結(jié)分析了Python列表list常見(jiàn)內(nèi)建函數(shù)的功能、使用方法及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下2019-10-10不可錯(cuò)過(guò)的十本Python好書(shū)
不可錯(cuò)過(guò)的十本Python好書(shū),分別適合入門(mén)、進(jìn)階到精深三個(gè)不同階段的人來(lái)閱讀,感興趣的小伙伴們可以參考一下2017-07-07Python中你應(yīng)該知道的一些內(nèi)置函數(shù)
python提供了內(nèi)聯(lián)模塊buidin,該模塊定義了一些軟件開(kāi)發(fā)中常用的函數(shù),這些函數(shù)實(shí)現(xiàn)了數(shù)據(jù)類型的轉(zhuǎn)換,數(shù)據(jù)的計(jì)算,序列的處理等功能。下面這篇文章主要給大家介紹了Python中一些大家應(yīng)該知道的內(nèi)置函數(shù),文中總結(jié)的非常詳細(xì),需要的朋友們下面來(lái)一起看看吧。2017-03-03python-pymysql如何實(shí)現(xiàn)更新mysql表中任意字段數(shù)據(jù)
這篇文章主要介紹了python-pymysql如何實(shí)現(xiàn)更新mysql表中任意字段數(shù)據(jù)問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-05-05