Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions src/llmcompressor/modifiers/transform/spinquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
TransformScheme,
apply_transform_config,
)
from compressed_tensors.utils import TorchDtype
from compressed_tensors.utils import TorchDtype, get_head_dim
from pydantic import Field, ValidationInfo, field_validator
from transformers import PreTrainedModel

Expand Down Expand Up @@ -126,16 +126,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:

self.mappings = infer_mapping_from_model(state.model)
self.norm_mappings = infer_norm_mapping_from_model(state.model)
head_dim = get_head_dim(state.model.config)

config_groups = {}
if SpinquantRotation.R1 in self.rotations:
config_groups["R1"] = self._create_r1_scheme()

if SpinquantRotation.R2 in self.rotations:
config_groups["R2"] = self._create_r2_scheme(state.model)
config_groups["R2"] = self._create_r2_scheme(head_dim)

if SpinquantRotation.R3 in self.rotations:
config_groups["R3"] = self._create_r3_scheme()
config_groups["R3"] = self._create_r3_scheme(head_dim)

if SpinquantRotation.R4 in self.rotations:
config_groups["R4"] = self._create_r4_scheme()
Expand Down Expand Up @@ -217,24 +218,7 @@ def _create_r1_scheme(self) -> TransformScheme:
],
)

def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
config = model.config

if hasattr(config, "head_dim"):
head_dim = config.head_dim
elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
head_dim = config.hidden_size // config.num_attention_heads
else:
raise NotImplementedError()

if self.transform_block_size:
if head_dim % self.transform_block_size != 0:
raise ValueError(
f"transform_block_size {self.transform_block_size} must be set "
f"such that model's head_dim {head_dim} is evenly divisible by it"
)
head_dim = self.transform_block_size

def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
return TransformScheme(
type=self.transform_type,
randomize=self.randomize,
Expand All @@ -251,9 +235,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
],
)

def _create_r3_scheme(self) -> TransformScheme:
raise NotImplementedError(
"SpinQuant R3 rotations will be added in a future release"
def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
return TransformScheme(
type=self.transform_type,
randomize=self.randomize,
requires_grad=self.learnable,
precision=self.precision,
head_dim=head_dim,
apply=[
TransformArgs(
targets=[self.mappings.attn],
location="q_attn",
),
TransformArgs(
targets=[self.mappings.attn],
location="k_cache",
),
],
)

def _create_r4_scheme(self) -> TransformScheme:
Expand Down
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SpinQuantMapping(BaseModel):
layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

:param embedding: name or regex of embedding layer
:param attn: name or regex of attention block in decoder layer
:param attn_q: name or regex of q_proj layer in attention block
:param attn_k: name or regex of k_proj layer in attention block
:param attn_v: name or regex of v_proj layer in attention block
Expand All @@ -29,6 +30,7 @@ class SpinQuantMapping(BaseModel):

embedding: str

attn: str
attn_q: str
attn_k: str
attn_v: str
Expand All @@ -50,6 +52,7 @@ def cast_to_list(cls, value):

_default_mappings = SpinQuantMapping(
embedding="re:.*embed_tokens$",
attn="re:.*self_attn$",
attn_q="re:.*q_proj$",
attn_k="re:.*k_proj$",
attn_v="re:.*v_proj$",
Expand Down