/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/util/gpu_solvers.h"
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif  // TENSORFLOW_USE_ROCM
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
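// Returns the number of nonzero elements in [begin, end).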
template <typename T>
int64_t CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, 0LL, [](int64_t accum, const T& val) {
    return accum + (val != T(0));
  });
}

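// For bool, true already converts to 1, so a plain sum counts the true
// entries without a custom reducer.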
template <>
int64_t CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, 0LL);
}

}  // namespace

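// CPU implementation of NumTrue: a single-threaded count of the nonzero
// elements of `input`, written into the scalar `num_true`.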
template <typename T>
struct NumTrue<CPUDevice, T, int64_t> {
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64_t>::UnalignedScalar num_true) {
    num_true() =
        CountAccumulator<T>(input.data(), input.data() + input.size());
    return OkStatus();
  }
};

template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
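  // Decomposes the flat row-major offset `index` into DIMS coordinates using
  // the precomputed strides, writing them into row `true_n` of `output`.
  // E.g. for a [2, 3] input the strides are {3, 1}, so flat index 4 becomes
  // the coordinate (1, 1).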
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64_t>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    for (int i = 0; i < DIMS; ++i) {
      output(true_n, i) = index / strides[i];
      index -= output(true_n, i) * strides[i];
    }
  }

  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64_t>::Matrix output, TIndex* found_true) {
    Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
    Eigen::DSizes<TIndex, DIMS> strides;

    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

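    // Compute row-major strides: strides[i] is the number of flat elements
    // spanned by a unit step along dimension i.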
    strides[DIMS - 1] = 1;
    for (int i = DIMS - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }

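    // Single pass over the input: for each nonzero element, write its
    // coordinates into the next free row of `output`. `found_true` keeps
    // counting past `output_size` so the caller can detect a mismatch with
    // the earlier NumTrue count.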
    Eigen::DenseIndex output_size = output.dimension(0);
    for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
      if (input.data()[n] != T(0)) {
        if (FastBoundsCheck(*found_true, output_size)) {
          WriteIndexRowMajor(output, strides, *found_true, n);
        }
        ++*found_true;
      }
    }
    return OkStatus();
  }
};

}  // namespace functor

template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "CPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

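    // First pass: count the nonzero elements so the output can be allocated
    // with exactly one row per true element.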
    int64_t num_true;
    TTypes<int64_t>::UnalignedScalar num_true_t(&num_true);

    Status s = functor::NumTrue<CPUDevice, T, int64_t>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true, input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a multithreaded block
    // copy by getting block counts above instead of a global NumTrue, then
    // having each block filled in separate threads below.
    int64_t found_true = 0;

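    // Second pass: dispatch on the runtime rank; the Where functor is only
    // instantiated for ranks 1 through 8.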
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64_t>::Compute(          \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64_t>(), &found_true);                              \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
      HANDLE_DIM(8);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp: Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them. When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

#define REGISTER_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(   \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<T>("T"), WhereCPUOp<T>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_OP);
TF_CALL_bool(REGISTER_WHERE_OP);

#undef REGISTER_WHERE_OP

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

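// Forward declarations of the GPU specializations; their definitions are
// compiled in the corresponding GPU (.cu.cc) translation units.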
#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::UnalignedScalar num_true);                            \
  extern template struct NumTrue<GPUDevice, T, Tindex>

#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64_t);

TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                    \
  template <>                                                       \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(                \
      OpKernelContext* ctx, const GPUDevice& d,                     \
      typename TTypes<T, Dims>::ConstTensor input,                  \
      typename TTypes<int64_t>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;
#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64_t);

#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);         \
  DECLARE_GPU_WHERE(6, T);         \
  DECLARE_GPU_WHERE(7, T);         \
  DECLARE_GPU_WHERE(8, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

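    // Use 32-bit indices whenever every element is addressable with int32;
    // inputs with 2^31 - 1 or more elements fall back to int64 indexing.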
    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64_t>(input, input_dims, context, done);
    }
  }

  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: call create_output
    // Step 3: call where kernel

    // Allocate pinned memory for `num_true`. This memory is accessible on
    // host and device.
    ScratchSpace<Tindex> num_true(context, 1, /*on_host=*/true);
    typename TTypes<Tindex>::UnalignedScalar num_true_t(
        num_true.mutable_data());

    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

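    // Once the NumTrue kernel has finished, `num_true` is valid on the host;
    // the callback below then allocates the output and runs the Where kernel.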
    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true = std::move(num_true), done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking. Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Steps 2 and 3: allocate the output and perform the selection/copy.
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(
          context,
          context->allocate_output(
              0, TensorShape({*num_true.data(), input_dims}), &output),
          done);

#define HANDLE_DIM(NDIM)                                                 \
  case NDIM: {                                                           \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(      \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64_t>(),  \
        &found_true);                                                    \
    OP_REQUIRES_OK_ASYNC(context, s, done);                              \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);
        HANDLE_DIM(6);
        HANDLE_DIM(7);
        HANDLE_DIM(8);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them. When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };

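    // Defer the callback until all work queued above (including the NumTrue
    // kernel) has completed on the stream, so the host-side count is ready.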
    auto stream = context->op_device_context()->stream();
    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

#define REGISTER_GPU_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<T>("T"), WhereGPUOp<T>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_GPU_WHERE_OP);
#undef REGISTER_GPU_WHERE_OP

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

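// On non-CPU devices, int32 tensors are conventionally kept in host memory,
// so int32 Where runs the CPU kernel with both input and output pinned to
// the host.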
REGISTER_KERNEL_BUILDER(Name("Where")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("index"),
                        WhereCPUOp<int32>);

}  // namespace tensorflow