打开微信,使用扫一扫进入页面后,点击右上角菜单,
点击“发送给朋友”或“分享到朋友圈”完成分享
在学习或者调试TF网络过程中,我们经常需要知道tf python api的代码最后生成出了怎样的模型。下面科普一下模型保存和加载的方法。
【模型保存】在写完python代码构建好网络后,模型我们可以保存成protobuf的二进制(pb)或者文本格式(prototxt)。
tf.train.write_graph(session.graph_def, 'path', 'filename', as_text=True)
as_text=True,保存为prototxt文本格式,肉眼可读可编辑,对调试网络,学习网络结构很有帮助。
as_text=False,保存为pb二进制格式,肉眼读不懂不可编辑。
【模型加载】加载pb和pbtxt的代码也会略有不同。
【加载prototxt】
from google.protobuf import text_format fid = tf.gfile.GFile('net.prototxt', 'r') pbtxt = fid.read() sess = tf.Session() graph_def = tf.GraphDef() graph_def = text_format.Merge(pbtxt, graph_def) tf.import_graph_def(graph_def)
【加载pb】
f= open('net.pb', 'r') graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def)
【完】
热门帖子
精华帖子