diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py
index 5c1b1fda6b..c88484392f 100644
--- a/onnxscript/function_libs/torch_lib/ops/vision.py
+++ b/onnxscript/function_libs/torch_lib/ops/vision.py
@@ -56,13 +56,13 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
 def torchvision_roi_align(
     input,
     boxes,
-    output_size: Sequence[int],
-    spatial_scale: float = 1.0,
+    spatial_scale: float,
+    pooled_height: int,
+    pooled_width: int,
     sampling_ratio: int = -1,
     aligned: bool = False,
 ):
-    """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
-    pooled_height, pooled_width = output_size
+    """roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)"""
     batch_indices = _process_batch_indices_for_roi_align(boxes)
     rois_coords = _process_rois_for_roi_align(boxes)
     coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
diff --git a/tests/function_libs/torch_lib/ops/vision_test.py b/tests/function_libs/torch_lib/ops/vision_test.py
new file mode 100644
index 0000000000..b5cdf1ccfb
--- /dev/null
+++ b/tests/function_libs/torch_lib/ops/vision_test.py
@@ -0,0 +1,49 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import unittest
+
+import torch
+import torchvision
+
+
+class VisionOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.model_path = "roi_align_test.onnx"
+
+    def tearDown(self):
+        if os.path.exists(self.model_path):
+            os.remove(self.model_path)
+
+    def test_roi_align_export_with_seven_arguments(self):
+        """
+        Tests that torchvision::roi_align exports correctly with 7 positional
+        arguments. This covers the signature change where output_size is
+        decomposed into pooled_height and pooled_width.
+        """
+
+        class RoiAlignModel(torch.nn.Module):
+            def forward(self, x, boxes):
+                return torchvision.ops.roi_align(
+                    x,
+                    boxes,
+                    output_size=(7, 7),
+                    spatial_scale=0.5,
+                    sampling_ratio=2,
+                    aligned=True,
+                )
+
+        # Create dummy inputs: (N, C, H, W) and (K, 5).
+        x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
+        boxes = torch.tensor([[0, 0, 0, 10, 10]], dtype=torch.float32)
+        model = RoiAlignModel().eval()
+
+        try:
+            torch.onnx.export(model, (x, boxes), self.model_path)
+        except Exception as e:
+            self.fail(f"torch.onnx.export failed for roi_align: {e}")
+
+
+if __name__ == "__main__":
+    unittest.main()