Reproducing the Xception Model in Python: A Detailed Walkthrough
Xception is Google's follow-up to Inception, another improvement on Inception v3, and it is well worth studying.
What is the Xception model
Xception is an improved version of Inception V3 that Google proposed after Inception. Its main change is to replace the multi-scale convolution-kernel feature extraction of Inception v3 with depthwise separable convolutions.
Before discussing the Xception model itself, we first need to explain what a depthwise separable convolution block is.
A depthwise separable convolution block consists of two parts: a depthwise convolution and an ordinary 1x1 (pointwise) convolution. The depthwise convolution kernel is typically 3x3 and can be thought of as per-channel feature extraction, while the 1x1 convolution adjusts the number of channels.
The figure below shows the structure of a depthwise separable convolution block:
The purpose of the depthwise separable convolution block is to replace an ordinary 3x3 convolution using fewer parameters.
Let's compare an ordinary convolution with a depthwise separable convolution block:
Suppose we have a 3×3 convolutional layer with 16 input channels and 32 output channels. Concretely, each of the 32 3×3 kernels traverses the data in all 16 input channels to produce the 32 output channels, which requires 16×32×3×3 = 4608 parameters.
With a depthwise separable convolution, we instead use 16 3×3 kernels, each traversing exactly one of the 16 input channels, yielding 16 feature maps. Then, to fuse the channels, 32 1×1 kernels traverse these 16 feature maps, which requires 16×3×3 + 16×32×1×1 = 656 parameters.
As you can see, depthwise separable convolution substantially reduces the number of model parameters.
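The comparison above is easy to verify with a few lines of arithmetic; this snippet simply restates the 16-in/32-out example from the text:

```python
# Parameter counts for a 3x3 convolution with 16 input and 32 output channels.
in_ch, out_ch, k = 16, 32, 3

# Ordinary convolution: every output channel owns a k x k kernel per input channel.
standard_params = in_ch * out_ch * k * k            # 16*32*3*3 = 4608

# Depthwise separable: one k x k kernel per input channel,
# followed by a 1x1 pointwise convolution to mix channels.
depthwise_params = in_ch * k * k                    # 16*3*3 = 144
pointwise_params = in_ch * out_ch * 1 * 1           # 16*32  = 512
separable_params = depthwise_params + pointwise_params

print(standard_params, separable_params)            # 4608 656
print(round(separable_params / standard_params, 3)) # 0.142, roughly 1/7 the parameters
```

The ratio shrinks further as the number of output channels grows, which is why the savings matter for a network as wide as Xception.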
Intuitively, in a depthwise separable convolution block the 3x3 kernels are only one channel thick: they slide over the input tensor one channel at a time, each convolution producing one output channel. Once every channel has been convolved, 1x1 convolutions adjust the channel depth.
(The accompanying video contains a few errors; thanks to zl960929 for pointing them out. The depthwise separable convolution block SeparableConv2D used by Xception performs the depthwise convolution first and then the 1x1 convolution.)
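To make the mechanics concrete, here is a minimal NumPy sketch of that two-step operation (stride 1, no padding, biases omitted). All names here are illustrative; this is not how Keras implements SeparableConv2D internally, just the same arithmetic spelled out:

```python
import numpy as np

def depthwise_separable_conv(x, depthwise_k, pointwise_k):
    """x: (H, W, C_in); depthwise_k: (3, 3, C_in); pointwise_k: (C_in, C_out)."""
    H, W, C_in = x.shape
    kh, kw, _ = depthwise_k.shape
    out_h, out_w = H - kh + 1, W - kw + 1

    # Depthwise step: each 3x3 kernel slides over exactly one input channel,
    # so the kernel is "only one channel thick".
    depthwise_out = np.zeros((out_h, out_w, C_in))
    for c in range(C_in):
        for i in range(out_h):
            for j in range(out_w):
                patch = x[i:i + kh, j:j + kw, c]
                depthwise_out[i, j, c] = np.sum(patch * depthwise_k[:, :, c])

    # Pointwise step: a 1x1 convolution is a per-pixel matrix multiply
    # across channels, which adjusts the channel depth.
    return depthwise_out @ pointwise_k  # (out_h, out_w, C_out)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 16))
dk = rng.normal(size=(3, 3, 16))
pk = rng.normal(size=(16, 32))
y = depthwise_separable_conv(x, dk, pk)
print(y.shape)  # (6, 6, 32)
```

The depthwise step never mixes channels; all cross-channel interaction happens in the pointwise step.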
The Xception model is divided into 3 flows: the Entry flow, the Middle flow, and the Exit flow.
It comprises 14 blocks in total: 4 in the Entry flow, 8 in the Middle flow, and 2 in the Exit flow.
The overall structure is as follows:
Internally, the model stacks residual connections with SeparableConv2D layers to build one block after another. Two block structures are common in Xception. The first is used mainly in the Entry flow and Exit flow:
The second is used mainly in the Middle flow:
Xception network implementation code
```python
#-------------------------------------------------------------#
#   Xception network
#-------------------------------------------------------------#
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.layers import Dense, Input, BatchNormalization, Activation, Conv2D, SeparableConv2D, MaxPooling2D
from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions

def Xception(input_shape=[299, 299, 3], classes=1000):
    img_input = Input(shape=input_shape)

    #--------------------------#
    #   Entry flow
    #--------------------------#
    #--------------------#
    #   block1
    #--------------------#
    # 299,299,3 -> 149,149,64
    x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
    x = BatchNormalization(name='block1_conv1_bn')(x)
    x = Activation('relu', name='block1_conv1_act')(x)
    x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = BatchNormalization(name='block1_conv2_bn')(x)
    x = Activation('relu', name='block1_conv2_act')(x)

    #--------------------#
    #   block2
    #--------------------#
    # 149,149,64 -> 75,75,128
    residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
    x = BatchNormalization(name='block2_sepconv1_bn')(x)
    x = Activation('relu', name='block2_sepconv2_act')(x)
    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
    x = BatchNormalization(name='block2_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    #   block3
    #--------------------#
    # 75,75,128 -> 38,38,256
    residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block3_sepconv1_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
    x = BatchNormalization(name='block3_sepconv1_bn')(x)
    x = Activation('relu', name='block3_sepconv2_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
    x = BatchNormalization(name='block3_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    #   block4
    #--------------------#
    # 38,38,256 -> 19,19,728
    residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block4_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
    x = BatchNormalization(name='block4_sepconv1_bn')(x)
    x = Activation('relu', name='block4_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
    x = BatchNormalization(name='block4_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
    x = layers.add([x, residual])

    #--------------------------#
    #   Middle flow
    #--------------------------#
    #--------------------#
    #   block5--block12
    #--------------------#
    # 19,19,728 -> 19,19,728
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
        x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
        x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
        x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)

        x = layers.add([x, residual])

    #--------------------------#
    #   Exit flow
    #--------------------------#
    #--------------------#
    #   block13
    #--------------------#
    # 19,19,728 -> 10,10,1024
    residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block13_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
    x = BatchNormalization(name='block13_sepconv1_bn')(x)
    x = Activation('relu', name='block13_sepconv2_act')(x)
    x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
    x = BatchNormalization(name='block13_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    #   block14
    #--------------------#
    # 10,10,1024 -> 10,10,2048
    x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
    x = BatchNormalization(name='block14_sepconv1_bn')(x)
    x = Activation('relu', name='block14_sepconv1_act')(x)
    x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
    x = BatchNormalization(name='block14_sepconv2_bn')(x)
    x = Activation('relu', name='block14_sepconv2_act')(x)

    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    inputs = img_input
    model = Model(inputs, x, name='xception')
    model.load_weights("xception_weights_tf_dim_ordering_tf_kernels.h5")
    return model
```
Image prediction
Once the network is built, the following code can be used to make predictions.
```python
import numpy as np

def preprocess_input(x):
    x /= 255.
    x -= 0.5
    x *= 2.
    return x

if __name__ == '__main__':
    model = Xception()
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    print('Input image shape:', x.shape)
    preds = model.predict(x)
    print(np.argmax(preds))
    print('Predicted:', decode_predictions(preds))
```
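Note that preprocess_input maps pixel values from [0, 255] to [-1, 1], the input range the pretrained Xception weights expect. A quick standalone check (repeating the function here so the snippet runs on its own):

```python
import numpy as np

def preprocess_input(x):
    x /= 255.   # scale [0, 255] down to [0, 1]
    x -= 0.5    # shift to [-0.5, 0.5]
    x *= 2.     # stretch to [-1, 1]
    return x

pixels = np.array([0., 127.5, 255.])
print(preprocess_input(pixels))  # [-1.  0.  1.]
```

Feeding raw [0, 255] pixels into the pretrained model without this step would noticeably degrade the predictions.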
The pretrained Xception weights needed for prediction can be downloaded from https://github.com/fchollet/deep-learning-models/releases, which is very convenient.
The prediction result is:
Predicted: [[('n02504458', 'African_elephant', 0.47570863), ('n01871265', 'tusker', 0.3173351), ('n02504013', 'Indian_elephant', 0.030323735), ('n02963159', 'cardigan', 0.0007877756), ('n02410509', 'bison', 0.00075616257)]]
This concludes the detailed walkthrough of reproducing the Xception model in Python. For more material on reproducing the Xception model, please follow the other related articles on 腳本之家!