# Demo of dropping into pdb from inside a tf.function.
# NOTE(review): unless tf.function bodies are forced to run eagerly, the
# Python body (and therefore pdb) runs only while the function is being
# traced, not on every call — confirm against the tf.function debugging guide.
@tf.function
def f(x):
  """Return x + 1 when x > 0, otherwise x, pausing in pdb on the positive path.

  Args:
    x: numeric tensor; presumably a scalar, since its truth value is used in
      an `if` (the visible caller passes tf.constant(1)) — TODO confirm.

  Returns:
    x incremented by one if x > 0, else x unchanged.
  """
  if x > 0:
    import pdb
    # Pause in the interactive debugger before the increment.
    pdb.set_trace()
    x = x + 1
  return x

# Force tf.function bodies to run as plain Python so the pdb breakpoint in
# `f` sees concrete tensor values. `tf.config.run_functions_eagerly` replaced
# the deprecated `experimental_run_functions_eagerly` in TF 2.3; fall back to
# the old spelling on releases that predate it.
if hasattr(tf.config, "run_functions_eagerly"):
    tf.config.run_functions_eagerly(True)
else:  # TF < 2.3
    tf.config.experimental_run_functions_eagerly(True)
f(tf.constant(1))
# Run the backbone on the model inputs.
base_model_out = base_model(inputs)

# Average over the spatial axes; GlobalAveragePooling2D expects a 4-D
# feature map from the backbone (assumed channels_last — TODO confirm).
# Note: despite its name, this holds the pooled output tensor, not a layer.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()(
    base_model_out
)

# 45-way softmax classification head over the pooled features.
# Note: despite its name, this holds the prediction tensor, not a layer.
prediction_layer = tf.keras.layers.Dense(
    units=45,
    activation='softmax',
)(global_average_layer)
# Element i of the range becomes a length-i vector filled with the value i
# (the fill value keeps the int64 dtype of Dataset.range).
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(
    lambda x: tf.fill(
        [tf.cast(x, tf.int32)],
        x,
    )
)
# Batch 4 variable-length vectors at a time, zero-padding each to the longest
# vector in its batch. The padding value must match the element dtype exactly:
# a bare Python 0 converts to an int32 tensor, which padded_batch rejects for
# this int64 dataset, so build it from the dataset's own element_spec.
dataset = dataset.padded_batch(
    batch_size=4,
    padded_shapes=(None,),
    padding_values=tf.constant(0, dtype=dataset.element_spec.dtype),
)