欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Tensorflow2.1 完成權(quán)重或模型的保存和加載

 更新時(shí)間:2022年11月17日 16:30:52   作者:我是王大你是誰  
這篇文章主要為大家介紹了Tensorflow2.1 完成權(quán)重或模型的保存和加載,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

前言

本文主要使用 cpu 版本的 tensorflow-2.1 來完成深度學(xué)習(xí)權(quán)重參數(shù)/模型的保存和加載操作。

在我們進(jìn)行項(xiàng)目期間,很多時(shí)候都要在模型訓(xùn)練期間、訓(xùn)練結(jié)束之后對模型或者模型權(quán)重進(jìn)行保存,然后我們可以從之前停止的地方恢復(fù)原模型效果繼續(xù)進(jìn)行訓(xùn)練或者直接投入實(shí)際使用,另外為了節(jié)省存儲空間我們還可以自定義保存內(nèi)容和保存頻率。

實(shí)現(xiàn)方法

1. 讀取數(shù)據(jù)

(1)本文重點(diǎn)介紹模型或者模型權(quán)重的保存和讀取的相關(guān)操作,使用到的是 MNIST 數(shù)據(jù)集僅是為了演示效果,我們無需關(guān)心模型訓(xùn)練的質(zhì)量好壞。

(2)這里是常規(guī)的讀取數(shù)據(jù)操作,我們?yōu)榱四茌^快介紹本文重點(diǎn)內(nèi)容,只使用了 MNIST 前 1000 條數(shù)據(jù),然后對數(shù)據(jù)進(jìn)行歸一化操作,加快模型訓(xùn)練收斂速度,并且將每張圖片的數(shù)據(jù)從二維壓縮成一維。

import os
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

2. 搭建深度學(xué)習(xí)模型

(1)這里主要是搭建一個(gè)最簡單的深度學(xué)習(xí)模型。

(2)第一層將圖片的長度為 784 的一維向量轉(zhuǎn)換成 256 維向量的全連接操作,并且用到了 relu 激活函數(shù)。

(3)第二層緊接著使用了防止過擬合的 Dropout 操作,神經(jīng)元丟棄率為 50% 。

(4)第三層為輸出層,也就是輸出每張圖片屬于對應(yīng) 10 種類別的分布概率。

(5)優(yōu)化器我們選擇了最常見的 Adam 。

(6)損失函數(shù)選擇了 SparseCategoricalCrossentropy 。

(7)評估指標(biāo)選用了 SparseCategoricalAccuracy 。

def create_model():
    model = tf.keras.Sequential([keras.layers.Dense(256, activation='relu', input_shape=(784,)),
                                 keras.layers.Dropout(0.5),
                                 keras.layers.Dense(10) ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model

3. 使用回調(diào)函數(shù)在每個(gè) epoch 后自動保存模型權(quán)重

(1)這里介紹一種在模型訓(xùn)練期間保存權(quán)重參數(shù)的方法,我們定義一個(gè)回調(diào)函數(shù) callback ,它可以在訓(xùn)練過程中將權(quán)重保存在自定義目錄中 weights_path ,在訓(xùn)練過程中一共執(zhí)行 5 次 epoch ,每次 epoch 結(jié)束之后就會保存一次模型的權(quán)重到指定的目錄。

(2)可以看到最后使用測試集進(jìn)行評估的 loss 為 0.4952 ,分類準(zhǔn)確率為 0.8500 。

weights_path = "training_weights/cp.ckpt"
weights_dir = os.path.dirname(weights_path)
callback = tf.keras.callbacks.ModelCheckpoint(filepath=weights_path, save_weights_only=True,  verbose=1)
model = create_model()
model.fit(train_images, 
          train_labels,  
          epochs=5,
          validation_data=(test_images, test_labels),
          callbacks=[callback]) 

輸出結(jié)果為:

 val_loss: 0.4952 - val_sparse_categorical_accuracy: 0.8500             

(3)我們?yōu)g覽目標(biāo)文件夾里,只有三個(gè)文件,每個(gè) epoch 后自動都會保存三個(gè)文件,在下一次 epoch 之后會自動更新這三個(gè)文件的內(nèi)容。

os.listdir(weights_dir)

結(jié)果為:

['checkpoint', 'cp.ckpt.data-00000-of-00001', 'cp.ckpt.index']

(4) 我們通過 create_model 定義了一個(gè)新的模型實(shí)例,然后讓其在沒有訓(xùn)練的情況下使用測試數(shù)據(jù)進(jìn)行評估,結(jié)果可想而知,準(zhǔn)確率差的離譜。

NewModel = create_model()
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)

結(jié)果為:

loss: 2.3694 - sparse_categorical_accuracy: 0.1330

(5) tensorflow 中只要兩個(gè)模型有相同的模型結(jié)構(gòu),就可以在它們之間共享權(quán)重,所以我們使用 NewModel 讀取了之前訓(xùn)練好的模型權(quán)重,再使用測試集對其進(jìn)行評估發(fā)現(xiàn),損失值和準(zhǔn)確率和舊模型的結(jié)果完全一樣,說明權(quán)重被相同結(jié)構(gòu)的新模型成功加載并使用。

NewModel.load_weights(checkpoint_path)
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)

輸出結(jié)果:

loss: 0.4952 - sparse_categorical_accuracy: 0.8500

4. 使用回調(diào)函數(shù)每經(jīng)過 5 個(gè) epoch 對模型權(quán)重保存一次

(1)如果我們想保留多個(gè)中間 epoch 的模型訓(xùn)練的權(quán)重,或者我們想每隔幾個(gè) epoch 保存一次模型訓(xùn)練的權(quán)重,這時(shí)候我們可以通過設(shè)置保存頻率 period 來完成,我這里讓新建的模型訓(xùn)練 30 個(gè) epoch ,在每經(jīng)過 10 epoch 后保存一次模型訓(xùn)練好的權(quán)重。

(2)使用測試集對此次模型進(jìn)行評估,損失值為 0.4047 ,準(zhǔn)確率為 0.8680 。

weights_path = "training_weights2/cp-{epoch:04d}.ckpt"
weights_dir = os.path.dirname(weights_path)
batch_size = 64
cp_callback = tf.keras.callbacks.ModelCheckpoint( filepath=weights_path, 
                                                  verbose=1, 
                                                  save_weights_only=True,
                                                  period=10)
model = create_model()
model.save_weights(weights_path.format(epoch=1))
model.fit(train_images, 
          train_labels,
          epochs=30, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=1)

結(jié)果輸出為:

val_loss: 0.4047 - val_sparse_categorical_accuracy: 0.8680   

(3)這里我們能看到指定目錄中的文件組成,這里的 0001 是因?yàn)橛?xùn)練時(shí)指定了要保存的 epoch 的權(quán)重,其他都是每 10 個(gè) epoch 保存的權(quán)重參數(shù)文件。目錄中有一個(gè) checkpoint ,它是一個(gè)檢查點(diǎn)文本文件,文件保存了一個(gè)目錄下所有的模型文件列表,首行記錄的是最后(最近)一次保存的模型名稱。

(4)每個(gè) epoch 保存下來的文件都包含:

  • 一個(gè)索引文件,指示哪些權(quán)重存儲在哪個(gè)分片中
  • 一個(gè)或多個(gè)包含模型權(quán)重的分片

瀏覽文件夾內(nèi)容

os.listdir(weights_dir)

結(jié)果如下:

['checkpoint', 'cp-0001.ckpt.data-00000-of-00001', 'cp-0001.ckpt.index', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0010.ckpt.index', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0020.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index']

(5)我們將最后一次保存的權(quán)重讀取出來,然后創(chuàng)建一個(gè)新的模型去讀取剛剛保存的最新的之前訓(xùn)練好的模型權(quán)重,然后通過測試集對新模型進(jìn)行評估,發(fā)現(xiàn)損失值準(zhǔn)確率和之前完全一樣,說明權(quán)重被成功讀取并使用。

latest = tf.train.latest_checkpoint(weights_dir)
newModel = create_model()
newModel.load_weights(latest)
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)

結(jié)果如下:

loss: 0.4047 - sparse_categorical_accuracy: 0.8680

5. 手動保存模型權(quán)重到指定目錄

(1)有時(shí)候我們還想手動將模型訓(xùn)練好的權(quán)重保存到指定的目錄下,我們可以使用 save_weights 函數(shù),通過我們新建了一個(gè)同樣的新模型,然后使用 load_weights 函數(shù)去讀取權(quán)重并使用測試集對其進(jìn)行評估,發(fā)現(xiàn)損失值和準(zhǔn)確率仍然和之前的兩種結(jié)果完全一樣。

model.save_weights('./training_weights3/my_cp')
newModel = create_model()
newModel.load_weights('./training_weights3/my_cp')
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)

結(jié)果如下:

loss: 0.4047 - sparse_categorical_accuracy: 0.8680

6. 手動保存整個(gè)模型結(jié)構(gòu)和權(quán)重

(1)有時(shí)候我們還需要保存整個(gè)模型的結(jié)構(gòu)和權(quán)重,這時(shí)候我們直接使用 save 函數(shù)即可將這些內(nèi)容保存到指定目錄,使用該方法要保證目錄是存在的否則會報(bào)錯(cuò),所以這里我們要創(chuàng)建文件夾。我們能看到損失值為 0.4821,準(zhǔn)確率為 0.8460 。

model = create_model()
model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels), verbose=1)
!mkdir my_model
modelPath = './my_model'
model.save(modelPath)

輸出結(jié)果:

val_loss: 0.4821 - val_sparse_categorical_accuracy: 0.8460

(2)然后我們通過函數(shù) load_model 即可生成出一個(gè)新的完全一樣結(jié)構(gòu)和權(quán)重的模型,我們使用測試集對其進(jìn)行評估,發(fā)現(xiàn)準(zhǔn)確率和損失值和之前完全一樣,說明模型結(jié)構(gòu)和權(quán)重被完全讀取恢復(fù)。

new_model = tf.keras.models.load_model(modelPath)
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)

輸出結(jié)果:

 loss: 0.4821 - sparse_categorical_accuracy: 0.8460

以上就是Tensorflow2.1 完成權(quán)重或模型的保存和加載的詳細(xì)內(nèi)容,更多關(guān)于Tensorflow完成權(quán)重模型保存加載的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 在PyCharm下使用 ipython 交互式編程的方法

    在PyCharm下使用 ipython 交互式編程的方法

    今天小編就為大家分享一篇在PyCharm下使用 ipython 交互式編程的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • Python設(shè)計(jì)模式之原型模式實(shí)例詳解

    Python設(shè)計(jì)模式之原型模式實(shí)例詳解

    這篇文章主要介紹了Python設(shè)計(jì)模式之原型模式,結(jié)合實(shí)例形式較為詳細(xì)的分析了Python原型模式的概念、原理、用法及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下
    2019-01-01
  • Python入門必須知道的11個(gè)知識點(diǎn)

    Python入門必須知道的11個(gè)知識點(diǎn)

    這篇文章主要為大家詳細(xì)介紹了Python入門必須知道的11個(gè)知識點(diǎn),幫助更好地了解python,感興趣的小伙伴們可以參考一下
    2018-03-03
  • python加速器numba使用詳解

    python加速器numba使用詳解

    本文主要介紹了python加速器numba使用詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • Python軟件包安裝的三種常見方法

    Python軟件包安裝的三種常見方法

    python擁有非常豐富的擴(kuò)展包,下面這篇文章主要給大家介紹了關(guān)于Python軟件包安裝的三種常見方法,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2022-07-07
  • pycharm使用docker容器開發(fā)的詳細(xì)教程

    pycharm使用docker容器開發(fā)的詳細(xì)教程

    這篇文章主要介紹了pycharm使用docker容器開發(fā)的詳細(xì)教程,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-01-01
  • python利用Excel讀取和存儲測試數(shù)據(jù)完成接口自動化教程

    python利用Excel讀取和存儲測試數(shù)據(jù)完成接口自動化教程

    這篇文章主要介紹了python利用Excel讀取和存儲測試數(shù)據(jù)完成接口自動化教程,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • Pandas透視表與交叉表的使用

    Pandas透視表與交叉表的使用

    Pandas中的交叉表和透視表的作用相似,本文就來介紹一下Pandas透視表與交叉表的使用,具有一定的參考價(jià)值,感興趣的可以了解一下
    2023-11-11
  • PyCharm 常用快捷鍵和設(shè)置方法

    PyCharm 常用快捷鍵和設(shè)置方法

    下面小編就為大家分享一篇PyCharm 常用快捷鍵和設(shè)置方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2017-12-12
  • Python+Matplotlib實(shí)現(xiàn)繪制三維折線圖

    Python+Matplotlib實(shí)現(xiàn)繪制三維折線圖

    立體圖視覺上層次分明色彩鮮艷,具有很強(qiáng)的視覺沖擊力,讓觀看的人駐景時(shí)間長,留下深刻的印象。今天我們就通過這篇文章來了解如何用python中的matplotlib庫繪制漂亮的三維折線圖吧
    2023-03-03

最新評論