Add converter for aten::_grouped_mm.default #2805

Sid-V5 wants to merge 4 commits into microsoft:main

Conversation
Implements the converter for aten::_grouped_mm.default to address issue microsoft#2795. It handles the batch/dense mode, where groups are implicit in the batch dimension, using MatMul with optional bias addition and dtype casting.
@microsoft-github-policy-service agree
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2805      +/-   ##
==========================================
- Coverage   71.86%   71.86%    -0.01%
==========================================
  Files         239      239
  Lines       29139    29147        +8
  Branches     2875     2877        +2
==========================================
+ Hits        20942    20946        +4
- Misses       7219     7221        +2
- Partials      978      980        +2
```
```python
# If offs is None, it uses the "dense" / "batch" mode where groups are
# implicit in the batch dimension.
# self: (G, M, K), mat2: (G, K, N) -> (G, M, N)
# TODO: Implement sparse mode when offs is not None.
```

Could you raise a `NotImplementedError`?
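A minimal sketch of the suggested guard, in plain Python with numpy standing in for onnxscript (the function name and body here are hypothetical, for illustration only):

```python
import numpy as np

def grouped_mm(self_t, mat2, offs=None):
    # Reject the unimplemented offset-based mode loudly instead of
    # silently ignoring `offs` and returning wrong results.
    if offs is not None:
        raise NotImplementedError(
            "aten::_grouped_mm: only the dense/batch mode (offs=None) is supported"
        )
    # Dense/batch mode: (G, M, K) @ (G, K, N) -> (G, M, N)
    return np.matmul(self_t, mat2)
```

Raising early keeps the failure visible at conversion time rather than producing a silently incorrect graph.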
```python
opinfo_core.OpInfo(
    "ops.aten._grouped_mm",
    aten_name="_grouped_mm",
    op=_mock_grouped_mm,
```

Why not use

```diff
-    op=_mock_grouped_mm,
+    op=torch.ops.aten._grouped_mm,
```

?
```python
# Test without bias
yield opinfo_core.SampleInput(self_t, args=(mat2_t,))


def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):
```

Check warning — Code scanning / lintrunner: PYLINT/W0613 (unused argument)
Pull request overview

Adds an ONNXScript converter for PyTorch's aten::_grouped_mm (dense/batched mode) and wires it into the TorchLib op test infrastructure to address missing dispatch for aten._grouped_mm.default (#2795).

Changes:
- Implemented the aten_grouped_mm converter using MatMul, optional Add (bias), and optional Cast (out_dtype).
- Registered the new op in TorchLib's tested op list.
- Added OpInfo + sample input scaffolding for _grouped_mm in the extra op database.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | Adds the aten::_grouped_mm converter implementation. |
| tests/function_libs/torch_lib/ops_test_data.py | Enables testing by adding a TorchLibOpInfo entry for ops.aten._grouped_mm. |
| tests/function_libs/torch_lib/extra_opinfo.py | Adds OpInfo registration and sample input generation for _grouped_mm. |
```python
def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):
    res = torch.matmul(self, mat2)
    if bias is not None:
        res = res + bias
```
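The mock's batched matmul can be cross-checked against an explicit per-group loop. A numpy sketch (numpy stands in for torch here, and both function names are hypothetical):

```python
import numpy as np

def mock_grouped_mm(self_t, mat2, bias=None):
    # Dense/batch mode: one batched matmul over the group dimension G.
    res = np.matmul(self_t, mat2)
    if bias is not None:
        res = res + bias
    return res

def grouped_mm_per_group(self_t, mat2, bias=None):
    # Reference: multiply each of the G groups separately, then stack.
    out = np.stack([s @ m for s, m in zip(self_t, mat2)])
    if bias is not None:
        out = out + bias
    return out
```

The two should agree elementwise for any (G, M, K) × (G, K, N) input pair, which is exactly what makes the single batched matmul a valid stand-in for per-group multiplication in the dense mode.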
Implemented the converter for aten::_grouped_mm.default to address #2795.

Changes

- Added the aten_grouped_mm function in onnxscript/function_libs/torch_lib/ops/core.py.

Implementation Details

The aten::_grouped_mm operator performs grouped matrix multiplication. This implementation handles the batch/dense mode (when offs is None), where groups are implicit in the batch dimension: self: (G, M, K), mat2: (G, K, N) → result: (G, M, N).

- op.MatMul for the core computation
- Optional bias addition via op.Add
- Optional out_dtype casting via op.Cast

The offset-based mode (when offs is provided) raises NotImplementedError, as it requires segment-level matrix multiplications that are not directly expressible with standard ONNX operators.

Testing

The function follows the same patterns as other converters in core.py (e.g., aten_bmm, aten_mm) and uses the @torch_op decorator for automatic registration.

Fixes #2795
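For illustration, the dense-mode chain described in the PR (MatMul, then optional bias Add, then optional out_dtype Cast) can be sketched in plain numpy; the function name is hypothetical, and numpy operations stand in for the corresponding ONNX ops:

```python
import numpy as np

def grouped_mm_dense_reference(self_t, mat2, bias=None, out_dtype=None):
    # Core computation: batched matmul over the group dimension G.
    # self_t: (G, M, K), mat2: (G, K, N) -> res: (G, M, N)
    res = np.matmul(self_t, mat2)      # mirrors op.MatMul
    if bias is not None:
        res = res + bias               # mirrors op.Add
    if out_dtype is not None:
        res = res.astype(out_dtype)    # mirrors op.Cast
    return res
```

Keeping the Cast last matches the out_dtype semantics: the multiply and bias add happen in the input dtype, and only the final result is converted.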