Commit 7e0749d

Add tests
1 parent b30977e commit 7e0749d

2 files changed: 106 additions & 24 deletions

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 90 additions & 1 deletion
@@ -15,9 +15,11 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.float8.inference import Float8MMConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    Granularity,
     PerRow,
     PerTensor,
     quantize_,
@@ -82,7 +84,7 @@ def test_fp8_linear_variants(
         dtype: torch.dtype,
         mode: str,
         compile: bool,
-        granularity,
+        granularity: Granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
@@ -148,6 +150,61 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+    )
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_fp8_matmul(
+        self,
+        dtype: torch.dtype,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+        M, N, K = sizes
+        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+        weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")
+        mm_config = Float8MMConfig()
+        input_tensor_fp8 = Float8Tensor.from_hp(
+            input_tensor,
+            granularity=granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+        )
+        weight_tensor_fp8 = Float8Tensor.from_hp(
+            weight_tensor,
+            granularity=granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+        )
+        output_tensor = torch.matmul(input_tensor, weight_tensor)
+        output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8)
+        error = compute_error(output_tensor, output_tensor_fp8)
+        assert error > 20, (
+            f"Quantization error is too high got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
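Note: the `assert error > 20` threshold is an SQNR in decibels, as reported by torchao's `compute_error` helper used above. A minimal sketch of that metric, assuming the usual signal-to-quantization-noise definition (the helper name `sqnr_db` below is illustrative, not part of this commit):

    import torch

    def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
        # SQNR in dB: 20 * log10(||reference|| / ||reference - quantized||)
        signal = torch.linalg.norm(reference.float())
        noise = torch.linalg.norm(reference.float() - quantized.float())
        return 20 * torch.log10(signal / noise)

    # Roughly: 20 dB means the error norm is about 10x smaller than the signal
    # norm, which the test treats as acceptable accuracy for the fp8 matmul.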
@@ -653,6 +710,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
 
         self.assertEqual(sliced_dequantized, sliced_original)
 
+    def test_to_dtype_layout(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        y_fp8 = torch.ops.aten.to.dtype_layout(
+            x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
+        )
+        self.assertEqual(y_fp8.dtype, x_fp8.dtype)
+        self.assertEqual(y_fp8.layout, x_fp8.layout)
+        self.assertEqual(y_fp8.device, torch.device("cpu"))
+
+    def test_has_compatible_shallow_copy_type(self):
+        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+        x1_fp8 = Float8Tensor.from_hp(x1)
+        x2_fp8 = Float8Tensor.from_hp(x2)
+        x3_fp8 = Float8Tensor.from_hp(x3)
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
+        self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
+        # Wrong shape
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))
+
+    def test_transpose(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        x_fp8_t = x_fp8.t()
+        torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
+        torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
+        self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
+        self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 
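Note: `test_has_compatible_shallow_copy_type` exercises the stricter check added to the `torch._has_compatible_shallow_copy_type` override in the second file below, which `nn.Module._apply` (and therefore `.to()`, `.cuda()`, etc.) consults before reusing a parameter's tensor object. A minimal sketch of the contrast, using plain dense tensors (illustrative, not part of this commit):

    import torch

    a = torch.randn(128, 512, dtype=torch.bfloat16)
    b = torch.randn(128, 256, dtype=torch.bfloat16)
    # For ordinary dense tensors the check does not depend on shape...
    assert torch._has_compatible_shallow_copy_type(a, b)
    # ...whereas the Float8Tensor override below additionally requires matching
    # shapes, so _apply() never aliases quantized data of the wrong size.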

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 16 additions & 23 deletions
@@ -286,18 +286,19 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
 
-    # TODO: make this better
     # During the backward pass, we transpose the weight tensor,
     # so if the weight tensor was originally rowwise quantized,
     # now it becomes colwise. In this case, simply dequantize
     # the tensor and do a bf16 matmul
-    is_backward = (
-        weight_tensor.block_size[0] == weight_tensor.shape[0] and
-        weight_tensor.block_size[1] == 1
+    is_colwise = (
+        weight_tensor.block_size[0] == weight_tensor.shape[0]
+        and weight_tensor.block_size[1] == 1
     )
-    if is_backward:
+    if is_colwise:
         return torch.nn.functional.linear(
-            input_tensor, weight_tensor.dequantize(), bias,
+            input_tensor,
+            weight_tensor.dequantize(),
+            bias,
         )
 
     act_quant_kwargs = weight_tensor.act_quant_kwargs
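Note: the renamed `is_colwise` condition describes a rowwise-quantized weight after the backward-pass transpose. A worked example of the metadata involved, with illustrative shapes that are not taken from this commit:

    # Rowwise-quantized weight: one scale per output row.
    #   weight.shape      == (256, 128)
    #   weight.block_size == (1, 128)
    # After weight.t() in the backward pass:
    #   weight.shape      == (128, 256)
    #   weight.block_size == (128, 1)
    shape, block_size = (128, 256), (128, 1)
    is_colwise = block_size[0] == shape[0] and block_size[1] == 1
    assert is_colwise  # hence the dequantize-and-bf16-linear fallback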
@@ -598,22 +599,9 @@ def _(func, types, args, kwargs):
         assert original_shape[-1] == size[-1], (
             f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
         )
-        # TODO(andrew): This is technically not needed for unsloth fp8 RL
-        # but fixes a bug nonetheless, can do this separately
-        # Example input shapes:
-        # self.shape = [6, 363, 4096]
-        # self.scale.shape = [6, 363, 1]
-        # self.block_size = [1, 1, 4096]
-        # size = [-1, 4096]
-        #
-        # Example output shapes:
-        # self.shape = [2178, 4096]
-        # self.scale.shape = [2178, 1]
-        # self.block_size = [1, 4096]
-        new_dim0 = original_shape[0] * original_shape[1]
-        assert size[0] == new_dim0 or size[0] == -1
-        qdata = self.qdata.reshape(new_dim0, -1)
-        scale = self.scale.reshape(new_dim0, -1)
+        # TODO: this seems wrong, we should merge the first two dimensions instead
+        qdata = self.qdata.reshape(*size)
+        scale = self.scale.reshape(*size)
         block_size = self.block_size.copy()
         block_size = [block_size[0] * block_size[1], block_size[2]]
     elif len(original_shape) == 2 and len(size) == 3:
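Note: the TODO added here points at merging the leading dimensions rather than reshaping `scale` directly to `size`, which cannot work when the scale carries a trailing singleton dimension. A minimal sketch of the merge the comment suggests, reusing the example shapes from the removed comment (illustrative, not part of this commit):

    import torch

    # 3D rowwise-quantized activation being flattened to 2D:
    #   qdata: (6, 363, 4096), scale: (6, 363, 1), block_size: [1, 1, 4096]
    qdata = torch.zeros(6, 363, 4096)
    scale = torch.ones(6, 363, 1)
    new_dim0 = qdata.shape[0] * qdata.shape[1]   # 2178
    qdata_2d = qdata.reshape(new_dim0, -1)       # (2178, 4096)
    scale_2d = scale.reshape(new_dim0, -1)       # (2178, 1)
    block_size_2d = [1, 4096]                    # one scale per flattened row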
@@ -754,10 +742,15 @@ def _(func, types, args, kwargs):
 # This is called during _apply() to see if we can shallow
 # copy the content of one tensor into another. For now,
 # we only allow shallow copy if both tensors are `Float8Tensor`
+# and have the same shape.
 @implements_torch_function(torch._has_compatible_shallow_copy_type)
 def _(func, types, args, kwargs):
     assert len(args) == 2
-    return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
+    return (
+        isinstance(args[0], Float8Tensor)
+        and isinstance(args[1], Float8Tensor)
+        and args[0].shape == args[1].shape
+    )
 
 
 @implements(aten.t.default)
