Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions src/detoxai/methods/clarcs/rrclarc.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,55 @@ def apply_model_correction(
assert len(cav_layers) == 1, "RR-CLARC only supports one CAV layer"
self.cav_layer = cav_layers[0]

# Register rr_clarc_hook
# Register rr_clarc_hook and build list of module names in order
cav_layer_found = False
module_names_in_order = []
for name, module in self.model.named_modules():
if name: # Skip empty string (root module)
module_names_in_order.append(name)
if name == self.cav_layer:
hook_fn = self.rr_clarc_hook()

handle = module.register_forward_hook(hook_fn)
self.hooks.append(handle)
_logger.debug(f"Added RR-CLARC hook to layer: {name}")
cav_layer_found = True
# Continue to collect all module names for proper layer ordering

if not cav_layer_found:
raise ValueError(f"CAV layer '{self.cav_layer}' not found in model")

# Get the index of cav_layer and the set of modules to freeze (all before cav_layer)
cav_layer_idx = module_names_in_order.index(self.cav_layer)
modules_to_freeze = set(module_names_in_order[:cav_layer_idx])

# Store original requires_grad state and freeze layers before cav_layer
self.original_requires_grad = {}
for name, param in self.model.named_parameters():
self.original_requires_grad[name] = param.requires_grad

# Check if this parameter belongs to a module that should be frozen
# For parameter 'layer1.0.conv.weight', we check if 'layer1', 'layer1.0',
# or 'layer1.0.conv' is in modules_to_freeze
if "." not in name:
# Top-level parameters (unlikely in practice) - skip freezing
continue

param_module = name.rsplit(".", 1)[
0
] # Extract module path from parameter name
should_freeze = False

# Check all possible parent module paths (break early on first match)
parts = param_module.split(".")
for i in range(1, len(parts) + 1):
module_path = ".".join(parts[:i])
if module_path in modules_to_freeze:
should_freeze = True
break # Found a match, no need to check further

if should_freeze:
param.requires_grad = False
_logger.debug(f"Frozen parameter: {name}")

# Override training_step in lightning model by modified_training_step
clone_original_training_step = deepcopy(self.lightning_model.training_step)
Expand All @@ -92,7 +133,9 @@ def apply_model_correction(

def configure_optimizers(self):
""" """
optimizer = torch.optim.Adam(self.parameters(), lr=ft_lr)
# Only optimize parameters that require gradients (from cav_layer onwards)
params_to_optimize = [p for p in self.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params_to_optimize, lr=ft_lr)
return optimizer

self.lightning_model.configure_optimizers = types.MethodType(
Expand All @@ -117,6 +160,11 @@ def configure_optimizers(self):
# Go back to eval mode
self.lightning_model.eval()

# Restore original requires_grad state
for name, param in self.model.named_parameters():
if name in self.original_requires_grad:
param.requires_grad = self.original_requires_grad[name]

# Remove hooks
self.remove_hooks()

Expand Down