Tensorflow加載模型實現(xiàn)圖像分類識別流程詳解
前言
深度學(xué)習(xí)框架在市面上有很多。比如Theano、Caffe、CNTK、MXnet 、Tensorflow等。今天講解的就是主角Tensorflow。Tensorflow的前身是Google大腦項目的一個分布式機器學(xué)習(xí)訓(xùn)練框架,它是一個十分基礎(chǔ)且集成度很高的系統(tǒng),它的目標就是為研究超大型規(guī)模的視覺項目,后面延申到各個領(lǐng)域。Tensorflow 在2015年正式開源,開源的一個月內(nèi)就收獲到1w多的starts,這足以說明Tensorflow的優(yōu)越性以及Google的影響力。在Api方面Tensorflow為了滿足絕大部分的開發(fā)者需求,這也是Google的一貫作風(fēng),集成了Java、Go、Python、C++等編程語言。
正文
圖像識別是一件很有趣的事,話不多說,咱們先了解下特征提取VGG in Tensorflow。官網(wǎng)地址:VGG in TensorFlow · Davi Frossard。
VGG 是牛津大學(xué)的 K. Simonyan 和 A. Zisserman 在論文“Very Deep Convolutional Networks for Large-Scale Image Recognition”中提出的卷積神經(jīng)網(wǎng)絡(luò)模型。該模型在 ImageNet 中實現(xiàn)了 92.7% 的 top-5 測試準確率,這是一個包含 1000 個類別的超過 1400 萬張圖像的數(shù)據(jù)集。 在這篇簡短的文章中,我們提供了 VGG16 的實現(xiàn)以及從原始 Caffe 模型轉(zhuǎn)換為 TensorFlow 的權(quán)重。這句話是VGGNet官方的介紹,直接從它提供的數(shù)字可以看出來,它的識別率是十分高的,是不是很激動,動起手來吧。
開發(fā)步驟分4步,如下所示:
a) 依賴加載
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import os import scipy.io import scipy.misc from imagenet_classes import class_names
b)定義卷積、池化等函數(shù)
def _conv_layer(input,weight,bias): conv = tf.nn.conv2d(input,weight,strides=[1,1,1,1],padding="SAME") return tf.nn.bias_add(conv,bias) def _pool_layer(input): return tf.nn.max_pool(input,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") def preprocess(image,mean_pixel): '''簡單預(yù)處理,全部圖片減去平均值''' return image-mean_pixel def unprocess(image,mean_pixel): return image+mean_pixel
c)圖像的讀取以及保存
def imread(path): return scipy.misc.imread(path) def imsave(image,path): img = np.clip(image,0,255).astype(np.int8) scipy.misc.imsave(path,image)
d) 定義網(wǎng)絡(luò)結(jié)構(gòu),這里使用的是VGG19
def net(data_path,input_image,sess=None): """ 讀取VGG模型參數(shù),搭建VGG網(wǎng)絡(luò) :param data_path: VGG模型文件位置 :param input_image: 輸入測試圖像 :return: """ layers = ( 'conv1_1', 'conv1_2', 'pool1', 'conv2_1', 'conv2_2', 'pool2', 'conv3_1', 'conv3_2', 'conv3_3','conv3_4', 'pool3', 'conv4_1', 'conv4_2', 'conv4_3','conv4_4', 'pool4', 'conv5_1', 'conv5_2', 'conv5_3','conv5_4', 'pool5', 'fc1' , 'fc2' , 'fc3' , 'softmax' ) data = scipy.io.loadmat(data_path) mean = data["normalization"][0][0][0][0][0] input_image = np.array([preprocess(input_image, mean)]).astype(np.float32)#去除平均值 net = {} current = input_image net["src_image"] = tf.constant(current) # 存儲數(shù)據(jù) count = 0 #計數(shù)存儲 for i in range(43): if str(data['layers'][0][i][0][0][0][0])[:4] == ("relu"): continue if str(data['layers'][0][i][0][0][0][0])[:4] == ("pool"): current = _pool_layer(current) elif str(data['layers'][0][i][0][0][0][0]) == ("softmax"): current = tf.nn.softmax(current) elif i == (37): shape = int(np.prod(current.get_shape()[1:])) current = tf.reshape(current, [-1, shape]) kernels, bias = data['layers'][0][i][0][0][0][0] kernels = np.reshape(kernels,[-1,4096]) bias = bias.reshape(-1) current = tf.nn.relu(tf.add(tf.matmul(current,kernels),bias)) elif i == (39): kernels, bias = data['layers'][0][i][0][0][0][0] kernels = np.reshape(kernels,[4096,4096]) bias = bias.reshape(-1) current = tf.nn.relu(tf.add(tf.matmul(current,kernels),bias)) elif i == 41: kernels, bias = data['layers'][0][i][0][0][0][0] kernels = np.reshape(kernels, [4096, 1000]) bias = bias.reshape(-1) current = tf.add(tf.matmul(current, kernels), bias) else: kernels,bias = data['layers'][0][i][0][0][0][0] #注意VGG存儲方式為[,] #kernels = np.transpose(kernels,[1,0,2,3]) bias = bias.reshape(-1)#降低維度 current = tf.nn.relu(_conv_layer(current,kernels,bias)) net[layers[count]] = current #存儲數(shù)據(jù) count += 1 return net, mean
e)加載模型進行識別
if __name__ == '__main__': VGG_PATH = "./one/imagenet-vgg-verydeep-19.mat" IMG_PATH = './one/3.jpg' input_image =imread(IMG_PATH) shape = (1, input_image.shape[0], input_image.shape[1], input_image.shape[2]) with tf.Session() as sess: image = tf.placeholder('float', shape=shape) nets, mean_pixel, all_layers= net(VGG_PATH, image) input_image_pre=np.array([preprocess(input_image,mean_pixel)]) layers = all_layers for i , layer in enumerate(layers): print("[%d/%d] %s" % (i+1,len(layers),layers)) features = nets[layer].eval(feed_dict={image:input_image_pre}) print("Type of 'feature' is ",type(features)) print("Shape of 'features' is %s" % (features.shape,)) if 1: plt.figure(i+1,figsize=(10,5)) plt.matshow(features[0,:,:,0],cmap=plt.cm.gray,fignum=i+1) plt.title(""+layer) plt.colorbar() plt.show()
VGG19網(wǎng)絡(luò)介紹
VGG19 的宏觀架構(gòu)如圖所示。我們在 TensorFlow 中的文件 vgg19.py 中對其進行編碼。請注意,我們包含一個預(yù)處理層,它采用像素值在 0-255 范圍內(nèi)的 RGB 圖像并減去平均圖像值(在整個 ImageNet 訓(xùn)練集上計算)。
總結(jié)
Tensorflow是一款十分不錯的深度學(xué)習(xí)框架,它在工業(yè)上得到的十分的認可并進行了實踐。因此,如果你還在猶豫生產(chǎn)落地使用框架,不要猶豫啦。VGGNet家族是一個十分優(yōu)秀的網(wǎng)絡(luò)結(jié)構(gòu),它在處理特征提取過程中,也是得到了很多公司和研究學(xué)者的認可,比較著名的有VGG16、VGG19等。
到此這篇關(guān)于Tensorflow加載模型實現(xiàn)圖像分類識別流程詳解的文章就介紹到這了,更多相關(guān)Tensorflow圖像分類識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python3實現(xiàn)爬取簡書首頁文章標題和文章鏈接的方法【測試可用】
這篇文章主要介紹了Python3實現(xiàn)爬取簡書首頁文章標題和文章鏈接的方法,結(jié)合實例形式分析了Python3基于urllib及bs4庫針對簡書網(wǎng)進行文章抓取相關(guān)操作技巧,需要的朋友可以參考下2018-12-12Python 字符串處理特殊空格\xc2\xa0\t\n Non-breaking space
今天遇到一個問題,使用python的find函數(shù)尋找字符串中的第一個空格時沒有找到正確的位置,下面是解決方法,需要的朋友可以參考下2020-02-02Python中動態(tài)檢測編碼chardet的使用教程
最近利用python抓取一些網(wǎng)上的數(shù)據(jù),遇到了編碼的問題。非常頭痛,幸運的是找到了解決的方法,下面這篇文章主要跟大家介紹了關(guān)于Python中動態(tài)檢測編碼chardet的使用方法,需要的朋友可以參考借鑒,下面來一起看看吧。2017-07-07Python 進程之間共享數(shù)據(jù)(全局變量)的方法
今天小編就為大家分享一篇Python 進程之間共享數(shù)據(jù)(全局變量)的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07Python Asyncio模塊實現(xiàn)的生產(chǎn)消費者模型的方法
這篇文章主要介紹了Python Asyncio模塊實現(xiàn)的生產(chǎn)消費者模型的方法,本文給大家介紹的非常詳細,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-03-03