基于Tensorflow批量數(shù)據(jù)的輸入實現(xiàn)方式
基于Tensorflow下的批量數(shù)據(jù)的輸入處理:
1.Tensor TFrecords格式
2.h5py的庫的數(shù)組方法
在tensorflow的框架下寫CNN代碼,我在書寫過程中,感覺不是框架內(nèi)容難寫, 更多的是我在對圖像的預(yù)處理和輸入這部分花了很多精神。
使用了兩種方法:
方法一:
Tensor 以Tfrecords的格式存儲數(shù)據(jù),如果對數(shù)據(jù)進行標(biāo)簽,可以同時做到數(shù)據(jù)打標(biāo)簽。
①創(chuàng)建TFrecords文件
orig_image = '/home/images/train_image/' gen_image = '/home/images/image_train.tfrecords' def create_record(): writer = tf.python_io.TFRecordWriter(gen_image) class_path = orig_image for img_name in os.listdir(class_path): #讀取每一幅圖像 img_path = class_path + img_name img = Image.open(img_path) #讀取圖像 #img = img.resize((256, 256)) #設(shè)置圖片大小, 在這里可以對圖像進行處理 img_raw = img.tobytes() #將圖片轉(zhuǎn)化為原聲bytes example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), #打標(biāo)簽 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存儲數(shù)據(jù) })) writer.write(example.SerializeToString()) writer.close()
②讀取TFrecords文件
def read_and_decode(filename): #創(chuàng)建文件隊列,不限讀取的數(shù)據(jù) filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = features['label'] img = features['img_raw'] img = tf.decode_raw(img, tf.uint8) #tf.float32 img = tf.image.convert_image_dtype(img, dtype=tf.float32) img = tf.reshape(img, [256, 256, 1]) label = tf.cast(label, tf.int32) return img, label
③批量讀取數(shù)據(jù),使用tf.train.batch
min_after_dequeue = 10000 capacity = min_after_dequeue + 3 * batch_size num_samples= len(os.listdir(orig_image)) create_record() img, label = read_and_decode(gen_image) total_batch = int(num_samples/batch_size) image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size, num_threads=32, capacity=capacity) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session() as sess: sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(total_batch): cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch]) coord.request_stop() coord.join(threads)
方法二:
使用h5py就是使用數(shù)組的格式來存儲數(shù)據(jù)
這個方法比較好,在CNN的過程中,會使用到多個數(shù)據(jù)類存儲,比較好用, 比如一個數(shù)據(jù)進行了兩種以上的變化,并且分類存儲,我認(rèn)為這個方法會比較好用。
import os import h5py import matplotlib.pyplot as plt import numpy as np import random from scipy.interpolate import griddata from skimage import img_as_float import matplotlib.pyplot as plt os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' class_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_gray_0_1/' for img_name in os.listdir(class_path): img_path = class_path + img_name img = io.imread(img_path) m1 = img_as_float(img) m2, m3 = sample_inter1(m1) #一個數(shù)據(jù)處理的函數(shù) m1 = m1.reshape([256, 256, 1]) m2 = m2.reshape([256, 256, 1]) m3 = m3.reshape([256, 256, 1]) orig_image.append(m1) sample_near.append(m2) sample_line.append(m3) arrorig_image = np.asarray(orig_image) # [?, 256, 256, 1] arrlsample_near = np.asarray(sample_near) # [?, 256, 256, 1] arrlsample_line = np.asarray(sample_line) # [?, 256, 256, 1] save_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_sample/train.h5' def make_data(path): with h5py.File(save_path, 'w') as hf: hf.create_dataset('orig_image', data=arrorig_image) hf.create_dataset('sample_near', data=arrlsample_near) hf.create_dataset('sample_line', data=arrlsample_line) def read_data(path): with h5py.File(path, 'r') as hf: orig_image = np.array(hf.get('orig_image')) #一定要對清楚上邊的標(biāo)簽名orig_image; sample_near = np.array(hf.get('sample_near')) sample_line = np.array(hf.get('sample_line')) return orig_image, sample_near, sample_line make_data(save_path) orig_image1, sample_near1, sample_line1 = read_data(save_path) total_number = len(orig_image1) batch_size = 20 batch_index = total_number/batch_size for i in range(batch_index): batch_orig = orig_image1[i*batch_size:(i+1)*batch_size] batch_sample_near = sample_near1[i*batch_size:(i+1)*batch_size] batch_sample_line = sample_line1[i*batch_size:(i+1)*batch_size]
在使用h5py的時候,生成的文件巨大的時候,讀取數(shù)據(jù)顯示錯誤:ioerror: unable to open file (bad object header version number)
基本就是這個生成的文件不能使用,適當(dāng)?shù)臏p少存儲的數(shù)據(jù),即可。
以上這篇基于Tensorflow批量數(shù)據(jù)的輸入實現(xiàn)方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python OS系統(tǒng)解決路徑中空格原因?qū)е挛募虿婚_的問題
這篇文章主要介紹了Python OS系統(tǒng)解決路徑中空格原因?qū)е挛募虿婚_的問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-02-02Python 跨文件夾導(dǎo)入自定義包的實現(xiàn)
有時我們自己編寫一些模塊時,跨文件夾調(diào)用會出現(xiàn)ModuleNotFoundError: No module named 'XXX',本文就來介紹一下解決方法,感興趣的可以了解一下2023-11-11python3 破解 geetest(極驗)的滑塊驗證碼功能
這篇文章主要介紹了python3 破解 geetest(極驗)的滑塊驗證碼功能,本文通過實例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2018-02-02Python2.7基于淘寶接口獲取IP地址所在地理位置的方法【測試可用】
這篇文章主要介紹了Python2.7基于淘寶接口獲取IP地址所在地理位置的方法,涉及Python調(diào)用淘寶IP庫接口進行IP查詢的簡單操作技巧,需要的朋友可以參考下2017-06-06