1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <utility> |
16 | #include <vector> |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/variant.h" |
22 | #include "tensorflow/core/framework/variant_encode_decode.h" |
23 | #include "tensorflow/core/kernels/ragged_tensor_variant.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace { |
29 | |
30 | /* Extracts the components of the variant-encoded tensor `encoded_variant` |
31 | * into a flat vector of `RaggedTensorVariant` objects. */ |
32 | Status RaggedComponentsFromVariant( |
33 | const Tensor& encoded_variant, int input_ragged_rank, |
34 | int output_ragged_rank, DataType value_dtype, DataType split_dtype, |
35 | std::vector<RaggedTensorVariant>* decoded_ragged) { |
36 | const auto& flat_variants = encoded_variant.flat<Variant>(); |
37 | decoded_ragged->reserve(flat_variants.size()); |
38 | |
39 | for (int i = 0; i < flat_variants.size(); i++) { |
40 | const auto& flat_variant = flat_variants(i); |
41 | const RaggedTensorVariant* decoded = |
42 | flat_variant.get<RaggedTensorVariant>(); |
43 | if (decoded == nullptr) { |
44 | return errors::InvalidArgument( |
45 | "Input Variant element at index " , i, |
46 | " doesn't hold a RaggedTensorVariant: " , flat_variant.DebugString()); |
47 | } |
48 | decoded_ragged->push_back(*decoded); |
49 | decoded = &decoded_ragged->back(); |
50 | // Check ragged rank & types |
51 | if (decoded->ragged_rank() != input_ragged_rank) { |
52 | return errors::InvalidArgument( |
53 | "Encoded input RaggedTensorVariant has ragged_rank=" , |
54 | decoded->ragged_rank(), ". Expected ragged_rank=" , input_ragged_rank, |
55 | "." ); |
56 | } |
57 | if (decoded->values().dtype() != value_dtype) { |
58 | return errors::InvalidArgument( |
59 | "Expected values Tensor dtype: " , DataTypeString(value_dtype), |
60 | ", found: " , DataTypeString(decoded->values().dtype())); |
61 | } |
62 | if (decoded->values().dims() < 1 && output_ragged_rank != 0) { |
63 | return errors::InvalidArgument( |
64 | "Ragged values must have rank >= 1; encoded scalar element at index " , |
65 | i, " has values Tensor: " , decoded->values().DebugString()); |
66 | } |
67 | for (const auto& splits : decoded->nested_splits()) { |
68 | if (splits.dtype() != split_dtype) { |
69 | return errors::InvalidArgument( |
70 | "Expected row_splits Tensor dtype: " , DataTypeString(split_dtype), |
71 | ", found: " , DataTypeString(splits.dtype())); |
72 | } |
73 | if (splits.dims() != 1) { |
74 | return errors::InvalidArgument( |
75 | "Ragged splits must have rank 1; encoded scalar element at index " , |
76 | i, " has splits Tensor " , splits.DebugString()); |
77 | } |
78 | } |
79 | } |
80 | return OkStatus(); |
81 | } |
82 | |
83 | /* Takes a set of RaggedTensorVariants for non-ragged tensors, stacks |
84 | * their flat_values, and sets output_ragged's flat_values to that stacked |
85 | * value. I.e.: |
86 | * |
87 | * output_ragged.values = stack([c.values for c in ragged_components]) |
88 | * |
89 | * Requires that elements of `ragged_components` have no splits. |
90 | * |
91 | * This should only be used when input_ragged_rank=0 and output_ragged_rank=0. |
92 | */ |
93 | template <typename VALUE_TYPE> |
94 | Status StackNonRaggedTensors( |
95 | const std::vector<RaggedTensorVariant>& ragged_components, |
96 | RaggedTensorVariant* output_ragged) { |
97 | if (ragged_components.empty()) { |
98 | output_ragged->set_values(Tensor(DataTypeToEnum<VALUE_TYPE>::value, {0})); |
99 | return OkStatus(); |
100 | } |
101 | |
102 | TensorShape component_values_shape = ragged_components[0].values().shape(); |
103 | TensorShape result_shape = component_values_shape; |
104 | result_shape.InsertDim(0, ragged_components.size()); |
105 | |
106 | output_ragged->set_values( |
107 | Tensor(DataTypeToEnum<VALUE_TYPE>::value, result_shape)); |
108 | auto output_values_flat = output_ragged->mutable_values()->flat<VALUE_TYPE>(); |
109 | int values_index = 0; |
110 | for (int i = 0; i < ragged_components.size(); i++) { |
111 | auto& component_values = ragged_components[i].values(); |
112 | if (component_values.shape() != component_values_shape) { |
113 | return errors::InvalidArgument( |
114 | "All flat_values must have compatible shapes. Shape at index 0: " , |
115 | component_values_shape, ". Shape at index " , i, ": " , |
116 | component_values.shape()); |
117 | } |
118 | auto component_values_flat = component_values.flat<VALUE_TYPE>(); |
119 | for (int j = 0; j < component_values_flat.size(); j++) { |
120 | output_values_flat(values_index++) = component_values_flat(j); |
121 | } |
122 | } |
123 | return OkStatus(); |
124 | } |
125 | |
/* Nested-stacks `ragged_components` into the single batched RaggedTensor
 * `output_ragged`.
 *
 * `nested_dim_sizes` is the dense shape of the input variant Tensor; its
 * length (`dims`) is the number of uniform outer dimensions introduced by
 * stacking.  The output gets `output_ragged_rank` splits tensors, built in
 * three phases: the first `dims - 1` are uniform splits derived from
 * `nested_dim_sizes`, the `dims`-th batches the components, and the last
 * `input_ragged_rank` concatenate the components' own splits.  Finally the
 * components' values tensors are concatenated along their outer dimension.
 *
 * Precondition (checked by the caller via RaggedComponentsFromVariant):
 * every component has ragged_rank == input_ragged_rank and rank-1 splits. */
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status NestedStackRaggedTensors(
    const std::vector<RaggedTensorVariant>& ragged_components,
    const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
    const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
  output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
  const int dims = nested_dim_sizes.size();

  // output_ragged_rank == 0 means the result is a plain stacked Tensor with
  // no splits; delegate to the non-ragged stacking helper.
  if (output_ragged_rank == 0) {
    if (input_ragged_rank > 0) {
      return errors::InvalidArgument(
          "Expected input_ragged_rank=0 if output_ragged_rank==0. "
          "Got input_ragged_rank=",
          input_ragged_rank);
    }
    return StackNonRaggedTensors<VALUE_TYPE>(ragged_components, output_ragged);
  }

  // Populate first `dims - 1` splits.  These dimensions are uniform, so each
  // splits vector is an arithmetic progression whose stride is the size of
  // the next nested dimension.
  for (int i = 0; i < dims - 1; i++) {
    int dims_splits_size = nested_dim_sizes[i] + 1;
    output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
                                        TensorShape({dims_splits_size})));
    auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
    int split_diff = nested_dim_sizes[i + 1];
    for (int j = 0; j < dims_splits_size; j++) {
      splits_vec(j) = j * split_diff;
    }
  }

  // Populate `dims`-th split: one row per component.  A component's row
  // length is its number of outer rows — splits(0).NumElements() - 1 when it
  // is itself ragged, else the outer dimension of its values tensor.
  int splits_size = ragged_components.size() + 1;
  output_ragged->append_splits(
      Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
  auto dims_splits_vec =
      output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
  dims_splits_vec(0) = 0;
  for (int i = 0; i < ragged_components.size(); i++) {
    int split_val = ragged_components[i].values().shape().dim_size(0);
    if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
      split_val = ragged_components[i].splits(0).NumElements() - 1;
    }
    dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
  }

  // Populate last `input_ragged_rank` splits by concatenating the i-th
  // splits vector of every component, offsetting each by the running total
  // (`last_split_value`) so the concatenated splits stay monotonic.
  for (int i = 0; i < input_ragged_rank; i++) {
    int split_index = dims + i;
    // Each component contributes its splits minus the leading 0; start at 1
    // for the single shared leading 0.
    int split_size = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (!ragged_components[j].nested_splits().empty()) {
        split_size += ragged_components[j].splits(i).NumElements() - 1;
      }
    }
    output_ragged->append_splits(
        Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
    auto splits_vec =
        output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
    splits_vec(0) = 0;
    SPLIT_TYPE last_split_value = 0;
    int index = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (ragged_components[j].nested_splits().empty()) {
        // Corner case: empty row. e.g [ [[x], [x]], [] ]
        continue;
      }
      auto component_splits_vec =
          ragged_components[j].splits(i).vec<SPLIT_TYPE>();
      // Skip k=0 (the component's leading 0) and shift by the total number
      // of values emitted so far.
      for (int k = 1; k < component_splits_vec.size(); k++, index++) {
        splits_vec(index) = component_splits_vec(k) + last_split_value;
      }
      last_split_value = splits_vec(index - 1);
    }
  }

  // If the variant tensor input is empty, then we have no way to determine
  // the correct shape for the dense_values. (It must have rank>=1, and its
  // outer dimension must be 0, but we don't know its shape beyond that.)
  // For now, we just use a shape of `[0]` in this case.
  // TODO(edloper): Update this op with an attribute containing information
  // about dense_values shape. If it's `None`, then we'll probably still have
  // to use shape=[0] here, but if we have more info, then we can use it.
  // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
  TensorShape component_values_shape;
  if (ragged_components.empty()) {
    component_values_shape = TensorShape({0});
  } else {
    component_values_shape = ragged_components[0].values().shape();
  }

  // Populate values: first compute the total outer dimension, checking that
  // all components' values have the same rank.
  int values_size = component_values_shape.dim_size(0);
  for (int i = 1; i < ragged_components.size(); i++) {
    if (ragged_components[i].values().dims() != component_values_shape.dims()) {
      return errors::InvalidArgument(
          "Rank of values must match for all "
          "components; values shape at index 0: ",
          component_values_shape.DebugString(), ", values shape at index ", i,
          ": ", ragged_components[i].values().shape().DebugString());
    }
    values_size += ragged_components[i].values().shape().dim_size(0);
  }
  component_values_shape.set_dim(0, values_size);
  output_ragged->set_values(
      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
  auto output_values_flat =
      output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
  int values_index = 0;

  // The inner (non-outer) dimensions must match exactly across components.
  TensorShape expected_value_shape = component_values_shape;
  expected_value_shape.RemoveDim(0);

  for (int i = 0; i < ragged_components.size(); i++) {
    // Check that the flat_values tensor shape is compatible.
    TensorShape value_shape = ragged_components[i].values().shape();
    value_shape.RemoveDim(0);
    if (value_shape != expected_value_shape) {
      return errors::InvalidArgument(
          "All flat_values must have compatible shapes. Shape at index 0: ",
          expected_value_shape, ". Shape at index ", i, ": ", value_shape,
          ". If you are using tf.map_fn, then you may need to specify an "
          "explicit fn_output_signature with appropriate ragged_rank, and/or "
          "convert output tensors to RaggedTensors.");
    }

    // Copy this component's rows into the output, viewing both tensors as
    // (outer_rows, inner_elements).
    auto component_values_flat =
        ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
    int num_inner_elements = ragged_components[i].values().NumElements();
    if (ragged_components[i].values().dim_size(0) > 0) {
      num_inner_elements /= ragged_components[i].values().dim_size(0);
    }
    for (int j = 0; j < ragged_components[i].values().dim_size(0);
         j++, values_index++) {
      for (int k = 0; k < num_inner_elements; k++) {
        output_values_flat(values_index, k) = component_values_flat(j, k);
      }
    }
  }
  return OkStatus();
}
266 | } // namespace |
267 | |
/* Kernel for the RaggedTensorFromVariant op: decodes a variant Tensor whose
 * elements each hold a RaggedTensorVariant, and nested-stacks them into one
 * batched RaggedTensor emitted as `output_nested_splits` plus a values
 * output. */
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorFromVariantOp : public OpKernel {
 public:
  explicit RaggedTensorFromVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // input_ragged_rank may be -1, meaning "infer it in Compute() from
    // output_ragged_rank and the input tensor's rank".
    OP_REQUIRES_OK(context, context->GetAttr("input_ragged_rank",
                                             &input_ragged_rank_attr_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_ragged_rank", &output_ragged_rank_));
  }

  void Compute(OpKernelContext* context) override {
    // Read input Tensor.
    const Tensor& encoded_variant = context->input(0);
    // Work on a local copy so the attribute value (possibly -1) is not
    // mutated across calls.
    auto input_ragged_rank_ = input_ragged_rank_attr_;

    if (input_ragged_rank_ == -1) {  // Infer input_ragged_rank_.
      input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims();
      // Special case: a dense (rank-0) output can never make the inferred
      // value positive, so clamp it to 0 rather than failing below.
      if (output_ragged_rank_ == 0 && input_ragged_rank_ < 0) {
        input_ragged_rank_ = 0;
      }
      OP_REQUIRES(context, input_ragged_rank_ >= 0,
                  errors::InvalidArgument(
                      "Inferred input_ragged_rank (output_ragged_rank - "
                      "encoded_variant.dims()) must be >= 0, found "
                      "output_ragged_rank: ",
                      output_ragged_rank_,
                      ", encoded_variant.dims(): ", encoded_variant.dims(),
                      ", inferred input_ragged_rank: ", input_ragged_rank_));
    }
    // The ranks must be consistent: stacking adds one ragged dimension per
    // dense dimension of the variant tensor (or both ranks are 0).
    OP_REQUIRES(
        context,
        (output_ragged_rank_ == 0 && input_ragged_rank_ == 0) ||
            (output_ragged_rank_ ==
             encoded_variant.dims() + input_ragged_rank_),
        errors::InvalidArgument(
            "output_ragged_rank must be equal to input_ragged_rank + "
            "encoded_ragged.dims(); output_ragged_rank: ",
            output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_,
            ", encoded_variant.dims(): ", encoded_variant.dims(), "."));

    // Decode all variants.
    const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
    const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
    std::vector<RaggedTensorVariant> decoded_components;
    OP_REQUIRES_OK(context,
                   RaggedComponentsFromVariant(
                       encoded_variant, input_ragged_rank_, output_ragged_rank_,
                       value_dtype, split_dtype, &decoded_components));

    // Corner case: input is a scalar.  A scalar variant always decodes to
    // exactly one component, which is returned unchanged (no stacking).
    if (encoded_variant.dims() == 0) {
      ReturnRaggedTensor(context, decoded_components[0]);
      return;
    }

    // Nested-Stack Ragged components into a batched RaggedTensor.
    std::vector<int> encoded_dim_sizes(encoded_variant.dims(), 0);
    for (int i = 0; i < encoded_variant.dims(); i++) {
      encoded_dim_sizes[i] = encoded_variant.dim_size(i);
    }
    RaggedTensorVariant output_ragged;
    OP_REQUIRES_OK(
        context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
                     decoded_components, encoded_dim_sizes, input_ragged_rank_,
                     output_ragged_rank_, &output_ragged));

    // Set output.
    ReturnRaggedTensor(context, output_ragged);
  }

 private:
  // Value of the "input_ragged_rank" attr; -1 means "infer".
  int input_ragged_rank_attr_;
  int output_ragged_rank_;

  // Emits `ragged_tensor` as the op's outputs: its splits tensors go to the
  // "output_nested_splits" list and its values tensor to the final output
  // slot (index == ragged_rank).
  void ReturnRaggedTensor(OpKernelContext* context,
                          const RaggedTensorVariant& ragged_tensor) {
    int ragged_rank = ragged_tensor.ragged_rank();
    OpOutputList splits_out;
    OP_REQUIRES_OK(context,
                   context->output_list("output_nested_splits", &splits_out));
    for (int i = 0; i < ragged_rank; i++) {
      splits_out.set(i, ragged_tensor.splits(i));
    }
    context->set_output(ragged_rank, ragged_tensor.values());
  }
};
355 | |
// Register the CPU kernel for every supported value dtype, each paired with
// both int32 and int64 splits.
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant")             \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<value_type>("Tvalues")  \
                              .TypeConstraint<split_type>("Tsplits"), \
                          RaggedTensorFromVariantOp<value_type, split_type>);
#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
372 | } // namespace tensorflow |
373 | |