This article shows how to save and restore a trained model with TensorFlow in Python 3.

1. Saving a trained model in TensorFlow 1.7

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        # Other model-building and training code omitted
        # Save the trained model
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

path/to/your/location/: the directory where the model is saved
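
One practical detail about the export path: in TensorFlow 1.x, the SavedModel builder that simple_save relies on refuses to write into an export directory that already exists, so each export needs a fresh path. The sketch below shows one way to derive such a path with the standard library only. The helper name fresh_export_dir and the timestamp-based naming scheme are assumptions made for illustration; they are not part of TensorFlow.

```python
import os
import time


def fresh_export_dir(base_dir, now=None):
    """Return a path under base_dir that does not exist yet.

    Hypothetical helper: the subdirectory name is an integer Unix
    timestamp, incremented until no file or directory has that name.
    The directory itself is NOT created here; the export call is
    expected to create it.
    """
    version = int(time.time() if now is None else now)
    candidate = os.path.join(base_dir, str(version))
    while os.path.exists(candidate):
        version += 1
        candidate = os.path.join(base_dir, str(version))
    return candidate


if __name__ == "__main__":
    export_dir = fresh_export_dir("path/to/your/location")
    print(export_dir)
    # The result would then be passed to the save call, e.g.
    # tf.saved_model.simple_save(sess, export_dir, inputs, outputs)
```

A numbered subdirectory per export also keeps older exports around, which can be handy when a newly trained model turns out to be worse than the previous one.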

2. Restoring a trained model in TensorFlow 1.7

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        # Load the graph and variables that were saved under the SERVING tag
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        # Look up the input and output tensors by name in the restored graph
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
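
The strings passed to get_tensor_by_name in the restore code follow TensorFlow's tensor naming convention, `<op_name>:<output_index>`. For example, 'dense/BiasAdd:0' refers to output 0 of the operation named dense/BiasAdd. The plain-Python sketch below only illustrates that convention; split_tensor_name is a hypothetical helper, not a TensorFlow function.

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, output_index)."""
    # rpartition splits on the LAST colon, so the index is always the suffix.
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(
            "expected '<op_name>:<output_index>', got %r" % (tensor_name,)
        )
    return op_name, int(index)


if __name__ == "__main__":
    for name in ["features_placeholder:0", "dense/BiasAdd:0"]:
        op_name, index = split_tensor_name(name)
        print("%s -> operation %r, output %d" % (name, op_name, index))
```

The name of the prediction tensor depends on how the model was built, so 'dense/BiasAdd:0' has to match the operation that produced model_output when the model was saved.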

path/to/your/location/: the directory of the model to restore
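
Before loading, it can help to confirm that the path really points at a SavedModel export. Such a directory contains a saved_model.pb (or saved_model.pbtxt) file and, for a trained model, a variables/ subdirectory holding the saved weights. The check below is a minimal sketch using only the standard library; looks_like_saved_model is a hypothetical helper name, and it only inspects the directory layout, not the file contents.

```python
import os


def looks_like_saved_model(export_dir):
    """Return True if export_dir has the basic SavedModel layout.

    Checks for a saved_model.pb or saved_model.pbtxt file plus a
    variables/ subdirectory. It does not validate the files themselves.
    """
    has_graph = any(
        os.path.isfile(os.path.join(export_dir, filename))
        for filename in ("saved_model.pb", "saved_model.pbtxt")
    )
    has_variables = os.path.isdir(os.path.join(export_dir, "variables"))
    return has_graph and has_variables


if __name__ == "__main__":
    path = "path/to/your/location/"
    if not looks_like_saved_model(path):
        print("No SavedModel export found at %s" % path)
```

Failing early with a clear message is usually easier to debug than an error raised from inside the loader.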

Documentation: https://www.tensorflow.org/programmers_guide/saved_model