softmax_common.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright 2022 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <stdint.h>
  15. #include <common_functions.h>
  /*
   * Branch-free selection helpers (the same idiom as gemmlowp's fixedpoint
   * MaskIfZero / MaskIfNonZero / SelectUsingMask).
   *
   * MASK_IF_ZERO(x)     -> all-ones (~0) when x == 0, otherwise 0.
   * MASK_IF_NON_ZERO(x) -> all-ones (~0) when x != 0, otherwise 0.
   * SELECT_USING_MASK(mask, a, b) -> a where mask bits are set, b elsewhere;
   *   with mask equal to ~0 or 0 this yields exactly a or b.
   *
   * Every expansion is fully parenthesised so the macros compose safely
   * inside larger expressions (e.g. `2 * MASK_IF_NON_ZERO(x)`).
   */
  #define MASK_IF_ZERO(x) ((x) == 0 ? ~0 : 0)
  #define MASK_IF_NON_ZERO(x) ((x) != 0 ? ~0 : 0)
  #define SELECT_USING_MASK(mask, a, b) (((mask) & (a)) ^ (~(mask) & (b)))
  /* Short aliases for the saturating/rounding helpers from common_functions.h. */
  #define SAT_HIGH_MUL(x, y) esp_nn_sat_round_doubling_high_mul((x), (y))
  #define DIV_POW2(x, y) esp_nn_div_by_power_of_two((x), (y))
  21. __NN_FORCE_INLINE__ int32_t mul_power_of_2(int val, int exp)
  22. {
  23. const int32_t thresh = ((1 << (31 - exp)) - 1);
  24. int32_t result = val << exp;
  25. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), INT32_MAX, result);
  26. result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), INT32_MIN, result);
  27. return result;
  28. }
  29. /**
  30. * @brief Calculate `1 / (1 + x)` for x in [0, 1]
  31. *
  32. * @param val input value to calculate `1/(1+x)` for
  33. * @return `int32_t` result
  34. * @note Newton-Raphson division
  35. *
  36. * https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  37. * Refer to that page for the logic behind the 48/17 and 32/17 constants.
  38. * Pseudocode: https://en.wikipedia.org/wiki/Division_algorithm#Pseudocode
  39. */
  40. __NN_FORCE_INLINE__ int32_t esp_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
  41. {
  42. const int64_t sum = (int64_t) val + INT32_MAX;
  43. const int32_t half_denominator = (int32_t) ((sum + (sum >= 0 ? 1 : -1)) / 2L);
  44. int32_t constant_48_over_17 = 1515870810;
  45. int32_t constant_neg_32_over_17 = -1010580540;
  46. int32_t x = constant_48_over_17 + SAT_HIGH_MUL(half_denominator, constant_neg_32_over_17);
  47. const int32_t fixed_2_one = (1 << 29);
  48. x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
  49. x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
  50. x += mul_power_of_2(SAT_HIGH_MUL(x, fixed_2_one - SAT_HIGH_MUL(half_denominator, x)), 2);
  51. return mul_power_of_2(x, 1);
  52. }
  53. #define ONE_OVER_ONE_X(x) esp_nn_one_over_one_plus_x_for_x_in_0_1((x))
  54. /**
  55. * @brief Return exp(x) for x < 0.
  56. *
  57. */
  58. __NN_FORCE_INLINE__ int32_t esp_nn_exp_on_negative_values(int32_t val)
  59. {
  60. int32_t shift = 24;
  61. const int32_t one_quarter = (1 << shift);
  62. int32_t mask = one_quarter - 1;
  63. const int32_t val_mod_minus_quarter = (val & mask) - one_quarter;
  64. const int32_t remainder = val_mod_minus_quarter - val;
  65. // calculate exponent for x in [-1/4, 0) in `result`
  66. const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
  67. const int32_t x2 = SAT_HIGH_MUL(x, x);
  68. const int32_t x3 = SAT_HIGH_MUL(x2, x);
  69. const int32_t x4 = SAT_HIGH_MUL(x2, x2);
  70. const int32_t one_over_3 = 715827883;
  71. const int32_t one_over_8 = 1895147668;
  72. const int32_t x4_over_4 = DIV_POW2(x4, 2);
  73. const int32_t x4_over_4_plus_x3_over_6_plus_x2_over_2 = DIV_POW2(SAT_HIGH_MUL(x4_over_4 + x3, one_over_3) + x2, 1);
  74. int32_t result = one_over_8 + SAT_HIGH_MUL(one_over_8, x + x4_over_4_plus_x3_over_6_plus_x2_over_2);
  75. #define SELECT_IF_NON_ZERO(x) { \
  76. mask = MASK_IF_NON_ZERO(remainder & (1 << shift++)); \
  77. result = SELECT_USING_MASK(mask, SAT_HIGH_MUL(result, x), result); \
  78. }
  79. SELECT_IF_NON_ZERO(1672461947)
  80. SELECT_IF_NON_ZERO(1302514674)
  81. SELECT_IF_NON_ZERO(790015084)
  82. SELECT_IF_NON_ZERO(290630308)
  83. SELECT_IF_NON_ZERO(39332535)
  84. SELECT_IF_NON_ZERO(720401)
  85. SELECT_IF_NON_ZERO(242)
  86. #undef SELECT_IF_NON_ZERO
  87. mask = MASK_IF_ZERO(val);
  88. return SELECT_USING_MASK(mask, INT32_MAX, result);
  89. }