1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
22using shape_inference::InferenceContext;
23
24REGISTER_OP("SymbolicGradient")
25 .Input("input: Tin")
26 .Output("output: Tout")
27 .Attr("Tin: list(type)")
28 .Attr("Tout: list(type)")
29 .Attr("f: func")
30 .SetShapeFn([](InferenceContext* c) {
31 if (c->num_inputs() < c->num_outputs()) {
32 return errors::InvalidArgument("len(inputs) < len(outputs)");
33 }
34 std::vector<DataType> types;
35 TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types));
36 // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
37 // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
38 // outputs (dx, dy, dz) are the same as (x, y, z).
39 for (int i = 0; i < c->num_outputs(); ++i) {
40 if (types[i] == DT_RESOURCE) {
41 const std::vector<shape_inference::ShapeAndType>* handle_type =
42 c->input_handle_shapes_and_types(i);
43 if (handle_type != nullptr) {
44 c->set_output(i, handle_type->at(0).shape);
45 } else {
46 c->set_output(i, c->UnknownShape());
47 }
48 } else {
49 c->set_output(i, c->input(i));
50 }
51 }
52 return OkStatus();
53 });
54
// RemoteCall: takes a `target` string plus a list of `args` (types `Tin`) and
// produces outputs of types `Tout` using the function attr `f`.
// NOTE(review): presumably `f` is invoked on the device/task named by
// `target`; the kernel is not visible here, so confirm before relying on it.
// The op is marked stateful, and all output shapes are reported as unknown.
REGISTER_OP("RemoteCall")
    .Input("target: string")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type)")
    .Attr("Tout: list(type)")
    .Attr("f: func")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
64
65// TODO(drpng): remove this.
66REGISTER_OP("_If")
67 .Input("cond: Tcond")
68 .Input("input: Tin")
69 .Output("output: Tout")
70 .Attr("Tcond: type")
71 .Attr("Tin: list(type)")
72 .Attr("Tout: list(type)")
73 .Attr("then_branch: func")
74 .Attr("else_branch: func")
75 .SetIsStateful()
76 .SetShapeFn(shape_inference::UnknownShape)
77 .Doc(R"doc(
78output = cond ? then_branch(input) : else_branch(input)
79
80cond: A Tensor. If the tensor is a scalar of non-boolean type, the
81 scalar is converted to a boolean according to the
82 following rule: if the scalar is a numerical value, non-zero means
83 True and zero means False; if the scalar is a string, non-empty
84 means True and empty means False. If the tensor is not a scalar,
85 being empty means False and being non-empty means True.
86input: A list of input tensors.
87then_branch: A function that takes 'inputs' and returns a list of
88 tensors, whose types are the same as what else_branch returns.
89else_branch: A function that takes 'inputs' and returns a list of
90 tensors. whose types are the same as what then_branch returns.
91)doc");
92
93Status IfShapeInferenceFn(shape_inference::InferenceContext* c) {
94 std::vector<PartialTensorShape> output_shapes;
95 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
96 // If `output_shapes` attr is set use that as the shapes of the outputs
97 // else return unknown shapes.
98 if (output_shapes.empty()) return shape_inference::UnknownShape(c);
99 if (output_shapes.size() != c->num_outputs()) {
100 return errors::InvalidArgument(
101 "`output_shapes` must be the same length as num outputs (",
102 output_shapes.size(), " vs. ", c->num_outputs());
103 }
104 for (size_t i = 0; i < output_shapes.size(); ++i) {
105 shape_inference::ShapeHandle output_shape_handle;
106 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
107 output_shapes[i], &output_shape_handle));
108 c->set_output(static_cast<int>(i), output_shape_handle);
109 }
110 return OkStatus();
111}
112
// StatelessIf: conditional op selecting between the `then_branch` and
// `else_branch` functions based on `cond` (of any single dtype `Tcond`), with
// `input` (types `Tin`) and outputs of types `Tout`. It is registered like
// "If" but without SetIsStateful(). The optional `output_shapes` attr
// (default empty) declares output shapes; see IfShapeInferenceFn.
REGISTER_OP("StatelessIf")
    .Input("cond: Tcond")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("output_shapes: list(shape) = []")
    .SetShapeFn(IfShapeInferenceFn);
124
// If: conditional op selecting between the `then_branch` and `else_branch`
// functions based on `cond`, with `input` (types `Tin`) and outputs of types
// `Tout`. It is marked stateful; "StatelessIf" is the identical registration
// without SetIsStateful(). The optional `output_shapes` attr (default empty)
// declares output shapes; see IfShapeInferenceFn.
REGISTER_OP("If")
    .Input("cond: Tcond")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Attr("output_shapes: list(shape) = []")
    .SetIsStateful()
    .SetShapeFn(IfShapeInferenceFn);
137
138Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) {
139 std::vector<PartialTensorShape> output_shapes;
140 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
141 // If `output_shapes` attr is set use that as the shapes of the outputs
142 // else return unknown shapes.
143 if (output_shapes.empty()) return shape_inference::UnknownShape(c);
144 if (output_shapes.size() != c->num_outputs()) {
145 return errors::InvalidArgument(
146 "`output_shapes` must be the same length as num outputs (",
147 output_shapes.size(), " vs. ", c->num_outputs());
148 }
149 for (size_t i = 0; i < output_shapes.size(); ++i) {
150 shape_inference::ShapeHandle output_shape_handle;
151 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
152 output_shapes[i], &output_shape_handle));
153 c->set_output(static_cast<int>(i), output_shape_handle);
154 }
155 return OkStatus();
156}
157
// StatelessCase: multi-way conditional in which the int32 `branch_index`
// selects one of the `branches` functions (at least one is required). It takes
// `input` (types `Tin`) and produces outputs of types `Tout`.
// NOTE(review): out-of-range `branch_index` handling lives in the kernel,
// which is not visible here.
// This is the "Case" registration without SetIsStateful(); see
// CaseShapeInferenceFn for how `output_shapes` is used.
REGISTER_OP("StatelessCase")
    .Input("branch_index: int32")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("branches: list(func) >= 1")
    .Attr("output_shapes: list(shape) = []")
    .SetShapeFn(CaseShapeInferenceFn);
167
// Case: multi-way conditional in which the int32 `branch_index` selects one of
// the `branches` functions (at least one is required). It takes `input`
// (types `Tin`) and produces outputs of types `Tout`. It is marked stateful;
// see CaseShapeInferenceFn for how `output_shapes` is used.
REGISTER_OP("Case")
    .Input("branch_index: int32")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("branches: list(func) >= 1")
    .Attr("output_shapes: list(shape) = []")
    .SetIsStateful()
    .SetShapeFn(CaseShapeInferenceFn);
178
179// TODO(drpng): remove this.
180REGISTER_OP("_While")
181 .Input("input: T")
182 .Output("output: T")
183 .Attr("T: list(type) >= 0")
184 .Attr("cond: func")
185 .Attr("body: func")
186 .SetIsStateful()
187 .SetShapeFn([](shape_inference::InferenceContext* c) {
188 for (int i = 0; i < c->num_outputs(); ++i) {
189 c->set_output(i, c->input(i));
190 }
191 return OkStatus();
192 })
193 .Doc(R"doc(
194output = input; While (Cond(output)) { output = Body(output) }
195
196input: A list of input tensors whose types are T.
197output: A list of output tensors whose types are T.
198cond: A function takes 'input' and returns a tensor. If the tensor is
199 a scalar of non-boolean, the scalar is converted to a boolean
200 according to the following rule: if the scalar is a numerical
201 value, non-zero means True and zero means False; if the scalar is
202 a string, non-empty means True and empty means False. If the
203 tensor is not a scalar, non-emptiness means True and False
204 otherwise.
205body: A function that takes a list of tensors and returns another
206 list of tensors. Both lists have the same types as specified
207 by T.
208)doc");
209
210Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) {
211 std::vector<PartialTensorShape> output_shapes;
212 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
213 // If `output_shapes` attr is set use that as the shapes of the outputs
214 // else use the input shapes.
215 if (!output_shapes.empty()) {
216 if (output_shapes.size() != c->num_outputs()) {
217 return errors::InvalidArgument(
218 "`output_shapes` must be the same length as num outputs (",
219 output_shapes.size(), " vs. ", c->num_outputs());
220 }
221 for (size_t i = 0; i < output_shapes.size(); ++i) {
222 shape_inference::ShapeHandle output_shape_handle;
223 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
224 output_shapes[i], &output_shape_handle));
225 c->set_output(static_cast<int>(i), output_shape_handle);
226 }
227 } else {
228 for (int i = 0; i < c->num_outputs(); ++i) {
229 c->set_output(i, c->input(i));
230 }
231 }
232 return OkStatus();
233}
234
// While: loop op over the loop variables `input` (types `T`) with `cond` and
// `body` functions. Outputs have the same type list `T` as the inputs.
// `parallel_iterations` defaults to 10.
// NOTE(review): it is presumably an upper bound on concurrently running
// iterations; that is enforced by the executor, not visible here.
// The op is marked stateful; "StatelessWhile" is the same registration without
// SetIsStateful(). See WhileShapeInferenceFn for output shapes.
REGISTER_OP("While")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .Attr("output_shapes: list(shape) = []")
    .Attr("parallel_iterations: int = 10")
    .SetIsStateful()
    .SetShapeFn(WhileShapeInferenceFn);
245
// StatelessWhile: the "While" registration without SetIsStateful(). It takes
// loop variables `input` (types `T`), `cond` and `body` functions, an optional
// `output_shapes` attr, and `parallel_iterations` (default 10). See
// WhileShapeInferenceFn for output shapes.
REGISTER_OP("StatelessWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .Attr("output_shapes: list(shape) = []")
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(WhileShapeInferenceFn);
255
// ToBool: maps a tensor of any single dtype `T` to a bool output, which is
// always a scalar (ScalarShape).
// NOTE(review): the conversion rule lives in the kernel, which is not visible
// here. It is presumably the same rule described for `cond` in the "_If" op
// doc; confirm.
REGISTER_OP("ToBool")
    .Input("input: T")
    .Output("output: bool")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);
261
// For: loop op with int32 inputs `start`, `limit` and `delta`, loop variables
// `input` (types `T`), and a `body` function. Outputs share the type list `T`.
// NOTE(review): presumably `body` runs once per index in
// [start, limit) stepping by `delta`, threading the loop variables through;
// the kernel is not visible here, so confirm.
// This op is not marked stateful, and its output shapes are reported as
// unknown.
REGISTER_OP("For")
    .Input("start: int32")
    .Input("limit: int32")
    .Input("delta: int32")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("body: func")
    .SetShapeFn(shape_inference::UnknownShape);
271
// No useful shape function is registered for the function-call ops directly:
// their outputs are declared with UnknownShape. Instead, ShapeRefiner is run
// by default to perform shape inference for them.
//
// PartitionedCall: calls function `f` with `args` (types `Tin`) and returns
// outputs of types `Tout`. It is not marked stateful; see
// "StatefulPartitionedCall" for the stateful variant.
// NOTE(review): `config` is presumably deprecated in favor of `config_proto`,
// as noted on StatefulPartitionedCall; confirm.
REGISTER_OP("PartitionedCall")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("f: func")
    .Attr("config: string = ''")
    .Attr("config_proto: string = ''")
    .Attr("executor_type: string = ''")
    .SetShapeFn(shape_inference::UnknownShape);
284
// StatefulPartitionedCall: the same signature as "PartitionedCall", but
// marked stateful and flagged via SetIsDistributedCommunication(). Output
// shapes are reported as unknown here; shape inference is left to
// ShapeRefiner.
REGISTER_OP("StatefulPartitionedCall")
    .Input("args: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("f: func")
    .Attr("config: string = ''")  // Deprecated in favor of config_proto
    .Attr("config_proto: string = ''")
    .Attr("executor_type: string = ''")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnknownShape);
297
298// This op is used as a placeholder in If branch functions. It doesn't provide a
299// valid output when run, so must either be removed (e.g. replaced with a
300// function input) or guaranteed not to be used (e.g. if mirroring an
301// intermediate output needed for the gradient computation of the other branch).
302REGISTER_OP("FakeParam")
303 .Output("output: dtype")
304 .Attr("dtype: type")
305 .Attr("shape: shape")
306 .SetShapeFn([](InferenceContext* c) {
307 PartialTensorShape shape;
308 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
309 shape_inference::ShapeHandle out;
310 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
311 c->set_output(0, out);
312 return OkStatus();
313 });
314
// Returns the device index as a scalar int32 (ScalarShape).
// NOTE(review): `device_names` presumably lists the candidate device names
// that the index refers to; the kernel is not visible here, so confirm.
// SetDoNotOptimize() is assumed to keep graph optimizations from rewriting or
// folding this op; confirm against OpDefBuilder.
REGISTER_OP("DeviceIndex")
    .Output("index: int32")
    .Attr("device_names: list(string)")
    .SetShapeFn(shape_inference::ScalarShape)
    .SetDoNotOptimize();
321
322} // end namespace tensorflow
323