Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ class Profiling:
profile_freq: int = 10
"""How often to collect profile traces, in iterations"""

profiler_active: int = 1
"""
The number of steps the profiler is active for in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

profiler_warmup: int = 3
"""
The number of warmup steps before the active steps in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

enable_memory_snapshot: bool = False
"""Whether to dump memory snapshot"""

Expand Down
10 changes: 5 additions & 5 deletions torchtitan/tools/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from torchtitan.config import Profiling as ProfilingConfig
from torchtitan.tools.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000

Expand All @@ -34,7 +31,11 @@ def maybe_enable_profiling(

if enable_profiling:
trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
profile_freq = profiling_config.profile_freq
profile_freq, warmup, active = (
profiling_config.profile_freq,
profiling_config.profiler_warmup,
profiling_config.profiler_active,
)

rank = torch.distributed.get_rank()

Expand All @@ -58,7 +59,6 @@ def trace_handler(prof):
if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)

warmup, active = WARMUP, 1
wait = profile_freq - (active + warmup)
assert (
wait >= 0
Expand Down
Loading