From d9d49b9518c3d67b50fbb94f8c10d81bded96892 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 16:28:24 +0000 Subject: [PATCH 1/2] Initial plan From 0bebad41c8470db0e904a207366e09d9a793a5e8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 16:33:34 +0000 Subject: [PATCH 2/2] Fix cubic_coeff_a=-0.5 for bicubic antialias=True in torchlib Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/nn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index ee6f589851..de89ff6bad 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2329,6 +2329,7 @@ def _aten_upsample_output_size( mode: str, coordinate_transformation_mode: str, antialias: int = 0, + cubic_coeff_a: float = -0.75, ) -> TReal: batch_and_channel = op.Shape(self, end=2, start=0) # When output_size is passed in as a list of integers, the torch.onnx @@ -2344,6 +2345,7 @@ def _aten_upsample_output_size( output_size, mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, + cubic_coeff_a=cubic_coeff_a, nearest_mode="floor", antialias=antialias, ) @@ -2355,6 +2357,7 @@ def _aten_upsample_scales( mode: str, coordinate_transformation_mode: str, antialias: int = 0, + cubic_coeff_a: float = -0.75, ) -> TReal: return op.Resize( self, @@ -2365,6 +2368,7 @@ def _aten_upsample_scales( None, mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, + cubic_coeff_a=cubic_coeff_a, nearest_mode="floor", antialias=antialias, ) @@ -2404,12 +2408,15 @@ def aten__upsample_bicubic2d_aa( # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, # unless when align_corners is True, in which case we do not know what is going on. coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + # PyTorch uses cubic_coeff_a=-0.5 (Keys interpolation, PIL-compatible) when + # antialias=True, as opposed to -0.75 (OpenCV-compatible) for the non-antialias case. return _aten_upsample_output_size( self, output_size, mode="cubic", coordinate_transformation_mode=coordinate_transformation_mode, antialias=1, + cubic_coeff_a=-0.5, )