/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

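// SelectOp implements the V1 "Select" op. `cond` may be a scalar (forwarding
// one of the inputs whole), a vector that is broadcast along the batch
// (outer) dimension of `then`/`else_`, or a tensor with the same shape as
// `then`/`else_` for element-wise selection.
//
// Illustrative batch-broadcasting example (not from the original source):
//   cond = [true, false]
//   then = [[1, 2], [3, 4]],  else = [[5, 6], [7, 8]]
//   output = [[1, 2], [7, 8]]   // row i comes from `then` iff cond[i].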
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

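    // Dispatch on the shape of `cond`: a scalar condition forwards one of the
    // inputs, a vector condition paired with a non-vector `then` selects whole
    // batches, and anything else selects element-wise.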
    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }

  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    functor::SelectScalarHandler<Device, T> handler;
    handler(ctx, cond, then, else_);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
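
// SelectV2Op implements the "SelectV2" op, which applies full NumPy-style
// broadcasting among `cond`, `then`, and `else_` and selects element-wise on
// the broadcast shape.
//
// Illustrative example (not from the original source):
//   cond shape [2, 1], then shape [1, 3], else shape []  ->  output shape [2, 3]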
template <typename Device, typename T>
class SelectV2Op : public OpKernel {
 public:
  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond = &ctx->input(0);
    const Tensor* then = &ctx->input(1);
    const Tensor* else_ = &ctx->input(2);

    // `cond`, `then`, and `else` must be broadcastable against each other
    // (bcast.IsValid()); this matches NumPy's behavior.
    BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(),
                        else_->shape().dim_sizes()},
                       false);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "condition ", cond->shape().DebugString(), ", then ",
                    then->shape().DebugString(), ", and else ",
                    else_->shape().DebugString(), " must be broadcastable"));

    // Broadcast each of `cond`, `then`, and `else` against the combined shape
    // to obtain the per-input reshape and broadcast dimensions.
    BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false);
    BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false);
    BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false);
    OP_REQUIRES(
        ctx,
        cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable"));

    // The combined shape must also be the final output shape.
    OP_REQUIRES(
        ctx,
        cond_bcast.output_shape() == bcast.output_shape() &&
            then_bcast.output_shape() == bcast.output_shape() &&
            else_bcast.output_shape() == bcast.output_shape(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable to the same shape"));

    Tensor* output = nullptr;
    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", output_shape, &output));

    if (output->NumElements() == 0) {
      return;
    }

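// HANDLE_DIM instantiates the broadcasting select functor for a fixed rank
// NDIMS: each input is reshaped to its NDIMS-dimensional form and broadcast
// to the output shape before the element-wise select.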
#define HANDLE_DIM(NDIMS)                                            \
  {                                                                  \
    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
    func(ctx->eigen_device<Device>(),                                \
         output->shaped<T, NDIMS>(bcast.result_shape()),             \
         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
  }

    const int ndims = static_cast<int>(bcast.result_shape().size());
    switch (ndims) {
      case 1:
        HANDLE_DIM(1);
        break;
      case 2:
        HANDLE_DIM(2);
        break;
      case 3:
        HANDLE_DIM(3);
        break;
      case 4:
        HANDLE_DIM(4);
        break;
      case 5:
        HANDLE_DIM(5);
        break;
      case 6:
        HANDLE_DIM(6);
        break;
      case 7:
        HANDLE_DIM(7);
        break;
      case 8:
        HANDLE_DIM(8);
        break;
      default:
        ctx->SetStatus(errors::Unimplemented(
            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
            ctx->input(1).shape().DebugString(), " is not supported yet."));
        break;
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};

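// Registration of the CPU implementations for all supported types.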
#define REGISTER_SELECT(type)                                        \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      SelectOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectV2Op<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      SelectOp<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectV2Op<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#else

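// When the MLIR-generated GPU kernels are enabled, only the V1 "Select" op is
// registered here; the "SelectV2" GPU kernels are expected to be provided by
// the MLIR-generated kernel registrations instead.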
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64_t);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

// CPU Specializations of Select functors.
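// Element-wise select: out[i] = cond_flat[i] ? then_flat[i] : else_flat[i],
// evaluated with Eigen's select expression on the given device.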
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};

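// Generic handler for a scalar `cond`: allocates (or forwards) the output and
// fills it from either `then` or `else_` via SelectScalarFunctor on the
// device.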
template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for CPU device. Forward input to output depending on the
// `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};

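// Generic batch-select implementation: reshapes `cond_vec` from [batch] to
// [batch, 1], broadcasts it to [batch, all_but_batch], and then applies an
// element-wise select against the flattened-to-matrix `then`/`else` inputs.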
template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A faster CPU implementation that loops over batches explicitly to avoid the
// cost of Eigen broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

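    // For each batch element, copy the whole row from either `then` or `else`
    // and prefetch the next row of both inputs plus the next condition value
    // (prefetching one element past the end on the final iteration is benign,
    // since prefetch hints do not fault). Work is sharded across threads using
    // a per-batch cost estimate.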
    auto work = [batch_size, output, c, t, e](int64_t start, int64_t end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

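// Rank-specialized broadcasting select used by SelectV2: each input is
// broadcast to the output shape with the precomputed per-input broadcast
// factors before the element-wise select.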
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
                                  .select(then_tensor.broadcast(then_bcast),
                                          else_tensor.broadcast(else_bcast));
  }
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};

}  // namespace functor

}  // namespace tensorflow