CTDA-CPA (Cross-Thread Dependency-Aware Critical Path Analysis) is an advanced performance analysis tool designed for large-scale model training. It accurately identifies performance hotspots by constructing complete critical paths across CPU multi-threading and GPU multi-stream execution.
- Cross-Thread Dependency Modeling: Captures inter-thread dependencies caused by Python GIL, ensuring complete critical path construction
- Accurate Hotspot Identification: Avoids "false hotspots" caused by computation overlap in traditional time-accumulation methods
- Non-Intrusive Analysis: Works with standard PyTorch Profiler data without requiring code modifications
- Framework Support: Compatible with DeepSpeed, Megatron-LM, FSDP, and native PyTorch training
- Visualization Tools: Includes computation dependency graph visualization for better understanding
CTDA-CPA has successfully identified bottlenecks and guided optimizations in:
- ResNet-18 Training: 47.06% performance improvement by optimizing data loading bottleneck
- Llama 2 Distributed Fine-tuning: 7.21% performance improvement by optimizing communication bottleneck
```bash
# Clone the repository
git clone https://github.com/yourusername/CTDA-CPA.git
cd CTDA-CPA

# Install dependencies
pip install -r requirements.txt
```

Add PyTorch Profiler to your training code:
```python
import torch
import torch.profiler as profiler

criterion = torch.nn.CrossEntropyLoss()  # example loss function

def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx >= 10:  # Profile first 10 batches
            break
```
```python
# Wrap training with the profiler
with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=True,
    profile_memory=False,
) as prof:
    train_one_epoch(model, train_loader, optimizer, device)

# Export trace file
prof.export_chrome_trace("trace.json")
```

```bash
# Basic analysis
python ctda_cpa/hotspot_analysis.py --input trace.json --output results/
```
```bash
# With custom options
python ctda_cpa/hotspot_analysis.py \
    --input trace.json \
    --top_functions 10
```

Key arguments:

- `--input`: Path to trace file (JSON or Parquet)
- `--top_functions`: Number of top hotspots to report (default: 10)
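Before running the analyzer, a trace can be sanity-checked directly: the `traceEvents`, `ph`, `cat`, and `dur` fields are part of the Chrome trace format that `export_chrome_trace` emits. The aggregation helper below is an illustrative sketch, not part of CTDA-CPA:

```python
import json
from collections import Counter

def total_duration_by_category(path):
    """Sum event durations per category in a Chrome trace file."""
    with open(path) as f:
        trace = json.load(f)
    totals = Counter()
    for ev in trace.get("traceEvents", []):
        if ev.get("ph") == "X":  # "complete" events carry a duration
            totals[ev.get("cat", "unknown")] += ev.get("dur", 0)
    return totals
```

Note that naively summing durations like this double-counts overlapped work across threads and streams, which is exactly the "false hotspot" problem the critical-path analysis avoids.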
The analysis generates a hotspot report:
```
=== CTDA-CPA Hotspot Analysis Report ===

Critical Path Total Time: 1000 ms

Top 10 Hotspots:
┌───┬───────────────────────┬─────────────────┬──────────┬────────────┬───────────────┐
│ # │ name                  │ category        │ duration │ proportion │ prop_in_total │
├───┼───────────────────────┼─────────────────┼──────────┼────────────┼───────────────┤
│ 1 │ resize of ImagingCore │ python function │ 100ms    │ 50%        │ 10%           │
│ 2 │ ncclAllReduce         │ kernel          │ 60ms     │ 30%        │ 6%            │
│ 3 │ aten::conv2d          │ cpu_op          │ 40ms     │ 20%        │ 4%            │
└───┴───────────────────────┴─────────────────┴──────────┴────────────┴───────────────┘

Optimization Recommendations:
• DataLoader is the primary bottleneck (50% of critical path)
  → Consider: data preprocessing, num_workers tuning, prefetching
• Communication overhead is significant (30% of critical path)
  → Consider: gradient bucketing, communication overlap
```
Example - Data Loading Bottleneck:

```python
# Optimized DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,            # Increase workers
    pin_memory=True,          # Faster GPU transfer
    persistent_workers=True,  # Reuse workers
)
```

See `optimizations/` for complete case studies.
Existing critical path analysis tools assume CPU threads are independent, but Python's GIL (Global Interpreter Lock) means only one thread executes at a time. In PyTorch training:
- Main Thread: Handles forward computation
- Autograd Thread: Executes backward propagation
- GIL Constraint: These threads execute serially with strict ordering
CTDA-CPA models these cross-thread dependencies by treating multi-threaded events as sequentially ordered based on timestamps, ensuring complete critical path construction across thread boundaries.
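That timestamp-based serialization can be sketched as follows. The event dicts mimic Chrome-trace fields (`tid`, `ts`, `dur`, `name`), and `cross_thread_edges` is an illustrative helper, not the tool's actual API:

```python
# Sketch: deriving GIL-induced cross-thread ordering from profiler events.
def cross_thread_edges(events):
    """Return (predecessor, successor) name pairs for consecutive CPU
    events that ran on different threads, ordered by start timestamp."""
    ordered = sorted(events, key=lambda e: e["ts"])
    edges = []
    for prev, curr in zip(ordered, ordered[1:]):
        if prev["tid"] != curr["tid"]:  # GIL hands off between threads
            edges.append((prev["name"], curr["name"]))
    return edges

events = [
    {"tid": 1, "ts": 0,  "dur": 5, "name": "forward"},
    {"tid": 2, "ts": 6,  "dur": 8, "name": "backward"},
    {"tid": 1, "ts": 15, "dur": 3, "name": "optimizer.step"},
]
print(cross_thread_edges(events))
# [('forward', 'backward'), ('backward', 'optimizer.step')]
```

Each emitted pair becomes a cross-thread edge in the dependency graph, so the critical path can cross from the main thread into the autograd thread and back.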
- Data Collection: Uses PyTorch Profiler to collect trace data (CPU ops, CUDA kernels, timing)
- Dependency Graph Construction: Builds a Computation Dependency Graph (CDG) with four types of dependencies:
- Sequential execution (same thread/stream)
- Invocation (CPU→GPU calls)
- Synchronization (CUDA sync operations)
- Cross-thread (GIL-induced ordering) ← our innovation
- Critical Path Identification: Uses timestamp-based backtracking to find the true critical path
- Hotspot Extraction: Identifies events on critical path that actually impact end-to-end time
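The critical-path step can be illustrated with a toy dependency graph. Node names, durations, and the `deps` predecessor map are made-up inputs, and the real implementation works over the full CDG with timestamp-based backtracking rather than this simple longest-path recursion:

```python
# Sketch: the critical path of a DAG is its longest path by total duration.
def critical_path(durations, deps):
    """Return (total_time, path) for the longest path through a DAG,
    given per-node durations and predecessor lists."""
    memo = {}

    def longest(node):
        if node not in memo:
            best_time, best_path = 0, []
            for pred in deps.get(node, []):
                t, path = longest(pred)
                if t > best_time:
                    best_time, best_path = t, path
            memo[node] = (best_time + durations[node], best_path + [node])
        return memo[node]

    # The critical path ends at whichever node finishes last.
    end = max(durations, key=lambda n: longest(n)[0])
    return longest(end)

durations = {"load": 100, "h2d": 10, "fwd": 40, "bwd": 60, "allreduce": 60}
deps = {"h2d": ["load"], "fwd": ["h2d"], "bwd": ["fwd"], "allreduce": ["bwd"]}
print(critical_path(durations, deps))
# (270, ['load', 'h2d', 'fwd', 'bwd', 'allreduce'])
```

Only events on this path contribute to end-to-end time; work that overlaps with it (e.g. a kernel hidden behind data loading) is correctly excluded from the hotspot report.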
| Feature | HTA | CTDA-CPA |
|---|---|---|
| Handles Computation Overlap | ❌ | ✅ |
| Cross-Thread Dependencies | ❌ | ✅ |
| Complete Critical Path | Partial | ✅ |
| Multi-threaded Training | Limited | ✅ |
Large trace files?

```bash
# Convert to Parquet for better performance
python ctda_cpa/utils/json2parquet.py trace.json trace.parquet
python ctda_cpa/hotspot_analysis.py --input trace.parquet --top_functions 10
```

```bash
cd examples/deepspeed
bash run.sh
```

Demonstrates profiling and analyzing distributed fine-tuning of Llama 2 13B with DeepSpeed ZeRO optimization.

```bash
cd examples/megatron
bash run.sh
```

Shows how to profile and analyze GPT-2 Large pretraining with Megatron-LM pipeline parallelism.
Optimized data loading pipeline achieving 47.06% speedup.
Optimized communication bottlenecks achieving 7.21% improvement.
This project is licensed under the MIT License - see the LICENSE file for details.
- Tong Yang - East China Normal University
- Ning Li - East China Normal University
- Bo Huang - East China Normal University
- Jianmei Guo - East China Normal University
Star ⭐ this repository if you find it helpful!