1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | // -------------------------------------------------------------------------- |
26 | namespace { |
27 | |
28 | Status SwitchShape(InferenceContext* c) { |
29 | ShapeHandle unused; |
30 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
31 | ShapeHandle out = c->input(0); |
32 | c->set_output(0, out); |
33 | c->set_output(1, out); |
34 | |
35 | // Handle resource shape / dtype. |
36 | auto* handle_data = c->input_handle_shapes_and_types(0); |
37 | if (handle_data != nullptr) { |
38 | c->set_output_handle_shapes_and_types(0, *handle_data); |
39 | c->set_output_handle_shapes_and_types(1, *handle_data); |
40 | } |
41 | return OkStatus(); |
42 | } |
43 | |
44 | Status SwitchNShape(InferenceContext* c) { |
45 | ShapeHandle unused; |
46 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
47 | ShapeHandle out = c->input(0); |
48 | int num_outs; |
49 | TF_RETURN_IF_ERROR(c->GetAttr("num_outs" , &num_outs)); |
50 | for (int i = 0; i < num_outs; i++) { |
51 | c->set_output(i, out); |
52 | } |
53 | |
54 | // Handle resource shape / dtype. |
55 | auto* handle_data = c->input_handle_shapes_and_types(0); |
56 | if (handle_data != nullptr) { |
57 | for (int i = 0; i < num_outs; i++) { |
58 | c->set_output_handle_shapes_and_types(i, *handle_data); |
59 | } |
60 | } |
61 | return OkStatus(); |
62 | } |
63 | |
64 | } // namespace |
65 | |
// Switch: takes `data` of any type T and a bool `pred`, and declares two
// outputs of type T named `output_false` and `output_true`. The kernel is
// defined elsewhere. Presumably it routes `data` to one of the two outputs
// according to `pred`, as the output names suggest (confirm in the kernel).
//
// Shape inference (SwitchShape): `pred` must be a scalar. Both outputs get
// the shape of `data`, plus any resource-handle shape/dtype data it carries.
// Forward type inference: full_type::ReplicateInput(0, 2). This is assumed
// to copy input 0's full type to both outputs; confirm against full_type docs.
REGISTER_OP("Switch")
    .Input("data: T")
    .Input("pred: bool")
    .Output("output_false: T")
    .Output("output_true: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput(0, 2))
    .SetShapeFn(SwitchShape);
74 | |
// RefSwitch: reference-typed counterpart of Switch. Input `data` and both
// outputs are Ref(T); `pred` is a plain bool. SetAllowsUninitializedInput()
// marks the op as accepting an uninitialized input (presumably the `data`
// ref). It reuses SwitchShape, so `pred` must be a scalar and both outputs
// take the shape of `data`. Unlike Switch, no forward type function is set.
REGISTER_OP("RefSwitch")
    .Input("data: Ref(T)")
    .Input("pred: bool")
    .Output("output_false: Ref(T)")
    .Output("output_true: Ref(T)")
    .Attr("T: type")
    .SetAllowsUninitializedInput()
    .SetShapeFn(SwitchShape);
83 | |
// _SwitchN: N-way generalization of Switch (the leading underscore
// presumably marks it as internal). It takes `data` of type T and an int32
// `output_index`, and declares `num_outs` (>= 1) outputs of type T.
// Presumably `output_index` selects which output receives `data`; confirm
// in the kernel.
//
// Shape inference (SwitchNShape): `output_index` must be a scalar. Every
// output gets the shape (and resource-handle metadata) of `data`.
REGISTER_OP("_SwitchN")
    .Input("data: T")
    .Input("output_index: int32")
    .Output("outputs: num_outs * T")
    .Attr("num_outs: int >= 1")
    .Attr("T: type")
    .SetShapeFn(SwitchNShape);
91 | |
92 | // -------------------------------------------------------------------------- |
93 | REGISTER_OP("RefSelect" ) |
94 | .Input("index: int32" ) |
95 | .Input("inputs: Ref(N * T)" ) |
96 | .Output("output: Ref(T)" ) |
97 | .Attr("T: type" ) |
98 | .Attr("N: int >= 1" ) |
99 | .SetShapeFn([](InferenceContext* c) { |
100 | ShapeHandle unused; |
101 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
102 | ShapeHandle first_input = c->input(1); |
103 | if (!c->FullyDefined(first_input)) { |
104 | c->set_output(0, c->UnknownShape()); |
105 | return OkStatus(); |
106 | } |
107 | // If any inputs aren't fully defined or don't match, we return unknown. |
108 | for (int i = 2; i < c->num_inputs(); ++i) { |
109 | ShapeHandle input = c->input(i); |
110 | if (!c->FullyDefined(input) || |
111 | !c->Merge(first_input, input, &unused).ok()) { |
112 | c->set_output(0, c->UnknownShape()); |
113 | return OkStatus(); |
114 | } |
115 | } |
116 | c->set_output(0, first_input); |
117 | return OkStatus(); |
118 | }); |
119 | |
120 | // -------------------------------------------------------------------------- |
121 | namespace { |
122 | Status MergeShape(InferenceContext* c) { |
123 | ShapeHandle out = c->input(0); |
124 | if (!c->RankKnown(out)) { |
125 | out = c->UnknownShape(); |
126 | } else { |
127 | int32_t rank = c->Rank(out); |
128 | for (int i = 1; i < c->num_inputs(); ++i) { |
129 | ShapeHandle input = c->input(i); |
130 | if (!c->RankKnown(input) || c->Rank(input) != rank) { |
131 | out = c->UnknownShape(); |
132 | break; |
133 | } |
134 | |
135 | for (int d = 0; d < rank; ++d) { |
136 | if (c->Value(c->Dim(input, d)) != c->Value(c->Dim(out, d))) { |
137 | TF_RETURN_IF_ERROR(c->ReplaceDim(out, d, c->UnknownDim(), &out)); |
138 | } |
139 | } |
140 | } |
141 | } |
142 | c->set_output(0, out); |
143 | c->set_output(1, c->Scalar()); |
144 | return OkStatus(); |
145 | } |
146 | } // namespace |
147 | |
// Merge: takes N (>= 1) inputs of type T and produces `output` of type T
// plus an int32 `value_index`. The kernel is defined elsewhere; presumably
// `value_index` reports which input supplied `output` (confirm in kernel).
//
// Shape inference (MergeShape): `output` gets the most specific shape
// compatible with all inputs, and `value_index` is a scalar. Forward type
// inference uses full_type::Merge().
REGISTER_OP("Merge")
    .Input("inputs: N * T")
    .Output("output: T")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetForwardTypeFn(full_type::Merge())
    .SetShapeFn(MergeShape);
156 | |
// RefMerge: reference-typed counterpart of Merge. The inputs and `output`
// are Ref(T), and `value_index` is a plain int32. It reuses MergeShape but,
// unlike Merge, registers no forward type function.
REGISTER_OP("RefMerge")
    .Input("inputs: Ref(N * T)")
    .Output("output: Ref(T)")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetShapeFn(MergeShape);
164 | |
165 | // -------------------------------------------------------------------------- |
166 | REGISTER_OP("Enter" ) |
167 | .Input("data: T" ) |
168 | .Output("output: T" ) |
169 | .Attr("T: type" ) |
170 | .Attr("frame_name: string" ) |
171 | .Attr("is_constant: bool = false" ) |
172 | .Attr("parallel_iterations: int = 10" ) |
173 | .SetForwardTypeFn(full_type::ReplicateInput()) |
174 | .SetShapeFn([](InferenceContext* c) { |
175 | c->set_output(0, c->UnknownShape()); |
176 | |
177 | // Handle resource shape / dtype, if present. |
178 | auto* handle_data = c->input_handle_shapes_and_types(0); |
179 | if (handle_data != nullptr) { |
180 | c->set_output_handle_shapes_and_types(0, *handle_data); |
181 | } |
182 | // Propagate shape if output is a constant. |
183 | bool is_constant; |
184 | TF_RETURN_IF_ERROR(c->GetAttr("is_constant" , &is_constant)); |
185 | if (is_constant) { |
186 | c->set_output(0, c->input(0)); |
187 | } |
188 | |
189 | return OkStatus(); |
190 | }); |
191 | |
192 | // -------------------------------------------------------------------------- |
// RefEnter: reference-typed counterpart of Enter, with the same attributes.
// NOTE(review): unlike Enter, whose shape fn reports an unknown shape unless
// `is_constant` is set, RefEnter uses UnchangedShape. Its output therefore
// always carries the input's shape, regardless of `is_constant`. This is
// left as-is because callers may depend on it.
REGISTER_OP("RefEnter")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .Attr("frame_name: string")
    .Attr("is_constant: bool = false")
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(shape_inference::UnchangedShape);
201 | |
202 | // -------------------------------------------------------------------------- |
// Exit: passes `data` of type T through as `output`. Presumably it marks
// the point where a value leaves an execution frame (the counterpart of
// Enter); confirm in the executor. The output keeps the input's shape, and
// forward type inference replicates the input's full type.
REGISTER_OP("Exit")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput())
    .SetShapeFn(shape_inference::UnchangedShape);
209 | |
// RefExit: reference-typed counterpart of Exit. The output keeps the input's
// shape; no forward type function is registered.
REGISTER_OP("RefExit")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
215 | |
216 | // -------------------------------------------------------------------------- |
// NextIteration: passes `data` of type T through as `output`. Presumably it
// carries a value into the next loop iteration; confirm in the executor.
// The output keeps the input's shape, and forward type inference replicates
// the input's full type.
REGISTER_OP("NextIteration")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput())
    .SetShapeFn(shape_inference::UnchangedShape);
223 | |
// RefNextIteration: reference-typed counterpart of NextIteration. The output
// keeps the input's shape; no forward type function is registered.
REGISTER_OP("RefNextIteration")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
229 | |
230 | // -------------------------------------------------------------------------- |
// LoopCond: takes a bool `input` and produces a bool `output`. The shape fn
// requires the input to be a scalar (rank 0) and passes that shape through.
// Presumably the output serves as a loop's continuation predicate; confirm
// in the executor.
REGISTER_OP("LoopCond")
    .Input("input: bool")
    .Output("output: bool")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRank(c, 0);
    });
237 | |
238 | // -------------------------------------------------------------------------- |
// ControlTrigger: declares no data inputs, outputs, or attributes, and its
// shape fn reports no outputs. Presumably it is used only for control
// dependencies.
REGISTER_OP("ControlTrigger").SetShapeFn(shape_inference::NoOutputs);
240 | |
241 | // -------------------------------------------------------------------------- |
// Abort: declares no inputs or outputs. It has two attributes: `error_msg`
// (default empty string) and `exit_without_error` (default false). The
// kernel lives elsewhere; presumably it terminates the process, reporting
// `error_msg` unless `exit_without_error` is set (confirm in the kernel).
REGISTER_OP("Abort")
    .Attr("error_msg: string = ''")
    .Attr("exit_without_error: bool = false")
    .SetShapeFn(shape_inference::NoOutputs);
246 | |
247 | } // namespace tensorflow |
248 | |