Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions keras/src/callbacks/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,21 @@ def __init__(
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
if profile_batch:
if backend.backend() not in ("jax", "tensorflow"):
backend_val = backend.backend()
if backend_val not in ("jax", "tensorflow"):
# TODO: profiling not available in torch, numpy
# TODO: profiling not available in torch, numpy
raise ValueError(
"Profiling is not yet available with the "
f"{backend.backend()} backend. Please open a PR "
f"{backend_val} backend. Please open a PR "
"if you'd like to add this feature. Received: "
f"profile_batch={profile_batch} (must be 0)"
)
elif backend.backend() == "jax":
elif backend_val == "jax":
if sys.version_info[1] < 12:
warnings.warn(
"Profiling with the "
f"{backend.backend()} backend requires python >= 3.12."
f"{backend_val} backend requires python >= 3.12."
)
profile_batch = 0

Expand Down Expand Up @@ -431,11 +433,12 @@ def on_test_end(self, logs=None):
self._pop_writer()

def on_train_batch_begin(self, batch, logs=None):
should_trace = self._should_trace
if not should_trace:
return
self._global_train_batch += 1
if self.write_steps_per_second:
self._batch_start_time = time.time()
if not self._should_trace:
return

if self._global_train_batch == self._start_batch:
self._start_trace()
Expand Down