esp_nn_conv_opt.c 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <esp_nn_defs.h>
  15. #include <common_functions.h>
  16. int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
  17. const data_dims_t *filter_dims,
  18. const data_dims_t *output_dims,
  19. const conv_params_t *conv_params)
  20. {
  21. return 0;
  22. }
  /*
   * Scratch-buffer registration for the optimized convolution.
   *
   * Intentionally a no-op: the kernels in this file do not use scratch
   * memory, so the supplied buffer is ignored.
   */
  void esp_nn_set_conv_scratch_buf_opt(const void *buf)
  {
      (void) buf;
  }
  26. __attribute__ ((noinline))
  27. static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
  28. const int8_t *input_data,
  29. const int8_t *filter_data,
  30. const int32_t *bias,
  31. const data_dims_t *output_dims,
  32. int8_t *out_data,
  33. const conv_params_t *conv_params,
  34. const quant_data_t *quant_data)
  35. {
  36. const uint16_t input_wd = input_dims->width;
  37. const uint16_t in_channels = input_dims->channels;
  38. const int32_t input_offset = conv_params->in_offset;
  39. const int32_t out_offset = conv_params->out_offset;
  40. const uint16_t stride_wd = conv_params->stride.width;
  41. const uint16_t stride_ht = conv_params->stride.height;
  42. const uint16_t out_wd = output_dims->width;
  43. const uint16_t out_ht = output_dims->height;
  44. const uint16_t out_channels = output_dims->channels;
  45. const int32_t activation_min = conv_params->activation.min;
  46. const int32_t activation_max = conv_params->activation.max;
  47. for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
  48. for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
  49. const int32_t *out_mult = quant_data->mult;
  50. const int32_t *out_shift = quant_data->shift;
  51. const int8_t *filter_ptr = filter_data;
  52. const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
  53. int32_t out_ch_idx = 0;
  54. for (; out_ch_idx < out_channels; out_ch_idx++) {
  55. int32_t conv_out = 0;
  56. const int8_t *input_ptr = input_base_ptr;
  57. int32_t in_ch_idx = 0;
  58. for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
  59. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  60. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  61. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  62. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  63. }
  64. for (; in_ch_idx < in_channels; in_ch_idx ++) {
  65. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  66. }
  67. if (bias) {
  68. conv_out += bias[out_ch_idx];
  69. }
  70. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
  71. conv_out += out_offset;
  72. conv_out = max(conv_out, activation_min);
  73. conv_out = min(conv_out, activation_max);
  74. *out_data++ = (int8_t) conv_out;
  75. }
  76. }
  77. }
  78. }
  79. /**
  80. * Assumption 1: i/p channels == o/p channels
  81. * Assumption 2: Pointers are valid
  82. * Assumption 3: dialation width = 1
  83. */
  84. void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
  85. const int8_t *input_data,
  86. const data_dims_t *filter_dims,
  87. const int8_t *filter_data,
  88. const int32_t *bias,
  89. const data_dims_t *output_dims,
  90. int8_t *out_data,
  91. const conv_params_t *conv_params,
  92. const quant_data_t *quant_data)
  93. {
  94. const uint16_t filter_wd = filter_dims->width;
  95. const uint16_t filter_ht = filter_dims->height;
  96. if (filter_wd == 1 && filter_ht == 1) {
  97. esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
  98. output_dims, out_data, conv_params, quant_data);
  99. return;
  100. }
  101. const uint16_t input_wd = input_dims->width;
  102. const uint16_t input_ht = input_dims->height;
  103. const uint16_t in_channels = input_dims->channels;
  104. const int32_t input_offset = conv_params->in_offset;
  105. const int32_t out_offset = conv_params->out_offset;
  106. const uint16_t pad_wd = conv_params->padding.width;
  107. const uint16_t pad_ht = conv_params->padding.height;
  108. const uint16_t stride_wd = conv_params->stride.width;
  109. const uint16_t stride_ht = conv_params->stride.height;
  110. const uint16_t out_wd = output_dims->width;
  111. const uint16_t out_ht = output_dims->height;
  112. const uint16_t out_channels = output_dims->channels;
  113. const int32_t activation_min = conv_params->activation.min;
  114. const int32_t activation_max = conv_params->activation.max;
  115. int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
  116. for (out_y = 0; out_y < out_ht; out_y++) {
  117. for (out_x = 0; out_x < out_wd; out_x++) {
  118. const int32_t *out_shift = quant_data->shift;
  119. const int32_t *out_mult = quant_data->mult;
  120. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  121. int32_t conv_out = 0;
  122. const int32_t base_y = stride_ht * out_y - pad_ht;
  123. const int32_t base_x = stride_wd * out_x - pad_wd;
  124. const int32_t filter_y_start = max(0, -base_y);
  125. const int32_t filter_x_start = max(0, -base_x);
  126. const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
  127. const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
  128. for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  129. for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  130. const int32_t in_row = base_y + filter_y_idx;
  131. const int32_t in_col = base_x + filter_x_idx;
  132. const int8_t *input_ptr = input_data +
  133. (in_row * input_wd + in_col) * in_channels;
  134. const int8_t *filter_ptr = filter_data +
  135. out_ch_idx * in_channels * filter_ht * filter_wd +
  136. (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
  137. int32_t in_ch_idx = 0;
  138. for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
  139. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  140. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  141. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  142. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  143. }
  144. for (; in_ch_idx < in_channels; in_ch_idx ++) {
  145. conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
  146. }
  147. }
  148. }
  149. if (bias) {
  150. conv_out += bias[out_ch_idx];
  151. }
  152. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
  153. conv_out += out_offset;
  154. conv_out = max(conv_out, activation_min);
  155. conv_out = min(conv_out, activation_max);
  156. *out_data++ = (int8_t) conv_out;
  157. }
  158. }
  159. }
  160. }