keras建模的3種方式詳解
keras建模的3種方式
keras是google公司2016年發(fā)布的tensorflow為后端的深度學(xué)習(xí)網(wǎng)絡(luò)的高級接口。
三種建模方式:
- 序列模型
- 函數(shù)模型
- 子類模型
第一種序列模型:
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Sequential from keras.models import load_model from keras.layers import Dense #加載數(shù)據(jù) def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #序列模型 def DNN(train_x,train_y,valid_x,valid_y): #創(chuàng)建模型 model=Sequential() model.add(Dense(64,input_dim=784,activation='relu')) model.add(Dense(128,activation='relu')) model.add(Dense(10,activation='softmax')) #查看網(wǎng)絡(luò)模型 model.summary() #編譯模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #訓(xùn)練模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('sequential.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN(train_x,train_y,valid_x,valid_y) model=load_model('sequential.h5') #下載模型 pre=model.predict(test_x) #測試驗證 #計算驗證集精度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第二種函數(shù)模型
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.models import load_model from keras.layers import Input,Dense #加載數(shù)據(jù) def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #函數(shù)模型 def DNN(train_x,train_y,valid_x,valid_y): #創(chuàng)建模型 inputs=Input(shape=(784,)) x=Dense(64,activation='relu')(inputs) x=Dense(128,activation='relu')(x) output=Dense(10,activation='softmax')(x) model=Model(input=inputs,output=output) #查看網(wǎng)絡(luò)結(jié)構(gòu) model.summary() #編譯模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #訓(xùn)練模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('fun_model.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN(train_x,train_y,valid_x,valid_y) model=load_model('fun_model.h5') #下載模型 pre=model.predict(test_x) #驗證數(shù)據(jù)集 #驗證數(shù)據(jù)集準確度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第三種子類模型
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.layers import Dense #加載數(shù)據(jù) def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #子類模型 class DNN(Model): def __init__(self,train_x,train_y,valid_x,valid_y): super(DNN,self).__init__() #初始化網(wǎng)絡(luò)模型 self.dense1=Dense(64,input_dim=784,activation='relu') self.dense2=Dense(128,activation='relu') self.dense3=Dense(10,activation='softmax') def call(self,inputs): #回調(diào)順序 x=self.dense1(inputs) x=self.dense2(x) x=self.dense3(x) return x train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') model=DNN(train_x,train_y,valid_x,valid_y) #編譯模型(學(xué)習(xí)率、損失函數(shù)、模型評估) model.compile(optimizer='adam(lr=0.001)',loss='categorical_crossentropy',metrics=['accuracy']) #訓(xùn)練模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #查看網(wǎng)絡(luò)結(jié)構(gòu) model.summary() pre=model.predict(test_x) #驗證數(shù)據(jù)集 #計算驗證數(shù)據(jù)集的準確度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
常用的損失函數(shù):
mse #均方差(回歸)
mae #絕對誤差(回歸)
binary_crossentropy #二值交叉熵(二分類,邏輯回歸)
categorical_crossentropy #交叉熵(多分類)
到此這篇關(guān)于keras建模的3種方式詳解的文章就介紹到這了,更多相關(guān)keras建模方式內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
解決python多線程報錯:AttributeError: Can''t pickle local object問題
這篇文章主要介紹了解決python多線程報錯:AttributeError: Can't pickle local object問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04Python3.4學(xué)習(xí)筆記之類型判斷,異常處理,終止程序操作小結(jié)
這篇文章主要介紹了Python3.4學(xué)習(xí)筆記之類型判斷,異常處理,終止程序操作,結(jié)合具體實例形式分析了Python3.4模塊導(dǎo)入、異常處理、退出程序等相關(guān)操作技巧與注意事項,需要的朋友可以參考下2019-03-03Python中__new__與__init__方法的區(qū)別詳解
這篇文章主要介紹了Python中__new__與__init__方法的區(qū)別,是Python學(xué)習(xí)中的基礎(chǔ)知識,需要的朋友可以參考下2015-05-05django-rest-framework解析請求參數(shù)過程詳解
這篇文章主要介紹了django-rest-framework解析請求參數(shù)過程詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-07-07Python數(shù)據(jù)可視化plt.savefig如何將圖片存入固定路徑
這篇文章主要介紹了Python數(shù)據(jù)可視化plt.savefig如何將圖片存入固定路徑問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-09-09使用python讀取txt文件的內(nèi)容,并刪除重復(fù)的行數(shù)方法
下面小編就為大家分享一篇使用python讀取txt文件的內(nèi)容,并刪除重復(fù)的行數(shù)方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04