Commit ce8796d
GatherBlockQuantized supports zero points and 8 bits for uint8 dtype (microsoft#25214)
Add support for uint8 GatherBlockQuantized in the following two areas:
* Allow zero points.
* Add bits attribute and support bits=8.
The major changes are to update shape inference and to update the unit
tests to cover these cases.
Note that this PR covers only the CPU implementation; the CUDA
implementation will be added later in another PR.
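
To illustrate what zero-point support means for the 8-bit case, here is a minimal reference sketch in plain Python, not the kernel added by this PR. It assumes the usual block-wise dequantization rule `(q - zero_point) * scale`, one scale and one zero point per block of `block_size` elements along K, and a default zero point of 128 when none is supplied. The function name, the nested-list layout, and that default are assumptions made for illustration only.

```python
from typing import List, Optional, Sequence


def gather_dequantize_8bit(
    data: Sequence[Sequence[int]],
    scales: Sequence[Sequence[float]],
    indices: Sequence[int],
    block_size: int,
    zero_points: Optional[Sequence[Sequence[int]]] = None,
    default_zero_point: int = 128,  # assumption: midpoint of the uint8 range
) -> List[List[float]]:
    """Gather rows of `data` by `indices` and dequantize them block by block."""
    output = []
    for row in indices:
        dequantized = []
        for k, q in enumerate(data[row]):
            block = k // block_size  # block that column k belongs to
            if zero_points is not None:
                zp = zero_points[row][block]
            else:
                zp = default_zero_point
            dequantized.append((q - zp) * scales[row][block])
        output.append(dequantized)
    return output


# Two rows of K = 4 values, block_size = 2, so two blocks per row.
data = [[0, 128, 255, 130], [10, 20, 30, 40]]
scales = [[0.5, 2.0], [1.0, 1.0]]
zero_points = [[128, 130], [0, 20]]
result = gather_dequantize_8bit(data, scales, [1, 0], 2, zero_points)
```

With 8 bits each element occupies a full byte, so no nibble unpacking is needed; the 4-bit path would additionally have to unpack two values per byte.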
### Motivation and Context
Previously, zero points were not supported when dtype is uint8. Only
4-bit quantization without zero points was supported.
This change makes it possible to share the weights of lm_head with 8-bit
quantization between GatherBlockQuantized and MatMulNBits.
For example, when K is a multiple of `block_size`, typical input and
output shapes are like the following:
* data has shape (N, K) for 8 bits, or (N, K / 2) for 4 bits.
* scales has shape (N, k_blocks), where k_blocks = (K / block_size).
* zero_points has shape (N, k_blocks) for 8 bits, or (N, (k_blocks + 1) / 2) for 4 bits.
* output will have shape (..., K), where ... is the shape of `indices`.

1 parent 4b18210 commit ce8796d
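
The shape rules listed in the commit message can be sketched as a small helper. This is an illustrative calculation only, not the shape-inference code changed by this PR; the function name and the returned dictionary are hypothetical, and the `/` in the commit message is read as integer division.

```python
from typing import Dict, Sequence, Tuple


def gather_block_quantized_shapes(
    n: int, k: int, block_size: int, bits: int, indices_shape: Sequence[int]
) -> Dict[str, Tuple[int, ...]]:
    """Return the typical tensor shapes when K is a multiple of block_size."""
    if bits not in (4, 8):
        raise ValueError("bits must be 4 or 8")
    if k % block_size != 0:
        raise ValueError("this sketch only covers K being a multiple of block_size")

    k_blocks = k // block_size
    if bits == 8:
        data_shape = (n, k)                # one byte per element
        zero_points_shape = (n, k_blocks)  # one byte per block
    else:
        data_shape = (n, k // 2)                      # two 4-bit values per byte
        zero_points_shape = (n, (k_blocks + 1) // 2)  # two zero points per byte

    return {
        "data": data_shape,
        "scales": (n, k_blocks),
        "zero_points": zero_points_shape,
        "output": tuple(indices_shape) + (k,),
    }


shapes_8bit = gather_block_quantized_shapes(4, 64, 16, 8, (2, 3))
shapes_4bit = gather_block_quantized_shapes(4, 48, 16, 4, (5,))
```

Note that for 4 bits with an odd number of blocks (3 in the second call), the zero-point row is rounded up to 2 bytes.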
File tree

6 files changed (+404, -225 lines)

- docs
- js/web/test/data/ops
- onnxruntime
  - contrib_ops
    - cpu/quantization
    - js/quantization
  - core/graph/contrib_ops
  - test/contrib_ops