Skip to content

Conversation

angelayi
Copy link
Contributor

@angelayi angelayi commented Oct 17, 2025

Purpose

Based on #24604, modified sequence-parallelism pass to do custom op matching w/o needing to enable the custom op

Test Plan

pytest -sv tests/compile/test_sequence_parallelism.py

Performance numbers

I did some benchmarking with the command on H100 w/o flashinfer

VLLM_DISABLE_COMPILE_CACHE=1 VLLM_USE_STANDALONE_COMPILE=1 VLLM_LOGGING_LEVEL=DEBUG vllm bench latency --model=nvidia/Llama-3.3-70B-Instruct-FP8 --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 8 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level": 3, "use_inductor_graph_partition": false, "splitting_ops":[], "cudagraph_mode": "FULL", }' --no-enable-prefix-caching

while varying

  • "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true} vs. "pass_config": {"enable_async_tp": false, "enable_sequence_parallelism": false}
  • "custom_ops":["+quant_fp8", "+rms_norm"] vs. "custom_ops":[]
image

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this on! Could you just add me as a co-author on one of the commits?

"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
def get_first_out_wrapper(fn):
@functools.wraps(fn)
def wrapper(*args):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work? I thought that during tracing the pattern matching tracer will think that args is a single parameter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! updated the test to assert the number of all_reduce/all_gather ops in the graph!

Signed-off-by: angelayi <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: To triage

Development

Successfully merging this pull request may close these issues.

2 participants