pad.cc 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pad.h"

#include <string.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
  21. namespace tflite {
  22. namespace ops {
  23. namespace micro {
  24. namespace pad {
  25. namespace {
  26. struct OpData {
  27. PadParams params;
  28. int32_t output_zero_point;
  29. };
  30. } // namespace
  31. void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  32. TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  33. return context->AllocatePersistentBuffer(context, sizeof(OpData));
  34. }
  35. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  36. MicroContext* micro_context = GetMicroContext(context);
  37. TFLITE_DCHECK(node->user_data != nullptr);
  38. OpData* data = static_cast<OpData*>(node->user_data);
  39. TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
  40. TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  41. TfLiteTensor* input =
  42. micro_context->AllocateTempInputTensor(node, /*index=*/0);
  43. TF_LITE_ENSURE(context, input != nullptr);
  44. TfLiteTensor* paddings =
  45. micro_context->AllocateTempInputTensor(node, /*index=*/1);
  46. TF_LITE_ENSURE(context, paddings != nullptr);
  47. TfLiteTensor* constant_values =
  48. NumInputs(node) == 3
  49. ? micro_context->AllocateTempInputTensor(node, /*index=*/2)
  50. : nullptr;
  51. TfLiteTensor* output =
  52. micro_context->AllocateTempOutputTensor(node, /*index=*/0);
  53. TF_LITE_ENSURE(context, output != nullptr);
  54. TF_LITE_ENSURE_EQ(context, input->type, output->type);
  55. // Current implementations rely on the inputs being <= 4D.
  56. TF_LITE_ENSURE(context, NumDimensions(input) <=
  57. reference_ops::PadKernelMaxDimensionCount());
  58. if (constant_values != nullptr) {
  59. TF_LITE_ENSURE_EQ(context, input->type, constant_values->type);
  60. // Ensure that constant_values is a scalar.
  61. TF_LITE_ENSURE_EQ(context, NumElements(constant_values), 1);
  62. }
  63. // There must be a pair of paddings for each output dimension.
  64. TF_LITE_ENSURE_EQ(context, GetTensorShape(paddings).FlatSize(),
  65. output->dims->size * 2);
  66. // On Micro, outputs must be properly sized by the converter.
  67. // NOTE: This data is only available because the paddings buffer is stored in
  68. // the flatbuffer:
  69. TF_LITE_ENSURE(context, IsConstantTensor(paddings));
  70. const int32_t* paddings_data = GetTensorData<int32_t>(paddings);
  71. for (int i = 0; i < output->dims->size; i++) {
  72. int output_dim = output->dims->data[i];
  73. int expected_dim =
  74. input->dims->data[i] + paddings_data[i * 2] + paddings_data[i * 2 + 1];
  75. TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
  76. }
  77. // Calculate OpData:
  78. data->params.resizing_category = ResizingCategory::kGenericResize;
  79. const int paddings_total = GetTensorShape(paddings).FlatSize();
  80. if (paddings_total == 8 && (paddings_data[0] == 0 && paddings_data[1] == 0) &&
  81. (paddings_data[6] == 0 && paddings_data[7] == 0)) {
  82. data->params.resizing_category = ResizingCategory::kImageStyle;
  83. }
  84. const int num_input_dimensions = NumDimensions(input);
  85. data->params.left_padding_count = num_input_dimensions;
  86. data->params.right_padding_count = num_input_dimensions;
  87. for (int idx = num_input_dimensions - 1; idx >= 0; --idx) {
  88. data->params.left_padding[idx] = paddings_data[idx * 2];
  89. data->params.right_padding[idx] = paddings_data[idx * 2 + 1];
  90. }
  91. if (input->type == kTfLiteInt8) {
  92. if (constant_values == nullptr) {
  93. // Quantized Pad requires that 0 is represented in the quantized
  94. // range.
  95. TF_LITE_ENSURE(context, output->params.zero_point >=
  96. std::numeric_limits<int8_t>::min());
  97. TF_LITE_ENSURE(context, output->params.zero_point <=
  98. std::numeric_limits<int8_t>::max());
  99. } else {
  100. // Quantized Pad requires that 'constant_values' is represented in the
  101. // same quantized range as the input and output tensors.
  102. TF_LITE_ENSURE_EQ(context, output->params.zero_point,
  103. constant_values->params.zero_point);
  104. TF_LITE_ENSURE_EQ(context, static_cast<double>(output->params.scale),
  105. static_cast<double>(constant_values->params.scale));
  106. }
  107. data->output_zero_point = output->params.zero_point;
  108. }
  109. micro_context->DeallocateTempTfLiteTensor(input);
  110. micro_context->DeallocateTempTfLiteTensor(paddings);
  111. if (constant_values != nullptr) {
  112. micro_context->DeallocateTempTfLiteTensor(constant_values);
  113. }
  114. micro_context->DeallocateTempTfLiteTensor(output);
  115. return kTfLiteOk;
  116. }
  117. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  118. TFLITE_DCHECK(node->user_data != nullptr);
  119. const OpData* data = static_cast<const OpData*>(node->user_data);
  120. const TfLiteEvalTensor* input =
  121. tflite::micro::GetEvalInput(context, node, /*index=*/0);
  122. const TfLiteEvalTensor* constant_values =
  123. NumInputs(node) == 3
  124. ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
  125. : nullptr;
  126. TfLiteEvalTensor* output =
  127. tflite::micro::GetEvalOutput(context, node, /*index=*/0);
  128. switch (input->type) {
  129. case kTfLiteFloat32: {
  130. float pad_value =
  131. constant_values == nullptr
  132. ? 0.f
  133. : *tflite::micro::GetTensorData<float>(constant_values);
  134. if (data->params.resizing_category == ResizingCategory::kImageStyle) {
  135. reference_ops::PadImageStyle(
  136. data->params, tflite::micro::GetTensorShape(input),
  137. tflite::micro::GetTensorData<float>(input), &pad_value,
  138. tflite::micro::GetTensorShape(output),
  139. tflite::micro::GetTensorData<float>(output));
  140. } else {
  141. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  142. tflite::micro::GetTensorData<float>(input),
  143. &pad_value, tflite::micro::GetTensorShape(output),
  144. tflite::micro::GetTensorData<float>(output));
  145. }
  146. } break;
  147. case kTfLiteInt8: {
  148. int8_t pad_value;
  149. if (constant_values == nullptr) {
  150. pad_value = static_cast<uint8_t>(data->output_zero_point);
  151. } else {
  152. pad_value = *tflite::micro::GetTensorData<int8_t>(constant_values);
  153. }
  154. if (data->params.resizing_category == ResizingCategory::kImageStyle) {
  155. reference_ops::PadImageStyle(
  156. data->params, tflite::micro::GetTensorShape(input),
  157. tflite::micro::GetTensorData<int8_t>(input), &pad_value,
  158. tflite::micro::GetTensorShape(output),
  159. tflite::micro::GetTensorData<int8_t>(output));
  160. } else {
  161. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  162. tflite::micro::GetTensorData<int8_t>(input),
  163. &pad_value, tflite::micro::GetTensorShape(output),
  164. tflite::micro::GetTensorData<int8_t>(output));
  165. }
  166. } break;
  167. case kTfLiteInt16: {
  168. int16_t pad_value =
  169. constant_values == nullptr
  170. ? 0
  171. : *tflite::micro::GetTensorData<int16_t>(constant_values);
  172. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  173. tflite::micro::GetTensorData<int16_t>(input),
  174. &pad_value, tflite::micro::GetTensorShape(output),
  175. tflite::micro::GetTensorData<int16_t>(output));
  176. } break;
  177. case kTfLiteInt32: {
  178. int32_t pad_value =
  179. constant_values == nullptr
  180. ? 0
  181. : *tflite::micro::GetTensorData<int32_t>(constant_values);
  182. reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
  183. tflite::micro::GetTensorData<int32_t>(input),
  184. &pad_value, tflite::micro::GetTensorShape(output),
  185. tflite::micro::GetTensorData<int32_t>(output));
  186. } break;
  187. default:
  188. MicroPrintf("Type %s not currently supported by Pad.",
  189. TfLiteTypeGetName(input->type));
  190. return kTfLiteError;
  191. }
  192. return kTfLiteOk;
  193. }
  194. } // namespace pad
  195. TfLiteRegistration Register_PAD() {
  196. return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
  197. }
  198. // Also register Pad as PadV2.
  199. TfLiteRegistration Register_PADV2() {
  200. return tflite::micro::RegisterOp(pad::Init, pad::Prepare, pad::Eval);
  201. }
  202. } // namespace micro
  203. } // namespace ops
  204. } // namespace tflite