
Commit 532d1f0

Authored by kmaherx, schraderSimon, and pre-commit-ci[bot]
Gemma default device fix (#161)
* Add max_memory parameter to run config
* Use configurable max_memory for offline explainer
* Fix breaking change in prompt input formatting from vLLM
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix bug in gemmascope device type check
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: Simon Schrader <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 0837a97

1 file changed (+3, −1 lines)


delphi/sparse_coders/custom/gemmascope.py

Lines changed: 3 additions & 1 deletion
@@ -104,6 +104,8 @@ def from_pretrained(cls, model_name_or_path, position, device):
         pt_params = {k: torch.from_numpy(v) for k, v in params.items()}
         model = cls(params["W_enc"].shape[0], params["W_enc"].shape[1])
         model.load_state_dict(pt_params)
-        if device == "cuda":
+        if device == "cuda" or (
+            isinstance(device, torch.device) and device.type == "cuda"
+        ):
             model.cuda()
         return model
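
Before this fix, passing a torch.device object rather than the literal string "cuda" skipped the .cuda() move, because the equality check only matched the string. A minimal sketch of the corrected predicate, pulled out as a standalone helper for illustration (the name is_cuda_device is not from the repo):

    import torch

    def is_cuda_device(device) -> bool:
        # Mirrors the patched check in gemmascope.py: a plain string
        # comparison misses torch.device("cuda") and torch.device("cuda:0"),
        # whose .type attribute is "cuda" regardless of device index.
        return device == "cuda" or (
            isinstance(device, torch.device) and device.type == "cuda"
        )

    # Both forms now trigger model.cuda():
    assert is_cuda_device("cuda")
    assert is_cuda_device(torch.device("cuda:0"))
    assert not is_cuda_device("cpu")

Note that a bare string like "cuda:0" still falls through both branches; the call sites evidently pass either the plain string "cuda" or a torch.device object.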
