diff --git a/docs/source/en/api/parallel.md b/docs/source/en/api/parallel.md index f2a6bee3910e..5f300d5dd566 100644 --- a/docs/source/en/api/parallel.md +++ b/docs/source/en/api/parallel.md @@ -22,3 +22,9 @@ Parallelism strategies help speed up diffusion transformers by distributing comp [[autodoc]] ContextParallelConfig [[autodoc]] hooks.apply_context_parallel + +## TensorParallelConfig + +[[autodoc]] TensorParallelConfig + +[[autodoc]] hooks.apply_tensor_parallel diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 08b0262a9ef9..66ab00e56461 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -431,3 +431,73 @@ pipeline = DiffusionPipeline.from_pretrained( CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, ).to(device) ``` + +## Tensor parallelism + +[Tensor parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) shards the weight matrices of a model across devices. Each device holds a column-wise (`"colwise"`) or row-wise (`"rowwise"`) slice of each layer, computes a partial result, and an `AllReduce`/`AllGather` at the layer boundary reconstructs the full output. Unlike context parallelism, it reduces the per-device *weight* memory, which is useful for models that do not fit on a single device. + +Pass a [`TensorParallelConfig`] to [`~ModelMixin.enable_parallelism`]. `tp_degree` is the number of devices to shard across and must divide the model's number of attention heads. The model must define a `_tp_plan` (a flat mapping of module-name globs to a `"colwise"`/`"rowwise"` style); [`Flux2Transformer2DModel`] ships one. + +```py +import torch +from torch import distributed as dist +from diffusers import DiffusionPipeline, TensorParallelConfig + +def setup_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + return device + +def main(): + device = setup_distributed() + world_size = dist.get_world_size() + + pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16 + ).to(device) + + pipeline.transformer.enable_parallelism(config=TensorParallelConfig(tp_degree=world_size)) + + generator = torch.Generator().manual_seed(42) + image = pipeline("a cat holding a sign that says hello", generator=generator).images[0] + if dist.get_rank() == 0: + image.save("output.png") + if dist.is_initialized(): + dist.destroy_process_group() + +if __name__ == "__main__": + main() +``` + +```shell +torchrun --nproc-per-node 2 above_script.py +``` + +### Custom device mesh (combining with context parallelism) + +`TensorParallelConfig` (and `ContextParallelConfig`) accept a custom `mesh`. This lets you carve a multi-dimensional [device mesh](https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html) and pass the relevant sub-mesh to each strategy. For example, a `tp × ring × ulysses` layout: + +```py +from torch.distributed.device_mesh import init_device_mesh + +mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("tp", "ring", "ulysses")) + +tp_config = TensorParallelConfig(mesh=mesh["tp"]) +``` + +When a custom `mesh` is supplied, `tp_degree` is inferred from `mesh.size()`. + +> [!WARNING] +> Combining context parallelism and tensor parallelism in a single `enable_parallelism()` call is not yet supported — passing both a `context_parallel_config` and a `tensor_parallel_config` raises an error. Enable one strategy at a time for now. + +### Neuron (AWS Trainium/Inferentia) + +On AWS Neuron, `enable_parallelism` automatically selects a pre-shard path that works around an NRT consecutive-reduce-scatter limitation on large weights. Because the weights are sharded on CPU before being placed on the device, **call `enable_parallelism` while the transformer is still on CPU, then move the pipeline to the Neuron device**: + +```py +pipeline.transformer.enable_parallelism(config=TensorParallelConfig(tp_degree=8)) +pipeline.transformer = pipeline.transformer.to("xla") +``` diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6353347503e1..34f499cdaf4b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -313,6 +313,7 @@ "StableCascadeUNet", "T2IAdapter", "T5FilmDecoder", + "TensorParallelConfig", "Transformer2DModel", "TransformerTemporalModel", "UNet1DModel", @@ -1172,6 +1173,7 @@ StableAudioDiTModel, T2IAdapter, T5FilmDecoder, + TensorParallelConfig, Transformer2DModel, TransformerTemporalModel, UNet1DModel, diff --git a/src/diffusers/hooks/tensor_parallel.py b/src/diffusers/hooks/tensor_parallel.py new file mode 100644 index 000000000000..19799098080d --- /dev/null +++ b/src/diffusers/hooks/tensor_parallel.py @@ -0,0 +1,154 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ..models._modeling_parallel import TensorParallelConfig +from ..utils import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +def _get_module(model: torch.nn.Module, path: str) -> torch.nn.Module: + """Resolve a dotted (wildcard-free) module path relative to ``model``.""" + submodule = model + if path: + for atom in path.split("."): + if not hasattr(submodule, atom): + raise ValueError(f"'{atom}' is not a submodule of '{submodule.__class__.__name__}'") + submodule = getattr(submodule, atom) + return submodule + + +def _resolve_tp_plan(model: torch.nn.Module, tp_plan: dict) -> list: + """Group a flat ``_tp_plan`` into per-module ``parallelize_module`` plans. + + ``tp_plan`` maps module-name globs (relative to ``model``) to a style string, e.g. + ``{"transformer_blocks.*.attn.to_q": "colwise"}``. Each glob is split at its single ``*``: + the prefix must resolve to a ``ModuleList`` and the suffix becomes the per-element relative + key. Entries are grouped by the repeated block instance so the caller issues one + ``parallelize_module`` call per block (required so ``RowwiseParallel`` attaches its input + redistribution at the block boundary). Keys without a ``*`` are grouped under the model itself. + + Returns: + A list of ``(submodule, {relative_path: style_str})`` tuples, in plan order. + """ + grouped: dict[int, tuple] = {} + order: list[int] = [] + + for pattern, style in tp_plan.items(): + if pattern.count("*") > 1: + raise ValueError(f"Wildcard '*' can only be used once in a `_tp_plan` key, got '{pattern}'.") + + if "*" in pattern: + prefix, _, suffix = pattern.partition("*") + container = _get_module(model, prefix.strip(".")) + if not isinstance(container, torch.nn.ModuleList): + raise ValueError( + f"`_tp_plan` wildcard '{pattern}' must expand over a `ModuleList`, but " + f"'{prefix.strip('.')}' resolved to '{container.__class__.__name__}'." + ) + relative = suffix.strip(".") + blocks = list(container) + else: + relative = pattern + blocks = [model] + + for block in blocks: + key = id(block) + if key not in grouped: + grouped[key] = (block, {}) + order.append(key) + grouped[key][1][relative] = style + + return [grouped[key] for key in order] + + +def _styles(relative_plan: dict) -> dict: + """Map a ``{relative_path: style_str}`` plan to ``parallelize_module`` style instances.""" + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + mapping = {"colwise": ColwiseParallel, "rowwise": RowwiseParallel} + resolved = {} + for path, style in relative_plan.items(): + if style not in mapping: + raise ValueError( + f"Unsupported tensor-parallel style '{style}' for '{path}'. Expected one of {list(mapping)}." + ) + resolved[path] = mapping[style]() + return resolved + + +def apply_tensor_parallel( + model: torch.nn.Module, + config: TensorParallelConfig, + tp_plan: dict, + *, + backend: str = "default", +) -> None: + """Apply tensor parallelism to a model from its flat ``_tp_plan``. + + This is model-agnostic: it only relies on ``tp_plan`` (a flat mapping of module-name globs to + ``"colwise"``/``"rowwise"`` styles) and on the device mesh stored on ``config``. The attention + processors derive their per-rank head/inner sizes from ``config`` at runtime, so no processor + swap is needed. + + Args: + model (`torch.nn.Module`): + The model to shard (e.g. a ``Flux2Transformer2DModel``). + config (`TensorParallelConfig`): + TP configuration. ``config.setup()`` must have been called so that ``config._mesh`` is + populated. + tp_plan (`dict`): + The model's ``_tp_plan`` (see :class:`~diffusers.models.transformers.Flux2Transformer2DModel`). + backend (`str`, *optional*, defaults to `"default"`): + ``"default"`` uses ``torch.distributed.tensor.parallel.parallelize_module`` directly. + ``"neuron"`` routes to the Neuron pre-shard path, which works around the Neuron NRT + consecutive-reduce-scatter bug and applies the Flux2 fused-weight permutations. + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + raise RuntimeError("apply_tensor_parallel requires an initialised torch.distributed process group.") + + tp_mesh = config._mesh + if tp_mesh is None: + raise ValueError("`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP.") + + groups = _resolve_tp_plan(model, tp_plan) + logger.debug(f"Applying tensor parallel (backend={backend}) over {len(groups)} module group(s) on mesh {tp_mesh}.") + + if backend == "neuron": + from ..models.transformers.transformer_flux2_neuron_tp import _apply_tp_neuron + + _apply_tp_neuron(model, tp_mesh, groups) + return + + try: + from torch.distributed.tensor.parallel import parallelize_module + except ImportError as e: + raise ImportError( + "apply_tensor_parallel requires PyTorch >= 2.3 with distributed tensor parallel support." + ) from e + + # Some models fuse projections into single Linear layers (e.g. Flux2's SwiGLU FFN and fused + # QKV+MLP). Their weights must be re-ordered before contiguous sharding so each rank gets a + # correct paired slice. + permuters = getattr(model, "_tp_fused_block_permuters", None) or {} + tp_size = tp_mesh.size() + + for submodule, relative_plan in groups: + permuter = permuters.get(submodule.__class__.__name__) + if permuter is not None: + permuter(submodule, tp_size) + parallelize_module(submodule, tp_mesh, _styles(relative_plan)) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7a1d0801f2c5..74548acf184a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -25,7 +25,7 @@ _import_structure = {} if is_torch_available(): - _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"] + _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig", "TensorParallelConfig"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] @@ -162,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from ._modeling_parallel import ContextParallelConfig, ParallelConfig + from ._modeling_parallel import ContextParallelConfig, ParallelConfig, TensorParallelConfig from .adapter import MultiAdapter, T2IAdapter from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 56e1eced9eef..98047abecfd8 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -35,7 +35,6 @@ # - Unified Attention # - More dispatcher attention backends # - CFG/Data Parallel -# - Tensor Parallel @dataclass @@ -154,6 +153,67 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() +@dataclass +class TensorParallelConfig: + """ + Configuration for tensor parallelism. + + Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. + Each device computes a partial result; an AllReduce/AllGather at layer boundaries + reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module`` + with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. + + On Neuron, use the ``_pre_shard_and_tp`` workaround from + ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug + on large tensors (>= 5120x5120). + + Args: + tp_degree (`int`, defaults to `1`): + Number of devices to shard across. Must be a divisor of the number of + attention heads (and FFN hidden dimensions) of the model being parallelised. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use. If provided, ``tp_degree`` is inferred from + ``mesh.size()`` and the argument is ignored. Useful when combining TP with + other parallelism strategies (e.g. CP) that share the same mesh. + """ + + tp_degree: int = 1 + mesh: torch.distributed.device_mesh.DeviceMesh | None = None + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None + + def __post_init__(self): + if self.tp_degree < 1: + raise ValueError("`tp_degree` must be >= 1.") + + def setup( + self, + rank: int, + world_size: int, + device: torch.device, + mesh: torch.distributed.device_mesh.DeviceMesh | None = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + if mesh is not None: + self._mesh = mesh + elif self.mesh is not None: + self._mesh = self.mesh + else: + from torch.distributed.device_mesh import init_device_mesh + + device_type = str(device).split(":")[0] + self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) + + # Keep ``tp_degree`` consistent with the mesh actually used (a custom mesh wins). The + # attention processors read ``tp_degree`` at runtime to compute their per-rank sizes. + self.tp_degree = self._mesh.size() + + @dataclass class ParallelConfig: """ @@ -162,9 +222,12 @@ class ParallelConfig: Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism. + tensor_parallel_config (`TensorParallelConfig`, *optional*): + Configuration for tensor parallelism. """ context_parallel_config: ContextParallelConfig | None = None + tensor_parallel_config: TensorParallelConfig | None = None _rank: int = None _world_size: int = None @@ -185,6 +248,20 @@ def setup( self._mesh = mesh if self.context_parallel_config is not None: self.context_parallel_config.setup(rank, world_size, device, mesh) + if self.tensor_parallel_config is not None: + self.tensor_parallel_config.setup(rank, world_size, device, mesh) + + @property + def _cp_world_size(self) -> int: + """Context-parallel world size, or 1 when context parallelism is not enabled. + + Lets attention backends branch on context parallelism without dereferencing a possibly + ``None`` ``context_parallel_config`` (e.g. when only tensor parallelism is active). + """ + cp = self.context_parallel_config + if cp is None or cp._world_size is None: + return 1 + return cp._world_size @dataclass(frozen=True) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d9920a877112..76feff885bcd 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1117,7 +1117,7 @@ def _flash_attention_forward_op( scale = query.shape[-1] ** (-0.5) # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. - if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + if grad_enabled or (_parallel_config is not None and _parallel_config._cp_world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 with torch.set_grad_enabled(grad_enabled): @@ -1225,7 +1225,7 @@ def _flash_attention_hub_forward_op( deterministic = False grad_enabled = any(x.requires_grad for x in (query, key, value)) - if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + if grad_enabled or (_parallel_config is not None and _parallel_config._cp_world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 with torch.set_grad_enabled(grad_enabled): @@ -1337,7 +1337,7 @@ def _flash_varlen_attention_hub_forward_op( deterministic = False grad_enabled = any(x.requires_grad for x in (query, key, value)) - if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + if grad_enabled or (_parallel_config is not None and _parallel_config._cp_world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 batch_size, seq_len_q, num_heads, _ = query.shape @@ -2664,7 +2664,7 @@ def _flash_attention( if attn_mask is not None: raise ValueError("`attn_mask` is not supported for flash-attn 2.") - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: out = flash_attn_func( q=query, k=key, @@ -2721,7 +2721,7 @@ def _flash_attention_hub( raise ValueError("`attn_mask` is not supported for flash-attn 2.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: out = func( q=query, k=key, @@ -2773,14 +2773,18 @@ def _flash_varlen_attention_hub( return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: - if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1: + if ( + _parallel_config is not None + and _parallel_config.context_parallel_config is not None + and _parallel_config.context_parallel_config.ring_degree > 1 + ): raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.") lse = None batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( @@ -2944,7 +2948,7 @@ def _flash_attention_3_hub( raise ValueError("`attn_mask` is not supported for flash-attn 3.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: out = func( q=query, k=key, @@ -3324,7 +3328,7 @@ def _native_attention( # SDPA handles both boolean and additive masks correctly attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query, @@ -3550,7 +3554,7 @@ def _native_npu_attention( ) -> torch.Tensor: if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) out = npu_fusion_attention( @@ -3634,7 +3638,7 @@ def _sage_attention( if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") lse = None - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: out = sageattn( q=query, k=key, @@ -3686,7 +3690,7 @@ def _sage_attention_hub( raise ValueError("`attn_mask` is not supported for sage attention") lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn - if _parallel_config is None: + if _parallel_config is None or _parallel_config.context_parallel_config is None: out = func( q=query, k=key, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 41b0f689d9a4..47142ecbe082 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -58,13 +58,19 @@ is_bitsandbytes_version, is_flashpack_available, is_peft_available, + is_torch_neuronx_available, is_torch_version, logging, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card from ..utils.torch_utils import empty_device_cache -from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig +from ._modeling_parallel import ( + ContextParallelConfig, + ContextParallelModelPlan, + ParallelConfig, + TensorParallelConfig, +) from .model_loading_utils import ( _caching_allocator_warmup, _determine_device_map, @@ -250,6 +256,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _repeated_blocks = [] _parallel_config = None _cp_plan = None + _tp_plan = None + _tp_fused_block_permuters = None _skip_keys = None def __init__(self): @@ -1612,7 +1620,7 @@ def compile_repeated_blocks(self, *args, **kwargs): def enable_parallelism( self, *, - config: ParallelConfig | ContextParallelConfig, + config: ParallelConfig | ContextParallelConfig | TensorParallelConfig, cp_plan: dict[str, ContextParallelModelPlan] | None = None, ): logger.warning( @@ -1631,6 +1639,14 @@ def enable_parallelism( if isinstance(config, ContextParallelConfig): config = ParallelConfig(context_parallel_config=config) + elif isinstance(config, TensorParallelConfig): + config = ParallelConfig(tensor_parallel_config=config) + + if config.context_parallel_config is not None and config.tensor_parallel_config is not None: + raise ValueError( + "Combining context parallelism and tensor parallelism in a single " + "`enable_parallelism()` call is not yet supported. Please enable only one at a time." + ) rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() @@ -1676,7 +1692,16 @@ def enable_parallelism( mesh_shape=cp_config.mesh_shape, mesh_dim_names=cp_config.mesh_dim_names, ) + elif config.tensor_parallel_config is not None: + tp_config = config.tensor_parallel_config + mesh = tp_config.mesh or torch.distributed.device_mesh.init_device_mesh( + device_type=device_type, + mesh_shape=(tp_config.tp_degree,), + mesh_dim_names=("tp",), + ) + # `config.setup()` is the single place the CP/TP mesh is recorded onto the config (and, + # for TP, `tp_degree` is synced to the actual mesh size); see `ParallelConfig.setup`. config.setup(rank, world_size, device, mesh=mesh) self._parallel_config = config @@ -1696,6 +1721,28 @@ def enable_parallelism( cp_plan = cp_plan if cp_plan is not None else self._cp_plan apply_context_parallel(self, config.context_parallel_config, cp_plan) + if config.tensor_parallel_config is not None: + if self._tp_plan is None: + raise ValueError( + "`_tp_plan` must be set on the model class to use tensor parallelism. " + f"'{self.__class__.__name__}' does not define one." + ) + tp_degree = config.tensor_parallel_config.tp_degree + num_heads = getattr(self.config, "num_attention_heads", None) + if num_heads is not None and num_heads % tp_degree != 0: + raise ValueError(f"`tp_degree` ({tp_degree}) must divide the number of attention heads ({num_heads}).") + + from ..hooks.tensor_parallel import apply_tensor_parallel + + # The Neuron pre-shard path works around the NRT consecutive-reduce-scatter bug. Neuron + # does not surface as the torch accelerator (`torch._C._get_accelerator().type` is + # "cpu"), so detect it from the TP mesh's device type instead — on Neuron the mesh is a + # `DeviceMesh("neuron", ...)`. + tp_mesh = config.tensor_parallel_config._mesh + mesh_device_type = tp_mesh.device_type if tp_mesh is not None else device_type + backend = "neuron" if (is_torch_neuronx_available() and mesh_device_type == "neuron") else "default" + apply_tensor_parallel(self, config.tensor_parallel_config, self._tp_plan, backend=backend) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 17c8bd0ffd52..cceb7907e5b9 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -281,6 +281,18 @@ def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_s return _get_projections(attn, hidden_states, encoder_hidden_states) +def _get_tp_degree(parallel_config) -> int: + """Return the tensor-parallel degree from a processor's ``_parallel_config`` (1 if TP is off). + + When tensor parallelism is enabled, each rank holds ``attn.heads // tp_degree`` heads (and a + proportionally smaller slice of the fused inner / MLP dims), so the attention processors derive + their local sizes from this value at runtime. ``tp_degree == 1`` recovers the non-TP behavior. + """ + if parallel_config is not None and getattr(parallel_config, "tensor_parallel_config", None) is not None: + return parallel_config.tensor_parallel_config.tp_degree + return 1 + + class Flux2SwiGLU(nn.Module): """ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection @@ -339,21 +351,25 @@ def __call__( attention_mask: torch.Tensor | None = None, image_rotary_emb: torch.Tensor | None = None, ) -> torch.Tensor: + # Under tensor parallelism each rank holds ``attn.heads // tp_degree`` heads after the + # column-wise weight sharding; ``tp_degree == 1`` recovers the non-TP behavior. + local_heads = attn.heads // _get_tp_degree(self._parallel_config) + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( attn, hidden_states, encoder_hidden_states ) - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) + query = query.unflatten(-1, (local_heads, -1)) + key = key.unflatten(-1, (local_heads, -1)) + value = value.unflatten(-1, (local_heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) if attn.added_kv_proj_dim is not None: - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + encoder_query = encoder_query.unflatten(-1, (local_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (local_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (local_heads, -1)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) @@ -581,18 +597,24 @@ def __call__( attention_mask: torch.Tensor | None = None, image_rotary_emb: torch.Tensor | None = None, ) -> torch.Tensor: + # Under tensor parallelism the fused ``to_qkv_mlp_proj`` is column-sharded, so each rank + # holds a proportionally smaller slice of the Q/K/V (inner) and MLP dims, and ``attn.heads + # // tp_degree`` heads. ``tp_degree == 1`` recovers the non-TP behavior. + tp_degree = _get_tp_degree(self._parallel_config) + local_heads = attn.heads // tp_degree + local_inner = attn.inner_dim // tp_degree + local_mlp = attn.mlp_hidden_dim * attn.mlp_mult_factor // tp_degree + # Parallel in (QKV + MLP in) projection hidden_states = attn.to_qkv_mlp_proj(hidden_states) - qkv, mlp_hidden_states = torch.split( - hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 - ) + qkv, mlp_hidden_states = torch.split(hidden_states, [3 * local_inner, local_mlp], dim=-1) # Handle the attention logic query, key, value = qkv.chunk(3, dim=-1) - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) + query = query.unflatten(-1, (local_heads, -1)) + key = key.unflatten(-1, (local_heads, -1)) + value = value.unflatten(-1, (local_heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) @@ -1036,6 +1058,107 @@ def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, t return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets)) +# ─── Tensor-parallel fused-weight permutations ────────────────────────────────────────────────── +# Flux2 fuses several projections into single Linear layers (SwiGLU gate+linear in the FFN, and +# Q/K/V + MLP in the single-stream ``to_qkv_mlp_proj`` / ``to_out``). Column/row sharding slices a +# weight contiguously and has no knowledge of that internal layout, so each rank would receive an +# unpaired slice (e.g. all-gate / no-linear). These helpers re-order the weight rows/columns so the +# contiguous slice each rank receives is the correct paired chunk. This is a property of Flux2's +# fused layers, not of any device backend, so it is applied on both the generic and Neuron paths. + + +def _permute_swiglu_for_tp(weight: torch.Tensor, tp_size: int) -> torch.Tensor: + """Interleave gate/linear chunks of a SwiGLU FFN weight (``ff.linear_in``) for column-wise TP. + + Re-orders ``[gate_0…gate_N, linear_0…linear_N]`` to ``[gate_0, linear_0, gate_1, linear_1, …]`` + so a contiguous row slice gives each rank paired gate+linear rows. + """ + with torch.no_grad(): + total = weight.shape[0] + inner = total // 2 + chunk = inner // tp_size + gate = weight[:inner] + linear = weight[inner:] + parts = [] + for i in range(tp_size): + parts.append(gate[i * chunk : (i + 1) * chunk]) + parts.append(linear[i * chunk : (i + 1) * chunk]) + return torch.cat(parts, dim=0) + + +def _permute_qkv_mlp_for_tp(weight: torch.Tensor, tp_size: int, inner_dim: int, mlp_hidden_dim: int) -> torch.Tensor: + """Interleave Q/K/V/gate/linear chunks of the fused ``to_qkv_mlp_proj`` weight for column-wise TP. + + Re-orders ``[Q, K, V, mlp_gate, mlp_linear]`` so rank *r* receives a contiguous slice with its + proportional share of each component. + """ + with torch.no_grad(): + q = weight[:inner_dim] + k = weight[inner_dim : 2 * inner_dim] + v = weight[2 * inner_dim : 3 * inner_dim] + mlp_gate = weight[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim] + mlp_lin = weight[3 * inner_dim + mlp_hidden_dim :] + + qkv_chunk = inner_dim // tp_size + mlp_chunk = mlp_hidden_dim // tp_size + + parts = [] + for i in range(tp_size): + parts += [ + q[i * qkv_chunk : (i + 1) * qkv_chunk], + k[i * qkv_chunk : (i + 1) * qkv_chunk], + v[i * qkv_chunk : (i + 1) * qkv_chunk], + mlp_gate[i * mlp_chunk : (i + 1) * mlp_chunk], + mlp_lin[i * mlp_chunk : (i + 1) * mlp_chunk], + ] + return torch.cat(parts, dim=0) + + +def _permute_out_for_tp(weight: torch.Tensor, tp_size: int, attn_dim: int, mlp_dim: int) -> torch.Tensor: + """Interleave attn/mlp input columns of the fused ``to_out`` weight for row-wise TP. + + Re-orders the ``[attn_out, mlp_out]`` input columns so rank *r* receives a contiguous slice of + paired attn+mlp columns. + """ + with torch.no_grad(): + attn_part = weight[:, :attn_dim] + mlp_part = weight[:, attn_dim:] + + attn_chunk = attn_dim // tp_size + mlp_chunk = mlp_dim // tp_size + + parts = [] + for i in range(tp_size): + parts.append(attn_part[:, i * attn_chunk : (i + 1) * attn_chunk]) + parts.append(mlp_part[:, i * mlp_chunk : (i + 1) * mlp_chunk]) + return torch.cat(parts, dim=1) + + +def _permute_flux2_double_block(block: nn.Module, tp_size: int) -> None: + """Permute the SwiGLU FFN weights of a Flux2 double-stream block in place.""" + block.ff.linear_in.weight.data = _permute_swiglu_for_tp(block.ff.linear_in.weight.data, tp_size) + block.ff_context.linear_in.weight.data = _permute_swiglu_for_tp(block.ff_context.linear_in.weight.data, tp_size) + + +def _permute_flux2_single_block(block: nn.Module, tp_size: int) -> None: + """Permute the fused QKV+MLP / output weights of a Flux2 single-stream block in place.""" + attn = block.attn + attn.to_qkv_mlp_proj.weight.data = _permute_qkv_mlp_for_tp( + attn.to_qkv_mlp_proj.weight.data, tp_size, attn.inner_dim, attn.mlp_hidden_dim + ) + attn.to_out.weight.data = _permute_out_for_tp( + attn.to_out.weight.data, tp_size, attn.inner_dim, attn.mlp_hidden_dim + ) + + +# Maps block class name -> in-place fused-weight permuter. Consumed by ``apply_tensor_parallel`` +# (both the generic and Neuron backends) before the weights are sharded. +_FLUX2_TP_FUSED_PERMUTERS = { + "Flux2TransformerBlock": _permute_flux2_double_block, + "Flux2SingleTransformerBlock": _permute_flux2_single_block, +} + + class Flux2Transformer2DModel( ModelMixin, ConfigMixin, @@ -1090,6 +1213,35 @@ class Flux2Transformer2DModel( "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), } + # Tensor-parallel sharding plan: a flat mapping of fully-qualified module-name globs + # (relative to the model) to a parallel style string ("colwise" / "rowwise"). This is + # the same shape as ``_cp_plan`` and as ``transformers`` ``base_model_tp_plan``; it is + # consumed generically by ``apply_tensor_parallel`` (and its Neuron backend, which also + # applies weight permutations for Flux2's fused layers). Strings are mapped to + # ColwiseParallel/RowwiseParallel inside the hook to keep torch.distributed.tensor out + # of module import time. + _tp_plan = { + # double-stream (cross-attention + FFN) blocks + "transformer_blocks.*.attn.to_q": "colwise", + "transformer_blocks.*.attn.to_k": "colwise", + "transformer_blocks.*.attn.to_v": "colwise", + "transformer_blocks.*.attn.to_out.0": "rowwise", + "transformer_blocks.*.attn.add_q_proj": "colwise", + "transformer_blocks.*.attn.add_k_proj": "colwise", + "transformer_blocks.*.attn.add_v_proj": "colwise", + "transformer_blocks.*.attn.to_add_out": "rowwise", + "transformer_blocks.*.ff.linear_in": "colwise", + "transformer_blocks.*.ff.linear_out": "rowwise", + "transformer_blocks.*.ff_context.linear_in": "colwise", + "transformer_blocks.*.ff_context.linear_out": "rowwise", + # single-stream (parallel self-attn + fused MLP) blocks + "single_transformer_blocks.*.attn.to_qkv_mlp_proj": "colwise", + "single_transformer_blocks.*.attn.to_out": "rowwise", + } + # Per-block fused-weight permuters applied before sharding (Flux2 fuses gate+linear and + # Q/K/V+MLP into single Linears). Backend-agnostic: required on both generic and Neuron paths. + _tp_fused_block_permuters = _FLUX2_TP_FUSED_PERMUTERS + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py b/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py new file mode 100644 index 000000000000..fba905d1c188 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py @@ -0,0 +1,229 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neuron-specific Tensor Parallelism utilities for Flux2 and Qwen3. + +This module provides the functions needed to apply tensor parallelism on AWS +Neuron hardware. The key difference from the generic ``apply_tensor_parallel`` +path is a workaround for a Neuron NRT bug: consecutive ``reduce_scatter`` +collectives for large weight tensors (≥ 5120×5120) can fail when all layers +are distributed in a single ``parallelize_module`` call. The fix is to +pre-shard each weight locally on CPU via ``DTensor.from_local`` *before* +calling ``parallelize_module``; the latter then sees already-placed DTensors, +skips the collective for weights, but still registers the required +input/output hooks for the forward pass. + +Entry points: + ``apply_tp_flux2_transformer_neuron(model, tp_mesh)`` + Apply TP to a ``Flux2Transformer2DModel``. Includes the weight + permutations required by Flux2's SwiGLU FFN and fused QKV+MLP + projections. + + ``apply_tp_qwen3_neuron(model, tp_mesh)`` + Apply TP to a ``Qwen3ForCausalLM`` text encoder. The sharding plan is + derived from ``model.config.base_model_tp_plan`` — the same plan used + by ``from_pretrained(tp_plan="auto")`` in transformers — so it stays in + sync automatically if the plan changes upstream. +""" + +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.nn as nn + + +if TYPE_CHECKING: + from transformers import Qwen3ForCausalLM + + from .transformer_flux2 import Flux2Transformer2DModel + + +def _pre_shard_and_tp( + module: nn.Module, + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", + plan: dict, + rank: int, + tp_size: int, +) -> None: + """Pre-shard Linear weights via ``DTensor.from_local``, then call ``parallelize_module``. + + Workaround for a Neuron NRT bug where consecutive ``reduce_scatter`` calls + for large weight tensors (≥ 5120×5120) fail when all layers are distributed + in a single ``parallelize_module`` call. By pre-sharding each weight on CPU + before the call, ``distribute_tensor`` inside ``parallelize_module`` sees an + already-placed DTensor and skips the collective, while the module hooks + (input/output specs) are still registered correctly. + + Args: + module: The block whose Linear sub-modules are being sharded. + tp_mesh: Device mesh for TP (1-D, size == tp_size). + plan: ``{relative_path: ColwiseParallel() | RowwiseParallel()}`` dict, + as expected by ``parallelize_module``. + rank: Current rank (``dist.get_rank()``). + tp_size: Total TP degree (``tp_mesh.size()``). + """ + from torch.distributed.tensor import DTensor, Shard + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module + + device = torch.neuron.current_device() + + for path, style in plan.items(): + # Resolve nested attribute path (e.g. "attn.to_q" or "attn.to_out.0") + submod = module + for part in path.split("."): + submod = getattr(submod, part) + + if not hasattr(submod, "weight"): + continue + + w = submod.weight.data # CPU at this point + if isinstance(style, ColwiseParallel): + rows = w.shape[0] // tp_size + shard = w[rank * rows : (rank + 1) * rows, :].contiguous().to(device) + submod.weight = nn.Parameter(DTensor.from_local(shard, tp_mesh, [Shard(0)])) + elif isinstance(style, RowwiseParallel): + cols = w.shape[1] // tp_size + shard = w[:, rank * cols : (rank + 1) * cols].contiguous().to(device) + submod.weight = nn.Parameter(DTensor.from_local(shard, tp_mesh, [Shard(1)])) + + # parallelize_module is now a no-op for weight distribution (already DTensors) + # but still registers the input/output hooks required for the forward pass. + parallelize_module(module, tp_mesh, plan) + + +def _apply_tp_neuron( + model: nn.Module, + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", + groups: list, +) -> None: + """Apply tensor parallelism on Neuron from resolved ``_tp_plan`` groups. + + ``groups`` is produced by ``diffusers.hooks.tensor_parallel._resolve_tp_plan`` — the same + source of truth used by the generic path, so the two backends shard identical layers. For + each ``(block, relative_plan)`` group this: + 1. permutes the model's fused weights (via ``model._tp_fused_block_permuters``, the same + backend-agnostic permuters the generic path uses) so column/row slicing gives each rank a + correct chunk, + 2. pre-shards the weights via ``DTensor.from_local`` (Neuron NRT consecutive-reduce-scatter + workaround), then calls ``parallelize_module`` to register the forward hooks. + + The attention processors derive their per-rank sizes from ``_parallel_config`` at runtime, so + no processor swap is performed here. Model weights must be on CPU when this is called. + """ + from ...hooks.tensor_parallel import _styles + + rank = dist.get_rank() + tp_size = tp_mesh.size() + permuters = getattr(model, "_tp_fused_block_permuters", None) or {} + + for block, relative_plan in groups: + permuter = permuters.get(block.__class__.__name__) + if permuter is not None: + permuter(block, tp_size) + _pre_shard_and_tp(block, tp_mesh, _styles(relative_plan), rank, tp_size) + + +def apply_tp_flux2_transformer_neuron( + model: "Flux2Transformer2DModel", + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", +) -> "Flux2Transformer2DModel": + """Apply tensor parallelism to a ``Flux2Transformer2DModel`` on Neuron. + + Thin wrapper kept for direct/standalone use. The model weights must still be on CPU when this + is called; move the model to the Neuron device *after*:: + + apply_tp_flux2_transformer_neuron(pipe.transformer, tp_mesh) + pipe.transformer = pipe.transformer.to(device) + + Prefer the public API ``model.enable_parallelism(config=TensorParallelConfig(...))``, which + dispatches here automatically on Neuron. + + Args: + model: ``Flux2Transformer2DModel`` with weights on CPU. + tp_mesh: 1-D Neuron device mesh of size ``tp_size``. + + Returns: + The same ``model`` instance, modified in-place. + """ + from ...hooks.tensor_parallel import _resolve_tp_plan + + _apply_tp_neuron(model, tp_mesh, _resolve_tp_plan(model, model._tp_plan)) + return model + + +def apply_tp_qwen3_neuron( + model: "Qwen3ForCausalLM", + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", +) -> "Qwen3ForCausalLM": + """Apply tensor parallelism to a ``Qwen3ForCausalLM`` text encoder on Neuron. + + The sharding plan is derived from ``model.config.base_model_tp_plan`` — + the same plan used by ``from_pretrained(tp_plan="auto")`` in transformers — + so it stays in sync automatically if the plan changes upstream. + + ``"replicated_with_grad_allreduce"`` entries (Q/K norm layers) are skipped: + those layers require gradient all-reduce in training but need no weight + sharding for inference. + + Qwen3's separate ``gate_proj`` / ``up_proj`` projections require no weight + permutations (unlike Flux2's fused SwiGLU). + + The model weights must still be on CPU when this function is called:: + + apply_tp_qwen3_neuron(pipe.text_encoder, tp_mesh) + pipe.text_encoder = pipe.text_encoder.to(device) + + **Primary path**: try ``Qwen3ForCausalLM.from_pretrained(model_id, tp_plan="auto")`` + first — transformers' native TP may work on Neuron directly since its hook + mechanism does not use DTensor reduce_scatter. Fall back to this function if + the NRT bug is triggered. + + Args: + model: ``Qwen3ForCausalLM`` with weights on CPU. + tp_mesh: 1-D Neuron device mesh of size ``tp_size``. + + Returns: + The same ``model`` instance, modified in-place. + """ + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + rank = dist.get_rank() + tp_size = tp_mesh.size() + + style_map = { + "colwise": ColwiseParallel(), + "colwise_gather_output": ColwiseParallel(), # lm_head — same for inference + "rowwise": RowwiseParallel(), + # "replicated_with_grad_allreduce" → skipped (q_norm/k_norm, inference only) + } + + # config.base_model_tp_plan example: + # {"layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", ...} + per_layer_plan = { + path.split("*.")[1]: style_map[style] + for path, style in model.config.base_model_tp_plan.items() + if "*." in path and style in style_map + } + + if not per_layer_plan: + raise ValueError( + "Could not extract a per-layer TP plan from `model.config.base_model_tp_plan`. " + f"Got: {model.config.base_model_tp_plan}" + ) + + for layer in model.model.layers: + _pre_shard_and_tp(layer, tp_mesh, per_layer_plan, rank, tp_size) + + return model diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index b1645b4ae244..1a49a17dd8c7 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -369,7 +369,9 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return torch.stack(out_ids).float() @staticmethod def _prepare_latent_ids( @@ -401,7 +403,9 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return latent_ids.float() @staticmethod def _prepare_image_ids( @@ -451,7 +455,9 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return image_latent_ids.float() @staticmethod def _patchify_latents(latents): diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index d768e6127f26..c21ca3fe8810 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -279,7 +279,9 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return torch.stack(out_ids).float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids @@ -312,7 +314,9 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids @@ -363,7 +367,9 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return image_latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents @@ -883,6 +889,13 @@ def __call__( # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) + # When running with tensor parallelism all ranks run the same + # (deterministic) scheduler step, so this broadcast is a safety + # measure only — it keeps ranks in sync if numerical drift + # or non-determinism ever causes a divergence. + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.broadcast(latents, src=0) + if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -904,7 +917,13 @@ def __call__( # Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item() latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + latent_device = latents.device + if torch_device == "neuron": + latents = latents.cpu() + latent_ids = latent_ids.cpu() latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2) + if torch_device == "neuron": + latents = latents.to(latent_device) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index fd9467003a71..c495f79f87b2 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -326,7 +326,9 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return torch.stack(out_ids).float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids @@ -359,7 +361,9 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return latent_ids.float() @staticmethod def _prepare_image_ids( diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py index 78ed42f20afb..a5e8d6c75db1 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py @@ -284,7 +284,9 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return torch.stack(out_ids).float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids @@ -317,7 +319,9 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids @@ -368,7 +372,9 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + # Cast position ids to float32: these are RoPE coordinate indices and the Neuron + # compiler does not support int64 tensors. float32 is exact for this index range. + return image_latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8439a2b93371..0466cf5c3ac5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2055,6 +2055,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TensorParallelConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"]