/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace comparisons {
namespace {
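// OpData holds the ComparisonParams consumed by the reference kernels. For
// the int8 path, Prepare() fills in the offsets, multipliers and shifts used
// by the *WithScaling variants; the float/int paths ignore these fields.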
struct OpData {
  ComparisonParams params;
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

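// Each *Eval function below follows the same pattern: fetch the two inputs
// and the boolean output, then dispatch on the input type. When the input
// shapes differ, the Broadcast4DSlow* variant is used; int8 inputs go through
// the *WithScaling variants so both operands are compared in a common
// rescaled domain, while all other types use the *NoScaling variants.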
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// TODO(renjieliu): Refactor the logic to avoid duplications.
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace

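// Init allocates a single OpData instance from the interpreter's persistent
// arena; TFLM kernels do not heap-allocate, so this buffer lives for the
// lifetime of the interpreter and is handed back to Prepare/Eval via
// node->user_data.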
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

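// For int8 inputs, Prepare precomputes the parameters the *WithScaling
// kernels use to bring both operands into a common fixed-point domain: the
// negated zero points as offsets, plus a multiplier/shift pair per input
// derived from each tensor's scale via QuantizeMultiplierSmallerThanOneExp,
// applied after an 8-bit left shift.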
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* input1 =
      micro_context->AllocateTempInputTensor(node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  TfLiteTensor* input2 =
      micro_context->AllocateTempInputTensor(node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);

  if (input1->type == kTfLiteInt8) {
    auto input1_offset = -input1->params.zero_point;
    auto input2_offset = -input2->params.zero_point;
    const int kLeftShift = 8;

    int32_t input1_multiplier;
    int input1_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input1->params.scale), &input1_multiplier,
        &input1_shift);
    int32_t input2_multiplier;
    int input2_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input2->params.scale), &input2_multiplier,
        &input2_shift);

    data->params.left_shift = kLeftShift;
    data->params.input1_offset = input1_offset;
    data->params.input1_multiplier = input1_multiplier;
    data->params.input1_shift = input1_shift;
    data->params.input2_offset = input2_offset;
    data->params.input2_multiplier = input2_multiplier;
    data->params.input2_shift = input2_shift;
  }

  micro_context->DeallocateTempTfLiteTensor(input1);
  micro_context->DeallocateTempTfLiteTensor(input2);
  return kTfLiteOk;
}

}  // namespace comparisons

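// Each registration wires the shared Init/Prepare into the comparison-specific
// eval function; the interpreter looks these up by builtin operator code.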
TfLiteRegistration Register_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::EqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_NOT_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::NotEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
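
// Usage sketch: an application would typically pull one of these kernels into
// its op resolver before building the interpreter, for example (assuming the
// Add* helpers of MicroMutableOpResolver are available in this version of
// TFLM and resolve to the Register_* functions above):
//
//   static tflite::MicroMutableOpResolver<2> resolver;
//   resolver.AddEqual();
//   resolver.AddGreater();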