Skip to content

Commit ffc9da0

Browse files
authored
fix(mag_cache): correct per-transformer step count for dual-transformer pipelines
1 parent 2d0110f commit ffc9da0

5 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .hooks import HookRegistry, ModelHook
2424
from .layer_skip import LayerSkipConfig, apply_layer_skip
2525
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
26-
from .mag_cache import MagCacheConfig, apply_mag_cache
26+
from .mag_cache import MagCacheConfig, apply_mag_cache, update_mag_cache_num_steps
2727
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
2828
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
2929
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

src/diffusers/hooks/mag_cache.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def __post_init__(self):
130130
if not torch.is_tensor(self.mag_ratios):
131131
self.mag_ratios = torch.tensor(self.mag_ratios)
132132

133+
self._original_mag_ratios = self.mag_ratios.clone()
134+
133135
if len(self.mag_ratios) != self.num_inference_steps:
134136
logger.debug(
135137
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
@@ -407,6 +409,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
407409
# Initialize registry on the root module so the Pipeline can set context.
408410
HookRegistry.check_if_exists_or_initialize(module)
409411

412+
module._mag_cache_config = config
410413
state_manager = StateManager(MagCacheState, (), {})
411414
remaining_blocks = []
412415

@@ -466,3 +469,12 @@ def _apply_mag_cache_block_hook(
466469

467470
hook = MagCacheBlockHook(state_manager, is_tail, config)
468471
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
472+
473+
def update_mag_cache_num_steps(module: torch.nn.Module, num_steps: int) -> None:
474+
config: MagCacheConfig = getattr(module, "_mag_cache_config", None)
475+
if config is None:
476+
return
477+
original_ratios = getattr(config, "_original_mag_ratios", config.mag_ratios)
478+
config.num_inference_steps = num_steps
479+
if original_ratios is not None:
480+
config.mag_ratios = nearest_interp(original_ratios, num_steps)

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,16 @@ def __call__(
586586
else:
587587
boundary_timestep = None
588588

589+
if boundary_timestep is not None:
590+
from ...hooks.mag_cache import update_mag_cache_num_steps
591+
592+
n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
593+
n_steps_t2 = len(timesteps) - n_steps_t1
594+
if self.transformer is not None:
595+
update_mag_cache_num_steps(self.transformer, n_steps_t1)
596+
if self.transformer_2 is not None:
597+
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)
598+
589599
with self.progress_bar(total=num_inference_steps) as progress_bar:
590600
for i, t in enumerate(timesteps):
591601
if self.interrupt:

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,16 @@ def __call__(
741741
else:
742742
boundary_timestep = None
743743

744+
if boundary_timestep is not None:
745+
from ...hooks.mag_cache import update_mag_cache_num_steps
746+
747+
n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
748+
n_steps_t2 = len(timesteps) - n_steps_t1
749+
if self.transformer is not None:
750+
update_mag_cache_num_steps(self.transformer, n_steps_t1)
751+
if self.transformer_2 is not None:
752+
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)
753+
744754
with self.progress_bar(total=num_inference_steps) as progress_bar:
745755
for i, t in enumerate(timesteps):
746756
if self.interrupt:

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,16 @@ def __call__(
955955
else:
956956
boundary_timestep = None
957957

958+
if boundary_timestep is not None:
959+
from ...hooks.mag_cache import update_mag_cache_num_steps
960+
961+
n_steps_t1 = sum(1 for t in timesteps if t >= boundary_timestep)
962+
n_steps_t2 = len(timesteps) - n_steps_t1
963+
if self.transformer is not None:
964+
update_mag_cache_num_steps(self.transformer, n_steps_t1)
965+
if self.transformer_2 is not None:
966+
update_mag_cache_num_steps(self.transformer_2, n_steps_t2)
967+
958968
with self.progress_bar(total=num_inference_steps) as progress_bar:
959969
for i, t in enumerate(timesteps):
960970
if self.interrupt:

0 commit comments

Comments
 (0)