svdf_common.cc

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>

#include <algorithm>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {

/**
 * This version of SVDF is specific to TFLite Micro. It differs from the
 * TFLite version in the following ways:
 *
 * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time
 *     for the Micro interpreter.
 * 2.) Output dimensions - the TFLite version determines the output size at
 *     runtime and resizes the output tensor. The Micro runtime does not
 *     support tensor resizing.
 */
const int kSvdfInputTensor = 0;
const int kSvdfWeightsFeatureTensor = 1;
const int kSvdfWeightsTimeTensor = 2;
const int kSvdfBiasTensor = 3;
const int kSvdfInputActivationStateTensor =
    4;  // This is a variable tensor, and will be modified by this op.
const int kSvdfOutputTensor = 0;
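
// Reference SVDF implementations. For each invocation the op:
//   1. left-shifts the activation state by one sample,
//   2. writes the dot products of the input with the feature weights into the
//      newest slot of each filter's state window,
//   3. computes the dot product of each filter's state window with the time
//      weights,
//   4. sums each unit's `rank` filter responses, adds the optional bias, and
//      applies the fused activation (float) or rescales to the output type
//      (quantized).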
template <typename T>
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
                              const TfLiteEvalTensor* input_tensor,
                              const TfLiteEvalTensor* weights_feature_tensor,
                              const TfLiteEvalTensor* weights_time_tensor,
                              const TfLiteEvalTensor* bias_tensor,
                              const TfLiteSVDFParams* params,
                              TfLiteEvalTensor* activation_state_tensor,
                              TfLiteEvalTensor* output_tensor,
                              const OpDataSvdf& data) {
  const int n_rank = params->rank;
  const int n_batch = input_tensor->dims->data[0];
  const int n_input = input_tensor->dims->data[1];
  const int n_filter = weights_feature_tensor->dims->data[0];
  const int n_unit = n_filter / n_rank;
  const int n_memory = weights_time_tensor->dims->data[1];

  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  int32_t* scratch_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_tensor_index));
  int32_t* scratch_output_tensor = static_cast<int32_t*>(
      context->GetScratchBuffer(context, data.scratch_output_tensor_index));

  // Shift states.
  T* const state_ptr = tflite::micro::GetTensorData<T>(activation_state_tensor);

  // Left shift the activation_state.
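  // The state for each (batch, filter) pair is a sliding window over the most
  // recent n_memory activations; shifting every element left by one drops the
  // oldest sample and frees the slot at index n_memory - 1 for the activation
  // written by the feature matmul below.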
  {
    T* new_state_start = state_ptr;
    const T* old_state_start = state_ptr + 1;
    const T* old_state_end = state_ptr + n_batch * n_filter * n_memory;
    while (old_state_start != old_state_end) {
      *new_state_start++ = *old_state_start++;
    }
  }

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Feature matmul.
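  // Each filter's dot product with the current input is written into the
  // newest (last) slot of that filter's state window, so consecutive results
  // are stored with a stride of n_memory.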
  {
    T* state = tflite::micro::GetTensorData<T>(activation_state_tensor);
    const int8_t* input = tflite::micro::GetTensorData<int8_t>(input_tensor);
    const int8_t* weight_feature =
        tflite::micro::GetTensorData<int8_t>(weights_feature_tensor);
    const int32_t output_max = std::numeric_limits<T>::max();
    const int32_t output_min = std::numeric_limits<T>::min();
    T* result_in_batch = state + (n_memory - 1);

    for (int b = 0; b < n_batch; b++) {
      const int8_t* matrix_ptr = weight_feature;
      for (int r = 0; r < n_filter; r++) {
        int32_t dot_prod = 0;
        const int8_t* vector_in_batch = input + b * n_input;
        for (int c = 0; c < n_input; c++) {
          dot_prod +=
              *matrix_ptr++ * (*vector_in_batch++ - data.input_zero_point);
        }
        dot_prod = MultiplyByQuantizedMultiplier(
            dot_prod, data.effective_scale_1_a, data.effective_scale_1_b);
        dot_prod = std::min(std::max(output_min, dot_prod), output_max);
        // The int16 version of the op assumes a zero_point of 0. This
        // code accounts for the potentially non-zero zero_point for the int8
        // version of the op.
        *result_in_batch = data.activation_state_zero_point + dot_prod;
        result_in_batch += n_memory;
      }
    }
  }

  // Time.
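  // Dot product of each filter's state window with the time weights. The
  // activation state zero point is subtracted here so that the int8 variant
  // matches the zero-point-free int16 math.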
  {
    for (int b = 0; b < n_batch; ++b) {
      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;

      // Perform batched vector dot product:
      const T* vector1_ptr =
          tflite::micro::GetTensorData<T>(weights_time_tensor);
      const T* vector2_ptr =
          tflite::micro::GetTensorData<T>(activation_state_tensor) +
          b * n_memory * n_filter;

      for (int i = 0; i < n_filter; i++) {
        *scratch_ptr_batch = 0;
        for (int j = 0; j < n_memory; j++) {
          *scratch_ptr_batch +=
              *vector1_ptr++ *
              (*vector2_ptr++ - data.activation_state_zero_point);
        }
        scratch_ptr_batch++;
      }
    }
  }

  // Reduce, add bias, rescale, activation.
  {
    // Add bias.
    if (bias_tensor) {
      // Vector batch assign:
      const int32_t* bias_data =
          tflite::micro::GetTensorData<int32_t>(bias_tensor);
      for (int i = 0; i < n_batch; ++i) {
        int32_t* output_ptr = scratch_output_tensor + i * n_unit;
        const int32_t* bias_ptr = bias_data;
        for (int j = 0; j < n_unit; ++j) {
          *output_ptr++ = *bias_ptr++;
        }
      }
    } else {
      int32_t* output_ptr = scratch_output_tensor;
      for (int i = 0; i < n_batch * n_unit; ++i) {
        *output_ptr++ = 0;
      }
    }

    // Reduce.
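    // Each output unit accumulates the time-matmul results of its `rank`
    // consecutive filters (n_unit = n_filter / n_rank).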
    for (int b = 0; b < n_batch; ++b) {
      int32_t* output_temp_ptr = scratch_output_tensor + b * n_unit;
      int32_t* scratch_ptr_batch = scratch_tensor + b * n_filter;

      // Reduction sum vector
      for (int i = 0; i < n_unit; ++i) {
        for (int j = 0; j < n_rank; ++j) {
          output_temp_ptr[i] += *scratch_ptr_batch++;
        }
      }
    }

    // Rescale.
    const int32_t output_max = std::numeric_limits<int8_t>::max();
    const int32_t output_min = std::numeric_limits<int8_t>::min();
    for (int i = 0; i < n_batch * n_unit; ++i) {
      int32_t x1 = scratch_output_tensor[i];
      int32_t x2 = MultiplyByQuantizedMultiplier(x1, data.effective_scale_2_a,
                                                 data.effective_scale_2_b);
      int32_t x3 = x2 + data.output_zero_point;
      int32_t x4 = std::min(std::max(output_min, x3), output_max);
      tflite::micro::GetTensorData<int8_t>(output_tensor)[i] =
          static_cast<int8_t>(x4);
    }
  }
}

/**
 * Generate two versions of the integer code. One with int16_t type for the
 * time weights and the activation state, and another one with int8_t for the
 * same.
 */
void EvalInt16SvdfReference(TfLiteContext* context, TfLiteNode* node,
                            const TfLiteEvalTensor* input_tensor,
                            const TfLiteEvalTensor* weights_feature_tensor,
                            const TfLiteEvalTensor* weights_time_tensor,
                            const TfLiteEvalTensor* bias_tensor,
                            const TfLiteSVDFParams* params,
                            TfLiteEvalTensor* activation_state_tensor,
                            TfLiteEvalTensor* output_tensor,
                            const OpDataSvdf& data) {
  EvalIntegerSvdfReference<int16_t>(
      context, node, input_tensor, weights_feature_tensor, weights_time_tensor,
      bias_tensor, params, activation_state_tensor, output_tensor, data);
}

void EvalInt8SvdfReference(TfLiteContext* context, TfLiteNode* node,
                           const TfLiteEvalTensor* input_tensor,
                           const TfLiteEvalTensor* weights_feature_tensor,
                           const TfLiteEvalTensor* weights_time_tensor,
                           const TfLiteEvalTensor* bias_tensor,
                           const TfLiteSVDFParams* params,
                           TfLiteEvalTensor* activation_state_tensor,
                           TfLiteEvalTensor* output_tensor,
                           const OpDataSvdf& data) {
  EvalIntegerSvdfReference<int8_t>(
      context, node, input_tensor, weights_feature_tensor, weights_time_tensor,
      bias_tensor, params, activation_state_tensor, output_tensor, data);
}

static inline void ApplyTimeWeightsBiasAndActivation(
    int batch_size, int memory_size, int num_filters, int num_units, int rank,
    const float* const weights_time_ptr, const float* const bias_ptr,
    TfLiteFusedActivation activation, float* const state_ptr,
    float* const scratch_ptr, float* const output_ptr) {
  // Compute matmul(activation_state, weights_time).
  for (int b = 0; b < batch_size; ++b) {
    // Perform batched vector dot product:
    float* scratch_ptr_batch = scratch_ptr + b * num_filters;
    const float* vector1_ptr = weights_time_ptr;
    const float* vector2_ptr = state_ptr + b * memory_size * num_filters;
    for (int i = 0; i < num_filters; ++i) {
      *scratch_ptr_batch = 0.f;
      for (int j = 0; j < memory_size; ++j) {
        *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
      }
      scratch_ptr_batch++;
    }
  }

  // Initialize output with bias if provided.
  if (bias_ptr) {
    // VectorBatchVectorAssign
    for (int i = 0; i < batch_size; ++i) {
      float* output_data = output_ptr + i * num_units;
      const float* bias_data = bias_ptr;
      for (int j = 0; j < num_units; ++j) {
        *output_data++ = *bias_data++;
      }
    }
  } else {
    float* output_data = output_ptr;
    for (int i = 0; i < batch_size * num_units; ++i) {
      *output_data++ = 0.0f;
    }
  }

  // Reduction sum.
  for (int b = 0; b < batch_size; ++b) {
    float* output_ptr_batch = output_ptr + b * num_units;
    float* scratch_ptr_batch = scratch_ptr + b * num_filters;

    // Reduction sum vector
    for (int i = 0; i < num_units; ++i) {
      for (int j = 0; j < rank; j++) {
        output_ptr_batch[i] += *scratch_ptr_batch++;
      }
    }
  }

  // Apply activation.
  for (int b = 0; b < batch_size; ++b) {
    float* output_ptr_batch = output_ptr + b * num_units;
    for (int i = 0; i < num_units; ++i) {
      *output_ptr_batch = tflite::ops::micro::ActivationValFloat(
          activation, *output_ptr_batch);
      ++output_ptr_batch;
    }
  }
}

void EvalFloatSvdfReference(
    TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* weights_feature,
    const TfLiteEvalTensor* weights_time, const TfLiteEvalTensor* bias,
    const TfLiteSVDFParams* params, int scratch_tensor_index,
    TfLiteEvalTensor* activation_state, TfLiteEvalTensor* output) {
  const int rank = params->rank;
  const int batch_size = input->dims->data[0];
  const int input_size = input->dims->data[1];
  const int num_filters = weights_feature->dims->data[0];
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];

  const float* weights_feature_ptr =
      tflite::micro::GetTensorData<float>(weights_feature);
  const float* weights_time_ptr =
      tflite::micro::GetTensorData<float>(weights_time);
  const float* bias_ptr = tflite::micro::GetTensorData<float>(bias);
  const float* input_ptr = tflite::micro::GetTensorData<float>(input);

  float* state_ptr = tflite::micro::GetTensorData<float>(activation_state);

  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);

  float* scratch_ptr = static_cast<float*>(
      context->GetScratchBuffer(context, scratch_tensor_index));

  float* output_ptr = tflite::micro::GetTensorData<float>(output);

  // Left shift the activation_state.
  {
    float* new_state_start = state_ptr;
    const float* old_state_start = state_ptr + 1;
    const float* old_state_end =
        state_ptr + batch_size * num_filters * memory_size;
    while (old_state_start != old_state_end) {
      *new_state_start++ = *old_state_start++;
    }
  }

  // Note: no need to clear the latest activation, matmul is not accumulative.

  // Compute conv1d(inputs, weights_feature).
  // The activation_state's rightmost column is used to save current cycle
  // activation. This is achieved by starting at state_ptr[memory_size - 1] and
  // having the stride equal to memory_size.

  // Perform batched matrix vector multiply operation:
  {
    const float* matrix = weights_feature_ptr;
    const float* vector = input_ptr;
    float* result = &state_ptr[memory_size - 1];
    float* result_in_batch = result;
    for (int i = 0; i < batch_size; ++i) {
      const float* matrix_ptr = matrix;
      for (int j = 0; j < num_filters; ++j) {
        float dot_prod = 0.0f;
        const float* vector_in_batch = vector + i * input_size;
        for (int k = 0; k < input_size; ++k) {
          dot_prod += *matrix_ptr++ * *vector_in_batch++;
        }
        *result_in_batch = dot_prod;
        result_in_batch += memory_size;
      }
    }
  }

  ApplyTimeWeightsBiasAndActivation(
      batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr,
      bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
}

TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);

  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);

  MicroContext* micro_context = GetMicroContext(context);

  // Validate Tensor Inputs (dtype depends on quantization):
  // [0] = Input, {2, batch_size, input_size}
  // [1] = Weights Feature, {2, num_filters, input_size}
  // [2] = Weights Time, {2, num_filters, memory_size}
  // [3] = Bias (optional), {1, num_units}
  // [4] = Activation State (variable),
  //         {2, batch_size, memory_size * num_filters}
  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kSvdfInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* weights_feature =
      micro_context->AllocateTempInputTensor(node, kSvdfWeightsFeatureTensor);
  TF_LITE_ENSURE(context, weights_feature != nullptr);
  TfLiteTensor* weights_time =
      micro_context->AllocateTempInputTensor(node, kSvdfWeightsTimeTensor);
  TF_LITE_ENSURE(context, weights_time != nullptr);
  TfLiteTensor* bias =
      micro_context->AllocateTempInputTensor(node, kSvdfBiasTensor);
  TfLiteTensor* activation_state = micro_context->AllocateTempInputTensor(
      node, kSvdfInputActivationStateTensor);
  TF_LITE_ENSURE(context, activation_state != nullptr);

  // Define input constants based on input tensor definition above:
  const int rank = params->rank;
  const int input_size = input->dims->data[1];
  const int batch_size = input->dims->data[0];
  const int num_filters = weights_feature->dims->data[0];
  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];

  // Validate Input Tensor:
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

  // Validate Tensor Output:
  // [0] = float/int8_t, {2, batch_size, num_units}
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kSvdfOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

  // Validate Weights Feature Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);

  // Validate Weights Time Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);

  // Validate Optional Bias Input Tensor:
  if (bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
  }

  // Validate Activation State Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
                    memory_size * num_filters);
  // Since is_variable is not part of TFLiteEvalTensor, check is_variable here.
  TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpDataSvdf* data = static_cast<OpDataSvdf*>(node->user_data);

  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
    TF_LITE_ENSURE(context, (weights_time->type == kTfLiteInt16) ||
                                (weights_time->type == kTfLiteInt8));
    TF_LITE_ENSURE(context, (activation_state->type == kTfLiteInt16) ||
                                (activation_state->type == kTfLiteInt8));
    if (bias != nullptr) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
    }
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
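
    // Both matmuls in Eval operate on raw quantized values, so their results
    // must be rescaled into the destination tensor's quantization space:
    //   effective_scale_1 = input_scale * weights_feature_scale / state_scale
    //   effective_scale_2 = state_scale * weights_time_scale / output_scale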
    const double effective_scale_1 = static_cast<double>(
        input->params.scale * weights_feature->params.scale /
        activation_state->params.scale);
    const double effective_scale_2 =
        static_cast<double>(activation_state->params.scale *
                            weights_time->params.scale / output->params.scale);

    // The bias tensor is optional; only check its scale when it is present.
    // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
    if (bias != nullptr) {
      TF_LITE_ENSURE(
          context,
          std::abs(static_cast<double>(bias->params.scale) -
                   static_cast<double>(activation_state->params.scale *
                                       weights_time->params.scale)) < 1e-5);
    }

    QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
                       &(data->effective_scale_1_b));
    QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
                       &(data->effective_scale_2_b));

    data->input_zero_point = input->params.zero_point;
    data->output_zero_point = output->params.zero_point;
    data->activation_state_zero_point = activation_state->params.zero_point;

    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
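
    // Two scratch buffers are requested: one for the per-filter time-matmul
    // results (batch_size * num_filters) and one for the per-unit sums before
    // rescaling (batch_size * num_units), both stored as int32.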
    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
        context, batch_size * num_filters * sizeof(int32_t),
        &(data->scratch_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_status);
    const TfLiteStatus scratch_output_status =
        context->RequestScratchBufferInArena(
            context, batch_size * num_units * sizeof(int32_t),
            &(data->scratch_output_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_output_status);
  } else {
    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
    if (bias != nullptr) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
    }
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);

    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
        context, batch_size * num_filters * sizeof(float),
        &(data->scratch_tensor_index));
    TF_LITE_ENSURE_OK(context, scratch_status);
  }

  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(weights_feature);
  micro_context->DeallocateTempTfLiteTensor(weights_time);
  micro_context->DeallocateTempTfLiteTensor(activation_state);
  micro_context->DeallocateTempTfLiteTensor(output);
  if (bias != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(bias);
  }
  return kTfLiteOk;
}

}  // namespace tflite