-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[Neuron] Add tensor parallel support for Neuron backend #13718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
98f6c8c
c58b8b8
3367409
0c51734
a76953c
2480388
1469c04
929ab72
52cac76
30cb353
28a5086
7fab0c4
68689e5
da79308
3bb9c7c
c4facab
dff1f32
1c930c4
1eb5ff9
cbe8f28
16b9606
7f13f68
a46cb19
a354b88
931bb85
9ab6dc3
48fb75b
c350f7b
644477a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,79 @@ | ||||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|
|
||||
| import torch | ||||
|
|
||||
| from ..models._modeling_parallel import TensorParallelConfig | ||||
| from ..utils import get_logger | ||||
|
|
||||
|
|
||||
| logger = get_logger(__name__) # pylint: disable=invalid-name | ||||
|
|
||||
|
|
||||
| def apply_tensor_parallel( | ||||
| model: torch.nn.Module, | ||||
| config: TensorParallelConfig, | ||||
| double_block_plan: dict, | ||||
| single_block_plan: dict, | ||||
| ) -> None: | ||||
| """Apply tensor parallelism to a ``Flux2Transformer2DModel``. | ||||
|
|
||||
| This is the generic (non-Neuron) path. It calls | ||||
| ``torch.distributed.tensor.parallel.parallelize_module`` directly on each | ||||
| transformer block, using the plans defined on the model. | ||||
|
|
||||
| For Neuron, use ``apply_tp_flux2_transformer_neuron`` from | ||||
| ``diffusers.models.transformers.transformer_flux2_neuron_tp`` instead, which | ||||
| pre-shards weights via ``DTensor.from_local`` to work around the Neuron NRT | ||||
| consecutive-reduce-scatter bug. | ||||
|
|
||||
| Args: | ||||
| model (`torch.nn.Module`): | ||||
| A ``Flux2Transformer2DModel`` instance. Must have ``transformer_blocks`` | ||||
| and ``single_transformer_blocks`` attributes. | ||||
|
Comment on lines
+43
to
+44
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It cannot be specific to a particular model-type, right? |
||||
| config (`TensorParallelConfig`): | ||||
| TP configuration. ``config.setup()`` must have been called before this | ||||
| function so that ``config._mesh`` is populated. | ||||
| double_block_plan (`dict`): | ||||
| ``parallelize_module`` plan for each double-stream block | ||||
| (``model.transformer_blocks``). Keys are relative module paths | ||||
| (e.g. ``"attn.to_q"``), values are ``ColwiseParallel()`` / | ||||
| ``RowwiseParallel()`` instances. | ||||
| single_block_plan (`dict`): | ||||
| ``parallelize_module`` plan for each single-stream block | ||||
| (``model.single_transformer_blocks``). | ||||
| """ | ||||
| if not torch.distributed.is_available() or not torch.distributed.is_initialized(): | ||||
| raise RuntimeError( | ||||
| "apply_tensor_parallel requires an initialised torch.distributed process group." | ||||
| ) | ||||
|
|
||||
| try: | ||||
| from torch.distributed.tensor.parallel import parallelize_module | ||||
| except ImportError as e: | ||||
| raise ImportError( | ||||
| "apply_tensor_parallel requires PyTorch >= 2.3 with distributed tensor parallel support." | ||||
| ) from e | ||||
|
|
||||
| tp_mesh = config._mesh | ||||
| if tp_mesh is None: | ||||
| raise ValueError( | ||||
| "`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP." | ||||
| ) | ||||
|
|
||||
| for block in model.transformer_blocks: | ||||
| parallelize_module(block, tp_mesh, double_block_plan) | ||||
|
|
||||
| for block in model.single_transformer_blocks: | ||||
| parallelize_module(block, tp_mesh, single_block_plan) | ||||
|
Comment on lines
+69
to
+79
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make it similar to
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,6 @@ | |
| # - Unified Attention | ||
| # - More dispatcher attention backends | ||
| # - CFG/Data Parallel | ||
| # - Tensor Parallel | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -154,6 +153,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di | |
| self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() | ||
|
|
||
|
|
||
| @dataclass | ||
| class TensorParallelConfig: | ||
| """ | ||
| Configuration for tensor parallelism. | ||
|
|
||
| Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. | ||
| Each device computes a partial result; an AllReduce/AllGather at layer boundaries | ||
| reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module`` | ||
| with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. | ||
|
|
||
| On Neuron, use the ``_pre_shard_and_tp`` workaround from | ||
| ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug | ||
| on large tensors (>= 5120x5120). | ||
|
|
||
| Args: | ||
| tp_degree (`int`, defaults to `1`): | ||
| Number of devices to shard across. Must be a divisor of the number of | ||
| attention heads (and FFN hidden dimensions) of the model being parallelised. | ||
| mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): | ||
| A custom device mesh to use. If provided, ``tp_degree`` is inferred from | ||
| ``mesh.size()`` and the argument is ignored. Useful when combining TP with | ||
| other parallelism strategies (e.g. CP) that share the same mesh. | ||
|
Comment on lines
+175
to
+177
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you provide an example for this? |
||
| """ | ||
|
|
||
| tp_degree: int = 1 | ||
| mesh: torch.distributed.device_mesh.DeviceMesh | None = None | ||
|
|
||
| _rank: int = None | ||
| _world_size: int = None | ||
| _device: torch.device = None | ||
| _mesh: torch.distributed.device_mesh.DeviceMesh = None | ||
|
|
||
| def __post_init__(self): | ||
| if self.tp_degree < 1: | ||
| raise ValueError("`tp_degree` must be >= 1.") | ||
|
|
||
| def setup( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this supposed to be called from? |
||
| self, | ||
| rank: int, | ||
| world_size: int, | ||
| device: torch.device, | ||
| mesh: torch.distributed.device_mesh.DeviceMesh | None = None, | ||
| ): | ||
| self._rank = rank | ||
| self._world_size = world_size | ||
| self._device = device | ||
| if mesh is not None: | ||
| self._mesh = mesh | ||
| elif self.mesh is not None: | ||
| self._mesh = self.mesh | ||
| else: | ||
| from torch.distributed.device_mesh import init_device_mesh | ||
|
|
||
| device_type = str(device).split(":")[0] | ||
| self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ParallelConfig: | ||
| """ | ||
|
|
@@ -162,9 +218,12 @@ class ParallelConfig: | |
| Args: | ||
| context_parallel_config (`ContextParallelConfig`, *optional*): | ||
| Configuration for context parallelism. | ||
| tensor_parallel_config (`TensorParallelConfig`, *optional*): | ||
| Configuration for tensor parallelism. | ||
| """ | ||
|
|
||
| context_parallel_config: ContextParallelConfig | None = None | ||
| tensor_parallel_config: TensorParallelConfig | None = None | ||
|
|
||
| _rank: int = None | ||
| _world_size: int = None | ||
|
|
@@ -185,6 +244,8 @@ def setup( | |
| self._mesh = mesh | ||
| if self.context_parallel_config is not None: | ||
| self.context_parallel_config.setup(rank, world_size, device, mesh) | ||
| if self.tensor_parallel_config is not None: | ||
| self.tensor_parallel_config.setup(rank, world_size, device, mesh) | ||
|
Comment on lines
245
to
+248
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's raise if both |
||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like an unrelated change?