
Fix bicubic antialias export: use cubic_coeff_a=-0.5 instead of -0.75#2849

Open
Copilot wants to merge 2 commits into main from copilot/fix-bicubic-antialias-coefficient

Conversation

Contributor

Copilot AI commented Mar 11, 2026

When exporting F.interpolate(mode='bicubic', antialias=True), the ONNX Resize node was emitted with cubic_coeff_a=-0.75 (OpenCV-compatible), but PyTorch uses -0.5 (Keys/PIL-compatible) for the antialias path. This caused ~32x higher numerical error vs. PyTorch when running the exported model in ONNX Runtime.

Changes

  • _aten_upsample_output_size / _aten_upsample_scales: Added cubic_coeff_a: float = -0.75 parameter (default preserves existing behavior for non-antialias cases) and thread it through to op.Resize.
  • aten__upsample_bicubic2d_aa: Pass cubic_coeff_a=-0.5 to match PyTorch's runtime behavior when antialias=True.
# antialias=True  → cubic_coeff_a=-0.5  (Keys/PIL-compatible)  ✓
# antialias=False → cubic_coeff_a=-0.75 (OpenCV-compatible)    ✓
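The coefficient selection described above can be sketched as a small helper. This is an illustrative sketch only, not the actual torchlib code; the function name `pick_cubic_coeff_a` is hypothetical, but the mapping it encodes is exactly the one the PR implements:

```python
def pick_cubic_coeff_a(antialias: bool) -> float:
    """Return the cubic_coeff_a to write on the ONNX Resize node.

    Mirrors ATen's behavior (hypothetical helper, for illustration):
      antialias=True  -> -0.5  (Keys/PIL-compatible)
      antialias=False -> -0.75 (OpenCV-compatible, the existing default)
    """
    return -0.5 if antialias else -0.75
```

Keeping `-0.75` as the default parameter value means all existing non-antialias call sites are unaffected; only the `aten__upsample_bicubic2d_aa` path opts in to `-0.5`.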
Original prompt

This section describes the original issue you should resolve

<issue_title>ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5)</issue_title>
<issue_description>### 🐛 Describe the bug

ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5)

Bug

When exporting F.interpolate(mode='bicubic', antialias=True) to ONNX via the dynamo exporter, the Resize node is written with cubic_coeff_a=-0.75. However, PyTorch internally uses cubic_coeff_a=-0.5 (Keys interpolation) when antialias=True, as documented in the source:

// aten/src/ATen/native/cpu/UpSampleKernel.cpp, line ~1347
// We are using -0.5 for bicubic, antialiasing=true (compatibility with PIL)
// and using -0.75 for bicubic, antialiasing=false (compatibility with Opencv)
constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
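To see why the value of `a` matters, the cubic convolution kernel from the ATen snippet above can be written in plain Python. This is a sketch of the standard Keys (1981) kernel with free parameter `a`, not a copy of the ATen implementation:

```python
def keys_cubic(x: float, a: float) -> float:
    """Cubic convolution kernel with free parameter a.

    PyTorch's ATen kernels use a=-0.5 (PIL-compatible) when antialias=True
    and a=-0.75 (OpenCV-compatible) when antialias=False.
    """
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

# Tap weights sampled midway between source pixels: the two coefficient
# choices give different weights, which is where the numerical gap
# between the exported model and PyTorch comes from.
for a in (-0.5, -0.75):
    taps = [keys_cubic(t - 0.5, a) for t in (-1, 0, 1, 2)]
    print(a, [round(w, 4) for w in taps])
```

Both kernels partition unity (the taps sum to 1 at any phase), so outputs are close in magnitude but differ pointwise, consistent with the ~5e-03 mean error reported below.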

The exported ONNX model therefore produces different results than PyTorch when run in ONNX Runtime (or any runtime that correctly respects the cubic_coeff_a attribute).

The -0.75 value was originally hardcoded in PR pytorch/pytorch#24805 for the non-antialias case and was carried forward without accounting for the antialias path. The distinction between -0.5 (Keys, PIL-compatible) and -0.75 (OpenCV-compatible) based on the antialias flag was introduced in the ATen kernels via pytorch/vision#3810 and pytorch#68819.

The legacy TorchScript exporter does not support antialias=True at all (UnsupportedOperatorError), so this only affects the dynamo exporter.

To reproduce

import io
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F


class BicubicAA(nn.Module):
    def forward(self, x):
        return F.interpolate(x, size=[224, 224], mode="bicubic",
                             align_corners=False, antialias=True)


# Export
model = BicubicAA()
model.eval()
x = torch.rand(1, 3, 800, 600)
buf = io.BytesIO()
torch.onnx.export(model, (x,), buf, opset_version=18, dynamo=True)
buf.seek(0)
onnx_model = onnx.load(buf)

# Inspect: cubic_coeff_a is -0.75 (wrong for antialias=True)
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                print(f"Exported cubic_coeff_a = {attr.f}")  # -0.75
            if attr.name == "antialias":
                print(f"Exported antialias = {attr.i}")       # 1

# Numerical impact
with torch.no_grad():
    pt_out = model(x).numpy()

buf.seek(0)
sess = ort.InferenceSession(buf.read())
ort_wrong = sess.run(None, {"x": x.numpy()})[0]

# Patch to correct value and re-run
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                attr.f = -0.5
buf2 = io.BytesIO()
onnx.save(onnx_model, buf2)
buf2.seek(0)
sess2 = ort.InferenceSession(buf2.read())
ort_fixed = sess2.run(None, {"x": x.numpy()})[0]

print(f"PyTorch vs ONNX (exported a=-0.75): mean={np.abs(ort_wrong - pt_out).mean():.2e}")
print(f"PyTorch vs ONNX (patched  a=-0.50): mean={np.abs(ort_fixed - pt_out).mean():.2e}")

Output:

Exported cubic_coeff_a = -0.75
Exported antialias = 1
PyTorch vs ONNX (exported a=-0.75): mean=5.31e-03
PyTorch vs ONNX (patched  a=-0.50): mean=1.67e-04

Patching cubic_coeff_a to -0.5 reduces mean error by 32x, confirming that PyTorch uses -0.5 at runtime but the exporter writes -0.75.

Expected behavior

When antialias=True, the ONNX Resize node should be exported with cubic_coeff_a=-0.5 to match PyTorch's runtime behavior. When antialias=False, cubic_coeff_a=-0.75 is correct.

Versions

Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 4.2.3
Libc version: glibc-2.31

Python version: 3.12.12 (main, Feb 3 2026, 22:51:04) [Clang 21.1.4 ] (64-bit runtime)
Python platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 565.57.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical...



Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix cubic_coeff_a value for bicubic antialias in ONNX export Fix bicubic antialias export: use cubic_coeff_a=-0.5 instead of -0.75 Mar 11, 2026
@justinchuby justinchuby marked this pull request as ready for review March 11, 2026 16:40
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Mar 11, 2026
@codecov

codecov bot commented Mar 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.86%. Comparing base (4c4f7a0) to head (0bebad4).
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2849   +/-   ##
=======================================
  Coverage   71.86%   71.86%           
=======================================
  Files         239      239           
  Lines       29139    29139           
  Branches     2875     2875           
=======================================
  Hits        20942    20942           
  Misses       7219     7219           
  Partials      978      978           


@justinchuby justinchuby enabled auto-merge (squash) March 11, 2026 16:52

Labels

module: torchlib Related to the torch/aten function lib in development

Development

Successfully merging this pull request may close these issues.

ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5)

2 participants