From 4bbb19714ffef37bc5d74a3333ea8c32e2acc390 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Oct 2025 20:54:05 +0000 Subject: [PATCH 1/4] Initial plan From bcd9208ff644d24d2be4d416845574858a484c83 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Oct 2025 21:05:40 +0000 Subject: [PATCH 2/4] Implement layer freezing in RR-ClArC fine-tuning Co-authored-by: istepka <49250572+istepka@users.noreply.github.com> --- src/detoxai/methods/clarcs/rrclarc.py | 38 ++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/detoxai/methods/clarcs/rrclarc.py b/src/detoxai/methods/clarcs/rrclarc.py index 272bbea..e36d738 100644 --- a/src/detoxai/methods/clarcs/rrclarc.py +++ b/src/detoxai/methods/clarcs/rrclarc.py @@ -76,6 +76,7 @@ def apply_model_correction( self.cav_layer = cav_layers[0] # Register rr_clarc_hook + cav_layer_found = False for name, module in self.model.named_modules(): if name == self.cav_layer: hook_fn = self.rr_clarc_hook() @@ -83,6 +84,34 @@ def apply_model_correction( handle = module.register_forward_hook(hook_fn) self.hooks.append(handle) _logger.debug(f"Added RR-CLARC hook to layer: {name}") + cav_layer_found = True + break + + if not cav_layer_found: + raise ValueError(f"CAV layer '{self.cav_layer}' not found in model") + + # Build a list of all module names in order + module_names_in_order = [name for name, _ in self.model.named_modules() if name] + + # Find the index of cav_layer + try: + cav_layer_idx = module_names_in_order.index(self.cav_layer) + except ValueError: + raise ValueError(f"CAV layer '{self.cav_layer}' not found in model modules") + + # Get the set of modules that should be frozen (all before cav_layer) + modules_to_freeze = set(module_names_in_order[:cav_layer_idx]) + + # Store original requires_grad state and freeze layers before cav_layer + self.original_requires_grad = {} + for name, param in self.model.named_parameters(): + self.original_requires_grad[name] = param.requires_grad + # Extract the module name from parameter name (e.g., "layer1.0.weight" -> "layer1.0") + param_module_name = name.rsplit(".", 1)[0] if "." in name else name + # Check if this parameter belongs to a module that should be frozen + if param_module_name in modules_to_freeze: + param.requires_grad = False + _logger.debug(f"Frozen parameter: {name}") # Override training_step in lightning model by modified_training_step clone_original_training_step = deepcopy(self.lightning_model.training_step) @@ -92,7 +121,9 @@ def apply_model_correction( def configure_optimizers(self): """ """ - optimizer = torch.optim.Adam(self.parameters(), lr=ft_lr) + # Only optimize parameters that require gradients (from cav_layer onwards) + params_to_optimize = [p for p in self.parameters() if p.requires_grad] + optimizer = torch.optim.Adam(params_to_optimize, lr=ft_lr) return optimizer self.lightning_model.configure_optimizers = types.MethodType( @@ -117,6 +148,11 @@ def configure_optimizers(self): # Go back to eval mode self.lightning_model.eval() + # Restore original requires_grad state + for name, param in self.model.named_parameters(): + if name in self.original_requires_grad: + param.requires_grad = self.original_requires_grad[name] + # Remove hooks self.remove_hooks() From 41aefe7c9680a26eb3d70b36afba5ac47487bd7a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Oct 2025 21:09:53 +0000 Subject: [PATCH 3/4] Improve layer freezing logic to handle nested modules correctly Co-authored-by: istepka <49250572+istepka@users.noreply.github.com> --- src/detoxai/methods/clarcs/rrclarc.py | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/detoxai/methods/clarcs/rrclarc.py b/src/detoxai/methods/clarcs/rrclarc.py index e36d738..c0dd08c 100644 --- a/src/detoxai/methods/clarcs/rrclarc.py +++ b/src/detoxai/methods/clarcs/rrclarc.py @@ -75,41 +75,45 @@ def apply_model_correction( assert len(cav_layers) == 1, "RR-CLARC only supports one CAV layer" self.cav_layer = cav_layers[0] - # Register rr_clarc_hook + # Register rr_clarc_hook and build list of module names in order cav_layer_found = False + module_names_in_order = [] for name, module in self.model.named_modules(): + if name: # Skip empty string (root module) + module_names_in_order.append(name) if name == self.cav_layer: hook_fn = self.rr_clarc_hook() - handle = module.register_forward_hook(hook_fn) self.hooks.append(handle) _logger.debug(f"Added RR-CLARC hook to layer: {name}") cav_layer_found = True - break if not cav_layer_found: raise ValueError(f"CAV layer '{self.cav_layer}' not found in model") - # Build a list of all module names in order - module_names_in_order = [name for name, _ in self.model.named_modules() if name] - - # Find the index of cav_layer - try: - cav_layer_idx = module_names_in_order.index(self.cav_layer) - except ValueError: - raise ValueError(f"CAV layer '{self.cav_layer}' not found in model modules") - - # Get the set of modules that should be frozen (all before cav_layer) + # Get the index of cav_layer and the set of modules to freeze (all before cav_layer) + cav_layer_idx = module_names_in_order.index(self.cav_layer) modules_to_freeze = set(module_names_in_order[:cav_layer_idx]) # Store original requires_grad state and freeze layers before cav_layer self.original_requires_grad = {} for name, param in self.model.named_parameters(): self.original_requires_grad[name] = param.requires_grad - # Extract the module name from parameter name (e.g., "layer1.0.weight" -> "layer1.0") - param_module_name = name.rsplit(".", 1)[0] if "." in name else name + # Check if this parameter belongs to a module that should be frozen - if param_module_name in modules_to_freeze: + # For parameter 'layer1.0.conv.weight', we check if 'layer1', 'layer1.0', + # or 'layer1.0.conv' is in modules_to_freeze + param_parts = name.rsplit(".", 1)[0] if "." in name else name + should_freeze = False + # Check all possible parent module paths + parts = param_parts.split(".") + for i in range(1, len(parts) + 1): + module_path = ".".join(parts[:i]) + if module_path in modules_to_freeze: + should_freeze = True + break + + if should_freeze: param.requires_grad = False _logger.debug(f"Frozen parameter: {name}") From b08f53acf6f63d99900b0bee6f64b441234bdc1f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Oct 2025 21:12:31 +0000 Subject: [PATCH 4/4] Address code review feedback: improve comments and edge case handling Co-authored-by: istepka <49250572+istepka@users.noreply.github.com> --- src/detoxai/methods/clarcs/rrclarc.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/detoxai/methods/clarcs/rrclarc.py b/src/detoxai/methods/clarcs/rrclarc.py index c0dd08c..0a92d22 100644 --- a/src/detoxai/methods/clarcs/rrclarc.py +++ b/src/detoxai/methods/clarcs/rrclarc.py @@ -87,6 +87,7 @@ def apply_model_correction( self.hooks.append(handle) _logger.debug(f"Added RR-CLARC hook to layer: {name}") cav_layer_found = True + # Continue to collect all module names for proper layer ordering if not cav_layer_found: raise ValueError(f"CAV layer '{self.cav_layer}' not found in model") @@ -103,15 +104,22 @@ def apply_model_correction( # Check if this parameter belongs to a module that should be frozen # For parameter 'layer1.0.conv.weight', we check if 'layer1', 'layer1.0', # or 'layer1.0.conv' is in modules_to_freeze - param_parts = name.rsplit(".", 1)[0] if "." in name else name + if "." not in name: + # Top-level parameters (unlikely in practice) - skip freezing + continue + + param_module = name.rsplit(".", 1)[ + 0 + ] # Extract module path from parameter name should_freeze = False - # Check all possible parent module paths - parts = param_parts.split(".") + + # Check all possible parent module paths (break early on first match) + parts = param_module.split(".") for i in range(1, len(parts) + 1): module_path = ".".join(parts[:i]) if module_path in modules_to_freeze: should_freeze = True - break + break # Found a match, no need to check further if should_freeze: param.requires_grad = False