@@ -0,0 +1,422 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+    const float* prev_activ_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& unextended_bias_shape,
+    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+    const float* prev_state_data,
+    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int batches =
+      MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+                  output_state_shape, 0, output_activ_shape, 0);
+  const int height =
+      MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+                  output_state_shape, 1, output_activ_shape, 1);
+  const int width =
+      MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+                  output_state_shape, 2, output_activ_shape, 2);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
+  const int total_input_depth = prev_activ_depth + input_depth;
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  const int intern_activ_depth =
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+  const int output_depth =
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+
+  // Concatenate prev_activ and input data together
+  float const* concat_input_arrays_data[2] = {input_data, prev_activ_data};
+  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+                                                       &prev_activ_shape};
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = 2;
+  Concatenation(concat_params, concat_input_arrays_shapes,
+                concat_input_arrays_data, concat_temp_shape, concat_temp_data);
+
+  // Fully connected
+  tflite::FullyConnectedParams fc_params;
+  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+  fc_params.float_activation_max = std::numeric_limits<float>::max();
+  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+                 weights_data, bias_shape, bias_data, activ_temp_shape,
+                 activ_temp_data);
+
+  // Memory state update (the LSTM "guts")
+  for (int b = 0; b < batches; ++b) {
+    for (int w = 0; w < width; ++w) {
+      for (int h = 0; h < height; ++h) {
+        for (int c = 0; c < output_depth; ++c) {
+          const float input_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      0 * output_depth + c)]));
+          const float new_input = std::tanh(activ_temp_data[Offset(
+              activ_temp_shape, b, h, w, 1 * output_depth + c)]);
+          const float forget_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      2 * output_depth + c)]));
+          const float output_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      3 * output_depth + c)]));
+          const float new_state =
+              input_gate * new_input +
+              forget_gate *
+                  prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+          output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+          output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
+              output_gate * std::tanh(new_state);
+        }
+      }
+    }
+  }
+}
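+
+// Illustrative shape example for the float LstmCell above (a sketch; the
+// concrete numbers are assumptions for exposition, not taken from this file):
+// with batches = height = width = 1, input_depth = 8 and output_depth = 4,
+// the DCHECKs above imply
+//   input:        {1, 1, 1, 8}
+//   prev_activ:   {1, 1, 1, 4}   (prev_activ_depth == output_depth)
+//   weights:      {16, 12}       (intern_activ_depth = 4 * output_depth rows,
+//                                 input_depth + prev_activ_depth columns)
+//   bias:         {1, 1, 1, 16}
+//   prev_state:   {1, 1, 1, 4}
+//   concat_temp:  {1, 1, 1, 12}
+//   activ_temp:   {1, 1, 1, 16}
+// with output_state and output_activ both {1, 1, 1, 4}. The activ_temp depth
+// holds the four gates contiguously: [input, new input, forget, output] at
+// offsets 0, 1, 2 and 3 times output_depth.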
+
+// Quantized LSTM cell implementation.
+// The quantization of the input and output arrays is as follows:
+//  - The input activations are quantized as uint8 on the interval
+//    [-1, 127/128].
+//    The rationale for that is that this is the natural interval for output
+//    activations (see next point), and these need to be concatenated together.
+//    We could accommodate different ranges by re-scaling, but we empirically
+//    found that setting the input activations range to be [-1, 127/128] in the
+//    first place, removing the need for re-scaling, greatly improves accuracy.
+//  - The output activations are quantized as uint8 on the interval
+//    [-1, 127/128].
+//    The rationale for that is that the definition of an LSTM cell makes them
+//    intrinsically constrained to [-1, 1]; tweaking that to [-1, 127/128]
+//    makes for simpler, more accurate fixed-point arithmetic.
+//  - The output-at-previous-timestep state array is quantized in the same way
+//    as the output activations.
+//  - The internal LSTM memory (not the output-at-previous-timestep, the other
+//    internal state array) is int16-quantized and may use any power-of-two,
+//    symmetric range, i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
+//    StateIntegerBits below; see the discussion of that template parameter
+//    below ("The StateIntegerBits template parameter").
+//  - The output of the internal fully-connected node is int16-quantized
+//    on the interval [-8, 8 * 32767/32768], the rationale for which is
+//    explained just below ("Why [-8, 8] for fully-connected output?").
+//
+//
+// === The StateIntegerBits template parameter ===
+//
+// The StateIntegerBits template parameter controls the fixed-point format used
+// to represent the internal memory of the LSTM cell (not the
+// output-at-previous-timestep, the other internal state array). It's currently
+// a template parameter so that the model can control it. The most typical
+// value for StateIntegerBits is 4. Other plausible values are anywhere between
+// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
+// and drop that template parameter. The reason why it can't be a runtime
+// parameter is that this controls the fixed-point format used, i.e. we need to
+// generate actually different code based on it. In particular, we generate code
+// for a fixed-point tanh() implementation for that format, which internally
+// uses a fixed-point exp() implementation, which internally uses a
+// barrel-shifter with a number of steps that depends on StateIntegerBits.
+// Another consequence of that is that a higher value of StateIntegerBits
+// results in a more expensive implementation (more barrel shifter steps
+// needed).
+//
+//
+// === Why [-8, 8] for fully-connected output? ===
+//
+// This array is only fed to Logistic and Tanh functions, for which
+// the quantized implementation will want to use fixed-point arithmetic,
+// requiring a power-of-two representation interval. Thus, we should right
+// away quantize this array to a power-of-two interval; otherwise, the
+// implementation will need to rescale it, losing any benefit that a tighter
+// representation interval might otherwise yield, while introducing some
+// numerical error and computational overhead.
+//
+// Now, Logistic and Tanh are nearly constant (nearly equal to their
+// horizontal asymptotes) outside of a small bounded interval around 0:
+//
+//       Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
+//       Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
+//      Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
+//
+// From this, we see that clamping to [-4, 4] would be too inaccurate
+// (the error of 1.8e-2 on Logistic would be felt even in 8-bit precision)
+// while clamping to [-16, 16] would make no difference even in float32.
+// However, for a fixed-point implementation in 16-bit integers, using 5
+// integer bits to represent the [-16, 16] range would leave only 11
+// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
+// representable values. Notice that this is higher than the
+// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
+// Using [-8, 8] thus seems like the better compromise overall, enjoying
+// an increment of 2.4e-4 between representable values and a worst-case
+// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
+// [-16, 16].
+//
+// Moreover, all other things being equal, it is nice to choose the narrower
+// representation range, as that makes the implementation of fixed-point
+// math functions a little cheaper (each integer bit requires an additional
+// barrel-shifter step in the implementation of exp(-x)). That is a further
+// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
+// sense for 32-bit float or 32-bit fixed-point quantization, but we are
+// aiming for 16-bit fixed-point quantization of these internal nodes here.
+//
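+// As a worked illustration of the increments quoted above (an illustrative
+// sketch, not part of the original rationale): a 16-bit fixed-point value
+// with I integer bits (in the gemmlowp sense, not counting the sign bit) has
+// 15 - I fractional bits, so consecutive representable values differ by
+// 2^-(15 - I):
+//
+//   I = 3  ->  range [-8, 8],    step 2^-12 ~= 2.4e-4
+//   I = 4  ->  range [-16, 16],  step 2^-11 ~= 4.9e-4
+//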
+template <int StateIntegerBits>
+inline void LstmCell(const LstmCellParams& params,
+                     const RuntimeShape& unextended_input_shape,
+                     const uint8_t* input_data_uint8,
+                     const RuntimeShape& unextended_prev_activ_shape,
+                     const uint8_t* prev_activ_data_uint8,
+                     const RuntimeShape& weights_shape,
+                     const uint8_t* weights_data_uint8,
+                     const RuntimeShape& unextended_bias_shape,
+                     const int32_t* bias_data_int32,
+                     const RuntimeShape& unextended_prev_state_shape,
+                     const int16_t* prev_state_data_int16,
+                     const RuntimeShape& unextended_output_state_shape,
+                     int16_t* output_state_data_int16,
+                     const RuntimeShape& unextended_output_activ_shape,
+                     uint8_t* output_activ_data_uint8,
+                     const RuntimeShape& unextended_concat_temp_shape,
+                     uint8_t* concat_temp_data_uint8,
+                     const RuntimeShape& unextended_activ_temp_shape,
+                     int16_t* activ_temp_data_int16, void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
+  int32_t weights_zero_point = params.weights_zero_point;
+  int32_t accum_multiplier = params.accum_multiplier;
+  int accum_shift = params.accum_shift;
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+  // Gather dimensions information, and perform consistency checks.
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int outer_size = MatchingFlatSizeSkipDim(
+      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+      output_activ_shape);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
+  const int total_input_depth = prev_activ_depth + input_depth;
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
+  const int intern_activ_depth =
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+  const int output_depth =
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
+  const int fc_output_depth =
+      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+  const int fc_accum_depth = total_input_depth;
+  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
+
+  // Depth-concatenate prev_activ and input data together.
+  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
+                                                prev_activ_data_uint8};
+  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+                                                       &prev_activ_shape};
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = 2;
+  Concatenation(concat_params, concat_input_arrays_shapes,
+                concat_input_arrays_data, concat_temp_shape,
+                concat_temp_data_uint8);
+
+  // Implementation of the fully connected node inside the LSTM cell.
+  // The operands are 8-bit integers, the accumulators are internally 32-bit
+  // integers, and the output is 16-bit fixed-point with 3 integer bits so
+  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
+  // is explained in the function comment above.
+  for (int b = 0; b < fc_batches; ++b) {
+    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
+      // Internal accumulation.
+      // Initialize accumulator with the bias-value.
+      int32_t accum = bias_data_int32[out_c];
+      // Accumulation loop.
+      for (int d = 0; d < fc_accum_depth; ++d) {
+        int16_t input_val =
+            concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
+        int16_t weights_val =
+            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
+        accum += input_val * weights_val;
+      }
+      // Down-scale the final int32 accumulator to the scale used by our
+      // (16-bit, using 3 integer bits) fixed-point format. The quantized
+      // multiplier and shift here have been pre-computed offline
+      // (e.g. by toco).
+      accum =
+          MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
+      // Saturate, cast to int16, and store to the temporary activations array.
+      accum = std::max(-32768, std::min(32767, accum));
+      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
+    }
+  }
+
+  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
+  // and muls, all done in 16-bit fixed-point.
+  for (int b = 0; b < outer_size; ++b) {
+    for (int c = 0; c < output_depth; ++c) {
+      // Define the fixed-point data types that we will use here. All use
+      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
+      // They only differ by the number of integral vs. fractional bits,
+      // determining the range of values that they can represent.
+      //
+      // F0 uses 0 integer bits, range [-1, 1].
+      // This is the return type of math functions such as tanh, logistic,
+      // whose range is in [-1, 1].
+      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
+      // F3 uses 3 integer bits, range [-8, 8].
+      // This is the range of the previous fully-connected node's output,
+      // which is our input here.
+      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
+      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
+      // 2^StateIntegerBits]. It's used to represent the internal state, whose
+      // number of integer bits is currently dictated by the model. See comment
+      // on the StateIntegerBits template parameter above.
+      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
+      // Implementation of input gate, using fixed-point logistic function.
+      F3 input_gate_input = F3::FromRaw(
+          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
+      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
+      // Implementation of input modulation gate, using fixed-point tanh
+      // function.
+      F3 input_modulation_gate_input = F3::FromRaw(
+          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
+      F0 input_modulation_gate_output =
+          gemmlowp::tanh(input_modulation_gate_input);
+      // Implementation of forget gate, using fixed-point logistic function.
+      F3 forget_gate_input = F3::FromRaw(
+          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
+      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
+      // Implementation of output gate, using fixed-point logistic function.
+      F3 output_gate_input = F3::FromRaw(
+          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
+      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
+      // Implementation of internal multiplication nodes, still in fixed-point.
+      F0 input_times_input_modulation =
+          input_gate_output * input_modulation_gate_output;
+      FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
+      FS prev_state_times_forget_state = forget_gate_output * prev_state;
+      // Implementation of internal addition node, saturating.
+      FS new_state = gemmlowp::SaturatingAdd(
+          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
+          prev_state_times_forget_state);
+      // Implementation of last internal Tanh node, still in fixed-point.
+      // Since a Tanh fixed-point implementation is specialized for a given
+      // number of integer bits, and each specialization can have a substantial
+      // code size, and we already used above a Tanh on an input with 3 integer
+      // bits, and per the table in the above function comment there is no
+      // significant accuracy to be lost by clamping to [-8, +8] for a
+      // 3-integer-bits representation, let us just do that. This helps people
+      // porting this to targets where code footprint must be minimized.
+      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
+      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
+      // Store the new internal state back to memory, as 16-bit integers.
+      // Note: here we store the original value with StateIntegerBits, not
+      // the rescaled 3-integer-bits value fed to tanh.
+      output_state_data_int16[b * output_depth + c] = new_state.raw();
+      // Down-scale the output activations to 8-bit integers, saturating,
+      // and store back to memory.
+      int16_t rescaled_output_activ =
+          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
+      int16_t clamped_output_activ = std::max<int16_t>(
+          -128, std::min<int16_t>(127, rescaled_output_activ));
+      output_activ_data_uint8[b * output_depth + c] =
+          128 + clamped_output_activ;
+    }
+  }
+}
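+
+// Illustrative instantiation (a sketch; the call site and context value are
+// assumptions, not taken from this file): a model whose internal state uses
+// the [-16, 16] range, i.e. StateIntegerBits = 4 (the typical value noted
+// above), would invoke the quantized kernel as
+//
+//   LstmCell<4>(params, input_shape, input_data_uint8, prev_activ_shape,
+//               prev_activ_data_uint8, weights_shape, weights_data_uint8,
+//               bias_shape, bias_data_int32, prev_state_shape,
+//               prev_state_data_int16, output_state_shape,
+//               output_state_data_int16, output_activ_shape,
+//               output_activ_data_uint8, concat_temp_shape,
+//               concat_temp_data_uint8, activ_temp_shape,
+//               activ_temp_data_int16, /*gemmlowp_context=*/nullptr);
+//
+// where the shape and buffer names mirror the parameter names above; the
+// reference implementation ignores gemmlowp_context, so nullptr is fine here.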
+
+}  // namespace reference_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_