Skip to content

[Neuron] Add tensor parallel support for Neuron backend#13718

Draft
JingyaHuang wants to merge 29 commits into
huggingface:mainfrom
JingyaHuang:support-neuron-tp
Draft

[Neuron] Add tensor parallel support for Neuron backend#13718
JingyaHuang wants to merge 29 commits into
huggingface:mainfrom
JingyaHuang:support-neuron-tp

Conversation

@JingyaHuang

Copy link
Copy Markdown
Contributor

What does this PR do?

This PR adds tensor parallel support for Neuron devices. Since TP isn't yet supported in diffusers, I followed the existing sequence parallel pattern and introduced a TensorParallelConfig.
This is still very much a work in progress. The goal at this stage is to surface the changes needed to enable TP on Neuron, not to land a stable implementation.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for starting this!

I think it'd be simpler to keep the changes limited to a single model and a pipeline for iterating more quickly.

Additionally, a few thoughts:

I don't think we're exposing the TP config in modeling_utils.py. I think the enable_parallelism() method accept it:

def enable_parallelism(

(and include all the necessary validation)

And then instead of manually iterating on transformer_blocks and single_transformer_blocks, we could try to configure that through class-level attributes, e.g., _tp_blocks or something like that.

After these changes, we could perhaps work on presenting some numbers where TP is beneficial, etc. WDYT?

Comment on lines -88 to +92
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
lambda image_url_or_path: (
load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an unrelated change?

Comment on lines +43 to +44
A ``Flux2Transformer2DModel`` instance. Must have ``transformer_blocks``
and ``single_transformer_blocks`` attributes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It cannot be specific to a particular model-type, right?

Comment on lines +69 to +79
tp_mesh = config._mesh
if tp_mesh is None:
raise ValueError(
"`config._mesh` is None. Call `config.setup(rank, world_size, device)` before applying TP."
)

for block in model.transformer_blocks:
parallelize_module(block, tp_mesh, double_block_plan)

for block in model.single_transformer_blocks:
parallelize_module(block, tp_mesh, single_block_plan)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it similar to

def apply_context_parallel(
?

Comment thread src/diffusers/models/transformers/transformer_flux2.py
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)

return latent_ids
return latent_ids.float()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀

This doesn't seem like a related change?

Comment on lines +175 to +177
A custom device mesh to use. If provided, ``tp_degree`` is inferred from
``mesh.size()`` and the argument is ignored. Useful when combining TP with
other parallelism strategies (e.g. CP) that share the same mesh.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide an example for this?

Comment on lines 245 to +248
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, mesh)
if self.tensor_parallel_config is not None:
self.tensor_parallel_config.setup(rank, world_size, device, mesh)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's raise if both context_parallel_config and tensor_parallel_config are specified?

if self.tp_degree < 1:
raise ValueError("`tp_degree` must be >= 1.")

def setup(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this supposed to be called from?

@github-actions github-actions Bot removed the utils label Jun 22, 2026
@github-actions github-actions Bot removed the tests label Jun 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants