456 add support for traces generated by jax 08#501
Conversation
…sing (#456) Replace hardcoded `tensorboard_plugin_profile` imports with `xprof` (falling back to the old package for backward compatibility). This enables loading traces generated by JAX 0.8.0+, which produce HLO instruction IDs > INT_MAX that crash tensorboard-plugin-profile <= 2.19.0. Also fix GPU PID detection: xprof remaps device PIDs (e.g. 1 -> 1001), breaking the previous `pid <= 100` heuristic. Now uses process metadata (`/device:GPU` in process name) which is robust to any PID scheme. Fix pre-existing bug where JaxTreePerfAnalyzer.from_file did not extract and pass metadata_events to build_tree, causing a TypeError on main. Co-authored-by: Cursor <cursoragent@cursor.com>
Three issues discovered when using xprof as the trace processing backend:
1. get_dict() gated custom_call_target and backend_config extraction on
metadata= being present in the HLO text. xprof's graph_viewer omits
metadata= (even with show_metadata=True), so custom_call_target was
silently dropped. Fix: extract these fields independently.
2. xprof's graph_viewer emits operands as bare references (%bitcast.39.0)
without inline type annotations (bf16[...] %bitcast.39.0). Add a
post-processing pass (_resolve_operand_references) that looks up each
reference in the hlo_ops dict and substitutes its output type.
3. xprof's xspace_to_tool_names() no longer extracts .hlo_proto.pb files
as a side effect. process_protobuf_file now returns {} gracefully
instead of crashing with IndexError when files are missing.
Also fix test_compare_perf_report.py to only compare columns present in
both reference and generated DataFrames (avoids KeyError from new columns
added by recent PRs).
Relax test_tree_event_cats assertion to tolerate minor categorization
differences between xprof and tensorboard-plugin-profile.
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add xprof==2.20.7 to install_requires for JAX trace processing - Raise clear RuntimeError when xprof returns None (e.g. permission issues) Co-authored-by: Cursor <cursoragent@cursor.com>
- jax_conv_minimal -> jax_conv_minimal_legacy (JAX ~0.6, full perf model) - jax08_conv -> jax_conv_minimal_08 (JAX 0.8, comparable minimal conv) - Update test_jax_conv_analysis.py path - Document naming in tests/traces/README.md - Add *.SSTABLE to .gitignore (xprof cache) - Remove legacy SSTABLE from repo Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
- Update add_gpu_ops_to_tree and _categorize_gpu_kernel_ops docstrings to reflect process_name-based GPU detection (not pid <= 100) - Clarify setup.py xprof comment: preferred library, required for JAX 0.8+ - Include converter library name in DataLoader RuntimeError when conversion fails - Log warning when falling back to tensorboard-plugin-profile - Log warning when HLO operand reference cannot be resolved - Add explicit missing-columns assert in test_compare_perf_report Co-authored-by: Cursor <cursoragent@cursor.com>
CI was using unpinned black (latest from pip), which can produce different formatting than local. Pin to match local version so lint passes. Co-authored-by: Cursor <cursoragent@cursor.com>
Reference reports may include these columns (from main) while generated reports on this branch may not. Add to cols_ignore to avoid assertion failure on UnaryElementwise and other sheets. Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Pull request overview
Migrates TraceLens’ JAX trace ingestion from tensorboard-plugin-profile to xprof 2.20.7 to support JAX 0.8+ xplane traces while preserving legacy perf/HLO linking behavior.
Changes:
- Switch trace conversion to prefer
xprof(with fallback import path) and add HLO parsing fixes for xprof output (operand reference resolution; metadata/custom_call parsing). - Update JAX GPU-event identification to use
process.process_name(e.g./device:GPU) instead of PID heuristics, and pass metadata events into JAX tree building. - Update/relax certain JAX conv tests and strengthen perf report comparison by asserting missing reference columns; add/rename MI300 JAX trace artifacts + documentation updates.
Reviewed changes
Copilot reviewed 9 out of 22 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
.github/workflows/lint.yml |
Pins Black version used in CI lint job. |
.gitignore |
Ignores *.SSTABLE artifacts. |
setup.py |
Adds xprof==2.20.7 dependency for JAX trace conversion. |
TraceLens/util.py |
Prefers xprof for xplane conversion; improves HLO metadata parsing and resolves bare operand references. |
TraceLens/TreePerf/gpu_event_analyser.py |
Derives GPU PIDs using process_name metadata. |
TraceLens/TreePerf/tree_perf.py |
Splits metadata vs non-metadata events and threads metadata into JaxTraceToTree.build_tree(). |
TraceLens/Trace2Tree/trace_to_tree.py |
Uses process.process_name to detect GPU events for tree augmentation/categorization. |
tests/test_compare_perf_report.py |
Adds assertion for missing reference columns before comparing values. |
tests/test_jax_conv_analysis.py |
Updates legacy JAX conv trace path and loosens event-category assertions. |
tests/traces/README.md |
Documents MI300 JAX trace naming for legacy vs 0.8 traces. |
tests/traces/mi300/jax_conv_minimal_legacy_perf_report.xlsx |
Adds/updates reference perf report artifact for MI300 legacy JAX conv. |
tests/traces/mi300/jax_conv_minimal_legacy/jit_forward_3d_conv(15).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit_convert_element_type(1).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit__unstack(7).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit__threefry_split(5).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit__threefry_seed(3).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit__normal(11).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
tests/traces/mi300/jax_conv_minimal_legacy/jit__multi_slice(13).hlo_proto.pb |
Adds HLO sidecar for legacy trace perf/HLO linking. |
Comments suppressed due to low confidence (3)
TraceLens/Trace2Tree/trace_to_tree.py:435
name = event.get("name")may beNonefor some trace events; later logic in this method does substring checks againstname, which will raise aTypeErrorifnameis not a string. Consider normalizing to an empty string (or skipping categorization) whennameis missing/non-string to keep tree building robust across trace variants.
if "/device:GPU" in proc_name:
if event.get("cat") == "kernel":
name = event.get("name")
TraceLens/util.py:56
- This change introduces xprof-based conversion (and a new JAX 0.8 trace is added to test data), but the test suite still only exercises JAX parsing against the legacy trace. Add a unit/integration test that loads the
jax_conv_minimal_08.xplane.pbviaJaxTreePerfAnalyzer.from_fileto ensure the xprof conversion path works and tree building/categorization doesn’t regress.
if filename_path.endswith("pb"):
try:
from xprof.convert import raw_to_tool_data as convert
converter_lib = "xprof"
except ImportError:
from tensorboard_plugin_profile.convert import (
raw_to_tool_data as convert,
)
converter_lib = "tensorboard-plugin-profile"
logger.warning(
"xprof not available, falling back to tensorboard-plugin-profile "
"for trace conversion. Install xprof for JAX 0.8+ support."
)
data, _ = convert.xspace_to_tool_data([filename_path], "trace_viewer@^", {})
if data is None:
raise RuntimeError(
f"Trace conversion using '{converter_lib}' returned None for "
f"{filename_path}. Ensure the file exists and the output directory "
"is writable (cache files may need to be written)."
)
data = data.decode("utf-8") # we get bytes back from the call above
TraceLens/TreePerf/gpu_event_analyser.py:337
JaxGPUEventAnalysernow derivesgpu_pidsfromprocess_namemetadata, butget_gpu_event_lists()still classifies GPU vs CPU usingpid < 100. If xprof/JAX 0.8 traces don’t preserve the legacy PID scheme, this will silently drop GPU events from metrics. Consider using the same process_name-based check (orpid in self.gpu_pids) insideget_gpu_event_lists()for consistency.
def __init__(self, events):
super().__init__(events) # Call the parent's __init__
self.gpu_pids = list(
set(
event["pid"]
for event in events
if "/device:GPU"
in (
event.get("process", {}).get("process_name", "")
if isinstance(event.get("process"), dict)
else str(event.get("process", ""))
)
)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Resolve conflicts: - tree_perf.py: keep defensive metadata_events None check from 456 - test_compare_perf_report.py: keep 456's missing_cols assert and cols_ignore Co-authored-by: Cursor <cursoragent@cursor.com>
protobuf 4.25.8 is incompatible with grpcio-status (requires >=6.31.1). Older protobuf causes RET_CHECK failure on JAX 0.8 traces with HLO ids > INT_MAX. - Add protobuf>=6.31.1,<7.0.0 to setup.py install_requires - Upgrade protobuf in regression-tests JAX step before running tests - Remove tensorboard-plugin-profile from JAX step (use xprof from main install) Co-authored-by: Cursor <cursoragent@cursor.com>
- gpu_event_analyser: use pid in gpu_pids instead of pid < 100 (xprof remaps PIDs) - gpu_event_analyser: add _is_gpu_event helper, fix get_breakdown_df_multigpu - trace_to_tree: add _is_gpu_event helper, guard name/args against None - test_jax_conv_analysis: clarify test_tree_event_cats assertion - test_jax_perf_report: add fixture to cleanup tmpdirs after tests Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
These hlo_proto.pb files should not be checked in - they are generated when xprof.convert / tensorboard_plugin_profile.convert processes the xplane.pb file
There was a problem hiding this comment.
“xprof’s xspace_to_tool_names() no longer extracts .hlo_proto.pb files as a side effect (unlike tensorboard_plugin_profile). For the legacy trace to work with xprof, these files need to be checked in. They’ve been in the repo since PR #380.”
There was a problem hiding this comment.
If that is then case, how do we process new traces from Jax 0.6 starting from scratch if we complete this PR? We can't break Jax 0.6.0 functionality yet
There was a problem hiding this comment.
Do we need to install both xprof and tensorboard-plugin-profile 2.19.0 and import the correct library based on the trace? Or does the protobuf version break this?
There was a problem hiding this comment.
The legacy hlo_proto.pb files had caused the tests to pass with xprof. I will remove them.
The protobuf version actually breaks in this situation:
- xprof’s conversion path uses grpcio / grpcio-status, which pins protobuf ≥ 6.31.1
- tbp 2.19 cannot be used safely with protobuf 6.31.1, which is outside the supported range, causing descriptor error.
I did not yet find a safe solution to have both xprof2.20.7 and tbp2.19 in the same environment. What are you suggestions, which is the priority jax0.6 or 0.8? @devalshahamd @gabeweisz
There was a problem hiding this comment.
Current State
| Tool | JAX 0.6 | JAX 0.8 | Protobuf |
|---|---|---|---|
| xprof 2.20.7 | Trace loads; no full perf model (FLOPS/bytes) — doesn't extract .hlo_proto.pb |
Full support | Requires ≥6.31.1 (grpcio-status) |
| tbp 2.19 | Full support (extracts .hlo_proto.pb) |
Crashes (INT_MAX in HLO ids) | Requires <5.0 (3.x–4.x) |
Resolve conflicts: - regression-tests.yml: adopt main's reorganized workflow (#506), add protobuf>=6.31.1 for JAX tests, use xprof from setup.py - test_compare_perf_report.py: accept deletion (replaced by test_perf_report_regression + test_compare_perf_reports) Made-with: Cursor
Made-with: Cursor
Summary
Closes #456. Migrates TraceLens from
tensorboard-plugin-profileto xprof 2.20.7 for JAX trace processing so traces from JAX 0.8+ can be loaded and analyzed.Why xprof 2.20.7 instead of 2.21.3
hlo_opandcorrelation_idfrom GPU kernel events, which breaks TraceLens HLO linking and perf metrics.INT_MAXcrash seen withtensorboard-plugin-profile2.19.Backward compatibility
(e.g.jax_conv_minimal_legacy, JAX ~0.6): All 10test_jax_conv_analysistests pass, including perf model (FLOPS, bytes, shapes).Changes
tensorboard-plugin-profilewith xprof 2.20.7 (with fallback for backward compatibility)pid <= 100) to process_name-based (/device:GPU)_resolve_operand_referencesfor xprof's bare operand referencescustom_call_targetandbackend_configextraction whenmetadata=is absentget_gpu_event_listsPID classification (usepid in self.gpu_pidsinstead ofpid < 100)jax_conv_minimal→jax_conv_minimal_legacy,jax08_conv→jax_conv_minimal_08intests/traces/mi300/tests/traces/mi300/jax_conv_minimal_08for E2E smoke testsprotobuf>=6.31.1,<7.0.0for grpcio-status compatibility (avoids RET_CHECK on JAX 0.8 traces)_is_gpu_event()helper, guardname/argsagainst None, test cleanup fixture