
Add converter for aten::_grouped_mm.default#2805

Open
Sid-V5 wants to merge 4 commits into microsoft:main from Sid-V5:add-grouped-mm-2795

Conversation

Sid-V5 commented Feb 12, 2026

Implemented the converter for aten::_grouped_mm.default to address #2795.

Changes

  • Added aten_grouped_mm function in onnxscript/function_libs/torch_lib/ops/core.py

Implementation Details

The aten::_grouped_mm operator performs grouped matrix multiplication. This implementation handles the batch/dense mode (when offs is None), where groups are implicit in the batch dimension:

  • self: (G, M, K), mat2: (G, K, N) → result: (G, M, N)
  • Uses op.MatMul for the core computation
  • Supports optional bias addition via op.Add
  • Supports optional out_dtype casting via op.Cast

The offset-based mode (when offs is provided) raises NotImplementedError, as it requires segment-level matrix multiplications that are not directly expressible with standard ONNX operators.
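The dense-mode semantics described above can be sketched as plain array code. This is a NumPy reference sketch of what the converter computes, not the ONNX converter itself: np.matmul stands in for op.MatMul, + for op.Add, and astype for op.Cast.

```python
import numpy as np

def grouped_mm_dense(self_t, mat2, bias=None, out_dtype=None):
    """Reference semantics of aten::_grouped_mm in dense/batch mode (offs is None).

    self_t: (G, M, K), mat2: (G, K, N) -> result: (G, M, N)
    """
    result = np.matmul(self_t, mat2)       # op.MatMul over the batch dimension G
    if bias is not None:
        result = result + bias             # op.Add (optional bias)
    if out_dtype is not None:
        result = result.astype(out_dtype)  # op.Cast (optional out_dtype)
    return result

# Example: 3 groups, each a (2, 4) @ (4, 5) product.
a = np.random.rand(3, 2, 4).astype(np.float32)
b = np.random.rand(3, 4, 5).astype(np.float32)
out = grouped_mm_dense(a, b, out_dtype=np.float64)
print(out.shape)  # (3, 2, 5)
```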

Testing

The function follows the same patterns as other converters in core.py (e.g., aten_bmm, aten_mm) and uses the @torch_op decorator for automatic registration.

Fixes #2795

Implements the converter for aten::_grouped_mm.default to address issue microsoft#2795. Handles the batch/dense mode where groups are implicit in the batch dimension using MatMul, with optional bias addition and dtype casting.
Sid-V5 (Author) commented Feb 12, 2026

@microsoft-github-policy-service agree

codecov bot commented Feb 16, 2026

Codecov Report

❌ Patch coverage is 50.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.86%. Comparing base (4c4f7a0) to head (76eaefc).
✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/core.py | 50.00% | 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2805      +/-   ##
==========================================
- Coverage   71.86%   71.86%   -0.01%     
==========================================
  Files         239      239              
  Lines       29139    29147       +8     
  Branches     2875     2877       +2     
==========================================
+ Hits        20942    20946       +4     
- Misses       7219     7221       +2     
- Partials      978      980       +2     



# If offs is None, it uses the "dense" / "batch" mode where groups are implicit in the batch dimension.
# self: (G, M, K), mat2: (G, K, N) -> (G, M, N)
# TODO: Implement sparse mode when offs is not None.
Collaborator comment:
Could you raise a not implemented error?
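The guard the reviewer asks for could look like the following. This is a hypothetical NumPy sketch, not the actual converter: the real function operates on ONNX graph values via op.MatMul/op.Add/op.Cast rather than on arrays.

```python
import numpy as np

def aten_grouped_mm(self_t, mat2, offs=None, bias=None, out_dtype=None):
    # Reject the offset-based mode up front, as requested in review.
    if offs is not None:
        raise NotImplementedError(
            "aten::_grouped_mm: offset-based mode (offs is not None) is not "
            "supported; segment-level matmuls are not directly expressible "
            "with standard ONNX operators."
        )
    # Dense/batch mode: groups are implicit in the batch dimension.
    result = np.matmul(self_t, mat2)
    if bias is not None:
        result = result + bias
    if out_dtype is not None:
        result = result.astype(out_dtype)
    return result
```

Raising early keeps the failure explicit at conversion time instead of silently producing a wrong graph.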

opinfo_core.OpInfo(
"ops.aten._grouped_mm",
aten_name="_grouped_mm",
op=_mock_grouped_mm,
Collaborator comment:

Why not use this instead?

Suggested change:
- op=_mock_grouped_mm,
+ op=torch.ops.aten._grouped_mm,

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Mar 13, 2026
@justinchuby justinchuby requested a review from Copilot March 13, 2026 16:26
# Test without bias
yield opinfo_core.SampleInput(self_t, args=(mat2_t,))

def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):

Code scanning / lintrunner — PYLINT/W0613 warning: Unused argument 'offs' (unused-argument). To disable, use # pylint: disable=unused-argument
Code scanning / lintrunner — PYLINT/W0613 warning: Unused argument 'out_dtype' (unused-argument). To disable, use # pylint: disable=unused-argument
Copilot AI (Contributor) left a comment

Pull request overview

Adds an ONNXScript converter for PyTorch’s aten::_grouped_mm (dense/batched mode) and wires it into the TorchLib op test infrastructure to address missing dispatch for aten._grouped_mm.default (#2795).

Changes:

  • Implemented aten_grouped_mm converter using MatMul, optional Add (bias), and optional Cast (out_dtype).
  • Registered the new op in TorchLib’s tested op list.
  • Added OpInfo + sample inputs scaffolding for _grouped_mm in the extra op database.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File | Description
onnxscript/function_libs/torch_lib/ops/core.py | Adds the aten::_grouped_mm converter implementation.
tests/function_libs/torch_lib/ops_test_data.py | Enables testing by adding a TorchLibOpInfo entry for ops.aten._grouped_mm.
tests/function_libs/torch_lib/extra_opinfo.py | Adds OpInfo registration and sample input generation for _grouped_mm.
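The sample-input scaffolding in extra_opinfo.py typically yields inputs with and without the optional bias. A hypothetical sketch of that generator, using a namedtuple as a stand-in for opinfo_core.SampleInput and NumPy in place of torch tensors:

```python
from collections import namedtuple
import numpy as np

# Stand-in for opinfo_core.SampleInput, for illustration only.
SampleInput = namedtuple("SampleInput", ["input", "args", "kwargs"])

def sample_inputs_grouped_mm(shapes=((3, 2, 4),), seed=0):
    """Yield dense-mode (self, mat2) pairs, with and without bias."""
    rng = np.random.default_rng(seed)
    for g, m, k in shapes:
        n = k + 1  # arbitrary output width for the test
        self_t = rng.standard_normal((g, m, k)).astype(np.float32)
        mat2_t = rng.standard_normal((g, k, n)).astype(np.float32)
        # Test without bias
        yield SampleInput(self_t, args=(mat2_t,), kwargs={})
        # Test with bias
        bias = rng.standard_normal((g, m, n)).astype(np.float32)
        yield SampleInput(self_t, args=(mat2_t,), kwargs={"bias": bias})
```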

Comment on lines +46 to +50

def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):
    res = torch.matmul(self, mat2)
    if bias is not None:
        res = res + bias


Labels

module: torchlib Related to the torch/aten function lib in development

Development

Successfully merging this pull request may close these issues.

Missing converter for OpOverload(op='aten._grouped_mm', overload='default')

4 participants