欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中的自定義反向傳播,求導(dǎo)實(shí)例

 更新時(shí)間:2020年01月06日 14:54:58   作者:xuxiaoyuxuxiaoyu  
今天小編就為大家分享一篇pytorch中的自定義反向傳播,求導(dǎo)實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

pytorch中自定義backward()函數(shù)。在圖像處理過(guò)程中,我們有時(shí)候會(huì)使用自己定義的算法處理圖像,這些算法多是基于numpy或者scipy等包。

那么如何將自定義算法的梯度加入到pytorch的計(jì)算圖中,能使用Loss.backward()操作自動(dòng)求導(dǎo)并優(yōu)化呢。下面的代碼展示了這個(gè)功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想實(shí)現(xiàn)自動(dòng)求導(dǎo),必須同時(shí)實(shí)現(xiàn)forward(),backward()兩個(gè)函數(shù)。

1、從代碼中可以看出來(lái),forward()函數(shù)是針對(duì)numpy數(shù)據(jù)操作,返回值再重新指定為torch.Tensor類型。因此就有這個(gè)問(wèn)題出現(xiàn)了:forward輸入input被轉(zhuǎn)換為numpy類型,輸出轉(zhuǎn)換為tensor類型,那么輸出output的grad_fn參數(shù)是如何指定的呢。調(diào)試發(fā)現(xiàn),當(dāng)main()中hr的requires_grad被指定為True,即hr被指定為需要求導(dǎo)的葉子節(jié)點(diǎn)。只要Bicubic類繼承自torch.autograd.Function,那么output也就是代碼中的lr的grad_fn就會(huì)被指定為<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic這個(gè)類。

2、backward()為求導(dǎo)的函數(shù),gard_output是鏈?zhǔn)角髮?dǎo)法則的上一級(jí)的梯度,grad_input即為我們想要得到的梯度。只需要在輸入指定grad_output,在調(diào)用loss.backward()過(guò)程中的某一步會(huì)執(zhí)行到Bicubic的backwward()函數(shù)

以上這篇pytorch中的自定義反向傳播,求導(dǎo)實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Flask框架運(yùn)用WTForms實(shí)現(xiàn)用戶注冊(cè)的示例詳解

    Flask框架運(yùn)用WTForms實(shí)現(xiàn)用戶注冊(cè)的示例詳解

    WTForms 是用于web開(kāi)發(fā)的靈活的表單驗(yàn)證和呈現(xiàn)庫(kù),它可以與您選擇的任何web框架和模板引擎一起工作,并支持?jǐn)?shù)據(jù)驗(yàn)證、CSRF保護(hù)、國(guó)際化等。本文將運(yùn)用WTForms實(shí)現(xiàn)用戶注冊(cè)功能,需要的可以參考一下
    2022-12-12
  • Pytorch轉(zhuǎn)tflite方式

    Pytorch轉(zhuǎn)tflite方式

    這篇文章主要介紹了Pytorch轉(zhuǎn)tflite方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2020-05-05
  • 使用python的pexpect模塊,實(shí)現(xiàn)遠(yuǎn)程免密登錄的示例

    使用python的pexpect模塊,實(shí)現(xiàn)遠(yuǎn)程免密登錄的示例

    今天小編就為大家分享一篇使用python的pexpect模塊,實(shí)現(xiàn)遠(yuǎn)程免密登錄的示例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-02-02
  • Python命令啟動(dòng)Web服務(wù)器實(shí)例詳解

    Python命令啟動(dòng)Web服務(wù)器實(shí)例詳解

    這篇文章主要介紹了Python命令啟動(dòng)Web服務(wù)器實(shí)例詳解的相關(guān)資料,需要的朋友可以參考下
    2017-02-02
  • python sklearn中tsne算法降維結(jié)果不一致問(wèn)題的解決方法

    python sklearn中tsne算法降維結(jié)果不一致問(wèn)題的解決方法

    最近在做一個(gè)文本聚類的分析,在對(duì)文本數(shù)據(jù)embedding后,想著看下數(shù)據(jù)的分布,于是用sklearn的TSNE算法來(lái)降維embedding后的數(shù)據(jù)結(jié)果,當(dāng)在多次執(zhí)行后,竟發(fā)現(xiàn)TSNE的結(jié)果竟然變了,而且每次都不一樣,所以本文就給大家講講如何解決sklearn中tsne算法降維結(jié)果不一致的問(wèn)題
    2023-10-10
  • Python Pycharm虛擬下百度飛漿PaddleX安裝報(bào)錯(cuò)問(wèn)題及處理方法(親測(cè)100%有效)

    Python Pycharm虛擬下百度飛漿PaddleX安裝報(bào)錯(cuò)問(wèn)題及處理方法(親測(cè)100%有效)

    最近很多朋友給小編留言在安裝PaddleX的時(shí)候總是出現(xiàn)各種奇葩問(wèn)題,不知道該怎么處理,今天小編通過(guò)本文給大家介紹下Python Pycharm虛擬下百度飛漿PaddleX安裝報(bào)錯(cuò)問(wèn)題及處理方法,真的有效,遇到同樣問(wèn)題的朋友快來(lái)參考下吧
    2021-05-05
  • Python構(gòu)造函數(shù)與析構(gòu)函數(shù)超詳細(xì)分析

    Python構(gòu)造函數(shù)與析構(gòu)函數(shù)超詳細(xì)分析

    在python之中定義一個(gè)類的時(shí)候會(huì)在類中創(chuàng)建一個(gè)名為_(kāi)_init__的函數(shù),這個(gè)函數(shù)就叫做構(gòu)造函數(shù)。它的作用就是在實(shí)例化類的時(shí)候去自動(dòng)的定義一些屬性和方法的值,而析構(gòu)函數(shù)恰恰是一個(gè)和它相反的函數(shù),這篇文章主要介紹了Python構(gòu)造函數(shù)與析構(gòu)函數(shù)
    2022-11-11
  • Python中循環(huán)后使用list.append()數(shù)據(jù)被覆蓋問(wèn)題的解決

    Python中循環(huán)后使用list.append()數(shù)據(jù)被覆蓋問(wèn)題的解決

    這篇文章主要給大家介紹了關(guān)于Python中循環(huán)后使用list.append()數(shù)據(jù)被覆蓋問(wèn)題的解決方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2018-07-07
  • 值得收藏,Python 開(kāi)發(fā)中的高級(jí)技巧

    值得收藏,Python 開(kāi)發(fā)中的高級(jí)技巧

    這篇文章主要介紹了Python 開(kāi)發(fā)中的高級(jí)技巧,非常不錯(cuò),具有收藏價(jià)值,感興趣的朋友一起看看吧
    2018-11-11
  • Pytorch中的 torch.distributions庫(kù)詳解

    Pytorch中的 torch.distributions庫(kù)詳解

    這篇文章主要介紹了Pytorch中的 torch.distributions庫(kù),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-02-02

最新評(píng)論