Python3 Tensorflow 1.7保存和还原模型(save or restore model)

TensorFlow中,tf.saved_model和tf.train.Saver都是用于保存和恢复模型的工具,但它们服务于不同的场景和需求。tf.train.Saver主要用于在训练过程中保存检查点(checkpoints),以便能够从中断的地方继续训练,或者进行模型的恢复和评估。它是TensorFlow 1.x中使用最广泛的保存和恢复机制。tf.saved_model是一个更全面的服务,它不仅保存变量,还保存了整个TensorFlow图的定义。这使得它非常适合于模型的导出和部署,特别是在与TensorFlow Serving、TensorFlow Lite或TensorFlow JS等技术栈集成时。

1、保存模型

Tensorflow 1.7中,保存训练模型,

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Sessionas sess:
        #省略其它逻辑代码
        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

path/to/your/location/ :保存模型的路径

2、恢复模型

Tensorflow 1.7中,还原训练模型,

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Sessionas sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
        'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

path/to/your/location/ :要还原模型的路径

文档地址:https://www.tensorflow.org/programmers_guide/saved_model

推荐阅读
cjavapy编程之路首页