/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>  // int8_t, int32_t
#include <cstring>  // std::memcpy

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {

constexpr int kParams = 0;
constexpr int kIndices = 1;
constexpr int kOutputTensor = 0;
constexpr int MAX_INDICES_ND = 5;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* params;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (params->type) {
    case kTfLiteFloat32:
    case kTfLiteInt8:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Params of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(params->type));
      return kTfLiteError;
  }
  switch (indices->type) {
    case kTfLiteInt32:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Indices of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(indices->type));
      return kTfLiteError;
  }

  const int params_rank = NumDimensions(params);
  const int indices_rank = NumDimensions(indices);
  const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
  if (params_rank < 1) {
    TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
    return kTfLiteError;
  }
  if (indices_rank < 1) {
    TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
    return kTfLiteError;
  }
  if (indices_nd > params_rank) {
    TF_LITE_KERNEL_LOG(
        context, "Index innermost dimension length must be <= params rank.");
    return kTfLiteError;
  }
  if (indices_nd > MAX_INDICES_ND) {
    TF_LITE_KERNEL_LOG(context,
                       "Index innermost dimension length must not exceed %d.",
                       MAX_INDICES_ND);
    return kTfLiteError;
  }

  // The output has the same element type as params.
  output->type = params->type;

  // TFLM gather_nd does not create the output tensor, but it needs to ensure
  // that the output shape is correct. The result shape is
  // indices.shape[:-1] + params.shape[indices.shape[-1]:].
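  // Illustrative example (shapes chosen here for clarity, not taken from the
  // original comment): with params of shape [3, 4, 5] and indices of shape
  // [6, 2], the innermost index length is 2, so the result shape is
  // indices.shape[:-1] + params.shape[2:] = [6] + [5] = [6, 5].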
  TfLiteIntArray* output_shape = output->dims;
  int output_index = 0;
  for (int i = 0; i < indices_rank - 1; ++i) {
    output_shape->data[output_index++] = indices->dims->data[i];
  }
  for (int i = indices_nd; i < params_rank; ++i) {
    output_shape->data[output_index++] = params->dims->data[i];
  }
  output_shape->size = output_index;
  return kTfLiteOk;
}

template <typename ParamsT, typename IndicesT>
TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
                      const TfLiteEvalTensor* indices,
                      TfLiteEvalTensor* output) {
  const int indices_dims = indices->dims->size;
  const int indices_nd = indices->dims->data[indices_dims - 1];
  const int params_dims = params->dims->size;
  const IndicesT* index_data = tflite::micro::GetTensorData<IndicesT>(indices);
  const ParamsT* param_data = tflite::micro::GetTensorData<ParamsT>(params);
  ParamsT* output_data = tflite::micro::GetTensorData<ParamsT>(output);

  int n_slices = 1;
  for (int i = 0; i < indices_dims - 1; ++i) {
    n_slices *= indices->dims->data[i];
  }

  // If indices[-1] == params.rank, fetch single elements.
  // If indices[-1] < params.rank, fetch slices.
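  // For example (values chosen here for illustration): with params of shape
  // [3, 4, 5], an innermost index length of 3 addresses single elements
  // (slice_size == 1), while an innermost index length of 2 addresses rows of
  // 5 contiguous elements (slice_size == 5).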
  int slice_size = 1;
  for (int i = indices_nd; i < params_dims; ++i) {
    slice_size *= params->dims->data[i];
  }
  int remain_flat_size = ElementCount(*params->dims);

  // dims_to_count[i] is the stride (in elements) of params dimension i in the
  // flattened params buffer, i.e. how many elements one step along that
  // dimension spans.
  int dims_to_count[MAX_INDICES_ND];
  for (int i = 0; i < indices_nd; ++i) {
    dims_to_count[i] = remain_flat_size / params->dims->data[i];
    remain_flat_size = dims_to_count[i];
  }
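  // Worked example (shapes chosen here for illustration): for params of shape
  // [3, 4, 5] and indices_nd == 2, ElementCount is 60 and dims_to_count is
  // {20, 5}, so an index row (i0, i1) starts at flat offset i0 * 20 + i1 * 5,
  // and slice_size == 5 elements are copied from there.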
  for (int i = 0; i < n_slices; ++i) {
    int from_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      int offset = i * indices_nd + j;
      IndicesT index = index_data[offset];
      from_pos += index * dims_to_count[j];
    }
    std::memcpy(output_data + i * slice_size, param_data + from_pos,
                sizeof(ParamsT) * slice_size);
  }
  return kTfLiteOk;
}

template <typename IndicesT>
TfLiteStatus EvalGatherNd(TfLiteContext* context,
                          const TfLiteEvalTensor* params,
                          const TfLiteEvalTensor* indices,
                          TfLiteEvalTensor* output) {
  switch (params->type) {
    case kTfLiteFloat32:
      return GatherNd<float, IndicesT>(params, indices, output);
    case kTfLiteInt8:
      return GatherNd<int8_t, IndicesT>(params, indices, output);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Params of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(params->type));
      return kTfLiteError;
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* params =
      tflite::micro::GetEvalInput(context, node, kParams);
  const TfLiteEvalTensor* indices =
      tflite::micro::GetEvalInput(context, node, kIndices);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  switch (indices->type) {
    case kTfLiteInt32:
      return EvalGatherNd<int32_t>(context, params, indices, output);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Indices of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(indices->type));
      return kTfLiteError;
  }
}

}  // namespace

TfLiteRegistration Register_GATHER_ND() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite