1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/ragged_to_dense_util.h" |
17 | |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | #include "tensorflow/core/framework/tensor_shape.h" |
21 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | using errors::InvalidArgument; |
26 | |
27 | tensorflow::Status GetRowPartitionTypesHelper( |
28 | const std::vector<string>& row_partition_type_strings, |
29 | std::vector<RowPartitionType>* row_partition_types) { |
30 | *row_partition_types = GetRowPartitionTypesHelper(row_partition_type_strings); |
31 | if (row_partition_types->size() != row_partition_type_strings.size()) { |
32 | // Something was not converted, return error status. |
33 | return InvalidArgument( |
34 | "Unknown string for partition info type: " , |
35 | row_partition_type_strings.at(row_partition_types->size())); |
36 | } |
37 | return OkStatus(); |
38 | } |
39 | |
40 | tensorflow::Status CombineRaggedTensorToTensorShapes( |
41 | int ragged_rank, const TensorShapeProto& shape, |
42 | const TensorShapeProto& value_shape, TensorShapeProto* output_shape) { |
43 | // Test for consistency of value_shape and shape specified. |
44 | // If shape is unspecified and value_shape is specified, then copy |
45 | // over the size from the value_shape dimension. |
46 | |
47 | if (value_shape.unknown_rank() && shape.unknown_rank()) { |
48 | output_shape->Clear(); |
49 | output_shape->set_unknown_rank(true); |
50 | return OkStatus(); |
51 | } |
52 | |
53 | if (shape.unknown_rank()) { |
54 | // Here, value_shape must be of known size. |
55 | while (output_shape->dim_size() < ragged_rank + value_shape.dim_size()) { |
56 | output_shape->add_dim()->set_size(-1); |
57 | } |
58 | } else { |
59 | *output_shape = shape; |
60 | } |
61 | if (value_shape.unknown_rank()) { |
62 | return OkStatus(); |
63 | } |
64 | // At this point, value_shape and output_shape have known ranks. |
65 | if (ragged_rank + value_shape.dim_size() != output_shape->dim_size()) { |
66 | return InvalidArgument( |
67 | "rt_input.shape and shape=" , TensorShape::DebugString(shape), |
68 | " are incompatible: rt_input.rank = " , |
69 | ragged_rank + value_shape.dim_size(), |
70 | " but shape.rank = " , output_shape->dim_size()); |
71 | } |
72 | |
73 | for (int i = 1; i < value_shape.dim_size(); ++i) { |
74 | const TensorShapeProto::Dim& value_dim = value_shape.dim(i); |
75 | TensorShapeProto::Dim* output_shape_dim = output_shape->mutable_dim( |
76 | output_shape->dim_size() - value_shape.dim_size() + i); |
77 | |
78 | if (value_dim.size() >= 0) { |
79 | if (output_shape_dim->size() >= 0) { |
80 | if (output_shape_dim->size() != value_dim.size()) { |
81 | return InvalidArgument( |
82 | "rt_input.shape and shape=" , TensorShape::DebugString(shape), |
83 | " are incompatible: rt_input.shape[" , i + ragged_rank, |
84 | "] = " , value_dim.size(), " but shape[" , i + ragged_rank, |
85 | "] = " , output_shape_dim->size()); |
86 | } |
87 | } else { |
88 | output_shape_dim->set_size(value_dim.size()); |
89 | } |
90 | } |
91 | } |
92 | return OkStatus(); |
93 | } |
94 | |
95 | tensorflow::Status ValidateDefaultValueShape( |
96 | const TensorShapeProto& default_value_shape, |
97 | const TensorShapeProto& value_shape) { |
98 | if (default_value_shape.unknown_rank() || value_shape.unknown_rank()) { |
99 | return OkStatus(); |
100 | } |
101 | |
102 | int default_ndims = default_value_shape.dim_size(); |
103 | int values_ndims = value_shape.dim_size(); |
104 | if (default_ndims >= values_ndims) { |
105 | return InvalidArgument( |
106 | "default_value.shape=" , TensorShape::DebugString(default_value_shape), |
107 | " and rt_input.flat_values.shape=" , |
108 | TensorShape::DebugString(value_shape), |
109 | " are incompatible: default_value.rank = " , default_ndims, |
110 | " must be less than rt_input.flat_values.rank = " , values_ndims); |
111 | } |
112 | for (int i = 0; i < std::min(default_ndims, values_ndims - 1); ++i) { |
113 | int default_dim = default_value_shape.dim(i).size(); |
114 | int value_dim = value_shape.dim(i + 1).size(); |
115 | if (default_dim >= 0 && value_dim >= 0 && default_dim != 1 && |
116 | default_dim != value_dim) { |
117 | return InvalidArgument( |
118 | "default_value.shape=" , TensorShape::DebugString(default_value_shape), |
119 | " and rt_input.flat_values.shape=" , |
120 | TensorShape::DebugString(value_shape), |
121 | " are incompatible: default_value.shape[" , |
122 | i - default_value_shape.dim_size(), "] = " , default_dim, |
123 | " but rt_input.flat_values.shape[" , |
124 | i - default_value_shape.dim_size(), "] = " , value_dim); |
125 | } |
126 | } |
127 | return OkStatus(); |
128 | } |
129 | |
130 | } // namespace tensorflow |
131 | |