@@ -44,7 +44,6 @@ def __init__(
4444 self .norm = norm
4545 self .is_activation = base_name != "weight"
4646
47-
4847 def calculate_mse_min_max (
4948 self ,
5049 observed : Tensor ,
@@ -88,7 +87,7 @@ def calculate_mse_min_max(
8887
8988 from compressed_tensors .quantization .utils import generate_gparam
9089
91- if (is_fp4 (self .quantization_args )) and global_scale is None :
90+ if (is_fp4 (self .quantization_args )) and global_scale is None :
9291 # If the quantization scheme is fp4 and global_scale is still None
9392 # i.e it has not yet been optimized, then we are should first get
9493 # the global scale and then optimize the local scales.
@@ -147,7 +146,7 @@ def calculate_updated_min_max(
147146 reduce_dims : Optional [Tuple [int ]] = None ,
148147 tensor_id : Optional [Any ] = None ,
149148 global_scale : Optional [torch .Tensor ] = None ,
150- is_local : Optional [bool ]= False ,
149+ is_local : Optional [bool ] = False ,
151150 ) -> Tuple [FloatTensor , IntTensor ]:
152151 """
153152 Updates the mse-clipped min and max values of the observed tensor using
@@ -258,7 +257,6 @@ def reset(self):
258257 self .min_val = {}
259258 self .max_val = {}
260259
261-
262260 def calculate_gparam (self , observed : Tensor ) -> torch .Tensor :
263261 """
264262 Generate a global scale using the observed min and max from MSE optimization.
@@ -276,4 +274,5 @@ def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
276274 )
277275
278276 return generate_gparam (
279- updated_min_val = updated_min_val , updated_max_val = updated_max_val )
277+ updated_min_val = updated_min_val , updated_max_val = updated_max_val
278+ )
0 commit comments