| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include "tensorflow/lite/c/builtin_op_data.h"
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/kernels/internal/quantization_util.h"
- #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
- #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
- #include "tensorflow/lite/kernels/kernel_util.h"
- #include "tensorflow/lite/kernels/op_macros.h"
- #include "tensorflow/lite/micro/kernels/kernel_util.h"
- #include "tensorflow/lite/micro/memory_helpers.h"
- namespace tflite {
- namespace {
- struct SqueezeContext {
- SqueezeContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
- micro_context = GetMicroContext(context);
- input = micro_context->AllocateTempInputTensor(node, 0);
- output = micro_context->AllocateTempOutputTensor(node, 0);
- }
- ~SqueezeContext() {
- micro_context->DeallocateTempTfLiteTensor(input);
- micro_context->DeallocateTempTfLiteTensor(output);
- }
- MicroContext* micro_context;
- TfLiteSqueezeParams* params;
- TfLiteTensor* input;
- TfLiteTensor* output;
- };
- TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- SqueezeContext op_context(context, node);
- const int input_num_dims = NumDimensions(op_context.input);
- const int num_squeeze_dims = op_context.params->num_squeeze_dims;
- // Determines number of dimensions of output tensor after squeeze.
- const TfLiteIntArray* input_dims = op_context.input->dims;
- const TfLiteIntArray* output_dims = op_context.output->dims;
- const int* squeeze_dims = op_context.params->squeeze_dims;
- constexpr int max_squeeze_dims = 8;
- TF_LITE_ENSURE(context, input_num_dims <= max_squeeze_dims);
- bool should_squeeze[max_squeeze_dims] = {};
- if (num_squeeze_dims == 0) {
- for (int idx = 0; idx < input_num_dims; ++idx) {
- if (input_dims->data[idx] == 1) {
- should_squeeze[idx] = true;
- }
- }
- } else {
- for (int idx = 0; idx < num_squeeze_dims; ++idx) {
- int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims
- : squeeze_dims[idx];
- TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims &&
- input_dims->data[current] == 1);
- should_squeeze[current] = true;
- }
- }
- // Ensure output dimensions are big enough.
- for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) {
- if (!should_squeeze[in_idx]) {
- TFLITE_CHECK_GE(output_dims->data[out_idx++], input_dims->data[in_idx]);
- }
- }
- return kTfLiteOk;
- }
- TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
- if (input->type == kTfLiteString) {
- TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
- TfLiteTypeGetName(input->type), input->type);
- return kTfLiteError;
- }
- TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
- size_t input_byte_size;
- size_t output_byte_size;
- TF_LITE_ENSURE_OK(context,
- TfLiteEvalTensorByteLength(input, &input_byte_size));
- TF_LITE_ENSURE_OK(context,
- TfLiteEvalTensorByteLength(output, &output_byte_size));
- TF_LITE_ENSURE_EQ(context, input_byte_size, output_byte_size);
- memcpy(output->data.raw, input->data.raw, input_byte_size);
- return kTfLiteOk;
- }
- } // namespace
- TfLiteRegistration Register_SQUEEZE() {
- return {/*init=*/nullptr,
- /*free=*/nullptr,
- /*prepare=*/Prepare,
- /*invoke=*/Eval,
- /*profiling_string=*/nullptr,
- /*builtin_code=*/0,
- /*custom_name=*/nullptr,
- /*version=*/0};
- }
- } // namespace tflite
|