From 59e48014c6f31cbb940db18c9402b59707cd4624 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 18 Jun 2026 12:10:13 +0000 Subject: [PATCH 1/5] feat: add image edit plus --- ...convert_joyimage_edit_plus_to_diffusers.py | 290 ++++++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_joyimage.py | 2 + .../transformer_joyimage_edit_plus.py | 365 +++++++++ src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/joyimage/__init__.py | 7 +- .../joyimage/pipeline_joyimage_edit_plus.py | 697 ++++++++++++++++++ .../pipelines/joyimage/pipeline_output.py | 8 + 10 files changed, 1377 insertions(+), 5 deletions(-) create mode 100644 scripts/convert_joyimage_edit_plus_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_joyimage_edit_plus.py create mode 100644 src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py diff --git a/scripts/convert_joyimage_edit_plus_to_diffusers.py b/scripts/convert_joyimage_edit_plus_to_diffusers.py new file mode 100644 index 000000000000..f01adb03c747 --- /dev/null +++ b/scripts/convert_joyimage_edit_plus_to_diffusers.py @@ -0,0 +1,290 @@ +import argparse +from typing import Any, Dict, Tuple + +import torch +from accelerate import init_empty_weights +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration + +from diffusers import ( + AutoencoderKLWan, + JoyImageEditPlusPipeline, +) +from diffusers.models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + + +# VAE conversion reused from convert_joyimage_edit_to_diffusers.py (identical VAE) +def convert_vae(vae_ckpt_path): + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + middle_key_mapping = { + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + attention_mapping = { + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + head_mapping = { + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + for key, value in old_state_dict.items(): + if key in middle_key_mapping: + new_state_dict[middle_key_mapping[key]] = value + elif key in attention_mapping: + new_state_dict[attention_mapping[key]] = value + elif key in head_mapping: + new_state_dict[head_mapping[key]] = value + elif key in quant_mapping: + new_state_dict[quant_mapping[key]] = value + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + new_state_dict["decoder.conv_in.bias"] = value + elif key.startswith("encoder.downsamples."): + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + new_state_dict[new_key] = value + elif key.startswith("decoder.upsamples."): + parts = key.split(".") + block_idx = int(parts[2]) + + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + new_state_dict[key] = value + continue + + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + new_state_dict[new_key] = value + + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + new_state_dict[new_key] = value + + elif ".resample." in key or ".time_conv." in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + + +def get_transformer_config() -> Dict[str, Any]: + return { + "hidden_size": 4096, + "in_channels": 16, + "num_attention_heads": 32, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_dim": 4096, + "rope_type": "rope", + "theta": 10000, + } + + +def convert_transformer(ckpt_path: str): + checkpoint = torch.load(ckpt_path, weights_only=True) + if "model" in checkpoint: + original_state_dict = checkpoint["model"] + else: + original_state_dict = checkpoint + + attn_suffixes = ( + "img_attn_qkv.", + "img_attn_q_norm.", + "img_attn_k_norm.", + "img_attn_proj.", + "txt_attn_qkv.", + "txt_attn_q_norm.", + "txt_attn_k_norm.", + "txt_attn_proj.", + ) + remapped = {} + for key, value in original_state_dict.items(): + new_key = key + if key.startswith("double_blocks."): + for suffix in attn_suffixes: + if "." + suffix in key and ".attn." + suffix not in key: + new_key = key.replace("." + suffix, ".attn." + suffix) + break + remapped[new_key] = value + + config = get_transformer_config() + with init_empty_weights(): + transformer = JoyImageEditPlusTransformer3DModel(**config) + transformer.load_state_dict(remapped, strict=True, assign=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser(description="Convert JoyImage Edit Plus checkpoints to diffusers format") + parser.add_argument("--transformer_ckpt_path", type=str, default=None) + parser.add_argument("--vae_ckpt_path", type=str, default=None) + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="bf16", help="Torch dtype (fp32, fp16, bf16)") + parser.add_argument("--flow_shift", type=float, default=1.5) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +if __name__ == "__main__": + args = get_args() + transformer = None + vae = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + vae = vae.to(dtype=dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + processor = AutoProcessor.from_pretrained(args.text_encoder_path) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + args.text_encoder_path, torch_dtype=torch.bfloat16 + ).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift) + transformer = transformer.to("cuda") + vae = vae.to("cuda") + pipe = JoyImageEditPlusPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + processor.save_pretrained(f"{args.output_path}/processor") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6353347503e1..69986ce8f67b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -275,6 +275,7 @@ "I2VGenXLUNet", "Ideogram4Transformer2DModel", "JoyImageEditTransformer3DModel", + "JoyImageEditPlusTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", @@ -623,6 +624,8 @@ "ImageTextPipelineOutput", "JoyImageEditPipeline", "JoyImageEditPipelineOutput", + "JoyImageEditPlusPipeline", + "JoyImageEditPlusPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", "Kandinsky5I2IPipeline", @@ -1135,6 +1138,7 @@ I2VGenXLUNet, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, @@ -1458,6 +1462,8 @@ ImageTextPipelineOutput, JoyImageEditPipeline, JoyImageEditPipelineOutput, + JoyImageEditPlusPipeline, + JoyImageEditPlusPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, Kandinsky5I2IPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7a1d0801f2c5..8969441121fb 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -121,6 +121,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_ideogram4"] = ["Ideogram4Transformer2DModel"] _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"] + _import_structure["transformers.transformer_joyimage_edit_plus"] = ["JoyImageEditPlusTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] @@ -254,6 +255,7 @@ HunyuanVideoTransformer3DModel, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatAudioDiTTransformer, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 1edceee3ca74..442a0022f4ff 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,6 +42,7 @@ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_ideogram4 import Ideogram4Transformer2DModel from .transformer_joyimage import JoyImageEditTransformer3DModel + from .transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index b17ddb05f799..d30b0501e02f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -283,6 +283,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # modulation ( @@ -312,6 +313,7 @@ def forward( hidden_states=img_modulated, encoder_hidden_states=txt_modulated, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, ) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..abc8c2b4340a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -0,0 +1,365 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm +from .transformer_joyimage import ( + JoyImageAttention, + JoyImageModulate, + JoyImageTimeTextImageEmbedding, + JoyImageTransformerBlock, +) + + +logger = logging.get_logger(__name__) + + +def _apply_rotary_emb_batched( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" + cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) + + if cos.ndim == 2: + # unbatched: [S, D] -> [1, S, 1, D] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + elif cos.ndim == 3: + # batched: [B, S, D] -> [B, S, 1, D] + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + + def _rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk) + return xq_out, xk_out + + +class JoyImageEditPlusAttnProcessor: + """Attention processor that supports batched RoPE embeddings for edit-plus multi-image input.""" + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + attn: "JoyImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") + + heads = attn.heads + + img_qkv = attn.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + txt_qkv = attn.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + img_query = attn.img_attn_q_norm(img_query) + img_key = attn.img_attn_k_norm(img_key) + txt_query = attn.txt_attn_q_norm(txt_query) + txt_key = attn.txt_attn_k_norm(txt_key) + + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb_batched(img_query, img_key, vis_freqs) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb_batched(txt_query, txt_key, txt_freqs) + + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + img_attn_output = attn.img_attn_proj(img_attn_output) + txt_attn_output = attn.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): + """JoyImage Edit Plus Transformer for multi-image editing. + + Uses a patchify+padding approach where each reference image and the target noise are independently + patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. + + Input format: [B, max_patches, C, pt, ph, pw] (6D padded patches) + """ + + _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] + _no_split_modules = ["JoyImageTransformerBlock"] + _supports_gradient_checkpointing = True + _keep_in_fp32_modules = [ + "time_embedder", + "norm1", + "norm2", + "norm_out", + ] + _repeated_blocks = ["JoyImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + ): + super().__init__() + + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list + self.rope_type = rope_type + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + self.condition_embedder = JoyImageTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + ) + + self.double_blocks = nn.ModuleList( + [ + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = FP32LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) + + self.gradient_checkpointing = False + + # Set batched-RoPE-aware attention processor on all blocks + for block in self.double_blocks: + block.attn.set_processor(JoyImageEditPlusAttnProcessor()) + + def _get_rotary_pos_embed_for_range( + self, + start: Tuple[int, int, int], + stop: Tuple[int, int, int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate 3D RoPE for a spatial range [start, stop).""" + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // 3] * 3 + + grids = [] + for i in range(3): + grids.append(torch.arange(start[i], stop[i], dtype=torch.float32)) + + mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0) + + cos_parts, sin_parts = [], [] + for i, dim in enumerate(rope_dim_list): + pos = mesh[i].reshape(-1) + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + angles = torch.outer(pos, freqs) + cos_parts.append(angles.cos().repeat_interleave(2, dim=1)) + sin_parts.append(angles.sin().repeat_interleave(2, dim=1)) + + return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor | None = None, + shape_list: List[List[Tuple[int, int, int]]] | None = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Tuple]: + """ + Args: + hidden_states: [B, max_patches, C, pt, ph, pw] - patchified latent input. + timestep: [B] - diffusion timestep. + encoder_hidden_states: [B, L, D] - text encoder outputs. + encoder_hidden_states_mask: [B, L] - attention mask for text tokens. + shape_list: Per-sample list of (t, h, w) tuples for each component (target + references). + return_dict: Whether to return a dict or tuple. + """ + batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape + device = hidden_states.device + + # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if isinstance(encoder_hidden_states_mask, list): + encoder_hidden_states_mask = encoder_hidden_states_mask[0] + + # Resolve shape_list from forward context if not explicitly provided + if shape_list is None: + try: + from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context + + forward_batch = get_forward_context().forward_batch + if forward_batch is not None and forward_batch.vae_image_sizes is not None: + shape_list = [list(forward_batch.vae_image_sizes)] * batch_size + except (ImportError, AttributeError): + pass + if shape_list is None: + raise ValueError( + "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" + ) + + # 1. Condition embeddings + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + # 2. Patchify via Conv3d: flatten (B, N) -> apply conv -> reshape back + x = hidden_states.reshape(batch_size * max_num_patches, channels, pt, ph, pw) + x = self.img_in(x) # (B*N, D, 1, 1, 1) + img = x.reshape(batch_size, max_num_patches, -1) + + # 3. Build per-component RoPE with temporal offsets + sample_cos_list, sample_sin_list = [], [] + + for i in range(batch_size): + s_cos_parts, s_sin_parts = [], [] + current_t_offset = 0 + + for thw in shape_list[i]: + t, h, w = thw + start = (current_t_offset, 0, 0) + stop = (current_t_offset + t, h, w) + cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(start, stop) + s_cos_parts.append(cos_emb) + s_sin_parts.append(sin_emb) + current_t_offset += t + + s_cos = torch.cat(s_cos_parts, dim=0).to(device) + s_sin = torch.cat(s_sin_parts, dim=0).to(device) + + actual_len = s_cos.shape[0] + pad_len = max_num_patches - actual_len + if pad_len > 0: + s_cos = F.pad(s_cos, (0, 0, 0, pad_len), value=1.0) + s_sin = F.pad(s_sin, (0, 0, 0, pad_len), value=0.0) + + sample_cos_list.append(s_cos) + sample_sin_list.append(s_sin) + + vis_freqs = (torch.stack(sample_cos_list), torch.stack(sample_sin_list)) + + # 4. Build attention mask: [B, 1, 1, img_seq + txt_seq] + # img patches: only actual (non-padding) patches are valid; txt uses encoder_hidden_states_mask + attention_mask = None + if encoder_hidden_states_mask is not None: + img_mask = torch.zeros(batch_size, max_num_patches, device=device, dtype=encoder_hidden_states_mask.dtype) + for i in range(batch_size): + actual_len = sum(t * h * w for t, h, w in shape_list[i]) + img_mask[i, :actual_len] = 1.0 + full_mask = torch.cat([img_mask, encoder_hidden_states_mask], dim=1) + attention_mask = full_mask.unsqueeze(1).unsqueeze(1).bool() + + # 5. Run double blocks + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img, txt = self._gradient_checkpointing_func(block, img, txt, vec, (vis_freqs, None), attention_mask) + else: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, None), + attention_mask=attention_mask, + ) + + # 6. Output projection + reshape to 6D patches + img = self.proj_out(self.norm_out(img)) + img = img.reshape( + batch_size, max_num_patches, pt, ph, pw, self.out_channels + ).permute(0, 1, 5, 2, 3, 4) # -> [B, N, C, pt, ph, pw] + + if not return_dict: + return (img,) + return Transformer2DModelOutput(sample=img) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 850a991941ff..b467a20806bb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -344,7 +344,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] - _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput"] + _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput", "JoyImageEditPlusPipeline", "JoyImageEditPlusPipelineOutput"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -757,7 +757,7 @@ from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .ideogram4 import Ideogram4Pipeline, Ideogram4PromptEnhancerHead - from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput + from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput, JoyImageEditPlusPipeline, JoyImageEditPlusPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py index 85b9246b22a6..a5faea9d9763 100644 --- a/src/diffusers/pipelines/joyimage/__init__.py +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -22,8 +22,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] - - _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit_plus"] = ["JoyImageEditPlusPipeline"] + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput", "JoyImageEditPlusPipelineOutput"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,8 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_joyimage_edit import JoyImageEditPipeline - from .pipeline_output import JoyImageEditPipelineOutput + from .pipeline_joyimage_edit_plus import JoyImageEditPlusPipeline + from .pipeline_output import JoyImageEditPipelineOutput, JoyImageEditPlusPipelineOutput else: import sys diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py new file mode 100644 index 000000000000..c938e8e8ab32 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -0,0 +1,697 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from transformers import ( + Qwen2Tokenizer, + Qwen3VLForConditionalGeneration, + Qwen3VLProcessor, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLWan +from ...models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import JoyImageEditImageProcessor, find_best_bucket +from .pipeline_output import JoyImageEditPlusPipelineOutput + + +EXAMPLE_DOC_STRING = """ +Examples: + ```python + >>> import torch + >>> from diffusers import JoyImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> model_id = "jdopensource/JoyAI-Image-Edit-Plus-Diffusers" + >>> pipe = JoyImageEditPlusPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> images = [ + ... load_image("dog.png"), + ... load_image("person.png"), + ... ] + >>> output = pipe( + ... images=images, + ... prompt="Let the person lovingly play with the dog.", + ... height=1024, + ... width=1024, + ... num_inference_steps=30, + ... guidance_scale=4.0, + ... generator=torch.manual_seed(42), + ... ) + >>> output.images[0].save("output.png") + ``` +""" + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +class JoyImageEditPlusPipeline(DiffusionPipeline): + """Diffusion pipeline for multi-image editing using JoyImage Edit Plus. + + Supports multiple reference images with different resolutions. Each reference image is independently + VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. + + Model offloading order: text_encoder -> transformer -> vae. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditPlusTransformer3DModel, + processor: Qwen3VLProcessor, + text_token_max_length: int = 2048, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.text_token_max_length = text_token_max_length + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.vae_image_processor = JoyImageEditImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + ) + + self.prompt_template_encode = { + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + self.prompt_template_encode_start_idx = { + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_last_decoder_hidden_states(self, forward_fn, **kwargs): + """ + Run ``forward_fn(**kwargs)`` while capturing the **pre-norm** output of the last decoder layer via a forward + hook. + + This model was trained on transformers 4.57, where ``Qwen3VLForConditionalGeneration``'s + ``@check_model_inputs`` decorator monkey-patched each decoder layer to collect ``hidden_states``. Because + ``Qwen3VLCausalLMOutputWithPast`` has no ``last_hidden_state`` field, ``tie_last_hidden_states`` had no effect + and ``hidden_states[-1]`` was the **pre-norm** output of the last decoder layer. + + Starting from https://github.com/huggingface/transformers/pull/42609 the CausalLM forward explicitly returns + ``hidden_states=outputs.hidden_states`` from the inner model. Combined with the subsequent + ``@check_model_inputs`` → ``@capture_outputs`` migration (transformers 5.x), ``hidden_states`` is now captured + at the ``Qwen3VLTextModel`` level where ``tie_last_hidden_states=True`` replaces ``hidden_states[-1]`` with the + **post-norm** ``last_hidden_state``. The CausalLM simply passes this through, so ``hidden_states[-1]`` becomes + post-norm – a ~10x scale difference (std ~2 vs ~21) that breaks inference. + + This helper bypasses both mechanisms by hooking the last decoder layer directly, returning the raw pre-norm + output regardless of the transformers version. + """ + captured = {} + + def _hook(_module, _input, output): + captured["hidden_states"] = output[0] if isinstance(output, tuple) else output + + handle = self.text_encoder.model.language_model.layers[-1].register_forward_hook(_hook) + try: + forward_fn(**kwargs) + finally: + handle.remove() + return captured["hidden_states"] + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[List[Image.Image]] = None, + max_sequence_length: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode prompts with inline tokens via the Qwen3-VL processor.""" + device = device or self._execution_device + template = self.prompt_template_encode["multiple_images"] + drop_idx = self.prompt_template_encode_start_idx["multiple_images"] + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) for p in prompt] + + inputs = self.processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + last_hidden_states = self._get_last_decoder_hidden_states(self.text_encoder, **inputs) + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + return prompt_embeds, prompt_embeds_mask + + def _pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: + current_length = x.shape[1] + if current_length >= target_length: + return x[:, -target_length:] + padding_length = target_length - current_length + if x.ndim >= 3: + padding = torch.zeros( + (x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device + ) + else: + padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) + return torch.cat([x, padding], dim=1) + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + w, h = img.size + bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) + img = img.resize((resize_w, resize_h), Image.LANCZOS) + left = (resize_w - bw) // 2 + top = (resize_h - bh) // 2 + img = img.crop((left, top, left + bw, top + bh)) + return img + + def _get_bucket_size(self, img: Image.Image) -> Tuple[int, int]: + return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + reference_images: Optional[List[List[Image.Image]]] = None, + enable_denormalization: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: + """Prepare 6D padded latent tensor with target noise + reference image latents. + + Returns: + padded_latents: [B, max_patches, C, pt, ph, pw] + target_mask: [B, max_patches] (True for target patches) + shape_list: per-sample list of (t, h, w) tuples for each component + """ + pt, ph, pw = self.transformer.config.patch_size + + all_patches = [] + all_target_masks = [] + all_shape_lists = [] + max_patches = 0 + + for i in range(batch_size): + sample_gen = generator[i] if isinstance(generator, list) else generator + + # Target noise + t_target = 1 + h_target = int(height) // self.vae_scale_factor_spatial + w_target = int(width) // self.vae_scale_factor_spatial + noise_shape = (num_channels_latents, t_target, h_target, w_target) + noise_block = randn_tensor(noise_shape, generator=sample_gen, device=device, dtype=dtype) + + sample_items = [noise_block] + + # Reference images + if reference_images is not None and reference_images[i]: + for ref_img_pil in reference_images[i]: + ref_h, ref_w = self._get_bucket_size(ref_img_pil) + ref_img_pil = self._resize_center_crop(ref_img_pil, (ref_h, ref_w)) + + ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) + ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() + ref_latent = ref_latent.to(dtype) + ref_latent = self.normalize_latents(ref_latent) + ref_latent = ref_latent.squeeze(0) # [C, 1, H', W'] + sample_items.append(ref_latent) + + # Patchify each item and build shape_list + sample_patches = [] + sample_masks = [] + sample_shapes = [] + + for j, item in enumerate(sample_items): + c, t, h, w = item.shape + l_t, l_h, l_w = t // pt, h // ph, w // pw + sample_shapes.append((l_t, l_h, l_w)) + + patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + sample_patches.append(patches) + sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, dtype=torch.bool)) + + combined_patches = torch.cat(sample_patches, dim=0) + combined_masks = torch.cat(sample_masks, dim=0) + + all_patches.append(combined_patches) + all_target_masks.append(combined_masks) + all_shape_lists.append(sample_shapes) + max_patches = max(max_patches, combined_patches.shape[0]) + + # Pad to uniform size + padded_latents = torch.zeros( + (batch_size, max_patches, num_channels_latents, pt, ph, pw), device=device, dtype=dtype + ) + target_mask = torch.zeros((batch_size, max_patches), device=device, dtype=torch.bool) + + for i in range(batch_size): + n = all_patches[i].shape[0] + padded_latents[i, :n] = all_patches[i] + target_mask[i, :n] = all_target_masks[i] + + return padded_latents, target_mask, all_shape_lists + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def guidance_scale(self) -> float: + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1 + + @property + def num_timesteps(self) -> int: + return self._num_timesteps + + @property + def interrupt(self) -> bool: + return self._interrupt + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + images: List[Image.Image] | List[List[Image.Image]] | None = None, + prompt: str | List[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 4096, + enable_denormalization: bool = True, + ): + r""" + Generate an edited image from multiple reference images and a text prompt. + + Args: + images (`List[Image.Image]` or `List[List[Image.Image]]`): + Reference images for editing. Each image can have a different resolution. + If a flat list is provided, it's treated as one sample with multiple references. + prompt (`str` or `List[str]`): + Text prompt describing the desired edit. + height (`int`, *optional*): + Output height in pixels. If None, determined from the last reference image's bucket. + width (`int`, *optional*): + Output width in pixels. If None, determined from the last reference image's bucket. + num_inference_steps (`int`, defaults to 30): + Number of denoising steps. + guidance_scale (`float`, defaults to 4.0): + Classifier-free guidance scale. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt for CFG. + generator (`torch.Generator`, *optional*): + RNG generator for reproducibility. + enable_denormalization (`bool`, defaults to True): + Whether to denormalize latents before VAE decoding. + + Examples: + + Returns: + [`JoyImageEditPlusPipelineOutput`] or `tuple`. + """ + # Normalize images input to List[List[Image]] + if images is not None: + if isinstance(images[0], Image.Image): + images = [images] # single sample + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Determine output resolution from last reference image if not specified + if height is None or width is None: + if images is not None and len(images[0]) > 0: + last_img = images[0][-1] + height, width = self._get_bucket_size(last_img) + else: + height = height or 1024 + width = width or 1024 + + device = self._execution_device + + # Pre-process images: bucket-resize each reference image (matching original pipeline) + if images is not None: + processed_images = [] + for sample_imgs in images: + processed_sample = [] + for img in sample_imgs: + ref_h, ref_w = self._get_bucket_size(img) + resize_img = self._resize_center_crop(img, (ref_h, ref_w)) + processed_sample.append(resize_img) + processed_images.append(processed_sample) + images = processed_images + + # Construct prompts with tokens + prompt = [prompt] if isinstance(prompt, str) else prompt + if images is not None: + formatted_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + p = prompt[i] if i < len(prompt) else prompt[0] + formatted_prompts.append(f"<|im_start|>user\n{image_tags}{p}<|im_end|>\n") + else: + formatted_prompts = [f"<|im_start|>user\n{p}<|im_end|>\n" for p in prompt] + + # Flatten all images for the processor + flattened_images = None + if images is not None: + flattened_images = [img for sublist in images for img in sublist] + + # Encode prompt + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=formatted_prompts, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + torch.save(prompt_embeds, "prompt_embeds.pt") + # Encode negative prompt for CFG + if self.do_classifier_free_guidance: + print(f"negative_prompt: {negative_prompt}") + if negative_prompt is None and negative_prompt_embeds is None: + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + neg_prompts.append(f"<|im_start|>user\n{image_tags} <|im_end|>\n") + negative_prompt = neg_prompts + elif negative_prompt is not None and negative_prompt_embeds is None: + neg_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + n = neg_list[i] if i < len(neg_list) else neg_list[0] + neg_prompts.append(f"<|im_start|>user\n{image_tags}{n}<|im_end|>\n") + negative_prompt = neg_prompts + + if negative_prompt_embeds is None: + neg_prompt_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=neg_prompt_list, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + # Pad and concatenate [negative, positive] + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat([ + self._pad_sequence(negative_prompt_embeds, max_seq_len), + self._pad_sequence(prompt_embeds, max_seq_len), + ]) + if prompt_embeds_mask is not None and negative_prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat([ + self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self._pad_sequence(prompt_embeds_mask, max_seq_len), + ]) + torch.save(prompt_embeds, 'prompt_embeds_2.pt') + + # Prepare timesteps — compute sigmas with single shift to match original scheduler + if timesteps is None and sigmas is None: + shift = getattr(self.scheduler.config, "shift", 1.0) + raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) + shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) + sigmas = shifted_sigmas[:-1].tolist() + original_shift = self.scheduler.shift + self.scheduler.set_shift(1.0) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self.scheduler.set_shift(original_shift) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # Prepare latents (patchified) + num_channels_latents = self.transformer.config.in_channels + padded_latents, target_mask, shape_list = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + reference_images=images, + enable_denormalization=enable_denormalization, + ) + torch.save(padded_latents, "padded_latents.pt") + torch.save(target_mask, "target_mask.pt") + # exit(0) + + # Zero out padding text tokens to prevent them from corrupting attention + # (original uses explicit attention masking; here we neutralize padding values) + if prompt_embeds_mask is not None: + prompt_embeds = prompt_embeds * prompt_embeds_mask.unsqueeze(-1) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + clean_reference_backup = padded_latents.clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference patches + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + + model_input = padded_latents + + # CFG expansion + if self.do_classifier_free_guidance: + model_input_cfg = torch.cat([model_input] * 2) + t_expand = t.repeat(model_input_cfg.shape[0]) + cfg_shape_list = shape_list * 2 + else: + model_input_cfg = model_input + t_expand = t.repeat(batch_size) + cfg_shape_list = shape_list + + # Transformer forward + noise_pred = self.transformer( + hidden_states=model_input_cfg, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + shape_list=cfg_shape_list, + return_dict=False, + )[0] + + # CFG combination with norm rescaling + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=2, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # Scheduler step + padded_latents = self.scheduler.step(noise_pred, t, padded_latents, return_dict=False)[0].to( + dtype=prompt_embeds.dtype + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + padded_latents = callback_outputs.pop("latents", padded_latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # Post-processing: decode target latents + if output_type != "latent": + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + pt, ph, pw = self.transformer.config.patch_size + + image_list = [] + for b_idx in range(batch_size): + l_t, l_h, l_w = shape_list[b_idx][0] + target_len = l_t * l_h * l_w + + target_patches = padded_latents[b_idx, :target_len] + video_latent = rearrange( + target_patches, + "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", + t=l_t, h=l_h, w=l_w, + ) + + video_latent = self.denormalize_latents(video_latent) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] + sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() + image_list.append(sample_image) + + # Convert to output format + output_images = [] + for img_tensor in image_list: + # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) + img_tensor = img_tensor[:, 0] + img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) + if output_type == "pil": + output_images.append(Image.fromarray(img_np)) + elif output_type == "np": + output_images.append(img_np) + else: + output_images.append(img_tensor) + + image = output_images + else: + image = padded_latents + + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return JoyImageEditPlusPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 175dce3540d7..40d9d3aa100f 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -14,3 +14,11 @@ class JoyImageEditPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + +@dataclass +class JoyImageEditPlusPipelineOutput(BaseOutput): + """ + Output class for JoyImage Edit Plus multi-image editing pipelines. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file From fff2f4b5d0489778e3bdf9e86638be5ad9aff146 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Mon, 22 Jun 2026 05:31:11 +0000 Subject: [PATCH 2/5] refactor: remove debug code --- .../pipelines/joyimage/pipeline_joyimage_edit_plus.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index c938e8e8ab32..980939f427d6 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -506,10 +506,8 @@ def __call__( max_sequence_length=max_sequence_length, ) - torch.save(prompt_embeds, "prompt_embeds.pt") # Encode negative prompt for CFG if self.do_classifier_free_guidance: - print(f"negative_prompt: {negative_prompt}") if negative_prompt is None and negative_prompt_embeds is None: neg_prompts = [] for i in range(batch_size): @@ -547,7 +545,6 @@ def __call__( self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - torch.save(prompt_embeds, 'prompt_embeds_2.pt') # Prepare timesteps — compute sigmas with single shift to match original scheduler if timesteps is None and sigmas is None: @@ -579,9 +576,6 @@ def __call__( reference_images=images, enable_denormalization=enable_denormalization, ) - torch.save(padded_latents, "padded_latents.pt") - torch.save(target_mask, "target_mask.pt") - # exit(0) # Zero out padding text tokens to prevent them from corrupting attention # (original uses explicit attention masking; here we neutralize padding values) From d295e3e5a17257f074e71ff0a0bdf1d6198aef7a Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:15:16 +0000 Subject: [PATCH 3/5] fix: address review issues for JoyImage Edit Plus - Remove einops dependency: replace rearrange with reshape/permute - Remove sglang-specific code from transformer forward - Remove unused import inspect from transformer - Fix hardcoded device_type="cuda" to use device.type - Simplify scheduler sigma math: delegate to retrieve_timesteps - Remove unused enable_denormalization parameter - Fix callback latents variable binding - Fix output_type="pt" to return stacked tensor - Set return_dict default to True in transformer forward - Add dummy objects for JoyImageEditPlus classes - Add transformer and pipeline test files --- .../transformer_joyimage_edit_plus.py | 19 +- .../joyimage/pipeline_joyimage_edit_plus.py | 50 ++-- src/diffusers/utils/dummy_pt_objects.py | 15 ++ .../dummy_torch_and_transformers_objects.py | 30 +++ ...t_models_transformer_joyimage_edit_plus.py | 114 +++++++++ .../joyimage/test_joyimage_edit_plus.py | 225 ++++++++++++++++++ 6 files changed, 403 insertions(+), 50 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_joyimage_edit_plus.py create mode 100644 tests/pipelines/joyimage/test_joyimage_edit_plus.py diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index abc8c2b4340a..572c983ec453 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import math from typing import List, Tuple, Union @@ -255,7 +254,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, shape_list: List[List[Tuple[int, int, int]]] | None = None, - return_dict: bool = False, + return_dict: bool = True, ) -> Union[torch.Tensor, Tuple]: """ Args: @@ -269,22 +268,6 @@ def forward( batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape device = hidden_states.device - # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) - if not isinstance(encoder_hidden_states, torch.Tensor): - encoder_hidden_states = encoder_hidden_states[0] - if isinstance(encoder_hidden_states_mask, list): - encoder_hidden_states_mask = encoder_hidden_states_mask[0] - - # Resolve shape_list from forward context if not explicitly provided - if shape_list is None: - try: - from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context - - forward_batch = get_forward_context().forward_batch - if forward_batch is not None and forward_batch.vae_image_sizes is not None: - shape_list = [list(forward_batch.vae_image_sizes)] * batch_size - except (ImportError, AttributeError): - pass if shape_list is None: raise ValueError( "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 980939f427d6..144650d46b05 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -18,7 +18,6 @@ import numpy as np import torch -from einops import rearrange from PIL import Image from transformers import ( Qwen2Tokenizer, @@ -282,7 +281,6 @@ def prepare_latents( device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]], reference_images: Optional[List[List[Image.Image]]] = None, - enable_denormalization: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: """Prepare 6D padded latent tensor with target noise + reference image latents. @@ -319,7 +317,7 @@ def prepare_latents( ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) - with torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() ref_latent = ref_latent.to(dtype) ref_latent = self.normalize_latents(ref_latent) @@ -336,7 +334,8 @@ def prepare_latents( l_t, l_h, l_w = t // pt, h // ph, w // pw sample_shapes.append((l_t, l_h, l_w)) - patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + patches = item.reshape(c, l_t, pt, l_h, ph, l_w, pw) + patches = patches.permute(1, 3, 5, 0, 2, 4, 6).reshape(-1, c, pt, ph, pw) sample_patches.append(patches) sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, dtype=torch.bool)) @@ -411,7 +410,6 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 4096, - enable_denormalization: bool = True, ): r""" Generate an edited image from multiple reference images and a text prompt. @@ -434,8 +432,6 @@ def __call__( Negative prompt for CFG. generator (`torch.Generator`, *optional*): RNG generator for reproducibility. - enable_denormalization (`bool`, defaults to True): - Whether to denormalize latents before VAE decoding. Examples: @@ -546,22 +542,10 @@ def __call__( self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - # Prepare timesteps — compute sigmas with single shift to match original scheduler - if timesteps is None and sigmas is None: - shift = getattr(self.scheduler.config, "shift", 1.0) - raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) - shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) - sigmas = shifted_sigmas[:-1].tolist() - original_shift = self.scheduler.shift - self.scheduler.set_shift(1.0) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self.scheduler.set_shift(original_shift) - else: - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) + # Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # Prepare latents (patchified) num_channels_latents = self.transformer.config.in_channels @@ -574,7 +558,6 @@ def __call__( device=device, generator=generator, reference_images=images, - enable_denormalization=enable_denormalization, ) # Zero out padding text tokens to prevent them from corrupting attention @@ -631,6 +614,7 @@ def __call__( ) if callback_on_step_end is not None: + latents = padded_latents callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] @@ -653,15 +637,13 @@ def __call__( target_len = l_t * l_h * l_w target_patches = padded_latents[b_idx, :target_len] - video_latent = rearrange( - target_patches, - "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", - t=l_t, h=l_h, w=l_w, - ) + c_lat = target_patches.shape[1] + video_latent = target_patches.reshape(l_t, l_h, l_w, c_lat, pt, ph, pw) + video_latent = video_latent.permute(3, 0, 4, 1, 5, 2, 6).reshape(1, c_lat, l_t * pt, l_h * ph, l_w * pw) video_latent = self.denormalize_latents(video_latent) - with torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() image_list.append(sample_image) @@ -671,15 +653,19 @@ def __call__( for img_tensor in image_list: # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) img_tensor = img_tensor[:, 0] - img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) if output_type == "pil": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(Image.fromarray(img_np)) elif output_type == "np": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(img_np) else: output_images.append(img_tensor) - image = output_images + if output_type == "pt": + image = torch.stack(output_images) + else: + image = output_images else: image = padded_latents diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8439a2b93371..5fa793a9ab9f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1500,6 +1500,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class JoyImageEditPlusTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0747e76cf715..44f9ab79eff3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2222,6 +2222,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class JoyImageEditPlusPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class JoyImageEditPlusPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..451dbfbbf0ca --- /dev/null +++ b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import JoyImageEditPlusTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class JoyImageEditPlusTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return JoyImageEditPlusTransformer3DModel + + @property + def output_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def input_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + return True + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": [1, 2, 2], + "in_channels": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "text_dim": 16, + "num_layers": 2, + "rope_dim_list": [4, 6, 6], + "theta": 256, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + max_patches = 2 + hidden_states = randn_tensor( + (batch_size, max_patches, 16, 1, 2, 2), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor((batch_size, 12, 16), generator=self.generator, device=torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + shape_list = [[(1, 1, 1), (1, 1, 1)]] + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "shape_list": shape_list, + } + + +class TestJoyImageEditPlusTransformer(JoyImageEditPlusTransformerTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestJoyImageEditPlusTransformerMemory(JoyImageEditPlusTransformerTesterConfig, MemoryTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerTraining(JoyImageEditPlusTransformerTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"JoyImageEditPlusTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestJoyImageEditPlusTransformerAttention(JoyImageEditPlusTransformerTesterConfig, AttentionTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerCompile(JoyImageEditPlusTransformerTesterConfig, TorchCompileTesterMixin): + pass diff --git a/tests/pipelines/joyimage/test_joyimage_edit_plus.py b/tests/pipelines/joyimage/test_joyimage_edit_plus.py new file mode 100644 index 000000000000..e41265d30128 --- /dev/null +++ b/tests/pipelines/joyimage/test_joyimage_edit_plus.py @@ -0,0 +1,225 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + JoyImageEditPlusPipeline, + JoyImageEditPlusTransformer3DModel, +) +from diffusers.hooks import apply_group_offloading + +from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class JoyImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = JoyImageEditPlusPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "images"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + super().setUp() + self._bucket_patcher = patch( + "diffusers.pipelines.joyimage.image_processor.find_best_bucket", + return_value=(32, 32), + ) + self._bucket_patcher.start() + + def tearDown(self): + self._bucket_patcher.stop() + super().tearDown() + + def get_dummy_components(self): + tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration" + + torch.manual_seed(0) + transformer = JoyImageEditPlusTransformer3DModel( + patch_size=[1, 2, 2], + in_channels=16, + hidden_size=32, + num_attention_heads=2, + text_dim=16, + num_layers=1, + rope_dim_list=[4, 6, 6], + theta=256, + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + processor = Qwen3VLProcessor.from_pretrained(tiny_ckpt_id) + processor.image_processor.min_pixels = 4 * 28 * 28 + processor.image_processor.max_pixels = 4 * 28 * 28 + + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(tiny_ckpt_id) + text_encoder.resize_token_embeddings(len(processor.tokenizer)) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": processor.tokenizer, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "combine the two images", + "images": [Image.new("RGB", (32, 32)), Image.new("RGB", (32, 32))], + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + @unittest.skip("num_images_per_prompt not applicable: each prompt is bound to reference images") + def test_num_images_per_prompt(self): + pass + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=False) + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4): + super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol) + + @require_torch_accelerator + def test_group_offloading_inference(self): + if not self.test_group_offloading: + return + + def create_pipe(): + torch.manual_seed(0) + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + return pipe(**inputs)[0] + + pipe = create_pipe().to(torch_device) + output_without_group_offloading = run_forward(pipe) + + pipe = create_pipe() + for component_name in ["transformer", "text_encoder"]: + component = getattr(pipe, component_name, None) + if component is None: + continue + if hasattr(component, "enable_group_offload"): + component.enable_group_offload( + torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1 + ) + else: + apply_group_offloading( + component, + onload_device=torch.device(torch_device), + offload_type="block_level", + num_blocks_per_group=1, + ) + pipe.vae.to(torch_device) + output_with_block_level = run_forward(pipe) + + pipe = create_pipe() + pipe.transformer.enable_group_offload(torch.device(torch_device), offload_type="leaf_level") + pipe.text_encoder.to(torch_device) + pipe.vae.to(torch_device) + output_with_leaf_level = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_block_level = output_with_block_level.detach().cpu().numpy() + output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4)) + + @unittest.skip("Qwen3VLForConditionalGeneration does not support leaf-level group offloading") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_cpu_offload_forward_pass(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_offload_forward_pass_twice(self): + pass From 84d120557adbc8cc643857071561a61d0aa56719 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:30:41 +0000 Subject: [PATCH 4/5] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 40d9d3aa100f..23cb24431462 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -21,4 +21,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): Output class for JoyImage Edit Plus multi-image editing pipelines. """ - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] From 6f2763a8b123c32ec45d7bd413669fa7ac9d1b8a Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:39:50 +0000 Subject: [PATCH 5/5] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 23cb24431462..30be7c248e33 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -22,3 +22,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] +