diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 047d76a27..8f1609890 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -1,6 +1,15 @@ ## TorchSharp Release Notes Releases, starting with 9/2/2021, are listed with the most recent release at the top. +# NuGet Version 0.106.0 (Upcoming) + +This release upgrades the libtorch backend to v2.9.0. + +__API Changes__: + +#1498 Add support for torch.export ExportedProgram models (.pt2 files)
+TorchSharp now supports loading and executing PyTorch models exported via torch.export and compiled with AOTInductor. Use `torch.export.load()` to load `.pt2` model packages created with `torch._inductor.aoti_compile_and_package()` in Python. In many cases this yields 30-40% lower inference latency than TorchScript models. Note that this is an inference-only API; training is not supported.
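+A minimal usage sketch (the file name and tensor shapes below are illustrative; the model must have been compiled with AOTInductor):
+
+```csharp
+// Load a .pt2 package produced by torch._inductor.aoti_compile_and_package()
+using var model = torch.export.load("model.pt2");
+
+// Run inference; inputs must match the shapes and device the model was compiled for
+var input = torch.rand(10);
+var outputs = model.run(input);   // returns torch.Tensor[]
+
+// Alternatively, load with a typed result when the output arity is known
+using var typed = torch.export.load<torch.Tensor>("model.pt2");
+var y = typed.call(input);
+```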
+ # NuGet Version 0.105.2 This release upgrades the libtorch backend to v2.7.1, using CUDA 12.8. diff --git a/build/Dependencies.props b/build/Dependencies.props index 6d3d32065..e08a88807 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -7,7 +7,7 @@ - 2.7.1 + 2.9.0 2.2.2 12.8 128 diff --git a/src/Native/LibTorchSharp/CMakeLists.txt b/src/Native/LibTorchSharp/CMakeLists.txt index 60b61f049..8e5e1e38a 100644 --- a/src/Native/LibTorchSharp/CMakeLists.txt +++ b/src/Native/LibTorchSharp/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES crc32c.h THSAutograd.h THSData.h + THSExport.h THSJIT.h THSNN.h THSStorage.h @@ -23,6 +24,7 @@ set(SOURCES THSActivation.cpp THSAutograd.cpp THSData.cpp + THSExport.cpp THSFFT.cpp THSJIT.cpp THSLinearAlgebra.cpp diff --git a/src/Native/LibTorchSharp/THSExport.cpp b/src/Native/LibTorchSharp/THSExport.cpp new file mode 100644 index 000000000..b777ddfed --- /dev/null +++ b/src/Native/LibTorchSharp/THSExport.cpp @@ -0,0 +1,51 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +#include "THSExport.h" + +// torch.export support via AOTInductor +// This uses torch::inductor::AOTIModelPackageLoader which is INFERENCE-ONLY +// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python + +ExportedProgramModule THSExport_load(const char* filename) +{ + CATCH( + // Load .pt2 file using AOTIModelPackageLoader + // This requires models to be compiled with aoti_compile_and_package() + auto* loader = new torch::inductor::AOTIModelPackageLoader(filename); + return loader; + ); + + return nullptr; +} + +void THSExport_Module_dispose(const ExportedProgramModule module) +{ + delete module; +} + +void THSExport_Module_run( + const ExportedProgramModule module, + const Tensor* input_tensors, + const int input_length, + Tensor** result_tensors, + int* result_length) +{ + CATCH( + // Convert input tensor pointers to a std::vector<torch::Tensor> + std::vector<torch::Tensor> inputs; + inputs.reserve(input_length); + for (int i = 0; i < input_length; i++) { + inputs.push_back(*input_tensors[i]); + } + + // Run inference + std::vector<torch::Tensor> outputs = module->run(inputs); + + // Allocate the output array with malloc so the managed caller can release it + // with Marshal.FreeHGlobal (which maps to free() on .NET 6+) + *result_length = (int)outputs.size(); + *result_tensors = (Tensor*)malloc(outputs.size() * sizeof(Tensor)); + + for (size_t i = 0; i < outputs.size(); i++) { + (*result_tensors)[i] = new torch::Tensor(outputs[i]); + } + ); +} diff --git a/src/Native/LibTorchSharp/THSExport.h b/src/Native/LibTorchSharp/THSExport.h new file mode 100644 index 000000000..1525c3fb2 --- /dev/null +++ b/src/Native/LibTorchSharp/THSExport.h @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +#pragma once + +#include "../Stdafx.h" + +#include "torch/torch.h" +#include "torch/csrc/inductor/aoti_package/model_package_loader.h" + +#include "Utils.h" + +// torch.export support via AOTInductor - Load and execute PyTorch ExportedProgram models (.pt2 files) +// ExportedProgram is PyTorch 2.x's recommended way to export models for production deployment +// +// IMPORTANT: This implementation uses torch::inductor::AOTIModelPackageLoader which is +// INFERENCE-ONLY. Training, parameter updates, and device movement are not supported. +// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
+ +// Load an AOTInductor-compiled model package from a .pt2 file +EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename); + +// Dispose of an ExportedProgram module +EXPORT_API(void) THSExport_Module_dispose(const ExportedProgramModule module); + +// Execute the ExportedProgram's forward method (inference only) +// Input: Array of tensors +// Output: Array of result tensors (caller must free) +EXPORT_API(void) THSExport_Module_run( + const ExportedProgramModule module, + const Tensor* input_tensors, + const int input_length, + Tensor** result_tensors, + int* result_length); diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h index 81e6d51ad..a6d14b360 100644 --- a/src/Native/LibTorchSharp/THSJIT.h +++ b/src/Native/LibTorchSharp/THSJIT.h @@ -98,3 +98,7 @@ EXPORT_API(TensorOrScalar*) THSJIT_AllocateTensorOrScalarArray(int32_t size); EXPORT_API(void) THSJIT_FreeTensorOrScalarArray(TensorOrScalar* ptr); EXPORT_API(void) THSJIT_SetTensorOrScalar(TensorOrScalar* array, int32_t index, int64_t type_code, int64_t array_index, ptrdiff_t handle); EXPORT_API(TensorOrScalar*) THSJIT_GetTensorOrScalar(TensorOrScalar* array, int32_t index); + +// Helper functions (shared with THSExport) +std::vector<c10::IValue> toIValue(const TensorOrScalar* tensorPtrs, const int length); +TensorOrScalar* ReturnHelper(c10::IValue result, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t* idx); diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h index 4c3606491..05dbb7a70 100644 --- a/src/Native/LibTorchSharp/Utils.h +++ b/src/Native/LibTorchSharp/Utils.h @@ -4,6 +4,7 @@ #include <string> #include "torch/torch.h" +#include "torch/csrc/inductor/aoti_package/model_package_loader.h" extern thread_local char *torch_last_err; @@ -24,6 +25,10 @@ typedef std::shared_ptr<torch::jit::Function> * JITFunction; typedef std::shared_ptr<c10::Type> * JITType; typedef std::shared_ptr<c10::TensorType>* JITTensorType; +// torch.export ExportedProgram module via AOTInductor +// Note: Uses torch::inductor::AOTIModelPackageLoader for inference-only execution +typedef torch::inductor::AOTIModelPackageLoader* ExportedProgramModule; + struct TensorArray { Tensor *array; int64_t size; diff --git a/src/Redist/libtorch-cpu/libtorch-macos-arm64-2.9.0.zip.sha b/src/Redist/libtorch-cpu/libtorch-macos-arm64-2.9.0.zip.sha new file mode 100644 index 000000000..b46e2ab91 --- /dev/null +++ b/src/Redist/libtorch-cpu/libtorch-macos-arm64-2.9.0.zip.sha @@ -0,0 +1 @@ +6D6AF87CAB301FA25CB4909697A03C65ED234E784CD96C8743A9AD6586238D0E diff --git a/src/Redist/libtorch-cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip.sha b/src/Redist/libtorch-cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip.sha new file mode 100644 index 000000000..84c029f65 --- /dev/null +++ b/src/Redist/libtorch-cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip.sha @@ -0,0 +1 @@ +22DE42ABDE933BE46CE843467930BD0190B72271BFA2C11F84DB95591A9834F1 diff --git a/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip.sha b/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip.sha new file mode 100644 index 000000000..03e5ce9f9 --- /dev/null +++ b/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip.sha @@ -0,0 +1 @@ +C826069DA829550BD3F1205159F8A95EE906A447DD141D08F42C568D4EE9E05E diff --git a/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-debug-2.9.0%2Bcpu.zip.sha b/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-debug-2.9.0%2Bcpu.zip.sha new file mode 100644 index 000000000..ab3845733 --- /dev/null +++
b/src/Redist/libtorch-cpu/libtorch-win-shared-with-deps-debug-2.9.0%2Bcpu.zip.sha @@ -0,0 +1 @@ +0892B92717B2396FE7ED62BE9AA6B78074C48BBB34D239F96FCCC70BE4560098 diff --git a/src/TorchSharp/Export/ExportedProgram.cs b/src/TorchSharp/Export/ExportedProgram.cs new file mode 100644 index 000000000..1dde69902 --- /dev/null +++ b/src/TorchSharp/Export/ExportedProgram.cs @@ -0,0 +1,215 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using TorchSharp.PInvoke; +using static TorchSharp.PInvoke.NativeMethods; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class export + { + /// <summary> + /// Load a PyTorch ExportedProgram from a .pt2 file compiled with AOTInductor. + /// </summary> + /// <param name="filename">Path to the .pt2 file</param> + /// <returns>ExportedProgram model for inference</returns> + /// <remarks> + /// IMPORTANT: The .pt2 file must be compiled with torch._inductor.aoti_compile_and_package() in Python. + /// Models saved with torch.export.save() alone will NOT work - they require AOTInductor compilation. + /// + /// This implementation is INFERENCE-ONLY. Training, parameter updates, and device movement + /// are not supported. The model is compiled for a specific device (CPU/CUDA) at compile time. + /// + /// Example Python code to create compatible .pt2 files: + /// <code> + /// import torch + /// import torch._inductor + /// + /// # Export the model + /// exported = torch.export.export(model, example_inputs) + /// + /// # Compile with AOTInductor (required for C++ loading) + /// torch._inductor.aoti_compile_and_package( + /// exported, + /// package_path="model.pt2" + /// ) + /// </code> + /// </remarks> + public static ExportedProgram load(string filename) + { + return new ExportedProgram(filename); + } + + /// <summary> + /// Load a PyTorch ExportedProgram with typed output. + /// </summary> + public static ExportedProgram<TResult> load<TResult>(string filename) + { + return new ExportedProgram<TResult>(filename); + } + } + } + + /// <summary> + /// Represents a PyTorch ExportedProgram loaded from an AOTInductor-compiled .pt2 file. + /// This is an INFERENCE-ONLY implementation - training and parameter updates are not supported. + /// </summary> + /// <remarks> + /// Unlike TorchScript models, ExportedProgram models are ahead-of-time (AOT) compiled for + /// a specific device and are optimized for inference performance. They provide 30-40% better + /// latency compared to TorchScript in many cases. + /// + /// Key limitations: + /// - Inference only (no training, no gradients) + /// - No parameter access or updates + /// - No device movement (compiled for specific device) + /// - No dynamic model structure changes + /// + /// Use torch.jit for models that require training or dynamic behavior. + /// </remarks> + public class ExportedProgram : IDisposable + { + private IntPtr handle; + private bool _disposed = false; + + internal ExportedProgram(string filename) + { + handle = THSExport_load(filename); + if (handle == IntPtr.Zero) + torch.CheckForErrors(); + } + + /// <summary> + /// Run inference on the model with the given input tensors. + /// </summary> + /// <param name="inputs">Input tensors for the model</param> + /// <returns>Array of output tensors</returns> + /// <remarks> + /// The number and shapes of inputs must match what the model was exported with. + /// All inputs must be on the same device that the model was compiled for.
/// </remarks> + public torch.Tensor[] run(params torch.Tensor[] inputs) + { + if (_disposed) + throw new ObjectDisposedException(nameof(ExportedProgram)); + + // Convert managed tensors to IntPtr array + IntPtr[] input_handles = new IntPtr[inputs.Length]; + for (int i = 0; i < inputs.Length; i++) + { + input_handles[i] = inputs[i].Handle; + } + + // Call native run method + THSExport_Module_run(handle, input_handles, inputs.Length, out IntPtr result_ptr, out int result_length); + torch.CheckForErrors(); + + // Marshal result array + torch.Tensor[] results = new torch.Tensor[result_length]; + IntPtr[] result_handles = new IntPtr[result_length]; + Marshal.Copy(result_ptr, result_handles, 0, result_length); + + for (int i = 0; i < result_length; i++) + { + results[i] = new torch.Tensor(result_handles[i]); + } + + // Free the malloc-allocated native array (the tensors are now owned by the managed Tensor objects) + Marshal.FreeHGlobal(result_ptr); + + return results; + } + + /// <summary> + /// Synonym for run() - executes forward pass. + /// </summary> + public torch.Tensor[] forward(params torch.Tensor[] inputs) => run(inputs); + + /// <summary> + /// Synonym for run() - executes the model. + /// </summary> + public torch.Tensor[] call(params torch.Tensor[] inputs) => run(inputs); + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (handle != IntPtr.Zero) + { + THSExport_Module_dispose(handle); + handle = IntPtr.Zero; + } + _disposed = true; + } + } + + ~ExportedProgram() + { + Dispose(false); + } + } + + /// <summary> + /// Generic version of ExportedProgram with typed output. + /// </summary> + /// <typeparam name="TResult">The return type (Tensor, Tensor[], or tuple of Tensors)</typeparam> + public class ExportedProgram<TResult> : ExportedProgram + { + internal ExportedProgram(string filename) : base(filename) + { + } + + /// <summary> + /// Run inference with typed return value. + /// </summary> + public new TResult run(params torch.Tensor[] inputs) + { + var results = base.run(inputs); + + // Handle different return types + if (typeof(TResult) == typeof(torch.Tensor)) + { + if (results.Length != 1) + throw new InvalidOperationException($"Expected 1 output tensor, got {results.Length}"); + return (TResult)(object)results[0]; + } + + if (typeof(TResult) == typeof(torch.Tensor[])) + { + return (TResult)(object)results; + } + + // Handle tuple types + if (typeof(TResult).IsGenericType) + { + var genericType = typeof(TResult).GetGenericTypeDefinition(); + if (genericType == typeof(ValueTuple<,>)) + { + if (results.Length != 2) + throw new InvalidOperationException($"Expected 2 output tensors, got {results.Length}"); + return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1]); + } + if (genericType == typeof(ValueTuple<,,>)) + { + if (results.Length != 3) + throw new InvalidOperationException($"Expected 3 output tensors, got {results.Length}"); + return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1], results[2]); + } + } + + throw new NotSupportedException($"Return type {typeof(TResult)} is not supported"); + } + + public new TResult forward(params torch.Tensor[] inputs) => run(inputs); + public new TResult call(params torch.Tensor[] inputs) => run(inputs); + } +} diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs new file mode 100644 index 000000000..388a22efe --- /dev/null +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved.
See LICENSE in the project root for license information. +#nullable enable +using System; +using System.Runtime.InteropServices; + +namespace TorchSharp.PInvoke +{ +#pragma warning disable CA2101 + internal static partial class NativeMethods + { + // torch.export support via AOTInductor (INFERENCE-ONLY) + // Models must be compiled with torch._inductor.aoti_compile_and_package() in Python + + // Load ExportedProgram from .pt2 file + [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)] + internal static extern IntPtr THSExport_load(string filename); + + // Dispose ExportedProgram module + [DllImport("LibTorchSharp")] + internal static extern void THSExport_Module_dispose(IntPtr handle); + + // Execute forward pass (inference only) + [DllImport("LibTorchSharp")] + internal static extern void THSExport_Module_run( + IntPtr module, + IntPtr[] input_tensors, + int input_length, + out IntPtr result_tensors, + out int result_length); + } +#pragma warning restore CA2101 +} diff --git a/test/TorchSharpTest/TestExport.cs b/test/TorchSharpTest/TestExport.cs new file mode 100644 index 000000000..fd24e807a --- /dev/null +++ b/test/TorchSharpTest/TestExport.cs @@ -0,0 +1,130 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.IO; +using System.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using Xunit; + +#nullable enable + +namespace TorchSharp +{ + [Collection("Sequential")] + public class TestExport + { + [Fact] + public void TestLoadExport_SimpleLinear() + { + // Test loading a simple linear model (inference-only) + using var exported = torch.export.load(@"simple_linear.export.pt2"); + Assert.NotNull(exported); + + var input = torch.ones(10); + var results = exported.run(input); + + Assert.NotNull(results); + Assert.Single(results); + Assert.Equal(new long[] { 5 }, results[0].shape); + Assert.Equal(torch.float32, results[0].dtype); + } + + [Fact] + public void TestLoadExport_LinearReLU() + { + // Test loading a Linear + ReLU model with typed output + using var exported = torch.export.load<Tensor>(@"linrelu.export.pt2"); + Assert.NotNull(exported); + + var input = torch.ones(10); + var result = exported.call(input); + + Assert.Equal(new long[] { 6 }, result.shape); + Assert.Equal(torch.float32, result.dtype); + + // ReLU should zero out negative values + Assert.True(result.data<float>().All(v => v >= 0)); + } + + [Fact] + public void TestLoadExport_TwoInputs() + { + // Test loading a model with two inputs + using var exported = torch.export.load(@"two_inputs.export.pt2"); + Assert.NotNull(exported); + + var input1 = torch.ones(10); + var input2 = torch.ones(10) * 2; + var results = exported.forward(input1, input2); + + Assert.NotNull(results); + Assert.Single(results); + Assert.Equal(new long[] { 10 }, results[0].shape); + + // Should be input1 + input2 = 1 + 2 = 3 + var expected = torch.ones(10) * 3; + Assert.True(expected.allclose(results[0])); + } + + [Fact] + public void TestLoadExport_TupleOutput() + { + // Test loading a model that returns a tuple + using var exported = torch.export.load<(Tensor, Tensor)>(@"tuple_out.export.pt2"); + Assert.NotNull(exported); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + var result = exported.call(x, y); + + Assert.IsType<(Tensor, Tensor)>(result); + var (sum, diff) = result; + + Assert.Equal(x.shape, sum.shape); + Assert.Equal(x.shape, diff.shape); + Assert.True((x + y).allclose(sum)); +
Assert.True((x - y).allclose(diff)); + } + + [Fact] + public void TestLoadExport_ListOutput() + { + // Test loading a model that returns a list + using var exported = torch.export.load<Tensor[]>(@"list_out.export.pt2"); + Assert.NotNull(exported); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + var result = exported.forward(x, y); + + Assert.IsType<Tensor[]>(result); + Assert.Equal(2, result.Length); + + Assert.True((x + y).allclose(result[0])); + Assert.True((x - y).allclose(result[1])); + } + + [Fact] + public void TestLoadExport_Sequential() + { + // Test loading a sequential model + using var exported = torch.export.load<Tensor>(@"sequential.export.pt2"); + Assert.NotNull(exported); + + var input = torch.ones(1000); + var result = exported.call(input); + + Assert.Equal(new long[] { 10 }, result.shape); + Assert.Equal(torch.float32, result.dtype); + } + + + [Fact] + public void TestExport_LoadNonExistentFile() + { + // Test error handling for non-existent file + Assert.Throws<System.Runtime.InteropServices.ExternalException>(() => + torch.export.load(@"nonexistent.pt2")); + } + } +} diff --git a/test/TorchSharpTest/TorchSharpTest.csproj b/test/TorchSharpTest/TorchSharpTest.csproj index 2de45fe06..85669483e 100644 --- a/test/TorchSharpTest/TorchSharpTest.csproj +++ b/test/TorchSharpTest/TorchSharpTest.csproj @@ -94,6 +94,24 @@ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + <None Include="simple_linear.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> + <None Include="linrelu.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> + <None Include="two_inputs.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> + <None Include="tuple_out.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> + <None Include="list_out.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> + <None Include="sequential.export.pt2"> + <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> + </None> <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> diff --git a/test/TorchSharpTest/generate_export_models.py b/test/TorchSharpTest/generate_export_models.py new file mode 100644 index 000000000..5f35f0c7a --- /dev/null +++ b/test/TorchSharpTest/generate_export_models.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +Generate AOTInductor-compiled ExportedProgram models for TorchSharp testing. + +This script creates .pt2 files using torch._inductor.aoti_compile_and_package(), +which compiles models with AOTInductor for inference-only execution in C++. + +IMPORTANT: Models created with torch.export.save() alone cannot be loaded in LibTorch C++. +They must be compiled with aoti_compile_and_package() to create a loadable package.
+""" + +import torch +import torch.nn as nn +import torch._inductor +from pathlib import Path + + +class SimpleLinear(nn.Module): + """Simple linear layer: 10 -> 5""" + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + +class LinearReLU(nn.Module): + """Linear layer with ReLU: 10 -> 6""" + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 6) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + +class TwoInputs(nn.Module): + """Model that takes two inputs and adds them""" + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y + + +class TupleOutput(nn.Module): + """Model that returns a tuple of (sum, difference)""" + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y, x - y + + +class ListOutput(nn.Module): + """Model that returns a list of [sum, difference]""" + def __init__(self): + super().__init__() + + def forward(self, x, y): + return [x + y, x - y] + + +class Sequential(nn.Module): + """Sequential model: 1000 -> 100 -> 50 -> 10""" + def __init__(self): + super().__init__() + self.seq = nn.Sequential( + nn.Linear(1000, 100), + nn.ReLU(), + nn.Linear(100, 50), + nn.ReLU(), + nn.Linear(50, 10) + ) + + def forward(self, x): + return self.seq(x) + + +def compile_and_save(model, example_inputs, output_path): + """ + Export and compile a model with AOTInductor. + + Args: + model: PyTorch module to export + example_inputs: Tuple of example inputs for tracing + output_path: Path where to save the .pt2 file + """ + print(f"Compiling {output_path}...") + + # Set model to eval mode (inference only) + model.eval() + + # Export the model + with torch.no_grad(): + exported = torch.export.export(model, example_inputs) + + # Compile with AOTInductor and package + # This creates a .pt2 file that can be loaded in LibTorch C++ + torch._inductor.aoti_compile_and_package( + exported, + package_path=str(output_path) + ) + + print(f" āœ“ Created {output_path}") + + +def main(): + print("Generating AOTInductor-compiled ExportedProgram models...\n") + + # Get the directory where this script is located + script_dir = Path(__file__).parent + + # 1. Simple Linear (10 -> 5) + model1 = SimpleLinear() + compile_and_save( + model1, + (torch.ones(10),), + script_dir / "simple_linear.export.pt2" + ) + + # 2. Linear + ReLU (10 -> 6) + model2 = LinearReLU() + compile_and_save( + model2, + (torch.ones(10),), + script_dir / "linrelu.export.pt2" + ) + + # 3. Two Inputs (adds two tensors) + model3 = TwoInputs() + compile_and_save( + model3, + (torch.ones(10), torch.ones(10)), + script_dir / "two_inputs.export.pt2" + ) + + # 4. Tuple Output (returns sum and difference) + model4 = TupleOutput() + compile_and_save( + model4, + (torch.rand(3, 4), torch.rand(3, 4)), + script_dir / "tuple_out.export.pt2" + ) + + # 5. List Output (returns [sum, difference]) + model5 = ListOutput() + compile_and_save( + model5, + (torch.rand(3, 4), torch.rand(3, 4)), + script_dir / "list_out.export.pt2" + ) + + # 6. 
Sequential (1000 -> 100 -> 50 -> 10) + model6 = Sequential() + compile_and_save( + model6, + (torch.ones(1000),), + script_dir / "sequential.export.pt2" + ) + + print("\n✓ All models compiled successfully!") + print("\nThese models are now compatible with LibTorch C++ via") + print("torch::inductor::AOTIModelPackageLoader for inference-only execution.") + + +if __name__ == "__main__": + main() diff --git a/test/TorchSharpTest/linrelu.export.pt2 b/test/TorchSharpTest/linrelu.export.pt2 new file mode 100644 index 000000000..b2e20bc8a Binary files /dev/null and b/test/TorchSharpTest/linrelu.export.pt2 differ diff --git a/test/TorchSharpTest/list_out.export.pt2 b/test/TorchSharpTest/list_out.export.pt2 new file mode 100644 index 000000000..fb04aec66 Binary files /dev/null and b/test/TorchSharpTest/list_out.export.pt2 differ diff --git a/test/TorchSharpTest/sequential.export.pt2 b/test/TorchSharpTest/sequential.export.pt2 new file mode 100644 index 000000000..abee0bd13 Binary files /dev/null and b/test/TorchSharpTest/sequential.export.pt2 differ diff --git a/test/TorchSharpTest/simple_linear.export.pt2 b/test/TorchSharpTest/simple_linear.export.pt2 new file mode 100644 index 000000000..74587004b Binary files /dev/null and b/test/TorchSharpTest/simple_linear.export.pt2 differ diff --git a/test/TorchSharpTest/tuple_out.export.pt2 b/test/TorchSharpTest/tuple_out.export.pt2 new file mode 100644 index 000000000..355417d7f Binary files /dev/null and b/test/TorchSharpTest/tuple_out.export.pt2 differ diff --git a/test/TorchSharpTest/two_inputs.export.pt2 b/test/TorchSharpTest/two_inputs.export.pt2 new file mode 100644 index 000000000..230cd4c4b Binary files /dev/null and b/test/TorchSharpTest/two_inputs.export.pt2 differ