@@ -15,9 +15,11 @@
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import run_tests

from torchao.float8.inference import Float8MMConfig
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Granularity,
PerRow,
PerTensor,
quantize_,
@@ -82,7 +84,7 @@ def test_fp8_linear_variants(
dtype: torch.dtype,
mode: str,
compile: bool,
granularity,
granularity: Granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
):
@@ -148,6 +150,61 @@ def test_fp8_linear_variants(
f"Quantization error is too high got a SQNR of {error}"
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
)
# Inputs are (M,..), K, N
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
],
)
def test_fp8_matmul_variants(
self,
dtype: torch.dtype,
granularity: Granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
self.skipTest(
"per tensor with fbgemm kernel preference does not work yet"
)
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
weight_tensor = torch.randn(N, K, dtype=dtype, device="cuda")
mm_config = Float8MMConfig()
input_tensor_fp8 = Float8Tensor.from_hp(
input_tensor,
granularity=granularity,
mm_config=mm_config,
kernel_preference=kernel_preference,
)
weight_tensor_fp8 = Float8Tensor.from_hp(
weight_tensor,
granularity=granularity,
mm_config=mm_config,
kernel_preference=kernel_preference,
)
output_tensor = torch.matmul(input_tensor, weight_tensor.t())
output_tensor_fp8 = torch.matmul(input_tensor_fp8, weight_tensor_fp8.t())
error = compute_error(output_tensor, output_tensor_fp8)
assert error > 20, (
f"Quantization error is too high, got a SQNR of {error}"
)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
not is_sm_at_least_90(),
@@ -653,6 +710,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):

self.assertEqual(sliced_dequantized, sliced_original)

def test_to_dtype_layout(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
y_fp8 = torch.ops.aten.to.dtype_layout(
x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
)
self.assertEqual(y_fp8.dtype, x_fp8.dtype)
self.assertEqual(y_fp8.layout, x_fp8.layout)
self.assertEqual(y_fp8.device, torch.device("cpu"))

def test_has_compatible_shallow_copy_type(self):
x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x1_fp8 = Float8Tensor.from_hp(x1)
x2_fp8 = Float8Tensor.from_hp(x2)
x3_fp8 = Float8Tensor.from_hp(x3)
self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
# Wrong shape
self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))

def test_transpose(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
x_fp8_t = x_fp8.t()
torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py (107 additions, 2 deletions)
@@ -248,18 +248,59 @@ def from_hp(
implements_torch_function = Float8Tensor.implements_torch_function


@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
return _float8_linear_impl(input_tensor, weight_tensor, bias)


@implements(aten.mm.default)
@implements_torch_function(torch.matmul)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = args[0], args[1]
return _float8_linear_impl(input_tensor, weight_tensor.t())


@implements(aten.addmm_.default)
def _(func, types, args, kwargs):
output_tensor, input_tensor, weight_tensor = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
out = _float8_linear_impl(input_tensor, weight_tensor.t())
return output_tensor.copy_(out)


def _float8_linear_impl(
input_tensor: torch.Tensor,
weight_tensor: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert isinstance(weight_tensor, Float8Tensor), (
f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
)

# During the backward pass, we transpose the weight tensor,
# so if the weight tensor was originally rowwise quantized,
# now it becomes colwise. In this case, simply dequantize
# the tensor and do a bf16 matmul
is_colwise = (
weight_tensor.block_size[0] == weight_tensor.shape[0]
and weight_tensor.block_size[1] == 1
)
if is_colwise:
return torch.nn.functional.linear(
input_tensor,
weight_tensor.dequantize(),
bias,
)

act_quant_kwargs = weight_tensor.act_quant_kwargs
# quantizing activation, if `act_quant_kwargs` is specified
if act_quant_kwargs is not None:
@@ -299,6 +340,7 @@ def _(func, types, args, kwargs):
assert _is_rowwise_scaled(input_tensor), (
"Input tensor must be rowwise block size"
)
wq = wq.contiguous()
res = torch.ops.fbgemm.f8f8bf16_rowwise(
xq,
wq,
@@ -557,6 +599,7 @@ def _(func, types, args, kwargs):
assert original_shape[-1] == size[-1], (
f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
)
# TODO: this seems wrong, we should merge the first two dimensions instead
qdata = self.qdata.reshape(*size)
scale = self.scale.reshape(*size)
block_size = self.block_size.copy()
@@ -665,6 +708,68 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new)


@implements(torch.ops.aten.to.dtype_layout)
def _(func, types, args, kwargs):
# only support kwargs for now
assert len(args) == 1
self = args[0]
# only support dtype, layout, and device for now
for k in kwargs.keys():
assert k in ["dtype", "layout", "device"]
# only support the same dtype and layout;
# converting to a different dtype or layout is undefined behavior
if "dtype" in kwargs:
assert kwargs["dtype"] == self.dtype
if "layout" in kwargs:
assert kwargs["layout"] == self.layout
# if device is the same, treat this like a no-op
device = kwargs.get("device")
if device == self.device:
return self
# otherwise, move all inner tensors to the new device
new_tensor = self.__class__(
func(self.qdata, device=device),
func(self.scale, device=device),
self.block_size,
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new_tensor)


# This is called during _apply() to see if we can shallow
# copy the content of one tensor into another. For now,
# we only allow shallow copy if both tensors are `Float8Tensor`
# and have the same shape.
@implements_torch_function(torch._has_compatible_shallow_copy_type)
def _(func, types, args, kwargs):
assert len(args) == 2
return (
isinstance(args[0], Float8Tensor)
and isinstance(args[1], Float8Tensor)
and args[0].shape == args[1].shape
)


@implements(aten.t.default)
def _(func, types, args, kwargs):
assert len(args) == 1
self = args[0]
assert len(self.block_size) == 2
new_tensor = self.__class__(
self.qdata.t(),
self.scale.t(),
(self.block_size[1], self.block_size[0]),
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new_tensor)


Float8Tensor.__module__ = "torchao.quantization"

# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
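
For reference, a minimal usage sketch of the matmul path exercised by test_fp8_matmul_variants above. The import path for Float8Tensor and the tensor shapes are assumptions for illustration (the diff sets Float8Tensor.__module__ to "torchao.quantization"); it also assumes a CUDA device that satisfies the skip conditions in the test. Only the from_hp, PerRow, and torch.matmul calls mirror the tests in this PR.

import torch
from torchao.quantization import Float8Tensor, PerRow  # import path assumed

# high-precision operands; shapes are illustrative, not taken from the PR
x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)

# quantize both operands to fp8 with rowwise scales
x_fp8 = Float8Tensor.from_hp(x, granularity=PerRow())
w_fp8 = Float8Tensor.from_hp(w, granularity=PerRow())

# the new aten.mm / torch.matmul override routes this through _float8_linear_impl,
# giving a result comparable (by SQNR) to torch.matmul(x, w.t())
out = torch.matmul(x_fp8, w_fp8.t())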