reduce_common.cc

/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>  // std::numeric_limits, used by the REDUCE_MAX lambdas below.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/reduce.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

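// Fixed bounds for the on-stack index/axis scratch arrays used by the MEAN
// and SUM eval paths below; they implicitly cap the supported input rank and
// the number of simultaneously reduced axes.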
const int kMaxNumberOfAxis = 5;
const int kMaxNumberOfReducedAxis = 2;
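
// Shared preparation for all reduce kernels: validates the input/axis/output
// tensor layout and, for int8 inputs, precomputes the fixed-point
// requantization (multiplier, shift) from the input and output scales.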
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
                           int32_t* multiplier, int* shift) {
  MicroContext* micro_context = GetMicroContext(context);

  // Inputs Tensor (dtype depends on quantization):
  // [0] = Input
  // [1] = Axis
  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);

  // Outputs Tensor (dtype depends on quantization):
  // [0] = Output

  // Validate number of inputs and outputs
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Validate axis type
  TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
  TF_LITE_ENSURE(context, axis != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);

  if (input->type == kTfLiteInt8) {
    TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
    const double real_multiplier = static_cast<double>(input->params.scale) /
                                   static_cast<double>(output->params.scale);
    QuantizeMultiplier(real_multiplier, multiplier, shift);
    micro_context->DeallocateTempTfLiteTensor(output);
  }

  micro_context->DeallocateTempTfLiteTensor(axis);
  micro_context->DeallocateTempTfLiteTensor(input);
  return kTfLiteOk;
}
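
// Prepare for REDUCE_MAX: runs the shared checks, caches the quantization
// parameters, and requests arena scratch buffers for the temporary index and
// resolved-axis arrays consumed at eval time.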
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
                              OpDataReduce* op_data) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
                                           &op_data->shift));

  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
  TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);

  op_data->input_scale = input->params.scale;
  op_data->output_scale = output->params.scale;
  op_data->num_output_elements = NumElements(output);

  context->RequestScratchBufferInArena(
      context, sizeof(int) * input->dims->size, &op_data->temp_buffer_idx);
  context->RequestScratchBufferInArena(
      context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
      &op_data->resolved_axis_idx);

  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(axis);
  return kTfLiteOk;
}
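
// Prepare for MEAN and SUM: for quantized (int8/int16) inputs, precomputes
// the requantization parameters and requests an int32 accumulation scratch
// buffer with one element per output value.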
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
                                    OpDataReduce* op_data) {
  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    const double real_multiplier = static_cast<double>(input->params.scale) /
                                   static_cast<double>(output->params.scale);
    QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
  }

  int output_size = NumElements(output);
  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    context->RequestScratchBufferInArena(
        context, output_size * sizeof(int32_t), &op_data->temp_buffer_idx);
    op_data->input_zp = input->params.zero_point;
    op_data->input_scale = input->params.scale;
    op_data->output_zp = output->params.zero_point;
    op_data->output_scale = output->params.scale;
  }

  TF_LITE_ENSURE_OK(
      context,
      PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift)));
  // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  return kTfLiteOk;
}
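
// Copies the requested axis indices into the fixed-size (4-element) axis
// array of MeanParams, padding any unused slots with 1.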
void ResolveAxis(const int* axis_data, int axis_count,
                 tflite::MeanParams* op_params) {
  int i = 0;
  for (; i < axis_count; ++i) {
    op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
  }
  for (; i < 4; ++i) {
    op_params->axis[i] = 1;
  }
  op_params->axis_count = axis_count;
}
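
// Evaluates MEAN. A specialized kernel handles the common 4D case of
// reducing across axes 1 and 2 with keep_dims; all other shapes fall back to
// the generic reference implementations.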
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
                            OpDataReduce* op_data) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TfLiteReducerParams* params =
      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];

  tflite::MeanParams op_params;
  ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);

  // Special case mean implementation exists for 4D mean across axes 1 and 2.
  bool special_case_4d_axes_1_and_2 =
      input->dims->size == 4 && op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1));

  switch (input->type) {
    case kTfLiteFloat32: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<float>(input),
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<float>(output));
      } else {
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<float>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<float>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis,
                tflite::micro::GetTensorData<float>(output)));
      }
    } break;
    case kTfLiteInt8: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis, temp_buffer));
      } else {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<int8_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, false));
      }
    } break;
    case kTfLiteInt16: {
      // Defer to specialized implementation for 4D Mean across axes 1 & 2.
      if (params->keep_dims && special_case_4d_axes_1_and_2) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
      } else if (op_data->input_zp == op_data->output_zp &&
                 op_data->input_scale == op_data->output_scale) {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
                                input->dims->data, input->dims->size,
                                tflite::micro::GetTensorData<int16_t>(output),
                                output->dims->data, output->dims->size,
                                tflite::micro::GetTensorData<int>(axis),
                                num_axis, params->keep_dims, temp_index,
                                resolved_axis, temp_buffer));
      } else {
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        TF_LITE_ENSURE(
            context,
            reference_ops::QuantizedMeanOrSum(
                tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
                op_data->input_scale, input->dims->data, input->dims->size,
                tflite::micro::GetTensorData<int16_t>(output),
                op_data->output_zp, op_data->output_scale, output->dims->data,
                output->dims->size, tflite::micro::GetTensorData<int>(axis),
                num_axis, params->keep_dims, temp_index, resolved_axis,
                temp_buffer, false));
      }
    } break;
    default:
      TF_LITE_ENSURE_MSG(context, false,
                         "Currently, only float32, int8 or int16 input type "
                         "is supported.");
  }
  return kTfLiteOk;
}
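
// Evaluates REDUCE_MAX for float32 and int8 via the generic reduction with a
// max folding function. The int8 path requires matching input/output
// quantization parameters so that no rescaling is needed.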
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
                           OpDataReduce* op_data) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);

  // Interpret an axis tensor with null dimensions as a scalar.
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<float>::lowest(),
              [](const float current, const float in) -> float {
                return (in > current) ? in : current;
              }));
      break;
    case kTfLiteInt8:
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              std::numeric_limits<int8_t>::lowest(),
              [](const int8_t current, const int8_t in) -> int8_t {
                return (in > current) ? in : current;
              }));
      break;
    default:
      MicroPrintf("Only float32 and int8 types are supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
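
// Evaluates SUM: float32 uses the generic reduction with an add folding
// function; int8/int16 reuse QuantizedMeanOrSum with compute_sum=true.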
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
                           OpDataReduce* op_data) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);

  // Interpret an axis tensor with null dimensions as a scalar.
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];
  switch (input->type) {
    case kTfLiteFloat32: {
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_index, resolved_axis, /*init_value=*/0.f,
              [](const float current, const float in) -> float {
                return in + current;
              }));
    } break;
    case kTfLiteInt8: {
      int32_t* temp_buffer = static_cast<int32_t*>(
          context->GetScratchBuffer(context, op_data->temp_buffer_idx));
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum(
              tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
              op_data->input_scale, input->dims->data, input->dims->size,
              tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp,
              op_data->output_scale, output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_index, resolved_axis, temp_buffer,
              /*compute_sum=*/true));
    } break;
    case kTfLiteInt16: {
      int32_t* temp_buffer = static_cast<int32_t*>(
          context->GetScratchBuffer(context, op_data->temp_buffer_idx));
      TF_LITE_ENSURE(
          context,
          reference_ops::QuantizedMeanOrSum(
              tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
              op_data->input_scale, input->dims->data, input->dims->size,
              tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp,
              op_data->output_scale, output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_index, resolved_axis, temp_buffer,
              /*compute_sum=*/true));
    } break;
    default:
      MicroPrintf("Only float32, int8, and int16 types are supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace tflite