| 9 | 9 |  | 
| 10 | 10 | import copy | 
| 11 | 11 | import unittest | 
| 12 |  | -import warnings | 
| 13 | 12 | from typing import List, Type | 
| 14 | 13 |  | 
| 15 | 14 | import torch | 
| 39 | 38 | ) | 
| 40 | 39 | from torchao.quantization.qat.api import ( | 
| 41 | 40 |     ComposableQATQuantizer, | 
| 42 |  | -    FromIntXQuantizationAwareTrainingConfig, | 
| 43 |  | -    IntXQuantizationAwareTrainingConfig, | 
| 44 | 41 |     QATConfig, | 
| 45 | 42 |     QATStep, | 
| 46 | 43 |     initialize_fake_quantizers, | 
| @@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self): | 
| 1718 | 1715 |         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) | 
| 1719 | 1716 |         self.assertFalse(torch.equal(new_weight, prev_weight)) | 
| 1720 | 1717 |  | 
| 1721 |  | -    def test_legacy_quantize_api_e2e(self): | 
| 1722 |  | -        """ | 
| 1723 |  | -        Test that the following two APIs are numerically equivalent: | 
| 1724 |  | - | 
| 1725 |  | -        New API: | 
| 1726 |  | -            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) | 
| 1727 |  | -            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) | 
| 1728 |  | - | 
| 1729 |  | -        Old API: | 
| 1730 |  | -            quantize_(model, IntXQuantizationAwareTrainingConfig(...)) | 
| 1731 |  | -            quantize_(model, FromIntXQuantizationAwareTrainingConfig()) | 
| 1732 |  | -            quantize_(model, Int8DynamicActivationInt4WeightConfig()) | 
| 1733 |  | -        """ | 
| 1734 |  | -        group_size = 16 | 
| 1735 |  | -        torch.manual_seed(self.SEED) | 
| 1736 |  | -        m = M() | 
| 1737 |  | -        baseline_model = copy.deepcopy(m) | 
| 1738 |  | - | 
| 1739 |  | -        # Baseline prepare | 
| 1740 |  | -        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) | 
| 1741 |  | -        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) | 
| 1742 |  | -        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config) | 
| 1743 |  | -        quantize_(baseline_model, old_qat_config) | 
| 1744 |  | - | 
| 1745 |  | -        # QATConfig prepare | 
| 1746 |  | -        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size) | 
| 1747 |  | -        quantize_(m, QATConfig(base_config, step="prepare")) | 
| 1748 |  | - | 
| 1749 |  | -        # Compare prepared values | 
| 1750 |  | -        torch.manual_seed(self.SEED) | 
| 1751 |  | -        x = m.example_inputs() | 
| 1752 |  | -        x2 = copy.deepcopy(x) | 
| 1753 |  | -        out = m(*x) | 
| 1754 |  | -        baseline_out = baseline_model(*x2) | 
| 1755 |  | -        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) | 
| 1756 |  | - | 
| 1757 |  | -        # Baseline convert | 
| 1758 |  | -        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig()) | 
| 1759 |  | -        quantize_(baseline_model, base_config) | 
| 1760 |  | - | 
| 1761 |  | -        # quantize_ convert | 
| 1762 |  | -        quantize_(m, QATConfig(base_config, step="convert")) | 
| 1763 |  | - | 
| 1764 |  | -        # Compare converted values | 
| 1765 |  | -        torch.manual_seed(self.SEED) | 
| 1766 |  | -        x = m.example_inputs() | 
| 1767 |  | -        x2 = copy.deepcopy(x) | 
| 1768 |  | -        out = m(*x) | 
| 1769 |  | -        baseline_out = baseline_model(*x2) | 
| 1770 |  | -        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) | 
| 1771 |  | - | 
| 1772 |  | -    def test_qat_api_deprecation(self): | 
| 1773 |  | -        """ | 
| 1774 |  | -        Test that the appropriate deprecation warning is logged exactly once per class. | 
| 1775 |  | -        """ | 
| 1776 |  | -        from torchao.quantization.qat import ( | 
| 1777 |  | -            FakeQuantizeConfig, | 
| 1778 |  | -            FakeQuantizer, | 
| 1779 |  | -            from_intx_quantization_aware_training, | 
| 1780 |  | -            intx_quantization_aware_training, | 
| 1781 |  | -        ) | 
| 1782 |  | - | 
| 1783 |  | -        # Reset deprecation warning state, otherwise we won't log warnings here | 
| 1784 |  | -        warnings.resetwarnings() | 
| 1785 |  | - | 
| 1786 |  | -        # Map from deprecated API to the args needed to instantiate it | 
| 1787 |  | -        deprecated_apis_to_args = { | 
| 1788 |  | -            IntXQuantizationAwareTrainingConfig: (), | 
| 1789 |  | -            FromIntXQuantizationAwareTrainingConfig: (), | 
| 1790 |  | -            intx_quantization_aware_training: (), | 
| 1791 |  | -            from_intx_quantization_aware_training: (), | 
| 1792 |  | -            FakeQuantizeConfig: (torch.int8, "per_channel"), | 
| 1793 |  | -            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),), | 
| 1794 |  | -        } | 
| 1795 |  | - | 
| 1796 |  | -        with warnings.catch_warnings(record=True) as _warnings: | 
| 1797 |  | -            # Call each deprecated API twice | 
| 1798 |  | -            for cls, args in deprecated_apis_to_args.items(): | 
| 1799 |  | -                cls(*args) | 
| 1800 |  | -                cls(*args) | 
| 1801 |  | - | 
| 1802 |  | -            # Each call should trigger the warning only once | 
| 1803 |  | -            self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) | 
| 1804 |  | -            for w in _warnings: | 
| 1805 |  | -                self.assertIn( | 
| 1806 |  | -                    "is deprecated and will be removed in a future release", | 
| 1807 |  | -                    str(w.message), | 
| 1808 |  | -                ) | 
| 1809 |  | - | 
| 1810 | 1718 |     def test_qat_api_convert_no_quantization(self): | 
| 1811 | 1719 |         """ | 
| 1812 | 1720 |         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization. | 
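
For reference, the docstring of the removed `test_legacy_quantize_api_e2e` doubles as a migration map from the deleted APIs to the surviving one. Below is a minimal sketch of the new two-step flow, assuming a toy `model` and the import paths used in this test file (the fine-tuning loop is elided):

```python
import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat.api import QATConfig

# Placeholder model; any module with quantizable linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# One base config drives both QAT steps (group_size mirrors the removed test).
base_config = Int8DynamicActivationInt4WeightConfig(group_size=16)

# Prepare: swap in fake-quantized modules for training
# (replaces quantize_(model, IntXQuantizationAwareTrainingConfig(act_config, weight_config))).
quantize_(model, QATConfig(base_config, step="prepare"))

# ... fine-tune the model here ...

# Convert: swap back to nn modules and apply the real quantization
# (replaces quantize_(model, FromIntXQuantizationAwareTrainingConfig()) followed by
# quantize_(model, Int8DynamicActivationInt4WeightConfig())).
quantize_(model, QATConfig(base_config, step="convert"))
```

The removed test asserted bitwise equivalence (`atol=0, rtol=0`) between the legacy and new flows at both the prepare and convert steps, so the swap should be numerically transparent.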