1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <cstdint> |
16 | #include <utility> |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/register_types.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/framework/variant.h" |
24 | #include "tensorflow/core/framework/variant_encode_decode.h" |
25 | #include "tensorflow/core/framework/variant_op_registry.h" |
26 | #include "tensorflow/core/kernels/concat_lib.h" |
27 | #include "tensorflow/core/kernels/ragged_tensor_variant.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/lib/core/status.h" |
30 | #include "tensorflow/core/platform/errors.h" |
31 | #include "tensorflow/core/util/tensor_ops_util.h" |
32 | |
33 | namespace tensorflow { |
34 | namespace { |
35 | |
36 | template <typename VALUE_TYPE> |
37 | Status UnbatchDenseZerothDim( |
38 | const RaggedTensorVariant& batched_ragged, |
39 | std::vector<RaggedTensorVariant>* ragged_components) { |
40 | Tensor batched_values = batched_ragged.values(); |
41 | TensorShape values_shape = batched_values.shape(); |
42 | if (values_shape.dims() < 1) { |
43 | return errors::InvalidArgument("Can't unbatch rank-0 tensor." ); |
44 | } |
45 | auto num_components = values_shape.dim_size(0); |
46 | values_shape.RemoveDim(0); |
47 | auto num_values = values_shape.num_elements(); |
48 | |
49 | ragged_components->resize(num_components); |
50 | const auto& batched_flat = batched_values.flat<VALUE_TYPE>(); |
51 | |
52 | for (auto i = decltype(num_components){}; i < num_components; i++) { |
53 | (*ragged_components)[i].set_values( |
54 | Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape)); |
55 | auto ragged_component_values_flat = |
56 | (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>(); |
57 | for (auto j = decltype(num_values){}; j < num_values; j++) { |
58 | ragged_component_values_flat(j) = batched_flat(j + i * num_values); |
59 | } |
60 | } |
61 | |
62 | return OkStatus(); |
63 | } |
64 | |
65 | template <typename VALUE_TYPE, typename SPLIT_TYPE> |
66 | Status UnbatchRaggedZerothDim( |
67 | const RaggedTensorVariant& batched_ragged, |
68 | std::vector<RaggedTensorVariant>* ragged_components) { |
69 | // Set up the component Ragged Tensors. |
70 | int ragged_rank = batched_ragged.ragged_rank(); |
71 | if (ragged_rank == 0) { |
72 | return UnbatchDenseZerothDim<VALUE_TYPE>(batched_ragged, ragged_components); |
73 | } |
74 | |
75 | auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>(); |
76 | auto num_components = batched_splits_top_vec.size() - 1; |
77 | |
78 | if (num_components < 0) { |
79 | return errors::Internal("Invalid split argument." ); |
80 | } |
81 | |
82 | int num_splits = ragged_rank - 1; |
83 | ragged_components->resize(num_components); |
84 | for (RaggedTensorVariant& ragged_component : *ragged_components) { |
85 | ragged_component.mutable_nested_splits()->reserve(num_splits); |
86 | } |
87 | const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>(); |
88 | auto num_inner_elems = batched_ragged.values().NumElements(); |
89 | if (batched_ragged.values().dim_size(0) > 1) { |
90 | num_inner_elems /= batched_ragged.values().dim_size(0); |
91 | } |
92 | TensorShape values_shape = batched_ragged.values().shape(); |
93 | |
94 | // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]] |
95 | if (num_splits == 0) { |
96 | for (auto i = decltype(num_components){}; i < num_components; i++) { |
97 | auto start = batched_splits_top_vec(i); |
98 | auto limit = batched_splits_top_vec(i + 1); |
99 | auto num_values = limit - start; |
100 | values_shape.set_dim(0, num_values); |
101 | (*ragged_components)[i].set_values( |
102 | Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape)); |
103 | auto ragged_component_values_flat = |
104 | (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>(); |
105 | for (auto j = decltype(num_values * num_inner_elems){}; |
106 | j < num_values * num_inner_elems; j++) { |
107 | ragged_component_values_flat(j) = |
108 | batched_flat(j + start * num_inner_elems); |
109 | } |
110 | } |
111 | return OkStatus(); |
112 | } |
113 | |
114 | // Unbatch nested splits. |
115 | std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec; |
116 | batched_splits_vec.reserve(ragged_rank); |
117 | for (int i = 0; i < ragged_rank; i++) { |
118 | batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>()); |
119 | } |
120 | std::vector<SPLIT_TYPE> index(num_splits, 1); |
121 | std::vector<SPLIT_TYPE> ragged_component_values_size(num_components, 0); |
122 | for (auto i = decltype(num_components){}; i < num_components; i++) { |
123 | std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec; |
124 | ragged_component_splits_vec.reserve(num_splits); |
125 | SPLIT_TYPE split_size = -1; |
126 | for (int j = 0; j < num_splits; j++) { |
127 | if (j == 0) { |
128 | split_size = |
129 | batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1; |
130 | } else { |
131 | // Update split size based on previous split. |
132 | SPLIT_TYPE last_index = ragged_component_splits_vec[j - 1].size() - 1; |
133 | split_size = ragged_component_splits_vec[j - 1](last_index) + 1; |
134 | } |
135 | (*ragged_components)[i].append_splits( |
136 | Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size}))); |
137 | ragged_component_splits_vec.push_back((*ragged_components)[i] |
138 | .mutable_splits(j) |
139 | ->template vec<SPLIT_TYPE>()); |
140 | SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1); |
141 | ragged_component_splits_vec[j](0) = 0; |
142 | for (SPLIT_TYPE k = 1; k < split_size; k++, index[j]++) { |
143 | ragged_component_splits_vec[j](k) = |
144 | batched_splits_vec[j + 1](index[j]) - last_split_value; |
145 | } |
146 | } |
147 | SPLIT_TYPE last_split_size = |
148 | ragged_component_splits_vec[num_splits - 1].size(); |
149 | ragged_component_values_size[i] = |
150 | ragged_component_splits_vec[num_splits - 1](last_split_size - 1); |
151 | } |
152 | |
153 | // Unbatch values. |
154 | int64_t value_index = 0; |
155 | for (auto i = decltype(num_components){}; i < num_components; i++) { |
156 | SPLIT_TYPE num_values = ragged_component_values_size[i]; |
157 | values_shape.set_dim(0, num_values); |
158 | (*ragged_components)[i].set_values( |
159 | Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape)); |
160 | auto ragged_component_values_flat = |
161 | (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>(); |
162 | for (int64_t j = 0; j < num_values * num_inner_elems; j++, value_index++) { |
163 | ragged_component_values_flat(j) = batched_flat(value_index); |
164 | } |
165 | } |
166 | |
167 | return OkStatus(); |
168 | } |
169 | } // namespace |
170 | |
171 | template <typename VALUE_TYPE, typename SPLIT_TYPE> |
172 | class RaggedTensorToVariantOp : public OpKernel { |
173 | public: |
174 | explicit RaggedTensorToVariantOp(OpKernelConstruction* context) |
175 | : OpKernel(context) { |
176 | OP_REQUIRES_OK(context, context->GetAttr("batched_input" , &batched_input_)); |
177 | } |
178 | |
179 | void Compute(OpKernelContext* context) override { |
180 | // Read ragged_splits inputs. |
181 | OpInputList ragged_nested_splits_in; |
182 | OP_REQUIRES_OK(context, context->input_list("rt_nested_splits" , |
183 | &ragged_nested_splits_in)); |
184 | const int ragged_nested_splits_len = ragged_nested_splits_in.size(); |
185 | RaggedTensorVariant batched_ragged_input; |
186 | // Read ragged_values input. |
187 | batched_ragged_input.set_values(context->input(ragged_nested_splits_len)); |
188 | batched_ragged_input.mutable_nested_splits()->reserve( |
189 | ragged_nested_splits_len); |
190 | for (int i = 0; i < ragged_nested_splits_len; i++) { |
191 | OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1, |
192 | errors::InvalidArgument("Requires nested_row_splits[" , i, "]" , |
193 | " to be rank 1 but is rank " , |
194 | ragged_nested_splits_in[i].dims())); |
195 | batched_ragged_input.append_splits(ragged_nested_splits_in[i]); |
196 | } |
197 | |
198 | if (!batched_input_) { |
199 | // Encode as a Scalar Variant Tensor. |
200 | Tensor* encoded_scalar; |
201 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), |
202 | &encoded_scalar)); |
203 | encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input); |
204 | return; |
205 | } |
206 | |
207 | // Unbatch the Ragged Tensor and encode the components. |
208 | std::vector<RaggedTensorVariant> unbatched_ragged_input; |
209 | OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>( |
210 | batched_ragged_input, &unbatched_ragged_input)); |
211 | |
212 | // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor. |
213 | Tensor* encoded_vector; |
214 | |
215 | // output_size will be used for calling TensorShape(int64_t ...). We |
216 | // cannot use `auto` type here, or there will be a narrowing error. |
217 | int64_t output_size = unbatched_ragged_input.size(); |
218 | OP_REQUIRES_OK(context, |
219 | context->allocate_output(0, TensorShape({output_size}), |
220 | &encoded_vector)); |
221 | auto encoded_vector_t = encoded_vector->vec<Variant>(); |
222 | for (auto i = decltype(output_size){}; i < output_size; i++) { |
223 | encoded_vector_t(i) = unbatched_ragged_input[i]; |
224 | } |
225 | } |
226 | |
227 | private: |
228 | bool batched_input_; |
229 | }; |
230 | |
231 | template <typename VALUE_TYPE, typename SPLIT_TYPE> |
232 | class RaggedTensorToVariantGradientOp : public OpKernel { |
233 | public: |
234 | using OpKernel::OpKernel; |
235 | |
236 | void Compute(OpKernelContext* context) override { |
237 | // Read inputs. |
238 | Tensor encoded_variant = context->input(0); |
239 | Tensor row_splits = context->input(1); |
240 | auto flat_row_splits = row_splits.flat<SPLIT_TYPE>(); |
241 | TensorShape dense_values_shape; |
242 | OP_REQUIRES_OK(context, |
243 | TensorShapeUtils::MakeShape(context->input(2).vec<int32>(), |
244 | &dense_values_shape)); |
245 | |
246 | const auto& flat_variants = encoded_variant.flat<Variant>(); |
247 | |
248 | // Get a Tensor containing the flat_values for each variant. |
249 | std::vector<Tensor> values; |
250 | for (int i = 0; i < flat_variants.size(); ++i) { |
251 | if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) { |
252 | values.push_back(encoded->values()); |
253 | } else { |
254 | // Missing value: this happens if only some of the variant values |
255 | // generated by ragged_tensor_to_variant impacted the value that we're |
256 | // calculating the gradient for. In this case, we will see a |
257 | // default-constructed variant; so treat it as a zero tensor with the |
258 | // appropriate shape. |
259 | const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v(); |
260 | auto piece_size = flat_row_splits(i + 1) - flat_row_splits(i); |
261 | TensorShape zeros_shape = dense_values_shape; |
262 | zeros_shape.set_dim(0, piece_size); |
263 | Tensor zero(value_dtype, zeros_shape); |
264 | zero.flat<VALUE_TYPE>().setZero(); |
265 | values.push_back(zero); |
266 | } |
267 | } |
268 | |
269 | if (values.size() == 1) { |
270 | // Just one flat_value tensor: return as-is. |
271 | context->set_output(0, values[0]); |
272 | } else { |
273 | Tensor* out = nullptr; |
274 | OP_REQUIRES_OK(context, |
275 | context->allocate_output(0, dense_values_shape, &out)); |
276 | // ConcatCPU assumes non-empty output. |
277 | if (dense_values_shape.num_elements() == 0) return; |
278 | // Multiple flat_values tensors: concatenate them together. |
279 | using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix; |
280 | using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix; |
281 | std::vector<std::unique_ptr<ConstPiece>> pieces; |
282 | pieces.reserve(values.size()); |
283 | for (const Tensor& t : values) { |
284 | // ConcatCPU assumes non-empty inputs. |
285 | if (t.NumElements() == 0) continue; |
286 | pieces.emplace_back( |
287 | new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()}))); |
288 | } |
289 | Piece out_flat = |
290 | out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()}); |
291 | ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat); |
292 | } |
293 | } |
294 | }; |
295 | |
296 | #define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ |
297 | REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ |
298 | .Device(DEVICE_CPU) \ |
299 | .TypeConstraint<value_type>("Tvalues") \ |
300 | .TypeConstraint<split_type>("Tsplits"), \ |
301 | RaggedTensorToVariantOp<value_type, split_type>); \ |
302 | REGISTER_KERNEL_BUILDER( \ |
303 | Name("RaggedTensorToVariantGradient") \ |
304 | .Device(DEVICE_CPU) \ |
305 | .TypeConstraint<value_type>("Tvalues") \ |
306 | .TypeConstraint<split_type>("Tsplits"), \ |
307 | RaggedTensorToVariantGradientOp<value_type, split_type>); |
308 | |
309 | #define REGISTER_KERNELS(value_type) \ |
310 | REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \ |
311 | REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t) |
312 | TF_CALL_POD_TYPES(REGISTER_KERNELS); |
313 | TF_CALL_tstring(REGISTER_KERNELS); |
314 | TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); |
315 | TF_CALL_quint16(REGISTER_KERNELS); |
316 | TF_CALL_qint16(REGISTER_KERNELS); |
317 | #undef REGISTER_KERNELS |
318 | #undef REGISTER_KERNELS_WITH_SPLIT_TYPE |
319 | } // namespace tensorflow |
320 | |