diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 786e0cf59f..c9f17d0190 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -15,9 +15,11 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.float8.inference import Float8MMConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    Granularity,
     PerRow,
     PerTensor,
     quantize_,
@@ -82,7 +84,7 @@ def test_fp8_linear_variants(
         dtype: torch.dtype,
         mode: str,
         compile: bool,
-        granularity,
+        granularity: Granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
@@ -148,6 +150,61 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+    )
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_fp8_matmul_variants(
+        self,
+        dtype: torch.dtype,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+        M, N, K = sizes
+        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+        weight_tensor = torch.randn(N, K, dtype=dtype, device="cuda")
+        mm_config = Float8MMConfig()
+        input_tensor_fp8 = Float8Tensor.from_hp(
+            input_tensor,
+            granularity=granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+        )
+        weight_tensor_fp8 = Float8Tensor.from_hp(
+            weight_tensor,
+            granularity=granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+        )
+        output_tensor = torch.matmul(input_tensor, weight_tensor.t())
+        output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8.t())
+        error = compute_error(output_tensor, output_tensor_fp8)
+        assert error > 20, (
+            f"Quantization error is too high, got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
@@ -653,6 +710,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
 
         self.assertEqual(sliced_dequantized, sliced_original)
 
+    def test_to_dtype_layout(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        y_fp8 = torch.ops.aten.to.dtype_layout(
+            x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
+        )
+        self.assertEqual(y_fp8.dtype, x_fp8.dtype)
+        self.assertEqual(y_fp8.layout, x_fp8.layout)
+        self.assertEqual(y_fp8.device, torch.device("cpu"))
+
+    def test_has_compatible_shallow_copy_type(self):
+        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+        x1_fp8 = Float8Tensor.from_hp(x1)
+        x2_fp8 = Float8Tensor.from_hp(x2)
+        x3_fp8 = Float8Tensor.from_hp(x3)
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
+        self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
+        # Wrong shape
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))
+
+    def test_transpose(self):
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        x_fp8_t = x_fp8.t()
+        torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
+        torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
+        self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
+        self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 47395a15af..348f1c65a8 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -248,18 +248,59 @@ def from_hp(
 implements_torch_function = Float8Tensor.implements_torch_function
 
 
-@implements([aten.linear.default])
-@implements_torch_function([torch.nn.functional.linear])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
         args[1],
         args[2] if len(args) > 2 else None,
     )
+    return _float8_linear_impl(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.mm.default)
+@implements_torch_function(torch.matmul)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = args[0], args[1]
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
+@implements(aten.addmm_.default)
+def _(func, types, args, kwargs):
+    output_tensor, input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    return output_tensor.copy_(out)
+
+
+def _float8_linear_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     assert isinstance(weight_tensor, Float8Tensor), (
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
 
+    # During the backward pass, we transpose the weight tensor,
+    # so if the weight tensor was originally rowwise quantized,
+    # now it becomes colwise. In this case, simply dequantize
+    # the tensor and do a bf16 matmul
+    is_colwise = (
+        weight_tensor.block_size[0] == weight_tensor.shape[0]
+        and weight_tensor.block_size[1] == 1
+    )
+    if is_colwise:
+        return torch.nn.functional.linear(
+            input_tensor,
+            weight_tensor.dequantize(),
+            bias,
+        )
+
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
@@ -299,6 +340,7 @@ def _(func, types, args, kwargs):
             assert _is_rowwise_scaled(input_tensor), (
                 "Input tensor must be rowwise block size"
             )
+            wq = wq.contiguous()
             res = torch.ops.fbgemm.f8f8bf16_rowwise(
                 xq,
                 wq,
@@ -557,6 +599,7 @@ def _(func, types, args, kwargs):
         assert original_shape[-1] == size[-1], (
             f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
         )
+        # TODO: this seems wrong, we should merge the first two dimensions instead
        qdata = self.qdata.reshape(*size)
         scale = self.scale.reshape(*size)
         block_size = self.block_size.copy()
@@ -665,6 +708,68 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
+@implements(torch.ops.aten.to.dtype_layout)
+def _(func, types, args, kwargs):
+    # only support kwargs for now
+    assert len(args) == 1
+    self = args[0]
+    # only support dtype, layout, and device for now
+    for k in kwargs.keys():
+        assert k in ["dtype", "layout", "device"]
+    # only support same dtype and layout,
+    # different dtype and layout has undefined behavior
+    if "dtype" in kwargs:
+        assert kwargs["dtype"] == self.dtype
+    if "layout" in kwargs:
+        assert kwargs["layout"] == self.layout
+    # if device is the same, treat this like a no-op
+    device = kwargs.get("device")
+    if device == self.device:
+        return self
+    # otherwise, move all inner tensors to the new device
+    new_tensor = self.__class__(
+        func(self.qdata, device=device),
+        func(self.scale, device=device),
+        self.block_size,
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
+
+
+# This is called during _apply() to see if we can shallow
+# copy the content of one tensor into another. For now,
+# we only allow shallow copy if both tensors are `Float8Tensor`
+# and have the same shape.
+@implements_torch_function(torch._has_compatible_shallow_copy_type)
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    return (
+        isinstance(args[0], Float8Tensor)
+        and isinstance(args[1], Float8Tensor)
+        and args[0].shape == args[1].shape
+    )
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    assert len(args) == 1
+    self = args[0]
+    assert len(self.block_size) == 2
+    new_tensor = self.__class__(
+        self.qdata.t(),
+        self.scale.t(),
+        (self.block_size[1], self.block_size[0]),
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
+
+
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
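A minimal usage sketch of the new matmul path, mirroring test_fp8_matmul_variants above; it is not part of the patch itself. It assumes the Float8Tensor.from_hp defaults (rowwise scaling, as checked in test_transpose), and the compute_error import path is an assumption; the SQNR > 20 threshold is taken from the test.

import torch

from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
from torchao.quantization.utils import compute_error  # assumed import path for the SQNR helper

# Quantize both operands to fp8 with default (rowwise) settings, then multiply.
# torch.matmul on Float8Tensor inputs now dispatches through _float8_linear_impl.
a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
a_fp8 = Float8Tensor.from_hp(a)
b_fp8 = Float8Tensor.from_hp(b)

out_ref = torch.matmul(a, b.t())          # bf16 reference
out_fp8 = torch.matmul(a_fp8, b_fp8.t())  # fp8 path added by this patch
print(compute_error(out_ref, out_fp8))    # the test expects an SQNR above 20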