diff --git a/lib/activations/activations/activations_computation.py b/lib/activations/activations/activations_computation.py
index 093633c..a431aee 100644
--- a/lib/activations/activations/activations_computation.py
+++ b/lib/activations/activations/activations_computation.py
@@ -1,27 +1,70 @@
-from typing import Literal
+from enum import Enum
+from typing import Callable
 
 import torch
+from nnsight.envoy import Envoy  # type: ignore
+from nnsight.intervention import InterventionProxy  # type: ignore
 from util.subject import Subject
 
 
-def get_activations_computing_func(subject: Subject, activation_type: Literal["MLP"], layer: int):
+class ActivationType(str, Enum):
+    RESID = "resid"
+    MLP_IN = "mlp_in"
+    MLP_OUT = "mlp_out"
+    ATTN_OUT = "attn_out"
+    NEURONS = "neurons"
+
+
+def _get_activations_funcs(
+    subject: Subject, activation_type: ActivationType, layer: int
+) -> tuple[Callable[[], Envoy], Callable[[Envoy], InterventionProxy]]:
+    if activation_type == ActivationType.RESID:
+        return (
+            lambda: subject.layers[layer],
+            lambda component: component.output[0],
+        )
+    if activation_type == ActivationType.MLP_IN:
+        return (
+            lambda: subject.mlps[layer],
+            lambda component: component.input,
+        )
+    if activation_type == ActivationType.MLP_OUT:
+        return (
+            lambda: subject.mlps[layer],
+            lambda component: component.output,
+        )
+    if activation_type == ActivationType.ATTN_OUT:
+        return (
+            lambda: subject.attns[layer],
+            lambda component: component.output[0],
+        )
+    if activation_type == ActivationType.NEURONS:
+        return (
+            lambda: subject.w_outs[layer],
+            lambda component: component.input,
+        )
+    raise ValueError(f"Unknown activation type: {activation_type}")
+
+
+def get_activations_computing_func(
+    subject: Subject, activation_type: ActivationType, layer: int
+) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
     """
     Returns a function that computes activations for a given input:
         input_ids: torch.Tensor
         attn_mask: torch.Tensor
     """
 
-    if activation_type == "MLP":
-        mlp_acts_for_layer = subject.w_outs[layer]
-
-        def get_mlp_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
-            with torch.no_grad():
-                with subject.model.trace(
-                    {"input_ids": input_ids, "attention_mask": attn_mask}  # type: ignore
-                ):
-                    acts = mlp_acts_for_layer.input.save()
-            return acts
-
-        return get_mlp_activations
-    else:
-        raise ValueError(f"Unknown activation type: {activation_type}")
+    get_component, get_activations = _get_activations_funcs(subject, activation_type, layer)
+
+    def activations_computing_func(
+        input_ids: torch.Tensor, attn_mask: torch.Tensor
+    ) -> torch.Tensor:
+        with torch.no_grad():
+            with subject.model.trace(
+                {"input_ids": input_ids, "attention_mask": attn_mask}  # type: ignore
+            ):
+                acts: torch.Tensor = get_activations(get_component()).save()  # type: ignore
+        return acts
+
+    return activations_computing_func
diff --git a/lib/activations/activations/exemplars_wrapper.py b/lib/activations/activations/exemplars_wrapper.py
index 9aa323b..294f3d8 100644
--- a/lib/activations/activations/exemplars_wrapper.py
+++ b/lib/activations/activations/exemplars_wrapper.py
@@ -6,6 +6,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from activations.activations import ActivationRecord
+from activations.activations_computation import ActivationType
 from activations.dataset import (
     ChatDataset,
     HFDatasetWrapper,
@@ -321,7 +322,7 @@ class ExemplarConfig(BaseModel):
     batch_size: int = 512
     rand_seqs: int = 10
     seed: int = 64
-    activation_type: Literal["MLP"] = "MLP"
+    activation_type: ActivationType = ActivationType.NEURONS
 
 
 class ExemplarsWrapper:
@@ -347,6 +348,8 @@ def __init__(
         if subject.is_chat_model:
             folder_name_components.append("chat")
         folder_name_components.append(f"{config.seq_len}seqlen")
+        if config.activation_type != "neurons":
+            folder_name_components.append(config.activation_type)
         assert subject.tokenizer.padding_side == "left"
 
         folder_name = "_".join(folder_name_components)
@@ -430,10 +433,7 @@ def load_layer_checkpoint(self, layer: int, split: ExemplarSplit) -> (
             ExemplarSplit.RANDOM_TEST,
         )
 
-        if self.config.activation_type == "MLP":
-            num_features = self.subject.I
-        else:
-            raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+        num_features = self.num_features
 
         num_top_feats_to_save = self.config.num_top_acts_to_save
         k, seq_len = self.config.k, self.config.seq_len
@@ -496,10 +496,7 @@ def save_layer_checkpoint(
         layer_dir = self.get_layer_dir(layer, split)
         os.makedirs(layer_dir, exist_ok=True)
 
-        if self.config.activation_type == "MLP":
-            num_features = self.subject.I
-        else:
-            raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+        num_features = self.num_features
 
         num_top_feats_to_save = self.config.num_top_acts_to_save
         k, seq_len = self.config.k, self.config.seq_len
@@ -883,6 +880,19 @@ def visualize_neuron_exemplars(
         )
         display(HTML(html_content))  # type: ignore
 
+    @property
+    def num_features(self) -> int:
+        if self.config.activation_type == ActivationType.NEURONS:
+            return self.subject.I
+        if self.config.activation_type in (
+            ActivationType.RESID,
+            ActivationType.MLP_IN,
+            ActivationType.MLP_OUT,
+            ActivationType.ATTN_OUT,
+        ):
+            return self.subject.D
+        raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+
 
 ###################
 # Example Configs #
diff --git a/project/expgen/.gitignore b/project/expgen/.gitignore
new file mode 100644
index 0000000..8fce603
--- /dev/null
+++ b/project/expgen/.gitignore
@@ -0,0 +1 @@
+data/
diff --git a/project/expgen/scripts/compute_exemplars.py b/project/expgen/scripts/compute_exemplars.py
index 7f24b64..ae017f3 100644
--- a/project/expgen/scripts/compute_exemplars.py
+++ b/project/expgen/scripts/compute_exemplars.py
@@ -4,8 +4,10 @@
 """
 
 import argparse
+from typing import Any
 
-from activations.dataset import fineweb_dset_config, lmsys_dset_config
+from activations.activations_computation import ActivationType
+from activations.dataset import HFDatasetWrapperConfig, fineweb_dset_config, lmsys_dset_config
 from activations.exemplars import ExemplarSplit
 from activations.exemplars_computation import (
     compute_exemplars_for_layer,
@@ -15,12 +17,25 @@
 from util.subject import Subject, get_subject_config
 
 parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--activation_type",
+    type=str,
+    choices=[
+        ActivationType.RESID,
+        ActivationType.MLP_IN,
+        ActivationType.MLP_OUT,
+        ActivationType.ATTN_OUT,
+        ActivationType.NEURONS,
+    ],
+    default="neurons",
+    help="Type of activations from which we pick indices to compute exemplars for.",
+)
 parser.add_argument(
     "--layer_indices",
     type=int,
     nargs="+",
     default=None,
-    help="Layers from which we pick neurons to compute exemplars for.",
+    help="Layers from which we pick indices to compute exemplars for.",
 )
 parser.add_argument(
     "--subject_hf_model_id",
@@ -87,7 +102,7 @@
 subject_config = get_subject_config(args.subject_hf_model_id)
 subject = Subject(subject_config, nnsight_lm_kwargs={"dispatch": True})
 
-hf_dataset_configs = []
+hf_dataset_configs: list[HFDatasetWrapperConfig] = []
 for hf_dataset in args.hf_datasets:
     if hf_dataset == "fineweb":
         hf_dataset_configs.append(fineweb_dset_config)
@@ -106,13 +121,14 @@
     num_top_acts_to_save=args.num_top_acts_to_save,
     batch_size=args.batch_size,
     seed=args.seed,
+    activation_type=args.activation_type,
 )
 exemplars_wrapper = ExemplarsWrapper(args.data_dir, exemplar_config, subject)
 
 layer_indices = args.layer_indices if args.layer_indices else range(subject.L)
 for layer in layer_indices:
     print(f"============ Layer {layer} ============")
-    kwargs = {
+    kwargs: dict[str, Any] = {
         "exemplars_wrapper": exemplars_wrapper,
         "layer": layer,
         "split": ExemplarSplit(args.split),
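
Usage sketch (reviewer note, not part of the patch): how the refactored
get_activations_computing_func is expected to compose after this change. The
model id "gpt2" and layer index 6 are illustrative placeholders; the imports
and signatures follow the patch above.

    from activations.activations_computation import (
        ActivationType,
        get_activations_computing_func,
    )
    from util.subject import Subject, get_subject_config

    # Build the subject the same way compute_exemplars.py does.
    subject = Subject(get_subject_config("gpt2"), nnsight_lm_kwargs={"dispatch": True})

    # ActivationType.NEURONS reproduces the old Literal["MLP"] path
    # (w_out input, subject.I features per layer); the other four variants
    # yield subject.D-dimensional activations.
    compute_acts = get_activations_computing_func(subject, ActivationType.RESID, layer=6)

    batch = subject.tokenizer(["Hello world"], return_tensors="pt", padding=True)
    acts = compute_acts(batch["input_ids"], batch["attention_mask"])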