diff --git a/keras/src/callbacks/tensorboard.py b/keras/src/callbacks/tensorboard.py index 506c8d6dafb4..356c9d3b21eb 100644 --- a/keras/src/callbacks/tensorboard.py +++ b/keras/src/callbacks/tensorboard.py @@ -176,19 +176,21 @@ def __init__( self.embeddings_freq = embeddings_freq self.embeddings_metadata = embeddings_metadata if profile_batch: - if backend.backend() not in ("jax", "tensorflow"): + backend_val = backend.backend() + if backend_val not in ("jax", "tensorflow"): + # TODO: profiling not available in torch, numpy # TODO: profiling not available in torch, numpy raise ValueError( "Profiling is not yet available with the " - f"{backend.backend()} backend. Please open a PR " + f"{backend_val} backend. Please open a PR " "if you'd like to add this feature. Received: " f"profile_batch={profile_batch} (must be 0)" ) - elif backend.backend() == "jax": + elif backend_val == "jax": if sys.version_info[1] < 12: warnings.warn( "Profiling with the " - f"{backend.backend()} backend requires python >= 3.12." + f"{backend_val} backend requires python >= 3.12." ) profile_batch = 0 @@ -431,11 +433,12 @@ def on_test_end(self, logs=None): self._pop_writer() def on_train_batch_begin(self, batch, logs=None): + should_trace = self._should_trace + if not should_trace: + return self._global_train_batch += 1 if self.write_steps_per_second: self._batch_start_time = time.time() - if not self._should_trace: - return if self._global_train_batch == self._start_batch: self._start_trace()