Pytorch?linear?多維輸入的參數問題
問題: 由于 在輸入lstm 層 每個batch 做了根據輸入序列最大長度做了padding,導致每個 batch 的 length 不同。 導致輸出 長度不同 。如:(batch, length, output_dim): (12,128,10),(12,111,10). 但是輸入 linear 層的時候沒有出現問題。
網站解釋:
官網 pytorch linear:
- Input:(*, H_{in})(∗,Hin?)where*∗means any number of dimensions including none andH_{in} = \text{in\_features}Hin?=in_features. 任意維度 number 理解有歧義 (a)number. k可以理解三維,四維。。。 (b) 可以理解 為某一維度的數 。
- Output:(*, H_{out})(∗,Hout?)where all but the last dimension are the same shape as the input andH_{out} = \text{out\_features}Hout?=out_features.
代碼解釋:
分別 用三維 和二維輸入數組,查看他們參數數目是否一樣。
import torch x = torch.randn(128, 20) # 輸入的維度是(128,20) m = torch.nn.Linear(20, 30) # 20,30是指維度 output = m(x) print('m.weight.shape:\n ', m.weight.shape) print('m.bias.shape:\n', m.bias.shape) print('output.shape:\n', output.shape) # ans = torch.mm(input,torch.t(m.weight))+m.bias 等價于下面的 ans = torch.mm(x, m.weight.t()) + m.bias print('ans.shape:\n', ans.shape) print(torch.equal(ans, output))
output:
m.weight.shape: torch.Size([30, 20]) m.bias.shape: torch.Size([30]) output.shape: torch.Size([128, 30]) ans.shape: torch.Size([128, 30]) True
x = torch.randn(128, 30,20) # 輸入的維度是(128,30,20) m = torch.nn.Linear(20, 30) # 20,30是指維度 output = m(x) print('m.weight.shape:\n ', m.weight.shape) print('m.bias.shape:\n', m.bias.shape) print('output.shape:\n', output.shape)
ouput: m.weight.shape: torch.Size([30, 20]) m.bias.shape: torch.Size([30]) output.shape: torch.Size([128, 30, 30])
結果:
(128,30,20),和 (128,20) 分別是如 nn.linear(30,20) 層。
weight.shape 均為: (30,20)
linear() 參數數目只和 input_dim ,output_dim 有關。
weight 在源碼的定義, 沒找到如何計算多維input的代碼。
到此這篇關于Pytorch linear 多維 輸入的參數的文章就介紹到這了,更多相關Pytorch多維 輸入內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python Unittest根據不同測試環(huán)境跳過用例的方法
這篇文章主要給大家介紹了關于Python Unittest如何根據不同測試環(huán)境跳過用例的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面來一起看看吧2018-12-12Django框架創(chuàng)建mysql連接與使用示例
這篇文章主要介紹了Django框架創(chuàng)建mysql連接與使用,簡單介紹了Linux環(huán)境下mysql的安裝,并結合實例形式分析了Django框架基于第三方庫pymysql連接mysql數據庫相關操作技巧,需要的朋友可以參考下2019-07-07python matplotlib 畫dataframe的時間序列圖實例
今天小編就為大家分享一篇python matplotlib 畫dataframe的時間序列圖實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-11-11