亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch 實現(xiàn)變分自動編碼器的操作

 更新時間:2021年05月24日 09:45:47   作者:xckkcxxck  
這篇文章主要介紹了pytorch 實現(xiàn)變分自動編碼器的操作,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

本來以為自動編碼器是很簡單的東西,但是也是看了好多資料仍然不太懂它的原理。先把代碼記錄下來,有時間好好研究。

這個例子是用MNIST數(shù)據(jù)集生成為例子

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
""" 
import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標準化
])
 
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 編碼
        z = self.reparametrize(mu, logvar) # 重新參數(shù)化成正態(tài)分布
        return self.decode(z), mu, logvar # 解碼,同時輸出均值方差 
 
net = VAE() # 實例化網(wǎng)絡
if torch.cuda.is_available():
    net = net.cuda()
    
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu)
 
#可以看到,對于輸入,網(wǎng)絡可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓練
 
#下面開始訓練 
reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定義一個函數(shù)將最后的結(jié)果轉(zhuǎn)換回圖片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))                    
          

補充:PyTorch 深度學習快速入門——變分自動編碼器

變分編碼器是自動編碼器的升級版本,其結(jié)構(gòu)跟自動編碼器是類似的,也由編碼器和解碼器構(gòu)成。

回憶一下,自動編碼器有個問題,就是并不能任意生成圖片,因為我們沒有辦法自己去構(gòu)造隱藏向量,需要通過一張圖片輸入編碼我們才知道得到的隱含向量是什么,這時我們就可以通過變分自動編碼器來解決這個問題。

其實原理特別簡單,只需要在編碼過程給它增加一些限制,迫使其生成的隱含向量能夠粗略的遵循一個標準正態(tài)分布,這就是其與一般的自動編碼器最大的不同。

這樣我們生成一張新圖片就很簡單了,我們只需要給它一個標準正態(tài)分布的隨機隱含向量,這樣通過解碼器就能夠生成我們想要的圖片,而不需要給它一張原始圖片先編碼。

一般來講,我們通過 encoder 得到的隱含向量并不是一個標準的正態(tài)分布,為了衡量兩種分布的相似程度,我們使用 KL divergence,利用其來表示隱含向量與標準正態(tài)分布之間差異的 loss,另外一個 loss 仍然使用生成圖片與原圖片的均方誤差來表示。

KL divergence 的公式如下

重參數(shù) 為了避免計算 KL divergence 中的積分,我們使用重參數(shù)的技巧,不是每次產(chǎn)生一個隱含向量,而是生成兩個向量,一個表示均值,一個表示標準差,這里我們默認編碼之后的隱含向量服從一個正態(tài)分布的之后,就可以用一個標準正態(tài)分布先乘上標準差再加上均值來合成這個正態(tài)分布,最后 loss 就是希望這個生成的正態(tài)分布能夠符合一個標準正態(tài)分布,也就是希望均值為 0,方差為 1

所以最后我們可以將我們的 loss 定義為下面的函數(shù),由均方誤差和 KL divergence 求和得到一個總的 loss

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

用 mnist 數(shù)據(jù)集來簡單說明一下變分自動編碼器

import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標準化
])
 
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 編碼
        z = self.reparametrize(mu, logvar) # 重新參數(shù)化成正態(tài)分布
        return self.decode(z), mu, logvar # 解碼,同時輸出均值方差
 
net = VAE() # 實例化網(wǎng)絡
if torch.cuda.is_available():
    net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu) 
 
Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,對于輸入,網(wǎng)絡可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓練 下面開始訓練

reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定義一個函數(shù)將最后的結(jié)果轉(zhuǎn)換回圖片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))
  
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

變分自動編碼器雖然比一般的自動編碼器效果要好,而且也限制了其輸出的編碼 (code) 的概率分布,但是它仍然是通過直接計算生成圖片和原始圖片的均方誤差來生成 loss,這個方式并不好,生成對抗網(wǎng)絡中,我們會講一講這種方式計算 loss 的局限性,然后會介紹一種新的訓練辦法,就是通過生成對抗的訓練方式來訓練網(wǎng)絡而不是直接比較兩張圖片的每個像素點的均方誤差

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python 中的9個實用技巧,助你提高開發(fā)效率

    python 中的9個實用技巧,助你提高開發(fā)效率

    這篇文章主要介紹了python 中的9個實用技巧,幫助大家提高python開發(fā)時的效率,感興趣的朋友可以了解下
    2020-08-08
  • 深入解讀Python如何進行文件讀寫

    深入解讀Python如何進行文件讀寫

    文件的作用 就是把一些存儲存放起來,可以讓程序下一次執(zhí)行的時候直接使用,而不必重新制作一份,省時省力,本文將帶你了解通過python如何進行文件的讀寫操作
    2021-10-10
  • numpy中幾種隨機數(shù)生成函數(shù)的用法

    numpy中幾種隨機數(shù)生成函數(shù)的用法

    numpy是Python中常用的科學計算庫,其中也包含了一些隨機數(shù)生成函數(shù),本文主要介紹了numpy中幾種隨機數(shù)生成函數(shù)的用法,具有一定的參考價值,感興趣的可以了解一下
    2023-11-11
  • python刪除文件、清空目錄的實現(xiàn)方法

    python刪除文件、清空目錄的實現(xiàn)方法

    這篇文章主要介紹了python刪除文件、清空目錄的實現(xiàn)方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-09-09
  • Pandas執(zhí)行SQL操作的實現(xiàn)

    Pandas執(zhí)行SQL操作的實現(xiàn)

    使用SQL語句能夠完成對table的增刪改查操作,Pandas同樣也可以實現(xiàn)SQL語句的基本功能,本文就來介紹一下,具有一檔的參考價值,感興趣的可以了解一下
    2024-07-07
  • 詳解python深淺拷貝區(qū)別

    詳解python深淺拷貝區(qū)別

    在本篇文章里小編給大家整理了關(guān)于python深淺拷貝區(qū)別的相關(guān)知識點總結(jié),有興趣的朋友們可以參考下。
    2019-06-06
  • Python算法練習之二分查找算法的實現(xiàn)

    Python算法練習之二分查找算法的實現(xiàn)

    二分查找也稱折半查找(Binary Search),它是一種效率較高的查找方法。本文將介紹python如何實現(xiàn)二分查找算法,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2022-06-06
  • 關(guān)于Python中的空值問題及解決

    關(guān)于Python中的空值問題及解決

    這篇文章主要介紹了關(guān)于Python中的空值問題及解決方案,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-11-11
  • Python生成器generator原理及用法解析

    Python生成器generator原理及用法解析

    這篇文章主要介紹了Python生成器generator原理及用法解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-07-07
  • pytorch實現(xiàn)下載加載mnist數(shù)據(jù)集

    pytorch實現(xiàn)下載加載mnist數(shù)據(jù)集

    這篇文章主要介紹了pytorch實現(xiàn)下載加載mnist數(shù)據(jù)集方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2024-06-06

最新評論