From 2722eabe8cbfd2583b1efd481bbc4e49ed2e16fc Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 13 Feb 2026 15:38:15 -0800 Subject: [PATCH 1/2] Migrate quantized_conv2d tests to graph builder (#17451) Summary: As titled. Much easier to debug that way. Retains the initial tests, adds a few more for 1D cases (will migrate to true 1D kernels soon). We also update the nchw ops to handle `uint8_t` (that got lost in translation when splitting the ops probably). Reviewed By: DrJessop Differential Revision: D93112638 --- .../op_quantized_conv2d_nchw_out.cpp | 310 +++++++++++++++--- 1 file changed, 266 insertions(+), 44 deletions(-) diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp index a17f1e6a3bc..6035b545ad8 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp @@ -175,6 +175,45 @@ void xa_opt_quantized_conv2d_nchw( bool conv1d = input.dim() == 3; constexpr int kNnlibMaxDim = 4; + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + if (input.scalar_type() == ScalarType::Char) { WORD8* __restrict__ p_out = (WORD8* __restrict__)out.mutable_data_ptr(); @@ -185,48 +224,6 @@ void xa_opt_quantized_conv2d_nchw( WORD32* __restrict__ p_bias = (WORD32* __restrict__)bias.const_data_ptr(); - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? 
out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - // WORD32* kernel_bias_ptr = - // (WORD32*)weight_zero_point.const_data_ptr(); - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - if (groups == 1) { WORD32 out_data_format = 1; @@ -245,13 +242,13 @@ void xa_opt_quantized_conv2d_nchw( WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = input.size(0); + p_inp_shape[0] = batches; p_inp_shape[1] = input_channels; p_inp_shape[2] = input_height; p_inp_shape[3] = input_width; WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = input.size(0); + p_out_shape[0] = batches; p_out_shape[1] = input_height; p_out_shape[2] = input_width; p_out_shape[3] = input_channels; @@ -439,6 +436,231 @@ void xa_opt_quantized_conv2d_nchw( return; } } + + if (input.scalar_type() == ScalarType::Byte) { + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 out_multiplier = out_multiplier32[0]; + WORD32 out_shift = out_shift32[0]; + + if (groups == 1) { + WORD32 out_data_format = 1; + + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * input_channels * input_height * input_width) + 8) * + sizeof(UWORD8)); + + UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_height * kernel_width) + + 8) * + sizeof(UWORD8)); + + UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); + UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_height; + p_inp_shape[3] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = input_height; + p_out_shape[2] = input_width; + p_out_shape[3] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; + + xa_nn_transpose_8_8( + (WORD8*)pin, + p_out_shape, + (WORD8*)p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_height; + p_inp_shape1[3] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_height; + p_out_shape1[2] = kernel_width; + p_out_shape1[3] = kernel_channels; + + xa_nn_transpose_8_8( + (WORD8*)pkernel, + p_out_shape1, + (WORD8*)p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + 
kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = + pin + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_std_asym8uxasym8u( + out_batch, + in_batch, + pkernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier, + out_shift, + out_zero_bias, + out_data_format, + p_scratch); + } + return; + } + + if (groups == input_channels) { + WORD32 channels_multiplier = out_channels / input_channels; + + scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * + sizeof(UWORD8)); + + UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = + p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_asym8uxasym8u( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier, + out_shift, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + (WORD8*)p_out, + p_out_shape, + (WORD8*)p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + return; + } + } } // The quantized convolution kernel. in_scale and weight_scale are implicit in From a6d7b4ac504c1f619fa59d753258c073b49fa37c Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 13 Feb 2026 15:38:15 -0800 Subject: [PATCH 2/2] Fix ReplaceConvWithChannelLastConvPass for nhwc depthwise convolutions (#17460) Summary: As titled. The current pass doesn't handle the depthwise weights with the right shape, leading to numerical errors. Add a test to catch that as well. 
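
For reference, a minimal standalone sketch of the weight-layout transform this fix introduces (plain `torch` with hypothetical tensor names, not the pass code itself, which uses `permute_copy` + `squeeze_copy` edge ops): the depthwise weight arrives as OIHW `[OC, 1, KH, KW]`, while NNLib's NHWC depthwise kernel expects a 3D `[KH, KW, OC]` layout, so the weight is permuted and the trailing singleton dim squeezed.

```python
import torch

# Hypothetical illustration of the depthwise weight transform performed by
# ReplaceConvWithChannelLastConvPass (permute_copy followed by squeeze_copy).
oc, kh, kw = 16, 3, 3
w_oihw = torch.randn(oc, 1, kh, kw)   # depthwise weight in OIHW: [OC, 1, KH, KW]

# [OC, 1, KH, KW] -> [KH, KW, OC, 1]
w_perm = w_oihw.permute(2, 3, 0, 1)
# [KH, KW, OC, 1] -> [KH, KW, OC], the 3D layout NNLib expects for NHWC depthwise
w_hwc = w_perm.squeeze(-1)

assert w_hwc.shape == (kh, kw, oc)
```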
Reviewed By: hsharma35, DrJessop Differential Revision: D93188974 --- backends/cadence/aot/ops_registrations.py | 16 +- backends/cadence/aot/replace_ops.py | 48 +++++- .../aot/tests/test_replace_ops_passes.py | 139 +++++++++++++++++- .../op_quantized_conv2d_nhwc_out.cpp | 11 +- 4 files changed, 200 insertions(+), 14 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 9b4568e008d..3c6575bb43c 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -1160,13 +1160,25 @@ def quantized_conv2d_nhwc_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: - out_channels, *kernel_size, _ = weight.shape - in_size = input.shape # Assert that the input tensor has at least 3 dimensions, and at most 6 assert len(in_size) > 2 assert len(in_size) < 6 + # Determine weight layout based on input and weight dimensions: + # - 1D conv: input is 3D, weight is 3D [OC, K, IC] + # - 2D depthwise conv: input is 4D, weight is 3D [KH, KW, OC] + # - 2D regular conv: input is 4D, weight is 4D [OC, KH, KW, IC] + if len(in_size) == 3: + # 1D conv: weight is [OC, K, IC] + out_channels, *kernel_size, _ = weight.shape + elif len(weight.shape) == 3: + # 2D depthwise conv: weight is [KH, KW, OC] + *kernel_size, out_channels = weight.shape + else: + # 2D regular conv: weight is [OC, KH, KW, IC] + out_channels, *kernel_size, _ = weight.shape + # Compute the output tensor size output_size = ( get_conv1d_output_size( diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 9f7dab28fc2..095dbfe969d 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1071,6 +1071,32 @@ def _change_nhwc_to_nchw( permute_node.meta = node.meta return permute_node + def _change_depthwise_weight_to_hwc( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert depthwise weight from OIHW [OC, 1, KH, KW] to HWC [KH, KW, OC]. + + NNLib depthwise convolution expects weights in [KH, KW, OC] format when + inp_data_format=0 (NHWC), but the standard NCHW->NHWC permutation produces + [OC, KH, KW, 1]. This function applies the correct permutation for depthwise + convolution weights. 
+ """ + # For depthwise: input shape is [OC, 1, KH, KW], target is [KH, KW, OC] + # Permute [0, 1, 2, 3] -> [2, 3, 0, 1] gives [KH, KW, OC, 1] + # Then squeeze the last dim (which is 1) to get [KH, KW, OC] + permute_indices = [2, 3, 0, 1] + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} + ) + permute_node.meta = node.meta + + # Squeeze the last dimension (which has size 1) + squeeze_node = graph.call_function( + exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {} + ) + squeeze_node.meta = node.meta + return squeeze_node + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: assert isinstance(node.target, EdgeOpOverload) quantized_op = ( @@ -1093,12 +1119,30 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: input_node = cast(torch.fx.Node, node.args[0]) weight_node = cast(torch.fx.Node, node.args[1]) + # Check if this is a depthwise convolution (groups == input_channels) + # and weight is 4D with shape [OC, 1, KH, KW] + groups = node.args[6] + input_shape = input_node.meta["val"].shape + weight_shape = weight_node.meta["val"].shape + input_channels = input_shape[1] # NCHW format, channels at index 1 + # Depthwise conv has 4D weight [OC, 1, KH, KW] where the IC dim is 1 + is_depthwise = ( + groups == input_channels + and len(weight_shape) == 4 + and weight_shape[1] == 1 + ) + # Insert transpose operations before the node with graph.inserting_before(node): # Convert input from NCHW to NHWC input_nhwc = self._change_nchw_to_nhwc(graph, input_node) - # Convert weight from NCHW to NHWC - weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) + # Convert weight from NCHW to the appropriate format + if is_depthwise: + # For depthwise: [OC, 1, KH, KW] -> [KH, KW, OC] for NNLib + weight_nhwc = self._change_depthwise_weight_to_hwc(graph, weight_node) + else: + # For regular conv: [OC, IC, KH, KW] -> [OC, KH, KW, IC] + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) # Non-quantized ops need to set the last optional argument to True channel_last_arg = [] if quantized_op else [True] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 60a24b82556..3804759c570 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -2019,8 +2019,145 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: "ReplaceConvWithChannelLastConvPass", ) + def create_depthwise_convolution_graph_module( + self, + ) -> Tuple[Tuple[torch.Tensor, ...], torch.fx.GraphModule]: + """Helper to create a depthwise convolution node. + + For depthwise convolution, groups == input_channels. 
+ Input shape: [N, C, H, W] = [1, 8, 224, 56] (NCHW) + Weight shape: [OC, 1, KH, KW] = [16, 1, 3, 3] where OC = C * channel_multiplier + """ + in_channels = 8 + out_channels = 16 + x = torch.randn(1, in_channels, 224, 56) + # Depthwise: weight shape is [out_channels, 1, kernel_h, kernel_w] + w = torch.randn(out_channels, 1, 3, 3) + b = torch.randn(out_channels) + stride = (1, 1) + padding = (1, 1) + dilation = (1, 1) + groups = in_channels # Depthwise: groups == input_channels + input_zero_point = 0 + w_zero_point = 0 + b_scale = 10 + out_scale = 1 + out_zero_point = 0 + out_multiplier = 5 + out_shift = 5 + args = ( + x, + w, + b, + stride, + padding, + dilation, + groups, + input_zero_point, + w_zero_point, + b_scale, + out_scale, + out_zero_point, + out_multiplier, + out_shift, + ) + placeholders = (x, w, b) + gm = single_op_builder( + placeholders=placeholders, + op=exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + args=args, + ) + return placeholders, gm + + def test_depthwise_convolution_weight_shape(self) -> None: + """Test that depthwise conv weight is transformed to [KH, KW, OC] format. + + For depthwise convolution with NHWC layout, NNLib expects weights in + [KH, KW, OC] format (3D), not [OC, KH, KW, 1] (4D standard NCHW->NHWC). + + The pass should: + 1. Detect depthwise convolution (groups == input_channels) + 2. Transform weight from [OC, 1, KH, KW] to [KH, KW, OC] (3D) + 3. Use permute_copy + squeeze_copy operations + """ + placeholders, gm = self.create_depthwise_convolution_graph_module() + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1 + ) + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + self.assertEqual(count_node(gm, exir_ops.edge.aten.squeeze_copy.dim), 0) + + # Apply replacement pass. 
+ p = ReplaceConvWithChannelLastConvPass() + gm_after_replacement = p.call(gm).graph_module + + # Verify the quantized_conv2d_nhwc node exists + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + ), + 1, + ) + + # For depthwise conv: + # - Input: 1 permute (NCHW -> NHWC) + # - Weight: 1 permute ([OC, 1, KH, KW] -> [KH, KW, OC, 1]) + # - Output: 1 permute (NHWC -> NCHW) + # Total: 3 permutes + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), + 3, + ) + + # For depthwise conv, weight should also have squeeze_copy to go from + # [KH, KW, OC, 1] to [KH, KW, OC] (3D) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim), + 1, + ) + + # Find the weight node being passed to the quantized_conv2d_nhwc + for node in gm_after_replacement.graph.nodes: + if node.target != exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: + continue + + # The weight argument (index 1) should come from squeeze_copy for depthwise + weight_node = node.args[1] + self.assertEqual( + weight_node.target, + exir_ops.edge.aten.squeeze_copy.dim, + "Depthwise conv weight should be processed by squeeze_copy", + ) + + # The squeeze_copy input should come from permute_copy + permute_node = weight_node.args[0] + self.assertEqual( + permute_node.target, + exir_ops.edge.aten.permute_copy.default, + "squeeze_copy input should come from permute_copy", + ) -class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase): + # Verify the weight shape after transformation is 3D [KH, KW, OC] + weight_shape = weight_node.meta["val"].shape + self.assertEqual( + len(weight_shape), + 3, + f"Depthwise weight should be 3D [KH, KW, OC], got {len(weight_shape)}D", + ) + # Original weight: [16, 1, 3, 3] (OC=16, 1, KH=3, KW=3) + # Expected after transform: [3, 3, 16] (KH, KW, OC) + self.assertEqual(weight_shape[0], 3) # KH + self.assertEqual(weight_shape[1], 3) # KW + self.assertEqual(weight_shape[2], 16) # OC + + # Validate numerical accuracy + validate( + gm, + gm_after_replacement, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) def create_slice_graph( self, input_shape: Sequence[int], diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp index b2a7c341997..4a620f346b6 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp @@ -284,7 +284,7 @@ void xa_opt_quantized_conv2d_nhwc( if (groups == input_channels) { WORD32 channels_multiplier = out_channels / input_channels; - + scratch_size = xa_nn_conv2d_depthwise_getsize( input_height, input_width, @@ -307,18 +307,11 @@ void xa_opt_quantized_conv2d_nhwc( p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * - sizeof(WORD8)); - - WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); - for (int _n = 0; _n < batches; _n++) { WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; WORD8* out_batch = - p_out_temp + _n * out_channels * out_height * out_width; + p_out + _n * out_channels * out_height * out_width; xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( out_batch,