欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch 如何檢查模型梯度是否可導(dǎo)

 更新時(shí)間:2021年06月05日 11:44:43   作者:煙雨風(fēng)渡  
這篇文章主要介紹了PyTorch 檢查模型梯度是否可導(dǎo)的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

一、PyTorch 檢查模型梯度是否可導(dǎo)

當(dāng)我們構(gòu)建復(fù)雜網(wǎng)絡(luò)模型或在模型中加入復(fù)雜操作時(shí),可能會(huì)需要驗(yàn)證該模型或操作是否可導(dǎo),即模型是否能夠優(yōu)化,在PyTorch框架下,我們可以使用torch.autograd.gradcheck函數(shù)來實(shí)現(xiàn)這一功能。

首先看一下官方文檔中關(guān)于該函數(shù)的介紹:

可以看到官方文檔中介紹了該函數(shù)基于何種方法,以及其參數(shù)列表,下面給出幾個(gè)例子介紹其使用方法,注意:

Tensor需要是雙精度浮點(diǎn)型且設(shè)置requires_grad = True

第一個(gè)例子:檢查某一操作是否可導(dǎo)

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

輸出為:

Are the gradients correct: True

第二個(gè)例子:檢查某一網(wǎng)絡(luò)模型是否可導(dǎo)

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

輸出為:

Are the gradients correct: True

二、Pytorch求導(dǎo)

1.標(biāo)量對(duì)矩陣求導(dǎo)

在這里插入圖片描述

驗(yàn)證:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩陣,注意,值必須要是float類型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都為默認(rèn)(默認(rèn)為False),所以求導(dǎo)時(shí),沒有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩陣對(duì)矩陣求導(dǎo)

在這里插入圖片描述 在這里插入圖片描述

驗(yàn)證:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩陣
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩陣
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括號(hào)里要加上這句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad為True的數(shù)組必須是float類型

進(jìn)行backgrad的必須是標(biāo)量,如果是向量,必須在后面括號(hào)里加上torch.ones_like(X)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python隊(duì)列Queue實(shí)現(xiàn)詳解

    Python隊(duì)列Queue實(shí)現(xiàn)詳解

    這篇文章主要介紹了Python隊(duì)列Queue實(shí)現(xiàn)詳解,隊(duì)列是一種列表,隊(duì)列用于存儲(chǔ)按順序排列的數(shù)據(jù),隊(duì)列是一種先進(jìn)先出的數(shù)據(jù)結(jié)構(gòu),不同的是隊(duì)列只能在隊(duì)尾插入元素,在隊(duì)首刪除元素,需要的朋友可以參考下
    2023-07-07
  • Python自動(dòng)發(fā)送郵件的方法實(shí)例總結(jié)

    Python自動(dòng)發(fā)送郵件的方法實(shí)例總結(jié)

    這篇文章主要介紹了Python自動(dòng)發(fā)送郵件的方法,結(jié)合實(shí)例形式總結(jié)分析了Python使用smtplib和email模塊發(fā)送郵件的相關(guān)使用技巧與操作注意事項(xiàng),需要的朋友可以參考下
    2018-12-12
  • Python實(shí)現(xiàn)簡單截取中文字符串的方法

    Python實(shí)現(xiàn)簡單截取中文字符串的方法

    這篇文章主要介紹了Python實(shí)現(xiàn)簡單截取中文字符串的方法,涉及Python字符串截取與編碼轉(zhuǎn)換的相關(guān)技巧,需要的朋友可以參考下
    2015-06-06
  • 解決Pycharm中import時(shí)無法識(shí)別自己寫的程序方法

    解決Pycharm中import時(shí)無法識(shí)別自己寫的程序方法

    今天小編就為大家分享一篇解決Pycharm中import時(shí)無法識(shí)別自己寫的程序方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2018-05-05
  • Python圖像識(shí)別+KNN求解數(shù)獨(dú)的實(shí)現(xiàn)

    Python圖像識(shí)別+KNN求解數(shù)獨(dú)的實(shí)現(xiàn)

    這篇文章主要介紹了Python圖像識(shí)別+KNN求解數(shù)獨(dú)的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11
  • Python format()格式化輸出方法

    Python format()格式化輸出方法

    這篇文章主要介紹了Python format()格式化輸出方法, Python 2.6以后,Python 中的就提供了字符串類型(str)提供了 format() 方法對(duì)字符串進(jìn)行格式化,夏敏我們就來了解這個(gè)方法吧,需要的小伙伴也可以參考一下

    2021-12-12
  • Python爬取股票信息,并可視化數(shù)據(jù)的示例

    Python爬取股票信息,并可視化數(shù)據(jù)的示例

    這篇文章主要介紹了Python爬取股票信息,并可視化數(shù)據(jù)的示例,幫助大家更好的理解和使用python爬蟲,感興趣的朋友可以了解下
    2020-09-09
  • PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決

    PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決

    這篇文章主要介紹了PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-03-03
  • Python實(shí)現(xiàn)的字典值比較功能示例

    Python實(shí)現(xiàn)的字典值比較功能示例

    這篇文章主要介紹了Python實(shí)現(xiàn)的字典值比較功能,可實(shí)現(xiàn)針對(duì)字典格式數(shù)據(jù)的判斷、比較功能,涉及Python字典格式數(shù)據(jù)的遍歷、判斷等相關(guān)操作技巧,需要的朋友可以參考下
    2018-01-01
  • TensorFlow tf.nn.conv2d_transpose是怎樣實(shí)現(xiàn)反卷積的

    TensorFlow tf.nn.conv2d_transpose是怎樣實(shí)現(xiàn)反卷積的

    這篇文章主要介紹了TensorFlow tf.nn.conv2d_transpose是怎樣實(shí)現(xiàn)反卷積的,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-04-04

最新評(píng)論