From 191ae695bc247c7cde3afb3d86bd6b814122fb0a Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 13 Feb 2026 12:41:28 -0800 Subject: [PATCH] [devtools] Add reference_graph parameter to calculate_numeric_gap This change adds a `reference_graph` parameter to the `calculate_numeric_gap` API, allowing users to explicitly select which graph to use as the golden reference for numeric gap calculation, which enables backends like Qualcomm to use the post-custom-transform graph as the golden reference for numeric gap calculation. This is part of the operator-level numeric discrepancy detector project for ExecuTorch Qualcomm backend (https://github.com/pytorch/executorch/issues/16381). Design doc: https://docs.google.com/document/d/1GaCHiy9InytOsUrl2BKEgOiP1iKTfpCVdWg6QDh0N2E/edit?tab=t.0#heading=h.fcrpnrtb6cud Differential Revision: [D93266779](https://our.internmc.facebook.com/intern/diff/D93266779/) [ghstack-poisoned] --- devtools/inspector/_inspector.py | 123 +++++++++-- devtools/inspector/tests/inspector_test.py | 233 ++++++++++++++++++++- 2 files changed, 338 insertions(+), 18 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 368824f71a3..47009512469 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1166,31 +1166,90 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + reference_graph_name: str, disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided - when using bundled program to create the etrecord + when using bundled program to create the etrecord. + + Args: + reference_graph_name: Name of the graph to use as the reference for intermediate + output capture. Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires + successful debug handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" + for post-custom-transform graph when transform_passes are applied in + to_edge_transform_and_lower with generate_etrecord=True). + disable_debug_handle_valdiation: If True, skip debug handle validation. + + Returns: + Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries. + + Raises: + ValueError: If the specified reference_graph_name is not available or if + debug handle backpropagation fails for "exported_program". """ if self._etrecord._representative_inputs is None: return {}, {} export_program = None - # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program - if self._etrecord.exported_program and propagate_back_debug_handle( - self._etrecord.exported_program, - self._etrecord.export_graph_id, - self._etrecord.edge_dialect_program, - disable_debug_handle_valdiation, - ): - export_program = self._etrecord.exported_program - else: - log.warning( - "Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program." - "Will fall back to use edge dialect program to extract intermediate output", - ) + if reference_graph_name == "exported_program": + # Use exported_program only if backpropagation succeeds + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, + ): + export_program = self._etrecord.exported_program + log.info( + "Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture" + ) + else: + raise ValueError( + "Cannot use 'exported_program' as reference graph: either the ATen dialect " + "exported program is not in ETRecord, or debug handle backpropagation failed. " + "Consider using 'edge_dialect_exported_program' instead." + ) + elif reference_graph_name == "edge_dialect_exported_program": + # Use edge_dialect_program directly export_program = self._etrecord.edge_dialect_program + log.info( + "Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture" + ) + else: + # Try to fetch from graph_map + # If no method name is provided (no "/" in the name), try adding "/forward" as default + lookup_name = reference_graph_name + if "/" not in reference_graph_name: + lookup_name = f"{reference_graph_name}/forward" + log.info( + f"No method name specified in '{reference_graph_name}', " + f"using '{lookup_name}' as default" + ) + + if ( + self._etrecord.graph_map is not None + and lookup_name in self._etrecord.graph_map + ): + export_program = self._etrecord.graph_map[lookup_name] + log.info( + f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture" + ) + else: + available_graphs = ( + list(self._etrecord.graph_map.keys()) + if self._etrecord.graph_map + else [] + ) + raise ValueError( + f"Reference graph '{lookup_name}' not found. " + f"Available options: 'exported_program', 'edge_dialect_exported_program', " + f"or one of the graphs in graph_map: {available_graphs}" + ) graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module @@ -1406,11 +1465,11 @@ def calculate_numeric_gap( self, distance: Union[str, NumericalComparatorBase], disable_debug_handle_valdiation: bool = False, + reference_graph: Optional[str] = None, ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. - If the exported graph is not supported, the function will fall back to use edge dialect graph. To use this function, you must first generate the ETRecord with representative inputs, and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then @@ -1423,18 +1482,48 @@ def calculate_numeric_gap( logic by subclassing NumericalComparatorBase and implementing the element_compare() method. Custom comparators can also override the preprocessing() method to apply transformations (e.g., layout conversion, dequantization) before comparison. - disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., + disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. + reference_graph: Name of the graph to use as the golden reference for intermediate output capture. + Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires successful debug + handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for + post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower + with generate_etrecord=True). + + If None (default), automatically selects the best available graph: + - Uses "exported_program" if available and debug handle backpropagation succeeds. + - Falls back to "edge_dialect_exported_program" otherwise. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ + # Determine the reference graph to use + if reference_graph is None: + # Auto-select: try exported_program first, fall back to edge_dialect_exported_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, + ): + reference_graph = "exported_program" + else: + log.warning( + "Either ATen dialect exported program is not in ETRecord, or debug handle " + "backpropagation failed. Falling back to 'edge_dialect_exported_program'." + ) + reference_graph = "edge_dialect_exported_program" + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( self._get_aot_intermediate_outputs_and_op_names( - disable_debug_handle_valdiation + reference_graph, + disable_debug_handle_valdiation, ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 6f22cd106c8..6413c91cd66 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,7 +17,7 @@ from typing import Callable, List, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pandas as pd @@ -1024,6 +1024,237 @@ def element_compare(self, a, b) -> float: ) self.assertIn("Invalid runtime debug handle", str(context.exception)) + def test_calculate_numeric_gap_with_reference_graph_name(self): + """Test calculate_numeric_gap with the reference_graph_name parameter.""" + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ): + inspector_instance = Inspector( + etdump_path="", + etrecord="", + ) + + # Create mock intermediate outputs + aot_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 2.0, 3.0])], 1), + (1,): ([torch.tensor([4.0, 5.0, 6.0])], 1), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + (1,): ([torch.tensor([5.0, 6.0, 7.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): ["op_0"], (1,): ["op_1"]} + runtime_debug_handle_to_op_name = {(0,): ["op_0"], (1,): ["op_1"]} + + # Create a mock graph module for the reference graph + class MockGraphModule: + def __init__(self): + self.graph = MockGraph() + + def module(self): + return self + + class MockGraph: + def __init__(self): + self.nodes = [] + + mock_graph_module = MockGraphModule() + + # Create a real ETRecord and use add_extra_export_modules to add the graph + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = None + mock_etrecord.edge_dialect_program = mock_graph_module + + # Simulate what add_extra_export_modules does - it adds "/forward" suffix + # So "edge_after_transform" becomes "edge_after_transform/forward" + mock_etrecord.graph_map = {"edge_after_transform/forward": mock_graph_module} + + inspector_instance._etrecord = mock_etrecord + + # Mock the runtime intermediate outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Mock IntermediateOutputCapturer to return our AOT outputs + with patch( + "executorch.devtools.inspector._inspector.IntermediateOutputCapturer" + ) as mock_capturer_class, patch( + "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping" + ) as mock_get_mapping: + mock_capturer = MagicMock() + mock_capturer.run_and_capture.return_value = aot_intermediate_outputs + mock_capturer_class.return_value = mock_capturer + mock_get_mapping.return_value = aot_debug_handle_to_op_name + + # Test with reference_graph_name parameter (without /forward suffix) + # The code should automatically add "/forward" when looking up + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph_name="edge_after_transform", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + + def test_calculate_numeric_gap_with_invalid_reference_graph_name(self): + """Test that calculate_numeric_gap raises ValueError for invalid reference_graph_name.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ): + inspector_instance = Inspector( + etdump_path="", + etrecord="", + ) + + # Create a real ETRecord with empty graph_map + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Test with non-existent reference_graph_name + # Since "non_existent_graph" has no "/", it will be looked up as "non_existent_graph/forward" + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph_name="non_existent_graph", + ) + self.assertIn("not found", str(context.exception)) + self.assertIn("non_existent_graph/forward", str(context.exception)) + + def test_calculate_numeric_gap_with_exported_program_name_backprop_failure(self): + """Test that calculate_numeric_gap raises ValueError when exported_program backpropagation fails.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ): + inspector_instance = Inspector( + etdump_path="", + etrecord="", + ) + + # Create mock graph modules + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + + def module(self): + return self + + mock_exported_program = MockGraphModule() + mock_edge_dialect_program = MockGraphModule() + + # Create a real ETRecord with exported_program + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = mock_exported_program + mock_etrecord.edge_dialect_program = mock_edge_dialect_program + mock_etrecord.export_graph_id = "graph_id" + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Mock propagate_back_debug_handle to return False (backpropagation failure) + with patch( + "executorch.devtools.inspector._inspector.propagate_back_debug_handle" + ) as mock_propagate: + mock_propagate.return_value = False + + # Test with "exported_program" should raise error when backpropagation fails + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph_name="exported_program", + ) + self.assertIn("Cannot use 'exported_program'", str(context.exception)) + self.assertIn("backpropagation failed", str(context.exception)) + + def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): + """Test calculate_numeric_gap with edge_dialect_exported_program reference_graph_name.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ): + inspector_instance = Inspector( + etdump_path="", + etrecord="", + ) + + # Create mock intermediate outputs + aot_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 2.0, 3.0])], 1), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): ["op_0"]} + runtime_debug_handle_to_op_name = {(0,): ["op_0"]} + + # Create mock graph modules + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + self.graph.nodes = [] + + def module(self): + return self + + mock_edge_dialect_program = MockGraphModule() + + # Create a real ETRecord + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = None + mock_etrecord.edge_dialect_program = mock_edge_dialect_program + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Mock the runtime intermediate outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Mock IntermediateOutputCapturer to return our AOT outputs + with patch( + "executorch.devtools.inspector._inspector.IntermediateOutputCapturer" + ) as mock_capturer_class, patch( + "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping" + ) as mock_get_mapping: + mock_capturer = MagicMock() + mock_capturer.run_and_capture.return_value = aot_intermediate_outputs + mock_capturer_class.return_value = mock_capturer + mock_get_mapping.return_value = aot_debug_handle_to_op_name + + # Test with edge_dialect_exported_program parameter + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph_name="edge_dialect_exported_program", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 1) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """