Skip to content
11 changes: 9 additions & 2 deletions tests/quantization/bnb/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,14 @@ def setUp(self) -> None:
transformer=transformer_8bit,
torch_dtype=torch.float16,
)
self.pipeline_8bit.enable_model_cpu_offload()
# On devices with roughly 24 GB of VRAM or less (checked below as total memory
# <= 25 GiB), enable_model_cpu_offload can OOM because it moves an entire
# sub-model to the accelerator at once. Fall back to sequential (per-layer)
# CPU offload in that case.
_, total_mem = torch.accelerator.get_memory_info(0)
if total_mem <= 25 * (1024**3):
self.pipeline_8bit.enable_sequential_cpu_offload()
else:
self.pipeline_8bit.enable_model_cpu_offload()

def tearDown(self):
del self.pipeline_8bit
Expand Down Expand Up @@ -709,7 +716,7 @@ def test_lora_loading(self):
expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
self.assertTrue(max_diff < 2e-3)


@require_transformers_version_greater("4.44.0")
Expand Down
Loading