亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch:實現(xiàn)簡單的GAN示例(MNIST數據集)

 更新時間:2020年01月10日 09:17:37   作者:xckkcxxck  
今天小編就為大家分享一篇pytorch:實現(xiàn)簡單的GAN示例(MNIST數據集),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

我就廢話不多說了,直接上代碼吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設置畫圖的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定義畫圖工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定義一個取樣的函數
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果
show_images(imgs)
 
#判別網絡
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成網絡
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判別器的 loss 就是將真實數據的得分判斷為 1,假的數據的得分判斷為 0,而生成器的 loss 就是將假的數據判斷為 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵損失函數
 
def discriminator_loss(logits_real, logits_fake): # 判別器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 來進行訓練,學習率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判別網絡
      real_data = Variable(x).view(bs, -1) # 真實數據
      logits_real = D_net(real_data) # 判別網絡得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數據
      logits_fake = D_net(fake_images) # 判別網絡得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 優(yōu)化判別網絡
      
      # 生成網絡
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數據
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成網絡的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 優(yōu)化生成網絡
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)      

以上這篇pytorch:實現(xiàn)簡單的GAN示例(MNIST數據集)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

  • Python PyQt5標準對話框用法示例

    Python PyQt5標準對話框用法示例

    這篇文章主要介紹了Python PyQt5標準對話框用法,結合實例形式分析了PyQt5常用的標準對話框及相關使用技巧,需要的朋友可以參考下
    2017-08-08
  • python之線程通過信號pyqtSignal刷新ui的方法

    python之線程通過信號pyqtSignal刷新ui的方法

    今天小編就為大家分享一篇python之線程通過信號pyqtSignal刷新ui的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • python創(chuàng)建模板文件及使用教程示例

    python創(chuàng)建模板文件及使用教程示例

    這篇文章主要介紹了python創(chuàng)建模板文件及使用教程示例
    2021-10-10
  • 淺談對pytroch中torch.autograd.backward的思考

    淺談對pytroch中torch.autograd.backward的思考

    這篇文章主要介紹了對pytroch中torch.autograd.backward的思考,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-12-12
  • python中import和from-import的區(qū)別解析

    python中import和from-import的區(qū)別解析

    這篇文章主要介紹了python中import和from-import的區(qū)別解析,本文通過實例代碼給大家講解的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2022-12-12
  • Python多路復用selector模塊的基本使用

    Python多路復用selector模塊的基本使用

    Python提供了selector模塊來實現(xiàn)IO多路復用,這篇文章給大家介紹了Python多路復用selector模塊的基本使用,感興趣的朋友一起看看吧
    2021-11-11
  • Python使用虛擬環(huán)境(安裝下載更新卸載)命令

    Python使用虛擬環(huán)境(安裝下載更新卸載)命令

    這篇文章主要為大家介紹了Python使用虛擬環(huán)境(安裝下載更新卸載)命令,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-11-11
  • Python安裝spark的詳細過程

    Python安裝spark的詳細過程

    這篇文章主要介紹了Python安裝spark的詳細過程,本文通過圖文實例代碼相結合給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2021-10-10
  • Python實現(xiàn)拉格朗日插值法的示例詳解

    Python實現(xiàn)拉格朗日插值法的示例詳解

    插值法是一種數學方法,用于在已知數據點(離散數據)之間插入數據,以生成連續(xù)的函數曲線,而格朗日插值法是一種多項式插值法。本文就來用Python實現(xiàn)拉格朗日插值法,希望對大家有所幫助
    2023-02-02
  • python應用之如何使用Python發(fā)送通知到微信

    python應用之如何使用Python發(fā)送通知到微信

    現(xiàn)在通過發(fā)微信信息來做消息通知和告警已經很普遍了,下面這篇文章主要給大家介紹了關于python應用之如何使用Python發(fā)送通知到微信的相關資料,文中通過實例代碼介紹的非常詳細,需要的朋友可以參考下
    2022-03-03

最新評論