/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
  13. #define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
  14. #include <cstdint>
  15. #include <memory>
  16. #include "tensorflow/lite/c/builtin_op_data.h"
  17. #include "tensorflow/lite/c/common.h"
  18. namespace tflite {
  19. // Pamameters for integer LSTM.
  20. // Consider split this into two Integer Parameters if more fields are added.
  21. struct IntegerLstmParameter {
  22. int32_t effective_input_to_input_scale_a;
  23. int32_t effective_input_to_input_scale_b;
  24. int32_t effective_recurrent_to_input_scale_a;
  25. int32_t effective_recurrent_to_input_scale_b;
  26. int32_t effective_cell_to_input_scale_a;
  27. int32_t effective_cell_to_input_scale_b;
  28. int32_t effective_input_to_forget_scale_a;
  29. int32_t effective_input_to_forget_scale_b;
  30. int32_t effective_recurrent_to_forget_scale_a;
  31. int32_t effective_recurrent_to_forget_scale_b;
  32. int32_t effective_cell_to_forget_scale_a;
  33. int32_t effective_cell_to_forget_scale_b;
  34. int32_t effective_input_to_cell_scale_a;
  35. int32_t effective_input_to_cell_scale_b;
  36. int32_t effective_recurrent_to_cell_scale_a;
  37. int32_t effective_recurrent_to_cell_scale_b;
  38. int32_t effective_input_to_output_scale_a;
  39. int32_t effective_input_to_output_scale_b;
  40. int32_t effective_recurrent_to_output_scale_a;
  41. int32_t effective_recurrent_to_output_scale_b;
  42. int32_t effective_cell_to_output_scale_a;
  43. int32_t effective_cell_to_output_scale_b;
  44. int32_t effective_proj_scale_a;
  45. int32_t effective_proj_scale_b;
  46. int32_t effective_hidden_scale_a;
  47. int32_t effective_hidden_scale_b;
  48. int32_t layer_norm_input_scale_a;
  49. int32_t layer_norm_input_scale_b;
  50. int32_t layer_norm_forget_scale_a;
  51. int32_t layer_norm_forget_scale_b;
  52. int32_t layer_norm_cell_scale_a;
  53. int32_t layer_norm_cell_scale_b;
  54. int32_t layer_norm_output_scale_a;
  55. int32_t layer_norm_output_scale_b;
  56. // Quantized clip value for cell and projection. Zero value means no clipping.
  57. int16_t quantized_cell_clip;
  58. int8_t quantized_proj_clip;
  59. int32_t hidden_zp;
  60. int32_t cell_scale;
  61. int32_t input_variance_guard;
  62. int32_t forget_variance_guard;
  63. int32_t cell_variance_guard;
  64. int32_t output_variance_guard;
  65. // Pre-calculate bias + zero_point * weight.
  66. int32_t* input_to_forget_effective_bias;
  67. int32_t* recurrent_to_forget_effective_bias;
  68. int32_t* input_to_cell_effective_bias;
  69. int32_t* recurrent_to_cell_effective_bias;
  70. int32_t* input_to_output_effective_bias;
  71. int32_t* recurrent_to_output_effective_bias;
  72. int32_t* input_to_input_effective_bias;
  73. int32_t* recurrent_to_input_effective_bias;
  74. int32_t* projection_effective_bias;
  75. // Scale and zero point for intermediate tensors.
  76. // Used only in the 8x8_8 case.
  77. int32_t intermediate_scale_a[8];
  78. int32_t intermediate_scale_b[8];
  79. int32_t intermediate_zp[12];
  80. };
  81. // Scales for hybrid op with integer inputs and float weights
  82. struct HybridLstmScales {
  83. float input_to_input_weights_scale;
  84. float input_to_forget_weights_scale;
  85. float input_to_cell_weights_scale;
  86. float input_to_output_weights_scale;
  87. float aux_input_to_input_weights_scale;
  88. float aux_input_to_forget_weights_scale;
  89. float aux_input_to_cell_weights_scale;
  90. float aux_input_to_output_weights_scale;
  91. float recurrent_to_input_weights_scale;
  92. float recurrent_to_forget_weights_scale;
  93. float recurrent_to_cell_weights_scale;
  94. float recurrent_to_output_weights_scale;
  95. float cell_to_input_weights_scale;
  96. float cell_to_forget_weights_scale;
  97. float cell_to_output_weights_scale;
  98. float projection_weights_scale;
  99. };
  100. TfLiteStatus EvalFloatLstm(
  101. const TfLiteEvalTensor* input,
  102. const TfLiteEvalTensor* input_to_input_weights,
  103. const TfLiteEvalTensor* input_to_forget_weights,
  104. const TfLiteEvalTensor* input_to_cell_weights,
  105. const TfLiteEvalTensor* input_to_output_weights,
  106. const TfLiteEvalTensor* recurrent_to_input_weights,
  107. const TfLiteEvalTensor* recurrent_to_forget_weights,
  108. const TfLiteEvalTensor* recurrent_to_cell_weights,
  109. const TfLiteEvalTensor* recurrent_to_output_weights,
  110. const TfLiteEvalTensor* cell_to_input_weights,
  111. const TfLiteEvalTensor* cell_to_forget_weights,
  112. const TfLiteEvalTensor* cell_to_output_weights,
  113. const TfLiteEvalTensor* input_layer_norm_coefficients,
  114. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  115. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  116. const TfLiteEvalTensor* output_layer_norm_coefficients,
  117. const TfLiteEvalTensor* aux_input,
  118. const TfLiteEvalTensor* aux_input_to_input_weights,
  119. const TfLiteEvalTensor* aux_input_to_forget_weights,
  120. const TfLiteEvalTensor* aux_input_to_cell_weights,
  121. const TfLiteEvalTensor* aux_input_to_output_weights,
  122. const TfLiteEvalTensor* input_gate_bias,
  123. const TfLiteEvalTensor* forget_gate_bias,
  124. const TfLiteEvalTensor* cell_gate_bias,
  125. const TfLiteEvalTensor* output_gate_bias,
  126. const TfLiteEvalTensor* projection_weights,
  127. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  128. bool forward_sequence, bool time_major, int output_offset,
  129. float* scratch_buffer, TfLiteEvalTensor* output_state,
  130. TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
  131. TfLiteStatus EvalHybridLstm(
  132. const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
  133. const TfLiteEvalTensor* input_to_input_weights,
  134. const TfLiteEvalTensor* input_to_input_weights_ledger,
  135. const TfLiteEvalTensor* input_to_forget_weights,
  136. const TfLiteEvalTensor* input_to_forget_weights_ledger,
  137. const TfLiteEvalTensor* input_to_cell_weights,
  138. const TfLiteEvalTensor* input_to_cell_weights_ledger,
  139. const TfLiteEvalTensor* input_to_output_weights,
  140. const TfLiteEvalTensor* input_to_output_weights_ledger,
  141. const TfLiteEvalTensor* recurrent_to_input_weights,
  142. const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
  143. const TfLiteEvalTensor* recurrent_to_forget_weights,
  144. const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
  145. const TfLiteEvalTensor* recurrent_to_cell_weights,
  146. const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
  147. const TfLiteEvalTensor* recurrent_to_output_weights,
  148. const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
  149. const TfLiteEvalTensor* cell_to_input_weights,
  150. const TfLiteEvalTensor* cell_to_forget_weights,
  151. const TfLiteEvalTensor* cell_to_output_weights,
  152. const TfLiteEvalTensor* input_layer_norm_coefficients,
  153. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  154. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  155. const TfLiteEvalTensor* output_layer_norm_coefficients,
  156. const TfLiteEvalTensor* aux_input,
  157. const TfLiteEvalTensor* aux_input_to_input_weights,
  158. const TfLiteEvalTensor* aux_input_to_forget_weights,
  159. const TfLiteEvalTensor* aux_input_to_cell_weights,
  160. const TfLiteEvalTensor* aux_input_to_output_weights,
  161. const TfLiteEvalTensor* input_gate_bias,
  162. const TfLiteEvalTensor* forget_gate_bias,
  163. const TfLiteEvalTensor* cell_gate_bias,
  164. const TfLiteEvalTensor* output_gate_bias,
  165. const TfLiteEvalTensor* projection_weights,
  166. const TfLiteEvalTensor* projection_weights_ledger,
  167. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  168. bool forward_sequence, bool time_major, int output_offset,
  169. float* scratch_buffer, float* input_sf, float* aux_input_sf,
  170. float* output_state_sf, float* prod_scaling_factors,
  171. float* recovered_cell_weights, int8_t* input_quantized,
  172. int8_t* aux_input_quantized, int8_t* output_state_quantized,
  173. int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
  174. TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
  175. TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
  176. int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
  177. bool* compute_row_sums);
  178. TfLiteStatus EvalInteger8x8_16Lstm(
  179. const TfLiteEvalTensor* input,
  180. const TfLiteEvalTensor* input_to_input_weights,
  181. const TfLiteEvalTensor* input_to_forget_weights,
  182. const TfLiteEvalTensor* input_to_cell_weights,
  183. const TfLiteEvalTensor* input_to_output_weights,
  184. const TfLiteEvalTensor* recurrent_to_input_weights,
  185. const TfLiteEvalTensor* recurrent_to_forget_weights,
  186. const TfLiteEvalTensor* recurrent_to_cell_weights,
  187. const TfLiteEvalTensor* recurrent_to_output_weights,
  188. const TfLiteEvalTensor* cell_to_input_weights,
  189. const TfLiteEvalTensor* cell_to_forget_weights,
  190. const TfLiteEvalTensor* cell_to_output_weights,
  191. const TfLiteEvalTensor* input_layer_norm_coefficients,
  192. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  193. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  194. const TfLiteEvalTensor* output_layer_norm_coefficients,
  195. const TfLiteEvalTensor* input_gate_bias,
  196. const TfLiteEvalTensor* forget_gate_bias,
  197. const TfLiteEvalTensor* cell_gate_bias,
  198. const TfLiteEvalTensor* output_gate_bias,
  199. const TfLiteEvalTensor* projection_weights,
  200. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  201. bool forward_sequence, bool time_major,
  202. const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
  203. TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
  204. TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
  205. int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
  206. TfLiteStatus EvalInteger8x8_8Lstm(
  207. const TfLiteEvalTensor* input,
  208. const TfLiteEvalTensor* input_to_input_weights,
  209. const TfLiteEvalTensor* input_to_forget_weights,
  210. const TfLiteEvalTensor* input_to_cell_weights,
  211. const TfLiteEvalTensor* input_to_output_weights,
  212. const TfLiteEvalTensor* recurrent_to_input_weights,
  213. const TfLiteEvalTensor* recurrent_to_forget_weights,
  214. const TfLiteEvalTensor* recurrent_to_cell_weights,
  215. const TfLiteEvalTensor* recurrent_to_output_weights,
  216. const TfLiteEvalTensor* cell_to_input_weights,
  217. const TfLiteEvalTensor* cell_to_forget_weights,
  218. const TfLiteEvalTensor* cell_to_output_weights,
  219. const TfLiteEvalTensor* input_layer_norm_coefficients,
  220. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  221. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  222. const TfLiteEvalTensor* output_layer_norm_coefficients,
  223. const TfLiteEvalTensor* input_gate_bias,
  224. const TfLiteEvalTensor* forget_gate_bias,
  225. const TfLiteEvalTensor* cell_gate_bias,
  226. const TfLiteEvalTensor* output_gate_bias,
  227. const TfLiteEvalTensor* projection_weights,
  228. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  229. TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
  230. TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
  231. int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
  232. int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
  233. } // namespace tflite
  234. #endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_