1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::DimensionHandle; |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | REGISTER_OP("SetSize" ) |
27 | .Input("set_indices: int64" ) |
28 | .Input("set_values: T" ) |
29 | .Input("set_shape: int64" ) |
30 | .Attr("validate_indices: bool = true" ) |
31 | .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}" ) |
32 | .Output("size: int32" ) |
33 | .SetShapeFn(shape_inference::UnknownShape); |
34 | |
35 | REGISTER_OP("DenseToDenseSetOperation" ) |
36 | .Input("set1: T" ) |
37 | .Input("set2: T" ) |
38 | .Attr("set_operation: string" ) |
39 | .Attr("validate_indices: bool = true" ) |
40 | .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}" ) |
41 | .Output("result_indices: int64" ) |
42 | .Output("result_values: T" ) |
43 | .Output("result_shape: int64" ) |
44 | .SetShapeFn([](InferenceContext* c) { |
45 | if (c->num_inputs() != 2) { |
46 | return errors::InvalidArgument("len(inputs) != 2." ); |
47 | } |
48 | // The following should stay in sync with `ComputeDenseToDense` shape |
49 | // assertions in kernels/set_kernels.cc. |
50 | // Dimension n contains the set values to be compared, so ranks must be |
51 | // >= 2, and the first n-1 dimensions of inputs and output must be |
52 | // compatible. |
53 | DimensionHandle output_rank; |
54 | ShapeHandle input0_shape = c->input(0); |
55 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape)); |
56 | if (c->RankKnown(input0_shape)) { |
57 | const int32_t input0_rank = c->Rank(input0_shape); |
58 | ShapeHandle input1_shape = c->input(1); |
59 | TF_RETURN_IF_ERROR( |
60 | c->WithRank(input1_shape, input0_rank, &input1_shape)); |
61 | if (c->RankKnown(input1_shape)) { |
62 | // If both ranks are specified, the first n-1 dims must be compatible. |
63 | const int32_t rank = c->Rank(input1_shape); |
64 | ShapeHandle group0_shape; |
65 | TF_RETURN_IF_ERROR( |
66 | c->Subshape(input0_shape, 0, rank - 1, &group0_shape)); |
67 | ShapeHandle group1_shape; |
68 | TF_RETURN_IF_ERROR( |
69 | c->Subshape(input1_shape, 0, rank - 1, &group1_shape)); |
70 | ShapeHandle unused_shape; |
71 | TF_RETURN_IF_ERROR( |
72 | c->Merge(group0_shape, group1_shape, &unused_shape)); |
73 | } |
74 | output_rank = c->MakeDim(input0_rank); |
75 | } else { |
76 | ShapeHandle input1_shape = c->input(1); |
77 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape)); |
78 | if (c->RankKnown(input1_shape)) { |
79 | output_rank = c->MakeDim(c->Rank(input1_shape)); |
80 | } else { |
81 | output_rank = c->UnknownDim(); |
82 | } |
83 | } |
84 | |
85 | c->set_output(0, c->Matrix(c->UnknownDim(), output_rank)); |
86 | c->set_output(1, c->Vector(c->UnknownDim())); |
87 | c->set_output(2, c->Vector(output_rank)); |
88 | return OkStatus(); |
89 | }); |
90 | |
91 | REGISTER_OP("DenseToSparseSetOperation" ) |
92 | .Input("set1: T" ) |
93 | .Input("set2_indices: int64" ) |
94 | .Input("set2_values: T" ) |
95 | .Input("set2_shape: int64" ) |
96 | .Attr("set_operation: string" ) |
97 | .Attr("validate_indices: bool = true" ) |
98 | .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}" ) |
99 | .Output("result_indices: int64" ) |
100 | .Output("result_values: T" ) |
101 | .Output("result_shape: int64" ) |
102 | .SetShapeFn([](InferenceContext* c) { |
103 | if (c->num_inputs() != 4) { |
104 | return errors::InvalidArgument("len(inputs) != 4." ); |
105 | } |
106 | // The following should stay in sync with `ComputeDenseToSparse` shape |
107 | // assertions in kernels/set_kernels.cc. |
108 | // Ranks must be compatible, and be >= 2. |
109 | ShapeHandle input1_shape_shape = c->input(3); |
110 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
111 | c, c->input(1), c->input(2), input1_shape_shape)); |
112 | |
113 | DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0); |
114 | |
115 | DimensionHandle output_rank_dim; |
116 | ShapeHandle input0_shape = c->input(0); |
117 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape)); |
118 | if (c->RankKnown(input0_shape)) { |
119 | const int32_t input0_rank = c->Rank(input0_shape); |
120 | TF_RETURN_IF_ERROR( |
121 | c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim)); |
122 | output_rank_dim = c->MakeDim(input0_rank); |
123 | } else if (c->ValueKnown(input1_rank_dim)) { |
124 | output_rank_dim = input1_rank_dim; |
125 | } else { |
126 | output_rank_dim = c->UnknownDim(); |
127 | } |
128 | |
129 | c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim)); |
130 | c->set_output(1, c->Vector(c->UnknownDim())); |
131 | c->set_output(2, c->Vector(output_rank_dim)); |
132 | return OkStatus(); |
133 | }); |
134 | |
135 | REGISTER_OP("SparseToSparseSetOperation" ) |
136 | .Input("set1_indices: int64" ) |
137 | .Input("set1_values: T" ) |
138 | .Input("set1_shape: int64" ) |
139 | .Input("set2_indices: int64" ) |
140 | .Input("set2_values: T" ) |
141 | .Input("set2_shape: int64" ) |
142 | .Attr("set_operation: string" ) |
143 | .Attr("validate_indices: bool = true" ) |
144 | .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}" ) |
145 | .Output("result_indices: int64" ) |
146 | .Output("result_values: T" ) |
147 | .Output("result_shape: int64" ) |
148 | .SetShapeFn([](InferenceContext* c) { |
149 | if (c->num_inputs() != 6) { |
150 | return errors::InvalidArgument("len(inputs) != 6." ); |
151 | } |
152 | // The following should stay in sync with `ComputeSparseToSparse` shape |
153 | // assertions in kernels/set_kernels.cc. |
154 | // Ranks must be compatible, and be >= 2. |
155 | ShapeHandle input0_shape_shape = c->input(2); |
156 | ShapeHandle input1_shape_shape = c->input(5); |
157 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
158 | c, c->input(0), c->input(1), input0_shape_shape)); |
159 | TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( |
160 | c, c->input(3), c->input(4), input1_shape_shape)); |
161 | |
162 | DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0); |
163 | DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0); |
164 | DimensionHandle output_rank_dim; |
165 | if (c->ValueKnown(input0_rank_dim)) { |
166 | const int64_t input0_rank = c->Value(input0_rank_dim); |
167 | if (input0_rank < 2) { |
168 | return errors::InvalidArgument("Input 0, expected rank >= 2, got " , |
169 | input0_rank, "." ); |
170 | } |
171 | TF_RETURN_IF_ERROR( |
172 | c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim)); |
173 | output_rank_dim = input0_rank_dim; |
174 | } else if (c->ValueKnown(input1_rank_dim)) { |
175 | const int64_t input1_rank = c->Value(input1_rank_dim); |
176 | if (input1_rank < 2) { |
177 | return errors::InvalidArgument("Input 1, expected rank >= 2, got " , |
178 | input1_rank, "." ); |
179 | } |
180 | output_rank_dim = input1_rank_dim; |
181 | } else { |
182 | output_rank_dim = c->UnknownDim(); |
183 | } |
184 | |
185 | c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim)); |
186 | c->set_output(1, c->Vector(c->UnknownDim())); |
187 | c->set_output(2, c->Vector(output_rank_dim)); |
188 | return OkStatus(); |
189 | }); |
190 | |
191 | } // namespace tensorflow |
192 | |