Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
98f6c8c
draft:add neuron as a legit backend
JingyaHuang Mar 18, 2026
c58b8b8
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 18, 2026
3367409
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 19, 2026
0c51734
Merge branch 'main' into add-neuron-backend
JingyaHuang Mar 25, 2026
a76953c
feat: neuron-specific changes in the pipeline
JingyaHuang Mar 26, 2026
2480388
tests: eager tests
JingyaHuang Mar 27, 2026
1469c04
draft: start with tp for flux2
JingyaHuang Apr 9, 2026
929ab72
fix: style
JingyaHuang Apr 9, 2026
52cac76
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 9, 2026
30cb353
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 9, 2026
28a5086
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 9, 2026
7fab0c4
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 10, 2026
68689e5
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 10, 2026
da79308
Merge branch 'main' into add-neuron-backend
JingyaHuang Apr 10, 2026
3bb9c7c
fix:apr_02 beta
JingyaHuang Apr 10, 2026
c4facab
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 10, 2026
dff1f32
feat:add wan
JingyaHuang Apr 10, 2026
1c930c4
Merge branch 'huggingface:main' into support-neuron-tp
JingyaHuang Apr 13, 2026
1eb5ff9
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 13, 2026
cbe8f28
fix:pixart
JingyaHuang Apr 14, 2026
16b9606
fix: rewrite flux swiglu activation to avoid gather op in neuron IR
JingyaHuang Apr 15, 2026
7f13f68
test: pixart compile mode on neuron
JingyaHuang Apr 15, 2026
a46cb19
Merge branch 'main' into neuron-torch-comppile
JingyaHuang Apr 22, 2026
a354b88
cleanup & fix style
JingyaHuang May 11, 2026
931bb85
Merge branch 'neuron-torch-comppile' into support-neuron-tp
JingyaHuang May 11, 2026
9ab6dc3
Merge branch 'main' into support-neuron-tp
JingyaHuang May 11, 2026
48fb75b
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 22, 2026
c350f7b
merge: another change
JingyaHuang Jun 22, 2026
644477a
Merge branch 'main' into support-neuron-tp
JingyaHuang Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final
os.makedirs(val_save_dir)

original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
lambda image_url_or_path: (
load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)
Comment on lines -88 to +92

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an unrelated change?

)(args.val_image_url_or_path)

if torch.backends.mps.is_available():
Expand Down
79 changes: 79 additions & 0 deletions src/diffusers/hooks/tensor_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ..models._modeling_parallel import TensorParallelConfig
from ..utils import get_logger


logger = get_logger(__name__) # pylint: disable=invalid-name


def apply_tensor_parallel(
model: torch.nn.Module,
config: TensorParallelConfig,
double_block_plan: dict,
single_block_plan: dict,
) -> None:
"""Apply tensor parallelism to a ``Flux2Transformer2DModel``.

This is the generic (non-Neuron) path. It calls
``torch.distributed.tensor.parallel.parallelize_module`` directly on each
transformer block, using the plans defined on the model.

For Neuron, use ``apply_tp_flux2_transformer_neuron`` from
``diffusers.models.transformers.transformer_flux2_neuron_tp`` instead, which
pre-shards weights via ``DTensor.from_local`` to work around the Neuron NRT
consecutive-reduce-scatter bug.

Args:
model (`torch.nn.Module`):
A ``Flux2Transformer2DModel`` instance. Must have ``transformer_blocks``
and ``single_transformer_blocks`` attributes.
Comment on lines +43 to +44

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It cannot be specific to a particular model-type, right?

config (`TensorParallelConfig`):
TP configuration. ``config.setup()`` must have been called before this
function so that ``config._mesh`` is populated.
double_block_plan (`dict`):
``parallelize_module`` plan for each double-stream block
(``model.transformer_blocks``). Keys are relative module paths
(e.g. ``"attn.to_q"``), values are ``ColwiseParallel()`` /
``RowwiseParallel()`` instances.
single_block_plan (`dict`):
``parallelize_module`` plan for each single-stream block
(``model.single_transformer_blocks``).
"""
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
raise RuntimeError(
"apply_tensor_parallel requires an initialised torch.distributed process group."
)

try:
from torch.distributed.tensor.parallel import parallelize_module
except ImportError as e:
raise ImportError(
"apply_tensor_parallel requires PyTorch >= 2.3 with distributed tensor parallel support."
) from e

tp_mesh = config._mesh
if tp_mesh is None:
raise ValueError(
"`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP."
)

for block in model.transformer_blocks:
parallelize_module(block, tp_mesh, double_block_plan)

for block in model.single_transformer_blocks:
parallelize_module(block, tp_mesh, single_block_plan)
Comment on lines +69 to +79

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it similar to

def apply_context_parallel(
?

2 changes: 1 addition & 1 deletion src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
logger = logging.get_logger(__name__)

_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
lambda: lambda model_cls, weights: weights,
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
Expand Down
63 changes: 62 additions & 1 deletion src/diffusers/models/_modeling_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel


@dataclass
Expand Down Expand Up @@ -154,6 +153,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


@dataclass
class TensorParallelConfig:
"""
Configuration for tensor parallelism.

Tensor parallelism shards weight matrices (column-wise and row-wise) across devices.
Each device computes a partial result; an AllReduce/AllGather at layer boundaries
reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module``
with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles.

On Neuron, use the ``_pre_shard_and_tp`` workaround from
``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug
on large tensors (>= 5120x5120).

Args:
tp_degree (`int`, defaults to `1`):
Number of devices to shard across. Must be a divisor of the number of
attention heads (and FFN hidden dimensions) of the model being parallelised.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use. If provided, ``tp_degree`` is inferred from
``mesh.size()`` and the argument is ignored. Useful when combining TP with
other parallelism strategies (e.g. CP) that share the same mesh.
Comment on lines +175 to +177

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide an example for this?

"""

tp_degree: int = 1
mesh: torch.distributed.device_mesh.DeviceMesh | None = None

_rank: int = None
_world_size: int = None
_device: torch.device = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None

def __post_init__(self):
if self.tp_degree < 1:
raise ValueError("`tp_degree` must be >= 1.")

def setup(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this supposed to be called from?

self,
rank: int,
world_size: int,
device: torch.device,
mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
if mesh is not None:
self._mesh = mesh
elif self.mesh is not None:
self._mesh = self.mesh
else:
from torch.distributed.device_mesh import init_device_mesh

device_type = str(device).split(":")[0]
self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",))


@dataclass
class ParallelConfig:
"""
Expand All @@ -162,9 +218,12 @@ class ParallelConfig:
Args:
context_parallel_config (`ContextParallelConfig`, *optional*):
Configuration for context parallelism.
tensor_parallel_config (`TensorParallelConfig`, *optional*):
Configuration for tensor parallelism.
"""

context_parallel_config: ContextParallelConfig | None = None
tensor_parallel_config: TensorParallelConfig | None = None

_rank: int = None
_world_size: int = None
Expand All @@ -185,6 +244,8 @@ def setup(
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, mesh)
if self.tensor_parallel_config is not None:
self.tensor_parallel_config.setup(rank, world_size, device, mesh)
Comment on lines 245 to +248

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's raise if both context_parallel_config and tensor_parallel_config are specified?



@dataclass(frozen=True)
Expand Down
182 changes: 182 additions & 0 deletions src/diffusers/models/transformers/transformer_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,146 @@ def __call__(
return hidden_states


class Flux2AttnProcessorTP(Flux2AttnProcessor):
"""
TP-aware version of ``Flux2AttnProcessor`` for double-stream transformer blocks.

After column-wise weight sharding, each rank holds ``attn.heads // tp_size`` heads.
The only difference from the base class is that ``unflatten`` uses the local head
count rather than the full ``attn.heads``.

Args:
tp_size (`int`): Number of tensor-parallel ranks (== ``tp_mesh.size()``).
"""

def __init__(self, tp_size: int):
super().__init__()
self.tp_size = tp_size

def __call__(
self,
attn: "Flux2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
local_heads = attn.heads // self.tp_size
head_dim = attn.head_dim

query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)

query = query.unflatten(-1, (local_heads, head_dim))
key = key.unflatten(-1, (local_heads, head_dim))
value = value.unflatten(-1, (local_heads, head_dim))

query = attn.norm_q(query)
key = attn.norm_k(key)

if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (local_heads, head_dim))
encoder_key = encoder_key.unflatten(-1, (local_heads, head_dim))
encoder_value = encoder_value.unflatten(-1, (local_heads, head_dim))

encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)

query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)

if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)

if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
return hidden_states


class Flux2ParallelSelfAttnProcessorTP(Flux2ParallelSelfAttnProcessor):
"""
TP-aware version of ``Flux2ParallelSelfAttnProcessor`` for single-stream blocks.

After column-wise weight sharding the fused ``to_qkv_mlp_proj`` projection,
each rank holds a proportionally smaller slice of Q/K/V and MLP dimensions.
The split sizes are computed from the local (per-rank) head count and inner dim.

Args:
tp_size (`int`): Number of tensor-parallel ranks (== ``tp_mesh.size()``).
"""

def __init__(self, tp_size: int):
super().__init__()
self.tp_size = tp_size

def __call__(
self,
attn: "Flux2ParallelSelfAttention",
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
local_heads = attn.heads // self.tp_size
head_dim = attn.head_dim
local_inner = attn.inner_dim // self.tp_size
local_mlp_gate = attn.mlp_hidden_dim * attn.mlp_mult_factor // self.tp_size

hidden_states = attn.to_qkv_mlp_proj(hidden_states)
qkv, mlp_hidden_states = torch.split(hidden_states, [3 * local_inner, local_mlp_gate], dim=-1)

query, key, value = qkv.chunk(3, dim=-1)

query = query.unflatten(-1, (local_heads, head_dim))
key = key.unflatten(-1, (local_heads, head_dim))
value = value.unflatten(-1, (local_heads, head_dim))

query = attn.norm_q(query)
key = attn.norm_k(key)

if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
return attn.to_out(hidden_states)


class Flux2KVParallelSelfAttnProcessor:
"""
Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
Expand Down Expand Up @@ -1090,6 +1230,48 @@ class Flux2Transformer2DModel(
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

# Tensor-parallel sharding plans (one per block type).
# Used by ``apply_tensor_parallel`` (generic path) and by the Neuron-specific
# ``apply_tp_flux2_transformer_neuron`` (which also needs weight permutations).
# Populated lazily on first access to avoid importing torch.distributed.tensor
# at module import time when TP is not used.
_tp_double_block_plan: "dict | None" = None
_tp_single_block_plan: "dict | None" = None

@classmethod
def _get_tp_double_block_plan(cls) -> dict:
"""Return the TP sharding plan for double-stream (cross-attention + FFN) blocks."""
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

if cls._tp_double_block_plan is None:
cls._tp_double_block_plan = {
"attn.to_q": ColwiseParallel(),
"attn.to_k": ColwiseParallel(),
"attn.to_v": ColwiseParallel(),
"attn.to_out.0": RowwiseParallel(),
"attn.add_q_proj": ColwiseParallel(),
"attn.add_k_proj": ColwiseParallel(),
"attn.add_v_proj": ColwiseParallel(),
"attn.to_add_out": RowwiseParallel(),
"ff.linear_in": ColwiseParallel(),
"ff.linear_out": RowwiseParallel(),
"ff_context.linear_in": ColwiseParallel(),
"ff_context.linear_out": RowwiseParallel(),
}
return cls._tp_double_block_plan

@classmethod
def _get_tp_single_block_plan(cls) -> dict:
"""Return the TP sharding plan for single-stream (parallel self-attn + fused MLP) blocks."""
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

if cls._tp_single_block_plan is None:
cls._tp_single_block_plan = {
"attn.to_qkv_mlp_proj": ColwiseParallel(),
"attn.to_out": RowwiseParallel(),
}
return cls._tp_single_block_plan

@register_to_config
def __init__(
self,
Expand Down
Loading
Loading