/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

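// SelectOp implements the legacy `Select` op: it picks values from `then`
// where `cond` is true and from `else` otherwise. Compute() dispatches on the
// shape of `cond`: a scalar `cond` selects the whole `then` or `else` tensor,
// a vector `cond` paired with higher-rank `then`/`else` selects entire
// outer-dimension slices (one per batch entry), and otherwise all three
// inputs must have the same shape and selection is elementwise.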
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
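  // Row-wise ("batch") selection: `cond` is a vector whose length matches the
  // outer dimension of `then`/`else`; cond[i] selects the entire i-th slice.
  //
  // Illustrative example:
  //   cond = [true, false]
  //   then = [[1, 2], [3, 4]]
  //   else = [[5, 6], [7, 8]]
  //   out  = [[1, 2], [7, 8]]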
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
96 "'then' and 'else' must have the same size. but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }

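  // Elementwise selection: `cond`, `then`, and `else` must all have exactly
  // the same shape; out[i] = cond[i] ? then[i] : else[i].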
  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

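  // Scalar `cond`: the entire `then` or `else` tensor is chosen. The work is
  // delegated to SelectScalarHandler, whose CPU specialization simply forwards
  // the chosen input tensor to the output (see below).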
  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
129 "'then' and 'else' must have the same size. but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    functor::SelectScalarHandler<Device, T> handler;
    handler(ctx, cond, then, else_);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
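
// SelectV2Op implements `SelectV2`, which generalizes `Select` with full
// numpy-style broadcasting: `cond`, `then`, and `else` are each broadcast to
// a common output shape before the elementwise selection is applied.
//
// Illustrative example:
//   cond = [true, false]           (shape [2])
//   then = [[1], [2]]              (shape [2, 1])
//   else = [[10, 20], [30, 40]]    (shape [2, 2])
//   out  = [[1, 20], [2, 40]]      (shape [2, 2])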
template <typename Device, typename T>
class SelectV2Op : public OpKernel {
 public:
  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    // `cond`, `then`, and `else` must be broadcastable to a common shape
    // (bcast.IsValid()). This matches the broadcasting behavior of numpy.
    BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(),
                        else_->shape().dim_sizes()},
                       false);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "condition ", cond->shape().DebugString(), ", then ",
                    then->shape().DebugString(), ", and else ",
                    else_->shape().DebugString(), " must be broadcastable"));

    // Broadcast `cond`, `then`, and `else` to the combined shape in order to
    // obtain each input's reshape and broadcast dimensions.
    BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false);
    BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false);
    BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false);
    OP_REQUIRES(
        ctx,
        cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable"));

    // Combined shape should be the final shape.
    OP_REQUIRES(
        ctx,
        cond_bcast.output_shape() == bcast.output_shape() &&
            then_bcast.output_shape() == bcast.output_shape() &&
            else_bcast.output_shape() == bcast.output_shape(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable to the same shape"));

    Tensor* output = nullptr;
    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", output_shape, &output));

    if (output->NumElements() == 0) {
      return;
    }

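// HANDLE_DIM instantiates BCastSelectFunctor for a fixed compile-time rank,
// reshaping each input to that rank and applying its per-input broadcast
// dimensions. Ranks 1 through 8 are handled in the switch below; anything
// higher falls through to the Unimplemented error in the default case.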
#define HANDLE_DIM(NDIMS)                                            \
  {                                                                  \
    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
    func(ctx->eigen_device<Device>(),                                \
         output->shaped<T, NDIMS>(bcast.result_shape()),             \
         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
  }

    const int ndims = static_cast<int>(bcast.result_shape().size());
    switch (ndims) {
      case 1:
        HANDLE_DIM(1);
        break;
      case 2:
        HANDLE_DIM(2);
        break;
      case 3:
        HANDLE_DIM(3);
        break;
      case 4:
        HANDLE_DIM(4);
        break;
      case 5:
        HANDLE_DIM(5);
        break;
      case 6:
        HANDLE_DIM(6);
        break;
      case 7:
        HANDLE_DIM(7);
        break;
      case 8:
        HANDLE_DIM(8);
        break;
      default:
        ctx->SetStatus(errors::Unimplemented(
            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
            ctx->input(1).shape().DebugString(), " is not supported yet."));
        break;
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};

#define REGISTER_SELECT(type)                                          \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
      SelectOp<CPUDevice, type>);                                      \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      SelectV2Op<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      SelectOp<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectV2Op<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#else

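// With the MLIR-generated GPU kernels enabled, SelectV2 on GPU is expected to
// be provided by those generated kernels, so only the legacy Select op is
// registered here.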
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64_t);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

// CPU Specializations of Select functors.
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};

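// Generic handler for a scalar `cond`: allocates (or forwards) the output
// buffer and applies SelectScalarFunctor on the device.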
template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for CPU device. Forwards either the `then` or the `else`
// input directly to the output, depending on the `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};

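// Generic batch-select implementation: reshapes `cond` to [batch, 1],
// broadcasts it across the inner dimension, and lets Eigen's select() pick
// whole rows from `then` or `else`.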
template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A fast CPU implementation that uses an explicit loop over batches instead
// of Eigen broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

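    // Each shard copies whole rows. Prefetch the next row of `then` and
    // `else` and the next `cond` flag to hide memory latency while the
    // current row is being copied.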
    auto work = [batch_size, output, c, t, e](int64_t start, int64_t end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

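// Generic broadcast-select implementation: broadcasts each operand to the
// common output shape and then selects elementwise.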
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
                                  .select(then_tensor.broadcast(then_bcast),
                                          else_tensor.broadcast(else_bcast));
  }
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};

}  // namespace functor

}  // namespace tensorflow