PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸
學(xué)習(xí)總結(jié)
(1)和上一講的模型訓(xùn)練是類似的,只是在線性模型的基礎(chǔ)上加個(gè)sigmoid,然后loss函數(shù)改為交叉熵BCE函數(shù)(當(dāng)然也可以用其他函數(shù)),另外一開始的數(shù)據(jù)y_data也從數(shù)值改為類別0和1(本例為二分類,注意x_data
和y_data
這里也是矩陣的形式)。
一、sigmoid函數(shù)
logistic function是一種sigmoid函數(shù)(還有其他sigmoid函數(shù)),但由于使用過于廣泛,pytorch默認(rèn)logistic function叫為sigmoid函數(shù)。還有如下的各種sigmoid函數(shù):
二、和Linear的區(qū)別
邏輯斯蒂和線性模型的unit區(qū)別如下圖:
sigmoid
函數(shù)是不需要參數(shù)的,所以不用對其初始化(直接調(diào)用nn.functional.sigmoid
即可)。
另外loss函數(shù)從MSE改用交叉熵BCE:盡可能和真實(shí)分類貼近。
如下圖右方表格所示,當(dāng) y ^ \hat{y} y^越接近y時(shí)則BCE Loss值越小。
三、邏輯斯蒂回歸(分類)PyTorch實(shí)現(xiàn)
# -*- coding: utf-8 -*- """ Created on Mon Oct 18 08:35:00 2021 @author: 86493 """ import torch import torch.nn as nn import matplotlib.pyplot as plt import torch.nn.functional as F import numpy as np # 準(zhǔn)備數(shù)據(jù) x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]]) losslst = [] class LogisticRegressionModel(nn.Module): def __init__(self): super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 和線性模型的網(wǎng)絡(luò)的唯一區(qū)別在這句,多了F.sigmoid y_predict = F.sigmoid(self.linear(x)) return y_predict model = LogisticRegressionModel() # 使用交叉熵作損失函數(shù) criterion = torch.nn.BCELoss(size_average = False) optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 訓(xùn)練 for epoch in range(1000): y_predict = model(x_data) loss = criterion(y_predict, y_data) # 打印loss對象會自動(dòng)調(diào)用__str__ print(epoch, loss.item()) losslst.append(loss.item()) # 梯度清零后反向傳播 optimizer.zero_grad() loss.backward() optimizer.step() # 畫圖 plt.plot(range(1000), losslst) plt.ylabel('Loss') plt.xlabel('epoch') plt.show() # test # 每周學(xué)習(xí)的時(shí)間,200個(gè)點(diǎn) x = np.linspace(0, 10, 200) x_t = torch.Tensor(x).view((200, 1)) y_t = model(x_t) y = y_t.data.numpy() plt.plot(x, y) # 畫 probability of pass = 0.5的紅色橫線 plt.plot([0, 10], [0.5, 0.5], c = 'r') plt.xlabel('Hours') plt.ylabel('Probability of Pass') plt.grid() plt.show()
可以看出處于通過和不通過的分界線是Hours=2.5。
Reference
到此這篇關(guān)于PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸的文章就介紹到這了,更多相關(guān)PyTorch 邏輯斯蒂回歸內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python之如何將標(biāo)簽轉(zhuǎn)化為one-hot(獨(dú)熱編碼)
這篇文章主要介紹了python之如何將標(biāo)簽轉(zhuǎn)化為one-hot(獨(dú)熱編碼)問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-06-06Python中的復(fù)制操作及copy模塊中的淺拷貝與深拷貝方法
淺拷貝和深拷貝是Python基礎(chǔ)學(xué)習(xí)中必須辨析的知識點(diǎn),這里我們將為大家解析Python中的復(fù)制操作及copy模塊中的淺拷貝與深拷貝方法:2016-07-07Python使用Selenium、PhantomJS爬取動(dòng)態(tài)渲染頁面
本文主要介紹了Python使用Selenium、PhantomJS爬取動(dòng)態(tài)渲染頁面,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-05-05python程序中斷然后接著中斷代碼繼續(xù)運(yùn)行問題
這篇文章主要介紹了python程序中斷然后接著中斷代碼繼續(xù)運(yùn)行問題,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-02-02python實(shí)現(xiàn)跨excel的工作表sheet之間的復(fù)制方法
今天小編就為大家分享一篇python實(shí)現(xiàn)跨excel的工作表sheet之間的復(fù)制方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05Python數(shù)據(jù)分析之?Pandas?Dataframe應(yīng)用自定義
這篇文章主要介紹了Python數(shù)據(jù)分析之?Pandas?Dataframe應(yīng)用自定義,文章基于python的相關(guān)資料展開?Pandas?Dataframe應(yīng)用自定義的詳細(xì)內(nèi)容,需要的小伙伴可以參考一下2022-05-05pandas如何將DataFrame?轉(zhuǎn)為txt文本去除引號
這篇文章主要介紹了pandas如何將DataFrame?轉(zhuǎn)為txt文本去除引號,文中補(bǔ)充介紹了DataFrame導(dǎo)CSV?txt?||?每行有雙引號的原因及解決辦法,感興趣的朋友跟隨小編一起看看吧2024-01-01PyChar學(xué)習(xí)教程之自定義文件與代碼模板詳解
pycharm默認(rèn)的【新建】文件,格式很不友好,那么就需要改一下文件模板。下面這篇文章主要給大家介紹了關(guān)于PyChar學(xué)習(xí)教程之自定義文件與代碼模板的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),需要的朋友們下面跟著小編來一起看看吧。2017-07-07