读取keras保存的h5文件,显示各层的权重
# hdf5的数据结构主要是File - Group - Dataset三级,# 数据集dataset, 是同一类型数据的多维数组; 组group, 是一种容器结构# 参考我们的文件系统,不同的文件存放在不同的目录下:# 目录就是group,描述了数据集DataSet的分类信息,通过group有效的将多种dataset进行管理和划分# 文件就是dataset,表示具体的数据测试文件下载:blstm
·
# hdf5的数据结构主要是File - Group - Dataset三级,
# 数据集dataset, 是同一类型数据的多维数组; 组group, 是一种容器结构
# 参考我们的文件系统,不同的文件存放在不同的目录下:
# 目录就是group,描述了数据集DataSet的分类信息,通过group有效的将多种dataset进行管理和划分
# 文件就是dataset,表示具体的数据
测试文件下载 : blstm_model.h5 和 best_model.weights
链接: https://pan.baidu.com/s/189lGr5foy4AafwVFGGa3dw 提取码: gtye
import os
import h5py
import numpy as np
def print_model_h5_wegiths(weight_file_path):
# weights的tensor保存在Dataset的value中,而每一集都会有attrs保存各网络层的属性
f = h5py.File(weight_file_path) # 读取weights h5文件返回File类
try:
if len(f.attrs.items()):
print("{} contains: ".format(f.filename)) # weight_file_path
print("Root attributes:")
for key, value in f.attrs.items():
print(" {}: {}".format(key, value))
# 输出储存在File类中的attrs信息,一般是各层的名称 : layer_names\ backend \keras_version
for layer, g in f.items():
# 读取各层的名称以及包含层信息的Group类
print(" {} with Group : {}".format(layer, g)) # model_weights with Group : <HDF5 (22 members)>),
print(" Attributes:")
for key, value in g.attrs.items():
# 输出储存在Group类中的attrs信息,一般是各层的weights和bias及他们的名称
# eg ;weight_names: [b'attention_2/q_kernel:0' b'attention_2/k_kernel:0' b'attention_2/w_kernel:0']
print(" {}: {}".format(key, value))
#
print(" Dataset:") # np.array(f.get(key)).shape()
for name, d in g.items(): # 读取各层储存具体信息的Dataset类
print('name: ', name, d)
if str(f.filename).endswith('.weights'):
for k, v in d.items():
# 输出储存在Dataset中的层名称和权重,也可以打印dataset的attrs
# k , v embeddings:0 <HDF5 dataset "embeddings:0": shape (21, 128), type "<f4">
print(' {} with shape : {} or {} '.format(k, np.array(d.get(k)).shape, np.array(v).shape))
print(" {} have weights : {}".format(k, np.array(v))) # 各层的权重
if str(f.filename).endswith('.h5'):
for k, v in d.items(): # v 等价于 d.get(k)
print(k, v)
# Adam <HDF5 group "/optimizer_weights/training/Adam" (63 members)>
finally:
f.close()
print('当前工作路径:', os.getcwd())
model_weights = r'../ckpt/best_model.weights'
print_model_h5_wegiths(model_weights )
print('***'*10)
h5_weight = r'../ckpt/blstm_model.h5'
print_model_h5_wegiths(h5_weight)
更多推荐
已为社区贡献1条内容
所有评论(0)