1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
17#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
18
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27namespace python_op_gen_internal {
28
29// Returns true if s is a Python keyword or built-in.
30bool IsPythonReserved(const string& s);
31
32// Whether the op should be prefixed with underscore.
33bool IsOpWithUnderscorePrefix(const string& s);
34
35// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
36// Also convert namespace characters ('>') to '_' because python does not
37// support '>' in names
38string AvoidPythonReserved(const string& s);
39
40// Convert an AttrValue with type `type` to the Python representation for
41// that value.
42string AttrValueToPython(const string& type, const AttrValue& value,
43 const string& dtype_module = "tf.");
44
45void GenerateLowerCaseOpName(const string& str, string* result);
46
47string DataTypeToPython(DataType dtype, const string& dtype_module);
48
49// Names that corresponds to a single input parameter.
50class ParamNames {
51 public:
52 // Create param based on Arg.
53 ParamNames(const string& name, const string& rename_to) : name_(name) {
54 rename_to_ = AvoidPythonReserved(rename_to);
55 }
56
57 // Get original parameter name.
58 string GetName() const { return name_; }
59
60 // Get the name to rename the parameter to. Note that AvoidPythonReserved
61 // has already been applied.
62 string GetRenameTo() const { return rename_to_; }
63
64 private:
65 // Original parameter name.
66 string name_;
67 // API name for this parameter.
68 string rename_to_;
69};
70
71class GenPythonOp {
72 public:
73 GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
74 const string& function_name, bool add_type_annotations_);
75 virtual ~GenPythonOp();
76
77 virtual string Code();
78
79 protected:
80 // Print: def Function(parameters):
81 void AddDefLine(const string& function_name, const string& parameters);
82 void AddDefLine(const string& parameters);
83
84 // Format the Op's descriptions so that it can be a Python docstring.
85 void AddDocStringDescription();
86
87 void AddDocStringArgs();
88 void AddDocStringInputs();
89 void AddDocStringAttrs();
90 void AddDocStringNameArg();
91 void AddOutputGlobals();
92 void AddDocStringOutputs();
93 void AddBody(const string& prefix);
94 void AddBodyNoReturn(const string& apply_prefix);
95 void AddExport();
96
97 // From constructor arguments
98 const OpDef& op_def_;
99 const ApiDef& api_def_;
100 const string function_name_;
101 bool add_type_annotations_;
102 const int num_outs_;
103
104 // Return value from Code() is prelude_ + result_.
105 string prelude_; // Code before function definition
106 string result_; // Function definition
107
108 // Map from attr name to the first input arg it is inferred from
109 std::unordered_map<string, string> inferred_attrs_;
110
111 // The names of the non-inferred attrs, in parameter order
112 std::vector<string> attrs_;
113
114 // All parameters, including inputs & non-inferred attrs, required and those
115 // with defaults, except "name"
116 std::vector<ParamNames> param_names_;
117};
118
119} // namespace python_op_gen_internal
120} // namespace tensorflow
121
122#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
123