// CTfLiteClass.h
  1. /*
  2. #define TFLITE_MINIMAL_CHECK(x) \
  3. if (!(x)) { \
  4. fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
  5. exit(1); \
  6. }
  7. */
  // Include guard: this header defines class CTfLiteClass, so it must not be
  // processed twice in one translation unit.
  #pragma once

  #include <cstdint>
  #include <string>

  #include "tensorflow/lite/micro/all_ops_resolver.h"
  #include "tensorflow/lite/micro/micro_error_reporter.h"
  #include "tensorflow/lite/micro/micro_interpreter.h"
  #include "tensorflow/lite/schema/schema_generated.h"
  #include "tensorflow/lite/micro/kernels/micro_ops.h"
  #include "esp_err.h"
  #include "esp_log.h"
  #include "CImageBasis.h"
  16. #ifdef SUPRESS_TFLITE_ERRORS
  17. #include "tensorflow/lite/core/api/error_reporter.h"
  18. #include "tensorflow/lite/micro/compatibility.h"
  19. #include "tensorflow/lite/micro/debug_log.h"
  ///// OwnMicroErrorReporter: used to prevent printing of errors (especially the unavoidable ones from CalculateActivationRangeQuantized in kernel_util.cc)
  21. namespace tflite {
  22. class OwnMicroErrorReporter : public ErrorReporter {
  23. public:
  24. int Report(const char* format, va_list args) override;
  25. };
  26. } // namespace tflite
  27. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  28. #endif
  29. class CTfLiteClass
  30. {
  31. protected:
  32. tflite::ErrorReporter *error_reporter;
  33. const tflite::Model* model;
  34. tflite::MicroInterpreter* interpreter;
  35. TfLiteTensor* output = nullptr;
  36. static tflite::AllOpsResolver resolver;
  37. int kTensorArenaSize;
  38. uint8_t *tensor_arena;
  39. unsigned char *modelload = NULL;
  40. float* input;
  41. int input_i;
  42. int im_height, im_width, im_channel;
  43. long GetFileSize(std::string filename);
  44. unsigned char* ReadFileToCharArray(std::string _fn);
  45. public:
  46. CTfLiteClass();
  47. ~CTfLiteClass();
  48. bool LoadModel(std::string _fn);
  49. void MakeAllocate();
  50. void GetInputTensorSize();
  51. bool LoadInputImageBasis(CImageBasis *rs);
  52. void Invoke();
  53. int GetAnzOutPut(bool silent = true);
  54. int GetOutClassification(int _von = -1, int _bis = -1);
  55. int GetClassFromImageBasis(CImageBasis *rs);
  56. std::string GetStatusFlow();
  57. float GetOutputValue(int nr);
  58. void GetInputDimension(bool silent);
  59. int ReadInputDimenstion(int _dim);
  60. };