
An Effective Method for Converting PyTorch to Keras, Explained with FlowNet as an Example

Updated: May 26, 2020, 08:43:52   Author: 咆哮的阿杰
This article presents an effective method for converting PyTorch models to Keras, using FlowNet as a worked example. It should serve as a useful reference.

Thanks to its dynamic-graph mechanism, PyTorch has seen wide adoption and looks poised to overtake TensorFlow, but in engineering deployments TF still has the edge. Sometimes we need to put a model into production and apply it to a real project, and TF's PB files together with its C++ interface become the practical tools. This post explains how to convert a PyTorch model into Keras; since Keras runs on top of TensorFlow, we can then obtain a PB file from it. Producing the PB file and using it will be covered in the next post.

Pytorch To Keras

First, be clear about one thing: the so-called PyTorch-to-Keras (or Keras-to-PyTorch) conversion tools floating around online and on GitHub mostly either fail to run at all or only work for a limited set of models. What these conversion scripts are good for is revealing the differences between the two frameworks, such as the shape of the parameters and the channel ordering (channels first or last). Once you understand those differences, you can write the conversion code yourself, which is the most reliable approach. The whole process then splits into two steps. I will use NVIDIA's open-source FlowNet as the example and convert the open-source PyTorch code into a Keras model.

1. Following the structure of the PyTorch model, write the corresponding Keras code; the Keras functional API makes this very convenient.

2. Assign the PyTorch model's parameters to the Keras model layer by layer, matching on layer names.

These two steps look simple, but in practice I took quite a few detours. The key question is whether a parameter's shape is the same in both frameworks; of course it is not. Below I illustrate with FlowNet.

The FlowNet code in PyTorch

We only show the layer names and layer parameters rather than pasting the entire structure, which would take up a lot of space.

First, the flowNet model built with Keras; we print the model information directly with model.summary():

__________________________________________________________________________________________________
Layer (type)   Output Shape  Param # Connected to   
==================================================================================================
input_1 (InputLayer)  (None, 6, 512, 512) 0      
__________________________________________________________________________________________________
conv0 (Conv2D)   (None, 64, 512, 512) 3520 input_1[0][0]   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 512, 512) 0  conv0[0][0]   
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 64, 514, 514) 0  leaky_re_lu_1[0][0]  
__________________________________________________________________________________________________
conv1 (Conv2D)   (None, 64, 256, 256) 36928 zero_padding2d_1[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 64, 256, 256) 0  conv1[0][0]   
__________________________________________________________________________________________________
conv1_1 (Conv2D)  (None, 128, 256, 256) 73856 leaky_re_lu_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 128, 256, 256) 0  conv1_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 128, 258, 258) 0  leaky_re_lu_3[0][0]  
__________________________________________________________________________________________________
conv2 (Conv2D)   (None, 128, 128, 128) 147584 zero_padding2d_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 128, 128, 128) 0  conv2[0][0]   
__________________________________________________________________________________________________
conv2_1 (Conv2D)  (None, 128, 128, 128) 147584 leaky_re_lu_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 128) 0  conv2_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 128, 130, 130) 0  leaky_re_lu_5[0][0]  
__________________________________________________________________________________________________
conv3 (Conv2D)   (None, 256, 64, 64) 295168 zero_padding2d_3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 64, 64) 0  conv3[0][0]   
__________________________________________________________________________________________________
conv3_1 (Conv2D)  (None, 256, 64, 64) 590080 leaky_re_lu_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 256, 64, 64) 0  conv3_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 256, 66, 66) 0  leaky_re_lu_7[0][0]  
__________________________________________________________________________________________________
conv4 (Conv2D)   (None, 512, 32, 32) 1180160 zero_padding2d_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 512, 32, 32) 0  conv4[0][0]   
__________________________________________________________________________________________________
conv4_1 (Conv2D)  (None, 512, 32, 32) 2359808 leaky_re_lu_8[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 512, 32, 32) 0  conv4_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, 512, 34, 34) 0  leaky_re_lu_9[0][0]  
__________________________________________________________________________________________________
conv5 (Conv2D)   (None, 512, 16, 16) 2359808 zero_padding2d_5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 512, 16, 16) 0  conv5[0][0]   
__________________________________________________________________________________________________
conv5_1 (Conv2D)  (None, 512, 16, 16) 2359808 leaky_re_lu_10[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 512, 16, 16) 0  conv5_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 512, 18, 18) 0  leaky_re_lu_11[0][0]  
__________________________________________________________________________________________________
conv6 (Conv2D)   (None, 1024, 8, 8) 4719616 zero_padding2d_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 1024, 8, 8) 0  conv6[0][0]   
__________________________________________________________________________________________________
conv6_1 (Conv2D)  (None, 1024, 8, 8) 9438208 leaky_re_lu_12[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 1024, 8, 8) 0  conv6_1[0][0]   
__________________________________________________________________________________________________
deconv5 (Conv2DTranspose) (None, 512, 16, 16) 8389120 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
predict_flow6 (Conv2D)  (None, 2, 8, 8) 18434 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 512, 16, 16) 0  deconv5[0][0]   
__________________________________________________________________________________________________
upsampled_flow6_to_5 (Conv2DTra (None, 2, 16, 16) 66  predict_flow6[0][0]  
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 1026, 16, 16) 0  leaky_re_lu_11[0][0]  
         leaky_re_lu_14[0][0]  
         upsampled_flow6_to_5[0][0] 
__________________________________________________________________________________________________
inter_conv5 (Conv2D)  (None, 512, 16, 16) 4728320 concatenate_1[0][0]  
__________________________________________________________________________________________________
deconv4 (Conv2DTranspose) (None, 256, 32, 32) 4202752 concatenate_1[0][0]  
__________________________________________________________________________________________________
predict_flow5 (Conv2D)  (None, 2, 16, 16) 9218 inter_conv5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU) (None, 256, 32, 32) 0  deconv4[0][0]   
__________________________________________________________________________________________________
upsampled_flow5_to4 (Conv2DTran (None, 2, 32, 32) 66  predict_flow5[0][0]  
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 770, 32, 32) 0  leaky_re_lu_9[0][0]  
         leaky_re_lu_15[0][0]  
         upsampled_flow5_to4[0][0] 
__________________________________________________________________________________________________
inter_conv4 (Conv2D)  (None, 256, 32, 32) 1774336 concatenate_2[0][0]  
__________________________________________________________________________________________________
deconv3 (Conv2DTranspose) (None, 128, 64, 64) 1577088 concatenate_2[0][0]  
__________________________________________________________________________________________________
predict_flow4 (Conv2D)  (None, 2, 32, 32) 4610 inter_conv4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU) (None, 128, 64, 64) 0  deconv3[0][0]   
__________________________________________________________________________________________________
upsampled_flow4_to3 (Conv2DTran (None, 2, 64, 64) 66  predict_flow4[0][0]  
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 386, 64, 64) 0  leaky_re_lu_7[0][0]  
         leaky_re_lu_16[0][0]  
         upsampled_flow4_to3[0][0] 
__________________________________________________________________________________________________
inter_conv3 (Conv2D)  (None, 128, 64, 64) 444800 concatenate_3[0][0]  
__________________________________________________________________________________________________
deconv2 (Conv2DTranspose) (None, 64, 128, 128) 395328 concatenate_3[0][0]  
__________________________________________________________________________________________________
predict_flow3 (Conv2D)  (None, 2, 64, 64) 2306 inter_conv3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU) (None, 64, 128, 128) 0  deconv2[0][0]   
__________________________________________________________________________________________________
upsampled_flow3_to2 (Conv2DTran (None, 2, 128, 128) 66  predict_flow3[0][0]  
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 194, 128, 128) 0  leaky_re_lu_5[0][0]  
         leaky_re_lu_17[0][0]  
         upsampled_flow3_to2[0][0] 
__________________________________________________________________________________________________
inter_conv2 (Conv2D)  (None, 64, 128, 128) 111808 concatenate_4[0][0]  
__________________________________________________________________________________________________
predict_flow2 (Conv2D)  (None, 2, 128, 128) 1154 inter_conv2[0][0]  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 2, 512, 512) 0  predict_flow2[0][0] 

Now the flownet model built in PyTorch:

 (conv0): Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1): Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1_1): Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2_1): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3): Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3_1): Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4): Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6): Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6_1): Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv5): Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv4): Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv3): Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv2): Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (inter_conv5): Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv4): Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv3): Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv2): Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

Because the PyTorch model's named_modules() output is not in execution order (the dynamic-graph mechanism means the path taken is only known once data flows through), the ordering above is scrambled. The point I want to make, though, is that the Keras model I built really does follow the official open-source PyTorch model.

模型搭建完畢之后,就到了關(guān)鍵的步驟:給Keras模型賦值。

Assigning weights to the Keras model

Three points deserve attention in this step.

PyTorch is channels_first, while Keras defaults to channels_last; add these two lines at the top of your code:

from keras import backend as K

K.set_image_data_format('channels_first')
K.set_learning_phase(0)

As everyone knows, a convolution layer's weight is a 4-D tensor. So, is the layout of the convolution kernel weights the same in PyTorch and Keras? Naturally it is not, otherwise I wouldn't need to write this point. That means the PyTorch weights have to be reshaped.

Since the convolution weight layouts differ between the two frameworks, transposed convolutions naturally differ as well.

Let's first look at the convolution-layer layout in each framework.

Keras convolution-layer weight layout

We inspect the Keras convolution weight layout with the following code:

for l in model.layers:
    print(l.name)
    for i, w in enumerate(l.get_weights()):
        print('%d' % i, w.shape)

The first convolution layer prints as follows; the shape after 0 is the kernel weight, and the one after 1 is the bias:

conv0
0 (3, 3, 6, 64)
1 (64,)

So the Keras convolution weight layout is [height, width, input_channels, out_channels].

PyTorch convolution-layer weight layout

net = FlowNet2SD()
for n, m in net.named_parameters():
    print(n)
    print(m.data.size())

conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])

The code above prints the shape of every layer's parameters; again, locate the first convolution layer and check its shape.

Comparing the two, we find that PyTorch's convolution weight shape has the form [out_channels, input_channels, height, width].

那么我們在取出Pytorch權(quán)重之后,需要用np.transpose改變一下權(quán)重的排序,才能送到Keras模型對應(yīng)的層上。

Keras transposed-convolution weight layout

deconv4
0 (4, 4, 256, 1026)
1 (256,)

The code is the same as above; find the corresponding transposed convolution and take a look.

We can see that in Keras, the transposed-convolution layout is [height, width, out_channels, input_channels].

PyTorch transposed-convolution weight layout

deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])

Again, the code is the same as above; find the corresponding transposed convolution and take a look.

We can see that in PyTorch, the transposed-convolution layout is [input_channels, out_channels, height, width].

Summary

對于卷積層來說,Pytorch的權(quán)重需要使用

np.transpose(weight.data.numpy(), [2, 3, 1, 0])

before they can be assigned to the corresponding layer of the Keras model.

對于轉(zhuǎn)置卷積來說,通過對比其實也是一樣的。不信你去試試嘛。O(∩_∩)O哈哈~

對于偏置項,兩種模塊都是一維的向量,不需要處理。

In some cases the kernel may also need to be flipped along its first two axes, but this is rarely necessary:

weights[::-1,::-1,:,:]
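For illustration, the slicing above reverses the first two axes of a Keras-format kernel [height, width, in, out], i.e. it flips each filter spatially (dummy 2x2 kernel):

```python
import numpy as np

kernel = np.arange(4).reshape(2, 2, 1, 1)  # Keras layout [h, w, in, out]
flipped = kernel[::-1, ::-1, :, :]         # reverse rows and columns

print(flipped[:, :, 0, 0])  # [[3 2]
                            #  [1 0]]
```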

Assignment

結(jié)束了預(yù)處理之后,我們就進入第二步,開始賦值了。

First, the preprocessing code:

for k, v in weights_from_torch.items():
    if 'bias' not in k:
        weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)

I only excerpted part of the assignment code for reference:

k_model = k_model()
for layer in k_model.layers:
    current_layer_name = layer.name
    if current_layer_name == 'conv0':
        weights = [weights_from_torch['conv0.0.weight'], weights_from_torch['conv0.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1':
        weights = [weights_from_torch['conv1.0.weight'], weights_from_torch['conv1.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1_1':
        weights = [weights_from_torch['conv1_1.0.weight'], weights_from_torch['conv1_1.0.bias']]
        layer.set_weights(weights)

First define the Keras model and use layers to obtain an iterator over all of its layers.

Traverse the iterator, assigning each layer its corresponding values.

Assignment uses set_weights, whose argument must be a list in the same form as the return value of get_weights, i.e. [conv_weights, bias_weights].
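Since the PyTorch parameter names follow the pattern '<layer>.0.weight', the long if/elif chain can also be written generically. A minimal sketch with a stand-in state dict (plain NumPy here; in the real workflow the arrays come from the already-transposed torch state dict, and the returned list goes straight into layer.set_weights):

```python
import numpy as np

# Stand-in for the preprocessed (already transposed) PyTorch state dict.
weights_from_torch = {
    'conv0.0.weight': np.zeros((3, 3, 6, 64)),  # Keras layout [h, w, in, out]
    'conv0.0.bias': np.zeros((64,)),
}

def keras_weights_for(state, layer_name):
    """Build the [kernel, bias] list that Keras set_weights() expects."""
    return [state[layer_name + '.0.weight'], state[layer_name + '.0.bias']]

w = keras_weights_for(weights_from_torch, 'conv0')
print(w[0].shape, w[1].shape)  # (3, 3, 6, 64) (64,)
```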

最后祝愿大家能實現(xiàn)自己模型的遷移。工程開源在了個人Github,有詳細的使用介紹,并且包含使用數(shù)據(jù),大家可以直接運行。

That concludes this walkthrough of an effective method for converting PyTorch to Keras, explained with FlowNet as the example. I hope it serves as a useful reference.
