
Fix aten_gru: add linear_before_reset=1 to match PyTorch GRU semantics#2853

Open
Copilot wants to merge 2 commits into main from copilot/add-linear-before-reset-to-gru

Conversation

Contributor

Copilot AI commented Mar 14, 2026

PyTorch nn.GRU applies the linear transformation before multiplying by the reset gate (linear_before_reset=1), but the aten_gru translation was emitting ONNX GRU ops with the default linear_before_reset=0, producing numerically wrong results (error ~0.1 vs expected ~1e-7).

Changes

  • onnxscript/function_libs/torch_lib/ops/core.py: Add linear_before_reset=1 to both op.GRU calls in aten_gru — the biased and unbiased variants.
# Before (incorrect — uses default linear_before_reset=0)
Y, Y_h = op.GRU(current_input, W, R, B, initial_h=layer_h,
                direction=direction, hidden_size=hidden_size_attr)

# After (correct — matches PyTorch GRU: ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh))
Y, Y_h = op.GRU(current_input, W, R, B, initial_h=layer_h,
                direction=direction, hidden_size=hidden_size_attr,
                linear_before_reset=1)
Original prompt

This section details the original issue you should resolve

<issue_title>GRU translation missing linear_before_reset=1 (produces incorrect results for PyTorch GRU)</issue_title>
<issue_description>## Summary

The aten_gru translation in onnxscript/function_libs/torch_lib/ops/core.py (added in #2674) does not set linear_before_reset=1 on the ONNX GRU op. This causes numerically incorrect results because PyTorch's nn.GRU uses the linear_before_reset=1 variant.

Details

PyTorch GRU computes the new gate as:

n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))

This matches the ONNX GRU spec with linear_before_reset=1. But the default linear_before_reset=0 applies the reset gate before the linear transformation — a different equation.
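To make the difference concrete, here is a minimal numpy sketch of the new-gate computation under both conventions, using made-up scalar weights for a single hidden unit (all values and names here are illustrative, not taken from the PR):

```python
import numpy as np

# Hypothetical scalar weights, biases, and state for one hidden unit.
W_in, W_hn = 0.5, -0.8            # input-to-new and hidden-to-new weights
b_in, b_hn = 0.1, 0.2             # corresponding biases
x_t, h_prev, r_t = 1.0, 0.3, 0.4  # input, previous hidden state, reset gate

# linear_before_reset=1 (PyTorch): the hidden linear term, including its
# bias, is computed first and then scaled by the reset gate.
n_lbr1 = np.tanh(W_in * x_t + b_in + r_t * (W_hn * h_prev + b_hn))

# linear_before_reset=0 (ONNX default): the reset gate scales h_{t-1}
# before the hidden linear transformation is applied.
n_lbr0 = np.tanh(W_in * x_t + b_in + W_hn * (r_t * h_prev) + b_hn)

print(n_lbr1, n_lbr0)  # the two variants disagree whenever r_t != 1
```

Because the reset gate and the hidden bias interact differently in the two forms, the outputs diverge at every step, which is why the end-to-end error compounds to ~0.1 rather than staying at float32 precision.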

The two op.GRU calls (~lines 4352 and 4362) need linear_before_reset=1 added.

Reproduction

import torch, numpy as np

m = torch.nn.GRU(1, 32, batch_first=True)
m.eval()
inp = torch.randn(1, 10, 1)

with torch.no_grad():
    pt_out, _ = m(inp)

torch.onnx.export(m, (inp,), f="gru.onnx")

import onnxruntime as ort
sess = ort.InferenceSession("gru.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: inp.numpy()})[0]

print("Max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
# Expected: ~1e-7 (float32 precision)
# Actual:   ~0.1 (incorrect GRU equation)

Environment

  • torch 2.10.0
  • onnxscript 0.6.2
  • onnxruntime 1.22.0

References

Comments on the Issue (you are @copilot in this section)



Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix missing linear_before_reset in GRU translation for ONNX Fix aten_gru: add linear_before_reset=1 to match PyTorch GRU semantics Mar 14, 2026
Copilot AI requested a review from justinchuby March 14, 2026 21:36
@justinchuby justinchuby marked this pull request as ready for review March 14, 2026 21:40
@justinchuby justinchuby enabled auto-merge (squash) March 14, 2026 21:40
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Mar 14, 2026

codecov bot commented Mar 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.86%. Comparing base (4c4f7a0) to head (170a407).
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2853   +/-   ##
=======================================
  Coverage   71.86%   71.86%           
=======================================
  Files         239      239           
  Lines       29139    29139           
  Branches     2875     2875           
=======================================
  Hits        20942    20942           
  Misses       7219     7219           
  Partials      978      978           



Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

GRU translation missing linear_before_reset=1 (produces incorrect results for PyTorch GRU)

2 participants