亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

Tensorflow中使用tfrecord方式讀取數(shù)據(jù)的方法

 更新時(shí)間:2018年06月19日 10:06:08   作者:無空ty  
這篇文章主要介紹了Tensorflow中使用tfrecord方式讀取數(shù)據(jù)的方法,適用于數(shù)據(jù)較多時(shí),小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧

前言

本博客默認(rèn)讀者對(duì)神經(jīng)網(wǎng)絡(luò)與Tensorflow有一定了解,對(duì)其中的一些術(shù)語不再做具體解釋。并且本博客主要以圖片數(shù)據(jù)為例進(jìn)行介紹,如有錯(cuò)誤,敬請(qǐng)斧正。

使用Tensorflow訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),我們可以用多種方式來讀取自己的數(shù)據(jù)。如果數(shù)據(jù)集比較小,而且內(nèi)存足夠大,可以選擇直接將所有數(shù)據(jù)讀進(jìn)內(nèi)存,然后每次取一個(gè)batch的數(shù)據(jù)出來。如果數(shù)據(jù)較多,可以每次直接從硬盤中進(jìn)行讀取,不過這種方式的讀取效率就比較低了。此篇博客就主要講一下Tensorflow官方推薦的一種較為高效的數(shù)據(jù)讀取方式——tfrecord。

從宏觀來講,tfrecord其實(shí)是一種數(shù)據(jù)存儲(chǔ)形式。使用tfrecord時(shí),實(shí)際上是先讀取原生數(shù)據(jù),然后轉(zhuǎn)換成tfrecord格式,再存儲(chǔ)在硬盤上。而使用時(shí),再把數(shù)據(jù)從相應(yīng)的tfrecord文件中解碼讀取出來。那么使用tfrecord和直接從硬盤讀取原生數(shù)據(jù)相比到底有什么優(yōu)勢(shì)呢?其實(shí),Tensorflow有和tfrecord配套的一些函數(shù),可以加快數(shù)據(jù)的處理。實(shí)際讀取tfrecord數(shù)據(jù)時(shí),先以相應(yīng)的tfrecord文件為參數(shù),創(chuàng)建一個(gè)輸入隊(duì)列,這個(gè)隊(duì)列有一定的容量(視具體硬件限制,用戶可以設(shè)置不同的值),在一部分?jǐn)?shù)據(jù)出隊(duì)列時(shí),tfrecord中的其他數(shù)據(jù)就可以通過預(yù)取進(jìn)入隊(duì)列,并且這個(gè)過程和網(wǎng)絡(luò)的計(jì)算是獨(dú)立進(jìn)行的。也就是說,網(wǎng)絡(luò)每一個(gè)iteration的訓(xùn)練不必等待數(shù)據(jù)隊(duì)列準(zhǔn)備好再開始,隊(duì)列中的數(shù)據(jù)始終是充足的,而往隊(duì)列中填充數(shù)據(jù)時(shí),也可以使用多線程加速。

下面,本文將從以下4個(gè)方面對(duì)tfrecord進(jìn)行介紹:

  1. tfrecord格式簡介
  2. 利用自己的數(shù)據(jù)生成tfrecord文件
  3. 從tfrecord文件讀取數(shù)據(jù)
  4. 實(shí)例測試

1. tfrecord格式簡介

這部分主要參考了另一篇博文,Tensorflow 訓(xùn)練自己的數(shù)據(jù)集(二)(TFRecord)

tfecord文件中的數(shù)據(jù)是通過tf.train.Example Protocol Buffer的格式存儲(chǔ)的,下面是tf.train.Example的定義

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

從上述代碼可以看出,tf.train.Example 的數(shù)據(jù)結(jié)構(gòu)很簡單。tf.train.Example中包含了一個(gè)從屬性名稱到取值的字典,其中屬性名稱為一個(gè)字符串,屬性的取值可以為字符串(BytesList ),浮點(diǎn)數(shù)列表(FloatList )或整數(shù)列表(Int64List )。例如我們可以將圖片轉(zhuǎn)換為字符串進(jìn)行存儲(chǔ),圖像對(duì)應(yīng)的類別標(biāo)號(hào)作為整數(shù)存儲(chǔ),而用于回歸任務(wù)的ground-truth可以作為浮點(diǎn)數(shù)存儲(chǔ)。通過后面的代碼我們會(huì)對(duì)tfrecord的這種字典形式有更直觀的認(rèn)識(shí)。

2. 利用自己的數(shù)據(jù)生成tfrecord文件

先上一段代碼,然后我再針對(duì)代碼進(jìn)行相關(guān)介紹。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio


def _bytes_feature(value):
  return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))

def _int64_feature(value):
  return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))


root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecords_filename = root_path + 'tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)


height = 300
width = 300
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']

txtfile = root_path + 'txt/train.txt'
fr = open(txtfile)

for i in fr.readlines():
  item = i.split()
  img = np.float64(misc.imread(root_path + '/images/train_images/' + item[0]))
  img = img - meanvalue
  maskmat = sio.loadmat(root_path + '/mats/train_mats/' + item[1])
  mask = np.float64(maskmat['seg_mask'])
  label = int(item[2])
  img_raw = img.tostring()
  mask_raw = mask.tostring()
  example = tf.train.Example(features=tf.train.Features(feature={
    'height': _int64_feature(height),
    'width': _int64_feature(width),
    'name': _bytes_feature(item[0]),
    'image_raw': _bytes_feature(img_raw),
    'mask_raw': _bytes_feature(mask_raw),
    'label': _int64_feature(label)}))

  writer.write(example.SerializeToString())

writer.close()
fr.close()

代碼中前兩個(gè)函數(shù)(_bytes_feature和_int64_feature)是將我們的原生數(shù)據(jù)進(jìn)行轉(zhuǎn)換用的,尤其是圖片要轉(zhuǎn)換成字符串再進(jìn)行存儲(chǔ)。這兩個(gè)函數(shù)的定義來自官方的示例。

接下來,我定義了數(shù)據(jù)的(路徑-label文件)txtfile,它大概長這個(gè)樣子:

這里稍微啰嗦下,介紹一下我的實(shí)驗(yàn)內(nèi)容。我做的是一個(gè)multi-task的實(shí)驗(yàn),一支task做分割,一支task做分類。所以txtfile中每一行是一個(gè)樣本,每個(gè)樣本又包含3項(xiàng),第一項(xiàng)為圖片名稱,第二項(xiàng)為相應(yīng)的ground-truth segmentation mask的名稱,第三項(xiàng)是圖片的標(biāo)簽。(txtfile中內(nèi)容形式無所謂,只要能讀到想讀的數(shù)據(jù)就可以)

接著回到主題繼續(xù)講代碼,之后我又定義了即將生成的tfrecord的文件路徑和名稱,即tfrecord_filename,還有一個(gè)writer,這個(gè)writer是進(jìn)行寫操作用的。

接下來是圖片的高度、寬度以及我事先在整個(gè)數(shù)據(jù)集上計(jì)算好的圖像均值文件。高度、寬度其實(shí)完全沒必要引入,這里只是為了說明tfrecord的生成而寫的。而均值文件是為了對(duì)圖像進(jìn)行事先的去均值化操作而引入的,在大多數(shù)機(jī)器學(xué)習(xí)任務(wù)中,圖像去均值化對(duì)提高算法的性能還是很有幫助的。

最后就是根據(jù)txtfile中的每一行進(jìn)行相關(guān)數(shù)據(jù)的讀取、轉(zhuǎn)換以及tfrecord的生成了。首先是根據(jù)圖片路徑讀取圖片內(nèi)容,然后圖像減去之前讀入的均值,接著根據(jù)segmentation mask的路徑讀取mask(如果只是圖像分類任務(wù),那么就不會(huì)有這些額外的mask),txtfile中的label讀出來是string格式,這里要轉(zhuǎn)換成int。然后圖像和mask數(shù)據(jù)也要用相應(yīng)的tosring函數(shù)轉(zhuǎn)換成string。

真正的核心是下面這一小段代碼:

example = tf.train.Example(features=tf.train.Features(feature={
    'height': _int64_feature(height),
    'width': _int64_feature(width),
    'name': _bytes_feature(item[0]),
    'image_raw': _bytes_feature(img_raw),
    'mask_raw': _bytes_feature(mask_raw),
    'label': _int64_feature(label)}))

writer.write(example.SerializeToString())

這里很好地體現(xiàn)了tfrecord的字典特性,tfrecord中每一個(gè)樣本都是一個(gè)小字典,這個(gè)字典可以包含任意多個(gè)鍵值對(duì)。比如我這里就存儲(chǔ)了圖片的高度、寬度、圖片名稱、圖片內(nèi)容、mask內(nèi)容以及圖片的label。對(duì)于我的任務(wù)來說,其實(shí)height、width、name都不是必需的,這里僅僅是為了展示。鍵值對(duì)的鍵全都是字符串,鍵起什么名字都可以,只要能方便以后使用就可以。

定義好一個(gè)example后就可以用之前的writer來把它真正寫入tfrecord文件了,這其實(shí)就跟把一行內(nèi)容寫入一個(gè)txt文件一樣。代碼的最后就是writer和txt文件對(duì)象的關(guān)閉了。

最后在指定文件夾下,就得到了指定名字的tfrecord文件,如下所示:

需要注意的是,生成的tfrecord文件比原生數(shù)據(jù)的大小還要大,這是正?,F(xiàn)象。這種現(xiàn)象可能是因?yàn)閳D片一般都存儲(chǔ)為jpg等壓縮格式,而tfrecord文件存儲(chǔ)的是解壓后的數(shù)據(jù)。

3. 從tfrecord文件讀取數(shù)據(jù)

還是代碼先行。

from scipy import misc
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'

def read_and_decode(filename_queue, random_crop=False, random_clip=False, shuffle_batch=True):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
   serialized_example,
   features={
     'height': tf.FixedLenFeature([], tf.int64),
     'width': tf.FixedLenFeature([], tf.int64),
     'name': tf.FixedLenFeature([], tf.string),              
     'image_raw': tf.FixedLenFeature([], tf.string),
     'mask_raw': tf.FixedLenFeature([], tf.string),                
     'label': tf.FixedLenFeature([], tf.int64)
   })

  image = tf.decode_raw(features['image_raw'], tf.float64)
  image = tf.reshape(image, [300,300,3])

  mask = tf.decode_raw(features['mask_raw'], tf.float64)
  mask = tf.reshape(mask, [300,300])

  name = features['name']

  label = features['label']
  width = features['width']
  height = features['height']

#  if random_crop:
#    image = tf.random_crop(image, [227, 227, 3])
#  else:
#    image = tf.image.resize_image_with_crop_or_pad(image, 227, 227)

#  if random_clip:
#    image = tf.image.random_flip_left_right(image)


  if shuffle_batch:
    images, masks, names, labels, widths, heights = tf.train.shuffle_batch([image, mask, name, label, width, height],
                        batch_size=4,
                        capacity=8000,
                        num_threads=4,
                        min_after_dequeue=2000)
  else:
    images, masks, names, labels, widths, heights = tf.train.batch([image, mask, name, label, width, height],
                    batch_size=4,
                    capacity=8000,
                    num_threads=4)
  return images, masks, names, labels, widths, heights

讀取tfrecord文件中的數(shù)據(jù)主要是應(yīng)用read_and_decode()這個(gè)函數(shù),可以看到其中有個(gè)參數(shù)是filename_queue,其實(shí)我們并不是直接從tfrecord文件進(jìn)行讀取,而是要先利用tfrecord文件創(chuàng)建一個(gè)輸入隊(duì)列,如本文開頭所述那樣。關(guān)于這點(diǎn),到后面真正的測試代碼我再介紹。

在read_and_decode()中,一上來我們先定義一個(gè)reader對(duì)象,然后使用reader得到serialized_example,這是一個(gè)序列化的對(duì)象,接著使用tf.parse_single_example()函數(shù)對(duì)此對(duì)象進(jìn)行初步解析。從代碼中可以看到,解析時(shí),我們要用到之前定義的那些鍵。對(duì)于圖像、mask這種轉(zhuǎn)換成字符串的數(shù)據(jù),要進(jìn)一步使用tf.decode_raw()函數(shù)進(jìn)行解析,這里要特別注意函數(shù)里的第二個(gè)參數(shù),也就是解析后的類型。之前圖片在轉(zhuǎn)成字符串之前是什么類型的數(shù)據(jù),那么這里的參數(shù)就要填成對(duì)應(yīng)的類型,否則會(huì)報(bào)錯(cuò)。對(duì)于name、label、width、height這樣的數(shù)據(jù)就不用再解析了,我們得到的features對(duì)象就是個(gè)字典,利用鍵就可以拿到對(duì)應(yīng)的值,如代碼所示。

我注釋掉的部分是用來做數(shù)據(jù)增強(qiáng)的,比如隨機(jī)的裁剪與翻轉(zhuǎn),除了這兩種,其他形式的數(shù)據(jù)增強(qiáng)也可以寫在這里,讀者可以根據(jù)自己的需要,決定是否使用各種數(shù)據(jù)增強(qiáng)方式。

函數(shù)最后就是使用解析出來的數(shù)據(jù)生成batch了。Tensorflow提供了兩種方式,一種是shuffle_batch,這種主要是用在訓(xùn)練中,隨機(jī)選取樣本組成batch。另外一種就是按照數(shù)據(jù)在tfrecord中的先后順序生成batch。對(duì)于生成batch的函數(shù),建議讀者去官網(wǎng)查看API文檔進(jìn)行細(xì)致了解。這里稍微做一下介紹,batch的大小,即batch_size就需要在生成batch的函數(shù)里指定。另外,capacity參數(shù)指定數(shù)據(jù)隊(duì)列一次性能放多少個(gè)樣本,此參數(shù)設(shè)置什么值需要視硬件環(huán)境而定。num_threads參數(shù)指定可以開啟幾個(gè)線程來向數(shù)據(jù)隊(duì)列中填充數(shù)據(jù),如果硬件性能不夠強(qiáng),最好設(shè)小一點(diǎn),否則容易崩。

4. 實(shí)例測試

實(shí)際使用時(shí)先指定好我們需要使用的tfrecord文件:

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'

然后用該tfrecord文件創(chuàng)建一個(gè)輸入隊(duì)列:

filename_queue = tf.train.string_input_producer([tfrecord_filename],
                          num_epochs=3)

這里有個(gè)參數(shù)是num_epochs,指定好之后,Tensorflow自然知道如何讀取數(shù)據(jù),保證在遍歷數(shù)據(jù)集的一個(gè)epoch中樣本不會(huì)重復(fù),也知道數(shù)據(jù)讀取何時(shí)應(yīng)該停止。

下面我將完整的測試代碼貼出:

def test_run(tfrecord_filename):
  filename_queue = tf.train.string_input_producer([tfrecord_filename],
                          num_epochs=3)
  images, masks, names, labels, widths, heights = read_and_decode(filename_queue)

  init_op = tf.group(tf.global_variables_initializer(),
            tf.local_variables_initializer())

  meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
  meanvalue = meanfile['mean']


  with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1):
      imgs, msks, nms, labs, wids, heis = sess.run([images, masks, names, labels, widths, heights])
      print 'batch' + str(i) + ': '
      #print type(imgs[0])

      for j in range(4):
        print nms[j] + ': ' + str(labs[j]) + ' ' + str(wids[j]) + ' ' + str(heis[j])
        img = np.uint8(imgs[j] + meanvalue)
        msk = np.uint8(msks[j])
        plt.subplot(4,2,j*2+1)
        plt.imshow(img)
        plt.subplot(4,2,j*2+2)
        plt.imshow(msk, vmin=0, vmax=5)
      plt.show()

    coord.request_stop()
    coord.join(threads)

函數(shù)中接下來就是利用之前定義的read_and_decode()來得到一個(gè)batch的數(shù)據(jù),此后我又讀入了均值文件,這是因?yàn)橹白隽巳ゾ堤幚恚绻o@示圖片需要再把均值加回來。

再之后就是建立一個(gè)Tensorflow session,然后初始化對(duì)象。這些是Tensorflow基本操作,不再贅述。下面的這兩句代碼非常重要,是讀取數(shù)據(jù)必不可少的。

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

然后是運(yùn)行sess.run()拿到實(shí)際數(shù)據(jù),之前只是相當(dāng)于定義好了,并沒有得到真實(shí)數(shù)值。為了簡單起見,我在之后的循環(huán)里只測試了一個(gè)batch的數(shù)據(jù),關(guān)于tfrecord的標(biāo)準(zhǔn)使用我也建議讀者去官網(wǎng)的數(shù)據(jù)讀取部分看看示例。循環(huán)里對(duì)數(shù)據(jù)的各種信息進(jìn)行了展示,結(jié)果如下:

從圖片的名字可以看出,數(shù)據(jù)的確是進(jìn)行了shuffle的,標(biāo)簽、寬度、高度、圖片本身以及對(duì)應(yīng)的mask圖像也全部展示出來了。

測試函數(shù)的最后,要使用以下兩句代碼進(jìn)行停止,就如同文件需要close()一樣:

以上就是本文的全部內(nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。

相關(guān)文章

  • python Pexpect模塊的使用

    python Pexpect模塊的使用

    這篇文章主要介紹了python Pexpect模塊的使用,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2020-12-12
  • 使用Python中的pytesseract模塊實(shí)現(xiàn)抓取圖片中文字

    使用Python中的pytesseract模塊實(shí)現(xiàn)抓取圖片中文字

    最近同事用網(wǎng)上提供掃描軟件進(jìn)行掃描識(shí)別文字,每天上線只能夠做兩次掃描,請(qǐng)求我研發(fā)一個(gè)小工具幫助解決識(shí)別圖片的中文字,最終我選擇使用pytesseract模塊可以解決這個(gè)需求問題,本文給大家分享實(shí)現(xiàn)代碼操作感興趣的朋友跟隨小編一起看看吧
    2022-11-11
  • 人工智能學(xué)習(xí)pyTorch自建數(shù)據(jù)集及可視化結(jié)果實(shí)現(xiàn)過程

    人工智能學(xué)習(xí)pyTorch自建數(shù)據(jù)集及可視化結(jié)果實(shí)現(xiàn)過程

    這篇文章主要為大家介紹了人工智能學(xué)習(xí)pyTorch自建數(shù)據(jù)集及可視化結(jié)果的實(shí)現(xiàn)過程,有需要的朋友可以借鑒參考下,希望能夠有所幫助
    2021-11-11
  • Python性能提升之延遲初始化

    Python性能提升之延遲初始化

    本文給大家分享的是在Python中使用延遲計(jì)算來提升性能的方法,非常的實(shí)用,有需要的小伙伴可以參考下
    2016-12-12
  • python鏈接sqlite數(shù)據(jù)庫的詳細(xì)代碼實(shí)例

    python鏈接sqlite數(shù)據(jù)庫的詳細(xì)代碼實(shí)例

    SQLite數(shù)據(jù)庫是一款非常小巧的嵌入式開源數(shù)據(jù)庫軟件,也就是說沒有獨(dú)立的維護(hù)進(jìn)程,所有的維護(hù)都來自于程序本身,它是遵守ACID的關(guān)聯(lián)式數(shù)據(jù)庫管理系統(tǒng),它的設(shè)計(jì)目標(biāo)是嵌入式的,而且目前已經(jīng)在很多嵌入式產(chǎn)品中使用了它,它占用資源非常的低
    2021-09-09
  • Python函數(shù)返回值實(shí)例分析

    Python函數(shù)返回值實(shí)例分析

    這篇文章主要介紹了Python函數(shù)返回值,實(shí)例分析了Python中返回一個(gè)返回值與多個(gè)返回值的方法,需要的朋友可以參考下
    2015-06-06
  • Python基于scipy實(shí)現(xiàn)信號(hào)濾波功能

    Python基于scipy實(shí)現(xiàn)信號(hào)濾波功能

    本文將以實(shí)戰(zhàn)的形式基于scipy模塊使用Python實(shí)現(xiàn)簡單濾波處理。這篇文章主要介紹了Python基于scipy實(shí)現(xiàn)信號(hào)濾波功能,需要的朋友可以參考下
    2019-05-05
  • SQLAlchemy的主要組件詳細(xì)講解

    SQLAlchemy的主要組件詳細(xì)講解

    SQLAlchemy是一個(gè)基于Python實(shí)現(xiàn)的ORM框架,能滿足大多數(shù)數(shù)據(jù)庫操作需求,同時(shí)支持多種數(shù)據(jù)庫引擎(SQLite,MySQL,Postgresql,Oracle等),這篇文章主要介紹了SQLAlchemy的主要組件有哪些,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)具有一定的參考借鑒價(jià)值,需要的朋友可以參考
    2023-08-08
  • Python 如何安裝Selenium(推薦)

    Python 如何安裝Selenium(推薦)

    Selenium 是一個(gè) Web的自動(dòng)化測試工具 ,最初是為網(wǎng)站 自動(dòng)化測試而開發(fā)的 , Selenium 可以直接調(diào)用瀏覽器 ,它支持所有主流的瀏覽器,本文給大家介紹Python 如何安裝Selenium,感興趣的朋友一起看看吧
    2021-05-05
  • 超詳細(xì)注釋之OpenCV按位AND OR XOR和NOT

    超詳細(xì)注釋之OpenCV按位AND OR XOR和NOT

    這篇文章主要介紹了OpenCV按位AND OR XOR和NOT運(yùn)算,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-09-09

最新評(píng)論