| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include "tensorflow/lite/micro/fake_micro_context.h"
- #include "tensorflow/lite/kernels/internal/compatibility.h"
- #include "tensorflow/lite/micro/micro_allocator.h"
- #include "tensorflow/lite/micro/micro_arena_constants.h"
- #include "tensorflow/lite/micro/micro_error_reporter.h"
- #include "tensorflow/lite/micro/simple_memory_allocator.h"
- namespace tflite {
- namespace {
- // Dummy static variables to allow creation of dummy MicroAllocator.
- // All tests are guarateed to run serially.
- static constexpr int KDummyTensorArenaSize = 256;
- static uint8_t dummy_tensor_arena[KDummyTensorArenaSize];
- } // namespace
- FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
- SimpleMemoryAllocator* allocator,
- MicroGraph* micro_graph)
- : MicroContext(
- MicroAllocator::Create(dummy_tensor_arena, KDummyTensorArenaSize,
- GetMicroErrorReporter()),
- nullptr, micro_graph),
- tensors_(tensors),
- allocator_(allocator) {}
- TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
- allocated_tensor_count_++;
- return &tensors_[tensor_index];
- }
- void FakeMicroContext::DeallocateTempTfLiteTensor(TfLiteTensor* tensor) {
- allocated_tensor_count_--;
- }
- bool FakeMicroContext::IsAllTempTfLiteTensorDeallocated() {
- return !allocated_tensor_count_;
- }
- TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
- TfLiteEvalTensor* eval_tensor =
- reinterpret_cast<TfLiteEvalTensor*>(allocator_->AllocateTemp(
- sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
- TFLITE_DCHECK(eval_tensor != nullptr);
- // In unit tests, the TfLiteTensor pointer contains the source of truth for
- // buffers and values:
- eval_tensor->data = tensors_[tensor_index].data;
- eval_tensor->dims = tensors_[tensor_index].dims;
- eval_tensor->type = tensors_[tensor_index].type;
- return eval_tensor;
- }
- void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
- // FakeMicroContext use SimpleMemoryAllocator, which does not automatically
- // apply the buffer alignment like MicroAllocator.
- // The buffer alignment is potentially wasteful but allows the
- // fake_micro_context to work correctly with optimized kernels.
- return allocator_->AllocatePersistentBuffer(bytes,
- MicroArenaBufferAlignment());
- }
- TfLiteStatus FakeMicroContext::RequestScratchBufferInArena(size_t bytes,
- int* buffer_index) {
- TFLITE_DCHECK(buffer_index != nullptr);
- if (scratch_buffer_count_ == kNumScratchBuffers_) {
- MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
- kNumScratchBuffers_);
- return kTfLiteError;
- }
- // For tests, we allocate scratch buffers from the tail and keep them around
- // for the lifetime of model. This means that the arena size in the tests will
- // be more than what we would have if the scratch buffers could share memory.
- scratch_buffers_[scratch_buffer_count_] =
- allocator_->AllocatePersistentBuffer(bytes, MicroArenaBufferAlignment());
- TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr);
- *buffer_index = scratch_buffer_count_++;
- return kTfLiteOk;
- }
- void* FakeMicroContext::GetScratchBuffer(int buffer_index) {
- TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers_);
- if (buffer_index >= scratch_buffer_count_) {
- return nullptr;
- }
- return scratch_buffers_[buffer_index];
- }
- } // namespace tflite
|