Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self):
# Verify other ETRecord components are preserved
self.assertIsNotNone(parsed_etrecord._debug_handle_map)
self.assertIsNotNone(parsed_etrecord._delegate_map)

def test_edge_after_transform_graph_capture(self):
"""Test that to_edge_transform_and_lower with transform_passes captures the after-transform graph.

When generate_etrecord=True and transform_passes are applied, the ETRecord should
contain the after-transform graph under the key 'edge_after_transform' in graph_map.
This enables backends like Qualcomm to use the post-custom-transform graph as the
golden reference for numeric gap calculation.
"""
from torch.fx.passes.infra.pass_base import PassBase, PassResult

# Create a simple custom pass that modifies the graph
class SimpleCustomPass(PassBase):
"""A simple pass that adds a marker attribute to each node."""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Mark each node to indicate this pass ran
for node in graph_module.graph.nodes:
node.meta["custom_pass_applied"] = True
return PassResult(graph_module=graph_module, modified=True)

f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs(), strict=True)

# Create edge program with custom transform pass and generate_etrecord=True
transform_passes = [SimpleCustomPass()]

edge_manager = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
transform_passes=transform_passes,
generate_etrecord=True,
)

# Verify that ETRecord was generated
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord

# Verify graph_map exists and contains the 'edge_after_transform' key
self.assertIsNotNone(etrecord.graph_map)
self.assertIn(
"edge_after_transform/forward",
etrecord.graph_map,
"graph_map should contain 'edge_after_transform/forward' when transform_passes are applied",
)

# Verify the captured graph has the custom pass marker
after_transform_graph = etrecord.graph_map["edge_after_transform/forward"]
self.assertIsNotNone(after_transform_graph)

# Check that at least one node has the custom_pass_applied marker
has_marker = False
for node in after_transform_graph.graph.nodes:
if node.meta.get("custom_pass_applied", False):
has_marker = True
break

self.assertTrue(
has_marker,
"The edge_after_transform graph should have the custom pass marker applied",
)

# Verify edge_dialect_program is still the pre-transform graph (original behavior preserved)
self.assertIsNotNone(etrecord.edge_dialect_program)

# Save and parse the ETRecord to verify persistence
et_output = edge_manager.to_executorch()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_custom_pass.bin"

# Get ETRecord and save
complete_etrecord = et_output.get_etrecord()
complete_etrecord.save(etrecord_path)

# Parse ETRecord back
parsed_etrecord = parse_etrecord(etrecord_path)

# Verify the after-transform graph is preserved after save/parse
self.assertIsNotNone(parsed_etrecord.graph_map)
self.assertIn(
"edge_after_transform/forward",
parsed_etrecord.graph_map,
"Parsed ETRecord should still contain 'edge_after_transform/forward'",
)

# Verify the parsed graph still has the marker
parsed_after_transform_graph = parsed_etrecord.graph_map[
"edge_after_transform/forward"
]
self.assertIsNotNone(parsed_after_transform_graph)

def test_no_edge_after_transform_without_transform_passes(self):
"""Test that 'edge_after_transform' is NOT added when no transform_passes are provided.

This ensures backward compatibility - when generate_etrecord=True but no transform_passes
are applied, the ETRecord should NOT have an 'edge_after_transform' entry.
"""
f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs(), strict=True)

# Create edge program WITHOUT transform_passes
edge_manager = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
generate_etrecord=True,
)

# Verify that ETRecord was generated
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord

# Verify that 'edge_after_transform' is NOT in graph_map
if etrecord.graph_map is not None:
self.assertNotIn(
"edge_after_transform/forward",
etrecord.graph_map,
"graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied",
)
5 changes: 5 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,11 @@ def to_edge_transform_and_lower( # noqa: C901
if transform_passes is not None:
edge_manager = edge_manager.transform(transform_passes)

if generate_etrecord:
edge_manager._etrecord.add_extra_export_modules(
{"edge_after_transform": copy.deepcopy(edge_manager)}
)

max_num_partitioners = 0
for partitioner_list in partitioner.values():
max_num_partitioners = max(max_num_partitioners, len(partitioner_list))
Expand Down
Loading