-
-
Notifications
You must be signed in to change notification settings - Fork 10.7k
[compile] Enable sequence parallelism matching w/o custom ops enabled #27126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: angelayi <[email protected]>
c1efc65
to
ed10d76
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking this on! Could you just add me as a co-author on one of the commits?
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" | ||
def get_first_out_wrapper(fn): | ||
@functools.wraps(fn) | ||
def wrapper(*args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work? I thought that during tracing the pattern matching tracer will think that args
is a single parameter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes! updated the test to assert the number of all_reduce/all_gather ops in the graph!
Signed-off-by: angelayi <[email protected]> Co-authored-by: Luka Govedič <[email protected]>
ed10d76
to
5d66118
Compare
Purpose
Based on #24604, modified sequence-parallelism pass to do custom op matching w/o needing to enable the custom op
Test Plan
pytest -sv tests/compile/test_sequence_parallelism.py
Performance numbers
I did some benchmarking with the command on H100 w/o flashinfer
while varying
"pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}
vs."pass_config": {"enable_async_tp": false, "enable_sequence_parallelism": false}
"custom_ops":["+quant_fp8", "+rms_norm"]
vs."custom_ops":[]