micro_graph.h 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_
  13. #define TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/micro/micro_allocator.h"
  16. #include "tensorflow/lite/schema/schema_generated.h"
  17. namespace tflite {
  18. // Abstracts the details of interacting with the tflite::Model.
  19. //
  20. // Provides methods to access, initialize, prepare, invoke and free any
  21. // subgraph in the tflite::Graph.
  22. class MicroGraph {
  23. public:
  24. // The lifetime of the context, model, and allocator must be at least as long
  25. // as that of the graph object, since the this class may need to access them
  26. // at any time.
  27. MicroGraph(TfLiteContext* context, const Model* model,
  28. MicroAllocator* allocator);
  29. virtual ~MicroGraph();
  30. // Sets up builtin data and calls TfLiteRegistration->Init for every operator
  31. // in every subgraph in the model.
  32. virtual TfLiteStatus InitSubgraphs();
  33. // Calls TfLiteRegistration->Prepare for every operator in every subgraph in
  34. // the model.
  35. virtual TfLiteStatus PrepareSubgraphs();
  36. // Calls TfLiteRegistration->Free for every operator in every subgraph in the
  37. // model.
  38. virtual TfLiteStatus FreeSubgraphs();
  39. // Calls TfLiteRegistration->Invoke for every operator in a single subgraph in
  40. // the model.
  41. virtual TfLiteStatus InvokeSubgraph(int subgraph_idx);
  42. // Zeros out all variable tensors in all subgraphs in the model.
  43. virtual TfLiteStatus ResetVariableTensors();
  44. // Number of tensor inputs to a specified subgraph in the model.
  45. virtual size_t NumSubgraphInputs(int subgraph_idx);
  46. // Get the specified input tensor of a specified subgraph in the model.
  47. virtual TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int input_idx);
  48. // Number of tensor outputs from a specified subgraph in the model.
  49. virtual size_t NumSubgraphOutputs(int subgraph_idx);
  50. // Get the specified output tensor of a specified subgraph in the model.
  51. virtual TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx, int output_idx);
  52. // Number of subgraphs in the model.
  53. virtual int NumSubgraphs();
  54. // Hook to pass in subgraph allocations tracked within the interpreter,
  55. // allowing MicroGraph to init / prepare / invoke subgraphs in the model.
  56. void SetSubgraphAllocations(SubgraphAllocations* subgraph_allocations);
  57. // Get the current subgraph index. Within an on operator, this is guaranteed
  58. // to be the subgraph of that operator.
  59. int GetCurrentSubgraphIndex() { return current_subgraph_index_; }
  60. // Gets the list of alloctions for each subgraph. This is the source of truth
  61. // for all per-subgraph allocation data.
  62. SubgraphAllocations* GetAllocations() { return subgraph_allocations_; }
  63. private:
  64. TfLiteContext* context_;
  65. const Model* model_;
  66. MicroAllocator* allocator_;
  67. SubgraphAllocations* subgraph_allocations_ = nullptr;
  68. int current_subgraph_index_;
  69. const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs_;
  70. TF_LITE_REMOVE_VIRTUAL_DELETE
  71. };
  72. } // namespace tflite
  73. #endif // TENSORFLOW_LITE_MICRO_MICRO_GRAPH_H_