esp_nn_depthwise_conv_opt.c 15 KB


  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <esp_nn_defs.h>
  15. #include <common_functions.h>
  16. int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
  17. const data_dims_t *filter_dims,
  18. const data_dims_t *output_dims,
  19. const dw_conv_params_t *conv_params)
  20. {
  21. return 0;
  22. }
  /*
   * Registers a caller-provided scratch buffer.
   *
   * The depthwise kernels in this file use no scratch memory
   * (esp_nn_get_depthwise_conv_scratch_size_opt reports 0), so the pointer is
   * accepted and deliberately ignored; nothing is read from or written to it.
   */
  void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
  {
      (void) buf;
  }
  26. /* common channel multiplier == 1 case */
  27. __attribute__ ((noinline))
  28. static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
  29. const int8_t *input_data,
  30. const data_dims_t *filter_dims,
  31. const int8_t *filter_data,
  32. const int32_t *bias,
  33. const data_dims_t *output_dims,
  34. int8_t *out_data,
  35. const dw_conv_params_t *conv_params,
  36. const quant_data_t *quant_data)
  37. {
  38. const uint16_t input_wd = input_dims->width;
  39. const uint16_t input_ht = input_dims->height;
  40. const uint16_t channels = input_dims->channels;
  41. const int32_t input_offset = conv_params->in_offset;
  42. const int32_t out_offset = conv_params->out_offset;
  43. const uint16_t pad_wd = conv_params->padding.width;
  44. const uint16_t pad_ht = conv_params->padding.height;
  45. const uint16_t stride_wd = conv_params->stride.width;
  46. const uint16_t stride_ht = conv_params->stride.height;
  47. const uint16_t filter_wd = filter_dims->width;
  48. const uint16_t filter_ht = filter_dims->height;
  49. const uint16_t out_wd = output_dims->width;
  50. const uint16_t out_ht = output_dims->height;
  51. const int32_t activation_min = conv_params->activation.min;
  52. const int32_t activation_max = conv_params->activation.max;
  53. int out_idx = 0;
  54. for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
  55. const int16_t base_y = (out_y * stride_ht) - pad_ht;
  56. for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
  57. const int16_t base_x = (out_x * stride_wd) - pad_wd;
  58. const int32_t *out_shift = quant_data->shift;
  59. const int32_t *out_mult = quant_data->mult;
  60. /* Select filter so as the point doesn't lie outside block */
  61. int filter_y_start = max(0, -base_y);
  62. int filter_x_start = max(0, -base_x);
  63. int filter_y_end = min(filter_ht, input_ht - base_y);
  64. int filter_x_end = min(filter_wd, input_wd - base_x);
  65. int ch_idx = 0;
  66. for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop
  67. int32_t result0 = 0;
  68. int32_t result1 = 0;
  69. int32_t result2 = 0;
  70. int32_t result3 = 0;
  71. for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  72. const int32_t idx_y = base_y + filter_y_idx;
  73. for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  74. const int32_t idx_x = base_x + filter_x_idx;
  75. int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
  76. int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
  77. int32_t input_val0 = input_data[input_index + 0] + input_offset;
  78. int32_t input_val1 = input_data[input_index + 1] + input_offset;
  79. int32_t input_val2 = input_data[input_index + 2] + input_offset;
  80. int32_t input_val3 = input_data[input_index + 3] + input_offset;
  81. int32_t filter_val0 = filter_data[filter_index + 0];
  82. int32_t filter_val1 = filter_data[filter_index + 1];
  83. int32_t filter_val2 = filter_data[filter_index + 2];
  84. int32_t filter_val3 = filter_data[filter_index + 3];
  85. result0 += input_val0 * filter_val0;
  86. result1 += input_val1 * filter_val1;
  87. result2 += input_val2 * filter_val2;
  88. result3 += input_val3 * filter_val3;
  89. }
  90. }
  91. if (bias) {
  92. result0 += bias[ch_idx + 0];
  93. result1 += bias[ch_idx + 1];
  94. result2 += bias[ch_idx + 2];
  95. result3 += bias[ch_idx + 3];
  96. }
  97. result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
  98. result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
  99. result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
  100. result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
  101. result0 += out_offset;
  102. result1 += out_offset;
  103. result2 += out_offset;
  104. result3 += out_offset;
  105. result0 = max(result0, activation_min);
  106. result1 = max(result1, activation_min);
  107. result2 = max(result2, activation_min);
  108. result3 = max(result3, activation_min);
  109. result0 = min(result0, activation_max);
  110. result1 = min(result1, activation_max);
  111. result2 = min(result2, activation_max);
  112. result3 = min(result3, activation_max);
  113. out_data[out_idx++] = result0;
  114. out_data[out_idx++] = result1;
  115. out_data[out_idx++] = result2;
  116. out_data[out_idx++] = result3;
  117. }
  118. for (; ch_idx < channels; ch_idx++) {//channel_loop
  119. int32_t result = 0;
  120. for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  121. const int32_t idx_y = base_y + filter_y_idx;
  122. for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  123. const int32_t idx_x = base_x + filter_x_idx;
  124. int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
  125. int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
  126. int32_t input_val = input_data[input_index] + input_offset;
  127. int32_t filter_val = filter_data[filter_index];
  128. result += input_val * filter_val;
  129. }
  130. }
  131. if (bias) {
  132. result += bias[ch_idx];
  133. }
  134. result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
  135. result += out_offset;
  136. result = max(result, activation_min);
  137. result = min(result, activation_max);
  138. out_data[out_idx++] = result;
  139. }
  140. }
  141. }
  142. }
  143. void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
  144. const int8_t *input_data,
  145. const data_dims_t *filter_dims,
  146. const int8_t *filter_data,
  147. const int32_t *bias,
  148. const data_dims_t *output_dims,
  149. int8_t *out_data,
  150. const dw_conv_params_t *conv_params,
  151. const quant_data_t *quant_data)
  152. {
  153. const uint16_t ch_mult = conv_params->ch_mult;
  154. if (ch_mult == 1) {
  155. esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
  156. bias, output_dims, out_data, conv_params, quant_data);
  157. return;
  158. }
  159. const uint16_t input_wd = input_dims->width;
  160. const uint16_t input_ht = input_dims->height;
  161. const uint16_t channels = input_dims->channels;
  162. const int32_t input_offset = conv_params->in_offset;
  163. const int32_t out_offset = conv_params->out_offset;
  164. const uint16_t pad_wd = conv_params->padding.width;
  165. const uint16_t pad_ht = conv_params->padding.height;
  166. const uint16_t stride_wd = conv_params->stride.width;
  167. const uint16_t stride_ht = conv_params->stride.height;
  168. const uint16_t filter_wd = filter_dims->width;
  169. const uint16_t filter_ht = filter_dims->height;
  170. const uint16_t out_wd = output_dims->width;
  171. const uint16_t out_ht = output_dims->height;
  172. const int32_t activation_min = conv_params->activation.min;
  173. const int32_t activation_max = conv_params->activation.max;
  174. int out_idx = 0;
  175. for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
  176. const int16_t base_y = (out_y * stride_ht) - pad_ht;
  177. for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
  178. const int16_t base_x = (out_x * stride_wd) - pad_wd;
  179. const int32_t *out_shift = quant_data->shift;
  180. const int32_t *out_mult = quant_data->mult;
  181. /* Select filter so as the point doesn't lie outside block */
  182. int filter_y_start = max(0, -base_y);
  183. int filter_x_start = max(0, -base_x);
  184. int filter_y_end = min(filter_ht, input_ht - base_y);
  185. int filter_x_end = min(filter_wd, input_wd - base_x);
  186. for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
  187. int ch_mult_idx = 0;
  188. for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
  189. int32_t result0 = 0;
  190. int32_t result1 = 0;
  191. int32_t result2 = 0;
  192. int32_t result3 = 0;
  193. const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
  194. for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  195. const int32_t idx_y = base_y + filter_y_idx;
  196. for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  197. const int32_t idx_x = base_x + filter_x_idx;
  198. int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
  199. int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
  200. int32_t input_val = input_data[input_index] + input_offset;
  201. int32_t filter_val0 = filter_data[filter_index + 0];
  202. int32_t filter_val1 = filter_data[filter_index + 1];
  203. int32_t filter_val2 = filter_data[filter_index + 2];
  204. int32_t filter_val3 = filter_data[filter_index + 3];
  205. result0 += input_val * filter_val0;
  206. result1 += input_val * filter_val1;
  207. result2 += input_val * filter_val2;
  208. result3 += input_val * filter_val3;
  209. }
  210. }
  211. if (bias) {
  212. result0 += bias[out_ch_idx + 0];
  213. result1 += bias[out_ch_idx + 1];
  214. result2 += bias[out_ch_idx + 2];
  215. result3 += bias[out_ch_idx + 3];
  216. }
  217. result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
  218. result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
  219. result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
  220. result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
  221. result0 += out_offset;
  222. result1 += out_offset;
  223. result2 += out_offset;
  224. result3 += out_offset;
  225. result0 = max(result0, activation_min);
  226. result1 = max(result1, activation_min);
  227. result2 = max(result2, activation_min);
  228. result3 = max(result3, activation_min);
  229. result0 = min(result0, activation_max);
  230. result1 = min(result1, activation_max);
  231. result2 = min(result2, activation_max);
  232. result3 = min(result3, activation_max);
  233. out_data[out_idx++] = result0;
  234. out_data[out_idx++] = result1;
  235. out_data[out_idx++] = result2;
  236. out_data[out_idx++] = result3;
  237. }
  238. for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
  239. int32_t result = 0;
  240. const int out_ch_idx = ch_idx * ch_mult + ch_mult_idx;
  241. for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  242. const int32_t idx_y = base_y + filter_y_idx;
  243. for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  244. const int32_t idx_x = base_x + filter_x_idx;
  245. int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
  246. int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
  247. int32_t input_val = input_data[input_index] + input_offset;
  248. int32_t filter_val = filter_data[filter_index];
  249. result += input_val * filter_val;
  250. }
  251. }
  252. if (bias) {
  253. result += bias[out_ch_idx];
  254. }
  255. result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
  256. result += out_offset;
  257. result = max(result, activation_min);
  258. result = min(result, activation_max);
  259. out_data[out_idx++] = result;
  260. }
  261. }
  262. }
  263. }
  264. }