From 6e781ab333829d4cb20a1d223aefcf16c9406039 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 12 Feb 2026 18:03:44 -0800 Subject: [PATCH] [devtools] Auto-record after-transform graph in ETRecord When `to_edge_transform_and_lower` is called with `generate_etrecord=True` and custom `transform_passes` are applied, the after-transform graph is now automatically recorded in the ETRecord's `graph_map` under the key `"edge_after_transform"`. This enables backends like Qualcomm to use the post-custom-transform graph as the golden reference for numeric gap calculation, while have zero impact for regular exportation. This is part of the operator-level numeric discrepancy detector project for ExecuTorch Qualcomm backend (https://github.com/pytorch/executorch/issues/16381). Design doc: https://docs.google.com/document/d/1GaCHiy9InytOsUrl2BKEgOiP1iKTfpCVdWg6QDh0N2E/edit?tab=t.0#heading=h.fcrpnrtb6cud Differential Revision: [D93176563](https://our.internmc.facebook.com/intern/diff/D93176563/) [ghstack-poisoned] --- devtools/etrecord/tests/etrecord_test.py | 119 +++++++++++++++++++++++ exir/program/_program.py | 5 + 2 files changed, 124 insertions(+) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 11463a976b4..a57515bffee 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self): # Verify other ETRecord components are preserved self.assertIsNotNone(parsed_etrecord._debug_handle_map) self.assertIsNotNone(parsed_etrecord._delegate_map) + + def test_edge_after_transform_graph_capture(self): + """Test that to_edge_transform_and_lower with transform_passes captures the after-transform graph. + + When generate_etrecord=True and transform_passes are applied, the ETRecord should + contain the after-transform graph under the key 'edge_after_transform' in graph_map. + This enables backends like Qualcomm to use the post-custom-transform graph as the + golden reference for numeric gap calculation. + """ + from torch.fx.passes.infra.pass_base import PassBase, PassResult + + # Create a simple custom pass that modifies the graph + class SimpleCustomPass(PassBase): + """A simple pass that adds a marker attribute to each node.""" + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Mark each node to indicate this pass ran + for node in graph_module.graph.nodes: + node.meta["custom_pass_applied"] = True + return PassResult(graph_module=graph_module, modified=True) + + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program with custom transform pass and generate_etrecord=True + transform_passes = [SimpleCustomPass()] + + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + transform_passes=transform_passes, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify graph_map exists and contains the 'edge_after_transform' key + self.assertIsNotNone(etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should contain 'edge_after_transform/forward' when transform_passes are applied", + ) + + # Verify the captured graph has the custom pass marker + after_transform_graph = etrecord.graph_map["edge_after_transform/forward"] + self.assertIsNotNone(after_transform_graph) + + # Check that at least one node has the custom_pass_applied marker + has_marker = False + for node in after_transform_graph.graph.nodes: + if node.meta.get("custom_pass_applied", False): + has_marker = True + break + + self.assertTrue( + has_marker, + "The edge_after_transform graph should have the custom pass marker applied", + ) + + # Verify edge_dialect_program is still the pre-transform graph (original behavior preserved) + self.assertIsNotNone(etrecord.edge_dialect_program) + + # Save and parse the ETRecord to verify persistence + et_output = edge_manager.to_executorch() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_custom_pass.bin" + + # Get ETRecord and save + complete_etrecord = et_output.get_etrecord() + complete_etrecord.save(etrecord_path) + + # Parse ETRecord back + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify the after-transform graph is preserved after save/parse + self.assertIsNotNone(parsed_etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + parsed_etrecord.graph_map, + "Parsed ETRecord should still contain 'edge_after_transform/forward'", + ) + + # Verify the parsed graph still has the marker + parsed_after_transform_graph = parsed_etrecord.graph_map[ + "edge_after_transform/forward" + ] + self.assertIsNotNone(parsed_after_transform_graph) + + def test_no_edge_after_transform_without_transform_passes(self): + """Test that 'edge_after_transform' is NOT added when no transform_passes are provided. + + This ensures backward compatibility - when generate_etrecord=True but no transform_passes + are applied, the ETRecord should NOT have an 'edge_after_transform' entry. + """ + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program WITHOUT transform_passes + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify that 'edge_after_transform' is NOT in graph_map + if etrecord.graph_map is not None: + self.assertNotIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied", + ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 8e825f6f85b..66dfef0b287 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1385,6 +1385,11 @@ def to_edge_transform_and_lower( # noqa: C901 if transform_passes is not None: edge_manager = edge_manager.transform(transform_passes) + if generate_etrecord: + edge_manager._etrecord.add_extra_export_modules( + {"edge_after_transform": copy.deepcopy(edge_manager)} + ) + max_num_partitioners = 0 for partitioner_list in partitioner.values(): max_num_partitioners = max(max_num_partitioners, len(partitioner_list))