jomjol 3 years ago
Parent
Commit
17fd0f96df
100 changed files with 5715 additions and 976 deletions
  1. +4 -2      README.md
  2. +3 -1      code/components/esp-nn/CMakeLists.txt
  3. +4 -4      code/components/esp-nn/Kconfig.projbuild
  4. +4 -3      code/components/esp-nn/README.md
  5. +5 -5      code/components/esp-nn/include/esp_nn.h
  6. +1 -0      code/components/esp-nn/include/esp_nn_ansi_c.h
  7. +82 -56    code/components/esp-nn/include/esp_nn_ansi_headers.h
  8. +83 -0     code/components/esp-nn/include/esp_nn_defs.h
  9. +24 -54    code/components/esp-nn/include/esp_nn_esp32s3.h
  10. +9 -10    code/components/esp-nn/include/esp_nn_generic_opt.h
  11. +50 -13   code/components/esp-nn/src/common/common_functions.h
  12. +30 -26   code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c
  13. +106 -79  code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c
  14. +179 -0   code/components/esp-nn/src/convolution/esp_nn_conv_opt.c
  15. +30 -27   code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c
  16. +291 -0   code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c
  17. +97 -37   code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c
  18. +8 -0     code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3
  19. +16 -4    code/components/esp-nn/tests/src/basic_math_test.c
  20. +70 -36   code/components/esp-nn/tests/src/convolution_test.c
  21. BIN       code/components/esp-nn_20220716.zip
  22. +1 -1     code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp
  23. +4 -1     code/components/tflite-lib/CMakeLists.txt
  24. +2 -0     code/components/tflite-lib/tensorflow/lite/builtin_ops.h
  25. +3 -0     code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h
  26. +7 -1     code/components/tflite-lib/tensorflow/lite/c/c_api_types.h
  27. +40 -10   code/components/tflite-lib/tensorflow/lite/c/common.cc
  28. +52 -3    code/components/tflite-lib/tensorflow/lite/c/common.h
  29. +11 -0    code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc
  30. +52 -1    code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h
  31. +3 -3     code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h
  32. +6 -3     code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h
  33. +5 -5     code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h
  34. +1 -1     code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h
  35. +3 -3     code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h
  36. +165 -0   code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc
  37. +104 -0   code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h
  38. +52 -0    code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc
  39. +59 -0    code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h
  40. +1 -1     code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc
  41. +4 -4     code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h
  42. +1 -1     code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc
  43. +4 -4     code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h
  44. +1 -1     code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc
  45. +7 -20    code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc
  46. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc
  47. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc
  48. +2 -16    code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc
  49. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc
  50. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc
  51. +3 -9     code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc
  52. +3 -9     code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc
  53. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc
  54. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc
  55. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc
  56. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc
  57. +12 -48   code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc
  58. +5 -11    code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc
  59. +40 -22   code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc
  60. +10 -0    code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h
  61. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc
  62. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc
  63. +3 -10    code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc
  64. +27 -1    code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h
  65. +9 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc
  66. +3 -2     code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc
  67. +1 -10    code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc
  68. +289 -79  code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc
  69. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc
  70. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc
  71. +46 -21   code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc
  72. +49 -22   code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc
  73. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc
  74. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc
  75. +2 -16    code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc
  76. +208 -0   code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc
  77. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc
  78. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc
  79. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc
  80. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc
  81. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc
  82. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc
  83. +19 -12   code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc
  84. +19 -1    code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h
  85. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc
  86. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc
  87. +2 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc
  88. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc
  89. +3 -2     code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc
  90. +3 -2     code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h
  91. +15 -0    code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc
  92. +22 -3    code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h
  93. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc
  94. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc
  95. +2 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc
  96. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc
  97. +2 -20    code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc
  98. +1 -8     code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc
  99. +2955 -0  code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc
  100. +250 -0  code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h

+ 4 - 2
README.md

@@ -54,8 +54,10 @@ In other cases you can contact the developer via email: <img src="https://raw.gi
 
 ##### Rolling (2022-07-16)
 
-- Updated esp32cam
+- TFMicro/Lite: Update (espressif Version 20220716)
+- Updated esp32cam (v20220716)
 - Integrated new analog classificational CNN (from @haverland)
+- Bugfix: Postprocessing
 
 ##### Rolling (2022-07-01)
 
@@ -79,7 +81,7 @@ Rolling (2022-04-26)
 - Extended MQTT with absolute Change (in addition to rate)
 - Internal optimization, removal of modelfile from `config.ini` (is now read out of the cnn file directly)
 
-- TFMicro/Lite: Update (espressif Verision 20220417)
+- TFMicro/Lite: Update (espressif Version 20220417)
 - ESP-IDF: Update to 4.3.0
 
 Rolling (2022-04-17)

+ 3 - 1
code/components/esp-nn/CMakeLists.txt

@@ -5,7 +5,9 @@ set(c_srcs
     "src/basic_math/esp_nn_add_ansi.c"
     "src/basic_math/esp_nn_mul_ansi.c"
     "src/convolution/esp_nn_conv_ansi.c"
+    "src/convolution/esp_nn_conv_opt.c"
     "src/convolution/esp_nn_depthwise_conv_ansi.c"
+    "src/convolution/esp_nn_depthwise_conv_opt.c"
     "src/fully_connected/esp_nn_fully_connected_ansi.c"
     "src/softmax/esp_nn_softmax_ansi.c"
     "src/softmax/esp_nn_softmax_opt.c"
@@ -23,7 +25,7 @@ if(CONFIG_IDF_TARGET_ESP32S3)
         "src/convolution/esp_nn_conv_esp32s3.c"
         "src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c"
         "src/convolution/esp_nn_conv_s16_mult8_esp32s3.S"
-        "src/convolution/esp_nn_conv_s16_mult8_1x1_esp32s3.S"
+        "src/convolution/esp_nn_conv_s8_mult8_1x1_esp32s3.S"
         "src/convolution/esp_nn_conv_s16_mult4_1x1_esp32s3.S"
         "src/convolution/esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3.S"
         "src/convolution/esp_nn_depthwise_conv_s16_mult1_esp32s3.S"

+ 4 - 4
code/components/esp-nn/Kconfig.projbuild

@@ -6,8 +6,8 @@ choice NN_OPTIMIZATIONS
    help
       Use ANSI-C versions for verification and debug purpose.
       Optimisations are automatically picked up for a chipset.
-      For ESP32-S3, assembly Optimisations are selected.
-      For ESP32, just the ANSI C versions are selected for now.
+      For ESP32-S3, assembly optimisations are selected.
+      For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.
 
 config NN_ANSI_C
    bool "ANSI C"
@@ -17,8 +17,8 @@ config NN_OPTIMIZED
    bool "Optimized versions"
    help
       Optimisations are automatically picked up for a chipset.
-      For ESP32-S3, assembly Optimisations are selected.
-      For ESP32, just the ANSI C versions are selected for now.
+      For ESP32-S3, assembly optimisations are selected.
+      For other platforms(viz., ESP32, ESP32-C3), generic optimisations are used.
 endchoice
 
 config NN_OPTIMIZATIONS

+ 4 - 3
code/components/esp-nn/README.md

@@ -7,7 +7,8 @@ The library contains optimised NN (Neural Network) functions for various Espress
 
 * Supported ESP chipsets include:
    * ESP32-S3 (Assembly versions optimised to benefit from vector instructions of ESP32-S3)
-   * ESP32 (ANSI C versions)
+   * ESP32 (Generic optimisations)
+   * ESP32-C3 (Generic optimisations)
 
 ## Performance
 
@@ -39,8 +40,8 @@ The library contains optimised NN (Neural Network) functions for various Espress
      * Optimized versions
      * ANSI C
 
-  * Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for ESP32,  ANSI-C versions are selected by default.
-  * For debugging purposes, you may want to select `ANSI C`
+  * Default selection is for `Optimized versions`. For ESP32-S3, assembly versions are automatically selected, whereas for other chipsets (viz., ESP32, ESP32-C3), generic optimisations are selected.
+  * For debugging purposes, you may want to select `ANSI C` reference versions.
 
 
 ## Contributing

+ 5 - 5
code/components/esp-nn/include/esp_nn.h

@@ -15,6 +15,7 @@
 #pragma once
 
 #if defined(CONFIG_NN_OPTIMIZED)
+// select apt optimisations
 #ifdef CONFIG_IDF_TARGET_ESP32S3
 #define ARCH_ESP32_S3 1
 #endif
@@ -31,12 +32,11 @@ extern "C" {
 #include "esp_nn_ansi_headers.h"
 
 #if defined(CONFIG_NN_OPTIMIZED)
-#ifdef ARCH_ESP32_S3
+#if defined(ARCH_ESP32_S3)
 #include "esp_nn_esp32s3.h"
-#endif
-#ifdef ARCH_ESP32
-#include "esp_nn_esp32.h"
-#endif
+#else // for other platforms use generic optimisations
+#include "esp_nn_generic_opt.h"
+#endif // #if defined(ARCH_ESP32_S3)
 #else
 #include "esp_nn_ansi_c.h"
 #endif

+ 1 - 0
code/components/esp-nn/include/esp_nn_ansi_c.h

@@ -19,6 +19,7 @@
 
 #pragma once
 
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi

+ 82 - 56
code/components/esp-nn/include/esp_nn_ansi_headers.h

@@ -18,8 +18,7 @@
  * @file        Header definitions to include for esp_nn reference functions
  */
 
-#include <stdint.h>
-
+#include "esp_nn_defs.h"
 /************************** Basic math functions ****************************/
 
 /**
@@ -81,28 +80,15 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
  *              optimization notes: Though input_offset is int32 type,
  *              offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max);
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data);
 
 /**
  * @brief       2d-convolution channelwise
@@ -112,43 +98,26 @@ void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
  *              inputs type: int8_t, output: int8_t
  *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max);
-
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht);
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_ansi(const void *buf);
 
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf);
 
 /************************** Activation functions *****************************/
@@ -252,9 +221,6 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h
  */
 void esp_nn_set_softmax_scratch_buf_ansi(void *buffer);
 
-/* ANSI C function to be hooked up when optimised version needed */
-void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
-
 /**
  * @brief       reference softmax function
  *
@@ -268,6 +234,66 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
                             const int32_t diff_min,
                             int8_t *output_data);
 
+
+//////////////////////////// Generic optimisations /////////////////////////////
+
+/************************** Convolution functions *****************************/
+
+/**
+ * @brief       2d-convolution channelwise optimized version
+ *
+ * @note        operation: result += (input + offset) * filter
+ *
+ *              inputs type: int8_t, output: int8_t
+ *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data);
+
+/**
+ * @brief       depthwise convolution per channel optimized version
+ *
+ * @note        inputs type: int8_t, output: int8_t
+ *              Version used in tflite is per channel.
+ *              This version follows the same footsprints.
+ *              Meaning, it has per out_channel shift and multiplier for
+ *              requantization
+ *
+ *              optimization notes: Though input_offset is int32 type,
+ *              offset values are contained in 8 bits [-128, 127]
+ */
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params);
+void esp_nn_set_conv_scratch_buf_opt(const void *buf);
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params);
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf);
+
+/* ANSI C function to be hooked up when optimised version needed */
+void esp_nn_set_softmax_scratch_buf_opt(void *buffer);
+
 /**
  * @brief       optimised version of softmax function
  *

+ 83 - 0
code/components/esp-nn/include/esp_nn_defs.h

@@ -0,0 +1,83 @@
+// Copyright 2022 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+/**
+ * @brief structure to club data dims
+ * this structure can be used for input, output and filter
+ */
+typedef struct data_dims {
+    int32_t width;
+    int32_t height;
+    int32_t channels;
+
+    int32_t extra; // can be used as batch or any other param
+} data_dims_t;
+
+/**
+ * @brief 2d data structure (width, height)
+ *
+ */
+typedef struct data_2d {
+    int32_t width;
+    int32_t height;
+} data_2d_t;
+
+/**
+ * @brief min/max activation
+ */
+typedef struct act_params {
+    int32_t min;
+    int32_t max;
+} act_params_t;
+
+/**
+ * @brief per channel quant data
+ *
+ * @note number of shift and mult elements are equal to output channels
+ */
+typedef struct quant_data {
+    int32_t *shift;
+    int32_t *mult;
+} quant_data_t;
+
+/**
+ * @brief params specific to convolution 2d
+ *
+ */
+typedef struct conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} conv_params_t;
+
+/**
+ * @brief params specific to depthwise convolution 2d
+ *
+ */
+typedef struct dw_conv_params {
+    int32_t in_offset;
+    int32_t out_offset;
+    int32_t ch_mult; // channel multiplier. (in_ch * ch_mult = out_ch)
+    data_2d_t stride;
+    data_2d_t padding;
+    data_2d_t dilation;
+    act_params_t activation;
+} dw_conv_params_t;

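The structs above replace the long flat argument lists used before this commit. A rough usage sketch in C follows (illustrative only: the function name example_conv_s8 and the dimensions, offsets and quantisation buffers are made-up values, not taken from this commit; a scratch buffer is assumed to have been set beforehand via esp_nn_get_conv_scratch_size / esp_nn_set_conv_scratch_buf):

#include "esp_nn.h"

void example_conv_s8(const int8_t *input, const int8_t *filter, const int32_t *bias,
                     int8_t *output, int32_t *per_ch_shift, int32_t *per_ch_mult)
{
    /* hypothetical 16x16x8 input convolved with a 3x3 kernel into a 14x14x16 output */
    data_dims_t input_dims  = { .width = 16, .height = 16, .channels = 8,  .extra = 1 };
    data_dims_t filter_dims = { .width = 3,  .height = 3 };
    data_dims_t output_dims = { .width = 14, .height = 14, .channels = 16 };

    conv_params_t conv_params = {
        .in_offset  = 128,
        .out_offset = -128,
        .stride     = { .width = 1, .height = 1 },
        .padding    = { .width = 0, .height = 0 },
        .dilation   = { .width = 1, .height = 1 },
        .activation = { .min = -128, .max = 127 },
    };
    /* per-channel requantisation: one shift and one multiplier per output channel */
    quant_data_t quant_data = { .shift = per_ch_shift, .mult = per_ch_mult };

    /* esp_nn.h maps esp_nn_conv_s8 to the ANSI, generic-opt or ESP32-S3 variant */
    esp_nn_conv_s8(&input_dims, input, &filter_dims, filter, bias,
                   &output_dims, output, &conv_params, &quant_data);
}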
+ 24 - 54
code/components/esp-nn/include/esp_nn_esp32s3.h

@@ -19,7 +19,7 @@
 
 #pragma once
 
-#include <stdint.h>
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 /************************** Basic math functions *****************************/
@@ -85,28 +85,15 @@ void esp_nn_mul_elementwise_s8_esp32s3(const int8_t *input1_data,
  *              optimization notes: Though input_offset is int32 type,
  *              offset values are contained in 8 bits [-128, 127]
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
-                                      int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max);
+                                      const data_dims_t *output_dims,
+                                      int8_t *output_data,
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data);
 
 /**
  * @brief       2d - convolution channelwise
@@ -116,43 +103,26 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
  *              inputs type: int8_t, output: int8_t
  *              input offsets: although int32_t, they are contained in 8 bits [-128, 127]
  */
-void esp_nn_conv_s8_esp32s3(const int8_t *input_data,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t in_channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input_data,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
-                            int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max);
-
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht);
+                            const data_dims_t *output_dims,
+                            int8_t *output_data,
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data);
+
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params);
 void esp_nn_set_conv_scratch_buf_esp32s3(const void *buf);
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht);
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params);
 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(const void *buf);
 
 /************************** Pooling functions *****************************/

+ 9 - 10
code/components/esp-nn/include/esp_nn_esp32.h → code/components/esp-nn/include/esp_nn_generic_opt.h

@@ -13,28 +13,27 @@
 // limitations under the License.
 
 /**
- * @file        Header definitions to include for esp_nn optimized functions for
- *              the ESP32 platform.
- *              We are hooking up just the C versions for now.
- *              The file hence is exactly same as `esp_nn_ansi_c.h`
+ * @file        Header definitions to include for esp_nn generic optimisations
+ *              For functions which not having optimisations, _ansi versions are picked.
  */
 
 #pragma once
 
+#include "esp_nn_defs.h"
 #include "esp_nn_ansi_headers.h"
 
 #define esp_nn_add_elementwise_s8 esp_nn_add_elementwise_s8_ansi
 #define esp_nn_mul_elementwise_s8 esp_nn_mul_elementwise_s8_ansi
 
-#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_ansi
+#define esp_nn_depthwise_conv_s8 esp_nn_depthwise_conv_s8_opt
 
-#define esp_nn_conv_s8 esp_nn_conv_s8_ansi
+#define esp_nn_conv_s8 esp_nn_conv_s8_opt
 
-#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_ansi
-#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_ansi
+#define esp_nn_get_conv_scratch_size esp_nn_get_conv_scratch_size_opt
+#define esp_nn_set_conv_scratch_buf esp_nn_set_conv_scratch_buf_opt
 
-#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_ansi
-#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_ansi
+#define esp_nn_get_depthwise_conv_scratch_size esp_nn_get_depthwise_conv_scratch_size_opt
+#define esp_nn_set_depthwise_conv_scratch_buf esp_nn_set_depthwise_conv_scratch_buf_opt
 
 #define esp_nn_relu6_s8 esp_nn_relu6_s8_ansi
 

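With the remapping above, code that calls the public esp-nn names does not change; only the symbols they resolve to do. A minimal sketch (the helper name query_conv_scratch_size is hypothetical; it assumes a non-ESP32-S3 target built with CONFIG_NN_OPTIMIZED, so that esp_nn.h pulls in esp_nn_generic_opt.h):

#include "esp_nn.h"

/* This call now preprocesses to esp_nn_get_conv_scratch_size_opt(...); before this
 * commit the same public name resolved to the _ansi reference implementation. */
static int query_conv_scratch_size(const data_dims_t *input_dims,
                                   const data_dims_t *filter_dims,
                                   const data_dims_t *output_dims,
                                   const conv_params_t *conv_params)
{
    return esp_nn_get_conv_scratch_size(input_dims, filter_dims, output_dims, conv_params);
}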
+ 50 - 13
code/components/esp-nn/src/common/common_functions.h

@@ -41,15 +41,39 @@
 
 __NN_FORCE_INLINE__ int32_t esp_nn_clz32(uint32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("nsau %0, %0" : "+r" (in));
     return in;
-}
-
-__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
-{
-    int32_t sign = (int32_t) (val64 >> 63);
-    int32_t to_add = sign & ((1ul << 31) - 1);
-    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
+#elif defined(__GNUC__)
+    return __builtin_clz(in);
+#else
+    int32_t count = 32;
+    uint32_t x = in, y = in >> 16;
+    if (y != 0) {
+        count -= 16;
+        x = y;
+    }
+    y = x >> 8;
+    if (y != 0) {
+        count -= 8;
+        x = y;
+    }
+    y = x >> 4;
+    if (y != 0) {
+        count -= 4;
+        x = y;
+    }
+    y = x >> 2;
+    if (y != 0) {
+        count -= 2;
+        x = y;
+    }
+    y = x >> 1;
+    if (y != 0) {
+        return count - 2;
+    }
+    return count - x;
+#endif
 }
 
 /**
@@ -57,8 +81,19 @@ __NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
  */
 __NN_FORCE_INLINE__ int32_t esp_nn_saturate8(int32_t in)
 {
+#if CONFIG_IDF_TARGET_ARCH_XTENSA
     __asm__ volatile("clamps %0, %0, 7" : "+a"(in));
     return in;
+#else
+    return max(INT8_MIN, min(in, INT8_MAX));
+#endif
+}
+
+__NN_FORCE_INLINE__ int32_t esp_nn_pick_sat_high32_of64(int64_t val64)
+{
+    int32_t sign = (int32_t) (val64 >> 63);
+    int32_t to_add = sign & ((1ul << 31) - 1);
+    return (int32_t) ((int64_t) (val64 + to_add) >> 31);
 }
 
 __NN_FORCE_INLINE__ int32_t esp_nn_sat_round_doubling_high_mul(int32_t in0, int32_t in1)
@@ -144,7 +179,7 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
                                              const uint16_t pad_ht)
 {
     /* memset with pad_val */
-    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels * 2);
+    memset(dst, pad_val, ((input_wd + 2 * pad_wd) * (input_ht + 2 * pad_ht)) * channels);
     dst += (pad_wd + input_wd + pad_wd) * channels;
 
     for (int i = 0; i < input_ht; i++) {
@@ -156,7 +191,6 @@ static void esp_nn_aligned_s8_pad_with_value(const int8_t *src, int8_t *dst,
     }
 }
 
-#if 0
 static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
                                                  const uint16_t input_wd,
                                                  const uint16_t input_ht,
@@ -169,13 +203,16 @@ static void esp_nn_aligned_s8_pad_end_with_value(const int8_t *src, int8_t *dst,
         for (int j = 0; j < input_wd * channels; j++) {
             *dst++ = *src++;
         }
-        memset(dst, pad_val, pad_wd * channels);
-        dst += pad_wd * channels;
+        if (pad_wd) {
+            memset(dst, pad_val, pad_wd * channels);
+            dst += pad_wd * channels;
+        }
     }
     /* pad end `pad_ht` lines at end */
-    memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    if (pad_ht) {
+        memset(dst, pad_val, (input_wd + pad_wd) * pad_ht * channels);
+    }
 }
-#endif
 
 /**
  * @brief       convert 8 bit input data to 16 bit

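A small sanity check for the portable esp_nn_clz32() fallback added above (an illustrative test, not part of the commit; the helper name check_clz32 is made up and the esp-nn common header is assumed to be on the include path). Whichever branch is compiled in — Xtensa nsau, __builtin_clz, or the plain-C bit scan — the result should be the leading-zero count of the 32-bit argument:

#include <assert.h>
#include <stdint.h>
#include <common_functions.h>

static void check_clz32(void)
{
    assert(esp_nn_clz32(1u) == 31);          /* only bit 0 set */
    assert(esp_nn_clz32(0xFFu) == 24);       /* highest set bit is bit 7 */
    assert(esp_nn_clz32(0x80000000u) == 0);  /* top bit already set */
}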
+ 30 - 26
code/components/esp-nn/src/convolution/esp_nn_conv_ansi.c

@@ -12,16 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
+#include <esp_nn_defs.h>
 
 #include <common_functions.h>
 
-int esp_nn_get_conv_scratch_size_ansi(const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t in_ch,
-                                      const uint16_t out_ch,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                      const data_dims_t *filter_dims,
+                                      const data_dims_t *output_dims,
+                                      const conv_params_t *conv_params)
 {
     return 0;
 }
@@ -108,29 +106,35 @@ void esp_nn_conv_u8_ansi(const uint8_t *input_data,
  * Assumption 2: Pointers are valid
  * Assumption 3: dialation width = 1
  */
-void esp_nn_conv_s8_ansi(const int8_t *input_data,
-                         const uint16_t input_wd,
-                         const uint16_t input_ht,
-                         const uint16_t in_channels,
-                         const int32_t input_offset,
-                         const uint16_t pad_wd,
-                         const uint16_t pad_ht,
-                         const uint16_t stride_wd,
-                         const uint16_t stride_ht,
+void esp_nn_conv_s8_ansi(const data_dims_t *input_dims,
+                         const int8_t *input_data,
+                         const data_dims_t *filter_dims,
                          const int8_t *filter_data,
-                         const uint16_t filter_wd,
-                         const uint16_t filter_ht,
                          const int32_t *bias,
+                         const data_dims_t *output_dims,
                          int8_t *out_data,
-                         const uint16_t out_wd,
-                         const uint16_t out_ht,
-                         const uint16_t out_channels,
-                         const int32_t out_offset,
-                         const int32_t *out_shift,
-                         const int32_t *out_mult,
-                         const int32_t activation_min,
-                         const int32_t activation_max)
+                         const conv_params_t *conv_params,
+                         const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
 
     for (out_y = 0; out_y < out_ht; out_y++) {

+ 106 - 79
code/components/esp-nn/src/convolution/esp_nn_conv_esp32s3.c

@@ -12,30 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>
 
 #include <common_functions.h>
 
 static int16_t *scratch_buffer = NULL;
 
-extern void esp_nn_conv_s16_mult8_1x1_esp32s3(const int8_t *input_data,
-                                              const uint16_t input_wd,
-                                              const uint16_t input_ht,
-                                              const uint16_t in_channels,
-                                              const int32_t input_offset,
-                                              const int16_t *filter_data,
-                                              const int32_t *bias,
-                                              int8_t *out_data,
-                                              const uint16_t out_wd,
-                                              const uint16_t out_ht,
-                                              const uint16_t out_channels,
-                                              const int32_t out_offset,
-                                              const int32_t *out_shift,
-                                              const int32_t *out_mult,
-                                              const int32_t activation_min,
-                                              const int32_t activation_max,
-                                              void *buffer /* scratch buffer */);
+extern void esp_nn_conv_s8_mult8_1x1_esp32s3(const int8_t *input_data,
+                                             const uint16_t input_wd,
+                                             const uint16_t input_ht,
+                                             const uint16_t in_channels,
+                                             const int32_t input_offset,
+                                             const int8_t *filter_aligned,
+                                             const int32_t *bias,
+                                             int8_t *out_data,
+                                             const uint16_t out_wd,
+                                             const uint16_t out_ht,
+                                             const uint16_t out_channels,
+                                             const int32_t out_offset,
+                                             const int32_t *out_shift,
+                                             const int32_t *out_mult,
+                                             const int32_t activation_min,
+                                             const int32_t activation_max,
+                                             void *buffer /* scratch buffer */);
 
 extern void esp_nn_conv_s16_mult4_1x1_esp32s3(const int16_t *input_data,
                                               const uint16_t input_wd,
@@ -81,34 +81,40 @@ extern void esp_nn_aligned_s8_to_s16_with_offset_esp32s3(const int8_t *src, int1
 
 extern void esp_nn_s8_to_s16_esp32s3(const int8_t *src, int16_t *dst, const int size);
 
-static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
-                                    const uint16_t input_wd,
-                                    const uint16_t input_ht,
-                                    const uint16_t in_channels,
-                                    const int32_t input_offset,
-                                    const uint16_t pad_wd,
-                                    const uint16_t pad_ht,
-                                    const uint16_t stride_wd,
-                                    const uint16_t stride_ht,
+static void esp_nn_conv_s8_unrolled(const data_dims_t *input_dims,
+                                    const int8_t *input_data,
+                                    const data_dims_t *filter_dims,
                                     const int8_t *filter_data,
-                                    const uint16_t filter_wd,
-                                    const uint16_t filter_ht,
                                     const int32_t *bias,
+                                    const data_dims_t *output_dims,
                                     int8_t *out_data,
-                                    const uint16_t out_wd,
-                                    const uint16_t out_ht,
-                                    const uint16_t out_channels,
-                                    const int32_t out_offset,
-                                    const int32_t *out_shift,
-                                    const int32_t *out_mult,
-                                    const int32_t activation_min,
-                                    const int32_t activation_max)
+                                    const conv_params_t *conv_params,
+                                    const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int32_t out_ch_idx, out_y, out_x, in_ch_idx, filter_y_idx, filter_x_idx;
 
     for (out_y = 0; out_y < out_ht; out_y++) {
         for (out_x = 0; out_x < out_wd; out_x++) {
-            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+            for (out_ch_idx = 0; out_ch_idx < out_ch; out_ch_idx++) {
                 int32_t conv_out = 0;
 
                 const int32_t base_y = stride_ht * out_y - pad_ht;
@@ -124,10 +130,10 @@ static void esp_nn_conv_s8_unrolled(const int8_t *input_data,
                     for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
                         const int32_t in_row = base_y + filter_y_idx;
                         const int32_t in_col = base_x + filter_x_idx;
-                        int32_t input_base_offset = (in_row * input_wd + in_col) * in_channels;
-                        int32_t filter_base_offset = out_ch_idx * in_channels * filter_ht * filter_wd +
-                                                       (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
-                        for (in_ch_idx = 0; in_ch_idx < in_channels; in_ch_idx++) {
+                        int32_t input_base_offset = (in_row * input_wd + in_col) * in_ch;
+                        int32_t filter_base_offset = out_ch_idx * in_ch * filter_ht * filter_wd +
+                                                       (filter_y_idx * filter_wd + filter_x_idx) * in_ch;
+                        for (in_ch_idx = 0; in_ch_idx < in_ch; in_ch_idx++) {
                             conv_out +=
                                 (input_data[input_base_offset + in_ch_idx] + input_offset) *
                                 filter_data[filter_base_offset + in_ch_idx];
@@ -332,18 +338,35 @@ static void esp_nn_conv_s8_pad_valid_ch3_3x3(const int8_t *input_data,
     }
 }
 
-int esp_nn_get_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                         const uint16_t input_ht,
-                                         const uint16_t in_ch,
-                                         const uint16_t out_ch,
-                                         const uint16_t filter_wd,
-                                         const uint16_t filter_ht)
+int esp_nn_get_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                         const data_dims_t *filter_dims,
+                                         const data_dims_t *output_dims,
+                                         const conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_ch = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_ch = output_dims->channels;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * in_ch * out_ch;
     int input_size = input_wd * input_ht * in_ch;
-    int transpose_buf_size = 8 * in_ch; /* to store intermediate data */
+
+    int transpose_buf_size = 2 * (8 * in_ch); /* to store intermediate data */
+    if (input_wd * input_ht < 8) {
+        transpose_buf_size = 0; // not using this for leftover
+    }
     int align_buf_size = 32; /* extra buffer for alignment */
-    return 2 * (filter_size + input_size +  transpose_buf_size) + align_buf_size;
+    if (in_ch % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
+            pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
+        return filter_size + transpose_buf_size + align_buf_size;
+    }
+    return 2 * (filter_size + input_size) +  transpose_buf_size + align_buf_size;
 }
 
 void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
@@ -351,29 +374,35 @@ void esp_nn_set_conv_scratch_buf_esp32s3(void *buf)
     scratch_buffer = (int16_t *) buf;
 }
 
-void esp_nn_conv_s8_esp32s3(const int8_t *input,
-                            const uint16_t input_wd,
-                            const uint16_t input_ht,
-                            const uint16_t channels,
-                            const int32_t input_offset,
-                            const uint16_t pad_wd,
-                            const uint16_t pad_ht,
-                            const uint16_t stride_wd,
-                            const uint16_t stride_ht,
+void esp_nn_conv_s8_esp32s3(const data_dims_t *input_dims,
+                            const int8_t *input,
+                            const data_dims_t *filter_dims,
                             const int8_t *filter_data,
-                            const uint16_t filter_wd,
-                            const uint16_t filter_ht,
                             const int32_t *bias,
+                            const data_dims_t *output_dims,
                             int8_t *out_data,
-                            const uint16_t out_wd,
-                            const uint16_t out_ht,
-                            const uint16_t out_channels,
-                            const int32_t out_offset,
-                            const int32_t *out_shift,
-                            const int32_t *out_mult,
-                            const int32_t activation_min,
-                            const int32_t activation_max)
+                            const conv_params_t *conv_params,
+                            const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
     int filter_size = filter_wd * filter_ht * channels * out_channels;
     int input_size = input_wd * input_ht * channels;
     int align_len = 16 - (filter_size & 15);
@@ -387,15 +416,16 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
 
 
     if (channels % 8 == 0 && filter_wd == 1 && filter_ht == 1 &&
             pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
-        int scratch_offset = (int) (filter_data16 + filter_size);
+        int8_t *filter_aligned = (int8_t *) scratch_buffer;
+        int scratch_offset = (int) (filter_aligned + filter_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
-        esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
-        esp_nn_conv_s16_mult8_1x1_esp32s3(
-            input, input_wd, input_ht, channels, input_offset, filter_data16,
+        memcpy(filter_aligned, filter_data, filter_size); // copy to aligned address
+        esp_nn_conv_s8_mult8_1x1_esp32s3(
+            input, input_wd, input_ht, channels, input_offset, filter_aligned,
             bias, out_data, out_wd, out_ht, out_channels, out_offset,
             out_shift, out_mult, activation_min, activation_max, scratch_buf);
     } else if (channels % 4 == 0 && filter_wd == 1 && filter_ht == 1 &&
-            (input_wd * input_ht) % 16 == 0 && /* TODO: remove this check */
+            (input_wd * input_ht) % 4 == 0 && /* TODO: remove this check */
             pad_wd == 0 && pad_ht == 0 && stride_wd == 1 && stride_ht == 1) {
         int scratch_offset = (int) (input_data16 + input_size);
         void *scratch_buf = (void *) (scratch_offset + 16 - (scratch_offset & 15));
@@ -427,10 +457,7 @@ void esp_nn_conv_s8_esp32s3(const int8_t *input,
         }
     } else {
         /* Basic unrolled version */
-        esp_nn_conv_s8_unrolled(input, input_wd, input_ht, channels, input_offset,
-                                pad_wd, pad_ht, stride_wd, stride_ht,
-                                filter_data, filter_wd, filter_ht, bias,
-                                out_data, out_wd, out_ht, out_channels, out_offset, out_shift,
-                                out_mult, activation_min, activation_max);
+        esp_nn_conv_s8_unrolled(input_dims, input, filter_dims, filter_data,
+                                bias, output_dims, out_data, conv_params, quant_data);
     }
 }

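Note on the reworked scratch-size estimate above (a rough worked example; the numbers are illustrative and not taken from this commit). For a 1x1 filter over an 8x8x16 input producing 32 output channels, with pad 0 and stride 1, the new fast path applies:

    filter_size        = 1 * 1 * 16 * 32              = 512
    transpose_buf_size = 2 * (8 * 16)                 = 256
    fast path          = 512 + 256 + 32 (align)       = 800 bytes
    general path       = 2 * (512 + 1024) + 256 + 32  = 3360 bytes

where input_size = 8 * 8 * 16 = 1024, so the dedicated 1x1 path avoids keeping a 16-bit copy of the whole input in scratch memory.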
+ 179 - 0
code/components/esp-nn/src/convolution/esp_nn_conv_opt.c

@@ -0,0 +1,179 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+
+#include <common_functions.h>
+
+int esp_nn_get_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                     const data_dims_t *filter_dims,
+                                     const data_dims_t *output_dims,
+                                     const conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+__attribute__ ((noinline))
+static void esp_nn_conv_s8_1x1(const data_dims_t *input_dims,
+                               const int8_t *input_data,
+                               const int8_t *filter_data,
+                               const int32_t *bias,
+                               const data_dims_t *output_dims,
+                               int8_t *out_data,
+                               const conv_params_t *conv_params,
+                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    for (int32_t in_row = 0; in_row < out_ht * stride_ht; in_row += stride_ht) {
+        for (int32_t in_col = 0; in_col < out_wd * stride_wd; in_col += stride_wd) {
+            const int32_t *out_mult = quant_data->mult;
+            const int32_t *out_shift = quant_data->shift;
+            const int8_t *filter_ptr = filter_data;
+            const int8_t *input_base_ptr = input_data + (in_row * input_wd + in_col) * in_channels;
+            int32_t out_ch_idx = 0;
+            for (; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int8_t *input_ptr = input_base_ptr;
+
+                int32_t in_ch_idx = 0;
+                for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                    conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}
+
+/**
+ * Assumption 1: i/p channels == o/p channels
+ * Assumption 2: Pointers are valid
+ * Assumption 3: dilation width = 1
+ */
+void esp_nn_conv_s8_opt(const data_dims_t *input_dims,
+                        const int8_t *input_data,
+                        const data_dims_t *filter_dims,
+                        const int8_t *filter_data,
+                        const int32_t *bias,
+                        const data_dims_t *output_dims,
+                        int8_t *out_data,
+                        const conv_params_t *conv_params,
+                        const quant_data_t *quant_data)
+{
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+
+    if (filter_wd == 1 && filter_ht == 1) {
+        esp_nn_conv_s8_1x1(input_dims, input_data, filter_data, bias,
+                           output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t in_channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t out_channels = output_dims->channels;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int32_t out_ch_idx, out_y, out_x, filter_y_idx, filter_x_idx;
+
+    for (out_y = 0; out_y < out_ht; out_y++) {
+        for (out_x = 0; out_x < out_wd; out_x++) {
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+            for (out_ch_idx = 0; out_ch_idx < out_channels; out_ch_idx++) {
+                int32_t conv_out = 0;
+
+                const int32_t base_y = stride_ht * out_y - pad_ht;
+                const int32_t base_x = stride_wd * out_x - pad_wd;
+
+                const int32_t filter_y_start = max(0, -base_y);
+                const int32_t filter_x_start = max(0, -base_x);
+
+                const int32_t filter_y_end = min(filter_ht, input_ht - base_y);
+                const int32_t filter_x_end = min(filter_wd, input_wd - base_x);
+
+                for (filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    for (filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t in_row = base_y + filter_y_idx;
+                        const int32_t in_col = base_x + filter_x_idx;
+
+                        const int8_t *input_ptr = input_data +
+                                        (in_row * input_wd + in_col) * in_channels;
+                        const int8_t *filter_ptr = filter_data +
+                                        out_ch_idx * in_channels * filter_ht * filter_wd +
+                                        (filter_y_idx * filter_wd + filter_x_idx) * in_channels;
+                        int32_t in_ch_idx = 0;
+                        for (; in_ch_idx < in_channels - 3; in_ch_idx += 4) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                        for (; in_ch_idx < in_channels; in_ch_idx ++) {
+                            conv_out += (*input_ptr++ + input_offset) * *filter_ptr++;
+                        }
+                    }
+                }
+                if (bias) {
+                    conv_out += bias[out_ch_idx];
+                }
+                conv_out = esp_nn_multiply_by_quantized_mult_fast(conv_out, *out_mult++, *out_shift++);
+                conv_out += out_offset;
+                conv_out = max(conv_out, activation_min);
+                conv_out = min(conv_out, activation_max);
+                *out_data++ = (int8_t) conv_out;
+            }
+        }
+    }
+}

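Note: with this commit the convolution entry points take dimension/parameter structs instead of long scalar argument lists. Below is a minimal, hedged usage sketch in the style of the updated tests; the buffer sizes and quantisation values are placeholders, and the generic wrapper names (esp_nn_get_conv_scratch_size, esp_nn_set_conv_scratch_buf, esp_nn_conv_s8) are assumed to dispatch to the per-target variants shown in this commit.

    #include <esp_nn.h>

    /* Hypothetical buffers sized for the dims used below: 8x8x16 input,
     * 1x1 filter with 32 output channels, 8x8x32 output. */
    static int8_t  input[8 * 8 * 16];
    static int8_t  filter_data[1 * 1 * 16 * 32];
    static int32_t bias[32];
    static int8_t  out_data[8 * 8 * 32];
    static int32_t out_mult[32];    /* per-channel quantised multipliers */
    static int32_t out_shift[32];   /* per-channel shifts */
    static int8_t  scratch[4096];   /* illustrative scratch area */

    void conv_s8_example(void)
    {
        data_dims_t input_dims  = {.width = 8, .height = 8, .channels = 16, 1};
        data_dims_t filter_dims = {.width = 1, .height = 1, 0, 0};
        data_dims_t output_dims = {.width = 8, .height = 8, .channels = 32, 1};
        conv_params_t conv_params = {.in_offset = 0, .out_offset = 0,
                                     .stride = {1, 1}, .padding = {0, 0},
                                     .dilation = {0, 0}, .activation = {-128, 127}};
        quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};

        /* query and install the scratch buffer before running the kernel */
        if (esp_nn_get_conv_scratch_size(&input_dims, &filter_dims,
                                         &output_dims, &conv_params) > 0) {
            esp_nn_set_conv_scratch_buf(scratch);
        }
        esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data,
                       bias, &output_dims, out_data, &conv_params, &quant_data);
    }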
+ 30 - 27
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_ansi.c

@@ -12,16 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 
-#include <stdint.h>
-
+#include <esp_nn_defs.h>
 #include <common_functions.h>
 
 
-int esp_nn_get_depthwise_conv_scratch_size_ansi(const uint16_t input_wd,
-                                                const uint16_t input_ht,
-                                                const uint16_t channels,
-                                                const uint16_t ch_mult,
-                                                const uint16_t filter_wd,
-                                                const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_ansi(const data_dims_t *input_dims,
+                                                const data_dims_t *filter_dims,
+                                                const data_dims_t *output_dims,
+                                                const dw_conv_params_t *conv_params)
 {
     return 0;
 }
@@ -31,29 +28,35 @@ void esp_nn_set_depthwise_conv_scratch_buf_ansi(const void *buf)
 
 
 }
 
 
-void esp_nn_depthwise_conv_s8_ansi(const int8_t *input_data,
-                                   const uint16_t input_wd,
-                                   const uint16_t input_ht,
-                                   const uint16_t channels,
-                                   const int32_t input_offset,
-                                   const uint16_t pad_wd,
-                                   const uint16_t pad_ht,
-                                   const uint16_t stride_wd,
-                                   const uint16_t stride_ht,
-                                   const uint16_t ch_mult,
+void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
+                                   const int8_t *input_data,
+                                   const data_dims_t *filter_dims,
                                    const int8_t *filter_data,
-                                   const uint16_t filter_wd,
-                                   const uint16_t filter_ht,
                                    const int32_t *bias,
+                                   const data_dims_t *output_dims,
                                    int8_t *out_data,
-                                   const uint16_t out_wd,
-                                   const uint16_t out_ht,
-                                   const int32_t out_offset,
-                                   const int32_t *out_shift,
-                                   const int32_t *out_mult,
-                                   const int32_t activation_min,
-                                   const int32_t activation_max)
+                                   const dw_conv_params_t *conv_params,
+                                   const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int out_idx = 0;
     int out_idx = 0;
     for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
         const int16_t base_y = (out_y * stride_ht) - pad_ht;

+ 291 - 0
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_opt.c

@@ -0,0 +1,291 @@
+// Copyright 2020-2021 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <esp_nn_defs.h>
+#include <common_functions.h>
+
+int esp_nn_get_depthwise_conv_scratch_size_opt(const data_dims_t *input_dims,
+                                               const data_dims_t *filter_dims,
+                                               const data_dims_t *output_dims,
+                                               const dw_conv_params_t *conv_params)
+{
+    return 0;
+}
+
+void esp_nn_set_depthwise_conv_scratch_buf_opt(const void *buf)
+{
+
+}
+
+/* common channel multiplier == 1 case */
+__attribute__ ((noinline))
+static void esp_nn_depthwise_conv_s8_ch_mult_1(const data_dims_t *input_dims,
+                                               const int8_t *input_data,
+                                               const data_dims_t *filter_dims,
+                                               const int8_t *filter_data,
+                                               const int32_t *bias,
+                                               const data_dims_t *output_dims,
+                                               int8_t *out_data,
+                                               const dw_conv_params_t *conv_params,
+                                               const quant_data_t *quant_data)
+{
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Clamp the filter window so it does not read outside the input block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            int ch_idx = 0;
+            for (; ch_idx < channels - 3; ch_idx += 4) {//channel_loop
+                int32_t result0 = 0;
+                int32_t result1 = 0;
+                int32_t result2 = 0;
+                int32_t result3 = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val0 = input_data[input_index + 0] + input_offset;
+                        int32_t input_val1 = input_data[input_index + 1] + input_offset;
+                        int32_t input_val2 = input_data[input_index + 2] + input_offset;
+                        int32_t input_val3 = input_data[input_index + 3] + input_offset;
+                        int32_t filter_val0 = filter_data[filter_index + 0];
+                        int32_t filter_val1 = filter_data[filter_index + 1];
+                        int32_t filter_val2 = filter_data[filter_index + 2];
+                        int32_t filter_val3 = filter_data[filter_index + 3];
+                        result0 += input_val0 * filter_val0;
+                        result1 += input_val1 * filter_val1;
+                        result2 += input_val2 * filter_val2;
+                        result3 += input_val3 * filter_val3;
+                    }
+                }
+                if (bias) {
+                    result0 += bias[ch_idx + 0];
+                    result1 += bias[ch_idx + 1];
+                    result2 += bias[ch_idx + 2];
+                    result3 += bias[ch_idx + 3];
+                }
+                result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                result0 += out_offset;
+                result1 += out_offset;
+                result2 += out_offset;
+                result3 += out_offset;
+
+                result0 = max(result0, activation_min);
+                result1 = max(result1, activation_min);
+                result2 = max(result2, activation_min);
+                result3 = max(result3, activation_min);
+
+                result0 = min(result0, activation_max);
+                result1 = min(result1, activation_max);
+                result2 = min(result2, activation_max);
+                result3 = min(result3, activation_max);
+
+                out_data[out_idx++] = result0;
+                out_data[out_idx++] = result1;
+                out_data[out_idx++] = result2;
+                out_data[out_idx++] = result3;
+            }
+            for (; ch_idx < channels; ch_idx++) {//channel_loop
+                int32_t result = 0;
+
+                for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                    const int32_t idx_y = base_y + filter_y_idx;
+                    for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                        const int32_t idx_x = base_x + filter_x_idx;
+                        int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                        int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels) + ch_idx;
+                        int32_t input_val = input_data[input_index] + input_offset;
+                        int32_t filter_val = filter_data[filter_index];
+                        result += input_val * filter_val;
+                    }
+                }
+                if (bias) {
+                    result += bias[ch_idx];
+                }
+                result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                result += out_offset;
+                result = max(result, activation_min);
+                result = min(result, activation_max);
+
+                out_data[out_idx++] = result;
+            }
+        }
+    }
+}
+
+void esp_nn_depthwise_conv_s8_opt(const data_dims_t *input_dims,
+                                  const int8_t *input_data,
+                                  const data_dims_t *filter_dims,
+                                  const int8_t *filter_data,
+                                  const int32_t *bias,
+                                  const data_dims_t *output_dims,
+                                  int8_t *out_data,
+                                  const dw_conv_params_t *conv_params,
+                                  const quant_data_t *quant_data)
+{
+    const uint16_t ch_mult = conv_params->ch_mult;
+    if (ch_mult == 1) {
+        esp_nn_depthwise_conv_s8_ch_mult_1(input_dims, input_data, filter_dims, filter_data,
+                                           bias, output_dims, out_data, conv_params, quant_data);
+        return;
+    }
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+
+    int out_idx = 0;
+    for (int out_y = 0; out_y < out_ht; out_y++) { //height loop
+        const int16_t base_y = (out_y * stride_ht) - pad_ht;
+        for (int out_x = 0; out_x < out_wd; out_x++) { //width_loop
+            const int16_t base_x = (out_x * stride_wd) - pad_wd;
+
+            const int32_t *out_shift = quant_data->shift;
+            const int32_t *out_mult = quant_data->mult;
+
+            /* Clamp the filter window so it does not read outside the input block */
+            int filter_y_start = max(0, -base_y);
+            int filter_x_start = max(0, -base_x);
+            int filter_y_end = min(filter_ht, input_ht - base_y);
+            int filter_x_end = min(filter_wd, input_wd - base_x);
+
+            for (int ch_idx = 0; ch_idx < channels; ch_idx++) {//channel_loop
+                int ch_mult_idx = 0;
+                for (; ch_mult_idx < ch_mult - 3; ch_mult_idx += 4) {
+                    int32_t result0 = 0;
+                    int32_t result1 = 0;
+                    int32_t result2 = 0;
+                    int32_t result3 = 0;
+                    const int out_ch_idx =  ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val0 = filter_data[filter_index + 0];
+                            int32_t filter_val1 = filter_data[filter_index + 1];
+                            int32_t filter_val2 = filter_data[filter_index + 2];
+                            int32_t filter_val3 = filter_data[filter_index + 3];
+                            result0 += input_val * filter_val0;
+                            result1 += input_val * filter_val1;
+                            result2 += input_val * filter_val2;
+                            result3 += input_val * filter_val3;
+                        }
+                    }
+                    if (bias) {
+                        result0 += bias[out_ch_idx + 0];
+                        result1 += bias[out_ch_idx + 1];
+                        result2 += bias[out_ch_idx + 2];
+                        result3 += bias[out_ch_idx + 3];
+                    }
+                    result0 = esp_nn_multiply_by_quantized_mult_fast(result0, *out_mult++, *out_shift++);
+                    result1 = esp_nn_multiply_by_quantized_mult_fast(result1, *out_mult++, *out_shift++);
+                    result2 = esp_nn_multiply_by_quantized_mult_fast(result2, *out_mult++, *out_shift++);
+                    result3 = esp_nn_multiply_by_quantized_mult_fast(result3, *out_mult++, *out_shift++);
+
+                    result0 += out_offset;
+                    result1 += out_offset;
+                    result2 += out_offset;
+                    result3 += out_offset;
+
+                    result0 = max(result0, activation_min);
+                    result1 = max(result1, activation_min);
+                    result2 = max(result2, activation_min);
+                    result3 = max(result3, activation_min);
+                    result0 = min(result0, activation_max);
+                    result1 = min(result1, activation_max);
+                    result2 = min(result2, activation_max);
+                    result3 = min(result3, activation_max);
+
+                    out_data[out_idx++] = result0;
+                    out_data[out_idx++] = result1;
+                    out_data[out_idx++] = result2;
+                    out_data[out_idx++] = result3;
+                }
+                for (; ch_mult_idx < ch_mult; ch_mult_idx++) {
+                    int32_t result = 0;
+                    const int out_ch_idx =  ch_idx * ch_mult + ch_mult_idx;
+
+                    for (int filter_y_idx = filter_y_start; filter_y_idx < filter_y_end; filter_y_idx++) {
+                        const int32_t idx_y = base_y + filter_y_idx;
+                        for (int filter_x_idx = filter_x_start; filter_x_idx < filter_x_end; filter_x_idx++) {
+                            const int32_t idx_x = base_x + filter_x_idx;
+                            int32_t input_index = (idx_y * input_wd + idx_x) * channels + ch_idx;
+                            int32_t filter_index = (filter_y_idx * filter_wd + filter_x_idx) * (channels * ch_mult) + out_ch_idx;
+                            int32_t input_val = input_data[input_index] + input_offset;
+                            int32_t filter_val = filter_data[filter_index];
+                            result += input_val * filter_val;
+                        }
+                    }
+                    if (bias) {
+                        result += bias[out_ch_idx];
+                    }
+                    result = esp_nn_multiply_by_quantized_mult_fast(result, *out_mult++, *out_shift++);
+                    result += out_offset;
+                    result = max(result, activation_min);
+                    result = min(result, activation_max);
+
+                    out_data[out_idx++] = result;
+                }
+            }
+        }
+    }
+}

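Note on the window clamping used by both optimised kernels above (filter_y_start = max(0, -base_y), filter_y_end = min(filter_ht, input_ht - base_y), and likewise for x). As an illustrative example, for a 3x3 filter with stride 1 and pad 1 at output row 0, base_y = 0 * 1 - 1 = -1, so filter_y_start = max(0, 1) = 1 and filter_y_end = min(3, input_ht + 1) = 3; the top filter row, which would otherwise read input row -1, is simply skipped rather than read from a padded copy of the input, which is why these C paths need no scratch buffer and their scratch-size functions return 0.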
+ 97 - 37
code/components/esp-nn/src/convolution/esp_nn_depthwise_conv_s8_esp32s3.c

@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 
-#include <stdint.h>
 #include <stdio.h>
+#include <esp_nn_defs.h>
 
 
 #include <common_functions.h>
 
 
@@ -353,17 +353,59 @@ void esp_nn_depthwise_conv_s8_ch_mult1(const int8_t *input_data,
     }
 }
 
 
-int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const uint16_t input_wd,
-                                                   const uint16_t input_ht,
-                                                   const uint16_t channels,
-                                                   const uint16_t ch_mult,
-                                                   const uint16_t filter_wd,
-                                                   const uint16_t filter_ht)
+int esp_nn_get_depthwise_conv_scratch_size_esp32s3(const data_dims_t *input_dims,
+                                                   const data_dims_t *filter_dims,
+                                                   const data_dims_t *output_dims,
+                                                   const dw_conv_params_t *conv_params)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t ch_mult = conv_params->ch_mult;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
-    int padding_used = ((filter_wd == 3) && (filter_ht == 3)) * 2;
-    int input_size = (input_wd + padding_used) * (input_ht + padding_used) * channels;
-    return  2 * (filter_size + input_size) + 16; //16 for alignment
+    int pad_width = 0, pad_height = 0;
+
+    if ((ch_mult == 1) && (channels % 8 == 0) && (filter_wd == 3) && (filter_ht == 3)) {
+        if (channels % 16 == 0) {
+            if (pad_wd || pad_ht) {
+                pad_width = pad_wd * 2;
+                pad_height = pad_ht * 2;
+            } else {
+                // check if we need to pad additionally
+                pad_width = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+                pad_height = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+                // printf("in(%d %d %d), out(%d %d), filter (%d %d) stride (%d %d), pad (%d %d)",
+                //         input_wd, input_ht, channels, out_wd, out_ht, filter_wd, filter_ht,
+                //         stride_wd, stride_ht, pad_wd, pad_ht);
+            }
+            if (pad_width || pad_height) {
+                int input_size = (input_wd + pad_width) * (input_ht + pad_height) * channels;
+                // printf("ask1 %d\n", filter_size + input_size + 16);
+                return filter_size + input_size + 16;  // 16 for alignment
+            } else {
+                // printf("ask2 %d\n", filter_size + 16);
+                return filter_size + 16;  // 16 for alignment
+            }
+        } else {
+            int input_size = input_wd * input_ht * channels;
+            // printf("ask3 %d\n", 2 * (filter_size + input_size) + 16);
+            return  2 * (filter_size + input_size) + 16; // 16 for alignment
+        }
+    } else if (ch_mult % 4 == 0) {
+        int input_size = input_wd * input_ht * channels;
+        // printf("ask4 %d\n", 2 * (filter_size + input_size) + 16);
+        return  2 * (filter_size + input_size) + 16; // 16 for alignment
+    }
+    return 32; // just few bytes
 }

 void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
@@ -376,29 +418,38 @@ void esp_nn_set_depthwise_conv_scratch_buf_esp32s3(void *buf)
  * Assumption 2: Pointers are valid
  * Assumption 3: dilation width = 1
  */
-void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
-                                      const uint16_t input_wd,
-                                      const uint16_t input_ht,
-                                      const uint16_t channels,
-                                      const int32_t input_offset,
-                                      const uint16_t pad_wd,
-                                      const uint16_t pad_ht,
-                                      const uint16_t stride_wd,
-                                      const uint16_t stride_ht,
-                                      const uint16_t ch_mult,
+
+
+
+void esp_nn_depthwise_conv_s8_esp32s3(const data_dims_t *input_dims,
+                                      const int8_t *input_data,
+                                      const data_dims_t *filter_dims,
                                       const int8_t *filter_data,
-                                      const uint16_t filter_wd,
-                                      const uint16_t filter_ht,
                                       const int32_t *bias,
+                                      const data_dims_t *output_dims,
                                       int8_t *out_data,
-                                      const uint16_t out_wd,
-                                      const uint16_t out_ht,
-                                      const int32_t out_offset,
-                                      const int32_t *out_shift,
-                                      const int32_t *out_mult,
-                                      const int32_t activation_min,
-                                      const int32_t activation_max)
+                                      const dw_conv_params_t *conv_params,
+                                      const quant_data_t *quant_data)
 {
+    const uint16_t input_wd = input_dims->width;
+    const uint16_t input_ht = input_dims->height;
+    const uint16_t channels = input_dims->channels;
+    const int32_t input_offset = conv_params->in_offset;
+    const int32_t out_offset = conv_params->out_offset;
+    const uint16_t pad_wd = conv_params->padding.width;
+    const uint16_t pad_ht = conv_params->padding.height;
+    const uint16_t stride_wd = conv_params->stride.width;
+    const uint16_t stride_ht = conv_params->stride.height;
+    const uint16_t filter_wd = filter_dims->width;
+    const uint16_t filter_ht = filter_dims->height;
+    const uint16_t out_wd = output_dims->width;
+    const uint16_t out_ht = output_dims->height;
+    const int32_t *out_shift = quant_data->shift;
+    const int32_t *out_mult = quant_data->mult;
+    const int32_t activation_min = conv_params->activation.min;
+    const int32_t activation_max = conv_params->activation.max;
+    const uint16_t ch_mult = conv_params->ch_mult;
+
     int filter_size = filter_wd * filter_ht * channels * ch_mult;
     int align_len = 16 - (filter_size & 15);
     int input_size = input_wd * input_ht * channels;
@@ -423,18 +474,27 @@ void esp_nn_depthwise_conv_s8_esp32s3(const int8_t *input_data,
                                                                   stride_wd, stride_ht, filter_aligned, bias,
                                                                   out_data, out_wd, out_ht, out_offset, out_shift,
                                                                   out_mult, activation_min, activation_max);
-            } else if ((pad_wd == 0) && (pad_ht == 0) &&
-                    // because this does not handle padding offset cases yet, run just for stride (1, 1).
-                    // end padding of input with `-input_offset` should solve this
-                    (stride_wd == 1) && (stride_ht == 1)) {
+            } else if ((channels % 16 == 0) && (pad_wd == 0) && (pad_ht == 0)) {
                 /* process in 8 bits */
                 int8_t *filter_aligned = (int8_t *) scratch_buffer;
+                int8_t *input_padded = (int8_t *) scratch_buffer + filter_size + align_len;
+
+                // check if we need to pad additionally
+                int pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd;
+                int pad_bottom = (out_ht * stride_ht + filter_ht - 1) - input_ht;
+                if (pad_right || pad_bottom) { // pad right and bottom
+                    esp_nn_aligned_s8_pad_end_with_value(input_data, input_padded, input_wd, input_ht,
+                                                         channels, -input_offset, pad_right, pad_bottom);
+                } else {
+                    input_padded = (int8_t *) input_data;
+                }
                 memcpy(filter_aligned, filter_data, filter_size);
-                esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_data, input_wd, input_ht, channels, input_offset,
-                                                                  stride_wd, stride_ht, filter_aligned,
-                                                                  bias, out_data, out_wd, out_ht, out_offset, out_shift,
+                esp_nn_depthwise_conv_s8_mult1_3x3_padded_esp32s3(input_padded, input_wd + pad_right,
+                                                                  input_ht + pad_bottom, channels, input_offset,
+                                                                  stride_wd, stride_ht, filter_aligned, bias,
+                                                                  out_data, out_wd, out_ht, out_offset, out_shift,
                                                                   out_mult, activation_min, activation_max);
-            } else { /* (channels % 8) == 0 && pad_wd == 1 && pad_ht == 1 */
+            } else { /* (channels % 8) == 0 */
                 esp_nn_s8_to_s16_esp32s3(filter_data, filter_data16, filter_size);
                 esp_nn_aligned_s8_to_s16_with_offset_esp32s3(input_data, input_data16, input_size, input_offset);
                 esp_nn_depthwise_conv_s16_mult1_3x3_esp32s3(input_data16, input_wd, input_ht, channels,

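Note on the additional right/bottom padding introduced above: pad_right = (out_wd * stride_wd + filter_wd - 1) - input_wd, and likewise pad_bottom for the height. Taking the new test case 9 as a worked example, a 6x6 input, 3x3 filter, stride 2 and a reported out_wd of 3 give pad_right = (3 * 2 + 3 - 1) - 6 = 2, so both the scratch-size estimate and the padded 3x3 path reserve room for a copy of the input extended by two columns and two rows filled with -input_offset.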
+ 8 - 0
code/components/esp-nn/test_app/sdkconfig.defaults.esp32s3

@@ -0,0 +1,8 @@
+# Default configurations for ESP32-S3
+
+CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
+CONFIG_ESP32S3_SPIRAM_SUPPORT=y
+
+CONFIG_ESP32S3_DATA_CACHE_64KB=y
+CONFIG_ESP32S3_DATA_CACHE_8WAYS=y
+CONFIG_ESP32S3_DATA_CACHE_LINE_64B=y

+ 16 - 4
code/components/esp-nn/tests/src/basic_math_test.c

@@ -23,7 +23,9 @@
 #include "test_utils.h"
 #include "test_utils.h"
 
 
 #if CONFIG_IDF_CMAKE
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
+#endif
 
 
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
@@ -138,6 +140,11 @@ void esp_nn_add_elementwise_s8_test()
         out_c_orig = out_data_c;
         out_opt_orig = out_data_opt;
 #endif
+        if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+                out_opt_orig == NULL) {
+            printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+            goto elementwise_add_test_cleanup;
+        }
 
 
         for (int i = 0; i < size; ++i) {
             input1[i] = rand() % 256 - 128;
@@ -194,10 +201,10 @@ elementwise_add_test_cleanup:
         if (input2_orig) {
             free(input2_orig);
         }
-        if (out_data_c) {
+        if (out_c_orig) {
             free(out_c_orig);
         }
-        if (out_data_opt) {
+        if (out_opt_orig) {
             free(out_opt_orig);
         }
     }
@@ -282,6 +289,11 @@ void esp_nn_mul_elementwise_s8_test()
         out_c_orig = out_data_c;
         out_opt_orig = out_data_opt;
 #endif
+        if (input1_orig == NULL || input2_orig == NULL || out_c_orig == NULL ||
+                out_opt_orig == NULL) {
+            printf(ANSI_COLOR_RED"%s error allocating buffers\n"ANSI_COLOR_RESET, __FUNCTION__);
+            goto elementwise_mult_test_cleanup;
+        }
 
 
         for (int i = 0; i < size; ++i) {
             input1[i] = rand() % 256 - 128;
@@ -333,10 +345,10 @@ elementwise_mult_test_cleanup:
         if (input2_orig) {
             free(input2_orig);
         }
-        if (out_data_c) {
+        if (out_c_orig) {
             free(out_c_orig);
         }
-        if (out_data_opt) {
+        if (out_opt_orig) {
             free(out_opt_orig);
         }
     }

+ 70 - 36
code/components/esp-nn/tests/src/convolution_test.c

@@ -22,8 +22,9 @@
 #include "test_utils.h"
 #include "test_utils.h"
 
 
 #if CONFIG_IDF_CMAKE
 #if CONFIG_IDF_CMAKE
+#if (CONFIG_SPIRAM_SUPPORT && (CONFIG_SPIRAM_USE_CAPS_ALLOC || CONFIG_SPIRAM_USE_MALLOC))
 #define IDF_HEAP_CAPS 1
-
+#endif
 #if IDF_HEAP_CAPS
 #include "esp_heap_caps.h"
 #endif
@@ -44,8 +45,8 @@ void esp_nn_depthwise_conv_s8_test()
     uint16_t filter_ht, filter_wd, ch_mult;
     uint16_t pad_wd, pad_ht, stride_wd, stride_ht;
 
 
-    // run for 10 iterations
-    for (int itr = 0; itr < 10; itr++) {
+    // run for 15 iterations
+    for (int itr = 0; itr < 15; itr++) {
         /* prepare data */
         switch (itr) {
         case 0: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)
@@ -144,22 +145,52 @@ void esp_nn_depthwise_conv_s8_test()
             stride_wd = 2;
             stride_ht = 2;
             break;
+        case 8: // same as case 7, with large parameters
+            input_wd = 58;
+            input_ht = 58;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 128;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
+        case 9: // (ch_mult 1, (channels % 16) = 0), filter (3,3), pad (0,0)  stride (2,2)
+            input_wd = 6;
+            input_ht = 6;
+            filter_ht = 3;
+            filter_wd = 3;
+            ch_mult = 1;
+            channels = 16;
+            pad_wd = 0;
+            pad_ht = 0;
+            stride_wd = 2;
+            stride_ht = 2;
+            break;
         default:
-            input_wd = 4;
-            input_ht = 4;
+            input_wd = 6;
+            input_ht = 6;
             filter_ht = 3;
             filter_wd = 3;
-            ch_mult = 4;
-            channels = 4;
-            pad_wd = 1;
-            pad_ht = 1;
-            stride_wd = 1;
-            stride_ht = 1;
+            ch_mult = 1;
+            channels = 16;
+            stride_wd = rand() % 2 + 1;
+            stride_ht = stride_wd;
+            pad_wd = stride_wd == 1 ? 0 : rand() % 2;
+            pad_ht = pad_wd;
+            printf("stride(%d), pad (%d)\t", stride_wd, pad_wd);
             break;
         }
 
 
         uint16_t out_wd = (input_wd - filter_wd + 1) / stride_wd;
         uint16_t out_ht = (input_ht - filter_ht + 1) / stride_ht;
+        if (itr == 9) {
+            // expect the function to handle this gracefully
+            out_wd += 1;
+            out_ht += 1;
+        }
         int in_size = input_wd * input_ht * channels;
         int out_size = out_wd * out_ht * channels * ch_mult;
         int filter_size = filter_wd * filter_ht * channels * ch_mult + 4;
@@ -210,9 +241,16 @@ void esp_nn_depthwise_conv_s8_test()
             out_mult[i] = 0x7eb0e200 + rand() % 50;
         }
 
 
-        int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(input_wd, input_ht,
-                                                                    channels, ch_mult,
-                                                                    filter_wd, filter_ht);
+        data_dims_t input_dims = {.width = input_wd, .height = input_ht, .channels = channels, 1};
+        data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = channels * ch_mult, 1};
+        data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+        dw_conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset, .ch_mult = ch_mult,
+                                        .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                        .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+        quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+        int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims,
+                                                                      &output_dims, &conv_params);
         if (scratch_buf_size > 0) {
 #if IDF_HEAP_CAPS
             scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -234,11 +272,8 @@ void esp_nn_depthwise_conv_s8_test()
         }

         /* C function */
-        esp_nn_depthwise_conv_s8_ansi(input, input_wd, input_ht, channels, input_offset,
-                                    pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                                    filter_data + 4, filter_wd, filter_ht,
-                                    bias + 1, out_data_c, out_wd, out_ht, out_offset, out_shift,
-                                    out_mult, activation_min, activation_max);
+        esp_nn_depthwise_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 4,
+                                      bias + 1, &output_dims, out_data_c, &conv_params, &quant_data);
 
 
         if (itr == 0) {
             profile_c_end();
@@ -246,11 +281,8 @@ void esp_nn_depthwise_conv_s8_test()
         }

         /* Optimized function */
-        esp_nn_depthwise_conv_s8(input, input_wd, input_ht, channels, input_offset,
-                                pad_wd, pad_ht, stride_wd, stride_ht, ch_mult,
-                                filter_data + 4, filter_wd, filter_ht,
-                                bias + 1, out_data_opt, out_wd, out_ht, out_offset, out_shift,
-                                out_mult, activation_min, activation_max);
+        esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter_data + 4,
+                                 bias + 1, &output_dims, out_data_opt, &conv_params, &quant_data);
 
 
         if (itr == 0) {
             /* disable profiler */
@@ -479,8 +511,16 @@ void esp_nn_conv_s8_test()
             out_mult[i] = 0x7f67f4f8 + rand() % 50;
         }
 
 
-        int scratch_buf_size = esp_nn_get_conv_scratch_size(in_wd, in_ht, in_channels,
-                                                            out_channels, filter_wd, filter_ht);
+        data_dims_t input_dims = {.width = in_wd, .height = in_ht, .channels = in_channels, 1};
+        data_dims_t output_dims = {.width = out_wd, .height = out_ht, .channels = out_channels, 1};
+        data_dims_t filter_dims = {.width = filter_wd, .height = filter_ht, 0, 0};
+        conv_params_t conv_params = {.in_offset = input_offset, .out_offset = out_offset,
+                                    .stride = {stride_wd, stride_ht}, .padding = {pad_wd, pad_ht},
+                                    .dilation = {0, 0}, .activation = {activation_min, activation_max}};
+        quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};
+
+        int scratch_buf_size = esp_nn_get_conv_scratch_size(&input_dims, &filter_dims,
+                                                            &output_dims, &conv_params);
        if (scratch_buf_size > 0) {
#if IDF_HEAP_CAPS
            void *scratch_buf = heap_caps_malloc(scratch_buf_size + 32, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
@@ -502,11 +542,8 @@ void esp_nn_conv_s8_test()
        }

        /* C function */
-        esp_nn_conv_s8_ansi(input, in_wd, in_ht, in_channels, input_offset,
-                            pad_wd, pad_ht, stride_wd, stride_ht,
-                            filter_data + 2, filter_wd, filter_ht, bias,
-                            out_data_c, out_wd, out_ht, out_channels, out_offset, out_shift,
-                            out_mult, activation_min, activation_max);
+        esp_nn_conv_s8_ansi(&input_dims, input, &filter_dims, filter_data + 2,
+                            bias, &output_dims, out_data_c, &conv_params, &quant_data);
 
 
        if (itr == 0) {
            profile_c_end();
@@ -514,11 +551,8 @@ void esp_nn_conv_s8_test()
        }

        /* Optimized function */
-        esp_nn_conv_s8(input, in_wd, in_ht, in_channels, input_offset,
-                    pad_wd, pad_ht, stride_wd, stride_ht,
-                    filter_data + 2, filter_wd, filter_ht, bias,
-                    out_data_opt, out_wd, out_ht, out_channels, out_offset, out_shift,
-                    out_mult, activation_min, activation_max);
+        esp_nn_conv_s8(&input_dims, input, &filter_dims, filter_data + 2,
+                       bias, &output_dims, out_data_opt, &conv_params, &quant_data);
 
 
        if (itr == 0) {
            /* disable profiler */
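For orientation, the out_wd / out_ht values that populate output_dims in the hunks above follow the usual convolution geometry. A minimal sketch of that relation (standard formula, not taken from this diff; the test computes these values elsewhere):

    /* output extents of a strided, zero-padded convolution */
    out_wd = (in_wd + 2 * pad_wd - filter_wd) / stride_wd + 1;
    out_ht = (in_ht + 2 * pad_ht - filter_ht) / stride_ht + 1;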

BIN
code/components/esp-nn_20220716.zip


+ 1 - 1
code/components/jomjol_flowcontroll/ClassFlowCNNGeneral.cpp

@@ -756,7 +756,7 @@ bool ClassFlowCNNGeneral::doNeuralNetwork(string time)
                            _fit = _val + _valminus;

                        }
-                        if (result > 10)
+                        if (result >= 10)
                            result = result - 10;
                        if (result < 0)
                            result = result + 10;
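The switch from `> 10` to `>= 10` matters for the wrap-around case on the 0..9 digit ring: a computed value of exactly 10 previously slipped through unchanged. A small worked example (simplified from the surrounding method):

    int result = 10;                          // digit ring is 0..9
    if (result >= 10) result = result - 10;   // new check: 10 wraps to 0
    if (result < 0)   result = result + 10;
    // with the old `result > 10` check, 10 would have been returned as-is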

+ 4 - 1
code/components/tflite-lib/CMakeLists.txt

@@ -25,7 +25,8 @@ list(REMOVE_ITEM srcs_kernels
          "${tfmicro_kernels_dir}/depthwise_conv.cc"
          "${tfmicro_kernels_dir}/fully_connected.cc"
          "${tfmicro_kernels_dir}/mul.cc"
-          "${tfmicro_kernels_dir}/pooling.cc")
+          "${tfmicro_kernels_dir}/pooling.cc"
+          "${tfmicro_kernels_dir}/softmax.cc")
 
 
FILE(GLOB esp_nn_kernels
          "${tfmicro_kernels_dir}/esp_nn/*.cc")
@@ -38,6 +39,8 @@ set(lib_srcs
          "${tflite_dir}/kernels/kernel_util.cc"
          "${tflite_dir}/micro/memory_planner/greedy_memory_planner.cc"
          "${tflite_dir}/micro/memory_planner/linear_memory_planner.cc"
+          "${tflite_dir}/micro/arena_allocator/recording_simple_memory_allocator.cc"
+          "${tflite_dir}/micro/arena_allocator/simple_memory_allocator.cc"
          "${tflite_dir}/c/common.cc"
          "${tflite_dir}/core/api/error_reporter.cc"
          "${tflite_dir}/core/api/flatbuffer_conversions.cc"

+ 2 - 0
code/components/tflite-lib/tensorflow/lite/builtin_ops.h

@@ -179,6 +179,8 @@ typedef enum {
  kTfLiteBuiltinMultinomial = 149,
  kTfLiteBuiltinGelu = 150,
  kTfLiteBuiltinDynamicUpdateSlice = 151,
+  kTfLiteBuiltinRelu0To1 = 152,
+  kTfLiteBuiltinUnsortedSegmentProd = 153,
} TfLiteBuiltinOperator;

#ifdef __cplusplus

+ 3 - 0
code/components/tflite-lib/tensorflow/lite/c/builtin_op_data.h

@@ -518,6 +518,9 @@ typedef struct {
  bool approximate;
} TfLiteGeluParams;
 
 
+typedef struct {
+  int num_segments;
+} TfLiteUnsortedSegmentProdParams;
#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

+ 7 - 1
code/components/tflite-lib/tensorflow/lite/c/c_api_types.h

@@ -113,7 +113,13 @@ typedef struct TfLiteQuantizationParams {
} TfLiteQuantizationParams;

// --------------------------------------------------------------------------
-// Opaque types used by c_api_opaque.h.
+// Opaque types used by c_api.h, c_api_opaque.h and common.h.
+
+// TfLiteOpaqueContext is an opaque version of TfLiteContext;
+typedef struct TfLiteOpaqueContext TfLiteOpaqueContext;
+
+// TfLiteOpaqueNode is an opaque version of TfLiteNode;
+typedef struct TfLiteOpaqueNode TfLiteOpaqueNode;
 
 
// TfLiteOpaqueTensor is an opaque version of TfLiteTensor;
typedef struct TfLiteOpaqueTensor TfLiteOpaqueTensor;

+ 40 - 10
code/components/tflite-lib/tensorflow/lite/c/common.cc

@@ -14,13 +14,33 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
+
#include "tensorflow/lite/c/c_api_types.h"
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+#include <string>
+
+#include "tensorflow/lite/core/macros.h"
+#include "tensorflow/lite/tensorflow_profiler_logger.h"
+#endif
 
 
#ifndef TF_LITE_STATIC_MEMORY
#include <stdlib.h>
#include <string.h>
#endif  // TF_LITE_STATIC_MEMORY
 
 
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+namespace tflite {
+// Use weak symbols here (even though they are guarded by macros) to avoid
+// build breakage when building a benchmark requires TFLite runs. The main
+// benchmark library should have tensor_profiler_logger dependency.
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor,
+                                               size_t num_bytes);
+
+TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor);
+}  // namespace tflite
+
+#endif  // TF_LITE_TENSORFLOW_PROFILER
+
extern "C" {

size_t TfLiteIntArrayGetSizeInBytes(int size) {
@@ -99,7 +119,12 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
void TfLiteTensorDataFree(TfLiteTensor* t) {
  if (t->allocation_type == kTfLiteDynamic ||
      t->allocation_type == kTfLitePersistentRo) {
-    free(t->data.raw);
+    if (t->data.raw) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+      tflite::OnTfLiteTensorDealloc(t);
+#endif
+      free(t->data.raw);
+    }
  }
  t->data.raw = nullptr;
}
@@ -161,7 +186,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
  t->dims = nullptr;

  if (t->dims_signature) {
-    TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+    TfLiteIntArrayFree((TfLiteIntArray*)t->dims_signature);
  }
  t->dims_signature = nullptr;
 
 
@@ -191,16 +216,12 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
}

TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
-  if (!src || !dst)
-    return kTfLiteOk;
-  if (src->bytes != dst->bytes)
-    return kTfLiteError;
-  if (src == dst)
-    return kTfLiteOk;
+  if (!src || !dst) return kTfLiteOk;
+  if (src->bytes != dst->bytes) return kTfLiteError;
+  if (src == dst) return kTfLiteOk;
 
 
  dst->type = src->type;
-  if (dst->dims)
-    TfLiteIntArrayFree(dst->dims);
+  if (dst->dims) TfLiteIntArrayFree(dst->dims);
  dst->dims = TfLiteIntArrayCopy(src->dims);
  memcpy(dst->data.raw, src->data.raw, src->bytes);
  dst->buffer_handle = src->buffer_handle;
@@ -218,8 +239,17 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
  // TODO(b/145340303): Tensor data should be aligned.
  if (!tensor->data.raw) {
    tensor->data.raw = (char*)malloc(num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
  } else if (num_bytes > tensor->bytes) {
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorDealloc(tensor);
+#endif
    tensor->data.raw = (char*)realloc(tensor->data.raw, num_bytes);
+#ifdef TF_LITE_TENSORFLOW_PROFILER
+    tflite::OnTfLiteTensorAlloc(tensor, num_bytes);
+#endif
  }
  tensor->bytes = num_bytes;
}
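Because OnTfLiteTensorAlloc / OnTfLiteTensorDealloc are declared as weak symbols above, a profiler-enabled build can supply strong definitions without every target having to link the logger. A rough sketch of such an override (illustrative only, not part of this commit):

    // hypothetical strong definitions provided by a profiler/logging library
    namespace tflite {
    void OnTfLiteTensorAlloc(TfLiteTensor* tensor, size_t num_bytes) {
      // forward the allocation event (tensor pointer, byte count) to a trace sink
    }
    void OnTfLiteTensorDealloc(TfLiteTensor* tensor) {
      // record the matching deallocation
    }
    }  // namespace tflite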

+ 52 - 3
code/components/tflite-lib/tensorflow/lite/c/common.h

@@ -173,9 +173,9 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
    }                                                 \
  } while (false)
#else  // TF_LITE_STRIP_ERROR_STRINGS
-#define UNUSED(...) (void)sizeof(#__VA_ARGS__)
-#define TF_LITE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
-#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) UNUSED(__VA_ARGS__)
+#define ARGS_UNUSED(...) (void)sizeof(#__VA_ARGS__)
+#define TF_LITE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
+#define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
#endif  // TF_LITE_STRIP_ERROR_STRINGS

// Check whether value is true, and if not return kTfLiteError from
@@ -842,6 +842,32 @@ typedef struct TfLiteContext {
                                   size_t* bytes);
} TfLiteContext;
 
 
+// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
+// for C API which doesn't use internal types (such as `TfLiteContext`) but only
+// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
+// field is exactly the same as with `TfLiteRegistration`.
+typedef struct TfLiteRegistrationExternal {
+  // Custom op name.
+  const char* custom_name;
+
+  // The version of the op. The version should be higher than 0.
+  const int version;
+
+  // Initializes the op from serialized data.
+  void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
+                size_t length);
+
+  // The pointer `buffer` is the data previously returned by an init invocation.
+  void (*free)(TfLiteOpaqueContext* context, void* buffer);
+
+  // Called when the inputs that this node depends on have been resized.
+  TfLiteStatus (*prepare)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+
+  // Called when the node is executed. (should read node->inputs and output to
+  // node->outputs).
+  TfLiteStatus (*invoke)(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
+} TfLiteRegistrationExternal;
+
typedef struct TfLiteRegistration {
  // Initializes the op from serialized data.
  // Called only *once* for the lifetime of the op, so any one-time allocations
@@ -903,8 +929,31 @@ typedef struct TfLiteRegistration {
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  int version;
+
+  // The external version of `TfLiteRegistration`. Since we can't use internal
+  // types (such as `TfLiteContext`) for C API to maintain ABI stability.
+  // C API user will provide `TfLiteRegistrationExternal` to implement custom
+  // ops. We keep it inside of `TfLiteRegistration` and use it to route
+  // callbacks properly.
+  TfLiteRegistrationExternal* registration_external;
} TfLiteRegistration;
 
 
+// Old version of `TfLiteRegistration` to maintain binary backward
+// compatibility.
+// WARNING: This structure is deprecated / not an official part of the API.
+// It should be only used for binary backward compatibility.
+typedef struct TfLiteRegistration_V1 {
+  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+  void (*free)(TfLiteContext* context, void* buffer);
+  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+  const char* (*profiling_string)(const TfLiteContext* context,
+                                  const TfLiteNode* node);
+  int32_t builtin_code;
+  const char* custom_name;
+  int version;
+} TfLiteRegistration_V1;
+
// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
// values should be 1, 2, 4, 8, ...etc.
typedef enum TfLiteDelegateFlags {
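A custom op exposed through the new external registration would be wired up roughly like the sketch below (MY_OP, MyOpPrepare and MyOpInvoke are made-up names; the field order follows the struct added above):

    static TfLiteStatus MyOpPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { return kTfLiteOk; }
    static TfLiteStatus MyOpInvoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { return kTfLiteOk; }

    static TfLiteRegistrationExternal my_op_registration = {
        /*custom_name=*/"MY_OP",
        /*version=*/1,
        /*init=*/nullptr,
        /*free=*/nullptr,
        /*prepare=*/MyOpPrepare,
        /*invoke=*/MyOpInvoke,
    };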

+ 11 - 0
code/components/tflite-lib/tensorflow/lite/core/api/flatbuffer_conversions.cc

@@ -836,6 +836,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
      *builtin_data = params.release();
      return kTfLiteOk;
    }
+    case BuiltinOperator_UNSORTED_SEGMENT_PROD: {
+      auto params = safe_allocator.Allocate<TfLiteUnsortedSegmentProdParams>();
+      TF_LITE_ENSURE(error_reporter, params != nullptr);
+      if (const auto* unsorted_segment_prod_params =
+              op->builtin_options_as_UnsortedSegmentProdOptions()) {
+        params->num_segments = unsorted_segment_prod_params->num_segments();
+      }
+      *builtin_data = params.release();
+      return kTfLiteOk;
+    }
    // Below are the ops with no builtin_data structure.
    // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
    // ok for now, since there is no call implementation either.
@@ -848,6 +858,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
    case BuiltinOperator_MATRIX_DIAG:
    case BuiltinOperator_MATRIX_SET_DIAG:
    case BuiltinOperator_RELU_N1_TO_1:
+    case BuiltinOperator_RELU_0_TO_1:
    case BuiltinOperator_SELECT:
    case BuiltinOperator_SELECT_V2:
    case BuiltinOperator_SLICE:

+ 52 - 1
code/components/tflite-lib/tensorflow/lite/core/api/op_resolver.h

@@ -23,6 +23,16 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"
 
 
+// Opaque type similar to TfLiteDelegate / TfLiteOpaqueDelegate.
+// This is used for cases (e.g. when using "TF Lite with Google Play Services")
+// where the TF Lite runtime might be built using a newer (or older)
+// version of the TF Lite sources than the app, and hence might have a
+// different definition of the TfLiteDelegate type. TF Lite APIs use
+// TfLiteOpaqueDelegate rather than TfLiteDelegate when they want to
+// refer to a delegate defined with that potentially different version
+// of the TfLiteDelegate type.
+struct TfLiteOpaqueDelegateStruct;
+
namespace tflite {

/// Abstract interface that returns TfLiteRegistrations given op codes or custom
@@ -37,8 +47,10 @@ class OpResolver {
  virtual const TfLiteRegistration* FindOp(const char* op,
                                           int version) const = 0;
 
 
+  // Represents a sequence of delegates.
  using TfLiteDelegatePtrVector =
      std::vector<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>>;
+
  // Returns optional delegates for resolving and handling ops in the flatbuffer
  // model. This may be used in addition to the standard TfLiteRegistration
  // lookup for graph resolution.
@@ -47,16 +59,55 @@ class OpResolver {
    return {};
  }
 
 
-  // Represent a function that creates a TfLite delegate instance.
+  // Represents a function that creates a TfLite delegate instance.
  using TfLiteDelegateCreator =
      std::function<std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
          int /*num_threads*/)>;
+
+  // Represents a sequence of delegate creator functions.
  using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
+
  // Returns a vector of delegate creators to create optional delegates for
  // resolving and handling ops in the flatbuffer model. This may be used in
  // addition to the standard TfLiteRegistration lookup for graph resolution.
+  //
+  // Note that this method is not used (will not be called) if you are using
+  // TF Lite in Google Play Services; the GetOpaqueDelegateCreators method
+  // (see below) is used for that case.
  virtual TfLiteDelegateCreators GetDelegateCreators() const { return {}; }
 
 
+  // TODO(b/202712825): it would be nice if we could avoid the need for separate
+  // "opaque" types & methods for use only with TF Lite in Google Play Services.
+
+  // Represents an opaque delegate instance.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegatePtr =
+      std::unique_ptr<TfLiteOpaqueDelegateStruct,
+                      void (*)(TfLiteOpaqueDelegateStruct*)>;
+
+  // Represents a function that creates an opaque delegate instance.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegateCreator =
+      std::function<TfLiteOpaqueDelegatePtr(int /*num_threads*/)>;
+
+  // Represents a sequence of opaque delegate creator functions.
+  // WARNING: Experimental interface, subject to change.
+  using TfLiteOpaqueDelegateCreators = std::vector<TfLiteOpaqueDelegateCreator>;
+
+  // Returns a vector of opaque delegate creators to create optional opaque
+  // delegates for resolving and handling ops in the flatbuffer model. This may
+  // be used in addition to the standard TfLiteRegistration lookup for graph
+  // resolution.
+  //
+  // Note that this method will be called only if you are using TF Lite in
+  // Google Play Services; if you are using regular TF Lite, GetDelegateCreators
+  // (see above) is used instead.
+  //
+  // WARNING: Experimental interface, subject to change.
+  virtual TfLiteOpaqueDelegateCreators GetOpaqueDelegateCreators() const {
+    return {};
+  }
+
  virtual ~OpResolver() {}

 private:
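An application-side resolver can inject a delegate through the creator hook added above; a minimal sketch, assuming hypothetical CreateMyDelegate / DestroyMyDelegate helpers (everything else follows the aliases declared in the class):

    // a creator matching the TfLiteDelegateCreator alias
    tflite::OpResolver::TfLiteDelegateCreator my_creator = [](int /*num_threads*/) {
      return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
          CreateMyDelegate(), [](TfLiteDelegate* d) { DestroyMyDelegate(d); });
    };
    // an OpResolver subclass would then return {my_creator} from GetDelegateCreators()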

+ 3 - 3
code/components/tflite-lib/tensorflow/lite/kernels/internal/reference/hard_swish.h

@@ -23,9 +23,9 @@ namespace tflite {
namespace reference_ops {

inline int16_t SaturatingLeftShift(int16_t value, int amount) {
-  int32_t result = static_cast<int32_t>(value) * (1 << amount);
-  result = std::min<int32_t>(result, std::numeric_limits<int16_t>::max());
-  result = std::max<int32_t>(result, std::numeric_limits<int16_t>::min());
+  int64_t result = static_cast<int64_t>(value) * (1 << amount);
+  result = std::min<int64_t>(result, std::numeric_limits<int16_t>::max());
+  result = std::max<int64_t>(result, std::numeric_limits<int16_t>::min());
  return result;
}
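The widening matters because the int32 intermediate could overflow before the clamp, which is undefined behaviour for signed types. A worked example: with value = 32767 and amount = 17, 32767 * (1 << 17) = 4,294,836,224, which exceeds INT32_MAX (2,147,483,647); an int64 intermediate holds it, and the min/max clamps still saturate the result to the int16 range.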
 
 

+ 6 - 3
code/components/tflite-lib/tensorflow/lite/kernels/internal/runtime_shape.h

@@ -27,6 +27,11 @@ class RuntimeShape {
 public:
  RuntimeShape& operator=(RuntimeShape const&) = delete;
 
 
+  // RuntimeShape in TFLM supports up to 5 dimensions.
+  // The name kMaxSmallSize comes from the same file of the upstream
+  // tensorflow lite repo and need to be kept the same for max reuse.
+  static constexpr int kMaxSmallSize = 5;
+
  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
@@ -104,11 +109,9 @@ class RuntimeShape {
                sizeof(int32_t) * shape.DimensionsCount());
  }
 
 
-  // A maximum of 4 dimensions are supported on TFLM.
-  static constexpr int kMaxSize = 5;
  int32_t size_;
  union {
-    int32_t dims_[kMaxSize];
+    int32_t dims_[kMaxSmallSize];
  };
};
 
 

+ 5 - 5
code/components/tflite-lib/tensorflow/lite/kernels/internal/types.h

@@ -974,11 +974,11 @@ struct StridedSliceParams {
  int8_t strides_count;
  int32_t strides[5];
 
 
-  int16_t begin_mask;
-  int16_t ellipsis_mask;
-  int16_t end_mask;
-  int16_t new_axis_mask;
-  int16_t shrink_axis_mask;
+  uint16_t begin_mask;
+  uint16_t ellipsis_mask;
+  uint16_t end_mask;
+  uint16_t new_axis_mask;
+  uint16_t shrink_axis_mask;
};

struct TanhParams {

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/kernels/kernel_util.h

@@ -308,7 +308,7 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                        const TfLiteTensor* input3,
                                        TfLiteIntArray** output_shape);
 
 
-// Return the size of given type in bytes. Return 0 in in case of string.
+// Return the size of given type in bytes. Return 0 in case of string.
int TfLiteTypeGetSize(TfLiteType type);

// Whether the current platform is mobile (Android or iOS).

+ 3 - 3
code/components/tflite-lib/tensorflow/lite/micro/ibuffer_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_
 
 
#include <cstddef>
#include <cstdint>
@@ -97,4 +97,4 @@ class INonPersistentBufferAllocator {
 
 
}  // namespace tflite
 
 
-#endif  // TENSORFLOW_LITE_MICRO_IBUFFER_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_IBUFFER_ALLOCATOR_H_

+ 165 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc

@@ -0,0 +1,165 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+NonPersistentArenaBufferAllocator::NonPersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      head_temp_(buffer),
+      next_temp_(buffer) {}
+
+NonPersistentArenaBufferAllocator::~NonPersistentArenaBufferAllocator() {}
+
+// Allocates a temporary buffer. This buffer is not resizable.
+uint8_t* NonPersistentArenaBufferAllocator::AllocateTemp(size_t size,
+                                                         size_t alignment) {
+  uint8_t* const aligned_result = AlignPointerUp(next_temp_, alignment);
+  const size_t available_memory = buffer_tail_ - aligned_result;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to allocate temp memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return nullptr;
+  }
+  next_temp_ = aligned_result + size;
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(aligned_result);
+  temp_buffer_count_++;
+  return aligned_result;
+}
+
+// Signals that a temporary buffer is no longer needed.
+void NonPersistentArenaBufferAllocator::DeallocateTemp(uint8_t* temp_buf) {
+  temp_buffer_ptr_check_sum_ ^= reinterpret_cast<intptr_t>(temp_buf);
+  temp_buffer_count_--;
+}
+
+// Returns true if all temporary buffers are already deallocated.
+bool NonPersistentArenaBufferAllocator::IsAllTempDeallocated() {
+  if (temp_buffer_count_ != 0 || temp_buffer_ptr_check_sum_ != 0) {
+    MicroPrintf(
+        "Number of allocated temp buffers: %d. Checksum passing status: %d",
+        temp_buffer_count_, !temp_buffer_ptr_check_sum_);
+    return false;
+  }
+  return true;
+}
+
+// Signals that all temporary allocations can be reclaimed. TFLM calls this
+// API when it knows that all temporary buffers that it requested have been
+// deallocated. The goal of this API is to let implementations of
+// INonPersistentBufferAllocator reuse the buffer with reasonable
+// complexity.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResetTempAllocations() {
+  if (!IsAllTempDeallocated()) {
+    MicroPrintf(
+        "All temp buffers must be freed before calling ResetTempAllocations()");
+    return kTfLiteError;
+  }
+  next_temp_ = head_temp_;
+  return kTfLiteOk;
+}
+
+// Returns a buffer that is resizable via ResizeBuffer().
+uint8_t* NonPersistentArenaBufferAllocator::AllocateResizableBuffer(
+    size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expected_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+
+  if (head_temp_ != expected_resizable_buf) {
+    MicroPrintf(
+        "Cannot allocate a new resizable buffer when one is already allocated");
+    return nullptr;
+  }
+
+  if (ResizeBuffer(expected_resizable_buf, size, alignment) == kTfLiteOk) {
+    return expected_resizable_buf;
+  }
+  return nullptr;
+}
+
+// Resizes a buffer that is previously returned by the AllocateResizableBuffer.
+// Note that ResizeBuffer(old_resizable_buf, 0, 1) effectively deallocates
+// a previous allocated resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::ResizeBuffer(
+    uint8_t* resizable_buf, size_t size, size_t alignment) {
+  // Only supports one resizable buffer, which starts at the buffer head.
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  if (resizable_buf != expect_resizable_buf) {
+    MicroPrintf("Internal error: buffer is not resizable");
+    return kTfLiteError;
+  }
+  if (head_temp_ != next_temp_) {
+    MicroPrintf("ResetTempAllocations() is not called before ResizeBuffer().");
+    return kTfLiteError;
+  }
+
+  const size_t available_memory = buffer_tail_ - expect_resizable_buf;
+  if (available_memory < size) {
+    MicroPrintf(
+        "Failed to resize buffer. Requested: %u, available %u, missing: %u",
+        size, available_memory, size - available_memory);
+    return kTfLiteError;
+  }
+  head_temp_ = expect_resizable_buf + size;
+  next_temp_ = head_temp_;
+
+  return kTfLiteOk;
+}
+
+// Frees up the memory occupied by the resizable buffer.
+TfLiteStatus NonPersistentArenaBufferAllocator::DeallocateResizableBuffer(
+    uint8_t* resizable_buf) {
+  return ResizeBuffer(resizable_buf, 0, 1);
+}
+
+// Returns a pointer pointing to the start of the overlay memory, which is
+// used for activation tensors and scratch buffers by kernels at Invoke stage.
+uint8_t* NonPersistentArenaBufferAllocator::GetOverlayMemoryAddress() const {
+  return buffer_head_;
+}
+
+// Reserves the size of the overlay memory. This overlay is reserved for the
+// kernels at Invoke stage. This is referred to as the overlay because before
+// the Invoke stage, the same memory can be used for temp buffers. The layout of
+// the memory is planned by the memory planner separately at Invoke stage.
+TfLiteStatus
+NonPersistentArenaBufferAllocator::ReserveNonPersistentOverlayMemory(
+    size_t size, size_t alignment) {
+  uint8_t* expect_resizable_buf = AlignPointerUp(buffer_head_, alignment);
+  return ResizeBuffer(expect_resizable_buf, size, alignment);
+}
+
+// Returns the size of non-persistent buffer in use.
+size_t NonPersistentArenaBufferAllocator::GetNonPersistentUsedBytes() const {
+  return (next_temp_ - buffer_head_);
+}
+
+// Returns the number of bytes available with a given alignment. This number
+// takes into account any temporary allocations.
+size_t NonPersistentArenaBufferAllocator::GetAvailableMemory(
+    size_t alignment) const {
+  uint8_t* const aligned_temp = AlignPointerUp(next_temp_, alignment);
+  uint8_t* const aligned_tail = AlignPointerDown(buffer_tail_, alignment);
+  return aligned_tail - aligned_temp;
+}
+
+}  // namespace tflite

+ 104 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h

@@ -0,0 +1,104 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// Implement INonPersistentBufferAllocator on an arena that is dedicated for
+// non-persistent buffers.
+class NonPersistentArenaBufferAllocator : public INonPersistentBufferAllocator {
+ public:
+  NonPersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~NonPersistentArenaBufferAllocator();
+
+  // Allocates a temporary buffer. This buffer is not resizable.
+  uint8_t* AllocateTemp(size_t size, size_t alignment) override;
+
+  // Signals that a temporary buffer is no longer needed.
+  void DeallocateTemp(uint8_t* buf) override;
+
+  // Returns true if all temporary buffers are already deallocated.
+  bool IsAllTempDeallocated() override;
+
+  // Signals that all temporary allocations can be reclaimed. TFLM calls this
+  // API when it knows that all temporary buffers that it requested has been
+  // deallocated.
+  TfLiteStatus ResetTempAllocations() override;
+
+  // Returns a buffer that is resizable via ResizeBuffer().
+  uint8_t* AllocateResizableBuffer(size_t size, size_t alignment) override;
+
+  // Resizes a buffer that is previously returned by the
+  // AllocateResizableBuffer.
+  TfLiteStatus ResizeBuffer(uint8_t* resizable_buf, size_t size,
+                            size_t alignment) override;
+
+  // Frees up the memory occupied by the resizable buffer.
+  TfLiteStatus DeallocateResizableBuffer(uint8_t* resizable_buf) override;
+
+  // Returns a pointer pointing to the start of the overlay memory, which is
+  // used for activation tensors and scratch buffers by kernels at Invoke stage.
+  uint8_t* GetOverlayMemoryAddress() const override;
+
+  // Reserves the size of the overlay memory. This overlay is reserved for the
+  // kernels at Invoke stage. This is referred to as the overlay because before
+  // the Invoke stage, the same memory can be used for temp buffers. The layout of
+  // the memory is planned by the memory planner separately at Invoke stage.
+  TfLiteStatus ReserveNonPersistentOverlayMemory(size_t size,
+                                                 size_t alignment) override;
+
+  // Returns the size of non-persistent buffer in use.
+  size_t GetNonPersistentUsedBytes() const override;
+
+  // Returns the number of bytes available with a given alignment. This number
+  // takes into account any temporary allocations.
+  size_t GetAvailableMemory(size_t alignment) const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // buffer_head_ to head_temp_ - 1 belongs to the only resizable buffer.
+  // head_temp_ to buffer_tail_ can be used for (non-resizable) temp buffers.
+  uint8_t* head_temp_;
+
+  // next_temp_ points to the next available temp buffer allocation address and
+  // its range is between head_temp_ and buffer_tail_
+  uint8_t* next_temp_;
+
+  // XOR Check sum for outstanding temp buffers.
+  // If all temp buffers are deallocated OR no temp buffers are allocated,
+  // temp_buffer_ptr_check_sum_ == nullptr.
+  intptr_t temp_buffer_ptr_check_sum_ = 0;
+  // Count of outstanding temp buffers.
+  int temp_buffer_count_ = 0;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_NON_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
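Taken together, the two-region layout described above gives the allocator a small lifecycle; a usage sketch (arena size and alignments are arbitrary illustration values):

    uint8_t arena[4096];
    tflite::NonPersistentArenaBufferAllocator allocator(arena, sizeof(arena));

    uint8_t* overlay = allocator.AllocateResizableBuffer(/*size=*/1024, /*alignment=*/16);
    uint8_t* tmp = allocator.AllocateTemp(/*size=*/128, /*alignment=*/16);
    // ... use tmp ...
    allocator.DeallocateTemp(tmp);
    allocator.ResetTempAllocations();     // must precede any resize of the overlay
    allocator.ResizeBuffer(overlay, /*size=*/2048, /*alignment=*/16);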

+ 52 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc

@@ -0,0 +1,52 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h"
+
+#include "tensorflow/lite/micro/memory_helpers.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
+
+namespace tflite {
+
+PersistentArenaBufferAllocator::PersistentArenaBufferAllocator(
+    uint8_t* buffer, size_t buffer_size)
+    : buffer_head_(buffer),
+      buffer_tail_(buffer + buffer_size),
+      tail_temp_(buffer_tail_) {}
+
+PersistentArenaBufferAllocator::~PersistentArenaBufferAllocator() {}
+
+uint8_t* PersistentArenaBufferAllocator::AllocatePersistentBuffer(
+    size_t size, size_t alignment) {
+  uint8_t* const aligned_result =
+      AlignPointerDown(tail_temp_ - size, alignment);
+  if (aligned_result < buffer_head_) {
+#ifndef TF_LITE_STRIP_ERROR_STRINGS
+    const size_t missing_memory = buffer_head_ - aligned_result;
+    MicroPrintf(
+        "Failed to allocate tail memory. Requested: %u, "
+        "available %u, missing: %u",
+        size, size - missing_memory, missing_memory);
+#endif
+    return nullptr;
+  }
+  tail_temp_ = aligned_result;
+  return aligned_result;
+}
+
+size_t PersistentArenaBufferAllocator::GetPersistentUsedBytes() const {
+  return buffer_tail_ - tail_temp_;
+}
+
+}  // namespace tflite

+ 59 - 0
code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h

@@ -0,0 +1,59 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
+#include "tensorflow/lite/micro/compatibility.h"
+
+namespace tflite {
+
+// PersistentArenaBufferAllocator is an implementation of
+// IPersistentBufferAllocator interface on an arena that is dedicated for
+// persistent buffers.
+class PersistentArenaBufferAllocator : public IPersistentBufferAllocator {
+ public:
+  PersistentArenaBufferAllocator(uint8_t* buffer, size_t buffer_size);
+  virtual ~PersistentArenaBufferAllocator();
+
+  // Allocates persistent memory. The persistent buffer is never freed.
+  // Returns nullptr if errors occured.
+  uint8_t* AllocatePersistentBuffer(size_t size, size_t alignment) override;
+
+  // Returns the size of all persistent allocations in bytes.
+  size_t GetPersistentUsedBytes() const override;
+
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+ private:
+  // The memory arena that this allocator manages.
+  uint8_t* const buffer_head_;
+  uint8_t* const buffer_tail_;
+
+  // The whole region is split into two parts:
+  // tail_temp_ to buffer_tail_ contains allocated buffers;
+  // buffer_head_ to tail_temp_ - 1 belongs to still available spaces.
+  // So in essence, the allocated region grows from the bottom and emulates
+  // SimpleMemoryAllocator's persistent part.
+  uint8_t* tail_temp_;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_PERSISTENT_ARENA_BUFFER_ALLOCATOR_H_
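For symmetry, a sketch of the persistent side (again with arbitrary sizes):

    uint8_t arena[2048];
    tflite::PersistentArenaBufferAllocator allocator(arena, sizeof(arena));
    uint8_t* weights = allocator.AllocatePersistentBuffer(/*size=*/512, /*alignment=*/16);
    size_t used = allocator.GetPersistentUsedBytes();   // 512 plus any alignment padding
    // persistent buffers grow down from the arena tail and are never freed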

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.cc → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
 
-#include "tensorflow/lite/micro/recording_simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h"
 
 
#include <new>
 
 

+ 4 - 4
code/components/tflite-lib/tensorflow/lite/micro/recording_simple_memory_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/recording_simple_memory_allocator.h

@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
 
-#ifndef TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
 
 
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 
namespace tflite {
 
 
@@ -62,4 +62,4 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator {
 
 
}  // namespace tflite
 
 
-#endif  // TENSORFLOW_LITE_MICRO_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_RECORDING_SIMPLE_MEMORY_ALLOCATOR_H_

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.cc → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
 
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 
 
#include <cstddef>
#include <cstdint>

+ 4 - 4
code/components/tflite-lib/tensorflow/lite/micro/simple_memory_allocator.h → code/components/tflite-lib/tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h

@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
 
-#ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
-#define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#ifndef TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_
 
 
#include <cstddef>
#include <cstdint>
 
 
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h"
 #include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/ibuffer_allocator.h"
 
 
namespace tflite {
 
 
@@ -147,4 +147,4 @@ class SimpleMemoryAllocator : public INonPersistentBufferAllocator,
 
 
}  // namespace tflite
 
 
-#endif  // TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_
+#endif  // TENSORFLOW_LITE_MICRO_ARENA_ALLOCATOR_SIMPLE_MEMORY_ALLOCATOR_H_

+ 1 - 1
code/components/tflite-lib/tensorflow/lite/micro/fake_micro_context.cc

@@ -16,10 +16,10 @@ limitations under the License.
#include "tensorflow/lite/micro/fake_micro_context.h"
 
 
#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
 #include "tensorflow/lite/micro/micro_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
 
 
namespace tflite {
namespace {

+ 7 - 20
code/components/tflite-lib/tensorflow/lite/micro/kernels/activations.cc

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 #include "tensorflow/lite/micro/micro_utils.h"

 namespace tflite {
@@ -60,8 +61,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }
    default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
      return kTfLiteError;
    }
  }
@@ -99,8 +100,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }
    default: {
-      TF_LITE_KERNEL_LOG(context, "Only float32 is supported currently, got %s",
-                         TfLiteTypeGetName(input->type));
+      MicroPrintf("Only float32 is supported currently, got %s",
+                  TfLiteTypeGetName(input->type));
      return kTfLiteError;
    }
  }
@@ -109,25 +110,11 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_RELU() {
-  return {/*init=*/ReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/ReluPrepare,
-          /*invoke=*/ReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(ReluInit, ReluPrepare, ReluEval);
}

TfLiteRegistration Register_RELU6() {
-  return {/*init=*/Relu6Init,
-          /*free=*/nullptr,
-          /*prepare=*/Relu6Prepare,
-          /*invoke=*/Relu6Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Relu6Init, Relu6Prepare, Relu6Eval);
}

}  // namespace tflite
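This file and the kernel files that follow all collapse the eight-field brace initializer into tflite::micro::RegisterOp(init, prepare, invoke). The helper itself is not shown in this diff; functionally it presumably amounts to something like the sketch below (an assumption, kept here only to make the refactor readable):

    // hypothetical equivalent of tflite::micro::RegisterOp
    TfLiteRegistration RegisterOp(
        void* (*init)(TfLiteContext*, const char*, size_t),
        TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*),
        TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*)) {
      TfLiteRegistration result = {};   // all other fields default to null/zero
      result.init = init;
      result.prepare = prepare;
      result.invoke = invoke;
      return result;
    }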

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/add.cc

@@ -159,14 +159,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/add_n.cc

@@ -208,14 +208,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_ADD_N() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 2 - 16
code/components/tflite-lib/tensorflow/lite/micro/kernels/arg_min_max.cc

@@ -104,25 +104,11 @@ TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace arg_min_max

TfLiteRegistration Register_ARG_MAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMaxEval);
}

TfLiteRegistration Register_ARG_MIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/arg_min_max::ArgMinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, arg_min_max::ArgMinEval);
}

}  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/assign_variable.cc

@@ -95,14 +95,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace.

TfLiteRegistration Register_ASSIGN_VARIABLE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/batch_to_space_nd.cc

@@ -105,14 +105,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace.

TfLiteRegistration Register_BATCH_TO_SPACE_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 3 - 9
code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_args.cc

@@ -84,14 +84,8 @@ TfLiteStatus BroadcastArgsEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_BROADCAST_ARGS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastArgsPrepare,
-          /*invoke=*/BroadcastArgsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastArgsPrepare,
+                                   BroadcastArgsEval);
}
 
 
-}  // namespace tflite
+}  // namespace tflite

+ 3 - 9
code/components/tflite-lib/tensorflow/lite/micro/kernels/broadcast_to.cc

@@ -116,14 +116,8 @@ TfLiteStatus BroadcastToEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_BROADCAST_TO() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/BroadcastToPrepare,
-          /*invoke=*/BroadcastToEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, BroadcastToPrepare,
+                                   BroadcastToEval);
}
 
 
-}  // namespace tflite
+}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/call_once.cc

@@ -82,14 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace.

TfLiteRegistration Register_CALL_ONCE() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }

 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/cast.cc

@@ -108,14 +108,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_CAST() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }

 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/ceil.cc

@@ -67,14 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace ceil

 TfLiteRegistration Register_CEIL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/ceil::Prepare,
-          /*invoke=*/ceil::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, ceil::Prepare, ceil::Eval);
 }

 }  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/circular_buffer.cc

@@ -108,14 +108,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
-  static TfLiteRegistration r = {/*init=*/CircularBufferInit,
-                                 /*free=*/nullptr,
-                                 /*prepare=*/CircularBufferPrepare,
-                                 /*invoke=*/CircularBufferEval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(CircularBufferInit, CircularBufferPrepare, CircularBufferEval);
   return &r;
 }


+ 12 - 48
code/components/tflite-lib/tensorflow/lite/micro/kernels/comparisons.cc

@@ -583,69 +583,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace comparisons

 TfLiteRegistration Register_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::EqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::EqualEval);
 }

 TfLiteRegistration Register_NOT_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::NotEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::NotEqualEval);
 }

 TfLiteRegistration Register_GREATER() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEval);
 }

 TfLiteRegistration Register_GREATER_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::GreaterEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::GreaterEqualEval);
 }

 TfLiteRegistration Register_LESS() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEval);
 }

 TfLiteRegistration Register_LESS_EQUAL() {
-  return {/*init=*/comparisons::Init,
-          /*free=*/nullptr,
-          /*prepare=*/comparisons::Prepare,
-          /*invoke=*/comparisons::LessEqualEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(comparisons::Init, comparisons::Prepare,
+                                   comparisons::LessEqualEval);
 }

 }  // namespace micro

+ 5 - 11
code/components/tflite-lib/tensorflow/lite/micro/kernels/concatenation.cc

@@ -148,12 +148,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE(context, input != nullptr);
     int num_dimensions = NumDimensions(input);

-    if (num_dimensions > 4) {
+    if (num_dimensions > RuntimeShape::kMaxSmallSize) {
       TF_LITE_KERNEL_LOG(
           context,
-          "Op Concatenation does not currently support num dimensions >4 "
+          "Op Concatenation does not currently support num dimensions > %d "
           "Tensor has %d dimensions.",
           "Tensor has %d dimensions.",
-          num_dimensions);
+          RuntimeShape::kMaxSmallSize, num_dimensions);
       return kTfLiteError;
     }
     micro_context->DeallocateTempTfLiteTensor(input);
@@ -252,14 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation

 TfLiteRegistration Register_CONCATENATION() {
-  return {/*init=*/concatenation::Init,
-          /*free=*/nullptr,
-          /*prepare=*/concatenation::Prepare,
-          /*invoke=*/concatenation::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(concatenation::Init, concatenation::Prepare,
+                                   concatenation::Eval);
 }

 }  // namespace micro

+ 40 - 22
code/components/tflite-lib/tensorflow/lite/micro/kernels/conv.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/kernels/padding.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 
 
 namespace tflite {
 namespace tflite {
 namespace {
 namespace {
@@ -67,23 +68,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
     }
     case kTfLiteInt16: {
-      reference_integer_ops::ConvPerChannel(
-          ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
-          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<int16_t>(input),
-          tflite::micro::GetTensorShape(filter),
-          tflite::micro::GetTensorData<int8_t>(filter),
-          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<std::int64_t>(bias),
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<int16_t>(output));
+      switch (bias->type) {
+        case kTfLiteInt32: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        case kTfLiteInt64: {
+          reference_integer_ops::ConvPerChannel(
+              ConvParamsQuantized(params, data),
+              data.per_channel_output_multiplier, data.per_channel_output_shift,
+              tflite::micro::GetTensorShape(input),
+              tflite::micro::GetTensorData<int16_t>(input),
+              tflite::micro::GetTensorShape(filter),
+              tflite::micro::GetTensorData<int8_t>(filter),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
+              tflite::micro::GetTensorShape(output),
+              tflite::micro::GetTensorData<int16_t>(output));
+          break;
+        }
+        default:
+          MicroPrintf("Bias type %s (%d) not supported.",
+                      TfLiteTypeGetName(bias->type), bias->type);
+          return kTfLiteError;
+      }
       break;
     }
     case kTfLiteInt8: {
@@ -94,14 +119,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
     }
     default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
+      MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+                  input->type);
       return kTfLiteError;
   }
   return kTfLiteOk;
@@ -110,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/ConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, ConvPrepare, Eval);
 }

 }  // namespace tflite
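Note (reviewer): conv.cc (and depthwise_conv.cc further down) switch the bias argument from GetTensorData to GetOptionalTensorData, so models without a bias tensor no longer dereference a null tensor. A hedged sketch of the helper's likely shape, assuming it lives alongside GetTensorData in micro/kernels/kernel_util:

template <typename T>
const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
  // Sketch: behave like GetTensorData, but tolerate an absent optional
  // tensor (e.g. a model with no bias) by returning nullptr.
  return tensor == nullptr ? nullptr
                           : reinterpret_cast<const T*>(tensor->data.raw);
}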

+ 10 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/conv_test.h

@@ -97,6 +97,16 @@ TfLiteStatus TestConvQuantizedPerChannel(
     float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
     TfLiteRegistration registration, int16_t* output_data);

+TfLiteStatus TestConvQuantizedPerChannel(
+    int* input_dims_data, const float* input_data, int16_t* input_quantized,
+    float input_scale, int input_zero_point, int* filter_dims_data,
+    const float* filter_data, int8_t* filter_data_quantized,
+    int* bias_dims_data, const float* bias_data, int32_t* bias_data_quantized,
+    float* bias_scales, int* bias_zero_points, int* output_dims_data,
+    const float* expected_output_data, int16_t* expected_output_data_quantized,
+    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+    TfLiteRegistration registration, int16_t* output_data);
+
 }  // namespace testing
 }  // namespace tflite


+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/cumsum.cc

@@ -169,14 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_CUMSUM() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }

 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/depth_to_space.cc

@@ -136,14 +136,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_DEPTH_TO_SPACE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
 }

 }  // namespace tflite

+ 3 - 10
code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.cc

@@ -62,7 +62,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetOptionalTensorData<float>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(filter),
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(bias),
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -92,14 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/DepthwiseConvPrepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
 }

 }  // namespace tflite

+ 27 - 1
code/components/tflite-lib/tensorflow/lite/micro/kernels/depthwise_conv.h

@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,32 @@ TfLiteStatus CalculateOpDataDepthwiseConv(

 TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node);

+// This is the most generic TfLiteRegistration. The actual supported types may
+// still be target dependent. The only requirement is that every implementation
+// (reference or optimized) must define this function.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D();
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int8 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8();
+
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16 activations and int8 weights and uses the latency optimized
+// implementations.
+TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16();
+
+#else
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT8() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+
+inline TfLiteRegistration Register_DEPTHWISE_CONV_2D_INT16() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+#endif
+
 }  // namespace tflite

 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_DEPTHWISE_CONV_H_
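Note (reviewer): the new Register_DEPTHWISE_CONV_2D_INT8/INT16 declarations let a build that only needs one activation type link the latency-optimized CMSIS-NN variant, while non-CMSIS targets (such as this ESP32 port) transparently fall back to the generic kernel through the inline aliases above. A hypothetical usage sketch; the resolver wiring is left out because its API is not part of this diff:

// Pick the int8-only variant explicitly. On a non-CMSIS build this is
// byte-for-byte the generic Register_DEPTHWISE_CONV_2D(), so application
// code stays portable across targets.
TfLiteRegistration dw_conv_registration =
    tflite::Register_DEPTHWISE_CONV_2D_INT8();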

+ 9 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize.cc

@@ -57,6 +57,13 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
                                   tflite::micro::GetTensorShape(output),
                                   tflite::micro::GetTensorData<float>(output));
         break;
+      case kTfLiteUInt8:
+        reference_ops::Dequantize(data->quantization_params,
+                                  tflite::micro::GetTensorShape(input),
+                                  tflite::micro::GetTensorData<uint8_t>(input),
+                                  tflite::micro::GetTensorShape(output),
+                                  tflite::micro::GetTensorData<float>(output));
+        break;
       default:
         MicroPrintf("Input %s, output %s not supported.",
                     TfLiteTypeGetName(input->type),
@@ -74,14 +81,8 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteRegistration Register_DEQUANTIZE() {
-  return {/*init=*/DequantizeInit,
-          /*free=*/nullptr,
-          /*prepare=*/DequantizePrepare,
-          /*invoke=*/DequantizeEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
+                                   DequantizeEval);
 }

 }  // namespace tflite

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/dequantize_common.cc

@@ -41,8 +41,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
   TF_LITE_ENSURE(context, output != nullptr);

-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
+                              input->type == kTfLiteInt16 ||
+                              input->type == kTfLiteUInt8);
   TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);

   if (output->type == kTfLiteInt32) {

+ 1 - 10
code/components/tflite-lib/tensorflow/lite/micro/kernels/detection_postprocess.cc

@@ -149,8 +149,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   return op_data;
 }

-void Free(TfLiteContext* context, void* buffer) {}
-
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* op_data = static_cast<OpData*>(node->user_data);

@@ -802,14 +800,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
-  static TfLiteRegistration r = {/*init=*/Init,
-                                 /*free=*/Free,
-                                 /*prepare=*/Prepare,
-                                 /*invoke=*/Eval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
+  static TfLiteRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
   return &r;
 }


+ 289 - 79
code/components/tflite-lib/tensorflow/lite/micro/kernels/elementwise.cc

@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@ limitations under the License.
 #include <cmath>

 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
@@ -27,6 +29,22 @@ namespace micro {
 namespace elementwise {
 namespace {

+constexpr int kAbsNameId = 0;
+constexpr int kRsrqtNameId = 1;
+
+const int kElementwiseInputTensor = 0;
+const int kElementwiseOutputTensor = 0;
+
+struct OpDataAbsRsqrt {
+  int32_t multiplier;
+  int shift;
+  int input_offset;
+  int output_offset;
+  bool needs_rescale;
+  TfLiteQuantizationType input_quantization_type;
+  TfLiteType input_type;
+};
+
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -35,11 +53,57 @@ bool IsLogicalSupportedType(const TfLiteType type) {
   return type == kTfLiteBool;
 }

+bool IsAbsSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
+}
+
+bool IsRsqrtSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32 || type == kTfLiteInt8;
+}
+
+inline void SetAbsOutputMultiplier(const float input_scale,
+                                   const float output_scale,
+                                   int32_t* multiplier, int* shift) {
+  QuantizeMultiplier(static_cast<double>(input_scale / output_scale),
+                     multiplier, shift);
+}
+
+inline void SetRsqrtOutputMultiplier(const float input_scale,
+                                     const float output_scale,
+                                     int32_t* multiplier, int* shift) {
+  const double scale =
+      1. / static_cast<double>((std::sqrt(input_scale) * output_scale));
+  QuantizeMultiplier(scale, multiplier, shift);
+}
+
 typedef bool (*IsSupportedType)(TfLiteType);
 template <IsSupportedType>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input =
+      micro_context->AllocateTempInputTensor(node, kElementwiseInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TfLiteTensor* output =
+      micro_context->AllocateTempOutputTensor(node, kElementwiseOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
+  if (!IsSupportedType(input->type)) {
+    TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
+                       TfLiteTypeGetName(input->type), input->type);
+    return kTfLiteError;
+  }
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return kTfLiteOk;
+}
 
 
+typedef bool (*IsSupportedType)(TfLiteType);
+template <IsSupportedType, const int op_nameid>
+TfLiteStatus PrepareAbsRsqrt(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
@@ -53,14 +117,87 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }

+  auto* op_data = static_cast<OpDataAbsRsqrt*>(node->user_data);
+  op_data->input_type = input->type;
+
+  // For int16 type input, we support both quantized and non-quantized
+  // evaluation.
+  if (op_nameid == kAbsNameId) {
+    op_data->input_quantization_type = input->quantization.type;
+  }
+
+  if (input->type == kTfLiteInt8 ||
+      (input->type == kTfLiteInt16 &&
+       input->quantization.type != kTfLiteNoQuantization)) {
+    TF_LITE_ENSURE_EQ(context, input->quantization.type,
+                      kTfLiteAffineQuantization);
+    TF_LITE_ENSURE_EQ(context, output->quantization.type,
+                      kTfLiteAffineQuantization);
+    const auto* input_params =
+        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
+    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
+        output->quantization.params);
+    TF_LITE_ENSURE(context, input_params != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale != nullptr);
+    TF_LITE_ENSURE(context, input_params->scale->size > 0);
+    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
+    TF_LITE_ENSURE(context, output_params != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale != nullptr);
+    TF_LITE_ENSURE(context, output_params->scale->size > 0);
+    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
+    op_data->input_offset = input_params->zero_point->data[0];
+    op_data->output_offset = output_params->zero_point->data[0];
+    if (input->type == kTfLiteInt16) {
+      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
+      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
+    }
+    const float input_scale = input_params->scale->data[0];
+    const float output_scale = output_params->scale->data[0];
+    op_data->needs_rescale = input_scale != output_scale;
+    if (op_nameid == kAbsNameId && op_data->needs_rescale) {
+      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                             &op_data->shift);
+    } else if (op_nameid == kRsrqtNameId) {
+      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
+                               &op_data->shift);
+    }
+  }
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(output);
   return kTfLiteOk;
 }

+template <typename T>
+inline TfLiteStatus EvalImplQuantized(
+    TfLiteContext* context, TfLiteNode* node,
+    T func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteStatus validate_input_func(TfLiteContext*, TfLiteNode*, T),
+    TfLiteType expected_type) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
+  const size_t num_elements = ElementCount(*input->dims);
+  const T* in_data = tflite::micro::GetTensorData<T>(input);
+  T* out_data = tflite::micro::GetTensorData<T>(output);
+  for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context,
+                        validate_input_func(context, node, in_data[i]));
+    }
+    out_data[i] = func(context, node, in_data[i]);
+  }
+  return kTfLiteOk;
+}
+
+template <typename T>
+inline T AbsHelper(T i) {
+  return std::abs(i);
+}
+
 template <typename T>
 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
-                             T func(T), TfLiteType expected_type) {
+                             T func(T), TfLiteStatus validate_input_func(T),
+                             TfLiteType expected_type) {
   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
@@ -68,6 +205,9 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
   const T* in_data = tflite::micro::GetTensorData<T>(input);
   T* out_data = tflite::micro::GetTensorData<T>(output);
   for (size_t i = 0; i < num_elements; ++i) {
+    if (validate_input_func) {
+      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
+    }
     out_data[i] = func(in_data[i]);
   }
   return kTfLiteOk;
@@ -75,16 +215,114 @@

 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                 float float_func(float)) {
-  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+  return EvalImpl<float>(context, node, float_func,
+                         /*validate_input_func=*/nullptr, kTfLiteFloat32);
 }

 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+
                                 bool bool_func(bool)) {
-  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
+  return EvalImpl<bool>(context, node, bool_func,
+                        /*validate_input_func=*/nullptr, kTfLiteBool);
+}
+
+void* ElementWiseAbsRsqrtInit(TfLiteContext* context, const char* buffer,
+                              size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataAbsRsqrt));
+}
+
+template <typename T>
+inline T AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = std::abs(i - op_data->input_offset);
+  if (!op_data->needs_rescale) {
+    return static_cast<T>(
+        std::min(std::max(static_cast<long int>(value + op_data->output_offset),
+                          static_cast<long int>(kMin)),
+                 static_cast<long int>(kMax)));
+  }
+
+  const int32_t output = tflite::MultiplyByQuantizedMultiplier(
+                             value, op_data->multiplier, op_data->shift) +
+                         op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+inline T RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node, T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  const int kMin = std::numeric_limits<T>::min();
+  const int kMax = std::numeric_limits<T>::max();
+
+  const int32_t value = (i - op_data->input_offset);
+  const int32_t kShift = 20;  // Shift to keep value integer.
+  if (value == 0) {
+    // Assume that any value close to 0 represents the max output value.
+    return static_cast<T>(kMax);
+  }
+  int32_t inv_sqrt_multiplier;
+  int inv_sqrt_shift;
+  GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
+                                   &inv_sqrt_shift);
+  const int32_t data = tflite::MultiplyByQuantizedMultiplier(
+      static_cast<int32_t>(1), inv_sqrt_multiplier, inv_sqrt_shift + kShift);
+  const int32_t output =
+      tflite::MultiplyByQuantizedMultiplier(data, op_data->multiplier,
+                                            op_data->shift - kShift) +
+      op_data->output_offset;
+  return static_cast<T>(std::min(
+      std::max(static_cast<long int>(output), static_cast<long int>(kMin)),
+      static_cast<long int>(kMax)));
+}
+
+template <typename T>
+TfLiteStatus validate_input_func(TfLiteContext* context, TfLiteNode* node,
+                                 T i) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+
+  TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
+                     "Rsqrt is only defined for positive values");
+  return static_cast<TfLiteStatus>(kTfLiteOk);
 }
 }

 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
+  OpDataAbsRsqrt* op_data = reinterpret_cast<OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  TfLiteQuantizationType input_quantization_type =
+      op_data->input_quantization_type;
+  TfLiteStatus eval_result;
+
+  switch (type) {
+    case kTfLiteFloat32:
+      eval_result = EvalNumeric(context, node, std::abs);
+      break;
+    case kTfLiteInt8:
+      eval_result =
+          EvalImplQuantized<int8_t>(context, node, AbsEvalQuantized,
+                                    /*validate_input_func=*/nullptr, type);
+      break;
+    case kTfLiteInt16:
+      eval_result =
+          input_quantization_type == kTfLiteNoQuantization
+              ? EvalImpl<int16_t>(context, node, AbsHelper,
+                                  /*validate_input_func=*/nullptr, type)
+              : EvalImplQuantized<int16_t>(context, node, AbsEvalQuantized,
+                                           /*validate_input_func=*/nullptr,
+                                           type);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+      break;
+  }
+  return eval_result;
 }
 }

 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
 }
 }

 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
+  const auto* op_data = static_cast<const OpDataAbsRsqrt*>(node->user_data);
+  TfLiteType type = op_data->input_type;
+  switch (type) {
+    case kTfLiteFloat32:
+      return EvalImpl<float>(
+          context, node, [](float f) { return 1.f / std::sqrt(f); },
+          /*validate_input_func=*/nullptr, type);
+    case kTfLiteInt8:
+      return EvalImplQuantized<int8_t>(context, node,
+                                       elementwise::RsqrtEvalQuantized,
+                                       elementwise::validate_input_func, type);
+
+    default:
+      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
+                         TfLiteTypeGetName(type));
+      return kTfLiteError;
+  }
 }

 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
@@ -119,101 +373,57 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace elementwise

 TfLiteRegistration Register_ABS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::AbsEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsAbsSupportedType,
+                                   elementwise::kAbsNameId>,
+      elementwise::AbsEval);
 }

 TfLiteRegistration Register_SIN() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SinEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SinEval);
 }

 TfLiteRegistration Register_COS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::CosEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval);
 }

 TfLiteRegistration Register_LOG() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::LogEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval);
 }

 TfLiteRegistration Register_SQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SqrtEval);
 }

 TfLiteRegistration Register_RSQRT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::RsqrtEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      elementwise::ElementWiseAbsRsqrtInit,
+      elementwise::PrepareAbsRsqrt<elementwise::IsRsqrtSupportedType,
+                                   elementwise::kRsrqtNameId>,
+      elementwise::RsqrtEval);
 }

 TfLiteRegistration Register_SQUARE() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-          /*invoke=*/elementwise::SquareEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SquareEval);
 }

 TfLiteRegistration Register_LOGICAL_NOT() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/
-          elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
-          /*invoke=*/elementwise::LogicalNotEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(
+      nullptr, elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+      elementwise::LogicalNotEval);
 }

 }  // namespace micro
 }  // namespace ops
-}  // namespace tflite
+}  // namespace tflite
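Note (reviewer): the rescale factor quantized by SetRsqrtOutputMultiplier follows directly from the affine quantization, assuming zero offsets (which PrepareAbsRsqrt enforces for int16 and folds into input_offset/output_offset for int8):

\[
y = \frac{1}{\sqrt{x}}, \qquad x = s_{in}\,q_{in}, \qquad
q_{out} = \frac{y}{s_{out}}
        = \frac{1}{s_{out}\sqrt{s_{in}}}\cdot\frac{1}{\sqrt{q_{in}}}.
\]

So the fixed-point multiplier is $1/(\sqrt{s_{in}}\,s_{out})$, matching the scale computed in SetRsqrtOutputMultiplier, while RsqrtEvalQuantized supplies the remaining $1/\sqrt{q_{in}}$ factor through GetInvSqrtQuantizedMultiplierExp.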

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/elu.cc

@@ -146,14 +146,7 @@ TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace

 TfLiteRegistration Register_ELU() {
-  return {/*init=*/EluInit,
-          /*free=*/nullptr,
-          /*prepare=*/EluPrepare,
-          /*invoke=*/EluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(EluInit, EluPrepare, EluEval);
 }

 }  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/add.cc

@@ -196,14 +196,7 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteRegistration Register_ADD() {
-  return {/*init=*/AddInit,
-          /*free=*/nullptr,
-          /*prepare=*/AddPrepare,
-          /*invoke=*/AddEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(AddInit, AddPrepare, AddEval);
 }

 }  // namespace tflite

+ 46 - 21
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/conv.cc

@@ -112,9 +112,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input->dims->data[3], 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output->dims->data[3], 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    conv_params_t conv_params = {
+                                  .in_offset = 0, .out_offset = 0,
+                                  .stride = {params.stride_width, params.stride_height},
+                                  .padding = {data->op_data.padding.width, data->op_data.padding.height},
+                                  .dilation = {0, 0}, .activation = {-128, 127}
+                                };
+
     int scratch_buf_size = esp_nn_get_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        output->dims->data[3], filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
     if (scratch_buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
         context, scratch_buf_size, &data->buffer_idx));
@@ -191,18 +206,33 @@ inline void EvalQuantizedPerChannel(
     const int input_size = input_width * input_height * input_depth;
     const int output_size = output_width * output_height * output_depth;

+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input_depth, 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output_depth, 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    conv_params_t conv_params = {
+                                  .in_offset = input_offset, .out_offset = output_offset,
+                                  .stride = {stride_width, stride_height},
+                                  .padding = {pad_width, pad_height},
+                                  .dilation = {0, 0},
+                                  .activation = {activation_min, activation_max}
+                                };
+    quant_data_t quant_data = {
+                                .shift = data.op_data.per_channel_output_shift,
+                                .mult = data.op_data.per_channel_output_multiplier
+                              };
+
     for (int i_batch = 0; i_batch < batch_size; i_batch++) {
     for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-                     input_width, input_height, input_depth, input_offset,
-                     pad_width, pad_height, stride_width, stride_height,
-                     tflite::micro::GetTensorData<int8_t>(filter),
-                     filter_width, filter_height,
+      esp_nn_conv_s8(&input_dims, input_data + i_batch * input_size,
+                     &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                      tflite::micro::GetTensorData<int32_t>(bias),
                      tflite::micro::GetTensorData<int32_t>(bias),
-                     output_width, output_height, output_depth, output_offset,
-                     data.op_data.per_channel_output_shift,
-                     data.op_data.per_channel_output_multiplier,
-                     activation_min, activation_max);
+                     &output_dims, output_data + i_batch * output_size,
+                     &conv_params, &quant_data);
     }
   } else {
     reference_integer_ops::ConvPerChannel(
@@ -299,21 +329,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
   }
-  conv_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  conv_total_time += time_this_instance;
+  //printf("time this instance: %llu\n", time_this_instance / 1000);
   return kTfLiteOk;
 }

 }  // namespace

 TfLiteRegistration Register_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }

 }  // namespace tflite

+ 49 - 22
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/depthwise_conv.cc

@@ -112,21 +112,36 @@ inline void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
     if (data.buffer_idx > -1) {
       scratch_buf = context->GetScratchBuffer(context, data.buffer_idx);
     }
+
     esp_nn_set_depthwise_conv_scratch_buf(scratch_buf);

+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input_depth, 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output_depth, 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    dw_conv_params_t conv_params =  {
+                                      .in_offset = input_offset, .out_offset = output_offset,
+                                      .ch_mult = depth_multiplier,
+                                      .stride = {stride_width, stride_height},
+                                      .padding = {pad_width, pad_height}, .dilation = {0, 0},
+                                      .activation = {activation_min, activation_max}
+                                    };
+    quant_data_t quant_data = {
+                                .shift = data.op_data.per_channel_output_shift,
+                                .mult = data.op_data.per_channel_output_multiplier
+                              };
+
     for (int i_batch = 0; i_batch < batch_size; i_batch++) {
-      esp_nn_depthwise_conv_s8(input_data + i_batch * input_size, input_width,
-                               input_height, input_depth, input_offset,
-                               pad_width, pad_height,
-                               stride_width, stride_height, depth_multiplier,
-                               tflite::micro::GetTensorData<int8_t>(filter),
-                               filter_width, filter_height,
+      esp_nn_depthwise_conv_s8(&input_dims, input_data + i_batch * input_size,
+                               &filter_dims, tflite::micro::GetTensorData<int8_t>(filter),
                                tflite::micro::GetTensorData<int32_t>(bias),
-                               output_data + i_batch * output_size,
-                               output_width, output_height, output_offset,
-                               data.op_data.per_channel_output_shift,
-                               data.op_data.per_channel_output_multiplier,
-                               activation_min, activation_max);
+                               &output_dims, output_data + i_batch * output_size,
+                               &conv_params, &quant_data);
     }
   } else {
     reference_integer_ops::DepthwiseConvPerChannel(
@@ -209,9 +224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

 #if ESP_NN
   if (input->type == kTfLiteInt8) {
+    data_dims_t input_dims =  {
+                                .width = input_width, .height = input_height,
+                                .channels = input->dims->data[3], 1
+                              };
+    data_dims_t output_dims = {
+                                .width = output_width, .height = output_height,
+                                .channels = output->dims->data[3], 1
+                              };
+    data_dims_t filter_dims = {.width = filter_width, .height = filter_height, 0, 0};
+    dw_conv_params_t conv_params =  {
+                                      .in_offset = 0, .out_offset = 0,
+                                      .ch_mult = params.depth_multiplier,
+                                      .stride = {params.stride_width, params.stride_height},
+                                      .padding = {data->op_data.padding.width, data->op_data.padding.height},
+                                      .dilation = {0, 0}, .activation = {-128, 127}
+                                    };
+
    int scratch_buf_size = esp_nn_get_depthwise_conv_scratch_size(
-        input_width, input_height, input->dims->data[3],
-        params.depth_multiplier, filter_width, filter_height);
+        &input_dims, &filter_dims, &output_dims, &conv_params);
    if (scratch_buf_size > 0) {
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
        context, scratch_buf_size, &data->buffer_idx));
@@ -299,21 +330,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
-  dc_total_time += esp_timer_get_time() - start_time;
+  long long time_this_instance = esp_timer_get_time() - start_time;
+  dc_total_time += time_this_instance;
+  // printf("time this instance: %llu\n", time_this_instance / 1000);
+
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_DEPTHWISE_CONV_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite
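For reference, the refactor above replaces the long scalar argument list with small descriptor structs. A minimal, self-contained sketch of the same calling convention follows (field names are taken from the hunks above; buffer sizes, values, and the function name run_dw_conv_example are illustrative placeholders, not part of the commit):

#include <esp_nn.h>

// Hedged sketch: one int8 depthwise convolution through the struct-based API.
static int8_t  input[16 * 16 * 8];
static int8_t  filter[3 * 3 * 8];
static int32_t bias[8];
static int8_t  output[16 * 16 * 8];
static int32_t out_mult[8];   // per-channel output multipliers
static int32_t out_shift[8];  // per-channel output shifts

void run_dw_conv_example(void) {
  data_dims_t input_dims  = {.width = 16, .height = 16, .channels = 8, 1};
  data_dims_t filter_dims = {.width = 3, .height = 3, 0, 0};
  data_dims_t output_dims = {.width = 16, .height = 16, .channels = 8, 1};
  dw_conv_params_t conv_params = {.in_offset = 0, .out_offset = 0, .ch_mult = 1,
                                  .stride = {1, 1}, .padding = {1, 1},
                                  .dilation = {0, 0}, .activation = {-128, 127}};
  quant_data_t quant_data = {.shift = out_shift, .mult = out_mult};

  esp_nn_depthwise_conv_s8(&input_dims, input, &filter_dims, filter, bias,
                           &output_dims, output, &conv_params, &quant_data);
}

The scratch-size query follows the same pattern, as used in Prepare above: esp_nn_get_depthwise_conv_scratch_size(&input_dims, &filter_dims, &output_dims, &conv_params).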

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/fully_connected.cc

@@ -185,14 +185,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/mul.cc

@@ -118,14 +118,7 @@ TfLiteStatus MulEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteRegistration Register_MUL() {
-  return {/*init=*/MulInit,
-          /*free=*/nullptr,
-          /*prepare=*/MulPrepare,
-          /*invoke=*/MulEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(MulInit, MulPrepare, MulEval);
}

}  // namespace tflite

+ 2 - 16
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/pooling.cc

@@ -221,25 +221,11 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}  // namespace

TfLiteRegistration Register_AVERAGE_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/AverageEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, AverageEval);
}

TfLiteRegistration Register_MAX_POOL_2D() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/PoolingPrepare,
-          /*invoke=*/MaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, PoolingPrepare, MaxEval);
}

}  // namespace tflite

+ 208 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/esp_nn/softmax.cc

@@ -0,0 +1,208 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+#include "freertos/FreeRTOS.h"
+#include <esp_timer.h>
+
+#if ESP_NN
+#include <esp_nn.h>
+#endif
+
+long long softmax_total_time = 0;
+
+namespace tflite {
+namespace {
+// Softmax parameter data that persists in user_data
+const int kInt16LUTArraySize = 513;
+
+struct NodeData {
+  SoftmaxParams op_data;
+#if ESP_NN
+  int buffer_idx;
+#endif
+};
+
+static void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(NodeData));
+}
+
+void SoftmaxQuantized(TfLiteContext* context, const TfLiteEvalTensor* input,
+                      TfLiteEvalTensor* output, const NodeData* data) {
+  if (input->type == kTfLiteInt8) {
+    if (output->type == kTfLiteInt16) {
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+    } else {
+#if ESP_NN
+      const int32_t input_beta_multiplier = data->op_data.input_multiplier;
+      const int32_t input_beta_left_shift = data->op_data.input_left_shift;
+      const int diff_min = data->op_data.diff_min;
+      const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+      const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+      const int trailing_dim = input_shape.DimensionsCount() - 1;
+      const int outer_size =
+          MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+      const int depth =
+          MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+      const int8_t *in_ptr = tflite::micro::GetTensorData<int8_t>(input);
+      int8_t *out_ptr = tflite::micro::GetTensorData<int8_t>(output);
+      void *scratch_buf = NULL;
+      if (data->buffer_idx > -1) {
+        scratch_buf = context->GetScratchBuffer(context, data->buffer_idx);
+      }
+      esp_nn_set_softmax_scratch_buf(scratch_buf);
+      esp_nn_softmax_s8(in_ptr, outer_size, depth, input_beta_multiplier,
+                        input_beta_left_shift, diff_min, out_ptr);
+#else
+      tflite::reference_ops::Softmax(
+          data->op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+#endif
+    }
+  } else {
+    tflite::reference_ops::SoftmaxInt16(
+        data->op_data, tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int16_t>(input),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
+  }
+}
+
+static TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  NodeData data = *static_cast<NodeData*>(node->user_data);
+
+  long long start_time = esp_timer_get_time();
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Softmax(
+          data.op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+    }
+    break;
+    case kTfLiteInt8:
+    case kTfLiteInt16: {
+      SoftmaxQuantized(context, input, output, &data);
+    }
+    break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  softmax_total_time += esp_timer_get_time() - start_time;
+  return kTfLiteOk;
+}
+
+static TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  MicroContext* micro_context = GetMicroContext(context);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
+  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  NodeData* data = static_cast<NodeData*>(node->user_data);
+  // Only allocate LUTs for kTfLiteInt16 data type
+  if (input->type == kTfLiteInt16) {
+    void* raw_exp_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+    data->op_data.exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+    data->op_data.one_over_one_plus_x_lut =
+        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+  }
+
+  if (output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE(context,
+                   input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  }
+
+  // Populate LUT if required
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    // exp LUT only used on negative values
+    // we consider exp(-10.0) is insignificant to accumulation
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return std::exp(value); }, -10.0f, 0.0f, -1.0f, 1.0f,
+        data->op_data.exp_lut);
+    gen_lut<float, int16_t, int16_t>(
+        [](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f, -1.0f,
+        1.0f, data->op_data.one_over_one_plus_x_lut);
+    data->op_data.zero_point = output->params.zero_point;
+    data->op_data.scale = output->params.scale;
+  }
+
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto ret_val =
+      CalculateSoftmaxParams(context, input, output, params, &data->op_data);
+
+#if ESP_NN
+  if (output->type == kTfLiteInt8 && input->type == kTfLiteInt8) {
+    const int32_t input_width = input->dims->data[1];
+    const int32_t input_height = input->dims->data[2];
+    int scratch_buf_size = esp_nn_get_softmax_scratch_size(input_width,
+                                                           input_height);
+    if (scratch_buf_size > 0) {
+      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+        context, scratch_buf_size, &data->buffer_idx));
+    }
+  }
+#endif
+
+  micro_context->DeallocateTempTfLiteTensor(input);
+  micro_context->DeallocateTempTfLiteTensor(output);
+  return ret_val;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_SOFTMAX() {
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
+}
+
+}  // namespace tflite
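As a concrete reading of the int8 fast path above: the tensor is flattened into outer_size rows of depth elements before esp_nn_softmax_s8 is called, where depth is the size of the trailing dimension and outer_size is the product of the remaining ones. A minimal sketch of that shape arithmetic (the shape values here are illustrative, not from the commit):

#include <cstdio>

int main() {
  const int dims[] = {2, 1, 1, 10};      // example NHWC-style output shape
  const int num_dims = 4;
  const int depth = dims[num_dims - 1];  // softmax runs over the last axis
  int outer_size = 1;
  for (int i = 0; i < num_dims - 1; ++i) outer_size *= dims[i];
  // Prints outer_size=2 depth=10: two independent rows of ten logits.
  std::printf("outer_size=%d depth=%d\n", outer_size, depth);
  return 0;
}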

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/exp.cc

@@ -72,14 +72,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_EXP() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/expand_dims.cc

@@ -146,14 +146,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_EXPAND_DIMS() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/fill.cc

@@ -135,14 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_FILL() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor.cc

@@ -42,14 +42,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace floor

TfLiteRegistration Register_FLOOR() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/floor::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, floor::Eval);
}

}  // namespace micro

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_div.cc

@@ -123,14 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_FLOOR_DIV() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/floor_mod.cc

@@ -121,14 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_FLOOR_MOD() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite

+ 19 - 12
code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.cc

@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
      node, kFullyConnectedOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
-
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
-  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
-                     "Hybrid models are not supported on TFLite Micro.");

  TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
                                 context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      break;
    }

+    case kTfLiteInt16: {
+      const int64_t* bias_data =
+          nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
+                          : nullptr;
+
+      tflite::reference_integer_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int16_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias), bias_data,
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+      break;
+    }
+
    default: {
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
@@ -138,14 +152,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite
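A note on the new kTfLiteInt16 branch above: in the 16x8 quantization scheme the activations are int16, the weights are int8, and the bias is int64, because a single int16*int8 product already needs up to 24 bits and a long dot product plus bias can overflow 32 bits. A small, self-contained sketch of that sizing argument (the arithmetic is illustrative, not from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t max_product = 32767LL * 127;               // largest |int16 * int8|
  const int64_t n_before_overflow = INT32_MAX / max_product;
  // Roughly 516: a fully-connected layer with a few hundred inputs can already
  // overflow a 32-bit accumulator, hence int64 accumulation and an int64 bias.
  std::printf("max product %lld, products until int32 overflow: %lld\n",
              static_cast<long long>(max_product),
              static_cast<long long>(n_before_overflow));
  return 0;
}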

+ 19 - 1
code/components/tflite-lib/tensorflow/lite/micro/kernels/fully_connected.h

@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
}

#endif
+
+#if defined(CMSIS_NN)
+// Returns a TfLiteRegistration struct for kernel variant that only supports
+// int16.
+TfLiteRegistration Register_FULLY_CONNECTED_INT16();
+
+#else
+// Note that while this block gets used for both reference and optimized kernels
+// that do not have any specialized implementations, the only goal here is to
+// define fallback implementation that allow reference kernels to still be used
+// from applications that call a more specific kernel variant.
+
+inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
+  return Register_FULLY_CONNECTED();
+}
+
+#endif
+
}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
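The fallback above means application code can always ask for the int16-specific registration and get something sensible. A minimal sketch of how a resolver might pick it up (assuming the MicroMutableOpResolver::AddFullyConnected overload that accepts a registration; the function name RegisterOps is a placeholder):

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Hedged sketch: on CMSIS-NN builds this selects the int16-only kernel,
// elsewhere it silently falls back to Register_FULLY_CONNECTED().
void RegisterOps(tflite::MicroMutableOpResolver<1>& resolver) {
  resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16());
}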

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/gather.cc

@@ -218,14 +218,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_GATHER() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/gather_nd.cc

@@ -195,14 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_GATHER_ND() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite

+ 2 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/hard_swish.cc

@@ -68,14 +68,8 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_HARD_SWISH() {
-  return {/*init=*/HardSwishInit,
-          /*free=*/nullptr,
-          /*prepare=*/tflite::HardSwishPrepare,
-          /*invoke=*/HardSwishEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(HardSwishInit, tflite::HardSwishPrepare,
+                                   HardSwishEval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/if.cc

@@ -115,14 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace.

TfLiteRegistration Register_IF() {
-  return {/*init=*/Init,
-          /*free=*/nullptr,
-          /*prepare=*/Prepare,
-          /*invoke=*/Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

}  // namespace tflite

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.cc

@@ -15,9 +15,9 @@ limitations under the License.

#include "tensorflow/lite/micro/kernels/kernel_runner.h"

+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/micro/test_helpers.h"

namespace tflite {
@@ -30,7 +30,7 @@ uint8_t KernelRunner::kKernelRunnerBuffer_[];
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                           TfLiteTensor* tensors, int tensors_size,
                           TfLiteIntArray* inputs, TfLiteIntArray* outputs,
-                           void* builtin_data)
+                           void* builtin_data, TfLiteIntArray* intermediates)
    : registration_(registration),
      allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
                                               kKernelRunnerBuffer_,
@@ -54,6 +54,7 @@ KernelRunner::KernelRunner(const TfLiteRegistration& registration,
  node_.inputs = inputs;
  node_.outputs = outputs;
  node_.builtin_data = builtin_data;
+  node_.intermediates = intermediates;
}

bool KernelRunner::ValidateTempBufferDeallocated() {

+ 3 - 2
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_runner.h

@@ -18,9 +18,9 @@ limitations under the License.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/arena_allocator/simple_memory_allocator.h"
#include "tensorflow/lite/micro/fake_micro_context.h"
#include "tensorflow/lite/micro/mock_micro_graph.h"
-#include "tensorflow/lite/micro/simple_memory_allocator.h"

namespace tflite {
namespace micro {
@@ -35,7 +35,8 @@ class KernelRunner {
 public:
  KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
               int tensors_size, TfLiteIntArray* inputs,
-               TfLiteIntArray* outputs, void* builtin_data);
+               TfLiteIntArray* outputs, void* builtin_data,
+               TfLiteIntArray* intermediates = nullptr);

  // Calls init and prepare on the kernel (i.e. TfLiteRegistration) struct. Any
  // exceptions will be DebugLog'd and returned as a status code.

+ 15 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.cc

@@ -36,6 +36,21 @@ int ValidateTensorIndexing(const TfLiteContext* context, int index,

}  // namespace

+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node)) {
+  return {/*init=*/init,
+          /*free=*/nullptr,
+          /*prepare=*/prepare,
+          /*invoke=*/invoke,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0,
+          /*registration_external=*/nullptr};
+}
+
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
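The RegisterOp helper above centralizes the brace initializer that every kernel used to spell out; the free/profiling_string/builtin_code/custom_name/version/registration_external fields are always defaulted. A minimal sketch of a custom kernel using it (MyInit, MyPrepare, MyEval, and Register_MY_OP are hypothetical names, not part of the commit):

#include "tensorflow/lite/micro/kernels/kernel_util.h"

// Hedged sketch: hypothetical kernel hooks defined elsewhere.
extern void* MyInit(TfLiteContext* context, const char* buffer, size_t length);
extern TfLiteStatus MyPrepare(TfLiteContext* context, TfLiteNode* node);
extern TfLiteStatus MyEval(TfLiteContext* context, TfLiteNode* node);

TfLiteRegistration Register_MY_OP() {
  // Replaces the hand-written eight-field brace initializer the diff removes.
  return tflite::micro::RegisterOp(MyInit, MyPrepare, MyEval);
}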

+ 22 - 3
code/components/tflite-lib/tensorflow/lite/micro/kernels/kernel_util.h

@@ -27,6 +27,11 @@ limitations under the License.
namespace tflite {
namespace micro {

+TfLiteRegistration RegisterOp(
+    void* (*init)(TfLiteContext* context, const char* buffer, size_t length),
+    TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node),
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node));
+
// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
@@ -40,19 +45,33 @@ const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
                                const TfLiteNode* node, int index);

-// Returns data for a TfLiteEvalTensor struct.
+// Returns data for a TfLiteEvalTensor struct that are expected to exist.
template <typename T>
T* GetTensorData(TfLiteEvalTensor* tensor) {
-  return tensor != nullptr ? reinterpret_cast<T*>(tensor->data.raw) : nullptr;
+  TFLITE_DCHECK(tensor != nullptr);
+  return reinterpret_cast<T*>(tensor->data.raw);
}

-// Returns const data for a TfLiteEvalTensor struct.
+// Returns const data for a TfLiteEvalTensor struct that are expected to exist.
template <typename T>
const T* GetTensorData(const TfLiteEvalTensor* tensor) {
  TFLITE_DCHECK(tensor != nullptr);
  return reinterpret_cast<const T*>(tensor->data.raw);
}

+// Returns data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+T* GetOptionalTensorData(TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr : reinterpret_cast<T*>(tensor->data.raw);
+}
+
+// Returns const data for a TfLiteEvalTensor struct that could be null.
+template <typename T>
+const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
+  return tensor == nullptr ? nullptr
+                           : reinterpret_cast<const T*>(tensor->data.raw);
+}
+
// Returns the shape of a TfLiteEvalTensor struct.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
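The split above is deliberate: GetTensorData now asserts that the tensor exists, while GetOptionalTensorData tolerates a null tensor (e.g. an omitted bias) and returns nullptr. A minimal sketch of the intended call-site pattern inside a hypothetical kernel's Eval (the tensor indices are placeholders):

#include "tensorflow/lite/micro/kernels/kernel_util.h"

TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* bias = tflite::micro::GetEvalInput(context, node, 2);
  // Required tensor: asserts (TFLITE_DCHECK) if the tensor is missing.
  const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
  // Optional tensor: simply yields nullptr when the bias is absent.
  const int32_t* bias_data = tflite::micro::GetOptionalTensorData<int32_t>(bias);
  (void)input_data;
  (void)bias_data;
  return kTfLiteOk;
}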
 
 

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/l2_pool_2d.cc

@@ -136,14 +136,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_L2_POOL_2D() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/L2Prepare,
-          /*invoke=*/L2Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, L2Prepare, L2Eval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/l2norm.cc

@@ -137,14 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace l2norm

TfLiteRegistration Register_L2NORM_REF() {
-  return {/*init=*/l2norm::Init,
-          /*free=*/nullptr,
-          /*prepare=*/l2norm::Prepare,
-          /*invoke=*/l2norm::Eval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(l2norm::Init, l2norm::Prepare, l2norm::Eval);
}

TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }

+ 2 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/leaky_relu.cc

@@ -88,14 +88,8 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteRegistration Register_LEAKY_RELU() {
-  return {/*init=*/LeakyReluInit,
-          /*free=*/nullptr,
-          /*prepare=*/LeakyReluPrepare,
-          /*invoke=*/LeakyReluEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LeakyReluInit, LeakyReluPrepare,
+                                   LeakyReluEval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/log_softmax.cc

@@ -142,14 +142,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_LOG_SOFTMAX() {
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/LogSoftmaxPrepare,
-          /*invoke=*/LogSoftmaxEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, LogSoftmaxPrepare, LogSoftmaxEval);
}

}  // namespace tflite

+ 2 - 20
code/components/tflite-lib/tensorflow/lite/micro/kernels/logical.cc

@@ -34,29 +34,11 @@ TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_LOGICAL_OR() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalOrEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalOrEval);
}

TfLiteRegistration Register_LOGICAL_AND() {
-  // Init, Free, Prepare, Eval are satisfying the Interface required by
-  // TfLiteRegistration.
-  return {/*init=*/nullptr,
-          /*free=*/nullptr,
-          /*prepare=*/nullptr,
-          /*invoke=*/LogicalAndEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(nullptr, nullptr, LogicalAndEval);
}

}  // namespace tflite

+ 1 - 8
code/components/tflite-lib/tensorflow/lite/micro/kernels/logistic.cc

@@ -106,13 +106,6 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace

TfLiteRegistration Register_LOGISTIC() {
-  return {/*init=*/LogisticInit,
-          /*free=*/nullptr,
-          /*prepare=*/LogisticPrepare,
-          /*invoke=*/LogisticEval,
-          /*profiling_string=*/nullptr,
-          /*builtin_code=*/0,
-          /*custom_name=*/nullptr,
-          /*version=*/0};
+  return tflite::micro::RegisterOp(LogisticInit, LogisticPrepare, LogisticEval);
}
}  // namespace tflite

+ 2955 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.cc

@@ -0,0 +1,2955 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/micro/kernels/lstm_eval.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"
+namespace tflite {
+namespace {
+
+void ComputeRowSums(
+    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
+    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
+    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
+    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
+    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
+    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
+    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
+    int n_input, int n_aux_input, int n_output,
+    const int8_t* input_to_input_weights_ptr,
+    const int8_t* input_to_forget_weights_ptr,
+    const int8_t* input_to_cell_weights_ptr,
+    const int8_t* input_to_output_weights_ptr,
+    const int8_t* aux_input_to_input_weights_ptr,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    const int8_t* aux_input_to_output_weights_ptr,
+    const int8_t* recurrent_to_input_weights_ptr,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    const int8_t* recurrent_to_output_weights_ptr,
+    const int8_t* projection_weights_ptr, bool use_cifg,
+    const float* aux_input_ptr) {
+  // Compute the row sums for dequantization
+  if (!use_cifg) {
+    micro_tensor_utils::ReductionSumVector(
+        input_to_input_weights_ptr, input_to_input_row_sums, n_cell, n_input);
+  }
+  micro_tensor_utils::ReductionSumVector(
+      input_to_forget_weights_ptr, input_to_forget_row_sums, n_cell, n_input);
+  micro_tensor_utils::ReductionSumVector(
+      input_to_cell_weights_ptr, input_to_cell_row_sums, n_cell, n_input);
+  micro_tensor_utils::ReductionSumVector(
+      input_to_output_weights_ptr, input_to_output_row_sums, n_cell, n_input);
+
+  if (aux_input_ptr) {
+    if (!use_cifg) {
+      micro_tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
+                                             aux_input_to_input_row_sums,
+                                             n_cell, n_aux_input);
+    }
+    micro_tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
+                                           aux_input_to_forget_row_sums, n_cell,
+                                           n_aux_input);
+    micro_tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
+                                           aux_input_to_cell_row_sums, n_cell,
+                                           n_aux_input);
+    micro_tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
+                                           aux_input_to_output_row_sums, n_cell,
+                                           n_aux_input);
+  }
+  if (!use_cifg) {
+    micro_tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
+                                           recurrent_to_input_row_sums, n_cell,
+                                           n_output);
+  }
+  micro_tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
+                                         recurrent_to_forget_row_sums, n_cell,
+                                         n_output);
+  micro_tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
+                                         recurrent_to_cell_row_sums, n_cell,
+                                         n_output);
+  micro_tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
+                                         recurrent_to_output_row_sums, n_cell,
+                                         n_output);
+
+  if (projection_weights_ptr != nullptr) {
+    micro_tensor_utils::ReductionSumVector(
+        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
+  }
+}
+
+// Calculates a single LSTM gate.
+//
+// Implements the following formula: (* is matrix multiply)
+//   gate = activate(W_input    * input + W_aux       * aux_input   +
+//                   W_peephole * cell  + W_recurrent * prev_output + bias)
+// with layer norm:
+//   gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
+//
+// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
+//
+// Parameters:
+// Input vectors (to LSTM):    | Size:                | Optional?
+//   input                     | n_input              |
+//   aux_input                 | n_aux_input          | y (bidir LSTM)
+// Input vectors (persistent states):
+//   output_state              | n_output             |
+//   cell_state                | n_cell               |
+// 'Constant' inputs:
+//   input_to_gate_weights     | n_cell * n_input     |
+//   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
+//   recurrent_to_gate_weights | n_cell * n_output    |
+//   cell_to_gate_weights      | n_cell               | y (peephole)
+//   gate_bias                 | n_cell               |
+//   layer_norm_coefficients   | n_cell               | y (layer norm)
+// Output vector:
+//   gate                      | n_cell               |
+// Scalar parameters:
+//   n_batch                                    - batch size / number of vectors
+//   n_input, n_aux_input, n_output, n_cell     - size of vectors.
+//   activation                                 - activation to use.
+//   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
+//   use_layer_norm                             - if doing layer norm LSTM.
+inline void CalculateLstmGateFloat(
+    const float* input, const float* input_to_gate_weights,
+    const float* aux_input, const float* aux_input_to_gate_weights,
+    const float* output_state, const float* recurrent_to_gate_weights,
+    const float* cell_state, const float* cell_to_gate_weights,
+    const float* layer_norm_coefficients, const float* gate_bias,
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation, float* gate,
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    memset(gate, 0, n_cell * n_batch * sizeof(float));
+  } else {
+    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
+                                                gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, n_batch,
+        gate);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    micro_tensor_utils::VectorBatchVectorCwiseProduct(
+        layer_norm_coefficients, n_cell, gate, n_batch, gate);
+    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  micro_tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell,
+                                              activation, gate);
+}
+
+// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
+//
+// Implements the following formula:
+//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
+//
+// With CIFG LSTM, input gate is replaced by (1-forget_gate).
+//
+// Parameters:
+//  - n_batch, n_cell: sizes of vectors
+//  - cell_state: input/output vector, size n_batch*n_cell
+//  - input_gate: input vector, size n_batch*n_cell.
+//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
+//  - cell_gate: input vector, size n_batch*n_cell.
+//  - use_cifg: use 1-forget_gate instead of input_gate.
+//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
+void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
+                         const float* input_gate, float* forget_gate,
+                         const float* cell_gate, bool use_cifg, float clip) {
+  micro_tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
+                                               n_batch * n_cell, cell_state);
+
+  if (use_cifg) {
+    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
+    // scratch, as input_gate array is not allocated in this case. (Be careful
+    // not to write to the scratch before reading the forget gate data.)
+    float* scratch = forget_gate;
+    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
+    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_gate, scratch, n_batch * n_cell, cell_state);
+  } else {
+    micro_tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_gate, input_gate, n_batch * n_cell, cell_state);
+  }
+  if (clip > 0.0f) {
+    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
+  }
+}
+
+// Calculates the output state tensor of an LSTM step.
+//
+// Implements the following formula:
+//   output_no_projection = output_gate .* activate(cell_state)
+//     (elementwise vector product)
+// If no projection is used:
+//   output = output_state = output_no_projection
+// With projection:
+//   output = output_state = clip(W*output_no_projection + bias)
+//
+// Output might not have a different 'stride' than n_batch, so we need to copy.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, projection_weights_scale, projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
+//  - scratch: scratch area, size n_batch*n_cell.
+void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
+                              const float* cell_state, const float* output_gate,
+                              TfLiteFusedActivation activation,
+                              const float* projection_weights,
+                              const float* projection_bias,
+                              const float proj_clip, float* output_state,
+                              float* scratch) {
+  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
+                                              activation, scratch);
+  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch,
+                                               n_batch * n_cell, scratch);
+
+  const bool use_projection = (projection_weights != nullptr);
+  const bool use_projection_bias = (projection_bias != nullptr);
+
+  if (use_projection) {
+    if (use_projection_bias) {
+      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
+                                                  n_batch, output_state);
+    } else {
+      memset(output_state, 0, n_batch * n_output * sizeof(float));
+    }
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
+    if (proj_clip > 0.0f) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch, n_batch * n_output * sizeof(float));
+  }
+}
+
+// Calculates a single LSTM gate, hybrid version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateHybrid(
+    // Input and weights
+    const int8_t* input, const float* input_sf, const int32_t* input_zp,
+    const int8_t* input_to_gate_weights,
+    const uint8_t* input_to_gate_weights_ledger,
+    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
+    // Aux input and weights
+    const int8_t* aux_input, const float* aux_input_sf,
+    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
+    const float aux_input_to_gate_weights_scale,
+    int32_t* aux_input_to_gate_row_sums,
+    // Output state and weights
+    const int8_t* output_state, const float* output_state_sf,
+    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
+    const uint8_t* recurrent_to_gate_weights_ledger,
+    const float recurrent_to_gate_weights_scale,
+    int32_t* recurrent_to_gate_row_sums,
+    // Cell state and weights (peephole LSTM)
+    const float* cell_state, const int8_t* cell_to_gate_weights,
+    const float cell_to_gate_weights_scale,
+    // Layer normalization coefficients (layer norm LSTM) + gate bias
+    const float* layer_norm_coefficients, const float* gate_bias,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    float* gate,
+    // Parameters for performance optimizations
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
+    const bool is_output_state_all_zeros, bool* compute_row_sums,
+    // Scratch arrays
+    float* scratch0,        // size: n_batch
+    float* scratch1,        // size: n_cell, only used if peephole LSTM
+    float* scales,          // size: n_batch
+    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
+) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    memset(gate, 0, n_cell * n_batch * sizeof(float));
+  } else {
+    micro_tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch,
+                                                gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    if (input_to_gate_weights_ledger != nullptr) {
+      for (int i = 0; i < n_batch; i++) {
+        scales[i] = input_to_gate_weights_scale * input_sf[i];
+      }
+      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
+          input, scales, n_batch, gate);
+
+    } else {
+      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_gate_weights, n_cell, n_input, input,
+          input_to_gate_weights_scale, input_sf, n_batch, gate,
+          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
+          input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+    }
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
+        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
+        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
+        aux_input_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  // Skip if output state is all zeros.
+  if (!is_output_state_all_zeros) {
+    if (recurrent_to_gate_weights_ledger != nullptr) {
+      for (int i = 0; i < n_batch; i++) {
+        scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
+      }
+      micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
+          n_output, output_state, scales, n_batch, gate);
+    } else {
+      micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_gate_weights, n_cell, n_output, output_state,
+          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
+          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
+          recurrent_to_gate_row_sums, compute_row_sums, scratch0, nullptr);
+    }
+  }
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    float* recovered_cell_weights = scratch1;
+    micro_tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
+                                             cell_to_gate_weights_scale,
+                                             recovered_cell_weights);
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    micro_tensor_utils::VectorBatchVectorCwiseProduct(
+        layer_norm_coefficients, n_cell, gate, n_batch, gate);
+    micro_tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  micro_tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch,
+                                              activation, gate);
+}
+
+// Calculates the output state tensor of an LSTM step. See Float version too.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, projection_weights_scale, projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
+//  - asymmetric_quantize_inputs: parameter to control quantization.
+//  - projection_weights_row_sums, compute_row_sums: Data for optimized
+//      MatrixBatchVectorMultiplyAccumulate.
+//  - scratch0: scratch area of size n_batch*n_cell
+//  - scratch1: scratch area of size n_batch*n_cell
+//  - scratch2: scratch area of size n_batch
+//  - scratch3: scratch area of size n_batch
+//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
+//  - scales: scratch area of size n_batch
+void CalculateLstmOutputHybrid(
+    int n_batch, int n_cell, int n_output, const float* cell_state,
+    const float* output_gate, TfLiteFusedActivation activation,
+    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
+    float projection_weights_scale, const float* projection_bias,
+    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
+    int32_t* projection_weights_row_sums, bool* compute_row_sums,
+    float* scratch0, int8_t* scratch1, float* scratch2, int32_t* scratch3,
+    int32_t* scratch4, float* scales) {
+  micro_tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
+                                              activation, scratch0);
+  micro_tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
+                                               n_batch * n_cell, scratch0);
+
+  const bool use_projection = (projection_weights != nullptr);
+  const bool use_projection_bias = (projection_bias != nullptr);
+
+  if (use_projection) {
+    if (use_projection_bias) {
+      micro_tensor_utils::VectorBatchVectorAssign(projection_bias, n_output,
+                                                  n_batch, output_state);
+    } else {
+      memset(output_state, 0, n_batch * n_output * sizeof(float));
+    }
+    if (!micro_tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
+      // Save quantization and matmul computation for all zero output.
+      micro_tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell,
+                                              scratch1, scratch2, scratch3,
+                                              asymmetric_quantize_inputs);
+      if (projection_weights_ledger != nullptr) {
+        for (int i = 0; i < n_batch; i++) {
+          scales[i] = projection_weights_scale * scratch2[i];
+        }
+        micro_tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+            projection_weights, projection_weights_ledger, n_output, n_cell,
+            scratch1, scales, n_batch, output_state);
+      } else {
+        micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            projection_weights, n_output, n_cell, scratch1,
+            projection_weights_scale, scratch2, n_batch, output_state,
+            /*per_channel_scale=*/nullptr, scratch3, scratch4,
+            projection_weights_row_sums, compute_row_sums, scratch2, nullptr);
+      }
+    }
+    if (proj_clip > 0.0f) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch0, n_batch * n_output * sizeof(float));
+  }
+}
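+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): both the float and the hybrid output calculations above reduce to
+//   output_state = clip(W_proj * (output_gate .* activation(cell_state))
+//                       + projection_bias, proj_clip)
+// when a projection is present. A single-batch float sketch with hypothetical
+// names; the hidden value is recomputed per output for brevity (assumes
+// <cmath> for std::tanh):
+static void LstmOutputSketch(const float* cell_state, const float* output_gate,
+                             const float* projection_weights,
+                             const float* projection_bias, float proj_clip,
+                             int n_cell, int n_output, float* output_state) {
+  for (int o = 0; o < n_output; ++o) {
+    float acc = (projection_bias != nullptr) ? projection_bias[o] : 0.0f;
+    for (int c = 0; c < n_cell; ++c) {
+      const float hidden = output_gate[c] * std::tanh(cell_state[c]);
+      acc += projection_weights[o * n_cell + c] * hidden;
+    }
+    if (proj_clip > 0.0f) {
+      acc = acc > proj_clip ? proj_clip : (acc < -proj_clip ? -proj_clip : acc);
+    }
+    output_state[o] = acc;
+  }
+}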
+
+// Calculates a single LSTM gate, int8x8_16 version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateInteger8x8_16(
+    // Input and weights
+    const int8_t* input, const int8_t* input_to_gate_weights,
+    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
+    const int32_t input_to_gate_scale_b,
+    // Output state and weights
+    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
+    const int32_t* recurrent_to_gate_bias,
+    const int32_t recurrent_to_gate_scale_a,
+    const int32_t recurrent_to_gate_scale_b,
+    // Cell state and weights
+    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
+    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
+    // Layer normalization parameters (layer norm LSTM)
+    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
+    const int32_t layer_norm_input_scale_a,
+    const int32_t layer_norm_input_scale_b,
+    const int32_t layer_norm_variance_guard,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    int16_t* gate,
+    // Parameters for performance optimizations
+    // Scratch arrays
+    int32_t* scratch5) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+
+  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
+  // versions, bias is only used in layer normalization.
+  memset(gate, 0, n_batch * n_cell * sizeof(int16_t));
+  // For each batch and cell: compute input_weight * input.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
+      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
+      nullptr);
+  // Note: no aux_input.
+
+  // For each batch and cell: compute recurrent_weight * output_state.
+  micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
+      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
+      n_cell, 0, scratch5, gate, nullptr);
+  // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
+  if (use_peephole) {
+    micro_tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_output, cell_state, n_batch,
+        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    micro_tensor_utils::ApplyLayerNorm(
+        gate, layer_norm_coefficients, layer_norm_bias,
+        layer_norm_input_scale_a, layer_norm_input_scale_b,
+        layer_norm_variance_guard, n_batch, n_cell, gate);
+  }
+  // Apply activation
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      micro_tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
+      break;
+    case kTfLiteActTanh:
+      micro_tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
+      break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
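+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): in the integer kernels every effective scale is passed as a
+// (scale_a, scale_b) pair; conceptually an int32 accumulator is rescaled by
+// scale_a * 2^(scale_b - 31), with scale_a acting as a Q31 multiplier and
+// scale_b as a (possibly negative) left shift. A rough approximation, without
+// the rounding/saturating-doubling details of the real helpers (assumes
+// <cstdint> for INT32_MAX/INT32_MIN):
+static int32_t RescaleSketch(int64_t acc, int32_t scale_a, int32_t scale_b) {
+  int64_t v = (acc * static_cast<int64_t>(scale_a)) >> 31;  // Q31 multiply
+  v = (scale_b >= 0) ? (v << scale_b) : (v >> -scale_b);    // apply exponent
+  if (v > INT32_MAX) v = INT32_MAX;                         // saturate
+  if (v < INT32_MIN) v = INT32_MIN;
+  return static_cast<int32_t>(v);
+}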
+
+// Updates the LSTM cell state, used by both integer LSTM versions.
+// Also see UpdateLstmCellFloat.
+//
+// Parameters:
+//  - n_batch, n_cell: sizes of vectors
+//  - cell_state: input/output vector, size n_batch*n_cell
+//  - cell_state_scale: scaling factor of cell state.
+//  - input_gate: input vector, size n_batch*n_cell.
+//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
+//  - cell_gate: input vector, size n_batch*n_cell.
+//  - use_cifg: use 1-forget_gate instead of input_gate.
+//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
+void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
+                           int32_t cell_state_scale, const int16_t* input_gate,
+                           int16_t* forget_gate, const int16_t* cell_gate,
+                           bool use_cifg, int16_t clip) {
+  // Use the forget_gate array as scratch, as input_gate array is not allocated
+  // in CIFG case. (Be careful not to write to the scratch before reading the
+  // forget gate data.)
+  int16_t* scratch = forget_gate;
+
+  micro_tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
+                               cell_state);
+  if (use_cifg) {
+    micro_tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
+    micro_tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
+                                 30 + cell_state_scale, scratch);
+  } else {
+    micro_tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
+                                 30 + cell_state_scale, scratch);
+  }
+  micro_tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell,
+                               cell_state);
+
+  if (clip > 0) {
+    micro_tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
+  }
+}
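+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): UpdateLstmCellInteger above is the fixed-point counterpart of the
+// float update c_t = f .* c_{t-1} + i .* g (with i = 1 - f in the CIFG case),
+// followed by optional clipping. A minimal float sketch with hypothetical
+// names:
+static void UpdateCellSketch(int n, float* cell_state, const float* input_gate,
+                             const float* forget_gate, const float* cell_gate,
+                             bool use_cifg, float clip) {
+  for (int k = 0; k < n; ++k) {
+    const float i = use_cifg ? 1.0f - forget_gate[k] : input_gate[k];
+    float c = forget_gate[k] * cell_state[k] + i * cell_gate[k];
+    if (clip > 0.0f) c = c > clip ? clip : (c < -clip ? -clip : c);
+    cell_state[k] = c;
+  }
+}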
+
+// Calculates the output state tensor of an LSTM step. See Float and hybrid
+// versions as well.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - cell_state_scale: scaling of cell_state.
+//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
+//  - hidden_zp: zero_point for cell_state.*output_gate
+//  - projection_weights, proj_scale_[a|b], projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
+//  - quantized_proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
+//  - scratch0: scratch area of size n_batch*n_cell
+//  - scratch1: scratch area of size n_batch*n_cell
+//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
+void CalculateLstmOutputInteger8x8_16(
+    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
+    int32_t cell_state_scale, const int16_t* output_gate,
+    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
+    const int8_t* projection_weights, int32_t proj_scale_a,
+    int32_t proj_scale_b, const int32_t* projection_bias,
+    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
+    int16_t* scratch0, int8_t* scratch1, int32_t* scratch2) {
+  // Note: unlike float/hybrid, the activation is always Tanh.
+  micro_tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch,
+                                n_cell, scratch0);
+  micro_tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a,
+                               hidden_scale_b, n_batch, n_cell, hidden_zp,
+                               scratch1);
+
+  const bool use_projection = (projection_weights != nullptr);
+
+  if (use_projection) {
+    // Note: no bias like in float/hybrid
+    memset(output_state, 0, n_batch * n_output * sizeof(int8_t));
+    micro_tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        scratch1, projection_bias, projection_weights, proj_scale_a,
+        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
+        output_state, nullptr);
+    if (quantized_proj_clip > 0) {
+      micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                        quantized_proj_clip);
+    }
+  } else {
+    std::memcpy(output_state, scratch1, n_batch * n_output * sizeof(int8_t));
+  }
+}
+
+// Calculates a single LSTM gate, int8x8_8 version.
+// Implements the same functionality as CalculateLstmGateFloat.
+void CalculateLstmGateInteger8x8_8(
+    // Inputs and weights
+    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
+    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
+    const int32_t input_times_weights_scale_a,
+    const int32_t input_times_weights_scale_b,
+    const int32_t input_times_weights_zp,
+    // Output state and weights
+    const int8_t* output_state, const int32_t output_state_zp,
+    const int8_t* recurrent_to_gate_weight,
+    const int32_t recurrent_to_gate_scale_a,
+    const int32_t recurrent_to_gate_scale_b,
+    const int32_t output_state_times_weights_scale_a,
+    const int32_t output_state_times_weights_scale_b,
+    const int32_t output_state_times_weights_zp,
+    // Layer normalization parameters (layer norm LSTM)
+    const int16_t* layer_norm_gate_weight,
+    const int32_t layer_norm_gate_scale_a,
+    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
+    // Array sizes
+    const int n_batch, const int n_input, const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation,
+    // Output
+    int16_t* gate,
+    // Scratch arrays, both sized n_batch*n_cell
+    int8_t* scratch0, int8_t* scratch1) {
+  // Multiply input * input_weights => scratch0
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
+      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
+      input_times_weights_zp);
+  // Multiply output_state * recurrent_weights => scratch1
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      output_state, output_state_zp, recurrent_to_gate_weight,
+      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
+      n_cell, scratch1, output_state_times_weights_zp);
+  // Add scratch0 + scratch1 => gate
+  micro_tensor_utils::TwoGateSaturatingAdd(
+      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
+      input_times_weights_scale_a, input_times_weights_scale_b,
+      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
+      n_batch, n_cell, gate);
+  // Apply layer normalization.
+  micro_tensor_utils::ApplyLayerNormFloat(
+      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
+      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
+  // Apply activation.
+  switch (activation) {
+    case kTfLiteActSigmoid:
+      micro_tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
+      break;
+    case kTfLiteActTanh:
+      micro_tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
+      break;
+    default:
+      // Only Sigmoid or Tanh is used.
+      TFLITE_ASSERT_FALSE;
+  }
+}
+
+// Calculates the output state tensor of an LSTM step. See Float and hybrid
+// versions as well.
+//
+// Parameters:
+//  - n_batch: batches: the number of distinct vectors in each array.
+//  - n_cell, n_output: sizes of vectors.
+//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
+//  - projection_weights, proj_scale_[a|b], projection_bias:
+//      constant inputs, describing projection matrix and bias.
+//  - output_state_zp: zero point of the output state.
+//  - quantized_proj_clip: if > 0, clip the output of the projection.
+//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
+//  - scratch: scratch area of size n_batch*n_cell
+void CalculateLstmOutputInteger8x8_8(
+    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
+    const int16_t* output_gate, const int8_t* projection_weights,
+    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
+    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
+    int16_t* scratch) {
+  // Note: unlike float/hybrid, the activation is always Tanh.
+  micro_tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
+  micro_tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell,
+                               15 + 15 - 15, scratch);
+  // Note: no bias like in float/hybrid
+  micro_tensor_utils::MatrixBatchVectorMultiply(
+      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
+      n_batch, n_cell, n_output, output_state_zp, output_state);
+  if (quantized_proj_clip > 0) {
+    micro_tensor_utils::CwiseClipping(output_state, n_batch * n_output,
+                                      quantized_proj_clip);
+  }
+}
+
+// Performs an LSTM batch inference step for input specified by input_ptr.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+//  - params: various LSTM params including activation, clipping, etc.,
+//  - n_batch: size of batch,
+//  - n_cell: number of cells (or units),
+//  - n_input: the input size,
+//  - n_aux_input: the auxiliary input size.
+//  - n_output: the output size.
+//  - output_batch_leading_dim: the leading dimension of the output buffer.
+//
+// Input of size 'n_batch * n_input':
+//   input_ptr
+// Input of size 'n_batch * n_aux_input':
+//   aux_input_ptr                     - optional (can be nullptr)
+//
+// LSTM weights:
+// Input weights of size 'n_cell * n_input':
+//   input_to_input_weights            - optional
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_output_weights
+// Auxiliary input weights of size 'n_cell * n_aux_input':
+//   aux_input_to_input_weights        - optional
+//   aux_input_to_forget_weights       - optional
+//   aux_input_to_cell_weights         - optional
+//   aux_input_to_output_weights       - optional
+// Recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights        - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_output_weights
+// Peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights             - optional
+//   cell_to_forget_weights            - optional
+//   cell_to_output_weights            - optional
+// Projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr            - optional
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr               - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   input_layer_norm_coefficients_ptr  - optional
+//   forget_layer_norm_coefficients_ptr - optional
+//   cell_layer_norm_coefficients_ptr   - optional
+//   output_layer_norm_coefficients_ptr - optional
+//
+// The pointers to the cell and output state and the output are updated.
+//
+// The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
+// in batch_major order, and each step processes batch_size many inputs from
+// input_ptr, and updates batch_size many cell and output states.
+//
+// The output_batch_leading_dim is output.shape[-1], i.e. the innermost (last)
+// dimension of the output tensor, and in most cases will be equal to n_output.
+// It is not equal to n_output when we want to store the LSTM output into a
+// slice of the output tensor, e.g.
+// for bidirectional LSTMs with merge_outputs. In this case, the batched
+// operations cannot be used since they assume that the batched outputs are
+// contiguous, and we manually loop over the batched outputs.
+inline void LstmStepFloat(
+    const float* input_ptr, const float* input_to_input_weights_ptr,
+    const float* input_to_forget_weights_ptr,
+    const float* input_to_cell_weights_ptr,
+    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
+    const float* aux_input_to_input_weights_ptr,
+    const float* aux_input_to_forget_weights_ptr,
+    const float* aux_input_to_cell_weights_ptr,
+    const float* aux_input_to_output_weights_ptr,
+    const float* recurrent_to_input_weights_ptr,
+    const float* recurrent_to_forget_weights_ptr,
+    const float* recurrent_to_cell_weights_ptr,
+    const float* recurrent_to_output_weights_ptr,
+    const float* cell_to_input_weights_ptr,
+    const float* cell_to_forget_weights_ptr,
+    const float* cell_to_output_weights_ptr,
+    const float* input_layer_norm_coefficients_ptr,
+    const float* forget_layer_norm_coefficients_ptr,
+    const float* cell_layer_norm_coefficients_ptr,
+    const float* output_layer_norm_coefficients_ptr,
+    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
+    const float* projection_weights_ptr, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_aux_input, int n_output, int output_batch_leading_dim,
+    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+    float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+
+  // Make named scratch buffers.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateFloat(
+        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
+        aux_input_to_input_weights_ptr, output_state_ptr,
+        recurrent_to_input_weights_ptr, cell_state_ptr,
+        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
+        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
+        is_input_all_zeros, is_aux_input_all_zeros);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
+      aux_input_to_forget_weights_ptr, output_state_ptr,
+      recurrent_to_forget_weights_ptr, cell_state_ptr,
+      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
+      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros);
+  // Calculate the cell update gate.
+  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
+                         aux_input_to_cell_weights_ptr, output_state_ptr,
+                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
+                         /*cell_to_gate_weights=*/nullptr,
+                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
+                         n_batch, n_input, n_aux_input, n_output, n_cell,
+                         params->activation, cell_gate_scratch,
+                         is_input_all_zeros, is_aux_input_all_zeros);
+  // Update the cell state.
+  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
+                      forget_gate_scratch, cell_gate_scratch, use_cifg,
+                      params->cell_clip);
+  // Calculate output gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
+      aux_input_to_output_weights_ptr, output_state_ptr,
+      recurrent_to_output_weights_ptr, cell_state_ptr,
+      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
+      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros);
+  // Update the output state.
+  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
+                           output_gate_scratch, params->activation,
+                           projection_weights_ptr, projection_bias_ptr,
+                           params->proj_clip, output_state_ptr, scratch2);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
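+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): LstmStepFloat processes a single time step for all batches; the
+// surrounding eval code (not shown in this file excerpt) calls it once per
+// step, advancing the input and output pointers. A schematic of that loop,
+// assuming time-major layout; `StepFn` is a hypothetical wrapper with the
+// weight arguments already bound, not an API of this file:
+template <typename StepFn>
+static void UnrolledLstmSketch(const float* input, int n_steps, int n_batch,
+                               int n_input, int output_batch_leading_dim,
+                               float* output, StepFn step) {
+  for (int t = 0; t < n_steps; ++t) {
+    const float* step_input = input + t * n_batch * n_input;
+    float* step_output = output + t * n_batch * output_batch_leading_dim;
+    step(step_input, step_output);  // one LstmStepFloat-style update
+  }
+}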
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+//   input_ptr
+// Input of size 'n_batch * n_aux_input':
+//   aux_input_ptr                     - optional (can be nullptr)
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weights            - optional
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_output_weights
+// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
+//   aux_input_to_input_weights        - optional
+//   aux_input_to_forget_weights       - optional
+//   aux_input_to_cell_weights         - optional
+//   aux_input_to_output_weights       - optional
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights        - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights             - optional
+//   cell_to_forget_weights            - optional
+//   cell_to_output_weights            - optional
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr            - optional
+// Weight scales (scalars) for each of the weights above.
+//   input_to_input_weights_scale      - optional
+//   input_to_forget_weights_scale
+//   input_to_cell_weights_scale
+//   input_to_output_weights_scale
+//   aux_input_to_input_weights_scale  - optional
+//   aux_input_to_forget_weights_scale - optional
+//   aux_input_to_cell_weights_scale   - optional
+//   aux_input_to_output_weights_scale - optional
+//   recurrent_to_input_weights_scale  - optional
+//   recurrent_to_forget_weights_scale
+//   recurrent_to_cell_weights_scale
+//   recurrent_to_output_weights_scale
+//   cell_to_input_weights_scale,
+//   cell_to_forget_weights_scale,
+//   cell_to_output_weights_scale,
+//   projection_weights_scale          - optional
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr               - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   input_layer_norm_coefficients_ptr  - optional
+//   forget_layer_norm_coefficients_ptr - optional
+//   cell_layer_norm_coefficients_ptr   - optional
+//   output_layer_norm_coefficients_ptr - optional
+//
+// Temporary pre-allocated storage for quantized values:
+//   quantized_input_ptr (same size as input_ptr)
+//   quantized_output_state_ptr (same size as output_state_ptr)
+//   quantized_output_scratch (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+//   recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * output_batch_leading_dim'
+inline void LstmStepHybrid(
+    const float* input_ptr, const int8_t* input_to_input_weights_ptr,
+    const uint8_t* input_to_input_weights_ledger_ptr,
+    float input_to_input_weights_scale,
+    const int8_t* input_to_forget_weights_ptr,
+    const uint8_t* input_to_forget_weights_ledger_ptr,
+    float input_to_forget_weights_scale,
+    const int8_t* input_to_cell_weights_ptr,
+    const uint8_t* input_to_cell_weights_ledger_ptr,
+    float input_to_cell_weights_scale,
+    const int8_t* input_to_output_weights_ptr,
+    const uint8_t* input_to_output_weights_ledger_ptr,
+    float input_to_output_weights_scale, const float* aux_input_ptr,
+    const int8_t* aux_input_to_input_weights_ptr,
+    float aux_input_to_input_weights_scale,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    float aux_input_to_forget_weights_scale,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    float aux_input_to_cell_weights_scale,
+    const int8_t* aux_input_to_output_weights_ptr,
+    float aux_input_to_output_weights_scale,
+    const int8_t* recurrent_to_input_weights_ptr,
+    const uint8_t* recurrent_to_input_weights_ledger_ptr,
+    float recurrent_to_input_weights_scale,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    const uint8_t* recurrent_to_forget_weights_ledger_ptr,
+    float recurrent_to_forget_weights_scale,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    const uint8_t* recurrent_to_cell_weights_ledger_ptr,
+    float recurrent_to_cell_weights_scale,
+    const int8_t* recurrent_to_output_weights_ptr,
+    const uint8_t* recurrent_to_output_weights_ledger_ptr,
+    float recurrent_to_output_weights_scale,
+    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+    const int8_t* cell_to_forget_weights_ptr,
+    float cell_to_forget_weights_scale,
+    const int8_t* cell_to_output_weights_ptr,
+    float cell_to_output_weights_scale,
+    const float* input_layer_norm_coefficients_ptr,
+    const float* forget_layer_norm_coefficients_ptr,
+    const float* cell_layer_norm_coefficients_ptr,
+    const float* output_layer_norm_coefficients_ptr,
+    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
+    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
+    const int8_t* projection_weights_ptr,
+    const uint8_t* projection_weights_ledger_ptr,
+    float projection_weights_scale, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_aux_input, int n_output, int output_batch_leading_dim,
+    float* scratch0, float* scratch1, float* scratch2, float* scratch3,
+    float* scales, float* input_sf, float* aux_input_sf, float* output_state_sf,
+    float* scaling_factors_scratch, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
+    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+    float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums, bool asymmetric_quantize_inputs) {
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+  // Make named scratch buffers for the different gates.
+  float* input_gate_scratch = scratch0;
+  float* forget_gate_scratch = scratch1;
+  float* cell_gate_scratch = scratch2;
+  float* output_gate_scratch = scratch3;
+
+  int32_t* input_to_input_row_sums = nullptr;
+  int32_t* input_to_forget_row_sums = nullptr;
+  int32_t* input_to_cell_row_sums = nullptr;
+  int32_t* input_to_output_row_sums = nullptr;
+  int32_t* aux_input_to_input_row_sums = nullptr;
+  int32_t* aux_input_to_forget_row_sums = nullptr;
+  int32_t* aux_input_to_cell_row_sums = nullptr;
+  int32_t* aux_input_to_output_row_sums = nullptr;
+  int32_t* recurrent_to_input_row_sums = nullptr;
+  int32_t* recurrent_to_forget_row_sums = nullptr;
+  int32_t* recurrent_to_cell_row_sums = nullptr;
+  int32_t* recurrent_to_output_row_sums = nullptr;
+  int32_t* projection_weights_row_sums = nullptr;
+
+  if (asymmetric_quantize_inputs) {
+    int num_row_sums = use_cifg ? 6 : 8;
+    if (aux_input_ptr != nullptr) {
+      num_row_sums += use_cifg ? 3 : 4;
+    }
+    if (projection_weights_ptr != nullptr) {
+      num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
+    }
+    TFLITE_DCHECK(row_sums_size == num_row_sums);
+    input_to_input_row_sums = row_sums;
+    input_to_forget_row_sums =
+        use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
+    input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
+    input_to_output_row_sums = input_to_cell_row_sums + n_cell;
+    if (aux_input_ptr != nullptr) {
+      aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
+      aux_input_to_forget_row_sums = use_cifg
+                                         ? aux_input_to_input_row_sums
+                                         : aux_input_to_input_row_sums + n_cell;
+      aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
+      aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
+    }
+    recurrent_to_input_row_sums = aux_input_ptr
+                                      ? aux_input_to_output_row_sums + n_cell
+                                      : input_to_output_row_sums + n_cell;
+    recurrent_to_forget_row_sums = use_cifg
+                                       ? recurrent_to_input_row_sums
+                                       : recurrent_to_input_row_sums + n_cell;
+    recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
+    recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
+    if (projection_weights_ptr != nullptr) {
+      projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
+    }
+    if (*compute_row_sums) {
+      ComputeRowSums(
+          input_to_input_row_sums, input_to_forget_row_sums,
+          input_to_cell_row_sums, input_to_output_row_sums,
+          aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
+          aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
+          recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
+          recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
+          projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
+          n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+          input_to_cell_weights_ptr, input_to_output_weights_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+          recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+          projection_weights_ptr, use_cifg, aux_input_ptr);
+      *compute_row_sums = false;
+    }
+  }
+
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      micro_tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       micro_tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
+  const bool is_output_state_all_zeros =
+      micro_tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
+  // Quantize inputs.
+  if (!is_input_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        input_ptr, n_batch, n_input, quantized_input_ptr, input_sf, input_zp,
+        asymmetric_quantize_inputs);
+  }
+  if (!is_aux_input_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        aux_input_ptr, n_batch, n_aux_input, quantized_aux_input_ptr,
+        aux_input_sf, aux_input_zp, asymmetric_quantize_inputs);
+  }
+  if (!is_output_state_all_zeros) {
+    micro_tensor_utils::BatchQuantizeFloats(
+        output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
+        output_state_sf, output_state_zp, asymmetric_quantize_inputs);
+  }
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateHybrid(
+        quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
+        input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
+        input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
+        aux_input_zp, aux_input_to_input_weights_ptr,
+        aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
+        quantized_output_state_ptr, output_state_sf, output_state_zp,
+        recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
+        recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
+        cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
+        input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
+        n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+        input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+        is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+        recovered_cell_weights, scales, accum_scratch_ptr);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
+      input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
+      input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_forget_weights_ptr,
+      aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
+      recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
+      cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+      forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
+      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+      forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+      recovered_cell_weights, scales, accum_scratch_ptr);
+  // Calculate the cell update gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
+      input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
+      input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_cell_weights_ptr,
+      aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
+      recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
+      /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
+      /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
+      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      params->activation, cell_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
+      scaling_factors_scratch, recovered_cell_weights, scales,
+      accum_scratch_ptr);
+  // Update the cell state.
+  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
+                      forget_gate_scratch, cell_gate_scratch, use_cifg,
+                      params->cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateHybrid(
+      quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
+      input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
+      input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
+      aux_input_zp, aux_input_to_output_weights_ptr,
+      aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
+      quantized_output_state_ptr, output_state_sf, output_state_zp,
+      recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
+      recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
+      cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
+      output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
+      n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
+      output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
+      is_output_state_all_zeros, compute_row_sums, scaling_factors_scratch,
+      recovered_cell_weights, scales, accum_scratch_ptr);
+  // Update the output state.
+  CalculateLstmOutputHybrid(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
+      projection_weights_scale, projection_bias_ptr, params->proj_clip,
+      output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
+      compute_row_sums, scratch2, quantized_output_scratch, input_sf, input_zp,
+      accum_scratch_ptr, scales);
+  // Copy output state to the output. Note that the output's rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::memcpy(output_ptr + b * output_batch_leading_dim,
+                output_state_ptr + b * n_output, n_output * sizeof(float));
+  }
+}
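+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): before each gate computation, LstmStepHybrid quantizes the float
+// activations batch by batch (BatchQuantizeFloats above). The symmetric case
+// conceptually reduces to the sketch below; the asymmetric path additionally
+// computes a zero point. Names are hypothetical:
+static void QuantizeVectorSketch(const float* values, int size,
+                                 int8_t* quantized, float* scaling_factor) {
+  float max_abs = 0.0f;
+  for (int i = 0; i < size; ++i) {
+    const float a = values[i] < 0.0f ? -values[i] : values[i];
+    if (a > max_abs) max_abs = a;
+  }
+  const float scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;
+  *scaling_factor = scale;
+  for (int i = 0; i < size; ++i) {
+    float q = values[i] / scale;
+    q = q > 127.0f ? 127.0f : (q < -128.0f ? -128.0f : q);
+    quantized[i] = static_cast<int8_t>(q >= 0.0f ? q + 0.5f : q - 0.5f);
+  }
+}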
+
+// Fully quantized lstm kernel for 16 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr            - optional
+//   input_to_forget_weight_ptr           - optional
+//   input_to_cell_weight_ptr             - optional
+//   input_to_output_weight_ptr           - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr        - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_output_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights               - optional
+//   cell_to_forget_weights              - optional
+//   cell_to_output_weights              - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr                     - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a    - optional
+//   effective_input_to_input_scale_b    - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a    - optional
+//   effective_recurrent_to_input_scale_b    - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                  - optional
+//   effective_proj_scale_b                  - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr                 - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr    - optional
+//   layer_norm_forget_weight_ptr   - optional
+//   layer_norm_cell_weight_ptr     - optional
+//   layer_norm_output_weight_ptr   - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a     - optional
+//   layer_norm_input_scale_b     - optional
+//   layer_norm_forget_scale_a    - optional
+//   layer_norm_forget_scale_b    - optional
+//   layer_norm_cell_scale_a      - optional
+//   layer_norm_cell_scale_b      - optional
+//   layer_norm_output_scale_a    - optional
+//   layer_norm_output_scale_b    - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   output_state_zp: zero point of output state
+//   hidden_zp: zero point for hidden state.
+//
+// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
+// n_batch.
+//   scratch0
+//   scratch1
+//   scratch2
+//   scratch3
+//   scratch4
+//   scratch5: this scratch buffer is created purely for optimizing the
+//              MatrixBatchVectorMultiplyAccumulate.
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * n_output'
+// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
+inline void LstmStepInteger8x8_16(
+    const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
+    int32_t effective_input_to_input_scale_a,
+    int32_t effective_input_to_input_scale_b,
+    const int8_t* input_to_forget_weight_ptr,
+    int32_t effective_input_to_forget_scale_a,
+    int32_t effective_input_to_forget_scale_b,
+    const int8_t* input_to_cell_weight_ptr,
+    int32_t effective_input_to_cell_scale_a,
+    int32_t effective_input_to_cell_scale_b,
+    const int8_t* input_to_output_weight_ptr,
+    int32_t effective_input_to_output_scale_a,
+    int32_t effective_input_to_output_scale_b,
+    const int8_t* recurrent_to_input_weight_ptr,
+    int32_t effective_recurrent_to_input_scale_a,
+    int32_t effective_recurrent_to_input_scale_b,
+    const int8_t* recurrent_to_forget_weight_ptr,
+    int32_t effective_recurrent_to_forget_scale_a,
+    int32_t effective_recurrent_to_forget_scale_b,
+    const int8_t* recurrent_to_cell_weight_ptr,
+    int32_t effective_recurrent_to_cell_scale_a,
+    int32_t effective_recurrent_to_cell_scale_b,
+    const int8_t* recurrent_to_output_weight_ptr,
+    int32_t effective_recurrent_to_output_scale_a,
+    int32_t effective_recurrent_to_output_scale_b,
+    const int16_t* cell_to_input_weight_ptr,
+    int32_t effective_cell_to_input_scale_a,
+    int32_t effective_cell_to_input_scale_b,
+    const int16_t* cell_to_forget_weight_ptr,
+    int32_t effective_cell_to_forget_scale_a,
+    int32_t effective_cell_to_forget_scale_b,
+    const int16_t* cell_to_output_weight_ptr,
+    int32_t effective_cell_to_output_scale_a,
+    int32_t effective_cell_to_output_scale_b,
+    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
+    int32_t effective_proj_scale_b, int32_t hidden_zp,
+    int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
+    const int16_t* layer_norm_input_weight_ptr,
+    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
+    const int16_t* layer_norm_forget_weight_ptr,
+    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
+    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
+    int32_t layer_norm_cell_scale_b,
+    const int16_t* layer_norm_output_weight_ptr,
+    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
+    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
+    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
+    int32_t cell_state_scale, int32_t input_variance_guard,
+    int32_t forget_variance_guard, int32_t cell_variance_guard,
+    int32_t output_variance_guard,
+    const int32_t* input_to_forget_effective_bias,
+    const int32_t* recurrent_to_forget_effective_bias,
+    const int32_t* input_to_cell_effective_bias,
+    const int32_t* recurrent_to_cell_effective_bias,
+    const int32_t* input_to_output_effective_bias,
+    const int32_t* recurrent_to_output_effective_bias,
+    const int32_t* input_to_input_effective_bias,
+    const int32_t* recurrent_to_input_effective_bias,
+    const int32_t* projection_effective_bias, int n_batch, int n_cell,
+    int n_input, int n_output, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int8_t* scratch4, int32_t* scratch5) {
+  // Make named scratch buffers for the different gates.
+  int16_t* input_gate_scratch = scratch0;
+  int16_t* forget_gate_scratch = scratch1;
+  int16_t* cell_gate_scratch = scratch2;
+  int16_t* output_gate_scratch = scratch3;
+
+  // Since we have already checked that weights are all there or none, we
+  // can check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weight_ptr == nullptr);
+
+  // Check for nullptrs.
+  TFLITE_DCHECK(input_to_forget_effective_bias);
+  TFLITE_DCHECK(recurrent_to_forget_effective_bias);
+  TFLITE_DCHECK(input_to_cell_effective_bias);
+  TFLITE_DCHECK(recurrent_to_cell_effective_bias);
+  TFLITE_DCHECK(input_to_output_effective_bias);
+  TFLITE_DCHECK(recurrent_to_output_effective_bias);
+  if (!use_cifg) {
+    TFLITE_DCHECK(input_to_input_effective_bias);
+    TFLITE_DCHECK(recurrent_to_input_effective_bias);
+  }
+  const bool use_projection = (projection_weight_ptr != nullptr);
+  if (use_projection) {
+    TFLITE_DCHECK(projection_effective_bias);
+  }
+  if (!use_cifg) {
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateInteger8x8_16(
+        input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
+        effective_input_to_input_scale_a, effective_input_to_input_scale_b,
+        output_state_ptr, recurrent_to_input_weight_ptr,
+        recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
+        effective_recurrent_to_input_scale_b, cell_state_ptr,
+        cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
+        effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
+        input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
+        input_variance_guard, n_batch, n_input, n_output, n_cell,
+        kTfLiteActSigmoid, input_gate_scratch, scratch5);
+  }
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      output_state_ptr, recurrent_to_forget_weight_ptr,
+      recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, cell_state_ptr,
+      cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
+      effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
+      forget_gate_bias_ptr, layer_norm_forget_scale_a,
+      layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      output_state_ptr, recurrent_to_cell_weight_ptr,
+      recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
+      effective_recurrent_to_cell_scale_b, cell_state_ptr,
+      /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
+      /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
+      cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
+      cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
+      cell_gate_scratch, scratch5);
+  // Update the cell state.
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
+                        input_gate_scratch, forget_gate_scratch,
+                        cell_gate_scratch, use_cifg, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_16(
+      input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      output_state_ptr, recurrent_to_output_weight_ptr,
+      recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, cell_state_ptr,
+      cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
+      effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
+      output_gate_bias_ptr, layer_norm_output_scale_a,
+      layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
+      n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
+  // Update the output state.
+  CalculateLstmOutputInteger8x8_16(
+      n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
+      output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
+      hidden_zp, projection_weight_ptr, effective_proj_scale_a,
+      effective_proj_scale_b, projection_effective_bias, output_state_zp,
+      quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
+  // Copy output state to the output. Note that unlike float or hybrid, output
+  // is always contiguous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
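+
+// --- Editor's note (illustrative sketch, not part of the upstream TFLite
+// sources): the float and hybrid steps copy the output row by row because the
+// rows may be strided by output_batch_leading_dim, whereas the integer step
+// can copy everything at once. A small sketch of that distinction, with
+// hypothetical names (assumes <cstring> for std::memcpy, which this file
+// already uses):
+static void CopyOutputSketch(const int8_t* output_state, int n_batch,
+                             int n_output, int leading_dim, int8_t* output) {
+  if (leading_dim == n_output) {
+    std::memcpy(output, output_state, n_batch * n_output * sizeof(int8_t));
+  } else {
+    for (int b = 0; b < n_batch; ++b) {
+      std::memcpy(output + b * leading_dim, output_state + b * n_output,
+                  n_output * sizeof(int8_t));
+    }
+  }
+}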
+
+// Fully quantized lstm kernel for 8 bit gate matmul output.
+//
+// Input tensor of size n_batch * n_input:
+//   input_ptr
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weight_ptr            - optional
+//   input_to_forget_weight_ptr           - optional
+//   input_to_cell_weight_ptr             - optional
+//   input_to_output_weight_ptr           - optional
+//
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weight_ptr        - optional
+//   recurrent_to_forget_weights_ptr
+//   recurrent_to_cell_weights_ptr
+//   recurrent_to_output_weights_ptr
+//
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights               - optional
+//   cell_to_forget_weights              - optional
+//   cell_to_output_weights              - optional
+//
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weight_ptr                     - optional
+//
+// Weight scales (scalars) for each of the weights above.
+//   effective_input_to_input_scale_a    - optional
+//   effective_input_to_input_scale_b    - optional
+//   effective_input_to_forget_scale_a
+//   effective_input_to_forget_scale_b
+//   effective_input_to_cell_scale_a
+//   effective_input_to_cell_scale_b
+//   effective_input_to_output_scale_a
+//   effective_input_to_output_scale_b
+//   effective_recurrent_to_input_scale_a    - optional
+//   effective_recurrent_to_input_scale_b    - optional
+//   effective_recurrent_to_forget_scale_a
+//   effective_recurrent_to_forget_scale_b
+//   effective_recurrent_to_cell_scale_a
+//   effective_recurrent_to_cell_scale_b
+//   effective_recurrent_to_output_scale_a
+//   effective_recurrent_to_output_scale_b
+//   effective_proj_scale_a                  - optional
+//   effective_proj_scale_b                  - optional
+//
+// Gate biases of size 'n_cell':
+//   input_gate_bias_ptr                 - optional
+//   forget_gate_bias_ptr
+//   cell_gate_bias_ptr
+//   output_gate_bias_ptr
+//
+// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
+//   layer_norm_input_weight_ptr    - optional
+//   layer_norm_forget_weight_ptr   - optional
+//   layer_norm_cell_weight_ptr     - optional
+//   layer_norm_output_weight_ptr   - optional
+//
+// Layer norm scales of size 'n_cell'.
+//   layer_norm_input_scale_a     - optional
+//   layer_norm_input_scale_b     - optional
+//   layer_norm_forget_scale_a    - optional
+//   layer_norm_forget_scale_b    - optional
+//   layer_norm_cell_scale_a      - optional
+//   layer_norm_cell_scale_b      - optional
+//   layer_norm_output_scale_a    - optional
+//   layer_norm_output_scale_b    - optional
+//
+// Scalar values:
+//   quantized_cell_clip: quantized clip value for cell.
+//   quantized_proj_clip: quantized clip value for projection.
+//   cell_state_scale: the power of two scale for cell state.
+//
+// Zero points:
+//   input_zp: zero point for input tensor.
+//   output_state_zp: zero point of output state.
+//   hidden_zp: zero point for hidden state.
+//
+// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
+// n_batch.
+//   scratch0
+//   scratch1
+//   scratch2
+//   scratch3
+//   scratch4
+//   scratch5
+//   scratch6
+//   scratch7
+//
+// Outputs:
+//   output_state_ptr - size 'n_batch * n_output'
+//   cell_state_ptr   - size 'n_batch * n_cell'
+//   output_ptr       - size 'n_batch * n_output'
+//
+// The zero point calculation can be moved into Prepare() for better performance.
+// TODO(b/159947023): scratch5 is unused, remove.
+inline void LstmStepInteger8x8_8(
+    const int8_t* input_ptr, int32_t input_zp,
+    const int8_t* input_to_input_weight_ptr,
+    int32_t effective_input_to_input_scale_a,
+    int32_t effective_input_to_input_scale_b,
+    const int8_t* input_to_forget_weight_ptr,
+    int32_t effective_input_to_forget_scale_a,
+    int32_t effective_input_to_forget_scale_b,
+    const int8_t* input_to_cell_weight_ptr,
+    int32_t effective_input_to_cell_scale_a,
+    int32_t effective_input_to_cell_scale_b,
+    const int8_t* input_to_output_weight_ptr,
+    int32_t effective_input_to_output_scale_a,
+    int32_t effective_input_to_output_scale_b,
+    const int8_t* recurrent_to_input_weight_ptr,
+    int32_t effective_recurrent_to_input_scale_a,
+    int32_t effective_recurrent_to_input_scale_b,
+    const int8_t* recurrent_to_forget_weight_ptr,
+    int32_t effective_recurrent_to_forget_scale_a,
+    int32_t effective_recurrent_to_forget_scale_b,
+    const int8_t* recurrent_to_cell_weight_ptr,
+    int32_t effective_recurrent_to_cell_scale_a,
+    int32_t effective_recurrent_to_cell_scale_b,
+    const int8_t* recurrent_to_output_weight_ptr,
+    int32_t effective_recurrent_to_output_scale_a,
+    int32_t effective_recurrent_to_output_scale_b,
+    const int8_t* cell_to_input_weight_ptr,
+    int32_t effective_cell_to_input_scale_a,
+    int32_t effective_cell_to_input_scale_b,
+    const int8_t* cell_to_forget_weight_ptr,
+    int32_t effective_cell_to_forget_scale_a,
+    int32_t effective_cell_to_forget_scale_b,
+    const int8_t* cell_to_output_weight_ptr,
+    int32_t effective_cell_to_output_scale_a,
+    int32_t effective_cell_to_output_scale_b,
+    const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
+    int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
+    int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
+    const int16_t* layer_norm_forget_weight_ptr,
+    int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
+    const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
+    int32_t layer_norm_cell_scale_b,
+    const int16_t* layer_norm_output_weight_ptr,
+    int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
+    const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
+    const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
+    const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
+    const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
+    const int32_t* intermediate_zp, int16_t quantized_cell_clip,
+    int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
+    int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
+    int16_t* scratch7) {
+  // TODO(b/159066113): scratch5 is unused, remove.
+
+  // Make named scratch buffers for the different gates.
+  int16_t* forget_gate_scratch = scratch2;
+  int16_t* cell_gate_scratch = scratch3;
+  int16_t* output_gate_scratch = scratch4;
+  // Only the CIFG variant is supported here: the input gate is coupled to the
+  // forget gate (see use_cifg=true below), so no input-gate scratch is needed.
+
+  // Calculate the forget gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_forget_weight_ptr,
+      effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
+      intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
+      output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
+      effective_recurrent_to_forget_scale_a,
+      effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
+      intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
+      layer_norm_forget_scale_a, layer_norm_forget_scale_b,
+      forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
+  // Calculate the cell update gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_cell_weight_ptr,
+      effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
+      intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
+      output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
+      effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
+      intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
+      layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+      layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
+      n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
+  // Update the cell state.
+  UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
+                        /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
+                        forget_gate_scratch, cell_gate_scratch,
+                        /*use_cifg=*/true, quantized_cell_clip);
+  // Calculate the output gate.
+  CalculateLstmGateInteger8x8_8(
+      input_ptr, input_zp, input_to_output_weight_ptr,
+      effective_input_to_output_scale_a, effective_input_to_output_scale_b,
+      intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
+      output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
+      effective_recurrent_to_output_scale_a,
+      effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
+      intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
+      layer_norm_output_scale_a, layer_norm_output_scale_b,
+      output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
+      kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
+  // Update the output state.
+  CalculateLstmOutputInteger8x8_8(
+      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
+      projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
+      projection_bias_ptr, output_state_zp, quantized_proj_clip,
+      output_state_ptr, scratch2);
+  // Copy output state to the output. Note that unlike the float or hybrid
+  // kernels, the output here is always contiguous.
+  std::memcpy(output_ptr, output_state_ptr,
+              n_batch * n_output * sizeof(int8_t));
+}
+
+}  // namespace
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  int max_time, n_batch;
+  if (input->dims->size == 3) {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  } else {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  }
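+  // In either case the innermost dimension is n_input: sequence inputs are
+  // laid out as {max_time, n_batch, n_input} when time_major and as
+  // {n_batch, max_time, n_input} otherwise, while a rank-2 {n_batch, n_input}
+  // input is treated as a single time step.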
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  // Index the scratch buffer pointers into the global scratch buffer.
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
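+  // The caller-provided scratch_buffer is thus partitioned into three (CIFG)
+  // or four (non-CIFG) contiguous per-gate regions of n_cell * n_batch floats.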
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If forward_sequence is set, step forward; otherwise step backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+
+      LstmStepFloat(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_input_weights),
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_to_output_weights),
+          aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_input_weights),
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_forget_weights),
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(aux_input_to_cell_weights),
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    aux_input_to_output_weights),
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_input_weights),
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_forget_weights),
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(recurrent_to_cell_weights),
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    recurrent_to_output_weights),
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_weights),
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state), input_gate_scratch,
+          forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+          output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If forward_sequence is set, step forward; otherwise step backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
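+        // Batch-major layout: element (b, t_rel) of an {n_batch, max_time, d}
+        // tensor starts at flat offset (b * max_time + t_rel) * d.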
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepFloat(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_input_weights),
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_forget_weights),
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_cell_weights),
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_to_output_weights),
+            aux_input_ptr,
+            aux_input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_input_weights),
+            aux_input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_forget_weights),
+            aux_input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_cell_weights),
+            aux_input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      aux_input_to_output_weights),
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_input_weights),
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_forget_weights),
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_cell_weights),
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      recurrent_to_output_weights),
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_input_weights),
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_forget_weights),
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_to_output_weights),
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      input_layer_norm_coefficients),
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      forget_layer_norm_coefficients),
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      cell_layer_norm_coefficients),
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      output_layer_norm_coefficients),
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(output_gate_bias),
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_weights),
+            projection_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_bias),
+            params,
+            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
+            output_batch_leading_dim, output_state_ptr, cell_state_ptr,
+            input_gate_scratch_ptr, forget_gate_scratch_ptr,
+            cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  }
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_gate_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_gate_scratch = scratch_buffer;
+    forget_gate_scratch = scratch_buffer + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer;
+    cell_gate_scratch = scratch_buffer + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer + 3 * n_cell * n_batch;
+  }
+
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+
+  int32_t* input_zp_ptr = nullptr;
+  int32_t* aux_input_zp_ptr = nullptr;
+  int32_t* output_state_zp_ptr = nullptr;
+  int32_t* row_sums_ptr = nullptr;
+  if (params->asymmetric_quantize_inputs) {
+    input_zp_ptr = input_zp;
+    aux_input_zp_ptr = aux_input_zp;
+    output_state_zp_ptr = output_state_zp;
+    row_sums_ptr = row_sums;
+  }
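+  // These remain nullptr when asymmetric input quantization is disabled;
+  // LstmStepHybrid is additionally passed params->asymmetric_quantize_inputs
+  // as its final argument.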
+
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If forward_sequence is set, step forward; otherwise step backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr =
+          tflite::micro::GetTensorData<float>(input) + t_rel * input_step;
+      const float* aux_input_ptr = nullptr;
+      if (aux_input) {
+        aux_input_ptr =
+            tflite::micro::GetTensorData<float>(aux_input) + t_rel * input_step;
+      }
+      float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                          t_rel * output_step + output_offset;
+      LstmStepHybrid(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+          input_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_input_weights_ledger),
+          hybrid_lstm_scales->input_to_input_weights_scale,
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+          input_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_forget_weights_ledger),
+          hybrid_lstm_scales->input_to_forget_weights_scale,
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+          input_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_cell_weights_ledger),
+          hybrid_lstm_scales->input_to_cell_weights_scale,
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+          input_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    input_to_output_weights_ledger),
+          hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
+          aux_input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_input_weights),
+          hybrid_lstm_scales->aux_input_to_input_weights_scale,
+          aux_input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_forget_weights),
+          hybrid_lstm_scales->aux_input_to_forget_weights_scale,
+          aux_input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(aux_input_to_cell_weights),
+          hybrid_lstm_scales->aux_input_to_cell_weights_scale,
+          aux_input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    aux_input_to_output_weights),
+          hybrid_lstm_scales->aux_input_to_output_weights_scale,
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_input_weights),
+          recurrent_to_input_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_input_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_input_weights_scale,
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_forget_weights),
+          recurrent_to_forget_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_forget_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_forget_weights_scale,
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+          recurrent_to_cell_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_cell_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_cell_weights_scale,
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_output_weights),
+          recurrent_to_output_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    recurrent_to_output_weights_ledger),
+          hybrid_lstm_scales->recurrent_to_output_weights_scale,
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+          hybrid_lstm_scales->cell_to_input_weights_scale,
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+          hybrid_lstm_scales->cell_to_forget_weights_scale,
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+          hybrid_lstm_scales->cell_to_output_weights_scale,
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    input_layer_norm_coefficients),
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    forget_layer_norm_coefficients),
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    cell_layer_norm_coefficients),
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(
+                    output_layer_norm_coefficients),
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(output_gate_bias),
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(projection_weights),
+          projection_weights_ledger == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<uint8_t>(
+                    projection_weights_ledger),
+          hybrid_lstm_scales->projection_weights_scale,
+          projection_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<float>(projection_bias),
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_gate_scratch, output_gate_scratch, scales, input_sf,
+          aux_input_sf, output_state_sf, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          output_state_quantized, cell_state_quantized,
+          tflite::micro::GetTensorData<float>(output_state),
+          tflite::micro::GetTensorData<float>(cell_state),
+          output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
+          output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If forward_sequence is set, step forward; otherwise step backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = tflite::micro::GetTensorData<float>(input) +
+                                 time_offset * input_step;
+        const float* aux_input_ptr = nullptr;
+        if (aux_input) {
+          aux_input_ptr = tflite::micro::GetTensorData<float>(aux_input) +
+                          time_offset * input_step;
+        }
+        float* output_ptr = tflite::micro::GetTensorData<float>(output) +
+                            time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        float* output_state_ptr =
+            tflite::micro::GetTensorData<float>(output_state) +
+            b * output_batch_leading_dim;
+        float* cell_state_ptr =
+            tflite::micro::GetTensorData<float>(cell_state) + b * n_cell;
+        // Offset the scratch pointers to the right batch.
+        float* input_gate_scratch_ptr =
+            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
+        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
+        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
+        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
+
+        LstmStepHybrid(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+            input_to_input_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_input_weights_ledger),
+            hybrid_lstm_scales->input_to_input_weights_scale,
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+            input_to_forget_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_forget_weights_ledger),
+            hybrid_lstm_scales->input_to_forget_weights_scale,
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+            input_to_cell_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_cell_weights_ledger),
+            hybrid_lstm_scales->input_to_cell_weights_scale,
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+            input_to_output_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      input_to_output_weights_ledger),
+            hybrid_lstm_scales->input_to_output_weights_scale, aux_input_ptr,
+            aux_input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_input_weights),
+            hybrid_lstm_scales->aux_input_to_input_weights_scale,
+            aux_input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_forget_weights),
+            hybrid_lstm_scales->aux_input_to_forget_weights_scale,
+            aux_input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_cell_weights),
+            hybrid_lstm_scales->aux_input_to_cell_weights_scale,
+            aux_input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      aux_input_to_output_weights),
+            hybrid_lstm_scales->aux_input_to_output_weights_scale,
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_input_weights),
+            recurrent_to_input_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_input_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_input_weights_scale,
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_forget_weights),
+            recurrent_to_forget_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_forget_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_forget_weights_scale,
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_cell_weights),
+            recurrent_to_cell_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_cell_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_cell_weights_scale,
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_output_weights),
+            recurrent_to_output_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      recurrent_to_output_weights_ledger),
+            hybrid_lstm_scales->recurrent_to_output_weights_scale,
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+            hybrid_lstm_scales->cell_to_input_weights_scale,
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+            hybrid_lstm_scales->cell_to_forget_weights_scale,
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+            hybrid_lstm_scales->cell_to_output_weights_scale,
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      input_layer_norm_coefficients),
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      forget_layer_norm_coefficients),
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      cell_layer_norm_coefficients),
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(
+                      output_layer_norm_coefficients),
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(output_gate_bias),
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(projection_weights),
+            projection_weights_ledger == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<uint8_t>(
+                      projection_weights_ledger),
+            hybrid_lstm_scales->projection_weights_scale,
+            projection_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<float>(projection_bias),
+            params,
+            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
+            output_batch_leading_dim, input_gate_scratch_ptr,
+            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+            output_gate_scratch_ptr, scales, input_sf, aux_input_sf,
+            output_state_sf, prod_scaling_factors, recovered_cell_weights,
+            input_quantized, aux_input_quantized, output_state_quantized,
+            cell_state_quantized, output_state_ptr, cell_state_ptr,
+            output_scratch_buffer, output_ptr, input_zp_ptr, aux_input_zp_ptr,
+            output_state_zp_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+            params->asymmetric_quantize_inputs);
+      }
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalInteger8x8_16Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major,
+    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
+    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
+    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
+  }
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Get params for time/batch/sequence.
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+
+  if (time_major) {
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      const int t_rel = t;
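+      // Note that t_rel is not reversed here: unlike the float and hybrid
+      // paths, this time-major integer path always steps forward regardless
+      // of forward_sequence.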
+      int8_t* output_ptr =
+          tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
+      const int8_t* input_ptr =
+          tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
+      LstmStepInteger8x8_16(
+          input_ptr,
+          input_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+          integer_lstm_param->effective_input_to_input_scale_a,
+          integer_lstm_param->effective_input_to_input_scale_b,
+          input_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+          integer_lstm_param->effective_input_to_forget_scale_a,
+          integer_lstm_param->effective_input_to_forget_scale_b,
+          input_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+          integer_lstm_param->effective_input_to_cell_scale_a,
+          integer_lstm_param->effective_input_to_cell_scale_b,
+          input_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+          integer_lstm_param->effective_input_to_output_scale_a,
+          integer_lstm_param->effective_input_to_output_scale_b,
+          recurrent_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_input_weights),
+          integer_lstm_param->effective_recurrent_to_input_scale_a,
+          integer_lstm_param->effective_recurrent_to_input_scale_b,
+          recurrent_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_forget_weights),
+          integer_lstm_param->effective_recurrent_to_forget_scale_a,
+          integer_lstm_param->effective_recurrent_to_forget_scale_b,
+          recurrent_to_cell_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+          integer_lstm_param->effective_recurrent_to_cell_scale_a,
+          integer_lstm_param->effective_recurrent_to_cell_scale_b,
+          recurrent_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(
+                    recurrent_to_output_weights),
+          integer_lstm_param->effective_recurrent_to_output_scale_a,
+          integer_lstm_param->effective_recurrent_to_output_scale_b,
+          cell_to_input_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
+          integer_lstm_param->effective_cell_to_input_scale_a,
+          integer_lstm_param->effective_cell_to_input_scale_b,
+          cell_to_forget_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
+          integer_lstm_param->effective_cell_to_forget_scale_a,
+          integer_lstm_param->effective_cell_to_forget_scale_b,
+          cell_to_output_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
+          integer_lstm_param->effective_cell_to_output_scale_a,
+          integer_lstm_param->effective_cell_to_output_scale_b,
+          projection_weights == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int8_t>(projection_weights),
+          integer_lstm_param->effective_proj_scale_a,
+          integer_lstm_param->effective_proj_scale_b,
+          integer_lstm_param->hidden_zp,
+          integer_lstm_param->effective_hidden_scale_a,
+          integer_lstm_param->effective_hidden_scale_b,
+          input_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    input_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_input_scale_a,
+          integer_lstm_param->layer_norm_input_scale_b,
+          forget_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    forget_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_forget_scale_a,
+          integer_lstm_param->layer_norm_forget_scale_b,
+          cell_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    cell_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_cell_scale_a,
+          integer_lstm_param->layer_norm_cell_scale_b,
+          output_layer_norm_coefficients == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int16_t>(
+                    output_layer_norm_coefficients),
+          integer_lstm_param->layer_norm_output_scale_a,
+          integer_lstm_param->layer_norm_output_scale_b,
+          input_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+          forget_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+          cell_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+          output_gate_bias == nullptr
+              ? nullptr
+              : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+          integer_lstm_param->quantized_cell_clip,
+          integer_lstm_param->quantized_proj_clip,
+          integer_lstm_param->cell_scale,
+          integer_lstm_param->input_variance_guard,
+          integer_lstm_param->forget_variance_guard,
+          integer_lstm_param->cell_variance_guard,
+          integer_lstm_param->output_variance_guard,
+          integer_lstm_param->input_to_forget_effective_bias,
+          integer_lstm_param->recurrent_to_forget_effective_bias,
+          integer_lstm_param->input_to_cell_effective_bias,
+          integer_lstm_param->recurrent_to_cell_effective_bias,
+          integer_lstm_param->input_to_output_effective_bias,
+          integer_lstm_param->recurrent_to_output_effective_bias,
+          integer_lstm_param->input_to_input_effective_bias,
+          integer_lstm_param->recurrent_to_input_effective_bias,
+          integer_lstm_param->projection_effective_bias, n_batch, n_cell,
+          n_input, n_output, tflite::micro::GetTensorData<int8_t>(output_state),
+          output_state_zp, tflite::micro::GetTensorData<int16_t>(cell_state),
+          output_ptr, scratch0, scratch1, scratch2, scratch3, scratch4,
+          scratch5);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If forward_sequence is set, step forward; otherwise step backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const int time_offset = b * max_time + t_rel;
+        const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input) +
+                                  time_offset * input_step;
+        int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output) +
+                             time_offset * output_step;
+
+        // Offset the {output,cell}_state pointers to the right batch.
+        int8_t* output_state_ptr =
+            tflite::micro::GetTensorData<int8_t>(output_state) +
+            b * output_batch_leading_dim;
+        int16_t* cell_state_ptr =
+            tflite::micro::GetTensorData<int16_t>(cell_state) + b * n_cell;
+
+        LstmStepInteger8x8_16(
+            input_ptr,
+            input_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+            integer_lstm_param->effective_input_to_input_scale_a,
+            integer_lstm_param->effective_input_to_input_scale_b,
+            input_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+            integer_lstm_param->effective_input_to_forget_scale_a,
+            integer_lstm_param->effective_input_to_forget_scale_b,
+            input_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+            integer_lstm_param->effective_input_to_cell_scale_a,
+            integer_lstm_param->effective_input_to_cell_scale_b,
+            input_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+            integer_lstm_param->effective_input_to_output_scale_a,
+            integer_lstm_param->effective_input_to_output_scale_b,
+            recurrent_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_input_weights),
+            integer_lstm_param->effective_recurrent_to_input_scale_a,
+            integer_lstm_param->effective_recurrent_to_input_scale_b,
+            recurrent_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_forget_weights),
+            integer_lstm_param->effective_recurrent_to_forget_scale_a,
+            integer_lstm_param->effective_recurrent_to_forget_scale_b,
+            recurrent_to_cell_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_cell_weights),
+            integer_lstm_param->effective_recurrent_to_cell_scale_a,
+            integer_lstm_param->effective_recurrent_to_cell_scale_b,
+            recurrent_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(
+                      recurrent_to_output_weights),
+            integer_lstm_param->effective_recurrent_to_output_scale_a,
+            integer_lstm_param->effective_recurrent_to_output_scale_b,
+            cell_to_input_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
+            integer_lstm_param->effective_cell_to_input_scale_a,
+            integer_lstm_param->effective_cell_to_input_scale_b,
+            cell_to_forget_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
+            integer_lstm_param->effective_cell_to_forget_scale_a,
+            integer_lstm_param->effective_cell_to_forget_scale_b,
+            cell_to_output_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
+            integer_lstm_param->effective_cell_to_output_scale_a,
+            integer_lstm_param->effective_cell_to_output_scale_b,
+            projection_weights == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int8_t>(projection_weights),
+            integer_lstm_param->effective_proj_scale_a,
+            integer_lstm_param->effective_proj_scale_b,
+            integer_lstm_param->hidden_zp,
+            integer_lstm_param->effective_hidden_scale_a,
+            integer_lstm_param->effective_hidden_scale_b,
+            input_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      input_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_input_scale_a,
+            integer_lstm_param->layer_norm_input_scale_b,
+            forget_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      forget_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_forget_scale_a,
+            integer_lstm_param->layer_norm_forget_scale_b,
+            cell_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      cell_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_cell_scale_a,
+            integer_lstm_param->layer_norm_cell_scale_b,
+            output_layer_norm_coefficients == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int16_t>(
+                      output_layer_norm_coefficients),
+            integer_lstm_param->layer_norm_output_scale_a,
+            integer_lstm_param->layer_norm_output_scale_b,
+            input_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+            forget_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+            cell_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+            output_gate_bias == nullptr
+                ? nullptr
+                : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+            integer_lstm_param->quantized_cell_clip,
+            integer_lstm_param->quantized_proj_clip,
+            integer_lstm_param->cell_scale,
+            integer_lstm_param->input_variance_guard,
+            integer_lstm_param->forget_variance_guard,
+            integer_lstm_param->cell_variance_guard,
+            integer_lstm_param->output_variance_guard,
+            integer_lstm_param->input_to_forget_effective_bias,
+            integer_lstm_param->recurrent_to_forget_effective_bias,
+            integer_lstm_param->input_to_cell_effective_bias,
+            integer_lstm_param->recurrent_to_cell_effective_bias,
+            integer_lstm_param->input_to_output_effective_bias,
+            integer_lstm_param->recurrent_to_output_effective_bias,
+            integer_lstm_param->input_to_input_effective_bias,
+            integer_lstm_param->recurrent_to_input_effective_bias,
+            integer_lstm_param->projection_effective_bias, /*n_batch=*/1,
+            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
+            cell_state_ptr, output_ptr, scratch0, scratch1, scratch2, scratch3,
+            scratch4, scratch5);
+      }
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalInteger8x8_8Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
+    int32_t input_zp, int32_t output_state_zp, int8_t* scratch0,
+    int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, int16_t* scratch4,
+    int16_t* scratch5, int16_t* scratch6, int16_t* scratch7) {
+  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
+  const int n_input = input->dims->data[input->dims->size - 1];
+  int max_time, n_batch;
+  if (input->dims->size == 2) {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  } else {
+    max_time = input->dims->data[0];
+    n_batch = input->dims->data[1];
+  }
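+  // For example, a [4, n_input] input is treated as a single timestep with
+  // n_batch = 4, while a [2, 4, n_input] input gives max_time = 2 and
+  // n_batch = 4.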
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Get params for time/batch/sequence.
+  const int output_batch_leading_dim =
+      output->dims->data[output->dims->size - 1];
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output_batch_leading_dim;
+
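+  // A rough illustration: with a time-major input of shape
+  // [max_time, n_batch, n_input], the slice for timestep t starts at offset
+  // t * input_step = t * n_batch * n_input, and the corresponding output
+  // slice starts at t * output_step.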
+  for (int t = 0; t < max_time; t++) {
+    const int t_rel = t;
+    int8_t* output_ptr =
+        tflite::micro::GetTensorData<int8_t>(output) + t_rel * output_step;
+    // Input can be int8 asymmetric or int16 symmetric.
+    const int8_t* input_ptr =
+        tflite::micro::GetTensorData<int8_t>(input) + t_rel * input_step;
+    LstmStepInteger8x8_8(
+        input_ptr, input_zp,
+
+        input_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
+        integer_lstm_param->effective_input_to_input_scale_a,
+        integer_lstm_param->effective_input_to_input_scale_b,
+
+        input_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
+        integer_lstm_param->effective_input_to_forget_scale_a,
+        integer_lstm_param->effective_input_to_forget_scale_b,
+
+        input_to_cell_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
+        integer_lstm_param->effective_input_to_cell_scale_a,
+        integer_lstm_param->effective_input_to_cell_scale_b,
+
+        input_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
+        integer_lstm_param->effective_input_to_output_scale_a,
+        integer_lstm_param->effective_input_to_output_scale_b,
+
+        recurrent_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
+        integer_lstm_param->effective_recurrent_to_input_scale_a,
+        integer_lstm_param->effective_recurrent_to_input_scale_b,
+
+        recurrent_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
+        integer_lstm_param->effective_recurrent_to_forget_scale_a,
+        integer_lstm_param->effective_recurrent_to_forget_scale_b,
+
+        recurrent_to_cell_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
+        integer_lstm_param->effective_recurrent_to_cell_scale_a,
+        integer_lstm_param->effective_recurrent_to_cell_scale_b,
+
+        recurrent_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
+        integer_lstm_param->effective_recurrent_to_output_scale_a,
+        integer_lstm_param->effective_recurrent_to_output_scale_b,
+
+        cell_to_input_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
+        integer_lstm_param->effective_cell_to_input_scale_a,
+        integer_lstm_param->effective_cell_to_input_scale_b,
+
+        cell_to_forget_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
+        integer_lstm_param->effective_cell_to_forget_scale_a,
+        integer_lstm_param->effective_cell_to_forget_scale_b,
+
+        cell_to_output_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
+        integer_lstm_param->effective_cell_to_output_scale_a,
+        integer_lstm_param->effective_cell_to_output_scale_b,
+
+        projection_weights == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int8_t>(projection_weights),
+        integer_lstm_param->effective_proj_scale_a,
+        integer_lstm_param->effective_proj_scale_b,
+
+        input_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  input_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_input_scale_a,
+        integer_lstm_param->layer_norm_input_scale_b,
+
+        forget_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  forget_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_forget_scale_a,
+        integer_lstm_param->layer_norm_forget_scale_b,
+
+        cell_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  cell_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_cell_scale_a,
+        integer_lstm_param->layer_norm_cell_scale_b,
+
+        output_layer_norm_coefficients == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int16_t>(
+                  output_layer_norm_coefficients),
+        integer_lstm_param->layer_norm_output_scale_a,
+        integer_lstm_param->layer_norm_output_scale_b,
+
+        input_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(input_gate_bias),
+        forget_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
+        cell_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
+        output_gate_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(output_gate_bias),
+        projection_bias == nullptr
+            ? nullptr
+            : tflite::micro::GetTensorData<int32_t>(projection_bias),
+
+        params, integer_lstm_param->intermediate_scale_a,
+        integer_lstm_param->intermediate_scale_b,
+        integer_lstm_param->intermediate_zp,
+        integer_lstm_param->quantized_cell_clip,
+        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
+        n_output, output_batch_leading_dim,
+        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
+        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr, scratch0,
+        scratch1, scratch2, scratch3, scratch4, scratch5, scratch6, scratch7);
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace tflite

+ 250 - 0
code/components/tflite-lib/tensorflow/lite/micro/kernels/lstm_eval.h

@@ -0,0 +1,250 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Parameters for integer LSTM.
+// Consider splitting this into two integer parameter structs if more fields
+// are added.
+struct IntegerLstmParameter {
+  int32_t effective_input_to_input_scale_a;
+  int32_t effective_input_to_input_scale_b;
+  int32_t effective_recurrent_to_input_scale_a;
+  int32_t effective_recurrent_to_input_scale_b;
+  int32_t effective_cell_to_input_scale_a;
+  int32_t effective_cell_to_input_scale_b;
+  int32_t effective_input_to_forget_scale_a;
+  int32_t effective_input_to_forget_scale_b;
+  int32_t effective_recurrent_to_forget_scale_a;
+  int32_t effective_recurrent_to_forget_scale_b;
+  int32_t effective_cell_to_forget_scale_a;
+  int32_t effective_cell_to_forget_scale_b;
+  int32_t effective_input_to_cell_scale_a;
+  int32_t effective_input_to_cell_scale_b;
+  int32_t effective_recurrent_to_cell_scale_a;
+  int32_t effective_recurrent_to_cell_scale_b;
+  int32_t effective_input_to_output_scale_a;
+  int32_t effective_input_to_output_scale_b;
+  int32_t effective_recurrent_to_output_scale_a;
+  int32_t effective_recurrent_to_output_scale_b;
+  int32_t effective_cell_to_output_scale_a;
+  int32_t effective_cell_to_output_scale_b;
+  int32_t effective_proj_scale_a;
+  int32_t effective_proj_scale_b;
+  int32_t effective_hidden_scale_a;
+  int32_t effective_hidden_scale_b;
+  int32_t layer_norm_input_scale_a;
+  int32_t layer_norm_input_scale_b;
+  int32_t layer_norm_forget_scale_a;
+  int32_t layer_norm_forget_scale_b;
+  int32_t layer_norm_cell_scale_a;
+  int32_t layer_norm_cell_scale_b;
+  int32_t layer_norm_output_scale_a;
+  int32_t layer_norm_output_scale_b;
+  // Quantized clip values for the cell state and the projection. A value of
+  // zero means no clipping.
+  int16_t quantized_cell_clip;
+  int8_t quantized_proj_clip;
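+  // Sketch of the usual clipping, assuming a symmetric range:
+  //   if (clip > 0) x = std::min(std::max(x, -clip), clip);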
+  int32_t hidden_zp;
+  int32_t cell_scale;
+
+  int32_t input_variance_guard;
+  int32_t forget_variance_guard;
+  int32_t cell_variance_guard;
+  int32_t output_variance_guard;
+
+  // Pre-calculate bias + zero_point * weight.
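+  // (A sketch of that precomputation, assuming row-major weights:
+  //  effective_bias[row] = bias[row] + zero_point * sum over cols of
+  //  weight[row][col].)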
+  int32_t* input_to_forget_effective_bias;
+  int32_t* recurrent_to_forget_effective_bias;
+  int32_t* input_to_cell_effective_bias;
+  int32_t* recurrent_to_cell_effective_bias;
+  int32_t* input_to_output_effective_bias;
+  int32_t* recurrent_to_output_effective_bias;
+  int32_t* input_to_input_effective_bias;
+  int32_t* recurrent_to_input_effective_bias;
+  int32_t* projection_effective_bias;
+
+  // Scale and zero point for intermediate tensors.
+  // Used only in the 8x8_8 case.
+  int32_t intermediate_scale_a[8];
+  int32_t intermediate_scale_b[8];
+  int32_t intermediate_zp[12];
+};
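+
+// Each (*_scale_a, *_scale_b) pair above follows the usual TFLite convention
+// of a quantized multiplier plus shift; a sketch of how an accumulator would
+// be rescaled with them:
+//   int32_t out = MultiplyByQuantizedMultiplier(acc, scale_a, scale_b);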
+
+// Scales for the hybrid op with integer inputs and float weights.
+struct HybridLstmScales {
+  float input_to_input_weights_scale;
+  float input_to_forget_weights_scale;
+  float input_to_cell_weights_scale;
+  float input_to_output_weights_scale;
+  float aux_input_to_input_weights_scale;
+  float aux_input_to_forget_weights_scale;
+  float aux_input_to_cell_weights_scale;
+  float aux_input_to_output_weights_scale;
+  float recurrent_to_input_weights_scale;
+  float recurrent_to_forget_weights_scale;
+  float recurrent_to_cell_weights_scale;
+  float recurrent_to_output_weights_scale;
+  float cell_to_input_weights_scale;
+  float cell_to_forget_weights_scale;
+  float cell_to_output_weights_scale;
+  float projection_weights_scale;
+};
+
+TfLiteStatus EvalFloatLstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, TfLiteEvalTensor* output);
+
+TfLiteStatus EvalHybridLstm(
+    const HybridLstmScales* hybrid_lstm_scales, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_input_weights_ledger,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_forget_weights_ledger,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_cell_weights_ledger,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* input_to_output_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights_ledger,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights_ledger,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* aux_input,
+    const TfLiteEvalTensor* aux_input_to_input_weights,
+    const TfLiteEvalTensor* aux_input_to_forget_weights,
+    const TfLiteEvalTensor* aux_input_to_cell_weights,
+    const TfLiteEvalTensor* aux_input_to_output_weights,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_weights_ledger,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major, int output_offset,
+    float* scratch_buffer, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* prod_scaling_factors,
+    float* recovered_cell_weights, int8_t* input_quantized,
+    int8_t* aux_input_quantized, int8_t* output_state_quantized,
+    int8_t* cell_state_quantized, float* scales, TfLiteEvalTensor* output_state,
+    TfLiteEvalTensor* cell_state, int32_t* output_scratch_buffer,
+    TfLiteEvalTensor* output, int32_t* input_zp, int32_t* aux_input_zp,
+    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums);
+
+TfLiteStatus EvalInteger8x8_16Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    bool forward_sequence, bool time_major,
+    const IntegerLstmParameter* integer_lstm_param, int32_t output_state_zp,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, int16_t* scratch0, int16_t* scratch1,
+    int16_t* scratch2, int16_t* scratch3, int8_t* scratch4, int32_t* scratch5);
+
+TfLiteStatus EvalInteger8x8_8Lstm(
+    const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* input_to_input_weights,
+    const TfLiteEvalTensor* input_to_forget_weights,
+    const TfLiteEvalTensor* input_to_cell_weights,
+    const TfLiteEvalTensor* input_to_output_weights,
+    const TfLiteEvalTensor* recurrent_to_input_weights,
+    const TfLiteEvalTensor* recurrent_to_forget_weights,
+    const TfLiteEvalTensor* recurrent_to_cell_weights,
+    const TfLiteEvalTensor* recurrent_to_output_weights,
+    const TfLiteEvalTensor* cell_to_input_weights,
+    const TfLiteEvalTensor* cell_to_forget_weights,
+    const TfLiteEvalTensor* cell_to_output_weights,
+    const TfLiteEvalTensor* input_layer_norm_coefficients,
+    const TfLiteEvalTensor* forget_layer_norm_coefficients,
+    const TfLiteEvalTensor* cell_layer_norm_coefficients,
+    const TfLiteEvalTensor* output_layer_norm_coefficients,
+    const TfLiteEvalTensor* input_gate_bias,
+    const TfLiteEvalTensor* forget_gate_bias,
+    const TfLiteEvalTensor* cell_gate_bias,
+    const TfLiteEvalTensor* output_gate_bias,
+    const TfLiteEvalTensor* projection_weights,
+    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
+    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
+    TfLiteEvalTensor* output, const IntegerLstmParameter* integer_lstm_param,
+    int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+    int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, int16_t* scratch7);
+
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_H_

Some files were not shown because of the large number of changed files