Conversation

@orangeH25 (Contributor) commented Oct 14, 2025

Related to #3044

Summary

This PR adds NPU (Ascend) backend support for the INT4 weight-only quantization workflow.

It introduces a new tensor subclass, Int4PlainInt32TensorNPU, aligned with the existing Int4PlainInt32Tensor for the plain_int32 packing format.
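As a rough usage sketch (the int4_packing_format argument and the "npu" device strings here are assumptions based on the existing Int4WeightOnlyConfig entry point, not necessarily the exact API on main):

import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# Hypothetical end-to-end usage: int4 weight-only quantization with the
# plain_int32 packing format on an Ascend NPU device.
model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16)).to("npu")
quantize_(model, Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32"))

x = torch.randn(1, 256, dtype=torch.bfloat16, device="npu")
y = model(x)  # dispatches to the NPU int4 weight-only linear path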

Environment

  • torchao version: 0.13.0 (main branch, commit: f64daac)
  • torch version: 2.7.1
  • torch_npu version: 2.7.1rc1
  • Ascend Toolkit (CANN): 8.2.RC1
  • Device: Ascend 910B4
  • OS: EulerOS 2.10 (Kernel 4.19.90, aarch64)
  • Python: 3.11

Files changed

Modified

  • torchao/quantization/__init__.py
  • torchao/quantization/quant_api.py
  • torchao/quantization/quantize_/workflows/__init__.py

Added

  • torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py
  • test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py

Implementation Overview

  • Introduces Int4PlainInt32TensorNPU to enable NPU backend support for INT4 weight-only quantization (a rough skeleton is sketched after this list).
  • Registers the new tensor subclass and integrates it into quant_api.py for dispatch.
  • Updates the __init__.py files to ensure proper import and exposure.
  • Adds corresponding test cases for the NPU workflow.
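
A rough skeleton of the new subclass, for orientation (illustrative only, not the actual contents of int4_plain_int32_tensor_npu.py; the attribute names are assumptions):

import torch
from torchao.utils import TorchAOBaseTensor

class Int4PlainInt32TensorNPU(TorchAOBaseTensor):
    # assumed layout: packed int32 qdata plus per-group scale / zero_point
    tensor_data_names = ["qdata", "scale", "zero_point"]

    @classmethod
    def from_hp(cls, w: torch.Tensor, block_size):
        # quantize the bf16/fp16 weight per group, then pack eight int4 values
        # into each int32 element for the NPU int4 matmul path
        ...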

Test Case

  • test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py
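A trimmed sketch of what such an NPU-gated test can look like (helper names, config arguments, and the tolerance here are assumptions, not the contents of that file):

import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization import quantize_, Int4WeightOnlyConfig

@unittest.skipIf(
    not (torch.accelerator.is_available()
         and torch.accelerator.current_accelerator().type == "npu"),
    "NPU not available",
)
class TestInt4PlainInt32TensorNPU(TestCase):
    def test_linear(self):
        linear = torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="npu")
        x = torch.randn(1, 256, dtype=torch.bfloat16, device="npu")
        ref = linear(x)
        quantize_(linear, Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32"))
        out = linear(x)
        # loose sanity check that the quantized output tracks the reference
        self.assertLess((out - ref).abs().mean().item(), 0.5)

if __name__ == "__main__":
    run_tests()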

@pytorch-bot bot commented Oct 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3172

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Oct 14, 2025
@orangeH25 marked this pull request as draft on October 14, 2025 at 11:44
Comment on lines 29 to 33
try:
import torch_npu
except ImportError:
torch_npu = None

PyTorch provides an autoload mechanism for out-of-tree backends, so we do not need to import torch_npu explicitly.
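
For example, with autoload in place, availability can be checked through the generic accelerator API without importing torch_npu directly (a sketch; torch.accelerator requires a recent PyTorch):

import torch

# torch_npu registers itself via PyTorch's out-of-tree backend autoload,
# so no explicit import is needed just to check for an NPU.
acc = torch.accelerator.current_accelerator(check_available=True)
npu_available = acc is not None and acc.type == "npu"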

Comment on lines 43 to 44
@unittest.skipIf(torch_npu is None, "torch_npu is not available")
@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available")
@fffrog commented Oct 14, 2025


Suggested change:
- @unittest.skipIf(torch_npu is None, "torch_npu is not available")
- @unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available")
+ @unittest.skipIf(torch.accelerator.current_accelerator(True).type == "npu" and torch.accelerator.is_available(), "NPU not available")

Comment on lines 45 to 48
@unittest.skipIf(
version.parse(torch_npu.__version__) < version.parse("2.7.1rc1"),
"Need torch_npu 2.7.1rc1+",
)

We can remove it because there is a strict version mapping between PyTorch and torch_npu.

)

assert int_data.dtype == torch.int32, (
f"torch_npu.npu_convert_weight_to_int4pack expects `int32` dtype"

Suggested change:
- f"torch_npu.npu_convert_weight_to_int4pack expects `int32` dtype"
+ f"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"

)

assert int_data.shape[-1] % 8 == 0, (
f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"

Suggested change:
- f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"
+ f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"

@orangeH25 marked this pull request as ready for review on October 15, 2025 at 06:28
@orangeH25 (Contributor, Author)

Hi @jcaip @jerryzh168, please help review this, thanks!

and torch.accelerator.is_available(),
"NPU not available",
)
class Int4PlainInt32TensorNPU(TestCase):
Contributor

Just curious, do we need NPUs to test this? I don't think we have any in CI.

@jcaip (Contributor) left a comment

Thanks for the PR @orangeH25 @fffrog!

The code looks good to me, but I'm curious about how to best test this. It looks like we skip tests in CI because we don't have NPU devices. I believe NPU support was added to TorchTune as well; do you know how they test device-specific functionality there?

Also, just a heads up: most of the team is at PTC / Open Source AI Week in SF this week, so we might be a little slow in responding :)

@jerryzh168 (Contributor)

Please don't include the device (NPU) in the name of the tensor, since packing_format is supposed to be device-agnostic.

int4 weight-only quantization on Ascend NPU backend (groupwise quantization only)

Tensor Attributes:
qdata: (N, K/8), packed int4 weight, the data type is int32 here with 8*int4, the original dtype can be float16 or bfloat16

Does this exactly align with Int4PlainInt32Tensor? If so, please merge it into that tensor subclass.
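
For reference, the qdata layout quoted above packs eight 4-bit values into each int32 element along the K dimension. A minimal illustration of that packing (the nibble order here is an assumption; the real layout is whatever the NPU pack op produces):

import torch

def pack_int4_nibbles(q: torch.Tensor) -> torch.Tensor:
    # q: (N, K) integer tensor with values in [0, 15]; returns (N, K // 8) int32,
    # eight 4-bit values per int32 element.
    N, K = q.shape
    assert K % 8 == 0, "last dim must be aligned to 8"
    q = q.to(torch.int32).reshape(N, K // 8, 8)
    packed = torch.zeros(N, K // 8, dtype=torch.int32)
    for i in range(8):
        packed |= (q[..., i] & 0xF) << (4 * i)
    return packed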

@orangeH25 (Contributor, Author) commented Oct 21, 2025

Hi @jcaip @jerryzh168, thanks for the review!


> Just curious, do we need NPUs to test this? I don't think we have any in CI.

Yes, this case is actually pretty common in open-source projects.

A typical approach is to set up a nightly CI job in the downstream repo that pulls the latest upstream code each day, builds it, and runs the tests. The results are then surfaced as a GitHub badge in the upstream repo's README.md so the latest build status is always visible.

> Does this exactly align with Int4PlainInt32Tensor? If so, please merge it into that tensor subclass.

You mean that we should keep the entry logic in quant_api.py unchanged:

elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
    new_weight = Int4PlainInt32Tensor.from_hp(
        weight,
        block_size,
    )
    return new_weight

and then handle different backend implementations in the from_hp and linear methods of Int4PlainInt32Tensor using simple if/else branches.

class Int4PlainInt32Tensor(TorchAOBaseTensor):
    ...

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        # dispatch on the device of the high-precision weight
        if w.device.type == "xpu":
            return from_hp_xpu(cls, w, block_size)
        elif w.device.type == "npu":
            return from_hp_npu(cls, w, block_size)

implements = Int4PlainInt32Tensor.implements
implements_torch_function = Int4PlainInt32Tensor.implements_torch_function

@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    # dispatch on the device of the activation tensor
    if input_tensor.device.type == "xpu":
        return linear_xpu(input_tensor, weight_tensor, bias)
    elif input_tensor.device.type == "npu":
        return linear_npu(input_tensor, weight_tensor, bias)

Did I get that right? Happy to hear any thoughts or suggestions you might have!

@jerryzh168 (Contributor)

> You mean that we should keep the entry logic in quant_api.py unchanged:
> and then handle different backend implementations in the from_hp and linear methods of Int4PlainInt32Tensor using simple if/else branches.

Yes, that's correct.

@orangeH25 (Contributor, Author)

> > You mean that we should keep the entry logic in quant_api.py unchanged:
> > and then handle different backend implementations in the from_hp and linear methods of Int4PlainInt32Tensor using simple if/else branches.
>
> Yes, that's correct.

Got it, I will follow this approach, thanks!

@orangeH25 (Contributor, Author)

> > You mean that we should keep the entry logic in quant_api.py unchanged:
> > and then handle different backend implementations in the from_hp and linear methods of Int4PlainInt32Tensor using simple if/else branches.
>
> Yes, that's correct.

Hi @jerryzh168 @jcaip, I've made those changes, please take a look. Really appreciate it!
