Skip to content

Commit

Permalink
Fast 1x1 Conv DSP case uses unordered im2col (#158)
Browse files Browse the repository at this point in the history
This is done for armclang only, and not for GCC, based on performance
measurements. A fallback to the old path is also provided for the case where
no additional buffer is supplied. This may happen when relying on
arm_convolve_wrapper_s8_get_buffer_size_dsp() for the buffer size. Note
that it is recommended to use arm_convolve_wrapper_s8_get_buffer_size().

Related github issue: #44
  • Loading branch information
mansnils authored Nov 11, 2024
1 parent 628c103 commit 22080c6
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 29 deletions.
112 changes: 106 additions & 6 deletions Source/ConvolutionFunctions/arm_convolve_1x1_s8_fast.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_convolve_1x1_s8_fast.c
* Description: Fast s8 version of 1x1 convolution (non-square shape)
*
* $Date: 04 January 2024
* $Revision: V.3.5.0
* $Date: 05 November 2024
* $Revision: V.3.6.0
*
* Target : Arm(R) M-Profile Architecture
*
Expand All @@ -46,7 +46,6 @@
* Refer header file for details.
*
*/

arm_cmsis_nn_status arm_convolve_1x1_s8_fast(const cmsis_nn_context *ctx,
const cmsis_nn_conv_params *conv_params,
const cmsis_nn_per_channel_quant_params *quant_params,
Expand All @@ -65,13 +64,114 @@ arm_cmsis_nn_status arm_convolve_1x1_s8_fast(const cmsis_nn_context *ctx,
return ARM_CMSIS_NN_ARG_ERROR;
}

(void)ctx;
(void)filter_dims;
(void)bias_dims;

const int32_t lhs_rows = input_dims->w * input_dims->h * input_dims->n;
const int32_t rhs_rows = output_dims->c;
const int32_t rhs_cols = input_dims->c;
const int32_t rhs_rows = output_dims->c;
int32_t lhs_rows = input_dims->w * input_dims->h * input_dims->n;

#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI) && defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
if (ctx->buf != NULL) /* Fall back to non buffered version if no additional memory buffer provided */
{
const int32_t batch = input_dims->n;
const int32_t output_h = output_dims->h;
const int32_t output_w = output_dims->w;
const int32_t input_inc = input_dims->w * rhs_cols;

for (int i_batch = 0; i_batch < batch; i_batch++)
{
const int32_t output_ch = output_dims->c;
const int8_t *ip = input_data;
int16_t *buffer_a = (int16_t *)ctx->buf;
int16_t *im2col_buf = (int16_t *)ctx->buf;
int8_t *out = output_data;
lhs_rows = 0;

for (int i_out_y = 0; i_out_y < output_h; i_out_y++, ip += input_inc)
{
for (int32_t k_x = 0, i_out_x = 0; i_out_x < output_w; i_out_x++, k_x += rhs_cols)
{
arm_s8_to_s16_unordered_with_offset(ip + k_x, im2col_buf, rhs_cols, conv_params->input_offset);
im2col_buf += rhs_cols;
lhs_rows++;
if (lhs_rows == 2)
{
out = arm_nn_mat_mult_kernel_s8_s16(filter_data,
buffer_a,
output_ch,
quant_params->shift,
quant_params->multiplier,
conv_params->output_offset,
conv_params->activation.min,
conv_params->activation.max,
rhs_cols,
rhs_cols,
bias_data,
out);
im2col_buf = buffer_a;
lhs_rows = 0;
}
}
if (out == NULL)
{
return ARM_CMSIS_NN_NO_IMPL_ERROR;
}
}

/* Handle left over columns */
if (lhs_rows != 0)
{
const int8_t *ker_a = filter_data;
for (int i = 0; i < output_ch; i++)
{
/* Load the accumulator with bias first */
int32_t sum = 0;
if (bias_data)
{
sum = bias_data[i];
}
const int16_t *ip_as_col = buffer_a;

/* 4 multiply and accumulates are done in one loop. */
uint16_t col_count = rhs_cols >> 2;
while (col_count)
{
int32_t ker_a1, ker_a2;
int32_t ip_b1, ip_b2;
ker_a = read_and_pad_reordered(ker_a, &ker_a1, &ker_a2);
ip_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
sum = SMLAD(ker_a1, ip_b1, sum);
ip_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
sum = SMLAD(ker_a2, ip_b2, sum);
col_count--;
}

/* Handle left over mac */
col_count = rhs_cols & 0x3;
while (col_count)
{
int8_t ker_a1 = *ker_a++;
int16_t ip_b1 = *ip_as_col++;
sum += ker_a1 * ip_b1;
col_count--;
}
sum = arm_nn_requantize(sum, quant_params->multiplier[i], quant_params->shift[i]);
sum += conv_params->output_offset;
sum = MAX(sum, conv_params->activation.min);
sum = MIN(sum, conv_params->activation.max);
*out++ = (int8_t)sum;
}
}
/* Advance to the next batch */
input_data += (input_dims->w * input_dims->h * rhs_cols);
output_data += (output_w * output_h * output_ch);
}
return ARM_CMSIS_NN_SUCCESS;
}
#else
(void)ctx;
#endif

arm_nn_mat_mult_nt_t_s8(input_data,
filter_data,
Expand Down
42 changes: 39 additions & 3 deletions Source/ConvolutionFunctions/arm_convolve_get_buffer_sizes_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_convolve_get_buffer_sizes_s8.c
* Description: Collection of get buffer size functions for the various s8 convolution layer functions.
*
* $Date: 28 March 2024
* $Revision: V.2.1.1
* $Date: 31 October 2024
* $Revision: V.2.2.1
*
* Target : Arm(R) M-Profile Architecture
*
Expand All @@ -40,6 +40,15 @@
* @addtogroup GetBufferSizeNNConv
* @{
*/
/*
 * Scratch buffer size, in bytes, for the DSP variant of the fast 1x1 s8 convolution.
 *
 * When built with a toolchain that defines __ARMCC_VERSION >= 6010050 the result
 * is room for two int16_t vectors of input_dims->c elements each. With any other
 * toolchain no scratch buffer is requested and 0 is returned.
 *
 * input_dims  Input tensor dimensions; only the channel count (c) is read.
 */
__STATIC_INLINE int32_t arm_convolve_1x1_s8_fast_get_buffer_size_dsp(const cmsis_nn_dims *input_dims)
{
#if !defined(__ARMCC_VERSION) || (__ARMCC_VERSION < 6010050)
    /* No scratch memory is requested for this toolchain. */
    (void)input_dims;
    return 0;
#else
    /* Two int16_t elements per input channel. */
    const int32_t element_count = 2 * input_dims->c;
    return element_count * (int32_t)sizeof(int16_t);
#endif
}

__STATIC_INLINE int32_t arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
const cmsis_nn_dims *filter_dims)
Expand Down Expand Up @@ -112,7 +121,11 @@ int32_t arm_convolve_1_x_n_s8_get_buffer_size(const cmsis_nn_conv_params *conv_p

/*
 * Scratch buffer size, in bytes, to pass to the fast 1x1 s8 convolution.
 *
 * For DSP builds without MVE (ARM_MATH_DSP defined, ARM_MATH_MVEI not defined)
 * the size is whatever arm_convolve_1x1_s8_fast_get_buffer_size_dsp() reports.
 * Every other configuration returns 0.
 *
 * input_dims  Input tensor dimensions, forwarded to the DSP helper when used.
 */
int32_t arm_convolve_1x1_s8_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
{
#if !defined(ARM_MATH_DSP) || defined(ARM_MATH_MVEI)
    /* No scratch buffer is requested in this configuration. */
    (void)input_dims;
    return 0;
#else
    const int32_t dsp_buffer_size = arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
    return dsp_buffer_size;
#endif
}

Expand All @@ -130,6 +143,8 @@ int32_t arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params *conv
{
#if defined(ARM_MATH_MVEI)
return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
#elif defined(ARM_MATH_DSP)
return arm_convolve_wrapper_s8_get_buffer_size_dsp(conv_params, input_dims, filter_dims, output_dims);
#else
(void)output_dims;
if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
Expand Down Expand Up @@ -190,7 +205,28 @@ int32_t arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params *
const cmsis_nn_dims *filter_dims,
const cmsis_nn_dims *output_dims)
{
return arm_convolve_wrapper_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
(void)output_dims;
if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
(filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
{
if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
{
return arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
}
else
{
return 0;
}
}
else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
(conv_params->stride.w * input_dims->c % 4 == 0))
{
return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
}
else
{
return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
* SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
*
* SPDX-License-Identifier: Apache-2.0
*
Expand Down Expand Up @@ -419,27 +419,30 @@ void buffer_size_arm_convolve_1x1_s8_fast(void)
cmsis_nn_dims filter_dims;
cmsis_nn_dims output_dims;

input_dims.n = KERNEL1X1_STRIDE_X_Y_2_INPUT_BATCHES;
input_dims.w = KERNEL1X1_STRIDE_X_Y_2_INPUT_W;
input_dims.h = KERNEL1X1_STRIDE_X_Y_2_INPUT_H;
input_dims.c = KERNEL1X1_STRIDE_X_Y_2_IN_CH;
filter_dims.w = KERNEL1X1_STRIDE_X_Y_2_FILTER_X;
filter_dims.h = KERNEL1X1_STRIDE_X_Y_2_FILTER_Y;
output_dims.w = KERNEL1X1_STRIDE_X_Y_2_OUTPUT_W;
output_dims.h = KERNEL1X1_STRIDE_X_Y_2_OUTPUT_H;
output_dims.c = KERNEL1X1_STRIDE_X_Y_2_OUT_CH;
input_dims.n = KERNEL1X1_INPUT_BATCHES;
input_dims.w = KERNEL1X1_INPUT_W;
input_dims.h = KERNEL1X1_INPUT_H;
input_dims.c = KERNEL1X1_IN_CH;
filter_dims.w = KERNEL1X1_FILTER_X;
filter_dims.h = KERNEL1X1_FILTER_Y;
output_dims.w = KERNEL1X1_OUTPUT_W;
output_dims.h = KERNEL1X1_OUTPUT_H;
output_dims.c = KERNEL1X1_OUT_CH;

conv_params.padding.w = KERNEL1X1_STRIDE_X_Y_2_PAD_X;
conv_params.padding.h = KERNEL1X1_STRIDE_X_Y_2_PAD_Y;
conv_params.stride.w = KERNEL1X1_STRIDE_X_Y_2_STRIDE_X;
conv_params.stride.h = KERNEL1X1_STRIDE_X_Y_2_STRIDE_Y;
conv_params.dilation.w = KERNEL1X1_STRIDE_X_Y_2_DILATION_X;
conv_params.dilation.h = KERNEL1X1_STRIDE_X_Y_2_DILATION_Y;
conv_params.padding.w = KERNEL1X1_PAD_X;
conv_params.padding.h = KERNEL1X1_PAD_Y;
conv_params.stride.w = KERNEL1X1_STRIDE_X;
conv_params.stride.h = KERNEL1X1_STRIDE_Y;
conv_params.dilation.w = KERNEL1X1_DILATION_X;
conv_params.dilation.h = KERNEL1X1_DILATION_Y;

conv_params.input_offset = KERNEL1X1_STRIDE_X_Y_2_INPUT_OFFSET;
conv_params.output_offset = KERNEL1X1_STRIDE_X_Y_2_OUTPUT_OFFSET;
conv_params.activation.min = KERNEL1X1_STRIDE_X_Y_2_OUT_ACTIVATION_MIN;
conv_params.activation.max = KERNEL1X1_STRIDE_X_Y_2_OUT_ACTIVATION_MAX;
conv_params.input_offset = KERNEL1X1_INPUT_OFFSET;
conv_params.output_offset = KERNEL1X1_OUTPUT_OFFSET;
conv_params.activation.min = KERNEL1X1_OUT_ACTIVATION_MIN;
conv_params.activation.max = KERNEL1X1_OUT_ACTIVATION_MAX;

TEST_ASSERT_EQUAL(conv_params.stride.w, 1);
TEST_ASSERT_EQUAL(conv_params.stride.h, 1);

const int32_t buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
const int32_t wrapper_buf_size =
Expand Down

0 comments on commit 22080c6

Please sign in to comment.