Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/diffusers/modular_pipelines/anima/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,19 @@ def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> Pi
return components, state


# Copied from diffusers.modular_pipelines.qwenimage.before_denoise.get_timesteps
def get_timesteps(scheduler, num_inference_steps, strength):
    """Slice the scheduler's timestep schedule according to ``strength``.

    Args:
        scheduler: Scheduler exposing ``timesteps`` and ``order``; ``set_begin_index`` is
            called on it when that method exists.
        num_inference_steps: Number of steps in the full schedule.
        strength: Fraction of the schedule to run. Values above ``1`` behave like ``1``.

    Returns:
        A tuple ``(timesteps, steps)`` holding the tail of ``scheduler.timesteps`` that
        will be executed and the number of inference steps left after skipping the head.
    """
    # Number of steps that will actually be denoised, capped at the full schedule length.
    steps_to_run = min(num_inference_steps * strength, num_inference_steps)
    skipped_steps = int(max(num_inference_steps - steps_to_run, 0))

    # Higher-order schedulers hold ``order`` entries per inference step.
    begin_index = skipped_steps * scheduler.order
    remaining_timesteps = scheduler.timesteps[begin_index:]

    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(begin_index)

    return remaining_timesteps, num_inference_steps - skipped_steps


class AnimaSetTimestepsStep(ModularPipelineBlocks):
model_name = "anima"

Expand Down Expand Up @@ -414,3 +427,76 @@ def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> Pi

self.set_block_state(state, block_state)
return components, state


class AnimaImg2ImgSetTimestepsStep(ModularPipelineBlocks):
    """Compute the full scheduler timestep schedule for Anima image-to-image inference.

    The complete (unsliced) schedule is written to the pipeline state. This block never
    calls ``scheduler.set_begin_index``; ``AnimaImg2ImgVaeEncoderStep`` does that later
    through ``get_timesteps`` when it trims the schedule by ``strength``.

    Components:
        scheduler (`FlowMatchEulerDiscreteScheduler`)

    Inputs:
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps.
        sigmas (`list`, *optional*):
            Custom sigmas for the denoising process. When omitted, a linear ramp from
            ``1.0`` down to ``1 / num_inference_steps`` is used.

    Outputs:
        timesteps (`Tensor`):
            Full timestep schedule for the denoising loop.
        num_inference_steps (`int`):
            Number of denoising steps as returned by ``retrieve_timesteps``.
    """

    model_name = "anima"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def description(self) -> str:
        return "Set the scheduler timesteps for Anima image-to-image inference."

    @property
    def inputs(self) -> list[InputParam]:
        template_names = ("num_inference_steps", "sigmas")
        return [InputParam.template(name) for name in template_names]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        timesteps_output = OutputParam(
            "timesteps",
            type_hint=torch.Tensor,
            description="Full timestep schedule for the denoising loop.",
        )
        step_count_output = OutputParam(
            "num_inference_steps", type_hint=int, description="Number of denoising steps."
        )
        return [timesteps_output, step_count_output]

    @torch.no_grad()
    def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # Fall back to a linear sigma ramp when the caller did not supply custom sigmas.
        if block_state.sigmas is None:
            sigmas = np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
        else:
            sigmas = block_state.sigmas

        full_timesteps, step_count = retrieve_timesteps(
            components.scheduler,
            device=execution_device,
            sigmas=sigmas,
        )
        block_state.timesteps = full_timesteps
        block_state.num_inference_steps = step_count
        # No set_begin_index here on purpose: AnimaImg2ImgVaeEncoderStep calls
        # get_timesteps(), which slices the schedule by strength and sets the offset.

        self.set_block_state(state, block_state)
        return components, state
239 changes: 239 additions & 0 deletions src/diffusers/modular_pipelines/anima/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLQwenImage
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .before_denoise import get_timesteps
from .modular_pipeline import AnimaModularPipeline


Expand Down Expand Up @@ -251,3 +256,237 @@ def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> Pi

self.set_block_state(state, block_state)
return components, state


# Copied from diffusers.modular_pipelines.qwenimage.encoders.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.modular_pipelines.qwenimage.encoders.encode_vae_image
def encode_vae_image(
    image: torch.Tensor,
    vae: AutoencoderKLQwenImage,
    generator: torch.Generator,
    device: torch.device,
    dtype: torch.dtype,
    latent_channels: int = 16,
    sample_mode: str = "argmax",
):
    """Encode a preprocessed image tensor with the VAE and normalize the latents.

    Args:
        image: 4D or 5D tensor. A 4D input gets a singleton axis inserted at position 2
            before encoding.
        vae: VAE whose ``encode`` output is read via ``retrieve_latents`` and whose
            ``config.latents_mean`` / ``config.latents_std`` drive the normalization.
        generator: A single generator, or a list with one generator per batch element
            (in which case the batch is encoded one sample at a time).
        device: Device the image is moved to before encoding.
        dtype: Dtype the image is cast to before encoding.
        latent_channels: Channel count used to reshape the mean/std statistics.
        sample_mode: Forwarded to ``retrieve_latents``.

    Returns:
        Image latents normalized as ``(latents - latents_mean) / latents_std``.

    Raises:
        ValueError: If ``image`` is not a tensor or is neither 4D nor 5D.
    """
    if not isinstance(image, torch.Tensor):
        raise ValueError(f"Expected image to be a tensor, got {type(image)}.")

    # A preprocessed image is expected as (batch_size, num_channels, height, width).
    rank = image.dim()
    if rank not in (4, 5):
        raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
    if rank == 4:
        image = image.unsqueeze(2)

    image = image.to(device=device, dtype=dtype)

    if isinstance(generator, list):
        # One generator per sample: encode each batch element separately.
        per_sample_latents = []
        for idx in range(image.shape[0]):
            encoded = vae.encode(image[idx : idx + 1])
            per_sample_latents.append(
                retrieve_latents(encoded, generator=generator[idx], sample_mode=sample_mode)
            )
        image_latents = torch.cat(per_sample_latents, dim=0)
    else:
        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)

    def _channel_stat(values):
        # Reshape per-channel statistics so they broadcast against the 5D latents.
        stat = torch.tensor(values).view(1, latent_channels, 1, 1, 1)
        return stat.to(image_latents.device, image_latents.dtype)

    latents_mean = _channel_stat(vae.config.latents_mean)
    latents_std = _channel_stat(vae.config.latents_std)

    return (image_latents - latents_mean) / latents_std


class AnimaImg2ImgVaeEncoderStep(ModularPipelineBlocks):
    """VAE Encoder step for Anima image-to-image generation.

    Preprocesses the input image, encodes it with the VAE, generates noise, slices the
    timestep schedule based on ``strength``, and adds noise to the image latents using
    ``scheduler.scale_noise()``.

    Components:
        vae (`AutoencoderKLQwenImage`)
        scheduler (`FlowMatchEulerDiscreteScheduler`)
        image_processor (`VaeImageProcessor`)

    Inputs:
        image (`PIL.Image.Image`):
            Input image to use as starting point.
        height (`int`, *optional*):
            Height of the output image. Defaults to pipeline default.
        width (`int`, *optional*):
            Width of the output image. Defaults to pipeline default.
        strength (`float`, *optional*, defaults to 0.9):
            How much to transform the reference image. ``1`` means fully denoise from
            random noise. Values that leave no denoising step after slicing (e.g. ``0``)
            are rejected with a ``ValueError``.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images to generate per prompt.
        generator (`Generator`, *optional*):
            Torch generator for deterministic generation.
        latents (`Tensor`, *optional*):
            Pre-computed noise tensor. Generated randomly if ``None``.
        timesteps (`Tensor`):
            Full timestep schedule produced by ``AnimaImg2ImgSetTimestepsStep``.
        num_inference_steps (`int`):
            Total number of inference steps from ``AnimaImg2ImgSetTimestepsStep``.
        batch_size (`int`):
            Number of prompts, provided by ``AnimaTextInputStep``.
        dtype (`torch.dtype`, *optional*):
            Dtype used by the Anima denoiser. Falls back to the VAE dtype when ``None``.

    Outputs:
        latents (`Tensor`):
            Noisy image latents to use as the starting point for denoising.
        timesteps (`Tensor`):
            Timestep schedule sliced by ``strength``.
        num_inference_steps (`int`):
            Number of denoising steps after strength-based slicing.
        padding_mask (`Tensor`):
            Cosmos padding mask for the image latents.
        height (`int`):
            Output image height (updated to pipeline default if not provided).
        width (`int`):
            Output image width (updated to pipeline default if not provided).

    Raises:
        ValueError: If slicing the schedule by ``strength`` leaves fewer than one
            denoising step, or if the encoded image batch cannot be evenly repeated to
            ``batch_size * num_images_per_prompt``.
    """

    model_name = "anima"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLQwenImage),
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return (
            "VAE Encoder step for Anima image-to-image generation. Encodes the input image, "
            "slices the timestep schedule by strength, and adds noise via scheduler.scale_noise()."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("image"),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam.template("strength"),
            InputParam.template("num_images_per_prompt"),
            InputParam.template("generator"),
            InputParam.template("latents"),
            InputParam.template("timesteps", required=True),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="Total number of inference steps from AnimaImg2ImgSetTimestepsStep.",
            ),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, provided by AnimaTextInputStep.",
            ),
            InputParam("dtype", type_hint=torch.dtype, description="Dtype used by the Anima denoiser."),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="Noisy image latents for the denoising process."
            ),
            OutputParam("timesteps", type_hint=torch.Tensor, description="Timestep schedule sliced by strength."),
            OutputParam(
                "num_inference_steps", type_hint=int, description="Number of denoising steps after strength slicing."
            ),
            OutputParam("padding_mask", type_hint=torch.Tensor, description="Cosmos padding mask for image latents."),
            OutputParam("height", type_hint=int, description="Image height used for generation."),
            OutputParam("width", type_hint=int, description="Image width used for generation."),
        ]

    @torch.no_grad()
    def __call__(self, components: AnimaModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        # dtype is provided by AnimaTextInputStep; fall back to vae dtype if not yet in state
        dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width

        # Slice the full schedule by strength; this also sets the scheduler's begin index.
        block_state.timesteps, block_state.num_inference_steps = get_timesteps(
            components.scheduler, block_state.num_inference_steps, block_state.strength
        )
        # An empty slice would make `timesteps[:1]` below empty and break (or silently
        # empty out) the noising step, so fail early with an actionable message.
        if block_state.num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {block_state.strength}, the "
                f"number of pipeline steps is {block_state.num_inference_steps} which is < 1 and not appropriate "
                "for this pipeline."
            )

        # Total batch = prompt batch × images per prompt
        total_batch = block_state.batch_size * block_state.num_images_per_prompt

        # Preprocess PIL image(s) to tensor
        processed_image = components.image_processor.preprocess(
            image=block_state.image, height=block_state.height, width=block_state.width
        )

        # Encode to image latents; use VAE dtype for encoding
        image_latents = encode_vae_image(
            image=processed_image,
            vae=components.vae,
            generator=block_state.generator,
            device=device,
            dtype=components.vae.dtype,
            latent_channels=components.num_channels_latents,
        )

        # Expand image_latents to total_batch (handles single image with multiple prompts)
        image_batch = image_latents.shape[0]
        if image_batch < total_batch:
            # Integer division would silently under-expand a non-divisible batch and
            # leave the latents out of sync with the per-sample timesteps below.
            if total_batch % image_batch != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image_batch} to a total batch size of {total_batch}."
                )
            # NOTE(review): `repeat` tiles the whole image batch ([a, b, a, b]); confirm this
            # ordering matches how prompt embeddings are expanded for num_images_per_prompt.
            image_latents = image_latents.repeat(total_batch // image_batch, 1, 1, 1, 1)

        # Generate initial noise (or use pre-provided latents as noise)
        if block_state.latents is None:
            noise = randn_tensor(
                image_latents.shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.float32,
            )
        else:
            noise = block_state.latents.to(device=device, dtype=torch.float32)

        # Add noise to image latents at the appropriate noise level for this strength:
        # the first timestep of the sliced schedule, one copy per sample.
        latent_timestep = block_state.timesteps[:1].repeat(total_batch)
        block_state.latents = components.scheduler.scale_noise(
            image_latents.to(dtype=torch.float32),
            latent_timestep,
            noise,
        )

        block_state.padding_mask = block_state.latents.new_zeros(
            1, 1, block_state.height, block_state.width, dtype=dtype
        )

        self.set_block_state(state, block_state)
        return components, state
Loading
Loading