CTfLiteClass.h 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #pragma once
  2. #ifndef CTFLITECLASS_H
  3. #define CTFLITECLASS_H
#include <cstdint>
#include <string>

#include "esp_err.h"
#include "esp_log.h"

#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"

#include "CImageBasis.h"
  15. #ifdef SUPRESS_TFLITE_ERRORS
  16. #include "tensorflow/lite/core/api/error_reporter.h"
  17. #include "tensorflow/lite/micro/compatibility.h"
  18. #include "tensorflow/lite/micro/debug_log.h"
  19. ///// OwnErrorReporter to prevent printing of Errors (especially unavoidable in CalculateActivationRangeQuantized@kerne_util.cc)
  20. namespace tflite {
  21. class OwnMicroErrorReporter : public ErrorReporter {
  22. public:
  23. int Report(const char* format, va_list args) override;
  24. };
  25. } // namespace tflite
  26. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  27. #endif
  28. class CTfLiteClass
  29. {
  30. protected:
  31. tflite::ErrorReporter *error_reporter;
  32. const tflite::Model* model;
  33. tflite::MicroInterpreter* interpreter;
  34. TfLiteTensor* output = nullptr;
  35. int kTensorArenaSize;
  36. uint8_t *tensor_arena;
  37. unsigned char *modelfile = NULL;
  38. float* input;
  39. int input_i;
  40. int im_height, im_width, im_channel;
  41. long GetFileSize(std::string filename);
  42. bool ReadFileToModel(std::string _fn);
  43. public:
  44. CTfLiteClass();
  45. ~CTfLiteClass();
  46. bool LoadModel(std::string _fn);
  47. bool MakeAllocate();
  48. void GetInputTensorSize();
  49. bool LoadInputImageBasis(CImageBasis *rs);
  50. void Invoke();
  51. int GetAnzOutPut(bool silent = true);
  52. int GetOutClassification(int _von = -1, int _bis = -1);
  53. int GetClassFromImageBasis(CImageBasis *rs);
  54. std::string GetStatusFlow();
  55. float GetOutputValue(int nr);
  56. void GetInputDimension(bool silent);
  57. int ReadInputDimenstion(int _dim);
  58. };
  59. void MakeStaticResolver();
  60. #endif //CTFLITECLASS_H