|
15 | 15 | from torch.testing._internal import common_utils |
16 | 16 | from torch.testing._internal.common_utils import run_tests |
17 | 17 |
|
| 18 | +from torchao.float8.inference import Float8MMConfig |
18 | 19 | from torchao.quantization import ( |
19 | 20 | Float8DynamicActivationFloat8WeightConfig, |
20 | 21 | Float8WeightOnlyConfig, |
| 22 | + Granularity, |
21 | 23 | PerRow, |
22 | 24 | PerTensor, |
23 | 25 | quantize_, |
@@ -82,7 +84,7 @@ def test_fp8_linear_variants( |
82 | 84 | dtype: torch.dtype, |
83 | 85 | mode: str, |
84 | 86 | compile: bool, |
85 | | - granularity, |
| 87 | + granularity: Granularity, |
86 | 88 | kernel_preference: KernelPreference, |
87 | 89 | sizes: Tuple, |
88 | 90 | ): |
@@ -148,6 +150,61 @@ def test_fp8_linear_variants( |
148 | 150 | f"Quantization error is too high got a SQNR of {error}" |
149 | 151 | ) |
150 | 152 |
|
| 153 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 154 | + @unittest.skipIf( |
| 155 | + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" |
| 156 | + ) |
| 157 | + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) |
| 158 | + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) |
| 159 | + @common_utils.parametrize( |
| 160 | + "kernel_preference", |
| 161 | + [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], |
| 162 | + ) |
| 163 | + # Inputs are (M,..), N, K
| 164 | + @common_utils.parametrize( |
| 165 | + "sizes", |
| 166 | + [ |
| 167 | + ((128,), 256, 128), |
| 168 | + ((32, 128), 64, 256), |
| 169 | + ], |
| 170 | + ) |
| 171 | + def test_fp8_matmul( |
| 172 | + self, |
| 173 | + dtype: torch.dtype, |
| 174 | + granularity: Granularity, |
| 175 | + kernel_preference: KernelPreference, |
| 176 | + sizes: Tuple, |
| 177 | + ): |
| 178 | + if ( |
| 179 | + isinstance(granularity, PerTensor) |
| 180 | + and kernel_preference == KernelPreference.FBGEMM |
| 181 | + ): |
| 182 | + self.skipTest(
| 183 | + "per tensor with fbgemm kernel preference does not work yet"
| 184 | + )
| 185 | + M, N, K = sizes |
| 186 | + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") |
| 187 | + weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda") |
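| | + # quantize both operands to fp8 with the same granularity, mm config,
| | + # and kernel preference so the fp8 matmul path is exercised end to end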
| 188 | + mm_config = Float8MMConfig() |
| 189 | + input_tensor_fp8 = Float8Tensor.from_hp( |
| 190 | + input_tensor, |
| 191 | + granularity=granularity, |
| 192 | + mm_config=mm_config, |
| 193 | + kernel_preference=kernel_preference, |
| 194 | + ) |
| 195 | + weight_tensor_fp8 = Float8Tensor.from_hp( |
| 196 | + weight_tensor, |
| 197 | + granularity=granularity, |
| 198 | + mm_config=mm_config, |
| 199 | + kernel_preference=kernel_preference, |
| 200 | + ) |
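| | + # run the matmul in high precision and through the fp8 subclass;
| | + # an SQNR above 20 dB means the fp8 result closely tracks the reference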
| 201 | + output_tensor = torch.matmul(input_tensor, weight_tensor) |
| 202 | + output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8) |
| 203 | + error = compute_error(output_tensor, output_tensor_fp8) |
| 204 | + assert error > 20, (
| 205 | + f"Quantization error is too high, got a SQNR of {error}"
| 206 | + ) |
| 207 | + |
151 | 208 | @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) |
152 | 209 | @unittest.skipIf( |
153 | 210 | not is_sm_at_least_90(), |
@@ -653,6 +710,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape): |
653 | 710 |
|
654 | 711 | self.assertEqual(sliced_dequantized, sliced_original) |
655 | 712 |
|
| | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
| 713 | + def test_to_dtype_layout(self):
| 714 | + x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) |
| 715 | + x_fp8 = Float8Tensor.from_hp(x) |
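| | + # aten.to.dtype_layout should move the subclass (qdata and scale) to the
| | + # target device while preserving its dtype and layout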
| 716 | + y_fp8 = torch.ops.aten.to.dtype_layout( |
| 717 | + x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu" |
| 718 | + ) |
| 719 | + self.assertEqual(y_fp8.dtype, x_fp8.dtype) |
| 720 | + self.assertEqual(y_fp8.layout, x_fp8.layout) |
| 721 | + self.assertEqual(y_fp8.device, torch.device("cpu")) |
| 722 | + |
| | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
| 723 | + def test_has_compatible_shallow_copy_type(self):
| 724 | + x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) |
| 725 | + x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) |
| 726 | + x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) |
| 727 | + x1_fp8 = Float8Tensor.from_hp(x1) |
| 728 | + x2_fp8 = Float8Tensor.from_hp(x2) |
| 729 | + x3_fp8 = Float8Tensor.from_hp(x3) |
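| | + # shallow copies are only compatible between tensors of the same
| | + # subclass type and shape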
| 730 | + self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8)) |
| 731 | + self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2)) |
| 732 | + self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8)) |
| 733 | + # Wrong shape |
| 734 | + self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8)) |
| 735 | + |
| | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
| 736 | + def test_transpose(self):
| 737 | + x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16) |
| 738 | + x_fp8 = Float8Tensor.from_hp(x) |
| 739 | + x_fp8_t = x_fp8.t() |
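| | + # transposing the subclass should transpose qdata and scale and swap the
| | + # per-row block_size from (1, 512) to (512, 1)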
| 740 | + torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0) |
| 741 | + torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0) |
| 742 | + self.assertEqual(x_fp8.block_size, (1, 512))
| 743 | + self.assertEqual(x_fp8_t.block_size, (512, 1))
| 744 | + |
656 | 745 |
|
657 | 746 | common_utils.instantiate_parametrized_tests(TestFloat8Tensor) |
658 | 747 |
|