Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import MagCacheConfig, apply_mag_cache
from .mag_cache import MagCacheConfig, apply_mag_cache, update_mag_cache_num_steps
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/hooks/mag_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def __post_init__(self):
if not torch.is_tensor(self.mag_ratios):
self.mag_ratios = torch.tensor(self.mag_ratios)

self._original_mag_ratios = self.mag_ratios.clone()

if len(self.mag_ratios) != self.num_inference_steps:
logger.debug(
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
Expand Down Expand Up @@ -407,6 +409,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
# Initialize registry on the root module so the Pipeline can set context.
HookRegistry.check_if_exists_or_initialize(module)

module._mag_cache_config = config
state_manager = StateManager(MagCacheState, (), {})
remaining_blocks = []

Expand Down Expand Up @@ -466,3 +469,12 @@ def _apply_mag_cache_block_hook(

hook = MagCacheBlockHook(state_manager, is_tail, config)
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)

def update_mag_cache_num_steps(module: torch.nn.Module, num_steps: int) -> None:
config: MagCacheConfig = getattr(module, "_mag_cache_config", None)
if config is None:
return
original_ratios = getattr(config, "_original_mag_ratios", config.mag_ratios)
config.num_inference_steps = num_steps
if original_ratios is not None:
config.mag_ratios = nearest_interp(original_ratios, num_steps)
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,16 @@ def __call__(
else:
boundary_timestep = None

if boundary_timestep is not None:
from ...hooks.mag_cache import update_mag_cache_num_steps

n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
n_steps_t2 = len(timesteps) - n_steps_t1
if self.transformer is not None:
update_mag_cache_num_steps(self.transformer, n_steps_t1)
if self.transformer_2 is not None:
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,16 @@ def __call__(
else:
boundary_timestep = None

if boundary_timestep is not None:
from ...hooks.mag_cache import update_mag_cache_num_steps

n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
n_steps_t2 = len(timesteps) - n_steps_t1
if self.transformer is not None:
update_mag_cache_num_steps(self.transformer, n_steps_t1)
if self.transformer_2 is not None:
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan_vace.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,16 @@ def __call__(
else:
boundary_timestep = None

if boundary_timestep is not None:
from ...hooks.mag_cache import update_mag_cache_num_steps

n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
n_steps_t2 = len(timesteps) - n_steps_t1
if self.transformer is not None:
update_mag_cache_num_steps(self.transformer, n_steps_t1)
if self.transformer_2 is not None:
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand Down
Loading