Skip to content

Commit 43ebf47

Browse files
authored
Implement tofile on tensors to reduce data write time by 40% (#210)
This PR introduces the `tofile` method on tensors (similarly named as the one on numpy arrays), which allows for faster write and lower memory usage on external data by bypassing tobytes(). Compatibility with existing `TensorProtocol`s is maintained in the external data module by using `tofile` only when it is available in the class. The `TorchTensor` class in PyTorch exporter should be updated accordingly to leverage the new logic when saving. Note that io time to disk is reduced by 40% below. > [!NOTE] > TensorProtocol is not updated because we do isinstance() checks on external implementations (PyTorch). Adding the method in the protocol will cause isinstance check to fail on those implementations that have not added the tofile method. Reference: https://github.com/microsoft/onnxscript/pull/2241/files/b2381658492510a9bcc8c0a8574db7368e33bceb Before: ``` ________________________________________________________ Executed in 48.08 secs fish external usr time 60.54 secs 0.00 millis 60.54 secs sys time 23.06 secs 1.22 millis 23.06 secs ``` <img width="2325" height="1136" alt="image" src="https://github.com/user-attachments/assets/5283057d-c401-41f1-98f0-e11aa3707591" /> <img width="1225" height="1236" alt="image" src="https://github.com/user-attachments/assets/5d900bd1-4283-4332-9ec3-dd2bd30f8ae3" /> After: ``` ________________________________________________________ Executed in 45.69 secs fish external usr time 60.68 secs 244.00 micros 60.68 secs sys time 22.22 secs 518.00 micros 22.22 secs ``` <img width="2332" height="1247" alt="image" src="https://github.com/user-attachments/assets/87daff64-5b39-4d54-a9c8-98e7ec339033" /> <img width="1239" height="1236" alt="image" src="https://github.com/user-attachments/assets/d1b87f94-f463-4e45-afcc-577f9a5f7c91" /> Fix #207 --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 59086bc commit 43ebf47

File tree

5 files changed

+598
-26
lines changed

5 files changed

+598
-26
lines changed

src/onnx_ir/_core.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,19 @@ def meta(self) -> _metadata.MetadataStore:
185185
self._metadata = _metadata.MetadataStore()
186186
return self._metadata
187187

188+
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    The bytes returned by :meth:`tobytes` are passed to ``file.write`` in a
    single call.

    .. versionadded:: 0.1.11

    Args:
        file: A file-like object with a ``write`` method that accepts bytes.
    """
    payload = self.tobytes()
    file.write(payload)
200+
188201
def display(self, *, page: bool = False) -> None:
189202
rich = _display.require_rich()
190203

@@ -337,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
337350
return array
338351

339352

353+
def _supports_fileno(file: Any) -> bool:
354+
"""Check if the file-like object supports fileno()."""
355+
if not hasattr(file, "fileno"):
356+
return False
357+
try:
358+
file.fileno()
359+
except Exception: # pylint: disable=broad-except
360+
return False
361+
return True
362+
363+
364+
def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
    """Build the numpy array whose raw memory is the serialized form of ``tensor``.

    Shared by the byte-serialization paths. Two adjustments are applied to the
    array returned by ``tensor.numpy()``:

    * INT4, UINT4 and FLOAT4E2M1 data is packed with
      ``_type_casting.pack_4bitx2``.
    * When ``_IS_LITTLE_ENDIAN`` is false, the array is converted to
      little-endian byte order.
    """
    data = tensor.numpy()
    four_bit_dtypes = {
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
    }
    if tensor.dtype not in four_bit_dtypes:
        assert tensor.dtype.itemsize == data.itemsize, "Bug: The itemsize should match"
    else:
        # Pack the array into int4
        data = _type_casting.pack_4bitx2(data)
    if _IS_LITTLE_ENDIAN:
        return data
    little_endian_dtype = data.dtype.newbyteorder("<")
    return data.astype(little_endian_dtype)
383+
384+
340385
class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
341386
"""An immutable concrete tensor.
342387
@@ -509,20 +554,24 @@ def tobytes(self) -> bytes:
509554
value is not a numpy array.
510555
"""
511556
# TODO(justinchuby): Support DLPack
512-
array = self.numpy()
513-
if self.dtype in {
514-
_enums.DataType.INT4,
515-
_enums.DataType.UINT4,
516-
_enums.DataType.FLOAT4E2M1,
517-
}:
518-
# Pack the array into int4
519-
array = _type_casting.pack_4bitx2(array)
520-
else:
521-
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
522-
if not _IS_LITTLE_ENDIAN:
523-
array = array.astype(array.dtype.newbyteorder("<"))
557+
array = _create_np_array_for_byte_representation(self)
524558
return array.tobytes()
525559

560+
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    When the underlying value is a numpy array and ``file`` exposes a working
    ``fileno()``, the array is written with ``numpy.ndarray.tofile``. Otherwise
    the result of :meth:`tobytes` is passed to ``file.write``.

    .. versionadded:: 0.1.11

    Args:
        file: A file-like object with a ``write`` method that accepts bytes, or has a ``fileno()`` method.
    """
    use_numpy_writer = isinstance(self._raw, np.ndarray) and _supports_fileno(file)
    if not use_numpy_writer:
        file.write(self.tobytes())
        return
    # Same special-case handling as tobytes(), via the shared helper.
    # NOTE: per the numpy documentation, ndarray.tofile bypasses ``file.write``
    # and writes straight to the underlying file.
    serialized = _create_np_array_for_byte_representation(self)
    serialized.tofile(file)
574+
526575

527576
class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
528577
"""An immutable concrete tensor with its data store on disk.
@@ -535,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
535584
the tensor is recommended if IO overhead and memory usage is a concern.
536585
537586
To obtain an array, call :meth:`numpy`. To obtain the bytes,
538-
call :meth:`tobytes`.
587+
call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.
539588
540589
The :attr:`location` must be a relative path conforming to the ONNX
541590
specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
@@ -590,7 +639,7 @@ def __init__(
590639
length: The length of the data in bytes.
591640
dtype: The data type of the tensor.
592641
shape: The shape of the tensor.
593-
name: The name of the tensor..
642+
name: The name of the tensor.
594643
doc_string: The documentation string.
595644
metadata_props: The metadata properties.
596645
base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -746,6 +795,18 @@ def tobytes(self) -> bytes:
746795
length = self._length or self.nbytes
747796
return self.raw[offset : offset + length]
748797

798+
def tofile(self, file) -> None:
    """Copy the tensor's bytes from its external data file into ``file``.

    The source file at :attr:`path` is opened in binary mode and streamed in
    chunks of at most 1 MiB, so the full tensor is never read in one call.
    Copying starts at the stored offset (when one is set) and covers the
    stored length, falling back to ``nbytes`` when no length is set.

    If the source file ends before the expected number of bytes has been
    read, copying stops at end-of-file instead of looping forever; only the
    bytes that exist are written.

    Args:
        file: A file-like object with a ``write`` method that accepts bytes.
    """
    self._check_validity()
    chunk_size = 1024 * 1024  # 1MB
    with open(self.path, "rb") as src:
        if self._offset is not None:
            src.seek(self._offset)
        bytes_to_copy = self._length or self.nbytes
        while bytes_to_copy > 0:
            chunk = src.read(min(chunk_size, bytes_to_copy))
            if not chunk:
                # read() returns b"" at EOF. Without this guard a file that is
                # shorter than offset + length would never decrement
                # bytes_to_copy and the loop would spin indefinitely.
                break
            file.write(chunk)
            bytes_to_copy -= len(chunk)
809+
749810
def valid(self) -> bool:
750811
"""Check if the tensor is valid.
751812
@@ -979,6 +1040,15 @@ def tobytes(self) -> bytes:
9791040
"""Return the bytes of the tensor."""
9801041
return self._evaluate().tobytes()
9811042

1043+
def tofile(self, file) -> None:
    """Write the evaluated tensor to ``file``.

    The tensor returned by ``self._evaluate()`` is asked to write itself when it
    provides a ``tofile`` method; otherwise the parent class implementation is
    used.

    Args:
        file: The file-like object to write to.
    """
    evaluated = self._evaluate()
    if not hasattr(evaluated, "tofile"):
        # Some existing implementation of TensorProtocol
        # may not have tofile() as it was introduced in v0.1.11
        super().tofile(file)
        return
    evaluated.tofile(file)
1051+
9821052

9831053
class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
9841054
"""A tensor that stores 4bit datatypes in packed format.
@@ -1113,6 +1183,23 @@ def tobytes(self) -> bytes:
11131183
array = array.astype(array.dtype.newbyteorder("<"))
11141184
return array.tobytes()
11151185

1186+
def tofile(self, file) -> None:
    """Write the tensor to a binary file.

    When ``file`` exposes a working ``fileno()``, the packed array from
    :meth:`numpy_packed` is written with ``numpy.ndarray.tofile`` (converted to
    little-endian first when ``_IS_LITTLE_ENDIAN`` is false). Otherwise the
    result of :meth:`tobytes` is passed to ``file.write``.

    .. versionadded:: 0.1.11

    Args:
        file: A file-like object with a ``write`` method that accepts bytes, or has a ``fileno()`` method.
    """
    if not _supports_fileno(file):
        file.write(self.tobytes())
        return
    # Mirrors the edge-case handling done in tobytes().
    packed = self.numpy_packed()
    if not _IS_LITTLE_ENDIAN:
        little_endian_dtype = packed.dtype.newbyteorder("<")
        packed = packed.astype(little_endian_dtype)
    packed.tofile(file)
1202+
11161203

11171204
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
11181205
"""Immutable symbolic dimension that can be shared across multiple shapes.

0 commit comments

Comments
 (0)