esp_nn_conv_esp32s3.c 24 KB


  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <stdint.h>
  #include <stdio.h>
  #include <string.h>

  #include <esp_nn_defs.h>
  #include <common_functions.h>
  /*
   * Scratch memory shared by the convolution paths in this file, installed by
   * esp_nn_set_conv_scratch_buf_esp32s3(). NULL means nothing is installed.
   * Declared as int16_t *, although the s8 1x1 path reinterprets it as
   * int8_t *.
   */
  static int16_t *scratch_buffer = NULL;

  /*
   * ESP32-S3 optimized kernels. Only their prototypes appear in this file; the
   * definitions live elsewhere (presumably hand-written assembly - confirm).
   */

  /* 1x1, stride-1, unpadded convolution on s8 data. Called from this file only
   * when in_channels % 8 == 0; `buffer` is kernel scratch space. */
  extern void esp_nn_conv_s8_mult8_1x1_esp32s3(
      const int8_t *input_data, const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t in_channels, const int32_t input_offset,
      const int8_t *filter_aligned, const int32_t *bias, int8_t *out_data,
      const uint16_t out_wd, const uint16_t out_ht, const uint16_t out_channels,
      const int32_t out_offset, const int32_t *out_shift, const int32_t *out_mult,
      const int32_t activation_min, const int32_t activation_max,
      void *buffer /* scratch buffer */);

  /* 1x1, stride-1, unpadded convolution on int16 copies of input and filter.
   * Called from this file only when in_channels % 4 == 0 and
   * input_wd * input_ht is a multiple of 4; `buffer` is kernel scratch space. */
  extern void esp_nn_conv_s16_mult4_1x1_esp32s3(
      const int16_t *input_data, const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t in_channels, const int16_t *filter_data, const int32_t *bias,
      int8_t *out_data, const uint16_t out_wd, const uint16_t out_ht,
      const uint16_t out_channels, const int32_t out_offset,
      const int32_t *out_shift, const int32_t *out_mult,
      const int32_t activation_min, const int32_t activation_max,
      void *buffer /* scratch buffer */);

  /* Convolution taking explicit filter size, padding and stride, operating on
   * int16 copies of input and filter. Called from this file when
   * in_channels % 8 == 0. */
  extern void esp_nn_conv_s16_mult8_esp32s3(
      const int16_t *input_data, const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t in_channels, const uint16_t pad_wd, const uint16_t pad_ht,
      const uint16_t stride_wd, const uint16_t stride_ht,
      const int16_t *filter_data, const uint16_t filter_wd, const uint16_t filter_ht,
      const int32_t *bias, int8_t *out_data, const uint16_t out_wd,
      const uint16_t out_ht, const uint16_t out_channels, const int32_t out_offset,
      const int32_t *out_shift, const int32_t *out_mult,
      const int32_t activation_min, const int32_t activation_max);

  /* Presumably widens `size` int8 values to int16; the name suggests it also
   * applies `offset` and expects aligned buffers (confirm). */
  extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
                                                           const int size, const int32_t offset);

  /* Presumably widens `size` int8 values to int16 (used here to stage the filter). */
  extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
  75. static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
  76. const int8_t *input_data,
  77. const data_dims_t *filter_dims,
  78. const int8_t *filter_data,
  79. const int32_t *bias,
  80. const data_dims_t *output_dims,
  81. int8_t *out_data,
  82. const conv_params_t *conv_params,
  83. const quant_data_t *quant_data)
  84. {
  85. const uint16_t input_wd = input_dims->width;
  86. const uint16_t input_ht = input_dims->height;
  87. const uint16_t in_ch = input_dims->channels;
  88. const int32_t input_offset = conv_params->in_offset;
  89. const int32_t out_offset = conv_params->out_offset;
  90. const uint16_t pad_wd = conv_params->padding.width;
  91. const uint16_t pad_ht = conv_params->padding.height;
  92. const uint16_t stride_wd = conv_params->stride.width;
  93. const uint16_t stride_ht = conv_params->stride.height;
  94. const uint16_t filter_wd = filter_dims->width;
  95. const uint16_t filter_ht = filter_dims->height;
  96. const uint16_t out_wd = output_dims->width;
  97. const uint16_t out_ht = output_dims->height;
  98. const uint16_t out_ch = output_dims->channels;
  99. const int32_t *out_shift = quant_data->shift;
  100. const int32_t *out_mult = quant_data->mult;
  101. const int32_t activation_min = conv_params->activation.min;
  102. const int32_t activation_max = conv_params->activation.max;
  103. int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
  104. for (out_y = 0; out_y < out_ht; out_y++) {
  105. for (out_x = 0; out_x < out_wd; out_x++) {
  106. for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
  107. int32_t conv_out = 0;
  108. const int32_t base_y = stride_ht * out_y - pad_ht;
  109. const int32_t base_x = stride_wd * out_x - pad_wd;
  110. const int32_t filter_y_start = max(0, -base_y);
  111. const int32_t filter_x_start = max(0, -base_x);
  112. const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
  113. const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
  114. for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  115. for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  116. const int32_t in_row = base_y + filter_y_idx;
  117. const int32_t in_col = base_x + filter_x_idx;
  118. int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
  119. int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
  120. (filter_y_idx * filter_wd + filter_x_idx) * in_ch;
  121. for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
  122. conv_out +=
  123. (input_data[input_base_offset + in_ch_idx] + input_offset) *
  124. filter_data[filter_base_offset + in_ch_idx];
  125. }
  126. }
  127. }
  128. if (bias) {
  129. conv_out += bias[out_ch_idx];
  130. }
  131. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  132. conv_out += out_offset;
  133. conv_out = max(conv_out, activation_min);
  134. conv_out = min(conv_out, activation_max);
  135. *out_data++ = (int8_t) conv_out;
  136. }
  137. }
  138. }
  139. }
  /*
   * Convolution for the no-padding case with an arbitrary filter size.
   *
   * The filter_wd pixels covered by one filter row are horizontally adjacent,
   * and each pixel's channels are contiguous, so every row of the window is
   * accumulated as one flat run of filter_wd * in_channels elements. One output
   * channel's filter (filter_ht x filter_wd x in_channels) is contiguous too, so
   * a single running pointer walks the whole filter tensor per output pixel.
   * Output is written in (out_y, out_x, out_channel) order; bias may be NULL.
   */
  static void esp_nn_conv_s8_pad_valid(const int8_t *input_data,
                                       const uint16_t input_wd,
                                       const uint16_t input_ht,
                                       const uint16_t in_channels,
                                       const int32_t input_offset,
                                       const uint16_t stride_wd,
                                       const uint16_t stride_ht,
                                       const int8_t *filter_data,
                                       const uint16_t filter_wd,
                                       const uint16_t filter_ht,
                                       const int32_t *bias,
                                       int8_t *out_data,
                                       const uint16_t out_wd,
                                       const uint16_t out_ht,
                                       const uint16_t out_channels,
                                       const int32_t out_offset,
                                       const int32_t *out_shift,
                                       const int32_t *out_mult,
                                       const int32_t activation_min,
                                       const int32_t activation_max)
  {
      /* Elements between vertically adjacent input pixels. */
      const int32_t row_pitch = input_wd * in_channels;
      /* Contiguous elements covered by one filter row. */
      const int32_t row_span = filter_wd * in_channels;
      int8_t *dst = out_data;

      (void) input_ht; /* unused: with no padding the window is assumed to fit */

      for (int32_t oy = 0; oy < out_ht; oy++) {
          for (int32_t ox = 0; ox < out_wd; ox++) {
              /* Top-left element of this output pixel's receptive field. */
              const int8_t *window = input_data +
                  (stride_ht * oy * input_wd + stride_wd * ox) * in_channels;
              const int8_t *taps = filter_data;

              for (int32_t oc = 0; oc < out_channels; oc++) {
                  int32_t acc = 0;
                  for (int32_t fy = 0; fy < filter_ht; fy++) {
                      const int8_t *src = window + fy * row_pitch;
                      for (int32_t k = 0; k < row_span; k++) {
                          acc += (src[k] + input_offset) * *taps++;
                      }
                  }

                  if (bias != NULL) {
                      acc += bias[oc];
                  }
                  acc = esp_nn_multiply_by_quantized_mult_fast(acc, out_mult[oc], out_shift[oc]);
                  acc += out_offset;
                  if (acc < activation_min) {
                      acc = activation_min;
                  }
                  if (acc > activation_max) {
                      acc = activation_max;
                  }
                  *dst++ = (int8_t) acc;
              }
          }
      }
  }
  /*
   * No-padding convolution specialised for a 3x3 filter and any channel count.
   *
   * Same data layout as the generic no-padding path: each of the three window
   * rows is a contiguous run of 3 * in_channels input elements, and each output
   * channel's filter is a contiguous block of 3 * 3 * in_channels taps.
   * Output is written in (out_y, out_x, out_channel) order; bias may be NULL.
   *
   * NOTE(review): nothing in the code shown calls this static function, so the
   * compiler may report it as unused.
   */
  static void esp_nn_conv_s8_pad_valid_3x3(const int8_t *input_data,
                                           const uint16_t input_wd,
                                           const uint16_t input_ht,
                                           const uint16_t in_channels,
                                           const int32_t input_offset,
                                           const uint16_t stride_wd,
                                           const uint16_t stride_ht,
                                           const int8_t *filter_data,
                                           const int32_t *bias,
                                           int8_t *out_data,
                                           const uint16_t out_wd,
                                           const uint16_t out_ht,
                                           const uint16_t out_channels,
                                           const int32_t out_offset,
                                           const int32_t *out_shift,
                                           const int32_t *out_mult,
                                           const int32_t activation_min,
                                           const int32_t activation_max)
  {
      const int32_t row_pitch = input_wd * in_channels;
      const int32_t row_span = 3 * in_channels;
      int8_t *dst = out_data;

      (void) input_ht; /* unused: with no padding the window is assumed to fit */

      for (int32_t oy = 0; oy < out_ht; oy++) {
          for (int32_t ox = 0; ox < out_wd; ox++) {
              const int8_t *window = input_data +
                  (stride_ht * oy * input_wd + stride_wd * ox) * in_channels;
              const int8_t *taps = filter_data;

              for (int32_t oc = 0; oc < out_channels; oc++) {
                  int32_t acc = 0;
                  for (int32_t fy = 0; fy < 3; fy++) {
                      const int8_t *src = window + fy * row_pitch;
                      for (int32_t k = 0; k < row_span; k++) {
                          acc += (src[k] + input_offset) * *taps++;
                      }
                  }

                  if (bias != NULL) {
                      acc += bias[oc];
                  }
                  acc = esp_nn_multiply_by_quantized_mult_fast(acc, out_mult[oc], out_shift[oc]);
                  acc += out_offset;
                  if (acc < activation_min) {
                      acc = activation_min;
                  }
                  if (acc > activation_max) {
                      acc = activation_max;
                  }
                  *dst++ = (int8_t) acc;
              }
          }
      }
  }
  246. static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
  247. const uint16_t input_wd,
  248. const uint16_t input_ht,
  249. const int32_t input_offset,
  250. const uint16_t stride_wd,
  251. const uint16_t stride_ht,
  252. const int8_t *filter_data,
  253. const int32_t *bias,
  254. int8_t *out_data,
  255. const uint16_t out_wd,
  256. const uint16_t out_ht,
  257. const uint16_t out_channels,
  258. const int32_t out_offset,
  259. const int32_t *out_shift,
  260. const int32_t *out_mult,
  261. const int32_t activation_min,
  262. const int32_t activation_max)
  263. {
  264. int32_t out_ch_idx, out_y, out_x, filter_y_idx;
  265. /* use scratch_buffer to pre-compute offset factor */
  266. int16_t *filter_sum = (int16_t *) scratch_buffer;
  267. const int8_t *filter_ptr = filter_data;
  268. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  269. int16_t sum_val = 0;
  270. for (int i = 0; i < 9; i++) {
  271. sum_val += *filter_ptr++;
  272. sum_val += *filter_ptr++;
  273. sum_val += *filter_ptr++;
  274. }
  275. *filter_sum++ = sum_val;
  276. }
  277. for (out_y = 0; out_y < out_ht; out_y++) {
  278. for (out_x = 0; out_x < out_wd; out_x++) {
  279. const int8_t *filter_data_ptr = filter_data;
  280. const int32_t base_y = stride_ht * out_y;
  281. const int32_t base_x = stride_wd * out_x;
  282. const int8_t *input_base_ptr = input_data + (base_y * input_wd + base_x) * 3;
  283. int16_t *filter_sum = (int16_t *) scratch_buffer;
  284. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  285. int32_t conv_out = 0;
  286. for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
  287. const int8_t *input_data_ptr = input_base_ptr + (filter_y_idx * input_wd) * 3;
  288. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  289. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  290. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  291. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  292. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  293. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  294. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  295. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  296. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  297. }
  298. conv_out += *filter_sum++ * input_offset;
  299. if (bias) {
  300. conv_out += bias[out_ch_idx];
  301. }
  302. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  303. conv_out += out_offset;
  304. conv_out = max(conv_out, activation_min);
  305. conv_out = min(conv_out, activation_max);
  306. *out_data++ = (int8_t) conv_out;
  307. }
  308. }
  309. }
  310. }
  311. int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
  312. const data_dims_t *filter_dims,
  313. const data_dims_t *output_dims,
  314. const conv_params_t *conv_params)
  315. {
  316. const uint16_t input_wd = input_dims->width;
  317. const uint16_t input_ht = input_dims->height;
  318. const uint16_t in_ch = input_dims->channels;
  319. const uint16_t filter_wd = filter_dims->width;
  320. const uint16_t filter_ht = filter_dims->height;
  321. const uint16_t out_ch = output_dims->channels;
  322. const uint16_t pad_wd = conv_params->padding.width;
  323. const uint16_t pad_ht = conv_params->padding.height;
  324. const uint16_t stride_wd = conv_params->stride.width;
  325. const uint16_t stride_ht = conv_params->stride.height;
  326. int filter_size = filter_wd * filter_ht * in_ch * out_ch;
  327. int input_size = input_wd * input_ht * in_ch;
  328. int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
  329. if (input_wd * input_ht < 8) {
  330. transpose_buf_size = 0; // not using this for leftover
  331. }
  332. int align_buf_size = 32; /* extra buffer for alignment */
  333. if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
  334. pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
  335. return filter_size + transpose_buf_size + align_buf_size;
  336. }
  337. return 2 * (filter_size + input_size) + transpose_buf_size + align_buf_size;
  338. }
  339. void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
  340. {
  341. scratch_buffer = (int16_t *) buf;
  342. }
  343. void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
  344. const int8_t *input,
  345. const data_dims_t *filter_dims,
  346. const int8_t *filter_data,
  347. const int32_t *bias,
  348. const data_dims_t *output_dims,
  349. int8_t *out_data,
  350. const conv_params_t *conv_params,
  351. const quant_data_t *quant_data)
  352. {
  353. const uint16_t input_wd = input_dims->width;
  354. const uint16_t input_ht = input_dims->height;
  355. const uint16_t channels = input_dims->channels;
  356. const int32_t input_offset = conv_params->in_offset;
  357. const int32_t out_offset = conv_params->out_offset;
  358. const uint16_t pad_wd = conv_params->padding.width;
  359. const uint16_t pad_ht = conv_params->padding.height;
  360. const uint16_t stride_wd = conv_params->stride.width;
  361. const uint16_t stride_ht = conv_params->stride.height;
  362. const uint16_t filter_wd = filter_dims->width;
  363. const uint16_t filter_ht = filter_dims->height;
  364. const uint16_t out_wd = output_dims->width;
  365. const uint16_t out_ht = output_dims->height;
  366. const uint16_t out_channels = output_dims->channels;
  367. const int32_t *out_shift = quant_data->shift;
  368. const int32_t *out_mult = quant_data->mult;
  369. const int32_t activation_min = conv_params->activation.min;
  370. const int32_t activation_max = conv_params->activation.max;
  371. int filter_size = filter_wd * filter_ht * channels * out_channels;
  372. int input_size = input_wd * input_ht * channels;
  373. int align_len = 16 - (filter_size & 15);
  374. int16_t *filter_data16 = scratch_buffer;
  375. int16_t *input_data16 = scratch_buffer + filter_size + align_len;
  376. if (scratch_buffer == NULL) {
  377. printf("esp_nn_conv error! scratch_buffer not set!\n");
  378. return;
  379. }
  380. if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
  381. pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
  382. int8_t *filter_aligned = (int8_t *) scratch_buffer;
  383. int scratch_offset = (int) (filter_aligned + filter_size);
  384. void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
  385. memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
  386. esp_nn_conv_s8_mult8_1x1_esp32s3(
  387. input, input_wd, input_ht, channels, input_offset, filter_aligned,
  388. bias, out_data, out_wd, out_ht, out_channels, out_offset,
  389. out_shift, out_mult, activation_min, activation_max, scratch_buf);
  390. } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
  391. (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
  392. pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
  393. int scratch_offset = (int) (input_data16 + input_size);
  394. void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
  395. esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
  396. esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
  397. esp_nn_conv_s16_mult4_1x1_esp32s3(
  398. input_data16, input_wd, input_ht, channels, filter_data16,
  399. bias, out_data, out_wd, out_ht, out_channels, out_offset,
  400. out_shift, out_mult, activation_min, activation_max, scratch_buf);
  401. } else if (channels % 8 == 0) {
  402. esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
  403. esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
  404. esp_nn_conv_s16_mult8_esp32s3(
  405. input_data16, input_wd, input_ht, channels, pad_wd, pad_ht,
  406. stride_wd, stride_ht, filter_data16, filter_wd, filter_ht, bias,
  407. out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
  408. out_mult, activation_min, activation_max);
  409. } else if (pad_wd == 0 && pad_ht == 0) {
  410. if (filter_wd == 3 && filter_ht == 3 && channels == 3) {
  411. esp_nn_conv_s8_pad_valid_ch3_3x3(input, input_wd, input_ht, input_offset,
  412. stride_wd, stride_ht, filter_data, bias,
  413. out_data, out_wd, out_ht, out_channels, out_offset,
  414. out_shift, out_mult, activation_min, activation_max);
  415. } else {
  416. esp_nn_conv_s8_pad_valid(input, input_wd, input_ht, channels, input_offset,
  417. stride_wd, stride_ht, filter_data, filter_wd, filter_ht, bias,
  418. out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
  419. out_mult, activation_min, activation_max);
  420. }
  421. } else {
  422. /* Basic unrolled version */
  423. esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
  424. bias, output_dims, out_data, conv_params, quant_data);
  425. }
  426. }