common_functions.h 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <stdint.h>
  16. #include <stdbool.h>
  17. #include <string.h>
  18. /**
  19. * c99 standard still doesn't strictly inline functions
  20. * We need to use attribute as well to do this.
  21. */
  22. #define __NN_FORCE_INLINE__ __attribute((always_inline)) static inline
  23. /* min/max macros */
  24. #ifndef max
  25. #define max(a, b) ({ \
  26. __typeof__ (a) _a = (a); \
  27. __typeof__ (b) _b = (b); \
  28. _a > _b ? _a : _b; \
  29. })
  30. #define min(a, b) ({ \
  31. __typeof__ (a) _a = (a); \
  32. __typeof__ (b) _b = (b); \
  33. _a < _b ? _a : _b; \
  34. })
  35. #endif
  36. __NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
  37. {
  38. #if CONFIG_IDF_TARGET_ARCH_XTENSA
  39. __asm__ volatile("nsau %0, %0" : "+r" (in));
  40. return in;
  41. #elif defined(__GNUC__)
  42. return __builtin_clz(in);
  43. #else
  44. int32_t count = 32;
  45. uint32_t x = in, y = in >> 16;
  46. if (y != 0) {
  47. count -= 16;
  48. x = y;
  49. }
  50. y = x >> 8;
  51. if (y != 0) {
  52. count -= 8;
  53. x = y;
  54. }
  55. y = x >> 4;
  56. if (y != 0) {
  57. count -= 4;
  58. x = y;
  59. }
  60. y = x >> 2;
  61. if (y != 0) {
  62. count -= 2;
  63. x = y;
  64. }
  65. y = x >> 1;
  66. if (y != 0) {
  67. return count - 2;
  68. }
  69. return count - x;
  70. #endif
  71. }
  72. /**
  73. * Signed saturate a 32 bit value to 8 bits keeping output in 32 bit variable.
  74. */
  75. __NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
  76. {
  77. #if CONFIG_IDF_TARGET_ARCH_XTENSA
  78. __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
  79. return in;
  80. #else
  81. return max(INT8_MIN, min(in, INT8_MAX));
  82. #endif
  83. }
  84. __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
  85. {
  86. int32_t sign = (int32_t) (val64 >> 63);
  87. int32_t to_add = sign & ((1ul << 31) - 1);
  88. return (int32_t) ((int64_t) (val64 + to_add) >> 31);
  89. }
  90. __NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
  91. {
  92. int32_t result;
  93. int64_t in0_64 = (int64_t) in0;
  94. bool overflow = (in0 == in1) && (in0 == (int32_t) INT32_MIN);
  95. /* Nudge value */
  96. int64_t nudge_val = 1 << 30;
  97. if ((in0 < 0) ^ (in1 < 0)) {
  98. nudge_val = 1 - nudge_val;
  99. }
  100. /* Multiply and add nudge */
  101. int64_t mult = in0_64 * in1 + nudge_val;
  102. /* Round and pickup 32 bits */
  103. result = esp_nn_pick_sat_high32_of64(mult);
  104. return overflow ? INT32_MAX : result;
  105. }
  106. /**
  107. * fast version
  108. * this will fail for values closer to INT32_MAX and INT32_MIN by `1 << (exponent - 1)`.
  109. * We can afford to do this because we are at the very last stage of filter.
  110. * Also it is pretty rare condition as our output is going to be 8 bit.
  111. */
  112. __NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two_fast(int32_t val, int32_t exponent)
  113. {
  114. int32_t to_add = (1 << (exponent - 1)) - (val < 0);
  115. return (int32_t) ((val + to_add) >> exponent);
  116. }
  117. __NN_FORCE_INLINE__ int32_t esp_nn_div_by_power_of_two(int32_t val, int32_t exponent)
  118. {
  119. int32_t result;
  120. const int32_t mask = (1 << exponent) - 1;
  121. const int32_t remainder = val & mask;
  122. result = val >> exponent;
  123. int32_t threshold = (mask >> 1) + (result < 0);
  124. if (remainder > threshold) {
  125. result += 1;
  126. }
  127. return result;
  128. }
  129. __NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult(int32_t x, int32_t mult, int32_t shift)
  130. {
  131. int32_t left_shift = shift > 0 ? shift : 0;
  132. int32_t right_shift = shift > 0 ? 0 : -shift;
  133. int32_t result = esp_nn_sat_round_doubling_high_mul(x * (1 << left_shift), mult);
  134. return esp_nn_div_by_power_of_two(result, right_shift);
  135. }
  136. __NN_FORCE_INLINE__ int32_t esp_nn_multiply_by_quantized_mult_fast(int32_t x, int32_t mult, int32_t shift)
  137. {
  138. int32_t left_shift = max(shift, 0);
  139. int32_t right_shift = left_shift - shift;
  140. int64_t nudge_val = 1 << 30;
  141. int64_t in0_64 = (int64_t) (x << left_shift);
  142. /* Multiply and add nudge */
  143. int64_t mult_64 = in0_64 * mult + nudge_val;
  144. int32_t result = (int32_t) (mult_64 >> 31);
  145. if (right_shift) {
  146. result = esp_nn_div_by_power_of_two_fast(result, right_shift);
  147. }
  148. return result;
  149. }
  150. static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
  151. const uint16_t input_wd,
  152. const uint16_t input_ht,
  153. const uint16_t channels,
  154. const int32_t pad_val,
  155. const uint16_t pad_wd,
  156. const uint16_t pad_ht)
  157. {
  158. /* memset with pad_val */
  159. memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
  160. dst += (pad_wd + input_wd + pad_wd) * channels;
  161. for (int i = 0; i < input_ht; i++) {
  162. dst += pad_wd * channels;
  163. for (int j = 0; j < input_wd * channels; j++) {
  164. *dst++ = *src++;
  165. }
  166. dst += pad_wd * channels;
  167. }
  168. }
  169. static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
  170. const uint16_t input_wd,
  171. const uint16_t input_ht,
  172. const uint16_t channels,
  173. const int32_t pad_val,
  174. const uint16_t pad_wd,
  175. const uint16_t pad_ht)
  176. {
  177. for (int i = 0; i < input_ht; i++) {
  178. for (int j = 0; j < input_wd * channels; j++) {
  179. *dst++ = *src++;
  180. }
  181. if (pad_wd) {
  182. memset(dst, pad_val, pad_wd * channels);
  183. dst += pad_wd * channels;
  184. }
  185. }
  186. /* pad end `pad_ht` lines at end */
  187. if (pad_ht) {
  188. memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
  189. }
  190. }
  191. /**
  192. * @brief convert 8 bit input data to 16 bit
  193. *
  194. * @param src int8_t source data
  195. * @param dst int16_t dst data
  196. * @param size length of data
  197. * @param offset offset to be added to src data. Range: [-128, 127]
  198. */
  199. __NN_FORCE_INLINE__ void esp_nn_s8_to_s16_with_offset(const int8_t *src, int16_t *dst,
  200. const int size, const int32_t offset)
  201. {
  202. int i = 0;
  203. for (; i < size; i += 2) {
  204. dst[i + 0] = src[i + 0] + offset;
  205. dst[i + 1] = src[i + 1] + offset;
  206. }
  207. if(i < size) {
  208. dst[i] = src[i] + offset;
  209. }
  210. }
  211. /**
  212. * @brief convert 8 bit input data to 16 bit
  213. *
  214. * @param src int8_t source data
  215. * @param dst int16_t dst data
  216. * @param size length of data
  217. */
  218. __NN_FORCE_INLINE__ void esp_nn_s8_to_s16(const int8_t *src, int16_t *dst, const int size)
  219. {
  220. int i = 0;
  221. for (; i < size; i += 2) {
  222. dst[i + 0] = src[i + 0];
  223. dst[i + 1] = src[i + 1];
  224. }
  225. if(i < size) {
  226. dst[i] = src[i];
  227. }
  228. }