expand_dims.cc 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/c/common.h"
  13. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  14. #include "tensorflow/lite/kernels/kernel_util.h"
  15. #include "tensorflow/lite/micro/kernels/kernel_util.h"
  16. #include "tensorflow/lite/micro/micro_utils.h"
  17. namespace tflite {
  18. namespace {
// Tensor indices for the EXPAND_DIMS op: input 0 is the tensor to expand,
// input 1 is the scalar axis at which to insert the new dimension, and
// output 0 is the expanded result.
constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;
  22. TfLiteStatus ExpandTensorDim(TfLiteContext* context,
  23. const TfLiteEvalTensor* input, int32_t axis,
  24. TfLiteEvalTensor* output) {
  25. const TfLiteIntArray* input_dims = input->dims;
  26. TfLiteIntArray* output_dims = output->dims;
  27. if (axis < 0) {
  28. axis = input_dims->size + 1 + axis;
  29. }
  30. TF_LITE_ENSURE(context, (axis <= input_dims->size));
  31. output_dims->size = input_dims->size + 1;
  32. for (int i = 0; i < output_dims->size; ++i) {
  33. if (i < axis) {
  34. output_dims->data[i] = input_dims->data[i];
  35. } else if (i == axis) {
  36. output_dims->data[i] = 1;
  37. } else {
  38. output_dims->data[i] = input_dims->data[i - 1];
  39. }
  40. }
  41. return kTfLiteOk;
  42. }
  43. TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
  44. const TfLiteEvalTensor* axis,
  45. int32_t* axis_value) {
  46. const int axis_dims = (tflite::micro::GetTensorShape(axis)).DimensionsCount();
  47. if (axis_dims > 1) {
  48. TF_LITE_KERNEL_LOG(context, "Axis has only one element for Expand_Dims.",
  49. axis_dims);
  50. return kTfLiteError;
  51. }
  52. if (kTfLiteInt32 == (axis->type)) {
  53. const int32_t* axis_ptr = tflite::micro::GetTensorData<int32_t>(axis);
  54. *axis_value = axis_ptr[0];
  55. return kTfLiteOk;
  56. } else {
  57. TF_LITE_KERNEL_LOG(context,
  58. "Axis type %s (%d) not supported by Expand_Dims.",
  59. TfLiteTypeGetName(axis->type), axis->type);
  60. return kTfLiteError;
  61. }
  62. }
  63. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  64. TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  65. TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  66. const TfLiteTensor* input;
  67. TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  68. const TfLiteTensor* axis;
  69. TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
  70. TfLiteTensor* output;
  71. TF_LITE_ENSURE_OK(context,
  72. GetOutputSafe(context, node, kOutputTensor, &output));
  73. output->type = input->type;
  74. if (IsDynamicTensor(axis)) {
  75. TF_LITE_KERNEL_LOG(context,
  76. "DynamicTensor is not yet supported by Expand_Dims.");
  77. return kTfLiteError;
  78. }
  79. return kTfLiteOk;
  80. }
  81. template <typename T>
  82. void memCopyN(T* out, const T* in, const int num_elements) {
  83. for (int i = 0; i < num_elements; ++i) {
  84. out[i] = in[i];
  85. }
  86. }
  87. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  88. const TfLiteEvalTensor* input =
  89. tflite::micro::GetEvalInput(context, node, kInputTensor);
  90. const TfLiteEvalTensor* axis =
  91. tflite::micro::GetEvalInput(context, node, kAxisTensor);
  92. TfLiteEvalTensor* output =
  93. tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  94. const int flat_size = ElementCount(*input->dims);
  95. const int input_dims = input->dims->size;
  96. int32_t axis_value;
  97. TF_LITE_ENSURE_OK(context,
  98. GetAxisValueFromTensor(context, axis, &axis_value));
  99. if ((axis_value > static_cast<int32_t>(input_dims)) ||
  100. (axis_value < static_cast<int32_t>(-(input_dims + 1)))) {
  101. TF_LITE_KERNEL_LOG(context, "Invalid Expand_Dims axis value (%d).",
  102. axis_value);
  103. return kTfLiteError;
  104. }
  105. ExpandTensorDim(context, input, axis_value, output);
  106. switch (input->type) {
  107. case kTfLiteFloat32: {
  108. memCopyN(tflite::micro::GetTensorData<float>(output),
  109. tflite::micro::GetTensorData<float>(input), flat_size);
  110. } break;
  111. case kTfLiteInt8: {
  112. memCopyN(tflite::micro::GetTensorData<int8_t>(output),
  113. tflite::micro::GetTensorData<int8_t>(input), flat_size);
  114. } break;
  115. default:
  116. TF_LITE_KERNEL_LOG(
  117. context,
  118. "Expand_Dims only currently supports int8 and float32, got %d.",
  119. input->type);
  120. return kTfLiteError;
  121. }
  122. return kTfLiteOk;
  123. }
  124. } // namespace
  125. TfLiteRegistration Register_EXPAND_DIMS() {
  126. return {/*init=*/nullptr,
  127. /*free=*/nullptr,
  128. /*prepare=*/Prepare,
  129. /*invoke=*/Eval,
  130. /*profiling_string=*/nullptr,
  131. /*builtin_code=*/0,
  132. /*custom_name=*/nullptr,
  133. /*version=*/0};
  134. }
  135. } // namespace tflite