Explorar o código

fix(tflite): Fix memory leaks in tflite integration (#2842)

Slider0007 %!s(int64=2) %!d(string=hai) anos
pai
achega
0e0fb459dc

+ 2 - 37
code/components/jomjol_tfliteclass/CTfLiteClass.cpp

@@ -12,17 +12,9 @@
 
 static const char *TAG = "TFLITE";
 
-/// Static Resolver muss mit allen Operatoren geladen Werden, die benöägit werden - ABER nur 1x --> gesonderte Funktion /////////////////////////////
-static bool MakeStaticResolverDone = false;
-static tflite::MicroMutableOpResolver<15> resolver;
 
-void MakeStaticResolver()
+void CTfLiteClass::MakeStaticResolver()
 {
-  if (MakeStaticResolverDone)
-    return;
-
-  MakeStaticResolverDone = true;
-
   resolver.AddFullyConnected();
   resolver.AddReshape();
   resolver.AddSoftmax();
@@ -34,7 +26,6 @@ void MakeStaticResolver()
   resolver.AddLeakyRelu();
   resolver.AddDequantize();
 }
-////////////////////////////////////////////////////////////////////////////////////////
 
 
 float CTfLiteClass::GetOutputValue(int nr)
@@ -207,9 +198,7 @@ bool CTfLiteClass::LoadInputImageBasis(CImageBasis *rs)
 
 bool CTfLiteClass::MakeAllocate()
 {
-
-  MakeStaticResolver();
-
+    MakeStaticResolver();
 
     #ifdef DEBUG_DETAIL_ON 
         LogFile.WriteHeapInfo("CTLiteClass::Alloc start");
@@ -217,13 +206,11 @@ bool CTfLiteClass::MakeAllocate()
 
     LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::MakeAllocate");
     this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize);
-//    this->interpreter = new tflite::MicroInterpreter(this->model, resolver, this->tensor_arena, this->kTensorArenaSize, this->error_reporter);
 
     if (this->interpreter) 
     {
         TfLiteStatus allocate_status = this->interpreter->AllocateTensors();
         if (allocate_status != kTfLiteOk) {
-            TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
             LogFile.WriteToFile(ESP_LOG_ERROR, TAG, "AllocateTensors() failed");
 
             this->GetInputDimension();   
@@ -313,13 +300,6 @@ bool CTfLiteClass::ReadFileToModel(std::string _fn)
 
 bool CTfLiteClass::LoadModel(std::string _fn)
 {
-#ifdef SUPRESS_TFLITE_ERRORS
-//    this->error_reporter = new tflite::ErrorReporter;
-    this->error_reporter = new tflite::OwnMicroErrorReporter;
-#else
-    this->error_reporter = new tflite::MicroErrorReporter;
-#endif
-
     LogFile.WriteToFile(ESP_LOG_DEBUG, TAG, "CTfLiteClass::LoadModel");
 
     if (!ReadFileToModel(_fn.c_str())) {
@@ -350,21 +330,6 @@ CTfLiteClass::CTfLiteClass()
 CTfLiteClass::~CTfLiteClass()
 {
   delete this->interpreter;
-//  delete this->error_reporter;
 
   psram_free_shared_tensor_arena_and_model_memory();
 }        
-
-#ifdef SUPRESS_TFLITE_ERRORS
-namespace tflite 
-{
-//tflite::ErrorReporter
-//  int OwnMicroErrorReporter::Report(const char* format, va_list args) 
-
-  int OwnMicroErrorReporter::Report(const char* format, va_list args) 
-  {
-    return 0;
-  }
-} 
-#endif
- 

+ 3 - 25
code/components/jomjol_tfliteclass/CTfLiteClass.h

@@ -5,39 +5,18 @@
 
 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
 #include "tensorflow/lite/micro/micro_interpreter.h"
-#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
-#include "tensorflow/lite/micro/kernels/micro_ops.h"
-
-#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
-#include "tensorflow/lite/micro/micro_interpreter.h"
 #include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/micro/kernels/micro_ops.h"
+
 #include "esp_err.h"
 #include "esp_log.h"
 
 #include "CImageBasis.h"
 
 
-
-#ifdef SUPRESS_TFLITE_ERRORS
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/micro/compatibility.h"
-#include "tensorflow/lite/micro/debug_log.h"
-///// OwnErrorReporter to prevent printing of Errors (especially unavoidable in CalculateActivationRangeQuantized@kerne_util.cc)
-namespace tflite {
-    class OwnMicroErrorReporter : public ErrorReporter {
-        public:
-           int Report(const char* format, va_list args) override;
-    };
-}  // namespace tflite
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-#endif
-
-
 class CTfLiteClass
 {
     protected:
-        tflite::ErrorReporter *error_reporter;
+        tflite::MicroMutableOpResolver<10> resolver;  
         const tflite::Model* model;
         tflite::MicroInterpreter* interpreter;
         TfLiteTensor* output = nullptr;     
@@ -54,6 +33,7 @@ class CTfLiteClass
 
         long GetFileSize(std::string filename);
         bool ReadFileToModel(std::string _fn);
+        void MakeStaticResolver();
 
     public:
         CTfLiteClass();
@@ -74,6 +54,4 @@ class CTfLiteClass
         int ReadInputDimenstion(int _dim);
 };
 
-void MakeStaticResolver();
-
 #endif //CTFLITECLASS_H

+ 0 - 9
code/include/defines.h

@@ -167,15 +167,6 @@
     #define LWT_DISCONNECTED "connection lost"
 
 
-    //CTfLiteClass
-    #define TFLITE_MINIMAL_CHECK(x)                              \
-        if (!(x)) {                                                \
-            fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
-            exit(1);                                                 \
-        }
-    // #define SUPRESS_TFLITE_ERRORS // use, to avoid error messages from TFLITE
-
-
     // connect_wlan.cpp
     //******************************
     /* WIFI roaming functionalities 802.11k+v (uses ca. 6kB - 8kB internal RAM; if SCAN CACHE activated: + 1kB / beacon)