@@ -693,53 +693,56 @@ def _get_kernel_with_merged_lora(self):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode is not None:
-            kernel_value = self._kernel
-            kernel_scale = self.kernel_scale
-            if self.lora_enabled:
-                # Dequantize kernel to float
-                if self.quantization_mode == "int4":
-                    unpacked_kernel = quantizers.unpack_int4(
-                        kernel_value, self._orig_input_dim
-                    )
-                    float_kernel = ops.divide(
-                        ops.cast(unpacked_kernel, self.compute_dtype),
-                        kernel_scale,
-                    )
-                    quant_range = (-8, 7)
-                elif self.quantization_mode == "int8":
-                    float_kernel = ops.divide(
-                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
-                    )
-                    quant_range = (-127, 127)
-                else:
-                    raise ValueError(
-                        "Unsupported quantization mode: "
-                        f"{self.quantization_mode}"
-                    )
-
-                # Merge LoRA weights in float domain
-                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
-                    self.lora_kernel_a, self.lora_kernel_b
-                )
-                merged_float_kernel = ops.add(float_kernel, lora_delta)
-
-                # Requantize
-                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
-                    merged_float_kernel,
-                    axis=0,
-                    value_range=quant_range,
-                    dtype="int8",
-                    to_numpy=True,
-                )
-                kernel_scale = ops.squeeze(kernel_scale, axis=0)
-
-                # Pack if int4
-                if self.quantization_mode == "int4":
-                    kernel_value, _, _ = quantizers.pack_int4(
-                        requantized_kernel
-                    )
-                else:
-                    kernel_value = requantized_kernel
+        if self.dtype_policy.quantization_mode is None:
+            return self.kernel, None
+
+        kernel_value = self._kernel
+        kernel_scale = self.kernel_scale
+
+        if not self.lora_enabled:
             return kernel_value, kernel_scale
-        return self.kernel, None
+
+        # Dequantize, Merge, and Re-quantize
+
+        # Dequantize kernel to float
+        if self.quantization_mode == "int4":
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel_value, self._orig_input_dim
+            )
+            float_kernel = ops.divide(
+                ops.cast(unpacked_kernel, self.compute_dtype),
+                kernel_scale,
+            )
+            quant_range = (-8, 7)
+        elif self.quantization_mode == "int8":
+            float_kernel = ops.divide(
+                ops.cast(kernel_value, self.compute_dtype), kernel_scale
+            )
+            quant_range = (-127, 127)
+        else:
+            raise ValueError(
+                f"Unsupported quantization mode: {self.quantization_mode}"
+            )
+
+        # Merge LoRA weights in float domain
+        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+            self.lora_kernel_a, self.lora_kernel_b
+        )
+        merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+        # Requantize
+        requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+            merged_float_kernel,
+            axis=0,
+            value_range=quant_range,
+            dtype="int8",
+            to_numpy=True,
+        )
+        kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+        # Pack if int4
+        if self.quantization_mode == "int4":
+            kernel_value, _, _ = quantizers.pack_int4(requantized_kernel)
+        else:
+            kernel_value = requantized_kernel
+        return kernel_value, kernel_scale
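For readers tracing the hunk above: the merge never happens in the integer domain. The stored kernel is dequantized to float, the LoRA delta `(lora_alpha / lora_rank) * lora_kernel_a @ lora_kernel_b` is added, and the result is requantized with per-column abs-max scaling. Below is a minimal NumPy sketch of that round trip for the int8 branch; `abs_max_quantize_np` is a hypothetical stand-in for `quantizers.abs_max_quantize`, not the Keras implementation.

```python
import numpy as np

def abs_max_quantize_np(x, axis=0, value_range=(-127, 127)):
    # Hypothetical stand-in: per-axis abs-max scale so quantized values
    # span the integer range; dequantization is simply q / scale.
    amax = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = value_range[1] / np.maximum(amax, 1e-12)  # avoid divide-by-zero
    q = np.clip(np.round(x * scale), *value_range).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
kernel = rng.normal(size=(16, 8)).astype(np.float32)
q_kernel, scale = abs_max_quantize_np(kernel)

rank, alpha = 4, 8.0
lora_a = 0.01 * rng.normal(size=(16, rank)).astype(np.float32)
lora_b = 0.01 * rng.normal(size=(rank, 8)).astype(np.float32)

# Dequantize to float: mirrors ops.divide(ops.cast(kernel, ...), kernel_scale).
float_kernel = q_kernel.astype(np.float32) / scale

# Merge the LoRA delta in the float domain, then requantize.
merged = float_kernel + (alpha / rank) * (lora_a @ lora_b)
q_merged, new_scale = abs_max_quantize_np(merged)
# new_scale has shape (1, 8); the hunk squeezes axis 0 the same way.
new_scale = np.squeeze(new_scale, axis=0)

print(np.max(np.abs(q_merged / new_scale - merged)))  # small round-trip error
```

Merging in float also explains why `kernel_scale` is recomputed: the delta changes each column's abs-max, so the old scales would no longer be valid for the merged kernel.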
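The int4 branch adds one wrinkle: two 4-bit values are packed per int8 byte along the input axis, which is why the code unpacks with `self._orig_input_dim` (to recover an odd leading dimension) before dequantizing, and re-packs after requantizing. The sketch below shows the general nibble-packing idea; `pack_int4_np` / `unpack_int4_np` are illustrative stand-ins, and the real `quantizers.pack_int4` evidently returns extra metadata as well, hence the `kernel_value, _, _ = ...` unpacking above.

```python
import numpy as np

def pack_int4_np(q):
    # Pack pairs along axis 0: two int4 values ([-8, 7]) per int8 byte.
    assert q.shape[0] % 2 == 0, "pad to an even length first"
    nibbles = q.astype(np.uint8) & 0x0F  # two's-complement nibbles
    return (nibbles[0::2] | (nibbles[1::2] << 4)).view(np.int8)

def unpack_int4_np(packed, orig_dim):
    # Inverse: split each byte into two nibbles and sign-extend.
    u = packed.view(np.uint8)
    out = np.empty((u.shape[0] * 2,) + u.shape[1:], dtype=np.uint8)
    out[0::2] = u & 0x0F
    out[1::2] = u >> 4
    signed = out.astype(np.int8)
    return np.where(signed > 7, signed - 16, signed)[:orig_dim]

q = np.array([[-8, 7], [3, -1], [0, 5], [-4, 2]], dtype=np.int8)
assert (unpack_int4_np(pack_int4_np(q), q.shape[0]) == q).all()
```

Note that requantization in the hunk still targets `dtype="int8"` even in int4 mode, just with `value_range=(-8, 7)`: the values are int4-ranged but live in int8 containers until `pack_int4` squeezes them two to a byte.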