if.cc

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>

#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/schema/schema_generated.h"
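
// This kernel implements the builtin IF operator for TFLite Micro: the first
// input is a scalar boolean condition, the remaining inputs are copied into
// either the "then" or the "else" subgraph, the selected subgraph is invoked,
// and its outputs are copied back into the node's outputs.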

namespace tflite {

namespace {

// Indices of the two branch subgraphs, resolved from the builtin IF params in
// Prepare() and stored in persistent memory allocated in Init().
struct OpData {
  int then_subgraph_index;
  int else_subgraph_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const auto* params =
      reinterpret_cast<const TfLiteIfParams*>(node->builtin_data);
  op_data->then_subgraph_index = params->then_subgraph_index;
  op_data->else_subgraph_index = params->else_subgraph_index;

  TF_LITE_ENSURE(context, node->inputs->size > 0);

  // The first input is the condition.
  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
  TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);

  // The first input of the node is the condition. The rest of the inputs are
  // passed to the branch subgraphs. Therefore, the number of subgraph inputs
  // will be the number of node inputs - 1.
  size_t num_inputs = node->inputs->size - 1;
  size_t num_outputs = node->outputs->size;

  // Casting to TfLiteIntArray is required since we are re-using
  // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
  // MicroGraph.
  // TODO(b/188226309): Design a cleaner way to get a graph from kernel
  // context.
  MicroGraph* graph_info;
  context->GetExecutionPlan(context,
                            reinterpret_cast<TfLiteIntArray**>(&graph_info));

  TF_LITE_ENSURE(context,
                 op_data->then_subgraph_index < graph_info->NumSubgraphs());
  TF_LITE_ENSURE(context,
                 op_data->else_subgraph_index < graph_info->NumSubgraphs());

  TF_LITE_ENSURE_EQ(
      context, num_inputs,
      graph_info->NumSubgraphInputs(op_data->then_subgraph_index));
  TF_LITE_ENSURE_EQ(
      context, num_outputs,
      graph_info->NumSubgraphOutputs(op_data->then_subgraph_index));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  bool cond_value = cond->data.b[0];

  // Casting to TfLiteIntArray is required since we are re-using
  // GetExecutionPlan from TfLiteContext. On TFLM this method returns a
  // MicroGraph.
  // TODO(b/188226309): Design a cleaner way to get a graph from kernel
  // context.
  MicroGraph* graph_info;
  context->GetExecutionPlan(context,
                            reinterpret_cast<TfLiteIntArray**>(&graph_info));

  // Currently we copy the input / output between the subgraphs. This isn't
  // optimized yet.
  int active_branch_subgraph_index =
      cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;

  for (size_t i = 0;
       i < graph_info->NumSubgraphInputs(active_branch_subgraph_index); ++i) {
    const TfLiteEvalTensor* input =
        tflite::micro::GetEvalInput(context, node, i + 1);

    TfLiteEvalTensor* subgraph_input =
        graph_info->GetSubgraphInput(active_branch_subgraph_index, i);

    // These checks must occur in Eval since TfLiteEvalTensors are not
    // available during Prepare.
    size_t input_bytes;
    size_t subgraph_input_bytes;
    TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(input, &input_bytes));
    TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(
                                   subgraph_input, &subgraph_input_bytes));
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
    TF_LITE_ENSURE_EQ(context, input_bytes, subgraph_input_bytes);
    memcpy(subgraph_input->data.raw, input->data.raw, input_bytes);
  }

  TF_LITE_ENSURE_OK(context,
                    graph_info->InvokeSubgraph(active_branch_subgraph_index));

  for (size_t i = 0;
       i < graph_info->NumSubgraphOutputs(active_branch_subgraph_index); ++i) {
    const TfLiteEvalTensor* output =
        tflite::micro::GetEvalOutput(context, node, i);

    TfLiteEvalTensor* subgraph_output =
        graph_info->GetSubgraphOutput(active_branch_subgraph_index, i);

    // These checks must occur in Eval since TfLiteEvalTensors are not
    // available during Prepare.
    size_t output_bytes;
    size_t subgraph_output_bytes;
    TF_LITE_ENSURE_OK(context,
                      TfLiteEvalTensorByteLength(output, &output_bytes));
    TF_LITE_ENSURE_OK(context, TfLiteEvalTensorByteLength(
                                   subgraph_output, &subgraph_output_bytes));
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, subgraph_output->type);
    TF_LITE_ENSURE_EQ(context, output_bytes, subgraph_output_bytes);
    memcpy(output->data.raw, subgraph_output->data.raw, output_bytes);
  }

  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_IF() {
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite