micro_interpreter.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_interpreter.h"

#include <cstdarg>
#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
namespace tflite {

MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter,
                                   MicroResourceVariables* resource_variables,
                                   MicroProfiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*MicroAllocator::Create(tensor_arena, tensor_arena_size,
                                         error_reporter)),
      graph_(&context_, model, &allocator_, resource_variables),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      input_tensors_(nullptr),
      output_tensors_(nullptr),
      micro_context_(&allocator_, model_, &graph_) {
  Init(profiler);
}
MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   MicroAllocator* allocator,
                                   ErrorReporter* error_reporter,
                                   MicroResourceVariables* resource_variables,
                                   MicroProfiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*allocator),
      graph_(&context_, model, allocator, resource_variables),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      input_tensors_(nullptr),
      output_tensors_(nullptr),
      micro_context_(&allocator_, model_, &graph_) {
  Init(profiler);
}
MicroInterpreter::~MicroInterpreter() {
  if (graph_.GetAllocations() != nullptr) {
    graph_.FreeSubgraphs();
  }
}
// Points the shared TfLiteContext at this interpreter's MicroContext and
// installs the callbacks that are valid in every lifecycle stage.
void MicroInterpreter::Init(MicroProfiler* profiler) {
  context_.impl_ = static_cast<void*>(&micro_context_);
  context_.ReportError = MicroContextReportOpError;
  context_.GetTensor = MicroContextGetTensor;
  context_.GetEvalTensor = MicroContextGetEvalTensor;
  context_.profiler = profiler;

  initialization_status_ = kTfLiteOk;
}
// For each operator in each subgraph, resolves the kernel registration and
// parses the builtin or custom options into the preallocated node storage.
TfLiteStatus MicroInterpreter::PrepareNodeAndRegistrationDataFromFlatbuffer() {
  for (int subgraph_idx = 0; subgraph_idx < graph_.NumSubgraphs();
       subgraph_idx++) {
    const SubGraph* subgraph = model_->subgraphs()->Get(subgraph_idx);
    TFLITE_DCHECK(subgraph != nullptr);

    auto* opcodes = model_->operator_codes();
    BuiltinDataAllocator* builtin_data_allocator =
        allocator_.GetBuiltinDataAllocator();
    uint32_t operators_size = NumSubgraphOperators(subgraph);
    for (size_t i = 0; i < operators_size; ++i) {
      const auto* op = subgraph->operators()->Get(i);
      const size_t index = op->opcode_index();
      if (index >= opcodes->size()) {
        MicroPrintf("Missing registration for opcode_index %d\n", index);
        return kTfLiteError;
      }
      const auto* opcode = opcodes->Get(index);
      TfLiteStatus status =
          GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                    &(graph_.GetAllocations()[subgraph_idx]
                                          .node_and_registrations[i]
                                          .registration));
      if (status != kTfLiteOk) {
        MicroPrintf("Failed to get registration from op code %s\n",
                    EnumNameBuiltinOperator(GetBuiltinCode(opcode)));
        return status;
      }
      const auto* registration = graph_.GetAllocations()[subgraph_idx]
                                     .node_and_registrations[i]
                                     .registration;
      if (registration == nullptr) {
        MicroPrintf("Skipping op for opcode_index %d\n", index);
        return kTfLiteError;
      }
      BuiltinOperator op_type =
          static_cast<BuiltinOperator>(registration->builtin_code);

      const char* custom_data = nullptr;
      size_t custom_data_size = 0;
      unsigned char* builtin_data = nullptr;

      if (op_type == BuiltinOperator_CUSTOM) {
        // Custom Ops may or may not have a non-null custom_options field.
        if (op->custom_options() != nullptr) {
          custom_data =
              reinterpret_cast<const char*>(op->custom_options()->data());
          custom_data_size = op->custom_options()->size();
        }
      } else {
        if (op->custom_options() != nullptr) {
          MicroPrintf(
              "Unsupported behavior: found builtin operator %s with custom "
              "options.\n",
              EnumNameBuiltinOperator(op_type));
          return kTfLiteError;
        }

        MicroOpResolver::BuiltinParseFunction parser =
            op_resolver_.GetOpDataParser(op_type);
        if (parser == nullptr) {
          MicroPrintf("Did not find a parser for %s",
                      EnumNameBuiltinOperator(op_type));
          return kTfLiteError;
        }
        TF_LITE_ENSURE_STATUS(parser(op, error_reporter_,
                                     builtin_data_allocator,
                                     (void**)(&builtin_data)));
      }

      TfLiteIntArray* inputs_array =
          FlatBufferVectorToTfLiteTypeArray(op->inputs());
      TfLiteIntArray* outputs_array =
          FlatBufferVectorToTfLiteTypeArray(op->outputs());

      TfLiteNode* node = &(
          graph_.GetAllocations()[subgraph_idx].node_and_registrations[i].node);
      *node = {};
      node->inputs = inputs_array;
      node->outputs = outputs_array;
      node->builtin_data = reinterpret_cast<void*>(builtin_data);
      node->custom_initial_data = custom_data;
      node->custom_initial_data_size = custom_data_size;

      if (op->intermediates() && (op->intermediates()->size() > 0)) {
        node->intermediates =
            FlatBufferVectorToTfLiteTypeArray(op->intermediates());
      }
    }
  }
  return kTfLiteOk;
}
TfLiteStatus MicroInterpreter::AllocateTensors() {
  SubgraphAllocations* allocations = allocator_.StartModelAllocation(model_);

  if (allocations == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Failed starting model allocation.\n");
    initialization_status_ = kTfLiteError;
    return kTfLiteError;
  }

  graph_.SetSubgraphAllocations(allocations);

  TF_LITE_ENSURE_STATUS(PrepareNodeAndRegistrationDataFromFlatbuffer());

  // Only allow AllocatePersistentBuffer in the Init stage.
  context_.AllocatePersistentBuffer = MicroContextAllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = nullptr;
  context_.GetExternalContext = nullptr;
  TF_LITE_ENSURE_STATUS(graph_.InitSubgraphs());

  // Both AllocatePersistentBuffer and RequestScratchBufferInArena are
  // available in the Prepare stage.
  context_.RequestScratchBufferInArena =
      MicroContextRequestScratchBufferInArena;
  // The external context becomes available in the Prepare stage.
  context_.GetExternalContext = MicroContextGetExternalContext;
  TF_LITE_ENSURE_STATUS(graph_.PrepareSubgraphs());

  // Prepare is done, we're ready for Invoke. Memory allocation is no longer
  // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
  context_.AllocatePersistentBuffer = nullptr;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = MicroContextGetScratchBuffer;

  TF_LITE_ENSURE_OK(&context_, allocator_.FinishModelAllocation(
                                    model_, graph_.GetAllocations(),
                                    &scratch_buffer_handles_));

  micro_context_.SetScratchBufferHandles(scratch_buffer_handles_);

  // TODO(b/162311891): Drop these allocations when the interpreter supports
  // handling buffers from TfLiteEvalTensor.
  input_tensors_ =
      reinterpret_cast<TfLiteTensor**>(allocator_.AllocatePersistentBuffer(
          sizeof(TfLiteTensor*) * inputs_size()));
  if (input_tensors_ == nullptr) {
    TF_LITE_REPORT_ERROR(
        error_reporter_,
        "Failed to allocate memory for context->input_tensors_, "
        "%d bytes required",
        sizeof(TfLiteTensor*) * inputs_size());
    return kTfLiteError;
  }

  for (size_t i = 0; i < inputs_size(); ++i) {
    input_tensors_[i] = allocator_.AllocatePersistentTfLiteTensor(
        model_, graph_.GetAllocations(), inputs().Get(i), 0);
    if (input_tensors_[i] == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Failed to initialize input tensor %d", i);
      return kTfLiteError;
    }
  }

  // TODO(b/162311891): Drop these allocations when the interpreter supports
  // handling buffers from TfLiteEvalTensor.
  output_tensors_ =
      reinterpret_cast<TfLiteTensor**>(allocator_.AllocatePersistentBuffer(
          sizeof(TfLiteTensor*) * outputs_size()));
  if (output_tensors_ == nullptr) {
    TF_LITE_REPORT_ERROR(
        error_reporter_,
        "Failed to allocate memory for context->output_tensors_, "
        "%d bytes required",
        sizeof(TfLiteTensor*) * outputs_size());
    return kTfLiteError;
  }

  for (size_t i = 0; i < outputs_size(); ++i) {
    output_tensors_[i] = allocator_.AllocatePersistentTfLiteTensor(
        model_, graph_.GetAllocations(), outputs().Get(i), 0);
    if (output_tensors_[i] == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Failed to initialize output tensor %d", i);
      return kTfLiteError;
    }
  }

  TF_LITE_ENSURE_STATUS(ResetVariableTensors());

  tensors_allocated_ = true;
  return kTfLiteOk;
}
TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult-to-debug segfaults.
  if (!tensors_allocated_) {
    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
  }
  return graph_.InvokeSubgraph(0);
}
TfLiteTensor* MicroInterpreter::input(size_t index) {
  const size_t length = inputs_size();
  if (index >= length) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Input index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return input_tensors_[index];
}

TfLiteTensor* MicroInterpreter::output(size_t index) {
  const size_t length = outputs_size();
  if (index >= length) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Output index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return output_tensors_[index];
}
// Repurposes FreeSubgraphs() to reset state for some ops for now, until a
// proper reset API is made. See b/220940833#comment25 for more context.
TfLiteStatus MicroInterpreter::Reset() {
  TfLiteStatus status = graph_.FreeSubgraphs();
  if (status != kTfLiteOk) {
    return status;
  }
  return graph_.ResetVariableTensors();
}
// TODO: remove this API completely in favor of MicroInterpreter::Reset
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
  return graph_.ResetVariableTensors();
}

TfLiteStatus MicroInterpreter::SetMicroExternalContext(
    void* external_context_payload) {
  return micro_context_.set_external_context(external_context_payload);
}

}  // namespace tflite
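
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): a minimal, hypothetical
// example of driving the MicroInterpreter implemented above. The model buffer
// parameter `g_model_data` and the arena size are placeholder assumptions;
// the calls themselves (tflite::GetModel, AllocateTensors, input/output,
// Invoke) are the API defined or exercised in this translation unit.
//
//   #include "tensorflow/lite/micro/micro_error_reporter.h"
//   #include "tensorflow/lite/micro/micro_interpreter.h"
//   #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
//   #include "tensorflow/lite/schema/schema_generated.h"
//
//   constexpr size_t kArenaSize = 16 * 1024;  // assumption: model-dependent
//   alignas(16) static uint8_t tensor_arena[kArenaSize];
//
//   TfLiteStatus RunOnce(const uint8_t* g_model_data) {
//     const tflite::Model* model = tflite::GetModel(g_model_data);
//     static tflite::MicroErrorReporter error_reporter;
//     static tflite::MicroMutableOpResolver<4> resolver;
//     // ... register the ops the model uses, e.g.
//     // resolver.AddFullyConnected(); ...
//     tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
//                                          kArenaSize, &error_reporter);
//     // AllocateTensors() plans the arena and runs Init/Prepare on kernels.
//     TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
//     TfLiteTensor* input = interpreter.input(0);
//     // ... populate input->data before invoking.
//     TF_LITE_ENSURE_STATUS(interpreter.Invoke());
//     TfLiteTensor* output = interpreter.output(0);
//     // ... read results from output->data.
//     (void)output;
//     return kTfLiteOk;
//   }
// ---------------------------------------------------------------------------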