
Commit 4a38e0d

fix bug where observers were called twice
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b98e61a commit 4a38e0d

File tree

2 files changed: +22 -8 lines changed


src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 3 additions & 6 deletions

@@ -4,9 +4,6 @@
 from compressed_tensors.modeling import (
     IMPL_ATTR,
     KV_CACHE_ATTR,
-    register_key_hook,
-    register_query_hook,
-    register_value_hook,
 )
 from compressed_tensors.quantization import (
     DynamicType,
@@ -309,10 +306,10 @@ def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
             )
         else:
             if hasattr(module, IMPL_ATTR):
-                hooks.add(register_query_hook(module, calibrate_query_hook))
+                hooks.add(self.register_hook(module, calibrate_query_hook, "query"))
             if hasattr(module, KV_CACHE_ATTR):
-                hooks.add(register_key_hook(module, calibrate_key_hook))
-                hooks.add(register_value_hook(module, calibrate_value_hook))
+                hooks.add(self.register_hook(module, calibrate_key_hook, "key"))
+                hooks.add(self.register_hook(module, calibrate_value_hook, "value"))
 
         # output activations
         if output:
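
This change routes attention-calibration hooks through HooksMixin.register_hook instead of calling the compressed_tensors helpers directly, so the resulting handles are tracked in self._hooks and the hooks pass through the same wrapped_hook logic as every other calibration hook. Below is a minimal, self-contained sketch (toy classes, not llm-compressor's actual implementation) of why funneling registration through one tracked entry point matters: every handle is recorded by the mixin, so hooks can be disabled or removed in one place rather than being registered ad hoc.

```python
import torch
from torch.utils.hooks import RemovableHandle


class ToyHooksMixin:
    """Toy stand-in for HooksMixin: every hook goes through register_hook."""

    def __init__(self):
        self._hooks = set()

    def register_hook(
        self, target: torch.nn.Module, hook, hook_type: str
    ) -> RemovableHandle:
        # Registration is funneled through one method, so the handle is
        # always tracked and can be torn down by remove_hooks later.
        handle = getattr(target, f"register_{hook_type}_hook")(hook)
        self._hooks.add(handle)
        return handle

    def remove_hooks(self):
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()


mixin = ToyHooksMixin()
layer = torch.nn.Linear(2, 2)
mixin.register_hook(layer, lambda m, i, o: print("observer ran"), "forward")

layer(torch.randn(1, 2))  # prints "observer ran" exactly once
mixin.remove_hooks()
layer(torch.randn(1, 2))  # hook has been removed, nothing prints
```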

src/llmcompressor/modifiers/utils/hooks.py

Lines changed: 19 additions & 2 deletions

@@ -1,8 +1,13 @@
 import contextlib
-from functools import wraps
+from functools import partial, wraps
 from typing import Any, Callable, ClassVar, Optional, Set, Union
 
 import torch
+from compressed_tensors.modeling import (
+    register_key_hook,
+    register_query_hook,
+    register_value_hook,
+)
 from loguru import logger
 from pydantic import BaseModel
 from torch.utils.hooks import RemovableHandle
@@ -92,7 +97,7 @@ def wrapped_hook(*args, **kwargs):
 
             return hook(*args, **kwargs)
 
-        register_function = getattr(target, f"register_{hook_type}_hook")
+        register_function = self._get_register_function(target, hook_type)
         handle = register_function(wrapped_hook, **kwargs)
         self._hooks.add(handle)
         logger.debug(f"{self} added {handle}")
@@ -113,3 +118,15 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
             hook.remove()
 
         self._hooks -= handles
+
+    def _get_register_function(
+        self, target: torch.nn.Module, hook_type: str
+    ) -> Callable:
+        if hook_type == "query":
+            return partial(register_query_hook, target)
+        elif hook_type == "key":
+            return partial(register_key_hook, target)
+        elif hook_type == "value":
+            return partial(register_value_hook, target)
+        else:
+            return getattr(target, f"register_{hook_type}_hook")
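
The new _get_register_function keeps register_hook's calling convention uniform: the attention helpers are module-level functions that take the target as their first argument, so they are bound with functools.partial, while every other hook type still falls back to the module's own register_<hook_type>_hook method. Here is a minimal runnable sketch of that dispatch pattern; register_query_hook below is a stand-in stub (the real one is imported from compressed_tensors.modeling), and only torch's built-in hooks are used for the fallback path.

```python
from functools import partial
from typing import Callable

import torch


def register_query_hook(module: torch.nn.Module, hook: Callable):
    # Stand-in stub for the compressed_tensors.modeling helper; here it
    # simply attaches the hook as a forward pre-hook on the module.
    return module.register_forward_pre_hook(hook)


def get_register_function(target: torch.nn.Module, hook_type: str) -> Callable:
    # Attention hooks are free functions, not methods on the target, so the
    # target is bound with partial; anything else falls back to the module's
    # own register_<hook_type>_hook method, exactly as before the change.
    if hook_type == "query":
        return partial(register_query_hook, target)
    return getattr(target, f"register_{hook_type}_hook")


layer = torch.nn.Linear(4, 4)

# Both branches return a callable with the same shape, fn(hook, **kwargs),
# which is what register_hook expects when it calls register_function.
forward_handle = get_register_function(layer, "forward")(lambda m, i, o: None)
query_handle = get_register_function(layer, "query")(lambda m, i: None)

forward_handle.remove()
query_handle.remove()
```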
