111 changes: 111 additions & 0 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -30,6 +30,7 @@
_is_fbgemm_gpu_genai_available,
is_sm_at_least_89,
is_sm_at_least_90,
is_sm_at_least_100,
torch_version_at_least,
)

@@ -49,6 +50,28 @@ def forward(self, x):
return x


class ToyConvModel(torch.nn.Module):
def __init__(
self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
):
super().__init__()
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
self.conv = convs[dim](
in_channels,
out_channels,
kernel_size,
bias=bias,
padding=padding,
dtype=dtype,
device=device,
)
if dim == 3:
self.conv = self.conv.to(memory_format=torch.channels_last_3d)

def forward(self, x):
return self.conv(x)


# TODO: move tests from test_affine_quantized_float.py here after all implementations are migrated
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +171,94 @@ def test_fp8_linear_variants(
f"Quantization error is too high got a SQNR of {error}"
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_100(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor()])
@common_utils.parametrize("inference_mode", [True, False])
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO],
)
# only test 3D conv for now
# `sizes` is (N, C_in, C_out, D, H, W)
@common_utils.parametrize(
"sizes",
[
(4, 16, 64, 32, 32, 32),
],
)
def test_fp8_conv_variants(
self,
dtype: torch.dtype,
compile: bool,
granularity,
inference_mode: bool,
kernel_preference: KernelPreference,
sizes: Tuple,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
self.skipTest(
"per tensor with fbgemm kernel preference does not work yet"
)

if kernel_preference == KernelPreference.FBGEMM and (
(not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
):
self.skipTest(
"Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
)

dim = 3
N, C_in, C_out, D, H, W = sizes
kernel_size = 3

# Note: the input is converted to the channels_last_3d memory format
input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)

# Create a conv model in the test dtype
model = ToyConvModel(
dim,
C_in,
C_out,
kernel_size,
bias=False,
padding=0,
dtype=dtype,
device="cuda",
).eval()

quantized_model = copy.deepcopy(model)

config = Float8DynamicActivationFloat8WeightConfig(
granularity=granularity,
kernel_preference=kernel_preference,
)

_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)

quantize_(quantized_model, config, filter_fn=_is_conv3d)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
with inference_mode_ctx:
output_original = model(input_tensor)
output_quantized = quantized_model(input_tensor)

error = compute_error(output_original, output_quantized)
assert error > 20, (
f"Quantization error is too high, got SQNR of {error}"
)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
not is_sm_at_least_90(),
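For context, here is a minimal usage sketch distilled from the new test above. The imports and API calls mirror the test file; the CUDA device with compute capability >= 10.0 and the fbgemm_gpu_genai dependency are assumptions carried over from the test's skip conditions:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# Assumes CUDA with compute capability >= 10.0 and fbgemm_gpu_genai installed,
# matching the skip conditions of test_fp8_conv_variants above.
model = torch.nn.Conv3d(
    16, 64, 3, bias=False, dtype=torch.bfloat16, device="cuda"
)
model = model.to(memory_format=torch.channels_last_3d).eval()

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
# quantize only the Conv3d modules, as in the test's filter_fn
quantize_(model, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv3d))

x = torch.randn(4, 16, 32, 32, 32, dtype=torch.bfloat16, device="cuda")
x = x.to(memory_format=torch.channels_last_3d)
with torch.inference_mode():
    out = model(x)  # dispatches to the fp8 conv3d path added in this PR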
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -1813,7 +1813,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
_check_hardware_support(granularity)
activation_granularity, weight_granularity = granularity

-    if not _fp8_mm_compat(weight):
+    if weight.dim() != 5 and not _fp8_mm_compat(weight):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return weight
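A note on the guard above: `_fp8_mm_compat` validates 2D matmul-style weights, while conv3d weights are 5D, so they must bypass that check. A quick shape illustration (a sketch, not part of the diff):

import torch

linear_w = torch.nn.Linear(16, 64).weight     # shape (64, 16) -> dim() == 2
conv3d_w = torch.nn.Conv3d(16, 64, 3).weight  # shape (64, 16, 3, 3, 3) -> dim() == 5
assert linear_w.dim() == 2 and conv3d_w.dim() == 5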
104 changes: 103 additions & 1 deletion torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -261,7 +261,7 @@ def _(func, types, args, kwargs):
)

act_quant_kwargs = weight_tensor.act_quant_kwargs
-    # quantizing activation, if `act_quant_kwargs` is specified
+    # quantize activation, if `act_quant_kwargs` is specified
if act_quant_kwargs is not None:
input_tensor = _choose_quant_func_and_quantize_tensor(
input_tensor, act_quant_kwargs
@@ -418,6 +418,108 @@ def _(func, types, args, kwargs):
return res


def _scaled_conv3d(
input_tensor,
weight_tensor,
bias,
stride,
padding,
dilation,
):
assert isinstance(weight_tensor, Float8Tensor), (
f"Expected `weight_tensor` to be a Float8Tensor, got: {type(input_tensor)}, {type(weight_tensor)}"
)

assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
"Only support 3D conv currently"
)
assert _is_fbgemm_gpu_genai_available(), (
"quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
)
act_quant_kwargs = weight_tensor.act_quant_kwargs
# quantize activation, if `act_quant_kwargs` is specified
if act_quant_kwargs is not None:
input_tensor = _choose_quant_func_and_quantize_tensor(
input_tensor, act_quant_kwargs
)

# move C_in to last dim
# after permute: (N, D, H, W, C_in)
act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])

# move C_in to last dim
# after permute: (C_out, K1, K2, K3, C_in)
weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])

assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
"Please make sure both activation and weights are in the `channels_last_3d` memory_format"
)

act_scale = input_tensor.scale
weight_scale = weight_tensor.scale
output = torch.ops.fbgemm.f8f8bf16_conv(
act_qdata,
weight_qdata,
act_scale * weight_scale,
padding,
stride,
dilation,
)
# output shape after permute: (N, C_out, D_out, H_out, W_out)
output = output.permute([0, 4, 1, 2, 3])
return output
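
The contiguity assert above relies on a property of channels_last_3d: for a 5D tensor in that memory format, the data is physically laid out as (N, D, H, W, C), so permuting the logical (N, C, D, H, W) view to put channels last yields a contiguous zero-copy view. A standalone check (illustrative only):

import torch

x = torch.randn(4, 16, 8, 8, 8).to(memory_format=torch.channels_last_3d)
assert x.permute([0, 2, 3, 4, 1]).is_contiguous()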


@implements(aten.convolution.default)
def _(func, types, args, kwargs):
(
input_tensor,
weight_tensor,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
) = args
assert not transposed, "transposed conv is not supported currently"
assert tuple(output_padding) == (0, 0, 0), (
f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
)
assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
return _scaled_conv3d(
input_tensor,
weight_tensor,
bias,
stride,
padding,
dilation,
)


@implements(aten.conv3d.default)
def _(func, types, args, kwargs):
(
input_tensor,
weight_tensor,
bias,
stride,
padding,
dilation,
groups,
) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
return _scaled_conv3d(
input_tensor,
weight_tensor,
bias,
stride,
padding,
dilation,
)
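
`fill_defaults` above pads the intercepted argument list out to the full `aten.conv3d.default` signature when trailing arguments are omitted. A sketch of the expected behavior, assuming the helper follows the padding semantics implied by the call site:

from torchao.utils import fill_defaults

# illustrative: with only (input, weight) passed, the remaining five
# positions are filled from the defaults tail
filled = fill_defaults(("x", "w"), 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
assert filled == ["x", "w", None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1]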


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
"""Supports slicing for 1d, 2d, and 3d tensors
1 change: 1 addition & 0 deletions torchao/utils.py
@@ -32,6 +32,7 @@
"is_MI300",
"is_sm_at_least_89",
"is_sm_at_least_90",
"is_sm_at_least_100",
"is_package_at_least",
"DummyModule",
# Deprecated