1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ |
17 | #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ |
18 | |
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
25 | |
26 | namespace tensorflow { |
27 | namespace python_op_gen_internal { |
28 | |
29 | // Returns true if s is a Python keyword or built-in. |
30 | bool IsPythonReserved(const string& s); |
31 | |
32 | // Whether the op should be prefixed with underscore. |
33 | bool IsOpWithUnderscorePrefix(const string& s); |
34 | |
35 | // Add a _ to the end of s if necessary to avoid a Python keyword or built-in. |
36 | // Also convert namespace characters ('>') to '_' because python does not |
37 | // support '>' in names |
38 | string AvoidPythonReserved(const string& s); |
39 | |
40 | // Convert an AttrValue with type `type` to the Python representation for |
41 | // that value. |
42 | string AttrValueToPython(const string& type, const AttrValue& value, |
43 | const string& dtype_module = "tf." ); |
44 | |
45 | void GenerateLowerCaseOpName(const string& str, string* result); |
46 | |
47 | string DataTypeToPython(DataType dtype, const string& dtype_module); |
48 | |
49 | // Names that corresponds to a single input parameter. |
50 | class ParamNames { |
51 | public: |
52 | // Create param based on Arg. |
53 | ParamNames(const string& name, const string& rename_to) : name_(name) { |
54 | rename_to_ = AvoidPythonReserved(rename_to); |
55 | } |
56 | |
57 | // Get original parameter name. |
58 | string GetName() const { return name_; } |
59 | |
60 | // Get the name to rename the parameter to. Note that AvoidPythonReserved |
61 | // has already been applied. |
62 | string GetRenameTo() const { return rename_to_; } |
63 | |
64 | private: |
65 | // Original parameter name. |
66 | string name_; |
67 | // API name for this parameter. |
68 | string rename_to_; |
69 | }; |
70 | |
71 | class GenPythonOp { |
72 | public: |
73 | GenPythonOp(const OpDef& op_def, const ApiDef& api_def, |
74 | const string& function_name, bool add_type_annotations_); |
75 | virtual ~GenPythonOp(); |
76 | |
77 | virtual string Code(); |
78 | |
79 | protected: |
80 | // Print: def Function(parameters): |
81 | void AddDefLine(const string& function_name, const string& parameters); |
82 | void AddDefLine(const string& parameters); |
83 | |
84 | // Format the Op's descriptions so that it can be a Python docstring. |
85 | void AddDocStringDescription(); |
86 | |
87 | void AddDocStringArgs(); |
88 | void AddDocStringInputs(); |
89 | void AddDocStringAttrs(); |
90 | void AddDocStringNameArg(); |
91 | void AddOutputGlobals(); |
92 | void AddDocStringOutputs(); |
93 | void AddBody(const string& prefix); |
94 | void AddBodyNoReturn(const string& apply_prefix); |
95 | void AddExport(); |
96 | |
97 | // From constructor arguments |
98 | const OpDef& op_def_; |
99 | const ApiDef& api_def_; |
100 | const string function_name_; |
101 | bool add_type_annotations_; |
102 | const int num_outs_; |
103 | |
104 | // Return value from Code() is prelude_ + result_. |
105 | string prelude_; // Code before function definition |
106 | string result_; // Function definition |
107 | |
108 | // Map from attr name to the first input arg it is inferred from |
109 | std::unordered_map<string, string> inferred_attrs_; |
110 | |
111 | // The names of the non-inferred attrs, in parameter order |
112 | std::vector<string> attrs_; |
113 | |
114 | // All parameters, including inputs & non-inferred attrs, required and those |
115 | // with defaults, except "name" |
116 | std::vector<ParamNames> param_names_; |
117 | }; |
118 | |
119 | } // namespace python_op_gen_internal |
120 | } // namespace tensorflow |
121 | |
122 | #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ |
123 | |