/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::DenseIndex IndexType;

template <typename Device, typename T, typename Index>
class GatherOp : public OpKernel {
 public:
  // QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
  // etc. here for the type of the second input argument. Should
  // we have the framework do some sort of integer promotion
  // automatically, or should that be something that users have to
  // do explicitly with a conversion operator in the graph?
  explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {
    // Set batch_dims_ to 0 if the attribute does not exist.
    if (c->HasAttr("batch_dims")) {
      OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
    } else {
      batch_dims_ = 0;
    }
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // GatherV2 added an axis argument. For backwards compatibility with
    // Gather, fall back to axis 0 if the op does not have an axis input.
    int64_t axis = 0;
    bool axis_is_set = false;  // Indicates whether the axis argument was set.
    if (c->num_inputs() == 3) {
      axis_is_set = true;
      const Tensor& axis_tensor = c->input(2);
      OP_REQUIRES(c, TensorShapeUtils::IsScalar(axis_tensor.shape()),
                  errors::InvalidArgument("axis must be scalar"));

      if (axis_tensor.dtype() == DT_INT32) {
        axis = axis_tensor.scalar<int32>()();
      } else if (axis_tensor.dtype() == DT_INT64) {
        axis = axis_tensor.scalar<int64_t>()();
      } else {
        OP_REQUIRES(c, false,
                    errors::InvalidArgument("axis must be int32 or int64."));
      }
    }

    int64_t min_params_dim = axis < 0 ? -axis : axis + 1;
    OP_REQUIRES(
        c, params.dims() >= min_params_dim,
        errors::InvalidArgument("Shape must be at least rank ", min_params_dim,
                                " but is rank ", params.dims()));

    if (axis < 0) {
      axis = params.dims() + axis;
    }
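    // For example, axis == 2 requires params to have rank >= 3, while
    // axis == -1 only requires rank >= 1 and is adjusted above to refer to
    // the last dimension of params.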

    // Modify only a local copy of batch_dims_.
    int32_t batch_dims = batch_dims_;
    if (batch_dims != 0) {
      OP_REQUIRES(c,
                  batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
                  errors::InvalidArgument("Expected batch_dims in the range [",
                                          -indices.dims(), ", ", indices.dims(),
                                          "], but got ", batch_dims));

      if (batch_dims < 0) {
        batch_dims = indices.dims() + batch_dims;
      }

      if (!axis_is_set) axis = batch_dims;
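      // For example, with indices of rank 2, batch_dims == -1 resolves to 1
      // above, and if no axis input was provided, axis defaults to 1 as well.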

      OP_REQUIRES(c, batch_dims < params.dims(),
                  errors::InvalidArgument("batch_dims (", batch_dims,
                                          ") must be less than rank(params) (",
                                          params.dims(), ")."));

      OP_REQUIRES(c, axis >= batch_dims,
                  errors::InvalidArgument("batch_dims (", batch_dims,
                                          ") must be less than or equal to ",
                                          "axis (", axis, ")."));
      for (int i = 0; i < batch_dims; ++i) {
        OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
                    errors::InvalidArgument(
                        "params.shape[", i, "]: ", params.dim_size(i),
                        " should be equal to indices.shape[", i,
                        "]: ", indices.dim_size(i)));
      }
    }

    // Check that we have enough index space
    int64_t gather_dim_size = params.dim_size(axis);
    const int64_t N = indices.NumElements();
    OP_REQUIRES(
        c, gather_dim_size <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[", axis, "] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", gather_dim_size, " > ",
                                std::numeric_limits<Index>::max()));
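    // For example, a gather dimension with 2^31 or more elements exceeds the
    // int32 limit and is only valid when Tindices is int64.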

    // The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
    // params.shape[axis + 1:].
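    // For example, with params.shape == [2, 3, 4, 5], indices.shape == [2, 6],
    // batch_dims == 1 and axis == 2, the result shape is [2, 3, 6, 5], and the
    // sizes computed below are batch_size == 2, outer_size == 3 and
    // inner_size == 5.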
    TensorShape result_shape;
    int64_t batch_size = 1;
    int64_t outer_size = 1;
    int64_t inner_size = 1;

    for (int i = 0; i < batch_dims; ++i) {
      result_shape.AddDim(params.dim_size(i));
      batch_size *= params.dim_size(i);
    }
    for (int i = batch_dims; i < axis; ++i) {
      result_shape.AddDim(params.dim_size(i));
      outer_size *= params.dim_size(i);
    }
    for (int i = batch_dims; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = axis + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
      inner_size *= params.dim_size(i);
    }

    Tensor* out = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N == 0) return;
    if (inner_size == 0) return;

    int64_t bad_i = -1;
    auto indices_flat = indices.flat<Index>();
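    // The functors below operate on flattened views: with batch dimensions,
    // params is viewed as [batch, outer, gather_dim, inner] and the output as
    // [batch, outer, N / batch, inner] (e.g. [2, 3, 4, 5] and [2, 3, 6, 5] in
    // the example above); otherwise 3-D views without a batch dimension are
    // used.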
    if (batch_dims > 0) {
      auto params_flat = params.shaped<T, 4>(
          {batch_size, outer_size, gather_dim_size, inner_size});
      auto out_flat = out->shaped<T, 4>(
          {batch_size, outer_size, N / batch_size, inner_size});

      functor::GatherFunctorBatched<Device, T, Index> functor;
      bad_i = functor(c, params_flat, indices_flat, out_flat);
    } else {
      auto params_flat =
          params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
      auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});

      functor::GatherFunctor<Device, T, Index> functor;
      bad_i = functor(c, params_flat, indices_flat, out_flat);
    }
    OP_REQUIRES(
        c, bad_i < 0,
        errors::InvalidArgument(
            "indices", SliceDebugString(indices.shape(), bad_i), " = ",
            indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
  }

 private:
  // The number of batch dimensions, as passed in the batch_dims attribute.
  // It must be less than or equal to rank(indices).
  int32 batch_dims_ = 0;
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("Gather")                               \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherOp<dev##Device, type, index_type>);    \
  REGISTER_KERNEL_BUILDER(Name("GatherV2")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices")  \
                              .HostMemory("axis"),                     \
                          GatherOp<dev##Device, type, index_type>)
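
// For example, REGISTER_GATHER_FULL(CPU, float, int32) instantiates
// GatherOp<CPUDevice, float, int32> and registers it for both the "Gather"
// and "GatherV2" ops, with the axis input of GatherV2 pinned to host memory.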

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64_t)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);

#undef REGISTER_GATHER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registration of the GPU implementations.
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_int32(REGISTER_GATHER_GPU);
TF_CALL_int64(REGISTER_GATHER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);

#undef REGISTER_GATHER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

}  // namespace tensorflow