|
9 | 9 |
|
10 | 10 | import copy |
11 | 11 | import unittest |
12 | | -import warnings |
13 | 12 | from typing import List, Type |
14 | 13 |
|
15 | 14 | import torch |
|
39 | 38 | ) |
40 | 39 | from torchao.quantization.qat.api import ( |
41 | 40 | ComposableQATQuantizer, |
42 | | - FromIntXQuantizationAwareTrainingConfig, |
43 | | - IntXQuantizationAwareTrainingConfig, |
44 | 41 | QATConfig, |
45 | 42 | QATStep, |
46 | 43 | initialize_fake_quantizers, |
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self): |
1718 | 1715 | self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) |
1719 | 1716 | self.assertFalse(torch.equal(new_weight, prev_weight)) |
1720 | 1717 |
|
1721 | | - def test_legacy_quantize_api_e2e(self): |
1722 | | - """ |
1723 | | - Test that the following two APIs are numerically equivalent: |
1724 | | -
|
1725 | | - New API: |
1726 | | - quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) |
1727 | | - quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) |
1728 | | -
|
1729 | | - Old API: |
1730 | | - quantize_(model, IntXQuantizationAwareTrainingConfig(...)) |
1731 | | - quantize_(model, FromIntXQuantizationAwareTrainingConfig()) |
1732 | | - quantize_(model, Int8DynamicActivationInt4WeightConfig()) |
1733 | | - """ |
1734 | | - group_size = 16 |
1735 | | - torch.manual_seed(self.SEED) |
1736 | | - m = M() |
1737 | | - baseline_model = copy.deepcopy(m) |
1738 | | - |
1739 | | - # Baseline prepare |
1740 | | - act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) |
1741 | | - weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) |
1742 | | - old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config) |
1743 | | - quantize_(baseline_model, old_qat_config) |
1744 | | - |
1745 | | - # QATConfig prepare |
1746 | | - base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size) |
1747 | | - quantize_(m, QATConfig(base_config, step="prepare")) |
1748 | | - |
1749 | | - # Compare prepared values |
1750 | | - torch.manual_seed(self.SEED) |
1751 | | - x = m.example_inputs() |
1752 | | - x2 = copy.deepcopy(x) |
1753 | | - out = m(*x) |
1754 | | - baseline_out = baseline_model(*x2) |
1755 | | - torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) |
1756 | | - |
1757 | | - # Baseline convert |
1758 | | - quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig()) |
1759 | | - quantize_(baseline_model, base_config) |
1760 | | - |
1761 | | - # quantize_ convert |
1762 | | - quantize_(m, QATConfig(base_config, step="convert")) |
1763 | | - |
1764 | | - # Compare converted values |
1765 | | - torch.manual_seed(self.SEED) |
1766 | | - x = m.example_inputs() |
1767 | | - x2 = copy.deepcopy(x) |
1768 | | - out = m(*x) |
1769 | | - baseline_out = baseline_model(*x2) |
1770 | | - torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) |
1771 | | - |
1772 | | - def test_qat_api_deprecation(self): |
1773 | | - """ |
1774 | | - Test that the appropriate deprecation warning is logged exactly once per class. |
1775 | | - """ |
1776 | | - from torchao.quantization.qat import ( |
1777 | | - FakeQuantizeConfig, |
1778 | | - FakeQuantizer, |
1779 | | - from_intx_quantization_aware_training, |
1780 | | - intx_quantization_aware_training, |
1781 | | - ) |
1782 | | - |
1783 | | - # Reset deprecation warning state, otherwise we won't log warnings here |
1784 | | - warnings.resetwarnings() |
1785 | | - |
1786 | | - # Map from deprecated API to the args needed to instantiate it |
1787 | | - deprecated_apis_to_args = { |
1788 | | - IntXQuantizationAwareTrainingConfig: (), |
1789 | | - FromIntXQuantizationAwareTrainingConfig: (), |
1790 | | - intx_quantization_aware_training: (), |
1791 | | - from_intx_quantization_aware_training: (), |
1792 | | - FakeQuantizeConfig: (torch.int8, "per_channel"), |
1793 | | - FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),), |
1794 | | - } |
1795 | | - |
1796 | | - with warnings.catch_warnings(record=True) as _warnings: |
1797 | | - # Call each deprecated API twice |
1798 | | - for cls, args in deprecated_apis_to_args.items(): |
1799 | | - cls(*args) |
1800 | | - cls(*args) |
1801 | | - |
1802 | | - # Each call should trigger the warning only once |
1803 | | - self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) |
1804 | | - for w in _warnings: |
1805 | | - self.assertIn( |
1806 | | - "is deprecated and will be removed in a future release", |
1807 | | - str(w.message), |
1808 | | - ) |
1809 | | - |
1810 | 1718 | def test_qat_api_convert_no_quantization(self): |
1811 | 1719 | """ |
1812 | 1720 | Test that `QATConfig(step="convert")` swaps back to nn modules without quantization. |
|
0 commit comments