micro_interpreter.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_interpreter.h"

#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <cstring>  // for memset() in ResetVariableTensors()

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {

#ifndef TF_LITE_STRIP_ERROR_STRINGS
// Returns the name to report for an op: the registered custom name for custom
// ops, otherwise the builtin operator's enum name.
const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    return registration->custom_name;
  } else {
    return EnumNameBuiltinOperator(
        BuiltinOperator(registration->builtin_code));
  }
}
#endif  // !defined(TF_LITE_STRIP_ERROR_STRINGS)

}  // namespace
namespace internal {

// ContextHelper backs the function pointers installed on TfLiteContext. Each
// static callback recovers the ContextHelper instance from
// TfLiteContext::impl_ and forwards the request to the MicroAllocator or to
// the tensor/scratch-buffer state owned by the interpreter.
ContextHelper::ContextHelper(ErrorReporter* error_reporter,
                             MicroAllocator* allocator, const Model* model)
    : allocator_(allocator), error_reporter_(error_reporter), model_(model) {}

void* ContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
                                              size_t bytes) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->AllocatePersistentBuffer(bytes);
}

TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
                                                        size_t bytes,
                                                        int* buffer_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
  return helper->allocator_->RequestScratchBufferInArena(bytes, buffer_idx);
}

void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
  ScratchBufferHandle* handle = helper->scratch_buffer_handles_ + buffer_idx;
  return handle->data;
}

void ContextHelper::ReportOpError(struct TfLiteContext* context,
                                  const char* format, ...) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
  ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
  va_list args;
  va_start(args, format);
  TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
  va_end(args);
#endif
}

TfLiteTensor* ContextHelper::GetTensor(const struct TfLiteContext* context,
                                       int tensor_idx) {
  ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
  return helper->allocator_->AllocateTempTfLiteTensor(
      helper->model_, helper->eval_tensors_, tensor_idx);
}

TfLiteEvalTensor* ContextHelper::GetEvalTensor(
    const struct TfLiteContext* context, int tensor_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(context->impl_);
  return &helper->eval_tensors_[tensor_idx];
}

void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) {
  eval_tensors_ = eval_tensors;
}

void ContextHelper::SetScratchBufferHandles(
    ScratchBufferHandle* scratch_buffer_handles) {
  scratch_buffer_handles_ = scratch_buffer_handles;
}

}  // namespace internal
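// A minimal usage sketch for the interpreter implemented below, illustrative
// only: the model buffer, resolver choice, and arena size are placeholders
// chosen by the application, not part of this file.
//
//   const tflite::Model* model = tflite::GetModel(model_data);
//   tflite::AllOpsResolver resolver;  // or a MicroMutableOpResolver
//   tflite::MicroErrorReporter error_reporter;
//   constexpr size_t kArenaSize = 16 * 1024;
//   static uint8_t arena[kArenaSize];
//   tflite::MicroInterpreter interpreter(model, resolver, arena, kArenaSize,
//                                        &error_reporter);
//   interpreter.AllocateTensors();
//   // ... write input data into interpreter.input(0) ...
//   interpreter.Invoke();
//   TfLiteTensor* output = interpreter.output(0);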
MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter,
                                   MicroProfiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*MicroAllocator::Create(tensor_arena, tensor_arena_size,
                                         error_reporter)),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      eval_tensors_(nullptr),
      context_helper_(error_reporter_, &allocator_, model),
      input_tensors_(nullptr),
      output_tensors_(nullptr) {
  Init(profiler);
}

MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   MicroAllocator* allocator,
                                   ErrorReporter* error_reporter,
                                   MicroProfiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*allocator),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      eval_tensors_(nullptr),
      context_helper_(error_reporter_, &allocator_, model),
      input_tensors_(nullptr),
      output_tensors_(nullptr) {
  Init(profiler);
}
MicroInterpreter::~MicroInterpreter() {
  if (node_and_registrations_ != nullptr) {
    for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
      TfLiteNode* node = &(node_and_registrations_[i].node);
      const TfLiteRegistration* registration =
          node_and_registrations_[i].registration;
      // registration is allocated outside the interpreter, so double-check
      // that it's not nullptr.
      if (registration != nullptr && registration->free != nullptr) {
        registration->free(&context_, node->user_data);
      }
    }
  }
}
void MicroInterpreter::Init(MicroProfiler* profiler) {
  const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs =
      model_->subgraphs();
  if (subgraphs->size() != 1) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Only 1 subgraph is currently supported.\n");
    initialization_status_ = kTfLiteError;
    return;
  }
  subgraph_ = (*subgraphs)[0];

  context_.impl_ = static_cast<void*>(&context_helper_);
  context_.ReportError = context_helper_.ReportOpError;
  context_.GetTensor = context_helper_.GetTensor;
  context_.GetEvalTensor = context_helper_.GetEvalTensor;
  context_.recommended_num_threads = 1;
  context_.profiler = profiler;

  initialization_status_ = kTfLiteOk;
}
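// AllocateTensors drives the pre-Invoke kernel lifecycle: it starts model
// allocation, calls every registration's init() and then prepare(), and
// finally commits the arena plan. The TfLiteContext allocation hooks are
// swapped as the stages advance: persistent buffers may be requested during
// init() and prepare(), scratch buffers may only be requested during
// prepare(), and once prepare() is done kernels can only look up scratch
// buffers via GetScratchBuffer.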
TfLiteStatus MicroInterpreter::AllocateTensors() {
  if (allocator_.StartModelAllocation(model_, op_resolver_,
                                      &node_and_registrations_,
                                      &eval_tensors_) != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Failed starting model allocation.\n");
    initialization_status_ = kTfLiteError;
    return kTfLiteError;
  }

  // Update the pointer now that TfLiteEvalTensor allocation has completed on
  // the context helper.
  // TODO(b/16157777): This call would not be needed if ContextHelper were
  // rolled into the interpreter.
  context_helper_.SetTfLiteEvalTensors(eval_tensors_);
  context_.tensors_size = subgraph_->tensors()->size();

  // Only allow AllocatePersistentBuffer in the Init stage.
  context_.AllocatePersistentBuffer = context_helper_.AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = nullptr;

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    size_t init_data_size;
    const char* init_data;
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      init_data = reinterpret_cast<const char*>(node->custom_initial_data);
      init_data_size = node->custom_initial_data_size;
    } else {
      init_data = reinterpret_cast<const char*>(node->builtin_data);
      init_data_size = 0;
    }
    if (registration->init) {
      node->user_data =
          registration->init(&context_, init_data, init_data_size);
    }
  }

  // Both AllocatePersistentBuffer and RequestScratchBufferInArena are
  // available in the Prepare stage.
  context_.RequestScratchBufferInArena =
      context_helper_.RequestScratchBufferInArena;
  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, node);
      if (prepare_status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Node %s (number %d) failed to prepare with status %d",
            OpNameFromRegistration(registration), i, prepare_status);
        return kTfLiteError;
      }
    }
    allocator_.FinishPrepareNodeAllocations(/*node_id=*/i);
  }

  // Prepare is done, we're ready for Invoke. Memory allocation is no longer
  // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
  context_.AllocatePersistentBuffer = nullptr;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = context_helper_.GetScratchBuffer;

  TF_LITE_ENSURE_OK(&context_,
                    allocator_.FinishModelAllocation(model_, eval_tensors_,
                                                     &scratch_buffer_handles_));
  // TODO(b/16157777): Remove this when ContextHelper is rolled into this
  // class.
  context_helper_.SetScratchBufferHandles(scratch_buffer_handles_);

  // TODO(b/162311891): Drop these allocations when the interpreter supports
  // handling buffers from TfLiteEvalTensor.
  input_tensors_ =
      reinterpret_cast<TfLiteTensor**>(allocator_.AllocatePersistentBuffer(
          sizeof(TfLiteTensor*) * inputs_size()));
  if (input_tensors_ == nullptr) {
    TF_LITE_REPORT_ERROR(
        error_reporter_,
        "Failed to allocate memory for context->input_tensors_, "
        "%d bytes required",
        sizeof(TfLiteTensor*) * inputs_size());
    return kTfLiteError;
  }

  for (size_t i = 0; i < inputs_size(); ++i) {
    input_tensors_[i] = allocator_.AllocatePersistentTfLiteTensor(
        model_, eval_tensors_, inputs().Get(i));
    if (input_tensors_[i] == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Failed to initialize input tensor %d", i);
      return kTfLiteError;
    }
  }

  // TODO(b/162311891): Drop these allocations when the interpreter supports
  // handling buffers from TfLiteEvalTensor.
  output_tensors_ =
      reinterpret_cast<TfLiteTensor**>(allocator_.AllocatePersistentBuffer(
          sizeof(TfLiteTensor*) * outputs_size()));
  if (output_tensors_ == nullptr) {
    TF_LITE_REPORT_ERROR(
        error_reporter_,
        "Failed to allocate memory for context->output_tensors_, "
        "%d bytes required",
        sizeof(TfLiteTensor*) * outputs_size());
    return kTfLiteError;
  }

  for (size_t i = 0; i < outputs_size(); ++i) {
    output_tensors_[i] = allocator_.AllocatePersistentTfLiteTensor(
        model_, eval_tensors_, outputs().Get(i));
    if (output_tensors_[i] == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Failed to initialize output tensor %d", i);
      return kTfLiteError;
    }
  }

  TF_LITE_ENSURE_STATUS(ResetVariableTensors());

  tensors_allocated_ = true;
  return kTfLiteOk;
}
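// Invoke runs each operator of the single supported subgraph in order. If the
// caller has not already allocated tensors, AllocateTensors() is run lazily
// first, and the temporary TfLiteTensor allocations made by a kernel are
// reset as soon as that kernel returns.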
TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult to debug segfaults.
  if (!tensors_allocated_) {
    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
  }

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;

// This ifdef is needed (even though ScopedMicroProfiler itself is a no-op with
// -DTF_LITE_STRIP_ERROR_STRINGS) because the function OpNameFromRegistration
// is only defined for builds with the error strings.
#if !defined(TF_LITE_STRIP_ERROR_STRINGS)
    ScopedMicroProfiler scoped_profiler(
        OpNameFromRegistration(registration),
        reinterpret_cast<MicroProfiler*>(context_.profiler));
#endif

    TFLITE_DCHECK(registration->invoke);
    TfLiteStatus invoke_status = registration->invoke(&context_, node);

    // All TfLiteTensor structs used in the kernel are allocated from temp
    // memory in the allocator. This creates a chain of allocations in the
    // temp section. The call below resets the chain of allocations to
    // prepare for the next call.
    allocator_.ResetTempAllocations();

    if (invoke_status == kTfLiteError) {
      TF_LITE_REPORT_ERROR(
          error_reporter_,
          "Node %s (number %d) failed to invoke with status %d",
          OpNameFromRegistration(registration), i, invoke_status);
      return kTfLiteError;
    } else if (invoke_status != kTfLiteOk) {
      return invoke_status;
    }
  }
  return kTfLiteOk;
}
TfLiteTensor* MicroInterpreter::input(size_t index) {
  const size_t length = inputs_size();
  if (index >= length) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Input index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return input_tensors_[index];
}

TfLiteTensor* MicroInterpreter::output(size_t index) {
  const size_t length = outputs_size();
  if (index >= length) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Output index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return output_tensors_[index];
}

TfLiteTensor* MicroInterpreter::tensor(size_t index) {
  const size_t length = tensors_size();
  if (index >= length) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Tensor index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return allocator_.AllocatePersistentTfLiteTensor(model_, eval_tensors_,
                                                   index);
}
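// Variable tensors carry state between invocations. Resetting fills each
// variable tensor's buffer with zero, or with the tensor's zero point for
// int8 tensors so the stored values represent a real value of 0.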
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
  for (size_t i = 0; i < subgraph_->tensors()->size(); ++i) {
    auto* tensor = subgraph_->tensors()->Get(i);
    if (tensor->is_variable()) {
      size_t buffer_size;
      TF_LITE_ENSURE_STATUS(
          TfLiteEvalTensorByteLength(&eval_tensors_[i], &buffer_size));

      int value = 0;
      if (tensor->type() == tflite::TensorType_INT8) {
        value = tensor->quantization()->zero_point()->Get(0);
      }
      memset(eval_tensors_[i].data.raw, value, buffer_size);
    }
  }
  return kTfLiteOk;
}

}  // namespace tflite