Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.8
+++++

* :pr:`372`: fix patch on rotary embedding
* :pr:`371`: fix make_fake_with_dynamic_dimensions

0.8.7
Expand Down
28 changes: 27 additions & 1 deletion onnx_diagnostic/helpers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,9 +704,35 @@ def string_type(
if obj.__class__.__name__ == "VirtualTensor":
if verbose:
print(f"[string_type] TT4:{type(obj)}")

def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F821
if isinstance(value, str):
return value
if hasattr(value, "node") and isinstance(value.node, str):
return f"{value.node}"

from torch.fx.experimental.sym_node import SymNode

if hasattr(value, "node") and isinstance(value.node, SymNode):
# '_expr' is safer than expr
return str(value.node._expr).replace(" ", "")

try:
val_int = int(value)
return val_int
except (
TypeError,
ValueError,
AttributeError,
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
):
pass

raise AssertionError(f"Unable to convert {value!r} into string")

return (
f"{obj.__class__.__name__}(name={obj.name!r}, "
f"dtype={obj.dtype}, shape={obj.shape})"
f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})"
)

if obj.__class__.__name__ == "KeyValuesWrapper":
Expand Down
Loading