esp_nn_esp32s3.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  /**
   * @file esp_nn_esp32s3.h
   * @brief Declarations of the esp_nn functions optimized for the ESP32-S3
   * platform, together with macros that map the generic esp_nn_* names to
   * the optimized implementations.
   */
  18. #pragma once
  19. #include <stdint.h>
  20. #include "esp_nn_ansi_headers.h"
  21. /************************** Basic math functions *****************************/
  /**
   * @brief Element-wise addition of two int8_t buffers, requantized to int8_t.
   *
   * Input offsets are declared int32_t but hold 8-bit values in [-128, 127].
   * Shift values are expected to be <= 0.
   * NOTE(review): confirm whether the "<= 0" expectation also covers
   * left_shift or only the per-input and output shifts.
   *
   * @param input1_data     first operand
   * @param input2_data     second operand
   * @param input1_offset   offset for the first operand, in [-128, 127]
   * @param input2_offset   offset for the second operand, in [-128, 127]
   * @param input1_mult     requantization multiplier for the first operand
   * @param input2_mult     requantization multiplier for the second operand
   * @param input1_shift    requantization shift for the first operand
   * @param input2_shift    requantization shift for the second operand
   * @param left_shift      shift shared by both operands
   * @param output          destination buffer
   * @param out_offset      output offset
   * @param out_mult        output requantization multiplier
   * @param out_shift       output requantization shift
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   * @param size            element count (presumably of each buffer; confirm)
   */
  void esp_nn_add_elementwise_s8_esp32s3(
      const int8_t *input1_data, const int8_t *input2_data,
      const int32_t input1_offset, const int32_t input2_offset,
      const int32_t input1_mult, const int32_t input2_mult,
      const int32_t input1_shift, const int32_t input2_shift,
      const int32_t left_shift,
      int8_t *output,
      const int32_t out_offset, const int32_t out_mult, const int32_t out_shift,
      const int32_t activation_min, const int32_t activation_max,
      const int32_t size);
  /**
   * @brief Element-wise multiplication of two int8_t buffers, requantized to
   * int8_t.
   *
   * Input offsets are declared int32_t but hold 8-bit values in [-128, 127].
   * The output shift is expected to be <= 0.
   *
   * @param input1_data     first operand
   * @param input2_data     second operand
   * @param input1_offset   offset for the first operand, in [-128, 127]
   * @param input2_offset   offset for the second operand, in [-128, 127]
   * @param output          destination buffer
   * @param out_offset      output offset
   * @param out_mult        output requantization multiplier
   * @param out_shift       output requantization shift (expected <= 0)
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   * @param size            element count (presumably of each buffer; confirm)
   */
  void esp_nn_mul_elementwise_s8_esp32s3(
      const int8_t *input1_data, const int8_t *input2_data,
      const int32_t input1_offset, const int32_t input2_offset,
      int8_t *output,
      const int32_t out_offset, const int32_t out_mult, const int32_t out_shift,
      const int32_t activation_min, const int32_t activation_max,
      const int32_t size);
  65. /************************** Convolution functions *****************************/
  /**
   * @brief Depthwise convolution with per-channel requantization.
   *
   * Inputs and output are int8_t. This follows the per-channel variant used
   * in TFLite: every output channel has its own requantization shift and
   * multiplier, which is why out_shift and out_mult are arrays rather than
   * scalars.
   *
   * Optimization note: although input_offset is int32_t, its value is
   * contained in 8 bits, i.e. [-128, 127].
   *
   * @param input_data      input feature map
   * @param input_wd        input width
   * @param input_ht        input height
   * @param channels        number of input channels
   * @param input_offset    offset added to input values, in [-128, 127]
   * @param pad_wd          horizontal padding
   * @param pad_ht          vertical padding
   * @param stride_wd       horizontal stride
   * @param stride_ht       vertical stride
   * @param ch_mult         channel multiplier
   * @param filter_data     filter weights
   * @param filter_wd       filter width
   * @param filter_ht       filter height
   * @param bias            bias array
   * @param out_data        output feature map
   * @param out_wd          output width
   * @param out_ht          output height
   * @param out_offset      offset added to output values
   * @param out_shift       per-output-channel requantization shifts
   * @param out_mult        per-output-channel requantization multipliers
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   *
   * NOTE(review): bias, out_shift and out_mult presumably hold
   * channels * ch_mult entries (one per output channel); confirm against
   * the implementation.
   */
  void esp_nn_depthwise_conv_s8_esp32s3(
      const int8_t *input_data,
      const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t channels, const int32_t input_offset,
      const uint16_t pad_wd, const uint16_t pad_ht,
      const uint16_t stride_wd, const uint16_t stride_ht,
      const uint16_t ch_mult,
      const int8_t *filter_data,
      const uint16_t filter_wd, const uint16_t filter_ht,
      const int32_t *bias,
      int8_t *out_data,
      const uint16_t out_wd, const uint16_t out_ht,
      const int32_t out_offset,
      const int32_t *out_shift, const int32_t *out_mult,
      const int32_t activation_min, const int32_t activation_max);
  /**
   * @brief Channel-wise 2-D convolution on int8_t data.
   *
   * Accumulation: result += (input + input_offset) * filter.
   * Inputs and output are int8_t. Although input_offset is int32_t, its value
   * is contained in 8 bits, i.e. [-128, 127]. Requantization is per output
   * channel, so out_shift and out_mult are arrays.
   *
   * @param input_data      input feature map
   * @param input_wd        input width
   * @param input_ht        input height
   * @param in_channels     number of input channels
   * @param input_offset    offset added to input values, in [-128, 127]
   * @param pad_wd          horizontal padding
   * @param pad_ht          vertical padding
   * @param stride_wd       horizontal stride
   * @param stride_ht       vertical stride
   * @param filter_data     filter weights
   * @param filter_wd       filter width
   * @param filter_ht       filter height
   * @param bias            bias array
   * @param out_data        output feature map
   * @param out_wd          output width
   * @param out_ht          output height
   * @param out_channels    number of output channels
   * @param out_offset      offset added to output values
   * @param out_shift       per-output-channel requantization shifts
   * @param out_mult        per-output-channel requantization multipliers
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   */
  void esp_nn_conv_s8_esp32s3(
      const int8_t *input_data,
      const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t in_channels, const int32_t input_offset,
      const uint16_t pad_wd, const uint16_t pad_ht,
      const uint16_t stride_wd, const uint16_t stride_ht,
      const int8_t *filter_data,
      const uint16_t filter_wd, const uint16_t filter_ht,
      const int32_t *bias,
      int8_t *out_data,
      const uint16_t out_wd, const uint16_t out_ht,
      const uint16_t out_channels, const int32_t out_offset,
      const int32_t *out_shift, const int32_t *out_mult,
      const int32_t activation_min, const int32_t activation_max);
  /**
   * @brief Scratch-size query for esp_nn_conv_s8_esp32s3().
   *
   * NOTE(review): presumably returns the number of bytes of scratch memory
   * the convolution needs for this geometry. The caller then supplies that
   * memory through esp_nn_set_conv_scratch_buf_esp32s3(). Confirm the units
   * and the required call order against the implementation.
   *
   * @param input_wd   input width
   * @param input_ht   input height
   * @param in_ch      number of input channels
   * @param out_ch     number of output channels
   * @param filter_wd  filter width
   * @param filter_ht  filter height
   * @return required scratch size (see the NOTE above)
   */
  int esp_nn_get_conv_scratch_size_esp32s3(
      const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t in_ch, const uint16_t out_ch,
      const uint16_t filter_wd, const uint16_t filter_ht);

  /**
   * @brief Sets the scratch buffer used by esp_nn_conv_s8_esp32s3().
   *
   * NOTE(review): the buffer is declared `const void *` even though scratch
   * memory is normally written to. Presumably the pointer is retained for
   * later calls, so it must stay valid while convolutions run. Confirm
   * ownership and lifetime against the implementation.
   */
  void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
  /**
   * @brief Scratch-size query for esp_nn_depthwise_conv_s8_esp32s3().
   *
   * NOTE(review): presumably returns the number of bytes of scratch memory
   * the depthwise convolution needs for this geometry. The caller then
   * supplies that memory through
   * esp_nn_set_depthwise_conv_scratch_buf_esp32s3(). Confirm the units and
   * the required call order against the implementation.
   *
   * @param input_wd   input width
   * @param input_ht   input height
   * @param channels   number of input channels
   * @param ch_mult    channel multiplier
   * @param filter_wd  filter width
   * @param filter_ht  filter height
   * @return required scratch size (see the NOTE above)
   */
  int esp_nn_get_depthwise_conv_scratch_size_esp32s3(
      const uint16_t input_wd, const uint16_t input_ht,
      const uint16_t channels, const uint16_t ch_mult,
      const uint16_t filter_wd, const uint16_t filter_ht);

  /**
   * @brief Sets the scratch buffer used by esp_nn_depthwise_conv_s8_esp32s3().
   *
   * NOTE(review): the buffer is declared `const void *` even though scratch
   * memory is normally written to. Presumably the pointer is retained for
   * later calls, so it must stay valid while convolutions run. Confirm
   * ownership and lifetime against the implementation.
   */
  void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
  144. /************************** Pooling functions *****************************/
  /**
   * @brief Max pooling on int8_t data.
   *
   * Input and output are int8_t. Unlike the arithmetic and convolution
   * kernels, pooling takes no input/output offsets and no requantization
   * multiplier or shift. The only value-range parameters are the activation
   * bounds.
   *
   * Note the parameter order: `channels` is the LAST parameter, not part of
   * the input-size group as in the convolution functions.
   *
   * @param input           input feature map
   * @param input_wd        input width
   * @param input_ht        input height
   * @param output          output feature map
   * @param output_wd       output width
   * @param output_ht       output height
   * @param stride_wd       horizontal stride
   * @param stride_ht       vertical stride
   * @param filter_wd       pooling window width
   * @param filter_ht       pooling window height
   * @param pad_wd          horizontal padding
   * @param pad_ht          vertical padding
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   * @param channels        number of channels
   */
  void esp_nn_max_pool_s8_esp32s3(
      const int8_t *input,
      const uint16_t input_wd, const uint16_t input_ht,
      int8_t *output,
      const uint16_t output_wd, const uint16_t output_ht,
      const uint16_t stride_wd, const uint16_t stride_ht,
      const uint16_t filter_wd, const uint16_t filter_ht,
      const uint16_t pad_wd, const uint16_t pad_ht,
      const int32_t activation_min, const int32_t activation_max,
      const uint16_t channels);
  /**
   * @brief Average pooling on int8_t data.
   *
   * Input and output are int8_t. Unlike the arithmetic and convolution
   * kernels, pooling takes no input/output offsets and no requantization
   * multiplier or shift. The only value-range parameters are the activation
   * bounds.
   *
   * Note the parameter order: `channels` is the LAST parameter, not part of
   * the input-size group as in the convolution functions.
   *
   * @param input           input feature map
   * @param input_wd        input width
   * @param input_ht        input height
   * @param output          output feature map
   * @param output_wd       output width
   * @param output_ht       output height
   * @param stride_wd       horizontal stride
   * @param stride_ht       vertical stride
   * @param filter_wd       pooling window width
   * @param filter_ht       pooling window height
   * @param pad_wd          horizontal padding
   * @param pad_ht          vertical padding
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   * @param channels        number of channels
   */
  void esp_nn_avg_pool_s8_esp32s3(
      const int8_t *input,
      const uint16_t input_wd, const uint16_t input_ht,
      int8_t *output,
      const uint16_t output_wd, const uint16_t output_ht,
      const uint16_t stride_wd, const uint16_t stride_ht,
      const uint16_t filter_wd, const uint16_t filter_ht,
      const uint16_t pad_wd, const uint16_t pad_ht,
      const int32_t activation_min, const int32_t activation_max,
      const uint16_t channels);
  187. /************************** Fully connected functions *****************************/
  /**
   * @brief Fully connected layer on int8_t data.
   *
   * Inputs and output are int8_t. Although the offsets are int32_t, the input
   * offset (and presumably filter_offset; confirm) is contained in 8 bits,
   * i.e. [-128, 127].
   *
   * Requantization uses a single out_shift / out_mult pair for every output
   * channel. These are scalars here, unlike the per-channel arrays taken by
   * the convolution functions.
   *
   * Current version works only on aligned input: row_len and out_channels
   * should both be multiples of 8.
   *
   * @param input_data      input vector
   * @param input_offset    offset added to input values, in [-128, 127]
   * @param row_len         length of one input row (multiple of 8)
   * @param filter_data     weight matrix
   * @param filter_offset   offset added to weight values
   * @param bias            bias array
   * @param out_data        output vector
   * @param out_channels    number of outputs (multiple of 8)
   * @param out_offset      offset added to output values
   * @param out_shift       requantization shift shared by all outputs
   * @param out_mult        requantization multiplier shared by all outputs
   * @param activation_min  lower activation bound
   * @param activation_max  upper activation bound
   */
  void esp_nn_fully_connected_s8_esp32s3(
      const int8_t *input_data, const int32_t input_offset,
      const uint16_t row_len,
      const int8_t *filter_data, const int32_t filter_offset,
      const int32_t *bias,
      int8_t *out_data, const uint16_t out_channels,
      const int32_t out_offset,
      const int32_t out_shift, const int32_t out_mult,
      const int32_t activation_min, const int32_t activation_max);
  /**
   * @brief In-place ReLU6 on an int8_t buffer.
   *
   * The same buffer is both input and output.
   * NOTE(review): presumably each element is clamped to [0, 6]; confirm
   * against the implementation.
   *
   * @param data  buffer that is read and overwritten
   * @param size  number of elements (at most 65535, given the uint16_t type)
   */
  void esp_nn_relu6_s8_esp32s3(int8_t *data,
                               uint16_t size);
  /********************** function defines ***************************/
  /*
   * Route the generic esp_nn_* API names to the ESP32-S3 implementations.
   * Softmax has no *_esp32s3 variant in this header and maps to the *_opt
   * functions instead. NOTE(review): the *_opt functions are presumably
   * declared in esp_nn_ansi_headers.h; confirm.
   */

  /* Basic math */
  #define esp_nn_add_elementwise_s8              esp_nn_add_elementwise_s8_esp32s3
  #define esp_nn_mul_elementwise_s8              esp_nn_mul_elementwise_s8_esp32s3

  /* Activation */
  #define esp_nn_relu6_s8                        esp_nn_relu6_s8_esp32s3

  /* Convolution */
  #define esp_nn_conv_s8                         esp_nn_conv_s8_esp32s3
  #define esp_nn_get_conv_scratch_size           esp_nn_get_conv_scratch_size_esp32s3
  #define esp_nn_set_conv_scratch_buf            esp_nn_set_conv_scratch_buf_esp32s3
  #define esp_nn_depthwise_conv_s8               esp_nn_depthwise_conv_s8_esp32s3
  #define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_esp32s3
  #define esp_nn_set_depthwise_conv_scratch_buf  esp_nn_set_depthwise_conv_scratch_buf_esp32s3

  /* Pooling */
  #define esp_nn_avg_pool_s8                     esp_nn_avg_pool_s8_esp32s3
  #define esp_nn_max_pool_s8                     esp_nn_max_pool_s8_esp32s3

  /* Fully connected */
  #define esp_nn_fully_connected_s8              esp_nn_fully_connected_s8_esp32s3

  /* Softmax (generic optimized variants) */
  #define esp_nn_get_softmax_scratch_size        esp_nn_get_softmax_scratch_size_opt
  #define esp_nn_set_softmax_scratch_buf         esp_nn_set_softmax_scratch_buf_opt
  #define esp_nn_softmax_s8                      esp_nn_softmax_s8_opt