Add a dequantize operator that converts int32 accumulator outputs to bfloat16 with per-group scale factor multiplication. This is needed to use the output of the INT8 GEMM operator (#93) in the rest of a model's forward pass.
Motivation:
The INT8 GEMM produces i32 accumulators, but the rest of the inference pipeline (RMSNorm, SiLU, RoPE, residual adds) operates in bf16. Without this conversion step, INT8 GEMM results can't feed back into the model, blocking end-to-end W8A8 quantized inference.
The existing dequant operator (iron/operators/dequant/) handles int4→bf16 for GPTQ/AWQ-style weight loading, which is a different use case. This operator would handle the GEMM accumulator conversion path.
Proposed behavior:
Input: int32 accumulator tensor + bfloat16 scale factors (one per group, precomputed as scale_activations * scale_weights)
Output: bfloat16 tensor of the same shape as the input, with each group multiplied by its scale factor
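A minimal reference sketch of the proposed behavior, for discussion only: NumPy has no native bfloat16, so float32 stands in for the output dtype, and the function name, signature, and group layout (groups of consecutive elements along the last axis) are assumptions, not the operator's actual API.

```python
import numpy as np

def dequantize_i32(acc: np.ndarray, scales: np.ndarray, group_size: int) -> np.ndarray:
    """Reference dequantize: int32 accumulators -> float, per-group scaling.

    `scales` holds one precomputed factor (scale_activations * scale_weights)
    per group of `group_size` consecutive elements along the last axis.
    float32 stands in for bfloat16 here, since NumPy lacks bf16 natively.
    """
    *lead, n = acc.shape
    assert n % group_size == 0, "last axis must divide evenly into groups"
    assert scales.shape[-1] == n // group_size, "one scale per group"
    # View the last axis as (num_groups, group_size), then broadcast the
    # per-group scale over each group.
    grouped = acc.reshape(*lead, n // group_size, group_size).astype(np.float32)
    out = grouped * scales[..., :, None].astype(np.float32)
    return out.reshape(*lead, n)
```

A real kernel would fuse this with the GEMM epilogue or stream the accumulators through vectorized int-to-float conversion, but the per-group broadcast above captures the intended numerics.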
Implementation approach:
Related: