CTfLiteClass.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
#include "CTfLiteClass.h"
#include "ClassLogFile.h"
#include "Helper.h"
#include "psram.h"
#include "esp_log.h"
#include "../../include/defines.h"
#include <sys/stat.h>
#include <new>
  8. // #define DEBUG_DETAIL_ON
  9. static const char *TAG = "TFLITE";
  10. /// Static Resolver muss mit allen Operatoren geladen Werden, die benöägit werden - ABER nur 1x --> gesonderte Funktion /////////////////////////////
  11. static bool MakeStaticResolverDone = false;
  12. static tflite::MicroMutableOpResolver<15> resolver;
  13. void MakeStaticResolver()
  14. {
  15. if (MakeStaticResolverDone)
  16. return;
  17. MakeStaticResolverDone = true;
  18. resolver.AddFullyConnected();
  19. resolver.AddReshape();
  20. resolver.AddSoftmax();
  21. resolver.AddConv2D();
  22. resolver.AddMaxPool2D();
  23. resolver.AddQuantize();
  24. resolver.AddMul();
  25. resolver.AddAdd();
  26. resolver.AddLeakyRelu();
  27. resolver.AddDequantize();
  28. }
  29. ////////////////////////////////////////////////////////////////////////////////////////
  30. float CTfLiteClass::GetOutputValue(int nr)
  31. {
  32. TfLiteTensor* output2 = this->interpreter->output(0);
  33. int numeroutput = output2->dims->data[1];
  34. if ((nr+1) > numeroutput)
  35. return -1000;
  36. return output2->data.f[nr];
  37. }
  38. int CTfLiteClass::GetClassFromImageBasis(CImageBasis *rs)
  39. {
  40. if (!LoadInputImageBasis(rs))
  41. return -1000;
  42. Invoke();
  43. return GetOutClassification();
  44. }
  45. int CTfLiteClass::GetOutClassification(int _von, int _bis)
  46. {
  47. TfLiteTensor* output2 = interpreter->output(0);
  48. float zw_max;
  49. float zw;
  50. int zw_class;
  51. if (output2 == NULL)
  52. return -1;
  53. int numeroutput = output2->dims->data[1];
  54. //ESP_LOGD(TAG, "number output neurons: %d", numeroutput);
  55. if (_bis == -1)
  56. _bis = numeroutput -1;
  57. if (_von == -1)
  58. _von = 0;
  59. if (_bis >= numeroutput)
  60. {
  61. ESP_LOGD(TAG, "NUMBER OF OUTPUT NEURONS does not match required classification!");
  62. return -1;
  63. }
  64. zw_max = output2->data.f[_von];
  65. zw_class = _von;
  66. for (int i = _von + 1; i <= _bis; ++i)
  67. {
  68. zw = output2->data.f[i];
  69. if (zw > zw_max)
  70. {
  71. zw_max = zw;
  72. zw_class = i;
  73. }
  74. }
  75. return (zw_class - _von);
  76. }
  77. void CTfLiteClass::GetInputDimension(bool silent = false)
  78. {
  79. TfLiteTensor* input2 = this->interpreter->input(0);
  80. int numdim = input2->dims->size;
  81. if (!silent) ESP_LOGD(TAG, "NumDimension: %d", numdim);
  82. int sizeofdim;
  83. for (int j = 0; j < numdim; ++j)
  84. {
  85. sizeofdim = input2->dims->data[j];
  86. if (!silent) ESP_LOGD(TAG, "SizeOfDimension %d: %d", j, sizeofdim);
  87. if (j == 1) im_height = sizeofdim;
  88. if (j == 2) im_width = sizeofdim;
  89. if (j == 3) im_channel = sizeofdim;
  90. }
  91. }
  92. int CTfLiteClass::ReadInputDimenstion(int _dim)
  93. {
  94. if (_dim == 0)
  95. return im_width;
  96. if (_dim == 1)
  97. return im_height;
  98. if (_dim == 2)
  99. return im_channel;
  100. return -1;
  101. }
  102. int CTfLiteClass::GetAnzOutPut(bool silent)
  103. {
  104. TfLiteTensor* output2 = this->interpreter->output(0);
  105. int numdim = output2->dims->size;
  106. if (!silent) ESP_LOGD(TAG, "NumDimension: %d", numdim);
  107. int sizeofdim;
  108. for (int j = 0; j < numdim; ++j)
  109. {
  110. sizeofdim = output2->dims->data[j];
  111. if (!silent) ESP_LOGD(TAG, "SizeOfDimension %d: %d", j, sizeofdim);
  112. }
  113. float fo;
  114. // Process the inference results.
  115. int numeroutput = output2->dims->data[1];
  116. for (int i = 0; i < numeroutput; ++i)
  117. {
  118. fo = output2->data.f[i];
  119. if (!silent) ESP_LOGD(TAG, "Result %d: %f", i, fo);
  120. }
  121. return numeroutput;
  122. }
  123. void CTfLiteClass::Invoke()
  124. {
  125. if (interpreter != nullptr)
  126. interpreter->Invoke();
  127. }
  128. bool CTfLiteClass::LoadInputImageBasis(CImageBasis *rs)
  129. {
  130. #ifdef DEBUG_DETAIL_ON
  131. LogFile.WriteHeapInfo("CTfLiteClass::LoadInputImageBasis - Start");
  132. #endif
  133. unsigned int w = rs->width;
  134. unsigned int h = rs->height;
  135. unsigned char red, green, blue;
  136. // ESP_LOGD(TAG, "Image: %s size: %d x %d\n", _fn.c_str(), w, h);
  137. input_i = 0;
  138. float* input_data_ptr = (interpreter->input(0))->data.f;
  139. for (int y = 0; y < h; ++y)
  140. for (int x = 0; x < w; ++x)
  141. {
  142. red = rs->GetPixelColor(x, y, 0);
  143. green = rs->GetPixelColor(x, y, 1);
  144. blue = rs->GetPixelColor(x, y, 2);
  145. *(input_data_ptr) = (float) red;
  146. input_data_ptr++;
  147. *(input_data_ptr) = (float) green;
  148. input_data_ptr++;
  149. *(input_data_ptr) = (float) blue;
  150. input_data_ptr++;
  151. }
  152. #ifdef DEBUG_DETAIL_ON
  153. LogFile.WriteHeapInfo("CTfLiteClass::LoadInputImageBasis - done");
  154. #endif
  155. return true;
  156. }
  157. bool CTfLiteClass::MakeAllocate()
  158. {
  159. MakeStaticResolver();
  160. #ifdef DEBUG_DETAIL_ON
  161. LogFile.WriteHeapInfo("CTLiteClass::Alloc start");
  162. #endif
  163. LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::MakeAllocate");
  164. this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize);
  165. // this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize, this->error_reporter);
  166. if (this->interpreter)
  167. {
  168. TfLiteStatus allocate_status = this->interpreter->AllocateTensors();
  169. if (allocate_status != kTfLiteOk) {
  170. TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
  171. LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "AllocateTensors() failed");
  172. this->GetInputDimension();
  173. return false;
  174. }
  175. }
  176. else
  177. {
  178. LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "new tflite::MicroInterpreter failed");
  179. LogFile.WriteHeapInfo("CTfLiteClass::MakeAllocate-new tflite::MicroInterpreter failed");
  180. return false;
  181. }
  182. #ifdef DEBUG_DETAIL_ON
  183. LogFile.WriteHeapInfo("CTLiteClass::Alloc done");
  184. #endif
  185. return true;
  186. }
  187. void CTfLiteClass::GetInputTensorSize()
  188. {
  189. #ifdef DEBUG_DETAIL_ON
  190. float *zw = this->input;
  191. int test = sizeof(zw);
  192. ESP_LOGD(TAG, "Input Tensor Dimension: %d", test);
  193. #endif
  194. }
  195. long CTfLiteClass::GetFileSize(std::string filename)
  196. {
  197. struct stat stat_buf;
  198. long rc = stat(filename.c_str(), &stat_buf);
  199. return rc == 0 ? stat_buf.st_size : -1;
  200. }
  201. bool CTfLiteClass::ReadFileToModel(std::string _fn)
  202. {
  203. LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::ReadFileToModel: " + _fn);
  204. long size = GetFileSize(_fn);
  205. if (size == -1)
  206. {
  207. LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "Model file doesn't exist: " + _fn + "!");
  208. return false;
  209. }
  210. else if(size > MAX_MODEL_SIZE) {
  211. LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "Unable to load model '" + _fn + "'! It does not fit in the reserved shared memory in PSRAM!");
  212. return false;
  213. }
  214. LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "Loading Model " + _fn + " /size: " + std::to_string(size) + " bytes...");
  215. #ifdef DEBUG_DETAIL_ON
  216. LogFile.WriteHeapInfo("CTLiteClass::Alloc modelfile start");
  217. #endif
  218. modelfile = (unsigned char*)psram_get_shared_model_memory();
  219. if(modelfile != NULL)
  220. {
  221. FILE* f = fopen(_fn.c_str(), "rb"); // previously only "r
  222. fread(modelfile, 1, size, f);
  223. fclose(f);
  224. #ifdef DEBUG_DETAIL_ON
  225. LogFile.WriteHeapInfo("CTLiteClass::Alloc modelfile successful");
  226. #endif
  227. return true;
  228. }
  229. else
  230. {
  231. LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "CTfLiteClass::ReadFileToModel: Can't allocate enough memory: " + std::to_string(size));
  232. LogFile.WriteHeapInfo("CTfLiteClass::ReadFileToModel");
  233. return false;
  234. }
  235. }
  236. bool CTfLiteClass::LoadModel(std::string _fn)
  237. {
  238. #ifdef SUPRESS_TFLITE_ERRORS
  239. // this->error_reporter = new tflite::ErrorReporter;
  240. this->error_reporter = new tflite::OwnMicroErrorReporter;
  241. #else
  242. this->error_reporter = new tflite::MicroErrorReporter;
  243. #endif
  244. LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::LoadModel");
  245. if (!ReadFileToModel(_fn.c_str())) {
  246. return false;
  247. }
  248. model = tflite::GetModel(modelfile);
  249. if(model == nullptr)
  250. return false;
  251. return true;
  252. }
  253. CTfLiteClass::CTfLiteClass()
  254. {
  255. this->model = nullptr;
  256. this->modelfile = NULL;
  257. this->interpreter = nullptr;
  258. this->input = nullptr;
  259. this->output = nullptr;
  260. this->kTensorArenaSize = TENSOR_ARENA_SIZE;
  261. this->tensor_arena = (uint8_t*)psram_get_shared_tensor_arena_memory();
  262. }
  263. CTfLiteClass::~CTfLiteClass()
  264. {
  265. delete this->interpreter;
  266. // delete this->error_reporter;
  267. psram_free_shared_tensor_arena_and_model_memory();
  268. }
  269. #ifdef SUPRESS_TFLITE_ERRORS
  270. namespace tflite
  271. {
  272. //tflite::ErrorReporter
  273. // int OwnMicroErrorReporter::Report(const char* format, va_list args)
  274. int OwnMicroErrorReporter::Report(const char* format, va_list args)
  275. {
  276. return 0;
  277. }
  278. }
  279. #endif