basic_math_test.c 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. // Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <stdint.h>
  15. #include <stdbool.h>
  16. #include <stdio.h>
  17. #include <stdlib.h>
  18. #include <malloc.h>
  19. #include <common_functions.h>
  20. #include <esp_nn.h>
  21. #include "test_utils.h"
  22. #if CONFIG_IDF_CMAKE
  23. #define IDF_HEAP_CAPS 1
  24. #if IDF_HEAP_CAPS
  25. #include "esp_heap_caps.h"
  26. #endif
  27. #endif
  28. void esp_nn_add_elementwise_s8_test()
  29. {
  30. /* prepare data */
  31. const int size = 1600 + 8 + 7; /* odd len to test leftover */
  32. int8_t *input1;
  33. int8_t *input2;
  34. int8_t *out_data_c;
  35. int8_t *out_data_opt;
  36. int8_t *input1_orig = NULL;
  37. int8_t *input2_orig = NULL;
  38. int8_t *out_c_orig = NULL;
  39. int8_t *out_opt_orig = NULL;
  40. int32_t input1_offset = 34;
  41. int32_t input2_offset = 35;
  42. int32_t output_offset = 36;
  43. int32_t input1_shift = -8; // right_shift amt always <= 0
  44. int32_t input2_shift = -8; // right_shift amt always <= 0
  45. int32_t output_shift = -9; // right_shift amt always <= 0
  46. int32_t left_shift = 15; // always +ve
  47. int32_t input1_mult = INT32_MAX;
  48. int32_t input2_mult = INT32_MAX;
  49. int32_t output_mult = INT32_MAX;
  50. int32_t activation_min = -128;
  51. int32_t activation_max = 127;
  52. for (int itr = 0; itr < 10; itr++) {
  53. switch (itr) {
  54. case 0: // all zeros
  55. input1_offset = 0;
  56. input2_offset = 0;
  57. output_offset = 0;
  58. input1_mult = 0;
  59. input2_mult = 0;
  60. output_mult = 0;
  61. input1_shift = 0;
  62. input2_shift = 0;
  63. output_shift = 0;
  64. left_shift = 0;
  65. break;
  66. case 1: // hit min
  67. input1_offset = -127;
  68. input2_offset = -127;
  69. output_offset = -128;
  70. input1_mult = MULT_MIN;
  71. input2_mult = MULT_MIN;
  72. output_mult = MULT_MIN;
  73. input1_shift = 0;
  74. input2_shift = 0;
  75. output_shift = 0;
  76. left_shift = 0;
  77. break;
  78. case 2: // hit max
  79. input1_offset = 128;
  80. input2_offset = 128;
  81. output_offset = -127;
  82. input1_mult = MULT_MAX;
  83. input2_mult = MULT_MAX;
  84. output_mult = MULT_MAX;
  85. input1_shift = SHIFT_MIN;
  86. input2_shift = SHIFT_MIN;
  87. output_shift = SHIFT_MIN;
  88. left_shift = 30 - 8; // since input is 8 bits
  89. break;
  90. case 3: // hit extreme max
  91. input1_offset = 128;
  92. input2_offset = 128;
  93. output_offset = -127;
  94. input1_mult = MULT_MAX;
  95. input2_mult = MULT_MAX;
  96. output_mult = MULT_MAX;
  97. input1_shift = 0;
  98. input2_shift = 0;
  99. output_shift = 0;
  100. left_shift = 30 - 8; // -8 since input is 8 bit
  101. break;
  102. default: // practical random input
  103. input1_offset = rand() % 256 - 127; // range [-127, 128]
  104. input2_offset = rand() % 256 - 127; // range [-127, 128]
  105. output_offset = rand() % 256 - 128; // range [-128, 127]
  106. input1_mult = MULT_MAX / 2 + rand() % INT16_MAX;
  107. input2_mult = MULT_MAX / 2 + rand() % INT16_MAX;
  108. output_mult = MULT_MAX / 2 + rand() % INT16_MAX;
  109. input1_shift = -8 + rand() % 4;
  110. input2_shift = -8 + rand() % 4;
  111. output_shift = -8 + rand() % 4;
  112. left_shift = rand() % 15;
  113. }
  114. #if IDF_HEAP_CAPS
  115. input1_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  116. input2_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  117. out_c_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  118. out_opt_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  119. input1 = 16 + input1_orig - ((uint32_t) input1_orig & 0xf);
  120. input2 = 16 + input2_orig - ((uint32_t) input2_orig & 0xf);
  121. out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
  122. out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
  123. #else
  124. input1 = memalign(16, size);
  125. input2 = memalign(16, size);
  126. out_data_c = memalign(16, size);
  127. out_data_opt = memalign(16, size);
  128. input1_orig = input1;
  129. input2_orig = input2;
  130. out_c_orig = out_data_c;
  131. out_opt_orig = out_data_opt;
  132. #endif
  133. for (int i = 0; i < size; ++i) {
  134. input1[i] = rand() % 256 - 128;
  135. input2[i] = rand() % 256 - 128;
  136. }
  137. if (itr == 0) {
  138. /* enable profiler */
  139. profile_c_start();
  140. }
  141. /* C function */
  142. esp_nn_add_elementwise_s8_ansi(input1, input2, input1_offset, input2_offset,
  143. input1_mult, input2_mult, input1_shift, input2_shift,
  144. left_shift, out_data_c, output_offset, output_mult,
  145. output_shift, activation_min, activation_max, size);
  146. if (itr == 0) {
  147. profile_c_end();
  148. profile_opt_start();
  149. }
  150. /* Optimized function */
  151. esp_nn_add_elementwise_s8(input1, input2, input1_offset, input2_offset,
  152. input1_mult, input2_mult, input1_shift, input2_shift,
  153. left_shift, out_data_opt, output_offset, output_mult,
  154. output_shift, activation_min, activation_max, size);
  155. if (itr == 0) {
  156. /* disable profiler */
  157. profile_opt_end();
  158. }
  159. bool ret = CHECK_EQUAL(out_data_c, out_data_opt, size);
  160. if (ret == false) {
  161. printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
  162. printf("Output: \n");
  163. PRINT_ARRAY_HEX(out_data_opt, size, 1);
  164. printf("Expected: \n");
  165. PRINT_ARRAY_HEX(out_data_c, size, 1);
  166. printf("Input1:\n");
  167. PRINT_ARRAY_HEX(input1, size, 1);
  168. printf("Input2:\n");
  169. PRINT_ARRAY_HEX(input2, size, 1);
  170. printf("in1_shift %d, in2_shift %d, left_shift %d, out_shift %d\n",
  171. input1_shift, input2_shift, left_shift, output_shift);
  172. printf("in1_mult %d, in2_mult %d, out_mult %d\n", input1_mult, input2_mult, output_mult);
  173. goto elementwise_add_test_cleanup;
  174. }
  175. printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
  176. elementwise_add_test_cleanup:
  177. if (input1_orig) {
  178. free(input1_orig);
  179. }
  180. if (input2_orig) {
  181. free(input2_orig);
  182. }
  183. if (out_data_c) {
  184. free(out_c_orig);
  185. }
  186. if (out_data_opt) {
  187. free(out_opt_orig);
  188. }
  189. }
  190. }
  191. void esp_nn_mul_elementwise_s8_test()
  192. {
  193. /* prepare data */
  194. const int size = 1600 + 8 + 7; /* odd len to test leftover */
  195. int8_t *input1;
  196. int8_t *input2;
  197. int8_t *out_data_c;
  198. int8_t *out_data_opt;
  199. int32_t input1_offset = 34;
  200. int32_t input2_offset = 35;
  201. int32_t output_offset = 36;
  202. int32_t output_shift = -7;
  203. int32_t output_mult = MULT_MAX; // max out_mult
  204. int32_t activation_min = -128;
  205. int32_t activation_max = 127;
  206. int8_t *input1_orig = NULL;
  207. int8_t *input2_orig = NULL;
  208. int8_t *out_c_orig = NULL;
  209. int8_t *out_opt_orig = NULL;
  210. for (int itr = 0; itr < 10; itr++) {
  211. switch (itr) {
  212. case 0: // all zeros
  213. input1_offset = 0;
  214. input2_offset = 0;
  215. output_offset = 0;
  216. output_mult = 0;
  217. output_shift = 0;
  218. break;
  219. case 1: // hit min
  220. input1_offset = -127;
  221. input2_offset = -127;
  222. output_offset = -128;
  223. output_mult = MULT_MIN;
  224. output_shift = 0;
  225. break;
  226. case 2: // hit max
  227. input1_offset = 128;
  228. input2_offset = 128;
  229. output_offset = -127;
  230. output_mult = MULT_MAX;
  231. output_shift = SHIFT_MIN;
  232. break;
  233. case 3: // hit extreme max
  234. input1_offset = 128;
  235. input2_offset = 128;
  236. output_offset = -127;
  237. output_mult = MULT_MAX;
  238. output_shift = 0;
  239. break;
  240. default: // practical random input
  241. input1_offset = rand() % 256 - 127; // range [-127, 128]
  242. input2_offset = rand() % 256 - 127; // range [-127, 128]
  243. output_offset = rand() % 256 - 128; // range [-128, 127]
  244. output_mult = MULT_MAX / 2 + rand() % INT16_MAX;
  245. output_shift = -8 + rand() % 4;
  246. }
  247. #if IDF_HEAP_CAPS
  248. input1_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  249. input2_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  250. out_c_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  251. out_opt_orig = (int8_t *) heap_caps_malloc(size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
  252. input1 = 16 + input1_orig - ((uint32_t) input1_orig & 0xf);
  253. input2 = 16 + input2_orig - ((uint32_t) input2_orig & 0xf);
  254. out_data_c = 16 + out_c_orig - ((uint32_t) out_c_orig & 0xf);
  255. out_data_opt = 16 + out_opt_orig - ((uint32_t) out_opt_orig & 0xf);
  256. #else
  257. input1 = memalign(16, size);
  258. input2 = memalign(16, size);
  259. out_data_c = memalign(16, size);
  260. out_data_opt = memalign(16, size);
  261. input1_orig = input1;
  262. input2_orig = input2;
  263. out_c_orig = out_data_c;
  264. out_opt_orig = out_data_opt;
  265. #endif
  266. for (int i = 0; i < size; ++i) {
  267. input1[i] = rand() % 256 - 128;
  268. input2[i] = rand() % 256 - 128;
  269. }
  270. if (itr == 0) {
  271. /* enable profiler */
  272. profile_c_start();
  273. }
  274. /* C function */
  275. esp_nn_mul_elementwise_s8_ansi(input1, input2, input1_offset, input2_offset,
  276. out_data_c, output_offset, output_mult, output_shift,
  277. activation_min, activation_max, size);
  278. if (itr == 0) {
  279. profile_c_end();
  280. profile_opt_start();
  281. }
  282. /* Optimized function */
  283. esp_nn_mul_elementwise_s8(input1, input2, input1_offset, input2_offset,
  284. out_data_opt, output_offset, output_mult, output_shift,
  285. activation_min, activation_max, size);
  286. if (itr == 0) {
  287. /* disable profiler */
  288. profile_opt_end();
  289. }
  290. bool ret = CHECK_EQUAL(out_data_c, out_data_opt, size);
  291. if (ret == false) {
  292. printf(ANSI_COLOR_RED"%s[%d] failed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
  293. printf("Output: \n");
  294. PRINT_ARRAY_HEX(out_data_opt, size, 1);
  295. printf("Expected: \n");
  296. PRINT_ARRAY_HEX(out_data_c, size, 1);
  297. printf("Input1:\n");
  298. PRINT_ARRAY_HEX(input1, size, 1);
  299. printf("Input2:\n");
  300. PRINT_ARRAY_HEX(input2, size, 1);
  301. goto elementwise_mult_test_cleanup;
  302. }
  303. printf(ANSI_COLOR_GREEN"%s[%d] passed\n"ANSI_COLOR_RESET, __FUNCTION__, itr);
  304. elementwise_mult_test_cleanup:
  305. if (input1_orig) {
  306. free(input1_orig);
  307. }
  308. if (input2_orig) {
  309. free(input2_orig);
  310. }
  311. if (out_data_c) {
  312. free(out_c_orig);
  313. }
  314. if (out_data_opt) {
  315. free(out_opt_orig);
  316. }
  317. }
  318. }