/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

namespace {

// For each slice `(start, limit)` in `value_slices`, append
// `params_dense_values_in[start:limit]` to `values_out`.  `value_size`
// indicates the number of scalars contained in each value
// params_dense_values_in[i].
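// Illustrative sketch (hypothetical values): with value_size = 2,
//   params_dense_values_in = [[1, 2], [3, 4], [5, 6], [7, 8]]
//   value_slices           = {(0, 2), (3, 4)}
// the rows copied into `values_out` are [[1, 2], [3, 4], [7, 8]].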
template <typename VALUE_TYPE, typename SPLITS_TYPE>
void WriteValueSlices(
    const Tensor& params_dense_values_in,
    const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
    SPLITS_TYPE value_size, Tensor* values_out) {
  const auto& params_dense_values =
      params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>();
  auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>();
  int out_pos = 0;
  for (const auto& slice : value_slices) {
    for (int i = slice.first; i < slice.second; ++i) {
      for (int j = 0; j < value_size; ++j) {
        values(out_pos, j) = params_dense_values(i, j);
      }
      ++out_pos;
    }
  }
}

}  // namespace

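// Kernel for the RaggedGather op, which gathers rows from a ragged `params`
// tensor (encoded as `params_nested_splits` plus `params_dense_values`)
// according to `indices`, producing a ragged output.
//
// Illustrative sketch with hypothetical values, for a single splits tensor:
//   params_nested_splits = [[0, 2, 5, 6]]
//   params_dense_values  = [a, b, c, d, e, f]
//   indices              = [2, 0]
// gathers params[2] = [f] and params[0] = [a, b], so the output is encoded as
//   output_nested_splits = [[0, 1, 3]]
//   output dense values  = [f, a, b]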
template <typename INDEX_TYPE, typename SPLITS_TYPE>
class RaggedGatherOpBase : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Get the input Tensors.

    OpInputList params_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("params_nested_splits",
                                                &params_nested_splits_in));
    OP_REQUIRES(
        context, params_nested_splits_in.size() > 0,
        errors::InvalidArgument("params_nested_splits must be non-empty"));

    const Tensor& params_dense_values_in =
        context->input(params_nested_splits_in.size());
    const Tensor& indices_in =
        context->input(params_nested_splits_in.size() + 1);

    OP_REQUIRES(context, params_nested_splits_in[0].dims() > 0,
                errors::InvalidArgument("Split tensors must not be scalars"));
    SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1;
    OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params));

    OP_REQUIRES(context, params_dense_values_in.dims() > 0,
                errors::InvalidArgument("params.rank must be nonzero"));
    SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0);

    // Calculate the `splits`, and store the value slices that we need to
    // copy in `value_slices`.
    std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>> value_slices;
    SPLITS_TYPE num_values = 0;
    std::vector<std::vector<SPLITS_TYPE>> out_splits;
    OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in,
                                       num_params_dense_values, &out_splits,
                                       &value_slices, &num_values));

    // Write the output tensors.
    OP_REQUIRES_OK(context, WriteSplits(out_splits, context));
    OP_REQUIRES_OK(context,
                   WriteValues(params_dense_values_in, value_slices,
                               out_splits.size(), num_values, context));
  }

 private:
  using ConstFlatType = typename TTypes<SPLITS_TYPE>::ConstFlat;

  // Check if any indices are out-of-bounds.
  ::tensorflow::Status ValidateIndices(const Tensor& indices_in,
                                       SPLITS_TYPE num_params) {
    const auto& indices = indices_in.flat<INDEX_TYPE>();
    for (SPLITS_TYPE i = 0; i < indices.size(); ++i) {
      SPLITS_TYPE index = indices(i);
      if (index < 0 || index >= num_params) {
        return errors::InvalidArgument(
            "indices", SliceDebugString(indices_in.shape(), i), " = ", index,
            " is not in [0, ", num_params, ")");
      }
    }
    return OkStatus();
  }

  // Construct the `splits` output tensors, encoded using a nested vector.
  // Also find the slices of values that need to be copied, and store them
  // in `value_slices`.  The total number of values that will be copied
  // (which we need for allocating the output values tensor) is stored in
  // `num_values`.
  ::tensorflow::Status MakeSplits(
      const Tensor& indices_in, const OpInputList& params_nested_splits_in,
      SPLITS_TYPE num_params_dense_values,
      std::vector<std::vector<SPLITS_TYPE>>* out_splits,
      std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>* value_slices,
      SPLITS_TYPE* num_values) {
    *num_values = 0;
    value_slices->clear();

    int num_splits = indices_in.dims() - 1 + params_nested_splits_in.size();
    out_splits->assign(num_splits, {0});

    // Get Eigen tensors.
    const auto& indices = indices_in.flat<INDEX_TYPE>();
    std::vector<ConstFlatType> params_nested_splits;
    params_nested_splits.reserve(params_nested_splits_in.size());
    for (const auto& splits_in : params_nested_splits_in) {
      params_nested_splits.push_back(splits_in.flat<SPLITS_TYPE>());
    }

    TF_RETURN_IF_ERROR(
        ValidateSplits(params_nested_splits, num_params_dense_values));

    // Add `splits` that come from all but the last dimension of the dense
    // Tensor `indices`.  In particular, for each dimension D, we add a
    // splits tensor whose values are:
    //   range(reduce_prod(indices.shape[:D+1]) + 1) * indices.shape[D+1]
    // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
    //   [0, 3, 6]                    # length=2+1, stride=3
    //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4
    int nrows = 1;
    for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
      nrows *= indices_in.dim_size(dim);
      int row_length = indices_in.dim_size(dim + 1);
      for (int i = 1; i < nrows + 1; ++i) {
        out_splits->at(dim).push_back(i * row_length);
      }
    }

    // Add `splits` that come from `params_nested_splits`.  Starting with the
    // outermost ragged dimension (i.e., the first `splits` tensor), we work
    // our way in, finding the range of values that should be copied.  As we
    // go, we update the output `splits` for each dimension with the
    // appropriate values.  In particular, the *lengths* of the slices from
    // `params_nested_splits` should be copied to generate corresponding
    // slice lengths in the output splits.  E.g., if we are copying a ragged
    // row with length 4, then we should add a new split point to out_splits
    // that is 4 greater than the previous split point in out_splits.
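    // Illustrative numbers (not from any particular input): if
    // out_splits[out_dim] currently ends in 3 and the gathered row spans
    // splits(start) = 7 to splits(start + 1) = 11 (length 4), then
    // delta = 3 - 7 = -4 and we append 11 + (-4) = 7, i.e. a split point
    // that is 4 greater than the previous one.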
    for (int i = 0; i < indices.size(); ++i) {
      int start = indices(i);
      int limit = indices(i) + 1;

      // Copy splits.
      for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
        const auto& splits = params_nested_splits[dim];
        int out_dim = dim + indices_in.dims() - 1;
        if (out_dim >= 0) {
          SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start);
          for (int j = start; j < limit; ++j) {
            out_splits->at(out_dim).push_back(splits(j + 1) + delta);
          }
        }
        start = splits(start);
        limit = splits(limit);
      }
      if (limit != start) {
        value_slices->emplace_back(start, limit);
        *num_values += limit - start;
      }
    }
    return OkStatus();
  }

  ::tensorflow::Status ValidateSplits(
      const std::vector<ConstFlatType>& params_nested_splits,
      SPLITS_TYPE num_params_dense_values) {
    // Validate each splits tensor: it must be non-empty, start at a
    // non-negative value, be sorted, and not point past the values it
    // segments.
    for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
      const auto& splits = params_nested_splits[dim];
      SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1)
                                   ? num_params_dense_values
                                   : params_nested_splits[dim + 1].size();
      if (splits.size() == 0) {
        return errors::InvalidArgument("Ragged splits may not be empty");
      }
      if (splits(0) < 0) {
        return errors::InvalidArgument("Ragged splits must be non-negative");
      }
      if (splits(splits.size() - 1) > last_split) {
        return errors::InvalidArgument(
            "Ragged splits must not point past values");
      }
      for (int i = 1; i < splits.size(); ++i) {
        if (splits(i - 1) > splits(i)) {
          return errors::InvalidArgument("Ragged splits must be sorted");
        }
      }
    }
    return OkStatus();
  }

  ::tensorflow::Status WriteSplits(
      const std::vector<std::vector<SPLITS_TYPE>>& out_splits,
      OpKernelContext* context) {
    OpOutputList splits_out;
    TF_RETURN_IF_ERROR(
        context->output_list("output_nested_splits", &splits_out));
    for (int i = 0; i < out_splits.size(); ++i) {
      Tensor* splits;
      SPLITS_TYPE num_splits = out_splits[i].size();
      TF_RETURN_IF_ERROR(
          splits_out.allocate(i, TensorShape({num_splits}), &splits));
      auto splits_flat = splits->flat<SPLITS_TYPE>();
      std::copy_n(out_splits[i].data(), out_splits[i].size(),
                  splits_flat.data());
    }
    return OkStatus();
  }

  ::tensorflow::Status WriteValues(
      const Tensor& params_dense_values_in,
      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
      int values_index, SPLITS_TYPE num_values,
      OpKernelContext* context) const {
    Tensor* values_out = nullptr;
    TensorShape values_shape = params_dense_values_in.shape();
    values_shape.set_dim(0, num_values);
    TF_RETURN_IF_ERROR(
        context->allocate_output(values_index, values_shape, &values_out));
    const SPLITS_TYPE num_elements = params_dense_values_in.NumElements();
    const SPLITS_TYPE value_size =
        num_elements == 0
            ? 0
            : (num_elements / params_dense_values_in.dim_size(0));
    CallWriteValueSlices(params_dense_values_in, value_slices, value_size,
                         values_out);
    return OkStatus();
  }

 protected:
  // Call WriteValueSlices() using the appropriate VALUE_TYPE template
  // parameter.  This pattern is used to reduce binary size.  In particular,
  // this allows us to have two instantiations of this class (one for each
  // index type), rather than 14 (one for each index type and value type),
  // which cuts the binary size of this op from ~300k to <90k.
  virtual void CallWriteValueSlices(
      const Tensor& params_dense_values_in,
      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
      SPLITS_TYPE value_size, Tensor* values_out) const = 0;
};

template <typename INDEX_TYPE, typename VALUE_TYPE, typename SPLITS_TYPE>
class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE> {
 public:
  using RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE>::RaggedGatherOpBase;

 private:
  void CallWriteValueSlices(
      const Tensor& params_dense_values_in,
      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
      SPLITS_TYPE value_size, Tensor* values_out) const override {
    WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices,
                                 value_size, values_out);
  }
};

#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \
                                            splits_type)            \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("RaggedGather")                                           \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<index_type>("Tindices")                    \
          .TypeConstraint<value_type>("Tvalues")                     \
          .TypeConstraint<splits_type>("Tsplits"),                   \
      RaggedGatherOp<index_type, value_type, splits_type>);
#define REGISTER_CPU_KERNEL(value_type)                              \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32)      \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64_t, value_type, int32)    \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64_t)    \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64_t, value_type, int64_t)
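// For example, REGISTER_CPU_KERNEL(float) expands to four kernel
// registrations for Tvalues=float, one for each (Tindices, Tsplits)
// combination of int32 and int64_t.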
TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE

}  // namespace tensorflow