esp_nn_conv_ansi.c 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <stdint.h>
  15. #include <common_functions.h>
  /**
   * Report how many bytes of scratch memory the ANSI convolution kernels need.
   *
   * The ANSI kernels in this file use no scratch memory, so the answer is
   * always 0 whatever the tensor dimensions are.
   *
   * @param input_wd   input width  (unused)
   * @param input_ht   input height (unused)
   * @param in_ch      input channels  (unused)
   * @param out_ch     output channels (unused)
   * @param filter_wd  filter width  (unused)
   * @param filter_ht  filter height (unused)
   * @return 0
   */
  int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
                                        const uint16_t input_ht,
                                        const uint16_t in_ch,
                                        const uint16_t out_ch,
                                        const uint16_t filter_wd,
                                        const uint16_t filter_ht)
  {
      /* None of the dimensions influence the result; silence unused warnings. */
      (void) input_wd;
      (void) input_ht;
      (void) in_ch;
      (void) out_ch;
      (void) filter_wd;
      (void) filter_ht;

      const int no_scratch_needed = 0;
      return no_scratch_needed;
  }
  /**
   * Hand a scratch buffer to the ANSI convolution kernels.
   *
   * The ANSI kernels in this file never use scratch memory, so the pointer is
   * ignored and nothing is stored.
   *
   * @param buf  scratch buffer (ignored; may be NULL)
   */
  void esp_nn_set_conv_scratch_buf_ansi(const void *buf)
  {
      (void) buf; /* intentionally unused */
  }
  /**
   * Reference (portable C) 2-D convolution on uint8 quantized data.
   *
   * Tensor layouts, as fixed by the index arithmetic below:
   *   input_data  : [input_ht][input_wd][in_channels]
   *   filter_data : [out_channels][filter_ht][filter_wd][in_channels]
   *   out_data    : [out_ht][out_wd][out_channels]
   *
   * For each output element the kernel sums
   *   (input + input_offset) * (filter + filter_offset)
   * over the filter window. It visits only taps that land inside the input;
   * taps in the padding region are skipped and contribute nothing.
   *
   * The sum is then post-processed:
   *   - bias[out_ch] is added when bias is non-NULL;
   *   - it is requantized with esp_nn_multiply_by_quantized_mult(out_mult,
   *     out_shift), using one pair shared by all output channels;
   *   - out_offset is added;
   *   - it is clamped to [activation_min, activation_max] and stored as uint8_t.
   *
   * in_channels and out_channels may differ.
   *
   * Assumptions:
   *   - all pointers are valid; bias may be NULL;
   *   - dilation is 1 in both dimensions (there is no dilation parameter).
   */
  void esp_nn_conv_u8_ansi(const uint8_t *input_data,
                           const uint16_t input_wd,
                           const uint16_t input_ht,
                           const uint16_t in_channels,
                           const int32_t input_offset,
                           const uint16_t pad_wd,
                           const uint16_t pad_ht,
                           const uint16_t stride_wd,
                           const uint16_t stride_ht,
                           const uint8_t *filter_data,
                           const uint16_t filter_wd,
                           const uint16_t filter_ht,
                           const int32_t filter_offset,
                           const int32_t *bias,
                           uint8_t *out_data,
                           const uint16_t out_wd,
                           const uint16_t out_ht,
                           const uint16_t out_channels,
                           const int32_t out_offset,
                           const int32_t out_shift,
                           const int32_t out_mult,
                           const int32_t activation_min,
                           const int32_t activation_max)
  {
      for (int out_y = 0; out_y < out_ht; out_y++) { // height loop
          /* Must be int32_t: out_y * stride_ht - pad_ht exceeds INT16_MAX for
           * inputs taller than 32767 rows, and narrowing to int16_t would wrap
           * to a negative row and drop the whole window. */
          const int32_t base_y = (int32_t) out_y * stride_ht - pad_ht;
          /* Clip the window rows so only taps inside the input are visited. */
          const int32_t filter_y_start = max(0, -base_y);
          const int32_t filter_y_end = min(filter_ht, input_ht - base_y);

          for (int out_x = 0; out_x < out_wd; out_x++) { // width loop
              const int32_t base_x = (int32_t) out_x * stride_wd - pad_wd;
              const int32_t filter_x_start = max(0, -base_x);
              const int32_t filter_x_end = min(filter_wd, input_wd - base_x);

              for (int out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) { // channel loop
                  int32_t result = 0;

                  for (int32_t filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
                      const int32_t idx_y = base_y + filter_y_idx;
                      for (int32_t filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                          const int32_t idx_x = base_x + filter_x_idx;
                          const int32_t input_base = (idx_y * input_wd + idx_x) * in_channels;
                          const int32_t filter_base = ((out_ch_idx * filter_ht + filter_y_idx)
                                                       * filter_wd + filter_x_idx) * in_channels;
                          for (int in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
                              const int32_t input_val = input_data[input_base + in_ch_idx] + input_offset;
                              const int32_t filter_val = filter_data[filter_base + in_ch_idx] + filter_offset;
                              result += input_val * filter_val;
                          }
                      }
                  }

                  if (bias) {
                      result += bias[out_ch_idx];
                  }
                  result = esp_nn_multiply_by_quantized_mult(result, out_mult, out_shift);
                  result += out_offset;
                  result = max(result, activation_min);
                  result = min(result, activation_max);

                  const int32_t out_index = (out_y * out_wd + out_x) * out_channels + out_ch_idx;
                  out_data[out_index] = (uint8_t) result;
              }
          }
      }
  }
  /**
   * Reference (portable C) 2-D convolution on int8 quantized data with
   * per-output-channel requantization.
   *
   * Tensor layouts, as fixed by the index arithmetic below:
   *   input_data  : [input_ht][input_wd][in_channels]
   *   filter_data : [out_channels][filter_ht][filter_wd][in_channels]
   *   out_data    : [out_ht][out_wd][out_channels], written sequentially
   *
   * Each output element is sum((input + input_offset) * filter) over the
   * filter window. Only taps inside the input are visited, so padded taps are
   * skipped. No filter offset is applied.
   *
   * The sum is then post-processed:
   *   - bias[out_ch] is added when bias is non-NULL;
   *   - it is requantized with out_mult[out_ch] / out_shift[out_ch];
   *   - out_offset is added;
   *   - it is clamped to [activation_min, activation_max].
   *
   * in_channels and out_channels may differ.
   *
   * Assumptions:
   *   - all pointers are valid; bias may be NULL; out_mult and out_shift hold
   *     at least out_channels entries;
   *   - dilation is 1 in both dimensions (there is no dilation parameter).
   */
  void esp_nn_conv_s8_ansi(const int8_t *input_data,
                           const uint16_t input_wd,
                           const uint16_t input_ht,
                           const uint16_t in_channels,
                           const int32_t input_offset,
                           const uint16_t pad_wd,
                           const uint16_t pad_ht,
                           const uint16_t stride_wd,
                           const uint16_t stride_ht,
                           const int8_t *filter_data,
                           const uint16_t filter_wd,
                           const uint16_t filter_ht,
                           const int32_t *bias,
                           int8_t *out_data,
                           const uint16_t out_wd,
                           const uint16_t out_ht,
                           const uint16_t out_channels,
                           const int32_t out_offset,
                           const int32_t *out_shift,
                           const int32_t *out_mult,
                           const int32_t activation_min,
                           const int32_t activation_max)
  {
      /* Number of weights that belong to a single output channel. */
      const int32_t weights_per_oc = (int32_t) in_channels * filter_ht * filter_wd;
      int8_t *dst = out_data;

      for (int32_t oy = 0; oy < out_ht; oy++) {
          const int32_t in_y0 = stride_ht * oy - pad_ht;
          /* Keep only the window rows that fall inside the input. */
          const int32_t ky_begin = max(0, -in_y0);
          const int32_t ky_end = min(filter_ht, input_ht - in_y0);

          for (int32_t ox = 0; ox < out_wd; ox++) {
              const int32_t in_x0 = stride_wd * ox - pad_wd;
              const int32_t kx_begin = max(0, -in_x0);
              const int32_t kx_end = min(filter_wd, input_wd - in_x0);

              for (int32_t oc = 0; oc < out_channels; oc++) {
                  const int8_t *kernel = filter_data + oc * weights_per_oc;
                  int32_t acc = 0;

                  for (int32_t ky = ky_begin; ky < ky_end; ky++) {
                      for (int32_t kx = kx_begin; kx < kx_end; kx++) {
                          const int8_t *src_px =
                              input_data + ((in_y0 + ky) * input_wd + (in_x0 + kx)) * in_channels;
                          const int8_t *w_px = kernel + (ky * filter_wd + kx) * in_channels;
                          int32_t ic = 0;
                          while (ic < in_channels) {
                              acc += (src_px[ic] + input_offset) * w_px[ic];
                              ic++;
                          }
                      }
                  }

                  if (bias) {
                      acc += bias[oc];
                  }
                  acc = esp_nn_multiply_by_quantized_mult(acc, out_mult[oc], out_shift[oc]);
                  acc += out_offset;
                  acc = max(acc, activation_min);
                  acc = min(acc, activation_max);
                  *dst++ = (int8_t) acc;
              }
          }
      }
  }