欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型實(shí)例

 更新時(shí)間:2020年07月02日 11:21:55   作者:OliverkingLi  
這篇文章主要介紹了keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

這里我們使用keras定義簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)全連接層訓(xùn)練MNIST數(shù)據(jù)集和cifar10數(shù)據(jù)集:

keras_mnist.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# 命令行參數(shù)運(yùn)行
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args =vars(ap.parse_args())
# 加載數(shù)據(jù)MNIST,然后歸一化到【0,1】,同時(shí)使用75%做訓(xùn)練,25%做測(cè)試
print("[INFO] loading MNIST (full) dataset")
dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/")
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# 將label進(jìn)行one-hot編碼
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# keras定義網(wǎng)絡(luò)結(jié)構(gòu)784--256--128--10
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="relu"))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))
# 開始訓(xùn)練
print("[INFO] training network...")
# 0.01的學(xué)習(xí)率
sgd = SGD(0.01)
# 交叉驗(yàn)證
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy'])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# 測(cè)試模型和評(píng)估
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=[str(x) for x in lb.classes_]))
# 保存可視化訓(xùn)練結(jié)果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

使用relu做激活函數(shù):

使用sigmoid做激活函數(shù):

接著我們自己定義一些modules去實(shí)現(xiàn)一個(gè)簡(jiǎn)單的卷基層去訓(xùn)練cifar10數(shù)據(jù)集:

imagetoarraypreprocessor.py

'''
該函數(shù)主要是實(shí)現(xiàn)keras的一個(gè)細(xì)節(jié)轉(zhuǎn)換,因?yàn)橛?xùn)練的圖像時(shí)RGB三顏色通道,讀取進(jìn)來(lái)的數(shù)據(jù)是有depth的,keras為了兼容一些后臺(tái),默認(rèn)是按照(height, width, depth)讀取,但有時(shí)候就要改變成(depth, height, width)
'''
from keras.preprocessing.image import img_to_array
class ImageToArrayPreprocessor:
	def __init__(self, dataFormat=None):
		self.dataFormat = dataFormat
 
	def preprocess(self, image):
		return img_to_array(image, data_format=self.dataFormat)
 

shallownet.py

'''
定義一個(gè)簡(jiǎn)單的卷基層:
input->conv->Relu->FC
'''
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Flatten, Dense
from keras import backend as K
 
class ShallowNet:
	@staticmethod
	def build(width, height, depth, classes):
		model = Sequential()
		inputShape = (height, width, depth)
 
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
 
		model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
		model.add(Activation("relu"))
 
		model.add(Flatten())
		model.add(Dense(classes))
		model.add(Activation("softmax"))
 
		return model

然后就是訓(xùn)練代碼:

keras_cifar10.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 標(biāo)簽0-9代表的類別string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.0001)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=1000, verbose=1)
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可視化訓(xùn)練結(jié)果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 1000), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 1000), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 1000), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 1000), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])
 

代碼中可以對(duì)訓(xùn)練的learning rate進(jìn)行微調(diào),大概可以接近60%的準(zhǔn)確率。

然后修改下代碼可以保存訓(xùn)練模型:

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 標(biāo)簽0-9代表的類別string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.005)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=50, verbose=1)
 
model.save(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可視化訓(xùn)練結(jié)果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 5), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 5), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 5), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 5), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])
 

命令行運(yùn)行:

我們使用另一個(gè)程序來(lái)加載上一次訓(xùn)練保存的模型,然后進(jìn)行測(cè)試:

test.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
# 標(biāo)簽0-9代表的類別string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
 
idxs = np.random.randint(0, len(testX), size=(10,))
testX = testX[idxs]
testY = testY[idxs]
 
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
 
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32).argmax(axis=1)
print("predictions\n", predictions)
for i in range(len(testY)):
	print("label:{}".format(labelNames[predictions[i]]))
 
trueLabel = []
for i in range(len(testY)):
	for j in range(len(testY[i])):
		if testY[i][j] != 0:
			trueLabel.append(j)
print(trueLabel)
 
print("ground truth testY:")
for i in range(len(trueLabel)):
	print("label:{}".format(labelNames[trueLabel[i]]))
 
print("TestY\n", testY)

以上這篇keras訓(xùn)練淺層卷積網(wǎng)絡(luò)并保存和加載模型實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 解決已經(jīng)安裝requests,卻依然提示No module named requests問(wèn)題

    解決已經(jīng)安裝requests,卻依然提示No module named requests問(wèn)題

    今天小編就為大家分享一篇解決已經(jīng)安裝requests,卻依然提示No module named 'requests'問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-05-05
  • Python 點(diǎn)擊指定位置驗(yàn)證碼破解的實(shí)現(xiàn)代碼

    Python 點(diǎn)擊指定位置驗(yàn)證碼破解的實(shí)現(xiàn)代碼

    這篇文章主要介紹了Python 點(diǎn)擊指定位置驗(yàn)證碼破解的實(shí)現(xiàn)代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-09-09
  • Django上傳xlsx文件直接轉(zhuǎn)化為DataFrame或直接保存的方法

    Django上傳xlsx文件直接轉(zhuǎn)化為DataFrame或直接保存的方法

    這篇文章主要介紹了Django上傳xlsx文件直接轉(zhuǎn)化為DataFrame或直接保存的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2021-05-05
  • python中利用h5py模塊讀取h5文件中的主鍵方法

    python中利用h5py模塊讀取h5文件中的主鍵方法

    今天小編就為大家分享一篇python中利用h5py模塊讀取h5文件中的主鍵方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-06-06
  • Python MySQL數(shù)據(jù)庫(kù)連接池組件pymysqlpool詳解

    Python MySQL數(shù)據(jù)庫(kù)連接池組件pymysqlpool詳解

    這篇文章主要跟大家介紹了關(guān)于Python MySQL數(shù)據(jù)庫(kù)連接池組件pymysqlpool的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起看看吧。
    2017-07-07
  • 在python下讀取并展示raw格式的圖片實(shí)例

    在python下讀取并展示raw格式的圖片實(shí)例

    今天小編就為大家分享一篇在python下讀取并展示raw格式的圖片實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-01-01
  • python中使用numpy包的向量矩陣相乘np.dot和np.matmul實(shí)現(xiàn)

    python中使用numpy包的向量矩陣相乘np.dot和np.matmul實(shí)現(xiàn)

    本文主要介紹了python中使用numpy包的向量矩陣相乘np.dot和np.matmul實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • Python 依賴庫(kù)太多了該如何管理

    Python 依賴庫(kù)太多了該如何管理

    在 Python 的項(xiàng)目中,如何管理所用的全部依賴庫(kù)呢?最主流的做法是維護(hù)一份“requirements.txt”,記錄下依賴庫(kù)的名字及其版本號(hào),需要的朋友可以參考下
    2019-11-11
  • Python中常見(jiàn)的數(shù)據(jù)類型小結(jié)

    Python中常見(jiàn)的數(shù)據(jù)類型小結(jié)

    這篇文章主要對(duì)Python中常見(jiàn)的數(shù)據(jù)類型進(jìn)行了總結(jié)歸納,很有參考借鑒價(jià)值,需要的朋友可以參考下
    2015-08-08
  • 運(yùn)行Python編寫的程序方法實(shí)例

    運(yùn)行Python編寫的程序方法實(shí)例

    在本篇文章里小編給大家整理了關(guān)于運(yùn)行Python編寫的程序方法實(shí)例內(nèi)容,有興趣的朋友們可以學(xué)習(xí)下。
    2020-10-10

最新評(píng)論