diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 596ec9cb7..aec8a6e4e 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -24,6 +24,7 @@ __Bug Fixes__: __API Changes__: +#1498 Add support for torch.export ExportedProgram models (.pt2 files)
#1503 Add ReadOnlySpan overloads to many methods.
#1478 Fix `torch.jit.ScriptModule.zero_grad`.
#1495 Make `torchvision.io.read_image` and `torchvision.io.read_image_async` allow subsequent opening of the file for reading.
diff --git a/src/Native/LibTorchSharp/CMakeLists.txt b/src/Native/LibTorchSharp/CMakeLists.txt index 60b61f049..8e5e1e38a 100644 --- a/src/Native/LibTorchSharp/CMakeLists.txt +++ b/src/Native/LibTorchSharp/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES crc32c.h THSAutograd.h THSData.h + THSExport.h THSJIT.h THSNN.h THSStorage.h @@ -23,6 +24,7 @@ set(SOURCES THSActivation.cpp THSAutograd.cpp THSData.cpp + THSExport.cpp THSFFT.cpp THSJIT.cpp THSLinearAlgebra.cpp diff --git a/src/Native/LibTorchSharp/THSExport.cpp b/src/Native/LibTorchSharp/THSExport.cpp new file mode 100644 index 000000000..06c4b8a30 --- /dev/null +++ b/src/Native/LibTorchSharp/THSExport.cpp @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +#include "THSExport.h" + +// torch.export support via AOTInductor +// This uses torch::inductor::AOTIModelPackageLoader which is INFERENCE-ONLY +// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python + +ExportedProgramModule THSExport_load(const char* filename) +{ + CATCH( + // Load .pt2 file using AOTIModelPackageLoader + // This requires models to be compiled with aoti_compile_and_package() + auto* loader = new torch::inductor::AOTIModelPackageLoader(filename); + return loader; + ); + + return nullptr; +} + +void THSExport_Module_dispose(const ExportedProgramModule module) +{ + delete module; +} + +void THSExport_Module_run( + const ExportedProgramModule module, + const Tensor* input_tensors, + const int input_length, + Tensor** result_tensors, + int64_t* result_length) +{ + *result_tensors = nullptr; + *result_length = 0; + + CATCH( + // Convert input tensor pointers to std::vector<torch::Tensor> + std::vector<torch::Tensor> inputs; + inputs.reserve(input_length); + for (int i = 0; i < input_length; i++) { + inputs.push_back(*input_tensors[i]); + } + + // Run inference + std::vector<torch::Tensor> outputs = module->run(inputs); + + // Allocate output array and copy results + auto 
count = outputs.size(); + auto* tensors = new Tensor[count]; + + for (size_t i = 0; i < count; i++) { + tensors[i] = new torch::Tensor(outputs[i]); + } + + // Only expose to caller after full success + *result_tensors = tensors; + *result_length = static_cast<int64_t>(count); + ); +} + +void THSExport_Module_run_free_results(Tensor* result_tensors) +{ + delete[] result_tensors; +} diff --git a/src/Native/LibTorchSharp/THSExport.h b/src/Native/LibTorchSharp/THSExport.h new file mode 100644 index 000000000..ed1758674 --- /dev/null +++ b/src/Native/LibTorchSharp/THSExport.h @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +#pragma once + +#include "../Stdafx.h" + +#include "torch/torch.h" +#include "torch/csrc/inductor/aoti_package/model_package_loader.h" + +#include "Utils.h" + +// torch.export ExportedProgram module via AOTInductor +// Note: Uses torch::inductor::AOTIModelPackageLoader for inference-only execution +typedef torch::inductor::AOTIModelPackageLoader* ExportedProgramModule; + +// torch.export support via AOTInductor - Load and execute PyTorch ExportedProgram models (.pt2 files) +// ExportedProgram is PyTorch 2.x's recommended way to export models for production deployment +// +// IMPORTANT: This implementation uses torch::inductor::AOTIModelPackageLoader which is +// INFERENCE-ONLY. Training, parameter updates, and device movement are not supported. +// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python. 
+ +// Load an AOTInductor-compiled model package from a .pt2 file +EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename); + +// Dispose of an ExportedProgram module +EXPORT_API(void) THSExport_Module_dispose(const ExportedProgramModule module); + +// Execute the ExportedProgram's forward method (inference only) +// Input: Array of tensors +// Output: Array of result tensors (caller must free with THSExport_Module_run_free_results) +EXPORT_API(void) THSExport_Module_run( + const ExportedProgramModule module, + const Tensor* input_tensors, + const int input_length, + Tensor** result_tensors, + int64_t* result_length); + +// Free the result tensor array allocated by THSExport_Module_run +EXPORT_API(void) THSExport_Module_run_free_results(Tensor* result_tensors); diff --git a/src/TorchSharp/Export/ExportedProgram.cs b/src/TorchSharp/Export/ExportedProgram.cs new file mode 100644 index 000000000..b016c16b8 --- /dev/null +++ b/src/TorchSharp/Export/ExportedProgram.cs @@ -0,0 +1,238 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using static TorchSharp.PInvoke.NativeMethods; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class export + { + /// + /// Load a PyTorch ExportedProgram from a .pt2 file compiled with AOTInductor. + /// + /// Path to the .pt2 file + /// ExportedProgram model for inference + /// + /// IMPORTANT: The .pt2 file must be compiled with torch._inductor.aoti_compile_and_package() in Python. + /// Models saved with torch.export.save() alone will NOT work - they require AOTInductor compilation. + /// + /// This implementation is INFERENCE-ONLY. Training, parameter updates, and device movement + /// are not supported. The model is compiled for a specific device (CPU/CUDA) at compile time. 
+ /// + /// Example Python code to create compatible .pt2 files: + /// + /// import torch + /// import torch._inductor + /// + /// # Export the model + /// exported = torch.export.export(model, example_inputs) + /// + /// # Compile with AOTInductor (required for C++ loading) + /// torch._inductor.aoti_compile_and_package( + /// exported, + /// package_path="model.pt2" + /// ) + /// + /// + public static ExportedProgram load(string filename) + { + return new ExportedProgram(filename); + } + + /// + /// Load a PyTorch ExportedProgram with typed output. + /// + public static ExportedProgram<TResult> load<TResult>(string filename) + { + return new ExportedProgram<TResult>(filename); + } + } + } + + /// + /// Represents a PyTorch ExportedProgram loaded from an AOTInductor-compiled .pt2 file. + /// This is an INFERENCE-ONLY implementation - training and parameter updates are not supported. + /// + /// + /// Unlike TorchScript models, ExportedProgram models are ahead-of-time (AOT) compiled for + /// a specific device and are optimized for inference performance. They provide 30-40% better + /// latency compared to TorchScript in many cases. + /// + /// Key limitations: + /// - Inference only (no training, no gradients) + /// - No parameter access or updates + /// - No device movement (compiled for specific device) + /// - No dynamic model structure changes + /// + /// Use torch.jit for models that require training or dynamic behavior. + /// + public class ExportedProgram : IDisposable + { + private IntPtr handle; + private bool _disposed = false; + + internal ExportedProgram(string filename) + { + handle = THSExport_load(filename); + if (handle == IntPtr.Zero) + torch.CheckForErrors(); + } + + /// + /// Run inference on the model with the given input tensors. + /// + /// Input tensors for the model + /// Array of output tensors + /// + /// The number and shapes of inputs must match what the model was exported with. + /// All inputs must be on the same device that the model was compiled for. 
+ /// + public torch.Tensor[] run(params torch.Tensor[] inputs) + { + if (_disposed) + throw new ObjectDisposedException(nameof(ExportedProgram)); + + // Convert managed tensors to IntPtr array + IntPtr[] input_handles = new IntPtr[inputs.Length]; + for (int i = 0; i < inputs.Length; i++) + { + input_handles[i] = inputs[i].Handle; + } + + // Call native run method + THSExport_Module_run(handle, input_handles, inputs.Length, out IntPtr result_ptr, out long result_length); + torch.CheckForErrors(); + + // Marshal result array + if (result_length < 0 || result_length > int.MaxValue) + throw new InvalidOperationException( + $"Native export run returned an out-of-range result length: {result_length}."); + + int count = (int)result_length; + torch.Tensor[] results = new torch.Tensor[count]; + IntPtr[] result_handles = new IntPtr[count]; + + try + { + Marshal.Copy(result_ptr, result_handles, 0, count); + + for (int i = 0; i < count; i++) + { + results[i] = new torch.Tensor(result_handles[i]); + } + } + finally + { + // Free the native array (tensors are now owned by managed Tensor objects) + THSExport_Module_run_free_results(result_ptr); + } + + return results; + } + + /// + /// Synonym for run() - executes forward pass. + /// + public torch.Tensor[] forward(params torch.Tensor[] inputs) => run(inputs); + + /// + /// Synonym for run() - executes the model. + /// + public torch.Tensor[] call(params torch.Tensor[] inputs) => run(inputs); + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (handle != IntPtr.Zero) + { + THSExport_Module_dispose(handle); + handle = IntPtr.Zero; + } + _disposed = true; + } + } + + ~ExportedProgram() + { + Dispose(false); + } + } + + /// + /// Generic version of ExportedProgram with typed output. 
+ /// + /// The return type (Tensor, Tensor[], or tuple of Tensors) + public class ExportedProgram : ExportedProgram + { + internal ExportedProgram(string filename) : base(filename) + { + } + + /// + /// Run inference with typed return value. + /// + public new TResult run(params torch.Tensor[] inputs) + { + var results = base.run(inputs); + + // Handle different return types + if (typeof(TResult) == typeof(torch.Tensor)) + { + if (results.Length != 1) + throw new InvalidOperationException($"Expected 1 output tensor, got {results.Length}"); + return (TResult)(object)results[0]; + } + + if (typeof(TResult) == typeof(torch.Tensor[])) + { + return (TResult)(object)results; + } + + // Handle tuple types + if (typeof(TResult).IsGenericType) + { + var resultType = typeof(TResult); + var genericType = resultType.GetGenericTypeDefinition(); + + if (genericType == typeof(ValueTuple<,>)) + { + var typeArgs = resultType.GetGenericArguments(); + if (typeArgs[0] != typeof(torch.Tensor) || typeArgs[1] != typeof(torch.Tensor)) + throw new NotSupportedException( + $"Tuple return type {resultType} is not supported. Only ValueTuple is supported."); + + if (results.Length != 2) + throw new InvalidOperationException($"Expected 2 output tensors, got {results.Length}"); + return (TResult)Activator.CreateInstance(resultType, results[0], results[1]); + } + + if (genericType == typeof(ValueTuple<,,>)) + { + var typeArgs = resultType.GetGenericArguments(); + if (typeArgs[0] != typeof(torch.Tensor) || typeArgs[1] != typeof(torch.Tensor) || typeArgs[2] != typeof(torch.Tensor)) + throw new NotSupportedException( + $"Tuple return type {resultType} is not supported. 
Only ValueTuple<Tensor, Tensor, Tensor> is supported."); + + if (results.Length != 3) + throw new InvalidOperationException($"Expected 3 output tensors, got {results.Length}"); + return (TResult)Activator.CreateInstance(resultType, results[0], results[1], results[2]); + } + } + + throw new NotSupportedException($"Return type {typeof(TResult)} is not supported"); + } + + public new TResult forward(params torch.Tensor[] inputs) => run(inputs); + public new TResult call(params torch.Tensor[] inputs) => run(inputs); + } +} diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs new file mode 100644 index 000000000..061a69f74 --- /dev/null +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +#nullable enable +using System; +using System.Runtime.InteropServices; + +namespace TorchSharp.PInvoke +{ +#pragma warning disable CA2101 + internal static partial class NativeMethods + { + // torch.export support via AOTInductor (INFERENCE-ONLY) + // Models must be compiled with torch._inductor.aoti_compile_and_package() in Python + + // Load ExportedProgram from .pt2 file + [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)] + internal static extern IntPtr THSExport_load(string filename); + + // Dispose ExportedProgram module + [DllImport("LibTorchSharp")] + internal static extern void THSExport_Module_dispose(IntPtr handle); + + // Execute forward pass (inference only) + [DllImport("LibTorchSharp")] + internal static extern void THSExport_Module_run( + IntPtr module, + IntPtr[] input_tensors, + int input_length, + out IntPtr result_tensors, + out long result_length); + + // Free result tensor array allocated by THSExport_Module_run + [DllImport("LibTorchSharp")] + internal static extern void THSExport_Module_run_free_results(IntPtr 
result_tensors); + } +#pragma warning restore CA2101 +} diff --git a/test/TorchSharpTest/TestExport.cs b/test/TorchSharpTest/TestExport.cs new file mode 100644 index 000000000..891a0f995 --- /dev/null +++ b/test/TorchSharpTest/TestExport.cs @@ -0,0 +1,67 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.IO; +using System.Runtime.InteropServices; +using Xunit; + +#nullable enable + +namespace TorchSharp +{ + [Collection("Sequential")] + public class TestExport + { + [Fact] + public void TestExport_LoadNonExistentFile() + { + Assert.Throws(() => + torch.export.load("nonexistent.pt2")); + } + + [Fact] + public void TestExport_LoadInvalidFile() + { + var tmpFile = Path.GetTempFileName(); + try + { + File.WriteAllBytes(tmpFile, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }); + Assert.ThrowsAny(() => + torch.export.load(tmpFile)); + } + finally + { + File.Delete(tmpFile); + } + } + + [Fact] + public void TestExport_LoadEmptyPath() + { + Assert.ThrowsAny(() => + torch.export.load("")); + } + + [Fact] + public void TestExport_GenericLoadNonExistentFile() + { + Assert.Throws(() => + torch.export.load("nonexistent.pt2")); + } + + [Fact] + public void TestExport_GenericLoadInvalidFile() + { + var tmpFile = Path.GetTempFileName(); + try + { + File.WriteAllBytes(tmpFile, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }); + Assert.ThrowsAny(() => + torch.export.load(tmpFile)); + } + finally + { + File.Delete(tmpFile); + } + } + } +}