pooling_common.cc

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/pooling.h"

namespace tflite {

const int kPoolingInputTensor = 0;
const int kPoolingOutputTensor = 0;
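
// Prepare-time helper shared by the pooling kernels: computes the padding
// offsets from the pooling parameters and stores them in the per-node
// OpDataPooling. The output height/width computed along the way are not used
// here.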
TfLiteStatus CalculateOpDataPooling(const TfLiteContext* context,
                                    const TfLitePoolParams* params,
                                    const TfLiteTensor* input,
                                    const TfLiteTensor* output,
                                    OpDataPooling* data) {
  // input: batch, height, width, channel
  int height = SizeOfDimension(input, 1);
  int width = SizeOfDimension(input, 2);

  int out_height, out_width;

  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      /*dilation_rate_height=*/1,
      /*dilation_rate_width=*/1, height, width, params->filter_height,
      params->filter_width, params->padding, &out_height, &out_width);

  return kTfLiteOk;
}
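
// Prepare logic shared by the average- and max-pooling kernels: fetches the
// input and output tensors, computes the padding, and caches the activation
// range for the input type (float32 or int8) in OpDataPooling.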
TfLiteStatus PoolingPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpDataPooling* data = static_cast<OpDataPooling*>(node->user_data);

  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kPoolingInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kPoolingOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpDataPooling(context, params, input, output, data));

  if (input->type == kTfLiteFloat32) {
    CalculateActivationRange(params->activation, &data->activation_min_f32,
                             &data->activation_max_f32);
  } else if (input->type == kTfLiteInt8) {
    CalculateActivationRangeQuantized(context, params->activation, output,
                                      &data->activation_min,
                                      &data->activation_max);
  }

  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);

  return kTfLiteOk;
}
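
// The Eval helpers below copy the cached parameters into a PoolParams struct
// and dispatch to the corresponding reference kernel (float or int8).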
void AveragePoolingEvalFloat(const TfLiteContext* context,
                             const TfLiteNode* node,
                             const TfLitePoolParams* params,
                             const OpDataPooling* data,
                             const TfLiteEvalTensor* input,
                             TfLiteEvalTensor* output) {
  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.float_activation_min = data->activation_min_f32;
  op_params.float_activation_max = data->activation_max_f32;
  reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
                             tflite::micro::GetTensorData<float>(input),
                             tflite::micro::GetTensorShape(output),
                             tflite::micro::GetTensorData<float>(output));
}

void AveragePoolingEvalQuantized(TfLiteContext* context,
                                 const TfLiteNode* node,
                                 const TfLitePoolParams* params,
                                 const OpDataPooling* data,
                                 const TfLiteEvalTensor* input,
                                 TfLiteEvalTensor* output) {
  TFLITE_DCHECK(input->type == kTfLiteInt8);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->activation_min;
  op_params.quantized_activation_max = data->activation_max;

  reference_integer_ops::AveragePool(
      op_params, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<int8_t>(input),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<int8_t>(output));
}

void MaxPoolingEvalFloat(TfLiteContext* context, TfLiteNode* node,
                         TfLitePoolParams* params, const OpDataPooling* data,
                         const TfLiteEvalTensor* input,
                         TfLiteEvalTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.float_activation_min = data->activation_min_f32;
  op_params.float_activation_max = data->activation_max_f32;
  reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<float>(input),
                         tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<float>(output));
}

void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLitePoolParams* params,
                             const OpDataPooling* data,
                             const TfLiteEvalTensor* input,
                             TfLiteEvalTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->activation_min;
  op_params.quantized_activation_max = data->activation_max;

  reference_integer_ops::MaxPool(op_params,
                                 tflite::micro::GetTensorShape(input),
                                 tflite::micro::GetTensorData<int8_t>(input),
                                 tflite::micro::GetTensorShape(output),
                                 tflite::micro::GetTensorData<int8_t>(output));
}

}  // namespace tflite
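
// ---------------------------------------------------------------------------
// Hypothetical usage sketch (not part of pooling_common.cc). It shows how a
// pooling kernel's Eval function could dispatch to the shared helpers above,
// following the usual TFLM pattern of reading the TfLiteEvalTensors via
// tflite::micro::GetEvalInput/GetEvalOutput. The function name and error
// handling below are illustrative; the real per-port pooling.cc may differ.
// ---------------------------------------------------------------------------
namespace tflite {

TfLiteStatus AveragePoolEvalSketch(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpDataPooling* data =
      static_cast<const OpDataPooling*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kPoolingInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kPoolingOutputTensor);

  // Dispatch on the input type prepared in PoolingPrepare.
  switch (input->type) {
    case kTfLiteFloat32:
      AveragePoolingEvalFloat(context, node, params, data, input, output);
      break;
    case kTfLiteInt8:
      AveragePoolingEvalQuantized(context, node, params, data, input, output);
      break;
    default:
      // Unsupported input type for this sketch.
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace tflite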