How to use models pretrained on ImageNet in Keras
Without further ado, let's go straight to the code:
import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet

# Load the VGG16 model
vgg_model = vgg16.VGG16(weights='imagenet')
# Load the InceptionV3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
# Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
# Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')
In the code above, we first import the module corresponding to each model, then load each model and initialize its parameters with the weights pretrained on ImageNet.
If you do not want to initialize a model with ImageNet-pretrained weights, replace 'imagenet' in each call with None (the Python value, not the string 'None'); the model is then randomly initialized.
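Before feeding an image to one of these pretrained models, it must be preprocessed the way the model expects. For VGG16, Keras's `vgg16.preprocess_input` (in its default "caffe" mode) converts RGB to BGR and subtracts the ImageNet per-channel means. Below is a minimal NumPy sketch of that step; the helper name `vgg_style_preprocess` is my own, and the mean values are the ones commonly cited for ImageNet:

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by classic VGG preprocessing
IMAGENET_BGR_MEANS = np.array([103.939, 116.779, 123.68])

def vgg_style_preprocess(rgb_batch):
    """Illustrative re-implementation of VGG-style preprocessing:
    flip RGB -> BGR, then subtract the ImageNet channel means."""
    bgr = rgb_batch[..., ::-1].astype(np.float64)  # reverse the channel axis
    return bgr - IMAGENET_BGR_MEANS

# A dummy 'image' batch: one 2x2 RGB image with every pixel value = 128
batch = np.full((1, 2, 2, 3), 128.0)
out = vgg_style_preprocess(batch)
print(out[0, 0, 0])  # per-pixel channel values after mean subtraction
```

In practice you would simply call `vgg16.preprocess_input` on your batch before `vgg_model.predict`; the sketch only shows what that call does internally.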
Bonus: classifying MNIST with high accuracy using an AlexNet-style model in Keras
Outline
This post has two distinguishing features: first, it reads the MNIST data directly from local files (assuming they have already been downloaded or copied from elsewhere); second, it uses an AlexNet-style network built on Keras (most tutorials online use raw TensorFlow) to classify MNIST, reaching a fairly high accuracy (about 98%).
Reading and parsing the local data
Many tutorials start by downloading MNIST with a single call. Although the dataset is only around 10 MB, the download is often very slow and frequently fails partway through, so getting the data this way can be a struggle:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
Instead, you can obtain the files separately (four .gz archives in total) and use a different API to parse them:
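The four archives use the simple IDX binary layout (a magic number, the dimension sizes, then raw bytes). If you would rather not depend on the TensorFlow helper at all, a minimal parser is easy to write. This is a sketch assuming the standard IDX format; since a real train-images-idx3-ubyte.gz must be downloaded separately, a tiny synthetic file stands in for it here:

```python
import gzip
import struct
import numpy as np

def read_idx_images(path):
    """Parse an MNIST images file (idx3-ubyte) from a .gz archive."""
    with gzip.open(path, 'rb') as f:
        # Header: magic number, image count, row count, column count (big-endian)
        magic, n, rows, cols = struct.unpack('>IIII', f.read(16))
        assert magic == 2051, 'not an idx3-ubyte images file'
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(n, rows, cols)

# Build a tiny synthetic idx3 file (2 images of 4x4 pixels) to demonstrate
payload = struct.pack('>IIII', 2051, 2, 4, 4) + bytes(range(32))
with gzip.open('/tmp/fake-images-idx3-ubyte.gz', 'wb') as f:
    f.write(payload)

images = read_idx_images('/tmp/fake-images-idx3-ubyte.gz')
print(images.shape)  # (2, 4, 4)
```

The labels files (idx1-ubyte, magic number 2049) can be parsed the same way with an 8-byte header instead of 16.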
# Load the already-downloaded dataset; "./MNIST_data" is the directory holding the MNIST files
mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
Note that the two APIs return the data in different forms. With the direct download, the image pixel values range from 0 to 255 and the labels are the integers 0, 1, 2, ..., 9. With the second API, the image values have already been divided by 255 (normalized into the 0-1 range), and when one_hot=True the labels are already one-hot encoded; for example, the label 2 becomes (0 0 1 0 0 0 0 0 0 0).
Therefore, data obtained the first way needs some preprocessing (normalization and one-hot encoding) before it can be fed into the network for training, while data from the second API can be used for training directly.
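For data obtained the first way, that preprocessing amounts to a couple of lines of NumPy; a minimal sketch:

```python
import numpy as np

num_classes = 10

# Raw-style data: pixel values 0..255, integer labels 0..9
x_raw = np.array([[0, 128, 255]], dtype=np.uint8)
y_raw = np.array([2, 7])

# Normalize the pixels into the 0..1 range
x_norm = x_raw.astype('float32') / 255.0

# One-hot encode the labels: e.g. label 2 -> (0 0 1 0 0 0 0 0 0 0)
y_onehot = np.eye(num_classes, dtype='float32')[y_raw]

print(x_norm)       # values now lie in [0, 1]
print(y_onehot[0])  # one-hot vector for label 2
```

Keras also provides `keras.utils.to_categorical` for the one-hot step; the `np.eye` indexing above is just an explicit, dependency-free equivalent.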
Tweaking the AlexNet model
In the published architecture, AlexNet applies normalization only after the first and second convolutional layers; the remaining three convolutional layers have none (corrections welcome if I am mistaken). Building the network exactly that way easily leads to vanishing gradients, which shows up as the accuracy staying stuck at a very low value.
Only after adding BatchNormalization after every convolutional layer does the accuracy improve steadily over the training iterations.
Complete code
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data  # TensorFlow ships the MNIST example data

batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28, 28, 1)  # input dimensions
img_rows, img_cols = 28, 28

# dataset input
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# Load the already-downloaded dataset; "./MNIST_data" is the directory holding the MNIST files
mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

# Define the input layer
inputs = keras.Input(shape=[img_rows, img_cols, 1])
# Define convolutional layer 1
conv1 = keras.layers.Conv2D(filters=64, kernel_size=[11, 11], strides=[1, 1],
                            activation=keras.activations.relu, use_bias=True, padding='same')(inputs)
# Define pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size=[2, 2], strides=[2, 2], padding='valid')(conv1)
# Define normalization layer 1
# (note: with channels-last input, axis=-1 -- the channel axis -- is the usual
# choice for BatchNormalization; axis=1 is kept here as in the original code)
stand1 = keras.layers.BatchNormalization(axis=1)(pooling1)
# Define convolutional layer 2
conv2 = keras.layers.Conv2D(filters=192, kernel_size=[5, 5], strides=[1, 1],
                            activation=keras.activations.relu, use_bias=True, padding='same')(stand1)
# Define pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size=[2, 2], strides=[2, 2], padding='valid')(conv2)
# Define normalization layer 2
stand2 = keras.layers.BatchNormalization(axis=1)(pooling2)
# Define convolutional layer 3
conv3 = keras.layers.Conv2D(filters=384, kernel_size=[3, 3], strides=[1, 1],
                            activation=keras.activations.relu, use_bias=True, padding='same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
# Define convolutional layer 4
conv4 = keras.layers.Conv2D(filters=384, kernel_size=[3, 3], strides=[1, 1],
                            activation=keras.activations.relu, use_bias=True, padding='same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
# Define convolutional layer 5
conv5 = keras.layers.Conv2D(filters=256, kernel_size=[3, 3], strides=[1, 1],
                            activation=keras.activations.relu, use_bias=True, padding='same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size=[2, 2], strides=[2, 2], padding='valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)

# Define the fully connected layers
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation=keras.activations.relu, use_bias=True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
fc2 = keras.layers.Dense(4096, activation=keras.activations.relu, use_bias=True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
fc3 = keras.layers.Dense(10, activation=keras.activations.softmax, use_bias=True)(drop2)

# Build the model with the functional Model API
model = keras.Model(inputs=inputs, outputs=fc3)
# Compile the model
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss=keras.losses.categorical_crossentropy,
              metrics=['accuracy'])
# Training configuration, for reference only
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))
That wraps up this post on using ImageNet-pretrained models in Keras. I hope it serves as a useful reference, and thanks for supporting 腳本之家.