pytorch自定義二值化網(wǎng)絡(luò)層方式
任務(wù)要求:
自定義一個層主要是定義該層的實現(xiàn)函數(shù),只需要重載Function的forward和backward函數(shù)即可,如下:
import torch from torch.autograd import Function from torch.autograd import Variable
定義二值化函數(shù)
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) a = torch.ones_like(input) b = -torch.ones_like(input) output = torch.where(input>=0,a,b) return output def backward(self, output_grad): input, = self.saved_tensors input_abs = torch.abs(input) ones = torch.ones_like(input) zeros = torch.zeros_like(input) input_grad = torch.where(input_abs<=1,ones, zeros) return input_grad
定義一個module
class BinarizedModule(nn.Module): def __init__(self): super(BinarizedModule, self).__init__() self.BF = BinarizedF() def forward(self,input): print(input.shape) output =self.BF(input) return output
進(jìn)行測試
a = Variable(torch.randn(4,480,640), requires_grad=True) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print(a) print(a.grad)
其中, 二值化函數(shù)部分也可以按照方式寫,但是速度慢了0.05s
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) output = torch.ones_like(input) output[input<0] = -1 return output def backward(self, output_grad): input, = self.saved_tensors input_grad = output_grad.clone() input_abs = torch.abs(input) input_grad[input_abs>1] = 0 return input_grad
以上這篇pytorch自定義二值化網(wǎng)絡(luò)層方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
PyCharm連接遠(yuǎn)程服務(wù)器的超級詳細(xì)教程
Pycharm可以與服務(wù)器建立連接,把相應(yīng)的項目同步到服務(wù)器上,下面這篇文章主要給大家介紹了關(guān)于PyCharm連接遠(yuǎn)程服務(wù)器的超級詳細(xì)教程,文中通過圖文介紹的非常詳細(xì),需要的朋友可以參考下2022-12-12Qt實現(xiàn)炫酷啟動圖動態(tài)進(jìn)度條效果
最近接到一個新需求,讓做一個動效進(jìn)度條。剛接手這個項目真的不知所措,后來慢慢理清思路,問題迎刃而解,下面小編通過本文給大家?guī)砹薗t實現(xiàn)炫酷啟動圖動態(tài)進(jìn)度條效果,感興趣的朋友一起看看吧2021-11-11pandas按行按列遍歷Dataframe的三種方式小結(jié)
本文主要介紹了pandas按行按列遍歷Dataframe,主要介紹了三種方法,具有一定的參考價值,感興趣的可以了解一下2023-11-11python入門學(xué)習(xí)關(guān)于for else的特殊特性講解
本文將介紹 Python 中的" for-else"特性,并通過簡單的示例說明如何正確使用它,有需要的朋友可以借鑒參考下,希望能夠有所幫助2021-11-11python numpy中array與pandas的DataFrame轉(zhuǎn)換方式
這篇文章主要介紹了python numpy中array與pandas的DataFrame轉(zhuǎn)換方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07paramiko模塊安裝和使用(遠(yuǎn)程登錄服務(wù)器)
paramiko是用python語言寫的一個模塊,遵循SSH2協(xié)議,支持以加密和認(rèn)證的方式,進(jìn)行遠(yuǎn)程服務(wù)器的連接,下面學(xué)習(xí)一下它的使用方法2014-01-01詳解OpenCV中直方圖,掩膜和直方圖均衡化的實現(xiàn)
這篇文章主要為大家詳細(xì)介紹了OpenCV中直方圖、掩膜、直方圖均衡化詳細(xì)介紹及代碼的實現(xiàn),文中的示例代碼講解詳細(xì),需要的可以參考一下2022-11-11