fake_micro_context.cc

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/fake_micro_context.h"

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"

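// FakeMicroContext is a lightweight MicroContext used by kernel unit tests:
// tensors are served straight from a caller-provided array and all buffer
// requests are routed to a test-owned SimpleMemoryAllocator.
//
// Rough usage sketch (hypothetical test code, not part of this file; how the
// tensors, allocator, and graph are set up varies between tests and is
// assumed here):
//
//   TfLiteTensor tensors[2] = { /* filled in by the test */ };
//   SimpleMemoryAllocator* allocator = /* test-owned allocator */;
//   MicroGraph* graph = /* test-owned graph */;
//   FakeMicroContext fake_context(tensors, allocator, graph);
//   TfLiteTensor* input = fake_context.AllocateTempTfLiteTensor(0);
//   // ... run the kernel under test ...
//   fake_context.DeallocateTempTfLiteTensor(input);
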
namespace tflite {
namespace {
// Dummy static variables to allow creation of a dummy MicroAllocator.
// All tests are guaranteed to run serially.
static constexpr int KDummyTensorArenaSize = 256;
static uint8_t dummy_tensor_arena[KDummyTensorArenaSize];
}  // namespace

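// The MicroContext base class requires a MicroAllocator, so a throwaway
// allocator backed by the dummy arena above is created purely to satisfy the
// base constructor; all real allocations in the fake go through |allocator_|.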
FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
                                   SimpleMemoryAllocator* allocator,
                                   MicroGraph* micro_graph)
    : MicroContext(
          MicroAllocator::Create(dummy_tensor_arena, KDummyTensorArenaSize,
                                 GetMicroErrorReporter()),
          nullptr, micro_graph),
      tensors_(tensors),
      allocator_(allocator) {}

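// Temp tensor requests are served directly from the caller-provided tensor
// array; only a running count is kept so that unbalanced deallocations can be
// detected by IsAllTempTfLiteTensorDeallocated().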
TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
  allocated_tensor_count_++;
  return &tensors_[tensor_index];
}

void FakeMicroContext::DeallocateTempTfLiteTensor(TfLiteTensor* tensor) {
  allocated_tensor_count_--;
}

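// Returns true once every temp tensor handed out above has been deallocated.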
bool FakeMicroContext::IsAllTempTfLiteTensorDeallocated() {
  return !allocated_tensor_count_;
}

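// Builds a temporary TfLiteEvalTensor that mirrors the data, dims, and type of
// the backing TfLiteTensor; the view itself is allocated from the test
// allocator's temp section on every call.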
TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
  TfLiteEvalTensor* eval_tensor =
      reinterpret_cast<TfLiteEvalTensor*>(allocator_->AllocateTemp(
          sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
  TFLITE_DCHECK(eval_tensor != nullptr);

  // In unit tests, the TfLiteTensor pointer contains the source of truth for
  // buffers and values:
  eval_tensor->data = tensors_[tensor_index].data;
  eval_tensor->dims = tensors_[tensor_index].dims;
  eval_tensor->type = tensors_[tensor_index].type;
  return eval_tensor;
}

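// Forwards persistent allocations to the test's SimpleMemoryAllocator, with
// the arena alignment applied explicitly.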
void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
  // FakeMicroContext uses SimpleMemoryAllocator, which does not automatically
  // apply the buffer alignment the way MicroAllocator does.
  // The explicit alignment is potentially wasteful but allows the
  // fake_micro_context to work correctly with optimized kernels.
  return allocator_->AllocatePersistentBuffer(bytes,
                                              MicroArenaBufferAlignment());
}

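// Scratch buffers are emulated with persistent allocations so that they stay
// valid for the whole test; see the comment inside for the size trade-off.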
TfLiteStatus FakeMicroContext::RequestScratchBufferInArena(size_t bytes,
                                                           int* buffer_index) {
  TFLITE_DCHECK(buffer_index != nullptr);

  if (scratch_buffer_count_ == kNumScratchBuffers_) {
    MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
                kNumScratchBuffers_);
    return kTfLiteError;
  }

  // For tests, we allocate scratch buffers from the tail and keep them around
  // for the lifetime of the model. This means that the arena size in the tests
  // will be more than what we would have if the scratch buffers could share
  // memory.
  scratch_buffers_[scratch_buffer_count_] =
      allocator_->AllocatePersistentBuffer(bytes, MicroArenaBufferAlignment());
  TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr);

  *buffer_index = scratch_buffer_count_++;
  return kTfLiteOk;
}

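// Returns the scratch buffer registered under |buffer_index|, or nullptr if
// the index was never handed out by RequestScratchBufferInArena().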
void* FakeMicroContext::GetScratchBuffer(int buffer_index) {
  TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers_);
  if (buffer_index >= scratch_buffer_count_) {
    return nullptr;
  }
  return scratch_buffers_[buffer_index];
}

}  // namespace tflite