
Training an MNIST Handwritten Digit Recognition Model with TensorFlow

Updated: February 13, 2020, 11:10:46   Author: Sebastien23
This article gives a detailed walkthrough of training an MNIST handwritten digit recognition model with TensorFlow. The example code is explained in detail and should be a useful reference for interested readers.

This post shares the complete code for training an MNIST handwritten digit recognition model with TensorFlow, for your reference. The details are as follows.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
INPUT_NODE = 784   # input layer size = pixels per image = 28x28 = 784
OUTPUT_NODE = 10   # output layer size = number of classes (digits 0-9)

LAYER1_NODE = 500  # number of nodes in the single hidden layer
BATCH_SIZE = 100   # number of examples per training batch; smaller values
                   # approach stochastic gradient descent, larger values
                   # approach full-batch gradient descent

LEARNING_RATE_BASE = 0.8      # initial learning rate
LEARNING_RATE_DECAY = 0.99    # learning-rate decay rate

REGULARIZATION_RATE = 0.0001  # coefficient of the regularization term
TRAINING_STEPS = 30000        # number of training steps
MOVING_AVG_DECAY = 0.99       # decay rate for the moving averages
 
# Helper function: given the network input and all parameters, compute
# the forward-propagation result of the network.
def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):

    # When no moving-average class is provided, use the current
    # parameter values directly.
    if avg_class is None:
        # Forward propagation through the hidden layer
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        # Forward propagation through the output layer
        return tf.matmul(layer1, weights2) + biases2
    else:
        # First look up the moving averages of the variables,
        # then compute the forward-propagation result.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) +
            avg_class.average(biases1))

        return tf.matmul(
            layer1, avg_class.average(weights2)) + avg_class.average(biases2)
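# Note that inference() returns raw logits without a softmax layer: the loss
# used below (tf.nn.sparse_softmax_cross_entropy_with_logits) applies the
# softmax internally, and softmax does not change the argmax used for accuracy.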
 
# The training process
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden-layer parameters
    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))

    # Output-layer parameters
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward propagation without moving averages (avg_class=None)
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Variable that counts training steps; marked as not trainable
    global_step = tf.Variable(0, trainable=False)
 
    # Initialize the moving-average class with the decay rate and the
    # training-step variable.
    variable_avgs = tf.train.ExponentialMovingAverage(
        MOVING_AVG_DECAY, global_step)

    # Apply moving averages to all trainable variables of the network
    variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
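    # For each variable v, apply() maintains a shadow copy updated as
    #     shadow = decay * shadow + (1 - decay) * v
    # Passing global_step as num_updates makes the effective decay
    #     min(MOVING_AVG_DECAY, (1 + global_step) / (10 + global_step)),
    # so the averages track the variables more closely early in training.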
 
    # Forward propagation using the moving-averaged parameters
    avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)
 
    # Cross-entropy as the loss function; y_ is one-hot, so tf.argmax
    # recovers the integer class label expected by the sparse version.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularization loss over the weights (biases are not regularized)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)

    loss = cross_entropy_mean + regularization
 
    # Exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,                            # current training step
        mnist.train.num_examples / BATCH_SIZE,  # steps per epoch
        LEARNING_RATE_DECAY)
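    # With the default staircase=False this computes
    #     lr = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps)
    # e.g. after exactly one epoch (global_step == decay_steps): 0.8 * 0.99 = 0.792.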
 
 
    # Minimize the loss; passing global_step increments it on every step
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Update the network parameters and their moving averages in one step
    with tf.control_dependencies([train_step, variables_avgs_op]):
        train_op = tf.no_op(name='train')

    # Check whether the forward-propagation result of the moving-average
    # model is correct
    correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
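    # Example: if a row of avg_y has its largest logit at index 3 and the
    # one-hot label row is also 1 at index 3, tf.equal yields True, which
    # casts to 1.0; the mean is then the fraction classified correctly.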
 
 
    # Create a session and start training
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # Validation data, used to monitor progress and decide when to stop
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}

        # Test data, used for the final evaluation of the model
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        # Training loop
        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average "
                      "model is %g " % (i, validate_acc))

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # After training, measure the final accuracy on the test set
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model "
              "is %g " % (TRAINING_STEPS, test_acc))
  
  
# Main entry point
def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    train(mnist)

# TensorFlow application entry point
if __name__ == '__main__':
    tf.app.run()

The output is as follows:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839
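
The script targets the TensorFlow 1.x graph API (placeholders, sessions, tf.contrib). If only TensorFlow 2.x is available, a minimal sketch of the adjustments is shown below; note that this is an assumption about your environment, and the tensorflow.examples.tutorials MNIST loader is no longer bundled with 2.x, so the input_data module has to be obtained separately.

# A minimal compatibility sketch, assuming TensorFlow 2.x is installed.
# Replace the original "import tensorflow as tf" with:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restore graph mode, placeholders and Sessions
# tf.contrib no longer exists in 2.x, so inside train() replace the
# tf.contrib.layers.l2_regularizer lines with the equivalent:
#   regularization = REGULARIZATION_RATE * (tf.nn.l2_loss(weights1) +
#                                           tf.nn.l2_loss(weights2))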

That's all for this article. I hope it helps with your learning, and I hope you will continue to support 腳本之家.
