/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
// Functor definition for GatherNdOp, must be compilable by nvcc.

#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tsl {
class Status;
}  // namespace tsl
namespace tensorflow {
using tsl::Status;

class OpKernelContext;
class Tensor;

namespace functor {

template <typename Device, typename T, typename Index, int IXDIM>
struct GatherNdSlice {
  // Performs a slice gather op on (Tparams, Tindices), writing to Tout.
  // Returns an index into Tindices if the value at that index is out of
  // range; returns -1 if all values of Tindices are in range.
  Index operator()(const Device& d, const Index slice_size,
                   typename TTypes<int32>::Scalar Tscratch,
                   typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Matrix Tout);
};
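
// A minimal usage sketch of the functor contract above (names here are
// illustrative, not part of this header): with IXDIM == 1, Tparams is a
// rank-2 tensor of shape [P0, slice_size], Tindices is an [N, 1] matrix of
// row indices, and Tout is an [N, slice_size] matrix, where row i of Tout
// receives the slice Tparams(Tindices(i, 0), :):
//
//   functor::GatherNdSlice<CPUDevice, float, int32, 1> func;
//   int32 bad_i = func(ctx->eigen_device<CPUDevice>(), slice_size,
//                      scratch_scalar, params_flat, indices_mat, out_mat);
//   // bad_i == -1 on success; otherwise it is the row of Tindices whose
//   // value was out of range.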

template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
                  const Tensor& indices, Tensor* out) {
  if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
    return errors::InvalidArgument("params must be at least a vector");
  }
  if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) {
    return errors::InvalidArgument("indices must be at least a vector");
  }
  if (indices.dim_size(indices.dims() - 1) > params.dims()) {
    return errors::InvalidArgument(
        "index innermost dimension length must be <= params rank; saw: ",
        indices.dim_size(indices.dims() - 1), " vs. ", params.dims());
  }
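
  // For example (illustrative shapes only): with params of shape [4, 5, 6]
  // (rank 3), the innermost dimension of indices may be at most 3; an
  // innermost dimension of 4 would fail the check above.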

  const TensorShape& indices_shape(indices.shape());
  const int64_t indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);

  // Check that we have enough index space.
  int64_t N_big = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_big *= indices_shape.dim_size(i);
  }
  if (N_big > std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(
        "indices has too many elements for int indexing: ", N_big, " > ",
        std::numeric_limits<int>::max());
  }
  if (params.NumElements() > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params.NumElements() too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params.NumElements(), " > ",
                                   std::numeric_limits<Index>::max());
  }

  // The result shape is
  //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
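  // For example (illustrative shapes only): indices.shape = [2, 3, 2] and
  // params.shape = [4, 5, 6] yield result shape [2, 3] + [6] = [2, 3, 6],
  // i.e. N_result = 6 gathered slices, each of size 6.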
  Index N_result = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_result *= indices_shape.dim_size(i);
  }

  const TensorShape& params_shape(params.shape());
  Index total_nd = params_shape.dims();

  TensorShape result_shape(indices_shape);
  result_shape.RemoveLastDims(1);

  int64_t slice_size_big = 1;
  for (Index i = indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
    result_shape.AddDim(params_shape.dim_size(i));
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  const Index slice_size = static_cast<Index>(slice_size_big);

  TF_RETURN_IF_ERROR(
      c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out));

  if (N_result > 0) {
    if (params_shape.num_elements() == 0) {
      return errors::InvalidArgument(
          "Requested more than 0 entries, but "
          "params is empty. Params shape: ",
          params_shape.DebugString());
    }

    auto indices_mat = indices.flat_inner_dims<Index>();

    Index bad_i = -1;
    // Request to copy slices / subtensors.
    // View out as a matrix with one row per gathered slice and slice_size
    // columns.
    auto out_mat = out->shaped<T, 2>({N_result, slice_size});
    Tensor scratch;
    TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch));
    auto scratch_scalar = scratch.scalar<int32>();

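    // Eigen tensors carry their rank as a compile-time template parameter,
    // so dispatch on the runtime value of indices_nd to instantiate
    // GatherNdSlice with a constant IXDIM for each supported depth.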
    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                              \
  case IXDIM: {                                                         \
    functor::GatherNdSlice<Device, T, Index, IXDIM> func;               \
    auto params_flat = params.flat_outer_dims<T, IXDIM + 1>();          \
    bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                 params_flat, indices_mat, out_mat);                    \
  } break
      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 0 and 7 "
            "are currently supported. Requested rank: ",
            indices_nd);
    }

    // bad_i is only set to a value >= 0 on CPU devices right now.
    if (bad_i >= 0) {
      auto shape = indices.shape();
      shape.RemoveLastDims(1);
      return errors::InvalidArgument(
          "indices", SliceDebugString(shape, bad_i), " = [",
          str_util::Join(
              gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd),
              ", "),
          "] does not index into param shape ", params.shape().DebugString(),
          ", node name: ", c->op_kernel().name());
    }
  }
  return OkStatus();
}
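
// A minimal sketch of how a kernel might call DoGatherNd (hypothetical
// Compute() body, not part of this header; assumes Device = CPUDevice,
// T = float, and Index = int32):
//
//   void Compute(OpKernelContext* c) override {
//     const Tensor& params = c->input(0);
//     const Tensor& indices = c->input(1);
//     Tensor out;
//     OP_REQUIRES_OK(c, functor::DoGatherNd<CPUDevice, float, int32>(
//                           c, params, indices, &out));
//     c->set_output(0, out);
//   }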

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_