import tensorflow as tf

# Paths to the checkpoint prefix and its accompanying .meta (graph structure) file.
checkpoint_path = '/Users/mac/Desktop/model/checkpoint-25'
meta_path = '/Users/mac/Desktop/model/checkpoint-25.meta'

# Build everything in a fresh graph so nothing from the default graph leaks in.
graph = tf.Graph()
with graph.as_default():
    # Use the tf.compat.v1 graph/session API; bind the session to `graph`
    # explicitly rather than relying on the enclosing as_default() context.
    with tf.compat.v1.Session(graph=graph) as sess:
        # Rebuild the graph structure stored in the .meta file; clear_devices
        # drops any device placements recorded when the model was saved.
        saver = tf.compat.v1.train.import_meta_graph(meta_path, clear_devices=True)

        # Load the trained variable values from the checkpoint into the session.
        saver.restore(sess, checkpoint_path)

        # Nodes to keep in the frozen graph. convert_variables_to_constants
        # keeps only these nodes plus their ancestors, so they must be the
        # OUTPUT tensors read at inference time (value_out / policy_softmax,
        # as fetched by the prediction code). Listing placeholders such as
        # 'input' or 'value_targets' would keep only those placeholders and
        # prune away the network itself.
        # NOTE(review): assumes the input placeholders fed at inference
        # (input, legals_mask, training) are ancestors of these outputs, so
        # they survive the pruning -- verify against the model definition.
        output_node_names = 'value_out,policy_softmax'

        # Replace variables with constants holding their current values and
        # prune the graph down to what the output nodes need.
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names.split(','),
        )

        # Serialize the frozen GraphDef to a .pb file.
        with tf.io.gfile.GFile('/Users/mac/Desktop/model/frozen_model.pb', 'wb') as f:
            f.write(output_graph_def.SerializeToString())

# pb预测：加载冻结后的 .pb 模型并进行推理 (prediction with the frozen .pb model)

import tensorflow.compat.v1 as tf
# 定义一个用于加载模型的函数
def load_model(model_path):
    """Load a serialized GraphDef (.pb) into a new graph and open a session on it.

    Args:
        model_path: Path of a binary-serialized ``GraphDef`` file, such as a
            frozen model produced by ``convert_variables_to_constants``.

    Returns:
        A ``tf.compat.v1.Session`` bound to a fresh ``tf.Graph`` that holds the
        imported nodes under their original, unprefixed names (``name=''``).
        The caller owns the session and should ``close()`` it when done.
    """
    # Use a private graph instead of a module-level ``graph`` global. Such a
    # global may not exist when this code runs on its own. If it does exist
    # and already contains same-named nodes, the imported nodes get renamed,
    # and lookups such as "value_out:0" could resolve to the wrong tensors.
    model_graph = tf.Graph()

    # Read and parse the file before creating any session, so a read/parse
    # failure does not leave an unclosed session behind.
    with tf.io.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    with model_graph.as_default():
        # name='' adds no prefix, so tensors keep names like "value_out:0".
        tf.import_graph_def(graph_def, name='')

    return tf.compat.v1.Session(graph=model_graph)
model_session = load_model('/home/keras/open_spiel/model/aaa.pb')
import pyspiel
import numpy as np
# Start a tic-tac-toe game from its initial position.
game = pyspiel.load_game("tic_tac_toe")
state = game.new_initial_state()
# Prepend a batch axis of size 1 to the observation and the legal-action mask.
obs = np.asarray(state.observation_tensor())[np.newaxis]
mask = np.asarray(state.legal_actions_mask())[np.newaxis]
def get_var(name):
    """Return the tensor named ``name + ":0"`` from the loaded model's graph.

    ``":0"`` selects the first output of the operation called ``name``.
    Reads the module-level ``model_session``.
    """
    model_graph = model_session.graph
    tensor_name = name + ":0"
    return model_graph.get_tensor_by_name(tensor_name)
# Look up the graph's input placeholders and the output tensors by name.
# `input_ph` is used instead of `input` so the builtin input() is not
# shadowed at module level.
input_ph = get_var("input")
legals_mask = get_var("legals_mask")
training = get_var("training")
value_out = get_var("value_out")
policy_softmax = get_var("policy_softmax")

# One forward pass on the batched observation/mask, with the training flag
# fed as False; fetches the value and policy outputs.
value, policy = model_session.run(
    [value_out, policy_softmax],
    feed_dict={
        input_ph: np.asarray(obs, dtype=np.float32),
        legals_mask: np.asarray(mask, dtype=bool),
        training: False,
    },
)

# 标签: none
# (添加新评论 — blog comment-form residue, not part of the script)