1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ |
18 | // Functor definition for GatherOp, must be compilable by nvcc. |
19 | |
#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
27 | |
28 | namespace tsl { |
29 | class Status; |
30 | } |
31 | namespace tensorflow { |
32 | using tsl::Status; |
33 | |
34 | class OpKernelContext; |
35 | class Tensor; |
36 | |
37 | namespace functor { |
38 | |
// Device-specialized functor that copies rows of slices from Tparams into
// Tout, one slice per index tuple in Tindices. Specializations are defined
// per device in the corresponding .cc/.cu files.
template <typename Device, typename T, typename Index, int IXDIM>
struct GatherNdSlice {
  // Performs a slice gather op on (Tparams, Tindices), writing to Tout.
  // Returns an index to Tindices if the value at that index is out of range.
  // Returns -1 if all values of Tindices are in range.
  //
  //   d: device to launch the copy on.
  //   slice_size: number of elements in each gathered slice.
  //   Tscratch: scalar int32 scratch buffer — semantics are device-specific;
  //     see the per-device definitions (not visible in this header).
  //   Tparams: params reshaped so the first IXDIM dims are indexed dims and
  //     the last dim is the flattened slice.
  //   Tindices: matrix of index tuples, one row per output slice, IXDIM
  //     coordinates per row.
  //   Tout: output matrix of shape [num_slices, slice_size].
  Index operator()(const Device& d, const Index slice_size,
                   typename TTypes<int32>::Scalar Tscratch,
                   typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Matrix Tout);
};
50 | |
51 | template <typename Device, typename T, typename Index> |
52 | Status DoGatherNd(OpKernelContext* c, const Tensor& params, |
53 | const Tensor& indices, Tensor* out) { |
54 | if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { |
55 | return errors::InvalidArgument("params must be at least a vector" ); |
56 | } |
57 | if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) { |
58 | return errors::InvalidArgument("indices must be at least a vector" ); |
59 | } |
60 | if (indices.dim_size(indices.dims() - 1) > params.dims()) { |
61 | return errors::InvalidArgument( |
62 | "index innermost dimension length must be <= params rank; saw: " , |
63 | indices.dim_size(indices.dims() - 1), " vs. " , params.dims()); |
64 | } |
65 | |
66 | const TensorShape& indices_shape(indices.shape()); |
67 | const int64_t indices_nd = indices_shape.dim_size(indices_shape.dims() - 1); |
68 | |
69 | // Check that we have enough index space |
70 | int64_t N_big = 1; |
71 | for (int i = 0; i < indices_shape.dims() - 1; ++i) { |
72 | N_big *= indices_shape.dim_size(i); |
73 | } |
74 | if (N_big > std::numeric_limits<int>::max()) { |
75 | return errors::InvalidArgument( |
76 | "indices has too many elements for int indexing: " , N_big, " > " , |
77 | std::numeric_limits<int>::max()); |
78 | } |
79 | if (params.NumElements() > std::numeric_limits<Index>::max()) { |
80 | return errors::InvalidArgument("params.NumElements() too large for " , |
81 | DataTypeString(DataTypeToEnum<Index>::v()), |
82 | " indexing: " , params.NumElements(), " > " , |
83 | std::numeric_limits<Index>::max()); |
84 | } |
85 | |
86 | // The result shape is |
87 | // indices.shape[:-1] + params.shape[indices.shape[-1]:] |
88 | Index N_result = 1; |
89 | for (int i = 0; i < indices_shape.dims() - 1; ++i) { |
90 | N_result *= indices_shape.dim_size(i); |
91 | } |
92 | |
93 | const TensorShape& params_shape(params.shape()); |
94 | Index total_nd = params_shape.dims(); |
95 | |
96 | TensorShape result_shape(indices_shape); |
97 | result_shape.RemoveLastDims(1); |
98 | |
99 | int64_t slice_size_big = 1; |
100 | for (Index i = indices_nd; i < total_nd; ++i) { |
101 | slice_size_big *= params_shape.dim_size(i); |
102 | result_shape.AddDim(params_shape.dim_size(i)); |
103 | } |
104 | |
105 | if (slice_size_big > std::numeric_limits<Index>::max()) { |
106 | return errors::InvalidArgument( |
107 | "slice size is too large for indexing: " , slice_size_big, " > " , |
108 | std::numeric_limits<Index>::max()); |
109 | } |
110 | |
111 | const Index slice_size = static_cast<Index>(slice_size_big); |
112 | |
113 | TF_RETURN_IF_ERROR( |
114 | c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out)); |
115 | |
116 | if (N_result > 0) { |
117 | if (params_shape.num_elements() == 0) { |
118 | return errors::InvalidArgument( |
119 | "Requested more than 0 entries, but " |
120 | "params is empty. Params shape: " , |
121 | params_shape.DebugString()); |
122 | } |
123 | |
124 | auto indices_mat = indices.flat_inner_dims<Index>(); |
125 | |
126 | Index bad_i = -1; |
127 | |
128 | // Request to copy slices / subtensors |
129 | // Make out a matrix with the slices the col size. |
130 | auto out_mat = out->shaped<T, 2>({N_result, slice_size}); |
131 | Tensor scratch; |
132 | TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch)); |
133 | auto scratch_scalar = scratch.scalar<int32>(); |
134 | |
135 | switch (indices_nd) { |
136 | #define PARAMS_CASE(IXDIM) \ |
137 | case IXDIM: { \ |
138 | functor::GatherNdSlice<Device, T, Index, IXDIM> func; \ |
139 | auto params_flat = params.flat_outer_dims<T, IXDIM + 1>(); \ |
140 | bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \ |
141 | params_flat, indices_mat, out_mat); \ |
142 | } break |
143 | PARAMS_CASE(0); |
144 | PARAMS_CASE(1); |
145 | PARAMS_CASE(2); |
146 | PARAMS_CASE(3); |
147 | PARAMS_CASE(4); |
148 | PARAMS_CASE(5); |
149 | PARAMS_CASE(6); |
150 | PARAMS_CASE(7); |
151 | #undef PARAMS_CASE |
152 | default: |
153 | return errors::InvalidArgument( |
154 | "Only indices.shape[-1] values between 1 and 7 " |
155 | "are currently supported. Requested rank: " , |
156 | indices_nd); |
157 | } |
158 | |
159 | // bad_i will only return >= 0 on CPUs right now. |
160 | if (bad_i >= 0) { |
161 | auto shape = indices.shape(); |
162 | shape.RemoveLastDims(1); |
163 | return errors::InvalidArgument( |
164 | "indices" , SliceDebugString(shape, bad_i), " = [" , |
165 | str_util::Join( |
166 | gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", " ), |
167 | "] does not index into param shape " , params.shape().DebugString(), |
168 | ", node name: " , c->op_kernel().name()); |
169 | } |
170 | } |
171 | return OkStatus(); |
172 | } |
173 | |
174 | } // namespace functor |
175 | } // namespace tensorflow |
176 | |
177 | #endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ |
178 | |