亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

使用Pytorch如何完成多分類(lèi)問(wèn)題

 更新時(shí)間:2023年02月02日 11:10:08   作者:LiBiGo  
這篇文章主要介紹了使用Pytorch如何完成多分類(lèi)問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

Pytorch如何完成多分類(lèi)

多分類(lèi)問(wèn)題在最后的輸出層采用的Softmax Layer,其具有兩個(gè)特點(diǎn):1.每個(gè)輸出的值都是在(0,1);2.所有值加起來(lái)和為1.

假設(shè)是最后線性層的輸出,則對(duì)應(yīng)的Softmax function為:

輸出經(jīng)過(guò)sigmoid運(yùn)算即可是西安輸出的分類(lèi)概率都大于0且總和為1。

上圖的交叉熵?fù)p失就包含了softmax計(jì)算和右邊的標(biāo)簽輸入計(jì)算(即框起來(lái)的部分)

所以在使用交叉熵?fù)p失的時(shí)候,神經(jīng)網(wǎng)絡(luò)的最后一層是不要做激活的,因?yàn)榘阉龀煞植嫉募せ钍前诮徊骒負(fù)p失里面的,最后一層不要做非線性變換,直接交給交叉熵?fù)p失。

如上圖,做交叉熵?fù)p失時(shí)要求y是一個(gè)長(zhǎng)整型的張量,構(gòu)造時(shí)直接用

criterion = torch.nn.CrossEntropyLoss()

3個(gè)類(lèi)別,分別是2,0,1

Y_pred1 ,Y_pred2還是線性輸出,沒(méi)經(jīng)過(guò)softmax,還不是概率分布,比如Y_pred1,0.9最大,表示對(duì)應(yīng)為第3個(gè)的概率最大,和2吻合,1.1最大,表示對(duì)應(yīng)為第1個(gè)的概率最大,和0吻合,2.1最大,表示對(duì)應(yīng)為第2個(gè)的概率最大,和1吻合,那么Y_pred1 的損失會(huì)比較小

對(duì)于Y_pred2,0.8最大,表示對(duì)應(yīng)為第1個(gè)的概率最大,和0不吻合,0.5最大,表示對(duì)應(yīng)為第3個(gè)的概率最大,和2不吻合,0.5最大,表示對(duì)應(yīng)為第3個(gè)的概率最大,和2不吻合,那么Y_pred2 的損失會(huì)比較大

Exercise 9-1: CrossEntropyLoss vs NLLLoss

What are the differences?

• Reading the document:

https://pytorch.org/docs/stable/nn.html#crossentropyloss

https://pytorch.org/docs/stable/nn.html#nllloss

• Try to know why:

• CrossEntropyLoss <==> LogSoftmax + NLLLoss

為什么要用transform

transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])

PyTorch讀圖像用的是python的imageLibrary,就是PIL,現(xiàn)在用的都是pillow,pillow讀進(jìn)來(lái)的圖像用神經(jīng)網(wǎng)絡(luò)處理的時(shí)候,神經(jīng)網(wǎng)絡(luò)有一個(gè)特點(diǎn)就是希望輸入的數(shù)值比較小,最好是在-1到+1之間,最好是輸入遵從正態(tài)分布,這樣的輸入對(duì)神經(jīng)網(wǎng)絡(luò)訓(xùn)練是最有幫助的

原始圖像是28*28的像素值在0到255之間,我們把它轉(zhuǎn)變成圖像張量,像素值是0到1

在視覺(jué)里面,灰度圖就是一個(gè)矩陣,但實(shí)際上并不是一個(gè)矩陣,我們把它叫做單通道圖像,彩色圖像是3通道,通道有寬度和高度,一般我們讀進(jìn)來(lái)的圖像張量是WHC(寬高通道)

在PyTorch里面我們需要轉(zhuǎn)化成CWH,把通道放在前面是為了在PyTorch里面進(jìn)行更高效的圖像處理,卷積運(yùn)算。所以拿到圖像之后,我們就把它先轉(zhuǎn)化成pytorch里面的一個(gè)Tensor,把0到255的值變成0到1的浮點(diǎn)數(shù),然后把維度由2828變成128*28的張量,由單通道變成多通道,

這個(gè)過(guò)程可以用transforms的ToTensor這個(gè)函數(shù)實(shí)現(xiàn)

歸一化

transforms.Normalize((0.1307, ), (0.3081, ))

這里的0.1307,0.3081是對(duì)Mnist數(shù)據(jù)集所有的像素求均值方差得到的

也就是說(shuō),將來(lái)拿到了圖像,先變成張量,然后Normalize,切換到0,1分布,然后供神經(jīng)網(wǎng)絡(luò)訓(xùn)練

如上圖,定義好transform變換之后,直接把它放到數(shù)據(jù)集里面,為什么要放在數(shù)據(jù)集里面呢,是為了在讀取第i個(gè)數(shù)據(jù)的時(shí)候,直接用transform處理

 

模型

輸入是一組圖像,激活層改用Relu

全連接神經(jīng)網(wǎng)絡(luò)要求輸入是一個(gè)矩陣

所以需要把輸入的張量變成一階的,這里的N表示有N個(gè)圖片

view函數(shù)可以改變張量的形狀,-1表示將來(lái)自動(dòng)去算它的值是多少,比如輸入是n128*28

將來(lái)會(huì)自動(dòng)把n算出來(lái),輸入了張量就知道形狀,就知道有多少個(gè)數(shù)值

最后輸出是(N,10)因?yàn)槭怯?-9這10個(gè)標(biāo)簽嘛,10表示該圖像屬于某一個(gè)標(biāo)簽的概率,現(xiàn)在還是線性值,我們?cè)儆胹oftmax把它變成概率

 #沿著第一個(gè)維度找最大值的下標(biāo),返回值有兩個(gè),因?yàn)槭?0列嘛,返回值一個(gè)是每一行的最大值,另一個(gè)是最大值的下標(biāo)(每一個(gè)樣本就是一行,每一行有10個(gè)量)(行是第0個(gè)維度,列是第1個(gè)維度)

MNIST數(shù)據(jù)集訓(xùn)練代碼

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
 
# prepare dataset
 
batch_size = 64
 
transform = transforms.Compose([
    transforms.ToTensor(), #先將圖像變換成一個(gè)張量tensor。
    transforms.Normalize((0.1307,), (0.3081,))
    #其中的0.1307是MNIST數(shù)據(jù)集的均值,0.3081是MNIST數(shù)據(jù)集的標(biāo)準(zhǔn)差。
])  # 歸一化,均值和方差
 
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,
                               download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
 
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,
                               download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
 
# design model using class
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)
 
    def forward(self, x):
        # 28 * 28 = 784
        # 784 = 28 * 28,即將N *1*28*28轉(zhuǎn)化成 N *1*784
        x = x.view(-1, 784)  # -1其實(shí)就是自動(dòng)獲取mini_batch
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # 最后一層不做激活,不進(jìn)行非線性變換
 
model = Net()
 
#CrossEntropyLoss <==> LogSoftmax + NLLLoss。
#也就是說(shuō)使用CrossEntropyLoss最后一層(線性層)是不需要做其他變化的;
#使用NLLLoss之前,需要對(duì)最后一層(線性層)先進(jìn)行SoftMax處理,再進(jìn)行l(wèi)og操作。
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
#momentum 是帶有優(yōu)化的一個(gè)訓(xùn)練過(guò)程參數(shù)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
def train(epoch):
    running_loss = 0.0
    #enumerate()函數(shù)用于將一個(gè)可遍歷的數(shù)據(jù)對(duì)象(如列表、元組或字符串)組合為一個(gè)索引序列,
    #同時(shí)列出數(shù)據(jù)和數(shù)據(jù)下標(biāo),一般用在 for 循環(huán)當(dāng)中。
    #enumerate(sequence, [start=0])
    for batch_idx, data in enumerate(train_loader, 0):
        # 獲得一個(gè)批次的數(shù)據(jù)和標(biāo)簽
        inputs, target = data
        optimizer.zero_grad()
 
        #forward + backward + update
        # 獲得模型預(yù)測(cè)結(jié)果(64, 10)
        outputs = model(inputs)
        # 交叉熵代價(jià)函數(shù)outputs(64,10),target(64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():#不需要計(jì)算梯度。
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            #orch.max的返回值有兩個(gè),第一個(gè)是每一行的最大值是多少,第二個(gè)是每一行最大值的下標(biāo)(索引)是多少。
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0個(gè)維度,行是第1個(gè)維度
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # 張量之間的比較運(yùn)算
    print('accuracy on test set: %d %% ' % (100 * correct / total))
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python之np.where()如何替換缺失值

    Python之np.where()如何替換缺失值

    這篇文章主要介紹了Python中的np.where()如何替換缺失值問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2024-02-02
  • 詳解python單元測(cè)試框架unittest

    詳解python單元測(cè)試框架unittest

    本篇文章給大家詳解了python單元測(cè)試框架unittest的相關(guān)知識(shí)點(diǎn),有興趣的朋友參考學(xué)習(xí)下。
    2018-07-07
  • 在linux系統(tǒng)下安裝python librtmp包的實(shí)現(xiàn)方法

    在linux系統(tǒng)下安裝python librtmp包的實(shí)現(xiàn)方法

    今天小編就為大家分享一篇在linux系統(tǒng)下安裝python librtmp包的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-07-07
  • 利用python模擬sql語(yǔ)句對(duì)員工表格進(jìn)行增刪改查

    利用python模擬sql語(yǔ)句對(duì)員工表格進(jìn)行增刪改查

    這篇文章主要給大家介紹了關(guān)于利用python模擬sql語(yǔ)句實(shí)現(xiàn)對(duì)員工表格進(jìn)行增刪改查的相關(guān)資料,文中介紹了詳細(xì)的需求以及示例代碼,對(duì)大家的理解和學(xué)習(xí)具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起看看吧。
    2017-07-07
  • 淺談四種快速易用的Python數(shù)據(jù)可視化方法

    淺談四種快速易用的Python數(shù)據(jù)可視化方法

    這篇文章主要介紹了淺談四種快速易用的Python數(shù)據(jù)可視化方法,數(shù)據(jù)可視化,是指用圖形的方式來(lái)展現(xiàn)數(shù)據(jù),從而更加清晰有效地傳遞信息,主要方法包括圖表類(lèi)型的選擇和圖表設(shè)計(jì)的準(zhǔn)則,需要的朋友可以參考下
    2023-04-04
  • Python3 itchat實(shí)現(xiàn)微信定時(shí)發(fā)送群消息的實(shí)例代碼

    Python3 itchat實(shí)現(xiàn)微信定時(shí)發(fā)送群消息的實(shí)例代碼

    使用微信,定時(shí)往指定的微信群里發(fā)送指定信息。接下來(lái)通過(guò)本文給大家分享Python3 itchat實(shí)現(xiàn)微信定時(shí)發(fā)送群消息的實(shí)例代碼,需要的朋友可以參考下
    2019-07-07
  • pandas 根據(jù)列的值選取所有行的示例

    pandas 根據(jù)列的值選取所有行的示例

    今天小編就為大家分享一篇pandas 根據(jù)列的值選取所有行的示例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-11-11
  • python連接mongodb操作數(shù)據(jù)示例(mongodb數(shù)據(jù)庫(kù)配置類(lèi))

    python連接mongodb操作數(shù)據(jù)示例(mongodb數(shù)據(jù)庫(kù)配置類(lèi))

    這篇文章主要介紹了python連接mongodb操作數(shù)據(jù)示例,主要包括插入數(shù)據(jù)、更新數(shù)據(jù)、查詢數(shù)據(jù)、刪除數(shù)據(jù)等
    2013-12-12
  • Python中分支語(yǔ)句與循環(huán)語(yǔ)句實(shí)例詳解

    Python中分支語(yǔ)句與循環(huán)語(yǔ)句實(shí)例詳解

    這篇文章主要給大家介紹了關(guān)于Python中分支語(yǔ)句與循環(huán)語(yǔ)句的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2018-09-09
  • Python串口通信的接收與發(fā)送的實(shí)現(xiàn)

    Python串口通信的接收與發(fā)送的實(shí)現(xiàn)

    串口通信是指通過(guò)串口進(jìn)行數(shù)據(jù)傳輸?shù)囊环N通信方式,本文就來(lái)介紹一下Python串口通信的接收與發(fā)送的實(shí)現(xiàn),具有一定的參考價(jià)值,感興趣的可以了解一下
    2023-11-11

最新評(píng)論