esp_nn_conv_esp32s3.c 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <stdint.h>
  15. #include <stdio.h>
  16. #include <common_functions.h>
/*
 * Caller-provided scratch memory used by the convolution routines below.
 * Installed by esp_nn_set_conv_scratch_buf_esp32s3(); esp_nn_conv_s8_esp32s3()
 * prints an error and returns without computing anything while it is NULL.
 */
static int16_t *scratch_buffer = NULL;

/*
 * Kernels defined outside this C file (presumably ESP32-S3 assembly — TODO
 * confirm). These prototypes must match their definitions exactly; the
 * compiler cannot cross-check them.
 */

/*
 * 1x1 / stride-1 / unpadded convolution. Receives the raw int8 input together
 * with input_offset (so presumably applies the offset itself) and a filter
 * already widened to int16. Dispatched by esp_nn_conv_s8_esp32s3 when
 * in_channels % 8 == 0. `buffer` is a 16-byte-aligned pointer into the scratch
 * area, placed after the widened filter.
 */
extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int32_t input_offset,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max,
void *buffer /* scratch buffer */);
/*
 * 1x1 / stride-1 / unpadded convolution on int16 input that the caller has
 * already widened with input_offset folded in (no input_offset parameter).
 * Dispatched when in_channels % 4 == 0 and input_wd * input_ht % 16 == 0.
 * `buffer` is a 16-byte-aligned pointer placed after the widened input.
 */
extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const int16_t *filter_data,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max,
void *buffer /* scratch buffer */);
/*
 * General convolution (any filter size, padding and stride) on int16 input
 * with input_offset already folded in and an int16 filter. Dispatched when
 * in_channels % 8 == 0 and the 1x1 fast paths do not apply. Takes no
 * scratch pointer.
 */
extern void esp_nn_conv_s16_mult8_esp32s3(const int16_t *input_data,
const uint16_t input_wd,
const uint16_t input_ht,
const uint16_t in_channels,
const uint16_t pad_wd,
const uint16_t pad_ht,
const uint16_t stride_wd,
const uint16_t stride_ht,
const int16_t *filter_data,
const uint16_t filter_wd,
const uint16_t filter_ht,
const int32_t *bias,
int8_t *out_data,
const uint16_t out_wd,
const uint16_t out_ht,
const uint16_t out_channels,
const int32_t out_offset,
const int32_t *out_shift,
const int32_t *out_mult,
const int32_t activation_min,
const int32_t activation_max);
/*
 * Widens `size` int8 values to int16, adding `offset` (behaviour inferred from
 * the name and call sites). The "aligned" in the name suggests src/dst
 * alignment requirements — TODO confirm against the implementation.
 */
extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int16_t *dst,
const int size, const int32_t offset);
/* Widens `size` int8 values to int16 (presumably without offset; used for filters). */
extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
  75. static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
  76. const uint16_t input_wd,
  77. const uint16_t input_ht,
  78. const uint16_t in_channels,
  79. const int32_t input_offset,
  80. const uint16_t pad_wd,
  81. const uint16_t pad_ht,
  82. const uint16_t stride_wd,
  83. const uint16_t stride_ht,
  84. const int8_t *filter_data,
  85. const uint16_t filter_wd,
  86. const uint16_t filter_ht,
  87. const int32_t *bias,
  88. int8_t *out_data,
  89. const uint16_t out_wd,
  90. const uint16_t out_ht,
  91. const uint16_t out_channels,
  92. const int32_t out_offset,
  93. const int32_t *out_shift,
  94. const int32_t *out_mult,
  95. const int32_t activation_min,
  96. const int32_t activation_max)
  97. {
  98. int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
  99. for (out_y = 0; out_y < out_ht; out_y++) {
  100. for (out_x = 0; out_x < out_wd; out_x++) {
  101. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  102. int32_t conv_out = 0;
  103. const int32_t base_y = stride_ht * out_y - pad_ht;
  104. const int32_t base_x = stride_wd * out_x - pad_wd;
  105. const int32_t filter_y_start = max(0, -base_y);
  106. const int32_t filter_x_start = max(0, -base_x);
  107. const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
  108. const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
  109. for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
  110. for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
  111. const int32_t in_row = base_y + filter_y_idx;
  112. const int32_t in_col = base_x + filter_x_idx;
  113. int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
  114. int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
  115. (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
  116. for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
  117. conv_out +=
  118. (input_data[input_base_offset + in_ch_idx] + input_offset) *
  119. filter_data[filter_base_offset + in_ch_idx];
  120. }
  121. }
  122. }
  123. if (bias) {
  124. conv_out += bias[out_ch_idx];
  125. }
  126. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  127. conv_out += out_offset;
  128. conv_out = max(conv_out, activation_min);
  129. conv_out = min(conv_out, activation_max);
  130. *out_data++ = (int8_t) conv_out;
  131. }
  132. }
  133. }
  134. }
  135. static void esp_nn_conv_s8_pad_valid(const int8_t *input_data,
  136. const uint16_t input_wd,
  137. const uint16_t input_ht,
  138. const uint16_t in_channels,
  139. const int32_t input_offset,
  140. const uint16_t stride_wd,
  141. const uint16_t stride_ht,
  142. const int8_t *filter_data,
  143. const uint16_t filter_wd,
  144. const uint16_t filter_ht,
  145. const int32_t *bias,
  146. int8_t *out_data,
  147. const uint16_t out_wd,
  148. const uint16_t out_ht,
  149. const uint16_t out_channels,
  150. const int32_t out_offset,
  151. const int32_t *out_shift,
  152. const int32_t *out_mult,
  153. const int32_t activation_min,
  154. const int32_t activation_max)
  155. {
  156. int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
  157. for (out_y = 0; out_y < out_ht; out_y++) {
  158. for (out_x = 0; out_x < out_wd; out_x++) {
  159. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  160. int32_t conv_out = 0;
  161. const int32_t base_y = stride_ht * out_y;
  162. const int32_t base_x = stride_wd * out_x;
  163. for (filter_y_idx = 0; filter_y_idx < filter_ht; filter_y_idx++) {
  164. for (filter_x_idx = 0; filter_x_idx < filter_wd; filter_x_idx++) {
  165. const int32_t in_row = base_y + filter_y_idx;
  166. const int32_t in_col = base_x + filter_x_idx;
  167. int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
  168. int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
  169. (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
  170. const int8_t *input_data_ptr = input_data + input_base_offset;
  171. const int8_t *filter_data_ptr = filter_data + filter_base_offset;
  172. for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
  173. conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
  174. }
  175. }
  176. }
  177. if (bias) {
  178. conv_out += bias[out_ch_idx];
  179. }
  180. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  181. conv_out += out_offset;
  182. conv_out = max(conv_out, activation_min);
  183. conv_out = min(conv_out, activation_max);
  184. *out_data++ = (int8_t) conv_out;
  185. }
  186. }
  187. }
  188. }
  189. static void esp_nn_conv_s8_pad_valid_3x3(const int8_t *input_data,
  190. const uint16_t input_wd,
  191. const uint16_t input_ht,
  192. const uint16_t in_channels,
  193. const int32_t input_offset,
  194. const uint16_t stride_wd,
  195. const uint16_t stride_ht,
  196. const int8_t *filter_data,
  197. const int32_t *bias,
  198. int8_t *out_data,
  199. const uint16_t out_wd,
  200. const uint16_t out_ht,
  201. const uint16_t out_channels,
  202. const int32_t out_offset,
  203. const int32_t *out_shift,
  204. const int32_t *out_mult,
  205. const int32_t activation_min,
  206. const int32_t activation_max)
  207. {
  208. int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
  209. for (out_y = 0; out_y < out_ht; out_y++) {
  210. for (out_x = 0; out_x < out_wd; out_x++) {
  211. const int32_t base_y = stride_ht * out_y;
  212. const int32_t base_x = stride_wd * out_x;
  213. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  214. int32_t conv_out = 0;
  215. for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
  216. for (filter_x_idx = 0; filter_x_idx < 3; filter_x_idx++) {
  217. const int32_t in_row = base_y + filter_y_idx;
  218. const int32_t in_col = base_x + filter_x_idx;
  219. int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
  220. int32_t filter_base_offset = out_ch_idx * in_channels * 3 * 3 +
  221. (filter_y_idx * 3 + filter_x_idx) * in_channels;
  222. const int8_t *input_data_ptr = input_data + input_base_offset;
  223. const int8_t *filter_data_ptr = filter_data + filter_base_offset;
  224. for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
  225. conv_out += (*input_data_ptr++ + input_offset) * *filter_data_ptr++;
  226. }
  227. }
  228. }
  229. if (bias) {
  230. conv_out += bias[out_ch_idx];
  231. }
  232. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  233. conv_out += out_offset;
  234. conv_out = max(conv_out, activation_min);
  235. conv_out = min(conv_out, activation_max);
  236. *out_data++ = (int8_t) conv_out;
  237. }
  238. }
  239. }
  240. }
  241. static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
  242. const uint16_t input_wd,
  243. const uint16_t input_ht,
  244. const int32_t input_offset,
  245. const uint16_t stride_wd,
  246. const uint16_t stride_ht,
  247. const int8_t *filter_data,
  248. const int32_t *bias,
  249. int8_t *out_data,
  250. const uint16_t out_wd,
  251. const uint16_t out_ht,
  252. const uint16_t out_channels,
  253. const int32_t out_offset,
  254. const int32_t *out_shift,
  255. const int32_t *out_mult,
  256. const int32_t activation_min,
  257. const int32_t activation_max)
  258. {
  259. int32_t out_ch_idx, out_y, out_x, filter_y_idx;
  260. /* use scratch_buffer to pre-compute offset factor */
  261. int16_t *filter_sum = (int16_t *) scratch_buffer;
  262. const int8_t *filter_ptr = filter_data;
  263. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  264. int16_t sum_val = 0;
  265. for (int i = 0; i < 9; i++) {
  266. sum_val += *filter_ptr++;
  267. sum_val += *filter_ptr++;
  268. sum_val += *filter_ptr++;
  269. }
  270. *filter_sum++ = sum_val;
  271. }
  272. for (out_y = 0; out_y < out_ht; out_y++) {
  273. for (out_x = 0; out_x < out_wd; out_x++) {
  274. const int8_t *filter_data_ptr = filter_data;
  275. const int32_t base_y = stride_ht * out_y;
  276. const int32_t base_x = stride_wd * out_x;
  277. const int8_t *input_base_ptr = input_data + (base_y * input_wd + base_x) * 3;
  278. int16_t *filter_sum = (int16_t *) scratch_buffer;
  279. for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
  280. int32_t conv_out = 0;
  281. for (filter_y_idx = 0; filter_y_idx < 3; filter_y_idx++) {
  282. const int8_t *input_data_ptr = input_base_ptr + (filter_y_idx * input_wd) * 3;
  283. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  284. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  285. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  286. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  287. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  288. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  289. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  290. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  291. conv_out += (*input_data_ptr++) * (*filter_data_ptr++);
  292. }
  293. conv_out += *filter_sum++ * input_offset;
  294. if (bias) {
  295. conv_out += bias[out_ch_idx];
  296. }
  297. conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, out_mult[out_ch_idx], out_shift[out_ch_idx]);
  298. conv_out += out_offset;
  299. conv_out = max(conv_out, activation_min);
  300. conv_out = min(conv_out, activation_max);
  301. *out_data++ = (int8_t) conv_out;
  302. }
  303. }
  304. }
  305. }
  306. int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
  307. const uint16_t input_ht,
  308. const uint16_t in_ch,
  309. const uint16_t out_ch,
  310. const uint16_t filter_wd,
  311. const uint16_t filter_ht)
  312. {
  313. int filter_size = filter_wd * filter_ht * in_ch * out_ch;
  314. int input_size = input_wd * input_ht * in_ch;
  315. int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
  316. int align_buf_size = 32; /* extra buffer for alignment */
  317. return 2 * (filter_size + input_size + transpose_buf_size) + align_buf_size;
  318. }
  319. void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
  320. {
  321. scratch_buffer = (int16_t *) buf;
  322. }
  323. void esp_nn_conv_s8_esp32s3(const int8_t *input,
  324. const uint16_t input_wd,
  325. const uint16_t input_ht,
  326. const uint16_t channels,
  327. const int32_t input_offset,
  328. const uint16_t pad_wd,
  329. const uint16_t pad_ht,
  330. const uint16_t stride_wd,
  331. const uint16_t stride_ht,
  332. const int8_t *filter_data,
  333. const uint16_t filter_wd,
  334. const uint16_t filter_ht,
  335. const int32_t *bias,
  336. int8_t *out_data,
  337. const uint16_t out_wd,
  338. const uint16_t out_ht,
  339. const uint16_t out_channels,
  340. const int32_t out_offset,
  341. const int32_t *out_shift,
  342. const int32_t *out_mult,
  343. const int32_t activation_min,
  344. const int32_t activation_max)
  345. {
  346. int filter_size = filter_wd * filter_ht * channels * out_channels;
  347. int input_size = input_wd * input_ht * channels;
  348. int align_len = 16 - (filter_size & 15);
  349. int16_t *filter_data16 = scratch_buffer;
  350. int16_t *input_data16 = scratch_buffer + filter_size + align_len;
  351. if (scratch_buffer == NULL) {
  352. printf("esp_nn_conv error! scratch_buffer not set!\n");
  353. return;
  354. }
  355. if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
  356. pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
  357. int scratch_offset = (int) (filter_data16 + filter_size);
  358. void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
  359. esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
  360. esp_nn_conv_s16_mult8_1x1_esp32s3(
  361. input, input_wd, input_ht, channels, input_offset, filter_data16,
  362. bias, out_data, out_wd, out_ht, out_channels, out_offset,
  363. out_shift, out_mult, activation_min, activation_max, scratch_buf);
  364. } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
  365. (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
  366. pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
  367. int scratch_offset = (int) (input_data16 + input_size);
  368. void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
  369. esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
  370. esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
  371. esp_nn_conv_s16_mult4_1x1_esp32s3(
  372. input_data16, input_wd, input_ht, channels, filter_data16,
  373. bias, out_data, out_wd, out_ht, out_channels, out_offset,
  374. out_shift, out_mult, activation_min, activation_max, scratch_buf);
  375. } else if (channels % 8 == 0) {
  376. esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
  377. esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input, input_data16, input_size, input_offset);
  378. esp_nn_conv_s16_mult8_esp32s3(
  379. input_data16, input_wd, input_ht, channels, pad_wd, pad_ht,
  380. stride_wd, stride_ht, filter_data16, filter_wd, filter_ht, bias,
  381. out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
  382. out_mult, activation_min, activation_max);
  383. } else if (pad_wd == 0 && pad_ht == 0) {
  384. if (filter_wd == 3 && filter_ht == 3 && channels == 3) {
  385. esp_nn_conv_s8_pad_valid_ch3_3x3(input, input_wd, input_ht, input_offset,
  386. stride_wd, stride_ht, filter_data, bias,
  387. out_data, out_wd, out_ht, out_channels, out_offset,
  388. out_shift, out_mult, activation_min, activation_max);
  389. } else {
  390. esp_nn_conv_s8_pad_valid(input, input_wd, input_ht, channels, input_offset,
  391. stride_wd, stride_ht, filter_data, filter_wd, filter_ht, bias,
  392. out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
  393. out_mult, activation_min, activation_max);
  394. }
  395. } else {
  396. /* Basic unrolled version */
  397. esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
  398. pad_wd, pad_ht, stride_wd, stride_ht,
  399. filter_data, filter_wd, filter_ht, bias,
  400. out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
  401. out_mult, activation_min, activation_max);
  402. }
  403. }