CTfLiteClass.cpp

#include "CTfLiteClass.h"
#include "ClassLogFile.h"
#include "Helper.h"
#include "psram.h"
#include "esp_log.h"
#include "../../include/defines.h"
#include <sys/stat.h>

// #define DEBUG_DETAIL_ON

static const char *TAG = "TFLITE";
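
// Register all TFLite Micro operators required by the model with the class' op resolver
// so the interpreter can resolve them when the tensors are allocated.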
void CTfLiteClass::MakeStaticResolver()
{
    resolver.AddFullyConnected();
    resolver.AddReshape();
    resolver.AddSoftmax();
    resolver.AddConv2D();
    resolver.AddMaxPool2D();
    resolver.AddQuantize();
    resolver.AddMul();
    resolver.AddAdd();
    resolver.AddLeakyRelu();
    resolver.AddDequantize();
}
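
// Return the value of output neuron 'nr' from the first output tensor,
// or -1000 if 'nr' is out of range.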
float CTfLiteClass::GetOutputValue(int nr)
{
    TfLiteTensor* output2 = this->interpreter->output(0);
    int numeroutput = output2->dims->data[1];

    if ((nr + 1) > numeroutput)
        return -1000;

    return output2->data.f[nr];
}
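
// Load the given image into the input tensor, run inference and return the index
// of the most probable class (or -1000 if the image could not be loaded).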
int CTfLiteClass::GetClassFromImageBasis(CImageBasis *rs)
{
    if (!LoadInputImageBasis(rs))
        return -1000;

    Invoke();

    return GetOutClassification();
}
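
// Find the output neuron with the highest activation within the range [_von, _bis]
// ("von"/"bis" is German for "from"/"to"); -1 selects the full output range.
// Returns the winning index relative to _von, or -1 on error.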
int CTfLiteClass::GetOutClassification(int _von, int _bis)
{
    TfLiteTensor* output2 = interpreter->output(0);
    float zw_max;
    float zw;
    int zw_class;

    if (output2 == NULL)
        return -1;

    int numeroutput = output2->dims->data[1];
    // ESP_LOGD(TAG, "number output neurons: %d", numeroutput);

    if (_bis == -1)
        _bis = numeroutput - 1;

    if (_von == -1)
        _von = 0;

    if (_bis >= numeroutput)
    {
        ESP_LOGD(TAG, "NUMBER OF OUTPUT NEURONS does not match required classification!");
        return -1;
    }

    zw_max = output2->data.f[_von];
    zw_class = _von;

    for (int i = _von + 1; i <= _bis; ++i)
    {
        zw = output2->data.f[i];
        if (zw > zw_max)
        {
            zw_max = zw;
            zw_class = i;
        }
    }

    return (zw_class - _von);
}
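
// Read the dimensions of the input tensor and cache height, width and channel count
// in im_height, im_width and im_channel (dimensions 1, 2 and 3, i.e. NHWC layout).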
void CTfLiteClass::GetInputDimension(bool silent = false)
{
    TfLiteTensor* input2 = this->interpreter->input(0);

    int numdim = input2->dims->size;
    if (!silent) ESP_LOGD(TAG, "NumDimension: %d", numdim);

    int sizeofdim;
    for (int j = 0; j < numdim; ++j)
    {
        sizeofdim = input2->dims->data[j];
        if (!silent) ESP_LOGD(TAG, "SizeOfDimension %d: %d", j, sizeofdim);
        if (j == 1) im_height = sizeofdim;
        if (j == 2) im_width = sizeofdim;
        if (j == 3) im_channel = sizeofdim;
    }
}
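
// Accessor for the cached input dimensions: 0 = width, 1 = height, 2 = channels; -1 otherwise.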
int CTfLiteClass::ReadInputDimenstion(int _dim)
{
    if (_dim == 0)
        return im_width;
    if (_dim == 1)
        return im_height;
    if (_dim == 2)
        return im_channel;

    return -1;
}
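
// Log the dimensions and values of the output tensor and return the number of
// output neurons ("Anz" is short for the German "Anzahl", i.e. count).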
int CTfLiteClass::GetAnzOutPut(bool silent)
{
    TfLiteTensor* output2 = this->interpreter->output(0);

    int numdim = output2->dims->size;
    if (!silent) ESP_LOGD(TAG, "NumDimension: %d", numdim);

    int sizeofdim;
    for (int j = 0; j < numdim; ++j)
    {
        sizeofdim = output2->dims->data[j];
        if (!silent) ESP_LOGD(TAG, "SizeOfDimension %d: %d", j, sizeofdim);
    }

    float fo;

    // Process the inference results.
    int numeroutput = output2->dims->data[1];
    for (int i = 0; i < numeroutput; ++i)
    {
        fo = output2->data.f[i];
        if (!silent) ESP_LOGD(TAG, "Result %d: %f", i, fo);
    }

    return numeroutput;
}
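
// Run inference on the currently loaded input tensor.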
void CTfLiteClass::Invoke()
{
    if (interpreter != nullptr)
        interpreter->Invoke();
}
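
// Copy an RGB image into the float input tensor, pixel by pixel
// (one float per color channel, row-major order).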
bool CTfLiteClass::LoadInputImageBasis(CImageBasis *rs)
{
#ifdef DEBUG_DETAIL_ON
    LogFile.WriteHeapInfo("CTfLiteClass::LoadInputImageBasis - Start");
#endif

    unsigned int w = rs->width;
    unsigned int h = rs->height;
    unsigned char red, green, blue;
    // ESP_LOGD(TAG, "Image: %s size: %d x %d\n", _fn.c_str(), w, h);

    input_i = 0;
    float* input_data_ptr = (interpreter->input(0))->data.f;

    for (int y = 0; y < h; ++y)
    {
        for (int x = 0; x < w; ++x)
        {
            red = rs->GetPixelColor(x, y, 0);
            green = rs->GetPixelColor(x, y, 1);
            blue = rs->GetPixelColor(x, y, 2);

            *(input_data_ptr) = (float) red;
            input_data_ptr++;
            *(input_data_ptr) = (float) green;
            input_data_ptr++;
            *(input_data_ptr) = (float) blue;
            input_data_ptr++;
        }
    }

#ifdef DEBUG_DETAIL_ON
    LogFile.WriteHeapInfo("CTfLiteClass::LoadInputImageBasis - done");
#endif

    return true;
}
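
// Create the MicroInterpreter on the shared tensor arena and allocate all tensors.
// Returns false if the interpreter could not be created or tensor allocation failed.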
bool CTfLiteClass::MakeAllocate()
{
    MakeStaticResolver();

#ifdef DEBUG_DETAIL_ON
    LogFile.WriteHeapInfo("CTfLiteClass::Alloc start");
#endif
    LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::MakeAllocate");

    this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize);
    LogFile.WriteToFile(ESP_LOG_INFO, TAG, "Trying to load the model. If it crashes here, it is most likely due to a corrupted model!");

    if (this->interpreter)
    {
        TfLiteStatus allocate_status = this->interpreter->AllocateTensors();
        if (allocate_status != kTfLiteOk)
        {
            LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "AllocateTensors() failed");
            this->GetInputDimension();
            return false;
        }
    }
    else
    {
        LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "new tflite::MicroInterpreter failed");
        LogFile.WriteHeapInfo("CTfLiteClass::MakeAllocate - new tflite::MicroInterpreter failed");
        return false;
    }

#ifdef DEBUG_DETAIL_ON
    LogFile.WriteHeapInfo("CTfLiteClass::Alloc done");
#endif

    return true;
}
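
// Debug-only helper: logs sizeof() of the input pointer (i.e. the pointer size,
// not the number of tensor elements).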
void CTfLiteClass::GetInputTensorSize()
{
#ifdef DEBUG_DETAIL_ON
    float *zw = this->input;
    int test = sizeof(zw);
    ESP_LOGD(TAG, "Input Tensor Dimension: %d", test);
#endif
}
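
// Return the size of a file in bytes, or -1 if it cannot be opened or stat'ed.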
long CTfLiteClass::GetFileSize(std::string filename)
{
    struct stat stat_buf;
    long rc = -1;

    FILE *pFile = fopen(filename.c_str(), "rb");   // check that the file can actually be opened
    if (pFile != NULL)
    {
        rc = stat(filename.c_str(), &stat_buf);
        fclose(pFile);
    }

    return rc == 0 ? stat_buf.st_size : -1;
}
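
// Read the model file into the shared model buffer in PSRAM.
// Fails if the file is missing, larger than MAX_MODEL_SIZE, or no PSRAM buffer is available.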
bool CTfLiteClass::ReadFileToModel(std::string _fn)
{
    LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::ReadFileToModel: " + _fn);

    long size = GetFileSize(_fn);
    if (size == -1)
    {
        LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "Model file doesn't exist: " + _fn + "!");
        return false;
    }
    else if (size > MAX_MODEL_SIZE)
    {
        LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "Unable to load model '" + _fn + "'! It does not fit in the reserved shared memory in PSRAM!");
        return false;
    }

    LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "Loading model " + _fn + " / size: " + std::to_string(size) + " bytes...");

#ifdef DEBUG_DETAIL_ON
    LogFile.WriteHeapInfo("CTfLiteClass::Alloc modelfile start");
#endif

    modelfile = (unsigned char*)psram_get_shared_model_memory();
    if (modelfile != NULL)
    {
        FILE *pFile = fopen(_fn.c_str(), "rb");   // binary read mode
        if (pFile != NULL)
        {
            fread(modelfile, 1, size, pFile);
            fclose(pFile);

#ifdef DEBUG_DETAIL_ON
            LogFile.WriteHeapInfo("CTfLiteClass::Alloc modelfile successful");
#endif
            return true;
        }
        else
        {
            LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "CTfLiteClass::ReadFileToModel: Model does not exist");
            return false;
        }
    }
    else
    {
        LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "CTfLiteClass::ReadFileToModel: Can't allocate enough memory: " + std::to_string(size));
        LogFile.WriteHeapInfo("CTfLiteClass::ReadFileToModel");
        return false;
    }
}
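
// Read the model file into PSRAM and map it with tflite::GetModel().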
bool CTfLiteClass::LoadModel(std::string _fn)
{
    LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::LoadModel");

    if (!ReadFileToModel(_fn.c_str())) {
        return false;
    }

    model = tflite::GetModel(modelfile);
    if (model == nullptr)
        return false;

    return true;
}
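
// Initialize all members and grab the shared tensor arena from PSRAM.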
CTfLiteClass::CTfLiteClass()
{
    this->model = nullptr;
    this->modelfile = NULL;
    this->interpreter = nullptr;
    this->input = nullptr;
    this->output = nullptr;
    this->kTensorArenaSize = TENSOR_ARENA_SIZE;
    this->tensor_arena = (uint8_t*)psram_get_shared_tensor_arena_memory();
}
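
// Delete the interpreter and release the shared PSRAM buffers (tensor arena and model).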
CTfLiteClass::~CTfLiteClass()
{
    delete this->interpreter;
    psram_free_shared_tensor_arena_and_model_memory();
}