/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {

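// Unbatches a uniform tensor (ragged_rank == 0) along its first dimension,
// copying each slice into its own RaggedTensorVariant. For example, a [2, 3]
// values tensor [[1, 2, 3], [4, 5, 6]] yields two components with values
// [1, 2, 3] and [4, 5, 6].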
template <typename VALUE_TYPE>
Status UnbatchDenseZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  Tensor batched_values = batched_ragged.values();
  TensorShape values_shape = batched_values.shape();
  if (values_shape.dims() < 1) {
    return errors::InvalidArgument("Can't unbatch rank-0 tensor.");
  }
  auto num_components = values_shape.dim_size(0);
  values_shape.RemoveDim(0);
  auto num_values = values_shape.num_elements();

  ragged_components->resize(num_components);
  const auto& batched_flat = batched_values.flat<VALUE_TYPE>();

  for (auto i = decltype(num_components){}; i < num_components; i++) {
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (auto j = decltype(num_values){}; j < num_values; j++) {
      ragged_component_values_flat(j) = batched_flat(j + i * num_values);
    }
  }

  return OkStatus();
}

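// Unbatches a ragged tensor along its zeroth dimension, producing one
// RaggedTensorVariant per outer row. For example, splits [0, 2, 3] over
// values [a, b, c] yield the components [a, b] and [c]. Delegates to
// UnbatchDenseZerothDim when the input has no ragged dimensions.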
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  if (ragged_rank == 0) {
    return UnbatchDenseZerothDim<VALUE_TYPE>(batched_ragged,
                                             ragged_components);
  }

  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  auto num_components = batched_splits_top_vec.size() - 1;

  if (num_components < 0) {
    return errors::Internal("Invalid split argument.");
  }

  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
  auto num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]].
  if (num_splits == 0) {
    for (auto i = decltype(num_components){}; i < num_components; i++) {
      auto start = batched_splits_top_vec(i);
      auto limit = batched_splits_top_vec(i + 1);
      auto num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
      for (auto j = decltype(num_values * num_inner_elems){};
           j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return OkStatus();
  }

  // Unbatch nested splits.
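  // Each component's inner splits are re-based to start at zero. For example,
  // with ragged_rank == 2, outer splits [0, 2, 3], and inner splits
  // [0, 1, 3, 4], component 0 gets inner splits [0, 1, 3] (three values) and
  // component 1 gets inner splits [0, 1] (one value).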
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
  std::vector<SPLIT_TYPE> index(num_splits, 1);
  std::vector<SPLIT_TYPE> ragged_component_values_size(num_components, 0);
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    SPLIT_TYPE split_size = -1;
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        SPLIT_TYPE last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(
          Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back((*ragged_components)[i]
                                                .mutable_splits(j)
                                                ->template vec<SPLIT_TYPE>());
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (SPLIT_TYPE k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
    SPLIT_TYPE last_split_size =
        ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }

  // Unbatch values.
  int64_t value_index = 0;
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    SPLIT_TYPE num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
    for (int64_t j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return OkStatus();
}
}  // namespace

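// Encodes a RaggedTensor (nested row splits plus flat values) as a variant
// tensor. When batched_input is false, the whole input is encoded as a single
// scalar variant. When batched_input is true, the input is unbatched along
// its zeroth dimension and each component is encoded into one element of a
// rank-1 variant tensor.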
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1,
                  errors::InvalidArgument("Requires nested_row_splits[", i, "]",
                                          " to be rank 1 but is rank ",
                                          ragged_nested_splits_in[i].dims()));
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a scalar variant tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Unbatch the ragged tensor and encode the components.
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar variant tensors into a rank-1 variant tensor.
    Tensor* encoded_vector;

    // output_size is passed to the TensorShape initializer list, which takes
    // int64_t; `auto` would deduce the unsigned size_t here and trigger a
    // narrowing error.
    int64_t output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (auto i = decltype(output_size){}; i < output_size; i++) {
      encoded_vector_t(i) = unbatched_ragged_input[i];
    }
  }

 private:
  bool batched_input_;
};

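// Computes the gradient for RaggedTensorToVariant: gathers the flat values of
// each encoded component (treating missing components as zeros) and
// concatenates them into the gradient for the dense values input.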
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens when only some of the variant values
        // generated by ragged_tensor_to_variant contributed to the value
        // whose gradient we are computing. In that case we see a
        // default-constructed variant, so treat it as a zero tensor with the
        // appropriate shape.
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        auto piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>().setZero();
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_values tensor: return it as-is.
      context->set_output(0, values[0]);
    } else {
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      // ConcatCPU assumes a non-empty output.
      if (dense_values_shape.num_elements() == 0) return;
      // Multiple flat_values tensors: concatenate them together.
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
      for (const Tensor& t : values) {
        // ConcatCPU assumes non-empty inputs.
        if (t.NumElements() == 0) continue;
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};

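// Register CPU kernels for both ops, for every supported value type paired
// with int32 and int64_t row-splits types.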
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<value_type>("Tvalues")        \
                              .TypeConstraint<split_type>("Tsplits"),       \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("RaggedTensorToVariantGradient")                                 \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<value_type>("Tvalues")                            \
          .TypeConstraint<split_type>("Tsplits"),                           \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow