@@ -135,8 +135,9 @@ def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
135135 return inputs .to (target_dtype ) if inputs .dtype == current_dtype else inputs
136136 if isinstance (inputs , dict ):
137137 return {k : cast_inputs_to_dtype (v , current_dtype , target_dtype ) for k , v in inputs .items ()}
138- if isinstance (inputs , list ):
139- return [cast_inputs_to_dtype (v , current_dtype , target_dtype ) for v in inputs ]
138+ if isinstance (inputs , (list , tuple )):
139+ # Preserve the container type so models that branch on it (e.g. `isinstance(..., tuple)`) still see a tuple.
140+ return type (inputs )(cast_inputs_to_dtype (v , current_dtype , target_dtype ) for v in inputs )
140141
141142 return inputs
142143
@@ -479,7 +480,11 @@ def test_keep_in_fp32_modules(self, tmp_path):
479480 )
480481 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ], ids = ["fp16" , "bf16" ])
481482 @torch .no_grad ()
482- def test_from_save_pretrained_dtype_inference (self , tmp_path , dtype , atol = 1e-4 , rtol = 0 ):
483+ def test_from_save_pretrained_dtype_inference (self , tmp_path , dtype ):
484+ # Low-precision inference is inherently lossy, and models that keep some modules in fp32 diverge further from
485+ # the fully-cast reference. Tolerances reflect the dtype's precision rather than a tight fp32-style threshold.
486+ atol = 3e-2 if dtype == torch .bfloat16 else 1e-2
487+ rtol = 0
483488 model = self .model_class (** self .get_init_dict ())
484489 model .to (torch_device )
485490 fp32_modules = model ._keep_in_fp32_modules or []
0 commit comments