diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 6b6b4f583a6..368824f71a3 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -42,7 +42,6 @@ from executorch.devtools.etrecord import ETRecord, parse_etrecord from executorch.devtools.inspector._inspector_utils import ( calculate_time_scale_factor, - compare_intermediate_outputs, create_debug_handle_to_op_node_mapping, DebugHandle, display_or_print_df, @@ -50,7 +49,6 @@ EXCLUDED_COLUMNS_WHEN_PRINTING, EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT, EXCLUDED_EVENTS_WHEN_PRINTING, - find_op_names, find_populated_event, FORWARD, gen_etdump_object, @@ -1421,8 +1419,10 @@ def calculate_numeric_gap( Args: distance: The metrics the inspector will use for gap calculation. Can be either: - A string: one of "MSE", "L1", or "SNR" for built-in comparators. - - A custom NumericalComparatorBase instance: allows you to define custom comparison logic - by subclassing NumericalComparatorBase and implementing the compare() method. + - A custom NumericalComparatorBase instance: allows you to define custom comparison + logic by subclassing NumericalComparatorBase and implementing the element_compare() + method. Custom comparators can also override the preprocessing() method to apply + transformations (e.g., layout conversion, dequantization) before comparison. disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR @@ -1448,48 +1448,27 @@ def calculate_numeric_gap( mapping = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) + + # Get or create comparator if isinstance(distance, NumericalComparatorBase): comparator = distance + # Inject inspector if not already set + if comparator.inspector is None: + comparator.inspector = self else: metric = distance.strip().upper() if metric == "MSE": - comparator = MSEComparator() + comparator = MSEComparator(inspector=self) elif metric == "L1": - comparator = L1Comparator() + comparator = L1Comparator(inspector=self) elif metric == "SNR": - comparator = SNRComparator() + comparator = SNRComparator(inspector=self) else: raise ValueError(f"Unsupported distance metric {distance!r}") - rows = [] - for (aot_debug_handle, aot_intermediate_output), ( - runtime_debug_handle, - runtime_intermediate_output, - ) in mapping.items(): - if aot_intermediate_output is None or runtime_intermediate_output is None: - continue - # If aot outputs length is > 1 then comparison fails since we dont really have - # any instances where runtime intermediate output is a tuple or list - # This does not happen when edge dialect program is reference for comparison - # but happens in aten graph where ops like unbind remain undecomposed - if ( - isinstance(aot_intermediate_output, Sequence) - and len(aot_intermediate_output) > 1 - ): - continue - rows.append( - { - "aot_ops": find_op_names( - aot_debug_handle, aot_debug_handle_to_op_names - ), - "aot_intermediate_output": aot_intermediate_output, - "runtime_ops": find_op_names( - runtime_debug_handle, runtime_debug_handle_to_op_names - ), - "runtime_intermediate_output": runtime_intermediate_output, - "gap": compare_intermediate_outputs( - aot_intermediate_output, runtime_intermediate_output, comparator - ), - } - ) - return pd.DataFrame(rows) + # Delegate to comparator's compare method (includes preprocessing) + return comparator.compare( + mapping, + aot_debug_handle_to_op_names, + runtime_debug_handle_to_op_names, + ) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 878e0ddb7e0..556987e4bbf 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -1068,40 +1068,6 @@ def find_op_names( return result -def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: - """ - Compare two outputs, handling both sequence and non-sequence cases, - and return a list of comparison results. - Parameters: - a: The first intermediate output to compare. - b: The second intermediate output to compare. - comparator: A comparator object with a `compare` method. - Returns: - List[float]: A list of comparison results. - Raises: - ValueError: If one input is a sequence and the other is not, or if sequences have different lengths. - """ - is_a_sequence = isinstance(a, Sequence) - is_b_sequence = isinstance(b, Sequence) - if is_a_sequence and is_b_sequence: - # Ensure both sequences have the same length - if len(a) != len(b): - raise ValueError( - f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." - ) - - # Compare each element in the sequences and return the list of results - return [comparator.compare(x, y) for x, y in zip(a, b)] - elif not is_a_sequence and not is_b_sequence: - # Compare non-sequence items and return the result in a list - return [comparator.compare(a, b)] - else: - # Raise an error if one is a sequence and the other is not - raise ValueError( - f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences." - ) - - def get_ancestor_node_identifiers(node: Node) -> List[str]: """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in. diff --git a/devtools/inspector/numerical_comparator/__init__.py b/devtools/inspector/numerical_comparator/__init__.py index 0090c50025f..68ccfabe02a 100644 --- a/devtools/inspector/numerical_comparator/__init__.py +++ b/devtools/inspector/numerical_comparator/__init__.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. +# Re-export DebugHandle from _inspector_utils for convenience +from executorch.devtools.inspector._inspector_utils import DebugHandle from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import ( L1Comparator, ) @@ -14,6 +16,7 @@ ) from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + IntermediateOutputMapping, NumericalComparatorBase, ) @@ -22,4 +25,11 @@ ) -__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"] +__all__ = [ + "DebugHandle", + "IntermediateOutputMapping", + "L1Comparator", + "MSEComparator", + "NumericalComparatorBase", + "SNRComparator", +] diff --git a/devtools/inspector/numerical_comparator/l1_numerical_comparator.py b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py index 43f4f170c2f..ddc6233b769 100644 --- a/devtools/inspector/numerical_comparator/l1_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -12,9 +12,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class L1Comparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """L1 (sum of absolute differences) comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """Sum up all these element-wise absolute differences between two tensors.""" t_a = convert_to_float_tensor(a) diff --git a/devtools/inspector/numerical_comparator/mse_numerical_comparator.py b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py index c4693ff2ad4..7a6b323e81a 100644 --- a/devtools/inspector/numerical_comparator/mse_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -12,9 +12,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class MSEComparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """Mean Squared Error comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """Compare mean squared difference between two outputs.""" t_a = convert_to_float_tensor(a) diff --git a/devtools/inspector/numerical_comparator/numerical_comparator_base.py b/devtools/inspector/numerical_comparator/numerical_comparator_base.py index db498980e1f..0d07be6f954 100644 --- a/devtools/inspector/numerical_comparator/numerical_comparator_base.py +++ b/devtools/inspector/numerical_comparator/numerical_comparator_base.py @@ -6,21 +6,258 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING + +import pandas as pd + +from executorch.devtools.inspector._inspector_utils import DebugHandle + +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + +# Type alias for the mapping used in preprocessing +# Maps (aot_debug_handle, aot_output) -> (runtime_debug_handle, runtime_output) +IntermediateOutputMapping = Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] class NumericalComparatorBase(ABC): + """Base class for numerical comparison with optional preprocessing. + + This class provides a framework for comparing intermediate outputs between + AOT (Ahead-of-Time) and runtime execution. Subclasses can override the + `preprocessing` method to transform tensors before comparison (e.g., layout + conversion, dequantization) and must implement `element_compare` for + element-wise comparison logic. + + The `compare` method is the main entry point called by Inspector, which + orchestrates the full comparison pipeline: preprocess -> element-wise compare + -> aggregate results into a DataFrame. + + Attributes: + _inspector: Optional reference to the Inspector instance, which provides + access to the reference graph and other metadata needed for preprocessing. + """ + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + """Initialize the comparator. + + Args: + inspector: Optional Inspector instance that provides access to the + reference graph and other metadata. Can be set later via the + `inspector` property. + """ + self._inspector: Optional["Inspector"] = inspector + + @property + def inspector(self) -> Optional["Inspector"]: + """Get the Inspector instance.""" + return self._inspector + + @inspector.setter + def inspector(self, value: Optional["Inspector"]) -> None: + """Set the Inspector instance.""" + self._inspector = value + + def preprocessing( + self, mapping: IntermediateOutputMapping + ) -> IntermediateOutputMapping: + """Transform the mapping before comparison. + + Override this method to apply custom preprocessing to the intermediate + outputs before comparison. This is useful for backends like Qualcomm that + require tensor transformations (e.g., dequantization, layout conversion) + before accurate numeric discrepancy measurement. + + The default implementation returns the mapping unchanged. + + Args: + mapping: Dictionary mapping AOT (debug_handle, intermediate_output) pairs + to runtime (debug_handle, intermediate_output) pairs. + + - Key: Tuple[DebugHandle, Any] + - DebugHandle: Tuple[int, ...] - debug handle(s) from AOT graph + - Any: torch.Tensor or sequence - AOT intermediate output + + - Value: Tuple[DebugHandle, Any] + - DebugHandle: Tuple[int, ...] - debug handle(s) from runtime + - Any: torch.Tensor or sequence - runtime intermediate output + + Returns: + The transformed mapping, ready for element-wise comparison. + + Note: + When implementing custom preprocessing, you can access the reference + graph via `self._inspector.get_reference_graph()` to retrieve node + metadata such as quantization parameters or layout information. + """ + return mapping + @abstractmethod - def compare(self, a: Any, b: Any) -> float: - """Compare two intermediate output and return a result. + def element_compare(self, a: Any, b: Any) -> float: + """Compare two tensors and return a scalar distance. - This method should be overridden by subclasses to provide custom comparison logic. + This method should be overridden by subclasses to provide custom + element-wise comparison logic (e.g., MSE, L1, SNR). Args: - a: The first intermediate output to compare. - b: The second intermediate output to compare. + a: The first intermediate output to compare (typically AOT output). + b: The second intermediate output to compare (typically runtime output). Returns: - A numerical result indicating the comparison outcome. + A numerical result indicating the comparison outcome (e.g., distance, + error metric). Lower values typically indicate better agreement. """ pass + + def compare( + self, + mapping: IntermediateOutputMapping, + aot_debug_handle_to_op_names: Dict[DebugHandle, List[str]], + runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]], + ) -> pd.DataFrame: + """Full comparison pipeline: preprocess -> element-wise compare -> aggregate. + + This is the main entry point called by Inspector.calculate_numeric_gap(). + It orchestrates the full comparison pipeline and returns a DataFrame + with the results. + + Args: + mapping: Dictionary mapping AOT (debug_handle, intermediate_output) pairs + to runtime (debug_handle, intermediate_output) pairs. + aot_debug_handle_to_op_names: Mapping from AOT debug handles to operator names. + runtime_debug_handle_to_op_names: Mapping from runtime debug handles to operator names. + + Returns: + pd.DataFrame: A DataFrame with columns: + - aot_ops: List of AOT operator names + - aot_intermediate_output: AOT intermediate output tensor + - runtime_ops: List of runtime operator names + - runtime_intermediate_output: Runtime intermediate output tensor + - gap: List of numerical gap values + """ + from executorch.devtools.inspector._inspector_utils import find_op_names + + def _validate_preprocessing_output( + processed_mapping: IntermediateOutputMapping, + ) -> None: + """Validate the output format of preprocessing(). + + Ensures the preprocessed mapping follows the expected format: + Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] + + Args: + processed_mapping: The mapping returned by preprocessing(). + + Raises: + TypeError: If processed_mapping is not a dict. + ValueError: If any key or value in the mapping has an invalid format. + """ + if not isinstance(processed_mapping, dict): + raise TypeError( + f"preprocessing() must return a dict, got {type(processed_mapping).__name__}. " + "Expected format: Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]" + ) + + for key, value in processed_mapping.items(): + # Validate key format: Tuple[DebugHandle, Any] + if not isinstance(key, tuple) or len(key) != 2: + raise ValueError( + f"Invalid key format in preprocessed mapping: {key}. " + "Expected Tuple[DebugHandle, Any] where DebugHandle is Tuple[int, ...]" + ) + aot_debug_handle, _ = key + if not isinstance(aot_debug_handle, tuple) or not all( + isinstance(x, int) for x in aot_debug_handle + ): + raise ValueError( + f"Invalid AOT debug handle in key: {aot_debug_handle}. " + "Expected Tuple[int, ...]" + ) + + # Validate value format: Tuple[DebugHandle, Any] + if not isinstance(value, tuple) or len(value) != 2: + raise ValueError( + f"Invalid value format in preprocessed mapping: {value}. " + "Expected Tuple[DebugHandle, Any] where DebugHandle is Tuple[int, ...]" + ) + runtime_debug_handle, _ = value + if not isinstance(runtime_debug_handle, tuple) or not all( + isinstance(x, int) for x in runtime_debug_handle + ): + raise ValueError( + f"Invalid runtime debug handle in value: {runtime_debug_handle}. " + "Expected Tuple[int, ...]" + ) + + def _compare_intermediate_outputs(a: Any, b: Any) -> List[float]: + """Compare two outputs, handling both sequence and non-sequence cases. + + Args: + a: The first intermediate output to compare. + b: The second intermediate output to compare. + + Returns: + List[float]: A list of comparison results. + + Raises: + ValueError: If one input is a sequence and the other is not, + or if sequences have different lengths. + """ + is_a_sequence = isinstance(a, Sequence) + is_b_sequence = isinstance(b, Sequence) + if is_a_sequence and is_b_sequence: + if len(a) != len(b): + raise ValueError( + f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length " + f"for comparison. len(a): {len(a)} len(b): {len(b)}." + ) + return [self.element_compare(x, y) for x, y in zip(a, b)] + elif not is_a_sequence and not is_b_sequence: + return [self.element_compare(a, b)] + else: + raise ValueError( + f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences " + f"or both must be non-sequences." + ) + + # Step 1: Apply preprocessing + processed_mapping = self.preprocessing(mapping) + + # Validate the preprocessed mapping format + _validate_preprocessing_output(processed_mapping) + + # Step 2: Element-wise comparison and aggregation + rows = [] + for (aot_debug_handle, aot_intermediate_output), ( + runtime_debug_handle, + runtime_intermediate_output, + ) in processed_mapping.items(): + if aot_intermediate_output is None or runtime_intermediate_output is None: + continue + # If aot outputs length is > 1 then comparison fails since we don't really have + # any instances where runtime intermediate output is a tuple or list. + # This does not happen when edge dialect program is reference for comparison + # but happens in aten graph where ops like unbind remain undecomposed. + if ( + isinstance(aot_intermediate_output, Sequence) + and len(aot_intermediate_output) > 1 + ): + continue + rows.append( + { + "aot_ops": find_op_names( + aot_debug_handle, aot_debug_handle_to_op_names + ), + "aot_intermediate_output": aot_intermediate_output, + "runtime_ops": find_op_names( + runtime_debug_handle, runtime_debug_handle_to_op_names + ), + "runtime_intermediate_output": runtime_intermediate_output, + "gap": _compare_intermediate_outputs( + aot_intermediate_output, runtime_intermediate_output + ), + } + ) + + # Step 3: Build and return DataFrame + return pd.DataFrame(rows) diff --git a/devtools/inspector/numerical_comparator/snr_numerical_comparator.py b/devtools/inspector/numerical_comparator/snr_numerical_comparator.py index efe881a2549..1e474a7eba3 100644 --- a/devtools/inspector/numerical_comparator/snr_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/snr_numerical_comparator.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -13,9 +13,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class SNRComparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """Signal-to-Noise Ratio comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """ Compare the Signal-to-Noise Ratio (SNR) between two inputs Formula: SNR = 10 * log10(original_power / error_power) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 422f5d5defe..6f22cd106c8 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -730,7 +730,7 @@ def test_calculate_numeric_gap_with_custom_comparator(self): # Create a custom comparator that returns the max absolute difference class MaxAbsDiffComparator(NumericalComparatorBase): - def compare(self, a, b): + def element_compare(self, a, b): if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): return torch.max(torch.abs(a - b)).item() return abs(a - b) @@ -795,6 +795,235 @@ def compare(self, a, b): # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0 self.assertEqual(df.iloc[1]["gap"][0], 1.0) + def test_calculate_numeric_gap_with_custom_comparator_and_preprocessing(self): + """Test calculate_numeric_gap with a custom comparator that includes preprocessing.""" + from executorch.devtools.inspector.numerical_comparator import ( + IntermediateOutputMapping, + NumericalComparatorBase, + ) + + # Create a custom comparator with preprocessing that scales runtime tensors by 2x + class ScalingComparator(NumericalComparatorBase): + def __init__(self, scale_factor: float = 2.0): + super().__init__() + self.scale_factor = scale_factor + self.preprocessing_called = False + + def preprocessing( + self, mapping: IntermediateOutputMapping + ) -> IntermediateOutputMapping: + """Scale runtime tensors by scale_factor before comparison.""" + self.preprocessing_called = True + transformed_mapping = {} + for (aot_handle, aot_output), ( + runtime_handle, + runtime_output, + ) in mapping.items(): + # Scale the runtime output + if isinstance(runtime_output, torch.Tensor): + scaled_runtime_output = runtime_output * self.scale_factor + else: + scaled_runtime_output = runtime_output + transformed_mapping[(aot_handle, aot_output)] = ( + runtime_handle, + scaled_runtime_output, + ) + return transformed_mapping + + def element_compare(self, a, b) -> float: + """Compute MSE between two tensors.""" + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.mean(torch.square(a.float() - b.float())).item() + return (a - b) ** 2 + + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # AOT outputs: [1.0, 2.0, 3.0] and [4.0, 5.0, 6.0] + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + # Runtime outputs: [1.0, 1.0, 1.0] and [2.0, 2.0, 2.0] + # After 2x scaling: [2.0, 2.0, 2.0] and [4.0, 4.0, 4.0] + runtime_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1), + (1,): ([torch.tensor([2.0, 2.0, 2.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Create custom comparator with 2x scaling + custom_comparator = ScalingComparator(scale_factor=2.0) + + # Test with custom comparator + df = inspector_instance.calculate_numeric_gap(distance=custom_comparator) + + # Verify preprocessing was called + self.assertTrue(custom_comparator.preprocessing_called) + + # Verify DataFrame structure + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_ops", + "aot_intermediate_output", + "runtime_ops", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + + # Verify the comparison after preprocessing + # For (0,): AOT=[1.0, 2.0, 3.0], Runtime after scaling=[2.0, 2.0, 2.0] + # MSE = mean((1-2)^2 + (2-2)^2 + (3-2)^2) = mean(1 + 0 + 1) = 2/3 + expected_gap_0 = (1.0 + 0.0 + 1.0) / 3.0 + self.assertAlmostEqual(df.iloc[0]["gap"][0], expected_gap_0, places=5) + + # For (1,): AOT=[4.0, 5.0, 6.0], Runtime after scaling=[4.0, 4.0, 4.0] + # MSE = mean((4-4)^2 + (5-4)^2 + (6-4)^2) = mean(0 + 1 + 4) = 5/3 + expected_gap_1 = (0.0 + 1.0 + 4.0) / 3.0 + self.assertAlmostEqual(df.iloc[1]["gap"][0], expected_gap_1, places=5) + + def test_calculate_numeric_gap_with_invalid_preprocessing_output(self): + """Test that invalid preprocessing output raises appropriate errors.""" + from executorch.devtools.inspector.numerical_comparator import ( + NumericalComparatorBase, + ) + + # Test 1: preprocessing returns non-dict + class NonDictPreprocessingComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return "invalid" # Should return a dict + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 2: preprocessing returns dict with invalid key format + class InvalidKeyFormatComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return {"invalid_key": ((0,), torch.tensor([1.0]))} + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 3: preprocessing returns dict with invalid debug handle in key + class InvalidKeyDebugHandleComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return { + (("not_int",), torch.tensor([1.0])): ((0,), torch.tensor([1.0])) + } + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 4: preprocessing returns dict with invalid value format + class InvalidValueFormatComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return {((0,), torch.tensor([1.0])): "invalid_value"} + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 5: preprocessing returns dict with invalid debug handle in value + class InvalidValueDebugHandleComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return { + ((0,), torch.tensor([1.0])): (("not_int",), torch.tensor([1.0])) + } + + def element_compare(self, a, b) -> float: + return 0.0 + + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1), + } + aot_debug_handle_to_op_name = {(0,): "op_0"} + runtime_debug_handle_to_op_name = {(0,): "op_0"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Test 1: Non-dict return type + with self.assertRaises(TypeError) as context: + inspector_instance.calculate_numeric_gap( + distance=NonDictPreprocessingComparator() + ) + self.assertIn("must return a dict", str(context.exception)) + + # Test 2: Invalid key format + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidKeyFormatComparator() + ) + self.assertIn("Invalid key format", str(context.exception)) + + # Test 3: Invalid debug handle in key + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidKeyDebugHandleComparator() + ) + self.assertIn("Invalid AOT debug handle", str(context.exception)) + + # Test 4: Invalid value format + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidValueFormatComparator() + ) + self.assertIn("Invalid value format", str(context.exception)) + + # Test 5: Invalid debug handle in value + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidValueDebugHandleComparator() + ) + self.assertIn("Invalid runtime debug handle", str(context.exception)) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 8c4bb4b38b9..b1f32c0ec6e 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -30,7 +30,6 @@ calculate_mse, calculate_snr, calculate_time_scale_factor, - compare_intermediate_outputs, convert_to_float_tensor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, @@ -606,24 +605,6 @@ def test_find_op_names_mixed_single_and_multiple_ops(self): ["op1", "op2", "op3", "op4", "op5", "op6", "op7"], ) - def test_compare_intermediate_outputs_sequences(self): - a = [1.0, 2.0, 3.0] - b = [1.0, 2.5, 3.5] - result = compare_intermediate_outputs(a, b, L1Comparator()) - self.assertEqual(result, [0.0, 0.5, 0.5]) - - def test_compare_intermediate_outputs_diff_len_sequences(self): - a = [1.0, 2.0] - b = [1.0, 2.0, 3.0] - with self.assertRaises(ValueError): - compare_intermediate_outputs(a, b, L1Comparator()) - - def test_compare_intermediate_outputs_sequence_and_non_sequence(self): - a = [1.0, 2.0] - b = 1.0 - with self.assertRaises(ValueError): - compare_intermediate_outputs(a, b, L1Comparator()) - def test_equip_debug_handle_to_export_program_success(self): """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" # Create a test model diff --git a/devtools/inspector/tests/l1_comparator_test.py b/devtools/inspector/tests/l1_comparator_test.py index 1e9f0be9c10..b2c1a86910e 100644 --- a/devtools/inspector/tests/l1_comparator_test.py +++ b/devtools/inspector/tests/l1_comparator_test.py @@ -18,32 +18,32 @@ def test_identical_tensors(self): a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[1, 2], [3, 4]]) expected = 0.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_scalar(self): a = 1 b = 2 expected = 1.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_with_nans_replaced_with_zero(self): a = torch.tensor([3, 2, -1, float("nan")]) b = torch.tensor([float("nan"), 0, -3, 1]) expected = 8.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_shape_mismatch_raises_exception(self): a = torch.tensor([0, 2, -1]) b = torch.tensor([1, 0, -3, 4]) with self.assertRaises(ValueError): - self.l1_comparator.compare(a, b) + self.l1_comparator.element_compare(a, b) def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 5]]) expected = 14.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) diff --git a/devtools/inspector/tests/mse_comparator_test.py b/devtools/inspector/tests/mse_comparator_test.py index b24302e12e8..f9e61af4e88 100644 --- a/devtools/inspector/tests/mse_comparator_test.py +++ b/devtools/inspector/tests/mse_comparator_test.py @@ -18,32 +18,32 @@ def test_identical_tensors(self): a = torch.tensor([[10, 4], [3, 4]]) b = torch.tensor([[10, 4], [3, 4]]) expected = 0.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_scalar(self): a = 10 b = 2 expected = 64.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_with_nans_replaced_with_zero(self): a = torch.tensor([3, 1, -3, float("nan")]) b = torch.tensor([float("nan"), 0, -3, 2]) expected = (9.0 + 1.0 + 0.0 + 4.0) / 4.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_shape_mismatch_raises_exception(self): a = torch.tensor([0, 2, -1]) b = torch.tensor([1, 1, -3, 4]) with self.assertRaises(ValueError): - self.mse_comparator.compare(a, b) + self.mse_comparator.element_compare(a, b) def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 10]]) expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) diff --git a/devtools/inspector/tests/snr_comparator_test.py b/devtools/inspector/tests/snr_comparator_test.py index b21e1f3d61a..93d0a2f5deb 100644 --- a/devtools/inspector/tests/snr_comparator_test.py +++ b/devtools/inspector/tests/snr_comparator_test.py @@ -19,27 +19,27 @@ def test_identical_tensors(self): # identical tensors --> error_power == 0 --> SNR is inf a = torch.tensor([[10, 4], [3, 4]]) b = torch.tensor([[10, 4], [3, 4]]) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertTrue(math.isinf(result) and result > 0) def test_scalar(self): # original_power == 1, error_power == 1 --> SNR = 10 * log10(1/1) = 0 a = 1 b = 2 - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertAlmostEqual(result, 0.0) def test_with_nans_replaced_with_zero(self): a = torch.tensor([float("nan"), 1.0]) b = torch.tensor([0.0, 1.0]) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertTrue(math.isinf(result) and result > 0) def test_shape_mismatch_raises_exception(self): a = torch.tensor([1, 2, -1]) b = torch.tensor([1, 1, -3, 4]) with self.assertRaises(ValueError): - self.snr_comparator.compare(a, b) + self.snr_comparator.element_compare(a, b) def test_2D_tensors(self): # original_power = mean([16, 81, 36, 16]) = 37.25 @@ -48,5 +48,5 @@ def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 5]]) expected = 10 * math.log10(37.25 / 17.0) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected)