/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/lstm_eval.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"

namespace tflite {
namespace {

void ComputeRowSums(
    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
    int n_input, int n_aux_input, int n_output,
    const int8_t* input_to_input_weights_ptr,
    const int8_t* input_to_forget_weights_ptr,
    const int8_t* input_to_cell_weights_ptr,
    const int8_t* input_to_output_weights_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    const int8_t* aux_input_to_forget_weights_ptr,
    const int8_t* aux_input_to_cell_weights_ptr,
    const int8_t* aux_input_to_output_weights_ptr,
    const int8_t* recurrent_to_input_weights_ptr,
    const int8_t* recurrent_to_forget_weights_ptr,
    const int8_t* recurrent_to_cell_weights_ptr,
    const int8_t* recurrent_to_output_weights_ptr,
    const int8_t* projection_weights_ptr, bool use_cifg,
    const float* aux_input_ptr) {
  // Compute the row sums for dequantization.
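  // The row sums feed the usual zero-point correction for asymmetrically
  // quantized inputs in the hybrid matmuls: for row i of an int8 weight
  // matrix w and a quantized input q with zero point zp,
  //   sum_j w[i][j] * (q[j] - zp) = sum_j w[i][j] * q[j] - zp * row_sum[i]
  // so caching row_sum[i] = sum_j w[i][j] turns the correction into a single
  // multiply per row on every step. (A sketch; the exact application lives
  // inside MatrixBatchVectorMultiplyAccumulate.)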
  if (!use_cifg) {
    micro_tensor_utils::ReductionSumVector(
        input_to_input_weights_ptr, input_to_input_row_sums, n_cell, n_input);
  }
  micro_tensor_utils::ReductionSumVector(
      input_to_forget_weights_ptr, input_to_forget_row_sums, n_cell, n_input);
  micro_tensor_utils::ReductionSumVector(
      input_to_cell_weights_ptr, input_to_cell_row_sums, n_cell, n_input);
  micro_tensor_utils::ReductionSumVector(
      input_to_output_weights_ptr, input_to_output_row_sums, n_cell, n_input);
  if (aux_input_ptr) {
    if (!use_cifg) {
      micro_tensor_utils::ReductionSumVector(
          aux_input_to_input_weights_ptr, aux_input_to_input_row_sums, n_cell,
          n_aux_input);
    }
    micro_tensor_utils::ReductionSumVector(
        aux_input_to_forget_weights_ptr, aux_input_to_forget_row_sums, n_cell,
        n_aux_input);
    micro_tensor_utils::ReductionSumVector(
        aux_input_to_cell_weights_ptr, aux_input_to_cell_row_sums, n_cell,
        n_aux_input);
    micro_tensor_utils::ReductionSumVector(
        aux_input_to_output_weights_ptr, aux_input_to_output_row_sums, n_cell,
        n_aux_input);
  }
  if (!use_cifg) {
    micro_tensor_utils::ReductionSumVector(
        recurrent_to_input_weights_ptr, recurrent_to_input_row_sums, n_cell,
        n_output);
  }
  micro_tensor_utils::ReductionSumVector(
      recurrent_to_forget_weights_ptr, recurrent_to_forget_row_sums, n_cell,
      n_output);
  micro_tensor_utils::ReductionSumVector(
      recurrent_to_cell_weights_ptr, recurrent_to_cell_row_sums, n_cell,
      n_output);
  micro_tensor_utils::ReductionSumVector(
      recurrent_to_output_weights_ptr, recurrent_to_output_row_sums, n_cell,
      n_output);
  if (projection_weights_ptr != nullptr) {
    micro_tensor_utils::ReductionSumVector(
        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
  }
}
// Calculates a single LSTM gate.
//
// Implements the following formula: (* is matrix multiply)
//   gate = activate(W_input    * input      +
//                   W_aux      * aux_input  +
//                   W_peephole * cell       +
//                   W_recurrent * prev_output + bias)
// with layer norm:
//   gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
//
// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
//
// Parameters:
// Input vectors (to LSTM):    | Size:                | Optional?
//   input                     | n_input              |
//   aux_input                 | n_aux_input          | y (bidir LSTM)
// Input vectors (persistent states):
//   output_state              | n_output             |
//   cell_state                | n_cell               |
// 'Constant' inputs:
//   input_to_gate_weights     | n_cell * n_input     |
//   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
//   recurrent_to_gate_weights | n_cell * n_output    |
//   cell_to_gate_weights      | n_cell               | y (peephole)
//   gate_bias                 | n_cell               |
//   layer_norm_coefficients   | n_cell               | y (layer norm)
// Output vector:
//   gate                      | n_cell               |
// Scalar parameters:
//   n_batch - batch size / number of vectors
//   n_input, n_aux_input, n_output, n_cell - size of vectors.
//   activation - activation to use.
//   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
//   use_layer_norm - if doing layer norm LSTM.
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    memset(gate, 0, n_cell * n_batch * sizeof(float));
  } else {
    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
                                                gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, n_batch,
        gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch,
      gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM).
  if (use_layer_norm) {
    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    micro_tensor_utils::VectorBatchVectorCwiseProduct(
        layer_norm_coefficients, n_cell, gate, n_batch, gate);
    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation.
  micro_tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell,
                                              activation, gate);
}
// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
//
// Implements the following formula:
//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
//
// With CIFG LSTM, the input gate is replaced by (1 - forget_gate).
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use (1 - forget_gate) instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  micro_tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                               n_batch * n_cell, cell_state);
  if (use_cifg) {
    // With CIFG, input_gate = 1 - forget_gate. Use the forget_gate array as
    // scratch, as the input_gate array is not allocated in this case. (Be
    // careful not to write to the scratch before reading the forget gate
    // data.)
    float* scratch = forget_gate;
    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
// Calculates the output state tensor of an LSTM step.
//
// Implements the following formula:
//   output_no_projection = output_gate .* activate(cell_state)
//     (elementwise vector product)
// If no projection is used:
//   output = output_state = output_no_projection
// With projection:
//   output = output_state = clip(W*output_no_projection + bias)
//
// The final output may use a row stride (output_batch_leading_dim) different
// from n_output, so the caller copies output_state into it row by row.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area, size n_batch*n_cell.
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
                              const float* cell_state, const float* output_gate,
                              TfLiteFusedActivation activation,
                              const float* projection_weights,
                              const float* projection_bias,
                              const float proj_clip, float* output_state,
                              float* scratch) {
  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                              activation, scratch);
  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch,
                                               n_batch * n_cell, scratch);
  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);
  if (use_projection) {
    if (use_projection_bias) {
      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
                                                  n_batch, output_state);
    } else {
      memset(output_state, 0, n_batch * n_output * sizeof(float));
    }
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch, n_batch * n_output * sizeof(float));
  }
}
// Calculates a single LSTM gate, hybrid version.
// Implements the same functionality as CalculateLstmGateFloat.
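// "Hybrid" here means float activations with int8 weights: the caller
// quantizes the float inputs on the fly, per batch row, and passes per-row
// scaling factors (*_sf) and, when asymmetric quantization is enabled,
// per-row zero points (*_zp). The row sums precomputed by ComputeRowSums
// feed the zero-point correction sketched above.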
void CalculateLstmGateHybrid(
    // Input and weights
    const int8_t* input, const float* input_sf, const int32_t* input_zp,
    const int8_t* input_to_gate_weights,
    const uint8_t* input_to_gate_weights_ledger,
    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
    // Aux input and weights
    const int8_t* aux_input, const float* aux_input_sf,
    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
    const float aux_input_to_gate_weights_scale,
    int32_t* aux_input_to_gate_row_sums,
    // Output state and weights
    const int8_t* output_state, const float* output_state_sf,
    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
    const uint8_t* recurrent_to_gate_weights_ledger,
    const float recurrent_to_gate_weights_scale,
    int32_t* recurrent_to_gate_row_sums,
    // Cell state and weights (peephole LSTM)
    const float* cell_state, const int8_t* cell_to_gate_weights,
    const float cell_to_gate_weights_scale,
    // Layer normalization coefficients (layer norm LSTM) + gate bias
    const float* layer_norm_coefficients, const float* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    float* gate,
    // Parameters for performance optimizations
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    const bool is_output_state_all_zeros, bool* compute_row_sums,
    // Scratch arrays
    float* scratch0,        // size: n_batch
    float* scratch1,        // size: n_cell, only used if peephole LSTM
    float* scales,          // size: n_batch
    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    memset(gate, 0, n_cell * n_batch * sizeof(float));
  } else {
    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
                                                gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    if (input_to_gate_weights_ledger != nullptr) {
      for (int i = 0; i < n_batch; i++) {
        scales[i] = input_to_gate_weights_scale * input_sf[i];
      }
      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
          input, scales, n_batch, gate);
    } else {
      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, n_cell, n_input, input,
          input_to_gate_weights_scale, input_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
          input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
    }
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
        aux_input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  // Skip if output state is all zeros.
  if (!is_output_state_all_zeros) {
    if (recurrent_to_gate_weights_ledger != nullptr) {
      for (int i = 0; i < n_batch; i++) {
        scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
      }
      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
          n_output, output_state, scales, n_batch, gate);
    } else {
      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, n_cell, n_output, output_state,
          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
          recurrent_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
    }
  }
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    float* recovered_cell_weights = scratch1;
    micro_tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
                                             cell_to_gate_weights_scale,
                                             recovered_cell_weights);
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM).
  if (use_layer_norm) {
    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    micro_tensor_utils::VectorBatchVectorCwiseProduct(
        layer_norm_coefficients, n_cell, gate, n_batch, gate);
    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation.
  micro_tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch,
                                              activation, gate);
}
// Calculates the output state tensor of an LSTM step. See Float version too.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - asymmetric_quantize_inputs: parameter to control quantization.
//  - projection_weights_row_sums, compute_row_sums: data for optimized
//    MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area of size n_batch
//  - scratch3: scratch area of size n_batch
//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
//  - scales: scratch area of size n_batch
void CalculateLstmOutputHybrid(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
    float projection_weights_scale, const float* projection_bias,
    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
    int32_t* projection_weights_row_sums, bool* compute_row_sums,
    float* scratch0, int8_t* scratch1, float* scratch2, int32_t* scratch3,
    int32_t* scratch4, float* scales) {
  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                              activation, scratch0);
  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
                                               n_batch * n_cell, scratch0);
  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);
  if (use_projection) {
    if (use_projection_bias) {
      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
                                                  n_batch, output_state);
    } else {
      memset(output_state, 0, n_batch * n_output * sizeof(float));
    }
    if (!micro_tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero output.
      micro_tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell,
                                              scratch1, scratch2, scratch3,
                                              asymmetric_quantize_inputs);
      if (projection_weights_ledger != nullptr) {
        for (int i = 0; i < n_batch; i++) {
          scales[i] = projection_weights_scale * scratch2[i];
        }
        micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
            projection_weights, projection_weights_ledger, n_output, n_cell,
            scratch1, scales, n_batch, output_state);
      } else {
        micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            projection_weights, n_output, n_cell, scratch1,
            projection_weights_scale, scratch2, n_batch, output_state,
            /*per_channel_scale=*/nullptr, scratch3, scratch4,
            projection_weights_row_sums, compute_row_sums, scratch2, nullptr);
      }
    }
    if (proj_clip > 0.0f) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch0, n_batch * n_output * sizeof(float));
  }
}
// Calculates a single LSTM gate, int8x8_16 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_16(
    // Input and weights
    const int8_t* input, const int8_t* input_to_gate_weights,
    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
    const int32_t input_to_gate_scale_b,
    // Output state and weights
    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
    const int32_t* recurrent_to_gate_bias,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    // Cell state and weights
    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
    const int32_t layer_norm_input_scale_a,
    const int32_t layer_norm_input_scale_b,
    const int32_t layer_norm_variance_guard,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Scratch arrays
    int32_t* scratch5) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
  // Initialize scratch buffers with zeros. Note that unlike the float and
  // hybrid versions, the bias is only used in layer normalization.
  memset(gate, 0, n_batch * n_cell * sizeof(int16_t));
  // For each batch and cell: compute input_weight * input.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
      nullptr);
  // Note: no aux_input.
  // For each batch and cell: compute recurrent_weight * output_state.
  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, 0, scratch5, gate, nullptr);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_output, cell_state, n_batch,
        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
  }
  // Do layer normalization (if layer norm LSTM).
  if (use_layer_norm) {
    micro_tensor_utils::ApplyLayerNorm(
        gate, layer_norm_coefficients, layer_norm_bias,
        layer_norm_input_scale_a, layer_norm_input_scale_b,
        layer_norm_variance_guard, n_batch, n_cell, gate);
  }
  // Apply activation.
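  // (Fixed-point formats, assuming the usual TFLM layout: the accumulated
  // pre-activation values in `gate` are Q3.12, and both reference kernels
  // below write back Q0.15 results; the zero multiplier/shift arguments
  // select the power-of-two fast path with no extra input rescaling.)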
  switch (activation) {
    case kTfLiteActSigmoid:
      reference_integer_ops::Logistic(
          0 /*data->input_multiplier*/, 0 /*data->input_left_shift*/,
          n_batch * n_cell /*NumElements(input->dims)*/,
          gate /*tflite::micro::GetTensorData<int16_t>(input)*/,
          gate /*tflite::micro::GetTensorData<int16_t>(output)*/);
      break;
    case kTfLiteActTanh: {
      int32_t dims_data = n_batch * n_cell;
      RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
      reference_integer_ops::Tanh(0, 0, tanh_inp_shape, gate, tanh_inp_shape,
                                  gate);
    } break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}
// Updates the LSTM cell state, used by both integer LSTM versions.
// Also see UpdateLstmCellFloat.
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - cell_state_scale: scaling factor of cell state.
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use (1 - forget_gate) instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
                           int32_t cell_state_scale, const int16_t* input_gate,
                           int16_t* forget_gate, const int16_t* cell_gate,
                           bool use_cifg, int16_t clip) {
  // Use the forget_gate array as scratch, as the input_gate array is not
  // allocated in the CIFG case. (Be careful not to write to the scratch
  // before reading the forget gate data.)
  int16_t* scratch = forget_gate;
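  // Shift arithmetic, assuming gates in Q0.15 and a cell state stored with
  // scale 2^cell_state_scale (cell_state_scale is typically negative, e.g.
  // -11 for Q4.11): forget_gate * cell_state keeps the cell state's scale
  // once the gate's 15 fractional bits are shifted out, while the Q0.15 x
  // Q0.15 gate product below carries scale 2^-30 and needs a shift of
  // 30 + cell_state_scale to land back on the cell state's scale.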
  micro_tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
                               cell_state);
  if (use_cifg) {
    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    micro_tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
                                 30 + cell_state_scale, scratch);
  } else {
    micro_tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
                                 30 + cell_state_scale, scratch);
  }
  micro_tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell,
                               cell_state);
  if (clip > 0) {
    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - cell_state_scale: scaling of cell_state.
//  - hidden_scale_[a|b]: effective scale of cell_state .* output_gate
//  - hidden_zp: zero_point for cell_state .* output_gate
//  - projection_weights, proj_scale_[a|b], projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
    int n_batch, int n_cell, int n_output, int16_t* cell_state,
    int32_t cell_state_scale, const int16_t* output_gate,
    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
    const int8_t* projection_weights, int32_t proj_scale_a,
    int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  {
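    // Shift derivation, assuming the cell state is stored with scale
    // 2^cell_state_scale and the reference Tanh expects Q3.12 input: a left
    // shift of (15 + cell_state_scale) would renormalize to Q0.15, and the
    // extra -3 lands on Q3.12.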
    int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
    int32_t dims_data = n_batch * n_cell;
    if (tanh_input_left_shift < 0) {
      // Handle a negative shift value by pre-shifting the cell state right.
      tanh_input_left_shift = -tanh_input_left_shift;
      for (int32_t i = 0; i < dims_data; i++) {
        cell_state[i] = cell_state[i] >> tanh_input_left_shift;
      }
      tanh_input_left_shift = 0;
    }
    RuntimeShape tanh_inp_shape = RuntimeShape(1, &dims_data);
    reference_integer_ops::Tanh(0, tanh_input_left_shift, tanh_inp_shape,
                                cell_state, tanh_inp_shape, scratch0);
  }
  micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
                               hidden_scale_b, n_batch, n_cell, hidden_zp,
                               scratch1);
  const bool use_projection = (projection_weights != nullptr);
  if (use_projection) {
    // Note: no separate bias assignment as in float/hybrid; projection_bias
    // is consumed by the matmul below.
    memset(output_state, 0, n_batch * n_output * sizeof(int8_t));
    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        scratch1, projection_bias, projection_weights, proj_scale_a,
        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
        output_state, nullptr);
    if (quantized_proj_clip > 0) {
      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                        quantized_proj_clip);
    }
  } else {
    std::memcpy(output_state, scratch1, n_batch * n_output * sizeof(int8_t));
  }
}
// Calculates a single LSTM gate, int8x8_8 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_8(
    // Inputs and weights
    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
    const int32_t input_times_weights_scale_a,
    const int32_t input_times_weights_scale_b,
    const int32_t input_times_weights_zp,
    // Output state and weights
    const int8_t* output_state, const int32_t output_state_zp,
    const int8_t* recurrent_to_gate_weight,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    const int32_t output_state_times_weights_scale_a,
    const int32_t output_state_times_weights_scale_b,
    const int32_t output_state_times_weights_zp,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_gate_weight,
    const int32_t layer_norm_gate_scale_a,
    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Scratch arrays, both sized n_batch*n_cell
    int8_t* scratch0, int8_t* scratch1) {
  // Multiply input * input_weights => scratch0
  micro_tensor_utils::MatrixBatchVectorMultiply(
      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
      input_times_weights_zp);
  // Multiply output_state * recurrent_weights => scratch1
  micro_tensor_utils::MatrixBatchVectorMultiply(
      output_state, output_state_zp, recurrent_to_gate_weight,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, scratch1, output_state_times_weights_zp);
  // Add scratch0 + scratch1 => gate
  micro_tensor_utils::TwoGateSaturatingAdd(
      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
      input_times_weights_scale_a, input_times_weights_scale_b,
      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
      n_batch, n_cell, gate);
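  // Note that layer normalization is applied unconditionally below (there is
  // no use_layer_norm check); this 8x8_8 variant is presumably only reached
  // for layer-norm LSTMs.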
  // Apply layer normalization.
  micro_tensor_utils::ApplyLayerNormFloat(
      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
  // Apply activation.
  switch (activation) {
    case kTfLiteActSigmoid:
      micro_tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      micro_tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}
// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, proj_scale_[a|b], projection_bias:
//    constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of the output state.
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area of size n_batch*n_cell
void CalculateLstmOutputInteger8x8_8(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    const int16_t* output_gate, const int8_t* projection_weights,
    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  micro_tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
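  // Both the tanh output and output_gate are Q0.15, so their product carries
  // a 2^-30 scale; the shift of 15 + 15 - 15 = 15 below returns the result
  // to Q0.15 (spelled out to show where each term comes from).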
  micro_tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell,
                               15 + 15 - 15, scratch);
  // Note: no separate bias assignment as in float/hybrid; projection_bias is
  // consumed by the matmul below.
  micro_tensor_utils::MatrixBatchVectorMultiply(
      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
      n_batch, n_cell, n_output, output_state_zp, output_state);
  if (quantized_proj_clip > 0) {
    micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                      quantized_proj_clip);
  }
}
// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr)
// and biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
//  - params: various LSTM params including activation, clipping, etc.,
//  - n_batch: size of batch,
//  - n_cell: number of cells (or units),
//  - n_input: the input size,
//  - n_aux_input: the auxiliary input size.
//  - n_output: the output size.
//  - output_batch_leading_dim: the leading dimension of the output buffer.
//
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr                      - optional (can be nullptr)
//
// LSTM weights:
// Input weights of size 'n_cell * n_input':
//   input_to_input_weights             - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights         - optional
//   aux_input_to_forget_weights        - optional
//   aux_input_to_cell_weights          - optional
//   aux_input_to_output_weights        - optional
// Recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights         - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Peephole weights of size 'n_cell', representing diagonal matrices:
//   cell_to_input_weights              - optional
//   cell_to_forget_weights             - optional
//   cell_to_output_weights             - optional
// Projection weights of size 'n_output * n_cell':
//   projection_weights_ptr             - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr                - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices:
//   input_layer_norm_coefficients_ptr  - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr   - optional
//   output_layer_norm_coefficients_ptr - optional
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
// in batch_major order, and each step processes batch_size many inputs from
// input_ptr, and updates batch_size many cell and output states.
//
// The output_batch_leading_dim is output.shape[-1], i.e. the last dimension of
// the output tensor, and in most cases is equal to n_output. It differs when
// the LSTM output is stored into a slice of the output tensor, e.g. for
// bidirectional LSTMs with merge_outputs. In this case, the batched operations
// cannot be used, since they assume that the batched outputs are contiguous,
// and we manually loop over the batched outputs.
inline void LstmStepFloat(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
  // Since we have already checked that the weights are all there or none, we
  // can check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;
  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch,
      is_input_all_zeros, is_aux_input_all_zeros);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
      aux_input_to_cell_weights_ptr, output_state_ptr,
      recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
      /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      params->activation, cell_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate the output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch,
      is_input_all_zeros, is_aux_input_all_zeros);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
                           output_gate_scratch, params->activation,
                           projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::memcpy(output_ptr + b * output_batch_leading_dim,
                output_state_ptr + b * n_output, n_output * sizeof(float));
  }
}
  877. // Same as above but with quantized weight matrices. In detail:
  878. // Input of size 'n_batch * n_input':
  879. // input_ptr
  880. // Input of size 'n_batch * n_aux_input':
  881. // aux_input_ptr - optional (can be nullptr)
  882. //
  883. // LSTM weights:
  884. // Quantized input weights of size 'n_cell * n_input':
  885. // input_to_input_weights - optional
  886. // input_to_forget_weights
  887. // input_to_cell_weights
  888. // input_to_input_weights
  889. // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
  890. // aux_input_to_input_weights - optional
  891. // aux_input_to_forget_weights - optional
  892. // aux_input_to_cell_weights - optional
  893. // aux_input_to_output_weights - optional
  894. // Quantized recurrent weights of size 'n_cell * n_output':
  895. // recurrent_to_input_weights - optional
  896. // recurrent_to_forget_weights
  897. // recurrent_to_cell_weights
  898. // recurrent_to_input_weights
  899. // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
  900. // cell_to_input_weights - optional
  901. // cell_to_cell_weights - optional
  902. // cell_to_output_weights - optional
  903. // Quantized projection weights of size 'n_output * n_cell'
  904. // projection_weights_ptr - optional
  905. // Weight scales (scalars) for each of the weights above.
  906. // input_to_input_weights_scale - optional
  907. // input_to_forget_weights_scale
  908. // input_to_cell_weights_scale
  909. // input_to_output_weights_scale
  910. // aux_input_to_input_weights_scale - optional
  911. // aux_input_to_forget_weights_scale - optional
  912. // aux_input_to_cell_weights_scale - optional
  913. // aux_input_to_output_weights_scale - optional
  914. // recurrent_to_input_weights_scale - optional
  915. // recurrent_to_forget_weights_scale
  916. // recurrent_to_cell_weights_scale
  917. // recurrent_to_output_weights_scale
918. // cell_to_input_weights_scale - optional
919. // cell_to_forget_weights_scale - optional
920. // cell_to_output_weights_scale - optional
  921. // projection_weights_scale - optional
  922. // Gate biases of size 'n_cell':
  923. // input_gate_bias_ptr - optional
  924. // forget_gate_bias_ptr
  925. // cell_gate_bias_ptr
  926. // output_gate_bias_ptr
  927. //
  928. // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
  929. // input_layer_norm_coefficients_ptr - optional
  930. // forget_layer_norm_coefficients_ptr - optional
  931. // cell_layer_norm_coefficients_ptr - optional
  932. // output_layer_norm_coefficients_ptr - optional
  933. //
  934. // Temporary pre-allocated storage for quantized values:
  935. // quantized_input_ptr (same size as input_ptr)
  936. // quantized_output_state_ptr (same size as output_state_ptr)
  937. // quantized_output_scratch (same size as cell_state_ptr)
  938. // Temporary pre-allocated storage for recovered values:
  939. // recovered_cell_weights (same size as cell_to_*_weights)
  940. //
  941. // Outputs:
  942. // output_state_ptr - size 'n_batch * n_output'
  943. // cell_state_ptr - size 'n_batch * n_cell'
  944. // output_ptr - size 'n_batch * output_batch_leading_dim'
  945. inline void LstmStepHybrid(
  946. const float* input_ptr, const int8_t* input_to_input_weights_ptr,
  947. const uint8_t* input_to_input_weights_ledger_ptr,
  948. float input_to_input_weights_scale,
  949. const int8_t* input_to_forget_weights_ptr,
  950. const uint8_t* input_to_forget_weights_ledger_ptr,
  951. float input_to_forget_weights_scale,
  952. const int8_t* input_to_cell_weights_ptr,
  953. const uint8_t* input_to_cell_weights_ledger_ptr,
  954. float input_to_cell_weights_scale,
  955. const int8_t* input_to_output_weights_ptr,
  956. const uint8_t* input_to_output_weights_ledger_ptr,
  957. float input_to_output_weights_scale, const float* aux_input_ptr,
  958. const int8_t* aux_input_to_input_weights_ptr,
  959. float aux_input_to_input_weights_scale,
  960. const int8_t* aux_input_to_forget_weights_ptr,
  961. float aux_input_to_forget_weights_scale,
  962. const int8_t* aux_input_to_cell_weights_ptr,
  963. float aux_input_to_cell_weights_scale,
  964. const int8_t* aux_input_to_output_weights_ptr,
  965. float aux_input_to_output_weights_scale,
  966. const int8_t* recurrent_to_input_weights_ptr,
  967. const uint8_t* recurrent_to_input_weights_ledger_ptr,
  968. float recurrent_to_input_weights_scale,
  969. const int8_t* recurrent_to_forget_weights_ptr,
  970. const uint8_t* recurrent_to_forget_weights_ledger_ptr,
  971. float recurrent_to_forget_weights_scale,
  972. const int8_t* recurrent_to_cell_weights_ptr,
  973. const uint8_t* recurrent_to_cell_weights_ledger_ptr,
  974. float recurrent_to_cell_weights_scale,
  975. const int8_t* recurrent_to_output_weights_ptr,
  976. const uint8_t* recurrent_to_output_weights_ledger_ptr,
  977. float recurrent_to_output_weights_scale,
  978. const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
  979. const int8_t* cell_to_forget_weights_ptr,
  980. float cell_to_forget_weights_scale,
  981. const int8_t* cell_to_output_weights_ptr,
  982. float cell_to_output_weights_scale,
  983. const float* input_layer_norm_coefficients_ptr,
  984. const float* forget_layer_norm_coefficients_ptr,
  985. const float* cell_layer_norm_coefficients_ptr,
  986. const float* output_layer_norm_coefficients_ptr,
  987. const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
  988. const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
  989. const int8_t* projection_weights_ptr,
  990. const uint8_t* projection_weights_ledger_ptr,
  991. float projection_weights_scale, const float* projection_bias_ptr,
  992. const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
  993. int n_aux_input, int n_output, int output_batch_leading_dim,
  994. float* scratch0, float* scratch1, float* scratch2, float* scratch3,
  995. float* scales, float* input_sf, float* aux_input_sf, float* output_state_sf,
  996. float* scaling_factors_scratch, float* recovered_cell_weights,
  997. int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
  998. int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
  999. float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
  1000. float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
  1001. int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
  1002. bool* compute_row_sums, bool asymmetric_quantize_inputs) {
  1003. // Since we have already checked that weights are all there or none, we
1004. // can check the existence of only one to get the condition.
  1005. const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  1006. // Make named scratch buffers for the different gates.
  1007. float* input_gate_scratch = scratch0;
  1008. float* forget_gate_scratch = scratch1;
  1009. float* cell_gate_scratch = scratch2;
  1010. float* output_gate_scratch = scratch3;
  1011. int32_t* input_to_input_row_sums = nullptr;
  1012. int32_t* input_to_forget_row_sums = nullptr;
  1013. int32_t* input_to_cell_row_sums = nullptr;
  1014. int32_t* input_to_output_row_sums = nullptr;
  1015. int32_t* aux_input_to_input_row_sums = nullptr;
  1016. int32_t* aux_input_to_forget_row_sums = nullptr;
  1017. int32_t* aux_input_to_cell_row_sums = nullptr;
  1018. int32_t* aux_input_to_output_row_sums = nullptr;
  1019. int32_t* recurrent_to_input_row_sums = nullptr;
  1020. int32_t* recurrent_to_forget_row_sums = nullptr;
  1021. int32_t* recurrent_to_cell_row_sums = nullptr;
  1022. int32_t* recurrent_to_output_row_sums = nullptr;
  1023. int32_t* projection_weights_row_sums = nullptr;
  1024. if (asymmetric_quantize_inputs) {
  1025. int num_row_sums = use_cifg ? 6 : 8;
  1026. if (aux_input_ptr != nullptr) {
  1027. num_row_sums += use_cifg ? 3 : 4;
  1028. }
  1029. if (projection_weights_ptr != nullptr) {
  1030. num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
  1031. }
  1032. TFLITE_DCHECK(row_sums_size == num_row_sums);
  1033. input_to_input_row_sums = row_sums;
  1034. input_to_forget_row_sums =
  1035. use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
  1036. input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
  1037. input_to_output_row_sums = input_to_cell_row_sums + n_cell;
  1038. if (aux_input_ptr != nullptr) {
  1039. aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
  1040. aux_input_to_forget_row_sums = use_cifg
  1041. ? aux_input_to_input_row_sums
  1042. : aux_input_to_input_row_sums + n_cell;
  1043. aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
  1044. aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
  1045. }
  1046. recurrent_to_input_row_sums = aux_input_ptr
  1047. ? aux_input_to_output_row_sums + n_cell
  1048. : input_to_output_row_sums + n_cell;
  1049. recurrent_to_forget_row_sums = use_cifg
  1050. ? recurrent_to_input_row_sums
  1051. : recurrent_to_input_row_sums + n_cell;
  1052. recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
  1053. recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
  1054. if (projection_weights_ptr != nullptr) {
  1055. projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
  1056. }
  1057. if (*compute_row_sums) {
  1058. ComputeRowSums(
  1059. input_to_input_row_sums, input_to_forget_row_sums,
  1060. input_to_cell_row_sums, input_to_output_row_sums,
  1061. aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
  1062. aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
  1063. recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
  1064. recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
  1065. projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
  1066. n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
  1067. input_to_cell_weights_ptr, input_to_output_weights_ptr,
  1068. aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
  1069. aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
  1070. recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
  1071. recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
  1072. projection_weights_ptr, use_cifg, aux_input_ptr);
  1073. *compute_row_sums = false;
  1074. }
  1075. }
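// Illustrative layout for one hypothetical configuration (non-CIFG, aux input
// present, projection present): row_sums is partitioned into n_cell-sized
// int32 chunks, packed back to back as
//   [in->i][in->f][in->c][in->o][aux->i][aux->f][aux->c][aux->o]
//   [rec->i][rec->f][rec->c][rec->o][projection ...]
// which matches the pointer arithmetic above. With CIFG the *_to_input chunks
// are not allocated and the forget pointers start where the input slots would.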
  1076. // Check if inputs are all zeros so we can skip some computations.
  1077. const bool is_input_all_zeros =
  1078. micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  1079. const bool is_aux_input_all_zeros =
  1080. (aux_input_ptr == nullptr ||
  1081. micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  1082. const bool is_output_state_all_zeros =
  1083. micro_tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
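// A zero vector contributes nothing to any gate pre-activation, so both its
// quantization below and the corresponding matmuls inside the gate
// calculations can be skipped outright.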
  1084. // Quantize inputs.
  1085. if (!is_input_all_zeros) {
  1086. micro_tensor_utils::BatchQuantizeFloats(
  1087. input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp,
  1088. asymmetric_quantize_inputs);
  1089. }
  1090. if (!is_aux_input_all_zeros) {
  1091. micro_tensor_utils::BatchQuantizeFloats(
  1092. aux_input_ptr, n_batch, n_aux_input, quantized_aux_input_ptr,
  1093. aux_input_sf, aux_input_zp, asymmetric_quantize_inputs);
  1094. }
  1095. if (!is_output_state_all_zeros) {
  1096. micro_tensor_utils::BatchQuantizeFloats(
  1097. output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
  1098. output_state_sf, output_state_zp, asymmetric_quantize_inputs);
  1099. }
  1100. if (!use_cifg) {
  1101. // Calculate the input gate. (If not CIFG.)
  1102. CalculateLstmGateHybrid(
  1103. quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
  1104. input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
  1105. input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
  1106. aux_input_zp, aux_input_to_input_weights_ptr,
  1107. aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
  1108. quantized_output_state_ptr, output_state_sf, output_state_zp,
  1109. recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
  1110. recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
  1111. cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
  1112. input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
  1113. n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
  1114. input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
  1115. is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
  1116. recovered_cell_weights, scales, accum_scratch_ptr);
  1117. }
  1118. // Calculate the forget gate.
  1119. CalculateLstmGateHybrid(
  1120. quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
  1121. input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
  1122. input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
  1123. aux_input_zp, aux_input_to_forget_weights_ptr,
  1124. aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
  1125. quantized_output_state_ptr, output_state_sf, output_state_zp,
  1126. recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
  1127. recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
  1128. cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
  1129. forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
  1130. n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
  1131. forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
  1132. is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
  1133. recovered_cell_weights, scales, accum_scratch_ptr);
  1134. // Calculate the cell update gate.
  1135. CalculateLstmGateHybrid(
  1136. quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
  1137. input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
  1138. input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
  1139. aux_input_zp, aux_input_to_cell_weights_ptr,
  1140. aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
  1141. quantized_output_state_ptr, output_state_sf, output_state_zp,
  1142. recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
  1143. recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
  1144. /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
  1145. /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
  1146. cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
  1147. params->activation, cell_gate_scratch, is_input_all_zeros,
  1148. is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
  1149. scaling_factors_scratch, recovered_cell_weights, scales,
  1150. accum_scratch_ptr);
  1151. // Update the cell state.
  1152. UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
  1153. forget_gate_scratch, cell_gate_scratch, use_cifg,
  1154. params->cell_clip);
  1155. // Calculate the output gate.
  1156. CalculateLstmGateHybrid(
  1157. quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
  1158. input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
  1159. input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
  1160. aux_input_zp, aux_input_to_output_weights_ptr,
  1161. aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
  1162. quantized_output_state_ptr, output_state_sf, output_state_zp,
  1163. recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
  1164. recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
  1165. cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
  1166. output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
  1167. n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
  1168. output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
  1169. is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
  1170. recovered_cell_weights, scales, accum_scratch_ptr);
  1171. // Update the output state.
  1172. CalculateLstmOutputHybrid(
  1173. n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
  1174. params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
  1175. projection_weights_scale, projection_bias_ptr, params->proj_clip,
  1176. output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
  1177. compute_row_sums, scratch2, quantized_output_scratch, input_sf, input_zp,
  1178. accum_scratch_ptr, scales);
  1179. // Copy output state to the output. Note that the output's rows may not be
  1180. // contiguous (output_batch_leading_dim != n_output).
  1181. for (int b = 0; b < n_batch; b++) {
  1182. std::memcpy(output_ptr + b * output_batch_leading_dim,
  1183. output_state_ptr + b * n_output, n_output * sizeof(float));
  1184. }
  1185. }
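// A minimal sketch of the per-gate hybrid arithmetic used above (simplified;
// variable names are illustrative): activations are quantized per batch row
// to int8, the int8 x int8 products accumulate in int32, and the float weight
// scale times the per-row input scaling factor converts back to float:
//   int32_t acc = 0;
//   for (int k = 0; k < n_input; ++k) acc += w_q[r * n_input + k] * x_q[k];
//   gate[r] += w_scale * x_sf[b] *
//              (acc - (asymmetric ? x_zp[b] * row_sums[r] : 0));
// The precomputed row sums make the asymmetric zero-point correction a single
// multiply instead of a second pass over the weights.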
  1186. // Fully quantized lstm kernel for 16 bit gate matmul output.
  1187. //
  1188. // Input tensor of size n_batch * n_input:
  1189. // input_ptr
  1190. //
  1191. // LSTM weights:
  1192. // Quantized input weights of size 'n_cell * n_input':
  1193. // input_to_input_weight_ptr - optional
  1194. // input_to_forget_weight_ptr - optional
  1195. // input_to_cell_weight_ptr - optional
  1196. // input_to_output_weight_ptr - optional
  1197. //
  1198. // Quantized recurrent weights of size 'n_cell * n_output':
  1199. // recurrent_to_input_weight_ptr - optional
  1200. // recurrent_to_forget_weights_ptr
  1201. // recurrent_to_cell_weights_ptr
1202. // recurrent_to_output_weights_ptr
  1203. //
  1204. // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
  1205. // cell_to_input_weights - optional
1206. // cell_to_forget_weights - optional
  1207. // cell_to_output_weights - optional
  1208. //
  1209. // Quantized projection weights of size 'n_output * n_cell'
  1210. // projection_weight_ptr - optional
  1211. //
  1212. // Weight scales (scalars) for each of the weights above.
  1213. // effective_input_to_input_scale_a - optional
  1214. // effective_input_to_input_scale_b - optional
  1215. // effective_input_to_forget_scale_a
  1216. // effective_input_to_forget_scale_b
  1217. // effective_input_to_cell_scale_a
  1218. // effective_input_to_cell_scale_b
  1219. // effective_input_to_output_scale_a
  1220. // effective_input_to_output_scale_b
  1221. // effective_recurrent_to_input_scale_a - optional
  1222. // effective_recurrent_to_input_scale_b - optional
  1223. // effective_recurrent_to_forget_scale_a
  1224. // effective_recurrent_to_forget_scale_b
  1225. // effective_recurrent_to_cell_scale_a
  1226. // effective_recurrent_to_cell_scale_b
  1227. // effective_recurrent_to_output_scale_a
  1228. // effective_recurrent_to_output_scale_b
  1229. // effective_proj_scale_a - optional
  1230. // effective_proj_scale_b - optional
  1231. //
  1232. // Gate biases of size 'n_cell':
  1233. // input_gate_bias_ptr - optional
  1234. // forget_gate_bias_ptr
  1235. // cell_gate_bias_ptr
  1236. // output_gate_bias_ptr
  1237. //
  1238. // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
  1239. // layer_norm_input_weight_ptr - optional
  1240. // layer_norm_forget_weight_ptr - optional
  1241. // layer_norm_cell_weight_ptr - optional
  1242. // layer_norm_output_weight_ptr - optional
  1243. //
  1244. // Layer norm scales of size 'n_cell'.
  1245. // layer_norm_input_scale_a - optional
  1246. // layer_norm_input_scale_b - optional
  1247. // layer_norm_forget_scale_a - optional
  1248. // layer_norm_forget_scale_b - optional
  1249. // layer_norm_cell_scale_a - optional
  1250. // layer_norm_cell_scale_b - optional
  1251. // layer_norm_output_scale_a - optional
  1252. // layer_norm_output_scale_b - optional
  1253. //
  1254. // Scalar values:
  1255. // quantized_cell_clip: quantized clip value for cell.
  1256. // quantized_proj_clip: quantized clip value for projection.
  1257. // cell_state_scale: the power of two scale for cell state.
  1258. //
  1259. // Zero points:
  1260. // output_state_zp: zero point of output state
  1261. // hidden_zp: zero point for hidden state.
  1262. //
  1263. // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
  1264. // n_batch.
  1265. // scratch0
  1266. // scratch1
  1267. // scratch2
  1268. // scratch3
  1269. // scratch4
  1270. // scratch5: this scratch buffer is created purely for optimizing the
  1271. // MatrixBatchVectorMultiplyAccumulate.
  1272. //
  1273. // Outputs:
  1274. // output_state_ptr - size 'n_batch * n_output'
  1275. // cell_state_ptr - size 'n_batch * n_cell'
  1276. // output_ptr - size 'n_batch * n_output'
  1277. // TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
  1278. inline void LstmStepInteger8x8_16(
  1279. const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
  1280. int32_t effective_input_to_input_scale_a,
  1281. int32_t effective_input_to_input_scale_b,
  1282. const int8_t* input_to_forget_weight_ptr,
  1283. int32_t effective_input_to_forget_scale_a,
  1284. int32_t effective_input_to_forget_scale_b,
  1285. const int8_t* input_to_cell_weight_ptr,
  1286. int32_t effective_input_to_cell_scale_a,
  1287. int32_t effective_input_to_cell_scale_b,
  1288. const int8_t* input_to_output_weight_ptr,
  1289. int32_t effective_input_to_output_scale_a,
  1290. int32_t effective_input_to_output_scale_b,
  1291. const int8_t* recurrent_to_input_weight_ptr,
  1292. int32_t effective_recurrent_to_input_scale_a,
  1293. int32_t effective_recurrent_to_input_scale_b,
  1294. const int8_t* recurrent_to_forget_weight_ptr,
  1295. int32_t effective_recurrent_to_forget_scale_a,
  1296. int32_t effective_recurrent_to_forget_scale_b,
  1297. const int8_t* recurrent_to_cell_weight_ptr,
  1298. int32_t effective_recurrent_to_cell_scale_a,
  1299. int32_t effective_recurrent_to_cell_scale_b,
  1300. const int8_t* recurrent_to_output_weight_ptr,
  1301. int32_t effective_recurrent_to_output_scale_a,
  1302. int32_t effective_recurrent_to_output_scale_b,
  1303. const int16_t* cell_to_input_weight_ptr,
  1304. int32_t effective_cell_to_input_scale_a,
  1305. int32_t effective_cell_to_input_scale_b,
  1306. const int16_t* cell_to_forget_weight_ptr,
  1307. int32_t effective_cell_to_forget_scale_a,
  1308. int32_t effective_cell_to_forget_scale_b,
  1309. const int16_t* cell_to_output_weight_ptr,
  1310. int32_t effective_cell_to_output_scale_a,
  1311. int32_t effective_cell_to_output_scale_b,
  1312. const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
  1313. int32_t effective_proj_scale_b, int32_t hidden_zp,
  1314. int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
  1315. const int16_t* layer_norm_input_weight_ptr,
  1316. int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
  1317. const int16_t* layer_norm_forget_weight_ptr,
  1318. int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
  1319. const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
  1320. int32_t layer_norm_cell_scale_b,
  1321. const int16_t* layer_norm_output_weight_ptr,
  1322. int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
  1323. const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
  1324. const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
  1325. int16_t quantized_cell_clip, int8_t quantized_proj_clip,
  1326. int32_t cell_state_scale, int32_t input_variance_guard,
  1327. int32_t forget_variance_guard, int32_t cell_variance_guard,
  1328. int32_t output_variance_guard,
  1329. const int32_t* input_to_forget_effective_bias,
  1330. const int32_t* recurrent_to_forget_effective_bias,
  1331. const int32_t* input_to_cell_effective_bias,
  1332. const int32_t* recurrent_to_cell_effective_bias,
  1333. const int32_t* input_to_output_effective_bias,
  1334. const int32_t* recurrent_to_output_effective_bias,
  1335. const int32_t* input_to_input_effective_bias,
  1336. const int32_t* recurrent_to_input_effective_bias,
  1337. const int32_t* projection_effective_bias, int n_batch, int n_cell,
  1338. int n_input, int n_output, int8_t* output_state_ptr,
  1339. int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
  1340. int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
  1341. int8_t* scratch4, int32_t* scratch5) {
  1342. // Make named scratch buffers for the different gates.
  1343. int16_t* input_gate_scratch = scratch0;
  1344. int16_t* forget_gate_scratch = scratch1;
  1345. int16_t* cell_gate_scratch = scratch2;
  1346. int16_t* output_gate_scratch = scratch3;
  1347. // Since we have already checked that weights are all there or none, we
1348. // can check the existence of only one to get the condition.
  1349. const bool use_cifg = (input_to_input_weight_ptr == nullptr);
  1350. // Check for nullptrs.
  1351. TFLITE_DCHECK(input_to_forget_effective_bias);
  1352. TFLITE_DCHECK(recurrent_to_forget_effective_bias);
  1353. TFLITE_DCHECK(input_to_cell_effective_bias);
  1354. TFLITE_DCHECK(recurrent_to_cell_effective_bias);
  1355. TFLITE_DCHECK(input_to_output_effective_bias);
  1356. TFLITE_DCHECK(recurrent_to_output_effective_bias);
  1357. if (!use_cifg) {
  1358. TFLITE_DCHECK(input_to_input_effective_bias);
  1359. TFLITE_DCHECK(recurrent_to_input_effective_bias);
  1360. }
  1361. const bool use_projection = (projection_weight_ptr != nullptr);
  1362. if (use_projection) {
  1363. TFLITE_DCHECK(projection_effective_bias);
  1364. }
  1365. if (!use_cifg) {
  1366. // Calculate the input gate. (If not CIFG.)
  1367. CalculateLstmGateInteger8x8_16(
  1368. input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
  1369. effective_input_to_input_scale_a, effective_input_to_input_scale_b,
  1370. output_state_ptr, recurrent_to_input_weight_ptr,
  1371. recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
  1372. effective_recurrent_to_input_scale_b, cell_state_ptr,
  1373. cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
  1374. effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
  1375. input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
  1376. input_variance_guard, n_batch, n_input, n_output, n_cell,
  1377. kTfLiteActSigmoid, input_gate_scratch, scratch5);
  1378. }
  1379. // Calculate the forget gate.
  1380. CalculateLstmGateInteger8x8_16(
  1381. input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
  1382. effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
  1383. output_state_ptr, recurrent_to_forget_weight_ptr,
  1384. recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
  1385. effective_recurrent_to_forget_scale_b, cell_state_ptr,
  1386. cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
  1387. effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
  1388. forget_gate_bias_ptr, layer_norm_forget_scale_a,
  1389. layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
  1390. n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
  1391. // Calculate the cell update gate.
  1392. CalculateLstmGateInteger8x8_16(
  1393. input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
  1394. effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
  1395. output_state_ptr, recurrent_to_cell_weight_ptr,
  1396. recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
  1397. effective_recurrent_to_cell_scale_b, cell_state_ptr,
  1398. /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
  1399. /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
  1400. cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
  1401. cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
  1402. cell_gate_scratch, scratch5);
  1403. // Update the cell state.
  1404. UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
  1405. input_gate_scratch, forget_gate_scratch,
  1406. cell_gate_scratch, use_cifg, quantized_cell_clip);
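// Note: cell_state_scale is a power-of-two exponent (see the header comment
// above), e.g. a hypothetical cell_state_scale of -11 stores cell values as
// value * 2^11 in int16, so the update can rescale with shifts alone.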
  1407. // Calculate the output gate.
  1408. CalculateLstmGateInteger8x8_16(
  1409. input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
  1410. effective_input_to_output_scale_a, effective_input_to_output_scale_b,
  1411. output_state_ptr, recurrent_to_output_weight_ptr,
  1412. recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
  1413. effective_recurrent_to_output_scale_b, cell_state_ptr,
  1414. cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
  1415. effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
  1416. output_gate_bias_ptr, layer_norm_output_scale_a,
  1417. layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
  1418. n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
  1419. // Update the output state.
  1420. CalculateLstmOutputInteger8x8_16(
  1421. n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
  1422. output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
  1423. hidden_zp, projection_weight_ptr, effective_proj_scale_a,
  1424. effective_proj_scale_b, projection_effective_bias, output_state_zp,
  1425. quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
  1426. // Copy output state to the output. Note that unlike float or hybrid, output
  1427. // is always contiguous.
  1428. std::memcpy(output_ptr, output_state_ptr,
  1429. n_batch * n_output * sizeof(int8_t));
  1430. }
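// For reference, each effective_*_scale_a/_scale_b pair above is assumed to
// encode a fixed-point rescale in TFLite's usual convention: scale_a is a Q31
// integer multiplier and scale_b a signed shift, applied along the lines of
//   MultiplyByQuantizedMultiplier(acc, scale_a, scale_b)
// i.e. roughly acc * scale_a * 2^(scale_b - 31), with rounding and saturation.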
  1431. // Fully quantized lstm kernel for 8 bit gate matmul output.
  1432. //
  1433. // Input tensor of size n_batch * n_input:
  1434. // input_ptr
  1435. //
  1436. // LSTM weights:
  1437. // Quantized input weights of size 'n_cell * n_input':
  1438. // input_to_input_weight_ptr - optional
  1439. // input_to_forget_weight_ptr - optional
  1440. // input_to_cell_weight_ptr - optional
  1441. // input_to_output_weight_ptr - optional
  1442. //
  1443. // Quantized recurrent weights of size 'n_cell * n_output':
  1444. // recurrent_to_input_weight_ptr - optional
  1445. // recurrent_to_forget_weights_ptr
  1446. // recurrent_to_cell_weights_ptr
1447. // recurrent_to_output_weights_ptr
  1448. //
  1449. // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
  1450. // cell_to_input_weights - optional
1451. // cell_to_forget_weights - optional
  1452. // cell_to_output_weights - optional
  1453. //
  1454. // Quantized projection weights of size 'n_output * n_cell'
  1455. // projection_weight_ptr - optional
  1456. //
  1457. // Weight scales (scalars) for each of the weights above.
  1458. // effective_input_to_input_scale_a - optional
  1459. // effective_input_to_input_scale_b - optional
  1460. // effective_input_to_forget_scale_a
  1461. // effective_input_to_forget_scale_b
  1462. // effective_input_to_cell_scale_a
  1463. // effective_input_to_cell_scale_b
  1464. // effective_input_to_output_scale_a
  1465. // effective_input_to_output_scale_b
  1466. // effective_recurrent_to_input_scale_a - optional
  1467. // effective_recurrent_to_input_scale_b - optional
  1468. // effective_recurrent_to_forget_scale_a
  1469. // effective_recurrent_to_forget_scale_b
  1470. // effective_recurrent_to_cell_scale_a
  1471. // effective_recurrent_to_cell_scale_b
  1472. // effective_recurrent_to_output_scale_a
  1473. // effective_recurrent_to_output_scale_b
  1474. // effective_proj_scale_a - optional
  1475. // effective_proj_scale_b - optional
  1476. //
  1477. // Gate biases of size 'n_cell':
  1478. // input_gate_bias_ptr - optional
  1479. // forget_gate_bias_ptr
  1480. // cell_gate_bias_ptr
  1481. // output_gate_bias_ptr
  1482. //
  1483. // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
  1484. // layer_norm_input_weight_ptr - optional
  1485. // layer_norm_forget_weight_ptr - optional
  1486. // layer_norm_cell_weight_ptr - optional
  1487. // layer_norm_output_weight_ptr - optional
  1488. //
  1489. // Layer norm scales of size 'n_cell'.
  1490. // layer_norm_input_scale_a - optional
  1491. // layer_norm_input_scale_b - optional
  1492. // layer_norm_forget_scale_a - optional
  1493. // layer_norm_forget_scale_b - optional
  1494. // layer_norm_cell_scale_a - optional
  1495. // layer_norm_cell_scale_b - optional
  1496. // layer_norm_output_scale_a - optional
  1497. // layer_norm_output_scale_b - optional
  1498. //
  1499. // Scalar values:
  1500. // quantized_cell_clip: quantized clip value for cell.
  1501. // quantized_proj_clip: quantized clip value for projection.
  1502. // cell_state_scale: the power of two scale for cell state.
  1503. //
  1504. // Zero points:
  1505. // input_zp: zero point for input tensor.
  1506. // output_state_zp: zero point of output state.
  1507. // hidden_zp: zero point for hidden state.
  1508. //
  1509. // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
  1510. // n_batch.
  1511. // scratch0
  1512. // scratch1
  1513. // scratch2
  1514. // scratch3
  1515. // scratch4
  1516. // scratch5
  1517. // scratch6
  1518. // scratch7
  1519. //
  1520. // Outputs:
  1521. // output_state_ptr - size 'n_batch * n_output'
  1522. // cell_state_ptr - size 'n_batch * n_cell'
  1523. // output_ptr - size 'n_batch * n_output'
  1524. //
1525. // Can move zero point calculation into Prepare() for better performance.
  1526. // TODO(b/159947023): scratch5 is unused, remove.
  1527. inline void LstmStepInteger8x8_8(
  1528. const int8_t* input_ptr, int32_t input_zp,
  1529. const int8_t* input_to_input_weight_ptr,
  1530. int32_t effective_input_to_input_scale_a,
  1531. int32_t effective_input_to_input_scale_b,
  1532. const int8_t* input_to_forget_weight_ptr,
  1533. int32_t effective_input_to_forget_scale_a,
  1534. int32_t effective_input_to_forget_scale_b,
  1535. const int8_t* input_to_cell_weight_ptr,
  1536. int32_t effective_input_to_cell_scale_a,
  1537. int32_t effective_input_to_cell_scale_b,
  1538. const int8_t* input_to_output_weight_ptr,
  1539. int32_t effective_input_to_output_scale_a,
  1540. int32_t effective_input_to_output_scale_b,
  1541. const int8_t* recurrent_to_input_weight_ptr,
  1542. int32_t effective_recurrent_to_input_scale_a,
  1543. int32_t effective_recurrent_to_input_scale_b,
  1544. const int8_t* recurrent_to_forget_weight_ptr,
  1545. int32_t effective_recurrent_to_forget_scale_a,
  1546. int32_t effective_recurrent_to_forget_scale_b,
  1547. const int8_t* recurrent_to_cell_weight_ptr,
  1548. int32_t effective_recurrent_to_cell_scale_a,
  1549. int32_t effective_recurrent_to_cell_scale_b,
  1550. const int8_t* recurrent_to_output_weight_ptr,
  1551. int32_t effective_recurrent_to_output_scale_a,
  1552. int32_t effective_recurrent_to_output_scale_b,
  1553. const int8_t* cell_to_input_weight_ptr,
  1554. int32_t effective_cell_to_input_scale_a,
  1555. int32_t effective_cell_to_input_scale_b,
  1556. const int8_t* cell_to_forget_weight_ptr,
  1557. int32_t effective_cell_to_forget_scale_a,
  1558. int32_t effective_cell_to_forget_scale_b,
  1559. const int8_t* cell_to_output_weight_ptr,
  1560. int32_t effective_cell_to_output_scale_a,
  1561. int32_t effective_cell_to_output_scale_b,
  1562. const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
  1563. int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
  1564. int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
  1565. const int16_t* layer_norm_forget_weight_ptr,
  1566. int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
  1567. const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
  1568. int32_t layer_norm_cell_scale_b,
  1569. const int16_t* layer_norm_output_weight_ptr,
  1570. int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
  1571. const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
  1572. const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
  1573. const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
  1574. const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
  1575. const int32_t* intermediate_zp, int16_t quantized_cell_clip,
  1576. int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
  1577. int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
  1578. int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
  1579. int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
  1580. int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
  1581. int16_t* scratch7) {
  1582. // TODO(b/159066113): scratch5 is unused, remove.
  1583. // Make named scratch buffers for the different gates.
  1584. int16_t* forget_gate_scratch = scratch2;
  1585. int16_t* cell_gate_scratch = scratch3;
  1586. int16_t* output_gate_scratch = scratch4;
1587. // Only the CIFG variant is supported here: no input gate is computed.
  1588. // Calculate the forget gate.
  1589. CalculateLstmGateInteger8x8_8(
  1590. input_ptr, input_zp, input_to_forget_weight_ptr,
  1591. effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
  1592. intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
  1593. output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
  1594. effective_recurrent_to_forget_scale_a,
  1595. effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
  1596. intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
  1597. layer_norm_forget_scale_a, layer_norm_forget_scale_b,
  1598. forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
  1599. kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
  1600. // Calculate the cell update gate.
  1601. CalculateLstmGateInteger8x8_8(
  1602. input_ptr, input_zp, input_to_cell_weight_ptr,
  1603. effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
  1604. intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
  1605. output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
  1606. effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
  1607. intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
  1608. layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
  1609. layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
  1610. n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
  1611. // Update the cell state.
  1612. UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
  1613. /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
  1614. forget_gate_scratch, cell_gate_scratch,
  1615. /*use_cifg=*/true, quantized_cell_clip);
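// With /*use_cifg=*/true the update above derives the input gate internally
// as (1 - forget_gate), which is why this kernel never computes a separate
// input gate.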
  1616. // Calculate the output gate.
  1617. CalculateLstmGateInteger8x8_8(
  1618. input_ptr, input_zp, input_to_output_weight_ptr,
  1619. effective_input_to_output_scale_a, effective_input_to_output_scale_b,
  1620. intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
  1621. output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
  1622. effective_recurrent_to_output_scale_a,
1623. effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
1624. intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
  1625. layer_norm_output_scale_a, layer_norm_output_scale_b,
  1626. output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
  1627. kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
  1628. // Update the output state.
  1629. CalculateLstmOutputInteger8x8_8(
  1630. n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
  1631. projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
  1632. projection_bias_ptr, output_state_zp, quantized_proj_clip,
  1633. output_state_ptr, scratch2);
  1634. // Copy output state to the output. Note that unlike float or hybrid, output
1635. // is always contiguous.
  1636. std::memcpy(output_ptr, output_state_ptr,
  1637. n_batch * n_output * sizeof(int8_t));
  1638. }
  1639. } // namespace
  1640. TfLiteStatus EvalFloatLstm(
  1641. const TfLiteEvalTensor* input,
  1642. const TfLiteEvalTensor* input_to_input_weights,
  1643. const TfLiteEvalTensor* input_to_forget_weights,
  1644. const TfLiteEvalTensor* input_to_cell_weights,
  1645. const TfLiteEvalTensor* input_to_output_weights,
  1646. const TfLiteEvalTensor* recurrent_to_input_weights,
  1647. const TfLiteEvalTensor* recurrent_to_forget_weights,
  1648. const TfLiteEvalTensor* recurrent_to_cell_weights,
  1649. const TfLiteEvalTensor* recurrent_to_output_weights,
  1650. const TfLiteEvalTensor* cell_to_input_weights,
  1651. const TfLiteEvalTensor* cell_to_forget_weights,
  1652. const TfLiteEvalTensor* cell_to_output_weights,
  1653. const TfLiteEvalTensor* input_layer_norm_coefficients,
  1654. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  1655. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  1656. const TfLiteEvalTensor* output_layer_norm_coefficients,
  1657. const TfLiteEvalTensor* aux_input,
  1658. const TfLiteEvalTensor* aux_input_to_input_weights,
  1659. const TfLiteEvalTensor* aux_input_to_forget_weights,
  1660. const TfLiteEvalTensor* aux_input_to_cell_weights,
  1661. const TfLiteEvalTensor* aux_input_to_output_weights,
  1662. const TfLiteEvalTensor* input_gate_bias,
  1663. const TfLiteEvalTensor* forget_gate_bias,
  1664. const TfLiteEvalTensor* cell_gate_bias,
  1665. const TfLiteEvalTensor* output_gate_bias,
  1666. const TfLiteEvalTensor* projection_weights,
  1667. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  1668. bool forward_sequence, bool time_major, int output_offset,
  1669. float* scratch_buffer, TfLiteEvalTensor* output_state,
  1670. TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output) {
  1671. TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  1672. int max_time, n_batch;
  1673. if (input->dims->size == 3) {
  1674. max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
  1675. n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  1676. } else {
  1677. max_time = 1;
  1678. n_batch = input->dims->data[0];
  1679. }
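// e.g. a hypothetical time-major input of shape {max_time, n_batch, n_input}
// = {4, 2, 8} yields max_time = 4 and n_batch = 2; the batch-major layout
// {2, 4, 8} swaps the first two dimensions.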
  1680. const int n_input = input->dims->data[input->dims->size - 1];
  1681. const int aux_input_size =
  1682. (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
  1683. // n_cell and n_output will be the same size when there is no projection.
  1684. const int n_cell = input_to_output_weights->dims->data[0];
  1685. const int n_output = recurrent_to_output_weights->dims->data[1];
  1686. // Since we have already checked that weights are all there or none, we can
  1687. // check the existence of only one to the get the condition.
  1688. const bool use_cifg = (input_to_input_weights == nullptr);
  1689. // Index the scratch buffers pointers to the global scratch buffer.
  1690. float* input_gate_scratch = nullptr;
  1691. float* cell_gate_scratch = nullptr;
  1692. float* forget_gate_scratch = nullptr;
  1693. float* output_gate_scratch = nullptr;
  1694. if (use_cifg) {
  1695. cell_gate_scratch = scratch_buffer;
  1696. forget_gate_scratch = scratch_buffer + n_cell * n_batch;
  1697. output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  1698. } else {
  1699. input_gate_scratch = scratch_buffer;
  1700. cell_gate_scratch = scratch_buffer + n_cell * n_batch;
  1701. forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  1702. output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
  1703. }
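// Illustrative partition (hypothetical sizes): with n_batch = 2 and
// n_cell = 3, each gate scratch is a 6-float region and scratch_buffer is
// split as [input | cell | forget | output] in that order (CIFG drops the
// input region), matching the offsets above.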
  1704. const int output_batch_leading_dim =
  1705. output->dims->data[output->dims->size - 1];
  1706. if (time_major) {
  1707. // Loop through the sequence.
  1708. const int input_step = n_batch * n_input;
  1709. const int output_step = n_batch * output_batch_leading_dim;
  1710. for (int t = 0; t < max_time; t++) {
  1711. // If this is the forward_sequence, step forward, otherwise step
  1712. // backwards.
  1713. const int t_rel = forward_sequence ? t : max_time - t - 1;
  1714. const float* input_ptr =
  1715. tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
  1716. const float* aux_input_ptr = nullptr;
  1717. if (aux_input) {
  1718. aux_input_ptr =
  1719. tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
  1720. }
  1721. float* output_ptr = tflite::micro::GetTensorData<float>(output) +
  1722. t_rel * output_step + output_offset;
  1723. LstmStepFloat(
  1724. input_ptr,
  1725. input_to_input_weights == nullptr
  1726. ? nullptr
  1727. : tflite::micro::GetTensorData<float>(input_to_input_weights),
  1728. input_to_forget_weights == nullptr
  1729. ? nullptr
  1730. : tflite::micro::GetTensorData<float>(input_to_forget_weights),
  1731. input_to_cell_weights == nullptr
  1732. ? nullptr
  1733. : tflite::micro::GetTensorData<float>(input_to_cell_weights),
  1734. input_to_output_weights == nullptr
  1735. ? nullptr
  1736. : tflite::micro::GetTensorData<float>(input_to_output_weights),
  1737. aux_input_ptr,
  1738. aux_input_to_input_weights == nullptr
  1739. ? nullptr
  1740. : tflite::micro::GetTensorData<float>(aux_input_to_input_weights),
  1741. aux_input_to_forget_weights == nullptr
  1742. ? nullptr
  1743. : tflite::micro::GetTensorData<float>(
  1744. aux_input_to_forget_weights),
  1745. aux_input_to_cell_weights == nullptr
  1746. ? nullptr
  1747. : tflite::micro::GetTensorData<float>(aux_input_to_cell_weights),
  1748. aux_input_to_output_weights == nullptr
  1749. ? nullptr
  1750. : tflite::micro::GetTensorData<float>(
  1751. aux_input_to_output_weights),
  1752. recurrent_to_input_weights == nullptr
  1753. ? nullptr
  1754. : tflite::micro::GetTensorData<float>(recurrent_to_input_weights),
  1755. recurrent_to_forget_weights == nullptr
  1756. ? nullptr
  1757. : tflite::micro::GetTensorData<float>(
  1758. recurrent_to_forget_weights),
  1759. recurrent_to_cell_weights == nullptr
  1760. ? nullptr
  1761. : tflite::micro::GetTensorData<float>(recurrent_to_cell_weights),
  1762. recurrent_to_output_weights == nullptr
  1763. ? nullptr
  1764. : tflite::micro::GetTensorData<float>(
  1765. recurrent_to_output_weights),
  1766. cell_to_input_weights == nullptr
  1767. ? nullptr
  1768. : tflite::micro::GetTensorData<float>(cell_to_input_weights),
  1769. cell_to_forget_weights == nullptr
  1770. ? nullptr
  1771. : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
  1772. cell_to_output_weights == nullptr
  1773. ? nullptr
  1774. : tflite::micro::GetTensorData<float>(cell_to_output_weights),
  1775. input_layer_norm_coefficients == nullptr
  1776. ? nullptr
  1777. : tflite::micro::GetTensorData<float>(
  1778. input_layer_norm_coefficients),
  1779. forget_layer_norm_coefficients == nullptr
  1780. ? nullptr
  1781. : tflite::micro::GetTensorData<float>(
  1782. forget_layer_norm_coefficients),
  1783. cell_layer_norm_coefficients == nullptr
  1784. ? nullptr
  1785. : tflite::micro::GetTensorData<float>(
  1786. cell_layer_norm_coefficients),
  1787. output_layer_norm_coefficients == nullptr
  1788. ? nullptr
  1789. : tflite::micro::GetTensorData<float>(
  1790. output_layer_norm_coefficients),
  1791. input_gate_bias == nullptr
  1792. ? nullptr
  1793. : tflite::micro::GetTensorData<float>(input_gate_bias),
  1794. forget_gate_bias == nullptr
  1795. ? nullptr
  1796. : tflite::micro::GetTensorData<float>(forget_gate_bias),
  1797. cell_gate_bias == nullptr
  1798. ? nullptr
  1799. : tflite::micro::GetTensorData<float>(cell_gate_bias),
  1800. output_gate_bias == nullptr
  1801. ? nullptr
  1802. : tflite::micro::GetTensorData<float>(output_gate_bias),
  1803. projection_weights == nullptr
  1804. ? nullptr
  1805. : tflite::micro::GetTensorData<float>(projection_weights),
  1806. projection_bias == nullptr
  1807. ? nullptr
  1808. : tflite::micro::GetTensorData<float>(projection_bias),
  1809. params, n_batch, n_cell, n_input, aux_input_size, n_output,
  1810. output_batch_leading_dim,
  1811. tflite::micro::GetTensorData<float>(output_state),
  1812. tflite::micro::GetTensorData<float>(cell_state), input_gate_scratch,
  1813. forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
  1814. output_ptr);
  1815. }
  1816. } else {
  1817. for (int b = 0; b < n_batch; b++) {
  1818. const int input_step = n_input;
  1819. const int output_step = output_batch_leading_dim;
  1820. for (int t = 0; t < max_time; t++) {
  1821. // If this is the forward_sequence, step forward, otherwise step
  1822. // backwards.
  1823. const int t_rel = forward_sequence ? t : max_time - t - 1;
  1824. const int time_offset = b * max_time + t_rel;
  1825. const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
  1826. time_offset * input_step;
  1827. const float* aux_input_ptr = nullptr;
  1828. if (aux_input) {
  1829. aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
  1830. time_offset * input_step;
  1831. }
  1832. float* output_ptr = tflite::micro::GetTensorData<float>(output) +
  1833. time_offset * output_step + output_offset;
  1834. // Offset the {output,cell}_state pointers to the right batch.
  1835. float* output_state_ptr =
  1836. tflite::micro::GetTensorData<float>(output_state) +
  1837. b * output_batch_leading_dim;
  1838. float* cell_state_ptr =
  1839. tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
  1840. // Offset the scratch pointers to the right batch.
  1841. float* input_gate_scratch_ptr =
  1842. input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
  1843. float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
  1844. float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
  1845. float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
  1846. LstmStepFloat(
  1847. input_ptr,
  1848. input_to_input_weights == nullptr
  1849. ? nullptr
  1850. : tflite::micro::GetTensorData<float>(input_to_input_weights),
  1851. input_to_forget_weights == nullptr
  1852. ? nullptr
  1853. : tflite::micro::GetTensorData<float>(input_to_forget_weights),
  1854. input_to_cell_weights == nullptr
  1855. ? nullptr
  1856. : tflite::micro::GetTensorData<float>(input_to_cell_weights),
  1857. input_to_output_weights == nullptr
  1858. ? nullptr
  1859. : tflite::micro::GetTensorData<float>(input_to_output_weights),
  1860. aux_input_ptr,
  1861. aux_input_to_input_weights == nullptr
  1862. ? nullptr
  1863. : tflite::micro::GetTensorData<float>(
  1864. aux_input_to_input_weights),
  1865. aux_input_to_forget_weights == nullptr
  1866. ? nullptr
  1867. : tflite::micro::GetTensorData<float>(
  1868. aux_input_to_forget_weights),
  1869. aux_input_to_cell_weights == nullptr
  1870. ? nullptr
  1871. : tflite::micro::GetTensorData<float>(
  1872. aux_input_to_cell_weights),
  1873. aux_input_to_output_weights == nullptr
  1874. ? nullptr
  1875. : tflite::micro::GetTensorData<float>(
  1876. aux_input_to_output_weights),
  1877. recurrent_to_input_weights == nullptr
  1878. ? nullptr
  1879. : tflite::micro::GetTensorData<float>(
  1880. recurrent_to_input_weights),
  1881. recurrent_to_forget_weights == nullptr
  1882. ? nullptr
  1883. : tflite::micro::GetTensorData<float>(
  1884. recurrent_to_forget_weights),
  1885. recurrent_to_cell_weights == nullptr
  1886. ? nullptr
  1887. : tflite::micro::GetTensorData<float>(
  1888. recurrent_to_cell_weights),
  1889. recurrent_to_output_weights == nullptr
  1890. ? nullptr
  1891. : tflite::micro::GetTensorData<float>(
  1892. recurrent_to_output_weights),
  1893. cell_to_input_weights == nullptr
  1894. ? nullptr
  1895. : tflite::micro::GetTensorData<float>(cell_to_input_weights),
  1896. cell_to_forget_weights == nullptr
  1897. ? nullptr
  1898. : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
  1899. cell_to_output_weights == nullptr
  1900. ? nullptr
  1901. : tflite::micro::GetTensorData<float>(cell_to_output_weights),
  1902. input_layer_norm_coefficients == nullptr
  1903. ? nullptr
  1904. : tflite::micro::GetTensorData<float>(
  1905. input_layer_norm_coefficients),
  1906. forget_layer_norm_coefficients == nullptr
  1907. ? nullptr
  1908. : tflite::micro::GetTensorData<float>(
  1909. forget_layer_norm_coefficients),
  1910. cell_layer_norm_coefficients == nullptr
  1911. ? nullptr
  1912. : tflite::micro::GetTensorData<float>(
  1913. cell_layer_norm_coefficients),
  1914. output_layer_norm_coefficients == nullptr
  1915. ? nullptr
  1916. : tflite::micro::GetTensorData<float>(
  1917. output_layer_norm_coefficients),
  1918. input_gate_bias == nullptr
  1919. ? nullptr
  1920. : tflite::micro::GetTensorData<float>(input_gate_bias),
  1921. forget_gate_bias == nullptr
  1922. ? nullptr
  1923. : tflite::micro::GetTensorData<float>(forget_gate_bias),
  1924. cell_gate_bias == nullptr
  1925. ? nullptr
  1926. : tflite::micro::GetTensorData<float>(cell_gate_bias),
  1927. output_gate_bias == nullptr
  1928. ? nullptr
  1929. : tflite::micro::GetTensorData<float>(output_gate_bias),
  1930. projection_weights == nullptr
  1931. ? nullptr
  1932. : tflite::micro::GetTensorData<float>(projection_weights),
  1933. projection_bias == nullptr
  1934. ? nullptr
  1935. : tflite::micro::GetTensorData<float>(projection_bias),
  1936. params,
  1937. /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
  1938. output_batch_leading_dim, output_state_ptr, cell_state_ptr,
  1939. input_gate_scratch_ptr, forget_gate_scratch_ptr,
  1940. cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
  1941. }
  1942. }
  1943. }
  1944. return kTfLiteOk;
  1945. }
  1946. TfLiteStatus EvalHybridLstm(
  1947. const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
  1948. const TfLiteEvalTensor* input_to_input_weights,
  1949. const TfLiteEvalTensor* input_to_input_weights_ledger,
  1950. const TfLiteEvalTensor* input_to_forget_weights,
  1951. const TfLiteEvalTensor* input_to_forget_weights_ledger,
  1952. const TfLiteEvalTensor* input_to_cell_weights,
  1953. const TfLiteEvalTensor* input_to_cell_weights_ledger,
  1954. const TfLiteEvalTensor* input_to_output_weights,
  1955. const TfLiteEvalTensor* input_to_output_weights_ledger,
  1956. const TfLiteEvalTensor* recurrent_to_input_weights,
  1957. const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
  1958. const TfLiteEvalTensor* recurrent_to_forget_weights,
  1959. const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
  1960. const TfLiteEvalTensor* recurrent_to_cell_weights,
  1961. const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
  1962. const TfLiteEvalTensor* recurrent_to_output_weights,
  1963. const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
  1964. const TfLiteEvalTensor* cell_to_input_weights,
  1965. const TfLiteEvalTensor* cell_to_forget_weights,
  1966. const TfLiteEvalTensor* cell_to_output_weights,
  1967. const TfLiteEvalTensor* input_layer_norm_coefficients,
  1968. const TfLiteEvalTensor* forget_layer_norm_coefficients,
  1969. const TfLiteEvalTensor* cell_layer_norm_coefficients,
  1970. const TfLiteEvalTensor* output_layer_norm_coefficients,
  1971. const TfLiteEvalTensor* aux_input,
  1972. const TfLiteEvalTensor* aux_input_to_input_weights,
  1973. const TfLiteEvalTensor* aux_input_to_forget_weights,
  1974. const TfLiteEvalTensor* aux_input_to_cell_weights,
  1975. const TfLiteEvalTensor* aux_input_to_output_weights,
  1976. const TfLiteEvalTensor* input_gate_bias,
  1977. const TfLiteEvalTensor* forget_gate_bias,
  1978. const TfLiteEvalTensor* cell_gate_bias,
  1979. const TfLiteEvalTensor* output_gate_bias,
  1980. const TfLiteEvalTensor* projection_weights,
  1981. const TfLiteEvalTensor* projection_weights_ledger,
  1982. const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
  1983. bool forward_sequence, bool time_major, int output_offset,
  1984. float* scratch_buffer, float* input_sf, float* aux_input_sf,
  1985. float* output_state_sf, float* prod_scaling_factors,
  1986. float* recovered_cell_weights, int8_t* input_quantized,
  1987. int8_t* aux_input_quantized, int8_t* output_state_quantized,
  1988. int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
  1989. TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
  1990. TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
  1991. int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
  1992. bool* compute_row_sums) {
  1993. TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  1994. const int n_input = input->dims->data[input->dims->size - 1];
  1995. int max_time, n_batch;
  1996. if (input->dims->size == 2) {
  1997. max_time = 1;
  1998. n_batch = input->dims->data[0];
  1999. } else {
  2000. max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
  2001. n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  2002. }
  2003. const int aux_input_size =
  2004. (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
  2005. // n_cell and n_output will be the same size when there is no projection.
  2006. const int n_cell = input_to_output_weights->dims->data[0];
  2007. const int n_output = recurrent_to_output_weights->dims->data[1];
  2008. // Since we have already checked that weights are all there or none, we can
  2009. // check the existence of only one to get the condition.
  2010. const bool use_cifg = (input_to_input_weights == nullptr);
  float* input_gate_scratch = nullptr;
  float* cell_gate_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_gate_scratch = scratch_buffer;
    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer;
    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
  }
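  // scratch_buffer is carved into contiguous n_cell * n_batch segments, one
  // per gate: three (cell, forget, output) under CIFG, four otherwise. The
  // caller must therefore provide at least
  // (use_cifg ? 3 : 4) * n_cell * n_batch floats.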
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  int32_t* input_zp_ptr = nullptr;
  int32_t* aux_input_zp_ptr = nullptr;
  int32_t* output_state_zp_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_zp_ptr = input_zp;
    aux_input_zp_ptr = aux_input_zp;
    output_state_zp_ptr = output_state_zp;
    row_sums_ptr = row_sums;
  }
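  // Zero points and cached weight row sums only matter when activations are
  // quantized asymmetrically; with symmetric quantization the zero point is 0
  // and the row-sum correction term vanishes, so null pointers are passed
  // through instead.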
  if (time_major) {
    // Feed the sequence into the LSTM step-by-step.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // When forward_sequence is true, step forward through time; otherwise
      // step backward.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr =
          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr =
            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                          t_rel * output_step + output_offset;
      LstmStepHybrid(
          input_ptr,
          input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
          input_to_input_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_input_weights_ledger),
          hybrid_lstm_scales->input_to_input_weights_scale,
          input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
          input_to_forget_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_forget_weights_ledger),
          hybrid_lstm_scales->input_to_forget_weights_scale,
          input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
          input_to_cell_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_cell_weights_ledger),
          hybrid_lstm_scales->input_to_cell_weights_scale,
          input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
          input_to_output_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    input_to_output_weights_ledger),
          hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
          aux_input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_input_weights),
          hybrid_lstm_scales->aux_input_to_input_weights_scale,
          aux_input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_forget_weights),
          hybrid_lstm_scales->aux_input_to_forget_weights_scale,
          aux_input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(aux_input_to_cell_weights),
          hybrid_lstm_scales->aux_input_to_cell_weights_scale,
          aux_input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    aux_input_to_output_weights),
          hybrid_lstm_scales->aux_input_to_output_weights_scale,
          recurrent_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_input_weights),
          recurrent_to_input_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_input_weights_ledger),
          hybrid_lstm_scales->recurrent_to_input_weights_scale,
          recurrent_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_forget_weights),
          recurrent_to_forget_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_forget_weights_ledger),
          hybrid_lstm_scales->recurrent_to_forget_weights_scale,
          recurrent_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
          recurrent_to_cell_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_cell_weights_ledger),
          hybrid_lstm_scales->recurrent_to_cell_weights_scale,
          recurrent_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_output_weights),
          recurrent_to_output_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    recurrent_to_output_weights_ledger),
          hybrid_lstm_scales->recurrent_to_output_weights_scale,
          cell_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
          hybrid_lstm_scales->cell_to_input_weights_scale,
          cell_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
          hybrid_lstm_scales->cell_to_forget_weights_scale,
          cell_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
          hybrid_lstm_scales->cell_to_output_weights_scale,
          input_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    input_layer_norm_coefficients),
          forget_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    forget_layer_norm_coefficients),
          cell_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    cell_layer_norm_coefficients),
          output_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(
                    output_layer_norm_coefficients),
          input_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(input_gate_bias),
          forget_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(forget_gate_bias),
          cell_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(cell_gate_bias),
          output_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(output_gate_bias),
          projection_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(projection_weights),
          projection_weights_ledger == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<uint8_t>(
                    projection_weights_ledger),
          hybrid_lstm_scales->projection_weights_scale,
          projection_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<float>(projection_bias),
          params, n_batch, n_cell, n_input, aux_input_size, n_output,
          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
          cell_gate_scratch, output_gate_scratch, scales, input_sf,
          aux_input_sf, output_state_sf, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          output_state_quantized, cell_state_quantized,
          tflite::micro::GetTensorData<float>(output_state),
          tflite::micro::GetTensorData<float>(cell_state),
          output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
          output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
          params->asymmetric_quantize_inputs);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // When forward_sequence is true, step forward through time; otherwise
        // step backward.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
                                 time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
                          time_offset * input_step;
        }
        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;
        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            tflite::micro::GetTensorData<float>(output_state) +
            b * output_batch_leading_dim;
        float* cell_state_ptr =
            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
        LstmStepHybrid(
            input_ptr,
            input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
            input_to_input_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_input_weights_ledger),
            hybrid_lstm_scales->input_to_input_weights_scale,
            input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
            input_to_forget_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_forget_weights_ledger),
            hybrid_lstm_scales->input_to_forget_weights_scale,
            input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
            input_to_cell_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_cell_weights_ledger),
            hybrid_lstm_scales->input_to_cell_weights_scale,
            input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
            input_to_output_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      input_to_output_weights_ledger),
            hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
            aux_input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_input_weights),
            hybrid_lstm_scales->aux_input_to_input_weights_scale,
            aux_input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_forget_weights),
            hybrid_lstm_scales->aux_input_to_forget_weights_scale,
            aux_input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_cell_weights),
            hybrid_lstm_scales->aux_input_to_cell_weights_scale,
            aux_input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      aux_input_to_output_weights),
            hybrid_lstm_scales->aux_input_to_output_weights_scale,
            recurrent_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_input_weights),
            recurrent_to_input_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_input_weights_ledger),
            hybrid_lstm_scales->recurrent_to_input_weights_scale,
            recurrent_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_forget_weights),
            recurrent_to_forget_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_forget_weights_ledger),
            hybrid_lstm_scales->recurrent_to_forget_weights_scale,
            recurrent_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_cell_weights),
            recurrent_to_cell_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_cell_weights_ledger),
            hybrid_lstm_scales->recurrent_to_cell_weights_scale,
            recurrent_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_output_weights),
            recurrent_to_output_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      recurrent_to_output_weights_ledger),
            hybrid_lstm_scales->recurrent_to_output_weights_scale,
            cell_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
            hybrid_lstm_scales->cell_to_input_weights_scale,
            cell_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
            hybrid_lstm_scales->cell_to_forget_weights_scale,
            cell_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
            hybrid_lstm_scales->cell_to_output_weights_scale,
            input_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      input_layer_norm_coefficients),
            forget_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      forget_layer_norm_coefficients),
            cell_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      cell_layer_norm_coefficients),
            output_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(
                      output_layer_norm_coefficients),
            input_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(input_gate_bias),
            forget_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(forget_gate_bias),
            cell_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(cell_gate_bias),
            output_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(output_gate_bias),
            projection_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(projection_weights),
            projection_weights_ledger == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<uint8_t>(
                      projection_weights_ledger),
            hybrid_lstm_scales->projection_weights_scale,
            projection_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<float>(projection_bias),
            params,
            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
            output_batch_leading_dim, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, scales, input_sf, aux_input_sf,
            output_state_sf, prod_scaling_factors, recovered_cell_weights,
            input_quantized, aux_input_quantized, output_state_quantized,
            cell_state_quantized, output_state_ptr, cell_state_ptr,
            output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
            output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
            params->asymmetric_quantize_inputs);
      }
    }
  }
  return kTfLiteOk;
}
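
// A sketch of the buffer contract implied by the hybrid path above (sizes are
// inferred from usage here, not an authoritative spec): scratch_buffer holds
// (use_cifg ? 3 : 4) * n_cell * n_batch floats; each *_quantized buffer is
// assumed to hold one full activation's worth of int8 values (e.g.
// n_batch * n_input for input_quantized); row_sums/row_sums_size are consumed
// only when params->asymmetric_quantize_inputs is set, with *compute_row_sums
// presumably acting as a one-shot flag requesting recomputation of the cached
// weight row sums. For example, a hypothetical caller might provision
// (names assumed):
//   float gate_scratch[4 * kNCell * kNBatch];  // non-CIFG: 4 gate segments
//   int8_t input_q[kNBatch * kNInput];         // staging for quantized input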

TfLiteStatus EvalInteger8x8_16Lstm(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major,
    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  }
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];
  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
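      // Unlike the hybrid path, this time-major loop always steps forward:
      // t_rel is simply t, so forward_sequence is not honored here.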
      const int t_rel = t;
      int8_t* output_ptr =
          tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
      const int8_t* input_ptr =
          tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
      LstmStepInteger8x8_16(
          input_ptr,
          input_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
          integer_lstm_param->effective_input_to_input_scale_a,
          integer_lstm_param->effective_input_to_input_scale_b,
          input_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
          integer_lstm_param->effective_input_to_forget_scale_a,
          integer_lstm_param->effective_input_to_forget_scale_b,
          input_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
          integer_lstm_param->effective_input_to_cell_scale_a,
          integer_lstm_param->effective_input_to_cell_scale_b,
          input_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
          integer_lstm_param->effective_input_to_output_scale_a,
          integer_lstm_param->effective_input_to_output_scale_b,
          recurrent_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_input_weights),
          integer_lstm_param->effective_recurrent_to_input_scale_a,
          integer_lstm_param->effective_recurrent_to_input_scale_b,
          recurrent_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_forget_weights),
          integer_lstm_param->effective_recurrent_to_forget_scale_a,
          integer_lstm_param->effective_recurrent_to_forget_scale_b,
          recurrent_to_cell_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
          integer_lstm_param->effective_recurrent_to_cell_scale_a,
          integer_lstm_param->effective_recurrent_to_cell_scale_b,
          recurrent_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(
                    recurrent_to_output_weights),
          integer_lstm_param->effective_recurrent_to_output_scale_a,
          integer_lstm_param->effective_recurrent_to_output_scale_b,
          cell_to_input_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
          integer_lstm_param->effective_cell_to_input_scale_a,
          integer_lstm_param->effective_cell_to_input_scale_b,
          cell_to_forget_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
          integer_lstm_param->effective_cell_to_forget_scale_a,
          integer_lstm_param->effective_cell_to_forget_scale_b,
          cell_to_output_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
          integer_lstm_param->effective_cell_to_output_scale_a,
          integer_lstm_param->effective_cell_to_output_scale_b,
          projection_weights == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int8_t>(projection_weights),
          integer_lstm_param->effective_proj_scale_a,
          integer_lstm_param->effective_proj_scale_b,
          integer_lstm_param->hidden_zp,
          integer_lstm_param->effective_hidden_scale_a,
          integer_lstm_param->effective_hidden_scale_b,
          input_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    input_layer_norm_coefficients),
          integer_lstm_param->layer_norm_input_scale_a,
          integer_lstm_param->layer_norm_input_scale_b,
          forget_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    forget_layer_norm_coefficients),
          integer_lstm_param->layer_norm_forget_scale_a,
          integer_lstm_param->layer_norm_forget_scale_b,
          cell_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    cell_layer_norm_coefficients),
          integer_lstm_param->layer_norm_cell_scale_a,
          integer_lstm_param->layer_norm_cell_scale_b,
          output_layer_norm_coefficients == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int16_t>(
                    output_layer_norm_coefficients),
          integer_lstm_param->layer_norm_output_scale_a,
          integer_lstm_param->layer_norm_output_scale_b,
          input_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
          forget_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
          cell_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
          output_gate_bias == nullptr
              ? nullptr
              : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
          integer_lstm_param->quantized_cell_clip,
          integer_lstm_param->quantized_proj_clip,
          integer_lstm_param->cell_scale,
          integer_lstm_param->input_variance_guard,
          integer_lstm_param->forget_variance_guard,
          integer_lstm_param->cell_variance_guard,
          integer_lstm_param->output_variance_guard,
          integer_lstm_param->input_to_forget_effective_bias,
          integer_lstm_param->recurrent_to_forget_effective_bias,
          integer_lstm_param->input_to_cell_effective_bias,
          integer_lstm_param->recurrent_to_cell_effective_bias,
          integer_lstm_param->input_to_output_effective_bias,
          integer_lstm_param->recurrent_to_output_effective_bias,
          integer_lstm_param->input_to_input_effective_bias,
          integer_lstm_param->recurrent_to_input_effective_bias,
          integer_lstm_param->projection_effective_bias, n_batch, n_cell,
          n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
          output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
          output_ptr, scratch0, scratch1, scratch2, scratch3, scratch4,
          scratch5);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // When forward_sequence is true, step forward through time; otherwise
        // step backward.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
                                  time_offset * input_step;
        int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
                             time_offset * output_step;
        // Offset the {output,cell}_state pointers to the right batch.
        int8_t* output_state_ptr =
            tflite::micro::GetTensorData<int8_t>(output_state) +
            b * output_batch_leading_dim;
        int16_t* cell_state_ptr =
            tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
        LstmStepInteger8x8_16(
            input_ptr,
            input_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
            integer_lstm_param->effective_input_to_input_scale_a,
            integer_lstm_param->effective_input_to_input_scale_b,
            input_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
            integer_lstm_param->effective_input_to_forget_scale_a,
            integer_lstm_param->effective_input_to_forget_scale_b,
            input_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
            integer_lstm_param->effective_input_to_cell_scale_a,
            integer_lstm_param->effective_input_to_cell_scale_b,
            input_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
            integer_lstm_param->effective_input_to_output_scale_a,
            integer_lstm_param->effective_input_to_output_scale_b,
            recurrent_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_input_weights),
            integer_lstm_param->effective_recurrent_to_input_scale_a,
            integer_lstm_param->effective_recurrent_to_input_scale_b,
            recurrent_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_forget_weights),
            integer_lstm_param->effective_recurrent_to_forget_scale_a,
            integer_lstm_param->effective_recurrent_to_forget_scale_b,
            recurrent_to_cell_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_cell_weights),
            integer_lstm_param->effective_recurrent_to_cell_scale_a,
            integer_lstm_param->effective_recurrent_to_cell_scale_b,
            recurrent_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(
                      recurrent_to_output_weights),
            integer_lstm_param->effective_recurrent_to_output_scale_a,
            integer_lstm_param->effective_recurrent_to_output_scale_b,
            cell_to_input_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
            integer_lstm_param->effective_cell_to_input_scale_a,
            integer_lstm_param->effective_cell_to_input_scale_b,
            cell_to_forget_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
            integer_lstm_param->effective_cell_to_forget_scale_a,
            integer_lstm_param->effective_cell_to_forget_scale_b,
            cell_to_output_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
            integer_lstm_param->effective_cell_to_output_scale_a,
            integer_lstm_param->effective_cell_to_output_scale_b,
            projection_weights == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int8_t>(projection_weights),
            integer_lstm_param->effective_proj_scale_a,
            integer_lstm_param->effective_proj_scale_b,
            integer_lstm_param->hidden_zp,
            integer_lstm_param->effective_hidden_scale_a,
            integer_lstm_param->effective_hidden_scale_b,
            input_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      input_layer_norm_coefficients),
            integer_lstm_param->layer_norm_input_scale_a,
            integer_lstm_param->layer_norm_input_scale_b,
            forget_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      forget_layer_norm_coefficients),
            integer_lstm_param->layer_norm_forget_scale_a,
            integer_lstm_param->layer_norm_forget_scale_b,
            cell_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      cell_layer_norm_coefficients),
            integer_lstm_param->layer_norm_cell_scale_a,
            integer_lstm_param->layer_norm_cell_scale_b,
            output_layer_norm_coefficients == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int16_t>(
                      output_layer_norm_coefficients),
            integer_lstm_param->layer_norm_output_scale_a,
            integer_lstm_param->layer_norm_output_scale_b,
            input_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
            forget_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
            cell_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
            output_gate_bias == nullptr
                ? nullptr
                : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
            integer_lstm_param->quantized_cell_clip,
            integer_lstm_param->quantized_proj_clip,
            integer_lstm_param->cell_scale,
            integer_lstm_param->input_variance_guard,
            integer_lstm_param->forget_variance_guard,
            integer_lstm_param->cell_variance_guard,
            integer_lstm_param->output_variance_guard,
            integer_lstm_param->input_to_forget_effective_bias,
            integer_lstm_param->recurrent_to_forget_effective_bias,
            integer_lstm_param->input_to_cell_effective_bias,
            integer_lstm_param->recurrent_to_cell_effective_bias,
            integer_lstm_param->input_to_output_effective_bias,
            integer_lstm_param->recurrent_to_output_effective_bias,
            integer_lstm_param->input_to_input_effective_bias,
            integer_lstm_param->recurrent_to_input_effective_bias,
            integer_lstm_param->projection_effective_bias, /*n_batch=*/1,
            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
            cell_state_ptr, output_ptr, scratch0, scratch1, scratch2, scratch3,
            scratch4, scratch5);
      }
    }
  }
  return kTfLiteOk;
}
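
// A sketch of the scratch contract implied by the signature above (element
// counts are assumptions based on per-gate usage, not an authoritative spec):
// scratch0..scratch3 are int16 gate buffers of n_batch * n_cell elements each,
// scratch4 is an int8 buffer for the quantized hidden state, and scratch5 is
// an int32 accumulator; all must be provisioned by the caller before invoking
// EvalInteger8x8_16Lstm.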

TfLiteStatus EvalInteger8x8_8Lstm(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
    int32_t input_zp, int32_t output_state_zp, int8_t* scratch0,
    int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4,
    int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = input->dims->data[0];
    n_batch = input->dims->data[1];
  }
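  // There is no time_major flag in this variant: a 3-D input is
  // unconditionally interpreted as time-major, [max_time, n_batch, n_input].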
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];
  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  const int input_step = n_batch * n_input;
  const int output_step = n_batch * output_batch_leading_dim;
  for (int t = 0; t < max_time; t++) {
    const int t_rel = t;
    int8_t* output_ptr =
        tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
    // Input can be int8 asymmetric or int16 symmetric; this implementation
    // reads it as int8, carrying the zero point in input_zp.
    const int8_t* input_ptr =
        tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
    LstmStepInteger8x8_8(
        input_ptr, input_zp,
        input_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
        integer_lstm_param->effective_input_to_input_scale_a,
        integer_lstm_param->effective_input_to_input_scale_b,
        input_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
        integer_lstm_param->effective_input_to_forget_scale_a,
        integer_lstm_param->effective_input_to_forget_scale_b,
        input_to_cell_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
        integer_lstm_param->effective_input_to_cell_scale_a,
        integer_lstm_param->effective_input_to_cell_scale_b,
        input_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
        integer_lstm_param->effective_input_to_output_scale_a,
        integer_lstm_param->effective_input_to_output_scale_b,
        recurrent_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
        integer_lstm_param->effective_recurrent_to_input_scale_a,
        integer_lstm_param->effective_recurrent_to_input_scale_b,
        recurrent_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
        integer_lstm_param->effective_recurrent_to_forget_scale_a,
        integer_lstm_param->effective_recurrent_to_forget_scale_b,
        recurrent_to_cell_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
        integer_lstm_param->effective_recurrent_to_cell_scale_a,
        integer_lstm_param->effective_recurrent_to_cell_scale_b,
        recurrent_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
        integer_lstm_param->effective_recurrent_to_output_scale_a,
        integer_lstm_param->effective_recurrent_to_output_scale_b,
        cell_to_input_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
        integer_lstm_param->effective_cell_to_input_scale_a,
        integer_lstm_param->effective_cell_to_input_scale_b,
        cell_to_forget_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
        integer_lstm_param->effective_cell_to_forget_scale_a,
        integer_lstm_param->effective_cell_to_forget_scale_b,
        cell_to_output_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
        integer_lstm_param->effective_cell_to_output_scale_a,
        integer_lstm_param->effective_cell_to_output_scale_b,
        projection_weights == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int8_t>(projection_weights),
        integer_lstm_param->effective_proj_scale_a,
        integer_lstm_param->effective_proj_scale_b,
        input_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  input_layer_norm_coefficients),
        integer_lstm_param->layer_norm_input_scale_a,
        integer_lstm_param->layer_norm_input_scale_b,
        forget_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  forget_layer_norm_coefficients),
        integer_lstm_param->layer_norm_forget_scale_a,
        integer_lstm_param->layer_norm_forget_scale_b,
        cell_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  cell_layer_norm_coefficients),
        integer_lstm_param->layer_norm_cell_scale_a,
        integer_lstm_param->layer_norm_cell_scale_b,
        output_layer_norm_coefficients == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int16_t>(
                  output_layer_norm_coefficients),
        integer_lstm_param->layer_norm_output_scale_a,
        integer_lstm_param->layer_norm_output_scale_b,
        input_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
        forget_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
        cell_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
        output_gate_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
        projection_bias == nullptr
            ? nullptr
            : tflite::micro::GetTensorData<int32_t>(projection_bias),
        params, integer_lstm_param->intermediate_scale_a,
        integer_lstm_param->intermediate_scale_b,
        integer_lstm_param->intermediate_zp,
        integer_lstm_param->quantized_cell_clip,
        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
        n_output, output_batch_leading_dim,
        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr, scratch0,
        scratch1, scratch2, scratch3, scratch4, scratch5, scratch6, scratch7);
  }
  return kTfLiteOk;
}
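
// As the parameter lists above suggest, the 8x8_16 kernel keeps an int16 cell
// state and folds zero-point corrections into precomputed effective biases,
// while the 8x8_8 kernel instead threads explicit intermediate scales and zero
// points (intermediate_scale_a/b, intermediate_zp) through each gate. Which
// kernel runs is decided elsewhere, presumably by the op's prepare logic,
// from the model's quantization parameters.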

}  // namespace tflite