diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 67de7076fa..2880619d60 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import math
+import string
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -56,6 +57,7 @@
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
+_EINSUM_SYMBOLS = string.ascii_letters
 
 
 @torch_op("aten::_local_scalar_dense", trace_only=True)
@@ -9791,6 +9793,81 @@
     raise NotImplementedError()
 
 
+def _get_einsum_symbol(dim: int) -> str:
+    if dim >= len(_EINSUM_SYMBOLS):
+        raise ValueError("aten::_trilinear only supports up to 52 dimensions")
+    return _EINSUM_SYMBOLS[dim]
+
+
+def _validate_trilinear_dims(
+    total_dim: int, dims: Sequence[int], dims_name: str
+) -> None:
+    seen_dims = set()
+    for dim in dims:
+        if dim < 0 or dim >= total_dim:
+            raise ValueError(
+                f"aten::_trilinear {dims_name} values must be in [0, {total_dim})"
+            )
+        if dim in seen_dims:
+            raise ValueError(
+                f"aten::_trilinear {dims_name} values must be unique"
+            )
+        seen_dims.add(dim)
+
+
+def _build_trilinear_subscript(
+    total_dim: int, expanded_dims: Sequence[int], dims_name: str
+) -> str:
+    _validate_trilinear_dims(total_dim, expanded_dims, dims_name)
+    expanded_dims_set = set(expanded_dims)
+    return "".join(
+        _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in expanded_dims_set
+    )
+
+
+def _build_trilinear_equation(
+    total_dim: int,
+    expand1: Sequence[int],
+    expand2: Sequence[int],
+    expand3: Sequence[int],
+    sumdim: Sequence[int],
+) -> str:
+    _validate_trilinear_dims(total_dim, sumdim, "sumdim")
+    sumdim_set = set(sumdim)
+    output_subscript = "".join(
+        _get_einsum_symbol(dim) for dim in range(total_dim) if dim not in sumdim_set
+    )
+    return (
+        f"{_build_trilinear_subscript(total_dim, expand1, 'expand1')},"
+        f"{_build_trilinear_subscript(total_dim, expand2, 'expand2')},"
+        f"{_build_trilinear_subscript(total_dim, expand3, 'expand3')}"
+        f"->{output_subscript}"
+    )
+
+
+@torch_op("aten::_trilinear", trace_only=True)
+def aten__trilinear(
+    i1: TReal,
+    i2: TReal,
+    i3: TReal,
+    expand1: Sequence[int],
+    expand2: Sequence[int],
+    expand3: Sequence[int],
+    sumdim: Sequence[int],
+    unroll_dim: int = 1,
+) -> TReal:
+    """_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor"""
+
+    del unroll_dim
+
+    input_rank = getattr(i1, "rank", None)
+    if input_rank is None:
+        input_rank = len(i1.shape)
+    total_dim = input_rank + len(expand1)
+    equation = _build_trilinear_equation(total_dim, expand1, expand2, expand3, sumdim)
+    return op.Einsum(i1, i2, i3, equation=equation)
+
+
 def aten_triplet_margin_loss(
     anchor: TensorType,
     positive: TensorType,
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index a28a6c9cd9..6c3f50412a 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -68,6 +68,35 @@ def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
     yield opinfo_core.SampleInput(input1, args=(input2, weight, None))
 
 
+def sample_inputs__trilinear(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for aten._trilinear using bilinear's internal call pattern."""
+    del op_info
+    del kwargs
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    cases = [
+        (2, 3, 4, 5),
+        (1, 2, 2, 1),
+        (3, 5, 2, 4),
+    ]
+    expand1 = (1, 3)
+    expand2 = (0,)
+    expand3 = (1, 2)
+    sumdim = (2, 3)
+
+    for batch_size, in1_features, in2_features, out_features in cases:
+        input1 = make_arg((batch_size, in1_features))
+        weight = make_arg((out_features, in1_features, in2_features))
+        input2 = make_arg((batch_size, in2_features))
+        yield opinfo_core.SampleInput(
+            input1,
+            args=(weight, input2, expand1, expand2, expand3, sumdim, 1),
+        )
+
+
 def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
 
@@ -2516,6 +2545,13 @@ def __init__(self):
         sample_inputs_func=sample_inputs_bilinear,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._trilinear.default",
+        aten_name="_trilinear.default",
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs__trilinear,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p",
         aten_name="bernoulli.p",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index a40535f4ba..16e365af5a 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1293,6 +1293,11 @@ def _where_input_wrangler(
         dtypes=(torch.int32,),
         reason="fixme: ORT does not have an implementation of Trilu for int32.",
     ),
+    TorchLibOpInfo(
+        "ops.aten._trilinear.default",
+        core_ops.aten__trilinear,
+        tolerance={torch.float32: (2e-5, 2e-5)},
+    ),
     TorchLibOpInfo("triu", core_ops.aten_triu).xfail(
         dtypes=(torch.int32,),
         reason="fixme: ORT does not have an implementation of Trilu for int32.",