circular_buffer.cc 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/circular_buffer.h"

#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
/*
 * The circular buffer custom operator is used to implement strided streaming
 * convolutions on TFLite Micro. Each time this operator is invoked, it inserts
 * the input at the end of the output buffer and shifts the existing output
 * values towards the start of the buffer, discarding the oldest value. It then
 * checks, based on a predetermined stride in time, whether the rest of the
 * graph should run on this invocation; if not, it signals the interpreter to
 * end the current run.
 *
 * Input: [<input N+1>]
 * Before shifting:
 * Output: [<input 1>, <input 2>, <input ...>, <input N>]
 *
 * After shifting:
 * Output: [<input 2>, <input 3>, <input ...>, <input N+1>]
 *
 * We make some assumptions in this custom operator:
 * - Input shape must be [1, 1, 1, depth]
 * - Output shape must be [1, num_slots, 1, depth]
 * - Input and output types must match.
 * - Input and output quantization params must be identical.
 */
  43. namespace tflite {
  44. void* CircularBufferInit(TfLiteContext* context, const char* buffer,
  45. size_t length) {
  46. TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  47. OpDataCircularBuffer* op_data = static_cast<OpDataCircularBuffer*>(
  48. context->AllocatePersistentBuffer(context, sizeof(OpDataCircularBuffer)));
  49. if (buffer != nullptr && length > 0) {
  50. const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  51. tflite::FlexbufferWrapper wrapper(buffer_t, length);
  52. op_data->cycles_max = wrapper.ElementAsInt32(kCircularBufferCyclesMaxIndex);
  53. } else {
  54. op_data->cycles_max = 0;
  55. }
  56. return op_data;
  57. }
  58. // Shifts buffer over by the output depth, and write new input to end of buffer.
  59. // num_slots is the number of samples stored in the output buffer.
  60. // depth is the size of each sample.
  61. void EvalInt8(const int8_t* input, int num_slots, int depth, int8_t* output) {
  62. memmove(output, &output[depth], (num_slots - 1) * depth);
  63. memcpy(&output[(num_slots - 1) * depth], input, depth);
  64. }
  65. TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
  66. const TfLiteEvalTensor* input =
  67. tflite::micro::GetEvalInput(context, node, kCircularBufferInputTensor);
  68. TfLiteEvalTensor* output =
  69. tflite::micro::GetEvalOutput(context, node, kCircularBufferOutputTensor);
  70. TFLITE_DCHECK(node->user_data != nullptr);
  71. OpDataCircularBuffer* data =
  72. reinterpret_cast<OpDataCircularBuffer*>(node->user_data);
  73. int num_slots = output->dims->data[1];
  74. int depth = output->dims->data[2] * output->dims->data[3];
  75. if (input->type == kTfLiteInt8) {
  76. EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
  77. tflite::micro::GetTensorData<int8_t>(output));
  78. } else {
  79. TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
  80. TfLiteTypeGetName(input->type), input->type);
  81. return kTfLiteError;
  82. }
  83. if (--data->cycles_until_run != 0) {
  84. // Signal the interpreter to end current run if the delay before op invoke
  85. // has not been reached.
  86. // TODO(b/149795762): Add kTfLiteAbort to TfLiteStatus enum.
  87. return static_cast<TfLiteStatus>(kTfLiteAbort);
  88. }
  89. data->cycles_until_run = data->cycles_max;
  90. return kTfLiteOk;
  91. }
  92. TfLiteRegistration* Register_CIRCULAR_BUFFER() {
  93. static TfLiteRegistration r = {/*init=*/CircularBufferInit,
  94. /*free=*/nullptr,
  95. /*prepare=*/CircularBufferPrepare,
  96. /*invoke=*/CircularBufferEval,
  97. /*profiling_string=*/nullptr,
  98. /*builtin_code=*/0,
  99. /*custom_name=*/nullptr,
  100. /*version=*/0};
  101. return &r;
  102. }
  103. } // namespace tflite