micro_interpreter.h

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_

#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Copied from tensorflow/lite/version.h to avoid a dependency chain into
// tensorflow/core.
#define TFLITE_SCHEMA_VERSION (3)

namespace tflite {

class MicroInterpreter {
 public:
  // The lifetime of the model, op resolver, tensor arena, error reporter and
  // profiler must be at least as long as that of the interpreter object, since
  // the interpreter may need to access them at any time. This means that you
  // should usually create them with the same scope as each other, for example
  // having them all allocated on the stack as local variables through a
  // top-level function. The interpreter doesn't do any deallocation of any of
  // the pointed-to objects; ownership remains with the caller.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   uint8_t* tensor_arena, size_t tensor_arena_size,
                   ErrorReporter* error_reporter,
                   MicroProfiler* profiler = nullptr);
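
  //
  // A minimal usage sketch for this constructor (illustrative only:
  // `g_model_data`, `kArenaSize` and the chosen op are assumptions, and
  // MicroErrorReporter / MicroMutableOpResolver come from separate headers
  // not included here):
  //
  //   const tflite::Model* model = tflite::GetModel(g_model_data);
  //   tflite::MicroErrorReporter error_reporter;
  //   tflite::MicroMutableOpResolver<1> resolver;
  //   resolver.AddFullyConnected();
  //   constexpr size_t kArenaSize = 10 * 1024;
  //   alignas(16) static uint8_t tensor_arena[kArenaSize];
  //   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
  //                                        kArenaSize, &error_reporter);
  //   if (interpreter.AllocateTensors() != kTfLiteOk) { /* handle error */ }
  //   // ... fill interpreter.input(0), then:
  //   if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }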

  // Create an interpreter instance using an existing MicroAllocator instance.
  // This constructor should be used when the allocator needs to handle
  // allocation for more than one interpreter, or when recording allocations
  // inside the interpreter. The lifetime of the allocator must be as long as
  // that of the interpreter object.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   MicroAllocator* allocator, ErrorReporter* error_reporter,
                   MicroProfiler* profiler = nullptr);

  ~MicroInterpreter();

  // Runs through the model and allocates all necessary input, output and
  // intermediate tensors.
  TfLiteStatus AllocateTensors();

  // In order to support partial graph runs for strided models, this can return
  // values other than kTfLiteOk and kTfLiteError.
  // TODO(b/149795762): Add this to the TfLiteStatus enum.
  TfLiteStatus Invoke();

  TfLiteTensor* input(size_t index);
  size_t inputs_size() const {
    return model_->subgraphs()->Get(0)->inputs()->size();
  }
  const flatbuffers::Vector<int32_t>& inputs() const {
    return *model_->subgraphs()->Get(0)->inputs();
  }
  TfLiteTensor* input_tensor(size_t index) { return input(index); }
  template <class T>
  T* typed_input_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = input_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }

  TfLiteTensor* output(size_t index);
  size_t outputs_size() const {
    return model_->subgraphs()->Get(0)->outputs()->size();
  }
  const flatbuffers::Vector<int32_t>& outputs() const {
    return *model_->subgraphs()->Get(0)->outputs();
  }
  TfLiteTensor* output_tensor(size_t index) { return output(index); }
  template <class T>
  T* typed_output_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = output_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }
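
  //
  // Sketch of the typed accessors above (assumes the model's first input and
  // output are int8 tensors; the accessors return nullptr on a type mismatch
  // or missing tensor):
  //
  //   int8_t* in = interpreter.typed_input_tensor<int8_t>(0);
  //   int8_t* out = interpreter.typed_output_tensor<int8_t>(0);
  //   if (in != nullptr && out != nullptr) {
  //     in[0] = 42;  // write quantized input
  //     interpreter.Invoke();
  //     int8_t result = out[0];  // read quantized output
  //   }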

  // Reset all variable tensors to the default value.
  TfLiteStatus ResetVariableTensors();

  TfLiteStatus initialization_status() const { return initialization_status_; }

  // Populates node and registration pointers representing the inference graph
  // of the model from values inside the flatbuffer (loaded from the TfLiteModel
  // instance). Persistent data (e.g. operator data) is allocated from the
  // arena.
  TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer();

  // For debugging only.
  // Returns the number of arena bytes actually used, i.e. the optimal arena
  // size. Only available after `AllocateTensors` has been called.
  // Note that `tensor_arena` normally requires 16-byte alignment to fully
  // utilize the space. If it is not aligned, the optimal arena size would be
  // arena_used_bytes() + 16.
  size_t arena_used_bytes() const { return allocator_.used_bytes(); }
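
  //
  // Arena sizing sketch (only meaningful after AllocateTensors(); the 16-byte
  // adjustment follows the note above):
  //
  //   interpreter.AllocateTensors();
  //   size_t needed = interpreter.arena_used_bytes();
  //   // If `tensor_arena` is not 16-byte aligned, budget `needed + 16`.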

 protected:
  const MicroAllocator& allocator() const { return allocator_; }
  const TfLiteContext& context() const { return context_; }

 private:
  // TODO(b/158263161): Consider switching to Create() function to enable
  // better error reporting during initialization.
  void Init(MicroProfiler* profiler);

  // Gets the current subgraph index used from within context methods.
  int get_subgraph_index() { return graph_.GetCurrentSubgraphIndex(); }

  // Static functions that are bound to the TfLiteContext instance:
  static void* AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes);
  static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
                                                  size_t bytes,
                                                  int* buffer_idx);
  static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
  static void ReportOpError(struct TfLiteContext* context, const char* format,
                            ...);
  static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
                                 int tensor_idx);
  static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
                                         int tensor_idx);
  static TfLiteStatus GetGraph(struct TfLiteContext* context,
                               TfLiteIntArray** args);

  const Model* model_;
  const MicroOpResolver& op_resolver_;
  ErrorReporter* error_reporter_;
  TfLiteContext context_ = {};
  MicroAllocator& allocator_;
  MicroGraph graph_;
  bool tensors_allocated_;

  TfLiteStatus initialization_status_;

  ScratchBufferHandle* scratch_buffer_handles_ = nullptr;

  // TODO(b/162311891): Clean these pointers up when this class supports
  // buffers from TfLiteEvalTensor.
  TfLiteTensor** input_tensors_;
  TfLiteTensor** output_tensors_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_