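Reading h5 files in Keras: a code walkthrough of load_weights and load_model
There are plenty of simple examples online for saving and loading h5 models and weights. Loading comes down to two functions:
1. keras.models.load_model(): loads the network structure together with the weights
2. model.load_weights(): loads the weights only (a method of the model instance, not a function in keras.models)
The load_model code contains the load_weights logic. The difference is that load_weights needs a network that has already been built, and it writes the weight data into the tensors of the corresponding layers.
The following uses loading h5 weights into ResNet50 as an example: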
import keras
from keras.preprocessing import image
import numpy as np
from network.resnet50 import ResNet50  # modified copy that does not load weights (the official one works too)

model = ResNet50()
# by_name defaults to False; with by_name=True only weights of layers with matching names are read
# here the model's layers correspond to the layers in the weights file (except the input layer)
model.load_weights(r'\models\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5')
The model structure can be printed with model.summary().

I. Loading weights into a model: load_weights()
def load_weights(self, filepath, by_name=False, skip_mismatch=False, reshape=False):
    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')
    with h5py.File(filepath, mode='r') as f:
        # a full-model file keeps its weights under the 'model_weights' group
        if 'layer_names' not in f.attrs and 'model_weights' in f:
            f = f['model_weights']
        if by_name:
            saving.load_weights_from_hdf5_group_by_name(
                f, self.layers, skip_mismatch=skip_mismatch, reshape=reshape)
        else:
            saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape)
Only saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape) matters here; the argument f is an h5py file object.
h5 files are read with the h5py package. First, a quick look at the ResNet50 weights file in HDFView.

The resnet50_v2 weights file has a single attribute, "layer_names", which holds an array of 177 strings. Each element is a layer name; the names and their order correspond exactly to the name of each layer in the network at the time Keras saved the weights.
Each key (layer name) has an attribute "weight_names", whose value may be empty.
For example:
conv1 has a "weight_names" of "conv1_W:0" and "conv1_b:0",
flatten_1 has an empty "weight_names".

That is enough of an introduction; how h5py reads the weight data is explained alongside the code below.
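To make the layout described above concrete, here is a minimal sketch that writes a tiny file with the same shape of metadata (a "layer_names" attribute on the root, a "weight_names" attribute on each layer group) and reads it back with h5py. The file name, the two toy layers and the decode helper are assumptions made for illustration; they are not taken from the real ResNet50 file.

```python
import os
import tempfile

import h5py
import numpy as np

# Build a tiny stand-in for a Keras weights file (toy names, zero-filled data).
path = os.path.join(tempfile.mkdtemp(), "toy_weights.h5")
with h5py.File(path, "w") as f:
    f.attrs["layer_names"] = [b"input_1", b"conv1"]
    # A layer without parameters: empty group, empty weight_names.
    g = f.create_group("input_1")
    g.attrs["weight_names"] = []
    # A layer with a kernel and a bias.
    g = f.create_group("conv1")
    g.attrs["weight_names"] = [b"conv1_W:0", b"conv1_b:0"]
    g.create_dataset("conv1_W:0", data=np.zeros((7, 7, 3, 64), dtype="float32"))
    g.create_dataset("conv1_b:0", data=np.zeros((64,), dtype="float32"))


def decode(n):
    """Attribute entries usually come back as bytes; turn them into str."""
    return n.decode("utf8") if isinstance(n, bytes) else str(n)


# Walk the file in saved layer order and collect the shape of every weight.
shapes = {}
with h5py.File(path, "r") as f:
    for layer_name in map(decode, f.attrs["layer_names"]):
        group = f[layer_name]
        weight_names = [decode(n) for n in group.attrs["weight_names"]]
        shapes[layer_name] = [np.asarray(group[n]).shape for n in weight_names]

print(shapes)
```

Iterating over the "layer_names" attribute rather than over f.keys() is what preserves the saved layer order.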
II. Loading weights from an hdf5 file: load_weights_from_hdf5_group()
1. Find the layers of the Keras model that have weight tensors (tf.Variable)
def load_weights_from_hdf5_group(f, layers, reshape=False):
    # filter model.layers of the Keras ResNet50 model:
    # keep only layers whose layer.weights is non-empty, dropping layers without learnable parameters
    filtered_layers = []
    for layer in layers:
        weights = layer.weights
        if weights:
            filtered_layers.append(layer)

filtered_layers is the list of 107 layers left after the input, padding, activation, merge/add, flatten and similar layers of the current ResNet50 model have been filtered out.
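The filter can be tried without Keras. The sketch below uses a hypothetical ToyLayer class that only carries a name and a weights list, which is all the filter looks at; the layer and variable names are illustrative.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer: a name and a list of weights."""

    def __init__(self, name, weights=()):
        self.name = name
        self.weights = list(weights)


layers = [
    ToyLayer("input_1"),
    ToyLayer("conv1", ["conv1/kernel:0", "conv1/bias:0"]),
    ToyLayer("activation_1"),
    ToyLayer("fc1000", ["fc1000/kernel:0", "fc1000/bias:0"]),
]

# An empty weights list is falsy, so layers without parameters drop out.
filtered_layers = [layer for layer in layers if layer.weights]
print([layer.name for layer in filtered_layers])  # ['conv1', 'fc1000']
```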
2. Get the names of the layers that hold weight data from the hdf5 file
As seen in HDFView, every layer has a "weight_names" attribute; if it is non-empty, the layer has weight data.
First, some basic operations on the h5py object f in the console (look up the relevant data structure definitions if needed):
>>> f
<HDF5 file "resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5" (mode r)>
>>> f.filename
'E:\\DeepLearning\\keras_test\\models\\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5'
>>> f.name
'/'
>>> f.attrs.keys()  # attribute list of f
<KeysViewHDF5 ['layer_names']>
>>> f.keys()  # not in layer order
<KeysViewHDF5 ['activation_1', 'activation_10', 'activation_11', 'activation_12', ...,'activation_8', 'activation_9', 'avg_pool', 'bn2a_branch1', 'bn2a_branch2a', ...,'res5c_branch2a', 'res5c_branch2b', 'res5c_branch2c', 'zeropadding2d_1']>
>>> f.attrs['layer_names']  # in layer order, matching summary()
array([b'input_1', b'zeropadding2d_1', b'conv1', b'bn_conv1', b'activation_1', b'maxpooling2d_1', b'res2a_branch2a', ..., b'res2a_branch1', b'bn2a_branch2c', b'bn2a_branch1', b'merge_1', b'activation_47', b'res5c_branch2b', b'bn5c_branch2b', ..., b'activation_48', b'res5c_branch2c', b'bn5c_branch2c', b'merge_16', b'activation_49', b'avg_pool', b'flatten_1', b'fc1000'], dtype='|S15')
>>> f['input_1']
<HDF5 group "/input_1" (0 members)>
>>> f['input_1'].attrs.keys()  # in Keras, every layer has a 'weight_names' attribute
<KeysViewHDF5 ['weight_names']>
>>> f['input_1'].attrs['weight_names']  # the input layer has no weights
array([], dtype=float64)
>>> f['conv1']
<HDF5 group "/conv1" (2 members)>
>>> f['conv1'].attrs.keys()
<KeysViewHDF5 ['weight_names']>
>>> f['conv1'].attrs['weight_names']  # the conv layer has weights W and b
array([b'conv1_W:0', b'conv1_b:0'], dtype='|S9')
Read the list of names of the layers that hold weight data from the file:
# get the layer names from the hdf5 file, in saved order
layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
# the line above amounts to: layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
filtered_layer_names = []
for name in layer_names:
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    # the line above amounts to: weight_names = [n.decode('utf8') for n in f[name].attrs['weight_names']]
    # keep the names of layers that have weights
    if weight_names:
        filtered_layer_names.append(name)
layer_names = filtered_layer_names
# check that the number of model layers with weight tensors equals
# the number of weighted layers read from the h5 file
if len(layer_names) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(layer_names)) +
                     ' layers into a model with ' +
                     str(len(filtered_layers)) + ' layers.')
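Note that with by_name=False the check above compares only counts: file layers and model layers are matched by position, which is why a file that calls a layer "merge_1" can still load into a model that calls it "add_1". A minimal sketch of this step with plain dictionaries follows; the helper names and the toy data (including the fc1000 weight names) are assumptions for illustration.

```python
def layers_with_weights(layer_names, weight_names_by_layer):
    """Keep names whose saved weight_names list is non-empty, preserving order."""
    return [name for name in layer_names if weight_names_by_layer.get(name)]


def check_counts(file_layer_names, model_layer_names):
    """Matching is positional, so only the two lengths are compared."""
    if len(file_layer_names) != len(model_layer_names):
        raise ValueError(
            "You are trying to load a weight file containing "
            f"{len(file_layer_names)} layers into a model with "
            f"{len(model_layer_names)} layers."
        )


# Toy data: names stored in a file vs. weighted layers of the current model.
file_layer_names = ["input_1", "zeropadding2d_1", "conv1", "merge_1", "fc1000"]
file_weight_names = {
    "conv1": ["conv1_W:0", "conv1_b:0"],
    "fc1000": ["fc1000_W:0", "fc1000_b:0"],
}
model_weighted_layers = ["conv1", "fc1000"]

kept = layers_with_weights(file_layer_names, file_weight_names)
check_counts(kept, model_weighted_layers)  # passes: 2 == 2
print(kept)  # ['conv1', 'fc1000']

try:
    check_counts(kept, ["conv1"])  # a model with one weighted layer fewer
except ValueError as err:
    print(err)
```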
3. Pair the weight data read from the hdf5 file with the tf.Variable objects of the Keras layers
First look at the weight data and at the layer's weight variables (Tensor, tf.Variable), using conv1 as the example:
>>> f['conv1']['conv1_W:0']  # the conv1_W:0 weight dataset
<HDF5 dataset "conv1_W:0": shape (7, 7, 3, 64), type "<f4">
>>> f['conv1']['conv1_W:0'].value  # the values, a standard 4-D array (.value was removed in h5py 3.0; use [()] there)
array([[[[ 2.82526277e-02, -1.18737184e-02, 1.51488732e-03, ...,
          -1.07003953e-02, -5.27982824e-02, -1.36667420e-03],
         [ 5.86827798e-03, 5.04415408e-02, 3.46324709e-03, ...,
           1.01423981e-02, 1.39493728e-02, 1.67549420e-02],
         [-2.44090753e-03, -4.86173332e-02, 2.69966386e-03, ...,
          -3.44439060e-04, 3.48098315e-02, 6.28910400e-03]],
        [[ 1.81872323e-02, -7.20698107e-03, 4.80302610e-03, ...,
          …. ]]]])
>>> conv1_w = np.asarray(f['conv1']['conv1_W:0'])  # convert directly to a numpy array
>>> conv1_w.shape
(7, 7, 3, 64)
# the convolution layer
>>> filtered_layers[0]
<keras.layers.convolutional.Conv2D object at 0x000001F7487C0E10>
>>> filtered_layers[0].name
'conv1'
>>> filtered_layers[0].input
<tf.Tensor 'conv1_pad/Pad:0' shape=(?, 230, 230, 3) dtype=float32>
# weight variables of the convolution layer
>>> filtered_layers[0].weights
[<tf.Variable 'conv1/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>, <tf.Variable 'conv1/bias:0' shape=(64,) dtype=float32_ref>]
Pair the model's weight variables (tf.Variable) with the weight data that was read, so the data can be written into the variables afterwards.
weight_value_tuples = []
# iterate over the filtered layer names
for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    # read this layer's weight data from the file as a list of numpy arrays
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
    # get the weight variables (tf.Variable) of the k-th weighted layer of the Keras model
    layer = filtered_layers[k]
    symbolic_weights = layer.weights
    # preprocess the weight data
    weight_values = preprocess_weights_for_loading(layer, weight_values,
        original_keras_version, original_backend, reshape=reshape)
    # check that the number of weight arrays equals the number of tf.Variable objects
    if len(weight_values) != len(symbolic_weights):
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
            '" in the current model) was found to correspond to layer ' + name +
            ' in the save file. However the new layer ' + layer.name + ' expects ' +
            str(len(symbolic_weights)) + ' weights, but the saved weights have ' +
            str(len(weight_values)) + ' elements.')
    # pair each tf.Variable with its weight data
    weight_value_tuples += zip(symbolic_weights, weight_values)
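The pairing logic itself is small enough to sketch with plain lists. In the sketch below strings stand in for tf.Variable objects and nested lists stand in for numpy arrays; pair_weights and all the toy names and values are assumptions for illustration, and the preprocessing step is left out.

```python
def pair_weights(layer_names, model_vars, file_values):
    """Build a flat list of (variable, value) pairs, layer by layer."""
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        symbolic_weights = model_vars[k]   # k-th weighted layer of the model
        weight_values = file_values[name]  # values saved under this layer name
        if len(symbolic_weights) != len(weight_values):
            raise ValueError(
                f"Layer #{k} expects {len(symbolic_weights)} weights, "
                f"but the saved weights have {len(weight_values)} elements."
            )
        # += with a zip object extends the list with one tuple per weight
        weight_value_tuples += zip(symbolic_weights, weight_values)
    return weight_value_tuples


layer_names = ["conv1", "fc1000"]
model_vars = [
    ["conv1/kernel:0", "conv1/bias:0"],
    ["fc1000/kernel:0", "fc1000/bias:0"],
]
file_values = {
    "conv1": [[0.1, 0.2], [0.0]],
    "fc1000": [[1.0], [0.5]],
}

pairs = pair_weights(layer_names, model_vars, file_values)
print(len(pairs), pairs[0])  # 4 ('conv1/kernel:0', [0.1, 0.2])
```

The result is one flat list across all layers, which is the form a single batched assignment needs.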
4. Write the loaded weight data into the layers' weight variables
Step 3 paired every layer's weight variable with its weight data; what remains is to write the data using TensorFlow's sess.run, which takes a single line of code:
K.batch_set_value(weight_value_tuples)
The actual implementation:
def batch_set_value(tuples):
    if tuples:
        assign_ops = []
        feed_dict = {}
        for x, value in tuples:
            # cast the value to the dtype of the weight variable
            value = np.asarray(value, dtype=dtype(x))
            tf_dtype = tf.as_dtype(x.dtype.name.split('_')[0])
            if hasattr(x, '_assign_placeholder'):
                assign_placeholder = x._assign_placeholder
                assign_op = x._assign_op
            else:
                # a tf.placeholder for the weight value
                assign_placeholder = tf.placeholder(tf_dtype, shape=value.shape)
                # the assign operation that writes into the weight variable
                assign_op = x.assign(assign_placeholder)
                # cache both on the variable so later calls reuse them
                # instead of adding new nodes to the graph
                x._assign_placeholder = assign_placeholder
                x._assign_op = assign_op
            assign_ops.append(assign_op)
            feed_dict[assign_placeholder] = value
        # run all assign ops in one tf.Session().run() call
        get_session().run(assign_ops, feed_dict=feed_dict)
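The reason the placeholder and the assign op are stored on the variable can be seen in a TensorFlow-free simulation. The sketch below is only a model of the caching idea: ToyGraph, ToyVariable and this batch_set_value are hypothetical stand-ins, not the Keras backend code.

```python
class ToyGraph:
    """Counts created assign ops (stand-in for a growing TF graph)."""

    def __init__(self):
        self.num_assign_ops = 0


class ToyVariable:
    def __init__(self, graph, value):
        self.graph = graph
        self.value = value

    def make_assign_op(self):
        # Creating an op grows the graph; this is the cost worth caching.
        self.graph.num_assign_ops += 1

        def assign_op(new_value):
            self.value = new_value

        return assign_op


def batch_set_value(tuples):
    pending = []
    for var, value in tuples:
        if not hasattr(var, "_assign_op"):
            var._assign_op = var.make_assign_op()  # built once, then reused
        pending.append((var._assign_op, value))
    # One "run" that applies all the collected assignments together.
    for assign_op, value in pending:
        assign_op(value)


graph = ToyGraph()
kernel, bias = ToyVariable(graph, 0.0), ToyVariable(graph, 0.0)
batch_set_value([(kernel, 1.5), (bias, -0.5)])
batch_set_value([(kernel, 2.5), (bias, 0.5)])  # second load reuses the cached ops
print(graph.num_assign_ops, kernel.value, bias.value)  # 2 2.5 0.5
```

Without the cache, every call would add a fresh pair of nodes per variable, so repeatedly loading weights would keep enlarging the graph.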
At this point, the flow of building the network first and then loading the weights from the h5 file is complete, and the model can be used directly for predict.
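III. Loading a model: load_model()
This is largely the same as above, with the extra step of loading the network; the weights are loaded the same way.
First save the model that had its weights loaded above as res50_model.h5 with model.save(), then inspect it in HDFView.

The file now has three attributes: backend, keras_version and model_config, which record the backend that produced the file, the Keras version, and the network structure in JSON format.
There is one key, "model_weights". Compared with the earlier h5 weights file, it carries two more attributes, giving ['backend', 'keras_version', 'layer_names']. The entries under this key are identical to the weight data of the earlier h5 file.
As before, first inspect the file structure with Python code: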
>>> ff
<HDF5 file "res50_model.h5" (mode r)>
>>> ff.attrs.keys()
<KeysViewHDF5 ['backend', 'keras_version', 'model_config']>
>>> ff.keys()
<KeysViewHDF5 ['model_weights']>
>>> ff['model_weights'].attrs.keys()  # ff['model_weights'] has three attributes
<KeysViewHDF5 ['backend', 'keras_version', 'layer_names']>
>>> ff['model_weights'].keys()  # not in layer order
<KeysViewHDF5 ['activation_1', 'activation_10', 'activation_11', 'activation_12', …, 'avg_pool', 'bn2a_branch1', 'bn2a_branch2a', 'bn2a_branch2b', …, 'bn5c_branch2c', 'bn_conv1', 'conv1', 'conv1_pad', 'fc1000', 'input_1', …, 'c_branch2a', 'res5c_branch2b', 'res5c_branch2c']>
>>> ff['model_weights'].attrs['layer_names']  # in layer order
array([b'input_1', b'conv1_pad', b'conv1', b'bn_conv1', b'activation_1', b'pool1_pad', b'max_pooling2d_1', b'res2a_branch2a', b'bn2a_branch2a', b'activation_2', b'res2a_branch2b', ... (omitted) b'activation_48', b'res5c_branch2c', b'bn5c_branch2c', b'add_16', b'activation_49', b'avg_pool', b'fc1000'], dtype='|S15')
1. The main model-loading function: load_model
def load_model(filepath, custom_objects=None, compile=True):
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')
    model = None
    opened_new_file = not isinstance(filepath, h5py.Group)
    # wrap the h5 file in an h5dict, which allows dict-style access by key
    f = h5dict(filepath, 'r')
    try:
        # deserialize the model and compile it
        model = _deserialize_model(f, custom_objects, compile)
    finally:
        if opened_new_file:
            f.close()
    return model
2. Deserializing and compiling: _deserialize_model
The main parts of def _deserialize_model(f, custom_objects=None, compile=True) are shown below.
Step 1: load the network structure; the implementation is identical to keras.models.model_from_json()
# read the JSON description of the network structure from the h5 file
model_config = f['model_config']
model_config = json.loads(model_config.decode('utf-8'))
# build the network structure from the JSON
model = model_from_config(model_config, custom_objects=custom_objects)
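The role of custom_objects in this step can be illustrated with a toy version of the same idea: decode the bytes, parse the JSON, then look each class_name up in a registry that custom_objects extends. Everything in the sketch (ToyDense, MyScale, toy_model_from_config and the simplified config layout) is a hypothetical stand-in, not the Keras implementation.

```python
import json


class ToyDense:
    def __init__(self, name, units):
        self.name, self.units = name, units


class MyScale:
    """A user-defined layer class the loader cannot know about on its own."""

    def __init__(self, name, factor):
        self.name, self.factor = name, factor


BUILTIN_CLASSES = {"Dense": ToyDense}


def toy_model_from_config(config, custom_objects=None):
    """Instantiate each layer spec via a class registry keyed by class_name."""
    registry = dict(BUILTIN_CLASSES)
    registry.update(custom_objects or {})
    return [registry[spec["class_name"]](**spec["config"])
            for spec in config["config"]["layers"]]


# The config is stored as UTF-8 encoded JSON, so it is decoded before parsing.
raw = json.dumps({
    "class_name": "Sequential",
    "config": {"layers": [
        {"class_name": "Dense", "config": {"name": "fc1", "units": 8}},
        {"class_name": "MyScale", "config": {"name": "scale", "factor": 2.0}},
    ]},
}).encode("utf-8")

config = json.loads(raw.decode("utf-8"))
built = toy_model_from_config(config, custom_objects={"MyScale": MyScale})
print([(type(layer).__name__, layer.name) for layer in built])
```

In this toy version, leaving out custom_objects makes the lookup of "MyScale" fail with a KeyError, which mirrors why a model with custom layers needs them passed to load_model.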
Step 2: load the network weights, exactly as in model.load_weights()
# get the ordered layer names and the model's layers
model_weights_group = f['model_weights']
layer_names = model_weights_group['layer_names']
layers = model.layers
# keep the layers that have weight tensors
filtered_layers = []
for layer in layers:
    weights = layer.weights
    if weights:
        filtered_layers.append(layer)
# keep the names of layers that have weight data
filtered_layer_names = []
for name in layer_names:
    layer_weights = model_weights_group[name]
    weight_names = layer_weights['weight_names']
    if weight_names:
        filtered_layer_names.append(name)
layer_names = filtered_layer_names
# pair the data into weight_value_tuples
weight_value_tuples = []
for k, name in enumerate(layer_names):
    layer_weights = model_weights_group[name]
    weight_names = layer_weights['weight_names']
    weight_values = [layer_weights[weight_name] for weight_name in weight_names]
    layer = filtered_layers[k]
    symbolic_weights = layer.weights
    weight_values = preprocess_weights_for_loading(...)
    weight_value_tuples += zip(symbolic_weights, weight_values)
# batched write
K.batch_set_value(weight_value_tuples)
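Step 3: compile and return the model
Normally the model is finished once the network is built, the weights are loaded, and compile has run. Any further settings can be handled afterwards. (A model saved after training carries additional configuration.)
For example, a network with only dense layers, inspected after training and saving, has an extra attribute "training_config" and an extra key "optimizer_weights".

The current res50_model.h5 has no such additional configuration.
The handling code is as follows: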
if compile:
    training_config = f.get('training_config')
    if training_config is None:
        warnings.warn('No training configuration found in save file: '
                      'the model was *not* compiled. Compile it manually.')
        return model
    training_config = json.loads(training_config.decode('utf-8'))
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config, custom_objects=custom_objects)
    # Recover loss functions and metrics.
    loss = convert_custom_objects(training_config['loss'])
    metrics = convert_custom_objects(training_config['metrics'])
    sample_weight_mode = training_config['sample_weight_mode']
    loss_weights = training_config['loss_weights']
    # Compile model.
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics,
                  loss_weights=loss_weights, sample_weight_mode=sample_weight_mode)
    # Set optimizer weights.
    if 'optimizer_weights' in f:
        # Build train function (to get weight updates).
        model._make_train_function()
        optimizer_weights_group = f['optimizer_weights']
        optimizer_weight_names = [
            n.decode('utf8') for n in optimizer_weights_group['weight_names']]
        optimizer_weight_values = [
            optimizer_weights_group[n] for n in optimizer_weight_names]
        try:
            model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
            warnings.warn('Error in loading the saved optimizer state. As a result, '
                          'your model is starting with a freshly initialized optimizer.')
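The first branch of this code decides whether the loaded model is compiled at all. The sketch below reproduces only that decision with plain dictionaries standing in for the h5dict; read_training_config and the two toy "files" are assumptions for illustration.

```python
import json
import warnings


def read_training_config(f, compile=True):
    """Return the parsed training config, or None if the model stays uncompiled."""
    if not compile:
        return None
    raw = f.get("training_config")
    if raw is None:
        warnings.warn("No training configuration found in save file: "
                      "the model was *not* compiled. Compile it manually.")
        return None
    return json.loads(raw.decode("utf-8"))


# Saved right after loading weights, never compiled (like res50_model.h5).
weights_then_saved = {"model_config": b"{}"}
# Saved after compile and training, so a training_config entry exists.
trained_then_saved = {
    "model_config": b"{}",
    "training_config": json.dumps(
        {"loss": "mse", "metrics": ["mae"]}).encode("utf-8"),
}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    first = read_training_config(weights_then_saved)
second = read_training_config(trained_then_saved)

print(first, len(caught), second["loss"])  # None 1 mse
```

A model returned from the first case can still run predict; it only needs a manual compile before training or evaluation.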

