456 add support for traces generated by jax 08 #501

Open
gphuang wants to merge 18 commits into main from 456-add-support-for-traces-generated-by-jax-08
gphuang (Contributor) commented Feb 24, 2026

Summary

Closes #456. Migrates TraceLens from tensorboard-plugin-profile to xprof 2.20.7 for JAX trace processing so traces from JAX 0.8+ can be loaded and analyzed.

Why xprof 2.20.7 instead of 2.21.3

  • xprof 2.20.8+ / 2.21.x removes hlo_op and correlation_id from GPU kernel events, which breaks TraceLens HLO linking and perf metrics.
  • xprof 2.20.7 keeps these fields and loads JAX 0.8 traces without the INT_MAX crash seen with tensorboard-plugin-profile 2.19.

Backward compatibility

  • Older JAX 0.6 traces (e.g. jax_conv_minimal_legacy, JAX ~0.6): protobuf version conflict; see the comments below. All 10 test_jax_conv_analysis tests pass, including the perf model (FLOPS, bytes, shapes).
  • JAX 0.8 traces: Trace loading, event categorization, and GPU timelines work; HLO metadata for perf models is limited until #425 ("Add hlo_op and framework name scope call_stack nodes to Jax.tree") lands.

Changes

  • Replace tensorboard-plugin-profile with xprof 2.20.7 (with fallback for backward compatibility)
  • Switch GPU event detection from PID-based (pid <= 100) to process_name-based (/device:GPU)
  • Add _resolve_operand_references for xprof's bare operand references
  • Fix custom_call_target and backend_config extraction when metadata= is absent
  • Fix get_gpu_event_lists PID classification (use pid in self.gpu_pids instead of pid < 100)
  • Rename traces: jax_conv_minimal → jax_conv_minimal_legacy, jax08_conv → jax_conv_minimal_08 in tests/traces/mi300/
  • Add JAX 0.8 conv trace tests/traces/mi300/jax_conv_minimal_08 for E2E smoke tests
  • Pin protobuf>=6.31.1,<7.0.0 for grpcio-status compatibility (avoids RET_CHECK on JAX 0.8 traces)
  • Pin Black to 26.1.0 in lint workflow (avoids CI failures when the Black version changes)
  • Code quality: _is_gpu_event() helper, guard name/args against None, test cleanup fixture

gphuang and others added 7 commits February 20, 2026 12:32
…sing (#456)

Replace hardcoded `tensorboard_plugin_profile` imports with `xprof` (falling
back to the old package for backward compatibility). This enables loading
traces generated by JAX 0.8.0+, which produce HLO instruction IDs > INT_MAX
that crash tensorboard-plugin-profile <= 2.19.0.

Also fix GPU PID detection: xprof remaps device PIDs (e.g. 1 -> 1001),
breaking the previous `pid <= 100` heuristic. Now uses process metadata
(`/device:GPU` in process name) which is robust to any PID scheme.

Fix pre-existing bug where JaxTreePerfAnalyzer.from_file did not extract
and pass metadata_events to build_tree, causing a TypeError on main.

Co-authored-by: Cursor <cursoragent@cursor.com>
Three issues discovered when using xprof as the trace processing backend:

1. get_dict() gated custom_call_target and backend_config extraction on
   metadata= being present in the HLO text. xprof's graph_viewer omits
   metadata= (even with show_metadata=True), so custom_call_target was
   silently dropped. Fix: extract these fields independently.

2. xprof's graph_viewer emits operands as bare references (%bitcast.39.0)
   without inline type annotations (bf16[...] %bitcast.39.0). Add a
   post-processing pass (_resolve_operand_references) that looks up each
   reference in the hlo_ops dict and substitutes its output type.

3. xprof's xspace_to_tool_names() no longer extracts .hlo_proto.pb files
   as a side effect. process_protobuf_file now returns {} gracefully
   instead of crashing with IndexError when files are missing.
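
The operand-resolution pass described in item 2 could look roughly like the sketch below. It is not the code from this PR: the layout of the hlo_ops dict (an instruction name mapped to a dict with an "output" type string and an "operands" list) is an assumption made for illustration, and resolve_operand_references is a hypothetical stand-in for _resolve_operand_references.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_operand_references(hlo_ops):
    """Prefix bare operand references with the output type of the op they name.

    Assumed (hypothetical) layout: hlo_ops maps an instruction name such as
    "%bitcast.39.0" to {"output": "bf16[2,3]", "operands": ["%param.0", ...]}.
    """
    for op in hlo_ops.values():
        resolved = []
        for operand in op.get("operands", []):
            if not operand.startswith("%"):
                # Already carries an inline type annotation: leave untouched.
                resolved.append(operand)
                continue
            producer = hlo_ops.get(operand)
            if producer is None or not producer.get("output"):
                # Unknown reference: keep it bare and warn rather than guess.
                logger.warning("Cannot resolve HLO operand reference %s", operand)
                resolved.append(operand)
                continue
            resolved.append(f"{producer['output']} {operand}")
        op["operands"] = resolved
    return hlo_ops


ops = {
    "%bitcast.39.0": {"output": "bf16[2,3]", "operands": []},
    "%add.1": {"output": "bf16[2,3]", "operands": ["%bitcast.39.0", "%missing.7"]},
}
resolve_operand_references(ops)
print(ops["%add.1"]["operands"])
```

In this sketch an unresolvable reference is left bare and logged, which matches the warning behaviour mentioned in a later commit on this branch.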

Also fix test_compare_perf_report.py to only compare columns present in
both reference and generated DataFrames (avoids KeyError from new columns
added by recent PRs).

Relax test_tree_event_cats assertion to tolerate minor categorization
differences between xprof and tensorboard-plugin-profile.

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add xprof==2.20.7 to install_requires for JAX trace processing
- Raise clear RuntimeError when xprof returns None (e.g. permission issues)

Co-authored-by: Cursor <cursoragent@cursor.com>
- jax_conv_minimal -> jax_conv_minimal_legacy (JAX ~0.6, full perf model)
- jax08_conv -> jax_conv_minimal_08 (JAX 0.8, comparable minimal conv)
- Update test_jax_conv_analysis.py path
- Document naming in tests/traces/README.md
- Add *.SSTABLE to .gitignore (xprof cache)
- Remove legacy SSTABLE from repo

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
- Update add_gpu_ops_to_tree and _categorize_gpu_kernel_ops docstrings to
  reflect process_name-based GPU detection (not pid <= 100)
- Clarify setup.py xprof comment: preferred library, required for JAX 0.8+
- Include converter library name in DataLoader RuntimeError when conversion fails
- Log warning when falling back to tensorboard-plugin-profile
- Log warning when HLO operand reference cannot be resolved
- Add explicit missing-columns assert in test_compare_perf_report

Co-authored-by: Cursor <cursoragent@cursor.com>
CI was using unpinned black (latest from pip), which can produce different
formatting than local. Pin to match local version so lint passes.

Co-authored-by: Cursor <cursoragent@cursor.com>
@gphuang gphuang self-assigned this Feb 24, 2026
@gphuang gphuang marked this pull request as ready for review February 24, 2026 12:39
Copilot AI review requested due to automatic review settings February 24, 2026 12:39
Reference reports may include these columns (from main) while generated
reports on this branch may not. Add to cols_ignore to avoid assertion
failure on UnaryElementwise and other sheets.

Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor

Copilot AI left a comment

Pull request overview

Migrates TraceLens’ JAX trace ingestion from tensorboard-plugin-profile to xprof 2.20.7 to support JAX 0.8+ xplane traces while preserving legacy perf/HLO linking behavior.

Changes:

  • Switch trace conversion to prefer xprof (with fallback import path) and add HLO parsing fixes for xprof output (operand reference resolution; metadata/custom_call parsing).
  • Update JAX GPU-event identification to use process.process_name (e.g. /device:GPU) instead of PID heuristics, and pass metadata events into JAX tree building.
  • Update/relax certain JAX conv tests and strengthen perf report comparison by asserting missing reference columns; add/rename MI300 JAX trace artifacts + documentation updates.

Reviewed changes

Copilot reviewed 9 out of 22 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| `.github/workflows/lint.yml` | Pins Black version used in CI lint job. |
| `.gitignore` | Ignores *.SSTABLE artifacts. |
| `setup.py` | Adds xprof==2.20.7 dependency for JAX trace conversion. |
| `TraceLens/util.py` | Prefers xprof for xplane conversion; improves HLO metadata parsing and resolves bare operand references. |
| `TraceLens/TreePerf/gpu_event_analyser.py` | Derives GPU PIDs using process_name metadata. |
| `TraceLens/TreePerf/tree_perf.py` | Splits metadata vs non-metadata events and threads metadata into JaxTraceToTree.build_tree(). |
| `TraceLens/Trace2Tree/trace_to_tree.py` | Uses process.process_name to detect GPU events for tree augmentation/categorization. |
| `tests/test_compare_perf_report.py` | Adds assertion for missing reference columns before comparing values. |
| `tests/test_jax_conv_analysis.py` | Updates legacy JAX conv trace path and loosens event-category assertions. |
| `tests/traces/README.md` | Documents MI300 JAX trace naming for legacy vs 0.8 traces. |
| `tests/traces/mi300/jax_conv_minimal_legacy_perf_report.xlsx` | Adds/updates reference perf report artifact for MI300 legacy JAX conv. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit_forward_3d_conv(15).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit_convert_element_type(1).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit__unstack(7).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit__threefry_split(5).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit__threefry_seed(3).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit__normal(11).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
| `tests/traces/mi300/jax_conv_minimal_legacy/jit__multi_slice(13).hlo_proto.pb` | Adds HLO sidecar for legacy trace perf/HLO linking. |
Comments suppressed due to low confidence (3)

TraceLens/Trace2Tree/trace_to_tree.py:435

  • name = event.get("name") may be None for some trace events; later logic in this method does substring checks against name, which will raise a TypeError if name is not a string. Consider normalizing to an empty string (or skipping categorization) when name is missing/non-string to keep tree building robust across trace variants.
            if "/device:GPU" in proc_name:

                if event.get("cat") == "kernel":
                    name = event.get("name")
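
The normalization suggested in this comment can be sketched as follows. The helper name kernel_name and the sample events are hypothetical; only the idea of falling back to an empty string when name is missing or not a string comes from the comment.

```python
def kernel_name(event):
    """Return the event name as a string, or "" when it is missing or not a string."""
    name = event.get("name")
    return name if isinstance(name, str) else ""


events = [
    {"cat": "kernel", "name": "gemm_kernel"},
    {"cat": "kernel", "name": None},
    {"cat": "kernel"},
]
# Substring checks no longer raise TypeError for events without a usable name.
matches = ["gemm" in kernel_name(e) for e in events]
print(matches)
```

Events without a usable name simply fail every substring check, so they fall through categorization instead of aborting tree building.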

TraceLens/util.py:56

  • This change introduces xprof-based conversion (and a new JAX 0.8 trace is added to test data), but the test suite still only exercises JAX parsing against the legacy trace. Add a unit/integration test that loads the jax_conv_minimal_08 .xplane.pb via JaxTreePerfAnalyzer.from_file to ensure the xprof conversion path works and tree building/categorization doesn’t regress.
        if filename_path.endswith("pb"):
            try:
                from xprof.convert import raw_to_tool_data as convert

                converter_lib = "xprof"
            except ImportError:
                from tensorboard_plugin_profile.convert import (
                    raw_to_tool_data as convert,
                )

                converter_lib = "tensorboard-plugin-profile"
                logger.warning(
                    "xprof not available, falling back to tensorboard-plugin-profile "
                    "for trace conversion. Install xprof for JAX 0.8+ support."
                )

            data, _ = convert.xspace_to_tool_data([filename_path], "trace_viewer@^", {})
            if data is None:
                raise RuntimeError(
                    f"Trace conversion using '{converter_lib}' returned None for "
                    f"{filename_path}. Ensure the file exists and the output directory "
                    "is writable (cache files may need to be written)."
                )
            data = data.decode("utf-8")  # we get bytes back from the call above

TraceLens/TreePerf/gpu_event_analyser.py:337

  • JaxGPUEventAnalyser now derives gpu_pids from process_name metadata, but get_gpu_event_lists() still classifies GPU vs CPU using pid < 100. If xprof/JAX 0.8 traces don’t preserve the legacy PID scheme, this will silently drop GPU events from metrics. Consider using the same process_name-based check (or pid in self.gpu_pids) inside get_gpu_event_lists() for consistency.
    def __init__(self, events):
        super().__init__(events)  # Call the parent's __init__
        self.gpu_pids = list(
            set(
                event["pid"]
                for event in events
                if "/device:GPU"
                in (
                    event.get("process", {}).get("process_name", "")
                    if isinstance(event.get("process"), dict)
                    else str(event.get("process", ""))
                )
            )
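
A minimal sketch of the consistent classification this comment asks for: derive the GPU PIDs from process metadata once, then classify every event by membership in that set. The function names and the sample events are hypothetical and do not reproduce the TraceLens implementation.

```python
def process_name(event):
    """Extract the process name, tolerating dict, string, or missing metadata."""
    process = event.get("process")
    if isinstance(process, dict):
        return str(process.get("process_name") or "")
    return str(process or "")


def split_events_by_device(events):
    """Return (gpu_pids, gpu_events, other_events) using process-name metadata."""
    gpu_pids = {
        e["pid"] for e in events if "pid" in e and "/device:GPU" in process_name(e)
    }
    gpu, other = [], []
    for e in events:
        (gpu if e.get("pid") in gpu_pids else other).append(e)
    return gpu_pids, gpu, other


events = [
    {"pid": 1001, "name": "conv", "process": {"process_name": "/device:GPU:0"}},
    {"pid": 1001, "name": "copy"},  # same PID, no metadata on this event
    {"pid": 7, "name": "host_op", "process": {"process_name": "/host:CPU"}},
]
gpu_pids, gpu, other = split_events_by_device(events)
print(gpu_pids, [e["name"] for e in gpu], [e["name"] for e in other])
```

With these sample PIDs, a `pid < 100` check would have placed PID 1001 on the CPU side and PID 7 on the GPU side, which is the silent misclassification described above.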

gphuang and others added 3 commits February 25, 2026 08:19
Resolve conflicts:
- tree_perf.py: keep defensive metadata_events None check from 456
- test_compare_perf_report.py: keep 456's missing_cols assert and cols_ignore

Co-authored-by: Cursor <cursoragent@cursor.com>
protobuf 4.25.8 is incompatible with grpcio-status (requires >=6.31.1).
Older protobuf causes RET_CHECK failure on JAX 0.8 traces with HLO ids > INT_MAX.

- Add protobuf>=6.31.1,<7.0.0 to setup.py install_requires
- Upgrade protobuf in regression-tests JAX step before running tests
- Remove tensorboard-plugin-profile from JAX step (use xprof from main install)

Co-authored-by: Cursor <cursoragent@cursor.com>
- gpu_event_analyser: use pid in gpu_pids instead of pid < 100 (xprof remaps PIDs)
- gpu_event_analyser: add _is_gpu_event helper, fix get_breakdown_df_multigpu
- trace_to_tree: add _is_gpu_event helper, guard name/args against None
- test_jax_conv_analysis: clarify test_tree_event_cats assertion
- test_jax_perf_report: add fixture to cleanup tmpdirs after tests

Co-authored-by: Cursor <cursoragent@cursor.com>
@gphuang gphuang requested a review from gabeweisz February 25, 2026 09:25
Collaborator

These hlo_proto.pb files should not be checked in - they are generated when xprof.convert / tensorboard_plugin_profile.convert processes the xplane.pb file

Contributor Author

“xprof’s xspace_to_tool_names() no longer extracts .hlo_proto.pb files as a side effect (unlike tensorboard_plugin_profile). For the legacy trace to work with xprof, these files need to be checked in. They’ve been in the repo since PR #380.”

Collaborator

If that is the case, how do we process new traces from Jax 0.6 starting from scratch if we complete this PR? We can't break Jax 0.6.0 functionality yet

Collaborator

Do we need to install both xprof and tensorboard-plugin-profile 2.19.0 and import the correct library based on the trace? Or does the protobuf version break this?

Contributor Author

@gphuang gphuang Feb 27, 2026

The legacy hlo_proto.pb files had caused the tests to pass with xprof. I will remove them.

The protobuf version actually breaks in this situation:

  • xprof’s conversion path uses grpcio / grpcio-status, which pins protobuf ≥ 6.31.1
  • tbp 2.19 cannot be used safely with protobuf 6.31.1, which is outside its supported range and causes a descriptor error.

I have not yet found a safe way to have both xprof 2.20.7 and tbp 2.19 in the same environment. What are your suggestions: which is the priority, JAX 0.6 or 0.8? @devalshahamd @gabeweisz

Contributor Author

@gphuang gphuang Feb 27, 2026

Current State

| Tool | JAX 0.6 | JAX 0.8 | Protobuf |
| --- | --- | --- | --- |
| xprof 2.20.7 | Trace loads; no full perf model (FLOPS/bytes) because it doesn't extract .hlo_proto.pb | Full support | Requires ≥6.31.1 (grpcio-status) |
| tbp 2.19 | Full support (extracts .hlo_proto.pb) | Crashes (INT_MAX in HLO ids) | Requires <5.0 (3.x–4.x) |
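
The protobuf constraint in the table can be expressed as a small check. This is only an illustration: the function name is hypothetical, the bounds are copied from the table plus the `<7.0.0` pin in the PR description, version strings are assumed to be plain dotted integers, and nothing here inspects the installed packages.

```python
def compatible_converters(protobuf_version):
    """Return the converters whose protobuf range (as listed above) admits this version."""
    parts = tuple(int(p) for p in protobuf_version.split(".")[:3])
    usable = []
    # xprof 2.20.7: protobuf >= 6.31.1 (via grpcio-status), pinned below 7.0.0 in setup.py.
    if (6, 31, 1) <= parts < (7, 0, 0):
        usable.append("xprof 2.20.7")
    # tensorboard-plugin-profile 2.19: protobuf 3.x to 4.x, i.e. below 5.0.
    if (3,) <= parts < (5, 0):
        usable.append("tensorboard-plugin-profile 2.19")
    return usable


for version in ("4.25.8", "5.29.0", "6.31.1"):
    print(version, compatible_converters(version))
```

Because the two ranges do not overlap, no single protobuf version returns both converters, which is the conflict described in this thread.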

gphuang and others added 7 commits February 26, 2026 09:07
Resolve conflicts:
- regression-tests.yml: adopt main's reorganized workflow (#506), add protobuf>=6.31.1 for JAX tests, use xprof from setup.py
- test_compare_perf_report.py: accept deletion (replaced by test_perf_report_regression + test_compare_perf_reports)

Made-with: Cursor
Development

Successfully merging this pull request may close these issues.

Add support for traces generated by Jax 8.0

3 participants