Implementing Image Classification with TensorFlow-Slim
Reference: https://github.com/tensorflow/models/tree/master/slim
Image classification with TensorFlow-Slim
Preparation
Install TensorFlow
See https://www.tensorflow.org/install/
For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:
wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
Download the TF-slim image model library
cd $WORKSPACE
git clone https://github.com/tensorflow/models/
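Prepare the data
Many public datasets are available; this article uses the Flowers dataset from the official site as an example.
The official repository provides code for downloading and converting the data. To understand that code and to be able to use our own data, we adapt it here instead of running it unchanged.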
cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz
The dataset folder is structured as follows:
flower_photos
├── daisy
│   ├── 100080576_f52e8ee070_n.jpg
│   └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips
In practice, our own datasets do not necessarily store images in one folder per class, so we generate list.txt to record the mapping between each image path and its label.
Python code:
import os

# Map each class folder name to an integer label.
class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

# Write one "relative/path label" line per image.
# Sorting keeps the output order the same from run to run.
with open(output_path, 'w') as fd:
    for class_name in sorted(class_names_to_ids):
        images_list = os.listdir(os.path.join(data_dir, class_name))
        for image_name in sorted(images_list):
            fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))
To make it easier to look up label names later, you can also create labels.txt, with one class name per line in label-id order:
daisy
dandelion
roses
sunflowers
tulips
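The line order in labels.txt is what ties a name to its integer id: line 0 is class 0, line 1 is class 1, and so on. As a minimal sketch of that convention (the helper name `load_label_names` is an illustration, not part of TF-slim, and an in-memory file stands in for labels.txt):

```python
import io

def load_label_names(lines):
    """Map each 0-based line index to the class name on that line."""
    return {i: line.strip() for i, line in enumerate(lines) if line.strip()}

# Stand-in for open('labels.txt'); the content matches the file above.
labels_file = io.StringIO(u"daisy\ndandelion\nroses\nsunflowers\ntulips\n")
id_to_name = load_label_names(labels_file)
print(id_to_name[4])  # tulips
```

If the order in labels.txt differs from the ids written to list.txt, the names reported later will be wrong even though training itself is unaffected.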
Randomly split the list into a training set and a validation set:
Python code:
import random

_NUM_VALIDATION = 350
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

fd = open(list_path)
lines = fd.readlines()
fd.close()

random.seed(_RANDOM_SEED)
random.shuffle(lines)

fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
    fd.write(line)
fd.close()

fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
    fd.write(line)
fd.close()
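Before converting anything, it can be worth checking that the two lists are disjoint and together cover every image. The following is a small sketch of the same shuffle-and-slice idea with those checks added; the function name `split_lines` and the sample entries are hypothetical, and it uses a private `random.Random` instance instead of seeding the global generator.

```python
import random

def split_lines(lines, num_validation, seed=0):
    """Shuffle a copy of lines with a fixed seed; return (train, val)."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    return shuffled[num_validation:], shuffled[:num_validation]

# Hypothetical entries in the "relative/path label" format of list.txt.
lines = ['class_{}/img_{}.jpg {}\n'.format(i % 5, i, i % 5) for i in range(20)]
train, val = split_lines(lines, num_validation=4)

assert not set(train) & set(val)             # no image is in both splits
assert sorted(train + val) == sorted(lines)  # nothing lost or duplicated
print('{} train, {} val'.format(len(train), len(val)))  # 16 train, 4 val
```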
Generate the TFRecord data:
Python code:
import sys
sys.path.insert(0, '../models/slim/')
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
    # Each line of the list file is "relative/path label".
    fd = open(list_path)
    lines = [line.split() for line in fd]
    fd.close()
    num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
    with tf.Graph().as_default():
        # Decode each JPEG once, only to read its height and width.
        decode_jpeg_data = tf.placeholder(dtype=tf.string)
        decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_path = os.path.join(output_dir,
                    'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
                tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
                start_ndx = shard_id * num_per_shard
                end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
                for i in range(start_ndx, end_ndx):
                    sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
                        i + 1, len(lines), shard_id))
                    sys.stdout.flush()
                    image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
                    image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
                    height, width = image.shape[0], image.shape[1]
                    # Store the original encoded bytes, not the decoded array.
                    example = dataset_utils.image_to_tfexample(
                        image_data, b'jpg', height, width, int(lines[i][1]))
                    tfrecord_writer.write(example.SerializeToString())
                tfrecord_writer.close()
    sys.stdout.write('\n')
    sys.stdout.flush()

os.system('mkdir -p train')
convert_dataset('list_train.txt', 'flower_photos', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'flower_photos', 'val/')
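The sharding arithmetic in convert_dataset can be checked on its own, without TensorFlow. This sketch reproduces only the index computation (the helper name `shard_ranges` is made up for illustration) and shows which half-open index range each shard receives:

```python
import math

def shard_ranges(num_items, num_shards):
    """Return (start, end) index pairs, end exclusive, one per shard."""
    per_shard = int(math.ceil(num_items / float(num_shards)))
    return [(s * per_shard, min((s + 1) * per_shard, num_items))
            for s in range(num_shards)]

print(shard_ranges(350, 5))
# [(0, 70), (70, 140), (140, 210), (210, 280), (280, 350)]
print(shard_ranges(12, 5))
# [(0, 3), (3, 6), (6, 9), (9, 12), (12, 12)]
```

The second call shows that rounding up the per-shard count can leave the last shard empty when the number of items is small relative to the number of shards; the file is still created, it just contains no records.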
The resulting folder structure is as follows:
data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│   ├── data_00000-of-00005.tfrecord
│   ├── ...
│   └── data_00004-of-00005.tfrecord
└── val
    ├── data_00000-of-00005.tfrecord
    ├── ...
    └── data_00004-of-00005.tfrecord
(Optional) Download a pretrained model
The official repository provides many pretrained models; this article uses Inception-ResNet-v2 as an example.
cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz
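Training
Reading the data
The official repository provides code for reading the Flowers dataset in models/slim/datasets/flowers.py. As before, we adapt it so that it can read the generic dataset defined above.
Save the following code as models/slim/datasets/dataset_classification.py.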
import os
import tensorflow as tf

slim = tf.contrib.slim

def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
    file_pattern = os.path.join(dataset_dir, file_pattern)
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    items_to_descriptions = {
        'image': 'A color image of varying size.',
        'label': 'A single integer between 0 and ' + str(num_classes - 1),
    }
    labels_to_names = None
    if labels_to_names_path is not None:
        fd = open(labels_to_names_path)
        labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
        fd.close()
    return slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
        labels_to_names=labels_to_names)
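Building the model
The official repository provides many models in models/slim/nets/.
To use a custom model, follow the structure of the official models and place it in the corresponding folder.
Start training
The official repository provides a training script. If you use the official data loading and preprocessing, you can start training as follows.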
cd $WORKSPACE/models/slim
CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004
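To train from scratch without fine-tuning, remove --checkpoint_path, --checkpoint_exclude_scopes and --trainable_scopes.
To fine-tune all layers, remove --checkpoint_exclude_scopes and --trainable_scopes. If --checkpoint_path still points to the original pretrained checkpoint, keep --checkpoint_exclude_scopes and remove only --trainable_scopes, because the checkpoint's logits layers do not match a 5-class model.
To use only the CPU, add --clone_on_cpu=True.
The other flags can be removed to fall back to their defaults, or adjusted as needed.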
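The reason for excluding the Logits and AuxLogits scopes is that those layers in the pretrained checkpoint are sized for the class count of the dataset it was trained on, so they cannot be loaded into a 5-class head. The sketch below illustrates the idea in plain Python, assuming the scopes are matched as name prefixes; `variables_to_restore` and the variable names are hypothetical and this is not the training script's own code.

```python
def variables_to_restore(variable_names, exclude_scopes):
    """Keep the names that do not start with any excluded scope prefix."""
    prefixes = [s.strip() for s in exclude_scopes.split(',') if s.strip()]
    return [name for name in variable_names
            if not any(name.startswith(p) for p in prefixes)]

# Hypothetical variable names, for illustration only.
names = [
    'InceptionResnetV2/Conv2d_1a_3x3/weights',
    'InceptionResnetV2/Logits/Logits/weights',
    'InceptionResnetV2/AuxLogits/Logits/biases',
]
kept = variables_to_restore(
    names, 'InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits')
print(kept)  # ['InceptionResnetV2/Conv2d_1a_3x3/weights']
```

--trainable_scopes works in the opposite direction: it selects which variables the optimizer updates, so listing only the two logits scopes trains the new classifier head while the restored layers stay fixed.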
To use your own data, modify models/slim/train_image_classifier.py as follows.
Change
from datasets import dataset_factory
to
from datasets import dataset_classification
Change
dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
to
dataset = dataset_classification.get_dataset(
    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
After
tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')
add
tf.app.flags.DEFINE_integer(
    'num_samples', 3320, 'Number of samples.')
tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')
tf.app.flags.DEFINE_string(
    'labels_to_names_path', None, 'Label names file path.')
Then start training with the following command:
cd $WORKSPACE/models/slim
python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_dir=../../data/train \
  --num_samples=3320 \
  --num_classes=5 \
  --labels_to_names_path=../../data/labels.txt \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
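The values passed to --num_samples and --num_classes have to match the list file that was converted to TFRecords, and for a different dataset they are easy to get wrong by hand. One way to derive them, sketched here with a hypothetical helper `summarize_list` and made-up file names, and assuming label ids are 0-based and contiguous:

```python
def summarize_list(lines):
    """Return (num_samples, num_classes) for "path label" lines."""
    labels = [int(line.split()[-1]) for line in lines if line.strip()]
    # Assumes ids start at 0 and are contiguous, so the count is max + 1.
    return len(labels), max(labels) + 1

# Stand-in for open('list_train.txt'); the file names are hypothetical.
sample_lines = [
    'daisy/a.jpg 0\n',
    'roses/b.jpg 2\n',
    'roses/c.jpg 2\n',
    'tulips/d.jpg 4\n',
]
num_samples, num_classes = summarize_list(sample_lines)
print('--num_samples={} --num_classes={}'.format(num_samples, num_classes))
# --num_samples=4 --num_classes=5
```

Using the largest id plus one, instead of the number of distinct labels, keeps the class count right even when a split happens to contain no images of some class.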
Visualizing the logs
You can visualize the training logs while training is still running, for example to watch the loss trend.
tensorboard --logdir train_logs/
Validation
The official repository provides an evaluation script.
python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2
Likewise, to use your own dataset, modify models/slim/eval_image_classifier.py as follows.
Change
from datasets import dataset_factory
to
from datasets import dataset_classification
Change
dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
to
dataset = dataset_classification.get_dataset(
    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
After
tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')
add
tf.app.flags.DEFINE_integer(
    'num_samples', 350, 'Number of samples.')
tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')
tf.app.flags.DEFINE_string(
    'labels_to_names_path', None, 'Label names file path.')
Then run validation with the following command:
python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_dir=../../data/val \
  --num_samples=350 \
  --num_classes=5 \
  --model_name=inception_resnet_v2
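You can run validation while training is in progress; take care to use a different GPU or to allocate GPU memory sensibly.
The evaluation logs can be visualized in the same way. If TensorBoard is already showing the training logs, use a different port, for example: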
tensorboard --logdir eval_logs/ --port 6007
Testing
Based on models/slim/eval_image_classifier.py, you can write a script that reads an image and runs inference with the model, models/slim/test_image_classifier.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')
tf.app.flags.DEFINE_string(
    'test_path', '', 'Test image path.')
tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
    'test_image_size', None, 'Test image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
    # The script defines --test_path, so that is the flag to validate.
    if not FLAGS.test_path:
        raise ValueError('You must supply the test image with --test_path')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = FLAGS.test_image_size or network_fn.default_image_size

        # Accept either a checkpoint directory or a single checkpoint file.
        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        with tf.Session() as sess:
            # Read, decode and preprocess the image, then add a batch dimension.
            image = open(FLAGS.test_path, 'rb').read()
            image = tf.image.decode_jpeg(image, channels=3)
            processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
            processed_images = tf.expand_dims(processed_image, 0)

            logits, _ = network_fn(processed_images)
            predictions = tf.argmax(logits, 1)

            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            predicted_ids = sess.run(predictions)
            print('{} {}'.format(FLAGS.test_path, predicted_ids[0]))


if __name__ == '__main__':
    tf.app.run()
Run the test with the following command:
python test_image_classifier.py \
  --checkpoint_path=train_logs/ \
  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
  --num_classes=5 \
  --model_name=inception_resnet_v2
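The test script prints only the predicted class index. If class names and confidences are more useful, the logits could be fetched from the session and post-processed as in the following sketch. It is plain Python under stated assumptions: `top_k` is a hypothetical helper, the logit values are made up, and the names follow the labels.txt order used in this article.

```python
import math

def top_k(logits, label_names, k=3):
    """Return the k most likely (name, probability) pairs via softmax."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(label_names, probs), key=lambda p: p[1], reverse=True)
    return ranked[:k]

names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
logits = [0.1, -1.2, 2.0, 0.3, 4.5]  # made-up values for illustration
for name, prob in top_k(logits, names, k=2):
    print('{} {:.3f}'.format(name, prob))
```

With these sample values the highest logit belongs to tulips, so it is listed first, followed by roses.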