Tensorflow 如何從checkpoint文件中加載變量名和變量值
假設你已經(jīng)經(jīng)過上千次的迭代,并且得到了以下模型:
則從這些checkpoint文件中加載變量名和變量值代碼如下:
model_dir = './ckpt-182802' import tensorflow as tf from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map: print("tensor_name: ", key) print(reader.get_tensor(key)) # Remove this is you want to print only variable names
Mnist
下面將給出一個基于卷積神經(jīng)網(wǎng)絡的手寫數(shù)字識別樣例:
# -*- coding: utf-8 -*- import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.python.framework import graph_util log_dir = './tensorboard' mnist = input_data.read_data_sets(train_dir="./mnist_data",one_hot=True) if tf.gfile.Exists(log_dir): tf.gfile.DeleteRecursively(log_dir) tf.gfile.MakeDirs(log_dir) #定義輸入數(shù)據(jù)mnist圖片大小28*28*1=784,None表示batch_size x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input") #定義標簽數(shù)據(jù),mnist共10類 y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_") #將數(shù)據(jù)調(diào)整為二維數(shù)據(jù),w*H*c---> 28*28*1,-1表示N張 image = tf.reshape(x,shape=[-1,28,28,1]) #第一層,卷積核={5*5*1*32},池化核={2*2*1,1*2*2*1} w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1")) b1= tf.Variable(initial_value=tf.zeros(shape=[32])) conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1") relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1") pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") #shape={None,14,14,32} #第二層,卷積核={5*5*32*64},池化核={2*2*1,1*2*2*1} w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2")) b2 = tf.Variable(initial_value=tf.zeros(shape=[64])) conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME") relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2") pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2") #shape={None,7,7,64} #FC1 w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3")) b3 = tf.Variable(initial_value=tf.zeros(shape=[1024])) #關鍵,進行reshape input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3") fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1") #shape={None,1024} #FC2 w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4")) b4 = tf.Variable(initial_value=tf.zeros(shape=[10])) fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit") #shape={None,10} #定義交叉熵損失 # 使用softmax將NN計算輸出值表示為概率 y = tf.nn.softmax(fc2,name="out") # 定義交叉熵損失函數(shù) cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_) loss = tf.reduce_mean(cross_entropy) tf.summary.scalar('Cross_Entropy',loss) #定義solver train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss) for var in tf.trainable_variables(): print var #train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss) #定義正確值,判斷二者下標index是否相等 correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #定義如何計算準確率 accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy") tf.summary.scalar('Training_ACC',accuracy) #定義初始化op merged = tf.summary.merge_all() init = tf.global_variables_initializer() saver = tf.train.Saver() #訓練NN with tf.Session() as session: session.run(fetches=init) writer = tf.summary.FileWriter(log_dir,session.graph) #定義記錄日志的位置 for i in range(0,500): xs, ys = mnist.train.next_batch(100) session.run(fetches=train,feed_dict={x:xs,y_:ys}) if i%10 == 0: train_accuracy,summary = session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys}) writer.add_summary(summary,i) print(i,"accuracy=",train_accuracy) ''' #訓練完成后,將網(wǎng)絡中的權值轉(zhuǎn)化為常量,形成常量graph,注意:需要x與label constant_graph = graph_util.convert_variables_to_constants(sess=session, input_graph_def=session.graph_def, output_node_names=['out','y_','input']) #將帶權值的graph序列化,寫成pb文件存儲起來 with tf.gfile.FastGFile("lenet.pb", mode='wb') as f: f.write(constant_graph.SerializeToString()) ''' saver.save(session,'./ckpt')
補充:查看tensorflow產(chǎn)生的checkpoint文件內(nèi)容的方法
tensorflow在保存權重模型時多使用tf.train.Saver().save 函數(shù)進行權重保存,保存的ckpt文件無法直接打開,但tensorflow提供了相關函數(shù) tf.train.NewCheckpointReader 可以對ckpt文件進行權重查看。
import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join('modelckpt', "fc_nn_model") # Read data from checkpoint file reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() # Print tensor name and values for key in var_to_shape_map: print("tensor_name: ", key) print(reader.get_tensor(key))
其中‘modelckpt'是存放.ckpt文件的文件夾,"fc_nn_model"是文件名,如下圖所示。
var_to_shape_map是一個字典,其中的鍵值是變量名,對應的值是該變量的形狀,如
{‘LSTM_input/bias_LSTM/Adam_1': [128]}
想要查看某變量值時,需要調(diào)用get_tensor函數(shù),即輸入以下代碼:
reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python開發(fā)SQLite3數(shù)據(jù)庫相關操作詳解【連接,查詢,插入,更新,刪除,關閉等】
這篇文章主要介紹了Python開發(fā)SQLite3數(shù)據(jù)庫相關操作,結合實例形式較為詳細的分析了Python操作SQLite3數(shù)據(jù)庫的連接,查詢,插入,更新,刪除,關閉等相關操作技巧,需要的朋友可以參考下2017-07-07基于Django框架利用Ajax實現(xiàn)點贊功能實例代碼
點贊這個功能是我們現(xiàn)在經(jīng)常會遇到的一個功能,下面這篇文章主要給大家介紹了關于基于Django框架利用Ajax實現(xiàn)點贊功能的相關資料,文中通過示例代碼介紹的非常詳細,需要的朋友們下面隨著小編來一起學習學習吧2018-08-08Python實現(xiàn)一個簡單三層神經(jīng)網(wǎng)絡的搭建及測試 代碼解析
一個完整的神經(jīng)網(wǎng)絡一般由三層構成:輸入層,隱藏層(可以有多層)和輸出層。本文所構建的神經(jīng)網(wǎng)絡隱藏層只有一層。一個神經(jīng)網(wǎng)絡主要由三部分構成(代碼結構上):初始化,訓練,和預測。,需要的朋友可以參考下面文章內(nèi)容的具體內(nèi)容2021-09-09