Fix Context Parallel validation checks #12446
base: main
Changes from all commits: faf61a4, 428399b, 1d76322, a66787b, 881e262, 0845ca0, 8018a6a, f925783, 5bfc7dd, fb15ff5
```diff
@@ -79,29 +79,47 @@ def __post_init__(self):
         if self.ulysses_degree is None:
             self.ulysses_degree = 1
 
+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
+
+    @property
+    def mesh_shape(self) -> Tuple[int, int]:
+        """Shape of the device mesh (ring_degree, ulysses_degree)."""
+        return (self.ring_degree, self.ulysses_degree)
+
+    @property
+    def mesh_dim_names(self) -> Tuple[str, str]:
+        """Dimension names for the device mesh."""
+        return ("ring", "ulysses")
+
     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+        if self.ulysses_degree * self.ring_degree > world_size:
+            raise ValueError(
+                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
```
Comment on lines +118 to +122:

Can't they be […]?

They are internal attributes that are derived from `mesh`, which is set through the […]. The guards are redundant; they would always be […].
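To make the effect of the new `__post_init__` checks concrete, here is a small usage sketch. It assumes `ContextParallelConfig` can be imported from the top-level `diffusers` namespace and that `rotate_method` defaults to `"allgather"`; both are assumptions, not something shown in this diff.

```python
# Sketch only: the import path and field defaults below are assumptions.
from diffusers import ContextParallelConfig

# Valid: exactly one of the two degrees is greater than 1
# (the other defaults to 1 in __post_init__).
ok = ContextParallelConfig(ring_degree=2)
print(ok.mesh_shape)  # (2, 1)

# Each of these combinations is now rejected at construction time.
bad_kwargs = [
    dict(ring_degree=1, ulysses_degree=1),          # neither degree > 1
    dict(ring_degree=0, ulysses_degree=2),          # a degree below 1
    dict(ring_degree=2, ulysses_degree=2),          # unified Ulysses-Ring not supported yet
    dict(ring_degree=2, rotate_method="alltoall"),  # only "allgather" is implemented
]
for kwargs in bad_kwargs:
    try:
        ContextParallelConfig(**kwargs)
    except (ValueError, NotImplementedError) as err:
        print(type(err).__name__, "-", err)
```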
```diff
@@ -119,22 +137,22 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
 
     def setup(
         self,
         rank: int,
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)
 
 
 @dataclass(frozen=True)
```
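The rename from `cp_mesh` to `mesh` only changes how the device mesh is threaded through the configs. Below is a rough sketch of that wiring, assuming two processes launched with `torchrun` and using PyTorch's `init_device_mesh` to build a mesh matching the new `mesh_shape`/`mesh_dim_names` properties; the direct `setup` call is for illustration only, since in practice the library performs this wiring when parallelism is enabled.

```python
# Illustration only: assumes 2 processes (e.g. `torchrun --nproc-per-node=2 script.py`)
# and that ContextParallelConfig is importable from the top-level diffusers namespace.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers import ContextParallelConfig

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

cp_config = ContextParallelConfig(ring_degree=2)  # ring_degree * ulysses_degree == 2 <= world_size

# Build a 2D device mesh whose shape and dimension names match the new properties.
mesh = init_device_mesh("cuda", cp_config.mesh_shape, mesh_dim_names=cp_config.mesh_dim_names)

# After the rename, the mesh is passed as `mesh` (previously `cp_mesh`).
cp_config.setup(rank, world_size, device, mesh)

dist.destroy_process_group()
```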
Should we hit line […], as both cannot be set, right?

Both can be set technically, but currently both can't be > 1. Also, this is for cases where you have 3 GPUs available and you set something like `ulysses_degree=1` and `ring_degree=4` (the number of GPUs requested is greater than the world size).

Feels slightly confusing to me, but since we're erroring out early for unsupported `ulysses_degree` and `ring_degree` value combos, I think it's okay.
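In concrete numbers, the situation described above is exactly what the new `setup` check guards against. A standalone illustration of the condition (plain Python, not library code):

```python
# 3 visible GPUs means world_size == 3, but 4 * 1 == 4 processes are requested.
ring_degree, ulysses_degree = 4, 1
world_size = 3

# The condition added to ContextParallelConfig.setup:
if ulysses_degree * ring_degree > world_size:
    raise ValueError(
        f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` "
        f"({ulysses_degree}) must not exceed the world size ({world_size})."
    )
```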