diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 78678cae..3b60a8b8 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -172,10 +172,10 @@ def _check_(): assert expected & set( dummies ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}" - assert sequence_length == dummies["input_ids"].shape[-1], ( - f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for " - f"model class {model.__class__.__name__}" - ) + # assert sequence_length == dummies["input_ids"].shape[-1], ( + # f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for " + # f"model class {model.__class__.__name__}" + # ) assert batch_size == dummies["input_ids"].shape[0], ( f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for " f"model class {model.__class__.__name__}" diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py index cb55b3b2..1605eead 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py @@ -214,7 +214,7 @@ def longrope_frequency_update(self, position_ids, device, layer_type=None): cond, (lambda x, y: x.clone()), (lambda x, y: y.clone()), - [long_inv_freq, original_inv_freq], + [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], ) setattr(self, f"{prefix}inv_freq", inv_freq) # if seq_len > original_max_position_embeddings: @@ -293,7 +293,7 @@ def dynamic_frequency_update(self, position_ids, device, layer_type=None): cond, (lambda x, y: x.clone()), (lambda x, y: y.clone()), - [long_inv_freq, original_inv_freq], + [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], ) setattr(self, f"{prefix}inv_freq", inv_freq)