1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/common_shape_fns.h"
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/shape_inference.h"
18#include "tensorflow/core/util/ragged_to_dense_util.h"
19
20namespace tensorflow {
21
22using errors::InvalidArgument;
23using shape_inference::DimensionHandle;
24using shape_inference::InferenceContext;
25using shape_inference::ShapeHandle;
26
27namespace {
28tensorflow::Status ValidateRowPartitionTypesAndShapes(
29 const std::vector<RowPartitionType>& row_partition_types,
30 InferenceContext* c) {
31 // Note: the allowed types may be extended in the future.
32 for (RowPartitionType row_partition_type : row_partition_types) {
33 switch (row_partition_type) {
34 case RowPartitionType::FIRST_DIM_SIZE:
35 case RowPartitionType::VALUE_ROWIDS:
36 case RowPartitionType::ROW_SPLITS:
37 break;
38 default:
39 return InvalidArgument("Unsupported partition type: ",
40 RowPartitionTypeToString(row_partition_type));
41 }
42 }
43
44 if (row_partition_types.empty()) {
45 return InvalidArgument("Partition info types should not be empty");
46 }
47 for (int i = 1; i < row_partition_types.size(); ++i) {
48 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
49 return InvalidArgument("FIRST_DIM_SIZE must be first");
50 }
51 }
52 if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
53 (row_partition_types.size() < 2 ||
54 row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
55 return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
56 }
57 if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
58 return InvalidArgument("VALUE_ROWIDS cannot be first");
59 }
60
61 int num_row_partition_tensors;
62 TF_RETURN_IF_ERROR(
63 c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
64 if (num_row_partition_tensors != row_partition_types.size()) {
65 return InvalidArgument(
66 "Number of row partition tensors (", num_row_partition_tensors,
67 ") does not equal the number of row partition types(",
68 row_partition_types.size(), ").");
69 }
70
71 for (int i = 0; i < num_row_partition_tensors; ++i) {
72 TensorShapeProto partition_shape;
73 c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
74 if (partition_shape.unknown_rank()) {
75 continue;
76 }
77 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
78 if (partition_shape.dim_size() != 0) {
79 return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
80 }
81 } else {
82 if (partition_shape.dim_size() != 1) {
83 return InvalidArgument("Row partition must be a vector.");
84 }
85 }
86 }
87 return OkStatus();
88}
89
90} // namespace
91
// Forward declarations of the shape functions used by the op registrations
// below; the definitions follow the registrations.
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);
97
98//==============================================================================
99// Registered Ops
100//==============================================================================
101
// Converts a ragged tensor (RAGGED_RANK nested row-splits vectors plus a
// flat values tensor) into the (indices, values, dense_shape) components of
// a SparseTensor with the same contents.
REGISTER_OP("RaggedTensorToSparse")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: T")
    .Output("sparse_indices: int64")
    .Output("sparse_values: T")
    .Output("sparse_dense_shape: int64")
    .Attr("RAGGED_RANK: int >= 1")
    .Attr("T: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToSparseShapeFn);
112
// Encodes a ragged tensor (splits + values) into a variant tensor. With
// `batched_input` set, each row is encoded separately (the shape function
// produces a vector output); otherwise the whole tensor becomes one scalar
// variant.
REGISTER_OP("RaggedTensorToVariant")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: Tvalues")
    .Output("encoded_ragged: variant")
    .Attr("RAGGED_RANK: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("batched_input: bool")
    .SetTypeConstructor(full_type::Unary(TFT_RAGGED, "Tvalues"))
    .SetShapeFn(RaggedTensorToVariantShapeFn);
123
// Decodes a variant tensor back into ragged splits/values components.
// `input_ragged_rank == -1` means the encoded ragged rank is inferred at
// run time rather than checked during shape inference.
REGISTER_OP("RaggedTensorFromVariant")
    .Input("encoded_ragged: variant")
    .Output("output_nested_splits: output_ragged_rank * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("input_ragged_rank: int >= -1")
    .Attr("output_ragged_rank: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorFromVariantShapeFn);
133
// Gradient op for RaggedTensorToVariant: maps the encoded gradient variants
// back to a gradient for the dense values, whose shape is supplied via the
// `dense_values_shape` input.
REGISTER_OP("RaggedTensorToVariantGradient")
    .Input("encoded_ragged_grad: variant")
    .Input("row_splits: Tsplits")
    .Input("dense_values_shape: int32")
    .Output("dense_values_grad: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
142
// Converts a ragged tensor (values plus `num_row_partition_tensors` row
// partition tensors, whose roles are named by `row_partition_types`) into a
// dense tensor of shape `shape`, using `default_value` for positions not
// covered by the ragged input.
REGISTER_OP("RaggedTensorToTensor")
    .Attr("T: type")
    .Attr("Tindex: {int64, int32}")
    .Attr("Tshape: {int64, int32}")
    .Attr("num_row_partition_tensors: int")
    .Attr("row_partition_types: list(string)")
    .Input("shape: Tshape")
    .Input("values: T")
    .Input("default_value: T")
    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
    .Output("result: T")
    .SetShapeFn(RaggedTensorToTensorShapeFn);
155
156//==============================================================================
157// Shape Functions
158//==============================================================================
159
160Status RaggedTensorToSparseShapeFn(InferenceContext* c) {
161 int64_t num_splits;
162 TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
163 // TODO(b/112274756): Allow ragged_rank to be 0.
164 if (num_splits < 1) {
165 return errors::InvalidArgument("Requires RAGGED_RANK>0");
166 }
167 ShapeHandle rt_dense_values = c->input(num_splits);
168 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
169
170 // Check that all rt_nested_splits have rank 1.
171 for (int64_t i = 0; i < num_splits; ++i) {
172 ShapeHandle splits = c->input(i);
173 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
174 }
175
176 DimensionHandle dense_dims =
177 c->RankKnown(rt_dense_values)
178 ? c->MakeDim(c->Rank(rt_dense_values) + num_splits)
179 : c->UnknownDim();
180 DimensionHandle num_values = c->NumElements(rt_dense_values);
181
182 c->set_output(0, c->Matrix(num_values, dense_dims)); // indices
183 c->set_output(1, c->Vector(num_values)); // values
184 c->set_output(2, c->Vector(dense_dims)); // dense_shape
185
186 return OkStatus();
187}
188
189Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
190 int64_t num_splits;
191 TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
192 bool batched;
193 TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input", &batched));
194 shape_inference::ShapeHandle rt_dense_values = c->input(num_splits);
195 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
196 for (int64_t i = 0; i < num_splits; ++i) {
197 shape_inference::ShapeHandle splits = c->input(i);
198 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
199 }
200 if (batched) {
201 auto num_first_splits = c->Dim(c->input(0), 0);
202 shape_inference::DimensionHandle num_rows;
203 TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows));
204 c->set_output(0, c->Vector(num_rows));
205 } else {
206 c->set_output(0, c->Scalar());
207 }
208 return OkStatus();
209}
210
211Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
212 ShapeHandle shape;
213 TF_RETURN_IF_ERROR(
214 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
215 c->set_output(0, shape);
216 return OkStatus();
217}
218
219Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
220 int64_t input_ragged_rank;
221 TF_RETURN_IF_ERROR(
222 c->GetAttr<int64_t>("input_ragged_rank", &input_ragged_rank));
223 int64_t output_ragged_rank;
224 TF_RETURN_IF_ERROR(
225 c->GetAttr<int64_t>("output_ragged_rank", &output_ragged_rank));
226 shape_inference::ShapeHandle encoded_ragged = c->input(0);
227 if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) {
228 shape_inference::ShapeHandle unused;
229 TF_RETURN_IF_ERROR(c->WithRank(
230 encoded_ragged, output_ragged_rank - input_ragged_rank, &unused));
231 }
232 for (int64_t i = 0; i < output_ragged_rank; i++) {
233 c->set_output(i, c->UnknownShapeOfRank(1));
234 }
235 c->set_output(output_ragged_rank, c->UnknownShape());
236 return OkStatus();
237}
238
239tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
240 TensorShapeProto shape;
241 {
242 ShapeHandle shape_handle;
243 TF_RETURN_IF_ERROR(
244 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
245 c->ShapeHandleToProto(shape_handle, &shape);
246 }
247
248 std::vector<RowPartitionType> row_partition_types;
249 TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
250 int ragged_rank = GetRaggedRank(row_partition_types);
251 TF_RETURN_IF_ERROR(
252 ValidateRowPartitionTypesAndShapes(row_partition_types, c));
253
254 TensorShapeProto value_shape;
255 c->ShapeHandleToProto(c->input(1), &value_shape);
256
257 TensorShapeProto default_value_shape;
258 c->ShapeHandleToProto(c->input(2), &default_value_shape);
259
260 TF_RETURN_IF_ERROR(
261 ValidateDefaultValueShape(default_value_shape, value_shape));
262
263 // TODO(martinz): Theoretically, we could check the first dimension of
264 // value_shape against the first dimension of the last row_partition_tensor
265 // assuming it is a VALUE_ROWIDS type.
266 // TODO(martinz): Although we normally don't know the first dimension of the
267 // output, we could infer it from the first dimension of the first
268 // row_partition_tensor if it is ROW_SPLITS type.
269 // TODO(martinz): If the shape is provided, but the value_shape has missing
270 // dimensions, we can check the default_value_shape against the shape.
271 TensorShapeProto output_shape;
272 TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
273 ragged_rank, shape, value_shape, &output_shape));
274
275 ShapeHandle output_shape_handle;
276 TF_RETURN_IF_ERROR(
277 c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
278 c->set_output(0, output_shape_handle);
279 return OkStatus();
280}
281
282} // namespace tensorflow
283