/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

using errors::InvalidArgument;

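// Converts a `RaggedTensor` (encoded as nested row splits plus a dense values
// tensor) to a `SparseTensor` with the same values. Illustrative example,
// hand-traced from the code below:
//   rt_nested_splits = [[0, 2, 3]], rt_dense_values = [10, 20, 30]
//     => sparse_indices     = [[0, 0], [0, 1], [1, 0]]
//        sparse_values      = [10, 20, 30]
//        sparse_dense_shape = [2, 2]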
template <typename SPLITS_TYPE>
class RaggedTensorToSparseOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;

  void Compute(OpKernelContext* context) override {
    // Read the `rt_nested_splits` input & convert to Eigen tensors.
    OpInputList rt_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &rt_nested_splits_in));
    const int rt_nested_splits_len = rt_nested_splits_in.size();
    OP_REQUIRES(context, rt_nested_splits_len > 0,
                errors::InvalidArgument("rt_nested_splits must be non-empty"));
    std::vector<ConstFlatSplits> rt_nested_splits;
    rt_nested_splits.reserve(rt_nested_splits_len);
    for (int i = 0; i < rt_nested_splits_len; ++i) {
      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
    }

    // Read the `rt_dense_values` input.
    const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
    OP_REQUIRES_OK(context,
                   ValidateInputs(rt_nested_splits, rt_dense_values_in));

    // Assemble each value in `sparse_indices` using three parts:
    // - `index_prefix` is the index in dimensions up through the last ragged
    //   dimension.
    // - `index_middle` is the index in the last ragged dimension.
    // - `index_suffix` is the index in the dense value dimensions.
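    // Illustrative example: with rt_nested_splits = [[0, 2, 3], [0, 1, 3, 4]]
    // and rt_dense_values of shape [4, 2], the value rt_dense_values[2][1]
    // gets the sparse index [0, 1, 1, 1]: index_prefix = [0, 1] (outer row 0,
    // second inner row within it), index_middle = 1 (offset of value 2 within
    // that inner row), and index_suffix = [1] (its column in the dense
    // values).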
    std::vector<int64_t> index_prefix(rt_nested_splits_len);
    std::vector<std::vector<int64_t>> index_suffixes =
        MakeIndexSuffixes(rt_dense_values_in.shape());

    // Allocate the `sparse_indices` output tensor.
    const int64_t nvals =
        (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
         index_suffixes.size());
    const int64_t indices_len =
        rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_indices_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({nvals, indices_len}),
                                          &sparse_indices_out));
    auto sparse_indices = sparse_indices_out->tensor<int64_t, 2>();

    // pos[i] is the current position in rt_nested_splits[i]. final_pos is a
    // reference to make it easier to refer to pos[-1].
    std::vector<int64_t> pos(rt_nested_splits_len);
    int64_t& final_pos = pos[rt_nested_splits_len - 1];

    // On each iteration of the loop, we increment pos[-1] and add indices for
    // all the values in the slice bounded by rt_nested_splits[-1][pos[-1]] and
    // rt_nested_splits[-1][pos[-1] + 1].
    // Use int64_t counters so a large number of values cannot overflow `int`.
    int64_t next_index = 0;
    int64_t max_final_pos = rt_nested_splits.back().size() - 1;
    for (; final_pos < max_final_pos; ++final_pos) {
      // Update `pos` to skip over completed elements (i.e., elements where
      // we have already generated indices for all contained values).
      for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
        while (IsCompleted(pos, dim, rt_nested_splits)) {
          pos[dim] += 1;
        }
      }

      // Update index_prefix.
      for (int dim = 0; dim < index_prefix.size(); ++dim) {
        int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
        index_prefix[dim] = pos[dim] - start;
      }

      // Get length of the final-ragged-dimension slice.
      const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
      int64_t slice_len = final_splits(final_pos + 1) - final_splits(final_pos);

      // Add sparse_indices for this slice.
      for (int64_t i = 0; i < slice_len; ++i) {
        for (const auto& index_suffix : index_suffixes) {
          int dim = 0;
          for (int64_t index : index_prefix) {  // index_prefix
            sparse_indices(next_index, dim++) = index;
          }
          sparse_indices(next_index, dim++) = i;  // index_middle
          for (int64_t index : index_suffix) {  // index_suffix
            sparse_indices(next_index, dim++) = index;
          }
          DCHECK_EQ(dim, indices_len);
          ++next_index;
        }
      }
    }
    DCHECK_EQ(next_index, nvals);

    // Output the `sparse_values` Tensor.
    if (rt_dense_values_in.dims() == 1) {
      context->set_output(1, rt_dense_values_in);
    } else {
      Tensor sparse_values_out(rt_dense_values_in.dtype());
      bool shapes_match = sparse_values_out.CopyFrom(
          rt_dense_values_in, {rt_dense_values_in.NumElements()});
      DCHECK(shapes_match);
      context->set_output(1, sparse_values_out);
    }

    // Output the `sparse_dense_shape` Tensor.
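    // Illustrative example: for rt_nested_splits = [[0, 2, 3]] over values of
    // shape [3, 2], the dense shape is [2, 2, 2]: 2 rows, a widest row of 2,
    // and the trailing dense dimension of 2.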
    int64_t ndims = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_dense_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
                                                     &sparse_dense_shape_out));
    auto sparse_dense_shape = sparse_dense_shape_out->vec<int64_t>();
    sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
    for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
      const auto& splits = rt_nested_splits[dim];
      SPLITS_TYPE max_width = 0;
      for (int i = 1; i < splits.size(); ++i) {
        max_width = std::max(max_width, splits(i) - splits(i - 1));
      }
      sparse_dense_shape(dim + 1) = max_width;
    }
    for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
      sparse_dense_shape(dim + rt_nested_splits_len) =
          rt_dense_values_in.dim_size(dim);
    }
  }

 private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
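  // For example, splits like [0, 3, 2] (decreasing) or [1, 2, 3] (first value
  // nonzero) are rejected here rather than causing out-of-bounds reads later.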
  static ::tensorflow::Status ValidateInputs(
      std::vector<ConstFlatSplits> rt_nested_splits,
      const Tensor& rt_dense_values_in) {
    for (int i = 0; i < rt_nested_splits.size(); ++i) {
      if (rt_nested_splits[i].size() == 0) {
        return InvalidArgument("Ragged splits may not be empty.");
      }
      if (rt_nested_splits[i](0) != 0) {
        return InvalidArgument("First value of ragged splits must be 0.");
      }
      for (int j = 1; j < rt_nested_splits[i].size(); ++j) {
        if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) {
          return InvalidArgument(
              "Ragged splits should be non-decreasing, but we got ",
              rt_nested_splits[i](j - 1), " followed by ",
              rt_nested_splits[i](j));
        }
      }
      if (i > 0) {
        SPLITS_TYPE last_split =
            rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
        if (rt_nested_splits[i].size() != last_split + 1) {
          return InvalidArgument(
              "Final value of ragged splits must match the length of "
              "the corresponding ragged values.");
        }
      }
    }
    if (rt_dense_values_in.dim_size(0) !=
        rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
      return InvalidArgument(
          "Final value of ragged splits must match the length of "
          "the corresponding ragged values.");
    }
    return OkStatus();
  }

  // Build a list of index suffixes that should be added for each ragged item,
  // to encode the indices of dense values in that ragged item. This basically
  // just gives a row-major enumeration of all indices in the given tensor
  // shape, ignoring dim[0] (since that's the dimension that iterates over
  // values, and we want index suffixes for a single value). Example:
  //   MakeIndexSuffixes(TensorShape({100, 3, 2}))
  //   --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
  static std::vector<std::vector<int64_t>> MakeIndexSuffixes(
      const TensorShape& values_shape) {
    std::vector<std::vector<int64_t>> suffixes{{}};
    for (int dim = 1; dim < values_shape.dims(); ++dim) {
      std::vector<std::vector<int64_t>> new_suffixes;
      for (const auto& suffix : suffixes) {
        for (int i = 0; i < values_shape.dim_size(dim); ++i) {
          new_suffixes.push_back(suffix);
          new_suffixes.back().push_back(i);
        }
      }
      suffixes.swap(new_suffixes);
    }
    return suffixes;
  }

  // Returns true if the ragged element at pos[dim] is "completed". A ragged
  // element is completed if we have already generated indices for all of its
  // values.
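  // For example, if rt_nested_splits[dim] = [0, 2, 3] and pos[dim] = 0, that
  // element is completed once pos[dim + 1] reaches 2 (its limit child).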
  static bool IsCompleted(
      const std::vector<int64_t>& pos, int dim,
      const std::vector<ConstFlatSplits>& rt_nested_splits) {
    int64_t current_child = pos[dim + 1];
    int64_t limit_child = rt_nested_splits[dim](pos[dim] + 1);
    return current_child >= limit_child;
  }
};

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        RaggedTensorToSparseOp<int32>);

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("Tsplits"),
                        RaggedTensorToSparseOp<int64_t>);
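
// A minimal hand-traced sketch of the int64_t-splits variant with 2-D values
// (illustrative, covering the nontrivial index_suffix path):
//   rt_nested_splits = [[0, 1, 3]], rt_dense_values = [[1, 2], [3, 4], [5, 6]]
//   => sparse_indices = [[0, 0, 0], [0, 0, 1], [1, 0, 0], [1, 0, 1],
//                        [1, 1, 0], [1, 1, 1]]
//      sparse_values = [1, 2, 3, 4, 5, 6]
//      sparse_dense_shape = [2, 2, 2]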

}  // namespace tensorflow