tensorflow識別自己手寫數(shù)字
tensorflow作為google開源的項(xiàng)目,現(xiàn)在趕超了caffe,好像成為最受歡迎的深度學(xué)習(xí)框架。確實(shí)在編寫的時(shí)候更能感受到代碼的真實(shí)存在,這點(diǎn)和caffe不同,caffe通過編寫配置文件進(jìn)行網(wǎng)絡(luò)的生成。環(huán)境tensorflow是0.10的版本,注意其他版本有的語句會(huì)有錯(cuò)誤,這是tensorflow版本之間的兼容問題。
還需要安裝PIL:pip install Pillow
圖片的格式:
– 圖像標(biāo)準(zhǔn)化,可安裝在20×20像素的框內(nèi),同時(shí)保留其長寬比。
– 圖片都集中在一個(gè)28×28的圖像中。
– 像素以列為主進(jìn)行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。
創(chuàng)建一個(gè).png的文件,背景是白色的,手寫的字體是黑色的,
下面是數(shù)據(jù)測試的代碼,一個(gè)兩層的卷積神經(jīng)網(wǎng),然后用save進(jìn)行模型的保存。
# coding: UTF-8 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import input_data ''''' 得到數(shù)據(jù) ''' mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) training = mnist.train.images trainlable = mnist.train.labels testing = mnist.test.images testlabel = mnist.test.labels print ("MNIST loaded") # 獲取交互式的方式 sess = tf.InteractiveSession() # 初始化變量 x = tf.placeholder("float", shape=[None, 784]) y_ = tf.placeholder("float", shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) ''''' 生成權(quán)重函數(shù),其中shape是數(shù)據(jù)的形狀 ''' def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) ''''' 生成偏執(zhí)項(xiàng) 其中shape是數(shù)據(jù)形狀 ''' def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') W_conv1 = weight_variable([5, 5, 1, 32]) b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1, 28, 28, 1]) h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) h_pool1 = max_pool_2x2(h_conv1) W_conv2 = weight_variable([5, 5, 32, 64]) b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) h_pool2 = max_pool_2x2(h_conv2) W_fc1 = weight_variable([7 * 7 * 64, 1024]) b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) keep_prob = tf.placeholder("float") h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) W_fc2 = weight_variable([1024, 10]) b_fc2 = bias_variable([10]) y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # 保存網(wǎng)絡(luò)訓(xùn)練的參數(shù) saver = tf.train.Saver() sess.run(tf.initialize_all_variables()) for i in range(8000): batch = mnist.train.next_batch(50) if i%100 == 0: train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}) print "step %d, training accuracy %g"%(i, train_accuracy) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) save_path = saver.save(sess, "model_mnist.ckpt") print("Model saved in life:", save_path) print "test accuracy %g"%accuracy.eval(feed_dict={ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
其中input_data.py如下代碼,是進(jìn)行mnist數(shù)據(jù)集的下載的:代碼是由mnist數(shù)據(jù)集提供的官方下載的版本。
# Copyright 2015 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Functions for downloading and reading MNIST data.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import gzip import os import tensorflow.python.platform import numpy from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' def maybe_download(filename, work_directory): """Download the data from Yann's website, unless it's already here.""" if not os.path.exists(work_directory): os.mkdir(work_directory) filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') return filepath def _read32(bytestream): dt = numpy.dtype(numpy.uint32).newbyteorder('>') return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] def extract_images(filename): """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" print('Extracting', filename) with gzip.open(filename) as bytestream: magic = _read32(bytestream) if magic != 2051: raise ValueError( 'Invalid magic number %d in MNIST image file: %s' % (magic, filename)) num_images = _read32(bytestream) rows = _read32(bytestream) cols = _read32(bytestream) buf = bytestream.read(rows * cols * num_images) data = numpy.frombuffer(buf, dtype=numpy.uint8) data = data.reshape(num_images, rows, cols, 1) return data def dense_to_one_hot(labels_dense, num_classes=10): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] index_offset = numpy.arange(num_labels) * num_classes labels_one_hot = numpy.zeros((num_labels, num_classes)) labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 return labels_one_hot def extract_labels(filename, one_hot=False): """Extract the labels into a 1D uint8 numpy array [index].""" print('Extracting', filename) with gzip.open(filename) as bytestream: magic = _read32(bytestream) if magic != 2049: raise ValueError( 'Invalid magic number %d in MNIST label file: %s' % (magic, filename)) num_items = _read32(bytestream) buf = bytestream.read(num_items) labels = numpy.frombuffer(buf, dtype=numpy.uint8) if one_hot: return dense_to_one_hot(labels) return labels class DataSet(object): def __init__(self, images, labels, fake_data=False, one_hot=False, dtype=tf.float32): """Construct a DataSet. one_hot arg is used only if fake_data is true. `dtype` can be either `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into `[0, 1]`. """ dtype = tf.as_dtype(dtype).base_dtype if dtype not in (tf.uint8, tf.float32): raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot else: assert images.shape[0] == labels.shape[0], ( 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) self._num_examples = images.shape[0] # Convert shape from [num examples, rows, columns, depth] # to [num examples, rows*columns] (assuming depth == 1) assert images.shape[3] == 1 images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]) if dtype == tf.float32: # Convert from [0, 255] -> [0.0, 1.0]. images = images.astype(numpy.float32) images = numpy.multiply(images, 1.0 / 255.0) self._images = images self._labels = labels self._epochs_completed = 0 self._index_in_epoch = 0 @property def images(self): return self._images @property def labels(self): return self._labels @property def num_examples(self): return self._num_examples @property def epochs_completed(self): return self._epochs_completed def next_batch(self, batch_size, fake_data=False): """Return the next `batch_size` examples from this data set.""" if fake_data: fake_image = [1] * 784 if self.one_hot: fake_label = [1] + [0] * 9 else: fake_label = 0 return [fake_image for _ in xrange(batch_size)], [ fake_label for _ in xrange(batch_size)] start = self._index_in_epoch self._index_in_epoch += batch_size if self._index_in_epoch > self._num_examples: # Finished epoch self._epochs_completed += 1 # Shuffle the data perm = numpy.arange(self._num_examples) numpy.random.shuffle(perm) self._images = self._images[perm] self._labels = self._labels[perm] # Start next epoch start = 0 self._index_in_epoch = batch_size assert batch_size <= self._num_examples end = self._index_in_epoch return self._images[start:end], self._labels[start:end] def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): class DataSets(object): pass data_sets = DataSets() if fake_data: def fake(): return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) data_sets.train = fake() data_sets.validation = fake() data_sets.test = fake() return data_sets TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TEST_IMAGES = 't10k-images-idx3-ubyte.gz' TEST_LABELS = 't10k-labels-idx1-ubyte.gz' VALIDATION_SIZE = 5000 local_file = maybe_download(TRAIN_IMAGES, train_dir) train_images = extract_images(local_file) local_file = maybe_download(TRAIN_LABELS, train_dir) train_labels = extract_labels(local_file, one_hot=one_hot) local_file = maybe_download(TEST_IMAGES, train_dir) test_images = extract_images(local_file) local_file = maybe_download(TEST_LABELS, train_dir) test_labels = extract_labels(local_file, one_hot=one_hot) validation_images = train_images[:VALIDATION_SIZE] validation_labels = train_labels[:VALIDATION_SIZE] train_images = train_images[VALIDATION_SIZE:] train_labels = train_labels[VALIDATION_SIZE:] data_sets.train = DataSet(train_images, train_labels, dtype=dtype) data_sets.validation = DataSet(validation_images, validation_labels, dtype=dtype) data_sets.test = DataSet(test_images, test_labels, dtype=dtype) return data_sets
然后進(jìn)行代碼的測試:
# import modules import sys import tensorflow as tf from PIL import Image, ImageFilter def predictint(imvalue): """ This function returns the predicted integer. The imput is the pixel values from the imageprepare() function. """ # Define the model (same as when creating the model file) x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') W_conv1 = weight_variable([5, 5, 1, 32]) b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1, 28, 28, 1]) h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) h_pool1 = max_pool_2x2(h_conv1) W_conv2 = weight_variable([5, 5, 32, 64]) b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) h_pool2 = max_pool_2x2(h_conv2) W_fc1 = weight_variable([7 * 7 * 64, 1024]) b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) keep_prob = tf.placeholder(tf.float32) h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) W_fc2 = weight_variable([1024, 10]) b_fc2 = bias_variable([10]) y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) init_op = tf.initialize_all_variables() saver = tf.train.Saver() """ Load the model_mnist.ckpt file file is stored in the same directory as this python script is started Use the model to predict the integer. Integer is returend as list. Based on the documentatoin at https://www.tensorflow.org/versions/master/how_tos/variables/index.html """ with tf.Session() as sess: sess.run(init_op) saver.restore(sess, "model_mnist.ckpt") # print ("Model restored.") prediction = tf.argmax(y_conv, 1) return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess) def imageprepare(argv): """ This function returns the pixel values. The imput is a png file location. """ im = Image.open(argv).convert('L') width = float(im.size[0]) height = float(im.size[1]) newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels if width > height: # check which dimension is bigger # Width is bigger. Width becomes 20 pixels. nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width if (nheight == 0): # rare case but minimum is 1 pixel nheigth = 1 # resize and sharpen img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition newImage.paste(img, (4, wtop)) # paste resized image on white canvas else: # Height is bigger. Heigth becomes 20 pixels. nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height if (nwidth == 0): # rare case but minimum is 1 pixel nwidth = 1 # resize and sharpen img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition newImage.paste(img, (wleft, 4)) # paste resized image on white canvas # newImage.save("sample.png") tv = list(newImage.getdata()) # get pixel values # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. tva = [(255 - x) * 1.0 / 255.0 for x in tv] return tva # print(tva) def main(argv): """ Main function. """ imvalue = imageprepare(argv) predint = predictint(imvalue) print (predint[0]) # first value in list if __name__ == "__main__": main('2.png')
其中我用于測試的代碼如下:
可以將圖片另存到路徑下面,然后進(jìn)行測試。
(1)載入我的手寫數(shù)字的圖像。
(2)將圖像轉(zhuǎn)換為黑白(模式“L”)
(3)確定原始圖像的尺寸是最大的
(4)調(diào)整圖像的大小,使得最大尺寸(醚的高度及寬度)為20像素,并且以相同的比例最小化尺寸刻度。
(5)銳化圖像。這會(huì)極大地強(qiáng)化結(jié)果。
(6)把圖像粘貼在28×28像素的白色畫布上。在最大的尺寸上從頂部或側(cè)面居中圖像4個(gè)像素。最大尺寸始終是20個(gè)像素和4 + 20 + 4 = 28,最小尺寸被定位在28和縮放的圖像的新的大小之間差的一半。
(7)獲取新的圖像(畫布+居中的圖像)的像素值。
(8)歸一化像素值到0和1之間的一個(gè)值(這也在TensorFlow MNIST教程中完成)。其中0是白色的,1是純黑色。從步驟7得到的像素值是與之相反的,其中255是白色的,0黑色,所以數(shù)值必須反轉(zhuǎn)。下述公式包括反轉(zhuǎn)和規(guī)格化(255-X)* 1.0 / 255.0
相關(guān)文章
pytorch訓(xùn)練時(shí)的顯存占用遞增的問題解決
本文主要介紹了pytorch訓(xùn)練時(shí)的顯存占用遞增的問題解決,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-01-01使用virtualenv創(chuàng)建Python環(huán)境及PyQT5環(huán)境配置的方法
這篇文章主要介紹了使用virtualenv創(chuàng)建Python環(huán)境及PyQT5環(huán)境配置的方法,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-09-09ruff check文件目錄檢測--exclude參數(shù)設(shè)置路徑詳解
這篇文章主要為大家介紹了ruff check文件目錄檢測exclude參數(shù)如何設(shè)置多少路徑詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-10-10Django框架教程之正則表達(dá)式URL誤區(qū)詳解
正則表達(dá)式對大家來說應(yīng)該都不陌生,下面這篇文章主要給大家介紹了關(guān)于Django框架教程之正則表達(dá)式URL誤區(qū)的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧。2018-01-01TensorFlow卷積神經(jīng)網(wǎng)絡(luò)AlexNet實(shí)現(xiàn)示例詳解
這篇文章主要為大家介紹了TensorFlow卷積神經(jīng)網(wǎng)絡(luò)AlexNet實(shí)現(xiàn)示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-11-11Python爬取智聯(lián)招聘數(shù)據(jù)分析師崗位相關(guān)信息的方法
這篇文章主要介紹了Python爬取智聯(lián)招聘數(shù)據(jù)分析師崗位相關(guān)信息的方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08