fully_connected.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
  13. #define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
  14. #include <cstdint>
  15. #include "tensorflow/lite/c/builtin_op_data.h"
  16. #include "tensorflow/lite/c/common.h"
  17. #include "tensorflow/lite/kernels/internal/types.h"
  18. namespace tflite {
  // Parameters for the fully connected kernel. An instance is handed to
  // CalculateOpDataFullyConnected() through a mutable pointer and to
  // FullyConnectedParamsQuantized() by const reference.
  struct OpDataFullyConnected {
    // Fixed-point encoding of the input-to-output scaling factor (the "real
    // multiplier"): a fixed point multiplier together with a left shift.
    int32_t output_multiplier;
    int output_shift;

    // Bounds of the fused activation layer; e.g. 0 and 255 for kNone with
    // uint8_t.
    int32_t output_activation_min;
    int32_t output_activation_max;

    // Index of the temporary tensor in which the quantized inputs are cached.
    int input_quantized_index;

    // Cached tensor zero points.
    int32_t input_zero_point;
    int32_t filter_zero_point;
    int32_t output_zero_point;
  };
  // Declarations only: each constant is defined in another translation unit.
  // NOTE(review): the names suggest these are the positions of the op's input,
  // weights, bias and output tensors -- confirm against the definitions.
  extern const int kFullyConnectedInputTensor;
  extern const int kFullyConnectedWeightsTensor;
  extern const int kFullyConnectedBiasTensor;
  extern const int kFullyConnectedOutputTensor;
  39. // Returns a FullyConnectedParams struct with all the parameters needed for a
  40. // float computation.
  41. FullyConnectedParams FullyConnectedParamsFloat(
  42. TfLiteFusedActivation activation);
  43. // Returns a FullyConnectedParams struct with all the parameters needed for a
  44. // quantized computation.
  45. FullyConnectedParams FullyConnectedParamsQuantized(
  46. const OpDataFullyConnected& op_data);
  47. TfLiteStatus CalculateOpDataFullyConnected(
  48. TfLiteContext* context, TfLiteFusedActivation activation,
  49. TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
  50. const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
  51. // This is the most generic TfLiteRegistration. The actual supported types may
  52. // still be target dependent. The only requirement is that every implementation
  53. // (reference or optimized) must define this function.
  54. TfLiteRegistration Register_FULLY_CONNECTED();
  55. #if defined(CMSIS_NN) || defined(HEXAGON)
  56. // Returns a TfLiteRegistration struct for kernel variant that only supports
  57. // int8.
  58. TfLiteRegistration Register_FULLY_CONNECTED_INT8();
  59. #else
  60. // Note that while this block gets used for both reference and optimized kernels
  61. // that do not have any specialized implementations, the only goal here is to
  62. // define fallback implementation that allow reference kernels to still be used
  63. // from applications that call a more specific kernel variant.
  64. inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
  65. return Register_FULLY_CONNECTED();
  66. }
  67. #endif
  68. #if defined(CMSIS_NN)
  69. // Returns a TfLiteRegistration struct for kernel variant that only supports
  70. // int16.
  71. TfLiteRegistration Register_FULLY_CONNECTED_INT16();
  72. #else
  73. // Note that while this block gets used for both reference and optimized kernels
  74. // that do not have any specialized implementations, the only goal here is to
  75. // define fallback implementation that allow reference kernels to still be used
  76. // from applications that call a more specific kernel variant.
  77. inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
  78. return Register_FULLY_CONNECTED();
  79. }
  80. #endif
  81. } // namespace tflite
  82. #endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_