1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <limits> |
16 | #include <memory> |
17 | #include <string> |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/register_types.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/tensor_shape.h" |
24 | #include "tensorflow/core/util/util.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | namespace { |
29 | |
30 | // For each slice in `(start, limit)` in `value_slices`, append |
31 | // `params_dense_values_in[start:limit] to `values_out`. `value_size` indicates |
32 | // the number of scalars contained in each value params_dense_values_in[i]. |
33 | template <typename VALUE_TYPE, typename SPLITS_TYPE> |
34 | void WriteValueSlices( |
35 | const Tensor& params_dense_values_in, |
36 | const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices, |
37 | SPLITS_TYPE value_size, Tensor* values_out) { |
38 | const auto& params_dense_values = |
39 | params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>(); |
40 | auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>(); |
41 | int out_pos = 0; |
42 | for (const auto& slice : value_slices) { |
43 | for (int i = slice.first; i < slice.second; ++i) { |
44 | for (int j = 0; j < value_size; ++j) { |
45 | values(out_pos, j) = params_dense_values(i, j); |
46 | } |
47 | ++out_pos; |
48 | } |
49 | } |
50 | } |
51 | |
52 | } // namespace |
53 | |
54 | template <typename INDEX_TYPE, typename SPLITS_TYPE> |
55 | class RaggedGatherOpBase : public OpKernel { |
56 | public: |
57 | using OpKernel::OpKernel; |
58 | |
59 | void Compute(OpKernelContext* context) override { |
60 | // Get the input Tensors. |
61 | |
62 | OpInputList params_nested_splits_in; |
63 | OP_REQUIRES_OK(context, context->input_list("params_nested_splits" , |
64 | ¶ms_nested_splits_in)); |
65 | OP_REQUIRES( |
66 | context, params_nested_splits_in.size() > 0, |
67 | errors::InvalidArgument("params_nested_splits must be non empty" )); |
68 | |
69 | const Tensor& params_dense_values_in = |
70 | context->input(params_nested_splits_in.size()); |
71 | const Tensor& indices_in = |
72 | context->input(params_nested_splits_in.size() + 1); |
73 | |
74 | OP_REQUIRES(context, params_nested_splits_in[0].dims() > 0, |
75 | errors::InvalidArgument("Split tensors must not be scalars" )); |
76 | SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1; |
77 | OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params)); |
78 | |
79 | OP_REQUIRES(context, params_dense_values_in.dims() > 0, |
80 | errors::InvalidArgument("params.rank must be nonzero" )); |
81 | SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0); |
82 | |
83 | // Calculate the `splits`, and store the value slices that we need to |
84 | // copy in `value_slices`. |
85 | std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>> value_slices; |
86 | SPLITS_TYPE num_values = 0; |
87 | std::vector<std::vector<SPLITS_TYPE>> out_splits; |
88 | OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in, |
89 | num_params_dense_values, &out_splits, |
90 | &value_slices, &num_values)); |
91 | |
92 | // Write the output tensors. |
93 | OP_REQUIRES_OK(context, WriteSplits(out_splits, context)); |
94 | OP_REQUIRES_OK(context, |
95 | WriteValues(params_dense_values_in, value_slices, |
96 | out_splits.size(), num_values, context)); |
97 | } |
98 | |
99 | private: |
100 | using ConstFlatType = typename TTypes<SPLITS_TYPE>::ConstFlat; |
101 | |
102 | // Check if any indices are out-of-bounds. |
103 | ::tensorflow::Status ValidateIndices(const Tensor& indices_in, |
104 | SPLITS_TYPE num_params) { |
105 | const auto& indices = indices_in.flat<INDEX_TYPE>(); |
106 | for (SPLITS_TYPE i = 0; i < indices.size(); ++i) { |
107 | SPLITS_TYPE index = indices(i); |
108 | if (index < 0 || index >= num_params) { |
109 | return errors::InvalidArgument( |
110 | "indices" , SliceDebugString(indices_in.shape(), i), " = " , index, |
111 | " is not in [0, " , num_params, ")" ); |
112 | } |
113 | } |
114 | return OkStatus(); |
115 | } |
116 | |
117 | // Construct the `splits` output tensors, encoded using a nested vector. |
118 | // Also find the slices of values that need to be copied, and store them |
119 | // in `value_slices`. The total number of values that will be copied (which |
120 | // we need for allocating the output values tensor) is stored in `num_values`. |
121 | ::tensorflow::Status MakeSplits( |
122 | const Tensor& indices_in, const OpInputList& params_nested_splits_in, |
123 | SPLITS_TYPE num_params_dense_values, |
124 | std::vector<std::vector<SPLITS_TYPE>>* out_splits, |
125 | std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>* value_slices, |
126 | SPLITS_TYPE* num_values) { |
127 | *num_values = 0; |
128 | value_slices->clear(); |
129 | |
130 | int num_splits = indices_in.dims() - 1 + params_nested_splits_in.size(); |
131 | out_splits->assign(num_splits, {0}); |
132 | |
133 | // Get Eigen tensors. |
134 | const auto& indices = indices_in.flat<INDEX_TYPE>(); |
135 | std::vector<ConstFlatType> params_nested_splits; |
136 | params_nested_splits.reserve(params_nested_splits_in.size()); |
137 | for (const auto& splits_in : params_nested_splits_in) { |
138 | params_nested_splits.push_back(splits_in.flat<SPLITS_TYPE>()); |
139 | } |
140 | |
141 | TF_RETURN_IF_ERROR( |
142 | ValidateSplits(params_nested_splits, num_params_dense_values)); |
143 | |
144 | // Add `splits` that come from all but the last dimension of the dense |
145 | // Tensor `indices`. In particular, for each dimension D, we add a |
146 | // splits tensor whose values are: |
147 | // range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1] |
148 | // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors: |
149 | // [0, 3, 6] # length=2+1, stride=3 |
150 | // [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4 |
151 | int nrows = 1; |
152 | for (int dim = 0; dim < indices_in.dims() - 1; ++dim) { |
153 | nrows *= indices_in.dim_size(dim); |
154 | int row_length = indices_in.dim_size(dim + 1); |
155 | for (int i = 1; i < nrows + 1; ++i) { |
156 | out_splits->at(dim).push_back(i * row_length); |
157 | } |
158 | } |
159 | |
160 | // Add `splits` that come from `params_nested_splits`. Starting with the |
161 | // outermost ragged dimension (i.e., the first `splits` tensor), we work |
162 | // our way in, finding the range of values that should be copied. As we |
163 | // go, we update the output `splits` for each dimension with the appropriate |
164 | // values. In particular, the *lengths* of the slices from `param_splits` |
165 | // should be copied to generate corresponding slice lengths in the output |
166 | // splits. E.g., if we are copying a ragged row with length 4, then we |
167 | // should add a new split point to out_splits that is 4 greater than the |
168 | // previous split point in out_splits. |
169 | for (int i = 0; i < indices.size(); ++i) { |
170 | int start = indices(i); |
171 | int limit = indices(i) + 1; |
172 | |
173 | // Copy splits. |
174 | for (int dim = 0; dim < params_nested_splits.size(); ++dim) { |
175 | const auto& splits = params_nested_splits[dim]; |
176 | int out_dim = dim + indices_in.dims() - 1; |
177 | if (out_dim >= 0) { |
178 | SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start); |
179 | for (int j = start; j < limit; ++j) { |
180 | out_splits->at(out_dim).push_back(splits(j + 1) + delta); |
181 | } |
182 | } |
183 | start = splits(start); |
184 | limit = splits(limit); |
185 | } |
186 | if (limit != start) { |
187 | value_slices->emplace_back(start, limit); |
188 | *num_values += limit - start; |
189 | } |
190 | } |
191 | return OkStatus(); |
192 | } |
193 | |
194 | ::tensorflow::Status ValidateSplits( |
195 | const std::vector<ConstFlatType>& params_nested_splits, |
196 | SPLITS_TYPE num_params_dense_values) { |
197 | // Validate |
198 | for (int dim = 0; dim < params_nested_splits.size(); ++dim) { |
199 | const auto& splits = params_nested_splits[dim]; |
200 | SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1) |
201 | ? num_params_dense_values |
202 | : params_nested_splits[dim + 1].size(); |
203 | if (splits.size() == 0) { |
204 | return errors::InvalidArgument("Ragged splits may not be empty" ); |
205 | } |
206 | if (splits(0) < 0) { |
207 | return errors::InvalidArgument("Ragged splits must be non-negative" ); |
208 | } |
209 | if (splits(splits.size() - 1) > last_split) { |
210 | return errors::InvalidArgument( |
211 | "Ragged splits must not point past values" ); |
212 | } |
213 | for (int i = 1; i < splits.size(); ++i) { |
214 | if (splits(i - 1) > splits(i)) { |
215 | return errors::InvalidArgument("Ragged splits must be sorted" ); |
216 | } |
217 | } |
218 | } |
219 | return OkStatus(); |
220 | } |
221 | |
222 | ::tensorflow::Status WriteSplits( |
223 | const std::vector<std::vector<SPLITS_TYPE>>& out_splits, |
224 | OpKernelContext* context) { |
225 | OpOutputList splits_out; |
226 | TF_RETURN_IF_ERROR( |
227 | context->output_list("output_nested_splits" , &splits_out)); |
228 | for (int i = 0; i < out_splits.size(); ++i) { |
229 | Tensor* splits; |
230 | SPLITS_TYPE num_splits = out_splits[i].size(); |
231 | TF_RETURN_IF_ERROR( |
232 | splits_out.allocate(i, TensorShape({num_splits}), &splits)); |
233 | auto splits_flat = splits->flat<SPLITS_TYPE>(); |
234 | std::copy_n(out_splits[i].data(), out_splits[i].size(), |
235 | splits_flat.data()); |
236 | } |
237 | return OkStatus(); |
238 | } |
239 | |
240 | ::tensorflow::Status WriteValues( |
241 | const Tensor& params_dense_values_in, |
242 | const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices, |
243 | int values_index, SPLITS_TYPE num_values, |
244 | OpKernelContext* context) const { |
245 | Tensor* values_out = nullptr; |
246 | TensorShape values_shape = params_dense_values_in.shape(); |
247 | values_shape.set_dim(0, num_values); |
248 | TF_RETURN_IF_ERROR( |
249 | context->allocate_output(values_index, values_shape, &values_out)); |
250 | const SPLITS_TYPE num_elements = params_dense_values_in.NumElements(); |
251 | const SPLITS_TYPE value_size = |
252 | num_elements == 0 ? 0 |
253 | : (num_elements / params_dense_values_in.dim_size(0)); |
254 | CallWriteValueSlices(params_dense_values_in, value_slices, value_size, |
255 | values_out); |
256 | return OkStatus(); |
257 | } |
258 | |
259 | protected: |
260 | // Call WriteValueSlices() using the appropriate VALUE_TYPE template |
261 | // parameter. This pattern is used to reduce binary size. In particular, |
262 | // this allows us to have two instantiations of this class (one for each |
263 | // index type), rather than 14 (one for each index type and value type), |
264 | // which cuts the binary size of this op from ~300k to <90k. |
265 | virtual void CallWriteValueSlices( |
266 | const Tensor& params_dense_values_in, |
267 | const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices, |
268 | SPLITS_TYPE value_size, Tensor* values_out) const = 0; |
269 | }; |
270 | |
271 | template <typename INDEX_TYPE, typename VALUE_TYPE, typename SPLITS_TYPE> |
272 | class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE> { |
273 | public: |
274 | using RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE>::RaggedGatherOpBase; |
275 | |
276 | private: |
277 | void CallWriteValueSlices( |
278 | const Tensor& params_dense_values_in, |
279 | const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices, |
280 | SPLITS_TYPE value_size, Tensor* values_out) const override { |
281 | WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices, |
282 | value_size, values_out); |
283 | } |
284 | }; |
285 | |
// Registers one RaggedGather CPU kernel for a specific (index, value, splits)
// type combination. (Comments cannot appear inside the macro bodies below,
// since `//` would swallow the line-continuation backslashes.)
#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \
                                            splits_type)            \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("RaggedGather")                                          \
          .Device(DEVICE_CPU)                                       \
          .TypeConstraint<index_type>("Tindices")                   \
          .TypeConstraint<value_type>("Tvalues")                    \
          .TypeConstraint<splits_type>("Tsplits"),                  \
      RaggedGatherOp<index_type, value_type, splits_type>);
// Registers kernels for one value type across all four combinations of
// int32/int64_t index and splits types.
#define REGISTER_CPU_KERNEL(value_type)                           \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32)   \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64_t, value_type, int32) \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64_t) \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64_t, value_type, int64_t)
// Instantiate kernels for all supported value types: plain-old-data types,
// strings, and the quantized types.
TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE
307 | |
308 | } // namespace tensorflow |
309 | |