Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
def torchvision_roi_align(
input,
boxes,
output_size: Sequence[int],
spatial_scale: float = 1.0,
spatial_scale: float,
pooled_height: int,
pooled_width: int,
sampling_ratio: int = -1,
aligned: bool = False,
):
"""roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
pooled_height, pooled_width = output_size
"""roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)"""
batch_indices = _process_batch_indices_for_roi_align(boxes)
rois_coords = _process_rois_for_roi_align(boxes)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
Expand All @@ -79,7 +79,6 @@ def torchvision_roi_align(
sampling_ratio=sampling_ratio,
)


@torch_op("torchvision::roi_pool", trace_only=True)
def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0):
"""roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor"""
Expand Down
50 changes: 50 additions & 0 deletions tests/function_libs/torch_lib/ops/vision_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest
import torch
import torchvision
from torch.onnx import export
import os

class VisionOperatorTest(unittest.TestCase):
def setUp(self):
self.model_path = "roi_align_test.onnx"

def tearDown(self):
if os.path.exists(self.model_path):
os.remove(self.model_path)

def test_roi_align_export_with_seven_arguments(self):
"""
Tests that torchvision::roi_align exports correctly with 7 positional arguments.
This covers the signature change where output_size is decomposed into
pooled_height and pooled_width.
"""
class RoiAlignModel(torch.nn.Module):
def forward(self, x, boxes):
return torchvision.ops.roi_align(
x,
boxes,
output_size=(7, 7),
spatial_scale=0.5,
sampling_ratio=2,
aligned=True
)

# Create dummy inputs: (N, C, H, W) and (K, 5)
x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
boxes = torch.tensor([[0, 0, 0, 10, 10]], dtype=torch.float32)
model = RoiAlignModel().eval()

try:
export(model, (x, boxes), self.model_path)
export_success = True
except Exception as e:
export_success = False
self.fail(f"torch.onnx.export failed for roi_align: {e}")

self.assertTrue(export_success)

if __name__ == "__main__":
unittest.main()