diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 2a9aa81608e7..f6e582d1642b 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,7 +23,7 @@ from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook - from .mag_cache import MagCacheConfig, apply_mag_cache + from .mag_cache import MagCacheConfig, apply_mag_cache, update_mag_cache_num_steps from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py index d28cd2d793b6..a56cccc0dfcc 100644 --- a/src/diffusers/hooks/mag_cache.py +++ b/src/diffusers/hooks/mag_cache.py @@ -130,6 +130,8 @@ def __post_init__(self): if not torch.is_tensor(self.mag_ratios): self.mag_ratios = torch.tensor(self.mag_ratios) + self._original_mag_ratios = self.mag_ratios.clone() + if len(self.mag_ratios) != self.num_inference_steps: logger.debug( f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" @@ -407,6 +409,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: # Initialize registry on the root module so the Pipeline can set context. HookRegistry.check_if_exists_or_initialize(module) + module._mag_cache_config = config state_manager = StateManager(MagCacheState, (), {}) remaining_blocks = [] @@ -466,3 +469,12 @@ def _apply_mag_cache_block_hook( hook = MagCacheBlockHook(state_manager, is_tail, config) registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) + +def update_mag_cache_num_steps(module: torch.nn.Module, num_steps: int) -> None: + config: MagCacheConfig = getattr(module, "_mag_cache_config", None) + if config is None: + return + original_ratios = getattr(config, "_original_mag_ratios", config.mag_ratios) + config.num_inference_steps = num_steps + if original_ratios is not None: + config.mag_ratios = nearest_interp(original_ratios, num_steps) \ No newline at end of file diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index be2d53f17932..00b7f2a852ba 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -586,6 +586,16 @@ def __call__( else: boundary_timestep = None + if boundary_timestep is not None: + from ...hooks.mag_cache import update_mag_cache_num_steps + + n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep) + n_steps_t2 = len(timesteps) - n_steps_t1 + if self.transformer is not None: + update_mag_cache_num_steps(self.transformer, n_steps_t1) + if self.transformer_2 is not None: + update_mag_cache_num_steps(self.transformer_2, n_steps_t2) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 8061f67ab6b9..3c753e61d136 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -741,6 +741,16 @@ def __call__( else: boundary_timestep = None + if boundary_timestep is not None: + from ...hooks.mag_cache import update_mag_cache_num_steps + + n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep) + n_steps_t2 = len(timesteps) - n_steps_t1 + if self.transformer is not None: + update_mag_cache_num_steps(self.transformer, n_steps_t1) + if self.transformer_2 is not None: + update_mag_cache_num_steps(self.transformer_2, n_steps_t2) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index b0896d382d67..910bbaadee86 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -955,6 +955,16 @@ def __call__( else: boundary_timestep = None + if boundary_timestep is not None: + from ...hooks.mag_cache import update_mag_cache_num_steps + + n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep) + n_steps_t2 = len(timesteps) - n_steps_t1 + if self.transformer is not None: + update_mag_cache_num_steps(self.transformer, n_steps_t1) + if self.transformer_2 is not None: + update_mag_cache_num_steps(self.transformer_2, n_steps_t2) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: