/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

/* Extracts the components of the variant-encoded tensor `encoded_variant`
 * into a flat vector of `RaggedTensorVariant` objects, validating that each
 * component has ragged_rank `input_ragged_rank` and the expected value and
 * row-splits dtypes. */
Status RaggedComponentsFromVariant(
    const Tensor& encoded_variant, int input_ragged_rank,
    int output_ragged_rank, DataType value_dtype, DataType split_dtype,
    std::vector<RaggedTensorVariant>* decoded_ragged) {
  const auto& flat_variants = encoded_variant.flat<Variant>();
  decoded_ragged->reserve(flat_variants.size());

  for (int i = 0; i < flat_variants.size(); i++) {
    const auto& flat_variant = flat_variants(i);
    const RaggedTensorVariant* decoded =
        flat_variant.get<RaggedTensorVariant>();
    if (decoded == nullptr) {
      return errors::InvalidArgument(
          "Input Variant element at index ", i,
          " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
    }
    decoded_ragged->push_back(*decoded);
    decoded = &decoded_ragged->back();
    // Check ragged rank & types.
    if (decoded->ragged_rank() != input_ragged_rank) {
      return errors::InvalidArgument(
          "Encoded input RaggedTensorVariant has ragged_rank=",
          decoded->ragged_rank(), ". Expected ragged_rank=", input_ragged_rank,
          ".");
    }
    if (decoded->values().dtype() != value_dtype) {
      return errors::InvalidArgument(
          "Expected values Tensor dtype: ", DataTypeString(value_dtype),
          ", found: ", DataTypeString(decoded->values().dtype()));
    }
    if (decoded->values().dims() < 1 && output_ragged_rank != 0) {
      return errors::InvalidArgument(
          "Ragged values must have rank >= 1; encoded scalar element at index ",
          i, " has values Tensor: ", decoded->values().DebugString());
    }
    for (const auto& splits : decoded->nested_splits()) {
      if (splits.dtype() != split_dtype) {
        return errors::InvalidArgument(
            "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
            ", found: ", DataTypeString(splits.dtype()));
      }
      if (splits.dims() != 1) {
        return errors::InvalidArgument(
            "Ragged splits must have rank 1; encoded scalar element at index ",
            i, " has splits Tensor ", splits.DebugString());
      }
    }
  }
  return OkStatus();
}

/* Takes a set of RaggedTensorVariants for non-ragged tensors, stacks
 * their flat_values, and sets output_ragged's flat_values to that stacked
 * value. I.e.:
 *
 *   output_ragged.values = stack([c.values for c in ragged_components])
 *
 * Requires that elements of `ragged_components` have no splits.
 *
 * This should only be used when input_ragged_rank=0 and output_ragged_rank=0.
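 *
 * For example (illustrative): stacking three components whose values are the
 * scalars 1, 2, and 3 produces a values Tensor of shape [3] containing
 * [1, 2, 3].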
 */
template <typename VALUE_TYPE>
Status StackNonRaggedTensors(
    const std::vector<RaggedTensorVariant>& ragged_components,
    RaggedTensorVariant* output_ragged) {
  if (ragged_components.empty()) {
    output_ragged->set_values(Tensor(DataTypeToEnum<VALUE_TYPE>::value, {0}));
    return OkStatus();
  }

  TensorShape component_values_shape = ragged_components[0].values().shape();
  TensorShape result_shape = component_values_shape;
  result_shape.InsertDim(0, ragged_components.size());

  output_ragged->set_values(
      Tensor(DataTypeToEnum<VALUE_TYPE>::value, result_shape));
  auto output_values_flat = output_ragged->mutable_values()->flat<VALUE_TYPE>();
  int values_index = 0;
  for (int i = 0; i < ragged_components.size(); i++) {
    auto& component_values = ragged_components[i].values();
    if (component_values.shape() != component_values_shape) {
      return errors::InvalidArgument(
          "All flat_values must have compatible shapes. Shape at index 0: ",
          component_values_shape, ". Shape at index ", i, ": ",
          component_values.shape());
    }
    auto component_values_flat = component_values.flat<VALUE_TYPE>();
    for (int j = 0; j < component_values_flat.size(); j++) {
      output_values_flat(values_index++) = component_values_flat(j);
    }
  }
  return OkStatus();
}
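
/* Stacks the RaggedTensorVariants in `ragged_components` (the decoded
 * elements of a variant Tensor whose dense shape is `nested_dim_sizes`) into
 * a single RaggedTensorVariant with `output_ragged_rank` ragged dimensions.
 * With `dims = nested_dim_sizes.size()`: the first `dims - 1` output splits
 * are uniform row splits for the dense dimensions of the variant Tensor, the
 * `dims`-th split partitions the stacked components by their outer length,
 * the last `input_ragged_rank` splits concatenate (with offsets) the
 * components' own nested splits, and the components' flat_values are
 * concatenated along dimension 0. When output_ragged_rank is 0, this
 * delegates to StackNonRaggedTensors. */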
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status NestedStackRaggedTensors(
    const std::vector<RaggedTensorVariant>& ragged_components,
    const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
    const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
  output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
  const int dims = nested_dim_sizes.size();

  if (output_ragged_rank == 0) {
    if (input_ragged_rank > 0) {
      return errors::InvalidArgument(
          "Expected input_ragged_rank=0 if output_ragged_rank==0. "
          "Got input_ragged_rank=",
          input_ragged_rank);
    }
    return StackNonRaggedTensors<VALUE_TYPE>(ragged_components, output_ragged);
  }

  // Populate first `dims - 1` splits.
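  // These splits describe the uniform (dense) outer dimensions of the input
  // variant Tensor, so every row in dimension `i` has exactly
  // `nested_dim_sizes[i + 1]` elements.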
  for (int i = 0; i < dims - 1; i++) {
    int dims_splits_size = nested_dim_sizes[i] + 1;
    output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
                                        TensorShape({dims_splits_size})));
    auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
    int split_diff = nested_dim_sizes[i + 1];
    for (int j = 0; j < dims_splits_size; j++) {
      splits_vec(j) = j * split_diff;
    }
  }

  // Populate `dims`-th split.
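  // This split partitions the stacked components: each component contributes
  // one row whose length is its number of rows (outer splits size - 1) if it
  // is ragged, or the outer dimension of its values otherwise.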
  int splits_size = ragged_components.size() + 1;
  output_ragged->append_splits(
      Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
  auto dims_splits_vec =
      output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
  dims_splits_vec(0) = 0;
  for (int i = 0; i < ragged_components.size(); i++) {
    int split_val = ragged_components[i].values().shape().dim_size(0);
    if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
      split_val = ragged_components[i].splits(0).NumElements() - 1;
    }
    dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
  }

  // Populate last `input_ragged_rank` splits.
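  // Concatenate the components' own nested splits level by level, offsetting
  // each component's splits by the running total so far; components with no
  // splits (empty rows) are skipped.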
  for (int i = 0; i < input_ragged_rank; i++) {
    int split_index = dims + i;
    int split_size = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (!ragged_components[j].nested_splits().empty()) {
        split_size += ragged_components[j].splits(i).NumElements() - 1;
      }
    }
    output_ragged->append_splits(
        Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
    auto splits_vec =
        output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
    splits_vec(0) = 0;
    SPLIT_TYPE last_split_value = 0;
    int index = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (ragged_components[j].nested_splits().empty()) {
        // Corner case: empty row, e.g. [ [[x], [x]], [] ].
        continue;
      }
      auto component_splits_vec =
          ragged_components[j].splits(i).vec<SPLIT_TYPE>();
      for (int k = 1; k < component_splits_vec.size(); k++, index++) {
        splits_vec(index) = component_splits_vec(k) + last_split_value;
      }
      last_split_value = splits_vec(index - 1);
    }
  }

  // If the variant tensor input is empty, then we have no way to determine
  // the correct shape for the dense_values. (It must have rank>=1, and its
  // outer dimension must be 0, but we don't know its shape beyond that.)
  // For now, we just use a shape of `[0]` in this case.
  // TODO(edloper): Update this op with an attribute containing information
  // about dense_values shape. If it's `None`, then we'll probably still have
  // to use shape=[0] here, but if we have more info, then we can use it.
  // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
  TensorShape component_values_shape;
  if (ragged_components.empty()) {
    component_values_shape = TensorShape({0});
  } else {
    component_values_shape = ragged_components[0].values().shape();
  }

  // Populate values.
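  // The output values Tensor keeps the components' inner shape; its outer
  // dimension is the sum of the components' outer dimensions, and each
  // component's values are copied into it in order.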
  int values_size = component_values_shape.dim_size(0);
  for (int i = 1; i < ragged_components.size(); i++) {
    if (ragged_components[i].values().dims() != component_values_shape.dims()) {
      return errors::InvalidArgument(
          "Rank of values must match for all "
          "components; values shape at index 0: ",
          component_values_shape.DebugString(), ", values shape at index ", i,
          ": ", ragged_components[i].values().shape().DebugString());
    }
    values_size += ragged_components[i].values().shape().dim_size(0);
  }
  component_values_shape.set_dim(0, values_size);
  output_ragged->set_values(
      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
  auto output_values_flat =
      output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
  int values_index = 0;

  TensorShape expected_value_shape = component_values_shape;
  expected_value_shape.RemoveDim(0);

  for (int i = 0; i < ragged_components.size(); i++) {
    // Check that the flat_values tensor shape is compatible.
    TensorShape value_shape = ragged_components[i].values().shape();
    value_shape.RemoveDim(0);
    if (value_shape != expected_value_shape) {
      return errors::InvalidArgument(
          "All flat_values must have compatible shapes. Shape at index 0: ",
          expected_value_shape, ". Shape at index ", i, ": ", value_shape,
          ". If you are using tf.map_fn, then you may need to specify an "
          "explicit fn_output_signature with appropriate ragged_rank, and/or "
          "convert output tensors to RaggedTensors.");
    }

    auto component_values_flat =
        ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
    int num_inner_elements = ragged_components[i].values().NumElements();
    if (ragged_components[i].values().dim_size(0) > 0) {
      num_inner_elements /= ragged_components[i].values().dim_size(0);
    }
    for (int j = 0; j < ragged_components[i].values().dim_size(0);
         j++, values_index++) {
      for (int k = 0; k < num_inner_elements; k++) {
        output_values_flat(values_index, k) = component_values_flat(j, k);
      }
    }
  }
  return OkStatus();
}
}  // namespace

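/* Kernel for the RaggedTensorFromVariant op: decodes a variant Tensor of
 * encoded RaggedTensors and outputs the nested row splits and flat values of
 * a single (possibly batched) RaggedTensor. */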
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorFromVariantOp : public OpKernel {
 public:
  explicit RaggedTensorFromVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("input_ragged_rank",
                                             &input_ragged_rank_attr_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_ragged_rank", &output_ragged_rank_));
  }

  void Compute(OpKernelContext* context) override {
    // Read input Tensor.
    const Tensor& encoded_variant = context->input(0);
    auto input_ragged_rank_ = input_ragged_rank_attr_;

    if (input_ragged_rank_ == -1) {  // Infer input_ragged_rank_.
      input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims();
      if (output_ragged_rank_ == 0 && input_ragged_rank_ < 0) {
        input_ragged_rank_ = 0;
      }
      OP_REQUIRES(context, input_ragged_rank_ >= 0,
                  errors::InvalidArgument(
                      "Inferred input_ragged_rank (output_ragged_rank - "
                      "encoded_variant.dims()) must be >= 0, found "
                      "output_ragged_rank: ",
                      output_ragged_rank_,
                      ", encoded_variant.dims(): ", encoded_variant.dims(),
                      ", inferred input_ragged_rank: ", input_ragged_rank_));
    }
    OP_REQUIRES(
        context,
        (output_ragged_rank_ == 0 && input_ragged_rank_ == 0) ||
            (output_ragged_rank_ ==
             encoded_variant.dims() + input_ragged_rank_),
        errors::InvalidArgument(
            "output_ragged_rank must be equal to input_ragged_rank + "
            "encoded_ragged.dims(); output_ragged_rank: ",
            output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_,
            ", encoded_variant.dims(): ", encoded_variant.dims(), "."));

    // Decode all variants.
    const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
    const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
    std::vector<RaggedTensorVariant> decoded_components;
    OP_REQUIRES_OK(context,
                   RaggedComponentsFromVariant(
                       encoded_variant, input_ragged_rank_, output_ragged_rank_,
                       value_dtype, split_dtype, &decoded_components));

    // Corner case: input is a scalar.
    if (encoded_variant.dims() == 0) {
      ReturnRaggedTensor(context, decoded_components[0]);
      return;
    }

    // Nested-stack the ragged components into a batched RaggedTensor.
    std::vector<int> encoded_dim_sizes(encoded_variant.dims(), 0);
    for (int i = 0; i < encoded_variant.dims(); i++) {
      encoded_dim_sizes[i] = encoded_variant.dim_size(i);
    }
    RaggedTensorVariant output_ragged;
    OP_REQUIRES_OK(
        context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
                     decoded_components, encoded_dim_sizes, input_ragged_rank_,
                     output_ragged_rank_, &output_ragged));

    // Set output.
    ReturnRaggedTensor(context, output_ragged);
  }

 private:
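  // Value of the "input_ragged_rank" attr; may be -1, in which case the
  // input ragged rank is inferred in Compute() from output_ragged_rank_ and
  // the rank of the input variant Tensor.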
  int input_ragged_rank_attr_;
  int output_ragged_rank_;

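  // Copies `ragged_tensor` to the op's outputs: outputs 0..ragged_rank-1 are
  // the nested row splits, and output `ragged_rank` is the flat values
  // Tensor.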
  void ReturnRaggedTensor(OpKernelContext* context,
                          const RaggedTensorVariant& ragged_tensor) {
    int ragged_rank = ragged_tensor.ragged_rank();
    OpOutputList splits_out;
    OP_REQUIRES_OK(context,
                   context->output_list("output_nested_splits", &splits_out));
    for (int i = 0; i < ragged_rank; i++) {
      splits_out.set(i, ragged_tensor.splits(i));
    }
    context->set_output(ragged_rank, ragged_tensor.values());
  }
};

#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant")             \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<value_type>("Tvalues")  \
                              .TypeConstraint<split_type>("Tsplits"), \
                          RaggedTensorFromVariantOp<value_type, split_type>);
#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow