mock_micro_graph.h 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_MOCK_MICRO_GRAPH_H_
  13. #define TENSORFLOW_LITE_MICRO_MOCK_MICRO_GRAPH_H_
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/micro/micro_allocator.h"
  16. #include "tensorflow/lite/micro/micro_graph.h"
  17. #include "tensorflow/lite/schema/schema_generated.h"
  18. namespace tflite {
  19. // MockMicroGraph stubs out all MicroGraph methods used during invoke. A count
  20. // of the number of calls to invoke for each subgraph is maintained for
  21. // validation of control flow operators.
  22. class MockMicroGraph : public MicroGraph {
  23. public:
  24. explicit MockMicroGraph(SimpleMemoryAllocator* allocator);
  25. TfLiteStatus InvokeSubgraph(int subgraph_idx) override;
  26. TfLiteStatus ResetVariableTensors() override;
  27. size_t NumSubgraphInputs(int subgraph_idx) override;
  28. TfLiteEvalTensor* GetSubgraphInput(int subgraph_idx, int tensor_idx) override;
  29. size_t NumSubgraphOutputs(int subgraph_idx) override;
  30. TfLiteEvalTensor* GetSubgraphOutput(int subgraph_idx,
  31. int tensor_idx) override;
  32. int NumSubgraphs() override;
  33. int get_init_count() const { return init_count_; }
  34. int get_prepare_count() const { return prepare_count_; }
  35. int get_free_count() const { return free_count_; }
  36. int get_invoke_count(int subgraph_idx) const {
  37. return invoke_counts_[subgraph_idx];
  38. }
  39. private:
  40. static constexpr int kMaxSubgraphs = 10;
  41. SimpleMemoryAllocator* allocator_;
  42. TfLiteEvalTensor* mock_tensor_;
  43. int init_count_;
  44. int prepare_count_;
  45. int free_count_;
  46. int invoke_counts_[kMaxSubgraphs];
  47. TF_LITE_REMOVE_VIRTUAL_DELETE
  48. };
  49. } // namespace tflite
  50. #endif // TENSORFLOW_LITE_MICRO_MOCK_MICRO_GRAPH_H_