circular_buffer.cc 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/c/builtin_op_data.h"
  13. #include "tensorflow/lite/c/common.h"
  14. #include "tensorflow/lite/kernels/internal/compatibility.h"
  15. #include "tensorflow/lite/kernels/internal/quantization_util.h"
  16. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  17. #include "tensorflow/lite/kernels/kernel_util.h"
  18. #include "tensorflow/lite/kernels/op_macros.h"
  19. #include "tensorflow/lite/micro/flatbuffer_utils.h"
  20. #include "tensorflow/lite/micro/kernels/kernel_util.h"
  21. /*
  22. * The circular buffer custom operator is used to implement strided streaming
  23. * convolutions on TFLite Micro. Each time this operator is invoked, it checks
  24. * whether or not to run, based on a predetermined stride in time. If the op
  25. * runs, it inserts the input into the end of the output buffer and shifts the
  26. * output values towards the start of the buffer. It discards the oldest value
  27. * in the output buffer.
  28. *
  29. * Input: [<input N+1>]
  30. * Before shifting:
  31. * Output: [<input 1>, <input 2>, <input ...>, <input N>]
  32. *
  33. * After shifting:
  34. * Output: [<input 2>, <input 3>, <input ...>, <input N+1>]
  35. *
  36. * We make some assumptions in this custom operator:
  37. * - Input shape must be [1, 1, 1, depth]
  38. * - Output shape must be [1, num_slots, 1, depth]
  39. * - Input and output types must match.
  40. * - Input and output quantization params must be identical.
  41. */
  42. namespace tflite {
  43. namespace ops {
  44. namespace micro {
  45. namespace circular_buffer {
  46. namespace {
  47. // The CircularBuffer op has one input and one output tensor.
  48. constexpr int kInputTensor = 0;
  49. constexpr int kOutputTensor = 0;
  50. // Indices into the init flexbuffer's vector.
  51. // The parameter's name is in the comment that follows.
  52. // Elements in the vectors are ordered alphabetically by parameter name.
  53. constexpr int kCyclesMaxIndex = 0; // 'cycles_max'
  54. // TODO(b/149795762): Add this to TfLiteStatus enum.
  55. constexpr TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(-9);
  56. // These fields control the stride period of a strided streaming model. This op
  57. // returns kTfLiteAbort until cycles_until_run-- is zero. At this time,
  58. // cycles_until_run is reset to cycles_max.
  59. struct OpData {
  60. int cycles_until_run;
  61. int cycles_max;
  62. };
  63. } // namespace
  64. void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  65. TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  66. OpData* op_data = static_cast<OpData*>(
  67. context->AllocatePersistentBuffer(context, sizeof(OpData)));
  68. if (buffer != nullptr && length > 0) {
  69. const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  70. tflite::FlexbufferWrapper wrapper(buffer_t, length);
  71. op_data->cycles_max = wrapper.ElementAsInt32(kCyclesMaxIndex);
  72. } else {
  73. op_data->cycles_max = 0;
  74. }
  75. return op_data;
  76. }
  77. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  78. const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  79. TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  80. TFLITE_DCHECK(node->user_data != nullptr);
  81. OpData* op_data = static_cast<OpData*>(node->user_data);
  82. TF_LITE_ENSURE(context, input != nullptr);
  83. TF_LITE_ENSURE(context, output != nullptr);
  84. TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
  85. TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
  86. TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
  87. TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
  88. TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  89. // The circular buffer custom operator currently only supports int8.
  90. TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
  91. if (op_data->cycles_max <= 0) {
  92. // The last circular buffer layer simply accumulates outputs, and does not
  93. // run periodically.
  94. // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
  95. static int cb_prepare_count = 0;
  96. cb_prepare_count++;
  97. // These checks specifically work for the only two streaming models
  98. // supported on TFLM. They use the shape of the output tensor along with the
  99. // layer number to determine if the circular buffer period should be 1 or 2.
  100. // These models are outlined int the following documents:
  101. // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
  102. // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
  103. if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
  104. output->dims->data[1] == 25 ||
  105. (cb_prepare_count == 5 && output->dims->data[2] == 2 &&
  106. output->dims->data[3] == 96)) {
  107. op_data->cycles_max = 1;
  108. cb_prepare_count = 0;
  109. } else {
  110. op_data->cycles_max = 2;
  111. }
  112. }
  113. op_data->cycles_until_run = op_data->cycles_max;
  114. node->user_data = op_data;
  115. return kTfLiteOk;
  116. }
  117. // Shifts buffer over by the output depth, and write new input to end of buffer.
  118. // num_slots is the number of samples stored in the output buffer.
  119. // depth is the size of each sample.
  120. void EvalInt8(const int8_t* input, int num_slots, int depth, int8_t* output) {
  121. memmove(output, &output[depth], (num_slots - 1) * depth);
  122. memcpy(&output[(num_slots - 1) * depth], input, depth);
  123. }
  124. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  125. const TfLiteEvalTensor* input =
  126. tflite::micro::GetEvalInput(context, node, kInputTensor);
  127. TfLiteEvalTensor* output =
  128. tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  129. TFLITE_DCHECK(node->user_data != nullptr);
  130. OpData* data = reinterpret_cast<OpData*>(node->user_data);
  131. int num_slots = output->dims->data[1];
  132. int depth = output->dims->data[2] * output->dims->data[3];
  133. if (input->type == kTfLiteInt8) {
  134. EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
  135. tflite::micro::GetTensorData<int8_t>(output));
  136. } else {
  137. TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
  138. TfLiteTypeGetName(input->type), input->type);
  139. return kTfLiteError;
  140. }
  141. if (--data->cycles_until_run != 0) {
  142. // Signal the interpreter to end current run if the delay before op invoke
  143. // has not been reached.
  144. // TODO(b/149795762): Add kTfLiteAbort to TfLiteStatus enum.
  145. return static_cast<TfLiteStatus>(kTfLiteAbort);
  146. }
  147. data->cycles_until_run = data->cycles_max;
  148. return kTfLiteOk;
  149. }
  150. } // namespace circular_buffer
  151. TfLiteRegistration* Register_CIRCULAR_BUFFER() {
  152. static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
  153. /*free=*/nullptr,
  154. /*prepare=*/circular_buffer::Prepare,
  155. /*invoke=*/circular_buffer::Eval,
  156. /*profiling_string=*/nullptr,
  157. /*builtin_code=*/0,
  158. /*custom_name=*/nullptr,
  159. /*version=*/0};
  160. return &r;
  161. }
  162. } // namespace micro
  163. } // namespace ops
  164. } // namespace tflite