1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::InferenceContext; |
23 | |
24 | REGISTER_OP("SymbolicGradient" ) |
25 | .Input("input: Tin" ) |
26 | .Output("output: Tout" ) |
27 | .Attr("Tin: list(type)" ) |
28 | .Attr("Tout: list(type)" ) |
29 | .Attr("f: func" ) |
30 | .SetShapeFn([](InferenceContext* c) { |
31 | if (c->num_inputs() < c->num_outputs()) { |
32 | return errors::InvalidArgument("len(inputs) < len(outputs)" ); |
33 | } |
34 | std::vector<DataType> types; |
35 | TF_RETURN_IF_ERROR(c->GetAttr("Tin" , &types)); |
36 | // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of |
37 | // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its |
38 | // outputs (dx, dy, dz) are the same as (x, y, z). |
39 | for (int i = 0; i < c->num_outputs(); ++i) { |
40 | if (types[i] == DT_RESOURCE) { |
41 | const std::vector<shape_inference::ShapeAndType>* handle_type = |
42 | c->input_handle_shapes_and_types(i); |
43 | if (handle_type != nullptr) { |
44 | c->set_output(i, handle_type->at(0).shape); |
45 | } else { |
46 | c->set_output(i, c->UnknownShape()); |
47 | } |
48 | } else { |
49 | c->set_output(i, c->input(i)); |
50 | } |
51 | } |
52 | return OkStatus(); |
53 | }); |
54 | |
// Runs function `f` on the remote device named by the `target` input.
// Marked stateful because remote execution may have side effects; output
// shapes cannot be determined statically, hence UnknownShape.
REGISTER_OP("RemoteCall")
    .Input("target: string")  // Fully specified device name to run `f` on.
    .Input("args: Tin")       // Arguments forwarded to `f`.
    .Output("output: Tout")   // Values returned by `f`.
    .Attr("Tin: list(type)")
    .Attr("Tout: list(type)")
    .Attr("f: func")  // The function to invoke remotely.
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
64 | |
65 | // TODO(drpng): remove this. |
66 | REGISTER_OP("_If" ) |
67 | .Input("cond: Tcond" ) |
68 | .Input("input: Tin" ) |
69 | .Output("output: Tout" ) |
70 | .Attr("Tcond: type" ) |
71 | .Attr("Tin: list(type)" ) |
72 | .Attr("Tout: list(type)" ) |
73 | .Attr("then_branch: func" ) |
74 | .Attr("else_branch: func" ) |
75 | .SetIsStateful() |
76 | .SetShapeFn(shape_inference::UnknownShape) |
77 | .Doc(R"doc( |
78 | output = cond ? then_branch(input) : else_branch(input) |
79 | |
80 | cond: A Tensor. If the tensor is a scalar of non-boolean type, the |
81 | scalar is converted to a boolean according to the |
82 | following rule: if the scalar is a numerical value, non-zero means |
83 | True and zero means False; if the scalar is a string, non-empty |
84 | means True and empty means False. If the tensor is not a scalar, |
85 | being empty means False and being non-empty means True. |
86 | input: A list of input tensors. |
87 | then_branch: A function that takes 'inputs' and returns a list of |
88 | tensors, whose types are the same as what else_branch returns. |
89 | else_branch: A function that takes 'inputs' and returns a list of |
90 | tensors. whose types are the same as what then_branch returns. |
91 | )doc" ); |
92 | |
93 | Status IfShapeInferenceFn(shape_inference::InferenceContext* c) { |
94 | std::vector<PartialTensorShape> output_shapes; |
95 | TF_RETURN_IF_ERROR(c->GetAttr("output_shapes" , &output_shapes)); |
96 | // If `output_shapes` attr is set use that as the shapes of the outputs |
97 | // else return unknown shapes. |
98 | if (output_shapes.empty()) return shape_inference::UnknownShape(c); |
99 | if (output_shapes.size() != c->num_outputs()) { |
100 | return errors::InvalidArgument( |
101 | "`output_shapes` must be the same length as num outputs (" , |
102 | output_shapes.size(), " vs. " , c->num_outputs()); |
103 | } |
104 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
105 | shape_inference::ShapeHandle output_shape_handle; |
106 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
107 | output_shapes[i], &output_shape_handle)); |
108 | c->set_output(static_cast<int>(i), output_shape_handle); |
109 | } |
110 | return OkStatus(); |
111 | } |
112 | |
// Conditional: output = cond ? then_branch(input) : else_branch(input).
// Stateless variant: both branches must be free of side effects, so the
// op may be pruned, deduplicated, or constant-folded.
REGISTER_OP("StatelessIf")
    .Input("cond: Tcond")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    // Optional static output shapes; empty list means "unknown".
    .Attr("output_shapes: list(shape) = []")
    .SetShapeFn(IfShapeInferenceFn);
124 | |
// Conditional: output = cond ? then_branch(input) : else_branch(input).
// Stateful variant: branches may contain side-effecting ops, so the op
// must not be pruned or deduplicated.
REGISTER_OP("If")
    .Input("cond: Tcond")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    // Optional static output shapes; empty list means "unknown".
    .Attr("output_shapes: list(shape) = []")
    .SetIsStateful()
    .SetShapeFn(IfShapeInferenceFn);
137 | |
138 | Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) { |
139 | std::vector<PartialTensorShape> output_shapes; |
140 | TF_RETURN_IF_ERROR(c->GetAttr("output_shapes" , &output_shapes)); |
141 | // If `output_shapes` attr is set use that as the shapes of the outputs |
142 | // else return unknown shapes. |
143 | if (output_shapes.empty()) return shape_inference::UnknownShape(c); |
144 | if (output_shapes.size() != c->num_outputs()) { |
145 | return errors::InvalidArgument( |
146 | "`output_shapes` must be the same length as num outputs (" , |
147 | output_shapes.size(), " vs. " , c->num_outputs()); |
148 | } |
149 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
150 | shape_inference::ShapeHandle output_shape_handle; |
151 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
152 | output_shapes[i], &output_shape_handle)); |
153 | c->set_output(static_cast<int>(i), output_shape_handle); |
154 | } |
155 | return OkStatus(); |
156 | } |
157 | |
// Multi-way branch: output = branches[branch_index](input).
// Stateless variant: branch functions must be side-effect free, so the
// op may be pruned, deduplicated, or constant-folded.
REGISTER_OP("StatelessCase")
    .Input("branch_index: int32")  // Selects which branch function to run.
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("branches: list(func) >= 1")
    // Optional static output shapes; empty list means "unknown".
    .Attr("output_shapes: list(shape) = []")
    .SetShapeFn(CaseShapeInferenceFn);
167 | |
// Multi-way branch: output = branches[branch_index](input).
// Stateful variant: branch functions may contain side-effecting ops, so
// the op must not be pruned or deduplicated.
REGISTER_OP("Case")
    .Input("branch_index: int32")  // Selects which branch function to run.
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("branches: list(func) >= 1")
    // Optional static output shapes; empty list means "unknown".
    .Attr("output_shapes: list(shape) = []")
    .SetIsStateful()
    .SetShapeFn(CaseShapeInferenceFn);
178 | |
179 | // TODO(drpng): remove this. |
// TODO(drpng): remove this.
// Legacy internal while-loop op; superseded by "While"/"StatelessWhile"
// below. The loop is shape-preserving: output i gets input i's shape.
REGISTER_OP("_While")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Loop-carried values keep their input shapes across iterations.
      for (int i = 0; i < c->num_outputs(); ++i) {
        c->set_output(i, c->input(i));
      }
      return OkStatus();
    })
    .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
    a scalar of non-boolean, the scalar is converted to a boolean
    according to the following rule: if the scalar is a numerical
    value, non-zero means True and zero means False; if the scalar is
    a string, non-empty means True and empty means False. If the
    tensor is not a scalar, non-emptiness means True and False
    otherwise.
body: A function that takes a list of tensors and returns another
    list of tensors. Both lists have the same types as specified
    by T.
)doc");
209 | |
210 | Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) { |
211 | std::vector<PartialTensorShape> output_shapes; |
212 | TF_RETURN_IF_ERROR(c->GetAttr("output_shapes" , &output_shapes)); |
213 | // If `output_shapes` attr is set use that as the shapes of the outputs |
214 | // else use the input shapes. |
215 | if (!output_shapes.empty()) { |
216 | if (output_shapes.size() != c->num_outputs()) { |
217 | return errors::InvalidArgument( |
218 | "`output_shapes` must be the same length as num outputs (" , |
219 | output_shapes.size(), " vs. " , c->num_outputs()); |
220 | } |
221 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
222 | shape_inference::ShapeHandle output_shape_handle; |
223 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( |
224 | output_shapes[i], &output_shape_handle)); |
225 | c->set_output(static_cast<int>(i), output_shape_handle); |
226 | } |
227 | } else { |
228 | for (int i = 0; i < c->num_outputs(); ++i) { |
229 | c->set_output(i, c->input(i)); |
230 | } |
231 | } |
232 | return OkStatus(); |
233 | } |
234 | |
// Loop: output = input; while (cond(output)) { output = body(output) }.
// Stateful variant: cond/body may contain side-effecting ops.
REGISTER_OP("While")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")  // Predicate deciding whether to run another iteration.
    .Attr("body: func")  // Loop body; maps the carried values to new values.
    // Optional static output shapes; empty list means "same as inputs".
    .Attr("output_shapes: list(shape) = []")
    // Max number of iterations allowed to run in parallel.
    .Attr("parallel_iterations: int = 10")
    .SetIsStateful()
    .SetShapeFn(WhileShapeInferenceFn);
245 | |
// Loop: output = input; while (cond(output)) { output = body(output) }.
// Stateless variant: cond/body must be free of side effects, so the op
// may be pruned, deduplicated, or constant-folded.
REGISTER_OP("StatelessWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")  // Predicate deciding whether to run another iteration.
    .Attr("body: func")  // Loop body; maps the carried values to new values.
    // Optional static output shapes; empty list means "same as inputs".
    .Attr("output_shapes: list(shape) = []")
    // Max number of iterations allowed to run in parallel.
    .Attr("parallel_iterations: int = 10")
    .SetShapeFn(WhileShapeInferenceFn);
255 | |
// Converts a tensor of any type to a scalar boolean (the truthiness test
// used by the functional control-flow ops above).
REGISTER_OP("ToBool")
    .Input("input: T")
    .Output("output: bool")
    .Attr("T: type")
    .SetShapeFn(shape_inference::ScalarShape);  // Always a scalar bool.
261 | |
// Counted loop: for (i = start; i < limit; i += delta) input = body(i, input).
REGISTER_OP("For")
    .Input("start: int32")  // Initial value of the loop counter.
    .Input("limit: int32")  // Exclusive upper bound of the loop counter.
    .Input("delta: int32")  // Counter increment per iteration.
    .Input("input: T")      // Loop-carried values.
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("body: func")  // Loop body; takes (counter, carried values).
    .SetShapeFn(shape_inference::UnknownShape);
271 | |
// While no useful shape function is registered for function call ops directly,
// ShapeRefiner is run by default to perform shape inference.
//
// Calls function `f`, partitioning its body across the available devices.
// Stateless variant: `f` must be free of side effects.
REGISTER_OP("PartitionedCall")
    .Input("args: Tin")      // Arguments forwarded to `f`.
    .Output("output: Tout")  // Values returned by `f`.
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("f: func")  // The function to call.
    .Attr("config: string = ''")
    .Attr("config_proto: string = ''")
    .Attr("executor_type: string = ''")
    .SetShapeFn(shape_inference::UnknownShape);
284 | |
// Calls function `f`, partitioning its body across the available devices.
// Stateful variant: `f` may contain side-effecting ops (including
// cross-host communication), so the op must not be pruned or deduplicated.
REGISTER_OP("StatefulPartitionedCall")
    .Input("args: Tin")      // Arguments forwarded to `f`.
    .Output("output: Tout")  // Values returned by `f`.
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("f: func")  // The function to call.
    .Attr("config: string = ''")  // Deprecated in favor of config_proto
    .Attr("config_proto: string = ''")
    .Attr("executor_type: string = ''")
    .SetIsStateful()
    .SetIsDistributedCommunication()
    .SetShapeFn(shape_inference::UnknownShape);
297 | |
// This op is used as a placeholder in If branch functions. It doesn't provide a
// valid output when run, so must either be removed (e.g. replaced with a
// function input) or guaranteed not to be used (e.g. if mirroring an
// intermediate output needed for the gradient computation of the other branch).
REGISTER_OP("FakeParam")
    .Output("output: dtype")
    .Attr("dtype: type")   // Static type of the placeholder value.
    .Attr("shape: shape")  // Static (possibly partial) shape of the value.
    .SetShapeFn([](InferenceContext* c) {
      // The output shape is exactly the statically declared `shape` attr.
      PartialTensorShape shape;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
      c->set_output(0, out);
      return OkStatus();
    });
314 | |
// Returns the device index.
REGISTER_OP("DeviceIndex")
    .Output("index: int32")  // Scalar index into `device_names`.
    .Attr("device_names: list(string)")
    .SetShapeFn(shape_inference::ScalarShape)
    // The result depends on op placement, so it must not be constant-folded
    // or otherwise rewritten before placement is decided.
    .SetDoNotOptimize();
321 | |
322 | } // end namespace tensorflow |
323 | |