
Commit 2ccc619

Add per-tensor fp8 quantization support for conv3d
Summary: As titled, we added support for quantizing conv3d weights through the Float8DynamicActivationFloat8WeightConfig API:

```
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(),
)
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(quantized_model, config, filter_fn=_is_conv3d)
```

Test Plan:
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 7e5d907 commit 2ccc619
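
For context, here is a minimal end-to-end sketch of how the new conv3d path might be exercised. This is not from the commit itself; the module, shapes, and dtype are illustrative, and it assumes a CUDA device with fbgemm_gpu_genai installed and a GPU recent enough for the fp8 conv kernel. The channels_last_3d conversion mirrors what the new test in this commit does.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# A toy model containing a single Conv3d layer (hypothetical shapes).
model = torch.nn.Sequential(
    torch.nn.Conv3d(16, 64, kernel_size=3, bias=False)
).to(device="cuda", dtype=torch.bfloat16)
# The fp8 conv kernel expects channels-last-3d activations and weights.
model = model.to(memory_format=torch.channels_last_3d)

# Quantize only the Conv3d modules, per-tensor fp8 for both activation and weight.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(model, config, filter_fn=_is_conv3d)

x = torch.randn(4, 16, 32, 32, 32, device="cuda", dtype=torch.bfloat16)
x = x.to(memory_format=torch.channels_last_3d)
out = model(x)  # runs through the quantized fp8 conv3d path
```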

4 files changed: +170 −2 lines

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 107 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@
     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )

@@ -49,6 +50,28 @@ def forward(self, x):
         return x


+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +171,90 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test 3D conv for now
+    # sizes are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+
+        if kernel_preference == KernelPreference.FBGEMM and (
+            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+        ):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: the input is in channels_last_3d memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv model with the given dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert error > 20, (
+            f"Quantization error is too high, got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
```
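
The test accepts the quantized output when the SQNR against the high-precision reference exceeds 20 dB. In case `compute_error` is unfamiliar, here is a sketch of an equivalent SQNR computation; this mirrors the usual torchao definition as far as I know, and the helper name `sqnr_db` is made up.

```python
import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB:
    # 20 * log10(||reference|| / ||reference - quantized||)
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm(reference.float() - quantized.float())
    return 20 * torch.log10(signal / noise)
```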

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1813,7 +1813,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if not _fp8_mm_compat(weight):
+    if weight.dim() != 5 and not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
```
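
The one-line change above skips `_fp8_mm_compat` for 5-D weights: that helper checks matmul-oriented shape constraints meant for 2-D linear weights, and a 5-D conv3d weight would otherwise be silently returned unquantized. A small illustration of the shapes involved (hypothetical values, only to show why `dim() == 5` identifies the conv3d case):

```python
import torch

# Linear weight: (out_features, in_features) -> 2-D, goes through _fp8_mm_compat
linear_weight = torch.empty(64, 16)
assert linear_weight.dim() == 2

# Conv3d weight: (C_out, C_in, K1, K2, K3) -> 5-D, skips the mm-compat check
conv3d_weight = torch.empty(64, 16, 3, 3, 3)
assert conv3d_weight.dim() == 5
```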

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 61 additions & 1 deletion
```diff
@@ -261,7 +261,7 @@ def _(func, types, args, kwargs):
     )

     act_quant_kwargs = weight_tensor.act_quant_kwargs
-    # quantizing activation, if `act_quant_kwargs` is specified
+    # quantize activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
         input_tensor = _choose_quant_func_and_quantize_tensor(
             input_tensor, act_quant_kwargs
@@ -418,6 +418,66 @@ def _(func, types, args, kwargs):
     return res


+@implements(aten.convolution.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    ) = args
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
+        "Only support 3D conv currently"
+    )
+    assert not transposed, "transposed conv is not supported currently"
+    assert tuple(output_padding) == (0, 0, 0), (
+        f"Only (0, 0, 0) is supported for `output_padding`, got: {output_padding}"
+    )
+    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+    assert _is_fbgemm_gpu_genai_available(), (
+        "quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
+    )
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantize activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    # move C_in to last dim
+    # after permute: (N, D, H, W, C_in)
+    act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+    # move C_in to last dim
+    # after permute: (C_out, K1, K2, K3, C_in)
+    weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+    assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
+        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+    )
+
+    act_scale = input_tensor.scale
+    weight_scale = weight_tensor.scale
+    output = torch.ops.fbgemm.f8f8bf16_conv(
+        act_qdata,
+        weight_qdata,
+        act_scale * weight_scale,
+        padding,
+        stride,
+        dilation,
+    )
+    # output shape after permute: (N, C_out, D_out, H_out, W_out)
+    output = output.permute([0, 4, 1, 2, 3])
+    return output
+
+
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     """Supports slicing for 1d, 2d, and 3d tensors

torchao/utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@
     "is_MI300",
     "is_sm_at_least_89",
     "is_sm_at_least_90",
+    "is_sm_at_least_100",
     "is_package_at_least",
     "DummyModule",
     # Deprecated
```
