// esp_nn_ansi_headers.h
// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. #pragma once
  15. /**
  16. * @file Header definitions to include for esp_nn reference functions
  17. */
  18. #include "esp_nn_defs.h"
  /************************** Basic math functions ****************************/

/**
 * @brief elementwise addition
 *
 * @note inputs type: int8_t, output: int8_t
 *       input offsets: although int32_t, they are contained in 8 bits [-128, 127]
 *
 * Shift values are expected to be <= 0.
 * NOTE(review): confirm whether that expectation also covers @p left_shift.
 *
 * @param input1_data, input2_data      int8_t input buffers (read only)
 * @param input1_offset, input2_offset  per-input offsets (8-bit range, see note)
 * @param input1_mult, input2_mult      per-input requantization multipliers
 * @param input1_shift, input2_shift    per-input requantization shifts
 * @param left_shift                    additional shift used by the implementation
 * @param output                        int8_t output buffer
 * @param out_offset                    output offset
 * @param out_mult, out_shift           output requantization multiplier / shift
 * @param activation_min, activation_max output clamp range
 * @param size                          number of elements to process
 */
void esp_nn_add_elementwise_s8_ansi(const int8_t *input1_data,
                                    const int8_t *input2_data,
                                    const int32_t input1_offset,
                                    const int32_t input2_offset,
                                    const int32_t input1_mult,
                                    const int32_t input2_mult,
                                    const int32_t input1_shift,
                                    const int32_t input2_shift,
                                    const int32_t left_shift,
                                    int8_t *output,
                                    const int32_t out_offset,
                                    const int32_t out_mult,
                                    const int32_t out_shift,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const int32_t size);

/**
 * @brief elementwise multiplication
 *
 * @note inputs type: int8_t, output: int8_t
 *       input offsets: although int32_t, they are contained in 8 bits [-128, 127]
 *
 * Output shift is expected to be <= 0.
 *
 * Unlike the addition above, there are no per-input multiplier/shift
 * parameters; only the output has a requantization multiplier/shift pair.
 *
 * @param input1_data, input2_data      int8_t input buffers (read only)
 * @param input1_offset, input2_offset  per-input offsets (8-bit range, see note)
 * @param output                        int8_t output buffer
 * @param out_offset                    output offset
 * @param out_mult, out_shift           output requantization multiplier / shift
 * @param activation_min, activation_max output clamp range
 * @param size                          number of elements to process
 */
void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
                                    const int8_t *input2_data,
                                    const int32_t input1_offset,
                                    const int32_t input2_offset,
                                    int8_t *output,
                                    const int32_t out_offset,
                                    const int32_t out_mult,
                                    const int32_t out_shift,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const int32_t size);
  63. /************************** Convolution functions *****************************/
  64. /**
  65. * @brief depthwise convolution per channel
  66. *
  67. * @note inputs type: int8_t, output: int8_t
  68. * Version used in tflite is per channel.
  69. * This version follows the same footsprints.
  70. * Meaning, it has per out_channel shift and multiplier for
  71. * requantization
  72. *
  73. * optimization notes: Though input_offset is int32 type,
  74. * offset values are contained in 8 bits [-128, 127]
  75. */
  76. void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
  77. const int8_t *input_data,
  78. const data_dims_t *filter_dims,
  79. const int8_t *filter_data,
  80. const int32_t *bias,
  81. const data_dims_t *output_dims,
  82. int8_t *out_data,
  83. const dw_conv_params_t *conv_params,
  84. const quant_data_t *quant_data);
  85. /**
  86. * @brief 2d-convolution channelwise
  87. *
  88. * @note operation: result += (input + offset) * filter
  89. *
  90. * inputs type: int8_t, output: int8_t
  91. * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  92. */
  93. void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
  94. const int8_t *input_data,
  95. const data_dims_t *filter_dims,
  96. const int8_t *filter_data,
  97. const int32_t *bias,
  98. const data_dims_t *output_dims,
  99. int8_t *out_data,
  100. const conv_params_t *conv_params,
  101. const quant_data_t *quant_data);
  102. int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
  103. const data_dims_t *filter_dims,
  104. const data_dims_t *output_dims,
  105. const conv_params_t *conv_params);
  106. void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
  107. int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
  108. const data_dims_t *filter_dims,
  109. const data_dims_t *output_dims,
  110. const dw_conv_params_t *conv_params);
  111. void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
  /************************** Activation functions *****************************/

/**
 * @brief relu6
 *
 * @note inout: int8_t. There is no separate output buffer; @p data is
 *       both read and written.
 *
 * @param data  int8_t buffer, processed in place
 * @param size  number of elements. Its uint16_t type limits a single call to
 *              at most 65535 elements; split larger buffers across calls.
 */
void esp_nn_relu6_s8_ansi(int8_t *data, uint16_t size);
  /************************** Pooling functions *****************************/

/**
 * @brief max_pool
 *
 * @note inputs type: int8_t, output: int8_t
 *       (this function takes no input offset)
 *
 * @param input                 int8_t input buffer (read only)
 * @param input_wd, input_ht    input width / height
 * @param output                int8_t output buffer
 * @param output_wd, output_ht  output width / height
 * @param stride_wd, stride_ht  window stride along width / height
 * @param filter_wd, filter_ht  pooling window width / height
 * @param pad_wd, pad_ht        padding along width / height
 * @param activation_min, activation_max output clamp range
 * @param channels              number of channels
 */
void esp_nn_max_pool_s8_ansi(const int8_t *input,
                             const uint16_t input_wd,
                             const uint16_t input_ht,
                             int8_t *output,
                             const uint16_t output_wd,
                             const uint16_t output_ht,
                             const uint16_t stride_wd,
                             const uint16_t stride_ht,
                             const uint16_t filter_wd,
                             const uint16_t filter_ht,
                             const uint16_t pad_wd,
                             const uint16_t pad_ht,
                             const int32_t activation_min,
                             const int32_t activation_max,
                             const uint16_t channels);

/**
 * @brief avg_pool
 *
 * @note inputs type: int8_t, output: int8_t
 *       (this function takes no input offset)
 *
 * Parameters are identical in name, type and order to
 * esp_nn_max_pool_s8_ansi().
 */
void esp_nn_avg_pool_s8_ansi(const int8_t *input,
                             const uint16_t input_wd,
                             const uint16_t input_ht,
                             int8_t *output,
                             const uint16_t output_wd,
                             const uint16_t output_ht,
                             const uint16_t stride_wd,
                             const uint16_t stride_ht,
                             const uint16_t filter_wd,
                             const uint16_t filter_ht,
                             const uint16_t pad_wd,
                             const uint16_t pad_ht,
                             const int32_t activation_min,
                             const int32_t activation_max,
                             const uint16_t channels);
  /************************** Fully connected functions ***********************/

/**
 * @brief fully connected
 *
 * @note inputs type: int8_t, output: int8_t
 *       input offsets: although int32_t, they are contained in 8 bits [-128, 127]
 *
 * @warning The order here is @p out_shift *before* @p out_mult, the reverse
 *          of esp_nn_add_elementwise_s8_ansi() / esp_nn_mul_elementwise_s8_ansi()
 *          (out_mult, out_shift). Both are int32_t, so a swapped call compiles
 *          silently.
 *
 * @param input_data      int8_t input (read only)
 * @param input_offset    input offset (8-bit range, see note)
 * @param row_len         input length; uint16_t, so at most 65535
 * @param filter_data     int8_t filter weights (read only)
 * @param filter_offset   filter offset
 * @param bias            int32_t bias values
 *                        (NOTE(review): confirm whether NULL is accepted)
 * @param out_data        int8_t output buffer
 * @param out_channels    number of output channels; uint16_t, so at most 65535
 * @param out_offset      output offset
 * @param out_shift       output requantization shift
 * @param out_mult        output requantization multiplier
 * @param activation_min, activation_max output clamp range
 */
void esp_nn_fully_connected_s8_ansi(const int8_t *input_data,
                                    const int32_t input_offset,
                                    const uint16_t row_len,
                                    const int8_t *filter_data,
                                    const int32_t filter_offset,
                                    const int32_t *bias,
                                    int8_t *out_data,
                                    const uint16_t out_channels,
                                    const int32_t out_offset,
                                    const int32_t out_shift,
                                    const int32_t out_mult,
                                    const int32_t activation_min,
                                    const int32_t activation_max);
  /**
 * @brief Get scratch buffer size needed by softmax function
 *
 * @param width   input width. Note that this query takes width first, whereas
 *                the softmax functions take (height, width).
 * @param height  input height
 * @return size in bytes
 *
 * @note buffer must be 4 byte aligned
 */
int32_t esp_nn_get_softmax_scratch_size_ansi(const int32_t width, const int32_t height);

/**
 * @brief Scratch-size query for the `_opt` softmax path.
 *
 * ANSI C function to be hooked up when the optimised version is needed.
 * Same (width, height) argument order as esp_nn_get_softmax_scratch_size_ansi().
 */
int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t height);

/**
 * @brief Set scratch buffer to be used by softmax function
 *
 * @param buffer  scratch buffer, which must be aligned to 4 bytes.
 *                This can be NULL if one needs to unset it.
 */
void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);

/**
 * @brief reference softmax function
 *
 * @note inputs type: int8_t, output: int8_t
 *
 * @param input_data   int8_t input (read only)
 * @param height       input height (height comes first here, unlike the
 *                     scratch-size query)
 * @param width        input width
 * @param mult, shift  quantization multiplier / shift
 * @param diff_min     NOTE(review): semantics not visible in this header;
 *                     see the implementation
 * @param output_data  int8_t output buffer
 */
void esp_nn_softmax_s8_ansi(const int8_t *input_data,
                            const int32_t height,
                            const int32_t width,
                            const int32_t mult,
                            const int32_t shift,
                            const int32_t diff_min,
                            int8_t *output_data);
  213. //////////////////////////// Generic optimisations /////////////////////////////
  214. /************************** Convolution functions *****************************/
  215. /**
  216. * @brief 2d-convolution channelwise optimized version
  217. *
  218. * @note operation: result += (input + offset) * filter
  219. *
  220. * inputs type: int8_t, output: int8_t
  221. * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  222. */
  223. void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
  224. const int8_t *input_data,
  225. const data_dims_t *filter_dims,
  226. const int8_t *filter_data,
  227. const int32_t *bias,
  228. const data_dims_t *output_dims,
  229. int8_t *out_data,
  230. const conv_params_t *conv_params,
  231. const quant_data_t *quant_data);
  232. /**
  233. * @brief depthwise convolution per channel optimized version
  234. *
  235. * @note inputs type: int8_t, output: int8_t
  236. * Version used in tflite is per channel.
  237. * This version follows the same footsprints.
  238. * Meaning, it has per out_channel shift and multiplier for
  239. * requantization
  240. *
  241. * optimization notes: Though input_offset is int32 type,
  242. * offset values are contained in 8 bits [-128, 127]
  243. */
  244. void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
  245. const int8_t *input_data,
  246. const data_dims_t *filter_dims,
  247. const int8_t *filter_data,
  248. const int32_t *bias,
  249. const data_dims_t *output_dims,
  250. int8_t *out_data,
  251. const dw_conv_params_t *conv_params,
  252. const quant_data_t *quant_data);
  253. int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
  254. const data_dims_t *filter_dims,
  255. const data_dims_t *output_dims,
  256. const conv_params_t *conv_params);
  257. void esp_nn_set_conv_scratch_buf_opt(const void *buf);
  258. int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
  259. const data_dims_t *filter_dims,
  260. const data_dims_t *output_dims,
  261. const dw_conv_params_t *conv_params);
  262. void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
  /**
 * @brief Set scratch buffer to be used by esp_nn_softmax_s8_opt().
 *
 * ANSI C function to be hooked up when the optimised version is needed.
 *
 * @param buffer  scratch buffer. NOTE(review): presumably this has the same
 *                4-byte alignment requirement and NULL-to-unset behaviour as
 *                esp_nn_set_softmax_scratch_buf_ansi(). Confirm.
 */
void esp_nn_set_softmax_scratch_buf_opt(void *buffer);

/**
 * @brief optimised version of softmax function
 *
 * @note The function uses an extra buffer (4 * width bytes), so the scratch
 *       buffer must be set with esp_nn_set_softmax_scratch_buf_opt() before
 *       calling this.
 *
 * Same parameter list as esp_nn_softmax_s8_ansi(): (height, width) order.
 */
void esp_nn_softmax_s8_opt(const int8_t *input_data,
                           const int32_t height,
                           const int32_t width,
                           const int32_t mult,
                           const int32_t shift,
                           const int32_t diff_min,
                           int8_t *output_data);