/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace select {

constexpr int kInputTensorCondition = 0;
constexpr int kInputTensorX = 1;
constexpr int kInputTensorY = 2;
constexpr int kOutputTensor = 0;

enum KernelType {
  kVersionOne,
  kVersionTwo,
};

struct OpData {
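  // True if the inputs do not all share a shape and the SelectV2 kernel must
  // broadcast them to a common output shape.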
  bool requires_broadcast;
  // True if input condition is scalar or input condition has rank one and
  // matches the first dimension of other inputs.
  bool has_low_rank_input_condition;
};

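// Allocates the per-node OpData state; it is released by SelectFree when the
// node is destroyed.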
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  data->has_low_rank_input_condition = false;
  return data;
}

void SelectFree(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

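// Validates the node's inputs and computes the output shape. The legacy
// Select kernel (kVersionOne) accepts a condition that is scalar or rank-1;
// SelectV2 (kVersionTwo) accepts any shapes that broadcast to a common shape.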
template <KernelType kernel_type>
TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input_condition;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
                                          &input_condition));
  const TfLiteTensor* input_x;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorX, &input_x));
  const TfLiteTensor* input_y;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorY, &input_y));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // The condition input must be bool; 'x' and 'y' must share a type, which
  // also becomes the output type.
  TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
  TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
  output->type = input_x->type;

  // When every tensor holds exactly one element, keep the output's original
  // shape even if the inputs mix scalar and single-element shapes.
  if (GetTensorShape(input_condition).FlatSize() == 1 &&
      GetTensorShape(input_x).FlatSize() == 1 &&
      GetTensorShape(input_y).FlatSize() == 1 &&
      GetTensorShape(output).FlatSize() == 1) {
    return kTfLiteOk;
  }

  bool same_shape = HaveSameShapes(input_condition, input_x) &&
                    HaveSameShapes(input_x, input_y);
  TfLiteIntArray* output_size;
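  // When the shapes are not all identical, the two kernel versions diverge:
  // Select only tolerates a scalar or rank-1 condition, while SelectV2
  // computes a broadcast output shape from all three inputs.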
  if (!same_shape) {
    switch (kernel_type) {
      case kVersionOne: {
        bool is_input_condition_scalar = NumDimensions(input_condition) == 0;
        bool has_rank_one_input_condition =
            NumDimensions(input_condition) == 1 &&
            SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0);
        data->has_low_rank_input_condition =
            is_input_condition_scalar || has_rank_one_input_condition;
        TF_LITE_ENSURE(context, data->has_low_rank_input_condition);

        output_size = TfLiteIntArrayCopy(input_x->dims);

        // 'x' and 'y' must have the same shape (their types were already
        // checked above).
        TF_LITE_ENSURE(context, HaveSameShapes(input_x, input_y));
        break;
      }
      case kVersionTwo: {
        TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                       context, input_condition, input_x,
                                       input_y, &output_size));
        data->requires_broadcast = true;
        break;
      }
      default:
        return kTfLiteError;
    }
  } else {
    output_size = TfLiteIntArrayCopy(input_x->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input_condition;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
                                          &input_condition));
  const TfLiteTensor* input_x;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorX, &input_x));
  const TfLiteTensor* input_y;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorY, &input_y));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

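// TF_LITE_SELECT expands to one call into the reference implementation for a
// concrete data type; TF_LITE_SWITCH dispatches that call on the runtime
// tensor type.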
#define TF_LITE_SELECT(type, op)                                           \
  reference_ops::op(GetTensorShape(input_condition),                       \
                    GetTensorData<bool>(input_condition),                   \
                    GetTensorShape(input_x), GetTensorData<type>(input_x), \
                    GetTensorShape(input_y), GetTensorData<type>(input_y), \
                    GetTensorShape(output), GetTensorData<type>(output));

#define TF_LITE_SWITCH(type, op)                                           \
  switch (type) {                                                          \
    case kTfLiteBool:                                                      \
      TF_LITE_SELECT(bool, op);                                            \
      break;                                                               \
    case kTfLiteFloat32:                                                   \
      TF_LITE_SELECT(float, op);                                           \
      break;                                                               \
    case kTfLiteUInt8:                                                     \
      TF_LITE_SELECT(uint8_t, op);                                         \
      break;                                                               \
    case kTfLiteInt8:                                                      \
      TF_LITE_SELECT(int8_t, op);                                          \
      break;                                                               \
    case kTfLiteInt16:                                                     \
      TF_LITE_SELECT(int16_t, op);                                         \
      break;                                                               \
    case kTfLiteInt32:                                                     \
      TF_LITE_SELECT(int32_t, op);                                         \
      break;                                                               \
    case kTfLiteInt64:                                                     \
      TF_LITE_SELECT(int64_t, op);                                         \
      break;                                                               \
    default:                                                               \
      TF_LITE_KERNEL_LOG(context,                                          \
                         "Does not support type other than bool|float|int, " \
                         "got %d",                                         \
                         type);                                            \
      return kTfLiteError;                                                 \
  }

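  // Pick the evaluation path chosen during Prepare: row/scalar selection for
  // a low-rank condition, broadcasting selection when shapes differ under
  // SelectV2, and plain elementwise selection otherwise.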
  if (data->has_low_rank_input_condition) {
    TF_LITE_SWITCH(input_x->type, RankOneSelect);
  } else if (data->requires_broadcast) {
    TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
  } else {
    TF_LITE_SWITCH(input_x->type, Select);
  }

#undef TF_LITE_SELECT
#undef TF_LITE_SWITCH
  return kTfLiteOk;
}

}  // namespace select

// Select op selects values of 'x' if the corresponding value of 'condition'
// is true, or the value of 'y' if it is false. The valid condition input
// shapes are:
//
// 1. the same shape as 'x' and 'y' (in which case the select is elementwise),
// 2. rank 1, matching the first dimension of 'x' and 'y', or
// 3. scalar.
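//
// For example (illustrative only), a rank-1 condition selects whole rows:
// condition = [true, false], x = [[1, 2], [3, 4]], y = [[5, 6], [7, 8]]
// yields [[1, 2], [7, 8]].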
TfLiteRegistration* Register_SELECT() {
  static TfLiteRegistration r = {select::SelectInit, select::SelectFree,
                                 select::SelectPrepare<select::kVersionOne>,
                                 select::SelectEval};
  return &r;
}

// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
// is true, or the value of 'y' if it is false. The valid condition input
// shapes are:
//
// 1. the same shape as 'x' and 'y' (in which case the select is elementwise),
// 2. any shapes broadcastable between 'condition', 'x' and 'y'.
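//
// For example (illustrative only), with broadcasting:
// condition = [[true], [false]] (shape [2, 1]), x = [[1, 2]] (shape [1, 2]),
// and y = 7 (scalar) yield [[1, 2], [7, 7]] (shape [2, 2]).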
TfLiteRegistration* Register_SELECT_V2() {
  static TfLiteRegistration r = {select::SelectInit, select::SelectFree,
                                 select::SelectPrepare<select::kVersionTwo>,
                                 select::SelectEval};
  return &r;
}
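
// A minimal usage sketch (an assumption, not part of this file): with the
// standard tflite::MutableOpResolver API, these kernels would typically be
// registered as
//   resolver.AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
//   resolver.AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());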

}  // namespace builtin
}  // namespace ops
}  // namespace tflite