1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/array_ops.cc.
17#define EIGEN_USE_THREADS
18
#include "tensorflow/core/kernels/gather_nd_op.h"

#include <utility>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
26
27namespace tensorflow {
28
29typedef Eigen::ThreadPoolDevice CPUDevice;
30typedef Eigen::GpuDevice GPUDevice;
31
32template <typename Device, typename T, typename Index>
33class GatherNdOp : public OpKernel {
34 public:
35 explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
36 const DataType dt = DataTypeToEnum<T>::v();
37 const DataType index_t = DataTypeToEnum<Index>::v();
38 OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
39 }
40
41 void Compute(OpKernelContext* c) override {
42 const Tensor& params = c->input(0);
43 const Tensor& indices = c->input(1);
44
45 Tensor out;
46 OP_REQUIRES_OK(
47 c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
48 c->set_output(0, out);
49 }
50};
51
// Registers GatherNdOp on device DEV for params element type TPARAMS and
// indices element type TINDICES. DEV is token-pasted twice: into the device
// name (DEVICE_##DEV, e.g. DEVICE_CPU) and into the Eigen device alias
// (DEV##Device, e.g. CPUDevice).
#define REGISTER_GATHER_ND_FULL(DEV, TPARAMS, TINDICES)              \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                          \
                              .Device(DEVICE_##DEV)                 \
                              .TypeConstraint<TPARAMS>("Tparams")   \
                              .TypeConstraint<TINDICES>("Tindices"), \
                          GatherNdOp<DEV##Device, TPARAMS, TINDICES>)

// Registers GatherNdOp on DEV for TPARAMS, once with int32 indices and once
// with int64_t indices.
#define REGISTER_GATHER_ND_ALL_INDICES(DEV, TPARAMS) \
  REGISTER_GATHER_ND_FULL(DEV, TPARAMS, int32);      \
  REGISTER_GATHER_ND_FULL(DEV, TPARAMS, int64_t)
62
63#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
64
65// TODO(ebrevdo): This is a pure data-movement kernel. It shouldn't be
66// instantiated for all different types. Instead, all the types should
67// be coalesced. So we should only have int8, int16, int32, int64 support.
68// And float is redirected to int32, double is redirected to int64,
69// and complex<float> is redirected to int32 with twice the number of
70// entries, similarly for complex<double>.
71//
72// Same for the GPU kernel.
73TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
74TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
75
76#undef REGISTER_GATHER_ND_CPU
77
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
//
// Each GatherNdSlice<GPUDevice, T, Index, NDIM> is declared `extern template`.
// This suppresses implicit instantiation in this translation unit, so the
// explicit instantiations must be provided by another translation unit
// (presumably the GPU-compiled GatherNd source; TODO confirm).
namespace functor {
// Declares the GPU specialization of GatherNdSlice::operator() for element
// type T, index type Index and index depth NDIM. Also marks the
// corresponding class instantiation as extern.
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)          \
  template <>                                                 \
  Index GatherNdSlice<GPUDevice, T, Index, NDIM>::operator()( \
      const GPUDevice& d, const Index slice_size,             \
      typename TTypes<int32>::Scalar Tscratch,                \
      typename TTypes<T, NDIM + 1>::ConstTensor Tparams,      \
      typename TTypes<Index>::ConstMatrix Tindices,           \
      typename TTypes<T>::Matrix Tout);                       \
  extern template struct GatherNdSlice<GPUDevice, T, Index, NDIM>;

// Declares the specializations for index depths NDIM = 0 through 7.
#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

// Declares the specializations for both index types, int32 and int64_t.
#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64_t)

// Keep this type list in sync with the GPU registrations below.
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

// Undefine every helper macro defined in this namespace block, including
// the innermost NDIM helper, so none of them leak into the rest of the file.
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_NDIM
}  // namespace functor

// Registration of the GPU implementations, for the same element types whose
// functor specializations are declared above, each with both index types.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_int32(REGISTER_GATHER_ND_GPU);
TF_CALL_int64(REGISTER_GATHER_ND_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
125
// The registration helpers are only meant for this file; drop them now that
// every kernel has been registered.
#undef REGISTER_GATHER_ND_FULL
#undef REGISTER_GATHER_ND_ALL_INDICES
128
129} // namespace tensorflow
130