美文网首页
两个tf模型合并

两个tf模型合并

作者: 求索_700e | 来源:发表于2019-06-12 21:33 被阅读0次

下面要实现的功能是：将g1和g2串联（g1的输出作为g2的输入），placeholder输入x为3.0，g1计算y=3*x，g2计算y+3，最后输出12（即3*3+3）。

文件model_b.py如下（需先运行它，在./models/下生成g2的checkpoint，供model_combined.py恢复）：

import tensorflow as tf

from tensorflow.python.framework import graph_util

from tensorflow.python.tools import saved_model_utils

MODEL_SAVE_PATH = "./models/" # 保存模型的路径

# Build g2 (y = x + 3) and save it as a checkpoint; model_combined.py later
# restores it from MODEL_SAVE_PATH + "g2_model.ckpt".
with tf.Graph().as_default() as g2:
    input1 = tf.placeholder(tf.float32, name='g2_input')
    data = tf.Variable(3.)
    added = tf.add(input1, data)
    tf.identity(added, name='g2_output')

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session(graph=g2) as sess:
        sess.run(init)

        # Frozen copy of g2 with the variable baked in as a constant. It is
        # only needed for the optional .pb export on the next line.
        g2def = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["g2_output"],
            variable_names_whitelist=None, variable_names_blacklist=None)
        # Optional: export the frozen graph as a binary .pb file.
        # tf.train.write_graph(g2def, MODEL_SAVE_PATH, 'model_g2.pb', as_text=False)

        # Saver.save refuses to write into a non-existent directory, so create
        # it first (MakeDirs succeeds if the directory already exists).
        tf.gfile.MakeDirs(MODEL_SAVE_PATH)
        saver.save(sess, MODEL_SAVE_PATH + "g2_model.ckpt")

文件model_combined.py如下:

import tensorflow as tf

from tensorflow.python.framework import graph_util

from tensorflow.python.tools import saved_model_utils

MODEL_SAVE_PATH = "./models/" # 保存模型的路径

# g1 and g2 are chained in series: the input x = 3.0 goes through g1 (y = 3 * x),
# g1's output feeds g2 (y + 3), and the combined graph prints 12.0.

# Stage 1: build g1 (y = x * 3) and freeze its variable into a constant GraphDef.
with tf.Graph().as_default() as g1:
    input1 = tf.placeholder(tf.float32, name='g1_input')
    data = tf.Variable(3.)
    mul = tf.multiply(input1, data)
    tf.identity(mul, name='g1_output')
    init = tf.global_variables_initializer()

    with tf.Session(graph=g1) as sess:
        sess.run(init)
        g1def = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["g1_output"],
            variable_names_whitelist=None,
            variable_names_blacklist=None)

# Stage 2: restore g2 from the checkpoint written by model_b.py and freeze it.
with tf.Graph().as_default() as g2:
    with tf.Session(graph=g2) as sess:
        g2_ckpt = MODEL_SAVE_PATH + "g2_model.ckpt"
        saver = tf.train.import_meta_graph(g2_ckpt + ".meta")
        saver.restore(sess, g2_ckpt)
        g2def = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["g2_output"])

##------------------------------------------------------------

# Stage 3: splice both frozen graphs into a new graph, wiring g1's output
# into g2's input, then run it.
with tf.Graph().as_default() as g_combined:
    x = tf.placeholder(tf.float32, name="my_input")
    # import_graph_def returns a list of the requested elements; unpack the
    # single tensor so the input_map below receives a Tensor, not a list.
    y, = tf.import_graph_def(g1def, input_map={"g1_input:0": x},
                             return_elements=["g1_output:0"])
    z, = tf.import_graph_def(g2def, input_map={"g2_input:0": y},
                             return_elements=["g2_output:0"])
    tf.identity(z, "my_output")

    with tf.Session(graph=g_combined) as sess:
        print(sess.run(z, feed_dict={'my_input:0': 3.}))

        # Export option 1: freeze the combined graph into a binary .pb file.
        # g_combineddef = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["my_output"])
        # tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)

        # Export option 2: write a SavedModel with named input/output signatures.
        # tf.saved_model.simple_save(sess,
        #                            "./modelbase",
        #                            inputs={"my_input": x},
        #                            outputs={"my_output": z})

相关文章

网友评论

      本文标题:两个tf模型合并

      本文链接:https://www.haomeiwen.com/subject/gdtrfctx.html