@@ -35,7 +35,7 @@ def get_config(group_size):

@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
class Int4PlainInt32TensorXPU(TestCase):
@parametrize(
"sizes",
[
@@ -98,8 +98,74 @@ def test_activation_prescaling(self):
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)


instantiate_parametrized_tests(Int4PlainInt32Tensor)
@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(
not torch.accelerator.is_available()
or torch.accelerator.current_accelerator().type != "npu",
"NPU not available",
)
class Int4PlainInt32TensorNPU(TestCase):
@parametrize("device", ["npu"])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("group_size", [32, 64])
def test_linear(self, device, sizes, dtype, group_size):
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
orig_output = linear(input)
quantize_(linear, get_config(group_size))
quantized_output = linear(input)
self.assertTrue(compute_error(orig_output, quantized_output) > 10)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_module_path(self, device, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
quantize_(linear, get_config(group_size=64))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_activation_prescaling(self, device, dtype):
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(64))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor to support activation pre-scaling"
)
assert qw.act_pre_scale is None, "Expected default `act_pre_scale` to be None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# make sure activation pre-scaling is applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
run_tests()
286 changes: 231 additions & 55 deletions torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -14,9 +14,7 @@
choose_qparams_affine,
quantize_affine,
)
from torchao.utils import (
TorchAOBaseTensor,
)
from torchao.utils import TorchAOBaseTensor, torch_version_at_least

__all__ = [
"Int4PlainInt32Tensor",
@@ -91,58 +89,155 @@ def from_hp(
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "xpu", (
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)
original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = torch.int32
scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
)
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
return Int4PlainInt32Tensor(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
block_size,
original_shape,
act_pre_scale=None,
)
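# Route to the device-specific packing path: XPU packs unsigned int4 ([0, 15]) via
# torch.ops.aten._convert_weight_to_int4pack, while NPU packs signed int4 ([-8, 7])
# via torch.ops.npu.npu_convert_weight_to_int4pack.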
if w.device.type == "xpu":
return _from_hp_xpu(cls, w, block_size)
elif w.device.type == "npu":
return _from_hp_npu(cls, w, block_size)
else:
raise NotImplementedError(
f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet."
)


def _from_hp_xpu(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "xpu", (
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)
original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = torch.int32
scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
)
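# Pack two int4 values into one uint8: odd columns go to the high nibble, even columns
# to the low nibble, e.g. adjacent columns [3, 5] pack to 0b0101_0011 = 0x53.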
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
return Int4PlainInt32Tensor(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
block_size,
original_shape,
act_pre_scale=None,
)


def _from_hp_npu(
cls,
w: torch.Tensor,
block_size: List[int],
):
# Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility.
assert torch_version_at_least("2.7.1"), (
"Need pytorch 2.7.1+ for NPU backend op support."
)

assert w.ndim == 2 and w.device.type == "npu", (
f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)

group_size = block_size[1]
k_dim = w.shape[-1]
assert group_size >= 32 and group_size % 32 == 0 and group_size < k_dim, (
f"Invalid group_size={group_size}: expected a multiple of 32 "
f"in the range [32, {k_dim - 1}] for per-group quantization (k_dim={k_dim})."
)

original_shape = w.shape
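# NPU path quantizes to signed int4 ([-8, 7]) and keeps scale/zero_point in the weight
# dtype, unlike the XPU path which uses the unsigned range [0, 15] with int32 zero points.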
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = -8
quant_max = 7
eps = 1e-6
scale_dtype = w.dtype
zero_point_dtype = w.dtype

scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)

int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)

assert int_data.dtype == torch.int32, (
"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
)
assert int_data.shape[-1] % 8 == 0, (
f"torch.ops.npu.npu_convert_weight_to_int4pack expects the last dim to be a multiple of 8, "
f"but got {int_data.shape[-1]}"
)

packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
int_data.contiguous(), 0
)

scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)

return Int4PlainInt32Tensor(
packed_weight.contiguous(),
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous(),
block_size,
original_shape,
act_pre_scale=None,
)


implements = Int4PlainInt32Tensor.implements
@@ -157,6 +252,22 @@ def _(func, types, args, kwargs):
args[1],
args[2] if len(args) > 2 else None,
)

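# Dispatch to the backend-specific GEMM path based on the device of the activation.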
if input_tensor.device.type == "xpu":
return _linear_xpu(input_tensor, weight_tensor, bias)
elif input_tensor.device.type == "npu":
return _linear_npu(input_tensor, weight_tensor, bias)
else:
raise NotImplementedError(
f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet."
)


def _linear_xpu(
input_tensor,
weight_tensor,
bias,
):
assert input_tensor.device.type == "xpu", (
f"For XPU device only but got: {input_tensor.device}"
)
@@ -201,6 +312,71 @@ def _(func, types, args, kwargs):
return y.to(orig_dtype)


def _linear_npu(
input_tensor,
weight_tensor,
bias,
):
assert input_tensor.device.type == "npu", (
f"For NPU device only but got: {input_tensor.device.type}"
)
assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
)
assert weight_tensor.block_size[0] == 1, (
f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
)
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"Shapes of input and weight do not match, input: {input_tensor.shape}, weight: {weight_tensor.shape}"
)

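# apply the optional activation pre-scale before the quantized matmul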
if weight_tensor.act_pre_scale is not None:
input_tensor = input_tensor * weight_tensor.act_pre_scale

act_mat = input_tensor
packed_weight = weight_tensor.qdata
scale = weight_tensor.scale
zero_point = weight_tensor.zero_point

orig_act_size = act_mat.shape
orig_dtype = act_mat.dtype

# dtype alignment
if act_mat.dtype == torch.float16:
scale = scale.to(torch.float16)
zero_point = zero_point.to(torch.float16)
if bias is not None:
bias = bias.to(torch.float16)
elif act_mat.dtype == torch.bfloat16:
scale = scale.to(torch.bfloat16)
zero_point = zero_point.to(torch.bfloat16)
if bias is not None:
bias = bias.to(torch.float32)

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]

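# fused matmul that dequantizes the int4 weight on the fly using the per-group
# antiquant scale/offset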
y = torch.ops.npu.npu_weight_quant_batchmatmul(
x=act_mat,
weight=packed_weight.transpose(-1, -2),
antiquant_scale=scale,
antiquant_offset=zero_point,
antiquant_group_size=groupsize,
bias=bias,
)

# remove out_feature padding
assert weight_tensor.ndim == 2
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

return y.to(orig_dtype)


Int4PlainInt32Tensor.__module__ = "torchao.quantization"

# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`