@@ -326,25 +326,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-        if self.bias is not None:
-            store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
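The new save path keys the store by ordinal position in `variable_serialization_spec` rather than by variable name. A minimal sketch of that scheme, with a plain dict standing in for the weight store and a hypothetical `serialize` helper (neither is Keras API); note how skipping an absent bias without bumping the index keeps the keys contiguous:

```python
# Illustrative only: `spec` mirrors one entry of
# `variable_serialization_spec`, `values` maps variable names to tensors,
# and None marks an absent variable (e.g. a layer with use_bias=False).
def serialize(spec, values):
    store = {}
    idx = 0
    for name in spec:
        value = values.get(name)
        if value is None:
            # Skip without bumping `idx`, matching the `continue` placed
            # before `idx += 1` in the diff above: keys stay contiguous.
            continue
        store[str(idx)] = value
        idx += 1
    return store

print(serialize(["kernel", "bias", "kernel_scale"],
                {"kernel": "K", "bias": None, "kernel_scale": "S"}))
# {'0': 'K', '1': 'S'}
```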
@@ -353,39 +353,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
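Loading walks the same spec in the same order and skips the same absent entries, so the ordinal keys line up with whatever the save path wrote; this is also why the name-keyed branch and the `_legacy_load_own_variables` fallback can be dropped. A companion sketch under the same assumptions as above:

```python
# Illustrative only: `present` plays the role of the `self.bias is None`
# check in the real loader.
def deserialize(spec, present, store):
    restored = {}
    idx = 0
    for name in spec:
        if name not in present:
            continue  # mirror the writer's skip; do not consume an index
        restored[name] = store[str(idx)]
        idx += 1
    return restored

store = {"0": "K", "1": "S"}  # produced by serialize() for a bias-free layer
print(deserialize(["kernel", "bias", "kernel_scale"],
                  {"kernel", "kernel_scale"}, store))
# {'kernel': 'K', 'kernel_scale': 'S'}
```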
@@ -418,53 +418,32 @@ def get_config(self):
         config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}

-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -473,6 +431,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",