Skip to content

Commit ca03377

Browse files
SurbhiJainUSCThe tunix Authors
authored andcommitted
Disable setting specific profiler options on Pathways backend.
PiperOrigin-RevId: 814737710
1 parent 2eb4184 commit ca03377

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

tunix/sft/profiler.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class ProfilerOptions:
2828
skip_first_n_steps: int
2929
# Number of steps to profile.
3030
profiler_steps: int
31+
# Whether to set the profile options.
32+
set_profile_options: bool = True
3133
# https://github.com/jax-ml/jax/blob/0b1b909dd66a113ee0d7e54e55d0efef480e2a8a/docs/profiling.md?plain=1#L285
3234
host_tracer_level: int = 2 # set to 2 to capture HBM profiles.
3335
# https://github.com/jax-ml/jax/blob/0b1b909dd66a113ee0d7e54e55d0efef480e2a8a/docs/profiling.md?plain=1#L300
@@ -69,14 +71,19 @@ def maybe_activate(self, step: int):
6971
if self._do_not_profile or step != self._first_profile_step:
7072
return
7173
logging.info("Starting JAX profiler at step %d.", step)
72-
profile_options = jax.profiler.ProfileOptions()
73-
profile_options.host_tracer_level = self._profiler_options.host_tracer_level
74-
profile_options.python_tracer_level = (
75-
self._profiler_options.python_tracer_level
76-
)
77-
jax.profiler.start_trace(
78-
log_dir=self._output_path, profiler_options=profile_options
79-
)
74+
if self._profiler_options.set_profile_options:
75+
profile_options = jax.profiler.ProfileOptions()
76+
profile_options.host_tracer_level = (
77+
self._profiler_options.host_tracer_level
78+
)
79+
profile_options.python_tracer_level = (
80+
self._profiler_options.python_tracer_level
81+
)
82+
jax.profiler.start_trace(
83+
log_dir=self._output_path, profiler_options=profile_options
84+
)
85+
else:
86+
jax.profiler.start_trace(log_dir=self._output_path)
8087

8188
def maybe_deactivate(self, step: int):
8289
"""End the profiler."""

0 commit comments

Comments
 (0)