1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <limits>
16#include <memory>
17#include <string>
18#include <vector>
19
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/framework/register_types.h"
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/framework/tensor_shape.h"
24#include "tensorflow/core/platform/errors.h"
25
26namespace tensorflow {
27
28using errors::InvalidArgument;
29
30template <typename SPLITS_TYPE>
31class RaggedTensorToSparseOp : public OpKernel {
32 public:
33 using OpKernel::OpKernel;
34 using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;
35
36 void Compute(OpKernelContext* context) override {
37 // Read the `rt_nested_splits` input & convert to Eigen tensors.
38 OpInputList rt_nested_splits_in;
39 OP_REQUIRES_OK(
40 context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
41 const int rt_nested_splits_len = rt_nested_splits_in.size();
42 OP_REQUIRES(context, rt_nested_splits_len > 0,
43 errors::InvalidArgument("rt_nested_splits must be non empty"));
44 std::vector<ConstFlatSplits> rt_nested_splits;
45 rt_nested_splits.reserve(rt_nested_splits_len);
46 for (int i = 0; i < rt_nested_splits_len; ++i) {
47 rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
48 }
49
50 // Read the `rt_dense_values` input.
51 const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
52 OP_REQUIRES_OK(context,
53 ValidateInputs(rt_nested_splits, rt_dense_values_in));
54
55 // Assemble each value in `sparse_indices` using three parts:
56 // - `index_prefix` is the index in dimensions up through the last ragged
57 // dimension.
58 // - `index_middle` is the index in the last ragged dimension.
59 // - `index_suffix` is the index in the dense value dimensions.
60 std::vector<int64_t> index_prefix(rt_nested_splits_len);
61 std::vector<std::vector<int64_t>> index_suffixes =
62 MakeIndexSuffixes(rt_dense_values_in.shape());
63
64 // Allocate the `sparse_indices` output tensor.
65 const int64_t nvals =
66 (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
67 index_suffixes.size());
68 const int64_t indices_len =
69 rt_nested_splits_len + rt_dense_values_in.dims();
70 Tensor* sparse_indices_out = nullptr;
71 OP_REQUIRES_OK(
72 context, context->allocate_output(0, TensorShape({nvals, indices_len}),
73 &sparse_indices_out));
74 auto sparse_indices = sparse_indices_out->tensor<int64_t, 2>();
75
76 // pos[i] is the current position in rt_nested_splits[i]. final_pos is a
77 // reference to make it easier to refer to pos[-1].
78 std::vector<int64_t> pos(rt_nested_splits_len);
79 int64_t& final_pos = pos[rt_nested_splits_len - 1];
80
81 // Each iteration through the loop, we increment pos[-1], and add indices
82 // for all the values corresponding to
83 // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
84 int next_index = 0;
85 int max_final_pos = rt_nested_splits.back().size() - 1;
86 for (; final_pos < max_final_pos; ++final_pos) {
87 // Update `pos` to skip over completed elements (i.e., elements where
88 // we have already generated indices for all contained values).
89 for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
90 while (IsCompleted(pos, dim, rt_nested_splits)) {
91 pos[dim] += 1;
92 }
93 }
94
95 // Update index_prefix.
96 for (int dim = 0; dim < index_prefix.size(); ++dim) {
97 int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
98 index_prefix[dim] = pos[dim] - start;
99 }
100
101 // Get length of the final-ragged-dimension slice.
102 const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
103 int64_t slice_len = final_splits(final_pos + 1) - final_splits(final_pos);
104
105 // Add sparse_indices for this slice.
106 for (int64_t i = 0; i < slice_len; ++i) {
107 for (const auto& index_suffix : index_suffixes) {
108 int dim = 0;
109 for (int64_t index : index_prefix) { // index_prefix
110 sparse_indices(next_index, dim++) = index;
111 }
112 sparse_indices(next_index, dim++) = i; // index_middle
113 for (int64_t index : index_suffix) { // index_suffix
114 sparse_indices(next_index, dim++) = index;
115 }
116 DCHECK_EQ(dim, indices_len);
117 ++next_index;
118 }
119 }
120 }
121 DCHECK_EQ(next_index, nvals);
122
123 // Output the `sparse_values` Tensor.
124 if (rt_dense_values_in.dims() == 1) {
125 context->set_output(1, rt_dense_values_in);
126 } else {
127 Tensor sparse_values_out(rt_dense_values_in.dtype());
128 bool shapes_match = sparse_values_out.CopyFrom(
129 rt_dense_values_in, {rt_dense_values_in.NumElements()});
130 DCHECK(shapes_match);
131 context->set_output(1, sparse_values_out);
132 }
133
134 // Output the `sparse_dense_shape` Tensor.
135 int64_t ndims = rt_nested_splits_len + rt_dense_values_in.dims();
136 Tensor* sparse_dense_shape_out = nullptr;
137 OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
138 &sparse_dense_shape_out));
139 auto sparse_dense_shape = sparse_dense_shape_out->vec<int64_t>();
140 sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
141 for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
142 const auto& splits = rt_nested_splits[dim];
143 SPLITS_TYPE max_width = 0;
144 for (int i = 1; i < splits.size(); ++i) {
145 max_width = std::max(max_width, splits(i) - splits(i - 1));
146 }
147 sparse_dense_shape(dim + 1) = max_width;
148 }
149 for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
150 sparse_dense_shape(dim + rt_nested_splits_len) =
151 rt_dense_values_in.dim_size(dim);
152 }
153 }
154
155 private:
156 // Validate `rt_nested_splits` to ensure we don't get any segfaults.
157 static ::tensorflow::Status ValidateInputs(
158 std::vector<ConstFlatSplits> rt_nested_splits,
159 const Tensor& rt_dense_values_in) {
160 for (int i = 0; i < rt_nested_splits.size(); ++i) {
161 if (rt_nested_splits[i].size() == 0) {
162 return InvalidArgument("ragged splits may not be empty.");
163 }
164 if (rt_nested_splits[i](0) != 0) {
165 return InvalidArgument("First value of ragged splits must be 0.");
166 }
167 for (int j = 1; j < rt_nested_splits[i].size(); ++j) {
168 if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) {
169 return InvalidArgument(
170 "Ragged splits should be non decreasing, but we got ",
171 rt_nested_splits[i](j - 1), " followed by ",
172 rt_nested_splits[i](j));
173 }
174 }
175 if (i > 0) {
176 SPLITS_TYPE last_split =
177 rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
178 if (rt_nested_splits[i].size() != last_split + 1) {
179 return InvalidArgument(
180 "Final value of ragged splits must match the length "
181 "the corresponding ragged values.");
182 }
183 }
184 }
185 if (rt_dense_values_in.dim_size(0) !=
186 rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
187 return InvalidArgument(
188 "Final value of ragged splits must match the length "
189 "the corresponding ragged values.");
190 }
191 return OkStatus();
192 }
193
194 // Build a list of index suffixes that should be added for each ragged item,
195 // to encode the indices of dense values in that ragged item. This basically
196 // just gives a row-major enumeration of all indices in the given tensor
197 // shape, ignoring dim[0] (since that's the dimension that iterates over
198 // values, and we want index suffixes for a single value). Example:
199 // MakeIndexSuffixes(TensorShape({100, 3, 2})
200 // --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
201 static std::vector<std::vector<int64_t>> MakeIndexSuffixes(
202 const TensorShape& values_shape) {
203 std::vector<std::vector<int64_t>> suffixes{{}};
204 for (int dim = 1; dim < values_shape.dims(); ++dim) {
205 std::vector<std::vector<int64_t>> new_suffixes;
206 for (const auto& suffix : suffixes) {
207 for (int i = 0; i < values_shape.dim_size(dim); ++i) {
208 new_suffixes.push_back(suffix);
209 new_suffixes.back().push_back(i);
210 }
211 }
212 suffixes.swap(new_suffixes);
213 }
214 return suffixes;
215 }
216
217 // Returns true if the ragged element at pos[dim] is "completed". A ragged
218 // element is completed if we have already generated indices for all of its
219 // values.
220 static bool IsCompleted(
221 const std::vector<int64_t>& pos, int dim,
222 const std::vector<ConstFlatSplits>& rt_nested_splits) {
223 int64_t current_child = pos[dim + 1];
224 int64_t limit_child = rt_nested_splits[dim](pos[dim] + 1);
225 return current_child >= limit_child;
226 }
227};
228
// Register the CPU kernel once per supported row-splits dtype (the op's
// `Tsplits` attribute): int32 and int64.
REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        RaggedTensorToSparseOp<int32>);

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("Tsplits"),
                        RaggedTensorToSparseOp<int64_t>);
238
239} // namespace tensorflow
240