
Commit e033343
Remove old QAT APIs
**Summary:** As a follow-up to #2641, which deprecated the old QAT APIs in 0.13.0, we now remove them in the next release, 0.15.0. Fixes #2630.

**Test Plan:** CI

ghstack-source-id: de20afc
Pull Request resolved: #3147
1 parent 8748104 commit e033343
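For anyone still on the removed classes, the replacement is the two-step `QATConfig` flow exercised throughout the diffs below. A minimal migration sketch, grounded in the removed test's docstring; the toy model, group size, and training-loop placeholder are illustrative:

import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # illustrative model

# Old flow (removed in this commit):
#   quantize_(model, IntXQuantizationAwareTrainingConfig(act_cfg, weight_cfg))
#   quantize_(model, FromIntXQuantizationAwareTrainingConfig())
#   quantize_(model, Int8DynamicActivationInt4WeightConfig())

# New flow: wrap the target PTQ config in QATConfig for both steps.
base_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
quantize_(model, QATConfig(base_config, step="prepare"))   # insert fake quantizers
# ... fine-tune the prepared model here ...
quantize_(model, QATConfig(base_config, step="convert"))   # swap in real quantized ops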

File tree: 11 files changed, +15 −273 lines

docs/source/api_ref_qat.rst

Lines changed: 0 additions & 2 deletions
@@ -42,8 +42,6 @@ Legacy QAT APIs
     :toctree: generated/
     :nosignatures:

-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer

test/prototype/test_embedding.py

Lines changed: 5 additions & 6 deletions
@@ -18,10 +18,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
@@ -257,7 +256,7 @@ def test_identical_to_IntxWeightOnlyConfig(
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
     ):
         # ASYMMETRIC in QAT is very different that PTQ configs
@@ -288,12 +287,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            QATConfig(weight_config=weight_config, step="prepare"),
             embedding_filter,
         )
         prepared_out = model(indices)

-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
@@ -355,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
         prepared_out = model(indices)

         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
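The pattern this test migrates to, sketched standalone; the model, filter, and fake-quantize settings here are illustrative, not the test's actual values:

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(torch.nn.Embedding(128, 64))  # illustrative model

def embedding_filter(module, fqn):
    # quantize_ applies its config only where this filter returns True
    return isinstance(module, torch.nn.Embedding)

# Weight-only QAT: pass weight_config alone, with no activation_config.
weight_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"), embedding_filter)
# ... training ...
quantize_(model, QATConfig(step="convert"), embedding_filter)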

test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py

Lines changed: 9 additions & 6 deletions
@@ -20,10 +20,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt4WeightConfig,
@@ -499,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self,
         weight_dtype,
         group_size,
@@ -545,7 +544,11 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(

         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+            QATConfig(
+                activation_config=activation_config,
+                weight_config=weight_config,
+                step="prepare",
+            ),
         )
         try:
             prepared_out = model(activations)
@@ -555,7 +558,7 @@
                 return
             raise e

-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
@@ -606,7 +609,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         prepared_out = model(activations)

         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
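Standalone, the new keyword-argument form from this diff looks like the sketch below; the toy model, group size, and the import path for `TorchAODType` are assumptions:

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
from torchao.quantization.quant_primitives import TorchAODType  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # illustrative model

# Explicit fake-quantize configs for activations and weights.
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)

# QATConfig takes both configs as keyword arguments plus an explicit step.
quantize_(
    model,
    QATConfig(
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    ),
)
# ... training ...
quantize_(model, QATConfig(step="convert"))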

test/quantization/test_qat.py

Lines changed: 0 additions & 92 deletions
@@ -9,7 +9,6 @@

 import copy
 import unittest
-import warnings
 from typing import List, Type

 import torch
@@ -39,8 +38,6 @@
 )
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
     initialize_fake_quantizers,
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

-    def test_legacy_quantize_api_e2e(self):
-        """
-        Test that the following two APIs are numerically equivalent:
-
-        New API:
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
-
-        Old API:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-            quantize_(model, Int8DynamicActivationInt4WeightConfig())
-        """
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        baseline_model = copy.deepcopy(m)
-
-        # Baseline prepare
-        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
-        quantize_(baseline_model, old_qat_config)
-
-        # QATConfig prepare
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        quantize_(m, QATConfig(base_config, step="prepare"))
-
-        # Compare prepared values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-        # Baseline convert
-        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(baseline_model, base_config)
-
-        # quantize_ convert
-        quantize_(m, QATConfig(base_config, step="convert"))
-
-        # Compare converted values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-    def test_qat_api_deprecation(self):
-        """
-        Test that the appropriate deprecation warning is logged exactly once per class.
-        """
-        from torchao.quantization.qat import (
-            FakeQuantizeConfig,
-            FakeQuantizer,
-            from_intx_quantization_aware_training,
-            intx_quantization_aware_training,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            IntXQuantizationAwareTrainingConfig: (),
-            FromIntXQuantizationAwareTrainingConfig: (),
-            intx_quantization_aware_training: (),
-            from_intx_quantization_aware_training: (),
-            FakeQuantizeConfig: (torch.int8, "per_channel"),
-            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each call should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -64,7 +64,6 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
 )
@@ -119,7 +118,6 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "intx_quantization_aware_training",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",

torchao/quantization/prototype/qat/api.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FakeQuantizeConfig,
+    IntxFakeQuantizeConfig as FakeQuantizeConfig,
 )

 __all__ = [
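This shim keeps the old prototype import path working by aliasing the renamed class. A quick check of the behavior (a sketch; it assumes both paths re-export the same class object):

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig

# The BC alias resolves to the renamed class itself, not a copy.
assert FakeQuantizeConfig is IntxFakeQuantizeConfig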

torchao/quantization/qat/__init__.py

Lines changed: 0 additions & 13 deletions
@@ -1,25 +1,19 @@
 from .api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
-    from_intx_quantization_aware_training,
     initialize_fake_quantizers,
-    intx_quantization_aware_training,
 )
 from .embedding import (
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .fake_quantize_config import (
-    FakeQuantizeConfig,
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
@@ -50,11 +44,4 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    # for BC
-    "FakeQuantizer",
-    "FakeQuantizeConfig",
-    "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "intx_quantization_aware_training",
-    "IntXQuantizationAwareTrainingConfig",
 ]
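For user code, the renamed classes still exported above are the drop-in replacements for the removed names (sketch):

# Before this commit (now removed):
#   from torchao.quantization.qat import FakeQuantizeConfig, FakeQuantizer
# After:
from torchao.quantization.qat import IntxFakeQuantizeConfig, IntxFakeQuantizer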
