美文网首页
Tensorflow 模型freeze并用于生产环境

Tensorflow 模型freeze并用于生产环境

作者: KK侠 | 来源:发表于2019-01-03 20:33 被阅读0次

模型freeze

网上有很多关于模型保存checkpoint的文章,这里不再详细介绍,可参考文章

image.png
meta:是压缩后的protobuf格式文件,用来保存图结构
data:是保存训练后权重
index:是不可变的k-v表,k为序列化的tensor名,v为其在data文件的地址
checkpoint:用来获取不同时间点的保存的结果

正常来说,有以下几个步骤:

  1. 加载图结构
  2. 加载保存的权重, 启动一个session 然后在该session中恢复权重
  3. 去掉所有和inference无关的metadata

其中1,2都比较常见,重点在于第三步,也就是图的'freezing'过程(frozen_graph),
关键想法在于,当一个模型被训练好后,训练中所使用的大量operation在inference时候都是不需要的,如一般定义的train_op,所以我们可以利用TF提供的 freeze_graph 函数来选取我们需要保留的operation,如推测的结果,TF会自动保留该operation需要用到所有信息

另外,对于所有的权重(variable), TF会保存成constant的形式,这样最终序列化的结构大小会比原来大大缩小

import os, argparse

import tensorflow as tf

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(sess, graph, save_path, output_node_names, export=True, collections=[]):
    """
         sess是训练好结果的session,如果已经保存在本地,可以加载进来
         graph为图结构
         output_node_names 是所有需要保存的节点(operation),用comma区分
   
         这里clear devices 防止设备信息被写入,不然如果换到不同环境的上执行 
         可能会出错,其实这步也可以在load的时候指定
     """
    clear_devices = True

    with sess as sess:
       """ 使用内置的函数来将 variables 变成 constants"""
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, 
            graph.as_graph_def(), 
            output_node_names.split(",") 
        ) 
        """ 这里可以选择 export meta graph 或者将graph序列化写出"""
        if export:
            """ export 会同时将 graph, saver, colletions 写到 'MetaGraphDef' 
            的pb结构中
            """
            tf.train.export_meta_graph(filename=save_path, graph_def=output_graph_def,
                                       collection_list=collections,
                                       clear_devices=clear_devices)
        else 
             """ 直接将GraphDef序列化出去,不过这种做法相比于上面还是有点区别 
             的,上面的做法,可以同时将collection信息写出去, 这里如果也想要写        
             collection 则需要额外写, 因为这里的output_graph_def 只是一个GraphDef, 
             而上面的export结果是MetaGraphDef,区别参考
             https://www.tensorflow.org/api_guides/python/meta_graph
             """
             with tf.gfile.GFile(save_path, "wb") as f:
                f.write(output_graph_def.SerializeToString())
    return output_graph_def

模型加载

这步比较简单, 根据 保存的方式对应加载

  """ export_meta_graph """
 def load_graph(load_path):
    session = tf.Session(graph=tf.Graph())
    with session.graph.as_default():
       """此时整个图的variable已经变成了constant,所以一个新的session就可以run"""
       tf.train.import_meta_graph(load_path)

 """ graph序列化的方式 """
 def load_graph(load_path):
    with tf.gfile.GFile(load_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        """注意这里的 prefix,会在每个node前面加上前缀,
        这里是创建新图所以无所谓,如果加载到已经构建的图中,
        可以区分是哪个图的节点
        """
        tf.import_graph_def(graph_def, name="prefix")

相关文章

网友评论

      本文标题:Tensorflow 模型freeze并用于生产环境

      本文链接:https://www.haomeiwen.com/subject/esctrqtx.html