對(duì)pytorch網(wǎng)絡(luò)層結(jié)構(gòu)的數(shù)組化詳解
最近再寫(xiě)openpose,它的網(wǎng)絡(luò)結(jié)構(gòu)是多階段的網(wǎng)絡(luò),所以寫(xiě)網(wǎng)絡(luò)的時(shí)候很想用列表的方式,但是直接使用列表不能將網(wǎng)絡(luò)中相應(yīng)的部分放入到cuda中去。
其實(shí)這個(gè)問(wèn)題很簡(jiǎn)單的,使用moduleList就好了。
1 我先是定義了一個(gè)函數(shù),用來(lái)根據(jù)超參數(shù),建立一個(gè)基礎(chǔ)網(wǎng)絡(luò)結(jié)構(gòu)
stage = [[3, 3, 3, 1, 1], [7, 7, 7, 7, 7, 1, 1]] branches_cfg = [[[128, 128, 128, 512, 38], [128, 128, 128, 512, 19]], [[128, 128, 128, 128, 128, 128, 38], [128, 128, 128, 128, 128, 128, 19]]] # used for add two branches as well as adapt to certain stage def add_extra(i, branches_cfg, stage): """ only add CNN of brancdes S & L in stage Ti at the end of net :param in_channels:the input channels & out :param stage: size of filter :param branches_cfg: channels of image :return:list of layers """ in_channels = i layers = [] for k in range(len(stage)): padding = stage[k] // 2 conv2d = nn.Conv2d(in_channels, branches_cfg[k], kernel_size=stage[k], padding=padding) layers += [conv2d, nn.ReLU(inplace=True)] in_channels = branches_cfg[k] return layers
2 然后用普通列表裝載他們
conf_bra_list = [] paf_bra_list = [] # param for branch network in_channels = 128 for i in range(all_stage): if i > 0: branches = branches_cfg[1] conv_sz = stage[1] else: branches = branches_cfg[0] conv_sz = stage[0] conf_bra_list.append(nn.Sequential(*add_extra(in_channels, branches[0], conv_sz))) paf_bra_list.append(nn.Sequential(*add_extra(in_channels, branches[1], conv_sz))) in_channels = 185
3 再然后,使用moduleList方法,把普通列表專(zhuān)成pytorch下的模塊
# to list self.conf_bra = nn.ModuleList(conf_bra_list) self.paf_bra = nn.ModuleList(paf_bra_list)
4 最后,調(diào)用就好了
out_0 = x # the base transform for k in range(len(self.vgg)): out_0 = self.vgg[k](out_0) # local name space name = locals() confs = [] pafs = [] outs = [] length = len(self.conf_bra) for i in range(length): name['conf_%s' % (i + 1)] = self.conf_bra[i](name['out_%s' % i]) name['paf_%s' % (i + 1)] = self.paf_bra[i](name['out_%s' % i]) name['out_%s' % (i + 1)] = torch.cat([name['conf_%s' % (i + 1)], name['paf_%s' % (i + 1)], out_0], 1) confs.append('conf_%s' % (i + 1)) pafs.append('paf_%s' % (i + 1)) outs.append('out_%s' % (i + 1))
5 順便裝了一下,使用了python局部變量命名空間,name = locals(),其實(shí)完全使用普通列表保存變量就好了,高興就好。
以上這篇對(duì)pytorch網(wǎng)絡(luò)層結(jié)構(gòu)的數(shù)組化詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python 實(shí)現(xiàn)人和電腦猜拳的示例代碼
這篇文章主要介紹了python 實(shí)現(xiàn)人和電腦猜拳的示例代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03python基于win32api實(shí)現(xiàn)鍵盤(pán)輸入
這篇文章主要介紹了python基于win32api實(shí)現(xiàn)鍵盤(pán)輸入,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-12-12opencv3/C++ 平面對(duì)象識(shí)別&透視變換方式
今天小編就為大家分享一篇opencv3/C++ 平面對(duì)象識(shí)別&透視變換方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-12-12PyTorch實(shí)現(xiàn)重寫(xiě)/改寫(xiě)Dataset并載入Dataloader
這篇文章主要介紹了PyTorch實(shí)現(xiàn)重寫(xiě)/改寫(xiě)Dataset并載入Dataloader,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-07-07Python的10道簡(jiǎn)單測(cè)試題(含答案)
這篇文章主要介紹了Python的10道簡(jiǎn)單測(cè)試題(含答案),學(xué)習(xí)了一段時(shí)間python的小伙伴來(lái)做幾道測(cè)試題檢驗(yàn)一下自己的學(xué)習(xí)成果吧2023-04-04Python return語(yǔ)句如何實(shí)現(xiàn)結(jié)果返回調(diào)用
這篇文章主要介紹了Python return語(yǔ)句如何實(shí)現(xiàn)結(jié)果返回調(diào)用,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10