1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25// --------------------------------------------------------------------------
26namespace {
27
28Status SwitchShape(InferenceContext* c) {
29 ShapeHandle unused;
30 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
31 ShapeHandle out = c->input(0);
32 c->set_output(0, out);
33 c->set_output(1, out);
34
35 // Handle resource shape / dtype.
36 auto* handle_data = c->input_handle_shapes_and_types(0);
37 if (handle_data != nullptr) {
38 c->set_output_handle_shapes_and_types(0, *handle_data);
39 c->set_output_handle_shapes_and_types(1, *handle_data);
40 }
41 return OkStatus();
42}
43
44Status SwitchNShape(InferenceContext* c) {
45 ShapeHandle unused;
46 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
47 ShapeHandle out = c->input(0);
48 int num_outs;
49 TF_RETURN_IF_ERROR(c->GetAttr("num_outs", &num_outs));
50 for (int i = 0; i < num_outs; i++) {
51 c->set_output(i, out);
52 }
53
54 // Handle resource shape / dtype.
55 auto* handle_data = c->input_handle_shapes_and_types(0);
56 if (handle_data != nullptr) {
57 for (int i = 0; i < num_outs; i++) {
58 c->set_output_handle_shapes_and_types(i, *handle_data);
59 }
60 }
61 return OkStatus();
62}
63
64} // namespace
65
// Switch: two-way conditional forward of `data`, selected by the boolean
// `pred`. Judging by the output names, `data` presumably appears on
// `output_true` when `pred` is true and on `output_false` otherwise. The
// kernel is not defined in this file, so confirm there.
// Shape inference (SwitchShape): `pred` must be a scalar. Both outputs take
// the shape of `data`, and any resource handle data is forwarded to both.
REGISTER_OP("Switch")
    .Input("data: T")
    .Input("pred: bool")
    .Output("output_false: T")
    .Output("output_true: T")
    .Attr("T: type")
    // NOTE(review): ReplicateInput(0, 2) appears to give both outputs the
    // full type of input 0. Confirm against the full_type definitions.
    .SetForwardTypeFn(full_type::ReplicateInput(0, 2))
    .SetShapeFn(SwitchShape);
74
// RefSwitch: reference-typed variant of Switch (Ref(T) input and outputs).
// It shares SwitchShape, so `pred` must be a scalar and both outputs take
// the shape of `data`.
REGISTER_OP("RefSwitch")
    .Input("data: Ref(T)")
    .Input("pred: bool")
    .Output("output_false: Ref(T)")
    .Output("output_true: Ref(T)")
    .Attr("T: type")
    // Allows `data` to be an uninitialized ref. NOTE(review): presumably so
    // that not-yet-initialized variables can be routed; confirm.
    .SetAllowsUninitializedInput()
    .SetShapeFn(SwitchShape);
83
// _SwitchN: N-way generalization of Switch, selected by the int32
// `output_index`. It produces `num_outs` (>= 1) outputs of type T.
// NOTE(review): the leading underscore presumably marks an internal op that
// is not part of the public API; confirm.
// Shape inference (SwitchNShape): `output_index` must be a scalar. Every
// output takes the shape (and any resource handle data) of `data`.
REGISTER_OP("_SwitchN")
    .Input("data: T")
    .Input("output_index: int32")
    .Output("outputs: num_outs * T")
    .Attr("num_outs: int >= 1")
    .Attr("T: type")
    .SetShapeFn(SwitchNShape);
91
92// --------------------------------------------------------------------------
93REGISTER_OP("RefSelect")
94 .Input("index: int32")
95 .Input("inputs: Ref(N * T)")
96 .Output("output: Ref(T)")
97 .Attr("T: type")
98 .Attr("N: int >= 1")
99 .SetShapeFn([](InferenceContext* c) {
100 ShapeHandle unused;
101 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
102 ShapeHandle first_input = c->input(1);
103 if (!c->FullyDefined(first_input)) {
104 c->set_output(0, c->UnknownShape());
105 return OkStatus();
106 }
107 // If any inputs aren't fully defined or don't match, we return unknown.
108 for (int i = 2; i < c->num_inputs(); ++i) {
109 ShapeHandle input = c->input(i);
110 if (!c->FullyDefined(input) ||
111 !c->Merge(first_input, input, &unused).ok()) {
112 c->set_output(0, c->UnknownShape());
113 return OkStatus();
114 }
115 }
116 c->set_output(0, first_input);
117 return OkStatus();
118 });
119
120// --------------------------------------------------------------------------
121namespace {
122Status MergeShape(InferenceContext* c) {
123 ShapeHandle out = c->input(0);
124 if (!c->RankKnown(out)) {
125 out = c->UnknownShape();
126 } else {
127 int32_t rank = c->Rank(out);
128 for (int i = 1; i < c->num_inputs(); ++i) {
129 ShapeHandle input = c->input(i);
130 if (!c->RankKnown(input) || c->Rank(input) != rank) {
131 out = c->UnknownShape();
132 break;
133 }
134
135 for (int d = 0; d < rank; ++d) {
136 if (c->Value(c->Dim(input, d)) != c->Value(c->Dim(out, d))) {
137 TF_RETURN_IF_ERROR(c->ReplaceDim(out, d, c->UnknownDim(), &out));
138 }
139 }
140 }
141 }
142 c->set_output(0, out);
143 c->set_output(1, c->Scalar());
144 return OkStatus();
145}
146} // namespace
147
// Merge: takes N (>= 1) inputs of type T and yields a single `output` plus
// an int32 `value_index`. NOTE(review): `value_index` is presumably the index
// of the input that was forwarded; the kernel is not in this file, so confirm.
// Shape inference (MergeShape): `output` is unknown if ranks are unknown or
// differ. Otherwise each dimension is unknown where the inputs disagree.
// `value_index` is a scalar.
REGISTER_OP("Merge")
    .Input("inputs: N * T")
    .Output("output: T")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetForwardTypeFn(full_type::Merge())
    .SetShapeFn(MergeShape);
156
// RefMerge: reference-typed variant of Merge (Ref(T) inputs and output).
// It shares MergeShape with Merge.
REGISTER_OP("RefMerge")
    .Input("inputs: Ref(N * T)")
    .Output("output: Ref(T)")
    .Output("value_index: int32")
    .Attr("T: type")
    .Attr("N: int >= 1")
    .SetShapeFn(MergeShape);
164
165// --------------------------------------------------------------------------
166REGISTER_OP("Enter")
167 .Input("data: T")
168 .Output("output: T")
169 .Attr("T: type")
170 .Attr("frame_name: string")
171 .Attr("is_constant: bool = false")
172 .Attr("parallel_iterations: int = 10")
173 .SetForwardTypeFn(full_type::ReplicateInput())
174 .SetShapeFn([](InferenceContext* c) {
175 c->set_output(0, c->UnknownShape());
176
177 // Handle resource shape / dtype, if present.
178 auto* handle_data = c->input_handle_shapes_and_types(0);
179 if (handle_data != nullptr) {
180 c->set_output_handle_shapes_and_types(0, *handle_data);
181 }
182 // Propagate shape if output is a constant.
183 bool is_constant;
184 TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
185 if (is_constant) {
186 c->set_output(0, c->input(0));
187 }
188
189 return OkStatus();
190 });
191
192// --------------------------------------------------------------------------
// RefEnter: reference-typed variant of Enter, with the same attributes.
// Unlike Enter's shape function, which reports an unknown shape unless
// `is_constant` is set, this op uses shape_inference::UnchangedShape and does
// not consult `is_constant`. NOTE(review): UnchangedShape presumably forwards
// input 0's shape to output 0; confirm in common_shape_fns.
REGISTER_OP("RefEnter")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .Attr("frame_name: string")
    .Attr("is_constant: bool = false")
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(shape_inference::UnchangedShape);
201
202// --------------------------------------------------------------------------
// Exit: forwards `data` to `output` with the same type T.
// NOTE(review): presumably the counterpart of Enter, leaving the current
// frame; the kernel is not in this file.
// Shape inference uses shape_inference::UnchangedShape.
REGISTER_OP("Exit")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput())
    .SetShapeFn(shape_inference::UnchangedShape);
209
// RefExit: reference-typed variant of Exit. Shape inference uses
// shape_inference::UnchangedShape.
REGISTER_OP("RefExit")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
215
216// --------------------------------------------------------------------------
// NextIteration: forwards `data` to `output` with the same type T.
// NOTE(review): presumably makes the value available to the next loop
// iteration; the kernel is not in this file.
// Shape inference uses shape_inference::UnchangedShape.
REGISTER_OP("NextIteration")
    .Input("data: T")
    .Output("output: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput())
    .SetShapeFn(shape_inference::UnchangedShape);
223
// RefNextIteration: reference-typed variant of NextIteration. Shape inference
// uses shape_inference::UnchangedShape.
REGISTER_OP("RefNextIteration")
    .Input("data: Ref(T)")
    .Output("output: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
229
230// --------------------------------------------------------------------------
// LoopCond: takes a bool `input` and produces a bool `output`.
// NOTE(review): presumably carries a while-loop's continuation predicate;
// confirm against the executor.
// Shape inference: UnchangedShapeWithRank(c, 0) requires a rank-0 (scalar)
// input and presumably forwards its shape.
REGISTER_OP("LoopCond")
    .Input("input: bool")
    .Output("output: bool")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRank(c, 0);
    });
237
238// --------------------------------------------------------------------------
// ControlTrigger: declares no inputs, outputs or attrs. NOTE(review):
// presumably used only through control edges; confirm.
REGISTER_OP("ControlTrigger").SetShapeFn(shape_inference::NoOutputs);
240
241// --------------------------------------------------------------------------
// Abort: has no inputs or outputs. It takes an `error_msg` string (default
// empty) and an `exit_without_error` flag (default false).
// NOTE(review): the kernel presumably terminates the process, with the flag
// selecting a clean exit; it is not defined in this file, so confirm.
REGISTER_OP("Abort")
    .Attr("error_msg: string = ''")
    .Attr("exit_without_error: bool = false")
    .SetShapeFn(shape_inference::NoOutputs);
246
247} // namespace tensorflow
248