1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/framework/python_op_gen_internal.h" |
17 | |
#include <float.h>
#include <stdio.h>

#include <cctype>
#include <iomanip>
#include <sstream>
#include <unordered_map>

#include "absl/strings/escaping.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 | |
45 | namespace tensorflow { |
46 | namespace python_op_gen_internal { |
47 | |
// Column limit for generated code and docstrings in this file: it is the
// width passed to WordWrap() for apply_op calls, and (minus the 4-column
// indent) to AppendWithinWidth() for docstring entries.
const int kRightMargin = 78;
// Names specified in tf_export decorators are exported to
// TensorFlow 2.0 by default. AddExport() omits an endpoint from the v2 names
// when its (or the op's) deprecation_version is set and <= this value.
const int kLatestAPIExportVersion = 2;
52 | |
// Returns true if `s` must not be used verbatim as a Python identifier in
// generated code: it is a Python keyword (Python 2 or Python 3) or one of the
// capitalized/dunder builtins listed below.
bool IsPythonReserved(const string& s) {
  // Allocated once and never freed, so no destructor runs for it at exit.
  static const std::set<string>* const kPythonReserved = new std::set<string>(
      {// Keywords in Python, from:
       //   import keyword
       //   print keyword.kwlist
       "and", "as", "assert", "break", "class", "continue", "def", "del",
       "elif", "else", "except", "exec", "finally", "for", "from", "global",
       "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
       "raise", "return", "try", "while", "with", "yield",
       // Python 3 keywords absent from the Python 2 list above. Using any of
       // them as a parameter name is a SyntaxError in Python 3 (async/await
       // since Python 3.7).
       "async", "await", "nonlocal",
       // Built-in functions and types in Python, from:
       //   [x for x in dir(__builtins__) if not x[0].islower()]
       "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
       "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
       "Ellipsis", "EnvironmentError", "Exception", "False",
       "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
       "ImportError", "ImportWarning", "IndentationError", "IndexError",
       "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
       "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
       "OverflowError", "PendingDeprecationWarning", "ReferenceError",
       "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
       "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
       "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
       "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
       "UnicodeWarning", "UserWarning", "ValueError", "Warning",
       "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
       "__package__"});

  return kPythonReserved->count(s) > 0;
}
82 | |
// Returns true if an op whose Python name is `s` should be generated with a
// leading underscore, because the bare name would shadow a lowercase Python
// builtin or collide with a Python-defined op under '*' imports.
bool IsOpWithUnderscorePrefix(const string& s) {
  // Allocated once and never freed, so no destructor runs for it at exit.
  static const auto* const kUnderscoreOps = new std::set<string>{
      // Lowercase built-in functions and types in Python, from:
      //   [x for x in dir(__builtins__) if x[0].islower()] except "round".
      // These need to be excluded so they don't conflict with actual built-in
      // functions since we use '*' imports.
      "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
      "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile",
      "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod",
      "enumerate", "eval", "execfile", "exit", "file", "filter", "float",
      "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help",
      "hex", "id", "input", "int", "intern", "isinstance", "issubclass",
      "iter", "len", "license", "list", "locals", "long", "map", "max",
      "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
      "print", "property", "quit", "range", "raw_input", "reduce", "reload",
      "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
      "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
      "xrange", "zip",
      // These have the same name as ops defined in Python and might be used
      // incorrectly depending on order of '*' imports.
      // TODO(annarev): reduce usage of '*' imports and remove these from the
      // list.
      "fused_batch_norm", "histogram_fixed_width", "stack",
      "batch_norm_with_global_normalization", "clip_by_value"};
  return kUnderscoreOps->find(s) != kUnderscoreOps->end();
}
109 | |
110 | string AvoidPythonReserved(const string& s) { |
111 | // Convert namespace separators ('>' characters) to joiners |
112 | string result = absl::StrReplaceAll(s, {{">" , "_" }}); |
113 | |
114 | if (IsPythonReserved(result)) return strings::StrCat(result, "_" ); |
115 | return result; |
116 | } |
117 | |
118 | // Indent the first line by "initial" spaces and all following lines |
119 | // by "rest" spaces. |
120 | string Indent(int initial, int rest, StringPiece in) { |
121 | // TODO(josh11b): Also word-wrapping? |
122 | string copy(in.data(), in.size()); |
123 | absl::StripTrailingAsciiWhitespace(©); |
124 | std::vector<string> v = str_util::Split(copy, '\n'); |
125 | |
126 | string result; |
127 | bool first = true; |
128 | for (const string& line : v) { |
129 | if (first) { |
130 | result = strings::StrCat(Spaces(initial), line, "\n" ); |
131 | first = false; |
132 | } else { |
133 | if (line.empty()) { |
134 | strings::StrAppend(&result, "\n" ); |
135 | } else { |
136 | strings::StrAppend(&result, Spaces(rest), line, "\n" ); |
137 | } |
138 | } |
139 | } |
140 | return result; |
141 | } |
142 | |
143 | // Adds append to *dest, with a space if the first line will be <= width, |
144 | // or a newline otherwise. |
145 | void AppendWithinWidth(string* dest, StringPiece append, int width) { |
146 | auto first_line = append.find('\n'); |
147 | if (first_line == string::npos) first_line = append.size(); |
148 | if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) { |
149 | strings::StrAppend(dest, "\n" , append); |
150 | } else { |
151 | strings::StrAppend(dest, " " , append); |
152 | } |
153 | } |
154 | |
155 | // Like DataTypeString() but uses the Python names for the |
156 | // float types. |
157 | string PythonDataTypeString(DataType dtype) { |
158 | switch (dtype) { |
159 | case DT_FLOAT: |
160 | return "float32" ; |
161 | case DT_DOUBLE: |
162 | return "float64" ; |
163 | default: |
164 | return DataTypeString(dtype); |
165 | } |
166 | } |
167 | |
168 | string TypeString(DataType dtype, bool ref) { |
169 | if (ref) { |
170 | return strings::StrCat("mutable `" , PythonDataTypeString(dtype), "`" ); |
171 | } else { |
172 | return strings::StrCat("`" , PythonDataTypeString(dtype), "`" ); |
173 | } |
174 | } |
175 | |
176 | string TypeListString(const AttrValue& value) { |
177 | string ret; |
178 | for (int t : value.list().type()) { |
179 | if (!ret.empty()) strings::StrAppend(&ret, ", " ); |
180 | DataType dtype = static_cast<DataType>(t); |
181 | if (IsRefType(dtype)) { |
182 | strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)), |
183 | " mutable" ); |
184 | } else { |
185 | strings::StrAppend(&ret, "`" , PythonDataTypeString(dtype), "`" ); |
186 | } |
187 | } |
188 | return ret; |
189 | } |
190 | |
191 | string SingleTensorName(DataType dtype, bool is_ref) { |
192 | const string type_str = TypeString(dtype, is_ref); |
193 | return strings::StrCat("A `Tensor` of type " , type_str, "." ); |
194 | } |
195 | |
// Sentinel description for an output whose type is unknown. GetReturns() uses
// it when output_type_string has no entry for an output, and treats an entry
// equal to it as "type unknown" (falling back to the output's description).
const char kUnknownTensorType[] = {"A `Tensor`."};
197 | |
// Returns a one-sentence docstring description of the type of `arg`, e.g.
// "A `Tensor` of type `int32`." or
// "A list of `Tensor` objects with the same type in: `float32`, `float64`.".
//
// `inferred_attrs` maps an attr name to the name of the arg it is inferred
// from (GenPythonOp::Code() records the first arg that uses the attr). It is
// used to phrase the description relative to that arg ("same type as `x`").
// `is_output` selects "Has ..." (outputs) vs "Must ..." (inputs) wording.
// NOTE(review): FindAttr() results are dereferenced without a null check, so
// this assumes a valid OpDef whose referenced attrs exist.
string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
                   const std::unordered_map<string, string>& inferred_attrs,
                   bool is_output) {
  if (!arg.number_attr().empty()) {
    // N Tensors with the same type
    const string* original_arg =
        gtl::FindOrNull(inferred_attrs, arg.number_attr());
    string prefix;
    if (original_arg == nullptr) {
      // The length attr is not inferred from any arg: name the attr itself.
      prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
    } else if (*original_arg == arg.name()) {
      // This arg determines the length; mention the minimum length, if any.
      const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
      if (attr->has_minimum() && attr->minimum() > 0) {
        prefix = strings::StrCat("A list of at least ", attr->minimum());
      } else {
        prefix = "A list of";
      }
    } else {
      // The length is fixed by another arg.
      prefix = strings::StrCat("A list with the same length as `",
                               AvoidPythonReserved(*original_arg), "` of");
    }

    if (arg.type() != DT_INVALID) {
      // Fixed element type.
      return strings::StrCat(prefix, " `Tensor` objects with type ",
                             TypeString(arg.type(), arg.is_ref()), ".");
    } else {
      // Element type comes from arg.type_attr().
      original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
      if (arg.is_ref()) {
        strings::StrAppend(&prefix, " mutable");
      }
      if (original_arg == nullptr) {
        return strings::StrCat(prefix, " `Tensor` objects with type `",
                               arg.type_attr(), "`.");
      } else if (*original_arg == arg.name()) {
        const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
        if (attr->has_allowed_values()) {
          return strings::StrCat(prefix,
                                 " `Tensor` objects with the same type in: ",
                                 TypeListString(attr->allowed_values()), ".");
        } else {
          return strings::StrCat(prefix,
                                 " `Tensor` objects with the same type.");
        }
      } else {
        return strings::StrCat(prefix,
                               " `Tensor` objects with the same type as `",
                               AvoidPythonReserved(*original_arg), "`.");
      }
    }
  } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
    // A single tensor typed by `type_attr`, or a list of tensors typed by
    // `type_list_attr`.
    const bool is_list = !arg.type_list_attr().empty();
    const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
    const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
    const string mutable_str = arg.is_ref() ? "mutable " : "";
    const string prefix =
        is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
                : strings::StrCat("A ", mutable_str, "`Tensor`");
    const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
    if (original_arg == nullptr) {
      return strings::StrCat(prefix, " of type `", attr_name, "`.");
    } else if (*original_arg == arg.name()) {
      // This arg determines the type; list the allowed types, if restricted.
      if (attr->has_allowed_values()) {
        if (is_list) {
          return strings::StrCat(prefix, " with types from: ",
                                 TypeListString(attr->allowed_values()), ".");
        } else {
          return strings::StrCat(
              prefix, is_output ? ". Has one of the following types: "
                                : ". Must be one of the following types: ",
              TypeListString(attr->allowed_values()), ".");
        }
      } else {
        return strings::StrCat(prefix, ".");
      }
    } else {
      return strings::StrCat(prefix,
                             is_output ? ". Has the same type as `"
                                       : ". Must have the same type as `",
                             AvoidPythonReserved(*original_arg), "`.");
    }
  } else {
    // Fixed single type.
    return SingleTensorName(arg.type(), arg.is_ref());
  }
}
282 | |
// Builds the "Returns:" section of a generated docstring.
//
// `output_type_string[i]` is the type description for output i (as produced
// by ArgTypeName()); a missing entry is treated as kUnknownTensorType.
// - No outputs: documents the returned Operation.
// - One output: a single indented paragraph (type description plus the
//   output's description).
// - Several outputs: a tuple summary line followed by one "name: ..." entry
//   per output.
// When ConsumeEquals() accepts an output description, the generated type info
// is skipped and only the remaining description text is used.
string GetReturns(const OpDef& op_def,
                  const std::vector<string>& output_type_string) {
  string result;
  DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
  const int num_outs = op_def.output_arg_size();
  strings::StrAppend(&result, "\n Returns:\n");
  if (num_outs == 0) {
    strings::StrAppend(&result, " The created Operation.\n");
  } else {
    if (num_outs == 1) {
      StringPiece description = op_def.output_arg(0).description();
      if (ConsumeEquals(&description)) {  // Skip the generated type info.
        strings::StrAppend(&result, Indent(4, 4, description));
      } else {
        // Special case of one output, don't use the name of the output unless
        // there is no description.
        string desc = output_type_string.empty() ? kUnknownTensorType
                                                 : output_type_string[0];
        if (desc == kUnknownTensorType) {
          // Special case where we don't understand how the output tensor type
          // depends on the input tensor types, just use the output arg
          // description if we can.
          if (!description.empty()) {
            desc = op_def.output_arg(0).description();
          } else if (!op_def.output_arg(0).name().empty()) {
            desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
                                   " `Tensor`.");
          }
        } else if (!description.empty()) {
          AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
        }
        strings::StrAppend(&result, Indent(4, 4, desc));
      }
    } else {
      // Name each output, falling back to "output<i>" for unnamed ones.
      std::vector<string> out_names(num_outs);
      for (int i = 0; i < num_outs; ++i) {
        if (!op_def.output_arg(i).name().empty()) {
          out_names[i] = op_def.output_arg(i).name();
        } else {
          out_names[i] = strings::StrCat("output", i);
        }
      }
      strings::StrAppend(&result, " A tuple of `Tensor` objects (",
                         absl::StrJoin(out_names, ", "), ").\n\n");
      for (int i = 0; i < num_outs; ++i) {
        string desc = strings::StrCat(out_names[i], ": ");
        StringPiece description = op_def.output_arg(i).description();
        if (ConsumeEquals(&description)) {  // Skip the generated type info.
          strings::StrAppend(&desc, description);
        } else {
          // Guard the index: output_type_string may be shorter than the
          // output count in non-debug builds, where DCHECK_EQ does not fire.
          const string type = static_cast<size_t>(i) < output_type_string.size()
                                  ? output_type_string[i]
                                  : kUnknownTensorType;
          if (!description.empty()) {
            if (type == kUnknownTensorType) {
              // Special case where we don't understand how the output tensor
              // type depends on the input tensor types, so we just use the
              // output arg description.
              strings::StrAppend(&desc, description);
            } else {
              strings::StrAppend(&desc, type, " ", description);
            }
          } else {
            strings::StrAppend(&desc, type);
          }
        }
        strings::StrAppend(&result, Indent(4, 6, desc));
      }
    }
  }
  return result;
}
355 | |
356 | string StringToPython(const string& str) { |
357 | return strings::StrCat("\"" , absl::CEscape(str), "\"" ); |
358 | } |
359 | |
360 | string DataTypeToPython(DataType dtype, const string& dtype_module) { |
361 | return strings::StrCat(dtype_module, PythonDataTypeString(dtype)); |
362 | } |
363 | |
364 | string ShapeToPython(const TensorShapeProto& shape) { |
365 | if (shape.unknown_rank()) { |
366 | return "None" ; |
367 | } |
368 | string python = "[" ; |
369 | for (const auto& dim : shape.dim()) { |
370 | if (python.size() > 1) strings::StrAppend(&python, ", " ); |
371 | if (!dim.name().empty()) { |
372 | strings::StrAppend(&python, "(" , StringToPython(dim.name()), ", " , |
373 | dim.size(), ")" ); |
374 | } else { |
375 | strings::StrAppend(&python, dim.size()); |
376 | } |
377 | } |
378 | strings::StrAppend(&python, "]" ); |
379 | return python; |
380 | } |
381 | |
// Renders a TensorProto using its protobuf short debug string.
// NOTE(review): this is presumably protobuf text format rather than a valid
// Python expression, unlike the other *ToPython helpers; it is only suitable
// for display (e.g. in docstrings).
string TensorToPython(const TensorProto& proto) {
  return proto.ShortDebugString();
}
385 | |
386 | string AttrListToPython(const AttrValue& value, |
387 | const string& dtype_module = "tf." ) { |
388 | string ret; |
389 | if (value.list().s_size() > 0) { |
390 | for (int i = 0; i < value.list().s_size(); ++i) { |
391 | if (i > 0) strings::StrAppend(&ret, ", " ); |
392 | strings::StrAppend(&ret, StringToPython(value.list().s(i))); |
393 | } |
394 | } else if (value.list().i_size() > 0) { |
395 | for (int i = 0; i < value.list().i_size(); ++i) { |
396 | if (i > 0) strings::StrAppend(&ret, ", " ); |
397 | strings::StrAppend(&ret, value.list().i(i)); |
398 | } |
399 | } else if (value.list().f_size() > 0) { |
400 | for (int i = 0; i < value.list().f_size(); ++i) { |
401 | if (i > 0) strings::StrAppend(&ret, ", " ); |
402 | strings::StrAppend(&ret, value.list().f(i)); |
403 | } |
404 | } else if (value.list().b_size() > 0) { |
405 | for (int i = 0; i < value.list().b_size(); ++i) { |
406 | if (i > 0) strings::StrAppend(&ret, ", " ); |
407 | strings::StrAppend(&ret, value.list().b(i) ? "True" : "False" ); |
408 | } |
409 | } else if (value.list().type_size() > 0) { |
410 | for (int i = 0; i < value.list().type_size(); ++i) { |
411 | if (i > 0) strings::StrAppend(&ret, ", " ); |
412 | strings::StrAppend(&ret, |
413 | DataTypeToPython(value.list().type(i), dtype_module)); |
414 | } |
415 | } else if (value.list().shape_size() > 0) { |
416 | for (int i = 0; i < value.list().shape_size(); ++i) { |
417 | if (i > 0) strings::StrAppend(&ret, ", " ); |
418 | strings::StrAppend(&ret, ShapeToPython(value.list().shape(i))); |
419 | } |
420 | } else if (value.list().tensor_size() > 0) { |
421 | for (int i = 0; i < value.list().tensor_size(); ++i) { |
422 | if (i > 0) strings::StrAppend(&ret, ", " ); |
423 | strings::StrAppend(&ret, TensorToPython(value.list().tensor(i))); |
424 | } |
425 | } else if (value.list().func_size() > 0) { |
426 | for (int i = 0; i < value.list().func_size(); ++i) { |
427 | if (i > 0) strings::StrAppend(&ret, ", " ); |
428 | strings::StrAppend(&ret, StringToPython(value.list().func(i).name())); |
429 | } |
430 | } |
431 | return ret; |
432 | } |
433 | |
434 | // NOTE: The return value may contain spaces (for example, it could be |
435 | // a string "foo bar" with an embedded space) and is not safe to pass |
436 | // to WordWrap(). |
437 | string AttrValueToPython(const string& type, const AttrValue& value, |
438 | const string& dtype_module) { |
439 | if (type == "string" ) { |
440 | return StringToPython(value.s()); |
441 | } else if (type == "int" ) { |
442 | return strings::StrCat(value.i()); |
443 | } else if (type == "float" ) { |
444 | if (std::isnan(value.f()) || std::isinf(value.f())) { |
445 | return strings::StrCat("float('" , value.f(), "')" ); |
446 | } else { |
447 | // Use locale-independent conversion. |
448 | static_assert(FLT_DIG < 10, "FLT_DIG is too big" ); |
449 | std::ostringstream s; |
450 | s.imbue(std::locale::classic()); |
451 | s << std::setprecision(FLT_DIG) << value.f(); |
452 | // If there is no I/O error for `std::ostringstream s` return s.str(), |
453 | // otherwise fallback to strings::StrCat(value.f()). |
454 | if (s.good()) { |
455 | return s.str(); |
456 | } |
457 | return strings::StrCat(value.f()); |
458 | } |
459 | } else if (type == "bool" ) { |
460 | return value.b() ? "True" : "False" ; |
461 | } else if (type == "type" ) { |
462 | return DataTypeToPython(value.type(), dtype_module); |
463 | } else if (type == "shape" ) { |
464 | return ShapeToPython(value.shape()); |
465 | } else if (type == "tensor" ) { |
466 | return TensorToPython(value.tensor()); |
467 | } else if (type == "func" ) { |
468 | return StringToPython(value.func().name()); |
469 | } else if (absl::StartsWith(type, "list(" )) { |
470 | return strings::StrCat("[" , AttrListToPython(value, dtype_module), "]" ); |
471 | } else { |
472 | return "?" ; |
473 | } |
474 | } |
475 | |
// Converts a CamelCase op name to snake_case and appends it to *result
// (*result is not cleared first). Namespace separators ('>') become '_'.
// A '_' is inserted before an uppercase letter (other than the first
// character) if it follows a lowercase letter or is followed by one, unless
// it directly follows a namespace separator.
// Examples: "FooBar" -> "foo_bar", "ABCDef" -> "abc_def",
// "Foo>BarBaz" -> "foo_bar_baz", "Conv2D" -> "conv2d".
void GenerateLowerCaseOpName(const string& str, string* result) {
  const char joiner = '_';
  const char namespace_separator = '>';
  // The <cctype> functions have undefined behavior for negative char values
  // (e.g. non-ASCII bytes), so classify through unsigned char.
  auto is_upper = [](char c) {
    return std::isupper(static_cast<unsigned char>(c)) != 0;
  };
  auto is_lower = [](char c) {
    return std::islower(static_cast<unsigned char>(c)) != 0;
  };

  const size_t size = str.size();
  for (size_t i = 0; i < size; ++i) {
    const char c = str[i];
    // Convert namespace separators ('>' characters) to joiners.
    if (c == namespace_separator) {
      result->push_back(joiner);
      continue;
    }

    // Emit a joiner only if a previous-lower-to-now-upper or a
    // now-upper-to-next-lower transition happens (but don't emit an extra
    // joiner if we just saw a namespace separator).
    if (i > 0 && is_upper(c) && str[i - 1] != namespace_separator) {
      const bool after_lower = is_lower(str[i - 1]);
      const bool before_lower = i + 1 < size && is_lower(str[i + 1]);
      if (after_lower || before_lower) result->push_back(joiner);
    }
    result->push_back(
        static_cast<char>(std::tolower(static_cast<unsigned char>(c))));
  }
}
501 | |
// Appends `delim` to *append_to unless it is still empty, so that items
// appended afterwards come out delimiter-separated.
static void AddDelimiter(string* append_to, const string& delim) {
  if (append_to->empty()) return;
  append_to->append(delim);
}
505 | |
506 | const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { |
507 | for (int i = 0; i < api_def.attr_size(); ++i) { |
508 | if (api_def.attr(i).name() == name) { |
509 | return &api_def.attr(i); |
510 | } |
511 | } |
512 | return nullptr; |
513 | } |
514 | |
// Captures the op/api definitions and generation options. No code is
// generated here; Code() does that. num_outs_ caches the op's output count.
// NOTE(review): if op_def_/api_def_ are reference members (see the header),
// the caller must keep op_def and api_def alive for this object's lifetime.
// add_type_annotations_ is not read anywhere in this file; presumably it is
// used by subclasses or by the wider generator.
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
                         const string& function_name, bool add_type_annotations)
    : op_def_(op_def),
      api_def_(api_def),
      function_name_(function_name),
      add_type_annotations_(add_type_annotations),
      num_outs_(op_def.output_arg_size()) {}

GenPythonOp::~GenPythonOp() {}
524 | |
525 | string GenPythonOp::Code() { |
526 | // This has all the input args followed by those attrs that don't have |
527 | // defaults. |
528 | std::vector<ParamNames> params_no_default; |
529 | // The parameters with defaults (these have to be listed after those without). |
530 | // No input args are included, just attrs. |
531 | std::vector<ParamNames> params_with_default; |
532 | |
533 | for (int i = 0; i < api_def_.arg_order_size(); ++i) { |
534 | const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); |
535 | const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); |
536 | params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to()); |
537 | if (!arg.type_attr().empty()) { |
538 | gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); |
539 | } else if (!arg.type_list_attr().empty()) { |
540 | gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(), |
541 | arg.name()); |
542 | } |
543 | if (!arg.number_attr().empty()) { |
544 | gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); |
545 | } |
546 | } |
547 | for (int i = 0; i < api_def_.attr_size(); ++i) { |
548 | const auto& attr(api_def_.attr(i)); |
549 | // Do not add inferred attrs to the Python function signature. |
550 | if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { |
551 | if (attr.has_default_value()) { |
552 | params_with_default.emplace_back(attr.name(), attr.rename_to()); |
553 | } else { |
554 | params_no_default.emplace_back(attr.name(), attr.rename_to()); |
555 | } |
556 | } |
557 | } |
558 | |
559 | // Save the list of attr parameters (attrs that won't be inferred), |
560 | // those with defaults go at the end. |
561 | // Get the attrs in the order we want by taking the attrs without defaults |
562 | // from the end of args_no_default, and adding args_no_default. |
563 | attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() + |
564 | params_with_default.size()); |
565 | for (int i = op_def_.input_arg_size(), end = params_no_default.size(); |
566 | i < end; ++i) { |
567 | attrs_.push_back(params_no_default[i].GetName()); |
568 | } |
569 | for (int i = 0, end = params_with_default.size(); i < end; ++i) { |
570 | attrs_.push_back(params_with_default[i].GetName()); |
571 | } |
572 | |
573 | param_names_.reserve(params_no_default.size() + params_with_default.size()); |
574 | param_names_.insert(param_names_.begin(), params_no_default.begin(), |
575 | params_no_default.end()); |
576 | for (const auto& param : params_with_default) { |
577 | param_names_.push_back(param); |
578 | } |
579 | |
580 | string parameters; |
581 | for (const auto& param : params_no_default) { |
582 | AddDelimiter(¶meters, ", " ); |
583 | strings::StrAppend(¶meters, param.GetRenameTo()); |
584 | } |
585 | for (const auto& param_and_default : params_with_default) { |
586 | AddDelimiter(¶meters, ", " ); |
587 | strings::StrAppend(¶meters, param_and_default.GetRenameTo(), "=None" ); |
588 | } |
589 | AddDelimiter(¶meters, ", " ); |
590 | strings::StrAppend(¶meters, "name=None" ); |
591 | |
592 | AddExport(); |
593 | AddDefLine(parameters); |
594 | AddDocStringDescription(); |
595 | AddDocStringArgs(); |
596 | AddDocStringInputs(); |
597 | AddDocStringAttrs(); |
598 | AddDocStringNameArg(); |
599 | AddOutputGlobals(); |
600 | AddDocStringOutputs(); |
601 | strings::StrAppend(&result_, " \"\"\"\n" ); |
602 | AddBody(" " ); |
603 | strings::StrAppend(&result_, "\n\n" ); |
604 | |
605 | return prelude_ + result_; |
606 | } |
607 | |
608 | void GenPythonOp::AddExport() { |
609 | if (api_def_.visibility() != ApiDef::VISIBLE) { |
610 | return; |
611 | } |
612 | // Whether op should be available in latest export version. |
613 | bool op_available_in_latest = |
614 | !api_def_.deprecation_version() || |
615 | api_def_.deprecation_version() > kLatestAPIExportVersion; |
616 | |
617 | string names; |
618 | string names_v1; |
619 | string deprecated_endpoints; |
620 | |
621 | for (const auto& endpoint : api_def_.endpoint()) { |
622 | string endpoint_name; |
623 | python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(), |
624 | &endpoint_name); |
625 | if (endpoint.deprecated() || endpoint.deprecation_version() > 0) { |
626 | AddDelimiter(&deprecated_endpoints, ", " ); |
627 | strings::StrAppend(&deprecated_endpoints, "'" , endpoint_name, "'" ); |
628 | } |
629 | // Add all endpoints to TensorFlow 1.* API. |
630 | AddDelimiter(&names_v1, ", " ); |
631 | strings::StrAppend(&names_v1, "'" , endpoint_name, "'" ); |
632 | // Add non-deprecated endpoints to TensorFlow 2.* API. |
633 | if (op_available_in_latest && |
634 | (!endpoint.deprecation_version() || |
635 | endpoint.deprecation_version() > kLatestAPIExportVersion)) { |
636 | AddDelimiter(&names, ", " ); |
637 | strings::StrAppend(&names, "'" , endpoint_name, "'" ); |
638 | } |
639 | } |
640 | |
641 | // tf_export decorator has the following format: |
642 | // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name]) |
643 | if (names != names_v1) { |
644 | AddDelimiter(&names, ", " ); |
645 | strings::StrAppend(&names, "v1=[" , names_v1, "]" ); |
646 | } |
647 | strings::StrAppend(&result_, "@tf_export(" , names, ")\n" ); |
648 | |
649 | // If all endpoints are deprecated, add @deprecated decorator. |
650 | if (!api_def_.deprecation_message().empty()) { |
651 | const string instructions = api_def_.deprecation_message(); |
652 | strings::StrAppend(&result_, "@deprecated(None, '" , instructions, "')\n" ); |
653 | } |
654 | // Add @deprecated_endpoints decorator. |
655 | if (!deprecated_endpoints.empty()) { |
656 | strings::StrAppend(&result_, "@deprecated_endpoints(" , deprecated_endpoints, |
657 | ")\n" ); |
658 | } |
659 | } |
660 | |
661 | void GenPythonOp::AddDefLine(const string& function_name, |
662 | const string& parameters) { |
663 | strings::StrAppend(&result_, "def " , function_name, "(" , parameters, "):\n" ); |
664 | } |
665 | |
// Appends the def line using the function name supplied at construction.
void GenPythonOp::AddDefLine(const string& parameters) {
  AddDefLine(function_name_, parameters);
}
669 | |
670 | void GenPythonOp::AddDocStringDescription() { |
671 | string ; |
672 | if (api_def_.summary().empty()) { |
673 | comment = "TODO: add doc.\n" ; |
674 | } else { |
675 | comment = strings::StrCat(api_def_.summary(), "\n" ); |
676 | if (!api_def_.description().empty()) { |
677 | strings::StrAppend(&comment, "\n" , Indent(2, 2, api_def_.description())); |
678 | } |
679 | } |
680 | strings::StrAppend(&result_, " r\"\"\"" , comment, "\n" ); |
681 | } |
682 | |
683 | void GenPythonOp::AddDocStringArgs() { |
684 | strings::StrAppend(&result_, " Args:\n" ); |
685 | } |
686 | |
687 | void GenPythonOp::AddDocStringInputs() { |
688 | for (int i = 0; i < api_def_.arg_order_size(); ++i) { |
689 | const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); |
690 | const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); |
691 | StringPiece description = api_def_arg.description(); |
692 | string desc; |
693 | if (ConsumeEquals(&description)) { // Skip the generated type info. |
694 | desc = strings::StrCat(param_names_[i].GetRenameTo(), ": " ); |
695 | } else { |
696 | desc = strings::StrCat(param_names_[i].GetRenameTo(), ": " , |
697 | ArgTypeName(op_def_, arg, inferred_attrs_, false)); |
698 | } |
699 | if (!description.empty()) { |
700 | AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); |
701 | } |
702 | strings::StrAppend(&result_, Indent(4, 6, desc)); |
703 | } |
704 | } |
705 | |
// Appends one docstring entry per non-inferred attr (attrs_, in signature
// order). Each entry reads "<name>: A[n] [optional ]<type>[ from: `...`]
// [ that is/has length `>= N`]. [Defaults to `...`.] <description>".
// NOTE(review): FindAttr() results are dereferenced unchecked; this assumes
// every name in attrs_ exists in both op_def_ and api_def_.
void GenPythonOp::AddDocStringAttrs() {
  for (const string& name : attrs_) {
    const auto& attr = *FindAttr(name, op_def_);
    const auto& api_def_attr = *FindAttr(name, api_def_);
    string desc =
        strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");

    // Maps an OpDef attr type to its docstring phrasing. Attr types not
    // listed here get no type phrase.
    static const char* const kAttrTypeName[][2] = {
        {"string", "`string`"},
        {"list(string)", "list of `strings`"},
        {"int", "`int`"},
        {"list(int)", "list of `ints`"},
        {"float", "`float`"},
        {"list(float)", "list of `floats`"},
        {"bool", "`bool`"},
        {"list(bool)", "list of `bools`"},
        {"type", "`tf.DType`"},
        {"list(type)", "list of `tf.DTypes`"},
        {"shape", "`tf.TensorShape` or list of `ints`"},
        {"list(shape)",
         "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
        {"tensor", "`tf.TensorProto`"},
        {"list(tensor)", "list of `tf.TensorProto` objects"},
        {"func", "function decorated with @Defun"},
        {"list(func)", "list of functions decorated with @Defun"},
    };
    for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
      if (attr.type() == kAttrTypeName[i][0]) {
        string s;
        if (api_def_attr.has_default_value()) {
          s = strings::StrCat("optional ", kAttrTypeName[i][1]);
        } else {
          s = kAttrTypeName[i][1];
        }
        // Article: "An" before "optional ..." and before backquoted words
        // starting with 'i' or 'o' (e.g. `int`); "A" otherwise.
        if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
          strings::StrAppend(&desc, "An ", s);
        } else {
          strings::StrAppend(&desc, "A ", s);
        }
        break;
      }
    }

    if (attr.has_allowed_values()) {
      strings::StrAppend(&desc, " from: `",
                         AttrListToPython(attr.allowed_values()), "`");
    }

    // For int attrs the minimum bounds the value; for other (list) attrs it
    // bounds the length and is only mentioned when positive.
    if (attr.has_minimum()) {
      if (attr.type() == "int") {
        strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
      } else if (attr.minimum() > 0) {
        strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
      }
    }

    strings::StrAppend(&desc, ".");

    if (api_def_attr.has_default_value()) {
      strings::StrAppend(
          &desc, " Defaults to `",
          AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
    }
    if (!api_def_attr.description().empty()) {
      AppendWithinWidth(&desc, api_def_attr.description(),
                        kRightMargin - 4 /* indent */);
    }
    strings::StrAppend(&result_, Indent(4, 6, desc));
  }
}
776 | |
777 | void GenPythonOp::AddDocStringNameArg() { |
778 | strings::StrAppend(&result_, |
779 | " name: A name for the operation (optional).\n" ); |
780 | } |
781 | |
782 | void GenPythonOp::AddOutputGlobals() { |
783 | // Generate a namedtuple class to hold the outputs, if there are multiple. |
784 | // Example: |
785 | // |
786 | // _OpOutputs = collections.namedtuple( |
787 | // "_OpOutputs", |
788 | // "out1 out2 out3") |
789 | if (num_outs_ > 1) { |
790 | std::vector<string> out_names; |
791 | out_names.reserve(num_outs_); |
792 | for (int i = 0; i < num_outs_; ++i) { |
793 | const string out_name = !api_def_.out_arg(i).rename_to().empty() |
794 | ? api_def_.out_arg(i).rename_to() |
795 | : strings::StrCat("output" , i); |
796 | out_names.push_back(strings::StrCat("\"" , out_name, "\"" )); |
797 | } |
798 | |
799 | strings::StrAppend(&prelude_, "_" , AvoidPythonReserved(op_def_.name()), |
800 | "Output = collections.namedtuple(\n" ); |
801 | strings::StrAppend(&prelude_, " \"" , AvoidPythonReserved(op_def_.name()), |
802 | "\",\n" ); |
803 | strings::StrAppend(&prelude_, " [" , absl::StrJoin(out_names, ", " ), |
804 | "])" ); |
805 | strings::StrAppend(&prelude_, "\n\n" ); |
806 | } |
807 | strings::StrAppend(&prelude_, "\n" ); |
808 | } |
809 | |
810 | void GenPythonOp::AddDocStringOutputs() { |
811 | std::vector<string> output_type_string; |
812 | output_type_string.reserve(num_outs_); |
813 | for (int i = 0; i < num_outs_; ++i) { |
814 | output_type_string.push_back( |
815 | ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true)); |
816 | } |
817 | strings::StrAppend(&result_, GetReturns(op_def_, output_type_string)); |
818 | } |
819 | |
820 | void GenPythonOp::AddBody(const string& prefix) { |
821 | const string apply_prefix = strings::StrCat( |
822 | prefix, "_result = _op_def_lib.apply_op(\"" , op_def_.name(), "\", " ); |
823 | AddBodyNoReturn(apply_prefix); |
824 | if (num_outs_ > 1) { |
825 | strings::StrAppend(&result_, prefix, "_result = _" , |
826 | AvoidPythonReserved(op_def_.name()), |
827 | "Output._make(_result)\n" ); |
828 | } |
829 | strings::StrAppend(&result_, prefix, "return _result\n" ); |
830 | } |
831 | |
832 | void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) { |
833 | string args; |
834 | for (size_t i = 0; i < param_names_.size(); ++i) { |
835 | strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()), |
836 | "=" , param_names_[i].GetRenameTo(), ", " ); |
837 | } |
838 | strings::StrAppend(&args, "name=name)" ); |
839 | |
840 | strings::StrAppend(&result_, |
841 | // Wrap the arguments, and indent to the (. |
842 | WordWrap(apply_prefix, args, kRightMargin), "\n" ); |
843 | } |
844 | |
845 | } // namespace python_op_gen_internal |
846 | } // namespace tensorflow |
847 | |