keras實現(xiàn)theano和tensorflow訓(xùn)練的模型相互轉(zhuǎn)換
我就廢話不多說了,大家還是直接看代碼吧~
</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">
# coding:utf-8 """ If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D), be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe), and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano). """ from keras import backend as K from keras.utils.np_utils import convert_kernel from text_classifier import keras_text_classifier import sys def th2tf( model): import tensorflow as tf ops = [] for layer in model.layers: if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']: original_w = K.get_value(layer.W) converted_w = convert_kernel(original_w) ops.append(tf.assign(layer.W, converted_w).op) K.get_session().run(ops) return model def tf2th(model): for layer in model.layers: if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']: original_w = K.get_value(layer.W) converted_w = convert_kernel(original_w) K.set_value(layer.W, converted_w) return model def conv_layer_converted(tf_weights, th_weights, m = 0): """ :param tf_weights: :param th_weights: :param m: 0-tf2th, 1-th2tf :return: """ if m == 0: # tf2th tc = keras_text_classifier(weights_path=tf_weights) model = tc.loadmodel() model = tf2th(model) model.save_weights(th_weights) elif m == 1: # th2tf tc = keras_text_classifier(weights_path=th_weights) model = tc.loadmodel() model = th2tf(model) model.save_weights(tf_weights) else: print("0-tf2th, 1-th2tf") return if __name__ == '__main__': if len(sys.argv) < 4: print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow") sys.exit(0) tf_weights = sys.argv[1] th_weights = sys.argv[2] m = int(sys.argv[3]) conv_layer_converted(tf_weights, th_weights, m)
補充知識:keras學(xué)習(xí)之修改底層為TensorFlow還是theano
我們知道,keras的底層是TensorFlow或者theano
要知道我們是用的哪個為底層,只需要import keras即可顯示
修改方法:
打開
修改
以上這篇keras實現(xiàn)theano和tensorflow訓(xùn)練的模型相互轉(zhuǎn)換就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
caffe binaryproto 與 npy相互轉(zhuǎn)換的實例講解
今天小編就為大家分享一篇caffe binaryproto 與 npy相互轉(zhuǎn)換的實例講解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07基于Numpy.convolve使用Python實現(xiàn)滑動平均濾波的思路詳解
這篇文章主要介紹了Python極簡實現(xiàn)滑動平均濾波(基于Numpy.convolve)的相關(guān)知識,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2019-05-05