python tensorflow基于cnn實(shí)現(xiàn)手寫數(shù)字識(shí)別
一份基于cnn的手寫數(shù)字自識(shí)別的代碼,供大家參考,具體內(nèi)容如下
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 加載數(shù)據(jù)集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 以交互式方式啟動(dòng)session
# 如果不使用交互式session,則在啟動(dòng)session前必須
# 構(gòu)建整個(gè)計(jì)算圖,才能啟動(dòng)該計(jì)算圖
sess = tf.InteractiveSession()
"""構(gòu)建計(jì)算圖"""
# 通過占位符來為輸入圖像和目標(biāo)輸出類別創(chuàng)建節(jié)點(diǎn)
# shape參數(shù)是可選的,有了它tensorflow可以自動(dòng)捕獲維度不一致導(dǎo)致的錯(cuò)誤
x = tf.placeholder("float", shape=[None, 784]) # 原始輸入
y_ = tf.placeholder("float", shape=[None, 10]) # 目標(biāo)值
# 為了不在建立模型的時(shí)候反復(fù)做初始化操作,
# 我們定義兩個(gè)函數(shù)用于初始化
def weight_variable(shape):
# 截尾正態(tài)分布,stddev是正態(tài)分布的標(biāo)準(zhǔn)偏差
initial = tf.truncated_normal(shape=shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
# 卷積核池化,步長為1,0邊距
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
"""第一層卷積"""
# 由一個(gè)卷積和一個(gè)最大池化組成。濾波器5x5中算出32個(gè)特征,是因?yàn)槭褂?2個(gè)濾波器進(jìn)行卷積
# 卷積的權(quán)重張量形狀是[5, 5, 1, 32],1是輸入通道的個(gè)數(shù),32是輸出通道個(gè)數(shù)
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一個(gè)輸出通道都有一個(gè)偏置量
b_conv1 = bias_variable([32])
# 位了使用卷積,必須將輸入轉(zhuǎn)換成4維向量,2、3維表示圖片的寬、高
# 最后一維表示圖片的顏色通道(因?yàn)槭腔叶葓D像所以通道數(shù)維1,RGB圖像通道數(shù)為3)
x_image = tf.reshape(x, [-1, 28, 28, 1])
# 第一層的卷積結(jié)果,使用Relu作為激活函數(shù)
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一層卷積后的池化結(jié)果
h_pool1 = max_pool_2x2(h_conv1)
"""第二層卷積"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
"""全連接層"""
# 圖片尺寸減小到7*7,加入一個(gè)有1024個(gè)神經(jīng)元的全連接層
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 將最后的池化層輸出張量reshape成一維向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全連接層的輸出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
"""使用Dropout減少過擬合"""
# 使用placeholder占位符來表示神經(jīng)元的輸出在dropout中保持不變的概率
# 在訓(xùn)練的過程中啟用dropout,在測試過程中關(guān)閉dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
"""輸出層"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型預(yù)測輸出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# 交叉熵?fù)p失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
# 模型訓(xùn)練,使用AdamOptimizer來做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# 正確預(yù)測,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 將布爾值轉(zhuǎn)化成浮點(diǎn)數(shù),取平均值作為精確度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# 在session中先初始化變量才能在session中調(diào)用
sess.run(tf.global_variables_initializer())
# 迭代優(yōu)化模型
for i in range(2000):
# 每次取50個(gè)樣本進(jìn)行訓(xùn)練
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中間不使用dropout
print("step %d, training accuracy %g" % (i, train_accuracy))
train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
做了2000次迭代,在測試集上的識(shí)別精度能夠到0.9772……
以上就是本文的全部內(nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python中的字符串類型基本知識(shí)學(xué)習(xí)教程
這篇文章主要介紹了Python中的字符串類型基本知識(shí)學(xué)習(xí)教程,包括轉(zhuǎn)義符和字符串拼接以及原始字符串等基礎(chǔ)知識(shí)講解,需要的朋友可以參考下2016-02-02
python函數(shù)中return后的語句一定不會(huì)執(zhí)行嗎?
這篇文章主要給大家詳細(xì)分析講解了關(guān)于python函數(shù)中return語句后的語句是否一定不會(huì)執(zhí)行的相關(guān)資料,文中介紹的非常詳細(xì),對(duì)大家具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面跟著小編一起來學(xué)習(xí)學(xué)習(xí)吧。2017-07-07
python 繪制擬合曲線并加指定點(diǎn)標(biāo)識(shí)的實(shí)現(xiàn)
這篇文章主要介紹了python 繪制擬合曲線并加指定點(diǎn)標(biāo)識(shí)的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
python構(gòu)建深度神經(jīng)網(wǎng)絡(luò)(DNN)
這篇文章主要為大家詳細(xì)介紹了python構(gòu)建深度神經(jīng)網(wǎng)絡(luò)DNN,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03
Python 2種方法求某個(gè)范圍內(nèi)的所有素?cái)?shù)(質(zhì)數(shù))
這篇文章主要介紹了Python 2種方法求某個(gè)范圍內(nèi)的所有素?cái)?shù)(質(zhì)數(shù)),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01

