用tensorflow搭建CNN的方法
CNN(Convolutional Neural Networks) 卷積神經(jīng)網(wǎng)絡(luò)簡(jiǎn)單講就是把一個(gè)圖片的數(shù)據(jù)傳遞給CNN,原涂層是由RGB組成,然后CNN把它的厚度加厚,長(zhǎng)寬變小,每做一層都這樣被拉長(zhǎng),最后形成一個(gè)分類器
在 CNN 中有幾個(gè)重要的概念:
- stride
- padding
- pooling
stride,就是每跨多少步抽取信息。每一塊抽取一部分信息,長(zhǎng)寬就縮減,但是厚度增加。抽取的各個(gè)小塊兒,再把它們合并起來,就變成一個(gè)壓縮后的立方體。
padding,抽取的方式有兩種,一種是抽取后的長(zhǎng)和寬縮減,另一種是抽取后的長(zhǎng)和寬和原來的一樣。
pooling,就是當(dāng)跨步比較大的時(shí)候,它會(huì)漏掉一些重要的信息,為了解決這樣的問題,就加上一層叫pooling,事先把這些必要的信息存儲(chǔ)起來,然后再變成壓縮后的層
利用tensorflow搭建CNN,也就是卷積神經(jīng)網(wǎng)絡(luò)是一件很簡(jiǎn)單的事情,筆者按照官方教程中使用MNIST手寫數(shù)字識(shí)別為例展開代碼,整個(gè)程序也基本與官方例程一致,不過在比較容易迷惑的地方加入了注釋,有一定的機(jī)器學(xué)習(xí)或者卷積神經(jīng)網(wǎng)絡(luò)制式的人都應(yīng)該可以迅速領(lǐng)會(huì)到代碼的含義。
#encoding=utf-8 import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True) def weight_variable(shape): initial = tf.truncated_normal(shape,stddev=0.1) #截?cái)嗾龖B(tài)分布,此函數(shù)原型為尺寸、均值、標(biāo)準(zhǔn)差 return tf.Variable(initial) def bias_variable(shape): initial = tf.constant(0.1,shape=shape) return tf.Variable(initial) def conv2d(x,W): return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') # strides第0位和第3為一定為1,剩下的是卷積的橫向和縱向步長(zhǎng) def max_pool_2x2(x): return tf.nn.max_pool(x,ksize = [1,2,2,1],strides=[1,2,2,1],padding='SAME')# 參數(shù)同上,ksize是池化塊的大小 x = tf.placeholder("float", shape=[None, 784]) y_ = tf.placeholder("float", shape=[None, 10]) # 圖像轉(zhuǎn)化為一個(gè)四維張量,第一個(gè)參數(shù)代表樣本數(shù)量,-1表示不定,第二三參數(shù)代表圖像尺寸,最后一個(gè)參數(shù)代表圖像通道數(shù) x_image = tf.reshape(x,[-1,28,28,1]) # 第一層卷積加池化 w_conv1 = weight_variable([5,5,1,32]) # 第一二參數(shù)值得卷積核尺寸大小,即patch,第三個(gè)參數(shù)是圖像通道數(shù),第四個(gè)參數(shù)是卷積核的數(shù)目,代表會(huì)出現(xiàn)多少個(gè)卷積特征 b_conv1 = bias_variable([32]) h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1) h_pool1 = max_pool_2x2(h_conv1) # 第二層卷積加池化 w_conv2 = weight_variable([5,5,32,64]) # 多通道卷積,卷積出64個(gè)特征 b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2) h_pool2 = max_pool_2x2(h_conv2) # 原圖像尺寸28*28,第一輪圖像縮小為14*14,共有32張,第二輪后圖像縮小為7*7,共有64張 w_fc1 = weight_variable([7*7*64,1024]) b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) # 展開,第一個(gè)參數(shù)為樣本數(shù)量,-1未知 f_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1) # dropout操作,減少過擬合 keep_prob = tf.placeholder(tf.float32) h_fc1_drop = tf.nn.dropout(f_fc1,keep_prob) w_fc2 = weight_variable([1024,10]) b_fc2 = bias_variable([10]) y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2) cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) # 定義交叉熵為loss函數(shù) train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 調(diào)用優(yōu)化器優(yōu)化 correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) sess = tf.InteractiveSession() sess.run(tf.initialize_all_variables()) for i in range(2000): batch = mnist.train.next_batch(50) if i%100 == 0: train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0}) print "step %d, training accuracy %g"%(i, train_accuracy) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) print "test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images[0:500], y_: mnist.test.labels[0:500], keep_prob: 1.0})
在程序中主要注意這么幾點(diǎn):
1、維度問題,由于我們tensorflow基于的是張量這樣一個(gè)概念,張量其實(shí)就是維度擴(kuò)展的矩陣,因此維度特別重要,而且維度也是很容易使人迷惑的地方。
2、卷積問題,卷積核不只是二維的,多通道卷積時(shí)卷積核就是三維的
3、最后進(jìn)行檢驗(yàn)的時(shí)候,如果一次性加載出所有的驗(yàn)證集,出現(xiàn)了內(nèi)存爆掉的情況,由于是使用的是云端的服務(wù)器,可能內(nèi)存小一些,如果內(nèi)存夠用可以直接全部加載上看結(jié)果
4、這個(gè)程序原始版本迭代次數(shù)設(shè)置了20000次,這個(gè)次數(shù)大約要訓(xùn)練數(shù)個(gè)小時(shí)(在不使用GPU的情況下),這個(gè)次數(shù)可以按照要求更改。
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python移動(dòng)測(cè)試開發(fā)subprocess模塊項(xiàng)目實(shí)戰(zhàn)
這篇文章主要為大家介紹了Python移動(dòng)測(cè)試開發(fā)subprocess模塊項(xiàng)目實(shí)戰(zhàn)示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-07-07基于Django OneToOneField和ForeignKey的區(qū)別詳解
這篇文章主要介紹了基于Django OneToOneField和ForeignKey的區(qū)別詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-03-03五個(gè)Pandas?實(shí)戰(zhàn)案例帶你分析操作數(shù)據(jù)
pandas是基于NumPy的一種工具,該工具是為了解決數(shù)據(jù)分析任務(wù)而創(chuàng)建的。Pandas納入了大量庫和一些標(biāo)準(zhǔn)的數(shù)據(jù)模型,提供了高效操作大型數(shù)據(jù)集的工具。pandas提供大量快速便捷地處理數(shù)據(jù)的函數(shù)和方法。你很快就會(huì)發(fā)現(xiàn),它是使Python強(qiáng)大而高效的數(shù)據(jù)分析環(huán)境的重要因素之一2022-01-01Python基于QQ郵箱實(shí)現(xiàn)SSL發(fā)送
這篇文章主要介紹了Python基于QQ郵箱實(shí)現(xiàn)SSL發(fā)送,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04tensorflow模型保存、加載之變量重命名實(shí)例
今天小編就為大家分享一篇tensorflow模型保存、加載之變量重命名實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-01-01python+opencv實(shí)現(xiàn)霍夫變換檢測(cè)直線
這篇文章主要為大家詳細(xì)介紹了python+opencv實(shí)現(xiàn)霍夫變換檢測(cè)直線,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-12-12