Tensorflow實(shí)現(xiàn)AlexNet卷積神經(jīng)網(wǎng)絡(luò)及運(yùn)算時間評測
本文實(shí)例為大家分享了Tensorflow實(shí)現(xiàn)AlexNet卷積神經(jīng)網(wǎng)絡(luò)的具體實(shí)現(xiàn)代碼,供大家參考,具體內(nèi)容如下
之前已經(jīng)介紹過了AlexNet的網(wǎng)絡(luò)構(gòu)建了,這次主要不是為了訓(xùn)練數(shù)據(jù),而是為了對每個batch的前饋(Forward)和反饋(backward)的平均耗時進(jìn)行計(jì)算。在設(shè)計(jì)網(wǎng)絡(luò)的過程中,分類的結(jié)果很重要,但是運(yùn)算速率也相當(dāng)重要。尤其是在跟蹤(Tracking)的任務(wù)中,如果使用的網(wǎng)絡(luò)太深,那么也會導(dǎo)致實(shí)時性不好。
from datetime import datetime import math import time import tensorflow as tf batch_size = 32 num_batches = 100 def print_activations(t): print(t.op.name, '', t.get_shape().as_list()) def inference(images): parameters = [] with tf.name_scope('conv1') as scope: kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights') conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME') biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases') bias = tf.nn.bias_add(conv, biases) conv1 = tf.nn.relu(bias, name = scope) print_activations(conv1) parameters += [kernel, biases] lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1') pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1') print_activations(pool1) with tf.name_scope('conv2') as scope: kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights') conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME') biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases') bias = tf.nn.bias_add(conv, biases) conv2 = tf.nn.relu(bias, name = scope) parameters += [kernel, biases] print_activations(conv2) lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2') pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2') print_activations(pool2) with tf.name_scope('conv3') as scope: kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights') conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME') biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases') bias = tf.nn.bias_add(conv, biases) conv3 = tf.nn.relu(bias, name = scope) parameters += [kernel, biases] print_activations(conv3) with tf.name_scope('conv4') as scope: kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights') conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME') biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases') bias = tf.nn.bias_add(conv, biases) conv4 = tf.nn.relu(bias, name = scope) parameters += [kernel, biases] print_activations(conv4) with tf.name_scope('conv5') as scope: kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights') conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME') biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases') bias = tf.nn.bias_add(conv, biases) conv5 = tf.nn.relu(bias, name = scope) parameters += [kernel, biases] print_activations(conv5) pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5') print_activations(pool5) return pool5, parameters def time_tensorflow_run(session, target, info_string): num_steps_burn_in = 10 total_duration = 0.0 total_duration_squared = 0.0 for i in range(num_batches + num_steps_burn_in): start_time = time.time() _ = session.run(target) duration = time.time() - start_time if i >= num_steps_burn_in: if not i % 10: print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration)) total_duration += duration total_duration_squared += duration * duration mn = total_duration / num_batches vr = total_duration_squared / num_batches - mn * mn sd = math.sqrt(vr) print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd)) def run_benchmark(): with tf.Graph().as_default(): image_size = 224 images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1)) pool5, parameters = inference(images) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) time_tensorflow_run(sess, pool5, "Forward") objective = tf.nn.l2_loss(pool5) grad = tf.gradients(objective, parameters) time_tensorflow_run(sess, grad, "Forward-backward") run_benchmark()
這里的代碼都是之前講過的,只是加了一個計(jì)算時間和現(xiàn)實(shí)網(wǎng)絡(luò)的卷積核的函數(shù),應(yīng)該很容易就看懂了,就不多贅述了。我在GTX TITAN X上前饋大概需要0.024s, 反饋大概需要0.079s。哈哈,自己動手試一試哦。
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)持久化shelve模塊用法分析
這篇文章主要介紹了Python數(shù)據(jù)持久化shelve模塊用法,結(jié)合實(shí)例形式較為詳細(xì)的總結(jié)分析了shelve模塊的功能、原理及簡單使用方法,需要的朋友可以參考下2018-06-06使用Python向C語言的鏈接庫傳遞數(shù)組、結(jié)構(gòu)體、指針類型的數(shù)據(jù)
今天小編就為大家分享一篇關(guān)于使用Python向C語言的鏈接庫傳遞數(shù)組、結(jié)構(gòu)體、指針類型的數(shù)據(jù),小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧2019-01-01Python3自動生成MySQL數(shù)據(jù)字典的markdown文本的實(shí)現(xiàn)
這篇文章主要介紹了Python3自動生成MySQL數(shù)據(jù)字典的markdown文本的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05Python原始字符串(raw strings)用法實(shí)例
這篇文章主要介紹了Python原始字符串(raw strings)用法實(shí)例,在使用Python進(jìn)行字符串處理的過程中非常具有實(shí)用價(jià)值,需要的朋友可以參考下2014-10-10Python強(qiáng)化練習(xí)之Tensorflow2 opp算法實(shí)現(xiàn)月球登陸器
在面向?qū)ο蟪霈F(xiàn)之前,我們采用的開發(fā)方法都是面向過程的編程(OPP)。面向過程的編程中最常用的一個分析方法是“功能分解”。我們會把用戶需求先分解成模塊,然后把模塊分解成大的功能,再把大的功能分解成小的功能,整個需求就是按照這樣的方式,最終分解成一個一個的函數(shù)2021-10-10Django的性能優(yōu)化實(shí)現(xiàn)解析
這篇文章主要介紹了Django的性能優(yōu)化實(shí)現(xiàn)解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07python selenium循環(huán)登陸網(wǎng)站的實(shí)現(xiàn)
這篇文章主要介紹了python selenium循環(huán)登陸網(wǎng)站的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11Python PyQt5干貨滿滿小項(xiàng)目輕松實(shí)現(xiàn)高效摳圖去背景
PyQt5以一套Python模塊的形式來實(shí)現(xiàn)功能。它包含了超過620個類,600個方法和函數(shù)。本篇文章手把手帶你用PyQt5輕松實(shí)現(xiàn)圖片扣除背景,大家可以在過程中查缺補(bǔ)漏,提升水平2021-11-11Python-OpenCV實(shí)戰(zhàn):利用 KNN 算法識別手寫數(shù)字
K-最近鄰(KNN)是監(jiān)督學(xué)習(xí)中最簡單的算法之一,KNN可用于分類和回歸問題。本文將為大家介紹的是通過KNN算法實(shí)現(xiàn)識別手寫數(shù)字。文中的示例代碼介紹詳細(xì),需要的朋友可以參考一下2021-12-12Python matplotlib修改默認(rèn)字體的操作
這篇文章主要介紹了Python matplotlib修改默認(rèn)字體的操作,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03