How to export a frozen graph in TF2
As TF2 focuses on eager execution, the concept of a session is gone and everything that runs becomes a concrete function. But the traditional frozen graph can still find its use cases, especially for development and debugging.
Let’s say that you have a very simple model built on top of Keras.
import tensorflow as tf

keras = tf.keras

class MyCustomLayer(keras.layers.Layer):
    def __init__(self):
        super(MyCustomLayer, self).__init__()
        self._weight = tf.Variable(initial_value=(2., 3.))

    def call(self, input):
        output = tf.sigmoid(input) * self._weight
        return output

model = keras.models.Sequential(
    [keras.layers.Input((1, 2)), MyCustomLayer()])
In TF1.x you could get the session, and from it the graph, via
tf.keras.backend.get_session()
In TF2, you can’t get a global session, but you have access to a concrete function.
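A concrete function is just a traced graph with a fixed input signature; a minimal standalone illustration (not the model above):

```python
import tensorflow as tf

@tf.function
def square(x):
    return x * x

# Tracing for a specific input signature yields a ConcreteFunction,
# which owns a graph much like a TF1 session graph did.
cf = square.get_concrete_function(tf.TensorSpec([], tf.float32))
print(type(cf.graph))        # a FuncGraph
print(cf(tf.constant(3.0)))  # tf.Tensor(9.0, ...)
```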
model = tf.keras.Sequential(...)
func = tf.function(model).get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
With a concrete function, you can already get the graph def via
func.graph.as_graph_def()
However, it’s full of Variable and ReadVariableOp nodes. We need to convert them to constants, because model inference doesn’t need to update the weights.
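You can see those nodes for yourself by inspecting the op types in the traced graph — a quick standalone check with a single-variable function:

```python
import tensorflow as tf

# A function that reads a tf.Variable traces to a graph
# containing a ReadVariableOp node for the variable read.
v = tf.Variable(2.0)

@tf.function
def scale(x):
    return x * v

func = scale.get_concrete_function(tf.TensorSpec([], tf.float32))
op_types = {node.op for node in func.graph.as_graph_def().node}
print("ReadVariableOp" in op_types)  # → True
```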
This can be done via a magic function called convert_variables_to_constants_v2_as_graph in TF2.2. It used to be called convert_variables_to_constants_v2 in TF2.1, IIRC.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph)

frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func)
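Continuing the single-variable sketch from above, freezing folds the variable reads into Const nodes, and the frozen function still runs against the baked-in weights:

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph)

v = tf.Variable(3.0)

@tf.function
def scale(x):
    return x * v

func = scale.get_concrete_function(tf.TensorSpec([], tf.float32))
frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func)

# The ReadVariableOp nodes are gone: variable reads became constants.
op_types = {node.op for node in graph_def.node}
print("ReadVariableOp" in op_types)  # → False

# The frozen function still computes the same result.
print(frozen_func(tf.constant(2.0)))
```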
But that’s not enough. We need to do more graph optimization to get a clean inference graph. Let’s call into grappler.
from tensorflow.lite.python.util import (
    get_grappler_config, run_graph_optimizations)

input_tensors = [
    tensor for tensor in frozen_func.inputs
    if tensor.dtype != tf.resource]
output_tensors = frozen_func.outputs

graph_def = run_graph_optimizations(
    graph_def,
    input_tensors,
    output_tensors,
    config=get_grappler_config(["constfold", "function"]),
    graph=frozen_func.graph)
Nice job! You should now have a clean inference graph.
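To actually export the result, serialize the graph def to a binary .pb file with tf.io.write_graph. Shown here on a trivial standalone graph (the filename and directory are my own choices); in the pipeline above you would pass the grappler-optimized graph_def instead:

```python
import os
import tempfile
import tensorflow as tf

@tf.function
def double(x):
    return x * 2.0

func = double.get_concrete_function(tf.TensorSpec([], tf.float32))
graph_def = func.graph.as_graph_def()

# Write the GraphDef as a binary protobuf (.pb) file.
outdir = tempfile.mkdtemp()
path = tf.io.write_graph(graph_def, logdir=outdir,
                         name="frozen_graph.pb", as_text=False)
print(path)
```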
Here is the complete code:
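A reconstruction assembled from the steps above (the convert_to_constants and tensorflow.lite.python.util imports are private TF APIs and may move between versions; the output filename is my own choice):

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph)
from tensorflow.lite.python.util import (
    get_grappler_config, run_graph_optimizations)

keras = tf.keras

class MyCustomLayer(keras.layers.Layer):
    def __init__(self):
        super(MyCustomLayer, self).__init__()
        self._weight = tf.Variable(initial_value=(2., 3.))

    def call(self, input):
        return tf.sigmoid(input) * self._weight

model = keras.models.Sequential(
    [keras.layers.Input((1, 2)), MyCustomLayer()])

# Trace the model into a concrete function.
func = tf.function(model).get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Freeze: fold Variable / ReadVariableOp nodes into Const nodes.
frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func)

# Run grappler for constant folding and function inlining.
input_tensors = [
    tensor for tensor in frozen_func.inputs
    if tensor.dtype != tf.resource]
output_tensors = frozen_func.outputs
graph_def = run_graph_optimizations(
    graph_def,
    input_tensors,
    output_tensors,
    config=get_grappler_config(["constfold", "function"]),
    graph=frozen_func.graph)

# Serialize the frozen, optimized graph to disk.
path = tf.io.write_graph(graph_def, logdir=".",
                         name="frozen_graph.pb", as_text=False)
```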