1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | #include "tensorflow/core/util/ragged_to_dense_util.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using errors::InvalidArgument; |
23 | using shape_inference::DimensionHandle; |
24 | using shape_inference::InferenceContext; |
25 | using shape_inference::ShapeHandle; |
26 | |
27 | namespace { |
28 | tensorflow::Status ValidateRowPartitionTypesAndShapes( |
29 | const std::vector<RowPartitionType>& row_partition_types, |
30 | InferenceContext* c) { |
31 | // Note: the allowed types may be extended in the future. |
32 | for (RowPartitionType row_partition_type : row_partition_types) { |
33 | switch (row_partition_type) { |
34 | case RowPartitionType::FIRST_DIM_SIZE: |
35 | case RowPartitionType::VALUE_ROWIDS: |
36 | case RowPartitionType::ROW_SPLITS: |
37 | break; |
38 | default: |
39 | return InvalidArgument("Unsupported partition type: " , |
40 | RowPartitionTypeToString(row_partition_type)); |
41 | } |
42 | } |
43 | |
44 | if (row_partition_types.empty()) { |
45 | return InvalidArgument("Partition info types should not be empty" ); |
46 | } |
47 | for (int i = 1; i < row_partition_types.size(); ++i) { |
48 | if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) { |
49 | return InvalidArgument("FIRST_DIM_SIZE must be first" ); |
50 | } |
51 | } |
52 | if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE && |
53 | (row_partition_types.size() < 2 || |
54 | row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) { |
55 | return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS" ); |
56 | } |
57 | if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) { |
58 | return InvalidArgument("VALUE_ROWIDS cannot be first" ); |
59 | } |
60 | |
61 | int num_row_partition_tensors; |
62 | TF_RETURN_IF_ERROR( |
63 | c->GetAttr("num_row_partition_tensors" , &num_row_partition_tensors)); |
64 | if (num_row_partition_tensors != row_partition_types.size()) { |
65 | return InvalidArgument( |
66 | "Number of row partition tensors (" , num_row_partition_tensors, |
67 | ") does not equal the number of row partition types(" , |
68 | row_partition_types.size(), ")." ); |
69 | } |
70 | |
71 | for (int i = 0; i < num_row_partition_tensors; ++i) { |
72 | TensorShapeProto partition_shape; |
73 | c->ShapeHandleToProto(c->input(3 + i), &partition_shape); |
74 | if (partition_shape.unknown_rank()) { |
75 | continue; |
76 | } |
77 | if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) { |
78 | if (partition_shape.dim_size() != 0) { |
79 | return InvalidArgument("FIRST_DIM_SIZE must be a scalar." ); |
80 | } |
81 | } else { |
82 | if (partition_shape.dim_size() != 1) { |
83 | return InvalidArgument("Row partition must be a vector." ); |
84 | } |
85 | } |
86 | } |
87 | return OkStatus(); |
88 | } |
89 | |
90 | } // namespace |
91 | |
// Shape-inference functions for the ops registered below. Forward-declared
// here so each REGISTER_OP can reference its SetShapeFn; the definitions
// follow the op registrations.
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);
97 | |
98 | //============================================================================== |
99 | // Registered Ops |
100 | //============================================================================== |
101 | |
// Converts a ragged tensor (RAGGED_RANK nested row-splits vectors plus a
// flat values tensor) into the three components of a SparseTensor holding
// the same values.
REGISTER_OP("RaggedTensorToSparse")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: T")
    .Output("sparse_indices: int64")
    .Output("sparse_values: T")
    .Output("sparse_dense_shape: int64")
    .Attr("RAGGED_RANK: int >= 1")
    .Attr("T: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToSparseShapeFn);
112 | |
// Encodes a ragged tensor as a variant tensor: a scalar variant when
// batched_input is false, or a vector of per-row variants when true.
REGISTER_OP("RaggedTensorToVariant")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: Tvalues")
    .Output("encoded_ragged: variant")
    .Attr("RAGGED_RANK: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("batched_input: bool")
    .SetTypeConstructor(full_type::Unary(TFT_RAGGED, "Tvalues"))
    .SetShapeFn(RaggedTensorToVariantShapeFn);
123 | |
// Decodes a variant tensor back into ragged-tensor components:
// output_ragged_rank splits vectors followed by the flat values tensor.
// input_ragged_rank may be -1, meaning "infer from the input".
REGISTER_OP("RaggedTensorFromVariant")
    .Input("encoded_ragged: variant")
    .Output("output_nested_splits: output_ragged_rank * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("input_ragged_rank: int >= -1")
    .Attr("output_ragged_rank: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorFromVariantShapeFn);
133 | |
// Gradient op for RaggedTensorToVariant: maps the variant-encoded gradient
// back to a gradient for the dense values, using the given row_splits and
// the desired dense_values_shape.
REGISTER_OP("RaggedTensorToVariantGradient")
    .Input("encoded_ragged_grad: variant")
    .Input("row_splits: Tsplits")
    .Input("dense_values_shape: int32")
    .Output("dense_values_grad: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
142 | |
// Converts a ragged tensor (values plus num_row_partition_tensors row
// partition tensors, whose meanings are given by row_partition_types) into a
// dense tensor of the requested shape, padding with default_value.
REGISTER_OP("RaggedTensorToTensor")
    .Attr("T: type")
    .Attr("Tindex: {int64, int32}")
    .Attr("Tshape: {int64, int32}")
    .Attr("num_row_partition_tensors: int")
    .Attr("row_partition_types: list(string)")
    .Input("shape: Tshape")
    .Input("values: T")
    .Input("default_value: T")
    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
    .Output("result: T")
    .SetShapeFn(RaggedTensorToTensorShapeFn);
155 | |
156 | //============================================================================== |
157 | // Shape Functions |
158 | //============================================================================== |
159 | |
160 | Status RaggedTensorToSparseShapeFn(InferenceContext* c) { |
161 | int64_t num_splits; |
162 | TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK" , &num_splits)); |
163 | // TODO(b/112274756): Allow ragged_rank to be 0. |
164 | if (num_splits < 1) { |
165 | return errors::InvalidArgument("Requires RAGGED_RANK>0" ); |
166 | } |
167 | ShapeHandle rt_dense_values = c->input(num_splits); |
168 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values)); |
169 | |
170 | // Check that all rt_nested_splits have rank 1. |
171 | for (int64_t i = 0; i < num_splits; ++i) { |
172 | ShapeHandle splits = c->input(i); |
173 | TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits)); |
174 | } |
175 | |
176 | DimensionHandle dense_dims = |
177 | c->RankKnown(rt_dense_values) |
178 | ? c->MakeDim(c->Rank(rt_dense_values) + num_splits) |
179 | : c->UnknownDim(); |
180 | DimensionHandle num_values = c->NumElements(rt_dense_values); |
181 | |
182 | c->set_output(0, c->Matrix(num_values, dense_dims)); // indices |
183 | c->set_output(1, c->Vector(num_values)); // values |
184 | c->set_output(2, c->Vector(dense_dims)); // dense_shape |
185 | |
186 | return OkStatus(); |
187 | } |
188 | |
189 | Status RaggedTensorToVariantShapeFn(InferenceContext* c) { |
190 | int64_t num_splits; |
191 | TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK" , &num_splits)); |
192 | bool batched; |
193 | TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input" , &batched)); |
194 | shape_inference::ShapeHandle rt_dense_values = c->input(num_splits); |
195 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values)); |
196 | for (int64_t i = 0; i < num_splits; ++i) { |
197 | shape_inference::ShapeHandle splits = c->input(i); |
198 | TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits)); |
199 | } |
200 | if (batched) { |
201 | auto num_first_splits = c->Dim(c->input(0), 0); |
202 | shape_inference::DimensionHandle num_rows; |
203 | TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows)); |
204 | c->set_output(0, c->Vector(num_rows)); |
205 | } else { |
206 | c->set_output(0, c->Scalar()); |
207 | } |
208 | return OkStatus(); |
209 | } |
210 | |
211 | Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { |
212 | ShapeHandle shape; |
213 | TF_RETURN_IF_ERROR( |
214 | c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape)); |
215 | c->set_output(0, shape); |
216 | return OkStatus(); |
217 | } |
218 | |
219 | Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { |
220 | int64_t input_ragged_rank; |
221 | TF_RETURN_IF_ERROR( |
222 | c->GetAttr<int64_t>("input_ragged_rank" , &input_ragged_rank)); |
223 | int64_t output_ragged_rank; |
224 | TF_RETURN_IF_ERROR( |
225 | c->GetAttr<int64_t>("output_ragged_rank" , &output_ragged_rank)); |
226 | shape_inference::ShapeHandle encoded_ragged = c->input(0); |
227 | if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) { |
228 | shape_inference::ShapeHandle unused; |
229 | TF_RETURN_IF_ERROR(c->WithRank( |
230 | encoded_ragged, output_ragged_rank - input_ragged_rank, &unused)); |
231 | } |
232 | for (int64_t i = 0; i < output_ragged_rank; i++) { |
233 | c->set_output(i, c->UnknownShapeOfRank(1)); |
234 | } |
235 | c->set_output(output_ragged_rank, c->UnknownShape()); |
236 | return OkStatus(); |
237 | } |
238 | |
// Shape function for RaggedTensorToTensor: combines the requested output
// shape (input 0), the flat values shape (input 1), and the ragged rank
// implied by the row partition types to infer the dense result shape.
tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
  // Requested output shape; a scalar shape tensor means "unknown shape".
  TensorShapeProto shape;
  {
    ShapeHandle shape_handle;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
    c->ShapeHandleToProto(shape_handle, &shape);
  }

  // Parse the row_partition_types attr and validate it together with the
  // shapes of the row partition tensors (inputs 3..3+N-1).
  std::vector<RowPartitionType> row_partition_types;
  TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
  int ragged_rank = GetRaggedRank(row_partition_types);
  TF_RETURN_IF_ERROR(
      ValidateRowPartitionTypesAndShapes(row_partition_types, c));

  TensorShapeProto value_shape;
  c->ShapeHandleToProto(c->input(1), &value_shape);

  TensorShapeProto default_value_shape;
  c->ShapeHandleToProto(c->input(2), &default_value_shape);

  // default_value must be compatible with the values tensor's inner shape.
  TF_RETURN_IF_ERROR(
      ValidateDefaultValueShape(default_value_shape, value_shape));

  // TODO(martinz): Theoretically, we could check the first dimension of
  // value_shape against the first dimension of the last row_partition_tensor
  // assuming it is a VALUE_ROWIDS type.
  // TODO(martinz): Although we normally don't know the first dimension of the
  // output, we could infer it from the first dimension of the first
  // row_partition_tensor if it is ROW_SPLITS type.
  // TODO(martinz): If the shape is provided, but the value_shape has missing
  // dimensions, we can check the default_value_shape against the shape.
  TensorShapeProto output_shape;
  TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
      ragged_rank, shape, value_shape, &output_shape));

  ShapeHandle output_shape_handle;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
  c->set_output(0, output_shape_handle);
  return OkStatus();
}
281 | |
282 | } // namespace tensorflow |
283 | |