@@ -28,6 +28,8 @@ class ProfilerOptions:
28
28
skip_first_n_steps : int
29
29
# Number of steps to profile.
30
30
profiler_steps : int
31
+ # Whether to set the profile options.
32
+ set_profile_options : bool = True
31
33
# https://github.com/jax-ml/jax/blob/0b1b909dd66a113ee0d7e54e55d0efef480e2a8a/docs/profiling.md?plain=1#L285
32
34
host_tracer_level : int = 2 # set to 2 to capture HBM profiles.
33
35
# https://github.com/jax-ml/jax/blob/0b1b909dd66a113ee0d7e54e55d0efef480e2a8a/docs/profiling.md?plain=1#L300
@@ -69,14 +71,19 @@ def maybe_activate(self, step: int):
69
71
if self ._do_not_profile or step != self ._first_profile_step :
70
72
return
71
73
logging .info ("Starting JAX profiler at step %d." , step )
72
- profile_options = jax .profiler .ProfileOptions ()
73
- profile_options .host_tracer_level = self ._profiler_options .host_tracer_level
74
- profile_options .python_tracer_level = (
75
- self ._profiler_options .python_tracer_level
76
- )
77
- jax .profiler .start_trace (
78
- log_dir = self ._output_path , profiler_options = profile_options
79
- )
74
+ if self ._profiler_options .set_profile_options :
75
+ profile_options = jax .profiler .ProfileOptions ()
76
+ profile_options .host_tracer_level = (
77
+ self ._profiler_options .host_tracer_level
78
+ )
79
+ profile_options .python_tracer_level = (
80
+ self ._profiler_options .python_tracer_level
81
+ )
82
+ jax .profiler .start_trace (
83
+ log_dir = self ._output_path , profiler_options = profile_options
84
+ )
85
+ else :
86
+ jax .profiler .start_trace (log_dir = self ._output_path )
80
87
81
88
def maybe_deactivate (self , step : int ):
82
89
"""End the profiler."""
0 commit comments