/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace where {
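
// The WHERE op reports the coordinates of all non-zero ("true") elements of
// its condition tensor as a 2-D int64 tensor of shape (num_true, cond_rank).
// For example, a 2x2 bool condition [[true, false], [false, true]] yields the
// output [[0, 0], [1, 1]].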

constexpr int kInputConditionTensor = 0;
constexpr int kOutputTensor = 0;

template <typename T>
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const TfLiteTensor* cond_tensor,
                                TfLiteTensor* output_tensor) {
  // Output tensor should have shape:
  // (num_true, cond_rank), where num_true denotes the number of true values
  // in condition.
  const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
  const int size = cond_shape.FlatSize();
  const int cond_rank = cond_shape.DimensionsCount();
  const T* cond_data = GetTensorData<T>(cond_tensor);

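  // Count the true (non-zero) condition values; this count becomes the first
  // output dimension.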
  int true_count = 0;
  for (int i = 0; i < size; ++i) {
    if (cond_data[i] != T(0)) {
      true_count++;
    }
  }
  TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
  output_dims->data[0] = true_count;
  output_dims->data[1] = cond_rank;
  return context->ResizeTensor(context, output_tensor, output_dims);
}

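// Sets the output type to int64 and, when the condition is a constant tensor,
// resizes the output immediately; otherwise marks the output dynamic so it
// can be resized in Eval.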
template <typename T>
TfLiteStatus PrepareOutput(TfLiteContext* context,
                           const TfLiteTensor* cond_tensor,
                           TfLiteTensor* output) {
  // As output will be a 2D tensor of indices, use int64 to be consistent with
  // tensorflow.
  output->type = kTfLiteInt64;

  // Exit early if cond is a non-const tensor. Set output tensor to dynamic so
  // output size can be determined in Eval.
  if (!IsConstantTensor(cond_tensor)) {
    SetTensorToDynamic(output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor<T>(context, cond_tensor, output);
}

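// Validates the node signature and dispatches output preparation on the
// condition tensor's element type.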
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* cond_tensor;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
                                          &cond_tensor));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (cond_tensor->type) {
    case kTfLiteBool:
      return PrepareOutput<bool>(context, cond_tensor, output);
    case kTfLiteFloat32:
      return PrepareOutput<float>(context, cond_tensor, output);
    case kTfLiteInt64:
      return PrepareOutput<int64_t>(context, cond_tensor, output);
    case kTfLiteInt32:
      return PrepareOutput<int32_t>(context, cond_tensor, output);
    case kTfLiteInt8:
      return PrepareOutput<int8_t>(context, cond_tensor, output);
    case kTfLiteUInt8:
      return PrepareOutput<uint8_t>(context, cond_tensor, output);
    case kTfLiteUInt32:
      return PrepareOutput<uint32_t>(context, cond_tensor, output);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Condition tensor has unsupported type: '%s'.",
                         TfLiteTypeGetName(cond_tensor->type));
  }
  return kTfLiteOk;
}

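// Resizes the output if it is dynamic, then writes the coordinates of every
// true condition element as int64 indices.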
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* cond_tensor;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
                                          &cond_tensor));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

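  // If the condition was not constant at Prepare time, the output is dynamic
  // and must be resized here, now that the condition values are available.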
  if (IsDynamicTensor(output)) {
    switch (cond_tensor->type) {
      case kTfLiteBool:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<bool>(context, cond_tensor, output));
        break;
      case kTfLiteFloat32:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<float>(context, cond_tensor, output));
        break;
      case kTfLiteInt64:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<int64_t>(context, cond_tensor, output));
        break;
      case kTfLiteInt32:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<int32_t>(context, cond_tensor, output));
        break;
      case kTfLiteInt8:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<int8_t>(context, cond_tensor, output));
        break;
      case kTfLiteUInt8:
        TF_LITE_ENSURE_OK(
            context, ResizeOutputTensor<uint8_t>(context, cond_tensor, output));
        break;
      case kTfLiteUInt32:
        TF_LITE_ENSURE_OK(context, ResizeOutputTensor<uint32_t>(
                                       context, cond_tensor, output));
        break;
      default:
        TF_LITE_KERNEL_LOG(context,
                           "Condition tensor has unsupported type: '%s'.",
                           TfLiteTypeGetName(cond_tensor->type));
    }
  }

  TfLiteIntArray* dims = cond_tensor->dims;
  if (dims->size == 0) {
    // Scalar tensors are not supported.
    TF_LITE_KERNEL_LOG(context, "Where op requires condition w/ rank > 0");
    return kTfLiteError;
  }

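  // Emit the coordinates of each true condition element into the output.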
  switch (cond_tensor->type) {
    case kTfLiteBool:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<bool>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteFloat32:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<float>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteInt64:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<int64_t>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteInt32:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<int32_t>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteInt8:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<int8_t>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteUInt8:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<uint8_t>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    case kTfLiteUInt32:
      reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
                                      GetTensorData<uint32_t>(cond_tensor),
                                      GetTensorData<int64_t>(output));
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Condition tensor has unsupported type: '%s'.",
                         TfLiteTypeGetName(cond_tensor->type));
  }
  return kTfLiteOk;
}
}  // namespace where

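// Returns the registration for the builtin WHERE op; Prepare and Eval above
// provide the kernel implementation.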
TfLiteRegistration* Register_WHERE() {
  static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
                                 where::Prepare, where::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite