diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 55f759e8a..fc47dc1e5 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -9,7 +9,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import TorchDtype
+from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel
 
@@ -126,16 +126,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
 
+        head_dim = get_head_dim(state.model.config)
         config_groups = {}
 
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -217,24 +218,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
-        if self.transform_block_size:
-            if head_dim % self.transform_block_size != 0:
-                raise ValueError(
-                    f"transform_block_size {self.transform_block_size} must be set "
-                    f"such that model's head_dim {head_dim} is evenly divisible by it"
-                )
-            head_dim = self.transform_block_size
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -251,9 +235,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
index 514d1f109..85c2f0c0f 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -14,6 +14,7 @@ class SpinQuantMapping(BaseModel):
     layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
 
     :param embedding: name or regex of embedding layer
+    :param attn: name or regex of attention block in decoder layer
    :param attn_q: name or regex of q_proj layer in attention block
     :param attn_k: name or regex of k_proj layer in attention block
     :param attn_v: name or regex of v_proj layer in attention block
@@ -29,6 +30,7 @@ class SpinQuantMapping(BaseModel):
 
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -50,6 +52,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
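
Reviewer note, not part of the patch: `_create_r2_scheme` no longer infers `head_dim` itself; `on_initialize` now calls `get_head_dim(state.model.config)` from `compressed_tensors.utils`. Below is a minimal sketch of the config-based fallback the deleted inline code implemented, assuming the helper follows the same logic (the function name `infer_head_dim` is illustrative, and the removed `transform_block_size` override is intentionally not reproduced here):

```python
# Sketch only: mirrors the inline head_dim inference removed from
# _create_r2_scheme; the real get_head_dim in compressed_tensors may
# differ in naming and error handling.
from transformers import PretrainedConfig


def infer_head_dim(config: PretrainedConfig) -> int:
    # Prefer an explicit head_dim attribute when the config provides one.
    if getattr(config, "head_dim", None) is not None:
        return config.head_dim
    # Otherwise derive it from hidden_size and num_attention_heads,
    # as the deleted code did.
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    raise NotImplementedError("cannot infer head_dim from this model config")
```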
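
A usage sketch that exercises the newly enabled R3 path (together with R1/R2/R4) through the public `SpinQuantModifier` and `oneshot` entry points; the model id and `W4A16` quantization scheme below are placeholders, not part of this change:

```python
# Sketch only: end-to-end smoke test of R1-R4 rotations followed by
# weight-only quantization. Model id and scheme are illustrative.
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier

recipe = [
    # R3 now builds a real TransformScheme targeting the decoder
    # self_attn block (q_attn / k_cache locations) instead of raising
    # NotImplementedError.
    SpinQuantModifier(rotations=["R1", "R2", "R3", "R4"], transform_type="hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

oneshot(model="meta-llama/Llama-3.2-1B-Instruct", recipe=recipe)
```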