  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_
  13. #define TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_
  14. #include "tensorflow/lite/c/common.h"
  15. #include "tensorflow/lite/micro/micro_allocator.h"
  16. #include "tensorflow/lite/micro/micro_graph.h"
  17. namespace tflite {
  18. // MicroContext is eventually going to become the API between TFLM and the
  19. // kernels, replacing all the functions in TfLiteContext. The end state is code
  20. // kernels to have code like:
  21. //
  22. // MicroContext* micro_context = GetMicroContext(context);
  23. // micro_context-><TFLM kernel API>
class MicroContext {
 public:
  // Does not take any ownership, and all pointers must refer to valid objects
  // that outlive the one constructed.
  explicit MicroContext(MicroAllocator* allocator, const Model* model,
                        MicroGraph* graph);
  virtual ~MicroContext();

  // Allocates a persistent buffer which has the same lifetime as the
  // interpreter. Returns nullptr on failure.
  // The memory is allocated from the tail (of the arena).
  // This method is only available in the Init or Prepare stage.
  // Virtual so that it can be faked for kernel tests.
  virtual void* AllocatePersistentBuffer(size_t bytes);

  // Requests a scratch buffer in the arena through static memory planning.
  // This method is only available in the Prepare stage; the buffer is
  // allocated by the interpreter between the Prepare and Eval stages. In the
  // Eval stage, the GetScratchBuffer API can be used to fetch the address.
  // On success, the planner-assigned index is written to *buffer_idx.
  // Virtual so that it can be faked for kernel tests.
  virtual TfLiteStatus RequestScratchBufferInArena(size_t bytes,
                                                   int* buffer_idx);

  // Returns the scratch buffer pointer for a previously requested index.
  // This method is only available in the Eval stage.
  // Virtual so that it can be faked for kernel tests.
  virtual void* GetScratchBuffer(int buffer_idx);

  // Returns a temporary TfLiteTensor struct for a given tensor index.
  // The returned tensor shall be freed via DeallocateTempTfLiteTensor.
  // Virtual so that it can be faked for kernel tests.
  virtual TfLiteTensor* AllocateTempTfLiteTensor(int tensor_idx);

  // Returns a temporary TfLiteTensor struct for the specified input tensor of
  // a given node. This is the recommended API over the deprecated
  // GetInput/GetInputSafe to get a temp input tensor. The returned tensor
  // shall be freed via calling DeallocateTempTfLiteTensor.
  virtual TfLiteTensor* AllocateTempInputTensor(const TfLiteNode* node,
                                                int index);

  // Returns a temporary TfLiteTensor struct for the specified output tensor of
  // a given node. This is the recommended API over the deprecated
  // GetOutput/GetOutputSafe to get a temp output tensor. The returned tensor
  // shall be freed via calling DeallocateTempTfLiteTensor.
  virtual TfLiteTensor* AllocateTempOutputTensor(const TfLiteNode* node,
                                                 int index);

  // Returns a temporary TfLiteTensor struct for the specified intermediate
  // tensor of a given node. This is the recommended API over the deprecated
  // GetIntermediates/GetIntermediatesSafe to get a temp intermediate tensor.
  // The returned tensor shall be freed via calling DeallocateTempTfLiteTensor.
  virtual TfLiteTensor* AllocateTempIntermediateTensor(const TfLiteNode* node,
                                                       int index);

  // Deallocates a temp TfLiteTensor obtained from one of the
  // AllocateTemp*Tensor methods above.
  // Virtual so that it can be faked for kernel tests.
  virtual void DeallocateTempTfLiteTensor(TfLiteTensor* tensor);

  // Returns a TfLiteEvalTensor struct for a given tensor index.
  // Virtual so that it can be faked for kernel tests.
  virtual TfLiteEvalTensor* GetEvalTensor(int tensor_idx);

  // Does not take ownership of the pointer; it must refer to a valid object
  // that outlives this class instance.
  // This can only be called once to set one external context.
  TfLiteStatus set_external_context(void* external_context_payload);
  void* external_context() { return external_context_payload_; }

  MicroGraph& graph() { return graph_; }

  // Sets the pointer to a list of ScratchBufferHandle instances.
  // Not part of the API between TFLM and kernels. Primarily used by the
  // framework for housekeeping in MicroContext.
  void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles);

 private:
  // Returns the tensor index as tensor_indices[index]. tensor_indices is of
  // max_size. Returns -1 if index is not in the valid range of tensor_indices.
  int GetTensorIndex(int index, int max_size, const int* tensor_indices);

  // All referenced objects are owned by the caller (see constructor comment).
  MicroAllocator& allocator_;
  MicroGraph& graph_;
  const Model* model_;
  ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
  void* external_context_payload_ = nullptr;

  TF_LITE_REMOVE_VIRTUAL_DELETE
};
  96. inline MicroContext* GetMicroContext(const struct TfLiteContext* context) {
  97. return reinterpret_cast<MicroContext*>(context->impl_);
  98. }
  99. // Deprecated API. Prefer to using the MicroContext API directly from the
  100. // kernels.
  101. // TODO(b/213010668): migrate all existing kernels to use MicroContext, delete
  102. // these functions, and remove corresponding members from the TfLiteContext
  103. // struct for TFLM.
  104. inline void* MicroContextAllocatePersistentBuffer(TfLiteContext* ctx,
  105. size_t bytes) {
  106. return GetMicroContext(ctx)->AllocatePersistentBuffer(bytes);
  107. }
  108. inline TfLiteStatus MicroContextRequestScratchBufferInArena(TfLiteContext* ctx,
  109. size_t bytes,
  110. int* buffer_idx) {
  111. return GetMicroContext(ctx)->RequestScratchBufferInArena(bytes, buffer_idx);
  112. }
  113. inline void* MicroContextGetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
  114. return GetMicroContext(ctx)->GetScratchBuffer(buffer_idx);
  115. }
  116. inline TfLiteTensor* MicroContextGetTensor(const struct TfLiteContext* context,
  117. int tensor_idx) {
  118. return GetMicroContext(context)->AllocateTempTfLiteTensor(tensor_idx);
  119. }
  120. inline TfLiteEvalTensor* MicroContextGetEvalTensor(
  121. const struct TfLiteContext* context, int tensor_idx) {
  122. return GetMicroContext(context)->GetEvalTensor(tensor_idx);
  123. }
  124. inline TfLiteExternalContext* MicroContextGetExternalContext(
  125. TfLiteContext* context, TfLiteExternalContextType unused) {
  126. return reinterpret_cast<TfLiteExternalContext*>(
  127. GetMicroContext(context)->external_context());
  128. }
// Requests that an error be reported with the printf-style format string
// `format` and its trailing variadic arguments.
void MicroContextReportOpError(struct TfLiteContext* context,
                               const char* format, ...);
  132. } // namespace tflite
  133. #endif // TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_