pytorch網(wǎng)絡模型構建場景的問題介紹
記錄使用pytorch構建網(wǎng)絡模型過程遇到的點
1. 網(wǎng)絡模型構建中的問題
1.1 輸入變量是Tensor張量
各個模塊和網(wǎng)絡模型的輸入, 一定要是tensor
張量;
可以用一個列表存放多個張量。
如果是張量維度不夠,需要升維度,
可以先使用 torch.unsqueeze(dim = expected)
然后再使用torch.cat(dim )
進行拼接;
需要傳遞梯度的數(shù)據(jù),禁止使用numpy
, 也禁止先使用numpy,然后再轉換成張量的這種情況出現(xiàn);
這是因為pytorch的機制是只有是 Tensor
張量的類型,才會有梯度等屬性值,如果是numpy這些類別,這些變量并會丟失其梯度值。
1.2 __init__()方法使用
class ex: def __init__(self): pass
__init__
方法必須接受至少一個參數(shù)即self,
Python中,self是指向該對象本身的一個引用,
通過在類的內(nèi)部使用self變量,
類中的方法可以訪問自己的成員變量,簡單來說,self.varname的意義為”訪問該對象的varname屬性“
當然,__init__()
中可以封裝任意的程序邏輯,這是允許的,init()方法還接受任意多個其他參數(shù),允許在初始化時提供一些數(shù)據(jù),例如,對于剛剛的worker類,可以這樣寫:
class worker: def __init__(self,name,pay): self.name=name self.pay=pay
這樣,在創(chuàng)建worker類的對象時,必須提供name和pay兩個參數(shù):
b=worker('Jim',5000)
Python會自動調(diào)用worker.init()方法,并傳遞參數(shù)。
細節(jié)參考這里init方法
1.3 內(nèi)置函數(shù)setattr()
此時,可以使用python自帶的內(nèi)置函數(shù) setattr()
, 和對應的getattr()
setattr(object, name, value)
object – 對象。
name – 字符串,對象屬性。
value – 屬性值。
對已存在的屬性進行賦值:
>>>class A(object):
... bar = 1
...
>>> a = A()
>>> getattr(a, 'bar') # 獲取屬性 bar 值
1
>>> setattr(a, 'bar', 5) # 設置屬性 bar 值
>>> a.bar
5
如果屬性不存在會創(chuàng)建一個新的對象屬性,并對屬性賦值:>>>class A():
... name = "runoob"
...
>>> a = A()
>>> setattr(a, "age", 28)
>>> print(a.age)
28
>>>
setattr() 語法
setattr(object, name, value)
object – 對象。
name – 字符串,對象屬性。
value – 屬性值。
1.4 網(wǎng)絡模型的構建
注意到, 在python的 __init__()
函數(shù)中, self
本身就是該類的對象的一個引用,即self是指向該對象本身的一個引用,
利用上述這一點,當在神經(jīng)網(wǎng)絡中,
需要給多個屬性進行實例化時,
且這多個屬性使用的是同一個類進行實例化.
則使用 setattr(self, string, object1)
添加屬性;
class Temporal_GroupTrans(nn.Module): def __init__(self, num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True): super(Temporal_GroupTrans, self).__init__() conv_block = Basic_slide_conv() for i in range( num_groups): setattr(self, "group" + str(i), conv_block) # 自定義transformer模型的初始化, CustomTransformerModel() 在該類中傳入初始化模型的參數(shù), # nip:512 輸入序列中,每個列向量的編碼維度, 16: 注意力頭的個數(shù) # 600: 中間mlp 隱藏層的維數(shù), 6: 堆疊transforEncode 編碼模塊的個數(shù); self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)
則使用 getattr(self, string, object1)
獲取屬性;
trans_input_sequence = [] for i in range(0, num_groups, ): # 每組語譜圖的大小是一個 (bt, ch,96,12)的矩陣,組與組之間沒有重疊; cur_group = x[:, :, :, 12 * i:12 * (i + 1)] # VARIABLE_fun = "self.group" # 每一組,與之對應的卷積模塊; # cur_fun = eval(VARIABLE_fun + str(i )) cur_fun = getattr(self, 'group'+str(i)) cur_group_out = cur_fun(cur_group).unsqueeze(dim=1) # [bt,1, 512] trans_input_sequence.append(cur_group_out)
到此這篇關于pytorch網(wǎng)絡模型構建場景的問題介紹的文章就介紹到這了,更多相關pytorch網(wǎng)絡模型構建內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
django第一個項目127.0.0.1:8000不能訪問的解決方案詳析
django項目服務啟動后無法通過127.0.0.1訪問,下面這篇文章主要給大家介紹了關于django第一個項目127.0.0.1:8000不能訪問的解決方案,需要的朋友可以參考下2022-10-10解決Python運行文件出現(xiàn)out of memory框的問題
今天小編就為大家分享一篇解決Python運行文件出現(xiàn)out of memory框的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12