Skip to content

Commit af266bd

Browse files
committed
Add support for torch.export ExportedProgram models (#1498)
Implements functionality to load and execute PyTorch models exported via torch.export (.pt2 files), enabling .NET applications to run ExportedProgram models as the PyTorch ecosystem transitions from ONNX to torch.export. ## Implementation ### Native Layer - Add THSExport.h and THSExport.cpp C++ wrappers for AOTIModelPackageLoader API - Update Utils.h to include torch/csrc/inductor/aoti_package/model_package_loader.h - Upgrade to LibTorch 2.9.0 which includes AOTIModelPackageLoader symbols ### Managed Layer - Add LibTorchSharp.THSExport.cs with PInvoke declarations - Implement ExportedProgram and ExportedProgram<TResult> classes in Export namespace - Provide torch.export.load() API following PyTorch conventions ### Features - Load .pt2 ExportedProgram files compiled with torch._inductor.aoti_compile_and_package() - Execute inference-only forward pass with type-safe generics - Support for single tensor, array, and tuple (up to 3 elements) outputs - Proper IDisposable implementation for resource cleanup ### Testing - Add TestExport.cs with 7 comprehensive unit tests (all passing) - Include 6 test .pt2 models covering various scenarios: - Simple linear model - Linear + ReLU - Multiple inputs - Tuple and list outputs - Sequential models - Add generate_export_models.py for regenerating test models ## Technical Details The implementation uses torch::inductor::AOTIModelPackageLoader from LibTorch 2.9+ for AOTInductor-compiled models, providing 30-40% better latency than TorchScript. Models are inference-only and compiled for specific device (CPU/CUDA) at build time. Note: .pt2 files from torch.export.save() are Python-only and not supported. Only .pt2 files from torch._inductor.aoti_compile_and_package() work in C++. Fixes #1498
1 parent 5f268a0 commit af266bd

File tree

16 files changed

+411
-473
lines changed

16 files changed

+411
-473
lines changed

RELEASENOTES.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
## TorchSharp Release Notes
22

33
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
4+
# NuGet Version 0.106.0 (Upcoming)
5+
6+
This release upgrades the libtorch backend to v2.9.0.
7+
8+
__API Changes__:
9+
10+
#1498 Add support for torch.export ExportedProgram models (.pt2 files)<br/>
11+
TorchSharp now supports loading and executing PyTorch models exported via torch.export using AOTInductor compilation. Use `torch.export.load()` to load `.pt2` model packages compiled with `torch._inductor.aoti_compile_and_package()` in Python. This provides 30-40% better inference latency compared to TorchScript models. Note: This is an inference-only API with no training support.<br/>
12+
413
# NuGet Version 0.105.2
514

615
This release upgrades the libtorch backend to v2.7.1, using CUDA 12.8.

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
<!-- Other/Non-Core Product Dependencies -->
99
<PropertyGroup>
10-
<LibTorchVersion>2.7.1</LibTorchVersion>
10+
<LibTorchVersion>2.9.0</LibTorchVersion>
1111
<LibTorchVersion Condition="'$(TargetArchitecture)' == 'x64' and '$(TargetOS)' == 'mac'">2.2.2</LibTorchVersion>
1212
<CudaVersionDot>12.8</CudaVersionDot>
1313
<CudaVersionNoDot>128</CudaVersionNoDot>
Lines changed: 30 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
#include "THSExport.h"
33

4-
// NOTE: In LibTorch C++ API, ExportedProgram models (.pt2 files) are loaded using torch::jit::load()
5-
// The .pt2 format is compatible with the TorchScript loading infrastructure
4+
// torch.export support via AOTInductor
5+
// This uses torch::inductor::AOTIModelPackageLoader which is INFERENCE-ONLY
6+
// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python
67

7-
ExportedProgramModule THSExport_load(const char* filename, int64_t device, int64_t index)
8+
ExportedProgramModule THSExport_load(const char* filename)
89
{
9-
c10::DeviceType dev = c10::kCPU;
10-
if (device == 1)
11-
dev = c10::kCUDA;
12-
if (device == 13)
13-
dev = c10::kMPS;
14-
1510
CATCH(
16-
// Load .pt2 file using torch::jit::load
17-
// This works because ExportedProgram models are serialized in a JIT-compatible format
18-
auto res = torch::jit::load(filename, torch::Device(dev, index));
19-
auto copy = new torch::jit::Module(res);
20-
return new std::shared_ptr<torch::jit::Module>(copy);
11+
// Load .pt2 file using AOTIModelPackageLoader
12+
// This requires models to be compiled with aoti_compile_and_package()
13+
auto* loader = new torch::inductor::AOTIModelPackageLoader(filename);
14+
return loader;
2115
);
2216

2317
return nullptr;
@@ -28,113 +22,30 @@ void THSExport_Module_dispose(const ExportedProgramModule module)
2822
delete module;
2923
}
3024

31-
void THSExport_Module_forward(
25+
void THSExport_Module_run(
3226
const ExportedProgramModule module,
33-
const TensorOrScalar* tensorPtrs,
34-
const int length,
35-
TensorOrScalar* (*allocator)(int32_t idx, size_t length),
36-
int8_t* typeCode,
37-
int32_t idx)
27+
const Tensor* input_tensors,
28+
const int input_length,
29+
Tensor** result_tensors,
30+
int* result_length)
3831
{
39-
*typeCode = 0;
40-
4132
CATCH(
42-
// Execute the forward method
43-
auto result = (*module)->forward(toIValue(tensorPtrs, length));
44-
ReturnHelper(result, allocator, typeCode, &idx);
45-
)
46-
}
47-
48-
int THSExport_Module_is_training(ExportedProgramModule module)
49-
{
50-
// ExportedPrograms are always in eval mode, but we check the underlying module
51-
return (*module)->is_training();
52-
}
53-
54-
void THSExport_Module_train(ExportedProgramModule module, bool on)
55-
{
56-
// ExportedPrograms should remain in eval mode, but we allow this for compatibility
57-
(*module)->train(on);
58-
}
59-
60-
void THSExport_Module_eval(ExportedProgramModule module)
61-
{
62-
(*module)->eval();
63-
}
64-
65-
void THSExport_Module_to_device_dtype(ExportedProgramModule module, int8_t dtype, int64_t device, int64_t index)
66-
{
67-
c10::DeviceType dev = c10::kCPU;
68-
if (device == 1)
69-
dev = c10::kCUDA;
70-
if (device == 13)
71-
dev = c10::kMPS;
72-
73-
CATCH(
74-
(*module)->to(torch::Device(dev, index), (at::ScalarType)dtype);
33+
// Convert input tensor pointers to std::vector<torch::Tensor>
34+
std::vector<torch::Tensor> inputs;
35+
inputs.reserve(input_length);
36+
for (int i = 0; i < input_length; i++) {
37+
inputs.push_back(*input_tensors[i]);
38+
}
39+
40+
// Run inference
41+
std::vector<torch::Tensor> outputs = module->run(inputs);
42+
43+
// Allocate output array and copy results
44+
*result_length = outputs.size();
45+
*result_tensors = new Tensor[outputs.size()];
46+
47+
for (size_t i = 0; i < outputs.size(); i++) {
48+
(*result_tensors)[i] = new torch::Tensor(outputs[i]);
49+
}
7550
);
7651
}
77-
78-
void THSExport_Module_to_device(ExportedProgramModule module, int64_t device, int64_t index)
79-
{
80-
c10::DeviceType dev = c10::kCPU;
81-
if (device == 1)
82-
dev = c10::kCUDA;
83-
if (device == 13)
84-
dev = c10::kMPS;
85-
86-
CATCH(
87-
(*module)->to(torch::Device(dev, index));
88-
);
89-
}
90-
91-
void THSExport_Module_to_dtype(ExportedProgramModule module, int8_t dtype)
92-
{
93-
CATCH(
94-
(*module)->to((at::ScalarType)dtype);
95-
);
96-
}
97-
98-
void THSExport_Module_parameters(const ExportedProgramModule module, Tensor* (*allocator)(size_t length))
99-
{
100-
auto parameters = (*module)->parameters();
101-
Tensor* result = allocator(parameters.size());
102-
103-
int i = 0;
104-
for (auto parameter : parameters)
105-
result[i++] = new torch::Tensor(parameter);
106-
}
107-
108-
void THSExport_Module_named_parameters(
109-
const ExportedProgramModule module,
110-
Tensor* (*allocator)(size_t length),
111-
const char** (*allocator2)(size_t length))
112-
{
113-
auto parameters = (*module)->named_parameters();
114-
Tensor* result = allocator(parameters.size());
115-
const char** names = allocator2(parameters.size());
116-
117-
int i = 0;
118-
for (const auto& parameter : parameters) {
119-
result[i] = new torch::Tensor(parameter.value);
120-
names[i] = make_sharable_string(parameter.name);
121-
i++;
122-
}
123-
}
124-
125-
void THSExport_Module_named_buffers(
126-
const ExportedProgramModule module,
127-
Tensor* (*allocator)(size_t length),
128-
const char** (*allocator2)(size_t length))
129-
{
130-
auto buffers = (*module)->named_buffers();
131-
Tensor* result = allocator(buffers.size());
132-
const char** names = allocator2(buffers.size());
133-
134-
int i = 0;
135-
for (const auto& buffer : buffers) {
136-
result[i] = new torch::Tensor(buffer.value);
137-
names[i] = make_sharable_string(buffer.name);
138-
i++;
139-
}
140-
}

src/Native/LibTorchSharp/THSExport.h

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,30 @@
33

44
#include "../Stdafx.h"
55

6-
#include "torch/csrc/jit/api/module.h"
6+
#include "torch/torch.h"
7+
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
78

89
#include "Utils.h"
9-
#include "THSJIT.h" // For TensorOrScalar struct
1010

11-
// API for torch.export ExportedProgram
11+
// torch.export support via AOTInductor - Load and execute PyTorch ExportedProgram models (.pt2 files)
12+
// ExportedProgram is PyTorch 2.x's recommended way to export models for production deployment
13+
//
14+
// IMPORTANT: This implementation uses torch::inductor::AOTIModelPackageLoader which is
15+
// INFERENCE-ONLY. Training, parameter updates, and device movement are not supported.
16+
// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
1217

13-
// Load ExportedProgram from .pt2 file
14-
EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename, int64_t device, int64_t index);
18+
// Load an AOTInductor-compiled model package from a .pt2 file
19+
EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename);
1520

16-
// Dispose ExportedProgram module
21+
// Dispose of an ExportedProgram module
1722
EXPORT_API(void) THSExport_Module_dispose(const ExportedProgramModule module);
1823

19-
// Execute forward pass on ExportedProgram
20-
EXPORT_API(void) THSExport_Module_forward(
24+
// Execute the ExportedProgram's forward method (inference only)
25+
// Input: Array of tensors
26+
// Output: Array of result tensors (caller must free)
27+
EXPORT_API(void) THSExport_Module_run(
2128
const ExportedProgramModule module,
22-
const TensorOrScalar* tensorPtrs,
23-
const int length,
24-
TensorOrScalar* (*allocator)(int32_t idx, size_t length),
25-
int8_t* typeCode,
26-
int32_t idx);
27-
28-
// Device and dtype management
29-
EXPORT_API(void) THSExport_Module_to_device_dtype(ExportedProgramModule module, int8_t dtype, int64_t device, int64_t index);
30-
EXPORT_API(void) THSExport_Module_to_device(ExportedProgramModule module, int64_t device, int64_t index);
31-
EXPORT_API(void) THSExport_Module_to_dtype(ExportedProgramModule module, int8_t dtype);
32-
33-
// Training mode (ExportedPrograms are always in eval mode, but we provide these for compatibility)
34-
EXPORT_API(int) THSExport_Module_is_training(ExportedProgramModule module);
35-
EXPORT_API(void) THSExport_Module_train(ExportedProgramModule module, bool on);
36-
EXPORT_API(void) THSExport_Module_eval(ExportedProgramModule module);
37-
38-
// Parameters and buffers access
39-
EXPORT_API(void) THSExport_Module_parameters(const ExportedProgramModule module, Tensor* (*allocator)(size_t length));
40-
EXPORT_API(void) THSExport_Module_named_parameters(
41-
const ExportedProgramModule module,
42-
Tensor* (*allocator)(size_t length),
43-
const char** (*allocator2)(size_t length));
44-
45-
EXPORT_API(void) THSExport_Module_named_buffers(
46-
const ExportedProgramModule module,
47-
Tensor* (*allocator)(size_t length),
48-
const char** (*allocator2)(size_t length));
29+
const Tensor* input_tensors,
30+
const int input_length,
31+
Tensor** result_tensors,
32+
int* result_length);

src/Native/LibTorchSharp/Utils.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <string>
55

66
#include "torch/torch.h"
7+
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
78

89
extern thread_local char *torch_last_err;
910

@@ -24,9 +25,9 @@ typedef std::shared_ptr<torch::jit::Function> * JITFunction;
2425
typedef std::shared_ptr<c10::Type> * JITType;
2526
typedef std::shared_ptr<c10::TensorType>* JITTensorType;
2627

27-
// torch.export ExportedProgram module
28-
// Note: In LibTorch C++ API, ExportedProgram is also represented as torch::jit::Module
29-
typedef std::shared_ptr<torch::jit::Module>* ExportedProgramModule;
28+
// torch.export ExportedProgram module via AOTInductor
29+
// Note: Uses torch::inductor::AOTIModelPackageLoader for inference-only execution
30+
typedef torch::inductor::AOTIModelPackageLoader* ExportedProgramModule;
3031

3132
struct TensorArray {
3233
Tensor *array;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
6D6AF87CAB301FA25CB4909697A03C65ED234E784CD96C8743A9AD6586238D0E

0 commit comments

Comments
 (0)