CTfLiteClass.h 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #pragma once
  2. #ifndef CTFLITECLASS_H
  3. #define CTFLITECLASS_H
  #include <cstdarg>
  #include <cstdint>
  #include <string>

  #include "tensorflow/lite/micro/all_ops_resolver.h"
  #include "tensorflow/lite/micro/micro_error_reporter.h"
  #include "tensorflow/lite/micro/micro_interpreter.h"
  #include "tensorflow/lite/schema/schema_generated.h"
  #include "tensorflow/lite/micro/kernels/micro_ops.h"
  #include "esp_err.h"
  #include "esp_log.h"
  #include "CImageBasis.h"
  12. #ifdef SUPRESS_TFLITE_ERRORS
  13. #include "tensorflow/lite/core/api/error_reporter.h"
  14. #include "tensorflow/lite/micro/compatibility.h"
  15. #include "tensorflow/lite/micro/debug_log.h"
  16. ///// OwnErrorReporter to prevent printing of Errors (especially unavoidable in CalculateActivationRangeQuantized@kerne_util.cc)
  17. namespace tflite {
  18. class OwnMicroErrorReporter : public ErrorReporter {
  19. public:
  20. int Report(const char* format, va_list args) override;
  21. };
  22. } // namespace tflite
  23. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  24. #endif
  25. class CTfLiteClass
  26. {
  27. protected:
  28. tflite::ErrorReporter *error_reporter;
  29. const tflite::Model* model;
  30. tflite::MicroInterpreter* interpreter;
  31. TfLiteTensor* output = nullptr;
  32. static tflite::AllOpsResolver resolver;
  33. int kTensorArenaSize;
  34. uint8_t *tensor_arena;
  35. unsigned char *modelload = NULL;
  36. float* input;
  37. int input_i;
  38. int im_height, im_width, im_channel;
  39. long GetFileSize(std::string filename);
  40. unsigned char* ReadFileToCharArray(std::string _fn);
  41. public:
  42. CTfLiteClass();
  43. ~CTfLiteClass();
  44. bool LoadModel(std::string _fn);
  45. void MakeAllocate();
  46. void GetInputTensorSize();
  47. bool LoadInputImageBasis(CImageBasis *rs);
  48. void Invoke();
  49. int GetAnzOutPut(bool silent = true);
  50. int GetOutClassification(int _von = -1, int _bis = -1);
  51. int GetClassFromImageBasis(CImageBasis *rs);
  52. std::string GetStatusFlow();
  53. float GetOutputValue(int nr);
  54. void GetInputDimension(bool silent);
  55. int ReadInputDimenstion(int _dim);
  56. };
  57. #endif //CTFLITECLASS_H