【模型读写】TF读写protobuf模型方法 qwer2019-07-15 14:06:47 回复 6 查看
【模型读写】TF读写protobuf模型方法




在学习或者调试TF网络过程中,我们经常需要知道tf python api的代码最后生成出了怎样的模型。下面科普一下模型保存和加载的方法。


【模型保存】在写完python代码构建好网络后,模型我们可以保存成protobuf的二进制(pb)或者文本格式(prototxt)。


tf.train.write_graph(session.graph_def, 'path', 'filename', as_text=True)

as_text=True,保存为prototxt文本格式,肉眼可读可编辑,对调试网络,学习网络结构很有帮助。

as_text=False,保存为pb二进制格式,肉眼读不懂不可编辑。


【模型加载】加载pb和pbtxt的代码也会略有不同。

【加载prototxt】

from google.protobuf import text_format fid = tf.gfile.GFile('net.prototxt', 'r') pbtxt = fid.read() sess = tf.Session() graph_def = tf.GraphDef() graph_def = text_format.Merge(pbtxt, graph_def) tf.import_graph_def(graph_def)


【加载pb】


f= open('net.pb', 'r') graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def)


【完】



官方微博 官方微信
版权所有 © 2019 寒武纪 Cambricon 备案/许可证号:京ICP备17003415
关闭