@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.float8.inference import Float8MMConfig
from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
Float8SemiSparseTensor,
)
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.testing.utils import skip_if_rocm
from torchao.utils import is_sm_at_least_90


@unittest.skipIf(not is_sm_at_least_90(), "Need H100+ to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestFloat8SemiSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_sparse_vs_dense_fp8(self, sizes):
Author:
@jcaip Updated testing to follow the style of [test_sparse_api.py](https://fburl.com/18n157bf). For now I'm omitting any config-related changes until QRT; however, this diff does include all ops (linear, addmm, mm) so that integration can be done as a follow-up.

Could I get feedback on the construction of the implementations and test before adding similar implementations for addmm and mm, and adding polish?

        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)

        mm_config = Float8MMConfig(use_fast_accum=True)
        input_fp8 = Float8Tensor.from_hp(
            input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
        )

        weight_fp8 = Float8Tensor.from_hp(
            linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
        )
        dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)

        weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
Contributor:
nit: .detach() instead of .data?

        sparse_output = torch.nn.functional.linear(
            input_fp8, weight_sparse_fp8, linear.bias
        )

        torch.testing.assert_close(dense_output, sparse_output, atol=3e-1, rtol=3e-1)


instantiate_parametrized_tests(TestFloat8SemiSparseTensor)


if __name__ == "__main__":
    run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -78,6 +78,7 @@
    quantize_affine,
)
from .quantize_.workflows import (
    Float8SemiSparseTensor,
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
@@ -148,6 +149,7 @@
    "Int4TilePackedTo4dTensor",
    "Float8Tensor",
    "Int4OpaqueTensor",
    "Float8SemiSparseTensor",
    # smooth quant - subject to change
    "get_scale",
    "SmoothFakeDynQuantMixin",
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,6 @@
from .float8.float8_semi_sparse_tensor import (
    Float8SemiSparseTensor,
)
from .float8.float8_tensor import (
    Float8Tensor,
    QuantizeTensorToFloat8Kwargs,
@@ -38,6 +41,7 @@
    "Int4PlainInt32Tensor",
    "Int4TilePackedTo4dTensor",
    "Float8Tensor",
    "Float8SemiSparseTensor",
    "QuantizeTensorToFloat8Kwargs",
    "Int4OpaqueTensor",
    "Int4ChooseQParamsAlgorithm",
153 changes: 153 additions & 0 deletions torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import torch

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization.quant_primitives import (
_choose_scale_float8,
_quantize_affine_float8,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]
aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["sparse", "meta", "scale"]
Contributor:
nit: we should use [compressed_values, metadata] instead of sparse and meta here.


    def __new__(
        cls,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["device"] = sparse.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        # The packed 2:4 representation stores half the columns of the original
        # weight, so the logical shape doubles the last dim of `sparse`.
        shape = (sparse.shape[0], 2 * sparse.shape[-1])
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.sparse = sparse
        self.meta = meta
        self.scale = scale

    def _quantization_type(self):
        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.utils import mask_creator

        # Zero out the weight with a 2:4 semi-structured mask before quantizing
        dense = w * mask_creator(w).bool()

        scale = _choose_scale_float8(
            dense,
            block_size=block_size,
            float8_dtype=torch.float8_e4m3fn,
        )

        w_fp8 = _quantize_affine_float8(
            dense,
            scale=scale,
            float8_dtype=torch.float8_e4m3fn,
        )

        # Compress to the CUTLASS SM90 semi-structured format: packed values + metadata
        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)

        return cls(
            sparse,
            meta,
            scale,
        )
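For readers unfamiliar with the 2:4 pattern that from_hp relies on, here is a small hand-written illustration; it uses an explicit mask rather than torchao.sparsity.utils.mask_creator and is not part of the diff.

```python
# In 2:4 ("semi-structured") sparsity, each contiguous group of 4 values along a
# row keeps at most 2 nonzeros; the CUTLASS compression above stores only those
# surviving values plus metadata recording their positions. Hand-written sketch.
import torch

w = torch.arange(1, 9, dtype=torch.float32).reshape(1, 8)
mask = torch.tensor([[1, 0, 1, 0, 0, 1, 1, 0]], dtype=torch.bool)  # 2 kept per group of 4
print(w * mask)  # tensor([[1., 0., 3., 0., 0., 6., 7., 0.]])
```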


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.t.default)
Contributor:
Why do you have to implement transpose? Transpose on a sparse matrix is kind of tricky; we should probably throw a ValueError if it's called.

def _(func, types, args, kwargs):
    from torch.utils._python_dispatch import return_and_correct_aliasing

    self = args[0]
    new = Float8SemiSparseTensor(
        sparse=self.sparse,
        meta=self.meta,
        scale=self.scale,
    )
    return return_and_correct_aliasing(func, args, kwargs, new)
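A minimal sketch of the alternative the reviewer suggests above (raising instead of returning an alias); it reuses the module's `implements` helper and `aten` alias, and is illustrative only, not part of this diff.

```python
# Hypothetical alternative to the handler above: reject transposition outright,
# since the packed 2:4 representation cannot simply be reinterpreted as its
# transpose. Shown only to illustrate the reviewer's note.
@implements(aten.t.default)
def _(func, types, args, kwargs):
    raise ValueError(
        "Transpose is not supported on Float8SemiSparseTensor; "
        "dequantize to a dense tensor before transposing."
    )
```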


def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
    from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
        Float8Tensor,
    )

    if isinstance(input_tensor, Float8Tensor):
Contributor:
nit: what does this conditional do?

        input = input_tensor.qdata
        input_scale = input_tensor.scale
        out_dtype = input_tensor.dtype
    else:
        input = input_tensor.qdata
        input_scale = input_tensor.scale
        out_dtype = input_tensor.dtype

    weight = weight_tensor.sparse
    weight_meta = weight_tensor.meta
    weight_scale = weight_tensor.scale

    # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim
    # For input [B, K], scale should be [B] not [B, 1]
    if input_scale.dim() > input.dim() - 1:
        input_scale = input_scale.squeeze(-1)

    return rowwise_scaled_linear_sparse_cutlass_f8f8(
        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
    )
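For context on the scale reshape above, a standalone plain-tensor illustration, assuming a per-row amax-based scale (this is not the library's quantization code).

```python
# A per-row scale computed with keepdim comes out as [B, 1]; the squeeze above
# turns it into the [B] shape the sparse kernel expects. Plain-tensor sketch only.
import torch

x = torch.randn(32, 128)
scale = x.abs().amax(dim=-1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max  # [32, 1]
if scale.dim() > x.dim() - 1:
    scale = scale.squeeze(-1)  # [32]
assert scale.shape == (32,)
```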


@implements([aten.mm.default, aten.addmm.default])
def _(func, types, args, kwargs):
    if func == aten.addmm.default:
        bias, input_tensor, weight_tensor = args
    else:  # aten.mm.default
        input_tensor, weight_tensor = args
        bias = None

Contributor:
I think you may need to do some transpose trickery here to support mm and addmm. My understanding is that linear(x, w, bias) returns xW^t + bias, so for mm / addmm you need to pass in a transposed weight.

    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
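To make the reviewer's point above concrete, here is a small dense-tensor check of the two conventions (plain torch, no torchao types involved).

```python
# linear consumes a [out_features, in_features] weight and transposes it
# internally, while mm/addmm multiply by the weight as given, so a weight that
# satisfies linear must be passed transposed to mm/addmm. Dense sketch only.
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)   # linear-style layout: [out_features, in_features]
b = torch.randn(16)

out_linear = torch.nn.functional.linear(x, w, b)  # x @ w.T + b -> [4, 16]
out_addmm = torch.addmm(b, x, w.t())              # matches only with w.t()
torch.testing.assert_close(out_linear, out_addmm)
```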


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)


Float8SemiSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Float8SemiSparseTensor])
@@ -256,9 +256,10 @@ def _(func, types, args, kwargs):
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert isinstance(weight_tensor, Float8Tensor), (
        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
    )

    # If weight is not Float8Tensor, return NotImplemented to allow weight's dispatch to handle it
    if not isinstance(weight_tensor, Float8Tensor):
        return NotImplemented

    act_quant_kwargs = weight_tensor.act_quant_kwargs
    # quantizing activation, if `act_quant_kwargs` is specified
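The replacement above leans on the dispatch fallback protocol: a handler that returns NotImplemented lets the other argument's subclass take over. A toy sketch of that mechanism (toy classes, not torchao code).

```python
# Toy illustration of the NotImplemented fallback used above: when A's handler
# declines an op, PyTorch retries with B's handler instead of raising.
import torch

class A(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is not torch.add:
            return NotImplemented  # defer to the other subclass
        return "handled by A"

class B(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return "handled by B"

a = torch.zeros(2).as_subclass(A)
b = torch.zeros(2).as_subclass(B)
print(torch.mul(a, b))  # -> "handled by B", because A returned NotImplemented
print(torch.add(a, b))  # -> "handled by A"
```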