From 98f6c8c257c4f2e00b209c188fc993feb93e3fc7 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 18 Mar 2026 11:15:03 +0000 Subject: [PATCH 01/12] draft:add neuron as a legit backend --- src/diffusers/pipelines/pipeline_utils.py | 58 ++++++++++++++++++++++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 ++ src/diffusers/utils/torch_utils.py | 25 ++++++++-- 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..44fe8367636d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,6 +68,7 @@ is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2248,6 +2249,61 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + def enable_neuron_compile( + self, + model_names: Optional[List[str]] = None, + cache_dir: Optional[str] = None, + fullgraph: bool = True, + ) -> None: + """ + Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``, + enabling whole-graph NEFF compilation for AWS Trainium/Inferentia. + + The first forward call per component triggers neuronx-cc compilation (slow). + Use ``neuron_warmup()`` to trigger this explicitly before timed inference. + + Args: + model_names (`List[str]`, *optional*): + Component names to compile. 
Defaults to all nn.Module components. + cache_dir (`str`, *optional*): + Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``. + Skips recompilation on subsequent runs. + fullgraph (`bool`, defaults to `True`): + Disallow graph breaks (required for full-graph fusion). + """ + requires_backends(self, "torch_neuronx") + import torch_neuronx # noqa: F401 — registers neuron backend + + if cache_dir is not None: + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir + + if model_names is None: + model_names = [ + name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module) + ] + + for name in model_names: + component = getattr(self, name, None) + if isinstance(component, torch.nn.Module) and not is_compiled_module(component): + logger.info(f"Compiling {name} with backend='neuron'") + setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) + + def neuron_warmup(self, *args, **kwargs) -> None: + """ + Runs a single dummy forward pass through the pipeline to trigger neuronx-cc + compilation for all components (static-shape NEFF compilation). + + This is equivalent to calling ``__call__`` with the same shapes but discards + the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast. + + Pass the same arguments you would use for real inference (height, width, + num_inference_steps, batch_size, etc.) so that the compiled shapes match. 
+ """ + logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...") + with torch.no_grad(): + self(*args, **kwargs) + logger.info("Neuron warmup complete.") + class StableDiffusionMixin: r""" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 23d7ac7c6c2d..8a86cf4f4151 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -110,6 +110,7 @@ is_timm_available, is_torch_available, is_torch_mlu_available, + is_torch_neuronx_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 551fa358a28d..e23fccc1a374 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -249,6 +250,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_neuronx_available(): + return _torch_neuronx_available + + def is_flax_available(): return _flax_available diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7f4cb3e12766..88b53e2b5b16 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -21,19 +21,26 @@ import os from . 
import logging -from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +from .import_utils import ( + is_torch_available, + is_torch_mlu_available, + is_torch_neuronx_available, + is_torch_npu_available, + is_torch_version, +) if is_torch_available(): import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "neuron": None, "default": None, } BACKEND_DEVICE_COUNT = { @@ -41,6 +48,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -48,6 +56,7 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + "neuron": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -55,6 +64,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -62,6 +72,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -69,6 +80,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "neuron": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -76,6 +88,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -164,11 +177,15 @@ def randn_tensor( layout = layout or torch.strided device = device or 
torch.device("cpu") + # Neuron (XLA) does not support creating random tensors directly on device; always use CPU + if device.type == "neuron": + rand_device = torch.device("cpu") + if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" - if device != "mps": + if device.type not in ("mps", "neuron"): logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" @@ -289,6 +306,8 @@ def get_device(): return "mps" elif is_torch_mlu_available(): return "mlu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + return "neuron" else: return "cpu" From a76953cf34aecd0efda8e364798102b3c71a0db2 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 26 Mar 2026 11:57:09 +0000 Subject: [PATCH 02/12] feat: neuron-specific changes in the pipeline --- .../models/unets/unet_2d_condition.py | 5 ++-- src/diffusers/pipelines/pipeline_utils.py | 2 ++ .../pipeline_stable_diffusion_xl.py | 26 ++++++++++++++++--- src/diffusers/utils/import_utils.py | 5 ++++ src/diffusers/utils/torch_utils.py | 2 +- 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..b533bef35414 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" is_npu = sample.device.type == "npu" + is_neuron = sample.device.type == "neuron" if isinstance(timestep, float): - dtype = torch.float32 if (is_mps 
or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 44fe8367636d..5b329f46e2aa 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2273,6 +2273,7 @@ def enable_neuron_compile( """ requires_backends(self, "torch_neuronx") import torch_neuronx # noqa: F401 — registers neuron backend + from torch_neuronx.neuron_dynamo_backend import set_model_name if cache_dir is not None: os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir @@ -2286,6 +2287,7 @@ def enable_neuron_compile( component = getattr(self, name, None) if isinstance(component, torch.nn.Module) and not is_compiled_module(component): logger.info(f"Compiling {name} with backend='neuron'") + set_model_name(name) setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) def neuron_warmup(self, *args, **kwargs) -> None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2f6b105702e8..fdda2547f09e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1092,7 +1092,11 @@ def __call__( ) # 4. Prepare timesteps - if XLA_AVAILABLE: + # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where + # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() + # are incompatible with static-graph compilation. 
+ is_neuron_device = hasattr(device, "type") and device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -1195,15 +1199,23 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. + # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. + if is_neuron_device: + latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) + else: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds + # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support + # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. + t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( latent_model_input, - t, + t_unet, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, @@ -1222,7 +1234,13 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. 
+ if is_neuron_device: + latents = self.scheduler.step( + noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False + )[0].to(device) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e23fccc1a374..2ce989626b3d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -584,6 +584,10 @@ def is_av_available(): """ +TORCH_NEURONX_IMPORT_ERROR = """ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -614,6 +618,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7d16c8556689..55fee1d3249e 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -93,7 +93,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, - "neuron": None, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 2480388fb12a423527d491ab5211c058b07b3262 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 27 
Mar 2026 17:55:26 +0000 Subject: [PATCH 03/12] tests: eager tests --- src/diffusers/pipelines/pipeline_utils.py | 58 --------------------- src/diffusers/utils/testing_utils.py | 3 ++ tests/pipelines/pixart_alpha/test_pixart.py | 10 +++- tests/testing_utils.py | 26 +++++++-- 4 files changed, 33 insertions(+), 64 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 5b329f46e2aa..bbee2189c22f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2249,64 +2249,6 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 - def enable_neuron_compile( - self, - model_names: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - fullgraph: bool = True, - ) -> None: - """ - Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``, - enabling whole-graph NEFF compilation for AWS Trainium/Inferentia. - - The first forward call per component triggers neuronx-cc compilation (slow). - Use ``neuron_warmup()`` to trigger this explicitly before timed inference. - - Args: - model_names (`List[str]`, *optional*): - Component names to compile. Defaults to all nn.Module components. - cache_dir (`str`, *optional*): - Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``. - Skips recompilation on subsequent runs. - fullgraph (`bool`, defaults to `True`): - Disallow graph breaks (required for full-graph fusion). 
- """ - requires_backends(self, "torch_neuronx") - import torch_neuronx # noqa: F401 — registers neuron backend - from torch_neuronx.neuron_dynamo_backend import set_model_name - - if cache_dir is not None: - os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir - - if model_names is None: - model_names = [ - name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module) - ] - - for name in model_names: - component = getattr(self, name, None) - if isinstance(component, torch.nn.Module) and not is_compiled_module(component): - logger.info(f"Compiling {name} with backend='neuron'") - set_model_name(name) - setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) - - def neuron_warmup(self, *args, **kwargs) -> None: - """ - Runs a single dummy forward pass through the pipeline to trigger neuronx-cc - compilation for all components (static-shape NEFF compilation). - - This is equivalent to calling ``__call__`` with the same shapes but discards - the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast. - - Pass the same arguments you would use for real inference (height, width, - num_inference_steps, batch_size, etc.) so that the compiled shapes match. 
- """ - logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...") - with torch.no_grad(): - self(*args, **kwargs) - logger.info("Neuron warmup complete.") - - class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 619a37034949..eefe52c477a6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -46,6 +46,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -113,6 +114,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 037a9f44f31e..0aa6812c6b25 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -28,6 +28,8 @@ PixArtTransformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available + from ...testing_utils import ( backend_empty_cache, enable_full_determinism, @@ -291,7 +293,9 @@ def test_pixart_1024(self): expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_512(self): generator = torch.Generator("cpu").manual_seed(0) @@ -307,7 +311,9 @@ def 
test_pixart_512(self): expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_1024_without_resolution_binning(self): generator = torch.manual_seed(0) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 53c1b8aa26ce..778381cf31e0 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -45,6 +45,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -109,6 +110,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( @@ -1427,6 +1430,15 @@ def _is_torch_fp64_available(device): # Behaviour flags BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + # Neuron device key: torch.neuron.current_device() returns an int (e.g. 0). + # We capture it once at import time if torch_neuronx is available so we can add it + # to all dispatch tables using the same key that torch_device is set to. 
+ _neuron_device = ( + torch.neuron.current_device() + if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()) + else None + ) + # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, @@ -1478,13 +1490,19 @@ def _is_torch_fp64_available(device): "default": None, } + if _neuron_device is not None: + BACKEND_EMPTY_CACHE[_neuron_device] = None + BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count + BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed + BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None + BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None + BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0 + BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize + # This dispatches a defined function according to the accelerator from the function definitions. def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs): - if device not in dispatch_table: - return dispatch_table["default"](*args, **kwargs) - - fn = dispatch_table[device] + fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"] # Some device agnostic functions return values. 
Need to guard against 'None' instead at # user level From 1469c04ff43b406465baaa4b32207f15a655c5ff Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 9 Apr 2026 16:32:41 +0000 Subject: [PATCH 04/12] draft: start with tp for flux2 --- src/diffusers/hooks/tensor_parallel.py | 79 +++++ src/diffusers/models/_modeling_parallel.py | 63 +++- .../models/transformers/transformer_flux2.py | 182 ++++++++++ .../transformer_flux2_neuron_tp.py | 310 ++++++++++++++++++ .../pipelines/flux2/pipeline_flux2.py | 6 +- .../pipelines/flux2/pipeline_flux2_klein.py | 13 +- .../flux2/pipeline_flux2_klein_kv.py | 6 +- 7 files changed, 649 insertions(+), 10 deletions(-) create mode 100644 src/diffusers/hooks/tensor_parallel.py create mode 100644 src/diffusers/models/transformers/transformer_flux2_neuron_tp.py diff --git a/src/diffusers/hooks/tensor_parallel.py b/src/diffusers/hooks/tensor_parallel.py new file mode 100644 index 000000000000..2008927ca702 --- /dev/null +++ b/src/diffusers/hooks/tensor_parallel.py @@ -0,0 +1,79 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from ..models._modeling_parallel import TensorParallelConfig +from ..utils import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +def apply_tensor_parallel( + model: torch.nn.Module, + config: TensorParallelConfig, + double_block_plan: dict, + single_block_plan: dict, +) -> None: + """Apply tensor parallelism to a ``Flux2Transformer2DModel``. + + This is the generic (non-Neuron) path. It calls + ``torch.distributed.tensor.parallel.parallelize_module`` directly on each + transformer block, using the plans defined on the model. + + For Neuron, use ``apply_tp_flux2_transformer_neuron`` from + ``diffusers.models.transformers.transformer_flux2_neuron_tp`` instead, which + pre-shards weights via ``DTensor.from_local`` to work around the Neuron NRT + consecutive-reduce-scatter bug. + + Args: + model (`torch.nn.Module`): + A ``Flux2Transformer2DModel`` instance. Must have ``transformer_blocks`` + and ``single_transformer_blocks`` attributes. + config (`TensorParallelConfig`): + TP configuration. ``config.setup()`` must have been called before this + function so that ``config._mesh`` is populated. + double_block_plan (`dict`): + ``parallelize_module`` plan for each double-stream block + (``model.transformer_blocks``). Keys are relative module paths + (e.g. ``"attn.to_q"``), values are ``ColwiseParallel()`` / + ``RowwiseParallel()`` instances. + single_block_plan (`dict`): + ``parallelize_module`` plan for each single-stream block + (``model.single_transformer_blocks``). + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + raise RuntimeError( + "apply_tensor_parallel requires an initialised torch.distributed process group." + ) + + try: + from torch.distributed.tensor.parallel import parallelize_module + except ImportError as e: + raise ImportError( + "apply_tensor_parallel requires PyTorch >= 2.3 with distributed tensor parallel support." 
+ ) from e + + tp_mesh = config._mesh + if tp_mesh is None: + raise ValueError( + "`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP." + ) + + for block in model.transformer_blocks: + parallelize_module(block, tp_mesh, double_block_plan) + + for block in model.single_transformer_blocks: + parallelize_module(block, tp_mesh, single_block_plan) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8573c01ca4c7..f14d37f594f3 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -35,7 +35,6 @@ # - Unified Attention # - More dispatcher attention backends # - CFG/Data Parallel -# - Tensor Parallel @dataclass @@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() +@dataclass +class TensorParallelConfig: + """ + Configuration for tensor parallelism. + + Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. + Each device computes a partial result; an AllReduce/AllGather at layer boundaries + reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module`` + with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. + + On Neuron, use the ``_pre_shard_and_tp`` workaround from + ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug + on large tensors (≥ 5120×5120). + + Args: + tp_degree (`int`, defaults to `1`): + Number of devices to shard across. Must be a divisor of the number of + attention heads (and FFN hidden dimensions) of the model being parallelised. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use. If provided, ``tp_degree`` is inferred from + ``mesh.size()`` and the argument is ignored. Useful when combining TP with + other parallelism strategies (e.g. CP) that share the same mesh. 
+ """ + + tp_degree: int = 1 + mesh: torch.distributed.device_mesh.DeviceMesh | None = None + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None + + def __post_init__(self): + if self.tp_degree < 1: + raise ValueError("`tp_degree` must be >= 1.") + + def setup( + self, + rank: int, + world_size: int, + device: torch.device, + mesh: torch.distributed.device_mesh.DeviceMesh | None = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + if mesh is not None: + self._mesh = mesh + elif self.mesh is not None: + self._mesh = self.mesh + else: + from torch.distributed.device_mesh import init_device_mesh + + device_type = str(device).split(":")[0] + self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) + + @dataclass class ParallelConfig: """ @@ -150,9 +206,12 @@ class ParallelConfig: Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism. + tensor_parallel_config (`TensorParallelConfig`, *optional*): + Configuration for tensor parallelism. 
""" context_parallel_config: ContextParallelConfig | None = None + tensor_parallel_config: TensorParallelConfig | None = None _rank: int = None _world_size: int = None @@ -173,6 +232,8 @@ def setup( self._mesh = mesh if self.context_parallel_config is not None: self.context_parallel_config.setup(rank, world_size, device, mesh) + if self.tensor_parallel_config is not None: + self.tensor_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 5c90f3a46a98..38fcddb0cf54 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -621,6 +621,146 @@ def __call__( return hidden_states +class Flux2AttnProcessorTP(Flux2AttnProcessor): + """ + TP-aware version of ``Flux2AttnProcessor`` for double-stream transformer blocks. + + After column-wise weight sharding, each rank holds ``attn.heads // tp_size`` heads. + The only difference from the base class is that ``unflatten`` uses the local head + count rather than the full ``attn.heads``. + + Args: + tp_size (`int`): Number of tensor-parallel ranks (== ``tp_mesh.size()``). 
+ """ + + def __init__(self, tp_size: int): + super().__init__() + self.tp_size = tp_size + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + local_heads = attn.heads // self.tp_size + head_dim = attn.head_dim + + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (local_heads, head_dim)) + key = key.unflatten(-1, (local_heads, head_dim)) + value = value.unflatten(-1, (local_heads, head_dim)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim)) + encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim)) + encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + hidden_states = 
attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + return hidden_states + + +class Flux2ParallelSelfAttnProcessorTP(Flux2ParallelSelfAttnProcessor): + """ + TP-aware version of ``Flux2ParallelSelfAttnProcessor`` for single-stream blocks. + + After column-wise weight sharding the fused ``to_qkv_mlp_proj`` projection, + each rank holds a proportionally smaller slice of Q/K/V and MLP dimensions. + The split sizes are computed from the local (per-rank) head count and inner dim. + + Args: + tp_size (`int`): Number of tensor-parallel ranks (== ``tp_mesh.size()``). + """ + + def __init__(self, tp_size: int): + super().__init__() + self.tp_size = tp_size + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + local_heads = attn.heads // self.tp_size + head_dim = attn.head_dim + local_inner = attn.inner_dim // self.tp_size + local_mlp_gate = attn.mlp_hidden_dim * attn.mlp_mult_factor // self.tp_size + + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split(hidden_states, [3 * local_inner, local_mlp_gate], dim=-1) + + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (local_heads, head_dim)) + key = key.unflatten(-1, (local_heads, head_dim)) + value = value.unflatten(-1, (local_heads, head_dim)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 
3).to(query.dtype) + + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + return attn.to_out(hidden_states) + + class Flux2KVParallelSelfAttnProcessor: """ Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens. @@ -1091,6 +1231,48 @@ class Flux2Transformer2DModel( "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), } + # Tensor-parallel sharding plans (one per block type). + # Used by ``apply_tensor_parallel`` (generic path) and by the Neuron-specific + # ``apply_tp_flux2_transformer_neuron`` (which also needs weight permutations). + # Populated lazily on first access to avoid importing torch.distributed.tensor + # at module import time when TP is not used. + _tp_double_block_plan: "dict | None" = None + _tp_single_block_plan: "dict | None" = None + + @classmethod + def _get_tp_double_block_plan(cls) -> dict: + """Return the TP sharding plan for double-stream (cross-attention + FFN) blocks.""" + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + if cls._tp_double_block_plan is None: + cls._tp_double_block_plan = { + "attn.to_q": ColwiseParallel(), + "attn.to_k": ColwiseParallel(), + "attn.to_v": ColwiseParallel(), + "attn.to_out.0": RowwiseParallel(), + "attn.add_q_proj": ColwiseParallel(), + "attn.add_k_proj": ColwiseParallel(), + "attn.add_v_proj": ColwiseParallel(), + "attn.to_add_out": RowwiseParallel(), + "ff.linear_in": ColwiseParallel(), + "ff.linear_out": RowwiseParallel(), + "ff_context.linear_in": ColwiseParallel(), + "ff_context.linear_out": RowwiseParallel(), + } + return cls._tp_double_block_plan + + @classmethod + def _get_tp_single_block_plan(cls) -> dict: + """Return the TP sharding plan for single-stream (parallel self-attn + fused MLP) blocks.""" + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + if cls._tp_single_block_plan is None: + 
cls._tp_single_block_plan = { + "attn.to_qkv_mlp_proj": ColwiseParallel(), + "attn.to_out": RowwiseParallel(), + } + return cls._tp_single_block_plan + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py b/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py new file mode 100644 index 000000000000..0c7b57110b1a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_flux2_neuron_tp.py @@ -0,0 +1,310 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neuron-specific Tensor Parallelism utilities for Flux2 and Qwen3. + +This module provides the functions needed to apply tensor parallelism on AWS +Neuron hardware. The key difference from the generic ``apply_tensor_parallel`` +path is a workaround for a Neuron NRT bug: consecutive ``reduce_scatter`` +collectives for large weight tensors (≥ 5120×5120) can fail when all layers +are distributed in a single ``parallelize_module`` call. The fix is to +pre-shard each weight locally on CPU via ``DTensor.from_local`` *before* +calling ``parallelize_module``; the latter then sees already-placed DTensors, +skips the collective for weights, but still registers the required +input/output hooks for the forward pass. + +Entry points: + ``apply_tp_flux2_transformer_neuron(model, tp_mesh)`` + Apply TP to a ``Flux2Transformer2DModel``. 
Includes the weight + permutations required by Flux2's SwiGLU FFN and fused QKV+MLP + projections. + + ``apply_tp_qwen3_neuron(model, tp_mesh)`` + Apply TP to a ``Qwen3ForCausalLM`` text encoder. The sharding plan is + derived from ``model.config.base_model_tp_plan`` — the same plan used + by ``from_pretrained(tp_plan="auto")`` in transformers — so it stays in + sync automatically if the plan changes upstream. +""" + +import torch +import torch.distributed as dist +import torch.nn as nn + + +def _permute_swiglu_for_tp(weight: torch.Tensor, tp_size: int) -> torch.Tensor: + """Interleave gate/linear chunks of a SwiGLU FFN weight for column-wise TP. + + ``ff.linear_in`` in Flux2 double-stream blocks stores + ``[gate_0…gate_N, linear_0…linear_N]`` (two halves concatenated). + After ``ColwiseParallel``, rank *r* takes rows + ``[r*chunk : (r+1)*chunk]`` from the full weight — but that would give + rank *r* only gate rows, not the paired gate+linear rows it needs. + This function re-orders to ``[gate_0, linear_0, gate_1, linear_1, …]`` + so that slicing is consistent. + """ + with torch.no_grad(): + total = weight.shape[0] + inner = total // 2 + chunk = inner // tp_size + gate = weight[:inner] + linear = weight[inner:] + parts = [] + for i in range(tp_size): + parts.append(gate[i * chunk : (i + 1) * chunk]) + parts.append(linear[i * chunk : (i + 1) * chunk]) + return torch.cat(parts, dim=0) + + +def _permute_qkv_mlp_for_tp( + weight: torch.Tensor, + tp_size: int, + inner_dim: int, + mlp_hidden_dim: int, +) -> torch.Tensor: + """Interleave Q/K/V/gate/linear chunks of the fused ``to_qkv_mlp_proj`` weight. + + ``to_qkv_mlp_proj`` in single-stream blocks concatenates + ``[Q, K, V, mlp_gate, mlp_linear]`` along the output dimension. + Re-order so that rank *r* receives a contiguous slice containing its + proportional share of each component. 
+ """ + with torch.no_grad(): + q = weight[:inner_dim] + k = weight[inner_dim : 2 * inner_dim] + v = weight[2 * inner_dim : 3 * inner_dim] + mlp_gate = weight[3 * inner_dim : 3 * inner_dim + mlp_hidden_dim] + mlp_lin = weight[3 * inner_dim + mlp_hidden_dim :] + + qkv_chunk = inner_dim // tp_size + mlp_chunk = mlp_hidden_dim // tp_size + + parts = [] + for i in range(tp_size): + parts += [ + q[i * qkv_chunk : (i + 1) * qkv_chunk], + k[i * qkv_chunk : (i + 1) * qkv_chunk], + v[i * qkv_chunk : (i + 1) * qkv_chunk], + mlp_gate[i * mlp_chunk : (i + 1) * mlp_chunk], + mlp_lin[i * mlp_chunk : (i + 1) * mlp_chunk], + ] + return torch.cat(parts, dim=0) + + +def _permute_out_for_tp( + weight: torch.Tensor, + tp_size: int, + attn_dim: int, + mlp_dim: int, +) -> torch.Tensor: + """Interleave attn/mlp output columns of the fused ``to_out`` weight. + + ``to_out`` in single-stream blocks accepts ``[attn_out, mlp_out]`` + concatenated along the input (column) dimension. Re-order columns so + that rank *r* receives a contiguous slice of paired attn+mlp columns. + """ + with torch.no_grad(): + attn_part = weight[:, :attn_dim] + mlp_part = weight[:, attn_dim:] + + attn_chunk = attn_dim // tp_size + mlp_chunk = mlp_dim // tp_size + + parts = [] + for i in range(tp_size): + parts.append(attn_part[:, i * attn_chunk : (i + 1) * attn_chunk]) + parts.append(mlp_part[:, i * mlp_chunk : (i + 1) * mlp_chunk]) + return torch.cat(parts, dim=1) + + +def _pre_shard_and_tp( + module: nn.Module, + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", + plan: dict, + rank: int, + tp_size: int, +) -> None: + """Pre-shard Linear weights via ``DTensor.from_local``, then call ``parallelize_module``. + + Workaround for a Neuron NRT bug where consecutive ``reduce_scatter`` calls + for large weight tensors (≥ 5120×5120) fail when all layers are distributed + in a single ``parallelize_module`` call. 
By pre-sharding each weight on CPU + before the call, ``distribute_tensor`` inside ``parallelize_module`` sees an + already-placed DTensor and skips the collective, while the module hooks + (input/output specs) are still registered correctly. + + Args: + module: The block whose Linear sub-modules are being sharded. + tp_mesh: Device mesh for TP (1-D, size == tp_size). + plan: ``{relative_path: ColwiseParallel() | RowwiseParallel()}`` dict, + as expected by ``parallelize_module``. + rank: Current rank (``dist.get_rank()``). + tp_size: Total TP degree (``tp_mesh.size()``). + """ + from torch.distributed.tensor import DTensor, Shard + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module + + device = torch.neuron.current_device() + + for path, style in plan.items(): + # Resolve nested attribute path (e.g. "attn.to_q" or "attn.to_out.0") + submod = module + for part in path.split("."): + submod = getattr(submod, part) + + if not hasattr(submod, "weight"): + continue + + w = submod.weight.data # CPU at this point + if isinstance(style, ColwiseParallel): + rows = w.shape[0] // tp_size + shard = w[rank * rows : (rank + 1) * rows, :].contiguous().to(device) + submod.weight = nn.Parameter(DTensor.from_local(shard, tp_mesh, [Shard(0)])) + elif isinstance(style, RowwiseParallel): + cols = w.shape[1] // tp_size + shard = w[:, rank * cols : (rank + 1) * cols].contiguous().to(device) + submod.weight = nn.Parameter(DTensor.from_local(shard, tp_mesh, [Shard(1)])) + + # parallelize_module is now a no-op for weight distribution (already DTensors) + # but still registers the input/output hooks required for the forward pass. + parallelize_module(module, tp_mesh, plan) + + +def apply_tp_flux2_transformer_neuron( + model: "Flux2Transformer2DModel", + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", +) -> "Flux2Transformer2DModel": + """Apply tensor parallelism to a ``Flux2Transformer2DModel`` on Neuron. + + Steps for each block type: + 1. 
Permute fused weights so that column-wise slicing gives each rank a + correct paired chunk (SwiGLU gate+linear, or Q/K/V/MLP). + 2. Pre-shard weights via ``DTensor.from_local`` (Neuron NRT workaround). + 3. Call ``parallelize_module`` to register input/output hooks. + 4. Replace the attention processor with the TP-aware variant. + + The model weights must still be on CPU when this function is called. + Move the model to the Neuron device *after* this call:: + + apply_tp_flux2_transformer_neuron(pipe.transformer, tp_mesh) + pipe.transformer = pipe.transformer.to(device) + + Args: + model: ``Flux2Transformer2DModel`` with weights on CPU. + tp_mesh: 1-D Neuron device mesh of size ``tp_size``. + + Returns: + The same ``model`` instance, modified in-place. + """ + from .transformer_flux2 import Flux2AttnProcessorTP, Flux2ParallelSelfAttnProcessorTP + + rank = dist.get_rank() + tp_size = tp_mesh.size() + + double_plan = model._get_tp_double_block_plan() + single_plan = model._get_tp_single_block_plan() + + # ── Double-stream blocks (cross-attn + FFN) ──────────────────────────── + for block in model.transformer_blocks: + # Permute SwiGLU weights before sharding + block.ff.linear_in.weight.data = _permute_swiglu_for_tp( + block.ff.linear_in.weight.data, tp_size + ) + block.ff_context.linear_in.weight.data = _permute_swiglu_for_tp( + block.ff_context.linear_in.weight.data, tp_size + ) + _pre_shard_and_tp(block, tp_mesh, double_plan, rank, tp_size) + block.attn.set_processor(Flux2AttnProcessorTP(tp_size)) + + # ── Single-stream blocks (parallel self-attn + fused MLP) ────────────── + for block in model.single_transformer_blocks: + attn = block.attn + inner_dim = attn.inner_dim + mlp_hidden = attn.mlp_hidden_dim + + attn.to_qkv_mlp_proj.weight.data = _permute_qkv_mlp_for_tp( + attn.to_qkv_mlp_proj.weight.data, tp_size, inner_dim, mlp_hidden + ) + attn.to_out.weight.data = _permute_out_for_tp( + attn.to_out.weight.data, tp_size, inner_dim, mlp_hidden + ) + 
_pre_shard_and_tp(block, tp_mesh, single_plan, rank, tp_size) + block.attn.set_processor(Flux2ParallelSelfAttnProcessorTP(tp_size)) + + return model + + +def apply_tp_qwen3_neuron( + model: "Qwen3ForCausalLM", + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", +) -> "Qwen3ForCausalLM": + """Apply tensor parallelism to a ``Qwen3ForCausalLM`` text encoder on Neuron. + + The sharding plan is derived from ``model.config.base_model_tp_plan`` — + the same plan used by ``from_pretrained(tp_plan="auto")`` in transformers — + so it stays in sync automatically if the plan changes upstream. + + ``"replicated_with_grad_allreduce"`` entries (Q/K norm layers) are skipped: + those layers require gradient all-reduce in training but need no weight + sharding for inference. + + Qwen3's separate ``gate_proj`` / ``up_proj`` projections require no weight + permutations (unlike Flux2's fused SwiGLU). + + The model weights must still be on CPU when this function is called:: + + apply_tp_qwen3_neuron(pipe.text_encoder, tp_mesh) + pipe.text_encoder = pipe.text_encoder.to(device) + + **Primary path**: try ``Qwen3ForCausalLM.from_pretrained(model_id, tp_plan="auto")`` + first — transformers' native TP may work on Neuron directly since its hook + mechanism does not use DTensor reduce_scatter. Fall back to this function if + the NRT bug is triggered. + + Args: + model: ``Qwen3ForCausalLM`` with weights on CPU. + tp_mesh: 1-D Neuron device mesh of size ``tp_size``. + + Returns: + The same ``model`` instance, modified in-place. 
+ """ + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + rank = dist.get_rank() + tp_size = tp_mesh.size() + + style_map = { + "colwise": ColwiseParallel(), + "colwise_gather_output": ColwiseParallel(), # lm_head — same for inference + "rowwise": RowwiseParallel(), + # "replicated_with_grad_allreduce" → skipped (q_norm/k_norm, inference only) + } + + # config.base_model_tp_plan example: + # {"layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", ...} + per_layer_plan = { + path.split("*.")[1]: style_map[style] + for path, style in model.config.base_model_tp_plan.items() + if "*." in path and style in style_map + } + + if not per_layer_plan: + raise ValueError( + "Could not extract a per-layer TP plan from `model.config.base_model_tp_plan`. " + f"Got: {model.config.base_model_tp_plan}" + ) + + for layer in model.model.layers: + _pre_shard_and_tp(layer, tp_mesh, per_layer_plan, rank, tp_size) + + return model diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 4b60c6042d4f..3130cca8d032 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -369,7 +369,7 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + return torch.stack(out_ids).float() @staticmethod def _prepare_latent_ids( @@ -401,7 +401,7 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + return latent_ids.float() @staticmethod def _prepare_image_ids( @@ -451,7 +451,7 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + return image_latent_ids.float() @staticmethod def _patchify_latents(latents): diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py 
b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 936d2c3804ab..50f865aeef14 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -279,7 +279,7 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + return torch.stack(out_ids).float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids @@ -312,7 +312,7 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + return latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids @@ -363,7 +363,7 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + return image_latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents @@ -877,6 +877,13 @@ def __call__( # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) + # With tensor parallelism all ranks run the same (deterministic) + # scheduler step, so this rank-0 broadcast is only a safety net against + # numerical drift or non-determinism. NOTE: it runs whenever + # torch.distributed is initialized, on the default process group. 
+ if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.broadcast(latents, src=0) + if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py index 671953be63c1..998222166d86 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py @@ -284,7 +284,7 @@ def _prepare_text_ids( coords = torch.cartesian_prod(t, h, w, l) out_ids.append(coords) - return torch.stack(out_ids) + return torch.stack(out_ids).float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids @@ -317,7 +317,7 @@ def _prepare_latent_ids( # Expand to batch: (B, H*W, 4) latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) - return latent_ids + return latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids @@ -368,7 +368,7 @@ def _prepare_image_ids( image_latent_ids = torch.cat(image_latent_ids, dim=0) image_latent_ids = image_latent_ids.unsqueeze(0) - return image_latent_ids + return image_latent_ids.float() @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents From 929ab7288fff62db2e5c32ec7e857eb3e9280f70 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 9 Apr 2026 16:37:56 +0000 Subject: [PATCH 05/12] fix: style --- .../train_instruct_pix2pix_sdxl.py | 8 ++- src/diffusers/loaders/peft.py | 2 +- src/diffusers/models/_modeling_parallel.py | 63 ++++++++++++++++++- src/diffusers/pipelines/pipeline_utils.py | 4 +- src/diffusers/utils/torch_utils.py | 9 ++- tests/pipelines/pixart_alpha/test_pixart.py | 1 - 6 files changed, 78 insertions(+), 9 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py 
b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 5df0e22fe1cc..ce146c895686 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -85,9 +85,11 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final os.makedirs(val_save_dir) original_image = ( - lambda image_url_or_path: load_image(image_url_or_path) - if urlparse(image_url_or_path).scheme - else Image.open(image_url_or_path).convert("RGB") + lambda image_url_or_path: ( + load_image(image_url_or_path) + if urlparse(image_url_or_path).scheme + else Image.open(image_url_or_path).convert("RGB") + ) )(args.val_image_url_or_path) if torch.backends.mps.is_available(): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index daa078bc25d5..68d9104e028d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: (lambda model_cls, weights: weights), + lambda: lambda model_cls, weights: weights, { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8573c01ca4c7..e673980dbf44 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -35,7 +35,6 @@ # - Unified Attention # - More dispatcher attention backends # - CFG/Data Parallel -# - Tensor Parallel @dataclass @@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() +@dataclass +class TensorParallelConfig: + """ + Configuration for tensor parallelism. + + Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. 
+ Each device computes a partial result; an AllReduce/AllGather at layer boundaries + reconstructs the full output. Uses ``torch.distributed.tensor.parallel.parallelize_module`` + with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. + + On Neuron, use the ``_pre_shard_and_tp`` workaround from + ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug + on large tensors (>= 5120x5120). + + Args: + tp_degree (`int`, defaults to `1`): + Number of devices to shard across. Must be a divisor of the number of + attention heads (and FFN hidden dimensions) of the model being parallelised. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use. If provided, it is used as-is and ``tp_degree`` + is ignored (it is not updated from ``mesh.size()``). Useful when combining TP with + other parallelism strategies (e.g. CP) that share the same mesh. + """ + + tp_degree: int = 1 + mesh: torch.distributed.device_mesh.DeviceMesh | None = None + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None + + def __post_init__(self): + if self.tp_degree < 1: + raise ValueError("`tp_degree` must be >= 1.") + + def setup( + self, + rank: int, + world_size: int, + device: torch.device, + mesh: torch.distributed.device_mesh.DeviceMesh | None = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + if mesh is not None: + self._mesh = mesh + elif self.mesh is not None: + self._mesh = self.mesh + else: + from torch.distributed.device_mesh import init_device_mesh + + device_type = str(device).split(":")[0] + self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) + + @dataclass class ParallelConfig: """ @@ -150,9 +206,12 @@ class ParallelConfig: Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism. 
+ tensor_parallel_config (`TensorParallelConfig`, *optional*): + Configuration for tensor parallelism. """ context_parallel_config: ContextParallelConfig | None = None + tensor_parallel_config: TensorParallelConfig | None = None _rank: int = None _world_size: int = None @@ -173,6 +232,8 @@ def setup( self._mesh = mesh if self.context_parallel_config is not None: self.context_parallel_config.setup(rank, world_size, device, mesh) + if self.tensor_parallel_config is not None: + self.tensor_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index bbee2189c22f..d675f1de04a7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,7 +68,6 @@ is_transformers_version, logging, numpy_to_pil, - requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2249,6 +2248,7 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 55fee1d3249e..e99719625df6 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -39,7 +39,14 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, 
"xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} + BACKEND_SUPPORTS_TRAINING = { + "cuda": True, + "xpu": True, + "cpu": True, + "mps": False, + "neuron": False, + "default": True, + } BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 0aa6812c6b25..86fe673a8c7d 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -27,7 +27,6 @@ PixArtAlphaPipeline, PixArtTransformer2DModel, ) - from diffusers.utils.import_utils import is_torch_neuronx_available from ...testing_utils import ( From 3bb9c7c3fc8483228d93c6bf9e16b2905712d17b Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 10 Apr 2026 15:35:35 +0000 Subject: [PATCH 06/12] fix:apr_02 beta --- src/diffusers/models/transformers/transformer_flux2.py | 3 ++- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 5c90f3a46a98..43d36d6476af 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -961,7 +961,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: pos = ids.float() is_mps = ids.device.type == "mps" is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + is_neuron = ids.device.type == "neuron" + freqs_dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] for i in range(len(self.axes_dim)): cos, sin = get_1d_rotary_pos_embed( diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 
604e51d88583..bda4e40f3768 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -29,6 +29,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_neuronx_available, is_torch_xla_available, logging, replace_example_docstring, @@ -862,7 +863,7 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - if XLA_AVAILABLE: + if XLA_AVAILABLE or is_torch_neuronx_available(): timestep_device = "cpu" else: timestep_device = device @@ -914,10 +915,11 @@ def __call__( # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" is_npu = latent_model_input.device.type == "npu" + is_neuron = latent_model_input.device.type == "neuron" if isinstance(current_timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) From dff1f32ccff3513fb6b155be3086ab998884c3f5 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 10 Apr 2026 16:17:13 +0000 Subject: [PATCH 07/12] feat:add wan --- .../transformers/transformer_wan_neuron_tp.py | 317 ++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_wan_neuron_tp.py diff --git a/src/diffusers/models/transformers/transformer_wan_neuron_tp.py b/src/diffusers/models/transformers/transformer_wan_neuron_tp.py new file mode 100644 index 000000000000..1377f10f0295 --- /dev/null +++ 
b/src/diffusers/models/transformers/transformer_wan_neuron_tp.py @@ -0,0 +1,317 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neuron-specific Tensor Parallelism utilities for WanTransformer3DModel. + +Entry point:: + + apply_tp_wan_transformer_neuron(model, tp_mesh) + +Apply TP to a ``WanTransformer3DModel`` for AWS Neuron. The model weights +must still be on CPU when this function is called. Move to the Neuron device +*after* this call:: + + apply_tp_wan_transformer_neuron(transformer, tp_mesh) + transformer = transformer.to(device) + +TP plan per ``WanTransformerBlock`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Self-attention (``attn1``): + ``to_q``, ``to_k``, ``to_v`` → ``ColwiseParallel`` + ``to_out.0`` → ``RowwiseParallel`` + +Cross-attention (``attn2``): + ``to_q``, ``to_k``, ``to_v`` → ``ColwiseParallel`` + ``to_out.0`` → ``RowwiseParallel`` + +Feed-forward (GELU-approximate): + ``ffn.net.0.proj`` → ``ColwiseParallel`` (dim → ffn_dim // tp_size) + ``ffn.net.2`` → ``RowwiseParallel`` (ffn_dim // tp_size → dim) + +Non-TP'd layers (small relative to 40 blocks; replicated): + ``patch_embedding``, ``condition_embedder``, ``norm_out``, ``proj_out``, + ``scale_shift_table``. + +norm_q / norm_k correctness +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``WanAttention`` applies ``RMSNorm(inner_dim=5120)`` to Q and K *before* +splitting into heads. 
After ``ColwiseParallel`` on ``to_q``/``to_k`` each +rank holds ``inner_dim // tp_size`` features. ``WanAttnProcessorTP`` +handles this via ``_apply_global_rms_norm``: the RMS is computed globally +across all TP ranks via ``dist.all_reduce(SUM)`` over local sum-of-squares, +so every rank applies the same scale — identical to the non-TP result. +The norm weight is sliced to each rank's portion. + +RoPE float64 → bfloat16 cast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``WanRotaryPosEmbed`` stores ``freqs_cos`` / ``freqs_sin`` as float64 for +construction precision. ``WanAttnProcessorTP`` casts to ``hidden_states.dtype`` +(bfloat16) before computing RoPE to avoid mixed-dtype XLA ops which produce +NaN on Neuron. + +RoPE non-contiguous patch +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``WanRotaryPosEmbed.forward`` uses ``.expand()`` which produces non-contiguous +views; the XLA cat kernel rejects them. +``apply_tp_wan_transformer_neuron`` replaces ``transformer.rope.forward`` with +the closure returned by ``_make_rope_forward_contiguous``, which calls ``.contiguous()`` on each expanded +tensor before the cat. +""" + +import torch +import torch.distributed as dist +import torch.nn as nn + +from .transformer_flux2_neuron_tp import _pre_shard_and_tp +from .transformer_wan import WanAttnProcessor + + +class WanAttnProcessorTP(WanAttnProcessor): + """TP-aware attention processor for ``WanAttention`` on Neuron. + + Differences from ``WanAttnProcessor``: + + 1. ``local_heads = attn.heads // tp_size`` — consistent with the sharded + ``inner_dim`` after ``ColwiseParallel``. + 2. ``_apply_global_rms_norm`` — computes the RMS globally across all TP + ranks via ``dist.all_reduce(SUM)`` over local sum-of-squares, then + applies the weight slice for this rank's portion. + 3. RoPE dtype cast — ``freqs_cos``/``freqs_sin`` are cast to + ``hidden_states.dtype`` (bfloat16) before use to avoid NaN on Neuron. + 4. I2V path (``add_k_proj`` / ``add_v_proj``) is not implemented — T2V only.
+ + Args: + tp_size: Total tensor-parallel degree (``tp_mesh.size()``). + rank: Current rank (``dist.get_rank()``). + """ + + def __init__(self, tp_size: int, rank: int): + super().__init__() + self.tp_size = tp_size + self.rank = rank + + def _apply_global_rms_norm( + self, + x: torch.Tensor, + norm_module: nn.Module, + local_dim: int, + ) -> torch.Tensor: + """RMSNorm with global RMS computed across all TP ranks via all-reduce. + + ``x``: ``[B, S, local_dim]`` — local shard after ColwiseParallel. + ``norm_module.weight``: ``[inner_dim = local_dim * tp_size]``. + + The global RMS = ``sqrt(sum_all_ranks(local_sum_sq) / inner_dim + eps)``. + All ranks contribute their local sum-of-squares via + ``dist.all_reduce(SUM)``, so each rank uses the same scale — identical + to the non-TP RMSNorm result. + """ + start = self.rank * local_dim + end = start + local_dim + w = norm_module.weight[start:end] + eps = getattr(norm_module, "eps", 1e-5) + inner_dim = local_dim * self.tp_size + # Local sum of squares [B, S, 1]; summed across TP ranks. + local_sq_sum = x.float().pow(2).sum(dim=-1, keepdim=True) + dist.all_reduce(local_sq_sum, op=dist.ReduceOp.SUM) + rms = local_sq_sum.div(inner_dim).add(eps).rsqrt() + return (x.float() * rms * w.float()).type_as(x) + + def __call__( + self, + attn: "WanAttention", # noqa: F821 + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + from torch.distributed.tensor import DTensor + + def _to_local(t: torch.Tensor) -> torch.Tensor: + return t.to_local() if isinstance(t, DTensor) else t + + local_heads = attn.heads // self.tp_size + orig_dtype = hidden_states.dtype + + # to_q/to_k/to_v are ColwiseParallel — output is a Shard(-1) DTensor; + # _to_local extracts the plain local shard [B, S, local_dim]. 
+ if encoder_hidden_states is None: + # Self-attention + query = _to_local(attn.to_q(hidden_states)) + key = _to_local(attn.to_k(hidden_states)) + value = _to_local(attn.to_v(hidden_states)) + else: + # Cross-attention: encoder_hidden_states is replicated on all ranks. + # Run in float32 so that the (cond - uncond) guidance difference is + # at full precision; bfloat16 errors here get amplified by CFG scale. + query = _to_local(attn.to_q(hidden_states)).float() + key = _to_local(attn.to_k(encoder_hidden_states)).float() + value = _to_local(attn.to_v(encoder_hidden_states)).float() + + local_dim = query.shape[-1] # inner_dim // tp_size + + # Global RMSNorm with weight sliced to this rank's portion. + query = self._apply_global_rms_norm(query, attn.norm_q, local_dim) + key = self._apply_global_rms_norm(key, attn.norm_k, local_dim) + + # Reshape: [B, S, local_dim] → [B, S, local_heads, head_dim] + query = query.unflatten(-1, (local_heads, -1)) + key = key.unflatten(-1, (local_heads, -1)) + value = value.unflatten(-1, (local_heads, -1)) + + # RoPE only in self-attention (cross-attention has rotary_emb=None). + # freqs_cos/sin are float64; cast to hidden_states dtype to avoid NaN. + if rotary_emb is not None: + freqs_cos, freqs_sin = rotary_emb + + def _apply_rotary_emb(hs, fc, fs): + x1, x2 = hs.unflatten(-1, (-1, 2)).unbind(-1) + cos = fc[..., 0::2].to(hs.dtype) + sin = fs[..., 1::2].to(hs.dtype) + out = torch.empty_like(hs) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out + + query = _apply_rotary_emb(query, freqs_cos, freqs_sin) + key = _apply_rotary_emb(key, freqs_cos, freqs_sin) + + # BSHD-layout attention; parallel_config only for self-attention. 
+ from ..attention_dispatch import dispatch_attention_fn + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), + ) + # [B, S, local_heads, head_dim] → [B, S, local_dim] + # Cast back to orig_dtype so RowwiseParallel to_out receives expected dtype. + hidden_states = hidden_states.flatten(2, 3).to(orig_dtype) + + # to_out[0] is RowwiseParallel: local matmul + all-reduce → [B, S, inner_dim]. + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) # Dropout (no-op at eval) + return hidden_states + + +def _make_rope_forward_contiguous(rope_mod: nn.Module): + """Return a patched ``WanRotaryPosEmbed.forward`` that calls ``.contiguous()`` + on expanded tensors before ``torch.cat``. + + ``WanRotaryPosEmbed.forward`` uses ``.expand()`` which produces non-contiguous + views. The XLA cat kernel on Neuron rejects non-contiguous inputs. This + patch is identical in semantics but inserts ``.contiguous()`` after each + expand to ensure packed memory layouts. + """ + + def _forward(hidden_states: torch.Tensor): + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = rope_mod.patch_size + ppf = num_frames // p_t + pph = height // p_h + ppw = width // p_w + split_sizes = [rope_mod.t_dim, rope_mod.h_dim, rope_mod.w_dim] + freqs_cos_split = rope_mod.freqs_cos.split(split_sizes, dim=1) + freqs_sin_split = rope_mod.freqs_sin.split(split_sizes, dim=1) + # .expand() produces non-contiguous views; .contiguous() copies them so + # that XLA cat sees packed memory layouts. 
+ freqs_cos_f = freqs_cos_split[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_cos_h = freqs_cos_split[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_cos_w = freqs_cos_split[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_sin_f = freqs_sin_split[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_sin_h = freqs_sin_split[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_sin_w = freqs_sin_split[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1).contiguous() + freqs_cos_out = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape( + 1, ppf * pph * ppw, 1, -1 + ) + freqs_sin_out = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape( + 1, ppf * pph * ppw, 1, -1 + ) + return freqs_cos_out, freqs_sin_out + + return _forward + + +def apply_tp_wan_transformer_neuron( + model: "WanTransformer3DModel", # noqa: F821 + tp_mesh: "torch.distributed.device_mesh.DeviceMesh", +) -> "WanTransformer3DModel": # noqa: F821 + """Apply tensor parallelism to a ``WanTransformer3DModel`` on Neuron. + + Steps: + + 1. Patch ``model.rope.forward`` with ``_make_rope_forward_contiguous`` to + fix XLA non-contiguous tensor errors in ``WanRotaryPosEmbed``. + 2. For each ``WanTransformerBlock``: + + a. Pre-shard Linear weights via ``DTensor.from_local`` (workaround for + the Neuron NRT consecutive-reduce-scatter bug on large tensors). + b. Call ``parallelize_module`` to register input/output hooks. + c. Replace the attention processor on ``attn1`` and ``attn2`` with + ``WanAttnProcessorTP``. + + The model weights must still be on CPU when this function is called. + Move to the Neuron device *after*:: + + apply_tp_wan_transformer_neuron(transformer, tp_mesh) + transformer = transformer.to(device) + + Args: + model: ``WanTransformer3DModel`` with weights on CPU. + tp_mesh: 1-D Neuron device mesh of size ``tp_size``. 
+ + Returns: + The same ``model`` instance, modified in-place. + """ + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + + rank = dist.get_rank() + tp_size = tp_mesh.size() + + # Patch RoPE forward to add .contiguous() before XLA cat. + model.rope.forward = _make_rope_forward_contiguous(model.rope) + + # TP plan per WanTransformerBlock. + # No weight permutations needed: WAN uses separate Q/K/V linears (not fused) + # and GELU (not SwiGLU), so column-wise slicing is always correct. + block_plan = { + # Self-attention + "attn1.to_q": ColwiseParallel(), + "attn1.to_k": ColwiseParallel(), + "attn1.to_v": ColwiseParallel(), + "attn1.to_out.0": RowwiseParallel(), + # Cross-attention (encoder_hidden_states replicated on all ranks) + "attn2.to_q": ColwiseParallel(), + "attn2.to_k": ColwiseParallel(), + "attn2.to_v": ColwiseParallel(), + "attn2.to_out.0": RowwiseParallel(), + # Feed-forward: net[0] is GELU (has .proj), net[2] is output Linear + "ffn.net.0.proj": ColwiseParallel(), + "ffn.net.2": RowwiseParallel(), + } + + processor = WanAttnProcessorTP(tp_size=tp_size, rank=rank) + + for block in model.blocks: + _pre_shard_and_tp(block, tp_mesh, block_plan, rank, tp_size) + block.attn1.set_processor(processor) + block.attn2.set_processor(processor) + + return model From cbe8f28a7d327cfe128412e8c6c5b3aeacc649ec Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 14 Apr 2026 16:53:24 +0000 Subject: [PATCH 08/12] fix:pixart --- src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index bda4e40f3768..2537b541cb7b 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -862,6 +862,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], 
dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + # Neuron compile backend does not support int64; downcast mask to int32. + if is_torch_neuronx_available() and prompt_attention_mask.dtype == torch.int64: + prompt_attention_mask = prompt_attention_mask.to(torch.int32) + # 4. Prepare timesteps if XLA_AVAILABLE or is_torch_neuronx_available(): timestep_device = "cpu" From 16b96069a060ed518b9f912845509ec8a84f6780 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 15 Apr 2026 16:00:50 +0000 Subject: [PATCH 09/12] fix: rewrite flux swiglu activation to avoid gather op in neuron IR --- src/diffusers/models/transformers/transformer_flux2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 43d36d6476af..69516ea4a37b 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -291,8 +291,8 @@ def __init__(self): self.gate_fn = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - x = self.gate_fn(x1) * x2 + half = x.shape[-1] // 2 + x = self.gate_fn(x[..., :half]) * x[..., half:] return x From 7f13f6859d0618cc1a610bda170fe99de0f58fc6 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 15 Apr 2026 17:26:58 +0000 Subject: [PATCH 10/12] test: pixart compile mode on neuron --- tests/pipelines/pixart_alpha/test_pixart.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 86fe673a8c7d..c97cebf5f56e 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -381,3 +381,47 @@ def test_pixart_512_without_resolution_binning(self): no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1] assert not 
np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4) + + @unittest.skipUnless(is_torch_neuronx_available(), "torch_neuronx not available") + def test_pixart_512_neuron_compile(self): + """ + Smoke-test PixArtAlphaPipeline under torch.compile(backend="neuron") at 512×512. + """ + import torch_neuronx # noqa: F401 — registers torch.neuron + from torch_neuronx.neuron_dynamo_backend import set_model_name + + device = torch.neuron.current_device() + generator = torch.Generator("cpu").manual_seed(0) + + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.bfloat16) + pipe = pipe.to(device) + # Flush pending lazy-XLA parameter-copy ops before compiling. + torch.neuron.synchronize() + + pipe.transformer.eval() + pipe.vae.eval() + pipe.text_encoder.eval() + + set_model_name("pixart_text_encoder") + pipe.text_encoder = torch.compile(pipe.text_encoder, backend="neuron", fullgraph=True) + set_model_name("pixart_transformer") + pipe.transformer = torch.compile(pipe.transformer, backend="neuron", fullgraph=True) + # VAE must be compiled after pipeline __init__ (which reads vae.config.block_out_channels). 
+ set_model_name("pixart_vae") + pipe.vae = torch.compile(pipe.vae, backend="neuron", fullgraph=True) + + image = pipe( + self.prompt, + generator=generator, + height=512, + width=512, + num_inference_steps=2, + output_type="np", + ).images + + self.assertEqual(image.shape, (1, 512, 512, 3)) + self.assertFalse(np.isnan(image).any(), "Output contains NaN values") + self.assertTrue( + (image >= 0.0).all() and (image <= 1.0).all(), + "Output pixel values outside [0, 1]", + ) From a354b88357de4c9e0edddc3b1ef8169545b6ca3e Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 11 May 2026 13:18:08 +0000 Subject: [PATCH 11/12] cleanup & fix style --- .../pipelines/flux2/pipeline_flux2_klein.py | 8 +++++++- .../pipeline_stable_diffusion_xl.py | 15 ++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 1f3b5c3c4fde..5c2fbc63a9e9 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -24,7 +24,7 @@ from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import randn_tensor, torch_device from ..pipeline_utils import DiffusionPipeline from .image_processor import Flux2ImageProcessor from .pipeline_output import Flux2PipelineOutput @@ -900,7 +900,13 @@ def __call__( # Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item() latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + latent_device = latents.device + if torch_device == "neuron": + latents = latents.cpu() + latent_ids = latent_ids.cpu() latents = self._unpack_latents_with_ids(latents, 
latent_ids, latent_height // 2, latent_width // 2) + if torch_device == "neuron": + latents = latents.to(latent_device) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index fdda2547f09e..4550ebcd980a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -45,7 +45,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import randn_tensor, torch_device from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput @@ -1092,11 +1092,8 @@ def __call__( ) # 4. Prepare timesteps - # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where - # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() - # are incompatible with static-graph compilation. - is_neuron_device = hasattr(device, "type") and device.type == "neuron" - if XLA_AVAILABLE or is_neuron_device: + is_neuron = torch_device == "neuron" + if XLA_AVAILABLE or is_neuron: timestep_device = "cpu" else: timestep_device = device @@ -1201,7 +1198,7 @@ def __call__( # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. 
- if is_neuron_device: + if is_neuron: latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) else: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1212,7 +1209,7 @@ def __call__( added_cond_kwargs["image_embeds"] = image_embeds # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. - t_unet = t.to(torch.float32).to(device) if is_neuron_device else t + t_unet = t.to(torch.float32).to(device) if is_neuron else t noise_pred = self.unet( latent_model_input, t_unet, @@ -1235,7 +1232,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. - if is_neuron_device: + if is_neuron: latents = self.scheduler.step( noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False )[0].to(device) From c350f7bfd16a32428f416e75d7cc0190dcc602ba Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 22 Jun 2026 12:00:19 +0000 Subject: [PATCH 12/12] merge: another change --- src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 77f21c47a7fa..9aca9dd19e32 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -867,12 +867,8 @@ def __call__( prompt_attention_mask = prompt_attention_mask.to(torch.int32) # 4. Prepare timesteps -<<<<<<< HEAD - if XLA_AVAILABLE or is_torch_neuronx_available(): -======= is_neuron_device = device.type == "neuron" if XLA_AVAILABLE or is_neuron_device: ->>>>>>> main timestep_device = "cpu" else: timestep_device = device