micro_graph.h 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_
  13. #define TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/micro/micro_allocator.h"
  16. #include "tensorflow/lite/micro/micro_resource_variable.h"
  17. #include "tensorflow/lite/schema/schema_generated.h"
  18. namespace tflite {
  19. // Abstracts the details of interacting with the tflite::Model.
  20. //
  21. // Provides methods to access, initialize, prepare, invoke and free any
  22. // subgraph in the tflite::Graph.
  23. class MicroGraph {
  24. public:
  25. // The lifetime of the context, model, allocator and resource_variables must
  26. // be at least as long as that of the graph object, since the this class may
  27. // need to access them at any time. If resource_variables is a nullptr,
  28. // GetResourceVariables will return a nullptr.
  29. MicroGraph(TfLiteContext* context, const Model* model,
  30. MicroAllocator* allocator,
  31. MicroResourceVariables* resource_variables);
  32. virtual ~MicroGraph();
  33. // Sets up builtin data and calls TfLiteRegistration->Init for every operator
  34. // in every subgraph in the model.
  35. virtual TfLiteStatus InitSubgraphs();
  36. // Calls TfLiteRegistration->Prepare for every operator in every subgraph in
  37. // the model.
  38. virtual TfLiteStatus PrepareSubgraphs();
  39. // Calls TfLiteRegistration->Free for every operator in every subgraph in the
  40. // model.
  41. virtual TfLiteStatus FreeSubgraphs();
  42. // Calls TfLiteRegistration->Invoke for every operator in a single subgraph in
  43. // the model.
  44. virtual TfLiteStatus InvokeSubgraph(int subgraph_idx);
  45. // Zeros out all variable tensors in all subgraphs in the model.
  46. virtual TfLiteStatus ResetVariableTensors();
  47. // Number of tensor inputs to a specified subgraph in the model.
  48. virtual size_t NumSubgraphInputs(int subgraph_idx);
  49. // Get the specified input tensor of a specified subgraph in the model.
  50. virtual TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int input_idx);
  51. // Number of tensor outputs from a specified subgraph in the model.
  52. virtual size_t NumSubgraphOutputs(int subgraph_idx);
  53. // Get the specified output tensor of a specified subgraph in the model.
  54. virtual TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx, int output_idx);
  55. // Number of subgraphs in the model.
  56. virtual int NumSubgraphs();
  57. // Hook to pass in subgraph allocations tracked within the interpreter,
  58. // allowing MicroGraph to init / prepare / invoke subgraphs in the model.
  59. void SetSubgraphAllocations(SubgraphAllocations* subgraph_allocations);
  60. // Get the current subgraph index. Within an on operator, this is guaranteed
  61. // to be the subgraph of that operator.
  62. int GetCurrentSubgraphIndex() { return current_subgraph_index_; }
  63. // Gets the list of alloctions for each subgraph. This is the source of truth
  64. // for all per-subgraph allocation data.
  65. SubgraphAllocations* GetAllocations() { return subgraph_allocations_; }
  66. // Get the resource variables for this TFLM graph.
  67. MicroResourceVariables* GetResourceVariables() { return resource_variables_; }
  68. private:
  69. TfLiteContext* context_;
  70. const Model* model_;
  71. MicroAllocator* allocator_;
  72. SubgraphAllocations* subgraph_allocations_ = nullptr;
  73. int current_subgraph_index_;
  74. MicroResourceVariables* resource_variables_;
  75. const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs_;
  76. TF_LITE_REMOVE_VIRTUAL_DELETE
  77. };
  78. } // namespace tflite
  79. #endif // TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_