Skip to content

Commit d527ac8

Browse files
committed
Use tofile directly
Signed-off-by: Justin Chu <[email protected]>
1 parent ef9b697 commit d527ac8

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

src/onnx_ir/_core.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,13 +1041,7 @@ def tobytes(self) -> bytes:
10411041
return self._evaluate().tobytes()
10421042

10431043
def tofile(self, file) -> None:
1044-
tensor = self._evaluate()
1045-
if hasattr(tensor, "tofile"):
1046-
# Some existing implementation (e.g. PyTorch <2.10) of TensorProtocol
1047-
# may not have tofile() as it was introduced in v0.1.11
1048-
tensor.tofile(file)
1049-
else:
1050-
super().tofile(file)
1044+
self._evaluate().tofile(file)
10511045

10521046

10531047
class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors

src/onnx_ir/_protocols.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ def tobytes(self) -> bytes:
147147
"""Return the tensor as a byte string conformed to the ONNX specification, in little endian."""
148148
...
149149

150+
def tofile(self, file) -> None:
151+
"""Write the tensor as a byte string conformed to the ONNX specification to the given file-like object.
152+
153+
The file-like object must support ``file.write(bytes)``.
154+
If the file-like object also supports ``file.fileno()``, it will be used
155+
to write the data directly to the underlying file descriptor. This is
156+
more efficient for large tensors.
157+
158+
.. versionadded:: 0.1.11
159+
"""
160+
...
161+
150162

151163
@typing.runtime_checkable
152164
class ValueProtocol(Protocol):

src/onnx_ir/external_data.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,7 @@ def _write_external_data(
210210
if current_offset > file_size:
211211
data_file.write(b"\0" * (current_offset - file_size))
212212

213-
if hasattr(tensor, "tofile"):
214-
# Some existing implementation (e.g. PyTorch <2.10) of TensorProtocol
215-
# may not have tofile() as it was introduced in v0.1.11
216-
tensor.tofile(data_file)
217-
else:
218-
raw_data = tensor.tobytes()
219-
if isinstance(tensor, _core.ExternalTensor):
220-
tensor.release()
221-
data_file.write(raw_data)
213+
tensor.tofile(data_file)
222214

223215

224216
def _create_external_tensor(

0 commit comments

Comments
 (0)