1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stddef.h> |
16 | #include <stdint.h> |
17 | |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/tensor.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
22 | #include "tensorflow/lite/kernels/kernel_util.h" |
23 | |
24 | namespace tflite { |
25 | namespace ops { |
26 | namespace builtin { |
27 | namespace select { |
28 | |
// Tensor indices for this kernel: the inputs are ordered (condition, x, y) and
// there is exactly one output.
constexpr int kInputTensorCondition = 0;  // bool selector tensor
constexpr int kInputTensorX = kInputTensorCondition + 1;
constexpr int kInputTensorY = kInputTensorX + 1;
constexpr int kOutputTensor = 0;
33 | |
// Chooses how SelectPrepare validates mismatched shapes: kVersionOne backs the
// SELECT op (scalar or rank-1 condition only), kVersionTwo backs SELECT_V2
// (general broadcasting between condition, x and y).
enum KernelType {
  kVersionOne = 0,
  kVersionTwo = kVersionOne + 1,
};
38 | |
// Per-node state written by SelectPrepare and read by SelectEval to choose the
// reference kernel. Both flags default to false so the struct is never read
// uninitialized, regardless of how it is constructed.
struct OpData {
  // True when the input shapes differ and the SELECT_V2 path must broadcast
  // them.
  bool requires_broadcast = false;
  // True if input condition is scalar or input condition has rank one and
  // matches the first dimension of other inputs.
  bool has_low_rank_input_condition = false;
};
45 | |
46 | void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) { |
47 | auto* data = new OpData; |
48 | data->requires_broadcast = false; |
49 | data->has_low_rank_input_condition = false; |
50 | return data; |
51 | } |
52 | |
53 | void SelectFree(TfLiteContext* context, void* buffer) { |
54 | delete reinterpret_cast<OpData*>(buffer); |
55 | } |
56 | |
57 | template <KernelType kernel_type> |
58 | TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) { |
59 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
60 | |
61 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
62 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
63 | |
64 | const TfLiteTensor* input_condition; |
65 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition, |
66 | &input_condition)); |
67 | const TfLiteTensor* input_x; |
68 | TF_LITE_ENSURE_OK(context, |
69 | GetInputSafe(context, node, kInputTensorX, &input_x)); |
70 | const TfLiteTensor* input_y; |
71 | TF_LITE_ENSURE_OK(context, |
72 | GetInputSafe(context, node, kInputTensorY, &input_y)); |
73 | TfLiteTensor* output; |
74 | TF_LITE_ENSURE_OK(context, |
75 | GetOutputSafe(context, node, kOutputTensor, &output)); |
76 | |
77 | // Input must be bool. |
78 | TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool); |
79 | TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type); |
80 | output->type = input_x->type; |
81 | |
82 | // Respect the original output shape when there are mixed shapes to represent |
83 | // a scalar data. |
84 | if (GetTensorShape(input_condition).FlatSize() == 1 && |
85 | GetTensorShape(input_x).FlatSize() == 1 && |
86 | GetTensorShape(input_y).FlatSize() == 1 && |
87 | GetTensorShape(output).FlatSize() == 1) { |
88 | return kTfLiteOk; |
89 | } |
90 | |
91 | bool same_shape = HaveSameShapes(input_condition, input_x) && |
92 | HaveSameShapes(input_x, input_y); |
93 | TfLiteIntArray* output_size; |
94 | if (!same_shape) { |
95 | switch (kernel_type) { |
96 | case kVersionOne: { |
97 | bool is_input_condition_scalar = NumDimensions(input_condition) == 0; |
98 | bool has_rank_one_input_condition = |
99 | NumDimensions(input_condition) == 1 && |
100 | SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0); |
101 | data->has_low_rank_input_condition = |
102 | is_input_condition_scalar || has_rank_one_input_condition; |
103 | TF_LITE_ENSURE(context, data->has_low_rank_input_condition); |
104 | |
105 | output_size = TfLiteIntArrayCopy(input_x->dims); |
106 | |
107 | // Input tensors must have the same type and size |
108 | TF_LITE_ENSURE(context, HaveSameShapes(input_x, input_y)); |
109 | break; |
110 | } |
111 | case kVersionTwo: { |
112 | TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( |
113 | context, input_condition, input_x, |
114 | input_y, &output_size)); |
115 | data->requires_broadcast = true; |
116 | break; |
117 | } |
118 | default: |
119 | return kTfLiteError; |
120 | } |
121 | } else { |
122 | output_size = TfLiteIntArrayCopy(input_x->dims); |
123 | } |
124 | |
125 | return context->ResizeTensor(context, output, output_size); |
126 | } |
127 | |
128 | TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { |
129 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
130 | const TfLiteTensor* input_condition; |
131 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition, |
132 | &input_condition)); |
133 | const TfLiteTensor* input_x; |
134 | TF_LITE_ENSURE_OK(context, |
135 | GetInputSafe(context, node, kInputTensorX, &input_x)); |
136 | const TfLiteTensor* input_y; |
137 | TF_LITE_ENSURE_OK(context, |
138 | GetInputSafe(context, node, kInputTensorY, &input_y)); |
139 | TfLiteTensor* output; |
140 | TF_LITE_ENSURE_OK(context, |
141 | GetOutputSafe(context, node, kOutputTensor, &output)); |
142 | |
143 | #define TF_LITE_SELECT(type, op) \ |
144 | reference_ops::op(GetTensorShape(input_condition), \ |
145 | GetTensorData<bool>(input_condition), \ |
146 | GetTensorShape(input_x), GetTensorData<type>(input_x), \ |
147 | GetTensorShape(input_y), GetTensorData<type>(input_y), \ |
148 | GetTensorShape(output), GetTensorData<type>(output)); |
149 | |
150 | #define TF_LITE_SWITCH(type, op) \ |
151 | switch (type) { \ |
152 | break; \ |
153 | case kTfLiteBool: \ |
154 | TF_LITE_SELECT(bool, op); \ |
155 | break; \ |
156 | case kTfLiteFloat32: \ |
157 | TF_LITE_SELECT(float, op); \ |
158 | break; \ |
159 | case kTfLiteUInt8: \ |
160 | TF_LITE_SELECT(uint8_t, op); \ |
161 | break; \ |
162 | case kTfLiteInt8: \ |
163 | TF_LITE_SELECT(int8_t, op); \ |
164 | break; \ |
165 | case kTfLiteInt16: \ |
166 | TF_LITE_SELECT(int16_t, op); \ |
167 | break; \ |
168 | case kTfLiteInt32: \ |
169 | TF_LITE_SELECT(int32_t, op); \ |
170 | break; \ |
171 | case kTfLiteInt64: \ |
172 | TF_LITE_SELECT(int64_t, op); \ |
173 | break; \ |
174 | default: \ |
175 | TF_LITE_KERNEL_LOG(context, \ |
176 | "Does not support type other than bool|float|int, " \ |
177 | "got %d", \ |
178 | type); \ |
179 | return kTfLiteError; \ |
180 | } |
181 | |
182 | if (data->has_low_rank_input_condition) { |
183 | TF_LITE_SWITCH(input_x->type, RankOneSelect); |
184 | } else if (data->requires_broadcast) { |
185 | TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow); |
186 | } else { |
187 | TF_LITE_SWITCH(input_x->type, Select); |
188 | } |
189 | |
190 | #undef TF_LITE_SELECT |
191 | #undef TF_LITE_SWITCH |
192 | return kTfLiteOk; |
193 | } |
194 | |
195 | } // namespace select |
196 | |
197 | // Select op selects values of 'x' if the corresponding value of 'condition' is |
198 | // true or the value of 'y' if false. There are valid condition input sizes: |
199 | // |
200 | // 1. Either the same shape (in which case the select is elementwise), or |
201 | // 2. condition must be Rank 1 and match over the first dimension, or |
202 | // 3. condition is scalar |
203 | TfLiteRegistration* Register_SELECT() { |
204 | static TfLiteRegistration r = {select::SelectInit, select::SelectFree, |
205 | select::SelectPrepare<select::kVersionOne>, |
206 | select::SelectEval}; |
207 | return &r; |
208 | } |
209 | |
210 | // SelectV2 op selects values of 'x' if the corresponding value of 'condition' |
211 | // is true or the value of 'y' if false. There are valid condition input sizes: |
212 | // |
213 | // 1. Either the same shape (in which case the select is elementwise), or |
214 | // 2. Broadcastable shapes between 'condition', 'x' and 'y'. |
215 | TfLiteRegistration* Register_SELECT_V2() { |
216 | static TfLiteRegistration r = {select::SelectInit, select::SelectFree, |
217 | select::SelectPrepare<select::kVersionTwo>, |
218 | select::SelectEval}; |
219 | return &r; |
220 | } |
221 | |
222 | } // namespace builtin |
223 | } // namespace ops |
224 | } // namespace tflite |
225 | |