lstm_shared.h 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
  13. #define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_
  namespace tflite {

  // Positional indices of the tensors consumed and produced by the LSTM
  // kernels. Trailing "Optional." tags mirror the original header's
  // annotations of which operands are optional.

  // ---- Primary input -------------------------------------------------------

  // Input activations, sized {n_batch, n_input}.
  constexpr int kLstmInputTensor{0};

  // ---- Input-to-gate weights, each sized {n_cell, n_input} -----------------

  constexpr int kLstmInputToInputWeightsTensor{1};  // Optional.
  constexpr int kLstmInputToForgetWeightsTensor{2};
  constexpr int kLstmInputToCellWeightsTensor{3};
  constexpr int kLstmInputToOutputWeightsTensor{4};

  // ---- Recurrent-to-gate weights, each sized {n_cell, n_output} ------------

  constexpr int kLstmRecurrentToInputWeightsTensor{5};  // Optional.
  constexpr int kLstmRecurrentToForgetWeightsTensor{6};
  constexpr int kLstmRecurrentToCellWeightsTensor{7};
  constexpr int kLstmRecurrentToOutputWeightsTensor{8};

  // ---- Peephole weights, each sized {n_cell} (a diagonal matrix) -----------

  constexpr int kLstmCellToInputWeightsTensor{9};    // Optional.
  constexpr int kLstmCellToForgetWeightsTensor{10};  // Optional.
  constexpr int kLstmCellToOutputWeightsTensor{11};  // Optional.

  // ---- Gate bias tensors, each sized {n_cell} ------------------------------

  constexpr int kLstmInputGateBiasTensor{12};  // Optional.
  constexpr int kLstmForgetGateBiasTensor{13};
  constexpr int kLstmCellGateBiasTensor{14};
  constexpr int kLstmOutputGateBiasTensor{15};

  // ---- Projection ----------------------------------------------------------

  // Projection weights, sized {n_output, n_cell}.
  constexpr int kLstmProjectionWeightsTensor{16};  // Optional.
  // Projection bias, sized {n_output}.
  constexpr int kLstmProjectionBiasTensor{17};  // Optional.

  // ---- Recurrent state -----------------------------------------------------
  // Declared as variable tensors; this op modifies them in place.

  constexpr int kLstmOutputStateTensor{18};
  constexpr int kLstmCellStateTensor{19};

  // ---- Layer-norm coefficients, each sized {n_cell} (a diagonal matrix) ----

  constexpr int kLstmInputLayerNormCoefficientsTensor{20};   // Optional.
  constexpr int kLstmForgetLayerNormCoefficientsTensor{21};  // Optional.
  constexpr int kLstmCellLayerNormCoefficientsTensor{22};    // Optional.
  constexpr int kLstmOutputLayerNormCoefficientsTensor{23};  // Optional.

  // ---- Outputs -------------------------------------------------------------
  // Output indices are numbered independently of the input indices above.

  constexpr int kLstmOutputTensor{0};

  }  // namespace tflite
  53. #endif // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_SHARED_H_