亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch自定義不可導(dǎo)激活函數(shù)的操作

 更新時(shí)間:2021年06月05日 14:46:53   作者:Luna_Lovegood_001  
這篇文章主要介紹了pytorch自定義不可導(dǎo)激活函數(shù)的操作,具有很好的參考價(jià)值,希望大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

pytorch自定義不可導(dǎo)激活函數(shù)

今天自定義不可導(dǎo)函數(shù)的時(shí)候遇到了一個(gè)大坑。

首先我需要自定義一個(gè)函數(shù):sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

然后我需要把它封裝為一個(gè)module 類型,就像 nn.Conv2d 模塊 封裝 f.conv2d 一樣,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定義函數(shù)
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

結(jié)果報(bào)錯(cuò)

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我試了半天,發(fā)現(xiàn)自定義函數(shù)后面要加 apply ,詳細(xì)見下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### <-----注意此處
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

問題解決了!

PyTorch自定義帶學(xué)習(xí)參數(shù)的激活函數(shù)(如sigmoid)

有的時(shí)候我們需要給損失函數(shù)設(shè)一個(gè)超參數(shù)但是又不想設(shè)固定閾值想和網(wǎng)絡(luò)一起自動(dòng)學(xué)習(xí),例如給Sigmoid一個(gè)參數(shù)alpha進(jìn)行調(diào)節(jié)

在這里插入圖片描述

在這里插入圖片描述

函數(shù)如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

驗(yàn)證和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

輸出結(jié)果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)

驗(yàn)證權(quán)重是不是會(huì)更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

輸出結(jié)果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

會(huì)更新~

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python常用數(shù)據(jù)庫接口sqlite3和MySQLdb學(xué)習(xí)指南

    Python常用數(shù)據(jù)庫接口sqlite3和MySQLdb學(xué)習(xí)指南

    在本章節(jié)中,我們將學(xué)習(xí) Python 中常用的數(shù)據(jù)庫接口,包括 sqlite3用于SQLite數(shù)據(jù)庫和MySQLdb用于 MySQL 數(shù)據(jù)庫,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-06-06
  • 音頻處理 windows10下python三方庫librosa安裝教程

    音頻處理 windows10下python三方庫librosa安裝教程

    這篇文章主要介紹了音頻處理 windows10下python三方庫librosa安裝方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-06-06
  • 淺析python遞歸函數(shù)和河內(nèi)塔問題

    淺析python遞歸函數(shù)和河內(nèi)塔問題

    這篇文章主要介紹了python遞歸函數(shù)和河內(nèi)塔問題,非常不錯(cuò),具有參考借鑒價(jià)值,需要的朋友可以參考下
    2017-04-04
  • Python成功解決TypeError: ‘method’ object is not subscriptable

    Python成功解決TypeError: ‘method’ object is 

    在Python編程中,有時(shí)候我們可能會(huì)遇到一個(gè)讓人摸不著頭腦的錯(cuò)誤信息:TypeError: 'method' object is not subscriptable,本文給大家介紹了Python如何成功解決TypeError: ‘method’ object is not subscriptable,需要的朋友可以參考下
    2024-06-06
  • python基礎(chǔ)之貪婪模式與非貪婪模式

    python基礎(chǔ)之貪婪模式與非貪婪模式

    這篇文章主要介紹了python貪婪模式與非貪婪模式 ,實(shí)例分析了Python中返回一個(gè)返回值與多個(gè)返回值的方法,需要的朋友可以參考下
    2021-10-10
  • python版微信跳一跳游戲輔助

    python版微信跳一跳游戲輔助

    這篇文章主要為大家詳細(xì)介紹了python版微信跳一跳游戲輔助,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-01-01
  • Python列表常見操作詳解(獲取,增加,刪除,修改,排序等)

    Python列表常見操作詳解(獲取,增加,刪除,修改,排序等)

    這篇文章主要介紹了Python列表常見操作,結(jié)合實(shí)例形式總結(jié)分析了Python列表常見的獲取、增加、刪除、修改、排序、計(jì)算等相關(guān)操作技巧,需要的朋友可以參考下
    2019-02-02
  • Python數(shù)據(jù)分析numpy文本數(shù)據(jù)讀取索引切片實(shí)例詳解

    Python數(shù)據(jù)分析numpy文本數(shù)據(jù)讀取索引切片實(shí)例詳解

    這篇文章主要為大家介紹了Python數(shù)據(jù)分析numpy文本數(shù)據(jù)讀取索引切片實(shí)例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-08-08
  • python 循環(huán)遍歷字典元素的簡單方法

    python 循環(huán)遍歷字典元素的簡單方法

    下面小編就為大家?guī)硪黄猵ython循環(huán)遍歷字典元素的簡單方法。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2016-09-09
  • Python面向?qū)ο笾K詳解

    Python面向?qū)ο笾K詳解

    這篇文章主要為大家介紹了Python面向?qū)ο笾K,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助
    2021-12-12

最新評論