/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_

#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Copied from tensorflow/lite/version.h to avoid a dependency chain into
// tensorflow/core.
#define TFLITE_SCHEMA_VERSION (3)

namespace tflite {

class MicroInterpreter {
 public:
  // The lifetime of the model, op resolver, tensor arena, error reporter,
  // resource variables, and profiler must be at least as long as that of the
  // interpreter object, since the interpreter may need to access them at any
  // time. This means you should usually create them all in the same scope,
  // for example allocating them on the stack as local variables in a
  // top-level function. The interpreter does not deallocate any of the
  // pointed-to objects; ownership remains with the caller.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   uint8_t* tensor_arena, size_t tensor_arena_size,
                   ErrorReporter* error_reporter,
                   MicroResourceVariables* resource_variables = nullptr,
                   MicroProfiler* profiler = nullptr);

  // Creates an interpreter instance using an existing MicroAllocator
  // instance. Use this constructor when a single allocator needs to handle
  // allocations for more than one interpreter, or when recording allocations
  // inside the interpreter. The lifetime of the allocator must be at least
  // as long as that of the interpreter object.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   MicroAllocator* allocator, ErrorReporter* error_reporter,
                   MicroResourceVariables* resource_variables = nullptr,
                   MicroProfiler* profiler = nullptr);
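
  // A minimal construction sketch for reference (hedged: `g_model_data`, the
  // MicroMutableOpResolver op count, and the arena size are illustrative
  // placeholders, not part of this header):
  //
  //   alignas(16) static uint8_t tensor_arena[10 * 1024];
  //   static tflite::MicroErrorReporter error_reporter;
  //   const tflite::Model* model = tflite::GetModel(g_model_data);
  //   tflite::MicroMutableOpResolver<4> resolver;  // add required ops here
  //   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
  //                                        sizeof(tensor_arena),
  //                                        &error_reporter);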

  ~MicroInterpreter();

  // Runs through the model and allocates all necessary input, output and
  // intermediate tensors.
  TfLiteStatus AllocateTensors();

  // In order to support partial graph runs for strided models, this can
  // return values other than kTfLiteOk and kTfLiteError.
  // TODO(b/149795762): Add this to the TfLiteStatus enum.
  TfLiteStatus Invoke();
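
  // Typical call sequence, as a sketch (error handling elided; `interpreter`
  // is the instance sketched above):
  //
  //   if (interpreter.AllocateTensors() != kTfLiteOk) { /* handle error */ }
  //   // ... fill input tensors (see the typed accessors below) ...
  //   if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }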

  // This is the recommended API for an application to pass an external
  // payload pointer as an external context to kernels. The lifetime of the
  // payload pointer should be at least as long as this interpreter. TFLM
  // supports only one external context.
  TfLiteStatus SetMicroExternalContext(void* external_context_payload);
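
  // Attaching a payload, as a sketch (`MyExternalContext` is a hypothetical
  // application-defined struct, not part of TFLM):
  //
  //   static MyExternalContext payload;  // must outlive the interpreter
  //   interpreter.SetMicroExternalContext(&payload);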

  TfLiteTensor* input(size_t index);
  size_t inputs_size() const {
    return model_->subgraphs()->Get(0)->inputs()->size();
  }
  const flatbuffers::Vector<int32_t>& inputs() const {
    return *model_->subgraphs()->Get(0)->inputs();
  }
  TfLiteTensor* input_tensor(size_t index) { return input(index); }

  // Returns a pointer to the input tensor's data if the tensor's element
  // type matches T; otherwise returns nullptr.
  template <class T>
  T* typed_input_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = input_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }
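
  // Usage sketch (the int8 input type is an assumption about the model, not
  // a property of this API):
  //
  //   int8_t* in = interpreter.typed_input_tensor<int8_t>(0);
  //   if (in == nullptr) { /* type mismatch or bad index */ }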

  TfLiteTensor* output(size_t index);
  size_t outputs_size() const {
    return model_->subgraphs()->Get(0)->outputs()->size();
  }
  const flatbuffers::Vector<int32_t>& outputs() const {
    return *model_->subgraphs()->Get(0)->outputs();
  }
  TfLiteTensor* output_tensor(size_t index) { return output(index); }

  // Returns a pointer to the output tensor's data if the tensor's element
  // type matches T; otherwise returns nullptr.
  template <class T>
  T* typed_output_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = output_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }
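
  // Reading results after Invoke(), as a sketch (the float output type is an
  // assumption about the model):
  //
  //   float* out = interpreter.typed_output_tensor<float>(0);
  //   if (out != nullptr) { /* out[0] .. out[n-1] hold the results */ }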

  // Reset the state to be what you would expect when the interpreter is
  // first created, i.e. after Init and Prepare are called for the very first
  // time.
  TfLiteStatus Reset();

  // TODO(b/244457206): remove this in favor of Reset().
  // Reset all variable tensors to the default value.
  TfLiteStatus ResetVariableTensors();

  TfLiteStatus initialization_status() const { return initialization_status_; }

  // Populates node and registration pointers representing the inference graph
  // of the model from values inside the flatbuffer (loaded from the TfLiteModel
  // instance). Persistent data (e.g. operator data) is allocated from the
  // arena.
  TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer();

  // For debugging only.
  // Returns the actual arena usage in bytes, i.e. the optimal arena size.
  // Only available after `AllocateTensors` has been called. Note that the
  // `tensor_arena` normally requires 16-byte alignment to fully utilize the
  // space; if that is not the case, the optimal arena size is
  // arena_used_bytes() + 16.
  size_t arena_used_bytes() const { return allocator_.used_bytes(); }
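
  // Arena right-sizing sketch (the MicroPrintf logging call is an
  // assumption; substitute whatever logging mechanism your build provides):
  //
  //   interpreter.AllocateTensors();
  //   MicroPrintf("Optimal arena size: %u",
  //               static_cast<unsigned>(interpreter.arena_used_bytes() + 16));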

 protected:
  const MicroAllocator& allocator() const { return allocator_; }
  const TfLiteContext& context() const { return context_; }

 private:
  // TODO(b/158263161): Consider switching to a Create() function to enable
  // better error reporting during initialization.
  void Init(MicroProfiler* profiler);

  // Gets the current subgraph index used from within context methods.
  int get_subgraph_index() { return graph_.GetCurrentSubgraphIndex(); }

  const Model* model_;
  const MicroOpResolver& op_resolver_;
  ErrorReporter* error_reporter_;
  TfLiteContext context_ = {};
  MicroAllocator& allocator_;
  MicroGraph graph_;
  bool tensors_allocated_;

  TfLiteStatus initialization_status_;

  ScratchBufferHandle* scratch_buffer_handles_ = nullptr;

  // TODO(b/162311891): Clean these pointers up when this class supports
  // buffers from TfLiteEvalTensor.
  TfLiteTensor** input_tensors_;
  TfLiteTensor** output_tensors_;

  MicroContext micro_context_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_