Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@
from executorch.devtools.etrecord import ETRecord, parse_etrecord
from executorch.devtools.inspector._inspector_utils import (
calculate_time_scale_factor,
compare_intermediate_outputs,
create_debug_handle_to_op_node_mapping,
DebugHandle,
display_or_print_df,
EDGE_DIALECT_GRAPH_KEY,
EXCLUDED_COLUMNS_WHEN_PRINTING,
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
EXCLUDED_EVENTS_WHEN_PRINTING,
find_op_names,
find_populated_event,
FORWARD,
gen_etdump_object,
Expand Down Expand Up @@ -1421,8 +1419,10 @@ def calculate_numeric_gap(
Args:
distance: The metrics the inspector will use for gap calculation. Can be either:
- A string: one of "MSE", "L1", or "SNR" for built-in comparators.
- A custom NumericalComparatorBase instance: allows you to define custom comparison logic
by subclassing NumericalComparatorBase and implementing the compare() method.
- A custom NumericalComparatorBase instance: allows you to define custom comparison
logic by subclassing NumericalComparatorBase and implementing the element_compare()
method. Custom comparators can also override the preprocessing() method to apply
transformations (e.g., layout conversion, dequantization) before comparison.
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
Expand All @@ -1448,48 +1448,27 @@ def calculate_numeric_gap(
mapping = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)

# Get or create comparator
if isinstance(distance, NumericalComparatorBase):
comparator = distance
# Inject inspector if not already set
if comparator.inspector is None:
comparator.inspector = self
else:
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
comparator = MSEComparator(inspector=self)
elif metric == "L1":
comparator = L1Comparator()
comparator = L1Comparator(inspector=self)
elif metric == "SNR":
comparator = SNRComparator()
comparator = SNRComparator(inspector=self)
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

rows = []
for (aot_debug_handle, aot_intermediate_output), (
runtime_debug_handle,
runtime_intermediate_output,
) in mapping.items():
if aot_intermediate_output is None or runtime_intermediate_output is None:
continue
# If aot outputs length is > 1 then comparison fails since we dont really have
# any instances where runtime intermediate output is a tuple or list
# This does not happen when edge dialect program is reference for comparison
# but happens in aten graph where ops like unbind remain undecomposed
if (
isinstance(aot_intermediate_output, Sequence)
and len(aot_intermediate_output) > 1
):
continue
rows.append(
{
"aot_ops": find_op_names(
aot_debug_handle, aot_debug_handle_to_op_names
),
"aot_intermediate_output": aot_intermediate_output,
"runtime_ops": find_op_names(
runtime_debug_handle, runtime_debug_handle_to_op_names
),
"runtime_intermediate_output": runtime_intermediate_output,
"gap": compare_intermediate_outputs(
aot_intermediate_output, runtime_intermediate_output, comparator
),
}
)
return pd.DataFrame(rows)
# Delegate to comparator's compare method (includes preprocessing)
return comparator.compare(
mapping,
aot_debug_handle_to_op_names,
runtime_debug_handle_to_op_names,
)
34 changes: 0 additions & 34 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,40 +1068,6 @@ def find_op_names(
return result


def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
"""
Compare two outputs, handling both sequence and non-sequence cases,
and return a list of comparison results.
Parameters:
a: The first intermediate output to compare.
b: The second intermediate output to compare.
comparator: A comparator object with a `compare` method.
Returns:
List[float]: A list of comparison results.
Raises:
ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
"""
is_a_sequence = isinstance(a, Sequence)
is_b_sequence = isinstance(b, Sequence)
if is_a_sequence and is_b_sequence:
# Ensure both sequences have the same length
if len(a) != len(b):
raise ValueError(
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
)

# Compare each element in the sequences and return the list of results
return [comparator.compare(x, y) for x, y in zip(a, b)]
elif not is_a_sequence and not is_b_sequence:
# Compare non-sequence items and return the result in a list
return [comparator.compare(a, b)]
else:
# Raise an error if one is a sequence and the other is not
raise ValueError(
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
)


def get_ancestor_node_identifiers(node: Node) -> List[str]:
"""Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.

Expand Down
12 changes: 11 additions & 1 deletion devtools/inspector/numerical_comparator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.


# Re-export DebugHandle from _inspector_utils for convenience
from executorch.devtools.inspector._inspector_utils import DebugHandle
from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import (
L1Comparator,
)
Expand All @@ -14,6 +16,7 @@
)

from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
IntermediateOutputMapping,
NumericalComparatorBase,
)

Expand All @@ -22,4 +25,11 @@
)


__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"]
__all__ = [
"DebugHandle",
"IntermediateOutputMapping",
"L1Comparator",
"MSEComparator",
"NumericalComparatorBase",
"SNRComparator",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any
from typing import Any, Optional, TYPE_CHECKING

import torch
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
NumericalComparatorBase,
)

if TYPE_CHECKING:
from executorch.devtools.inspector._inspector import Inspector


class L1Comparator(NumericalComparatorBase):
def compare(self, a: Any, b: Any) -> float:
"""L1 (sum of absolute differences) comparator for numerical discrepancy detection."""

def __init__(self, inspector: Optional["Inspector"] = None) -> None:
super().__init__(inspector)

def element_compare(self, a: Any, b: Any) -> float:
"""Sum up all these element-wise absolute differences between two tensors."""

t_a = convert_to_float_tensor(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any
from typing import Any, Optional, TYPE_CHECKING

import torch
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
NumericalComparatorBase,
)

if TYPE_CHECKING:
from executorch.devtools.inspector._inspector import Inspector


class MSEComparator(NumericalComparatorBase):
def compare(self, a: Any, b: Any) -> float:
"""Mean Squared Error comparator for numerical discrepancy detection."""

def __init__(self, inspector: Optional["Inspector"] = None) -> None:
super().__init__(inspector)

def element_compare(self, a: Any, b: Any) -> float:
"""Compare mean squared difference between two outputs."""

t_a = convert_to_float_tensor(a)
Expand Down
Loading
Loading