欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

TensorFlow——Checkpoint為模型添加檢查點的實例

 更新時間:2020年01月21日 09:51:06   作者:Baby-Lily  
今天小編就為大家分享一篇TensorFlow——Checkpoint為模型添加檢查點的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

1.檢查點

保存模型并不限于在訓(xùn)練模型后,在訓(xùn)練模型之中也需要保存,因為TensorFlow訓(xùn)練模型時難免會出現(xiàn)中斷的情況,我們自然希望能夠?qū)⒂?xùn)練得到的參數(shù)保存下來,否則下次又要重新訓(xùn)練。

這種在訓(xùn)練中保存模型,習(xí)慣上稱之為保存檢查點。

2.添加保存點

通過添加檢查點,可以生成載入檢查點文件,并能夠指定生成檢查文件的個數(shù),例如使用saver的另一個參數(shù)——max_to_keep=1,表明最多只保存一個檢查點文件,在保存時使用如下的代碼傳入迭代次數(shù)。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代碼中,我們使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)將訓(xùn)練的參數(shù)傳入檢查點進行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一個文件,這樣在訓(xùn)練過程中得到的新的模型就會覆蓋以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
  saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的兩種方法也可以對checkpoint文件進行加載,tf.train.latest_checkpoint(savedir)為加載最后的檢查點文件。這種方式,我們可以通過保存指定訓(xùn)練次數(shù)的檢查點,比如保存5的倍數(shù)次保存一下檢查點。

3.簡便保存檢查點

我們還可以用更加簡單的方法進行檢查點的保存,tf.train.MonitoredTrainingSession()函數(shù),該函數(shù)可以直接實現(xiàn)保存載入檢查點模型的文件,與前面的方法不同的是,它是按照訓(xùn)練時間來保存檢查點的,可以通過指定save_checkpoint_secs參數(shù)的具體秒數(shù),設(shè)置多久保存一次檢查點。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代碼中,我們設(shè)置了沒訓(xùn)練了5秒中之后,就保存一次檢查點,它默認的保存時間間隔是10分鐘,這種按照時間的保存模式更適合使用大型數(shù)據(jù)集訓(xùn)練復(fù)雜模型的情況,注意在使用上述的方法時,要定義global_step變量,在訓(xùn)練完一個批次或者一個樣本之后,要將其進行加1的操作,否則將會報錯。

以上這篇TensorFlow——Checkpoint為模型添加檢查點的實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • NumPy 數(shù)組屬性的具體使用

    NumPy 數(shù)組屬性的具體使用

    本文主要介紹了NumPy 數(shù)組屬性的具體使用,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2022-08-08
  • Python操作MySQL模擬銀行轉(zhuǎn)賬

    Python操作MySQL模擬銀行轉(zhuǎn)賬

    這篇文章主要為大家詳細介紹了Python操作MySQL模擬銀行轉(zhuǎn)賬,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-03-03
  • 如何在Cloud Studio上執(zhí)行Python代碼?

    如何在Cloud Studio上執(zhí)行Python代碼?

    這篇文章主要介紹了如何在Cloud Studio上執(zhí)行Python代碼?,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-08-08
  • Python實現(xiàn)圖像的垂直投影示例

    Python實現(xiàn)圖像的垂直投影示例

    今天小編就為大家分享一篇Python實現(xiàn)圖像的垂直投影示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-01-01
  • Python使用docx模塊編輯Word文檔

    Python使用docx模塊編輯Word文檔

    docx提供了一組功能豐富的函數(shù)和方法,用于創(chuàng)建、修改和讀取Word文檔,Python可以用它對word文檔進行大批量的編輯,下面小編就來通過一些示例為大家好好講講吧
    2023-07-07
  • Python深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)殘差塊

    Python深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)殘差塊

    這篇文章主要為大家介紹了Python深度學(xué)習(xí)中的神經(jīng)網(wǎng)絡(luò)殘差塊示例詳解有需要的 朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步
    2021-10-10
  • 在django中form的label和verbose name的區(qū)別說明

    在django中form的label和verbose name的區(qū)別說明

    這篇文章主要介紹了在django中form的label和verbose name的區(qū)別說明,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • python 提取文件的小程序

    python 提取文件的小程序

    在做網(wǎng)站項目時,開發(fā)經(jīng)常要給工程一個升級包,包含本次修改的內(nèi)容,這個升級包的內(nèi)容就是tomcat的發(fā)布目錄下的文件;
    2009-07-07
  • pytest配置項目不同環(huán)境URL的實現(xiàn)

    pytest配置項目不同環(huán)境URL的實現(xiàn)

    pytest-base-url是pytest的第三方插件,主要用來幫助我們進行切換測試環(huán)境地址,下面就來介紹一下配置不同環(huán)境URL的實現(xiàn),感興趣的可以了解一下
    2024-02-02
  • Python寫一個字符串數(shù)字后綴部分的遞增函數(shù)

    Python寫一個字符串數(shù)字后綴部分的遞增函數(shù)

    這篇文章主要介紹了Python寫一個字符串數(shù)字后綴部分的遞增函數(shù),寫函數(shù)之前需要Python處理重名字符串,添加或遞增數(shù)字字符串后綴,下面具體過程,需要的小伙伴可以參考一下
    2022-03-03

最新評論