lstm_eval.cc
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/lstm_eval.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"

namespace tflite {
namespace {

void ComputeRowSums(
    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
    int n_input, int n_aux_input, int n_output,
    const int8_t* input_to_input_weights_ptr,
    const int8_t* input_to_forget_weights_ptr,
    const int8_t* input_to_cell_weights_ptr,
    const int8_t* input_to_output_weights_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    const int8_t* aux_input_to_forget_weights_ptr,
    const int8_t* aux_input_to_cell_weights_ptr,
    const int8_t* aux_input_to_output_weights_ptr,
    const int8_t* recurrent_to_input_weights_ptr,
    const int8_t* recurrent_to_forget_weights_ptr,
    const int8_t* recurrent_to_cell_weights_ptr,
    const int8_t* recurrent_to_output_weights_ptr,
    const int8_t* projection_weights_ptr, bool use_cifg,
    const float* aux_input_ptr) {
  // Compute the row sums for dequantization
  if (!use_cifg) {
    micro_tensor_utils::ReductionSumVector(
        input_to_input_weights_ptr, input_to_input_row_sums, n_cell, n_input);
  }
  micro_tensor_utils::ReductionSumVector(
      input_to_forget_weights_ptr, input_to_forget_row_sums, n_cell, n_input);
  micro_tensor_utils::ReductionSumVector(
      input_to_cell_weights_ptr, input_to_cell_row_sums, n_cell, n_input);
  micro_tensor_utils::ReductionSumVector(
      input_to_output_weights_ptr, input_to_output_row_sums, n_cell, n_input);
  if (aux_input_ptr) {
    if (!use_cifg) {
      micro_tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
                                             aux_input_to_input_row_sums,
                                             n_cell, n_aux_input);
    }
    micro_tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
                                           aux_input_to_forget_row_sums, n_cell,
                                           n_aux_input);
    micro_tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
                                           aux_input_to_cell_row_sums, n_cell,
                                           n_aux_input);
    micro_tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
                                           aux_input_to_output_row_sums, n_cell,
                                           n_aux_input);
  }
  if (!use_cifg) {
    micro_tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
                                           recurrent_to_input_row_sums, n_cell,
                                           n_output);
  }
  micro_tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
                                         recurrent_to_forget_row_sums, n_cell,
                                         n_output);
  micro_tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
                                         recurrent_to_cell_row_sums, n_cell,
                                         n_output);
  micro_tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
                                         recurrent_to_output_row_sums, n_cell,
                                         n_output);
  if (projection_weights_ptr != nullptr) {
    micro_tensor_utils::ReductionSumVector(
        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
  }
}
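
// Background on why the row sums above are worth precomputing (illustrative
// math, not extra computation): with an asymmetrically quantized input
// x ~= input_sf * (q_x - input_zp), a weight row dot product dequantizes as
//   W_row . x ~= w_scale * input_sf *
//                (dot(q_w_row, q_x) - input_zp * sum(q_w_row))
// so caching sum(q_w_row) per row lets the zero-point correction be applied
// each step without re-scanning the (constant) weights.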

// Calculates a single LSTM gate.
//
// Implements the following formula: (* is matrix multiply)
//   gate = activate(W_input    * input + W_aux       * aux_input   +
//                   W_peephole * cell  + W_recurrent * prev_output + bias)
// with layer norm:
//   gate = activate(W_norm * normalize(...) + bias)  // not adding bias inside
//
// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
//
// Parameters:
// Input vectors (to LSTM):    | Size:                | Optional?
//   input                     | n_input              |
//   aux_input                 | n_aux_input          | y (bidir LSTM)
// Input vectors (persistent states):
//   output_state              | n_output             |
//   cell_state                | n_cell               |
// 'Constant' inputs:
//   input_to_gate_weights     | n_cell * n_input     |
//   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
//   recurrent_to_gate_weights | n_cell * n_output    |
//   cell_to_gate_weights      | n_cell               | y (peephole)
//   gate_bias                 | n_cell               |
//   layer_norm_coefficients   | n_cell               | y (layer norm)
// Output vector:
//   gate                      | n_cell               |
// Scalar parameters:
//   n_batch - batch size / number of vectors
//   n_input, n_aux_input, n_output, n_cell - size of vectors.
//   activation - activation to use.
//   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
//   use_layer_norm - if doing layer norm LSTM.
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    memset(gate, 0, n_cell * n_batch * sizeof(float));
  } else {
    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
                                                gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, n_batch,
        gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    micro_tensor_utils::VectorBatchVectorCwiseProduct(
        layer_norm_coefficients, n_cell, gate, n_batch, gate);
    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  micro_tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell,
                                              activation, gate);
}

// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
//
// Implements the following formula:
//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
//
// With CIFG LSTM, input gate is replaced by (1 - forget_gate).
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  micro_tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                               n_batch * n_cell, cell_state);
  if (use_cifg) {
    // With CIFG, input_gate = 1 - forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
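
// Worked example of the update above (hypothetical values, n_batch = 1,
// n_cell = 1): with cell_state = 0.5, forget_gate = 0.9, input_gate = 0.4,
// cell_gate = 0.8:
//   new cell_state = 0.9 * 0.5 + 0.4 * 0.8 = 0.77
// With CIFG the input gate is implicit, 1 - 0.9 = 0.1, giving
//   new cell_state = 0.9 * 0.5 + 0.1 * 0.8 = 0.53
// A positive `clip` would then bound the result to [-clip, +clip].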

// Calculates the output state tensor of an LSTM step.
//
// Implements the following formula:
//   output_no_projection = output_gate .* activate(cell_state)
//     (elementwise vector product)
// If no projection is used:
//   output = output_state = output_no_projection
// With projection:
//   output = output_state = clip(W*output_no_projection + bias)
//
// Output might have a different 'stride' than n_batch, so we need to copy.
//
// Parameters:
//  - n_batch: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area, size n_batch*n_cell.
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
                              const float* cell_state, const float* output_gate,
                              TfLiteFusedActivation activation,
                              const float* projection_weights,
                              const float* projection_bias,
                              const float proj_clip, float* output_state,
                              float* scratch) {
  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                              activation, scratch);
  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch,
                                               n_batch * n_cell, scratch);
  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);
  if (use_projection) {
    if (use_projection_bias) {
      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
                                                  n_batch, output_state);
    } else {
      memset(output_state, 0, n_batch * n_output * sizeof(float));
    }
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch, n_batch * n_output * sizeof(float));
  }
}
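
// Note on the no-projection branch above: copying n_batch * n_output floats
// out of `scratch` (which holds n_batch * n_cell values) is only valid because
// an LSTM without a projection layer has n_output == n_cell; the projection is
// exactly what allows the output size to differ from the cell size.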

// Calculates a single LSTM gate, hybrid version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateHybrid(
    // Input and weights
    const int8_t* input, const float* input_sf, const int32_t* input_zp,
    const int8_t* input_to_gate_weights,
    const uint8_t* input_to_gate_weights_ledger,
    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
    // Aux input and weights
    const int8_t* aux_input, const float* aux_input_sf,
    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
    const float aux_input_to_gate_weights_scale,
    int32_t* aux_input_to_gate_row_sums,
    // Output state and weights
    const int8_t* output_state, const float* output_state_sf,
    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
    const uint8_t* recurrent_to_gate_weights_ledger,
    const float recurrent_to_gate_weights_scale,
    int32_t* recurrent_to_gate_row_sums,
    // Cell state and weights (peephole LSTM)
    const float* cell_state, const int8_t* cell_to_gate_weights,
    const float cell_to_gate_weights_scale,
    // Layer normalization coefficients (layer norm LSTM) + gate bias
    const float* layer_norm_coefficients, const float* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    float* gate,
    // Parameters for performance optimizations
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    const bool is_output_state_all_zeros, bool* compute_row_sums,
    // Scratch arrays
    float* scratch0,        // size: n_batch
    float* scratch1,        // size: n_cell, only used if peephole LSTM
    float* scales,          // size: n_batch
    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    memset(gate, 0, n_cell * n_batch * sizeof(float));
  } else {
    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
                                                gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    if (input_to_gate_weights_ledger != nullptr) {
      for (int i = 0; i < n_batch; i++) {
        scales[i] = input_to_gate_weights_scale * input_sf[i];
      }
      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
          input, scales, n_batch, gate);
    } else {
      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, n_cell, n_input, input,
          input_to_gate_weights_scale, input_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
          input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
    }
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
        aux_input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  // Skip if output state is all zeros.
  if (!is_output_state_all_zeros) {
    if (recurrent_to_gate_weights_ledger != nullptr) {
      for (int i = 0; i < n_batch; i++) {
        // The recurrent term multiplies the quantized output state, so its
        // scaling factors (not the input's) apply here.
        scales[i] = recurrent_to_gate_weights_scale * output_state_sf[i];
      }
      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
          n_output, output_state, scales, n_batch, gate);
    } else {
      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, n_cell, n_output, output_state,
          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
          recurrent_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
    }
  }
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    float* recovered_cell_weights = scratch1;
    micro_tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
                                             cell_to_gate_weights_scale,
                                             recovered_cell_weights);
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    micro_tensor_utils::VectorBatchVectorCwiseProduct(
        layer_norm_coefficients, n_cell, gate, n_batch, gate);
    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  micro_tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch,
                                              activation, gate);
}
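
// Sketch of the hybrid accumulation above, per batch b (illustrative only):
//   gate[b] += (weights_scale * sf[b]) *
//              (dot(q_weights_row, q_vector[b]) - zp[b] * row_sum)
// The dot product runs entirely in int8/int32; only the final accumulation
// into `gate` is float. The zero-point/row-sum term is the correction that
// ComputeRowSums precomputes, and it vanishes when the zero points are zero
// (symmetric quantization).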

// Calculates the output state tensor of an LSTM step. See Float version too.
//
// Parameters:
//  - n_batch: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - asymmetric_quantize_inputs: parameter to control quantization.
//  - projection_weights_row_sums, compute_row_sums: Data for optimized
//    MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area of size n_batch
//  - scratch3: scratch area of size n_batch
//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
//  - scales: scratch area of size n_batch
void CalculateLstmOutputHybrid(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
    float projection_weights_scale, const float* projection_bias,
    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
    int32_t* projection_weights_row_sums, bool* compute_row_sums,
    float* scratch0, int8_t* scratch1, float* scratch2, int32_t* scratch3,
    int32_t* scratch4, float* scales) {
  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                              activation, scratch0);
  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
                                               n_batch * n_cell, scratch0);
  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);
  if (use_projection) {
    if (use_projection_bias) {
      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
                                                  n_batch, output_state);
    } else {
      memset(output_state, 0, n_batch * n_output * sizeof(float));
    }
    if (!micro_tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero output.
      micro_tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell,
                                              scratch1, scratch2, scratch3,
                                              asymmetric_quantize_inputs);
      if (projection_weights_ledger != nullptr) {
        for (int i = 0; i < n_batch; i++) {
          scales[i] = projection_weights_scale * scratch2[i];
        }
        micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
            projection_weights, projection_weights_ledger, n_output, n_cell,
            scratch1, scales, n_batch, output_state);
      } else {
        micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            projection_weights, n_output, n_cell, scratch1,
            projection_weights_scale, scratch2, n_batch, output_state,
            /*per_channel_scale=*/nullptr, scratch3, scratch4,
            projection_weights_row_sums, compute_row_sums, scratch2, nullptr);
      }
    }
    if (proj_clip > 0.0f) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch0, n_batch * n_output * sizeof(float));
  }
}
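
// Note on BatchQuantizeFloats above: it quantizes each batch row of scratch0
// to int8 in scratch1, producing per-row scaling factors in scratch2 and
// (when asymmetric_quantize_inputs is set) per-row zero points in scratch3;
// with symmetric quantization the zero points stay at zero, so the row-sum
// correction in the subsequent matmul contributes nothing.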

// Calculates a single LSTM gate, int8x8_16 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_16(
    // Input and weights
    const int8_t* input, const int8_t* input_to_gate_weights,
    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
    const int32_t input_to_gate_scale_b,
    // Output state and weights
    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
    const int32_t* recurrent_to_gate_bias,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    // Cell state and weights
    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
    const int32_t layer_norm_input_scale_a,
    const int32_t layer_norm_input_scale_b,
    const int32_t layer_norm_variance_guard,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Parameters for performance optimizations
    // Scratch arrays
    int32_t* scratch5) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
  // versions, bias is only used in layer normalization.
  memset(gate, 0, n_batch * n_cell * sizeof(int16_t));
  // For each batch and cell: compute input_weight * input.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
      nullptr);
  // Note: no aux_input.
  // For each batch and cell: compute recurrent_weight * output_state.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, 0, scratch5, gate, nullptr);
  // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
  if (use_peephole) {
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_output, cell_state, n_batch,
        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    micro_tensor_utils::ApplyLayerNorm(
        gate, layer_norm_coefficients, layer_norm_bias,
        layer_norm_input_scale_a, layer_norm_input_scale_b,
        layer_norm_variance_guard, n_batch, n_cell, gate);
  }
  // Apply activation
  switch (activation) {
    case kTfLiteActSigmoid:
      micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}
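
// Fixed-point bookkeeping for the gate math above (per the TFLite integer
// LSTM design, stated here for orientation): the matmuls accumulate into
// `gate` as Q3.12 int16 values, which is why ApplyTanh is called with 3
// integer bits; ApplySigmoid likewise expects Q3.12 input. Both activations
// produce Q0.15 output, so downstream code can treat gate values as
// fractions in [-1, 1).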

// Updates the LSTM cell state, used by both integer LSTM versions.
// Also see UpdateLstmCellFloat.
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - cell_state_scale: scaling factor of cell state.
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
                           int32_t cell_state_scale, const int16_t* input_gate,
                           int16_t* forget_gate, const int16_t* cell_gate,
                           bool use_cifg, int16_t clip) {
  // Use the forget_gate array as scratch, as input_gate array is not allocated
  // in CIFG case. (Be careful not to write to the scratch before reading the
  // forget gate data.)
  int16_t* scratch = forget_gate;
  micro_tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
                               cell_state);
  if (use_cifg) {
    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    micro_tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
                                 30 + cell_state_scale, scratch);
  } else {
    micro_tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
                                 30 + cell_state_scale, scratch);
  }
  micro_tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell,
                               cell_state);
  if (clip > 0) {
    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
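
// Shift bookkeeping for the CwiseMul calls above: gates are Q0.15 and the
// cell state uses a power-of-two scale 2^cell_state_scale. Multiplying
// forget_gate (Q0.15) by cell_state adds 15 fractional bits, so a right
// shift of 15 restores the cell scale. input_gate * cell_gate is Q0.30
// (scale 2^-30); shifting right by 30 + cell_state_scale (e.g. 19 if
// cell_state_scale were -11) lands the product on the cell-state scale, so
// the final CwiseAdd operates on matching scales.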

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - cell_state_scale: scaling of cell_state.
//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
//  - hidden_zp: zero_point for cell_state.*output_gate
//  - projection_weights, proj_scale_[a|b], projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    int32_t cell_state_scale, const int16_t* output_gate,
    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
    const int8_t* projection_weights, int32_t proj_scale_a,
    int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
                                n_cell, scratch0);
  micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
                               hidden_scale_b, n_batch, n_cell, hidden_zp,
                               scratch1);
  const bool use_projection = (projection_weights != nullptr);
  if (use_projection) {
    // Note: no bias like in float/hybrid
    memset(output_state, 0, n_batch * n_output * sizeof(int8_t));
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        scratch1, projection_bias, projection_weights, proj_scale_a,
        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
        output_state, nullptr);
    if (quantized_proj_clip > 0) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        quantized_proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch1, n_batch * n_output * sizeof(int8_t));
  }
}
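
// The same bookkeeping applies here: ApplyTanh(15 + cell_state_scale, ...)
// interprets the int16 cell state as having 15 + cell_state_scale integer
// bits (e.g. Q4.11 if cell_state_scale were -11), and the CwiseMul with
// hidden_scale_a/b and hidden_zp rescales the product of output_gate and
// tanh(cell_state) directly into the int8 quantization of the hidden
// (pre-projection) tensor.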

// Calculates a single LSTM gate, int8x8_8 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_8(
    // Inputs and weights
    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
    const int32_t input_times_weights_scale_a,
    const int32_t input_times_weights_scale_b,
    const int32_t input_times_weights_zp,
    // Output state and weights
    const int8_t* output_state, const int32_t output_state_zp,
    const int8_t* recurrent_to_gate_weight,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    const int32_t output_state_times_weights_scale_a,
    const int32_t output_state_times_weights_scale_b,
    const int32_t output_state_times_weights_zp,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_gate_weight,
    const int32_t layer_norm_gate_scale_a,
    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Scratch arrays, both sized n_batch*n_cell
    int8_t* scratch0, int8_t* scratch1) {
  // Multiply input * input_weights => scratch0
  micro_tensor_utils::MatrixBatchVectorMultiply(
      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
      input_times_weights_zp);
  // Multiply output_state * recurrent_weights => scratch1
  micro_tensor_utils::MatrixBatchVectorMultiply(
      output_state, output_state_zp, recurrent_to_gate_weight,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, scratch1, output_state_times_weights_zp);
  // Add scratch0 + scratch1 => gate
  micro_tensor_utils::TwoGateSaturatingAdd(
      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
      input_times_weights_scale_a, input_times_weights_scale_b,
      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
      n_batch, n_cell, gate);
  // Apply layer normalization.
  micro_tensor_utils::ApplyLayerNormFloat(
      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
  // Apply activation.
  switch (activation) {
    case kTfLiteActSigmoid:
      micro_tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      micro_tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}
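
// Unlike the float, hybrid, and 8x8_16 gate routines above, this 8x8_8
// variant has no use_layer_norm flag: ApplyLayerNormFloat (which also folds
// in gate_bias) is applied unconditionally, so this kernel only supports
// layer-norm LSTMs. Intermediates stay in int8/int16 with explicit zero
// points (input_times_weights_zp, output_state_times_weights_zp) rather
// than being dequantized to float.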

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, proj_scale_[a|b], projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of the output state.
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area of size n_batch*n_cell
void CalculateLstmOutputInteger8x8_8(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    const int16_t* output_gate, const int8_t* projection_weights,
    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  micro_tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
  micro_tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell,
                               15 + 15 - 15, scratch);
  // Note: no bias like in float/hybrid
  micro_tensor_utils::MatrixBatchVectorMultiply(
      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
      n_batch, n_cell, n_output, output_state_zp, output_state);
  if (quantized_proj_clip > 0) {
    micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                      quantized_proj_clip);
  }
}
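
// The shift of 15 + 15 - 15 above spells out the fixed-point accounting:
// output_gate and tanh(cell_state) each carry 15 fractional bits, their
// product carries 30, and shifting right by 15 returns the result to Q0.15
// before the projection matmul.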

// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
//  - params: various LSTM params including activation, clipping, etc.,
//  - n_batch: size of batch,
//  - n_cell: number of cells (or units),
//  - n_input: the input size,
//  - n_aux_input: the auxiliary input size.
//  - n_output: the output size.
//  - output_batch_leading_dim: the leading dimension of the output buffer.
//
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr - optional (can be nullptr)
//
// LSTM weights:
// Input weights of size 'n_cell * n_input':
//   input_to_input_weights - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights - optional
//   aux_input_to_forget_weights - optional
//   aux_input_to_cell_weights - optional
//   aux_input_to_output_weights - optional
// Recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
// Projection weights of size 'n_output * n_cell'
//   projection_weights_ptr - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   input_layer_norm_coefficients_ptr - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr - optional
//   output_layer_norm_coefficients_ptr - optional
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
// in batch_major order, and each step processes batch_size many inputs from
// input_ptr, and updates batch_size many cell and output states.
//
// The output_batch_leading_dim is output.shape[-1], i.e. the last dimension of
// the output tensor, and in most cases will be equal to n_output. It is usually
// not when we want to store the LSTM output into a slice of the output tensor,
// e.g. for bidirectional LSTMs with merge_outputs. In this case, the batched
// operations cannot be used since they assume that the batched outputs are
// contiguous, and we manually loop over the batched outputs.
inline void LstmStepFloat(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;
  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
                         aux_input_to_cell_weights_ptr, output_state_ptr,
                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
                         /*cell_to_gate_weights=*/nullptr,
                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
                         n_batch, n_input, n_aux_input, n_output, n_cell,
                         params->activation, cell_gate_scratch,
                         is_input_all_zeros, is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
                           output_gate_scratch, params->activation,
                           projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::memcpy(output_ptr + b * output_batch_leading_dim,
                output_state_ptr + b * n_output, n_output * sizeof(float));
  }
}
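
// Typical driver loop (illustrative sketch only): a full-sequence evaluator
// calls LstmStepFloat once per time step on a batch-major sequence of
// max_time steps, advancing the input and output pointers while the
// persistent states are updated in place:
//
//   for (int t = 0; t < max_time; ++t) {
//     const float* input_t = input + t * n_batch * n_input;
//     float* output_t = output + t * n_batch * output_batch_leading_dim;
//     LstmStepFloat(input_t, /* weights, biases, params ... */
//                   output_state_ptr, cell_state_ptr,
//                   scratch0, scratch1, scratch2, scratch3, output_t);
//   }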

// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr - optional (can be nullptr)
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weights - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights - optional
//   aux_input_to_forget_weights - optional
//   aux_input_to_cell_weights - optional
//   aux_input_to_output_weights - optional
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
// Quantized projection weights of size 'n_output * n_cell'
//   projection_weights_ptr - optional
// Weight scales (scalars) for each of the weights above.
//   input_to_input_weights_scale - optional
//   input_to_forget_weights_scale
//   input_to_cell_weights_scale
//   input_to_output_weights_scale
//   aux_input_to_input_weights_scale - optional
//   aux_input_to_forget_weights_scale - optional
//   aux_input_to_cell_weights_scale - optional
//   aux_input_to_output_weights_scale - optional
//   recurrent_to_input_weights_scale - optional
//   recurrent_to_forget_weights_scale
//   recurrent_to_cell_weights_scale
//   recurrent_to_output_weights_scale
//   cell_to_input_weights_scale
//   cell_to_forget_weights_scale
//   cell_to_output_weights_scale
//   projection_weights_scale - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   input_layer_norm_coefficients_ptr - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr - optional
//   output_layer_norm_coefficients_ptr - optional
//
// Temporary pre-allocated storage for quantized values:
//   quantized_input_ptr (same size as input_ptr)
//   quantized_output_state_ptr (same size as output_state_ptr)
//   quantized_output_scratch (same size as cell_state_ptr)
// Temporary pre-allocated storage for recovered values:
//   recovered_cell_weights (same size as cell_to_*_weights)
//
// Outputs:
//   output_state_ptr - size 'n_batch * n_output'
//   cell_state_ptr - size 'n_batch * n_cell'
//   output_ptr - size 'n_batch * output_batch_leading_dim'
inline void LstmStepHybrid(
    const float* input_ptr, const int8_t* input_to_input_weights_ptr,
    const uint8_t* input_to_input_weights_ledger_ptr,
    float input_to_input_weights_scale,
    const int8_t* input_to_forget_weights_ptr,
    const uint8_t* input_to_forget_weights_ledger_ptr,
    float input_to_forget_weights_scale,
    const int8_t* input_to_cell_weights_ptr,
    const uint8_t* input_to_cell_weights_ledger_ptr,
    float input_to_cell_weights_scale,
    const int8_t* input_to_output_weights_ptr,
    const uint8_t* input_to_output_weights_ledger_ptr,
    float input_to_output_weights_scale, const float* aux_input_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    float aux_input_to_input_weights_scale,
    const int8_t* aux_input_to_forget_weights_ptr,
    float aux_input_to_forget_weights_scale,
    const int8_t* aux_input_to_cell_weights_ptr,
    float aux_input_to_cell_weights_scale,
    const int8_t* aux_input_to_output_weights_ptr,
    float aux_input_to_output_weights_scale,
    const int8_t* recurrent_to_input_weights_ptr,
    const uint8_t* recurrent_to_input_weights_ledger_ptr,
    float recurrent_to_input_weights_scale,
    const int8_t* recurrent_to_forget_weights_ptr,
    const uint8_t* recurrent_to_forget_weights_ledger_ptr,
    float recurrent_to_forget_weights_scale,
    const int8_t* recurrent_to_cell_weights_ptr,
    const uint8_t* recurrent_to_cell_weights_ledger_ptr,
    float recurrent_to_cell_weights_scale,
    const int8_t* recurrent_to_output_weights_ptr,
    const uint8_t* recurrent_to_output_weights_ledger_ptr,
    float recurrent_to_output_weights_scale,
    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
    const int8_t* cell_to_forget_weights_ptr,
    float cell_to_forget_weights_scale,
    const int8_t* cell_to_output_weights_ptr,
    float cell_to_output_weights_scale,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const int8_t* projection_weights_ptr,
    const uint8_t* projection_weights_ledger_ptr,
    float projection_weights_scale, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* scratch0, float* scratch1, float* scratch2, float* scratch3,
    float* scales, float* input_sf, float* aux_input_sf, float* output_state_sf,
    float* scaling_factors_scratch, float* recovered_cell_weights,
    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
    float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
    bool* compute_row_sums, bool asymmetric_quantize_inputs) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  // Make named scratch buffers for the different gates.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;
  int32_t* input_to_input_row_sums = nullptr;
  int32_t* input_to_forget_row_sums = nullptr;
  int32_t* input_to_cell_row_sums = nullptr;
  int32_t* input_to_output_row_sums = nullptr;
  int32_t* aux_input_to_input_row_sums = nullptr;
  int32_t* aux_input_to_forget_row_sums = nullptr;
  int32_t* aux_input_to_cell_row_sums = nullptr;
  int32_t* aux_input_to_output_row_sums = nullptr;
  int32_t* recurrent_to_input_row_sums = nullptr;
  int32_t* recurrent_to_forget_row_sums = nullptr;
  int32_t* recurrent_to_cell_row_sums = nullptr;
  int32_t* recurrent_to_output_row_sums = nullptr;
  int32_t* projection_weights_row_sums = nullptr;
  if (asymmetric_quantize_inputs) {
    int num_row_sums = use_cifg ? 6 : 8;
    if (aux_input_ptr != nullptr) {
      num_row_sums += use_cifg ? 3 : 4;
    }
    if (projection_weights_ptr != nullptr) {
      num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
    }
    TFLITE_DCHECK(row_sums_size == num_row_sums);
    input_to_input_row_sums = row_sums;
    input_to_forget_row_sums =
        use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
    input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
    input_to_output_row_sums = input_to_cell_row_sums + n_cell;
    if (aux_input_ptr != nullptr) {
      aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
      aux_input_to_forget_row_sums =
          use_cifg ? aux_input_to_input_row_sums
                   : aux_input_to_input_row_sums + n_cell;
      aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
      aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
    }
    recurrent_to_input_row_sums = aux_input_ptr
                                      ? aux_input_to_output_row_sums + n_cell
                                      : input_to_output_row_sums + n_cell;
    recurrent_to_forget_row_sums =
        use_cifg ? recurrent_to_input_row_sums
                 : recurrent_to_input_row_sums + n_cell;
    recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
    recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
    if (projection_weights_ptr != nullptr) {
      projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
    }
    if (*compute_row_sums) {
      ComputeRowSums(
          input_to_input_row_sums, input_to_forget_row_sums,
          input_to_cell_row_sums, input_to_output_row_sums,
          aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
          aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
          recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
          recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
          projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
          n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
          input_to_cell_weights_ptr, input_to_output_weights_ptr,
          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
          recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
          recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
          projection_weights_ptr, use_cifg, aux_input_ptr);
      *compute_row_sums = false;
    }
  }
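  // For illustration only: with use_cifg == false, aux_input_ptr == nullptr,
  // and a projection whose n_output <= n_cell, num_row_sums is 8 + 1 = 9 and
  // the row_sums buffer is carved into n_cell-sized chunks as
  //
  //   row_sums + 0 * n_cell -> input_to_input_row_sums
  //   row_sums + 1 * n_cell -> input_to_forget_row_sums
  //   row_sums + 2 * n_cell -> input_to_cell_row_sums
  //   row_sums + 3 * n_cell -> input_to_output_row_sums
  //   row_sums + 4 * n_cell -> recurrent_to_input_row_sums
  //   row_sums + 5 * n_cell -> recurrent_to_forget_row_sums
  //   row_sums + 6 * n_cell -> recurrent_to_cell_row_sums
  //   row_sums + 7 * n_cell -> recurrent_to_output_row_sums
  //   row_sums + 8 * n_cell -> projection_weights_row_sums
  //
  // With CIFG the *_to_input chunks alias their *_to_forget neighbors, which
  // is why num_row_sums starts at 6 instead of 8.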
  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  const bool is_output_state_all_zeros =
      micro_tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
  // Quantize inputs.
  if (!is_input_all_zeros) {
    micro_tensor_utils::BatchQuantizeFloats(
        input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp,
        asymmetric_quantize_inputs);
  }
  if (!is_aux_input_all_zeros) {
    micro_tensor_utils::BatchQuantizeFloats(
        aux_input_ptr, n_batch, n_aux_input, quantized_aux_input_ptr,
        aux_input_sf, aux_input_zp, asymmetric_quantize_inputs);
  }
  if (!is_output_state_all_zeros) {
    micro_tensor_utils::BatchQuantizeFloats(
        output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
        output_state_sf, output_state_zp, asymmetric_quantize_inputs);
  }
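  // For illustration only: in the symmetric case
  // (asymmetric_quantize_inputs == false) each batch row is quantized
  // independently to int8 with its own scaling factor, conceptually
  //
  //   max_abs = max(|row[0]|, ..., |row[n_input - 1]|);
  //   input_sf[b] = max_abs / 127.0f;
  //   quantized_row[i] = static_cast<int8_t>(round(row[i] / input_sf[b]));
  //
  // so the int8 matmuls below can be rescaled back to float by multiplying
  // with input_sf[b] times the relevant weight scale.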
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateHybrid(
        quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
        input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
        input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
        aux_input_zp, aux_input_to_input_weights_ptr,
        aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
        quantized_output_state_ptr, output_state_sf, output_state_zp,
        recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
        recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
        cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
        input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
        n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
        input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
        is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
        recovered_cell_weights, scales, accum_scratch_ptr);
  }
  // Calculate the forget gate.
  CalculateLstmGateHybrid(
      quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
      input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
      input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
      aux_input_zp, aux_input_to_forget_weights_ptr,
      aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
      quantized_output_state_ptr, output_state_sf, output_state_zp,
      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
      recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
      cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
      forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
      forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
      recovered_cell_weights, scales, accum_scratch_ptr);
  // Calculate the cell update gate.
  CalculateLstmGateHybrid(
      quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
      input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
      input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
      aux_input_zp, aux_input_to_cell_weights_ptr,
      aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
      quantized_output_state_ptr, output_state_sf, output_state_zp,
      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
      recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
      /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
      /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      params->activation, cell_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
      scaling_factors_scratch, recovered_cell_weights, scales,
      accum_scratch_ptr);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
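  // For reference, UpdateLstmCellFloat computes, per batch and cell,
  //
  //   c_t = f_t * c_{t-1} + i_t * g_t
  //
  // with i_t = 1 - f_t when CIFG is used (input_gate_scratch is then
  // ignored), followed by clipping to [-cell_clip, cell_clip] when
  // params->cell_clip > 0.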
  // Calculate the output gate.
  CalculateLstmGateHybrid(
      quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
      input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
      input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
      aux_input_zp, aux_input_to_output_weights_ptr,
      aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
      quantized_output_state_ptr, output_state_sf, output_state_zp,
      recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
      recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
      cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
      output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
      output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
      recovered_cell_weights, scales, accum_scratch_ptr);
  // Update the output state.
  CalculateLstmOutputHybrid(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
      projection_weights_scale, projection_bias_ptr, params->proj_clip,
      output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
      compute_row_sums, scratch2, quantized_output_scratch, input_sf, input_zp,
      accum_scratch_ptr, scales);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::memcpy(output_ptr + b * output_batch_leading_dim,
                output_state_ptr + b * n_output, n_output * sizeof(float));
  }
}

// Fully quantized lstm kernel for 16 bit gate matmul output.
//
// Input tensor of size n_batch * n_input:
//   input_ptr
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weight_ptr - optional
//   input_to_forget_weight_ptr
//   input_to_cell_weight_ptr
//   input_to_output_weight_ptr
//
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weight_ptr - optional
//   recurrent_to_forget_weights_ptr
//   recurrent_to_cell_weights_ptr
//   recurrent_to_output_weights_ptr
//
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
//
// Quantized projection weights of size 'n_output * n_cell'
//   projection_weight_ptr - optional
//
// Weight scales (scalars) for each of the weights above.
//   effective_input_to_input_scale_a - optional
//   effective_input_to_input_scale_b - optional
//   effective_input_to_forget_scale_a
//   effective_input_to_forget_scale_b
//   effective_input_to_cell_scale_a
//   effective_input_to_cell_scale_b
//   effective_input_to_output_scale_a
//   effective_input_to_output_scale_b
//   effective_recurrent_to_input_scale_a - optional
//   effective_recurrent_to_input_scale_b - optional
//   effective_recurrent_to_forget_scale_a
//   effective_recurrent_to_forget_scale_b
//   effective_recurrent_to_cell_scale_a
//   effective_recurrent_to_cell_scale_b
//   effective_recurrent_to_output_scale_a
//   effective_recurrent_to_output_scale_b
//   effective_proj_scale_a - optional
//   effective_proj_scale_b - optional
//
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   layer_norm_input_weight_ptr - optional
//   layer_norm_forget_weight_ptr - optional
//   layer_norm_cell_weight_ptr - optional
//   layer_norm_output_weight_ptr - optional
//
// Layer norm scales (scalars):
//   layer_norm_input_scale_a - optional
//   layer_norm_input_scale_b - optional
//   layer_norm_forget_scale_a - optional
//   layer_norm_forget_scale_b - optional
//   layer_norm_cell_scale_a - optional
//   layer_norm_cell_scale_b - optional
//   layer_norm_output_scale_a - optional
//   layer_norm_output_scale_b - optional
//
// Scalar values:
//   quantized_cell_clip: quantized clip value for cell.
//   quantized_proj_clip: quantized clip value for projection.
//   cell_state_scale: the power of two scale for cell state.
//
// Zero points:
//   output_state_zp: zero point of output state.
//   hidden_zp: zero point for hidden state.
//
// Temporary pre-allocated storage for the calculation. Each is of size
// 'n_cell * n_batch'.
//   scratch0
//   scratch1
//   scratch2
//   scratch3
//   scratch4
//   scratch5: this scratch buffer is created purely for optimizing the
//             MatrixBatchVectorMultiplyAccumulate.
//
// Outputs:
//   output_state_ptr - size 'n_batch * n_output'
//   cell_state_ptr - size 'n_batch * n_cell'
//   output_ptr - size 'n_batch * n_output'
// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
inline void LstmStepInteger8x8_16(
    const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
    int32_t effective_input_to_input_scale_a,
    int32_t effective_input_to_input_scale_b,
    const int8_t* input_to_forget_weight_ptr,
    int32_t effective_input_to_forget_scale_a,
    int32_t effective_input_to_forget_scale_b,
    const int8_t* input_to_cell_weight_ptr,
    int32_t effective_input_to_cell_scale_a,
    int32_t effective_input_to_cell_scale_b,
    const int8_t* input_to_output_weight_ptr,
    int32_t effective_input_to_output_scale_a,
    int32_t effective_input_to_output_scale_b,
    const int8_t* recurrent_to_input_weight_ptr,
    int32_t effective_recurrent_to_input_scale_a,
    int32_t effective_recurrent_to_input_scale_b,
    const int8_t* recurrent_to_forget_weight_ptr,
    int32_t effective_recurrent_to_forget_scale_a,
    int32_t effective_recurrent_to_forget_scale_b,
    const int8_t* recurrent_to_cell_weight_ptr,
    int32_t effective_recurrent_to_cell_scale_a,
    int32_t effective_recurrent_to_cell_scale_b,
    const int8_t* recurrent_to_output_weight_ptr,
    int32_t effective_recurrent_to_output_scale_a,
    int32_t effective_recurrent_to_output_scale_b,
    const int16_t* cell_to_input_weight_ptr,
    int32_t effective_cell_to_input_scale_a,
    int32_t effective_cell_to_input_scale_b,
    const int16_t* cell_to_forget_weight_ptr,
    int32_t effective_cell_to_forget_scale_a,
    int32_t effective_cell_to_forget_scale_b,
    const int16_t* cell_to_output_weight_ptr,
    int32_t effective_cell_to_output_scale_a,
    int32_t effective_cell_to_output_scale_b,
    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
    int32_t effective_proj_scale_b, int32_t hidden_zp,
    int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
    const int16_t* layer_norm_input_weight_ptr,
    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
    const int16_t* layer_norm_forget_weight_ptr,
    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
    int32_t layer_norm_cell_scale_b,
    const int16_t* layer_norm_output_weight_ptr,
    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
    int32_t cell_state_scale, int32_t input_variance_guard,
    int32_t forget_variance_guard, int32_t cell_variance_guard,
    int32_t output_variance_guard,
    const int32_t* input_to_forget_effective_bias,
    const int32_t* recurrent_to_forget_effective_bias,
    const int32_t* input_to_cell_effective_bias,
    const int32_t* recurrent_to_cell_effective_bias,
    const int32_t* input_to_output_effective_bias,
    const int32_t* recurrent_to_output_effective_bias,
    const int32_t* input_to_input_effective_bias,
    const int32_t* recurrent_to_input_effective_bias,
    const int32_t* projection_effective_bias, int n_batch, int n_cell,
    int n_input, int n_output, int8_t* output_state_ptr,
    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
    int8_t* scratch4, int32_t* scratch5) {
  // Make named scratch buffers for the different gates.
  int16_t* input_gate_scratch = scratch0;
  int16_t* forget_gate_scratch = scratch1;
  int16_t* cell_gate_scratch = scratch2;
  int16_t* output_gate_scratch = scratch3;
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weight_ptr == nullptr);
  // Check for nullptrs.
  TFLITE_DCHECK(input_to_forget_effective_bias);
  TFLITE_DCHECK(recurrent_to_forget_effective_bias);
  TFLITE_DCHECK(input_to_cell_effective_bias);
  TFLITE_DCHECK(recurrent_to_cell_effective_bias);
  TFLITE_DCHECK(input_to_output_effective_bias);
  TFLITE_DCHECK(recurrent_to_output_effective_bias);
  if (!use_cifg) {
    TFLITE_DCHECK(input_to_input_effective_bias);
    TFLITE_DCHECK(recurrent_to_input_effective_bias);
  }
  const bool use_projection = (projection_weight_ptr != nullptr);
  if (use_projection) {
    TFLITE_DCHECK(projection_effective_bias);
  }
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateInteger8x8_16(
        input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
        effective_input_to_input_scale_a, effective_input_to_input_scale_b,
        output_state_ptr, recurrent_to_input_weight_ptr,
        recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
        effective_recurrent_to_input_scale_b, cell_state_ptr,
        cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
        effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
        input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
        input_variance_guard, n_batch, n_input, n_output, n_cell,
        kTfLiteActSigmoid, input_gate_scratch, scratch5);
  }
  // Calculate the forget gate.
  CalculateLstmGateInteger8x8_16(
      input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
      output_state_ptr, recurrent_to_forget_weight_ptr,
      recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
      effective_recurrent_to_forget_scale_b, cell_state_ptr,
      cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
      effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
      forget_gate_bias_ptr, layer_norm_forget_scale_a,
      layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
      n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
  // Calculate the cell update gate.
  CalculateLstmGateInteger8x8_16(
      input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
      output_state_ptr, recurrent_to_cell_weight_ptr,
      recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
      effective_recurrent_to_cell_scale_b, cell_state_ptr,
      /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
      /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
      cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
      cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
      cell_gate_scratch, scratch5);
  // Update the cell state.
  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
                        input_gate_scratch, forget_gate_scratch,
                        cell_gate_scratch, use_cifg, quantized_cell_clip);
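  // For reference: cell_state_scale is a power-of-two exponent, so e.g. with
  // cell_state_scale == -11 the int16 cell state holds Q4.11 values
  // (real_value = stored_value * 2^-11), and quantized_cell_clip is the clip
  // threshold expressed in that same fixed-point format.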
  // Calculate the output gate.
  CalculateLstmGateInteger8x8_16(
      input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
      output_state_ptr, recurrent_to_output_weight_ptr,
      recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
      effective_recurrent_to_output_scale_b, cell_state_ptr,
      cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
      effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
      output_gate_bias_ptr, layer_norm_output_scale_a,
      layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
      n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
  // Update the output state.
  CalculateLstmOutputInteger8x8_16(
      n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
      output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
      hidden_zp, projection_weight_ptr, effective_proj_scale_a,
      effective_proj_scale_b, projection_effective_bias, output_state_zp,
      quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
  // Copy output state to the output. Note that unlike float or hybrid, output
  // is always contiguous.
  std::memcpy(output_ptr, output_state_ptr,
              n_batch * n_output * sizeof(int8_t));
}
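
// Note on the effective scales used by the integer kernels: each (a, b) pair
// is a 32-bit fixed-point multiplier plus a power-of-two exponent produced,
// typically at prepare time, from a real-valued scale such as
//
//   effective_scale = input_scale * weight_scale / intermediate_scale
//
// so that rescaling an int32 accumulator 'acc' is, conceptually,
//
//   MultiplyByQuantizedMultiplier(acc, effective_scale_a, effective_scale_b)
//
// (see TFLite's kernel utility headers for the exact rounding behavior).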

// Fully quantized lstm kernel for 8 bit gate matmul output.
//
// Input tensor of size n_batch * n_input:
//   input_ptr
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weight_ptr - optional
//   input_to_forget_weight_ptr
//   input_to_cell_weight_ptr
//   input_to_output_weight_ptr
//
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weight_ptr - optional
//   recurrent_to_forget_weights_ptr
//   recurrent_to_cell_weights_ptr
//   recurrent_to_output_weights_ptr
//
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
//
// Quantized projection weights of size 'n_output * n_cell'
//   projection_weight_ptr - optional
//
// Weight scales (scalars) for each of the weights above.
//   effective_input_to_input_scale_a - optional
//   effective_input_to_input_scale_b - optional
//   effective_input_to_forget_scale_a
//   effective_input_to_forget_scale_b
//   effective_input_to_cell_scale_a
//   effective_input_to_cell_scale_b
//   effective_input_to_output_scale_a
//   effective_input_to_output_scale_b
//   effective_recurrent_to_input_scale_a - optional
//   effective_recurrent_to_input_scale_b - optional
//   effective_recurrent_to_forget_scale_a
//   effective_recurrent_to_forget_scale_b
//   effective_recurrent_to_cell_scale_a
//   effective_recurrent_to_cell_scale_b
//   effective_recurrent_to_output_scale_a
//   effective_recurrent_to_output_scale_b
//   effective_proj_scale_a - optional
//   effective_proj_scale_b - optional
//
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   layer_norm_input_weight_ptr - optional
//   layer_norm_forget_weight_ptr - optional
//   layer_norm_cell_weight_ptr - optional
//   layer_norm_output_weight_ptr - optional
//
// Layer norm scales (scalars):
//   layer_norm_input_scale_a - optional
//   layer_norm_input_scale_b - optional
//   layer_norm_forget_scale_a - optional
//   layer_norm_forget_scale_b - optional
//   layer_norm_cell_scale_a - optional
//   layer_norm_cell_scale_b - optional
//   layer_norm_output_scale_a - optional
//   layer_norm_output_scale_b - optional
//
// Scalar values:
//   quantized_cell_clip: quantized clip value for cell.
//   quantized_proj_clip: quantized clip value for projection.
//   cell_state_scale: the power of two scale for cell state.
//
// Zero points:
//   input_zp: zero point for input tensor.
//   output_state_zp: zero point of output state.
//   hidden_zp: zero point for hidden state.
//
// Temporary pre-allocated storage for the calculation. Each is of size
// 'n_cell * n_batch'.
//   scratch0
//   scratch1
//   scratch2
//   scratch3
//   scratch4
//   scratch5
//   scratch6
//   scratch7
//
// Outputs:
//   output_state_ptr - size 'n_batch * n_output'
//   cell_state_ptr - size 'n_batch * n_cell'
//   output_ptr - size 'n_batch * n_output'
//
// Can move zero point calculation into Prepare() for better performance.
// TODO(b/159947023): scratch5 is unused, remove.
inline void LstmStepInteger8x8_8(
    const int8_t* input_ptr, int32_t input_zp,
    const int8_t* input_to_input_weight_ptr,
    int32_t effective_input_to_input_scale_a,
    int32_t effective_input_to_input_scale_b,
    const int8_t* input_to_forget_weight_ptr,
    int32_t effective_input_to_forget_scale_a,
    int32_t effective_input_to_forget_scale_b,
    const int8_t* input_to_cell_weight_ptr,
    int32_t effective_input_to_cell_scale_a,
    int32_t effective_input_to_cell_scale_b,
    const int8_t* input_to_output_weight_ptr,
    int32_t effective_input_to_output_scale_a,
    int32_t effective_input_to_output_scale_b,
    const int8_t* recurrent_to_input_weight_ptr,
    int32_t effective_recurrent_to_input_scale_a,
    int32_t effective_recurrent_to_input_scale_b,
    const int8_t* recurrent_to_forget_weight_ptr,
    int32_t effective_recurrent_to_forget_scale_a,
    int32_t effective_recurrent_to_forget_scale_b,
    const int8_t* recurrent_to_cell_weight_ptr,
    int32_t effective_recurrent_to_cell_scale_a,
    int32_t effective_recurrent_to_cell_scale_b,
    const int8_t* recurrent_to_output_weight_ptr,
    int32_t effective_recurrent_to_output_scale_a,
    int32_t effective_recurrent_to_output_scale_b,
    const int8_t* cell_to_input_weight_ptr,
    int32_t effective_cell_to_input_scale_a,
    int32_t effective_cell_to_input_scale_b,
    const int8_t* cell_to_forget_weight_ptr,
    int32_t effective_cell_to_forget_scale_a,
    int32_t effective_cell_to_forget_scale_b,
    const int8_t* cell_to_output_weight_ptr,
    int32_t effective_cell_to_output_scale_a,
    int32_t effective_cell_to_output_scale_b,
    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
    int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
    const int16_t* layer_norm_forget_weight_ptr,
    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
    int32_t layer_norm_cell_scale_b,
    const int16_t* layer_norm_output_weight_ptr,
    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
    const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
    const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
    const int32_t* intermediate_zp, int16_t quantized_cell_clip,
    int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
    int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
    int16_t* scratch7) {
  // TODO(b/159066113): scratch5 is unused, remove.
  // Make named scratch buffers for the different gates.
  int16_t* forget_gate_scratch = scratch2;
  int16_t* cell_gate_scratch = scratch3;
  int16_t* output_gate_scratch = scratch4;
  // Only CIFG is supported in this kernel, so there is no input gate to
  // compute; it is coupled to the forget gate below.
  // Calculate the forget gate.
  CalculateLstmGateInteger8x8_8(
      input_ptr, input_zp, input_to_forget_weight_ptr,
      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
      intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
      output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
      effective_recurrent_to_forget_scale_a,
      effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
      intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
      forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
      kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
  // Calculate the cell update gate.
  CalculateLstmGateInteger8x8_8(
      input_ptr, input_zp, input_to_cell_weight_ptr,
      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
      intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
      output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
      effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
  // Update the cell state.
  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
                        /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
                        forget_gate_scratch, cell_gate_scratch,
                        /*use_cifg=*/true, quantized_cell_clip);
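  // For reference: with use_cifg == true, UpdateLstmCellInteger derives the
  // input gate from the forget gate as i_t = 1 - f_t, which is why this
  // kernel has no input-gate scratch buffer at all.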
  // Calculate the output gate.
  CalculateLstmGateInteger8x8_8(
      input_ptr, input_zp, input_to_output_weight_ptr,
      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
      effective_recurrent_to_output_scale_a,
      effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
      intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
      layer_norm_output_scale_a, layer_norm_output_scale_b,
      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
  // Update the output state.
  CalculateLstmOutputInteger8x8_8(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
      projection_bias_ptr, output_state_zp, quantized_proj_clip,
      output_state_ptr, scratch2);
  // Copy output state to the output. Note that unlike float or hybrid, output
  // is always contiguous.
  std::memcpy(output_ptr, output_state_ptr,
              n_batch * n_output * sizeof(int8_t));
}
}  // namespace

TfLiteStatus EvalFloatLstm(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* aux_input,
    const TfLiteEvalTensor* aux_input_to_input_weights,
    const TfLiteEvalTensor* aux_input_to_forget_weights,
    const TfLiteEvalTensor* aux_input_to_cell_weights,
    const TfLiteEvalTensor* aux_input_to_output_weights,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major, int output_offset,
    float* scratch_buffer, TfLiteEvalTensor* output_state,
    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
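  // For illustration only: a time-major input is laid out as
  // [max_time, n_batch, n_input], a batch-major one as
  // [n_batch, max_time, n_input]; a rank-2 input of shape [n_batch, n_input]
  // is treated as a single time step.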
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  // Index the scratch buffer pointers into the global scratch buffer.
  float* input_gate_scratch = nullptr;
  float* cell_gate_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_gate_scratch = scratch_buffer;
    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer;
    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
  }
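  // For illustration only: scratch_buffer holds one n_cell * n_batch plane
  // per active gate, e.g. without CIFG:
  //
  //   scratch_buffer + 0 * n_cell * n_batch -> input gate
  //   scratch_buffer + 1 * n_cell * n_batch -> cell gate
  //   scratch_buffer + 2 * n_cell * n_batch -> forget gate
  //   scratch_buffer + 3 * n_cell * n_batch -> output gate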
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr =
          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr =
            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                          t_rel * output_step + output_offset;
      LstmStepFloat(
          input_ptr,
          input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_to_input_weights),
          input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_to_forget_weights),
          input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_to_cell_weights),
          input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_to_output_weights),
          aux_input_ptr,
          aux_input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(aux_input_to_input_weights),
          aux_input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    aux_input_to_forget_weights),
          aux_input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(aux_input_to_cell_weights),
          aux_input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    aux_input_to_output_weights),
          recurrent_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(recurrent_to_input_weights),
          recurrent_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    recurrent_to_forget_weights),
          recurrent_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(recurrent_to_cell_weights),
          recurrent_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    recurrent_to_output_weights),
          cell_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_to_input_weights),
          cell_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
          cell_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_to_output_weights),
          input_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    input_layer_norm_coefficients),
          forget_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    forget_layer_norm_coefficients),
          cell_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    cell_layer_norm_coefficients),
          output_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    output_layer_norm_coefficients),
          input_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_gate_bias),
          forget_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(forget_gate_bias),
          cell_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_gate_bias),
          output_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(output_gate_bias),
          projection_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(projection_weights),
          projection_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(projection_bias),
          params, n_batch, n_cell, n_input, aux_input_size, n_output,
          output_batch_leading_dim,
          tflite::micro::GetTensorData<float>(output_state),
          tflite::micro::GetTensorData<float>(cell_state), input_gate_scratch,
          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
          output_ptr);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
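        // For illustration: in the batch-major layout the time steps of batch
        // b are contiguous, so element (b, t_rel) starts at flat offset
        // (b * max_time + t_rel) * n_input into the input tensor.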
        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
                                 time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
                          time_offset * input_step;
        }
        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;
        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            tflite::micro::GetTensorData<float>(output_state) +
            b * output_batch_leading_dim;
        float* cell_state_ptr =
            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
        LstmStepFloat(
            input_ptr,
            input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_to_input_weights),
            input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_to_forget_weights),
            input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_to_cell_weights),
            input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_to_output_weights),
            aux_input_ptr,
            aux_input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      aux_input_to_input_weights),
            aux_input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      aux_input_to_forget_weights),
            aux_input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      aux_input_to_cell_weights),
            aux_input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      aux_input_to_output_weights),
            recurrent_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      recurrent_to_input_weights),
            recurrent_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      recurrent_to_forget_weights),
            recurrent_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      recurrent_to_cell_weights),
            recurrent_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      recurrent_to_output_weights),
            cell_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_to_input_weights),
            cell_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
            cell_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_to_output_weights),
            input_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      input_layer_norm_coefficients),
            forget_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      forget_layer_norm_coefficients),
            cell_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      cell_layer_norm_coefficients),
            output_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      output_layer_norm_coefficients),
            input_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_gate_bias),
            forget_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(forget_gate_bias),
            cell_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_gate_bias),
            output_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(output_gate_bias),
            projection_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(projection_weights),
            projection_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(projection_bias),
            params,
            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
            output_batch_leading_dim, output_state_ptr, cell_state_ptr,
            input_gate_scratch_ptr, forget_gate_scratch_ptr,
            cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
  return kTfLiteOk;
}
  1923. TfLiteStatus EvalHybridLstm(
  1924. const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
  1925. const TfLiteEvalTensor* input_to_input_weights,
  1926. const TfLiteEvalTensor* input_to_input_weights_ledger,
  1927. const TfLiteEvalTensor* input_to_forget_weights,
  1928. const TfLiteEvalTensor* input_to_forget_weights_ledger,
  1929. const TfLiteEvalTensor* input_to_cell_weights,
  1930. const TfLiteEvalTensor* input_to_cell_weights_ledger,
  1931. const TfLiteEvalTensor* input_to_output_weights,
  1932. const TfLiteEvalTensor* input_to_output_weights_ledger,
  1933. const TfLiteEvalTensor* recurrent_to_input_weights,
  1934. const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
  1935. const TfLiteEvalTensor* recurrent_to_forget_weights,
  1936. const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
  1937. const TfLiteEvalTensor* recurrent_to_cell_weights,
  1938. const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
  1939. const TfLiteEvalTensor* recurrent_to_output_weights,
  1940. const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
  1941. const TfLiteEvalTensor* cell_to_input_weights,
  1942. const TfLiteEvalTensor* cell_to_forget_weights,
  1943. const TfLiteEvalTensor* cell_to_output_weights,
  1944. const TfLiteEvalTensor* input_layer_norm_coefficients,
  1945. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  1946. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  1947. const TfLiteEvalTensor* output_layer_norm_coefficients,
  1948. const TfLiteEvalTensor* aux_input,
  1949. const TfLiteEvalTensor* aux_input_to_input_weights,
  1950. const TfLiteEvalTensor* aux_input_to_forget_weights,
  1951. const TfLiteEvalTensor* aux_input_to_cell_weights,
  1952. const TfLiteEvalTensor* aux_input_to_output_weights,
  1953. const TfLiteEvalTensor* input_gate_bias,
  1954. const TfLiteEvalTensor* forget_gate_bias,
  1955. const TfLiteEvalTensor* cell_gate_bias,
  1956. const TfLiteEvalTensor* output_gate_bias,
  1957. const TfLiteEvalTensor* projection_weights,
  1958. const TfLiteEvalTensor* projection_weights_ledger,
  1959. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  1960. bool forward_sequence, bool time_major, int output_offset,
  1961. float* scratch_buffer, float* input_sf, float* aux_input_sf,
  1962. float* output_state_sf, float* prod_scaling_factors,
  1963. float* recovered_cell_weights, int8_t* input_quantized,
  1964. int8_t* aux_input_quantized, int8_t* output_state_quantized,
  1965. int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
  1966. TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
  1967. TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
  1968. int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
  1969. bool* compute_row_sums) {
  1970. TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  1971. const int n_input = input->dims->data[input->dims->size - 1];
  1972. int max_time, n_batch;
  1973. if (input->dims->size == 2) {
  1974. max_time = 1;
  1975. n_batch = input->dims->data[0];
  1976. } else {
  1977. max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
  1978. n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  1979. }
  1980. const int aux_input_size =
  1981. (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
  1982. // n_cell and n_output will be the same size when there is no projection.
  1983. const int n_cell = input_to_output_weights->dims->data[0];
  1984. const int n_output = recurrent_to_output_weights->dims->data[1];
  1985. // Since we have already checked that weights are all there or none, we can
  1986. // check the existence of only one to get the condition.
  1987. const bool use_cifg = (input_to_input_weights == nullptr);
  1988. float* input_gate_scratch = nullptr;
  1989. float* cell_gate_scratch = nullptr;
  1990. float* forget_gate_scratch = nullptr;
  1991. float* output_gate_scratch = nullptr;
  1992. if (use_cifg) {
  1993. cell_gate_scratch = scratch_buffer;
  1994. forget_gate_scratch = scratch_buffer + n_cell * n_batch;
  1995. output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  1996. } else {
  1997. input_gate_scratch = scratch_buffer;
  1998. cell_gate_scratch = scratch_buffer + n_cell * n_batch;
  1999. forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  2000. output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
  2001. }
  2002. const int output_batch_leading_dim =
  2003. output->dims->data[output->dims->size - 1];
  2004. int32_t* input_zp_ptr = nullptr;
  2005. int32_t* aux_input_zp_ptr = nullptr;
  2006. int32_t* output_state_zp_ptr = nullptr;
  2007. int32_t* row_sums_ptr = nullptr;
  2008. if (params->asymmetric_quantize_inputs) {
  2009. input_zp_ptr = input_zp;
  2010. aux_input_zp_ptr = aux_input_zp;
  2011. output_state_zp_ptr = output_state_zp;
  2012. row_sums_ptr = row_sums;
  2013. }
  2014. if (time_major) {
  2015. // Feed the sequence into the LSTM step-by-step.
  2016. const int input_step = n_batch * n_input;
  2017. const int output_step = n_batch * output_batch_leading_dim;
  2018. for (int t = 0; t < max_time; t++) {
  2019. // If this is the forward_sequence, step forward, otherwise step
  2020. // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr =
          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr =
            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                          t_rel * output_step + output_offset;
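      // Run one hybrid LSTM step over the whole batch at time t_rel. Each
      // optional weight tensor (and its ledger) is forwarded as a raw
      // pointer, with nullptr standing in for an absent tensor.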
      LstmStepHybrid(
          input_ptr,
          input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
          input_to_input_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_input_weights_ledger),
          hybrid_lstm_scales->input_to_input_weights_scale,
          input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
          input_to_forget_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_forget_weights_ledger),
          hybrid_lstm_scales->input_to_forget_weights_scale,
          input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
          input_to_cell_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_cell_weights_ledger),
          hybrid_lstm_scales->input_to_cell_weights_scale,
          input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
          input_to_output_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_output_weights_ledger),
          hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
          aux_input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_input_weights),
          hybrid_lstm_scales->aux_input_to_input_weights_scale,
          aux_input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_forget_weights),
          hybrid_lstm_scales->aux_input_to_forget_weights_scale,
          aux_input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(aux_input_to_cell_weights),
          hybrid_lstm_scales->aux_input_to_cell_weights_scale,
          aux_input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_output_weights),
          hybrid_lstm_scales->aux_input_to_output_weights_scale,
          recurrent_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_input_weights),
          recurrent_to_input_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_input_weights_ledger),
          hybrid_lstm_scales->recurrent_to_input_weights_scale,
          recurrent_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_forget_weights),
          recurrent_to_forget_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_forget_weights_ledger),
          hybrid_lstm_scales->recurrent_to_forget_weights_scale,
          recurrent_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
          recurrent_to_cell_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_cell_weights_ledger),
          hybrid_lstm_scales->recurrent_to_cell_weights_scale,
          recurrent_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_output_weights),
          recurrent_to_output_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_output_weights_ledger),
          hybrid_lstm_scales->recurrent_to_output_weights_scale,
          cell_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
          hybrid_lstm_scales->cell_to_input_weights_scale,
          cell_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
          hybrid_lstm_scales->cell_to_forget_weights_scale,
          cell_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
          hybrid_lstm_scales->cell_to_output_weights_scale,
          input_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    input_layer_norm_coefficients),
          forget_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    forget_layer_norm_coefficients),
          cell_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    cell_layer_norm_coefficients),
          output_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    output_layer_norm_coefficients),
          input_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_gate_bias),
          forget_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(forget_gate_bias),
          cell_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_gate_bias),
          output_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(output_gate_bias),
          projection_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(projection_weights),
          projection_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    projection_weights_ledger),
          hybrid_lstm_scales->projection_weights_scale,
          projection_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(projection_bias),
          params, n_batch, n_cell, n_input, aux_input_size, n_output,
          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
          cell_gate_scratch, output_gate_scratch, scales, input_sf,
          aux_input_sf, output_state_sf, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          output_state_quantized, cell_state_quantized,
          tflite::micro::GetTensorData<float>(output_state),
          tflite::micro::GetTensorData<float>(cell_state),
          output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
          output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
          params->asymmetric_quantize_inputs);
    }
  } else {
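    // Batch-major input: loop over batches on the outside so the state and
    // scratch pointers can be offset per batch entry, then step through time
    // for that entry with n_batch == 1.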
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // For a forward sequence, step forward through time; otherwise step
        // backward through it.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
                                 time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
                          time_offset * input_step;
        }
        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;
        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            tflite::micro::GetTensorData<float>(output_state) +
            b * output_batch_leading_dim;
        float* cell_state_ptr =
            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
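        // Run one hybrid LSTM step for this single batch entry at time t_rel.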
        LstmStepHybrid(
            input_ptr,
            input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
            input_to_input_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_input_weights_ledger),
            hybrid_lstm_scales->input_to_input_weights_scale,
            input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
            input_to_forget_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_forget_weights_ledger),
            hybrid_lstm_scales->input_to_forget_weights_scale,
            input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
            input_to_cell_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_cell_weights_ledger),
            hybrid_lstm_scales->input_to_cell_weights_scale,
            input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
            input_to_output_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_output_weights_ledger),
            hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
            aux_input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_input_weights),
            hybrid_lstm_scales->aux_input_to_input_weights_scale,
            aux_input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_forget_weights),
            hybrid_lstm_scales->aux_input_to_forget_weights_scale,
            aux_input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_cell_weights),
            hybrid_lstm_scales->aux_input_to_cell_weights_scale,
            aux_input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_output_weights),
            hybrid_lstm_scales->aux_input_to_output_weights_scale,
            recurrent_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_input_weights),
            recurrent_to_input_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_input_weights_ledger),
            hybrid_lstm_scales->recurrent_to_input_weights_scale,
            recurrent_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_forget_weights),
            recurrent_to_forget_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_forget_weights_ledger),
            hybrid_lstm_scales->recurrent_to_forget_weights_scale,
            recurrent_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_cell_weights),
            recurrent_to_cell_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_cell_weights_ledger),
            hybrid_lstm_scales->recurrent_to_cell_weights_scale,
            recurrent_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_output_weights),
            recurrent_to_output_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_output_weights_ledger),
            hybrid_lstm_scales->recurrent_to_output_weights_scale,
            cell_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
            hybrid_lstm_scales->cell_to_input_weights_scale,
            cell_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
            hybrid_lstm_scales->cell_to_forget_weights_scale,
            cell_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
            hybrid_lstm_scales->cell_to_output_weights_scale,
            input_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      input_layer_norm_coefficients),
            forget_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      forget_layer_norm_coefficients),
            cell_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      cell_layer_norm_coefficients),
            output_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      output_layer_norm_coefficients),
            input_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_gate_bias),
            forget_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(forget_gate_bias),
            cell_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_gate_bias),
            output_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(output_gate_bias),
            projection_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(projection_weights),
            projection_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      projection_weights_ledger),
            hybrid_lstm_scales->projection_weights_scale,
            projection_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(projection_bias),
            params,
            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
            output_batch_leading_dim, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, scales, input_sf, aux_input_sf,
            output_state_sf, prod_scaling_factors, recovered_cell_weights,
            input_quantized, aux_input_quantized, output_state_quantized,
            cell_state_quantized, output_state_ptr, cell_state_ptr,
            output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
            output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
            params->asymmetric_quantize_inputs);
      }
    }
  }
  return kTfLiteOk;
}
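
// Fully integer LSTM evaluation: int8 activations and weights with an int16
// cell state (the 8x8_16 kernel). Effective scales, effective biases, and
// clipping thresholds are precomputed in integer_lstm_param.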
TfLiteStatus EvalInteger8x8_16Lstm(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major,
    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  }
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];
  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
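  // Note: the time-major path below always steps forward in time
  // (t_rel == t); forward_sequence is honored only in the batch-major path.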
  if (time_major) {
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      const int t_rel = t;
      int8_t* output_ptr =
          tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
      const int8_t* input_ptr =
          tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
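      // Run one integer LSTM step over the whole batch at time t_rel.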
      LstmStepInteger8x8_16(
          input_ptr,
          input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
          integer_lstm_param->effective_input_to_input_scale_a,
          integer_lstm_param->effective_input_to_input_scale_b,
          input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
          integer_lstm_param->effective_input_to_forget_scale_a,
          integer_lstm_param->effective_input_to_forget_scale_b,
          input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
          integer_lstm_param->effective_input_to_cell_scale_a,
          integer_lstm_param->effective_input_to_cell_scale_b,
          input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
          integer_lstm_param->effective_input_to_output_scale_a,
          integer_lstm_param->effective_input_to_output_scale_b,
          recurrent_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_input_weights),
          integer_lstm_param->effective_recurrent_to_input_scale_a,
          integer_lstm_param->effective_recurrent_to_input_scale_b,
          recurrent_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_forget_weights),
          integer_lstm_param->effective_recurrent_to_forget_scale_a,
          integer_lstm_param->effective_recurrent_to_forget_scale_b,
          recurrent_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
          integer_lstm_param->effective_recurrent_to_cell_scale_a,
          integer_lstm_param->effective_recurrent_to_cell_scale_b,
          recurrent_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_output_weights),
          integer_lstm_param->effective_recurrent_to_output_scale_a,
          integer_lstm_param->effective_recurrent_to_output_scale_b,
          cell_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
          integer_lstm_param->effective_cell_to_input_scale_a,
          integer_lstm_param->effective_cell_to_input_scale_b,
          cell_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
          integer_lstm_param->effective_cell_to_forget_scale_a,
          integer_lstm_param->effective_cell_to_forget_scale_b,
          cell_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
          integer_lstm_param->effective_cell_to_output_scale_a,
          integer_lstm_param->effective_cell_to_output_scale_b,
          projection_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(projection_weights),
          integer_lstm_param->effective_proj_scale_a,
          integer_lstm_param->effective_proj_scale_b,
          integer_lstm_param->hidden_zp,
          integer_lstm_param->effective_hidden_scale_a,
          integer_lstm_param->effective_hidden_scale_b,
          input_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    input_layer_norm_coefficients),
          integer_lstm_param->layer_norm_input_scale_a,
          integer_lstm_param->layer_norm_input_scale_b,
          forget_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    forget_layer_norm_coefficients),
          integer_lstm_param->layer_norm_forget_scale_a,
          integer_lstm_param->layer_norm_forget_scale_b,
          cell_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    cell_layer_norm_coefficients),
          integer_lstm_param->layer_norm_cell_scale_a,
          integer_lstm_param->layer_norm_cell_scale_b,
          output_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    output_layer_norm_coefficients),
          integer_lstm_param->layer_norm_output_scale_a,
          integer_lstm_param->layer_norm_output_scale_b,
          input_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
          forget_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
          cell_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
          output_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
          integer_lstm_param->quantized_cell_clip,
          integer_lstm_param->quantized_proj_clip,
          integer_lstm_param->cell_scale,
          integer_lstm_param->input_variance_guard,
          integer_lstm_param->forget_variance_guard,
          integer_lstm_param->cell_variance_guard,
          integer_lstm_param->output_variance_guard,
          integer_lstm_param->input_to_forget_effective_bias,
          integer_lstm_param->recurrent_to_forget_effective_bias,
          integer_lstm_param->input_to_cell_effective_bias,
          integer_lstm_param->recurrent_to_cell_effective_bias,
          integer_lstm_param->input_to_output_effective_bias,
          integer_lstm_param->recurrent_to_output_effective_bias,
          integer_lstm_param->input_to_input_effective_bias,
          integer_lstm_param->recurrent_to_input_effective_bias,
          integer_lstm_param->projection_effective_bias, n_batch, n_cell,
          n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
          output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
          output_ptr, scratch0, scratch1, scratch2, scratch3, scratch4,
          scratch5);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // For a forward sequence, step forward through time; otherwise step
        // backward through it.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
                                  time_offset * input_step;
        int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
                             time_offset * output_step;
        // Offset the {output,cell}_state pointers to the right batch.
        int8_t* output_state_ptr =
            tflite::micro::GetTensorData<int8_t>(output_state) +
            b * output_batch_leading_dim;
        int16_t* cell_state_ptr =
            tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
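        // Run one integer LSTM step for this single batch entry at time t_rel.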
        LstmStepInteger8x8_16(
            input_ptr,
            input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
            integer_lstm_param->effective_input_to_input_scale_a,
            integer_lstm_param->effective_input_to_input_scale_b,
            input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
            integer_lstm_param->effective_input_to_forget_scale_a,
            integer_lstm_param->effective_input_to_forget_scale_b,
            input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
            integer_lstm_param->effective_input_to_cell_scale_a,
            integer_lstm_param->effective_input_to_cell_scale_b,
            input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
            integer_lstm_param->effective_input_to_output_scale_a,
            integer_lstm_param->effective_input_to_output_scale_b,
            recurrent_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_input_weights),
            integer_lstm_param->effective_recurrent_to_input_scale_a,
            integer_lstm_param->effective_recurrent_to_input_scale_b,
            recurrent_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_forget_weights),
            integer_lstm_param->effective_recurrent_to_forget_scale_a,
            integer_lstm_param->effective_recurrent_to_forget_scale_b,
            recurrent_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_cell_weights),
            integer_lstm_param->effective_recurrent_to_cell_scale_a,
            integer_lstm_param->effective_recurrent_to_cell_scale_b,
            recurrent_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_output_weights),
            integer_lstm_param->effective_recurrent_to_output_scale_a,
            integer_lstm_param->effective_recurrent_to_output_scale_b,
            cell_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
            integer_lstm_param->effective_cell_to_input_scale_a,
            integer_lstm_param->effective_cell_to_input_scale_b,
            cell_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
            integer_lstm_param->effective_cell_to_forget_scale_a,
            integer_lstm_param->effective_cell_to_forget_scale_b,
            cell_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
            integer_lstm_param->effective_cell_to_output_scale_a,
            integer_lstm_param->effective_cell_to_output_scale_b,
            projection_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(projection_weights),
            integer_lstm_param->effective_proj_scale_a,
            integer_lstm_param->effective_proj_scale_b,
            integer_lstm_param->hidden_zp,
            integer_lstm_param->effective_hidden_scale_a,
            integer_lstm_param->effective_hidden_scale_b,
            input_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      input_layer_norm_coefficients),
            integer_lstm_param->layer_norm_input_scale_a,
            integer_lstm_param->layer_norm_input_scale_b,
            forget_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      forget_layer_norm_coefficients),
            integer_lstm_param->layer_norm_forget_scale_a,
            integer_lstm_param->layer_norm_forget_scale_b,
            cell_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      cell_layer_norm_coefficients),
            integer_lstm_param->layer_norm_cell_scale_a,
            integer_lstm_param->layer_norm_cell_scale_b,
            output_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      output_layer_norm_coefficients),
            integer_lstm_param->layer_norm_output_scale_a,
            integer_lstm_param->layer_norm_output_scale_b,
            input_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
            forget_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
            cell_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
            output_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
            integer_lstm_param->quantized_cell_clip,
            integer_lstm_param->quantized_proj_clip,
            integer_lstm_param->cell_scale,
            integer_lstm_param->input_variance_guard,
            integer_lstm_param->forget_variance_guard,
            integer_lstm_param->cell_variance_guard,
            integer_lstm_param->output_variance_guard,
            integer_lstm_param->input_to_forget_effective_bias,
            integer_lstm_param->recurrent_to_forget_effective_bias,
            integer_lstm_param->input_to_cell_effective_bias,
            integer_lstm_param->recurrent_to_cell_effective_bias,
            integer_lstm_param->input_to_output_effective_bias,
            integer_lstm_param->recurrent_to_output_effective_bias,
            integer_lstm_param->input_to_input_effective_bias,
            integer_lstm_param->recurrent_to_input_effective_bias,
            integer_lstm_param->projection_effective_bias, /*n_batch=*/1,
            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
            cell_state_ptr, output_ptr, scratch0, scratch1, scratch2, scratch3,
            scratch4, scratch5);
      }
    }
  }
  return kTfLiteOk;
}
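
// Fully integer LSTM evaluation with int8 gate intermediates (the 8x8_8
// kernel). This path processes the input as time-major [time, batch, input]
// and always steps forward; there is no batch-major loop.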
TfLiteStatus EvalInteger8x8_8Lstm(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
    int32_t input_zp, int32_t output_state_zp, int8_t* scratch0,
    int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4,
    int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = input->dims->data[0];
    n_batch = input->dims->data[1];
  }
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];
  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  const int input_step = n_batch * n_input;
  const int output_step = n_batch * output_batch_leading_dim;
  for (int t = 0; t < max_time; t++) {
    const int t_rel = t;
    int8_t* output_ptr =
        tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
    // Input can be int8 asymmetric or int16 symmetric.
    const int8_t* input_ptr =
        tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
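    // Run one integer LSTM step over the whole batch at time t_rel; input_zp
    // supplies the input zero point for asymmetric quantization.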
    LstmStepInteger8x8_8(
        input_ptr, input_zp,
        input_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
        integer_lstm_param->effective_input_to_input_scale_a,
        integer_lstm_param->effective_input_to_input_scale_b,
        input_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
        integer_lstm_param->effective_input_to_forget_scale_a,
        integer_lstm_param->effective_input_to_forget_scale_b,
        input_to_cell_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
        integer_lstm_param->effective_input_to_cell_scale_a,
        integer_lstm_param->effective_input_to_cell_scale_b,
        input_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
        integer_lstm_param->effective_input_to_output_scale_a,
        integer_lstm_param->effective_input_to_output_scale_b,
        recurrent_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
        integer_lstm_param->effective_recurrent_to_input_scale_a,
        integer_lstm_param->effective_recurrent_to_input_scale_b,
        recurrent_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
        integer_lstm_param->effective_recurrent_to_forget_scale_a,
        integer_lstm_param->effective_recurrent_to_forget_scale_b,
        recurrent_to_cell_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
        integer_lstm_param->effective_recurrent_to_cell_scale_a,
        integer_lstm_param->effective_recurrent_to_cell_scale_b,
        recurrent_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
        integer_lstm_param->effective_recurrent_to_output_scale_a,
        integer_lstm_param->effective_recurrent_to_output_scale_b,
        cell_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
        integer_lstm_param->effective_cell_to_input_scale_a,
        integer_lstm_param->effective_cell_to_input_scale_b,
        cell_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
        integer_lstm_param->effective_cell_to_forget_scale_a,
        integer_lstm_param->effective_cell_to_forget_scale_b,
        cell_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
        integer_lstm_param->effective_cell_to_output_scale_a,
        integer_lstm_param->effective_cell_to_output_scale_b,
        projection_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(projection_weights),
        integer_lstm_param->effective_proj_scale_a,
        integer_lstm_param->effective_proj_scale_b,
        input_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  input_layer_norm_coefficients),
        integer_lstm_param->layer_norm_input_scale_a,
        integer_lstm_param->layer_norm_input_scale_b,
        forget_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  forget_layer_norm_coefficients),
        integer_lstm_param->layer_norm_forget_scale_a,
        integer_lstm_param->layer_norm_forget_scale_b,
        cell_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  cell_layer_norm_coefficients),
        integer_lstm_param->layer_norm_cell_scale_a,
        integer_lstm_param->layer_norm_cell_scale_b,
        output_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  output_layer_norm_coefficients),
        integer_lstm_param->layer_norm_output_scale_a,
        integer_lstm_param->layer_norm_output_scale_b,
        input_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
        forget_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
        cell_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
        output_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
        projection_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(projection_bias),
        params, integer_lstm_param->intermediate_scale_a,
        integer_lstm_param->intermediate_scale_b,
        integer_lstm_param->intermediate_zp,
        integer_lstm_param->quantized_cell_clip,
        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
        n_output, output_batch_leading_dim,
        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr, scratch0,
        scratch1, scratch2, scratch3, scratch4, scratch5, scratch6, scratch7);
  }
  return kTfLiteOk;
}

}  // namespace tflite