16 changes: 14 additions & 2 deletions backends/cadence/aot/ops_registrations.py
@@ -1160,13 +1160,25 @@ def quantized_conv2d_nhwc_per_tensor_meta(
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 5
assert len(in_size) > 2
assert len(in_size) < 6

# Determine weight layout based on input and weight dimensions:
# - 1D conv: input is 3D, weight is 3D [OC, K, IC]
# - 2D depthwise conv: input is 4D, weight is 3D [KH, KW, OC]
# - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC]
if len(in_size) == 3:
# 1D conv: weight is [OC, K, IC]
out_channels, *kernel_size, _ = weight.shape
elif len(weight.shape) == 3:
# 2D depthwise conv: weight is [KH, KW, OC]
*kernel_size, out_channels = weight.shape
else:
# 2D regular conv: weight is [OC, KH, KW, IC]
out_channels, *kernel_size, _ = weight.shape

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
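For intuition, a standalone sketch (not part of the diff) of the layout dispatch above; infer_weight_layout is a hypothetical helper, and the example shapes mirror the three branches:

import torch

def infer_weight_layout(input_rank: int, weight: torch.Tensor):
    # Mirrors the meta-function branches: 1D conv, 2D depthwise, 2D regular.
    if input_rank == 3:
        out_channels, *kernel_size, _ = weight.shape  # [OC, K, IC]
    elif weight.dim() == 3:
        *kernel_size, out_channels = weight.shape  # [KH, KW, OC]
    else:
        out_channels, *kernel_size, _ = weight.shape  # [OC, KH, KW, IC]
    return out_channels, kernel_size

assert infer_weight_layout(3, torch.empty(16, 3, 8)) == (16, [3])  # 1D conv
assert infer_weight_layout(4, torch.empty(3, 3, 16)) == (16, [3, 3])  # depthwise
assert infer_weight_layout(4, torch.empty(16, 3, 3, 8)) == (16, [3, 3])  # regular 2D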
48 changes: 46 additions & 2 deletions backends/cadence/aot/replace_ops.py
@@ -1071,6 +1071,32 @@ def _change_nhwc_to_nchw(
permute_node.meta = node.meta
return permute_node

def _change_depthwise_weight_to_hwc(
self, graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
"""Convert depthwise weight from OIHW [OC, 1, KH, KW] to HWC [KH, KW, OC].

NNLib depthwise convolution expects weights in [KH, KW, OC] format when
inp_data_format=0 (NHWC), but the standard NCHW->NHWC permutation produces
[OC, KH, KW, 1]. This function applies the correct permutation for depthwise
convolution weights.
"""
# For depthwise: input shape is [OC, 1, KH, KW], target is [KH, KW, OC]
# Permute [0, 1, 2, 3] -> [2, 3, 0, 1] gives [KH, KW, OC, 1]
# Then squeeze the last dim (which is 1) to get [KH, KW, OC]
permute_indices = [2, 3, 0, 1]
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
)
permute_node.meta = node.meta

# Squeeze the last dimension (which has size 1)
squeeze_node = graph.call_function(
exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {}
)
squeeze_node.meta = node.meta
return squeeze_node
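
As a quick sanity check (an aside, not part of the diff), the same permute + squeeze in eager torch:

import torch

w = torch.randn(16, 1, 3, 3)  # depthwise weight, [OC, 1, KH, KW]
w_hwc = w.permute(2, 3, 0, 1).squeeze(-1)  # [KH, KW, OC, 1] -> [KH, KW, OC]
assert w_hwc.shape == (3, 3, 16)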

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
assert isinstance(node.target, EdgeOpOverload)
quantized_op = (
@@ -1093,12 +1119,30 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
input_node = cast(torch.fx.Node, node.args[0])
weight_node = cast(torch.fx.Node, node.args[1])

# Check if this is a depthwise convolution (groups == input_channels)
# and weight is 4D with shape [OC, 1, KH, KW]
groups = node.args[6]
input_shape = input_node.meta["val"].shape
weight_shape = weight_node.meta["val"].shape
input_channels = input_shape[1] # NCHW format, channels at index 1
# Depthwise conv has 4D weight [OC, 1, KH, KW] where the IC dim is 1
is_depthwise = (
groups == input_channels
and len(weight_shape) == 4
and weight_shape[1] == 1
)

# Insert transpose operations before the node
with graph.inserting_before(node):
# Convert input from NCHW to NHWC
input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
# Convert weight from NCHW to NHWC
weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)
# Convert weight from NCHW to the appropriate format
if is_depthwise:
# For depthwise: [OC, 1, KH, KW] -> [KH, KW, OC] for NNLib
weight_nhwc = self._change_depthwise_weight_to_hwc(graph, weight_node)
else:
# For regular conv: [OC, IC, KH, KW] -> [OC, KH, KW, IC]
weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)

# Non-quantized ops need to set the last optional argument to True
channel_last_arg = [] if quantized_op else [True]
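The depthwise check in maybe_remove_or_replace reduces to a small predicate; a sketch with a hypothetical is_depthwise helper (not in the diff), using the shapes from the test below:

def is_depthwise(groups, input_shape, weight_shape) -> bool:
    # groups == input channels (index 1 in NCHW) and weight is [OC, 1, KH, KW]
    return (
        groups == input_shape[1]
        and len(weight_shape) == 4
        and weight_shape[1] == 1
    )

assert is_depthwise(8, (1, 8, 224, 56), (16, 1, 3, 3))  # depthwise conv
assert not is_depthwise(1, (1, 8, 224, 56), (16, 8, 3, 3))  # regular conv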
139 changes: 138 additions & 1 deletion backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -2019,8 +2019,145 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
"ReplaceConvWithChannelLastConvPass",
)

def create_depthwise_convolution_graph_module(
self,
) -> Tuple[Tuple[torch.Tensor, ...], torch.fx.GraphModule]:
"""Helper to create a depthwise convolution node.

For depthwise convolution, groups == input_channels.
Input shape: [N, C, H, W] = [1, 8, 224, 56] (NCHW)
Weight shape: [OC, 1, KH, KW] = [16, 1, 3, 3] where OC = C * channel_multiplier
"""
in_channels = 8
out_channels = 16
x = torch.randn(1, in_channels, 224, 56)
# Depthwise: weight shape is [out_channels, 1, kernel_h, kernel_w]
w = torch.randn(out_channels, 1, 3, 3)
b = torch.randn(out_channels)
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = in_channels # Depthwise: groups == input_channels
input_zero_point = 0
w_zero_point = 0
b_scale = 10
out_scale = 1
out_zero_point = 0
out_multiplier = 5
out_shift = 5
args = (
x,
w,
b,
stride,
padding,
dilation,
groups,
input_zero_point,
w_zero_point,
b_scale,
out_scale,
out_zero_point,
out_multiplier,
out_shift,
)
placeholders = (x, w, b)
gm = single_op_builder(
placeholders=placeholders,
op=exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
args=args,
)
return placeholders, gm
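
For reference (an aside, not part of the diff), the same configuration in eager PyTorch; groups equal to input channels makes this a depthwise conv with a channel multiplier of 2:

import torch

x = torch.randn(1, 8, 224, 56)
w = torch.randn(16, 1, 3, 3)  # [OC, IC/groups, KH, KW] with groups == IC == 8
b = torch.randn(16)
y = torch.nn.functional.conv2d(x, w, b, stride=1, padding=1, groups=8)
assert y.shape == (1, 16, 224, 56)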

def test_depthwise_convolution_weight_shape(self) -> None:
"""Test that depthwise conv weight is transformed to [KH, KW, OC] format.

For depthwise convolution with NHWC layout, NNLib expects weights in
[KH, KW, OC] format (3D), not [OC, KH, KW, 1] (4D standard NCHW->NHWC).

The pass should:
1. Detect depthwise convolution (groups == input_channels)
2. Transform weight from [OC, 1, KH, KW] to [KH, KW, OC] (3D)
3. Use permute_copy + squeeze_copy operations
"""
placeholders, gm = self.create_depthwise_convolution_graph_module()
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1
)
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
self.assertEqual(count_node(gm, exir_ops.edge.aten.squeeze_copy.dim), 0)

# Apply replacement pass.
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module

# Verify the quantized_conv2d_nhwc node exists
self.assertEqual(
count_node(
gm_after_replacement,
exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
),
1,
)

# For depthwise conv:
# - Input: 1 permute (NCHW -> NHWC)
# - Weight: 1 permute ([OC, 1, KH, KW] -> [KH, KW, OC, 1])
# - Output: 1 permute (NHWC -> NCHW)
# Total: 3 permutes
self.assertEqual(
count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
3,
)

# For depthwise conv, weight should also have squeeze_copy to go from
# [KH, KW, OC, 1] to [KH, KW, OC] (3D)
self.assertEqual(
count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim),
1,
)

# Find the weight node being passed to the quantized_conv2d_nhwc
for node in gm_after_replacement.graph.nodes:
if node.target != exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor:
continue

# The weight argument (index 1) should come from squeeze_copy for depthwise
weight_node = node.args[1]
self.assertEqual(
weight_node.target,
exir_ops.edge.aten.squeeze_copy.dim,
"Depthwise conv weight should be processed by squeeze_copy",
)

# The squeeze_copy input should come from permute_copy
permute_node = weight_node.args[0]
self.assertEqual(
permute_node.target,
exir_ops.edge.aten.permute_copy.default,
"squeeze_copy input should come from permute_copy",
)

# Verify the weight shape after transformation is 3D [KH, KW, OC]
weight_shape = weight_node.meta["val"].shape
self.assertEqual(
len(weight_shape),
3,
f"Depthwise weight should be 3D [KH, KW, OC], got {len(weight_shape)}D",
)
# Original weight: [16, 1, 3, 3] (OC=16, 1, KH=3, KW=3)
# Expected after transform: [3, 3, 16] (KH, KW, OC)
self.assertEqual(weight_shape[0], 3)  # KH
self.assertEqual(weight_shape[1], 3)  # KW
self.assertEqual(weight_shape[2], 16)  # OC

# Validate numerical accuracy
validate(
gm,
gm_after_replacement,
placeholders,
"ReplaceConvWithChannelLastConvPass",
)

class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
def create_slice_graph(
self,
input_shape: Sequence[int],
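As a closing aside (not part of the diff), the value-preservation property the new test relies on: each output channel's kernel survives the [OC, 1, KH, KW] -> [KH, KW, OC] transform intact.

import torch

w = torch.randn(16, 1, 3, 3)
w_hwc = w.permute(2, 3, 0, 1).squeeze(-1)
assert w_hwc.shape == (3, 3, 16)
for oc in range(16):
    # Kernel values are relocated, not changed.
    assert torch.equal(w_hwc[:, :, oc], w[oc, 0])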