pytorch自定義二值化網絡層方式
更新時間:2020年01月07日 13:44:24 作者:ChLee98
今天小編就為大家分享一篇pytorch自定義二值化網絡層方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
任務要求:
自定義一個層主要是定義該層的實現(xiàn)函數(shù),只需要重載Function的forward和backward函數(shù)即可,如下:
import torch from torch.autograd import Function from torch.autograd import Variable
定義二值化函數(shù)
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) a = torch.ones_like(input) b = -torch.ones_like(input) output = torch.where(input>=0,a,b) return output def backward(self, output_grad): input, = self.saved_tensors input_abs = torch.abs(input) ones = torch.ones_like(input) zeros = torch.zeros_like(input) input_grad = torch.where(input_abs<=1,ones, zeros) return input_grad
定義一個module
class BinarizedModule(nn.Module): def __init__(self): super(BinarizedModule, self).__init__() self.BF = BinarizedF() def forward(self,input): print(input.shape) output =self.BF(input) return output
進行測試
a = Variable(torch.randn(4,480,640), requires_grad=True) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print(a) print(a.grad)
其中, 二值化函數(shù)部分也可以按照方式寫,但是速度慢了0.05s
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) output = torch.ones_like(input) output[input<0] = -1 return output def backward(self, output_grad): input, = self.saved_tensors input_grad = output_grad.clone() input_abs = torch.abs(input) input_grad[input_abs>1] = 0 return input_grad
以上這篇pytorch自定義二值化網絡層方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python numpy中array與pandas的DataFrame轉換方式
這篇文章主要介紹了python numpy中array與pandas的DataFrame轉換方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07詳解OpenCV中直方圖,掩膜和直方圖均衡化的實現(xiàn)
這篇文章主要為大家詳細介紹了OpenCV中直方圖、掩膜、直方圖均衡化詳細介紹及代碼的實現(xiàn),文中的示例代碼講解詳細,需要的可以參考一下2022-11-11