模型freeze
网上有很多关于模型保存checkpoint的文章,这里不再详细介绍,可参考文章

meta:是压缩后的protobuf格式文件,用来保存图结构
data:是保存训练后权重
index:是不可变的k-v表,k为序列化的tensor名,v为其在data文件的地址
checkpoint:用来获取不同时间点的保存的结果
正常来说,有以下几个步骤:
- 加载图结构
- 加载保存的权重, 启动一个session 然后在该session中恢复权重
- 去掉所有和inference无关的metadata
其中1,2都比较常见,重点在于第三步,也就是图的'freezing'过程(frozen_graph),
关键想法在于,当一个模型被训练好后,训练中所使用的大量operation在inference时候都是不需要的,如一般定义的train_op,所以我们可以利用TF提供的 freeze_graph 函数来选取我们需要保留的operation,如推测的结果,TF会自动保留该operation需要用到所有信息
另外,对于所有的权重(variable), TF会保存成constant的形式,这样最终序列化的结构大小会比原来大大缩小
import os, argparse
import tensorflow as tf
# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph
dir = os.path.dirname(os.path.realpath(__file__))
def freeze_graph(sess, graph, save_path, output_node_names, export=True, collections=[]):
"""
sess是训练好结果的session,如果已经保存在本地,可以加载进来
graph为图结构
output_node_names 是所有需要保存的节点(operation),用comma区分
这里clear devices 防止设备信息被写入,不然如果换到不同环境的上执行
可能会出错,其实这步也可以在load的时候指定
"""
clear_devices = True
with sess as sess:
""" 使用内置的函数来将 variables 变成 constants"""
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
graph.as_graph_def(),
output_node_names.split(",")
)
""" 这里可以选择 export meta graph 或者将graph序列化写出"""
if export:
""" export 会同时将 graph, saver, colletions 写到 'MetaGraphDef'
的pb结构中
"""
tf.train.export_meta_graph(filename=save_path, graph_def=output_graph_def,
collection_list=collections,
clear_devices=clear_devices)
else
""" 直接将GraphDef序列化出去,不过这种做法相比于上面还是有点区别
的,上面的做法,可以同时将collection信息写出去, 这里如果也想要写
collection 则需要额外写, 因为这里的output_graph_def 只是一个GraphDef,
而上面的export结果是MetaGraphDef,区别参考
https://www.tensorflow.org/api_guides/python/meta_graph
"""
with tf.gfile.GFile(save_path, "wb") as f:
f.write(output_graph_def.SerializeToString())
return output_graph_def
模型加载
这步比较简单, 根据 保存的方式对应加载
""" export_meta_graph """
def load_graph(load_path):
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
"""此时整个图的variable已经变成了constant,所以一个新的session就可以run"""
tf.train.import_meta_graph(load_path)
""" graph序列化的方式 """
def load_graph(load_path):
with tf.gfile.GFile(load_path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
"""注意这里的 prefix,会在每个node前面加上前缀,
这里是创建新图所以无所谓,如果加载到已经构建的图中,
可以区分是哪个图的节点
"""
tf.import_graph_def(graph_def, name="prefix")
网友评论