77# This file provides the util functions to apply activation checkpointing to the model.
88# Technically, this is not a part of distributed, but distributed module is the best place to put it.
99
10+ import os
1011from collections import defaultdict
1112
1213import torch
@@ -293,6 +294,7 @@ def apply_ac(
293294 model_compile_enabled : bool = False ,
294295 use_flex_attn : bool = False ,
295296 op_sac_save_list : set [torch ._ops .OpOverload ] | None = None ,
297+ base_folder : str = "" ,
296298) -> None :
297299 """Apply activation checkpointing to the model.
298300
@@ -311,15 +313,27 @@ def apply_ac(
311313 None
312314 """
313315
314- for layer_id , transformer_block in model .layers .named_children ():
315- transformer_block = _apply_ac_to_transformer_block (
316- transformer_block ,
317- ac_config ,
318- base_fqn = f"layers.{ layer_id } " ,
319- model_compile_enabled = model_compile_enabled ,
320- use_flex_attn = use_flex_attn ,
321- op_sac_save_list = op_sac_save_list ,
322- )
323- model .layers .register_module (layer_id , transformer_block )
316+ if ac_config .mode == "memory_budget" :
317+ assert model_compile_enabled , "Memory budget mode requires model to be compiled"
318+ if ac_config .visualize_memory_budget_pareto :
319+ pareto_dir = os .path .join (base_folder , "memory_budget_pareto" )
320+ if not os .path .exists (pareto_dir ):
321+ os .makedirs (pareto_dir , exist_ok = True )
322+ torch ._functorch .config .memory_budget_pareto_dir = pareto_dir
323+ torch ._functorch .config .visualize_memory_budget_pareto = True
324+
325+ torch ._functorch .config .activation_memory_budget = ac_config .memory_budget
326+ logger .info (f"Selected { ac_config .memory_budget } budget option" )
327+ else :
328+ for layer_id , transformer_block in model .layers .named_children ():
329+ transformer_block = _apply_ac_to_transformer_block (
330+ transformer_block ,
331+ ac_config ,
332+ base_fqn = f"layers.{ layer_id } " ,
333+ model_compile_enabled = model_compile_enabled ,
334+ use_flex_attn = use_flex_attn ,
335+ op_sac_save_list = op_sac_save_list ,
336+ )
337+ model .layers .register_module (layer_id , transformer_block )
324338
325339 logger .info (f"Applied { ac_config .mode } activation checkpointing to the model" )
0 commit comments