Przeglądaj źródła

Rolling 20220716_2

jomjol 3 lat temu
rodzic
commit
17fd0f96df
100 zmienionych plików z 5715 dodań i 976 usunięć
  1. 4 2
      README.md
  2. 3 1
      code/components/esp-nn/CMakeLists.txt
  3. 4 4
      code/components/esp-nn/Kconfig.projbuild
  4. 4 3
      code/components/esp-nn/README.md
  5. 5 5
      code/components/esp-nn/include/esp_nn.h
  6. 1 0
      code/components/esp-nn/include/esp_nn_ansi_c.h
  7. 82 56
      code/components/esp-nn/include/esp_nn_ansi_headers.h
  8. 83 0
      code/components/esp-nn/include/esp_nn_defs.h
  9. 24 54
      code/components/esp-nn/include/esp_nn_esp32s3.h
  10. 9 10
      code/components/esp-nn/include/esp_nn_generic_opt.h
  11. 50 13
      code/components/esp-nn/src/common/common_functions.h
  12. 30 26
      code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
  13. 106 79
      code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
  14. 179 0
      code/components/esp-nn/src/convolution/esp_nn_conv_opt.c
  15. 30 27
      code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
  16. 291 0
      code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c
  17. 97 37
      code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
  18. 8 0
      code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3
  19. 16 4
      code/components/esp-nn/tests/src/basic_math_test.c
  20. 70 36
      code/components/esp-nn/tests/src/convolution_test.c
  21. BIN
      code/components/esp-nn_20220716.zip
  22. 1 1
      code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
  23. 4 1
      code/components/tflite-lib/CMakeLists.txt
  24. 2 0
      code/components/tflite-lib/tensorflow/lite/builtin_ops.h
  25. 3 0
      code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h
  26. 7 1
      code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
  27. 40 10
      code/components/tflite-lib/tensorflow/lite/c/common.cc
  28. 52 3
      code/components/tflite-lib/tensorflow/lite/c/common.h
  29. 11 0
      code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
  30. 52 1
      code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h
  31. 3 3
      code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h
  32. 6 3
      code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h
  33. 5 5
      code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h
  34. 1 1
      code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
  35. 3 3
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h
  36. 165 0
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc
  37. 104 0
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h
  38. 52 0
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc
  39. 59 0
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h
  40. 1 1
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc
  41. 4 4
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h
  42. 1 1
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc
  43. 4 4
      code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h
  44. 1 1
      code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
  45. 7 20
      code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc
  46. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc
  47. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
  48. 2 16
      code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc
  49. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
  50. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
  51. 3 9
      code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
  52. 3 9
      code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
  53. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
  54. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
  55. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
  56. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc
  57. 12 48
      code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
  58. 5 11
      code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
  59. 40 22
      code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc
  60. 10 0
      code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h
  61. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
  62. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
  63. 3 10
      code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc
  64. 27 1
      code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h
  65. 9 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc
  66. 3 2
      code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
  67. 1 10
      code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
  68. 289 79
      code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
  69. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
  70. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
  71. 46 21
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
  72. 49 22
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
  73. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
  74. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc
  75. 2 16
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc
  76. 208 0
      code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc
  77. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc
  78. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc
  79. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc
  80. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc
  81. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc
  82. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc
  83. 19 12
      code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc
  84. 19 1
      code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h
  85. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc
  86. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc
  87. 2 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc
  88. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc
  89. 3 2
      code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc
  90. 3 2
      code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h
  91. 15 0
      code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc
  92. 22 3
      code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h
  93. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc
  94. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc
  95. 2 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc
  96. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc
  97. 2 20
      code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc
  98. 1 8
      code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc
  99. 2955 0
      code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc
  100. 250 0
      code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h

+ 4 - 2
README.md

@@ -54,8 +54,10 @@ In other cases you can contact the developer via email: <img src="https://raw.gi
 
 ##### Rolling (2022-07-16)
 
-- Updated esp32cam
+- TFMicro/Lite: Update (espressif Version 20220716)
+- Updated esp32cam (v20220716)
 - Integrated new analog classificational CNN (from @haverland)
+- Bugfix: Postprocessing
 
 ##### Rolling (2022-07-01)
 
@@ -79,7 +81,7 @@ Rolling (2022-04-26)
 - Extended MQTT with absolute Change (in addition to rate)
 - Internal optimization, removal of modelfile from `config.ini` (is now read out of the cnn file directly)
 
-- TFMicro/Lite: Update (espressif Verision 20220417)
+- TFMicro/Lite: Update (espressif Version 20220417)
 - ESP-IDF: Update to 4.3.0
 
 Rolling (2022-04-17)

+ 3 - 1
code/components/esp-nn/CMakeLists.txt

@@ -5,7 +5,9 @@ set(c_srcs
     "src/basic_math/esp_nn_add_ansi.c"
     "src/basic_math/esp_nn_mul_ansi.c"
     "src/convolution/esp_nn_conv_ansi.c"
+    "src/convolution/esp_nn_conv_opt.c"
     "src/convolution/esp_nn_depthwise_conv_ansi.c"
+    "src/convolution/esp_nn_depthwise_conv_opt.c"
     "src/fully_connected/esp_nn_fully_connected_ansi.c"
     "src/softmax/esp_nn_softmax_ansi.c"
     "src/softmax/esp_nn_softmax_opt.c"
@@ -23,7 +25,7 @@ if(CONFIG_IDF_TARGET_ESP32S3)
         "src/convolution/esp_nn_conv_esp32s3.c"
         "src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c"
         "src/convolution/esp_nn_conv_s16_mult8_esp32s3.S"
-        "src/convolution/esp_nn_conv_s16_mult8_1x1_esp32s3.S"
+        "src/convolution/esp_nn_conv_s8_mult8_1x1_esp32s3.S"
         "src/convolution/esp_nn_conv_s16_mult4_1x1_esp32s3.S"
         "src/convolution/esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3.S"
         "src/convolution/esp_nn_depthwise_conv_s16_mult1_esp32s3.S"

+ 4 - 4
code/components/esp-nn/Kconfig.projbuild

@@ -6,8 +6,8 @@ choice NN_OPTIMIZATIONS
    help
       Use ANSI-C versions for verification and debug purpose.
       Optimisations are automatically picked up for a chipset.
-      For ESP32-S3, assembly Optimisations are selected.
-      For ESP32, just the ANSI C versions are selected for now.
+      For ESP32-S3, assembly optimisations are selected.
+      For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.
 
 config NN_ANSI_C
    bool "ANSI C"
@@ -17,8 +17,8 @@ config NN_OPTIMIZED
    bool "Optimized versions"
    help
       Optimisations are automatically picked up for a chipset.
-      For ESP32-S3, assembly Optimisations are selected.
-      For ESP32, just the ANSI C versions are selected for now.
+      For ESP32-S3, assembly optimisations are selected.
+      For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.
 endchoice
 
 config NN_OPTIMIZATIONS

+ 4 - 3
code/components/esp-nn/README.md

@@ -7,7 +7,8 @@ The library contains optimised NN (Neural Network) functions for various Espress
 
 * Supported ESP chipsets include:
    * ESP32-S3 (Assembly versions optimised to benefit from vector instructions of ESP32-S3)
-   * ESP32 (ANSI C versions)
+   * ESP32 (Generic optimisations)
+   * ESP32-C3 (Generic optimisations)
 
 ## Performance
 
@@ -39,8 +40,8 @@ The library contains optimised NN (Neural Network) functions for various Espress
      * Optimized versions
      * ANSI C
 
-  * Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for ESP32,  ANSI-C versions are selected by default.
-  * For debugging purposes, you may want to select `ANSI C`
+  * Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for other chipsets (viz., ESP32, ESP32-C3), generic optimisations are selected.
+  * For debugging purposes, you may want to select `ANSI C` reference versions.
 
 
 ## Contributing

+ 5 - 5
code/components/esp-nn/include/esp_nn.h

@@ -15,6 +15,7 @@
 #pragma once
 
 #if defined(CONFIG_NN_OPTIMIZED)
+// select apt optimisations
 #ifdef CONFIG_IDF_TARGET_ESP32S3
 #define ARCH_ESP32_S3 1
 #endif
@@ -31,12 +32,11 @@ extern "C" {
 #include "esp_nn_ansi_headers.h"
 
 #if defined(CONFIG_NN_OPTIMIZED)
-#ifdef ARCH_ESP32_S3
+#if defined(ARCH_ESP32_S3)
 #include "esp_nn_esp32s3.h"
-#endif
-#ifdef ARCH_ESP32
-#include "esp_nn_esp32.h"
-#endif
+#else // for other platforms use generic optimisations
+#include "esp_nn_generic_opt.h"
+#endif // #if defined(ARCH_ESP32_S3)
 #else
 #include "esp_nn_ansi_c.h"
 #endif

+ 1 - 0
code/components/esp-nn/include/esp_nn_ansi_c.h

@@ -19,6 +19,7 @@
 
 #pragma once
 
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi

+ 82 - 56
code/components/esp-nn/include/esp_nn_ansi_headers.h

@@ -18,8 +18,7 @@
  * @file        Header definitions to include for esp_nn reference functions
  */
 
-#include <stdint.h>
-
+#include "esp_nn_defs.h"
 /************************** Basic math functions ****************************/
 
 /**
@@ -81,28 +80,15 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
  *              optimization notes: Though input_offset is int32 type,
  *              offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max);
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data);
 
 /**
  * @brief       2d-convolution channelwise
@@ -112,43 +98,26 @@ void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
  *              inputs type: int8_t, output: int8_t
  *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max);
-
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht);
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
 
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
 
 /************************** Activation functions *****************************/
@@ -252,9 +221,6 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h
  */
 void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);
 
-/* ANSI C function to be hooked up when optimised version needed */
-void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
-
 /**
  * @brief       reference softmax function
  *
@@ -268,6 +234,66 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
                             const int32_t diff_min,
                             int8_t *output_data);
 
+
+//////////////////////////// Generic optimisations /////////////////////////////
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief       2d-convolution channelwise optimized version
+ *
+ * @note        operation: result += (input + offset) * filter
+ *
+ *              inputs type: int8_t, output: int8_t
+ *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data);
+
+/**
+ * @brief       depthwise convolution per channel optimized version
+ *
+ * @note        inputs type: int8_t, output: int8_t
+ *              Version used in tflite is per channel.
+ *              This version follows the same footsprints.
+ *              Meaning, it has per out_channel shift and multiplier for
+ *              requantization
+ *
+ *              optimization notes: Though input_offset is int32 type,
+ *              offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params);
+void esp_nn_set_conv_scratch_buf_opt(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params);
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
 /**
  * @brief       optimised version of softmax function
  *

+ 83 - 0
code/components/esp-nn/include/esp_nn_defs.h

@@ -0,0 +1,83 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+/**
+ * @brief structure to club data dims
+ * this structure can be used for input, output and filter
+ */
+typedef struct data_dims {
+    int32_t width;
+    int32_t height;
+    int32_t channels;
+
+    int32_t extra; // can be used as batch or any other param
+} data_dims_t;
+
+/**
+ * @brief 2d data structure (width, height)
+ *
+ */
+typedef struct data_2d {
+    int32_t width;
+    int32_t height;
+} data_2d_t;
+
+/**
+ * @brief min/max activation
+ */
+typedef struct act_params {
+    int32_t min;
+    int32_t max;
+} act_params_t;
+
+/**
+ * @brief per channel quant data
+ *
+ * @note number of shift and mult elements are equal to output channels
+ */
+typedef struct quant_data {
+    int32_t *shift;
+    int32_t *mult;
+} quant_data_t;
+
+/**
+ * @brief params specific to convolution 2d
+ *
+ */
+typedef struct conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} conv_params_t;
+
+/**
+ * @brief params specific to depthwise convolution 2d
+ *
+ */
+typedef struct dw_conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch)
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} dw_conv_params_t;

+ 24 - 54
code/components/esp-nn/include/esp_nn_esp32s3.h

@@ -19,7 +19,7 @@
 
 #pragma once
 
-#include <stdint.h>
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 /************************** Basic math functions *****************************/
@@ -85,28 +85,15 @@ void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
  *              optimization notes: Though input_offset is int32 type,
  *              offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
-                                      int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max);
+                                      const data_dims_t *output_dims,
+                                      int8_t *output_data,
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data);
 
 /**
  * @brief       2d - convolution channelwise
@@ -116,43 +103,26 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
  *              inputs type: int8_t, output: int8_t
  *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_esp32s3(const int8_t *input_data,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t in_channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input_data,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
-                            int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max);
-
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht);
+                            const data_dims_t *output_dims,
+                            int8_t *output_data,
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
 
 /************************** Pooling functions *****************************/

+ 9 - 10
code/components/esp-nn/include/esp_nn_esp32.h → code/components/esp-nn/include/esp_nn_generic_opt.h

@@ -13,28 +13,27 @@
 // limitations under the License.
 
 /**
- * @file        Header definitions to include for esp_nn optimized functions for
- *              the ESP32 platform.
- *              We are hooking up just the C versions for now.
- *              The file hence is exactly same as `esp_nn_ansi_c.h`
+ * @file        Header definitions to include for esp_nn generic optimisations
+ *              For functions which not having optimisations, _ansi versions are picked.
  */
 
 #pragma once
 
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
 #define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
 
-#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt
 
-#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+#define esp_nn_conv_s8 esp_nn_conv_s8_opt
 
-#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
-#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt
 
-#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
-#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt
 
 #define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
 

+ 50 - 13
code/components/esp-nn/src/common/common_functions.h

@@ -41,15 +41,39 @@
 
 __NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("nsau %0, %0" : "+r" (in));
     return in;
-}
-
-__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
-{
-    int32_t sign = (int32_t) (val64 >> 63);
-    int32_t to_add = sign & ((1ul << 31) - 1);
-    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+#elif defined(__GNUC__)
+    return __builtin_clz(in);
+#else
+    int32_t count = 32;
+    uint32_t x = in, y = in >> 16;
+    if (y != 0) {
+        count -= 16;
+        x = y;
+    }
+    y = x >> 8;
+    if (y != 0) {
+        count -= 8;
+        x = y;
+    }
+    y = x >> 4;
+    if (y != 0) {
+        count -= 4;
+        x = y;
+    }
+    y = x >> 2;
+    if (y != 0) {
+        count -= 2;
+        x = y;
+    }
+    y = x >> 1;
+    if (y != 0) {
+        return count - 2;
+    }
+    return count - x;
+#endif
 }
 
 /**
@@ -57,8 +81,19 @@ __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
  */
 __NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
     return in;
+#else
+    return max(INT8_MIN, min(in, INT8_MAX));
+#endif
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+    int32_t sign = (int32_t) (val64 >> 63);
+    int32_t to_add = sign & ((1ul << 31) - 1);
+    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
 }
 
 __NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
@@ -144,7 +179,7 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
                                              const uint16_t pad_ht)
 {
     /* memset with pad_val */
-    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
     dst += (pad_wd + input_wd + pad_wd) * channels;
 
     for (int i = 0; i < input_ht; i++) {
@@ -156,7 +191,6 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
     }
 }
 
-#if 0
 static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
                                                  const uint16_t input_wd,
                                                  const uint16_t input_ht,
@@ -169,13 +203,16 @@ static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
         for (int j = 0; j < input_wd * channels; j++) {
             *dst++ = *src++;
         }
-        memset(dst, pad_val, pad_wd * channels);
-        dst += pad_wd * channels;
+        if (pad_wd) {
+            memset(dst, pad_val, pad_wd * channels);
+            dst += pad_wd * channels;
+        }
     }
     /* pad end `pad_ht` lines at end */
-    memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    if (pad_ht) {
+        memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    }
 }
-#endif
 
 /**
  * @brief       convert 8 bit input data to 16 bit

+ 30 - 26
code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c

@@ -12,16 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
+#include <esp_nn_defs.h>
 
 #include <common_functions.h>
 
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params)
 {
     return 0;
 }
@@ -108,29 +106,35 @@ void esp_nn_conv_u8_ansi(const uint8_t *input_data,
  * Assumption 2: Pointers are valid
  * Assumption 3: dialation width = 1
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max)
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
 
     for (out_y = 0; out_y < out_ht; out_y++) {

+ 106 - 79
code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c

@@ -12,30 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>
 
 #include <common_functions.h>
 
 static int16_t *scratch_buffer = NULL;
 
-extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
-                                              const uint16_t input_wd,
-                                              const uint16_t input_ht,
-                                              const uint16_t in_channels,
-                                              const int32_t input_offset,
-                                              const int16_t *filter_data,
-                                              const int32_t *bias,
-                                              int8_t *out_data,
-                                              const uint16_t out_wd,
-                                              const uint16_t out_ht,
-                                              const uint16_t out_channels,
-                                              const int32_t out_offset,
-                                              const int32_t *out_shift,
-                                              const int32_t *out_mult,
-                                              const int32_t activation_min,
-                                              const int32_t activation_max,
-                                              void *buffer /* scratch buffer */);
+extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data,
+                                             const uint16_t input_wd,
+                                             const uint16_t input_ht,
+                                             const uint16_t in_channels,
+                                             const int32_t input_offset,
+                                             const int8_t *filter_aligned,
+                                             const int32_t *bias,
+                                             int8_t *out_data,
+                                             const uint16_t out_wd,
+                                             const uint16_t out_ht,
+                                             const uint16_t out_channels,
+                                             const int32_t out_offset,
+                                             const int32_t *out_shift,
+                                             const int32_t *out_mult,
+                                             const int32_t activation_min,
+                                             const int32_t activation_max,
+                                             void *buffer /* scratch buffer */);
 
 extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
                                               const uint16_t input_wd,
@@ -81,34 +81,40 @@ extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int1
 
 extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
 
-static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
-                                    const uint16_t input_wd,
-                                    const uint16_t input_ht,
-                                    const uint16_t in_channels,
-                                    const int32_t input_offset,
-                                    const uint16_t pad_wd,
-                                    const uint16_t pad_ht,
-                                    const uint16_t stride_wd,
-                                    const uint16_t stride_ht,
+static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
+                                    const int8_t *input_data,
+                                    const data_dims_t *filter_dims,
                                     const int8_t *filter_data,
-                                    const uint16_t filter_wd,
-                                    const uint16_t filter_ht,
                                     const int32_t *bias,
+                                    const data_dims_t *output_dims,
                                     int8_t *out_data,
-                                    const uint16_t out_wd,
-                                    const uint16_t out_ht,
-                                    const uint16_t out_channels,
-                                    const int32_t out_offset,
-                                    const int32_t *out_shift,
-                                    const int32_t *out_mult,
-                                    const int32_t activation_min,
-                                    const int32_t activation_max)
+                                    const conv_params_t *conv_params,
+                                    const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
 
     for (out_y = 0; out_y < out_ht; out_y++) {
         for (out_x = 0; out_x < out_wd; out_x++) {
-            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+            for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
                 int32_t conv_out = 0;
 
                 const int32_t base_y = stride_ht * out_y - pad_ht;
@@ -124,10 +130,10 @@ static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
                     for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                         const int32_t in_row = base_y + filter_y_idx;
                         const int32_t in_col = base_x + filter_x_idx;
-                        int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
-                        int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
-                                                       (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
-                        for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+                        int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
+                        int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
+                                                       (filter_y_idx * filter_wd + filter_x_idx) * in_ch;
+                        for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
                             conv_out +=
                                 (input_data[input_base_offset + in_ch_idx] + input_offset) *
                                 filter_data[filter_base_offset + in_ch_idx];
@@ -332,18 +338,35 @@ static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
     }
 }
 
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * in_ch * out_ch;
     int input_size = input_wd * input_ht * in_ch;
-    int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
+
+    int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
+    if (input_wd * input_ht < 8) {
+        transpose_buf_size = 0; // not using this for leftover
+    }
     int align_buf_size = 32; /* extra buffer for alignment */
-    return 2 * (filter_size + input_size +  transpose_buf_size) + align_buf_size;
+    if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
+            pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+        return filter_size + transpose_buf_size + align_buf_size;
+    }
+    return 2 * (filter_size + input_size) +  transpose_buf_size + align_buf_size;
 }
 
 void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
@@ -351,29 +374,35 @@ void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
     scratch_buffer = (int16_t *) buf;
 }
 
-void esp_nn_conv_s8_esp32s3(const int8_t *input,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
+                            const data_dims_t *output_dims,
                             int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max)
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int filter_size = filter_wd * filter_ht * channels * out_channels;
     int input_size = input_wd * input_ht * channels;
     int align_len = 16 - (filter_size & 15);
@@ -387,15 +416,16 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
 
     if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
             pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
-        int scratch_offset = (int) (filter_data16 + filter_size);
+        int8_t *filter_aligned = (int8_t *) scratch_buffer;
+        int scratch_offset = (int) (filter_aligned + filter_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
-        esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
-        esp_nn_conv_s16_mult8_1x1_esp32s3(
-            input, input_wd, input_ht, channels, input_offset, filter_data16,
+        memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
+        esp_nn_conv_s8_mult8_1x1_esp32s3(
+            input, input_wd, input_ht, channels, input_offset, filter_aligned,
             bias, out_data, out_wd, out_ht, out_channels, out_offset,
             out_shift, out_mult, activation_min, activation_max, scratch_buf);
     } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
-            (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
+            (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
             pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
         int scratch_offset = (int) (input_data16 + input_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
@@ -427,10 +457,7 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
         }
     } else {
         /* Basic unrolled version */
-        esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
-                                pad_wd, pad_ht, stride_wd, stride_ht,
-                                filter_data, filter_wd, filter_ht, bias,
-                                out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
-                                out_mult, activation_min, activation_max);
+        esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
+                                bias, output_dims, out_data, conv_params, quant_data);
     }
 }

+ 179 - 0
code/components/esp-nn/src/convolution/esp_nn_conv_opt.c

@@ -0,0 +1,179 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+
+#include <common_functions.h>
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+__attribute__ ((noinline))
+static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
+                               const int8_t *input_data,
+                               const int8_t *filter_data,
+                               const int32_t *bias,
+                               const data_dims_t *output_dims,
+                               int8_t *out_data,
+                               const conv_params_t *conv_params,
+                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
+        for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
+            const int32_t *out_mult = quant_data->mult;
+            const int32_t *out_shift = quant_data->shift;
+            const int8_t *filter_ptr = filter_data;
+            const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
+            int32_t out_ch_idx = 0;
+            for (; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int8_t *input_ptr = input_base_ptr;
+
+                int32_t in_ch_idx = 0;
+                for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dialation width = 1
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data)
+{
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+
+    if (filter_wd == 1 && filter_ht == 1) {
+        esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
+                           output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
+
+    for (out_y = 0; out_y < out_ht; out_y++) {
+        for (out_x = 0; out_x < out_wd; out_x++) {
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int32_t base_y = stride_ht * out_y - pad_ht;
+                const int32_t base_x = stride_wd * out_x - pad_wd;
+
+                const int32_t filter_y_start = max(0, -base_y);
+                const int32_t filter_x_start = max(0, -base_x);
+
+                const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+                const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+                for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t in_row = base_y + filter_y_idx;
+                        const int32_t in_col = base_x + filter_x_idx;
+
+                        const int8_t *input_ptr = input_data +
+                                        (in_row * input_wd + in_col) * in_channels;
+                        const int8_t *filter_ptr = filter_data +
+                                        out_ch_idx * in_channels * filter_ht * filter_wd +
+                                        (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+                        int32_t in_ch_idx = 0;
+                        for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                        for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                    }
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}

+ 30 - 27
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c

@@ -12,16 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
-
+#include <esp_nn_defs.h>
 #include <common_functions.h>
 
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params)
 {
     return 0;
 }
@@ -31,29 +28,35 @@ void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
 
 }
 
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max)
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int out_idx = 0;
     for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
         const int16_t base_y = (out_y * stride_ht) - pad_ht;

+ 291 - 0
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c

@@ -0,0 +1,291 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+#include <common_functions.h>
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+/* common channel multiplier == 1 case */
+__attribute__ ((noinline))
+static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
+                                               const int8_t *input_data,
+                                               const data_dims_t *filter_dims,
+                                               const int8_t *filter_data,
+                                               const int32_t *bias,
+                                               const data_dims_t *output_dims,
+                                               int8_t *out_data,
+                                               const dw_conv_params_t *conv_params,
+                                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Select filter so as the point doesn't lie outside block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            int ch_idx = 0;
+            for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop
+                int32_t result0 = 0;
+                int32_t result1 = 0;
+                int32_t result2 = 0;
+                int32_t result3 = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val0 = input_data[input_index + 0] + input_offset;
+                        int32_t input_val1 = input_data[input_index + 1] + input_offset;
+                        int32_t input_val2 = input_data[input_index + 2] + input_offset;
+                        int32_t input_val3 = input_data[input_index + 3] + input_offset;
+                        int32_t filter_val0 = filter_data[filter_index + 0];
+                        int32_t filter_val1 = filter_data[filter_index + 1];
+                        int32_t filter_val2 = filter_data[filter_index + 2];
+                        int32_t filter_val3 = filter_data[filter_index + 3];
+                        result0 += input_val0 * filter_val0;
+                        result1 += input_val1 * filter_val1;
+                        result2 += input_val2 * filter_val2;
+                        result3 += input_val3 * filter_val3;
+                    }
+                }
+                if (bias) {
+                    result0 += bias[ch_idx + 0];
+                    result1 += bias[ch_idx + 1];
+                    result2 += bias[ch_idx + 2];
+                    result3 += bias[ch_idx + 3];
+                }
+                result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                result0 += out_offset;
+                result1 += out_offset;
+                result2 += out_offset;
+                result3 += out_offset;
+
+                result0 = max(result0, activation_min);
+                result1 = max(result1, activation_min);
+                result2 = max(result2, activation_min);
+                result3 = max(result3, activation_min);
+
+                result0 = min(result0, activation_max);
+                result1 = min(result1, activation_max);
+                result2 = min(result2, activation_max);
+                result3 = min(result3, activation_max);
+
+                out_data[out_idx++] = result0;
+                out_data[out_idx++] = result1;
+                out_data[out_idx++] = result2;
+                out_data[out_idx++] = result3;
+            }
+            for (; ch_idx < channels; ch_idx++) {//channel_loop
+                int32_t result = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val = input_data[input_index] + input_offset;
+                        int32_t filter_val = filter_data[filter_index];
+                        result += input_val * filter_val;
+                    }
+                }
+                if (bias) {
+                    result += bias[ch_idx];
+                }
+                result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                result += out_offset;
+                result = max(result, activation_min);
+                result = min(result, activation_max);
+
+                out_data[out_idx++] = result;
+            }
+        }
+    }
+}
+
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data)
+{
+    const uint16_t ch_mult = conv_params->ch_mult;
+    if (ch_mult == 1) {
+        esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
+                                           bias, output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Select filter so as the point doesn't lie outside block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+                int ch_mult_idx = 0;
+                for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
+                    int32_t result0 = 0;
+                    int32_t result1 = 0;
+                    int32_t result2 = 0;
+                    int32_t result3 = 0;
+                    const int out_ch_idx =  ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val0 = filter_data[filter_index + 0];
+                            int32_t filter_val1 = filter_data[filter_index + 1];
+                            int32_t filter_val2 = filter_data[filter_index + 2];
+                            int32_t filter_val3 = filter_data[filter_index + 3];
+                            result0 += input_val * filter_val0;
+                            result1 += input_val * filter_val1;
+                            result2 += input_val * filter_val2;
+                            result3 += input_val * filter_val3;
+                        }
+                    }
+                    if (bias) {
+                        result0 += bias[out_ch_idx + 0];
+                        result1 += bias[out_ch_idx + 1];
+                        result2 += bias[out_ch_idx + 2];
+                        result3 += bias[out_ch_idx + 3];
+                    }
+                    result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                    result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                    result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                    result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                    result0 += out_offset;
+                    result1 += out_offset;
+                    result2 += out_offset;
+                    result3 += out_offset;
+
+                    result0 = max(result0, activation_min);
+                    result1 = max(result1, activation_min);
+                    result2 = max(result2, activation_min);
+                    result3 = max(result3, activation_min);
+                    result0 = min(result0, activation_max);
+                    result1 = min(result1, activation_max);
+                    result2 = min(result2, activation_max);
+                    result3 = min(result3, activation_max);
+
+                    out_data[out_idx++] = result0;
+                    out_data[out_idx++] = result1;
+                    out_data[out_idx++] = result2;
+                    out_data[out_idx++] = result3;
+                }
+                for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
+                    int32_t result = 0;
+                    const int out_ch_idx =  ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val = filter_data[filter_index];
+                            result += input_val * filter_val;
+                        }
+                    }
+                    if (bias) {
+                        result += bias[out_ch_idx];
+                    }
+                    result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                    result += out_offset;
+                    result = max(result, activation_min);
+                    result = min(result, activation_max);
+
+                    out_data[out_idx++] = result;
+                }
+            }
+        }
+    }
+}

+ 97 - 37
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c

@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>
 
 #include <common_functions.h>
 
@@ -353,17 +353,59 @@ void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
     }
 }
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t ch_mult = conv_params->ch_mult;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
-    int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2;
-    int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels;
-    return  2 * (filter_size + input_size) + 16; //16 for alignment
+    int pad_width = 0, pad_height = 0;
+
+    if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) {
+        if (channels % 16 == 0) {
+            if (pad_wd || pad_ht) {
+                pad_width = pad_wd * 2;
+                pad_height = pad_ht * 2;
+            } else {
+                // check if we need to pad additionally
+                pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+                pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+                // printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)",
+                //         input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht,
+                //         stride_wd, stride_ht, pad_wd, pad_ht);
+            }
+            if (pad_width || pad_height) {
+                int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels;
+                // printf("ask1 %d\n", filter_size + input_size + 16);
+                return filter_size + input_size + 16;  // 16 for alignment
+            } else {
+                // printf("ask2 %d\n", filter_size + 16);
+                return filter_size + 16;  // 16 for alignment
+            }
+        } else {
+            int input_size = input_wd * input_ht * channels;
+            // printf("ask3 %d\n", 2 * (filter_size + input_size) + 16);
+            return  2 * (filter_size + input_size) + 16; // 16 for alignment
+        }
+    } else if (ch_mult % 4 == 0) {
+        int input_size = input_wd * input_ht * channels;
+        // printf("ask4 %d\n", 2 * (filter_size + input_size) + 16);
+        return  2 * (filter_size + input_size) + 16; // 16 for alignment
+    }
+    return 32; // just few bytes
 }
 
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
@@ -376,29 +418,38 @@ void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
  * Assumption 2: Pointers are valid
  * Assumption 3: dialation width = 1
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+
+
+
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
+                                      const data_dims_t *output_dims,
                                       int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max)
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
     int align_len = 16 - (filter_size & 15);
     int input_size = input_wd * input_ht * channels;
@@ -423,18 +474,27 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
                                                                   stride_wd, stride_ht, filter_aligned, bias,
                                                                   out_data, out_wd, out_ht, out_offset, out_shift,
                                                                   out_mult, activation_min, activation_max);
-            } else if ((pad_wd == 0) && (pad_ht == 0) &&
-                    // because this does not handle padding offset cases yet, run just for stride (1, 1).
-                    // end padding of input with `-input_offset` should solve this
-                    (stride_wd == 1) && (stride_ht == 1)) {
+            } else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) {
                 /* process in 8 bits */
                 int8_t *filter_aligned = (int8_t *) scratch_buffer;
+                int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
+
+                // check if we need to pad additionally
+                int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+                int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+                if (pad_right || pad_bottom) { // pad right and bottom
+                    esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht,
+                                                         channels, -input_offset, pad_right, pad_bottom);
+                } else {
+                    input_padded = (int8_t *) input_data;
+                }
                 memcpy(filter_aligned, filter_data, filter_size);
-                esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset,
-                                                                  stride_wd, stride_ht, filter_aligned,
-                                                                  bias, out_data, out_wd, out_ht, out_offset, out_shift,
+                esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right,
+                                                                  input_ht + pad_bottom, channels, input_offset,
+                                                                  stride_wd, stride_ht, filter_aligned, bias,
+                                                                  out_data, out_wd, out_ht, out_offset, out_shift,
                                                                   out_mult, activation_min, activation_max);
-            } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */
+            } else { /* (channels % 8) == 0 */
                 esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
                 esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
                 esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,

+ 8 - 0
code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3

@@ -0,0 +1,8 @@
+# Default configurations for ESP32-S3
+
+CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
+CONFIG_ESP32S3_SPIRAM_SUPPORT=y
+
+CONFIG_ESP32S3_DATA_CACHE_64KB=y
+CONFIG_ESP32S3_DATA_CACHE_8WAYS=y
+CONFIG_ESP32S3_DATA_CACHE_LINE_64B=y

+ 16 - 4
code/components/esp-nn/tests/src/basic_math_test.c

@@ -23,7 +23,9 @@
 #include "test_utils.h"
 
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
+#endif
 
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
@@ -138,6 +140,11 @@ void esp_nn_add_elementwise_s8_test()
         out_c_orig = out_data_c;
         out_opt_orig = out_data_opt;
 #endif
+        if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+                out_opt_orig == NULL) {
+            printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+            goto elementwise_add_test_cleanup;
+        }
 
         for (int i = 0; i < size; ++i) {
             input1[i] = rand() % 256 - 128;
@@ -194,10 +201,10 @@ elementwise_add_test_cleanup:
         if (input2_orig) {
             free(input2_orig);
         }
-        if (out_data_c) {
+        if (out_c_orig) {
             free(out_c_orig);
         }
-        if (out_data_opt) {
+        if (out_opt_orig) {
             free(out_opt_orig);
         }
     }
@@ -282,6 +289,11 @@ void esp_nn_mul_elementwise_s8_test()
         out_c_orig = out_data_c;
         out_opt_orig = out_data_opt;
 #endif
+        if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+                out_opt_orig == NULL) {
+            printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+            goto elementwise_mult_test_cleanup;
+        }
 
         for (int i = 0; i < size; ++i) {
             input1[i] = rand() % 256 - 128;
@@ -333,10 +345,10 @@ elementwise_mult_test_cleanup:
         if (input2_orig) {
             free(input2_orig);
         }
-        if (out_data_c) {
+        if (out_c_orig) {
             free(out_c_orig);
         }
-        if (out_data_opt) {
+        if (out_opt_orig) {
             free(out_opt_orig);
         }
     }

+ 70 - 36
code/components/esp-nn/tests/src/convolution_test.c

@@ -22,8 +22,9 @@
 #include "test_utils.h"
 
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
-
+#endif
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
 #endif
@@ -44,8 +45,8 @@ void esp_nn_depthwise_conv_s8_test()
     uint16_t filter_ht, filter_wd, ch_mult;
     uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
 
-    // run for 10 iterations
-    for (int itr = 0; itr < 10; itr++) {
+    // run for 15 iterations
+    for (int itr = 0; itr < 15; itr++) {
         /* prepare data */
         switch (itr) {
         case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)
@@ -144,22 +145,52 @@ void esp_nn_depthwise_conv_s8_test()
             stride_wd = 2;
             stride_ht = 2;
             break;
+        case 8: // same as case 7, with large parameters
+            input_wd = 58;
+            input_ht = 58;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 128;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
+        case 9: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)  stride (2,2)
+            input_wd = 6;
+            input_ht = 6;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 16;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
         default:
-            input_wd = 4;
-            input_ht = 4;
+            input_wd = 6;
+            input_ht = 6;
             filter_ht = 3;
             filter_wd = 3;
-            ch_mult = 4;
-            channels = 4;
-            pad_wd = 1;
-            pad_ht = 1;
-            stride_wd = 1;
-            stride_ht = 1;
+            ch_mult = 1;
+            channels = 16;
+            stride_wd = rand() % 2 + 1;
+            stride_ht = stride_wd;
+            pad_wd = stride_wd == 1 ? 0 : rand() % 2;
+            pad_ht = pad_wd;
+            printf("stride(%d), pad (%d)\t", stride_wd, pad_wd);
             break;
         }
 
         uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd;
         uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht;
+        if (itr == 9) {
+            // expect the function to handle this gracefully
+            out_wd += 1;
+            out_ht += 1;
+        }
         int in_size = input_wd * input_ht * channels;
         int out_size = out_wd * out_ht * channels * ch_mult;
         int filter_size = filter_wd * filter_ht * channels * ch_mult + 4;
@@ -210,9 +241,16 @@ void esp_nn_depthwise_conv_s8_test()
             out_mult[i] = 0x7eb0e200 + rand() % 50;
         }
 
-        int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht,
-                                                                    channels, ch_mult,
-                                                                    filter_wd, filter_ht);
+        data_dims_t input_dims = {.width = input_wd, .height = input_ht, .channels = channels, 1};
+        data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1};
+        data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+        dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, .ch_mult = ch_mult,
+                                        .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                        .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+        quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+        int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims,
+                                                                      &output_dims, &conv_params);
         if (scratch_buf_size > 0) {
 #if IDF_HEAP_CAPS
             scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -234,11 +272,8 @@ void esp_nn_depthwise_conv_s8_test()
         }
 
         /* C function */
-        esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset,
-                                    pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                                    filter_data + 4, filter_wd, filter_ht,
-                                    bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift,
-                                    out_mult, activation_min, activation_max);
+        esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 4,
+                                      bias + 1, &output_dims, out_data_c, &conv_params, &quant_data);
 
         if (itr == 0) {
             profile_c_end();
@@ -246,11 +281,8 @@ void esp_nn_depthwise_conv_s8_test()
         }
 
         /* Optimized function */
-        esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset,
-                                pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                                filter_data + 4, filter_wd, filter_ht,
-                                bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift,
-                                out_mult, activation_min, activation_max);
+        esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter_data + 4,
+                                 bias + 1, &output_dims, out_data_opt, &conv_params, &quant_data);
 
         if (itr == 0) {
             /* disable profiler */
@@ -479,8 +511,16 @@ void esp_nn_conv_s8_test()
             out_mult[i] = 0x7f67f4f8 + rand() % 50;
         }
 
-        int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels,
-                                                            out_channels, filter_wd, filter_ht);
+        data_dims_t input_dims = {.width = in_wd, .height = in_ht, .channels = in_channels, 1};
+        data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = out_channels, 1};
+        data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+        conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset,
+                                    .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                    .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+        quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+        int scratch_buf_size = esp_nn_get_conv_scratch_size(&input_dims, &filter_dims,
+                                                            &output_dims, &conv_params);
         if (scratch_buf_size > 0) {
 #if IDF_HEAP_CAPS
             void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -502,11 +542,8 @@ void esp_nn_conv_s8_test()
         }
 
         /* C function */
-        esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset,
-                            pad_wd, pad_ht, stride_wd, stride_ht,
-                            filter_data + 2, filter_wd, filter_ht, bias,
-                            out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift,
-                            out_mult, activation_min, activation_max);
+        esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 2,
+                            bias, &output_dims, out_data_c, &conv_params, &quant_data);
 
         if (itr == 0) {
             profile_c_end();
@@ -514,11 +551,8 @@ void esp_nn_conv_s8_test()
         }
 
         /* Optimized function */
-        esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset,
-                    pad_wd, pad_ht, stride_wd, stride_ht,
-                    filter_data + 2, filter_wd, filter_ht, bias,
-                    out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift,
-                    out_mult, activation_min, activation_max);
+        esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data + 2,
+                       bias, &output_dims, out_data_opt, &conv_params, &quant_data);
 
         if (itr == 0) {
             /* disable profiler */

BIN
code/components/esp-nn_20220716.zip


+ 1 - 1
code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp

@@ -756,7 +756,7 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
                             _fit = _val + _valminus;
 
                         }
-                        if (result > 10)
+                        if (result >= 10)
                             result = result - 10;
                         if (result < 0)
                             result = result + 10;

+ 4 - 1
code/components/tflite-lib/CMakeLists.txt

@@ -25,7 +25,8 @@ list(REMOVE_ITEM srcs_kernels
           "${tfmicro_kernels_dir}/depthwise_conv.cc"
           "${tfmicro_kernels_dir}/fully_connected.cc"
           "${tfmicro_kernels_dir}/mul.cc"
-          "${tfmicro_kernels_dir}/pooling.cc")
+          "${tfmicro_kernels_dir}/pooling.cc"
+          "${tfmicro_kernels_dir}/softmax.cc")
 
 FILE(GLOB esp_nn_kernels
           "${tfmicro_kernels_dir}/esp_nn/*.cc")
@@ -38,6 +39,8 @@ set(lib_srcs
           "${tflite_dir}/kernels/kernel_util.cc"
           "${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc"
           "${tflite_dir}/micro/memory_planner/linear_memory_planner.cc"
+          "${tflite_dir}/micro/arena_allocator/recording_simple_memory_allocator.cc"
+          "${tflite_dir}/micro/arena_allocator/simple_memory_allocator.cc"
           "${tflite_dir}/c/common.cc"
           "${tflite_dir}/core/api/error_reporter.cc"
           "${tflite_dir}/core/api/flatbuffer_conversions.cc"

+ 2 - 0
code/components/tflite-lib/tensorflow/lite/builtin_ops.h

@@ -179,6 +179,8 @@ typedef enum {
   kTfLiteBuiltinMultinomial = 149,
   kTfLiteBuiltinGelu = 150,
   kTfLiteBuiltinDynamicUpdateSlice = 151,
+  kTfLiteBuiltinRelu0To1 = 152,
+  kTfLiteBuiltinUnsortedSegmentProd = 153,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus

+ 3 - 0
code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h

@@ -518,6 +518,9 @@ typedef struct {
   bool approximate;
 } TfLiteGeluParams;
 
+typedef struct {
+  int num_segments;
+} TfLiteUnsortedSegmentProdParams;
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus

+ 7 - 1
code/components/tflite-lib/tensorflow/lite/c/c_api_types.h

@@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams {
 } TfLiteQuantizationParams;
 
 // --------------------------------------------------------------------------
-// Opaque types used by c_api_opaque.h.
+// Opaque types used by c_api.h, c_api_opaque.h and common.h.
+
+// TfLiteOpaqueContext is an opaque version of TfLiteContext;
+typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
+
+// TfLiteOpaqueNode is an opaque version of TfLiteNode;
+typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
 
 // TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
 typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;

+ 40 - 10
code/components/tflite-lib/tensorflow/lite/c/common.cc

@@ -14,13 +14,33 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/lite/c/common.h"
+
 #include "tensorflow/lite/c/c_api_types.h"
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+#include <string>
+
+#include "tensorflow/lite/core/macros.h"
+#include "tensorflow/lite/tensorflow_profiler_logger.h"
+#endif
 
 #ifndef TF_LITE_STATIC_MEMORY
 #include <stdlib.h>
 #include <string.h>
 #endif  // TF_LITE_STATIC_MEMORY
 
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+namespace tflite {
+// Use weak symbols here (even though they are guarded by macros) to avoid
+// build breakage when building a benchmark requires TFLite runs. The main
+// benchmark library should have tensor_profiler_logger dependency.
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor,
+                                               size_t num_bytes);
+
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor);
+}  // namespace tflite
+
+#endif  // TF_LITE_TENSORFLOW_PROFILER
+
 extern "C" {
 
 size_t TfLiteIntArrayGetSizeInBytes(int size) {
@@ -99,7 +119,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
 void TfLiteTensorDataFree(TfLiteTensor* t) {
   if (t->allocation_type == kTfLiteDynamic ||
       t->allocation_type == kTfLitePersistentRo) {
-    free(t->data.raw);
+    if (t->data.raw) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+      tflite::OnTfLiteTensorDealloc(t);
+#endif
+      free(t->data.raw);
+    }
   }
   t->data.raw = nullptr;
 }
@@ -161,7 +186,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
   t->dims = nullptr;
 
   if (t->dims_signature) {
-    TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+    TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
   }
   t->dims_signature = nullptr;
 
@@ -191,16 +216,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
 }
 
 TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
-  if (!src || !dst)
-    return kTfLiteOk;
-  if (src->bytes != dst->bytes)
-    return kTfLiteError;
-  if (src == dst)
-    return kTfLiteOk;
+  if (!src || !dst) return kTfLiteOk;
+  if (src->bytes != dst->bytes) return kTfLiteError;
+  if (src == dst) return kTfLiteOk;
 
   dst->type = src->type;
-  if (dst->dims)
-    TfLiteIntArrayFree(dst->dims);
+  if (dst->dims) TfLiteIntArrayFree(dst->dims);
   dst->dims = TfLiteIntArrayCopy(src->dims);
   memcpy(dst->data.raw, src->data.raw, src->bytes);
   dst->buffer_handle = src->buffer_handle;
@@ -218,8 +239,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
   // TODO(b/145340303): Tensor data should be aligned.
   if (!tensor->data.raw) {
     tensor->data.raw = (char*)malloc(num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
   } else if (num_bytes > tensor->bytes) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorDealloc(tensor);
+#endif
     tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
   }
   tensor->bytes = num_bytes;
 }

+ 52 - 3
code/components/tflite-lib/tensorflow/lite/c/common.h

@@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
     }                                                 \
   } while (false)
 #else  // TF_LITE_STRIP_ERROR_STRINGS
-#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
-#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
-#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
+#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__)
+#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
+#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
 #endif  // TF_LITE_STRIP_ERROR_STRINGS
 
 // Check whether value is true, and if not return kTfLiteError from
@@ -842,6 +842,32 @@ typedef struct TfLiteContext {
                                    size_t* bytes);
 } TfLiteContext;
 
+// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
+// for C API which doesn't use internal types (such as `TfLiteContext`) but only
+// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
+// field is the exactly the same as with `TfLiteRegistration`.
+typedef struct TfLiteRegistrationExternal {
+  // Custom op name.
+  const char* custom_name;
+
+  // The version of the op. The verion should be higher than 0.
+  const int version;
+
+  // Initializes the op from serialized data.
+  void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
+                size_t length);
+
+  // The pointer `buffer` is the data previously returned by an init invocation.
+  void (*free)(TfLiteOpaqueContext* context, void* buffer);
+
+  // Called when the inputs that this node depends on have been resized.
+  TfLiteStatus (*prepare)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+
+  // Called when the node is executed. (should read node->inputs and output to
+  // node->outputs).
+  TfLiteStatus (*invoke)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+} TfLiteRegistrationExternal;
+
 typedef struct TfLiteRegistration {
   // Initializes the op from serialized data.
   // Called only *once* for the lifetime of the op, so any one-time allocations
@@ -903,8 +929,31 @@ typedef struct TfLiteRegistration {
   // Note: It is the responsibility of the registration binder to set this
   // properly.
   int version;
+
+  // The external version of `TfLiteRegistration`. Since we can't use internal
+  // types (such as `TfLiteContext`) for C API to maintain ABI stability.
+  // C API user will provide `TfLiteRegistrationExternal` to implement custom
+  // ops. We keep it inside of `TfLiteRegistration` and use it to route
+  // callbacks properly.
+  TfLiteRegistrationExternal* registration_external;
 } TfLiteRegistration;
 
+// Old version of `TfLiteRegistration` to maintain binary backward
+// compatibility.
+// WARNING: This structure is deprecated / not an official part of the API.
+// It should be only used for binary backward compatibility.
+typedef struct TfLiteRegistration_V1 {
+  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+  void (*free)(TfLiteContext* context, void* buffer);
+  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+  const char* (*profiling_string)(const TfLiteContext* context,
+                                  const TfLiteNode* node);
+  int32_t builtin_code;
+  const char* custom_name;
+  int version;
+} TfLiteRegistration_V1;
+
 // The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
 // values should be 1, 2, 4, 8, ...etc.
 typedef enum TfLiteDelegateFlags {

+ 11 - 0
code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc

@@ -836,6 +836,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
+    case BuiltinOperator_UNSORTED_SEGMENT_PROD: {
+      auto params = safe_allocator.Allocate<TfLiteUnsortedSegmentProdParams>();
+      TF_LITE_ENSURE(error_reporter, params != nullptr);
+      if (const auto* unsorted_segment_prod_params =
+              op->builtin_options_as_UnsortedSegmentProdOptions()) {
+        params->num_segments = unsorted_segment_prod_params->num_segments();
+      }
+      *builtin_data = params.release();
+      return kTfLiteOk;
+    }
     // Below are the ops with no builtin_data structure.
     // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
     // ok for now, since there is no call implementation either.
@@ -848,6 +858,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_MATRIX_DIAG:
     case BuiltinOperator_MATRIX_SET_DIAG:
     case BuiltinOperator_RELU_N1_TO_1:
+    case BuiltinOperator_RELU_0_TO_1:
     case BuiltinOperator_SELECT:
     case BuiltinOperator_SELECT_V2:
     case BuiltinOperator_SLICE:

+ 52 - 1
code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h

@@ -23,6 +23,16 @@ limitations under the License.
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
+// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
+// This is used for cases (e.g. when using "TF Lite with Google Play Services")
+// where the TF Lite runtime might be built using a newer (or older)
+// version of the TF Lite sources than the app, and hence might have a
+// different definition of the TfLiteDelegate type. TF Lite APIs use
+// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
+// refer to a delegate defined with that potentially different version
+// of the TfLiteDelegate type.
+struct TfLiteOpaqueDelegateStruct;
+
 namespace tflite {
 
 /// Abstract interface that returns TfLiteRegistrations given op codes or custom
@@ -37,8 +47,10 @@ class OpResolver {
   virtual const TfLiteRegistration* FindOp(const char* op,
                                            int version) const = 0;
 
+  // Represents a sequence of delegates.
   using TfLiteDelegatePtrVector =
       std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
+
   // Returns optional delegates for resolving and handling ops in the flatbuffer
   // model. This may be used in addition to the standard TfLiteRegistration
   // lookup for graph resolution.
@@ -47,16 +59,55 @@ class OpResolver {
     return {};
   }
 
-  // Represent a function that creates a TfLite delegate instance.
+  // Represents a function that creates a TfLite delegate instance.
   using TfLiteDelegateCreator =
       std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
           int /*num_threads*/)>;
+
+  // Represents a sequence of delegate creator functions.
   using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
+
   // Returns a vector of delegate creators to create optional delegates for
   // resolving and handling ops in the flatbuffer model. This may be used in
   // addition to the standard TfLiteRegistration lookup for graph resolution.
+  //
+  // Note that this method is not used (will not be called) if you are using
+  // TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
+  // (see below) is used for that case.
   virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
 
+  // TODO(b/202712825): it would be nice if we could avoid the need for separate
+  // "opaque" types & methods for use only with TF Lite in Google Play Services.
+
+  // Represents an opaque delegate instance.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegatePtr =
+      std::unique_ptr<TfLiteOpaqueDelegateStruct,
+                      void (*)(TfLiteOpaqueDelegateStruct*)>;
+
+  // Represents a function that creates an opaque delegate instance.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegateCreator =
+      std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
+
+  // Represents a sequence of opaque delegate creator functions.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegateCreators = std::vector<TfLiteOpaqueDelegateCreator>;
+
+  // Returns a vector of opaque delegate creators to create optional opaque
+  // delegates for resolving and handling ops in the flatbuffer model. This may
+  // be used in addition to the standard TfLiteRegistration lookup for graph
+  // resolution.
+  //
+  // Note that this method will be called only if you are using TF Lite in
+  // Google Play Services; if you are using regular TF Lite, GetDelegateCreators
+  // (see above) is used instead.
+  //
+  // WARNING: Experimental interface, subject to change.
+  virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
+    return {};
+  }
+
   virtual ~OpResolver() {}
 
  private:

+ 3 - 3
code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h

@@ -23,9 +23,9 @@ namespace tflite {
 namespace reference_ops {
 
 inline int16_t SaturatingLeftShift(int16_t value, int amount) {
-  int32_t result = static_cast<int32_t>(value) * (1 << amount);
-  result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
-  result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
+  int64_t result = static_cast<int64_t>(value) * (1 << amount);
+  result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
+  result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
   return result;
 }
 

+ 6 - 3
code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h

@@ -27,6 +27,11 @@ class RuntimeShape {
  public:
   RuntimeShape& operator=(RuntimeShape const&) = delete;
 
+  // RuntimeShape in TFLM supports up to 5 dimensions.
+  // The name kMaxSmallSize comes from the same file of the upstream
+  // tensorflow lite repo and need to be kept the same for max reuse.
+  static constexpr int kMaxSmallSize = 5;
+
   RuntimeShape() : size_(0) {}
 
   explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
@@ -104,11 +109,9 @@ class RuntimeShape {
                 sizeof(int32_t) * shape.DimensionsCount());
   }
 
-  // A maximum of 4 dimensions are supported on TFLM.
-  static constexpr int kMaxSize = 5;
   int32_t size_;
   union {
-    int32_t dims_[kMaxSize];
+    int32_t dims_[kMaxSmallSize];
   };
 };
 

+ 5 - 5
code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h

@@ -974,11 +974,11 @@ struct StridedSliceParams {
   int8_t strides_count;
   int32_t strides[5];
 
-  int16_t begin_mask;
-  int16_t ellipsis_mask;
-  int16_t end_mask;
-  int16_t new_axis_mask;
-  int16_t shrink_axis_mask;
+  uint16_t begin_mask;
+  uint16_t ellipsis_mask;
+  uint16_t end_mask;
+  uint16_t new_axis_mask;
+  uint16_t shrink_axis_mask;
 };
 
 struct TanhParams {

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h

@@ -308,7 +308,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                         const TfLiteTensor* input3,
                                         TfLiteIntArray** output_shape);
 
-// Return the size of given type in bytes. Return 0 in in case of string.
+// Return the size of given type in bytes. Return 0 in case of string.
 int TfLiteTypeGetSize(TfLiteType type);
 
 // Whether the current platform is mobile (Android or iOS).

+ 3 - 3
code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
 
 #include <cstddef>
 #include <cstdint>
@@ -97,4 +97,4 @@ class INonPersistentBufferAllocator {
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_

+ 165 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc

@@ -0,0 +1,165 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      head_temp_(buffer),
+      next_temp_(buffer) {}
+
+NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {}
+
+// Allocates a temporary buffer. This buffer is not resizable.
+uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size,
+                                                         size_t alignment) {
+  uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment);
+  const size_t available_memory = buffer_tail_ - aligned_result;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to allocate temp memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return nullptr;
+  }
+  next_temp_ = aligned_result + size;
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result);
+  temp_buffer_count_++;
+  return aligned_result;
+}
+
+// Signals that a temporary buffer is no longer needed.
+void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf);
+  temp_buffer_count_--;
+}
+
+// Returns true if all temporary buffers are already deallocated.
+bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() {
+  if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
+    MicroPrintf(
+        "Number of allocated temp buffers: %d. Checksum passing status: %d",
+        temp_buffer_count_, !temp_buffer_ptr_check_sum_);
+    return false;
+  }
+  return true;
+}
+
+// Signals that all temporary allocations can be reclaimed. TFLM calls this
+// API when it knows that all temporary buffers that it requested has been
+// deallocated. The goal of API is to facilitate implementations of
+// INonPersistentBufferAllocator can reuse buffer with some reasonable
+// complexity.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() {
+  if (!IsAllTempDeallocated()) {
+    MicroPrintf(
+        "All temp buffers must be freed before calling ResetTempAllocations()");
+    return kTfLiteError;
+  }
+  next_temp_ = head_temp_;
+  return kTfLiteOk;
+}
+
+// Returns a buffer that is resizable viable ResizeBuffer().
+uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer(
+    size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+
+  if (head_temp_ != expected_resizable_buf) {
+    MicroPrintf(
+        "Cannot allocate a new resizable buffer when one is already allocated");
+    return nullptr;
+  }
+
+  if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) {
+    return expected_resizable_buf;
+  }
+  return nullptr;
+}
+
+// Resizes a buffer that is previously returned by the AllocateResizableBuffer.
+// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates
+// a previous allocated resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer(
+    uint8_t* resizable_buf, size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  if (resizable_buf != expect_resizable_buf) {
+    MicroPrintf("Internal error: buffer is not resizable");
+    return kTfLiteError;
+  }
+  if (head_temp_ != next_temp_) {
+    MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer().");
+    return kTfLiteError;
+  }
+
+  const size_t available_memory = buffer_tail_ - expect_resizable_buf;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to resize buffer. Requested: %u, available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return kTfLiteError;
+  }
+  head_temp_ = expect_resizable_buf + size;
+  next_temp_ = head_temp_;
+
+  return kTfLiteOk;
+}
+
+// Frees up the memory occupied by the resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer(
+    uint8_t* resizable_buf) {
+  return ResizeBuffer(resizable_buf, 0, 1);
+}
+
+// Returns a pointer pointing to the start of the overlay memory, which is
+// used for activation tensors and scratch buffers by kernels at Invoke stage.
+uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const {
+  return buffer_head_;
+}
+
+// Reserves the size of the overlay memory. This overlay is reserved for the
+// kernels at Invoke stage. This is referred to as the overlay because before
+// Invoket state, the same memory can be used for temp buffers. The layout of
+// the memory is planned by the memory planner separately at Invoke stage.
+TfLiteStatus
+NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
+    size_t size, size_t alignment) {
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  return ResizeBuffer(expect_resizable_buf, size, alignment);
+}
+
+// Returns the size of non-persistent buffer in use.
+size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const {
+  return (next_temp_ - buffer_head_);
+}
+
+// Returns the number of bytes available with a given alignment. This number
+// takes in account any temporary allocations.
+size_t NonPersistentArenaBufferAllocator::GetAvailableMemory(
+    size_t alignment) const {
+  uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment);
+  uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment);
+  return aligned_tail - aligned_temp;
+}
+
+}  // namespace tflite

+ 104 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h

@@ -0,0 +1,104 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// Implement INonPersistentBufferAllocator on an arena that is dedicated for
+// non-persistent buffers.
+class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator {
+ public:
+  NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~NonPersistentArenaBufferAllocator();
+
+  // Allocates a temporary buffer. This buffer is not resizable.
+  uint8_t* AllocateTemp(size_t size, size_t alignment) override;
+
+  // Signals that a temporary buffer is no longer needed.
+  void DeallocateTemp(uint8_t* buf) override;
+
+  // Returns true if all temporary buffers are already deallocated.
+  bool IsAllTempDeallocated() override;
+
+  // Signals that all temporary allocations can be reclaimed. TFLM calls this
+  // API when it knows that all temporary buffers that it requested has been
+  // deallocated.
+  TfLiteStatus ResetTempAllocations() override;
+
+  // Returns a buffer that is resizable viable ResizeBuffer().
+  uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override;
+
+  // Resizes a buffer that is previously returned by the
+  // AllocateResizableBuffer.
+  TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
+                            size_t alignment) override;
+
+  // Frees up the memory occupied by the resizable buffer.
+  TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override;
+
+  // Returns a pointer pointing to the start of the overlay memory, which is
+  // used for activation tensors and scratch buffers by kernels at Invoke stage.
+  uint8_t* GetOverlayMemoryAddress() const override;
+
+  // Reserves the size of the overlay memory. This overlay is reserved for the
+  // kernels at Invoke stage. This is referred to as the overlay because before
+  // Invoket state, the same memory can be used for temp buffers. The layout of
+  // the memory is planned by the memory planner separately at Invoke stage.
+  TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
+                                                 size_t alignment) override;
+
+  // Returns the size of non-persistent buffer in use.
+  size_t GetNonPersistentUsedBytes() const override;
+
+  // Returns the number of bytes available with a given alignment. This number
+  // takes in account any temporary allocations.
+  size_t GetAvailableMemory(size_t alignment) const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer.
+  // head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers.
+  uint8_t* head_temp_;
+
+  // next_temp_ points to the next available temp buffer allocation address and
+  // its range is between head_temp_ and buffer_tail_
+  uint8_t* next_temp_;
+
+  // XOR Check sum for outstanding temp buffers.
+  // If all temp buffers are deallocated OR no temp buffers are allocated,
+  // temp_buffer_ptr_check_sum_ == nullptr.
+  intptr_t temp_buffer_ptr_check_sum_ = 0;
+  // Count of outstanding temp buffers.
+  int temp_buffer_count_ = 0;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_

+ 52 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc

@@ -0,0 +1,52 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+PersistentArenaBufferAllocator::PersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      tail_temp_(buffer_tail_) {}
+
+PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {}
+
+uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer(
+    size_t size, size_t alignment) {
+  uint8_t* const aligned_result =
+      AlignPointerDown(tail_temp_ - size, alignment);
+  if (aligned_result < buffer_head_) {
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+    const size_t missing_memory = buffer_head_ - aligned_result;
+    MicroPrintf(
+        "Failed to allocate tail memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, size - missing_memory, missing_memory);
+#endif
+    return nullptr;
+  }
+  tail_temp_ = aligned_result;
+  return aligned_result;
+}
+
+size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const {
+  return buffer_tail_ - tail_temp_;
+}
+
+}  // namespace tflite

+ 59 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h

@@ -0,0 +1,59 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// PersistentArenaBufferAllocator is an implementatation of
+// IPersistentBufferAllocator interface on an arena that is dedicated for
+// persistent buffers.
+class PersistentArenaBufferAllocator : public IPersistentBufferAllocator {
+ public:
+  PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~PersistentArenaBufferAllocator();
+
+  // Allocates persistent memory. The persistent buffer is never freed.
+  // Returns nullptr if errors occured.
+  uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override;
+
+  // Returns the size of all persistent allocations in bytes.
+  size_t GetPersistentUsedBytes() const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // tail_temp_ to buffer_tail_ contains allocated buffers;
+  // buffer_head_ to tail_temp_ - 1 belongs to still available spaces.
+  // So in essence, the allocated region grows from the bottom and emulates
+  // SimpleMemoryAllocator's persistent part.
+  uint8_t* tail_temp_;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/micro/recording_simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h"
 
 #include <new>
 

+ 4 - 4
code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h

@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
 
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 
@@ -62,4 +62,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 
 #include <cstddef>
 #include <cstdint>

+ 4 - 4
code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h

@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
 
 #include <cstddef>
 #include <cstdint>
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/ibuffer_allocator.h"
 
 namespace tflite {
 
@@ -147,4 +147,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
 
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc

@@ -16,10 +16,10 @@ limitations under the License.
 #include "tensorflow/lite/micro/fake_micro_context.h"
 
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/micro_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 namespace {

+ 7 - 20
code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
       return kTfLiteError;
     }
   }
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
       return kTfLiteError;
     }
   }
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_RELU() {
-  return {/*init=*/ReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/ReluPrepare,
-          /*invoke=*/ReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
 }
 
 TfLiteRegistration Register_RELU6() {
-  return {/*init=*/Relu6Init,
-          /*free=*/nullptr,
-          /*prepare=*/Relu6Prepare,
-          /*invoke=*/Relu6Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc

@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc

@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_ADD_N() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 2 - 16
code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc

@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace arg_min_max
 
 TfLiteRegistration Register_ARG_MAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
 }
 
 TfLiteRegistration Register_ARG_MIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
 }
 
 }  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc

@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_ASSIGN_VARIABLE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc

@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 3 - 9
code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc

@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_BROADCAST_ARGS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastArgsPrepare,
-          /*invoke=*/BroadcastArgsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
+                                   BroadcastArgsEval);
 }
 
-}  // namespace tflite
+}  // namespace tflite

+ 3 - 9
code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc

@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_BROADCAST_TO() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastToPrepare,
-          /*invoke=*/BroadcastToEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
+                                   BroadcastToEval);
 }
 
-}  // namespace tflite
+}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc

@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_CALL_ONCE() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc

@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CAST() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc

@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace ceil
 
 TfLiteRegistration Register_CEIL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/ceil::Prepare,
-          /*invoke=*/ceil::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
 }
 
 }  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc

@@ -108,14 +108,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
-  static TfLiteRegistration r = {/*init=*/CircularBufferInit,
-                                 /*free=*/nullptr,
-                                 /*prepare=*/CircularBufferPrepare,
-                                 /*invoke=*/CircularBufferEval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
   return &r;
 }
 

+ 12 - 48
code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc

@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace comparisons
 
 TfLiteRegistration Register_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::EqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::EqualEval);
 }
 
 TfLiteRegistration Register_NOT_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::NotEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::NotEqualEval);
 }
 
 TfLiteRegistration Register_GREATER() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEval);
 }
 
 TfLiteRegistration Register_GREATER_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEqualEval);
 }
 
 TfLiteRegistration Register_LESS() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEval);
 }
 
 TfLiteRegistration Register_LESS_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEqualEval);
 }
 
 }  // namespace micro

+ 5 - 11
code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc

@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE(context, input != nullptr);
     int num_dimensions = NumDimensions(input);
 
-    if (num_dimensions > 4) {
+    if (num_dimensions > RuntimeShape::kMaxSmallSize) {
       TF_LITE_KERNEL_LOG(
           context,
-          "Op Concatenation does not currently support num dimensions >4 "
+          "Op Concatenation does not currently support num dimensions > %d "
           "Tensor has %d dimensions.",
-          num_dimensions);
+          RuntimeShape::kMaxSmallSize, num_dimensions);
       return kTfLiteError;
     }
     micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation
 
 TfLiteRegistration Register_CONCATENATION() {
-  return {/*init=*/concatenation::Init,
-          /*free=*/nullptr,
-          /*prepare=*/concatenation::Prepare,
-          /*invoke=*/concatenation::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
+                                   concatenation::Eval);
 }
 
 }  // namespace micro

+ 40 - 22
code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 
 namespace tflite {
 namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
     }
     case kTfLiteInt16: {
-      reference_integer_ops::ConvPerChannel(
-          ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
-          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<int16_t>(input),
-          tflite::micro::GetTensorShape(filter),
-          tflite::micro::GetTensorData<int8_t>(filter),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<std::int64_t>(bias),
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<int16_t>(output));
+      switch (bias->type) {
+        case kTfLiteInt32: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        case kTfLiteInt64: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        default:
+          MicroPrintf("Bias type %s (%d) not supported.",
+                      TfLiteTypeGetName(bias->type), bias->type);
+          return kTfLiteError;
+      }
       break;
     }
     case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
     }
     default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
+      MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+                  input->type);
       return kTfLiteError;
   }
   return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/ConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
 }
 
 }  // namespace tflite

+ 10 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h

@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
     float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
     TfLiteRegistration registration, int16_t* output_data);
 
+TfLiteStatus TestConvQuantizedPerChannel(
+    int* input_dims_data, const float* input_data, int16_t* input_quantized,
+    float input_scale, int input_zero_point, int* filter_dims_data,
+    const float* filter_data, int8_t* filter_data_quantized,
+    int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
+    float* bias_scales, int* bias_zero_points, int* output_dims_data,
+    const float* expected_output_data, int16_t* expected_output_data_quantized,
+    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+    TfLiteRegistration registration, int16_t* output_data);
+
 }  // namespace testing
 }  // namespace tflite
 

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc

@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_CUMSUM() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc

@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_DEPTH_TO_SPACE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 3 - 10
code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc

@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/DepthwiseConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
 }
 
 }  // namespace tflite

+ 27 - 1
code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h

@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
 
 TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);
 
+// This is the most generic TfLiteRegistration. The actual supported types may
+// still be target dependent. The only requirement is that every implementation
+// (reference or optimized) must define this function.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D();
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int8 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
+
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
+
+#else
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+#endif
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_

+ 9 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc

@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
                                   tflite::micro::GetTensorShape(output),
                                   tflite::micro::GetTensorData<float>(output));
         break;
+      case kTfLiteUInt8:
+        reference_ops::Dequantize(data->quantization_params,
+                                  tflite::micro::GetTensorShape(input),
+                                  tflite::micro::GetTensorData<uint8_t>(input),
+                                  tflite::micro::GetTensorShape(output),
+                                  tflite::micro::GetTensorData<float>(output));
+        break;
       default:
         MicroPrintf("Input %s, output %s not supported.",
                     TfLiteTypeGetName(input->type),
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_DEQUANTIZE() {
-  return {/*init=*/DequantizeInit,
-          /*free=*/nullptr,
-          /*prepare=*/DequantizePrepare,
-          /*invoke=*/DequantizeEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
+                                   DequantizeEval);
 }
 
 }  // namespace tflite

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc

@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
   TF_LITE_ENSURE(context, output != nullptr);
 
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
+                              input->type == kTfLiteInt16 ||
+                              input->type == kTfLiteUInt8);
   TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
 
   if (output->type == kTfLiteInt32) {

+ 1 - 10
code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc

@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   return op_data;
 }
 
-void Free(TfLiteContext* context, void* buffer) {}
-
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* op_data = static_cast<OpData*>(node->user_data);
 
@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
-  static TfLiteRegistration r = {/*init=*/Init,
-                                 /*free=*/Free,
-                                 /*prepare=*/Prepare,
-                                 /*invoke=*/Eval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
   return &r;
 }
 

+ 289 - 79
code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc

@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
 #include <cmath>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
 namespace elementwise {
 namespace {
 
+constexpr int kAbsNameId = 0;
+constexpr int kRsrqtNameId = 1;
+
+const int kElementwiseInputTensor = 0;
+const int kElementwiseOutputTensor = 0;
+
+struct OpDataAbsRsqrt {
+  int32_t multiplier;
+  int shift;
+  int input_offset;
+  int output_offset;
+  bool needs_rescale;
+  TfLiteQuantizationType input_quantization_type;
+  TfLiteType input_type;
+};
+
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -35,11 +53,57 @@ bool IsLogicalSupportedType(const TfLiteType type) {
   return type == kTfLiteBool;
 }
 
+bool IsAbsSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
+}
+
+bool IsRsqrtSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8;
+}
+
+inline void SetAbsOutputMultiplier(const float input_scale,
+                                   const float output_scale,
+                                   int32_t* multiplier, int* shift) {
+  QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
+                     multiplier, shift);
+}
+
+inline void SetRsqrtOutputMultiplier(const float input_scale,
+                                     const float output_scale,
+                                     int32_t* multiplier, int* shift) {
+  const double scale =
+      1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
+  QuantizeMultiplier(scale, multiplier, shift);
+}
+
 typedef bool (*IsSupportedType)(TfLiteType);
 template <IsSupportedType>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input =
+      micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* output =
+      micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  if (!IsSupportedType(input->type)) {
+    TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return kTfLiteOk;
+}
 
+typedef bool (*IsSupportedType)(TfLiteType);
+template <IsSupportedType, const int op_nameid>
+TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
@@ -53,14 +117,87 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }
 
+  auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+  op_data->input_type = input->type;
+
+  // For int16 type input, we support both quantized and non-quantized
+  // evaluation.
+  if (op_nameid == kAbsNameId) {
+    op_data->input_quantization_type = input->quantization.type;
+  }
+
+  if (input->type == kTfLiteInt8 ||
+      (input->type == kTfLiteInt16 &&
+       input->quantization.type != kTfLiteNoQuantization)) {
+    TF_LITE_ENSURE_EQ(context, input->quantization.type,
+                      kTfLiteAffineQuantization);
+    TF_LITE_ENSURE_EQ(context, output->quantization.type,
+                      kTfLiteAffineQuantization);
+    const auto* input_params =
+        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
+    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
+        output->quantization.params);
+    TF_LITE_ENSURE(context, input_params != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale->size > 0);
+    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
+    TF_LITE_ENSURE(context, output_params != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale->size > 0);
+    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
+    op_data->input_offset = input_params->zero_point->data[0];
+    op_data->output_offset = output_params->zero_point->data[0];
+    if (input->type == kTfLiteInt16) {
+      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
+      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
+    }
+    const float input_scale = input_params->scale->data[0];
+    const float output_scale = output_params->scale->data[0];
+    op_data->needs_rescale = input_scale != output_scale;
+    if (op_nameid == kAbsNameId && op_data->needs_rescale) {
+      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                             &op_data->shift);
+    } else if (op_nameid == kRsrqtNameId) {
+      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                               &op_data->shift);
+    }
+  }
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(output);
   return kTfLiteOk;
 }
 
+template <typename T>
+inline TfLiteStatus EvalImplQuantized(
+    TfLiteContext* context, TfLiteNode* node,
+    T func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteType expected_type) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
+  const size_t num_elements = ElementCount(*input->dims);
+  const T* in_data = tflite::micro::GetTensorData<T>(input);
+  T* out_data = tflite::micro::GetTensorData<T>(output);
+  for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context,
+                        validate_input_func(context, node, in_data[i]));
+    }
+    out_data[i] = func(context, node, in_data[i]);
+  }
+  return kTfLiteOk;
+}
+
+template <typename T>
+inline T AbsHelper(T i) {
+  return std::abs(i);
+}
+
 template <typename T>
 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
-                             T func(T), TfLiteType expected_type) {
+                             T func(T), TfLiteStatus validate_input_func(T),
+                             TfLiteType expected_type) {
   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +205,9 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
   const T* in_data = tflite::micro::GetTensorData<T>(input);
   T* out_data = tflite::micro::GetTensorData<T>(output);
   for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
+    }
     out_data[i] = func(in_data[i]);
   }
   return kTfLiteOk;
@@ -75,16 +215,114 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
 
 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                 float float_func(float)) {
-  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+  return EvalImpl<float>(context, node, float_func,
+                         /*validate_input_func=*/nullptr, kTfLiteFloat32);
 }
 
 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+
                                 bool bool_func(bool)) {
-  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+  return EvalImpl<bool>(context, node, bool_func,
+                        /*validate_input_func=*/nullptr, kTfLiteBool);
+}
+
+void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
+                              size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
+}
+
+template <typename T>
+inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = std::abs(i - op_data->input_offset);
+  if (!op_data->needs_rescale) {
+    return static_cast<T>(
+        std::min(std::max(static_cast<long int>(value + op_data->output_offset),
+                          static_cast<long int>(kMin)),
+                 static_cast<long int>(kMax)));
+  }
+
+  const int32_t output = tflite::MultiplyByQuantizedMultiplier(
+                             value, op_data->multiplier, op_data->shift) +
+                         op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = (i - op_data->input_offset);
+  const int32_t kShift = 20;  // Shift to keep value integer.
+  if (value == 0) {
+    // Assume that any value close to 0 represents the max output value.
+    return static_cast<T>(kMax);
+  }
+  int32_t inv_sqrt_multiplier;
+  int inv_sqrt_shift;
+  GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
+                                   &inv_sqrt_shift);
+  const int32_t data = tflite::MultiplyByQuantizedMultiplier(
+      static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
+  const int32_t output =
+      tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
+                                            op_data->shift - kShift) +
+      op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
+                                 T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+
+  TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
+                     "Rsqrt is only defined for positive values");
+  return static_cast<TfLiteStatus>(kTfLiteOk);
 }
 
 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
-  return EvalNumeric(context, node, std::abs);
+  OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  TfLiteQuantizationType input_quantization_type =
+      op_data->input_quantization_type;
+  TfLiteStatus eval_result;
+
+  switch (type) {
+    case kTfLiteFloat32:
+      eval_result = EvalNumeric(context, node, std::abs);
+      break;
+    case kTfLiteInt8:
+      eval_result =
+          EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
+                                    /*validate_input_func=*/nullptr, type);
+      break;
+    case kTfLiteInt16:
+      eval_result =
+          input_quantization_type == kTfLiteNoQuantization
+              ? EvalImpl<int16_t>(context, node, AbsHelper,
+                                  /*validate_input_func=*/nullptr, type)
+              : EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
+                                           /*validate_input_func=*/nullptr,
+                                           type);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+      break;
+  }
+  return eval_result;
 }
 
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
@@ -104,7 +342,23 @@ TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
-  return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  switch (type) {
+    case kTfLiteFloat32:
+      return EvalImpl<float>(
+          context, node, [](float f) { return 1.f / std::sqrt(f); },
+          /*validate_input_func=*/nullptr, type);
+    case kTfLiteInt8:
+      return EvalImplQuantized<int8_t>(context, node,
+                                       elementwise::RsqrtEvalQuantized,
+                                       elementwise::validate_input_func, type);
+
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+  }
 }
 
 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace elementwise
 
 TfLiteRegistration Register_ABS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::AbsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
+                                   elementwise::kAbsNameId>,
+      elementwise::AbsEval);
 }
 
 TfLiteRegistration Register_SIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SinEval);
 }
 
 TfLiteRegistration Register_COS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::CosEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval);
 }
 
 TfLiteRegistration Register_LOG() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::LogEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval);
 }
 
 TfLiteRegistration Register_SQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SqrtEval);
 }
 
 TfLiteRegistration Register_RSQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::RsqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
+                                   elementwise::kRsrqtNameId>,
+      elementwise::RsqrtEval);
 }
 
 TfLiteRegistration Register_SQUARE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SquareEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SquareEval);
 }
 
 TfLiteRegistration Register_LOGICAL_NOT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
-          /*invoke=*/elementwise::LogicalNotEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+      elementwise::LogicalNotEval);
 }
 
 }  // namespace micro
 }  // namespace ops
-}  // namespace tflite
+}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc

@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_ELU() {
-  return {/*init=*/EluInit,
-          /*free=*/nullptr,
-          /*prepare=*/EluPrepare,
-          /*invoke=*/EluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc

@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
 }
 
 }  // namespace tflite

+ 46 - 21
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc

@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input->dims->data[3], 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output->dims->data[3], 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    conv_params_t conv_params = {
+                                  .in_offset = 0, .out_offset = 0,
+                                  .stride = {params.stride_width, params.stride_height},
+                                  .padding = {data->op_data.padding.width, data->op_data.padding.height},
+                                  .dilation = {0, 0}, .activation = {-128, 127}
+                                };
+
     int scratch_buf_size = esp_nn_get_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        output->dims->data[3], filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
         context, scratch_buf_size, &data->buffer_idx));
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
     const int input_size = input_width * input_height * input_depth;
     const int output_size = output_width * output_height * output_depth;
 
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input_depth, 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output_depth, 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    conv_params_t conv_params = {
+                                  .in_offset = input_offset, .out_offset = output_offset,
+                                  .stride = {stride_width, stride_height},
+                                  .padding = {pad_width, pad_height},
+                                  .dilation = {0, 0},
+                                  .activation = {activation_min, activation_max}
+                                };
+    quant_data_t quant_data = {
+                                .shift = data.op_data.per_channel_output_shift,
+                                .mult = data.op_data.per_channel_output_multiplier
+                              };
+
     for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-      esp_nn_conv_s8(input_data + i_batch * input_size,
-                     input_width, input_height, input_depth, input_offset,
-                     pad_width, pad_height, stride_width, stride_height,
-                     tflite::micro::GetTensorData<int8_t>(filter),
-                     filter_width, filter_height,
+      esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
+                     &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                      tflite::micro::GetTensorData<int32_t>(bias),
-                     output_data + i_batch * output_size,
-                     output_width, output_height, output_depth, output_offset,
-                     data.op_data.per_channel_output_shift,
-                     data.op_data.per_channel_output_multiplier,
-                     activation_min, activation_max);
+                     &output_dims, output_data + i_batch * output_size,
+                     &conv_params, &quant_data);
     }
   } else {
     reference_integer_ops::ConvPerChannel(
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
   }
-  conv_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  conv_total_time += time_this_instance;
+  //printf("time this instance: %llu\n", time_this_instance / 1000);
   return kTfLiteOk;
 }
 
 }  // namespace
 
 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 49 - 22
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc

@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
     if (data.buffer_idx > -1) {
       scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
     }
+
     esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);
 
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input_depth, 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output_depth, 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    dw_conv_params_t conv_params =  {
+                                      .in_offset = input_offset, .out_offset = output_offset,
+                                      .ch_mult = depth_multiplier,
+                                      .stride = {stride_width, stride_height},
+                                      .padding = {pad_width, pad_height}, .dilation = {0, 0},
+                                      .activation = {activation_min, activation_max}
+                                    };
+    quant_data_t quant_data = {
+                                .shift = data.op_data.per_channel_output_shift,
+                                .mult = data.op_data.per_channel_output_multiplier
+                              };
+
     for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-      esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
-                               input_height, input_depth, input_offset,
-                               pad_width, pad_height,
-                               stride_width, stride_height, depth_multiplier,
-                               tflite::micro::GetTensorData<int8_t>(filter),
-                               filter_width, filter_height,
+      esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
+                               &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                                tflite::micro::GetTensorData<int32_t>(bias),
-                               output_data + i_batch * output_size,
-                               output_width, output_height, output_offset,
-                               data.op_data.per_channel_output_shift,
-                               data.op_data.per_channel_output_multiplier,
-                               activation_min, activation_max);
+                               &output_dims, output_data + i_batch * output_size,
+                               &conv_params, &quant_data);
     }
   } else {
     reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input->dims->data[3], 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output->dims->data[3], 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    dw_conv_params_t conv_params =  {
+                                      .in_offset = 0, .out_offset = 0,
+                                      .ch_mult = params.depth_multiplier,
+                                      .stride = {params.stride_width, params.stride_height},
+                                      .padding = {data->op_data.padding.width, data->op_data.padding.height},
+                                      .dilation = {0, 0}, .activation = {-128, 127}
+                                    };
+
     int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        params.depth_multiplier, filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
         context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
   }
-  dc_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  dc_total_time += time_this_instance;
+  // printf("time this instance: %llu\n", time_this_instance / 1000);
+
   return kTfLiteOk;
 }
 
 }  // namespace
 
 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc

@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc

@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_MUL() {
-  return {/*init=*/MulInit,
-          /*free=*/nullptr,
-          /*prepare=*/MulPrepare,
-          /*invoke=*/MulEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
 }
 
 }  // namespace tflite

+ 2 - 16
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc

@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 }  // namespace
 
 TfLiteRegistration Register_AVERAGE_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/AverageEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
 }
 
 TfLiteRegistration Register_MAX_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/MaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
 }
 
 }  // namespace tflite

+ 208 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc

@@ -0,0 +1,208 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include <esp_timer.h>
+
+#if ESP_NN
+#include <esp_nn.h>
+#endif
+
+long long softmax_total_time = 0;
+
+namespace tflite {
+namespace {
+// Softmax parameter data that persists in user_data
+const int kInt16LUTArraySize = 513;
+
+struct NodeData {
+  SoftmaxParams op_data;
+#if ESP_NN
+  int buffer_idx;
+#endif
+};
+
+static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
+                      TfLiteEvalTensor* output, const NodeData* data) {
+  if (input->type == kTfLiteInt8) {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+    } else {
+#if ESP_NN
+      const int32_t input_beta_multiplier = data->op_data.input_multiplier;
+      const int32_t input_beta_left_shift = data->op_data.input_left_shift;
+      const int diff_min = data->op_data.diff_min;
+      const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+      const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+      const int trailing_dim = input_shape.DimensionsCount() - 1;
+      const int outer_size =
+          MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+      const int depth =
+          MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+      const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
+      int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
+      void *scratch_buf = NULL;
+      if (data->buffer_idx > -1) {
+        scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
+      }
+      esp_nn_set_softmax_scratch_buf(scratch_buf);
+      esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
+                        input_beta_left_shift, diff_min, out_ptr);
+#else
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+#endif
+    }
+  } else {
+    tflite::reference_ops::SoftmaxInt16(
+        data->op_data, tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int16_t>(input),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
+  }
+}
+
+static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  NodeData data = *static_cast<NodeData*>(node->user_data);
+
+  long long start_time = esp_timer_get_time();
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Softmax(
+          data.op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+    }
+    break;
+    case kTfLiteInt8:
+    case kTfLiteInt16: {
+      SoftmaxQuantized(context, input, output, &data);
+    }
+    break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  softmax_total_time += esp_timer_get_time() - start_time;
+  return kTfLiteOk;
+}
+
+static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  NodeData* data = static_cast<NodeData*>(node->user_data);
+  // Only allocate LUTs for KTfLiteInt16 data type
+  if (input->type == kTfLiteInt16) {
+    void* raw_exp_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+    data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+    data->op_data.one_over_one_plus_x_lut =
+        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+  }
+
+  if (output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE(context,
+                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  }
+
+  // Populate LUT if required
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    // exp LUT only used on negative values
+    // we consider exp(-10.0) is insignificant to accumulation
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
+        data->op_data.exp_lut);
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
+        1.0f, data->op_data.one_over_one_plus_x_lut);
+    data->op_data.zero_point = output->params.zero_point;
+    data->op_data.scale = output->params.scale;
+  }
+
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto ret_val =
+      CalculateSoftmaxParams(context, input, output, params, &data->op_data);
+
+#if ESP_NN
+  if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
+    const int32_t input_width = input->dims->data[1];
+    const int32_t input_height = input->dims->data[2];
+    int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
+                                                           input_height);
+    if (scratch_buf_size > 0) {
+      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+        context, scratch_buf_size, &data->buffer_idx));
+    }
+  }
+#endif
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return ret_val;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_SOFTMAX() {
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
+}
+
+}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc

@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_EXP() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc

@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_EXPAND_DIMS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc

@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FILL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc

@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace floor
 
 TfLiteRegistration Register_FLOOR() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/floor::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
 }
 
 }  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc

@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FLOOR_DIV() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc

@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FLOOR_MOD() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 19 - 12
code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc

@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
       node, kFullyConnectedOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
-
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
-  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
-                     "Hybrid models are not supported on TFLite Micro.");
 
   TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
                                  context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
 
+    case kTfLiteInt16: {
+      const int64_t* bias_data =
+          nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
+                          : nullptr;
+
+      tflite::reference_integer_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int16_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias), bias_data,
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+      break;
+    }
+
     default: {
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 19 - 1
code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h

@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
 }
 
 #endif
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16.
+TfLiteRegistration Register_FULLY_CONNECTED_INT16();
+
+#else
+// Note that while this block gets used for both reference and optimized kernels
+// that do not have any specialized implementations, the only goal here is to
+// define fallback implementation that allow reference kernels to still be used
+// from applications that call a more specific kernel variant.
+
+inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
+  return Register_FULLY_CONNECTED();
+}
+
+#endif
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc

@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_GATHER() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc

@@ -195,14 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_GATHER_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 2 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc

@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_HARD_SWISH() {
-  return {/*init=*/HardSwishInit,
-          /*free=*/nullptr,
-          /*prepare=*/tflite::HardSwishPrepare,
-          /*invoke=*/HardSwishEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
+                                   HardSwishEval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc

@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace.
 
 TfLiteRegistration Register_IF() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
 }  // namespace tflite

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc

@@ -15,9 +15,9 @@ limitations under the License.
 
 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
 
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/test_helpers.h"
 
 namespace tflite {
@@ -30,7 +30,7 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[];
 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                            TfLiteTensor* tensors, int tensors_size,
                            TfLiteIntArray* inputs, TfLiteIntArray* outputs,
-                           void* builtin_data)
+                           void* builtin_data, TfLiteIntArray* intermediates)
     : registration_(registration),
       allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
                                                kKernelRunnerBuffer_,
@@ -54,6 +54,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
   node_.inputs = inputs;
   node_.outputs = outputs;
   node_.builtin_data = builtin_data;
+  node_.intermediates = intermediates;
 }
 
 bool KernelRunner::ValidateTempBufferDeallocated() {

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h

@@ -18,9 +18,9 @@ limitations under the License.
 
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/fake_micro_context.h"
 #include "tensorflow/lite/micro/mock_micro_graph.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 namespace tflite {
 namespace micro {
@@ -35,7 +35,8 @@ class KernelRunner {
  public:
   KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
                int tensors_size, TfLiteIntArray* inputs,
-               TfLiteIntArray* outputs, void* builtin_data);
+               TfLiteIntArray* outputs, void* builtin_data,
+               TfLiteIntArray* intermediates = nullptr);
 
   // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
   // exceptions will be DebugLog'd and returned as a status code.

+ 15 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc

@@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,
 
 }  // namespace
 
+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
+  return {/*init=*/init,
+          /*free=*/nullptr,
+          /*prepare=*/prepare,
+          /*invoke=*/invoke,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0,
+          /*registration_external=*/nullptr};
+}
+
 // Returns a mutable tensor for a given input index. is_variable must be checked
 // during prepare when the full TfLiteTensor is available.
 TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,

+ 22 - 3
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h

@@ -27,6 +27,11 @@ limitations under the License.
 namespace tflite {
 namespace micro {
 
+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
+
 // Returns a mutable tensor for a given input index. is_variable must be checked
 // during prepare when the full TfLiteTensor is available.
 TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
@@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
 TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
                                 const TfLiteNode* node, int index);
 
-// Returns data for a TfLiteEvalTensor struct.
+// Returns data for a TfLiteEvalTensor struct that are expected to exist.
 template <typename T>
 T* GetTensorData(TfLiteEvalTensor* tensor) {
-  return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr;
+  TFLITE_DCHECK(tensor != nullptr);
+  return reinterpret_cast<T*>(tensor->data.raw);
 }
 
-// Returns const data for a TfLiteEvalTensor struct.
+// Returns const data for a TfLiteEvalTensor struct that are expected to exist.
 template <typename T>
 const T* GetTensorData(const TfLiteEvalTensor* tensor) {
   TFLITE_DCHECK(tensor != nullptr);
   return reinterpret_cast<const T*>(tensor->data.raw);
 }
 
+// Returns data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+T* GetOptionalTensorData(TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw);
+}
+
+// Returns const data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr
+                           : reinterpret_cast<const T*>(tensor->data.raw);
+}
+
 // Returns the shape of a TfLiteEvalTensor struct.
 const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
 

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc

@@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_L2_POOL_2D() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/L2Prepare,
-          /*invoke=*/L2Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc

@@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace l2norm
 
 TfLiteRegistration Register_L2NORM_REF() {
-  return {/*init=*/l2norm::Init,
-          /*free=*/nullptr,
-          /*prepare=*/l2norm::Prepare,
-          /*invoke=*/l2norm::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval);
 }
 
 TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }

+ 2 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc

@@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteRegistration Register_LEAKY_RELU() {
-  return {/*init=*/LeakyReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/LeakyReluPrepare,
-          /*invoke=*/LeakyReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare,
+                                   LeakyReluEval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc

@@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOG_SOFTMAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/LogSoftmaxPrepare,
-          /*invoke=*/LogSoftmaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval);
 }
 
 }  // namespace tflite

+ 2 - 20
code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc

@@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOGICAL_OR() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalOrEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval);
 }
 
 TfLiteRegistration Register_LOGICAL_AND() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalAndEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval);
 }
 
 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc

@@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_LOGISTIC() {
-  return {/*init=*/LogisticInit,
-          /*free=*/nullptr,
-          /*prepare=*/LogisticPrepare,
-          /*invoke=*/LogisticEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval);
 }
 }  // namespace tflite

+ 2955 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc

@@ -0,0 +1,2955 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/kernels/lstm_eval.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"
+namespace tflite {
+namespace {
+
+void ComputeRowSums(
+    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
+    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
+    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
+    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
+    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
+    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
+    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
+    int n_input, int n_aux_input, int n_output,
+    const int8_t* input_to_input_weights_ptr,
+    const int8_t* input_to_forget_weights_ptr,
+    const int8_t* input_to_cell_weights_ptr,
+    const int8_t* input_to_output_weights_ptr,
+    const int8_t* aux_input_to_input_weights_ptr,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    const int8_t* aux_input_to_output_weights_ptr,
+    const int8_t* recurrent_to_input_weights_ptr,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    const int8_t* recurrent_to_output_weights_ptr,
+    const int8_t* projection_weights_ptr, bool use_cifg,
+    const float* aux_input_ptr) {
+  // Compute the row sums for dequantization
+  if (!use_cifg) {
+    micro_tensor_utils::ReductionSumVector(
+        input_to_input_weights_ptr, input_to_input_row_sums, n_cell, n_input);
+  }
+  micro_tensor_utils::ReductionSumVector(
+      input_to_forget_weights_ptr, input_to_forget_row_sums, n_cell, n_input);
+  micro_tensor_utils::ReductionSumVector(
+      input_to_cell_weights_ptr, input_to_cell_row_sums, n_cell, n_input);
+  micro_tensor_utils::ReductionSumVector(
+      input_to_output_weights_ptr, input_to_output_row_sums, n_cell, n_input);
+
+  if (aux_input_ptr) {
+    if (!use_cifg) {
+      micro_tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
+                                             aux_input_to_input_row_sums,
+                                             n_cell, n_aux_input);
+    }
+    micro_tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
+                                           aux_input_to_forget_row_sums, n_cell,
+                                           n_aux_input);
+    micro_tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
+                                           aux_input_to_cell_row_sums, n_cell,
+                                           n_aux_input);
+    micro_tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
+                                           aux_input_to_output_row_sums, n_cell,
+                                           n_aux_input);
+  }
+  if (!use_cifg) {
+    micro_tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
+                                           recurrent_to_input_row_sums, n_cell,
+                                           n_output);
+  }
+  micro_tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
+                                         recurrent_to_forget_row_sums, n_cell,
+                                         n_output);
+  micro_tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
+                                         recurrent_to_cell_row_sums, n_cell,
+                                         n_output);
+  micro_tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
+                                         recurrent_to_output_row_sums, n_cell,
+                                         n_output);
+
+  if (projection_weights_ptr != nullptr) {
+    micro_tensor_utils::ReductionSumVector(
+        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
+  }
+}
+
+// Calculates a single LSTM gate.
+//
+// Implements the following formula: (* is matrix multiply)
+//   gate = activate(W_input    * input + W_aux       * aux_input   +
+//                   W_peephole * cell  + W_recurrent * prev_output + bias)
+// with layer norm:
+//   gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
+//
+// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
+//
+// Parameters:
+// Input vectors (to LSTM):    | Size:                | Optional?
+//   input                     | n_input              |
+//   aux_input                 | n_aux_input          | y (bidir LSTM)
+// Input vectors (persistent states):
+//   output_state              | n_output             |
+//   cell_state                | n_cell               |
+// 'Constant' inputs:
+//   input_to_gate_weights     | n_cell * n_input     |
+//   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
+//   recurrent_to_gate_weights | n_cell * n_output    |
+//   cell_to_gate_weights      | n_cell               | y (peephole)
+//   gate_bias                 | n_cell               |
+//   layer_norm_coefficients   | n_cell               | y (layer norm)
+// Output vector:
+//   gate                      | n_cell               |
+// Scalar parameters:
+//   n_batch                                    - batch size / number of vectors
+//   n_input, n_aux_input, n_output, n_cell     - size of vectors.
+//   activation                                 - activation to use.
+//   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
+//   use_layer_norm                             - if doing layer norm LSTM.
+inline void CalculateLstmGateFloat(
+    const float* input, const float* input_to_gate_weights,
+    const float* aux_input, const float* aux_input_to_gate_weights,
+    const float* output_state, const float* recurrent_to_gate_weights,
+    const float* cell_state, const float* cell_to_gate_weights,
+    const float* layer_norm_coefficients, const float* gate_bias,
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation, float* gate,
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    memset(gate, 0, n_cell * n_batch * sizeof(float));
+  } else {
+    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
+                                                gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, n_batch,
+        gate);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    micro_tensor_utils::VectorBatchVectorCwiseProduct(
+        layer_norm_coefficients, n_cell, gate, n_batch, gate);
+    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  micro_tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell,
+                                              activation, gate);
+}
+
+// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
+//
+// Implements the following formula:
+//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
+//
+// With CIFG LSTM, input gate is replaced by (1-forget_gate).
+//
+// Parameters:
+//  - n_batch, n_cell: sizes of vectors
+//  - cell_state: input/output vector, size n_batch*n_cell
+//  - input_gate: input vector, size n_batch*n_cell.
+//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
+//  - cell_gate: input vector, size n_batch*n_cell.
+//  - use_cifg: use 1-forget_gate instead of input_gate.
+//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
+void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
+                         const float* input_gate, float* forget_gate,
+                         const float* cell_gate, bool use_cifg, float clip) {
+  micro_tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
+                                               n_batch * n_cell, cell_state);
+
+  if (use_cifg) {
+    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
+    // scratch, as input_gate array is not allocated in this case. (Be careful
+    // not to write to the scratch before reading the forget gate data.)
+    float* scratch = forget_gate;
+    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
+    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_gate, scratch, n_batch * n_cell, cell_state);
+  } else {
+    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_gate, input_gate, n_batch * n_cell, cell_state);
+  }
+  if (clip > 0.0f) {
+    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
+  }
+}
+
+// Calculates the output state tensor of an LSTM step.
+//
+// Implements the following formula:
+//   output_no_projection = output_gate .* activate(cell_state)
+//     (elementwise vector product)
+// If no projection is used:
+//   output = output_state = output_no_projection
+// With projection:
+//   output = output_state = clip(W*output_no_projection + bias)
+//
+// Output might not have a different 'stride' than n_batch, so we need to copy.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, projection_weights_scale, projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contigous.
+//  - scratch: scratch area, size n_batch*n_cell.
+void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
+                              const float* cell_state, const float* output_gate,
+                              TfLiteFusedActivation activation,
+                              const float* projection_weights,
+                              const float* projection_bias,
+                              const float proj_clip, float* output_state,
+                              float* scratch) {
+  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
+                                              activation, scratch);
+  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch,
+                                               n_batch * n_cell, scratch);
+
+  const bool use_projection = (projection_weights != nullptr);
+  const bool use_projection_bias = (projection_bias != nullptr);
+
+  if (use_projection) {
+    if (use_projection_bias) {
+      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
+                                                  n_batch, output_state);
+    } else {
+      memset(output_state, 0, n_batch * n_output * sizeof(float));
+    }
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
+    if (proj_clip > 0.0f) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch, n_batch * n_output * sizeof(float));
+  }
+}
+
+// Calculates a single LSTM gate, hybrid version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateHybrid(
+    // Input and weights
+    const int8_t* input, const float* input_sf, const int32_t* input_zp,
+    const int8_t* input_to_gate_weights,
+    const uint8_t* input_to_gate_weights_ledger,
+    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
+    // Aux input and weights
+    const int8_t* aux_input, const float* aux_input_sf,
+    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
+    const float aux_input_to_gate_weights_scale,
+    int32_t* aux_input_to_gate_row_sums,
+    // Output state and weights
+    const int8_t* output_state, const float* output_state_sf,
+    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
+    const uint8_t* recurrent_to_gate_weights_ledger,
+    const float recurrent_to_gate_weights_scale,
+    int32_t* recurrent_to_gate_row_sums,
+    // Cell state and weights (peephole LSTM)
+    const float* cell_state, const int8_t* cell_to_gate_weights,
+    const float cell_to_gate_weights_scale,
+    // Layer normalization coefficients (layer norm LSTM) + gate bias
+    const float* layer_norm_coefficients, const float* gate_bias,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    float* gate,
+    // Parameters for performance optimizations
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
+    const bool is_output_state_all_zeros, bool* compute_row_sums,
+    // Scratch arrays
+    float* scratch0,        // size: n_batch
+    float* scratch1,        // size: n_cell, only used if peephole LSTM
+    float* scales,          // size: n_batch
+    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
+) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    memset(gate, 0, n_cell * n_batch * sizeof(float));
+  } else {
+    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
+                                                gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    if (input_to_gate_weights_ledger != nullptr) {
+      for (int i = 0; i < n_batch; i++) {
+        scales[i] = input_to_gate_weights_scale * input_sf[i];
+      }
+      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
+          input, scales, n_batch, gate);
+
+    } else {
+      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_gate_weights, n_cell, n_input, input,
+          input_to_gate_weights_scale, input_sf, n_batch, gate,
+          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
+          input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+    }
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
+        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
+        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
+        aux_input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  // Skip if output state is all zeros.
+  if (!is_output_state_all_zeros) {
+    if (recurrent_to_gate_weights_ledger != nullptr) {
+      for (int i = 0; i < n_batch; i++) {
+        scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
+      }
+      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
+          n_output, output_state, scales, n_batch, gate);
+    } else {
+      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_gate_weights, n_cell, n_output, output_state,
+          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
+          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
+          recurrent_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+    }
+  }
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    float* recovered_cell_weights = scratch1;
+    micro_tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
+                                             cell_to_gate_weights_scale,
+                                             recovered_cell_weights);
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    micro_tensor_utils::VectorBatchVectorCwiseProduct(
+        layer_norm_coefficients, n_cell, gate, n_batch, gate);
+    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  micro_tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch,
+                                              activation, gate);
+}
+
+// Calculates the output state tensor of an LSTM step. See Float version too.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, projection_weights_scale, projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contigous.
+//  - asymmetric_quantize_inputs: parameter to control quantization.
+//  - projection_weights_row_sums, compute_row_sums: Data for optimized
+//      MatrixBatchVectorMultiplyAccumulate.
+//  - scratch0: scratch area of size n_batch*n_cell
+//  - scratch1: scratch area of size n_batch*n_cell
+//  - scratch2: scratch area of size n_batch
+//  - scratch3: scratch area of size n_batch
+//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
+//  - scales: scratch area of size n_batch
+void CalculateLstmOutputHybrid(
+    int n_batch, int n_cell, int n_output, const float* cell_state,
+    const float* output_gate, TfLiteFusedActivation activation,
+    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
+    float projection_weights_scale, const float* projection_bias,
+    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
+    int32_t* projection_weights_row_sums, bool* compute_row_sums,
+    float* scratch0, int8_t* scratch1, float* scratch2, int32_t* scratch3,
+    int32_t* scratch4, float* scales) {
+  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
+                                              activation, scratch0);
+  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
+                                               n_batch * n_cell, scratch0);
+
+  const bool use_projection = (projection_weights != nullptr);
+  const bool use_projection_bias = (projection_bias != nullptr);
+
+  if (use_projection) {
+    if (use_projection_bias) {
+      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
+                                                  n_batch, output_state);
+    } else {
+      memset(output_state, 0, n_batch * n_output * sizeof(float));
+    }
+    if (!micro_tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
+      // Save quantization and matmul computation for all zero output.
+      micro_tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell,
+                                              scratch1, scratch2, scratch3,
+                                              asymmetric_quantize_inputs);
+      if (projection_weights_ledger != nullptr) {
+        for (int i = 0; i < n_batch; i++) {
+          scales[i] = projection_weights_scale * scratch2[i];
+        }
+        micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+            projection_weights, projection_weights_ledger, n_output, n_cell,
+            scratch1, scales, n_batch, output_state);
+      } else {
+        micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            projection_weights, n_output, n_cell, scratch1,
+            projection_weights_scale, scratch2, n_batch, output_state,
+            /*per_channel_scale=*/nullptr, scratch3, scratch4,
+            projection_weights_row_sums, compute_row_sums, scratch2, nullptr);
+      }
+    }
+    if (proj_clip > 0.0f) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch0, n_batch * n_output * sizeof(float));
+  }
+}
+
+// Calculates a single LSTM gate, int8x8_16 version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateInteger8x8_16(
+    // Input and weights
+    const int8_t* input, const int8_t* input_to_gate_weights,
+    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
+    const int32_t input_to_gate_scale_b,
+    // Output state and weights
+    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
+    const int32_t* recurrent_to_gate_bias,
+    const int32_t recurrent_to_gate_scale_a,
+    const int32_t recurrent_to_gate_scale_b,
+    // Cell state and weights
+    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
+    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
+    // Layer normalization parameters (layer norm LSTM)
+    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
+    const int32_t layer_norm_input_scale_a,
+    const int32_t layer_norm_input_scale_b,
+    const int32_t layer_norm_variance_guard,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    int16_t* gate,
+    // Parameters for performance optimizations
+    // Scratch arrays
+    int32_t* scratch5) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
+  // versions, bias is only used in layer normalization.
+  memset(gate, 0, n_batch * n_cell * sizeof(int16_t));
+  // For each batch and cell: compute input_weight * input.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
+      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
+      nullptr);
+  // Note: no aux_input.
+
+  // For each batch and cell: compute recurrent_weight * output_state.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
+      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
+      n_cell, 0, scratch5, gate, nullptr);
+  // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
+  if (use_peephole) {
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_output, cell_state, n_batch,
+        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::ApplyLayerNorm(
+        gate, layer_norm_coefficients, layer_norm_bias,
+        layer_norm_input_scale_a, layer_norm_input_scale_b,
+        layer_norm_variance_guard, n_batch, n_cell, gate);
+  }
+  // Apply activation
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
+      break;
+    case kTfLiteActTanh:
+      micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
+      break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
+
+// Updates the LSTM cell state, used by both integer LSTM versions.
+// Also see UpdateLstmCellFloat.
+//
+// Parameters:
+//  - n_batch, n_cell: sizes of vectors
+//  - cell_state: input/output vector, size n_batch*n_cell
+//  - cell_state_scale: scaling factor of cell state.
+//  - input_gate: input vector, size n_batch*n_cell.
+//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
+//  - cell_gate: input vector, size n_batch*n_cell.
+//  - use_cifg: use 1-forget_gate instead of input_gate.
+//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
+void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
+                           int32_t cell_state_scale, const int16_t* input_gate,
+                           int16_t* forget_gate, const int16_t* cell_gate,
+                           bool use_cifg, int16_t clip) {
+  // Use the forget_gate array as scratch, as input_gate array is not allocated
+  // in CIFG case. (Be careful not to write to the scratch before reading the
+  // forget gate data.)
+  int16_t* scratch = forget_gate;
+
+  micro_tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
+                               cell_state);
+  if (use_cifg) {
+    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
+    micro_tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
+                                 30 + cell_state_scale, scratch);
+  } else {
+    micro_tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
+                                 30 + cell_state_scale, scratch);
+  }
+  micro_tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell,
+                               cell_state);
+
+  if (clip > 0) {
+    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
+  }
+}
+
+// Calculates the output state tensor of an LSTM step. See Float and hybrid
+// versions as well.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - cell_state_scale: scaling of cell_state.
+//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
+//  - hidden_zp: zero_point for cell_state.*output_gate
+//  - projection_weights, proj_scale_[a|b], projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
+//  - quantized_proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contigous.
+//  - scratch0: scratch area of size n_batch*n_cell
+//  - scratch1: scratch area of size n_batch*n_cell
+//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
+void CalculateLstmOutputInteger8x8_16(
+    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
+    int32_t cell_state_scale, const int16_t* output_gate,
+    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
+    const int8_t* projection_weights, int32_t proj_scale_a,
+    int32_t proj_scale_b, const int32_t* projection_bias,
+    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
+    int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
+  // Note: unlike float/hybrid, the activation is always Tanh.
+  micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
+                                n_cell, scratch0);
+  micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
+                               hidden_scale_b, n_batch, n_cell, hidden_zp,
+                               scratch1);
+
+  const bool use_projection = (projection_weights != nullptr);
+
+  if (use_projection) {
+    // Note: no bias like in float/hybrid
+    memset(output_state, 0, n_batch * n_output * sizeof(int8_t));
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        scratch1, projection_bias, projection_weights, proj_scale_a,
+        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
+        output_state, nullptr);
+    if (quantized_proj_clip > 0) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        quantized_proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch1, n_batch * n_output * sizeof(int8_t));
+  }
+}
+
+// Calculates a single LSTM gate, int8x8_8 version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateInteger8x8_8(
+    // Inputs and weights
+    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
+    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
+    const int32_t input_times_weights_scale_a,
+    const int32_t input_times_weights_scale_b,
+    const int32_t input_times_weights_zp,
+    // Output state and weights
+    const int8_t* output_state, const int32_t output_state_zp,
+    const int8_t* recurrent_to_gate_weight,
+    const int32_t recurrent_to_gate_scale_a,
+    const int32_t recurrent_to_gate_scale_b,
+    const int32_t output_state_times_weights_scale_a,
+    const int32_t output_state_times_weights_scale_b,
+    const int32_t output_state_times_weights_zp,
+    // Layer normalization parameters (layer norm LSTM)
+    const int16_t* layer_norm_gate_weight,
+    const int32_t layer_norm_gate_scale_a,
+    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    int16_t* gate,
+    // Scratch arrays, both sized n_batch*n_cell
+    int8_t* scratch0, int8_t* scratch1) {
+  // Multiply input * input_weights => scratch0
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
+      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
+      input_times_weights_zp);
+  // Multiply output_state * recurrent_weights => scratch1
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      output_state, output_state_zp, recurrent_to_gate_weight,
+      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
+      n_cell, scratch1, output_state_times_weights_zp);
+  // Add scratch0 + scratch1 => gate
+  micro_tensor_utils::TwoGateSaturatingAdd(
+      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
+      input_times_weights_scale_a, input_times_weights_scale_b,
+      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
+      n_batch, n_cell, gate);
+  // Apply layer normalization.
+  micro_tensor_utils::ApplyLayerNormFloat(
+      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
+      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
+  // Apply activation.
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      micro_tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
+      break;
+    case kTfLiteActTanh:
+      micro_tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
+      break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
+
+// Calculates the output state tensor of an LSTM step. See Float and hybrid
+// versions as well.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, proj_scale_[a|b], projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - output_state_zp: zero point of the output state.
+//  - quantized_proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contigous.
+//  - scratch: scratch area of size n_batch*n_cell
+void CalculateLstmOutputInteger8x8_8(
+    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
+    const int16_t* output_gate, const int8_t* projection_weights,
+    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
+    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
+    int16_t* scratch) {
+  // Note: unlike float/hybrid, the activation is always Tanh.
+  micro_tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
+  micro_tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell,
+                               15 + 15 - 15, scratch);
+  // Note: no bias like in float/hybrid
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
+      n_batch, n_cell, n_output, output_state_zp, output_state);
+  if (quantized_proj_clip > 0) {
+    micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                      quantized_proj_clip);
+  }
+}
+
+// Performs an LSTM batch inference step for input specified by input_ptr.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+//  - params: various LSTM params including activation, clipping, etc.,
+//  - n_batch: size of batch,
+//  - n_cell: number of cells (or units),
+//  - n_input: the input size,
+//  - n_aux_input: the auxiliary input size.
+//  - n_output: the output size.
+//  - output_batch_leading_dim: the leading dimension of the output buffer.
+//
+// Input of size 'n_batch * n_input':
+//   input_ptr
+// Input of size 'n_batch * n_aux_input':
+//   aux_input_ptr                     - optional (can be nullptr)
+//
+// LSTM weights:
+// Input weights of size 'n_cell * n_input':
+//   input_to_input_weights            - optional
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_output_weights
+// Auxiliary input weights of size 'n_cell * n_aux_input':
+//   aux_input_to_input_weights        - optional
+//   aux_input_to_forget_weights       - optional
+//   aux_input_to_cell_weights         - optional
+//   aux_input_to_output_weights       - optional
+// Recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights        - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_input_weights
+// Peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights             - optional
+//   cell_to_cell_weights              - optional
+//   cell_to_output_weights            - optional
+// Projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr            - optional
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr               - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   input_layer_norm_coefficients_ptr  - optional
+//   forget_layer_norm_coefficients_ptr - optional
+//   cell_layer_norm_coefficients_ptr   - optional
+//   output_layer_norm_coefficients_ptr - optional
+//
+// The pointers to the cell and output state and the output are updated.
+//
+// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
+// in batch_major order, and each step processes batch_size many inputs from
+// input_ptr, and updates batch_size many cell and output states.
+//
+// The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
+// output tensor, and in most cases will be equal to n_output. It is usually not
+// when we want to store the LSTM output into a slice of the output tensor, e.g.
+// for bidirectional LSTMs with merge_outputs. In this case, the batched
+// operations cannot be used since they assume that the batched outputs are
+// contiguous, and we manually loop over the batched outputs.
+inline void LstmStepFloat(
+    const float* input_ptr, const float* input_to_input_weights_ptr,
+    const float* input_to_forget_weights_ptr,
+    const float* input_to_cell_weights_ptr,
+    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
+    const float* aux_input_to_input_weights_ptr,
+    const float* aux_input_to_forget_weights_ptr,
+    const float* aux_input_to_cell_weights_ptr,
+    const float* aux_input_to_output_weights_ptr,
+    const float* recurrent_to_input_weights_ptr,
+    const float* recurrent_to_forget_weights_ptr,
+    const float* recurrent_to_cell_weights_ptr,
+    const float* recurrent_to_output_weights_ptr,
+    const float* cell_to_input_weights_ptr,
+    const float* cell_to_forget_weights_ptr,
+    const float* cell_to_output_weights_ptr,
+    const float* input_layer_norm_coefficients_ptr,
+    const float* forget_layer_norm_coefficients_ptr,
+    const float* cell_layer_norm_coefficients_ptr,
+    const float* output_layer_norm_coefficients_ptr,
+    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
+    const float* projection_weights_ptr, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_aux_input, int n_output, int output_batch_leading_dim,
+    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to the get the condition.
+  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+
+  // Make named scratch buffers.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateFloat(
+        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
+        aux_input_to_input_weights_ptr, output_state_ptr,
+        recurrent_to_input_weights_ptr, cell_state_ptr,
+        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
+        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
+        is_input_all_zeros, is_aux_input_all_zeros);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
+      aux_input_to_forget_weights_ptr, output_state_ptr,
+      recurrent_to_forget_weights_ptr, cell_state_ptr,
+      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
+      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros);
+  // Calculate the cell update gate.
+  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
+                         aux_input_to_cell_weights_ptr, output_state_ptr,
+                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
+                         /*cell_to_gate_weights=*/nullptr,
+                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
+                         n_batch, n_input, n_aux_input, n_output, n_cell,
+                         params->activation, cell_gate_scratch,
+                         is_input_all_zeros, is_aux_input_all_zeros);
+  // Update the cell state.
+  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
+                      forget_gate_scratch, cell_gate_scratch, use_cifg,
+                      params->cell_clip);
+  // Calculate output gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
+      aux_input_to_output_weights_ptr, output_state_ptr,
+      recurrent_to_output_weights_ptr, cell_state_ptr,
+      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
+      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros);
+  // Update the output state.
+  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
+                           output_gate_scratch, params->activation,
+                           projection_weights_ptr, projection_bias_ptr,
+                           params->proj_clip, output_state_ptr, scratch2);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+//   input_ptr
+// Input of size 'n_batch * n_aux_input':
+//   aux_input_ptr                     - optional (can be nullptr)
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weights            - optional
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_input_weights
+// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
+//   aux_input_to_input_weights        - optional
+//   aux_input_to_forget_weights       - optional
+//   aux_input_to_cell_weights         - optional
+//   aux_input_to_output_weights       - optional
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights        - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_input_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights             - optional
+//   cell_to_cell_weights              - optional
+//   cell_to_output_weights            - optional
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr            - optional
+// Weight scales (scalars) for each of the weights above.
+//   input_to_input_weights_scale      - optional
+//   input_to_forget_weights_scale
+//   input_to_cell_weights_scale
+//   input_to_output_weights_scale
+//   aux_input_to_input_weights_scale  - optional
+//   aux_input_to_forget_weights_scale - optional
+//   aux_input_to_cell_weights_scale   - optional
+//   aux_input_to_output_weights_scale - optional
+//   recurrent_to_input_weights_scale  - optional
+//   recurrent_to_forget_weights_scale
+//   recurrent_to_cell_weights_scale
+//   recurrent_to_output_weights_scale
+//   cell_to_input_weights_scale,
+//   cell_to_forget_weights_scale,
+//   cell_to_output_weights_scale,
+//   projection_weights_scale          - optional
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr               - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   input_layer_norm_coefficients_ptr  - optional
+//   forget_layer_norm_coefficients_ptr - optional
+//   cell_layer_norm_coefficients_ptr   - optional
+//   output_layer_norm_coefficients_ptr - optional
+//
+// Temporary pre-allocated storage for quantized values:
+//   quantized_input_ptr (same size as input_ptr)
+//   quantized_output_state_ptr (same size as output_state_ptr)
+//   quantized_output_scratch (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+//   recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * output_batch_leading_dim'
+inline void LstmStepHybrid(
+    const float* input_ptr, const int8_t* input_to_input_weights_ptr,
+    const uint8_t* input_to_input_weights_ledger_ptr,
+    float input_to_input_weights_scale,
+    const int8_t* input_to_forget_weights_ptr,
+    const uint8_t* input_to_forget_weights_ledger_ptr,
+    float input_to_forget_weights_scale,
+    const int8_t* input_to_cell_weights_ptr,
+    const uint8_t* input_to_cell_weights_ledger_ptr,
+    float input_to_cell_weights_scale,
+    const int8_t* input_to_output_weights_ptr,
+    const uint8_t* input_to_output_weights_ledger_ptr,
+    float input_to_output_weights_scale, const float* aux_input_ptr,
+    const int8_t* aux_input_to_input_weights_ptr,
+    float aux_input_to_input_weights_scale,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    float aux_input_to_forget_weights_scale,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    float aux_input_to_cell_weights_scale,
+    const int8_t* aux_input_to_output_weights_ptr,
+    float aux_input_to_output_weights_scale,
+    const int8_t* recurrent_to_input_weights_ptr,
+    const uint8_t* recurrent_to_input_weights_ledger_ptr,
+    float recurrent_to_input_weights_scale,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    const uint8_t* recurrent_to_forget_weights_ledger_ptr,
+    float recurrent_to_forget_weights_scale,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    const uint8_t* recurrent_to_cell_weights_ledger_ptr,
+    float recurrent_to_cell_weights_scale,
+    const int8_t* recurrent_to_output_weights_ptr,
+    const uint8_t* recurrent_to_output_weights_ledger_ptr,
+    float recurrent_to_output_weights_scale,
+    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+    const int8_t* cell_to_forget_weights_ptr,
+    float cell_to_forget_weights_scale,
+    const int8_t* cell_to_output_weights_ptr,
+    float cell_to_output_weights_scale,
+    const float* input_layer_norm_coefficients_ptr,
+    const float* forget_layer_norm_coefficients_ptr,
+    const float* cell_layer_norm_coefficients_ptr,
+    const float* output_layer_norm_coefficients_ptr,
+    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
+    const int8_t* projection_weights_ptr,
+    const uint8_t* projection_weights_ledger_ptr,
+    float projection_weights_scale, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_aux_input, int n_output, int output_batch_leading_dim,
+    float* scratch0, float* scratch1, float* scratch2, float* scratch3,
+    float* scales, float* input_sf, float* aux_input_sf, float* output_state_sf,
+    float* scaling_factors_scratch, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
+    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+    float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums, bool asymmetric_quantize_inputs) {
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to the get the condition.
+  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
+  int32_t* input_to_input_row_sums = nullptr;
+  int32_t* input_to_forget_row_sums = nullptr;
+  int32_t* input_to_cell_row_sums = nullptr;
+  int32_t* input_to_output_row_sums = nullptr;
+  int32_t* aux_input_to_input_row_sums = nullptr;
+  int32_t* aux_input_to_forget_row_sums = nullptr;
+  int32_t* aux_input_to_cell_row_sums = nullptr;
+  int32_t* aux_input_to_output_row_sums = nullptr;
+  int32_t* recurrent_to_input_row_sums = nullptr;
+  int32_t* recurrent_to_forget_row_sums = nullptr;
+  int32_t* recurrent_to_cell_row_sums = nullptr;
+  int32_t* recurrent_to_output_row_sums = nullptr;
+  int32_t* projection_weights_row_sums = nullptr;
+
+  if (asymmetric_quantize_inputs) {
+    int num_row_sums = use_cifg ? 6 : 8;
+    if (aux_input_ptr != nullptr) {
+      num_row_sums += use_cifg ? 3 : 4;
+    }
+    if (projection_weights_ptr != nullptr) {
+      num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
+    }
+    TFLITE_DCHECK(row_sums_size == num_row_sums);
+    input_to_input_row_sums = row_sums;
+    input_to_forget_row_sums =
+        use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
+    input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
+    input_to_output_row_sums = input_to_cell_row_sums + n_cell;
+    if (aux_input_ptr != nullptr) {
+      aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
+      aux_input_to_forget_row_sums = use_cifg
+                                         ? aux_input_to_input_row_sums
+                                         : aux_input_to_input_row_sums + n_cell;
+      aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
+      aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
+    }
+    recurrent_to_input_row_sums = aux_input_ptr
+                                      ? aux_input_to_output_row_sums + n_cell
+                                      : input_to_output_row_sums + n_cell;
+    recurrent_to_forget_row_sums = use_cifg
+                                       ? recurrent_to_input_row_sums
+                                       : recurrent_to_input_row_sums + n_cell;
+    recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
+    recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
+    if (projection_weights_ptr != nullptr) {
+      projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
+    }
+    if (*compute_row_sums) {
+      ComputeRowSums(
+          input_to_input_row_sums, input_to_forget_row_sums,
+          input_to_cell_row_sums, input_to_output_row_sums,
+          aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
+          aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
+          recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
+          recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
+          projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
+          n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+          input_to_cell_weights_ptr, input_to_output_weights_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+          recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+          projection_weights_ptr, use_cifg, aux_input_ptr);
+      *compute_row_sums = false;
+    }
+  }
+
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
+  const bool is_output_state_all_zeros =
+      micro_tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
+  // Quantize inputs.
+  if (!is_input_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp,
+        asymmetric_quantize_inputs);
+  }
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        aux_input_ptr, n_batch, n_aux_input, quantized_aux_input_ptr,
+        aux_input_sf, aux_input_zp, asymmetric_quantize_inputs);
+  }
+  if (!is_output_state_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
+        output_state_sf, output_state_zp, asymmetric_quantize_inputs);
+  }
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateHybrid(
+        quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
+        input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
+        input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
+        aux_input_zp, aux_input_to_input_weights_ptr,
+        aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
+        quantized_output_state_ptr, output_state_sf, output_state_zp,
+        recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
+        recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
+        cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
+        input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
+        n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+        input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+        is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+        recovered_cell_weights, scales, accum_scratch_ptr);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
+      input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
+      input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_forget_weights_ptr,
+      aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
+      recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
+      cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+      forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
+      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+      forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+      recovered_cell_weights, scales, accum_scratch_ptr);
+  // Calculate the cell update gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
+      input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
+      input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_cell_weights_ptr,
+      aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
+      recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
+      /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
+      /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
+      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      params->activation, cell_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
+      scaling_factors_scratch, recovered_cell_weights, scales,
+      accum_scratch_ptr);
+  // Update the cell state.
+  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
+                      forget_gate_scratch, cell_gate_scratch, use_cifg,
+                      params->cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
+      input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
+      input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_output_weights_ptr,
+      aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
+      recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
+      cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
+      output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
+      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+      output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+      recovered_cell_weights, scales, accum_scratch_ptr);
+  // Update the output state.
+  CalculateLstmOutputHybrid(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
+      projection_weights_scale, projection_bias_ptr, params->proj_clip,
+      output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
+      compute_row_sums, scratch2, quantized_output_scratch, input_sf, input_zp,
+      accum_scratch_ptr, scales);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
+
+// Fully quantized lstm kernel for 16 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr            - optional
+//   input_to_forget_weight_ptr           - optional
+//   input_to_cell_weight_ptr             - optional
+//   input_to_output_weight_ptr           - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr        - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_input_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights               - optional
+//   cell_to_cell_weights                - optional
+//   cell_to_output_weights              - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr                     - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a    - optional
+//   effective_input_to_input_scale_b    - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a    - optional
+//   effective_recurrent_to_input_scale_b    - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                  - optional
+//   effective_proj_scale_b                  - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr                 - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr    - optional
+//   layer_norm_forget_weight_ptr   - optional
+//   layer_norm_cell_weight_ptr     - optional
+//   layer_norm_output_weight_ptr   - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a     - optional
+//   layer_norm_input_scale_b     - optional
+//   layer_norm_forget_scale_a    - optional
+//   layer_norm_forget_scale_b    - optional
+//   layer_norm_cell_scale_a      - optional
+//   layer_norm_cell_scale_b      - optional
+//   layer_norm_output_scale_a    - optional
+//   layer_norm_output_scale_b    - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   output_state_zp: zero point of output state
+//   hidden_zp: zero point for hidden state.
+//
+// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
+// n_batch.
+//   scratch0
+//   scratch1
+//   scratch2
+//   scratch3
+//   scratch4
+//   scratch5: this scratch buffer is created purely for optimizing the
+//              MatrixBatchVectorMultiplyAccumulate.
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * n_output'
+// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
+inline void LstmStepInteger8x8_16(
+    const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
+    int32_t effective_input_to_input_scale_a,
+    int32_t effective_input_to_input_scale_b,
+    const int8_t* input_to_forget_weight_ptr,
+    int32_t effective_input_to_forget_scale_a,
+    int32_t effective_input_to_forget_scale_b,
+    const int8_t* input_to_cell_weight_ptr,
+    int32_t effective_input_to_cell_scale_a,
+    int32_t effective_input_to_cell_scale_b,
+    const int8_t* input_to_output_weight_ptr,
+    int32_t effective_input_to_output_scale_a,
+    int32_t effective_input_to_output_scale_b,
+    const int8_t* recurrent_to_input_weight_ptr,
+    int32_t effective_recurrent_to_input_scale_a,
+    int32_t effective_recurrent_to_input_scale_b,
+    const int8_t* recurrent_to_forget_weight_ptr,
+    int32_t effective_recurrent_to_forget_scale_a,
+    int32_t effective_recurrent_to_forget_scale_b,
+    const int8_t* recurrent_to_cell_weight_ptr,
+    int32_t effective_recurrent_to_cell_scale_a,
+    int32_t effective_recurrent_to_cell_scale_b,
+    const int8_t* recurrent_to_output_weight_ptr,
+    int32_t effective_recurrent_to_output_scale_a,
+    int32_t effective_recurrent_to_output_scale_b,
+    const int16_t* cell_to_input_weight_ptr,
+    int32_t effective_cell_to_input_scale_a,
+    int32_t effective_cell_to_input_scale_b,
+    const int16_t* cell_to_forget_weight_ptr,
+    int32_t effective_cell_to_forget_scale_a,
+    int32_t effective_cell_to_forget_scale_b,
+    const int16_t* cell_to_output_weight_ptr,
+    int32_t effective_cell_to_output_scale_a,
+    int32_t effective_cell_to_output_scale_b,
+    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
+    int32_t effective_proj_scale_b, int32_t hidden_zp,
+    int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
+    const int16_t* layer_norm_input_weight_ptr,
+    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
+    const int16_t* layer_norm_forget_weight_ptr,
+    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
+    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
+    int32_t layer_norm_cell_scale_b,
+    const int16_t* layer_norm_output_weight_ptr,
+    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
+    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
+    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
+    int32_t cell_state_scale, int32_t input_variance_guard,
+    int32_t forget_variance_guard, int32_t cell_variance_guard,
+    int32_t output_variance_guard,
+    const int32_t* input_to_forget_effective_bias,
+    const int32_t* recurrent_to_forget_effective_bias,
+    const int32_t* input_to_cell_effective_bias,
+    const int32_t* recurrent_to_cell_effective_bias,
+    const int32_t* input_to_output_effective_bias,
+    const int32_t* recurrent_to_output_effective_bias,
+    const int32_t* input_to_input_effective_bias,
+    const int32_t* recurrent_to_input_effective_bias,
+    const int32_t* projection_effective_bias, int n_batch, int n_cell,
+    int n_input, int n_output, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int8_t* scratch4, int32_t* scratch5) {
+  // Make named scratch buffers for the different gates.
+  int16_t* input_gate_scratch = scratch0;
+  int16_t* forget_gate_scratch = scratch1;
+  int16_t* cell_gate_scratch = scratch2;
+  int16_t* output_gate_scratch = scratch3;
+
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to the get the condition.
+  const bool use_cifg = (input_to_input_weight_ptr == nullptr);
+
+  // Check for nullptrs.
+  TFLITE_DCHECK(input_to_forget_effective_bias);
+  TFLITE_DCHECK(recurrent_to_forget_effective_bias);
+  TFLITE_DCHECK(input_to_cell_effective_bias);
+  TFLITE_DCHECK(recurrent_to_cell_effective_bias);
+  TFLITE_DCHECK(input_to_output_effective_bias);
+  TFLITE_DCHECK(recurrent_to_output_effective_bias);
+  if (!use_cifg) {
+    TFLITE_DCHECK(input_to_input_effective_bias);
+    TFLITE_DCHECK(recurrent_to_input_effective_bias);
+  }
+  const bool use_projection = (projection_weight_ptr != nullptr);
+  if (use_projection) {
+    TFLITE_DCHECK(projection_effective_bias);
+  }
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateInteger8x8_16(
+        input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
+        effective_input_to_input_scale_a, effective_input_to_input_scale_b,
+        output_state_ptr, recurrent_to_input_weight_ptr,
+        recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
+        effective_recurrent_to_input_scale_b, cell_state_ptr,
+        cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
+        effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
+        input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
+        input_variance_guard, n_batch, n_input, n_output, n_cell,
+        kTfLiteActSigmoid, input_gate_scratch, scratch5);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      output_state_ptr, recurrent_to_forget_weight_ptr,
+      recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, cell_state_ptr,
+      cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
+      effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
+      forget_gate_bias_ptr, layer_norm_forget_scale_a,
+      layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      output_state_ptr, recurrent_to_cell_weight_ptr,
+      recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
+      effective_recurrent_to_cell_scale_b, cell_state_ptr,
+      /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
+      /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
+      cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
+      cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
+      cell_gate_scratch, scratch5);
+  // Update the cell state.
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
+                        input_gate_scratch, forget_gate_scratch,
+                        cell_gate_scratch, use_cifg, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      output_state_ptr, recurrent_to_output_weight_ptr,
+      recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, cell_state_ptr,
+      cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
+      effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
+      output_gate_bias_ptr, layer_norm_output_scale_a,
+      layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
+  // Update the output state.
+  CalculateLstmOutputInteger8x8_16(
+      n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
+      output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
+      hidden_zp, projection_weight_ptr, effective_proj_scale_a,
+      effective_proj_scale_b, projection_effective_bias, output_state_zp,
+      quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
+  // Copy output state to the output. Note that unlike float or hybrid, output
+  // is always contiguous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
+
+// Fully quantized lstm kernel for 8 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr            - optional
+//   input_to_forget_weight_ptr           - optional
+//   input_to_cell_weight_ptr             - optional
+//   input_to_output_weight_ptr           - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr        - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_input_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights               - optional
+//   cell_to_cell_weights                - optional
+//   cell_to_output_weights              - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr                     - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a    - optional
+//   effective_input_to_input_scale_b    - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a    - optional
+//   effective_recurrent_to_input_scale_b    - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                  - optional
+//   effective_proj_scale_b                  - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr                 - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr    - optional
+//   layer_norm_forget_weight_ptr   - optional
+//   layer_norm_cell_weight_ptr     - optional
+//   layer_norm_output_weight_ptr   - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a     - optional
+//   layer_norm_input_scale_b     - optional
+//   layer_norm_forget_scale_a    - optional
+//   layer_norm_forget_scale_b    - optional
+//   layer_norm_cell_scale_a      - optional
+//   layer_norm_cell_scale_b      - optional
+//   layer_norm_output_scale_a    - optional
+//   layer_norm_output_scale_b    - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   input_zp: zero point for input tensor.
+//   output_state_zp: zero point of output state.
+//   hidden_zp: zero point for hidden state.
+//
+// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
+// n_batch.
+//   scratch0
+//   scratch1
+//   scratch2
+//   scratch3
+//   scratch4
+//   scratch5
+//   scratch6
+//   scratch7
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * n_output'
+//
+// Can move zero point calculation into Prepare() for better perfomance.
+// TODO(b/159947023): scratch5 is unused, remove.
+inline void LstmStepInteger8x8_8(
+    const int8_t* input_ptr, int32_t input_zp,
+    const int8_t* input_to_input_weight_ptr,
+    int32_t effective_input_to_input_scale_a,
+    int32_t effective_input_to_input_scale_b,
+    const int8_t* input_to_forget_weight_ptr,
+    int32_t effective_input_to_forget_scale_a,
+    int32_t effective_input_to_forget_scale_b,
+    const int8_t* input_to_cell_weight_ptr,
+    int32_t effective_input_to_cell_scale_a,
+    int32_t effective_input_to_cell_scale_b,
+    const int8_t* input_to_output_weight_ptr,
+    int32_t effective_input_to_output_scale_a,
+    int32_t effective_input_to_output_scale_b,
+    const int8_t* recurrent_to_input_weight_ptr,
+    int32_t effective_recurrent_to_input_scale_a,
+    int32_t effective_recurrent_to_input_scale_b,
+    const int8_t* recurrent_to_forget_weight_ptr,
+    int32_t effective_recurrent_to_forget_scale_a,
+    int32_t effective_recurrent_to_forget_scale_b,
+    const int8_t* recurrent_to_cell_weight_ptr,
+    int32_t effective_recurrent_to_cell_scale_a,
+    int32_t effective_recurrent_to_cell_scale_b,
+    const int8_t* recurrent_to_output_weight_ptr,
+    int32_t effective_recurrent_to_output_scale_a,
+    int32_t effective_recurrent_to_output_scale_b,
+    const int8_t* cell_to_input_weight_ptr,
+    int32_t effective_cell_to_input_scale_a,
+    int32_t effective_cell_to_input_scale_b,
+    const int8_t* cell_to_forget_weight_ptr,
+    int32_t effective_cell_to_forget_scale_a,
+    int32_t effective_cell_to_forget_scale_b,
+    const int8_t* cell_to_output_weight_ptr,
+    int32_t effective_cell_to_output_scale_a,
+    int32_t effective_cell_to_output_scale_b,
+    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
+    int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
+    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
+    const int16_t* layer_norm_forget_weight_ptr,
+    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
+    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
+    int32_t layer_norm_cell_scale_b,
+    const int16_t* layer_norm_output_weight_ptr,
+    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
+    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
+    const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
+    const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
+    const int32_t* intermediate_zp, int16_t quantized_cell_clip,
+    int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
+    int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
+    int16_t* scratch7) {
+  // TODO(b/159066113): scratch5 is unused, remove.
+
+  // Make named scratch buffers for the different gates.
+  int16_t* forget_gate_scratch = scratch2;
+  int16_t* cell_gate_scratch = scratch3;
+  int16_t* output_gate_scratch = scratch4;
+  // no-CIFG is not supported here
+
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_forget_weight_ptr,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
+      output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
+      effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
+      intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
+      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
+      forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_cell_weight_ptr,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
+      output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
+      effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
+      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
+      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
+      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
+  // Update the cell state.
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
+                        /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
+                        forget_gate_scratch, cell_gate_scratch,
+                        /*use_cifg=*/true, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_output_weight_ptr,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
+      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
+      effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
+      intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
+      layer_norm_output_scale_a, layer_norm_output_scale_b,
+      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
+  // Update the output state.
+  CalculateLstmOutputInteger8x8_8(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
+      projection_bias_ptr, output_state_zp, quantized_proj_clip,
+      output_state_ptr, scratch2);
+  // Copy output state to the output. Note that unlike float or hybrid, output
+  // is always contigous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
+
+}  // namespace
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  int max_time, n_batch;
+  if (input->dims->size == 3) {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  } else {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  }
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to the get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+
+      LstmStepFloat(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_input_weights),
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_output_weights),
+          aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_input_weights),
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_forget_weights),
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_cell_weights),
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_output_weights),
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_input_weights),
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_forget_weights),
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_cell_weights),
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_output_weights),
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_weights),
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state), input_gate_scratch,
+          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+          output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepFloat(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_input_weights),
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_output_weights),
+            aux_input_ptr,
+            aux_input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_input_weights),
+            aux_input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_forget_weights),
+            aux_input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_cell_weights),
+            aux_input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_output_weights),
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_input_weights),
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_forget_weights),
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_cell_weights),
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_output_weights),
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      input_layer_norm_coefficients),
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      forget_layer_norm_coefficients),
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      cell_layer_norm_coefficients),
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      output_layer_norm_coefficients),
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(output_gate_bias),
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_weights),
+            projection_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_bias),
+            params,
+            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
+            output_batch_leading_dim, output_state_ptr, cell_state_ptr,
+            input_gate_scratch_ptr, forget_gate_scratch_ptr,
+            cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  }
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+
+  int32_t* input_zp_ptr = nullptr;
+  int32_t* aux_input_zp_ptr = nullptr;
+  int32_t* output_state_zp_ptr = nullptr;
+  int32_t* row_sums_ptr = nullptr;
+  if (params->asymmetric_quantize_inputs) {
+    input_zp_ptr = input_zp;
+    aux_input_zp_ptr = aux_input_zp;
+    output_state_zp_ptr = output_state_zp;
+    row_sums_ptr = row_sums;
+  }
+
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+      LstmStepHybrid(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+          input_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_input_weights_ledger),
+          hybrid_lstm_scales->input_to_input_weights_scale,
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+          input_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_forget_weights_ledger),
+          hybrid_lstm_scales->input_to_forget_weights_scale,
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+          input_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_cell_weights_ledger),
+          hybrid_lstm_scales->input_to_cell_weights_scale,
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+          input_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_output_weights_ledger),
+          hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_input_weights),
+          hybrid_lstm_scales->aux_input_to_input_weights_scale,
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_forget_weights),
+          hybrid_lstm_scales->aux_input_to_forget_weights_scale,
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(aux_input_to_cell_weights),
+          hybrid_lstm_scales->aux_input_to_cell_weights_scale,
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_output_weights),
+          hybrid_lstm_scales->aux_input_to_output_weights_scale,
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_input_weights),
+          recurrent_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_input_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_input_weights_scale,
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_forget_weights),
+          recurrent_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_forget_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_forget_weights_scale,
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+          recurrent_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_cell_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_cell_weights_scale,
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_output_weights),
+          recurrent_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_output_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_output_weights_scale,
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+          hybrid_lstm_scales->cell_to_input_weights_scale,
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+          hybrid_lstm_scales->cell_to_forget_weights_scale,
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+          hybrid_lstm_scales->cell_to_output_weights_scale,
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(projection_weights),
+          projection_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    projection_weights_ledger),
+          hybrid_lstm_scales->projection_weights_scale,
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_gate_scratch, output_gate_scratch, scales, input_sf,
+          aux_input_sf, output_state_sf, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          output_state_quantized, cell_state_quantized,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state),
+          output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
+          output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepHybrid(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+            input_to_input_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_input_weights_ledger),
+            hybrid_lstm_scales->input_to_input_weights_scale,
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+            input_to_forget_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_forget_weights_ledger),
+            hybrid_lstm_scales->input_to_forget_weights_scale,
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+            input_to_cell_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_cell_weights_ledger),
+            hybrid_lstm_scales->input_to_cell_weights_scale,
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+            input_to_output_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_output_weights_ledger),
+            hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
+            aux_input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_input_weights),
+            hybrid_lstm_scales->aux_input_to_input_weights_scale,
+            aux_input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_forget_weights),
+            hybrid_lstm_scales->aux_input_to_forget_weights_scale,
+            aux_input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_cell_weights),
+            hybrid_lstm_scales->aux_input_to_cell_weights_scale,
+            aux_input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_output_weights),
+            hybrid_lstm_scales->aux_input_to_output_weights_scale,
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_input_weights),
+            recurrent_to_input_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_input_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_input_weights_scale,
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_forget_weights),
+            recurrent_to_forget_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_forget_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_forget_weights_scale,
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_cell_weights),
+            recurrent_to_cell_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_cell_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_cell_weights_scale,
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_output_weights),
+            recurrent_to_output_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_output_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_output_weights_scale,
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+            hybrid_lstm_scales->cell_to_input_weights_scale,
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+            hybrid_lstm_scales->cell_to_forget_weights_scale,
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+            hybrid_lstm_scales->cell_to_output_weights_scale,
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      input_layer_norm_coefficients),
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      forget_layer_norm_coefficients),
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      cell_layer_norm_coefficients),
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      output_layer_norm_coefficients),
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(output_gate_bias),
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(projection_weights),
+            projection_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      projection_weights_ledger),
+            hybrid_lstm_scales->projection_weights_scale,
+            projection_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_bias),
+            params,
+            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
+            output_batch_leading_dim, input_gate_scratch_ptr,
+            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+            output_gate_scratch_ptr, scales, input_sf, aux_input_sf,
+            output_state_sf, prod_scaling_factors, recovered_cell_weights,
+            input_quantized, aux_input_quantized, output_state_quantized,
+            cell_state_quantized, output_state_ptr, cell_state_ptr,
+            output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
+            output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+            params->asymmetric_quantize_inputs);
+      }
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalInteger8x8_16Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major,
+    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
+    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  }
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Get params for time/batch/sequence.
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+
+  if (time_major) {
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      const int t_rel = t;
+      int8_t* output_ptr =
+          tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
+      const int8_t* input_ptr =
+          tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
+      LstmStepInteger8x8_16(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+          integer_lstm_param->effective_input_to_input_scale_a,
+          integer_lstm_param->effective_input_to_input_scale_b,
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+          integer_lstm_param->effective_input_to_forget_scale_a,
+          integer_lstm_param->effective_input_to_forget_scale_b,
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+          integer_lstm_param->effective_input_to_cell_scale_a,
+          integer_lstm_param->effective_input_to_cell_scale_b,
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+          integer_lstm_param->effective_input_to_output_scale_a,
+          integer_lstm_param->effective_input_to_output_scale_b,
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_input_weights),
+          integer_lstm_param->effective_recurrent_to_input_scale_a,
+          integer_lstm_param->effective_recurrent_to_input_scale_b,
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_forget_weights),
+          integer_lstm_param->effective_recurrent_to_forget_scale_a,
+          integer_lstm_param->effective_recurrent_to_forget_scale_b,
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+          integer_lstm_param->effective_recurrent_to_cell_scale_a,
+          integer_lstm_param->effective_recurrent_to_cell_scale_b,
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_output_weights),
+          integer_lstm_param->effective_recurrent_to_output_scale_a,
+          integer_lstm_param->effective_recurrent_to_output_scale_b,
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
+          integer_lstm_param->effective_cell_to_input_scale_a,
+          integer_lstm_param->effective_cell_to_input_scale_b,
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
+          integer_lstm_param->effective_cell_to_forget_scale_a,
+          integer_lstm_param->effective_cell_to_forget_scale_b,
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
+          integer_lstm_param->effective_cell_to_output_scale_a,
+          integer_lstm_param->effective_cell_to_output_scale_b,
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(projection_weights),
+          integer_lstm_param->effective_proj_scale_a,
+          integer_lstm_param->effective_proj_scale_b,
+          integer_lstm_param->hidden_zp,
+          integer_lstm_param->effective_hidden_scale_a,
+          integer_lstm_param->effective_hidden_scale_b,
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    input_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_input_scale_a,
+          integer_lstm_param->layer_norm_input_scale_b,
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    forget_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_forget_scale_a,
+          integer_lstm_param->layer_norm_forget_scale_b,
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    cell_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_cell_scale_a,
+          integer_lstm_param->layer_norm_cell_scale_b,
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    output_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_output_scale_a,
+          integer_lstm_param->layer_norm_output_scale_b,
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+          integer_lstm_param->quantized_cell_clip,
+          integer_lstm_param->quantized_proj_clip,
+          integer_lstm_param->cell_scale,
+          integer_lstm_param->input_variance_guard,
+          integer_lstm_param->forget_variance_guard,
+          integer_lstm_param->cell_variance_guard,
+          integer_lstm_param->output_variance_guard,
+          integer_lstm_param->input_to_forget_effective_bias,
+          integer_lstm_param->recurrent_to_forget_effective_bias,
+          integer_lstm_param->input_to_cell_effective_bias,
+          integer_lstm_param->recurrent_to_cell_effective_bias,
+          integer_lstm_param->input_to_output_effective_bias,
+          integer_lstm_param->recurrent_to_output_effective_bias,
+          integer_lstm_param->input_to_input_effective_bias,
+          integer_lstm_param->recurrent_to_input_effective_bias,
+          integer_lstm_param->projection_effective_bias, n_batch, n_cell,
+          n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
+          output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
+          output_ptr, scratch0, scratch1, scratch2, scratch3, scratch4,
+          scratch5);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
+                                  time_offset * input_step;
+        int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
+                             time_offset * output_step;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        int8_t* output_state_ptr =
+            tflite::micro::GetTensorData<int8_t>(output_state) +
+            b * output_batch_leading_dim;
+        int16_t* cell_state_ptr =
+            tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
+
+        LstmStepInteger8x8_16(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+            integer_lstm_param->effective_input_to_input_scale_a,
+            integer_lstm_param->effective_input_to_input_scale_b,
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+            integer_lstm_param->effective_input_to_forget_scale_a,
+            integer_lstm_param->effective_input_to_forget_scale_b,
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+            integer_lstm_param->effective_input_to_cell_scale_a,
+            integer_lstm_param->effective_input_to_cell_scale_b,
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+            integer_lstm_param->effective_input_to_output_scale_a,
+            integer_lstm_param->effective_input_to_output_scale_b,
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_input_weights),
+            integer_lstm_param->effective_recurrent_to_input_scale_a,
+            integer_lstm_param->effective_recurrent_to_input_scale_b,
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_forget_weights),
+            integer_lstm_param->effective_recurrent_to_forget_scale_a,
+            integer_lstm_param->effective_recurrent_to_forget_scale_b,
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_cell_weights),
+            integer_lstm_param->effective_recurrent_to_cell_scale_a,
+            integer_lstm_param->effective_recurrent_to_cell_scale_b,
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_output_weights),
+            integer_lstm_param->effective_recurrent_to_output_scale_a,
+            integer_lstm_param->effective_recurrent_to_output_scale_b,
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
+            integer_lstm_param->effective_cell_to_input_scale_a,
+            integer_lstm_param->effective_cell_to_input_scale_b,
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
+            integer_lstm_param->effective_cell_to_forget_scale_a,
+            integer_lstm_param->effective_cell_to_forget_scale_b,
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
+            integer_lstm_param->effective_cell_to_output_scale_a,
+            integer_lstm_param->effective_cell_to_output_scale_b,
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(projection_weights),
+            integer_lstm_param->effective_proj_scale_a,
+            integer_lstm_param->effective_proj_scale_b,
+            integer_lstm_param->hidden_zp,
+            integer_lstm_param->effective_hidden_scale_a,
+            integer_lstm_param->effective_hidden_scale_b,
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      input_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_input_scale_a,
+            integer_lstm_param->layer_norm_input_scale_b,
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      forget_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_forget_scale_a,
+            integer_lstm_param->layer_norm_forget_scale_b,
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      cell_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_cell_scale_a,
+            integer_lstm_param->layer_norm_cell_scale_b,
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      output_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_output_scale_a,
+            integer_lstm_param->layer_norm_output_scale_b,
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+            integer_lstm_param->quantized_cell_clip,
+            integer_lstm_param->quantized_proj_clip,
+            integer_lstm_param->cell_scale,
+            integer_lstm_param->input_variance_guard,
+            integer_lstm_param->forget_variance_guard,
+            integer_lstm_param->cell_variance_guard,
+            integer_lstm_param->output_variance_guard,
+            integer_lstm_param->input_to_forget_effective_bias,
+            integer_lstm_param->recurrent_to_forget_effective_bias,
+            integer_lstm_param->input_to_cell_effective_bias,
+            integer_lstm_param->recurrent_to_cell_effective_bias,
+            integer_lstm_param->input_to_output_effective_bias,
+            integer_lstm_param->recurrent_to_output_effective_bias,
+            integer_lstm_param->input_to_input_effective_bias,
+            integer_lstm_param->recurrent_to_input_effective_bias,
+            integer_lstm_param->projection_effective_bias, /*n_batch=*/1,
+            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
+            cell_state_ptr, output_ptr, scratch0, scratch1, scratch2, scratch3,
+            scratch4, scratch5);
+      }
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalInteger8x8_8Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
+    int32_t input_zp, int32_t output_state_zp, int8_t* scratch0,
+    int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4,
+    int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = input->dims->data[0];
+    n_batch = input->dims->data[1];
+  }
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Get params for time/batch/sequence.
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output_batch_leading_dim;
+
+  for (int t = 0; t < max_time; t++) {
+    const int t_rel = t;
+    int8_t* output_ptr =
+        tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
+    // Input can be int8 asymmetric or int16 symmetric.
+    const int8_t* input_ptr =
+        tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
+    LstmStepInteger8x8_8(
+        input_ptr, input_zp,
+
+        input_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+        integer_lstm_param->effective_input_to_input_scale_a,
+        integer_lstm_param->effective_input_to_input_scale_b,
+
+        input_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+        integer_lstm_param->effective_input_to_forget_scale_a,
+        integer_lstm_param->effective_input_to_forget_scale_b,
+
+        input_to_cell_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+        integer_lstm_param->effective_input_to_cell_scale_a,
+        integer_lstm_param->effective_input_to_cell_scale_b,
+
+        input_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+        integer_lstm_param->effective_input_to_output_scale_a,
+        integer_lstm_param->effective_input_to_output_scale_b,
+
+        recurrent_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
+        integer_lstm_param->effective_recurrent_to_input_scale_a,
+        integer_lstm_param->effective_recurrent_to_input_scale_b,
+
+        recurrent_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
+        integer_lstm_param->effective_recurrent_to_forget_scale_a,
+        integer_lstm_param->effective_recurrent_to_forget_scale_b,
+
+        recurrent_to_cell_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+        integer_lstm_param->effective_recurrent_to_cell_scale_a,
+        integer_lstm_param->effective_recurrent_to_cell_scale_b,
+
+        recurrent_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
+        integer_lstm_param->effective_recurrent_to_output_scale_a,
+        integer_lstm_param->effective_recurrent_to_output_scale_b,
+
+        cell_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+        integer_lstm_param->effective_cell_to_input_scale_a,
+        integer_lstm_param->effective_cell_to_input_scale_b,
+
+        cell_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+        integer_lstm_param->effective_cell_to_forget_scale_a,
+        integer_lstm_param->effective_cell_to_forget_scale_b,
+
+        cell_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+        integer_lstm_param->effective_cell_to_output_scale_a,
+        integer_lstm_param->effective_cell_to_output_scale_b,
+
+        projection_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(projection_weights),
+        integer_lstm_param->effective_proj_scale_a,
+        integer_lstm_param->effective_proj_scale_b,
+
+        input_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  input_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_input_scale_a,
+        integer_lstm_param->layer_norm_input_scale_b,
+
+        forget_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  forget_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_forget_scale_a,
+        integer_lstm_param->layer_norm_forget_scale_b,
+
+        cell_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  cell_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_cell_scale_a,
+        integer_lstm_param->layer_norm_cell_scale_b,
+
+        output_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  output_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_output_scale_a,
+        integer_lstm_param->layer_norm_output_scale_b,
+
+        input_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+        forget_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+        cell_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+        output_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+        projection_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(projection_bias),
+
+        params, integer_lstm_param->intermediate_scale_a,
+        integer_lstm_param->intermediate_scale_b,
+        integer_lstm_param->intermediate_zp,
+        integer_lstm_param->quantized_cell_clip,
+        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
+        n_output, output_batch_leading_dim,
+        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
+        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr, scratch0,
+        scratch1, scratch2, scratch3, scratch4, scratch5, scratch6, scratch7);
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace tflite

+ 250 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h

@@ -0,0 +1,250 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Pamameters for integer LSTM.
+// Consider split this into two Integer Parameters if more fields are added.
+struct IntegerLstmParameter {
+  int32_t effective_input_to_input_scale_a;
+  int32_t effective_input_to_input_scale_b;
+  int32_t effective_recurrent_to_input_scale_a;
+  int32_t effective_recurrent_to_input_scale_b;
+  int32_t effective_cell_to_input_scale_a;
+  int32_t effective_cell_to_input_scale_b;
+  int32_t effective_input_to_forget_scale_a;
+  int32_t effective_input_to_forget_scale_b;
+  int32_t effective_recurrent_to_forget_scale_a;
+  int32_t effective_recurrent_to_forget_scale_b;
+  int32_t effective_cell_to_forget_scale_a;
+  int32_t effective_cell_to_forget_scale_b;
+  int32_t effective_input_to_cell_scale_a;
+  int32_t effective_input_to_cell_scale_b;
+  int32_t effective_recurrent_to_cell_scale_a;
+  int32_t effective_recurrent_to_cell_scale_b;
+  int32_t effective_input_to_output_scale_a;
+  int32_t effective_input_to_output_scale_b;
+  int32_t effective_recurrent_to_output_scale_a;
+  int32_t effective_recurrent_to_output_scale_b;
+  int32_t effective_cell_to_output_scale_a;
+  int32_t effective_cell_to_output_scale_b;
+  int32_t effective_proj_scale_a;
+  int32_t effective_proj_scale_b;
+  int32_t effective_hidden_scale_a;
+  int32_t effective_hidden_scale_b;
+  int32_t layer_norm_input_scale_a;
+  int32_t layer_norm_input_scale_b;
+  int32_t layer_norm_forget_scale_a;
+  int32_t layer_norm_forget_scale_b;
+  int32_t layer_norm_cell_scale_a;
+  int32_t layer_norm_cell_scale_b;
+  int32_t layer_norm_output_scale_a;
+  int32_t layer_norm_output_scale_b;
+  // Quantized clip value for cell and projection. Zero value means no clipping.
+  int16_t quantized_cell_clip;
+  int8_t quantized_proj_clip;
+  int32_t hidden_zp;
+  int32_t cell_scale;
+
+  int32_t input_variance_guard;
+  int32_t forget_variance_guard;
+  int32_t cell_variance_guard;
+  int32_t output_variance_guard;
+
+  // Pre-calculate bias + zero_point * weight.
+  int32_t* input_to_forget_effective_bias;
+  int32_t* recurrent_to_forget_effective_bias;
+  int32_t* input_to_cell_effective_bias;
+  int32_t* recurrent_to_cell_effective_bias;
+  int32_t* input_to_output_effective_bias;
+  int32_t* recurrent_to_output_effective_bias;
+  int32_t* input_to_input_effective_bias;
+  int32_t* recurrent_to_input_effective_bias;
+  int32_t* projection_effective_bias;
+
+  // Scale and zero point for intermediate tensors.
+  // Used only in the 8x8_8 case.
+  int32_t intermediate_scale_a[8];
+  int32_t intermediate_scale_b[8];
+  int32_t intermediate_zp[12];
+};
+
+// Scales for hybrid op with integer inputs and float weights
+struct HybridLstmScales {
+  float input_to_input_weights_scale;
+  float input_to_forget_weights_scale;
+  float input_to_cell_weights_scale;
+  float input_to_output_weights_scale;
+  float aux_input_to_input_weights_scale;
+  float aux_input_to_forget_weights_scale;
+  float aux_input_to_cell_weights_scale;
+  float aux_input_to_output_weights_scale;
+  float recurrent_to_input_weights_scale;
+  float recurrent_to_forget_weights_scale;
+  float recurrent_to_cell_weights_scale;
+  float recurrent_to_output_weights_scale;
+  float cell_to_input_weights_scale;
+  float cell_to_forget_weights_scale;
+  float cell_to_output_weights_scale;
+  float projection_weights_scale;
+};
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums);
+
+TfLiteStatus EvalInteger8x8_16Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major,
+    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
+    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
+
+TfLiteStatus EvalInteger8x8_8Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
+
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików