| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #ifndef TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_UTIL_H_
- #define TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_UTIL_H_
- #include <cstdint>
- #include "tensorflow/lite/c/builtin_op_data.h"
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/kernels/internal/compatibility.h"
- #include "tensorflow/lite/kernels/internal/types.h"
- namespace tflite {
- namespace micro {
- // Returns a mutable pointer to the TfLiteEvalTensor for input `index` of
- // `node`. Writing through an input is only appropriate for variable tensors;
- // the tensor's is_variable flag must be checked during Prepare, when the full
- // TfLiteTensor is available.
- // NOTE(review): whether nullptr can be returned for an absent/optional input
- // is decided by the implementation — confirm before relying on it.
- TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
- const TfLiteNode* node, int index);
- // Returns a read-only pointer to the TfLiteEvalTensor for input `index` of
- // `node`.
- // NOTE(review): whether nullptr can be returned for an absent/optional input
- // is decided by the implementation — confirm before passing the result to the
- // const GetTensorData overload, which DCHECKs for a non-null tensor.
- const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
- const TfLiteNode* node, int index);
- // Returns a mutable pointer to the TfLiteEvalTensor for output `index` of
- // `node`, so the kernel can write its results into it.
- TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
- const TfLiteNode* node, int index);
- // Returns the raw data buffer of `tensor` viewed as a mutable T*.
- // A null `tensor` is tolerated: the result is then nullptr and nothing is
- // dereferenced.
- template <typename T>
- T* GetTensorData(TfLiteEvalTensor* tensor) {
- if (tensor == nullptr) {
- return nullptr;
- }
- return reinterpret_cast<T*>(tensor->data.raw);
- }
- // Returns the raw data buffer of `tensor` viewed as a read-only const T*.
- // Passing a null `tensor` is treated as a programming error and trips
- // TFLITE_DCHECK. Because TFLITE_DCHECK can be compiled out (see
- // compatibility.h, e.g. under NDEBUG), a null `tensor` additionally yields
- // nullptr instead of being dereferenced. That matches the non-const overload
- // and avoids undefined behavior in builds without the check.
- template <typename T>
- const T* GetTensorData(const TfLiteEvalTensor* tensor) {
- TFLITE_DCHECK(tensor != nullptr);
- if (tensor == nullptr) {
- return nullptr;
- }
- return reinterpret_cast<const T*>(tensor->data.raw);
- }
- // Returns the shape (dimensions) of `tensor` as a RuntimeShape.
- // NOTE(review): the result is returned by const value, which prevents callers
- // from moving out of it; dropping the const must be done together with the
- // definition in the implementation file.
- const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
- // Returns true if `input1` and `input2` have the same shape, false otherwise.
- HaveSameShapes(const TfLiteEvalTensor* input1,
- const TfLiteEvalTensor* input2);
- // Converts a TfLitePadding value (as found in builtin op params) into the
- // corresponding PaddingType used by the kernel implementations.
- // NOTE(review): the result for values with no direct PaddingType counterpart
- // is defined in the implementation — confirm before relying on it.
- PaddingType RuntimePaddingType(TfLitePadding padding);
- // Relocates the dims of `tensor` from the FlatBuffer into the persistent
- // storage arena, copying the existing dims data into the new storage area so
- // the tensor's dims live in writable memory.
- // `tensor` and `eval_tensor` must refer to the same tensor.
- // Only use during the Prepare phase.
- // Returns a TfLiteStatus reporting whether the relocation succeeded.
- TfLiteStatus CreateWritableTensorDimsWithCopy(TfLiteContext* context,
- TfLiteTensor* tensor,
- TfLiteEvalTensor* eval_tensor);
- // Returns an opaque blob of payload data. How the payload is interpreted is up
- // to the OP. This is the recommended API for an OP to get an external context;
- // OPs should use it instead of calling the GetExternalContext function on the
- // context directly. Example usage:
- //
- // An application can set an external context through the interpreter:
- // interpreter->SetMicroExternalContext(pointer_to_your_payload);
- //
- // Inside an OP that needs this payload, get the payload pointer with:
- // Prepare(TfLiteContext* context) {
- // ...
- // payload_ptr =
- // static_cast<your_data_type*>(GetMicroExternalContext(context));
- // ...
- // }
- //
- void* GetMicroExternalContext(TfLiteContext* context);
- } // namespace micro
- } // namespace tflite
- #endif // TENSORFLOW_LITE_MICRO_KERNELS_KERNEL_UTIL_H_
|