310 changes: 266 additions & 44 deletions backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp
@@ -175,6 +175,45 @@ void xa_opt_quantized_conv2d_nchw(
bool conv1d = input.dim() == 3;
constexpr int kNnlibMaxDim = 4;

WORD32 input_height = conv1d ? 1 : input.size(2);
WORD32 input_width = conv1d ? input.size(2) : input.size(3);
WORD32 input_channels = input.size(1);
WORD32 kernel_height = conv1d ? 1 : weight.size(2);
WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3);
WORD32 kernel_channels = weight.size(1);
WORD32 out_channels = weight.size(0);
WORD32 out_height = conv1d ? 1 : out.size(2);
WORD32 out_width = conv1d ? out.size(2) : out.size(3);
WORD32 batches = input.size(0);

WORD32 x_stride = stride[1];
WORD32 y_stride = stride[0];
WORD32 x_padding = padding[1];
WORD32 y_padding = padding[0];
WORD32 dilation_width = dilation[1];
WORD32 dilation_height = dilation[0];
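For reference, the spatial output sizes read back from out above should satisfy the standard convolution arithmetic; a minimal sketch of that relationship (conv_out_extent is a hypothetical helper, not part of this file):

// Sketch: expected output extent along one spatial axis of a dilated,
// strided, padded convolution (standard PyTorch semantics).
static inline WORD32 conv_out_extent(
    WORD32 in_extent, WORD32 pad, WORD32 dilation, WORD32 kernel, WORD32 stride) {
  return (in_extent + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}
// e.g. out_height == conv_out_extent(
//          input_height, y_padding, dilation_height, kernel_height, y_stride);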

WORD32 input_zero_bias = -in_zero_point;
WORD32 kernel_zero_bias = -weight_zero_point;

WORD32 out_multiplier32[out_channels];
WORD32 out_shift32[out_channels];

float out_scale = 1. / output_scale;

for (int i = 0; i < out_channels; i++) {
  // Fold the combined requantization scale (bias_scale / output_scale) into
  // a Q31 fixed-point multiplier; the shift is pinned to 0, which assumes
  // the combined scale stays below 1.0.
  out_multiplier32[i] = bias_scale * out_scale * 2147483648;
  out_shift32[i] = 0;
}
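A more general decomposition, in the usual TFLite style, normalizes the scale into a mantissa and a base-2 exponent instead of pinning the shift to 0. A sketch under that assumption (quantize_multiplier is a hypothetical helper, not part of this file):

#include <math.h>
// Sketch: split a positive real scale into a Q31 multiplier and a shift so
// that scale ~= multiplier * 2^(shift - 31).
static void quantize_multiplier(float scale, WORD32* multiplier, WORD32* shift) {
  int exp = 0;
  float q = frexpf(scale, &exp);            // q in [0.5, 1), scale == q * 2^exp
  long long q_fixed = (long long)roundf(q * 2147483648.f); // q * 2^31
  if (q_fixed == (1LL << 31)) {             // rounding landed on exactly 1.0 * 2^31
    q_fixed /= 2;
    ++exp;
  }
  *multiplier = (WORD32)q_fixed;
  *shift = exp;
}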

WORD32 out_zero_bias = output_zero_point;
WORD32 inp_precision = 8;
WORD32 kernel_precision = 8;
pVOID p_scratch = nullptr;
WORD32* ptr_scratch;

WORD32 scratch_size = 0;

if (input.scalar_type() == ScalarType::Char) {
WORD8* __restrict__ p_out =
(WORD8* __restrict__)out.mutable_data_ptr<int8_t>();
@@ -185,48 +224,6 @@ void xa_opt_quantized_conv2d_nchw(
WORD32* __restrict__ p_bias =
(WORD32* __restrict__)bias.const_data_ptr<int32_t>();

WORD32 input_height = conv1d ? 1 : input.size(2);
WORD32 input_width = conv1d ? input.size(2) : input.size(3);
WORD32 input_channels = input.size(1);
WORD32 kernel_height = conv1d ? 1 : weight.size(2);
WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3);
WORD32 kernel_channels = weight.size(1);
WORD32 out_channels = weight.size(0);
WORD32 out_height = conv1d ? 1 : out.size(2);
WORD32 out_width = conv1d ? out.size(2) : out.size(3);
WORD32 batches = input.size(0);

WORD32 x_stride = stride[1];
WORD32 y_stride = stride[0];
WORD32 x_padding = padding[1];
WORD32 y_padding = padding[0];
WORD32 dilation_width = dilation[1];
WORD32 dilation_height = dilation[0];

// WORD32* kernel_bias_ptr =
// (WORD32*)weight_zero_point.const_data_ptr<int32_t>();

WORD32 input_zero_bias = -in_zero_point;
WORD32 kernel_zero_bias = -weight_zero_point;

WORD32 out_multiplier32[out_channels];
WORD32 out_shift32[out_channels];

float out_scale = 1. / output_scale;

for (int i = 0; i < out_channels; i++) {
out_multiplier32[i] = bias_scale * out_scale * 2147483648;
out_shift32[i] = 0;
}

WORD32 out_zero_bias = output_zero_point;
WORD32 inp_precision = 8;
WORD32 kernel_precision = 8;
pVOID p_scratch = nullptr;
WORD32* ptr_scratch;

WORD32 scratch_size = 0;

if (groups == 1) {
WORD32 out_data_format = 1;

@@ -245,13 +242,13 @@ void xa_opt_quantized_conv2d_nchw(
WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8);

WORD32 p_inp_shape[kNnlibMaxDim];
p_inp_shape[0] = input.size(0);
p_inp_shape[0] = batches;
p_inp_shape[1] = input_channels;
p_inp_shape[2] = input_height;
p_inp_shape[3] = input_width;

WORD32 p_out_shape[kNnlibMaxDim];
p_out_shape[0] = input.size(0);
p_out_shape[0] = batches;
p_out_shape[1] = input_height;
p_out_shape[2] = input_width;
p_out_shape[3] = input_channels;
@@ -439,6 +436,231 @@ void xa_opt_quantized_conv2d_nchw(
return;
}
}

if (input.scalar_type() == ScalarType::Byte) {
UWORD8* __restrict__ p_out =
(UWORD8* __restrict__)out.mutable_data_ptr<uint8_t>();
UWORD8* __restrict__ p_inp =
(UWORD8* __restrict__)input.const_data_ptr<uint8_t>();
UWORD8* __restrict__ p_kernel =
(UWORD8* __restrict__)weight.const_data_ptr<uint8_t>();
WORD32* __restrict__ p_bias =
(WORD32* __restrict__)bias.const_data_ptr<int32_t>();

// The per-tensor asym8u kernels below take a single requantization
// multiplier and shift, so only channel 0's values are used.
WORD32 out_multiplier = out_multiplier32[0];
WORD32 out_shift = out_shift32[0];

if (groups == 1) {
WORD32 out_data_format = 1;

UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory(
ctx,
((batches * input_channels * input_height * input_width) + 8) *
sizeof(UWORD8));

UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory(
ctx,
((out_channels * kernel_channels * kernel_height * kernel_width) +
8) *
sizeof(UWORD8));

UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8);
UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8);
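ALIGN_PTR rounds the raw allocation up to the requested power-of-two boundary before the buffers are handed to the NNLib kernels. Its actual definition lives elsewhere in this backend; a plausible equivalent, shown only as an assumption to make the intent explicit:

#include <stdint.h>
// Assumed shape of the alignment helper used throughout this file: round
// ptr up to the next multiple of bytes (which must be a power of two).
#define ALIGN_PTR(ptr, bytes) \
  ((void*)((((uintptr_t)(ptr)) + ((bytes)-1)) & ~((uintptr_t)((bytes)-1))))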

WORD32 p_inp_shape[kNnlibMaxDim];
p_inp_shape[0] = batches;
p_inp_shape[1] = input_channels;
p_inp_shape[2] = input_height;
p_inp_shape[3] = input_width;

WORD32 p_out_shape[kNnlibMaxDim];
p_out_shape[0] = batches;
p_out_shape[1] = input_height;
p_out_shape[2] = input_width;
p_out_shape[3] = input_channels;

WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1};

xa_nn_transpose_8_8(
(WORD8*)pin,
p_out_shape,
(WORD8*)p_inp,
p_inp_shape,
p_permute_vec,
kNnlibMaxDim,
kNnlibMaxDim);
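The permute vector {0, 2, 3, 1} moves the channel axis last, converting the activations from NCHW to the NHWC layout that the asym8u convolution kernel consumes. The same permutation written as plain loops, illustrative only and equivalent to the call above:

// Reference semantics of the transpose above: out[n][h][w][c] = in[n][c][h][w].
for (int n = 0; n < batches; ++n)
  for (int c = 0; c < input_channels; ++c)
    for (int h = 0; h < input_height; ++h)
      for (int w = 0; w < input_width; ++w)
        pin[((n * input_height + h) * input_width + w) * input_channels + c] =
            p_inp[((n * input_channels + c) * input_height + h) * input_width + w];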

WORD32 p_inp_shape1[kNnlibMaxDim];
p_inp_shape1[0] = out_channels;
p_inp_shape1[1] = kernel_channels;
p_inp_shape1[2] = kernel_height;
p_inp_shape1[3] = kernel_width;

WORD32 p_out_shape1[kNnlibMaxDim];
p_out_shape1[0] = out_channels;
p_out_shape1[1] = kernel_height;
p_out_shape1[2] = kernel_width;
p_out_shape1[3] = kernel_channels;

xa_nn_transpose_8_8(
(WORD8*)pkernel,
p_out_shape1,
(WORD8*)p_kernel,
p_inp_shape1,
p_permute_vec,
kNnlibMaxDim,
kNnlibMaxDim);

scratch_size = xa_nn_conv2d_getsize(
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
kernel_channels,
dilation_height,
dilation_width,
y_stride,
y_padding,
x_stride,
x_padding,
out_height,
out_width,
out_channels,
inp_precision,
kernel_precision,
out_data_format);

// getsize() returns a negative value on failure; clamp to zero so the
// scratch allocation below still receives a well-defined size.
scratch_size = scratch_size < 0 ? 0 : scratch_size;

ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);

p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);

for (int _n = 0; _n < batches; _n++) {
UWORD8* in_batch =
pin + _n * input_channels * input_height * input_width;
UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_std_asym8uxasym8u(
out_batch,
in_batch,
pkernel,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
out_channels,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
kernel_zero_bias,
out_multiplier,
out_shift,
out_zero_bias,
out_data_format,
p_scratch);
}
return;
}

if (groups == input_channels) {
WORD32 channels_multiplier = out_channels / input_channels;
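In this depthwise branch each input channel forms its own group and fans out to channels_multiplier output channels, so out_channels must divide evenly by input_channels (e.g. 32 input channels and 64 output channels give a multiplier of 2). A hypothetical guard, not in the original:

// Sketch of the preconditions this branch relies on (requires <assert.h>).
assert(groups == input_channels);            // one group per input channel
assert(out_channels % input_channels == 0);  // integral fan-out per channel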

scratch_size = xa_nn_conv2d_depthwise_getsize(
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
channels_multiplier,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
inp_precision,
1); // NCHW

scratch_size = scratch_size < 0 ? 0 : scratch_size;

ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size);

p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);

UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory(
ctx,
((batches * out_channels * out_height * out_width) + 8) *
sizeof(UWORD8));

UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8);

for (int _n = 0; _n < batches; _n++) {
UWORD8* in_batch =
p_inp + _n * input_channels * input_height * input_width;
UWORD8* out_batch =
p_out_temp + _n * out_channels * out_height * out_width;

xa_nn_conv2d_depthwise_asym8uxasym8u(
out_batch,
p_kernel,
in_batch,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
channels_multiplier,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
kernel_zero_bias,
out_multiplier,
out_shift,
out_zero_bias,
1, // NCHW
0, // NHWC
p_scratch);
}

WORD32 p_inp_shape[kNnlibMaxDim];
p_inp_shape[0] = batches;
p_inp_shape[1] = out_height;
p_inp_shape[2] = out_width;
p_inp_shape[3] = out_channels;

WORD32 p_out_shape[kNnlibMaxDim];
p_out_shape[0] = batches;
p_out_shape[1] = out_channels;
p_out_shape[2] = out_height;
p_out_shape[3] = out_width;

WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2};

xa_nn_transpose_8_8(
(WORD8*)p_out,
p_out_shape,
(WORD8*)p_out_temp,
p_inp_shape,
p_permute_vec,
kNnlibMaxDim,
kNnlibMaxDim);
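The {0, 3, 1, 2} vector is the inverse of the {0, 2, 3, 1} permutation used on the way in, restoring the NHWC depthwise output to NCHW. A quick sanity sketch that the two compose to the identity:

// perm[i] names the source axis that lands at destination axis i.
const int fwd[4] = {0, 2, 3, 1}; // NCHW -> NHWC
const int inv[4] = {0, 3, 1, 2}; // NHWC -> NCHW
for (int i = 0; i < 4; ++i)
  assert(fwd[inv[i]] == i); // composing the two yields the identity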

return;
}
}
}

// The quantized convolution kernel. in_scale and weight_scale are implicit in