@@ -130,6 +130,8 @@ def __post_init__(self):
130130 if not torch .is_tensor (self .mag_ratios ):
131131 self .mag_ratios = torch .tensor (self .mag_ratios )
132132
133+ self ._original_mag_ratios = self .mag_ratios .clone ()
134+
133135 if len (self .mag_ratios ) != self .num_inference_steps :
134136 logger .debug (
135137 f"Interpolating mag_ratios from length { len (self .mag_ratios )} to { self .num_inference_steps } "
@@ -407,6 +409,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
407409 # Initialize registry on the root module so the Pipeline can set context.
408410 HookRegistry .check_if_exists_or_initialize (module )
409411
412+ module ._mag_cache_config = config
410413 state_manager = StateManager (MagCacheState , (), {})
411414 remaining_blocks = []
412415
@@ -466,3 +469,12 @@ def _apply_mag_cache_block_hook(
466469
467470 hook = MagCacheBlockHook (state_manager , is_tail , config )
468471 registry .register_hook (hook , _MAG_CACHE_BLOCK_HOOK )
472+
473+ def update_mag_cache_num_steps (module : torch .nn .Module , num_steps : int ) -> None :
474+ config : MagCacheConfig = getattr (module , "_mag_cache_config" , None )
475+ if config is None :
476+ return
477+ original_ratios = getattr (config , "_original_mag_ratios" , config .mag_ratios )
478+ config .num_inference_steps = num_steps
479+ if original_ratios is not None :
480+ config .mag_ratios = nearest_interp (original_ratios , num_steps )
0 commit comments