使用Pytorch實(shí)現(xiàn)Swish激活函數(shù)的示例詳解
前言
激活函數(shù)是人工神經(jīng)網(wǎng)絡(luò)的基本組成部分。他們將非線性引入模型,使其能夠?qū)W習(xí)數(shù)據(jù)中的復(fù)雜關(guān)系。Swish 激活函數(shù)就是此類激活函數(shù)之一,因其獨(dú)特的屬性和相對于廣泛使用的整流線性單元 (ReLU) 激活的潛在優(yōu)勢而受到關(guān)注。在本文中,我們將深入研究 Swish 激活函數(shù),提供數(shù)學(xué)公式,探索其相對于 ReLU 的優(yōu)勢,并使用 PyTorch 演示其實(shí)現(xiàn)。
Swish 激活功能
Swish 激活函數(shù)由 Google 研究人員于 2017 年推出,其數(shù)學(xué)定義如下:
Swish(x) = x * sigmoid(x)
Where:
- x:激活函數(shù)的輸入值。
- sigmoid(x):sigmoid 函數(shù),將任何實(shí)數(shù)值映射到范圍 [0, 1]。隨著 x 的增加,它從 0 平滑過渡到 1。
Swish 激活將線性分量(輸入 x)與非線性分量(sigmoid函數(shù))相結(jié)合,產(chǎn)生平滑且可微的激活函數(shù)。
在哪里使用 Swish 激活?
Swish 可用于各種神經(jīng)網(wǎng)絡(luò)架構(gòu),包括前饋神經(jīng)網(wǎng)絡(luò)、卷積神經(jīng)網(wǎng)絡(luò) (CNN) 和循環(huán)神經(jīng)網(wǎng)絡(luò) (RNN)。它的優(yōu)勢在深度網(wǎng)絡(luò)中變得尤為明顯,它可以幫助緩解梯度消失問題。
Swish 激活函數(shù)相對于 ReLU 的優(yōu)點(diǎn)
現(xiàn)在,我們來探討一下 Swish 激活函數(shù)與流行的 ReLU 激活函數(shù)相比的優(yōu)勢。
平滑度和可微分性
由于 sigmoid 分量的存在,Swish 是一個平滑且可微的函數(shù)。此屬性使其非常適合基于梯度的優(yōu)化技術(shù),例如隨機(jī)梯度下降 (SGD) 和反向傳播。相比之下,ReLU 在零處不可微(ReLU 的導(dǎo)數(shù)在 x=0 時未定義),這可能會帶來優(yōu)化挑戰(zhàn)。
改進(jìn)深度網(wǎng)絡(luò)的學(xué)習(xí)
在深度神經(jīng)網(wǎng)絡(luò)中,與 ReLU 相比,Swish 可以實(shí)現(xiàn)更好的學(xué)習(xí)和收斂。Swish 的平滑性有助于梯度在網(wǎng)絡(luò)中更平滑地流動,減少訓(xùn)練期間梯度消失的可能性。這在非常深的網(wǎng)絡(luò)中尤其有用。
類似的計算成本
Swish 激活的計算效率很高,類似于 ReLU。這兩個函數(shù)都涉及基本的算術(shù)運(yùn)算,不會顯著增加訓(xùn)練或推理過程中的計算負(fù)擔(dān)。
使用 PyTorch 實(shí)現(xiàn)
現(xiàn)在,我們來看看如何使用 PyTorch 實(shí)現(xiàn) Swish 激活函數(shù)。我們將創(chuàng)建一個自定義 Swish 模塊并將其集成到一個簡單的神經(jīng)網(wǎng)絡(luò)中。
讓我們從導(dǎo)入必要的庫開始。
import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader
一旦我們完成了庫的導(dǎo)入,我們就可以定義自定義激活——Swish。
以下代碼定義了一個繼承 PyTorch 基類的類。類內(nèi)部有一個forward方法。該方法定義模塊如何處理輸入數(shù)據(jù)。它將輸入張量作為參數(shù),并在應(yīng)用 Swish 激活后返回輸出張量。
# Swish功能 class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x)
定義 Swish 類后,我們繼續(xù)定義神經(jīng)網(wǎng)絡(luò)模型。
在下面的代碼片段中,我們使用 PyTorch 定義了一個專為圖像分類任務(wù)設(shè)計的神經(jīng)網(wǎng)絡(luò)模型。
輸入層有28×28像素。
隱藏層
- 第一個隱藏層由 256 個神經(jīng)元組成。它采用扁平輸入并應(yīng)用線性變換來產(chǎn)生輸出。
- 第二個隱藏層由 128 個神經(jīng)元組成,從前一層獲取 256 維輸出并產(chǎn)生 128 維輸出。
- Swish 激活函數(shù)應(yīng)用于兩個隱藏層,以向網(wǎng)絡(luò)引入非線性。
- 輸出層由 10 個神經(jīng)元組成,用于執(zhí)行 10 個類別的分類。
# 定義神經(jīng)網(wǎng)絡(luò)模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28 * 28, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 10) self.swish = Swish() def forward(self, x): x = x.view(-1, 28 * 28) x = self.fc1(x) x = self.swish(x) x = self.fc2(x) x = self.swish(x) x = self.fc3(x) return x
為了設(shè)置用于訓(xùn)練的神經(jīng)網(wǎng)絡(luò),我們創(chuàng)建模型的實(shí)例,定義損失函數(shù)、優(yōu)化器和數(shù)據(jù)轉(zhuǎn)換。
# 創(chuàng)建模型的實(shí)例 model = Net() # 定義損失函數(shù)和優(yōu)化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 定義數(shù)據(jù)轉(zhuǎn)換 transform = transforms.Compose([ transforms.ToTensor(), ])
完成此步驟后,我們可以繼續(xù)在數(shù)據(jù)集上訓(xùn)練和評估模型。讓我們使用以下代碼加載 MNIST 數(shù)據(jù)并創(chuàng)建用于訓(xùn)練的數(shù)據(jù)加載器。
# 加載MNIST數(shù)據(jù)集 train_dataset = datasets.MNIST('', train=True, download=True, transform=transform) test_dataset = datasets.MNIST('', train=False, download=True, transform=transform) # 創(chuàng)建數(shù)據(jù)加載器 train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
有了這些數(shù)據(jù)加載器,我們就可以繼續(xù)訓(xùn)練循環(huán)來迭代批量的訓(xùn)練和測試數(shù)據(jù)。
在下面的代碼中,我們執(zhí)行了神經(jīng)網(wǎng)絡(luò)的訓(xùn)練循環(huán)。該循環(huán)將重復(fù) 5 個時期,在此期間更新模型的權(quán)重,以最大限度地減少損失并提高其在訓(xùn)練數(shù)據(jù)上的性能。
# 訓(xùn)練循環(huán) num_epochs = 5 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) for epoch in range(num_epochs): model.train() total_loss = 0.0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")
輸出:
Epoch 1/5, Loss: 1.6938323568503062
Epoch 2/5, Loss: 0.4569567457397779
Epoch 3/5, Loss: 0.3522500048557917
Epoch 4/5, Loss: 0.31695075702369213
Epoch 5/5, Loss: 0.2961081813474496
最后一步是模型評估步驟。
# 評估循環(huán) model.eval() correct = 0 total = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) outputs = model(data) _, predicted = torch.max(outputs.data, 1) total += target.size(0) correct += (predicted == target).sum().item() print(f"Accuracy on test set: {100 * correct / total}%")
輸出:
Accuracy on test set: 92.02%
結(jié)論
Swish 激活函數(shù)為 ReLU 等傳統(tǒng)激活函數(shù)提供了一種有前景的替代方案。它的平滑性、可微性和改善深度網(wǎng)絡(luò)學(xué)習(xí)的潛力使其成為現(xiàn)代神經(jīng)網(wǎng)絡(luò)架構(gòu)的寶貴工具。通過在 PyTorch 中實(shí)施 Swish,您可以利用其優(yōu)勢并探索其在各種機(jī)器學(xué)習(xí)任務(wù)中的有效性。
以上就是使用Pytorch實(shí)現(xiàn)Swish激活函數(shù)的示例詳解的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Swish激活函數(shù)的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
靈活運(yùn)用Python 枚舉類來實(shí)現(xiàn)設(shè)計狀態(tài)碼信息
在python中枚舉是一種類(Enum,IntEnum),存放在enum模塊中。枚舉類型可以給一組標(biāo)簽賦予一組特定的值,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-09-09Matplotlib實(shí)戰(zhàn)之柱狀圖繪制詳解
柱狀圖,是一種使用矩形條,對不同類別進(jìn)行數(shù)值比較的統(tǒng)計圖表,這篇文章主要為大家詳細(xì)介紹了如何使用Matplotlib繪制柱狀圖,需要的可以參考下2023-08-08Python教程pandas數(shù)據(jù)分析去重復(fù)值
Pandas指定行進(jìn)行去重更新值,加載數(shù)據(jù)sample抽樣函數(shù),指定需要更新的值append直接添加append函數(shù)用法,根據(jù)某一列key值進(jìn)行去重key唯一2021-09-09跟老齊學(xué)Python之使用Python操作數(shù)據(jù)庫(1)
本文詳細(xì)講述了使用python操作數(shù)據(jù)庫所需要了解的知識以及準(zhǔn)備工作,十分的詳盡,這里推薦給想學(xué)習(xí)python的小伙伴。2014-11-11使用FastCGI部署Python的Django應(yīng)用的教程
這篇文章主要介紹了使用FastCGI部署Python的Django應(yīng)用的教程,FastCGI也是被最廣泛的應(yīng)用于Python框架和服務(wù)器連接的模塊,需要的朋友可以參考下2015-07-07使用Python處理KNN分類算法的實(shí)現(xiàn)代碼
KNN分類算法(K-Nearest-Neighbors?Classification),又叫K近鄰算法,是一個概念極其簡單,而分類效果又很優(yōu)秀的分類算法,這篇文章主要介紹了使用Python處理KNN分類算法,需要的朋友可以參考下2022-09-09python反射機(jī)制內(nèi)置函數(shù)及場景構(gòu)造詳解
這篇文章主要為大家介紹了python反射機(jī)制內(nèi)置函數(shù)及場景構(gòu)造示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-11-11