# checkpoint转pb: convert a TF1 checkpoint into a frozen .pb graph.
import tensorflow as tf
# Paths to the checkpoint prefix, its .meta graph file, and the frozen output.
checkpoint_path = '/Users/mac/Desktop/model/checkpoint-25'
meta_path = '/Users/mac/Desktop/model/checkpoint-25.meta'
output_pb_path = '/Users/mac/Desktop/model/frozen_model.pb'

# Nodes to keep in the frozen graph. convert_variables_to_constants keeps only
# the sub-graph reachable from these names, so they must include the tensors
# the prediction step fetches (value_out, policy_softmax). The placeholders the
# prediction step feeds (input, legals_mask, training) are listed as well so
# they survive pruning even if an output happens not to depend on one of them.
# The previous value ('input,value_targets') named placeholders only, which
# dropped the actual outputs from the frozen graph.
output_node_names = 'input,legals_mask,training,value_out,policy_softmax'

# Build everything in a dedicated graph.
graph = tf.Graph()
with graph.as_default():
    # tf.compat.v1 exposes the TF1 graph/session API.
    with tf.compat.v1.Session(graph=graph) as sess:
        # Import the graph structure from the .meta file; clear_devices drops
        # device placements recorded when the checkpoint was written.
        saver = tf.compat.v1.train.import_meta_graph(meta_path, clear_devices=True)
        # Load the variable values from the checkpoint into this session.
        saver.restore(sess, checkpoint_path)
        # Replace variables with constants holding their restored values and
        # prune the graph down to what output_node_names needs.
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names.split(','),
        )

# Serialize the frozen GraphDef to a .pb file.
with tf.io.gfile.GFile(output_pb_path, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
# pb预测: load the frozen .pb graph and run a prediction with it.
import tensorflow.compat.v1 as tf
# 定义一个用于加载模型的函数
def load_model(model_path):
with graph.as_default():
# 创建一个新的tf.Session
sess = tf.compat.v1.Session(graph=graph)
with tf.io.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
# 解析GraphDef
graph_def.ParseFromString(f.read())
# 导入GraphDef到图
tf.import_graph_def(graph_def, name='')
return sess
import pyspiel
import numpy as np

# Session over the frozen graph produced by the conversion step.
model_session = load_model('/home/keras/open_spiel/model/aaa.pb')

# Build a batch containing only the initial tic-tac-toe position;
# expand_dims(..., 0) adds a leading batch axis of size 1.
game = pyspiel.load_game("tic_tac_toe")
state = game.new_initial_state()
obs = np.expand_dims(state.observation_tensor(), 0)
mask = np.expand_dims(state.legal_actions_mask(), 0)


def get_var(name):
    """Return output 0 of the node called *name* in the loaded graph."""
    return model_session.graph.get_tensor_by_name(name + ":0")


# Named `input_tensor` rather than `input` to avoid shadowing the builtin.
input_tensor = get_var("input")
legals_mask = get_var("legals_mask")
training = get_var("training")
value_out = get_var("value_out")
policy_softmax = get_var("policy_softmax")

# Inference pass. training=False presumably switches off training-only
# behaviour in the model (e.g. batch-norm statistics) -- confirm against the
# model definition.
value, policy = model_session.run(
    [value_out, policy_softmax],
    feed_dict={
        input_tensor: np.array(obs, dtype=np.float32),
        legals_mask: np.array(mask, dtype=bool),
        training: False,
    },
)