Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 101 additions & 14 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,19 @@ def meta(self) -> _metadata.MetadataStore:
self._metadata = _metadata.MetadataStore()
return self._metadata

def tofile(self, file) -> None:
    """Dump the tensor's raw bytes into a binary file-like object.

    The serialized form produced by :meth:`tobytes` is handed to
    ``file.write`` in a single call.

    .. versionadded:: 0.1.11

    Args:
        file: Any object exposing a ``write`` method that accepts bytes.
    """
    # Look up ``write`` before serializing, exactly as a direct
    # ``file.write(self.tobytes())`` expression would.
    write = file.write
    write(self.tobytes())

def display(self, *, page: bool = False) -> None:
rich = _display.require_rich()

Expand Down Expand Up @@ -337,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
return array


def _supports_fileno(file: Any) -> bool:
"""Check if the file-like object supports fileno()."""
if not hasattr(file, "fileno"):
return False
try:
file.fileno()
except Exception: # pylint: disable=broad-except
return False
return True


def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
    """Build the numpy array whose raw buffer is the tensor's serialized form.

    Used when serializing a tensor to bytes. 4-bit data types are packed with
    ``_type_casting.pack_4bitx2``; the result is converted to little-endian
    byte order when the host is not little-endian.
    """
    values = tensor.numpy()
    four_bit_dtypes = {
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
    }
    if tensor.dtype not in four_bit_dtypes:
        # Sanity check only: no conversion is applied to the other dtypes here.
        assert tensor.dtype.itemsize == values.itemsize, "Bug: The itemsize should match"
    else:
        # Pack the 4-bit values with pack_4bitx2 before serialization.
        values = _type_casting.pack_4bitx2(values)
    if _IS_LITTLE_ENDIAN:
        return values
    return values.astype(values.dtype.newbyteorder("<"))


class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor.

Expand Down Expand Up @@ -509,20 +554,24 @@ def tobytes(self) -> bytes:
value is not a numpy array.
"""
# TODO(justinchuby): Support DLPack
array = self.numpy()
if self.dtype in {
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
}:
# Pack the array into int4
array = _type_casting.pack_4bitx2(array)
else:
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
if not _IS_LITTLE_ENDIAN:
array = array.astype(array.dtype.newbyteorder("<"))
array = _create_np_array_for_byte_representation(self)
return array.tobytes()

def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    .. versionadded:: 0.1.11

    Args:
        file: A file-like object that either has a ``write`` method accepting
            bytes or has a ``fileno()`` method.
    """
    if not isinstance(self._raw, np.ndarray) or not _supports_fileno(file):
        # Generic path: serialize to bytes, then hand them to ``write``.
        file.write(self.tobytes())
        return
    # Descriptor-backed path: apply the same special-case handling as
    # tobytes(), but let numpy write the array itself instead of building an
    # intermediate bytes object.
    packed = _create_np_array_for_byte_representation(self)
    packed.tofile(file)


class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
"""An immutable concrete tensor with its data store on disk.
Expand All @@ -535,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
the tensor is recommended if IO overhead and memory usage is a concern.

To obtain an array, call :meth:`numpy`. To obtain the bytes,
call :meth:`tobytes`.
call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.

The :attr:`location` must be a relative path conforming to the ONNX
specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
Expand Down Expand Up @@ -590,7 +639,7 @@ def __init__(
length: The length of the data in bytes.
dtype: The data type of the tensor.
shape: The shape of the tensor.
name: The name of the tensor..
name: The name of the tensor.
doc_string: The documentation string.
metadata_props: The metadata properties.
base_dir: The base directory for the external data. It is used to resolve relative paths.
Expand Down Expand Up @@ -746,6 +795,18 @@ def tobytes(self) -> bytes:
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def tofile(self, file) -> None:
    """Copy the tensor's bytes from the external data file into ``file``.

    The data is streamed in chunks of at most 1MB, starting at the tensor's
    offset when one is set, so the whole payload is never read in one call.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes.

    Raises:
        EOFError: If the external data file ends before the expected number
            of bytes has been read.
    """
    self._check_validity()
    with open(self.path, "rb") as src:
        if self._offset is not None:
            src.seek(self._offset)
        bytes_to_copy = self._length or self.nbytes
        chunk_size = 1024 * 1024  # 1MB
        while bytes_to_copy > 0:
            chunk = src.read(min(chunk_size, bytes_to_copy))
            if not chunk:
                # An empty read means end of file. Without this check the
                # remaining count never decreases and the loop never exits.
                raise EOFError(
                    f"External data file '{self.path}' ended with "
                    f"{bytes_to_copy} byte(s) still expected"
                )
            file.write(chunk)
            bytes_to_copy -= len(chunk)

def valid(self) -> bool:
"""Check if the tensor is valid.

Expand Down Expand Up @@ -979,6 +1040,15 @@ def tobytes(self) -> bytes:
"""Return the bytes of the tensor."""
return self._evaluate().tobytes()

def tofile(self, file) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether tofile() makes sense for LazyTensor. Hmm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say more?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just thought a lazy tensor isn't even real until it's evaluated, so intuitively it doesn't fit well with tofile(), which is meant to write data to disk. But I guess the general expectation is that all tensors have this method, so it's understandable.

Copy link
Member Author

@justinchuby justinchuby Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually useful: even when the tensor is lazily evaluated, we still want to avoid tobytes() making a copy of the tensor data before writing to file. The screenshots on the PR description are showing lazy tensors.

tensor = self._evaluate()
if hasattr(tensor, "tofile"):
# Some existing implementation of TensorProtocol
# may not have tofile() as it was introduced in v0.1.11
tensor.tofile(file)
else:
super().tofile(file)


class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
"""A tensor that stores 4bit datatypes in packed format.
Expand Down Expand Up @@ -1113,6 +1183,23 @@ def tobytes(self) -> bytes:
array = array.astype(array.dtype.newbyteorder("<"))
return array.tobytes()

def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    .. versionadded:: 0.1.11

    Args:
        file: A file-like object that either has a ``write`` method accepting
            bytes or has a ``fileno()`` method.
    """
    if not _supports_fileno(file):
        # Generic path: serialize to bytes, then hand them to ``write``.
        file.write(self.tobytes())
        return
    # Descriptor-backed path: same little-endian handling as tobytes(), with
    # numpy writing the packed array itself.
    packed = self.numpy_packed()
    if not _IS_LITTLE_ENDIAN:
        packed = packed.astype(packed.dtype.newbyteorder("<"))
    packed.tofile(file)


class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
"""Immutable symbolic dimension that can be shared across multiple shapes.
Expand Down
Loading