/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/add.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
- namespace tflite {
- const int kAddInputTensor1 = 0;
- const int kAddInputTensor2 = 1;
- const int kAddOutputTensor = 0;
- TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,
- const TfLiteTensor* input1,
- const TfLiteTensor* input2,
- TfLiteTensor* output, OpDataAdd* data) {
- data->requires_broadcast = !HaveSameShapes(input1, input2);
- if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
- // 8bit -> 8bit general quantized path, with general rescalings
- data->input1_offset = -input1->params.zero_point;
- data->input2_offset = -input2->params.zero_point;
- data->output_offset = output->params.zero_point;
- data->left_shift = (output->type == kTfLiteInt16) ? 15 : 20;
- const double twice_max_input_scale =
- 2 * static_cast<double>(
- std::max(input1->params.scale, input2->params.scale));
- const double real_input1_multiplier =
- static_cast<double>(input1->params.scale) / twice_max_input_scale;
- const double real_input2_multiplier =
- static_cast<double>(input2->params.scale) / twice_max_input_scale;
- const double real_output_multiplier =
- twice_max_input_scale /
- ((1 << data->left_shift) * static_cast<double>(output->params.scale));
- QuantizeMultiplierSmallerThanOneExp(
- real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
- QuantizeMultiplierSmallerThanOneExp(
- real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
- QuantizeMultiplierSmallerThanOneExp(
- real_output_multiplier, &data->output_multiplier, &data->output_shift);
- TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
- context, params->activation, output, &data->output_activation_min,
- &data->output_activation_max));
- } else if (output->type == kTfLiteFloat32) {
- CalculateActivationRange(params->activation,
- &data->output_activation_min_f32,
- &data->output_activation_max_f32);
- }
- return kTfLiteOk;
- }
- TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->user_data != nullptr);
- TFLITE_DCHECK(node->builtin_data != nullptr);
- MicroContext* micro_context = GetMicroContext(context);
- TfLiteTensor* input1 =
- micro_context->AllocateTempInputTensor(node, kAddInputTensor1);
- TF_LITE_ENSURE(context, input1 != nullptr);
- TfLiteTensor* input2 =
- micro_context->AllocateTempInputTensor(node, kAddInputTensor2);
- TF_LITE_ENSURE(context, input2 != nullptr);
- TfLiteTensor* output =
- micro_context->AllocateTempOutputTensor(node, kAddOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
- OpDataAdd* data = static_cast<OpDataAdd*>(node->user_data);
- auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
- TF_LITE_ENSURE_STATUS(
- CalculateOpDataAdd(context, params, input1, input2, output, data));
- micro_context->DeallocateTempTfLiteTensor(input1);
- micro_context->DeallocateTempTfLiteTensor(input2);
- micro_context->DeallocateTempTfLiteTensor(output);
- return kTfLiteOk;
- }
- } // namespace tflite
|