Visualize TF2 graph in TensorBoard
Do you ever wonder what is happening underneath a TF2 eager-mode model, or want to improve performance by moving as much of the workload as possible into the TF compute graph? This article shows how to visualize that graph in TensorBoard.
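The rough idea is to export the traced graph through the `tf.summary` library and then open it in TensorBoard. Use `tf.summary.trace_on` to start tracing and `tf.summary.trace_export` to export the trace. Only code running inside a `tf.function` is turned into a graph, so the code you want to inspect must be called through one.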
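As a minimal, self-contained sketch of that pattern, the following records one call to a `tf.function` and writes its graph to an event file. The function `add_one`, the tag `add_one_trace` and the temporary log directory are illustrative choices, not part of the original code.

```python
import os
import tempfile

import tensorflow as tf

# Throwaway log directory for this sketch; the article uses logs/func/<timestamp>.
logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)


@tf.function
def add_one(t):
    # Illustrative function; any tf.function works here.
    return t + 1


# Turn on graph collection (profiler off) before calling the function.
tf.summary.trace_on(graph=True, profiler=False)
add_one(tf.constant(1.0))

# Export the collected graph as a summary under the tag "add_one_trace".
with writer.as_default():
    tf.summary.trace_export(name="add_one_trace", step=0)
writer.flush()

# The graph summary lives in a TensorBoard event file inside logdir.
event_files = [f for f in os.listdir(logdir) if f.startswith("events.out.tfevents")]
print(event_files)
```

Pointing TensorBoard at that directory should then show the graph under the `add_one_trace` tag.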
For example, say I want to understand what the graph looks like when a `tf.data` dataset is used. To explore your own code instead, replace the part between `# Code: Begin` and `# Code: End`.
# Visualization
%load_ext tensorboard
import tensorflow as tf
import tensorboard
from datetime import datetime

# Set up logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/func/%s' % stamp  # <- Name of this `run`
writer = tf.summary.create_file_writer(logdir)

# Initialization
x = tf.random.uniform([1])

# Start tracing and store it in `tf.summary`
tf.summary.trace_on(graph=True, profiler=False)

# Code: Begin
############## Call `tf.function` when tracing.
@tf.function
def expand(x):
    # Turn each integer n from the dataset into a 2 x n tensor of ones.
    return tf.ones([2, x])

# Named sum_dataset (not sum) to avoid shadowing Python's built-in sum.
@tf.function
def sum_dataset(x):
    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    ds = ds.map(expand)
    # AutoGraph converts this Python loop over the dataset into graph ops.
    for item in ds.unbatch():
        x += tf.math.reduce_sum(item)
    return x

z = sum_dataset(x)
# Code: End
###########

with writer.as_default():
    tf.summary.trace_export(
        name="my_func_trace",  # <- Name of tag
        step=0,
        profiler_outdir=logdir)
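To see what happens to the Python `for` loop over the dataset before it reaches the graph, you can print the code AutoGraph generates for the function. This sketch redefines a function equivalent to the one above (with `expand` inlined as a lambda for brevity); `tf.autograph.to_code` reads the function's source, so run it from a script or notebook cell.

```python
import tensorflow as tf


@tf.function
def sum_dataset(x):
    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    ds = ds.map(lambda n: tf.ones([2, n]))
    for item in ds.unbatch():
        x += tf.math.reduce_sum(item)
    return x


# to_code takes a plain Python function, so pass the undecorated one.
generated = tf.autograph.to_code(sum_dataset.python_function)
print(generated)
```

In the printed code the loop should appear as a call to `ag__.for_stmt`, which is how AutoGraph keeps the iteration over the dataset inside the graph instead of unrolling it in Python.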
Then visualize it in TensorBoard:
%tensorboard --logdir logs/func
If you don't see the graph, or the graph does not change after you rerun the code with new data, make sure you have selected the correct run and tag in the left panel.
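If TensorBoard still shows no graph, you can check outside TensorBoard that a function really traces to a graph by listing the op types in its concrete function. This sketch uses a hypothetical `total` function that builds a similar pipeline but sums with `Dataset.reduce` instead of a `for` loop.

```python
import tensorflow as tf


@tf.function
def total(x):
    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    ds = ds.map(lambda n: tf.ones([2, n]))
    # reduce keeps the whole iteration inside the graph as a single dataset op.
    return ds.unbatch().reduce(x, lambda acc, item: acc + tf.math.reduce_sum(item))


# Trace the function for a float32 input of shape [1] without executing it.
concrete = total.get_concrete_function(tf.TensorSpec(shape=[1], dtype=tf.float32))
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print(op_types)
```

The outer graph should contain dataset ops such as `TensorSliceDataset` and `MapDataset`; the body of the mapped lambda is stored as a separate library function, so its ops do not show up in this list.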