From 6a38cb0d3900e26ddc8bfc49d34a634c050fbd8b Mon Sep 17 00:00:00 2001 From: wineandchord Date: Thu, 12 Mar 2026 22:16:33 +0800 Subject: [PATCH 1/6] Add aten._trilinear support to torch_lib core --- .../function_libs/torch_lib/ops/core.py | 55 +++++++++++++++++++ tests/function_libs/torch_lib/extra_opinfo.py | 36 ++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 15 +++++ 3 files changed, 106 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 67de7076fa..c37b66c8e4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -56,6 +56,7 @@ _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi +_EINSUM_SYMBOLS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @torch_op("aten::_local_scalar_dense", trace_only=True) @@ -1192,6 +1193,60 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor: return op.CastLike(sampled, self) +def _get_einsum_symbol(dim: int) -> str: + if dim >= len(_EINSUM_SYMBOLS): + raise ValueError("aten::_trilinear only supports up to 52 dimensions") + return _EINSUM_SYMBOLS[dim] + + +def _build_trilinear_subscript(total_dim: int, expanded_dims: Sequence[int]) -> str: + expanded_dims_set = set(expanded_dims) + return "".join( + _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set + ) + + +def _build_trilinear_equation( + total_dim: int, + expand1: Sequence[int], + expand2: Sequence[int], + expand3: Sequence[int], + sumdim: Sequence[int], +) -> str: + sumdim_set = set(sumdim) + output_subscript = "".join( + _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set + ) + return ( + f"{_build_trilinear_subscript(total_dim, expand1)}," + f"{_build_trilinear_subscript(total_dim, expand2)}," + f"{_build_trilinear_subscript(total_dim, expand3)}->{output_subscript}" + ) + + +@torch_op("aten::_trilinear", trace_only=True) +def aten__trilinear( + i1: TReal, + i2: TReal, + i3: TReal, + expand1: Sequence[int], + expand2: Sequence[int], + expand3: Sequence[int], + sumdim: Sequence[int], + unroll_dim: int = 1, +) -> TReal: + """_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor""" + + del unroll_dim + + input_rank = getattr(i1, "rank", None) + if input_rank is None: + input_rank = len(i1.shape) + total_dim = input_rank + len(expand1) + equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim) + return op.Einsum(i1, i2, i3, equation=equation) + + @torch_op("aten::bilinear", trace_only=True) def aten_bilinear( input1: TensorType, diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index a28a6c9cd9..919e9e6942 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -68,6 +68,35 @@ def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(input1, args=(input2, weight, None)) +def sample_inputs__trilinear(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for aten._trilinear using bilinear's internal call pattern.""" + del op_info + del kwargs + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + cases = [ + (2, 3, 4, 5), + (1, 2, 2, 1), + (3, 5, 2, 4), + ] + expand1 = [1, 3] + expand2 = [0] + expand3 = [1, 2] + sumdim = [2, 3] + + for batch_size, in1_features, in2_features, out_features in cases: + input1 = make_arg((batch_size, in1_features)) + weight = make_arg((out_features, in1_features, in2_features)) + input2 = make_arg((batch_size, in2_features)) + yield opinfo_core.SampleInput( + input1, + args=(weight, input2, expand1, expand2, expand3, sumdim, 1), + ) + + def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ -2516,6 +2545,13 @@ def __init__(self): sample_inputs_func=sample_inputs_bilinear, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._trilinear.default", + aten_name="_trilinear.default", + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs__trilinear, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a40535f4ba..9242b0b7b3 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -270,6 +270,15 @@ def _einsum_input_wrangler( return [args[1], args[0]], kwargs +def _trilinear_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + for index in range(3, 7): + if isinstance(args[index], np.ndarray): + args[index] = args[index].tolist() + return args, kwargs + + def _embedding_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -601,6 +610,12 @@ def _where_input_wrangler( TorchLibOpInfo( "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} ), + TorchLibOpInfo( + "ops.aten._trilinear.default", + core_ops.aten__trilinear, + tolerance={torch.float32: (2e-5, 2e-5)}, + input_wrangler=_trilinear_input_wrangler, + ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with From 907a359a6e488dd98d396777918b2bba2f72d470 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Fri, 13 Mar 2026 00:13:34 +0800 Subject: [PATCH 2/6] Use string.ascii_letters for einsum symbols --- onnxscript/function_libs/torch_lib/ops/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index c37b66c8e4..8aed631ed5 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -12,6 +12,7 @@ from __future__ import annotations import math +import string from typing import Any, Optional, Sequence, Tuple, Union import numpy as np @@ -56,7 +57,7 @@ _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi -_EINSUM_SYMBOLS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +_EINSUM_SYMBOLS = string.ascii_letters @torch_op("aten::_local_scalar_dense", trace_only=True) From 4fe91ab8d9ddbad9396a114839bba9ae4c1a10a3 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Fri, 13 Mar 2026 10:10:23 +0800 Subject: [PATCH 3/6] Adjust aten._trilinear sample inputs --- tests/function_libs/torch_lib/extra_opinfo.py | 8 ++++---- tests/function_libs/torch_lib/ops_test_data.py | 10 ---------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 919e9e6942..6c3f50412a 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -82,10 +82,10 @@ def sample_inputs__trilinear(op_info, device, dtype, requires_grad, **kwargs): (1, 2, 2, 1), (3, 5, 2, 4), ] - expand1 = [1, 3] - expand2 = [0] - expand3 = [1, 2] - sumdim = [2, 3] + expand1 = (1, 3) + expand2 = (0,) + expand3 = (1, 2) + sumdim = (2, 3) for batch_size, in1_features, in2_features, out_features in cases: input1 = make_arg((batch_size, in1_features)) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 9242b0b7b3..5215d3e6b1 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -270,15 +270,6 @@ def _einsum_input_wrangler( return [args[1], args[0]], kwargs -def _trilinear_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - for index in range(3, 7): - if isinstance(args[index], np.ndarray): - args[index] = args[index].tolist() - return args, kwargs - - def _embedding_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -614,7 +605,6 @@ def _where_input_wrangler( "ops.aten._trilinear.default", core_ops.aten__trilinear, tolerance={torch.float32: (2e-5, 2e-5)}, - input_wrangler=_trilinear_input_wrangler, ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we From f232330f09c98f557a7a805f21f7444c0ce7ab89 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Sat, 14 Mar 2026 00:39:55 +0800 Subject: [PATCH 4/6] Reorder aten._trilinear definitions --- .../function_libs/torch_lib/ops/core.py | 108 +++++++++--------- .../function_libs/torch_lib/ops_test_data.py | 10 +- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8aed631ed5..6f979f8f42 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1194,60 +1194,6 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor: return op.CastLike(sampled, self) -def _get_einsum_symbol(dim: int) -> str: - if dim >= len(_EINSUM_SYMBOLS): - raise ValueError("aten::_trilinear only supports up to 52 dimensions") - return _EINSUM_SYMBOLS[dim] - - -def _build_trilinear_subscript(total_dim: int, expanded_dims: Sequence[int]) -> str: - expanded_dims_set = set(expanded_dims) - return "".join( - _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set - ) - - -def _build_trilinear_equation( - total_dim: int, - expand1: Sequence[int], - expand2: Sequence[int], - expand3: Sequence[int], - sumdim: Sequence[int], -) -> str: - sumdim_set = set(sumdim) - output_subscript = "".join( - _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set - ) - return ( - f"{_build_trilinear_subscript(total_dim, expand1)}," - f"{_build_trilinear_subscript(total_dim, expand2)}," - f"{_build_trilinear_subscript(total_dim, expand3)}->{output_subscript}" - ) - - -@torch_op("aten::_trilinear", trace_only=True) -def aten__trilinear( - i1: TReal, - i2: TReal, - i3: TReal, - expand1: Sequence[int], - expand2: Sequence[int], - expand3: Sequence[int], - sumdim: Sequence[int], - unroll_dim: int = 1, -) -> TReal: - """_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor""" - - del unroll_dim - - input_rank = getattr(i1, "rank", None) - if input_rank is None: - input_rank = len(i1.shape) - total_dim = input_rank + len(expand1) - equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim) - return op.Einsum(i1, i2, i3, equation=equation) - - @torch_op("aten::bilinear", trace_only=True) def aten_bilinear( input1: TensorType, @@ -9847,6 +9793,60 @@ def aten_tril_indices(row: int, col: int, offset: int = 0) -> TensorType: raise NotImplementedError() +def _get_einsum_symbol(dim: int) -> str: + if dim >= len(_EINSUM_SYMBOLS): + raise ValueError("aten::_trilinear only supports up to 52 dimensions") + return _EINSUM_SYMBOLS[dim] + + +def _build_trilinear_subscript(total_dim: int, expanded_dims: Sequence[int]) -> str: + expanded_dims_set = set(expanded_dims) + return "".join( + _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set + ) + + +def _build_trilinear_equation( + total_dim: int, + expand1: Sequence[int], + expand2: Sequence[int], + expand3: Sequence[int], + sumdim: Sequence[int], +) -> str: + sumdim_set = set(sumdim) + output_subscript = "".join( + _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set + ) + return ( + f"{_build_trilinear_subscript(total_dim, expand1)}," + f"{_build_trilinear_subscript(total_dim, expand2)}," + f"{_build_trilinear_subscript(total_dim, expand3)}->{output_subscript}" + ) + + +@torch_op("aten::_trilinear", trace_only=True) +def aten__trilinear( + i1: TReal, + i2: TReal, + i3: TReal, + expand1: Sequence[int], + expand2: Sequence[int], + expand3: Sequence[int], + sumdim: Sequence[int], + unroll_dim: int = 1, +) -> TReal: + """_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor""" + + del unroll_dim + + input_rank = getattr(i1, "rank", None) + if input_rank is None: + input_rank = len(i1.shape) + total_dim = input_rank + len(expand1) + equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim) + return op.Einsum(i1, i2, i3, equation=equation) + + def aten_triplet_margin_loss( anchor: TensorType, positive: TensorType, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 5215d3e6b1..16e365af5a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -601,11 +601,6 @@ def _where_input_wrangler( TorchLibOpInfo( "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} ), - TorchLibOpInfo( - "ops.aten._trilinear.default", - core_ops.aten__trilinear, - tolerance={torch.float32: (2e-5, 2e-5)}, - ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with @@ -1298,6 +1293,11 @@ def _where_input_wrangler( dtypes=(torch.int32,), reason="fixme: ORT does not have an implementation of Trilu for int32.", ), + TorchLibOpInfo( + "ops.aten._trilinear.default", + core_ops.aten__trilinear, + tolerance={torch.float32: (2e-5, 2e-5)}, + ), TorchLibOpInfo("triu", core_ops.aten_triu).xfail( dtypes=(torch.int32,), reason="fixme: ORT does not have an implementation of Trilu for int32.", From 2e064411df5e191bd5b78fcca4d2010e430e80f3 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Sat, 14 Mar 2026 07:51:10 +0800 Subject: [PATCH 5/6] Validate aten._trilinear dimension inputs --- .../function_libs/torch_lib/ops/core.py | 29 ++++++++++++-- tests/function_libs/torch_lib/ops_test.py | 40 +++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6f979f8f42..2880619d60 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9799,7 +9799,26 @@ def _get_einsum_symbol(dim: int) -> str: return _EINSUM_SYMBOLS[dim] -def _build_trilinear_subscript(total_dim: int, expanded_dims: Sequence[int]) -> str: +def _validate_trilinear_dims( + total_dim: int, dims: Sequence[int], dims_name: str +) -> None: + seen_dims = set() + for dim in dims: + if dim < 0 or dim >= total_dim: + raise ValueError( + f"aten::_trilinear {dims_name} values must be in [0, {total_dim})" + ) + if dim in seen_dims: + raise ValueError( + f"aten::_trilinear {dims_name} values must be unique" + ) + seen_dims.add(dim) + + +def _build_trilinear_subscript( + total_dim: int, expanded_dims: Sequence[int], dims_name: str +) -> str: + _validate_trilinear_dims(total_dim, expanded_dims, dims_name) expanded_dims_set = set(expanded_dims) return "".join( _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set @@ -9813,14 +9832,16 @@ def _build_trilinear_equation( expand3: Sequence[int], sumdim: Sequence[int], ) -> str: + _validate_trilinear_dims(total_dim, sumdim, "sumdim") sumdim_set = set(sumdim) output_subscript = "".join( _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set ) return ( - f"{_build_trilinear_subscript(total_dim, expand1)}," - f"{_build_trilinear_subscript(total_dim, expand2)}," - f"{_build_trilinear_subscript(total_dim, expand3)}->{output_subscript}" + f"{_build_trilinear_subscript(total_dim, expand1, 'expand1')}," + f"{_build_trilinear_subscript(total_dim, expand2, 'expand2')}," + f"{_build_trilinear_subscript(total_dim, expand3, 'expand3')}" + f"->{output_subscript}" ) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index beb74b5462..c825525ec0 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -40,6 +40,7 @@ import onnxscript from onnxscript._internal import version_utils +from onnxscript.function_libs.torch_lib.ops import core as core_ops from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -110,6 +111,45 @@ def test_script_function_passes_checker( onnx.checker.check_function(function_proto) # type: ignore[attr-defined] +class TestTrilinearHelpers(unittest.TestCase): + def test_build_trilinear_equation_returns_expected_equation(self) -> None: + equation = core_ops._build_trilinear_equation( + 4, + (1, 3), + (0,), + (1, 2), + (2, 3), + ) + + self.assertEqual(equation, "ac,bcd,ad->ab") + + def test_build_trilinear_equation_rejects_out_of_range_dims(self) -> None: + with self.assertRaisesRegex( + ValueError, + "aten::_trilinear expand1 values must be in", + ): + core_ops._build_trilinear_equation( + 4, + (4,), + (0,), + (1, 2), + (2, 3), + ) + + def test_build_trilinear_equation_rejects_duplicate_dims(self) -> None: + with self.assertRaisesRegex( + ValueError, + "aten::_trilinear sumdim values must be unique", + ): + core_ops._build_trilinear_equation( + 4, + (1, 3), + (0,), + (1, 2), + (2, 2), + ) + + def run_test_output_match( test_suite: unittest.TestCase, device: str, From c5d4cdac47e5632332b2d58e58e17d1d53106ad5 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Sat, 14 Mar 2026 08:50:23 +0800 Subject: [PATCH 6/6] Remove unnecessary trilinear helper tests --- tests/function_libs/torch_lib/ops_test.py | 40 ----------------------- 1 file changed, 40 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index c825525ec0..beb74b5462 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -40,7 +40,6 @@ import onnxscript from onnxscript._internal import version_utils -from onnxscript.function_libs.torch_lib.ops import core as core_ops from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -111,45 +110,6 @@ def test_script_function_passes_checker( onnx.checker.check_function(function_proto) # type: ignore[attr-defined] -class TestTrilinearHelpers(unittest.TestCase): - def test_build_trilinear_equation_returns_expected_equation(self) -> None: - equation = core_ops._build_trilinear_equation( - 4, - (1, 3), - (0,), - (1, 2), - (2, 3), - ) - - self.assertEqual(equation, "ac,bcd,ad->ab") - - def test_build_trilinear_equation_rejects_out_of_range_dims(self) -> None: - with self.assertRaisesRegex( - ValueError, - "aten::_trilinear expand1 values must be in", - ): - core_ops._build_trilinear_equation( - 4, - (4,), - (0,), - (1, 2), - (2, 3), - ) - - def test_build_trilinear_equation_rejects_duplicate_dims(self) -> None: - with self.assertRaisesRegex( - ValueError, - "aten::_trilinear sumdim values must be unique", - ): - core_ops._build_trilinear_equation( - 4, - (1, 3), - (0,), - (1, 2), - (2, 2), - ) - - def run_test_output_match( test_suite: unittest.TestCase, device: str,