欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Keras使用ImageNet上預(yù)訓(xùn)練的模型方式

 更新時(shí)間:2020年05月23日 15:09:06   作者:breeze5428  
這篇文章主要介紹了Keras使用ImageNet上預(yù)訓(xùn)練的模型方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧

我就廢話不多說了,大家還是直接看代碼吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代碼中,我們首先import各種模型對(duì)應(yīng)的module,然后load模型,并用ImageNet的參數(shù)初始化模型的參數(shù)。

如果不想使用ImageNet上預(yù)訓(xùn)練到的權(quán)重初始話模型,可以將各語句的中'imagenet'替換為'None'。

補(bǔ)充知識(shí):keras上使用alexnet模型來高準(zhǔn)確度對(duì)mnist數(shù)據(jù)進(jìn)行分類

綱要

本文有兩個(gè)特點(diǎn):一是直接對(duì)本地mnist數(shù)據(jù)進(jìn)行讀?。僭O(shè)事先已經(jīng)下載或從別處拷來)二是基于keras框架(網(wǎng)上多是基于tf)使用alexnet對(duì)mnist數(shù)據(jù)進(jìn)行分類,并獲得較高準(zhǔn)確度(約為98%)

本地?cái)?shù)據(jù)讀取和分析

很多代碼都是一開始簡單調(diào)用一行代碼來從網(wǎng)站上下載mnist數(shù)據(jù),雖然只有10來MB,但是現(xiàn)在下載速度非常慢,而且經(jīng)常中途出錯(cuò),要費(fèi)很大的勁才能拿到數(shù)據(jù)。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其實(shí)可以單獨(dú)來獲得這些數(shù)據(jù)(一共4個(gè)gz包,如下所示),然后調(diào)用別的接口來分析它們。

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #導(dǎo)入已經(jīng)下載好的數(shù)據(jù)集,"./MNIST_data"為存放mnist數(shù)據(jù)的目錄

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

這里面要注意的是,兩種接口拿到的數(shù)據(jù)形式是不一樣的。 從網(wǎng)上直接下載下來的數(shù)據(jù) 其image data值的范圍是0~255,且label值為0,1,2,3...9。 而第二種接口獲取的數(shù)據(jù) image值已經(jīng)除以255(歸一化)變成0~1范圍,且label值已經(jīng)是one-hot形式(one_hot=True時(shí)),比如label值2的one-hot code為(0 0 1 0 0 0 0 0 0 0)

所以,以第一種方式獲取的數(shù)據(jù)需要做一些預(yù)處理(歸一和one-hot)才能輸入網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練 而第二種接口拿到的數(shù)據(jù)則可以直接進(jìn)行訓(xùn)練。

Alexnet模型的微調(diào)

按照公開的模型框架,Alexnet只有第1、2個(gè)卷積層才跟著BatchNormalization,后面三個(gè)CNN都沒有(如有說錯(cuò),請(qǐng)指正)。如果按照這個(gè)來搭建網(wǎng)絡(luò)模型,很容易導(dǎo)致梯度消失,現(xiàn)象就是 accuracy值一直處在很低的值。 如下所示。

在每個(gè)卷積層后面都加上BN后,準(zhǔn)確度才迭代提高。如下所示

完整代碼

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已經(jīng)包含了mnist案例的數(shù)據(jù)
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #導(dǎo)入已經(jīng)下載好的數(shù)據(jù)集,"./MNIST_data"為存放mnist數(shù)據(jù)的目錄
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法構(gòu)建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 編譯模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 訓(xùn)練配置,僅供參考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上這篇Keras使用ImageNet上預(yù)訓(xùn)練的模型方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • pytorch自定義二值化網(wǎng)絡(luò)層方式

    pytorch自定義二值化網(wǎng)絡(luò)層方式

    今天小編就為大家分享一篇pytorch自定義二值化網(wǎng)絡(luò)層方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-01-01
  • Python中eval帶來的潛在風(fēng)險(xiǎn)代碼分析

    Python中eval帶來的潛在風(fēng)險(xiǎn)代碼分析

    這篇文章主要介紹了Python中eval帶來的潛在風(fēng)險(xiǎn)代碼分析,具有一定借鑒價(jià)值,需要的朋友可以參考下。
    2017-12-12
  • Flask?數(shù)據(jù)庫遷移詳情

    Flask?數(shù)據(jù)庫遷移詳情

    本文給大家分享的是?Flask?數(shù)據(jù)庫遷移詳情,db.create_all()不會(huì)重新創(chuàng)建表或是更新表,需要先使用db.drop_all()刪除數(shù)據(jù)庫中所有的表之后再調(diào)用db.create_all()才能重新創(chuàng)建表,但是這樣的話,原來表中的數(shù)據(jù)就都被刪除了,這肯定是不行的,這時(shí)就出現(xiàn)了數(shù)據(jù)庫遷移的概念
    2021-11-11
  • Python中的OpenCV圖像腐蝕處理和膨脹處理

    Python中的OpenCV圖像腐蝕處理和膨脹處理

    這篇文章主要介紹了Python中的OpenCV圖像腐蝕處理和膨脹處理,OpenCV是一個(gè)跨平臺(tái)的計(jì)算機(jī)視覺庫,可用于開發(fā)實(shí)時(shí)的圖像處理、計(jì)算機(jī)視覺以及模式識(shí)別程序,需要的朋友可以參考下
    2023-08-08
  • python matplotlib imshow熱圖坐標(biāo)替換/映射實(shí)例

    python matplotlib imshow熱圖坐標(biāo)替換/映射實(shí)例

    這篇文章主要介紹了python matplotlib imshow熱圖坐標(biāo)替換/映射實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-03-03
  • selenium3+python3環(huán)境搭建教程圖解

    selenium3+python3環(huán)境搭建教程圖解

    這篇文章主要介紹了selenium3+python3環(huán)境搭建教程圖解,需要的朋友可以參考下
    2018-12-12
  • 安裝完P(guān)ython包然后找不到模塊的解決步驟

    安裝完P(guān)ython包然后找不到模塊的解決步驟

    今天小編就為大家分享一篇安裝完P(guān)ython包然后找不到模塊的解決步驟,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • Python PyQt5實(shí)現(xiàn)拖放效果的原理詳解

    Python PyQt5實(shí)現(xiàn)拖放效果的原理詳解

    這篇文章主要為大家詳細(xì)介紹了Python PyQt5中拖放效果的實(shí)現(xiàn)原理與實(shí)現(xiàn)代碼,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下
    2022-11-11
  • 構(gòu)建高效的python requests長連接池詳解

    構(gòu)建高效的python requests長連接池詳解

    這篇文章主要介紹了構(gòu)建高效的python requests長連接池詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • python如何讀寫csv數(shù)據(jù)

    python如何讀寫csv數(shù)據(jù)

    這篇文章主要為大家詳細(xì)介紹了python如何讀寫csv數(shù)據(jù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-03-03

最新評(píng)論