micro_graph.cc 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/micro/micro_graph.h"
  13. #include "flatbuffers/flatbuffers.h" // from @flatbuffers
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/kernels/internal/compatibility.h"
  16. #include "tensorflow/lite/micro/flatbuffer_utils.h"
  17. #include "tensorflow/lite/micro/memory_helpers.h"
  18. #include "tensorflow/lite/micro/micro_error_reporter.h"
  19. #include "tensorflow/lite/micro/micro_profiler.h"
  20. #include "tensorflow/lite/schema/schema_generated.h"
  21. namespace tflite {
  22. namespace {
  23. const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  24. if (registration->builtin_code == BuiltinOperator_CUSTOM) {
  25. return registration->custom_name;
  26. } else {
  27. return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
  28. }
  29. }
  30. } // namespace
  31. MicroGraph::MicroGraph(TfLiteContext* context, const Model* model,
  32. MicroAllocator* allocator,
  33. MicroResourceVariables* resource_variables)
  34. : context_(context),
  35. model_(model),
  36. allocator_(allocator),
  37. current_subgraph_index_(0),
  38. resource_variables_(resource_variables) {
  39. if (model != nullptr) {
  40. subgraphs_ = model->subgraphs();
  41. }
  42. }
  43. MicroGraph::~MicroGraph() {}
  44. TfLiteStatus MicroGraph::InitSubgraphs() {
  45. int previous_subgraph_idx = current_subgraph_index_;
  46. for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
  47. subgraph_idx++) {
  48. current_subgraph_index_ = subgraph_idx;
  49. uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
  50. for (size_t i = 0; i < operators_size; ++i) {
  51. TfLiteNode* node =
  52. &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
  53. const TfLiteRegistration* registration =
  54. subgraph_allocations_[subgraph_idx]
  55. .node_and_registrations[i]
  56. .registration;
  57. size_t init_data_size;
  58. const char* init_data;
  59. if (registration->builtin_code == BuiltinOperator_CUSTOM) {
  60. init_data = reinterpret_cast<const char*>(node->custom_initial_data);
  61. init_data_size = node->custom_initial_data_size;
  62. } else {
  63. init_data = reinterpret_cast<const char*>(node->builtin_data);
  64. init_data_size = 0;
  65. }
  66. if (registration->init) {
  67. node->user_data =
  68. registration->init(context_, init_data, init_data_size);
  69. }
  70. }
  71. }
  72. current_subgraph_index_ = previous_subgraph_idx;
  73. return kTfLiteOk;
  74. }
  75. TfLiteStatus MicroGraph::PrepareSubgraphs() {
  76. int previous_subgraph_idx = current_subgraph_index_;
  77. for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
  78. subgraph_idx++) {
  79. current_subgraph_index_ = subgraph_idx;
  80. uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
  81. for (size_t i = 0; i < operators_size; ++i) {
  82. TfLiteNode* node =
  83. &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
  84. const TfLiteRegistration* registration =
  85. subgraph_allocations_[subgraph_idx]
  86. .node_and_registrations[i]
  87. .registration;
  88. if (registration->prepare != nullptr) {
  89. TfLiteStatus prepare_status = registration->prepare(context_, node);
  90. if (prepare_status != kTfLiteOk) {
  91. MicroPrintf("Node %s (number %df) failed to prepare with status %d",
  92. OpNameFromRegistration(registration), i, prepare_status);
  93. return kTfLiteError;
  94. }
  95. }
  96. allocator_->FinishPrepareNodeAllocations(/*node_id=*/i);
  97. }
  98. }
  99. current_subgraph_index_ = previous_subgraph_idx;
  100. return kTfLiteOk;
  101. }
  102. TfLiteStatus MicroGraph::FreeSubgraphs() {
  103. int previous_subgraph_idx = current_subgraph_index_;
  104. for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
  105. subgraph_idx++) {
  106. current_subgraph_index_ = subgraph_idx;
  107. uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
  108. for (size_t i = 0; i < operators_size; ++i) {
  109. TfLiteNode* node =
  110. &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
  111. const TfLiteRegistration* registration =
  112. subgraph_allocations_[subgraph_idx]
  113. .node_and_registrations[i]
  114. .registration;
  115. // registration is allocated outside the interpreter, so double check to
  116. // make sure it's not nullptr;
  117. if (registration != nullptr && registration->free != nullptr) {
  118. registration->free(context_, node->user_data);
  119. }
  120. }
  121. }
  122. current_subgraph_index_ = previous_subgraph_idx;
  123. return kTfLiteOk;
  124. }
  125. TfLiteStatus MicroGraph::InvokeSubgraph(int subgraph_idx) {
  126. int previous_subgraph_idx = current_subgraph_index_;
  127. current_subgraph_index_ = subgraph_idx;
  128. if (static_cast<size_t>(subgraph_idx) >= subgraphs_->size()) {
  129. MicroPrintf("Accessing subgraph %d but only %d subgraphs found",
  130. subgraph_idx, subgraphs_->size());
  131. return kTfLiteError;
  132. }
  133. uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
  134. for (size_t i = 0; i < operators_size; ++i) {
  135. TfLiteNode* node =
  136. &(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
  137. const TfLiteRegistration* registration = subgraph_allocations_[subgraph_idx]
  138. .node_and_registrations[i]
  139. .registration;
  140. // This ifdef is needed (even though ScopedMicroProfiler itself is a no-op with
  141. // -DTF_LITE_STRIP_ERROR_STRINGS) because the function OpNameFromRegistration is
  142. // only defined for builds with the error strings.
  143. #if !defined(TF_LITE_STRIP_ERROR_STRINGS)
  144. ScopedMicroProfiler scoped_profiler(
  145. OpNameFromRegistration(registration),
  146. reinterpret_cast<MicroProfiler*>(context_->profiler));
  147. #endif
  148. TFLITE_DCHECK(registration->invoke);
  149. TfLiteStatus invoke_status = registration->invoke(context_, node);
  150. // All TfLiteTensor structs used in the kernel are allocated from temp
  151. // memory in the allocator. This creates a chain of allocations in the
  152. // temp section. The call below resets the chain of allocations to
  153. // prepare for the next call.
  154. allocator_->ResetTempAllocations();
  155. if (invoke_status == kTfLiteError) {
  156. MicroPrintf("Node %s (number %d) failed to invoke with status %d",
  157. OpNameFromRegistration(registration), i, invoke_status);
  158. return kTfLiteError;
  159. } else if (invoke_status != kTfLiteOk) {
  160. return invoke_status;
  161. }
  162. }
  163. current_subgraph_index_ = previous_subgraph_idx;
  164. return kTfLiteOk;
  165. }
  166. TfLiteStatus MicroGraph::ResetVariableTensors() {
  167. for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
  168. subgraph_idx++) {
  169. const SubGraph* subgraph = (*subgraphs_)[subgraph_idx];
  170. for (size_t i = 0; i < subgraph->tensors()->size(); ++i) {
  171. auto* tensor = subgraph->tensors()->Get(i);
  172. if (tensor->is_variable()) {
  173. size_t buffer_size;
  174. TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(
  175. &subgraph_allocations_[subgraph_idx].tensors[i], &buffer_size));
  176. int value = 0;
  177. if (tensor->type() == tflite::TensorType_INT8) {
  178. value = tensor->quantization()->zero_point()->Get(0);
  179. }
  180. memset(subgraph_allocations_[subgraph_idx].tensors[i].data.raw, value,
  181. buffer_size);
  182. }
  183. }
  184. }
  185. if (resource_variables_ != nullptr) {
  186. resource_variables_->ResetAll();
  187. }
  188. return kTfLiteOk;
  189. }
  190. int MicroGraph::NumSubgraphs() { return model_->subgraphs()->size(); }
  191. void MicroGraph::SetSubgraphAllocations(
  192. SubgraphAllocations* subgraph_allocations) {
  193. subgraph_allocations_ = subgraph_allocations;
  194. }
  195. size_t MicroGraph::NumSubgraphInputs(int subgraph_idx) {
  196. return model_->subgraphs()->Get(subgraph_idx)->inputs()->size();
  197. }
  198. TfLiteEvalTensor* MicroGraph::GetSubgraphInput(int subgraph_idx,
  199. int input_idx) {
  200. int tensor_idx =
  201. model_->subgraphs()->Get(subgraph_idx)->inputs()->Get(input_idx);
  202. return &subgraph_allocations_[subgraph_idx].tensors[tensor_idx];
  203. }
  204. size_t MicroGraph::NumSubgraphOutputs(int subgraph_idx) {
  205. return model_->subgraphs()->Get(subgraph_idx)->outputs()->size();
  206. }
  207. TfLiteEvalTensor* MicroGraph::GetSubgraphOutput(int subgraph_idx,
  208. int output_idx) {
  209. int tensor_idx =
  210. model_->subgraphs()->Get(subgraph_idx)->outputs()->Get(output_idx);
  211. return &subgraph_allocations_[subgraph_idx].tensors[tensor_idx];
  212. }
  213. } // namespace tflite