Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions keras/src/callbacks/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,22 @@ def __init__(
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
if profile_batch:
if backend.backend() not in ("jax", "tensorflow"):
bkend = backend.backend()
if bkend not in ("jax", "tensorflow"):
# TODO: profiling not available in torch, numpy
# TODO: profiling not available in torch, numpy
raise ValueError(
"Profiling is not yet available with the "
f"{backend.backend()} backend. Please open a PR "
f"{bkend} backend. Please open a PR "
"if you'd like to add this feature. Received: "
f"profile_batch={profile_batch} (must be 0)"
)
elif backend.backend() == "jax":
elif bkend == "jax":
# Inline sys.version_info[1] < 12 check instead of recomputing bkend
if sys.version_info[1] < 12:
warnings.warn(
"Profiling with the "
f"{backend.backend()} backend requires python >= 3.12."
f"{bkend} backend requires python >= 3.12."
)
profile_batch = 0

Expand Down Expand Up @@ -334,8 +337,10 @@ def _pop_writer(self):
# See _push_writer for the content of the previous_context, which is
# pair of context.
previous_context = self._prev_summary_state.pop()
previous_context[1].__exit__(*sys.exc_info())
previous_context[0].__exit__(*sys.exc_info())
# Cache sys.exc_info result to reduce function calls
exc_info = sys.exc_info()
previous_context[1].__exit__(*exc_info)
previous_context[0].__exit__(*exc_info)

def _close_writers(self):
for writer in self._writers.values():
Expand Down