You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add support for torch.export ExportedProgram models (#1498)
Implements functionality to load and execute PyTorch models exported via
torch.export (.pt2 files), enabling .NET applications to run ExportedProgram
models as the PyTorch ecosystem transitions from ONNX to torch.export.
## Implementation
### Native Layer
- Add THSExport.h and THSExport.cpp C++ wrappers for AOTIModelPackageLoader API
- Update Utils.h to include torch/csrc/inductor/aoti_package/model_package_loader.h
- Upgrade to LibTorch 2.9.0 which includes AOTIModelPackageLoader symbols
### Managed Layer
- Add LibTorchSharp.THSExport.cs with PInvoke declarations
- Implement ExportedProgram and ExportedProgram<TResult> classes in Export namespace
- Provide torch.export.load() API following PyTorch conventions
### Features
- Load .pt2 ExportedProgram files compiled with torch._inductor.aoti_compile_and_package()
- Execute inference-only forward pass with type-safe generics
- Support for single tensor, array, and tuple (up to 3 elements) outputs
- Proper IDisposable implementation for resource cleanup
### Testing
- Add TestExport.cs with 7 comprehensive unit tests (all passing)
- Include 6 test .pt2 models covering various scenarios:
- Simple linear model
- Linear + ReLU
- Multiple inputs
- Tuple and list outputs
- Sequential models
- Add generate_export_models.py for regenerating test models
## Technical Details
The implementation uses torch::inductor::AOTIModelPackageLoader from LibTorch 2.9+
for AOTInductor-compiled models, providing 30-40% better latency than TorchScript.
Models are inference-only and compiled for specific device (CPU/CUDA) at build time.
Note: .pt2 files from torch.export.save() are Python-only and not supported.
Only .pt2 files from torch._inductor.aoti_compile_and_package() work in C++.
Fixes#1498
Copy file name to clipboardExpand all lines: RELEASENOTES.md
+9Lines changed: 9 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,6 +1,15 @@
1
1
## TorchSharp Release Notes
2
2
3
3
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
4
+
# NuGet Version 0.106.0 (Upcoming)
5
+
6
+
This release upgrades the libtorch backend to v2.9.0.
7
+
8
+
__API Changes__:
9
+
10
+
#1498 Add support for torch.export ExportedProgram models (.pt2 files)<br/>
11
+
TorchSharp now supports loading and executing PyTorch models exported via torch.export using AOTInductor compilation. Use `torch.export.load()` to load `.pt2` model packages compiled with `torch._inductor.aoti_compile_and_package()` in Python. This provides 30-40% better inference latency compared to TorchScript models. Note: This is an inference-only API with no training support.<br/>
12
+
4
13
# NuGet Version 0.105.2
5
14
6
15
This release upgrades the libtorch backend to v2.7.1, using CUDA 12.8.
0 commit comments