From 54457886893c61214ebd69524e8ead0f571e30a1 Mon Sep 17 00:00:00 2001 From: Yuri Khrustalev Date: Wed, 25 Mar 2026 16:55:50 -0700 Subject: [PATCH 1/2] Add more quants --- src/liquidonnx/lfm2/builder.py | 235 +++++++++++++++++++++++++++++++-- src/liquidonnx/lfm2/export.py | 160 +++++++++++++++++----- 2 files changed, 353 insertions(+), 42 deletions(-) diff --git a/src/liquidonnx/lfm2/builder.py b/src/liquidonnx/lfm2/builder.py index 0a46f8f..856010a 100644 --- a/src/liquidonnx/lfm2/builder.py +++ b/src/liquidonnx/lfm2/builder.py @@ -29,6 +29,66 @@ logger = logging.getLogger(__name__) +# === INT4 Block Quantization === + +INT4_BITS = 4 +INT4_MAX = (1 << INT4_BITS) - 1 # 15, max value for unsigned 4-bit +DEFAULT_BLOCK_SIZE = 32 +SCALE_EPS = 1e-10 + + +def quantize_int4_block( + weight: np.ndarray, block_size: int = DEFAULT_BLOCK_SIZE +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Quantize weight tensor to INT4 with block-wise scales and zero points. + + Args: + weight: FP32 weight tensor of shape [..., K] where K is quantized dimension + block_size: Number of elements per quantization block + + Returns: + quant: UINT8 tensor with packed INT4 values (2 per byte) + scales: FP32 scales, one per block + zero_points: UINT8 packed zero points (2 per byte) + """ + *batch_dims, K = weight.shape + n_blocks = (K + block_size - 1) // block_size + + pad_K = n_blocks * block_size + if pad_K != K: + pad_shape = list(weight.shape) + pad_shape[-1] = pad_K - K + weight = np.concatenate([weight, np.zeros(pad_shape, dtype=weight.dtype)], axis=-1) + + weight_blocked = weight.reshape(*batch_dims, n_blocks, block_size) + + w_min = weight_blocked.min(axis=-1, keepdims=True) + w_max = weight_blocked.max(axis=-1, keepdims=True) + + scale = (w_max - w_min) / float(INT4_MAX) + scale = np.where(scale < SCALE_EPS, 1.0, scale) + zero_point = np.round(-w_min / scale).clip(0, INT4_MAX).astype(np.uint8) + + # q = round(w/s + zp) to match community + quant = np.round(weight_blocked / scale + zero_point).clip(0, INT4_MAX).astype(np.uint8) + + # Pack two INT4 values into one UINT8 (low nibble first) + quant_packed = quant[..., 0::2] | (quant[..., 1::2] << 4) + + scales = scale.squeeze(-1).astype(np.float32) + + # Pack zero points + zero_point = zero_point.squeeze(-1) + if n_blocks % 2 == 1: + zp_shape = list(zero_point.shape) + zp_shape[-1] = 1 + zero_point = np.concatenate([zero_point, np.zeros(zp_shape, dtype=np.uint8)], axis=-1) + zp_packed = zero_point[..., 0::2] | (zero_point[..., 1::2] << 4) + + quant_final = quant_packed.reshape(*batch_dims, -1) + + return quant_final, scales, zp_packed + @dataclass class LFM2Config: @@ -86,7 +146,12 @@ class LFM2Builder(ONNXBuilderBase): """ def __init__( - self, config: LFM2Config, use_integrated_rope: bool = False, vl_naming: bool = False + self, + config: LFM2Config, + use_integrated_rope: bool = False, + vl_naming: bool = False, + use_q4: bool = False, + q4_block_size: int = DEFAULT_BLOCK_SIZE, ): """ Args: @@ -96,17 +161,114 @@ def __init__( vl_naming: Use VL-style node naming (Shape, Gather_1) instead of LFM2-style (Shape_for_slice, Gather_for_slice). Community VL and LFM2 models use different conventions. + use_q4: Use INT4 quantized embedding (GatherBlockQuantized) and lm_head + (MatMulNBits). Other MatMul layers are left as FP32 for post-export + quantization. + q4_block_size: Block size for INT4 quantization (default: 32). """ super().__init__() self.config = config self.head_dim = config.hidden_size // config.num_attention_heads self.use_integrated_rope = use_integrated_rope self.vl_naming = vl_naming + self.use_q4 = use_q4 + self.q4_block_size = q4_block_size # Categorize layers self.conv_indices = [i for i, t in enumerate(config.layer_types) if t == "conv"] self.attn_indices = [i for i, t in enumerate(config.layer_types) if t == "full_attention"] + # === Q4 Quantization Methods === + + def _quantize_for_matmul_nbits( + self, weight: np.ndarray, name: str + ) -> tuple[str, str, str, int, int]: + """Quantize weight for MatMulNBits operator. + + Args: + weight: FP32 weight tensor of shape [K, N] (already transposed for MatMul) + name: Base name for initializers + + Returns: + Tuple of (quant_name, scales_name, zp_name, K, N) + """ + K, N = weight.shape + block_size = self.q4_block_size + + weight_t = weight.T # [N, K] + quant, scales, zp = quantize_int4_block(weight_t, block_size) + + n_blocks = (K + block_size - 1) // block_size + quant_3d = quant.reshape(N, n_blocks, block_size // 2) + + quant_name = f"{name}_quant" + scales_name = f"{name}_scales" + zp_name = f"{name}_zp" + + self.add_initializer(quant_name, quant_3d, dtype=np.uint8) + self.add_initializer(scales_name, scales) + self.add_initializer(zp_name, zp, dtype=np.uint8) + + return quant_name, scales_name, zp_name, K, N + + def make_matmul_nbits( + self, input_name: str, weight: np.ndarray, name: str, output_name: str + ) -> str: + """Create MatMulNBits node for INT4 quantized linear layer. + + Args: + input_name: Input tensor name + weight: Weight matrix [K, N] (already transposed for MatMul) + name: Base name for the operation + output_name: Output tensor name + """ + quant_name, scales_name, zp_name, K, N = self._quantize_for_matmul_nbits(weight, name) + + return self.make_node( + "MatMulNBits", + [input_name, quant_name, scales_name, zp_name], + [output_name], + domain="com.microsoft", + K=K, + N=N, + bits=4, + block_size=self.q4_block_size, + ) + + def make_gather_block_quantized( + self, weight: np.ndarray, indices_name: str, name: str, output_name: str + ) -> str: + """Create GatherBlockQuantized node for INT4 quantized embedding lookup. + + Args: + weight: Embedding weight [vocab_size, hidden_size] + indices_name: Input token IDs tensor name + name: Base name for initializers + output_name: Output tensor name + """ + block_size = self.q4_block_size + + quant, scales, zp = quantize_int4_block(weight, block_size) + + quant_name = f"{name}_quant" + scales_name = f"{name}_scales" + zp_name = f"{name}_zp" + + self.add_initializer(quant_name, quant, dtype=np.uint8) + self.add_initializer(scales_name, scales) + self.add_initializer(zp_name, zp, dtype=np.uint8) + + return self.make_node( + "GatherBlockQuantized", + [quant_name, indices_name, scales_name, zp_name], + [output_name], + domain="com.microsoft", + bits=4, + block_size=block_size, + gather_axis=0, + quantize_axis=1, + ) + def make_simple_layernorm( self, input_name: str, weight_name: str, path: str, name: str = None ) -> str: @@ -318,7 +480,17 @@ def build_outputs(self): ) def build_embedding(self) -> str: - self.add_initializer("model.embed_tokens.weight", self.weights["model.embed_tokens.weight"]) + embed_weight = self.weights["model.embed_tokens.weight"] + + if self.use_q4: + return self.make_gather_block_quantized( + embed_weight, + "input_ids", + "model_embed_tokens_weight", + "/model/embed_tokens/GatherBlockQuantized/output_0", + ) + + self.add_initializer("model.embed_tokens.weight", embed_weight) return self.make_node( "Gather", ["model.embed_tokens.weight", "input_ids"], @@ -830,8 +1002,48 @@ def build_lm_head(self, hidden_state: str) -> str: name=f"/model/layers.{num_layers}/final_norm_layernorm/SkipLayerNorm", ) - # LM head with tied embeddings (community approach) - # Transpose embed_tokens at runtime instead of storing a copy (saves 256MB) + if self.use_q4: + # Q4: Use MatMulNBits for lm_head with shared embedding weights + embed_quant_name = "model_embed_tokens_weight_quant" + embed_quant = None + for init in self.initializers: + if init.name == embed_quant_name: + embed_quant = onnx.numpy_helper.to_array(init) + break + + if embed_quant is None: + raise ValueError("Embedding quant not found - build_embedding must be called first") + + vocab_size = embed_quant.shape[0] + K = self.config.hidden_size + n_blocks = (K + self.q4_block_size - 1) // self.q4_block_size + + # Reshape to 3D for MatMulNBits: [N, n_blocks, block_size/2] + embed_quant_matmul = embed_quant.reshape( + vocab_size, n_blocks, self.q4_block_size // 2 + ) + self.add_initializer( + "model_embed_tokens_weight_quant_matmul", embed_quant_matmul, dtype=np.uint8 + ) + + # Reuse scales and zero points from embedding + return self.make_node( + "MatMulNBits", + [ + normed, + "model_embed_tokens_weight_quant_matmul", + "model_embed_tokens_weight_scales", + "model_embed_tokens_weight_zp", + ], + ["logits"], + domain="com.microsoft", + K=K, + N=vocab_size, + bits=4, + block_size=self.q4_block_size, + ) + + # FP32: Transpose embed_tokens at runtime instead of storing a copy # embed_tokens.weight [vocab, hidden] → [hidden, vocab] lm_head_weight = self.make_node( "Transpose", @@ -876,10 +1088,12 @@ def build_value_info(self): self.add_value_info(f"{mask_prefix}/{gather_name}/Cast/output_0", TensorProto.INT32, []) # === Embedding output === + if self.use_q4: + embed_output = "/model/embed_tokens/GatherBlockQuantized/output_0" + else: + embed_output = "/model/embed_tokens/Gather/output_0" self.add_value_info( - "/model/embed_tokens/Gather/output_0", - TensorProto.FLOAT, - ["batch_size", "sequence_length", H], + embed_output, TensorProto.FLOAT, ["batch_size", "sequence_length", H] ) # === Per-layer outputs === @@ -1083,9 +1297,10 @@ def build_value_info(self): TensorProto.FLOAT, ["batch_size", "sequence_length", H], ) - self.add_value_info( - "/lm_head/Transpose/output_0", TensorProto.FLOAT, [H, self.config.vocab_size] - ) + if not self.use_q4: + self.add_value_info( + "/lm_head/Transpose/output_0", TensorProto.FLOAT, [H, self.config.vocab_size] + ) def load_weights(self, model_path: str): """Load weights from HuggingFace model.""" diff --git a/src/liquidonnx/lfm2/export.py b/src/liquidonnx/lfm2/export.py index 815b8f4..70451f6 100644 --- a/src/liquidonnx/lfm2/export.py +++ b/src/liquidonnx/lfm2/export.py @@ -8,13 +8,15 @@ ├── config.json ├── tokenizer.json └── onnx/ - ├── model.onnx # FP32 + ├── model.onnx # FP32 ├── model.onnx_data - ├── model_fp16.onnx # --precision fp16 + ├── model_fp16.onnx # --precision fp16 ├── model_fp16.onnx_data - ├── model_q4.onnx # --precision q4 + ├── model_q4.onnx # --precision q4 (GatherBlockQuantized) ├── model_q4.onnx_data - ├── model_q8.onnx # --precision q8 + ├── model_q4f32.onnx # --precision q4f32 (FP32 embedding) + ├── model_q4f32.onnx_data + ├── model_q8.onnx # --precision q8 └── model_q8.onnx_data Usage: @@ -31,14 +33,11 @@ # Export with specific precisions uv run lfm2-export LiquidAI/LFM2-350M --precision fp16 q4 - # Export with all precisions (fp16, q4, q8) + # Export with all precisions uv run lfm2-export LiquidAI/LFM2-350M --precision # Convert existing export (skip FP32 export) uv run lfm2-export LiquidAI/LFM2-350M --precision --skip-export - - # Quantize with lm_head included - uv run lfm2-export LiquidAI/LFM2-350M --precision q4 --no-exclude-lm-head """ import argparse @@ -195,36 +194,100 @@ def export_model(model_path: str, output_dir: pathlib.Path | str): return output_path +def export_model_q4( + model_path: str, + output_dir: pathlib.Path | str, + block_size: int = 32, + symmetric: bool = True, +): + """Export LFM2 model with Q4 quantization (GatherBlockQuantized + MatMulNBits). + + Builds the model with INT4 embedding (GatherBlockQuantized) and lm_head + (MatMulNBits), then post-export quantizes remaining FP32 MatMul layers. + """ + output_dir = pathlib.Path(output_dir) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + lfm2_config = LFM2Config.from_hf_config(config) + + builder = LFM2Builder(lfm2_config, use_q4=True, q4_block_size=block_size) + model = builder.build(model_path) + + onnx_dir = output_dir / "onnx" + onnx_dir.mkdir(parents=True, exist_ok=True) + + # Save intermediate model (embedding+lm_head quantized, other layers FP32) + intermediate_path = onnx_dir / "model_q4_intermediate.onnx" + intermediate_data = onnx_dir / "model_q4_intermediate.onnx_data" + if intermediate_data.exists(): + intermediate_data.unlink() + + onnx.save_model( + model, + str(intermediate_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="model_q4_intermediate.onnx_data", + ) + + # Post-export quantization for remaining FP32 MatMul nodes + output_path = onnx_dir / "model_q4.onnx" + if output_path.exists(): + logger.info("Skipping q4 (already exists)") + intermediate_path.unlink(missing_ok=True) + intermediate_data.unlink(missing_ok=True) + return + + quantize_model( + intermediate_path, + output_path, + bits=4, + block_size=block_size, + exclude_lm_head=True, + symmetric=symmetric, + ) + + # Clean up intermediate + intermediate_path.unlink(missing_ok=True) + intermediate_data.unlink(missing_ok=True) + + _, q4_mb = get_model_size(output_path) + logger.info(f" Q4 model: {q4_mb:.1f} MB") + + def do_quantize( onnx_dir: pathlib.Path, bits: int, exclude_lm_head: bool, block_size: int, symmetric: bool = False, + output_name: str | None = None, + source: str = "model.onnx", ): """Quantize model to INT4 or INT8. Args: - onnx_dir: Directory containing model.onnx + onnx_dir: Directory containing source model bits: Quantization bits (4 or 8) exclude_lm_head: Keep lm_head in FP32 block_size: Block size for quantization symmetric: Use symmetric quantization (no zero points, better JSEP compatibility) + output_name: Override output filename (default: model_q{bits}.onnx) + source: Source model filename to quantize from (default: model.onnx) """ - input_model = onnx_dir / "model.onnx" + input_model = onnx_dir / source if not input_model.exists(): - raise FileNotFoundError(f"model.onnx not found in {onnx_dir}") + raise FileNotFoundError(f"{source} not found in {onnx_dir}") - output_model = onnx_dir / f"model_q{bits}.onnx" + output_model = onnx_dir / (output_name or f"model_q{bits}.onnx") if output_model.exists(): - logger.info(f"Skipping q{bits} (already exists)") + logger.info(f"Skipping {output_model.name} (already exists)") return _, orig_mb = get_model_size(input_model) quant_type = "symmetric" if symmetric else "asymmetric" - logger.info(f"Quantizing to Q{bits} ({quant_type})...") + logger.info(f"Quantizing {source} to Q{bits} ({quant_type})...") quantize_model( input_model, output_model, @@ -280,18 +343,13 @@ def main(): parser.add_argument( "--skip-export", action="store_true", - help="Skip FP32 export, only run precision conversion", + help="Skip FP32 base model export. Q4 always rebuilds from HF weights.", ) parser.add_argument( "--precision", nargs="*", metavar="PRECISION", - help="Output precisions: fp16, q4, q8, or all (default if no args)", - ) - parser.add_argument( - "--no-exclude-lm-head", - action="store_true", - help="Quantize lm_head layer (by default kept in FP32)", + help="Output precisions: fp16, q4, q4f32, q8, or all (default if no args)", ) parser.add_argument( "--block-size", @@ -326,23 +384,30 @@ def main(): output_name = args.output_name or f"{model_name}-ONNX" output_dir = args.output_dir / "exports" / output_name - quant_bits = [] do_fp16_conversion = False + do_q4 = False + do_q4f32 = False + do_q8 = False if args.precision is not None: if len(args.precision) == 0: - quant_bits = [4, 8] do_fp16_conversion = True + do_q4 = True + do_q4f32 = True + do_q8 = True else: for p in args.precision: p = p.lower() if p == "fp16": do_fp16_conversion = True - elif p in ("q4", "q8"): - quant_bits.append(int(p[1])) + elif p == "q4": + do_q4 = True + elif p == "q4f32": + do_q4f32 = True + elif p == "q8": + do_q8 = True else: - parser.error(f"Invalid precision: {p}. Use fp16, q4, or q8.") + parser.error(f"Invalid precision: {p}. Use fp16, q4, q4f32, or q8.") - exclude_lm_head = not args.no_exclude_lm_head onnx_dir = output_dir / "onnx" if not args.skip_export: @@ -360,14 +425,45 @@ def main(): do_fp16(onnx_dir) logger.info(f" {model_name}: OK") - for bits in quant_bits: + if do_q4: logger.info("=" * 60) - logger.info(f"Quantizing to Q{bits}") + logger.info("Exporting Q4 (GatherBlockQuantized)") logger.info("=" * 60) + export_model_q4( + args.model, + output_dir, + block_size=args.block_size, + symmetric=not args.q4_asymmetric, + ) + logger.info(f" {model_name}: OK") - # Q4: symmetric by default (required for WebGPU), Q8: asymmetric - symmetric = (bits == 4) and not args.q4_asymmetric - do_quantize(onnx_dir, bits, exclude_lm_head, args.block_size, symmetric=symmetric) + if do_q4f32: + logger.info("=" * 60) + logger.info("Quantizing to Q4F32") + logger.info("=" * 60) + symmetric = not args.q4_asymmetric + do_quantize( + onnx_dir, + bits=4, + exclude_lm_head=True, + block_size=args.block_size, + symmetric=symmetric, + output_name="model_q4f32.onnx", + ) + logger.info(f" {model_name}: OK") + + if do_q8: + logger.info("=" * 60) + logger.info("Quantizing to Q8") + logger.info("=" * 60) + do_quantize( + onnx_dir, + bits=8, + exclude_lm_head=True, + block_size=args.block_size, + symmetric=False, + output_name="model_q8.onnx", + ) logger.info(f" {model_name}: OK") if not args.no_split_data: From 2f658843904e362ebad3bb882a469768b95e8ab5 Mon Sep 17 00:00:00 2001 From: Yuri Khrustalev Date: Wed, 8 Apr 2026 12:35:33 -0400 Subject: [PATCH 2/2] Fix ruff formatting in builder.py Co-Authored-By: Claude Opus 4.6 (1M context) --- src/liquidonnx/lfm2/builder.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/liquidonnx/lfm2/builder.py b/src/liquidonnx/lfm2/builder.py index 856010a..95a5924 100644 --- a/src/liquidonnx/lfm2/builder.py +++ b/src/liquidonnx/lfm2/builder.py @@ -1019,9 +1019,7 @@ def build_lm_head(self, hidden_state: str) -> str: n_blocks = (K + self.q4_block_size - 1) // self.q4_block_size # Reshape to 3D for MatMulNBits: [N, n_blocks, block_size/2] - embed_quant_matmul = embed_quant.reshape( - vocab_size, n_blocks, self.q4_block_size // 2 - ) + embed_quant_matmul = embed_quant.reshape(vocab_size, n_blocks, self.q4_block_size // 2) self.add_initializer( "model_embed_tokens_weight_quant_matmul", embed_quant_matmul, dtype=np.uint8 ) @@ -1092,9 +1090,7 @@ def build_value_info(self): embed_output = "/model/embed_tokens/GatherBlockQuantized/output_0" else: embed_output = "/model/embed_tokens/Gather/output_0" - self.add_value_info( - embed_output, TensorProto.FLOAT, ["batch_size", "sequence_length", H] - ) + self.add_value_info(embed_output, TensorProto.FLOAT, ["batch_size", "sequence_length", H]) # === Per-layer outputs === for layer_idx in range(num_layers):