[WIP] Move float8 cutlass sparse layout to Float8SemiSparseTensor #3182
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import ( | ||
| TestCase, | ||
| instantiate_parametrized_tests, | ||
| parametrize, | ||
| run_tests, | ||
| ) | ||
|
|
||
| from torchao.float8.inference import Float8MMConfig | ||
| from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import ( | ||
| Float8SemiSparseTensor, | ||
| ) | ||
| from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor | ||
| from torchao.sparsity.sparse_api import apply_fake_sparsity | ||
| from torchao.testing.utils import skip_if_rocm | ||
| from torchao.utils import is_sm_at_least_90 | ||
|
|
||
|
|
||
| @unittest.skipIf(not is_sm_at_least_90(), "Need H100+ to run") | ||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| class TestFloat8SemiSparseTensor(TestCase): | ||
| def setUp(self): | ||
| self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] | ||
|
|
||
| @skip_if_rocm("ROCm enablement in progress") | ||
| @parametrize( | ||
| "sizes", | ||
| [ | ||
| ((128,), 256, 128), | ||
| ((32, 128), 512, 128), | ||
| ((2, 32, 128), 256, 128), | ||
| ], | ||
| ) | ||
| def test_sparse_vs_dense_fp8(self, sizes): | ||
| dtype = torch.bfloat16 | ||
| device = "cuda" | ||
|
|
||
| M, N, K = sizes | ||
| input = torch.randn(*M, K, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(K, N, dtype=dtype, device=device) | ||
|
|
||
| apply_fake_sparsity(linear) | ||
|
|
||
| mm_config = Float8MMConfig(use_fast_accum=True) | ||
| input_fp8 = Float8Tensor.from_hp( | ||
| input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config | ||
| ) | ||
|
|
||
| weight_fp8 = Float8Tensor.from_hp( | ||
| linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config | ||
| ) | ||
| dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias) | ||
|
|
||
| weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K]) | ||
|
nit: .detach() instead of .data? |
||
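A minimal sketch of the suggested tweak (assuming `from_hp` only needs a plain, detached view of the weight):

```python
# Suggested alternative to linear.weight.data: .detach() yields the same
# storage but keeps autograd's version-counter checks intact.
weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.detach(), [1, K])
```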
| sparse_output = torch.nn.functional.linear( | ||
| input_fp8, weight_sparse_fp8, linear.bias | ||
| ) | ||
|
|
||
| torch.testing.assert_close(dense_output, sparse_output, atol=3e-1, rtol=3e-1) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(TestFloat8SemiSparseTensor) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| from typing import List | ||
|
|
||
| import torch | ||
|
|
||
| from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 | ||
| from torchao.quantization.quant_primitives import ( | ||
| _choose_scale_float8, | ||
| _quantize_affine_float8, | ||
| ) | ||
| from torchao.utils import TorchAOBaseTensor | ||
|
|
||
| __all__ = ["Float8SemiSparseTensor"] | ||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| class Float8SemiSparseTensor(TorchAOBaseTensor): | ||
| tensor_data_names = ["sparse", "meta", "scale"] | ||
|
nit: we should use [ |
||
|
|
||
| def __new__( | ||
| cls, | ||
| sparse: torch.Tensor, | ||
| meta: torch.Tensor, | ||
| scale: torch.Tensor, | ||
| ): | ||
| kwargs = {} | ||
| kwargs["device"] = sparse.device | ||
| kwargs["dtype"] = scale.dtype | ||
| kwargs["requires_grad"] = False | ||
| shape = (sparse.shape[0], 2 * sparse.shape[-1]) | ||
| return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
|
||
| def __init__( | ||
| self, | ||
| sparse: torch.Tensor, | ||
| meta: torch.Tensor, | ||
| scale: torch.Tensor, | ||
| ): | ||
| super().__init__() | ||
| self.sparse = sparse | ||
| self.meta = meta | ||
| self.scale = scale | ||
|
|
||
| def _quantization_type(self): | ||
| return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" | ||
|
|
||
| @classmethod | ||
| def from_hp( | ||
| cls, | ||
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| from torchao.sparsity.utils import mask_creator | ||
|
|
||
| dense = w * mask_creator(w).bool() | ||
|
|
||
| scale = _choose_scale_float8( | ||
| dense, | ||
| block_size=block_size, | ||
| float8_dtype=torch.float8_e4m3fn, | ||
| ) | ||
|
|
||
| w_fp8 = _quantize_affine_float8( | ||
| dense, | ||
| scale=scale, | ||
| float8_dtype=torch.float8_e4m3fn, | ||
| ) | ||
|
|
||
| sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8) | ||
|
|
||
| return cls( | ||
| sparse, | ||
| meta, | ||
| scale, | ||
| ) | ||
|
|
||
|
|
||
| implements = Float8SemiSparseTensor.implements | ||
| implements_torch_function = Float8SemiSparseTensor.implements_torch_function | ||
|
|
||
|
|
||
| @implements(aten.t.default) | ||
|
Why do you have to implement transpose? Transpose on a sparse matrix is kind of tricky, we should probably throw an error instead. |
||
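A sketch of the alternative the comment suggests, assuming transposed views are never actually needed on the packed layout:

```python
# Hypothetical replacement for the no-op transpose: fail loudly, since the
# CUTLASS 2:4 packed fp8 layout has no cheap transposed representation.
@implements(aten.t.default)
def _(func, types, args, kwargs):
    raise NotImplementedError(
        "aten.t is not supported for Float8SemiSparseTensor; "
        "dequantize to a dense tensor if a transposed view is required."
    )
```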
| def _(func, types, args, kwargs): | ||
| from torch.utils._python_dispatch import return_and_correct_aliasing | ||
|
|
||
| self = args[0] | ||
| new = Float8SemiSparseTensor( | ||
| sparse=self.sparse, | ||
| meta=self.meta, | ||
| scale=self.scale, | ||
| ) | ||
| return return_and_correct_aliasing(func, args, kwargs, new) | ||
|
|
||
|
|
||
| def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias): | ||
| from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 | ||
| from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( | ||
| Float8Tensor, | ||
| ) | ||
|
|
||
| if isinstance(input_tensor, Float8Tensor): | ||
|
nit: what does this conditional do? |
||
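Both branches read the same three attributes, so one possible simplification (a sketch that keeps `Float8Tensor` as the only supported activation type) is:

```python
# The two branches are identical today; until a non-Float8Tensor activation
# is actually supported, the isinstance check can be dropped or made an assert.
assert isinstance(input_tensor, Float8Tensor)
input = input_tensor.qdata
input_scale = input_tensor.scale
out_dtype = input_tensor.dtype
```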
| input = input_tensor.qdata | ||
| input_scale = input_tensor.scale | ||
| out_dtype = input_tensor.dtype | ||
| else: | ||
| input = input_tensor.qdata | ||
| input_scale = input_tensor.scale | ||
| out_dtype = input_tensor.dtype | ||
|
|
||
| weight = weight_tensor.sparse | ||
| weight_meta = weight_tensor.meta | ||
| weight_scale = weight_tensor.scale | ||
|
|
||
| # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim | ||
| # For input [B, K], scale should be [B] not [B, 1] | ||
| if input_scale.dim() > input.dim() - 1: | ||
| input_scale = input_scale.squeeze(-1) | ||
|
|
||
| return rowwise_scaled_linear_sparse_cutlass_f8f8( | ||
| input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype | ||
| ) | ||
|
|
||
|
|
||
| @implements([aten.mm.default, aten.addmm.default]) | ||
| def _(func, types, args, kwargs): | ||
| if func == aten.addmm.default: | ||
| bias, input_tensor, weight_tensor = args | ||
| else: # aten.mm.default | ||
| input_tensor, weight_tensor = args | ||
| bias = None | ||
|
|
||
|
I think you may need to do some transpose trickery here to support mm and addmm, my understanding is that |
||
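For reference, a minimal dense example of the decomposition the comment points at: `F.linear` lowers to `addmm`/`mm` with the weight already transposed, so the operand reaching these overrides is conceptually `W.t()`.

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)  # nn.Linear stores its weight as [out_features, in_features]
b = torch.randn(16)

# F.linear(x, w, b) is computed as addmm(b, x, w.t()), so an mm/addmm override
# receives the transposed [in, out] operand rather than the stored [out, in] weight.
assert torch.allclose(
    torch.nn.functional.linear(x, w, b), torch.addmm(b, x, w.t())
)
```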
| return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias) | ||
|
|
||
|
|
||
| @implements(aten.linear.default) | ||
| @implements_torch_function(torch.nn.functional.linear) | ||
| def _(func, types, args, kwargs): | ||
| input_tensor, weight_tensor, bias = ( | ||
| args[0], | ||
| args[1], | ||
| args[2] if len(args) > 2 else None, | ||
| ) | ||
| return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias) | ||
|
|
||
|
|
||
| Float8SemiSparseTensor.__module__ = "torchao.quantization" | ||
|
|
||
| # Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True` | ||
| torch.serialization.add_safe_globals([Float8SemiSparseTensor]) | ||
@jcaip Updated the testing to follow the style of [test_sparse_apu.py](https://fburl.com/18n157bf). For now I'm omitting any config-related changes until QRT; however, this diff does include all ops (linear, addmm, mm) so that integration can be done as a follow-up.
Could I get feedback on the construction of the implementations and the test before I add similar coverage for addmm and mm and polish things up?
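For reviewers, a condensed end-to-end usage sketch distilled from the test above (assumes a CUDA device with SM90+):

```python
import torch
import torch.nn.functional as F

from torchao.float8.inference import Float8MMConfig
from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
    Float8SemiSparseTensor,
)
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
from torchao.sparsity.sparse_api import apply_fake_sparsity

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_fake_sparsity(linear)  # prune the weight to a 2:4 pattern

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
x_fp8 = Float8Tensor.from_hp(
    x, float8_dtype=torch.float8_e4m3fn, mm_config=Float8MMConfig(use_fast_accum=True)
)
w_fp8_sparse = Float8SemiSparseTensor.from_hp(linear.weight.detach(), [1, 128])

# Dispatches to the rowwise-scaled CUTLASS sparse fp8 kernel via F.linear.
out = F.linear(x_fp8, w_fp8_sparse, linear.bias)
```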