Draft

35 commits
3529a0a  template1 (tolgacangoz, Oct 6, 2025)
4f2ee5e  temp2 (tolgacangoz, Oct 6, 2025)
778fb54  up (tolgacangoz, Oct 6, 2025)
d77b6ba  up (tolgacangoz, Oct 6, 2025)
2fc6ac2  fix-copies (tolgacangoz, Oct 6, 2025)
d667d03  Add support for Wan2.2-Animate-14B model in convert_wan_to_diffusers.py (tolgacangoz, Oct 7, 2025)
6182d44  style (tolgacangoz, Oct 7, 2025)
8c9fd89  Refactor WanAnimate model components (tolgacangoz, Oct 7, 2025)
d01e941  Enhance `WanAnimatePipeline` with new parameters for mode and tempora… (tolgacangoz, Oct 7, 2025)
7af953b  Update `WanAnimatePipeline` to require additional video inputs and im… (tolgacangoz, Oct 7, 2025)
a0372e3  Add Wan 2.2 Animate 14B model support and introduce Wan-Animate frame… (tolgacangoz, Oct 7, 2025)
05a01c6  Add unit test template for `WanAnimatePipeline` functionality (tolgacangoz, Oct 7, 2025)
22b83ce  Add unit tests for `WanAnimateTransformer3DModel` in GGUF format (tolgacangoz, Oct 7, 2025)
7fb6732  style (tolgacangoz, Oct 7, 2025)
3e6f893  Improve the template of `transformer_wan_animate.py` (tolgacangoz, Oct 7, 2025)
624a314  Update `WanAnimatePipeline` (tolgacangoz, Oct 7, 2025)
fc0edb5  style (tolgacangoz, Oct 7, 2025)
eb7eedd  Refactor test for `WanAnimatePipeline` to include new input structure (tolgacangoz, Oct 7, 2025)
8968b42  from `einops` to `torch` (tolgacangoz, Oct 8, 2025)
dce83a8  Merge branch 'main' into integrations/wan2.2-animate (tolgacangoz, Oct 8, 2025)
75b2382  Add padding functionality to `WanAnimatePipeline` for video frames (tolgacangoz, Oct 8, 2025)
802896e  style (tolgacangoz, Oct 8, 2025)
e06098f  Enhance `WanAnimatePipeline` with additional input parameters for imp… (tolgacangoz, Oct 8, 2025)
84768f6  up (tolgacangoz, Oct 8, 2025)
06e6138  Refactor `WanAnimatePipeline` for improved tensor handling and mask g… (tolgacangoz, Oct 8, 2025)
5777ce0  Refactor `WanAnimatePipeline` to streamline latent tensor processing … (tolgacangoz, Oct 9, 2025)
b8337c6  style (tolgacangoz, Oct 9, 2025)
f4eb9a0  Add new layers and functions to `transformer_wan_animate.py` for enha… (tolgacangoz, Oct 9, 2025)
4e6651b  Merge branch 'main' into integrations/wan2.2-animate (tolgacangoz, Oct 13, 2025)
d80ae19  Refactor `transformer_wan_animate.py` to improve modularity and type … (tolgacangoz, Oct 10, 2025)
348a945  Refactor `transformer_wan_animate.py` to enhance modularity and updat… (tolgacangoz, Oct 15, 2025)
7774421  Update the `ConvLayer` class to conditionally apply bias based on act… (tolgacangoz, Oct 17, 2025)
a5536e2  Simplify (tolgacangoz, Oct 17, 2025)
6a8662d  refactor transformer (tolgacangoz, Oct 17, 2025)
96a126a  Enhance `convert_wan_to_diffusers.py` for Animate model integration (tolgacangoz, Oct 17, 2025)
83 changes: 83 additions & 0 deletions docs/source/en/api/pipelines/wan.md
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)

> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -249,6 +250,82 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p

The general rule of thumb when preparing inputs for the VACE pipeline is that the input images, or the video frames you want to use for conditioning, should have a corresponding black mask. A black mask signifies that the model will not generate new content for that area and will only use it to condition the generation process. Parts or frames that the model should generate must have a white mask.
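For illustration only, the snippet below sketches this convention with a hand-built mask; the resolution and the half-frame split are arbitrary example values, not requirements of the pipeline.

```python
import numpy as np
from PIL import Image

# Black (0) marks regions the model should only use for conditioning;
# white (255) marks regions the model should generate.
height, width = 480, 832
mask = np.zeros((height, width), dtype=np.uint8)  # fully black: everything is conditioning
mask[:, width // 2 :] = 255                       # white right half: generate new content here
conditioning_mask = Image.fromarray(mask)
```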

</hfoption>
</hfoptions>

### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication

[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.

*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*

The example below demonstrates how to use the Wan-Animate pipeline to generate a video from a text description, a starting character image, a pose video, and a face video (plus, optionally, a background video and a mask video) in "animation" or "replacement" mode.

<hfoptions id="Animate usage">
<hfoption id="usage">

```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel


model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Preprocessing: The input video should be preprocessed into several materials before being fed into the inference process.
# TODO: Diffusersify the preprocessing process: !python wan/modules/animate/preprocess/preprocess_data.py


image = load_image("preprocessed_results/astronaut.jpg")
pose_video = load_video("preprocessed_results/pose_video.mp4")
face_video = load_video("preprocessed_results/face_video.mp4")

def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width

def center_crop_resize(image, height, width):
    # Calculate resize ratio to match first frame dimensions
    resize_ratio = max(width / image.width, height / image.height)

    # Resize so the image covers the target size, then center crop to (height, width)
    image = image.resize((round(image.width * resize_ratio), round(image.height * resize_ratio)))
    image = TF.center_crop(image, [height, width])

    return image, height, width

image, height, width = aspect_ratio_resize(image, pipe)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."

# guidance_scale (`float`, *optional*, defaults to 1.0):
#     Classifier-free guidance scale, used only for expression control.
#     In most cases it is not necessary, and generation is faster without it.
#     Consider enabling it when expression adjustments are needed.
output = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt=prompt,
    height=height,
    width=width,
    guidance_scale=1.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
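Replacement mode additionally consumes the background and mask videos produced by the preprocessing step. The sketch below is hypothetical: the `background_video`, `mask_video`, and `mode` argument names are assumptions inferred from the description above, so check the `WanAnimatePipeline` documentation for the exact signature.

```python
# Hypothetical sketch of "replacement" mode; argument names are assumptions, not the confirmed API.
background_video = load_video("preprocessed_results/background_video.mp4")
mask_video = load_video("preprocessed_results/mask_video.mp4")

output = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    background_video=background_video,  # assumed parameter name
    mask_video=mask_video,              # assumed parameter name
    mode="replacement",                 # assumed parameter name and value
    prompt=prompt,
    height=height,
    width=width,
    guidance_scale=1.0,
).frames[0]
export_to_video(output, "output_replacement.mp4", fps=16)
```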

</hfoption>
</hfoptions>

## Notes

- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -359,6 +436,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__

## WanAnimatePipeline

[[autodoc]] WanAnimatePipeline
- all
- __call__

## WanPipelineOutput

[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
138 changes: 132 additions & 6 deletions scripts/convert_wan_to_diffusers.py
@@ -1,16 +1,19 @@
import argparse
import math
import pathlib
from typing import Any, Dict, Tuple

import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel, CLIPVisionModel

from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
@@ -105,8 +108,88 @@
"after_proj": "proj_out",
}

ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Motion encoder mappings
"motion_encoder.enc.net_app.convs": "condition_embedder.motion_embedder.convs",
"motion_encoder.enc.fc": "condition_embedder.motion_embedder.linears",
"motion_encoder.dec.direction.weight": "condition_embedder.motion_embedder.weight",
# Face encoder mappings
"face_encoder.conv1_local": "condition_embedder.face_embedder.conv1_local",
"face_encoder.conv2": "condition_embedder.face_embedder.conv2",
"face_encoder.conv3": "condition_embedder.face_embedder.conv3",
"face_encoder.out_proj": "condition_embedder.face_embedder.out_proj",
"face_encoder.norm1": "condition_embedder.face_embedder.norm1",
"face_encoder.norm2": "condition_embedder.face_embedder.norm2",
"face_encoder.norm3": "condition_embedder.face_embedder.norm3",
"face_encoder.padding_tokens": "condition_embedder.face_embedder.padding_tokens",
# Face adapter mappings
"face_adapter.fuser_blocks": "face_adapter",
}

def convert_equal_linear_weight(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert EqualLinear weights to standard Linear weights by applying the scale factor.
EqualLinear uses: F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
where scale = (1 / sqrt(in_dim)) * lr_mul
"""

def convert_equal_conv2d_weight(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert EqualConv2d weights to standard Conv2d weights by applying the scale factor.
EqualConv2d uses: F.conv2d(input, self.weight * self.scale, bias=self.bias, ...)
where scale = 1 / sqrt(in_channel * kernel_size^2)
"""

def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert all motion encoder weights for Animate model.
This handles both EqualLinear (in fc/linears) and EqualConv2d (in conv layers).

In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
"""

TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {"condition_embedder.motion_embedder": convert_animate_motion_encoder_weights}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
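Only the docstrings of the conversion helpers appear in the hunk above. As a rough, standalone illustration of the rescaling they describe, the sketch below folds EqualLinear's runtime scale into plain Linear parameters; it follows the formula quoted in the docstring (`scale = (1 / sqrt(in_dim)) * lr_mul`), is not the PR's actual implementation, and the helper name is made up.

```python
import math

import torch


def fold_equal_linear_scale(weight: torch.Tensor, bias: torch.Tensor, lr_mul: float = 1.0):
    # EqualLinear computes F.linear(x, weight * scale, bias * lr_mul) with
    # scale = (1 / sqrt(in_dim)) * lr_mul, so an equivalent nn.Linear needs the
    # pre-scaled weight and bias baked into its parameters.
    in_dim = weight.shape[1]
    scale = (1.0 / math.sqrt(in_dim)) * lr_mul
    return weight * scale, bias * lr_mul
```

An EqualConv2d weight would be handled the same way, with `scale = 1 / sqrt(in_channel * kernel_size^2)` as stated in its docstring.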
@@ -364,6 +447,31 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"motion_encoder_dim": 512,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": 257 * 2,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP


@@ -380,10 +488,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)

with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)

for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -926,7 +1036,7 @@ def get_args():
if __name__ == "__main__":
args = get_args()

if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1052,7 @@ def get_args():
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1064,8 @@ def get_args():
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)

if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1128,20 @@ def get_args():
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -258,6 +258,7 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
@@ -618,6 +619,7 @@
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
@@ -949,6 +951,7 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1279,6 +1282,7 @@
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -100,6 +100,7 @@
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -198,6 +199,7 @@
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -36,4 +36,5 @@
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel