詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式
TensorFlow訓(xùn)練網(wǎng)絡(luò)有兩種方式,一種是基于tensor(array),另外一種是迭代器
兩種方式區(qū)別是:
- 第一種是要加載全部數(shù)據(jù)形成一個(gè)tensor,然后調(diào)用model.fit()然后指定參數(shù)batch_size進(jìn)行將所有數(shù)據(jù)進(jìn)行分批訓(xùn)練
- 第二種是自己先將數(shù)據(jù)分批形成一個(gè)迭代器,然后遍歷這個(gè)迭代器,分別訓(xùn)練每個(gè)批次的數(shù)據(jù)
方式一:通過迭代器
IMAGE_SIZE = 1000 # step1:加載數(shù)據(jù)集 (train_images, train_labels), (val_images, val_labels) = tf.keras.datasets.mnist.load_data() # step2:將圖像歸一化 train_images, val_images = train_images / 255.0, val_images / 255.0 # step3:設(shè)置訓(xùn)練集大小 train_images = train_images[:IMAGE_SIZE] val_images = val_images[:IMAGE_SIZE] train_labels = train_labels[:IMAGE_SIZE] val_labels = val_labels[:IMAGE_SIZE] # step4:將圖像的維度變?yōu)?IMAGE_SIZE,28,28,1) train_images = tf.expand_dims(train_images, axis=3) val_images = tf.expand_dims(val_images, axis=3) # step5:將圖像的尺寸變?yōu)?32,32) train_images = tf.image.resize(train_images, [32, 32]) val_images = tf.image.resize(val_images, [32, 32]) # step6:將數(shù)據(jù)變?yōu)榈? train_loader = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(32) val_loader = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(IMAGE_SIZE) # step5:導(dǎo)入模型 model = LeNet5() # 讓模型知道輸入數(shù)據(jù)的形式 model.build(input_shape=(1, 32, 32, 1)) # 結(jié)局Output Shape為 multiple model.call(Input(shape=(32, 32, 1))) # step6:編譯模型 model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) # 權(quán)重保存路徑 checkpoint_path = "./weight/cp.ckpt" # 回調(diào)函數(shù),用戶保存權(quán)重 save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0) EPOCHS = 11 for epoch in range(1, EPOCHS): # 每個(gè)批次訓(xùn)練集誤差 train_epoch_loss_avg = tf.keras.metrics.Mean() # 每個(gè)批次訓(xùn)練集精度 train_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() # 每個(gè)批次驗(yàn)證集誤差 val_epoch_loss_avg = tf.keras.metrics.Mean() # 每個(gè)批次驗(yàn)證集精度 val_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() for x, y in train_loader: history = model.fit(x, y, validation_data=val_loader, callbacks=[save_callback], verbose=0) # 更新誤差,保留上次 train_epoch_loss_avg.update_state(history.history['loss'][0]) # 更新精度,保留上次 train_epoch_accuracy.update_state(y, model(x, training=True)) val_epoch_loss_avg.update_state(history.history['val_loss'][0]) val_epoch_accuracy.update_state(next(iter(val_loader))[1], model(next(iter(val_loader))[0], training=True)) # 使用.result()計(jì)算每個(gè)批次的誤差和精度結(jié)果 print("Epoch {:d}: trainLoss: {:.3f}, trainAccuracy: {:.3%} valLoss: {:.3f}, valAccuracy: {:.3%}".format(epoch, train_epoch_loss_avg.result(), train_epoch_accuracy.result(), val_epoch_loss_avg.result(), val_epoch_accuracy.result()))
方式二:適用model.fit()進(jìn)行分批訓(xùn)練
import model_sequential (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() # step2:將圖像歸一化 train_images, test_images = train_images / 255.0, test_images / 255.0 # step3:將圖像的維度變?yōu)?60000,28,28,1) train_images = tf.expand_dims(train_images, axis=3) test_images = tf.expand_dims(test_images, axis=3) # step4:將圖像尺寸改為(60000,32,32,1) train_images = tf.image.resize(train_images, [32, 32]) test_images = tf.image.resize(test_images, [32, 32]) # step5:導(dǎo)入模型 # history = LeNet5() history = model_sequential.LeNet() # 讓模型知道輸入數(shù)據(jù)的形式 history.build(input_shape=(1, 32, 32, 1)) # history(tf.zeros([1, 32, 32, 1])) # 結(jié)局Output Shape為 multiple history.call(Input(shape=(32, 32, 1))) history.summary() # step6:編譯模型 history.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) # 權(quán)重保存路徑 checkpoint_path = "./weight/cp.ckpt" # 回調(diào)函數(shù),用戶保存權(quán)重 save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=1) # step7:訓(xùn)練模型 history = history.fit(train_images, train_labels, epochs=10, batch_size=32, validation_data=(test_images, test_labels), callbacks=[save_callback])
到此這篇關(guān)于詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式的文章就介紹到這了,更多相關(guān)TensorFlow訓(xùn)練網(wǎng)絡(luò)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!?
相關(guān)文章
Python3.9最新版下載與安裝圖文教程詳解(Windows系統(tǒng)為例)
這篇文章主要介紹了Python3.9最新版下載與安裝圖文教程詳解,本文通過圖文并茂的形式給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-11-11Pycharm+django2.2+python3.6+MySQL實(shí)現(xiàn)簡單的考試報(bào)名系統(tǒng)
這篇文章主要介紹了Pycharm+django2.2+python3.6+MySQL實(shí)現(xiàn)簡單的考試報(bào)名系統(tǒng),本文圖文并茂給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-09-09Scrapy模擬登錄趕集網(wǎng)的實(shí)現(xiàn)代碼
這篇文章主要介紹了Scrapy模擬登錄趕集網(wǎng)的實(shí)現(xiàn)代碼,本文通過代碼圖文相結(jié)合給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-07-07python中windows鏈接linux執(zhí)行命令并獲取執(zhí)行狀態(tài)的問題小結(jié)
這篇文章主要介紹了python中windows鏈接linux執(zhí)行命令并獲取執(zhí)行狀態(tài),由于工具是pyqt寫的所以牽扯到用python鏈接linux的問題,這里記錄一下一些碰到的問題,需要的朋友可以參考下2022-11-11Python 實(shí)現(xiàn)使用dict 創(chuàng)建二維數(shù)據(jù)、DataFrame
下面小編就為大家分享一篇Python 實(shí)現(xiàn)使用dict 創(chuàng)建二維數(shù)據(jù)、DataFrame,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Python實(shí)現(xiàn)光速定位并提取兩個(gè)文件的不同之處
如果你經(jīng)常與Excel或Word打交道,那么從兩份表格/文檔中找到不一樣的元素是一件讓人很頭疼的工作。本文就將以兩份真實(shí)的Excel/Word文件為例,講解如何使用Python光速對比并提取文件中的不同之處2022-08-08OpenCV-Python實(shí)現(xiàn)輪廓的特征值
輪廓自身的一些屬性特征及輪廓所包圍對象的特征對于描述圖像具有重要意義。本篇博文將介紹幾個(gè)輪廓自身的屬性特征及輪廓包圍對象的特征,感興趣的可以了解一下2021-06-06