conv copy.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace micro {
namespace conv {

constexpr int kInputTensor = 0;
constexpr int kFilterTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;

// Adjusted by jomjol, 2020-06-05: raised from the upstream default of 1024.
// constexpr int kMaxChannels = 1024;
constexpr int kMaxChannels = 4096;
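// Note: the two per-channel arrays in OpData below each hold kMaxChannels
// int32_t entries, so this setting costs 2 * 4096 * 4 bytes = 32 KB per
// OpData instance. Eval() places OpData on the stack, so the stack must be
// sized accordingly.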

// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;

// This file contains reference implementations of Conv for float,
// uint8_t, and per-channel int8_t data.
struct OpData {
  TfLitePaddingValues padding;
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
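  // For example, a real multiplier of 0.75 is stored as
  // output_multiplier = 1610612736 (0.75 * 2^31) with output_shift = 0,
  // since 1610612736 / 2^31 * 2^0 == 0.75.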
  // Per channel output multiplier and shift.
  // TODO(b/141139247): Allocate these dynamically when possible.
  int32_t per_channel_output_multiplier[kMaxChannels];
  int32_t per_channel_output_shift[kMaxChannels];
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
};

inline PaddingType RuntimePaddingType(TfLitePadding padding) {
  switch (padding) {
    case TfLitePadding::kTfLitePaddingSame:
      return PaddingType::kSame;
    case TfLitePadding::kTfLitePaddingValid:
      return PaddingType::kValid;
    case TfLitePadding::kTfLitePaddingUnknown:
    default:
      return PaddingType::kNone;
  }
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             TfLiteConvParams* params, int width, int height,
                             int filter_width, int filter_height, int out_width,
                             int out_height, const TfLiteType data_type,
                             OpData* data) {
  bool has_bias = node->inputs->size == 3;
  // Check number of inputs/outputs.
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params->padding;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      params->dilation_height_factor, params->dilation_width_factor, height,
      width, filter_height, filter_width, padding, &out_height, &out_width);
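  // ComputePaddingHeightWidth also writes the output size through the last
  // two pointers. For SAME padding, out = ceil(in / stride) and the total
  // padding per axis is max((out - 1) * stride + effective_filter - in, 0),
  // where effective_filter = (filter - 1) * dilation + 1; VALID padding
  // adds no padding at all.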

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    int output_channels = filter->dims->data[kConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift),
        output_channels));
  }
  return kTfLiteOk;
}

void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                   TfLiteConvParams* params, OpData* data,
                   const TfLiteTensor* input, const TfLiteTensor* filter,
                   const TfLiteTensor* bias, TfLiteTensor* im2col,
                   TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
  const int32_t input_offset = -input->params.zero_point;
  const int32_t filter_offset = -filter->params.zero_point;
  const int32_t output_offset = output->params.zero_point;
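  // A quantized value q represents the real value (q - zero_point) * scale,
  // so the input and filter offsets are the negated zero points: the kernel
  // can then simply add the offset to each raw value. The output offset is
  // added back (not subtracted) when requantizing the result.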

  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = -data->output_shift;
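  // The shift is negated because PopulateConvolutionQuantizationParams and
  // ConvParams use opposite sign conventions for the shift. The trailing
  // nullptr passed to reference_ops::Conv below is the cpu_backend_context
  // parameter, which the reference implementation does not use.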
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_ops::Conv(op_params, GetTensorShape(input),
                      GetTensorData<uint8_t>(input), GetTensorShape(filter),
                      GetTensorData<uint8_t>(filter), GetTensorShape(bias),
                      GetTensorData<int32_t>(bias), GetTensorShape(output),
                      GetTensorData<uint8_t>(output), GetTensorShape(im2col),
                      GetTensorData<uint8_t>(im2col), nullptr);
}

void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                             TfLiteConvParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter,
                             const TfLiteTensor* bias, TfLiteTensor* output,
                             TfLiteTensor* im2col) {
  ConvParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
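  // No weights_offset is set here: the TFLite quantization spec requires
  // per-channel int8_t filters to be symmetrically quantized, i.e. their
  // zero point is always 0.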
  reference_integer_ops::ConvPerChannel(
      op_params, data->per_channel_output_multiplier,
      data->per_channel_output_shift, GetTensorShape(input),
      GetTensorData<int8_t>(input), GetTensorShape(filter),
      GetTensorData<int8_t>(filter), GetTensorShape(bias),
      GetTensorData<int32_t>(bias), GetTensorShape(output),
      GetTensorData<int8_t>(output));
}

void EvalFloat(TfLiteContext* context, TfLiteNode* node,
               TfLiteConvParams* params, OpData* data,
               const TfLiteTensor* input, const TfLiteTensor* filter,
               const TfLiteTensor* bias, TfLiteTensor* im2col,
               TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
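  // CalculateActivationRange maps the fused activation to a clamp range,
  // e.g. kTfLiteActRelu6 -> [0, 6] and kTfLiteActRelu -> [0, max float],
  // while kTfLiteActNone leaves the range effectively unbounded.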

  ConvParams op_params;
  op_params.padding_type = RuntimePaddingType(params->padding);
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.dilation_width_factor = params->dilation_width_factor;
  op_params.dilation_height_factor = params->dilation_height_factor;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  reference_ops::Conv(op_params, GetTensorShape(input),
                      GetTensorData<float>(input), GetTensorShape(filter),
                      GetTensorData<float>(filter), GetTensorShape(bias),
                      GetTensorData<float>(bias), GetTensorShape(output),
                      GetTensorData<float>(output), GetTensorShape(im2col),
                      GetTensorData<float>(im2col));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kBiasTensor);
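
  // Input and output tensors are laid out NHWC (batch, height, width,
  // channels) and the filter OHWI (output channels, height, width, input
  // channels), so index 1 is always height and index 2 is always width.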
  int input_width = input->dims->data[2];
  int input_height = input->dims->data[1];
  int filter_width = filter->dims->data[2];
  int filter_height = filter->dims->data[1];
  int output_width = output->dims->data[2];
  int output_height = output->dims->data[1];

  OpData data;

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);
    TF_LITE_ENSURE(context,
                   affine_quantization->scale->size == 1 ||
                       affine_quantization->scale->size ==
                           filter->dims->data[kConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }
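  // The scale count check above allows size 1 (per-tensor quantization) or
  // exactly one scale and zero point per output channel of the filter.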

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, &data));

  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      EvalFloat(context, node, params, &data, input, filter, bias, nullptr,
                nullptr, output);
      break;
    case kTfLiteInt8:
      EvalQuantizedPerChannel(context, node, params, &data, input, filter,
                              bias, output, nullptr);
      break;
    case kTfLiteUInt8:
      EvalQuantized(context, node, params, &data, input, filter, bias, nullptr,
                    nullptr, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace conv
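
// Because prepare is nullptr in the registration below, there is no one-time
// setup for this kernel: CalculateOpData runs inside conv::Eval on every
// invocation.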
TfLiteRegistration* Register_CONV_2D() {
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/conv::Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite