
Implementing batch data input with TensorFlow

Updated: 2020-02-05 14:36:09   Author: J_PrCz
Today I'm sharing an implementation of batch data input with TensorFlow. It makes a good reference, and I hope it helps you. Follow along and take a look.

Batch data input under TensorFlow can be handled in two ways:

1. The TFRecords format (serialized tensors)

2. Array-based storage with the h5py library

While writing CNN code under the TensorFlow framework, I found that the framework itself was not the hard part to write; far more of my effort went into preprocessing the images and feeding them in.

I used two methods:

Method 1:

Store the data as serialized tensors in TFRecords format; if the data needs labels, the labels can be written at the same time.

① Create the TFRecords file

import os
import tensorflow as tf
from PIL import Image

orig_image = '/home/images/train_image/'
gen_image = '/home/images/image_train.tfrecords'
def create_record():
  writer = tf.python_io.TFRecordWriter(gen_image)
  class_path = orig_image
  for img_name in os.listdir(class_path): # iterate over every image in the directory
    img_path = class_path + img_name
    img = Image.open(img_path) # read the image
    #img = img.resize((256, 256)) # resize here if the images need preprocessing
    img_raw = img.tobytes() # convert the image to raw bytes
    example = tf.train.Example(
         features=tf.train.Features(feature={
             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), # attach the label
             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) # store the pixel data
             }))
    writer.write(example.SerializeToString())
  writer.close()
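
The snippet above writes a fixed label of 0 for every image. As a minimal sketch of the "label while writing" idea, assuming a hypothetical layout with one subdirectory per class (which differs from the flat directory used above), the label could be derived from the directory name:

# Sketch only: assumes /home/images/train_image/<class_name>/<img_name>,
# a hypothetical layout, not the one used in the snippet above.
def create_record_with_labels(root_dir, out_path):
  writer = tf.python_io.TFRecordWriter(out_path)
  for label, class_name in enumerate(sorted(os.listdir(root_dir))):
    class_dir = os.path.join(root_dir, class_name)
    for img_name in os.listdir(class_dir):
      img = Image.open(os.path.join(class_dir, img_name))
      example = tf.train.Example(features=tf.train.Features(feature={
          'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), # class index as label
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
      writer.write(example.SerializeToString())
  writer.close()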

② Read the TFRecords file

def read_and_decode(filename):
  # create a filename queue; with the default num_epochs=None it cycles through the data indefinitely
  filename_queue = tf.train.string_input_producer([filename])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)

  features = tf.parse_single_example(
      serialized_example,
      features={
          'label': tf.FixedLenFeature([], tf.int64),
          'img_raw': tf.FixedLenFeature([], tf.string)})
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8) # the raw bytes were written as uint8
  img = tf.image.convert_image_dtype(img, dtype=tf.float32) # rescale to floats in [0, 1]
  img = tf.reshape(img, [256, 256, 1])
  label = tf.cast(label, tf.int32)
  return img, label

③ Read the data in batches with tf.train.batch

batch_size = 16 # assumed value; batch_size was not defined in the original snippet
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size # common rule of thumb for queue capacity
num_samples = len(os.listdir(orig_image))
create_record()
img, label = read_and_decode(gen_image)
total_batch = int(num_samples / batch_size)
image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size,
                      num_threads=32, capacity=capacity)
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
  sess.run(init_op)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  for i in range(total_batch):
     cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch])
  coord.request_stop()
  coord.join(threads)
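
Note that min_after_dequeue only matters for tf.train.shuffle_batch; tf.train.batch delivers examples in file order and uses only capacity. If shuffled batches are wanted, a minimal drop-in sketch reusing the tensors above would be:

# Sketch: shuffled batching. min_after_dequeue is how many examples stay
# buffered after each dequeue, which controls how well batches are mixed.
image_batch, label_batch = tf.train.shuffle_batch(
    [img, label], batch_size=batch_size, num_threads=32,
    capacity=capacity, min_after_dequeue=min_after_dequeue)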

Method 2:

With h5py, the data is simply stored in array form.

I find this method more convenient: a CNN workflow often needs to store several related sets of data, for instance when one input is put through two or more different transformations and each variant has to be stored under its own name.

import os
import h5py
import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from skimage import io, img_as_float

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_gray_0_1/'

orig_image = []   # original images
sample_near = []  # first transformed variant
sample_line = []  # second transformed variant

for img_name in os.listdir(class_path):
  img_path = class_path + img_name
  img = io.imread(img_path)
  m1 = img_as_float(img)
  m2, m3 = sample_inter1(m1) # a user-defined data-processing function (defined elsewhere)
  m1 = m1.reshape([256, 256, 1])
  m2 = m2.reshape([256, 256, 1])
  m3 = m3.reshape([256, 256, 1])
  orig_image.append(m1)
  sample_near.append(m2)
  sample_line.append(m3)

arrorig_image = np.asarray(orig_image) # [?, 256, 256, 1]
arrlsample_near = np.asarray(sample_near) # [?, 256, 256, 1] 
arrlsample_line = np.asarray(sample_line) # [?, 256, 256, 1] 

save_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_sample/train.h5'

def make_data(path):
  with h5py.File(path, 'w') as hf:
     hf.create_dataset('orig_image', data=arrorig_image)
     hf.create_dataset('sample_near', data=arrlsample_near)
     hf.create_dataset('sample_line', data=arrlsample_line)

def read_data(path):
  with h5py.File(path, 'r') as hf:
     orig_image = np.array(hf.get('orig_image')) # the key must match the dataset name written above
     sample_near = np.array(hf.get('sample_near'))
     sample_line = np.array(hf.get('sample_line'))
  return orig_image, sample_near, sample_line

make_data(save_path)
orig_image1, sample_near1, sample_line1 = read_data(save_path)
total_number = len(orig_image1)
batch_size = 20
batch_index = total_number // batch_size # integer division so range() receives an int
for i in range(batch_index):
  batch_orig = orig_image1[i*batch_size:(i+1)*batch_size]
  batch_sample_near = sample_near1[i*batch_size:(i+1)*batch_size]
  batch_sample_line = sample_line1[i*batch_size:(i+1)*batch_size]
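
The loop above always walks the arrays in the same order and silently drops the last total_number % batch_size samples. A minimal sketch, using only numpy and the names above, that shuffles each epoch and keeps the final partial batch:

# Sketch: shuffle the sample order and slice an index array instead of ranges;
# the last slice may be shorter than batch_size.
perm = np.random.permutation(total_number)
for i in range(0, total_number, batch_size):
  idx = perm[i:i + batch_size]
  batch_orig = orig_image1[idx]
  batch_sample_near = sample_near1[idx]
  batch_sample_line = sample_line1[idx]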

When the file h5py generates grows huge, reading it back can fail with: IOError: unable to open file (bad object header version number)

When that happens the file is essentially unusable; reduce the amount of data stored per file and regenerate it.
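
One way to keep the file smaller in the first place (a sketch I have not benchmarked, based on standard h5py options) is to let h5py chunk and gzip-compress each dataset when writing:

# Sketch: gzip compression can shrink large image arrays considerably, and
# chunks=True lets h5py pick a chunk shape so slices can be read lazily.
def make_data_compressed(path):
  with h5py.File(path, 'w') as hf:
    hf.create_dataset('orig_image', data=arrorig_image, compression='gzip', chunks=True)
    hf.create_dataset('sample_near', data=arrlsample_near, compression='gzip', chunks=True)
    hf.create_dataset('sample_line', data=arrlsample_line, compression='gzip', chunks=True)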

That is everything in this piece on implementing batch data input with TensorFlow. I hope it gives you a useful reference, and I hope you will keep supporting 脚本之家.
