
Conversation

utkarshsharma1 (Collaborator)

This PR introduces critical fixes to the `psum` & `psum_scatter` microbenchmark to ensure its correctness and accuracy.

  1. Correct Input Sharding (`in_specs`) for both collective operations (see the sketch below)
    Problem: the input tensor was fully replicated (`in_specs=P(None, None)`)
    Fix: the input is now correctly sharded with `in_specs=P(None, "ici")`
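
A minimal sketch of what the fixed wiring might look like under `shard_map`, assuming a 1-D mesh whose axis is named `"ici"` (per the description above); the function names, shapes, and dtype are illustrative, not the PR's actual code:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical setup: a 1-D mesh over all local devices, axis named "ici".
mesh = Mesh(jax.devices(), axis_names=("ici",))

# The collective under test: an all-reduce over the "ici" axis.
def psum_op(x):
    return jax.lax.psum(x, axis_name="ici")

# The fix: shard the input's second dimension across "ici"
# (in_specs=P(None, "ici")) instead of replicating it (P(None, None)).
psum_bench = shard_map(
    psum_op,
    mesh=mesh,
    in_specs=P(None, "ici"),
    out_specs=P(None, None),  # psum leaves the result replicated
)

x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)  # illustrative size
y = jax.jit(psum_bench)(x)
```

For the `psum_scatter` variant, the per-shard output is scattered rather than replicated, so `out_specs` would presumably carry the `"ici"` axis on the scattered dimension instead of `P(None, None)`.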

@chishuen chishuen requested review from chishuen and removed request for chishuen August 28, 2025 05:26
@chishuen (Collaborator)

I think `psum` and `psum_scatter` do not require sharding. In fact, it might be better to keep the matrix unsharded; otherwise, we might also need to tweak the implementation of AllGather so that the message size (the x-axis) aligns (see the sketch below).

This might be helpful: https://jax-ml.github.io/scaling-book/training/#fully-sharded-data-parallelism-fsdp
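
For contrast, a sketch of the unsharded variant suggested above, under the same assumed mesh setup: with `in_specs=P(None, None)` every device holds the full matrix, so the per-device data entering the collective stays comparable to the AllGather benchmark's x-axis. Names and shapes are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("ici",))

# Unsharded variant: the input is fully replicated, so each device's
# psum contribution is the entire matrix.
replicated_psum = shard_map(
    lambda x: jax.lax.psum(x, axis_name="ici"),
    mesh=mesh,
    in_specs=P(None, None),   # keep the matrix unsharded
    out_specs=P(None, None),
)

x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
y = jax.jit(replicated_psum)(x)
```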
