diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 30a1c835..134c526f 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.8 +++++ +* :pr:`372`: fix patch on rotary embedding * :pr:`371`: fix make_fake_with_dynamic_dimensions 0.8.7 diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index f0a480ae..972cce4b 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -704,9 +704,35 @@ def string_type( if obj.__class__.__name__ == "VirtualTensor": if verbose: print(f"[string_type] TT4:{type(obj)}") + + def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F821 + if isinstance(value, str): + return value + if hasattr(value, "node") and isinstance(value.node, str): + return f"{value.node}" + + from torch.fx.experimental.sym_node import SymNode + + if hasattr(value, "node") and isinstance(value.node, SymNode): + # '_expr' is safer than expr + return str(value.node._expr).replace(" ", "") + + try: + val_int = int(value) + return val_int + except ( + TypeError, + ValueError, + AttributeError, + torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, + ): + pass + + raise AssertionError(f"Unable to convert {value!r} into string") + return ( f"{obj.__class__.__name__}(name={obj.name!r}, " - f"dtype={obj.dtype}, shape={obj.shape})" + f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})" ) if obj.__class__.__name__ == "KeyValuesWrapper":