1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_TOOLS_NODEGEN_NODEBUILDER_H |
17 | #define GLOW_TOOLS_NODEGEN_NODEBUILDER_H |
18 | |
19 | #include "MemberType.h" |
20 | #include "glow/Support/Support.h" |
21 | |
22 | #include "llvm/ADT/ArrayRef.h" |
23 | |
24 | #include <cassert> |
25 | #include <fstream> |
26 | #include <iostream> |
27 | #include <sstream> |
28 | #include <string> |
29 | #include <unordered_map> |
30 | #include <vector> |
31 | |
32 | class Builder; |
33 | class NodeBuilder; |
34 | |
35 | class NodeBuilder { |
36 | /// The node name. |
37 | std::string name_; |
38 | /// The node operands. |
39 | std::vector<std::string> nodeInputs_; |
40 | /// A list of node inputs that are overwritten, i.e. are @out parameters |
41 | /// essentially. |
42 | std::vector<unsigned> nodeOverwrittenInputs_; |
43 | /// Initializes the result types of the nodes. The first argument is the c++ |
44 | /// expression that computes the type. For example "X->getType()". The second |
45 | /// argument is the name of the return type. Format: (type, name) |
46 | std::vector<std::pair<std::string, std::string>> nodeOutputs_; |
47 | /// A list of node members. Format: (type, name). |
48 | std::vector<std::pair<MemberTypeInfo, std::string>> members_; |
49 | /// The node enum cases. |
50 | std::vector<std::string> enum_; |
51 | /// A list of extra parameter that are declared in the node constructor. The |
52 | /// arguments are used when creating the result types of the node. |
53 | std::vector<std::string> ctorTypeParams_; |
54 | /// Stores the decl and body of a new public method that will be added to the |
55 | /// class. |
56 | std::vector<std::pair<std::string, std::string>> ; |
57 | /// Header file stream. |
58 | std::ofstream &hStream; |
59 | /// CPP file stream. |
60 | std::ofstream &cStream; |
61 | /// Def file stream. |
62 | std::ofstream &dStream; |
63 | /// Import file stream. |
64 | std::ofstream &iStream; |
65 | /// Export file stream. |
66 | std::ofstream &eStream; |
67 | /// Documentation string printed with the class definition. |
68 | std::string docstring_; |
69 | /// Whether node has side effects. By default there are no side effects. |
70 | bool hasSideEffects_{false}; |
71 | /// Specifies if this Node is backend specific. |
72 | bool isBackendSpecific_{false}; |
73 | /// Specifies if this Node is data parallel. |
74 | bool isDataParallel_{false}; |
75 | /// Specifies if this Node can have extra results. |
76 | bool {false}; |
77 | /// Specifies if this Node should skip serialization autogen. |
78 | bool skipAutogenSerialization_{false}; |
79 | |
80 | public: |
81 | NodeBuilder(std::ofstream &H, std::ofstream &C, std::ofstream &D, |
82 | std::ofstream &I, std::ofstream &E, const std::string &name, |
83 | bool isBackendSpecific) |
84 | : name_(name), hStream(H), cStream(C), dStream(D), iStream(I), eStream(E), |
85 | isBackendSpecific_(isBackendSpecific) { |
86 | dStream << "DEF_NODE(" << name << "Node, " << name << ")\n" ; |
87 | } |
88 | |
89 | /// Add an operand to the node. The name should start with a capital letter. |
90 | /// For example: "Input". |
91 | NodeBuilder &addInput(const std::string &op) { |
92 | nodeInputs_.push_back(op); |
93 | return *this; |
94 | } |
95 | /// Add a member to the node. Format: type, name. |
96 | /// The name should start with a capital letter. |
97 | /// For example: "Filter". |
98 | /// If \p addSetter then a setter will be generated for this member. |
99 | NodeBuilder &addMember(MemberType type, const std::string &name, |
100 | bool addSetter = false); |
101 | /// Add a member to the node. Format type, name. |
102 | /// The name should start with a capital letter. |
103 | /// For example: "Filter". |
104 | /// If MemberTypeInfo refers to an external user-defined type, this type T |
105 | /// should satisfy the following requirements: |
106 | /// * There should be a hash function with a signature like `llvm::hash_code |
107 | /// hash_value(const T)` which takes T by value, by reference or as a |
108 | /// pointer, depending on the intended use. |
109 | /// * There should be a stream output operator with a signature like |
110 | /// `llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T);`, which |
111 | /// takes T by value, by reference or as a pointer, depending on the |
112 | /// intended use. |
113 | /// * There should be a comparison operator `bool operator==(const T LHS, |
114 | /// const T RHS)` (or a custom comparator function mentioned in |
115 | /// MemberTypeInfo::cmpFn), which takes T by reference or by value depending |
116 | /// on the intended use. |
117 | /// If \p addSetter then a setter will be generated for this member. |
118 | NodeBuilder &addMember(MemberTypeInfo typeInfo, const std::string &name, |
119 | bool addSetter = false) { |
120 | typeInfo.addSetter = addSetter; |
121 | members_.push_back({typeInfo, name}); |
122 | return *this; |
123 | } |
124 | |
125 | /// Adds the body of a new public method to the class. \p decl is the |
126 | /// decleration that goes in the header file. \p body is the implementation |
127 | /// that goes in the cpp file. |
128 | NodeBuilder &(const std::string &decl, |
129 | const std::string &body) { |
130 | extraMethods_.push_back(std::make_pair(decl, body)); |
131 | return *this; |
132 | } |
133 | |
134 | /// Add an field to the enum. The enum name should start with a capital |
135 | /// letter. For example: "External". |
136 | NodeBuilder &addEnumCase(const std::string &op) { |
137 | enum_.push_back(op); |
138 | return *this; |
139 | } |
140 | /// Set the expression that initializes a new return type for the node. |
141 | /// Example: 'LHS->getType()', "Result". |
142 | NodeBuilder &addResult(const std::string &ty, |
143 | const std::string &name = "Result" ) { |
144 | nodeOutputs_.push_back({ty, name}); |
145 | return *this; |
146 | } |
147 | /// Add a TypeRef parameter to the constructor and use this argument to add |
148 | /// a result type to the node. |
149 | NodeBuilder &addResultFromCtorArg(const std::string &name = "Result" ) { |
150 | ctorTypeParams_.push_back(name); |
151 | nodeOutputs_.push_back({name, name}); |
152 | return *this; |
153 | } |
154 | |
155 | /// Set the documentation string. Each line will be prepended with "/// ". |
156 | NodeBuilder &setDocstring(const std::string &docstring) { |
157 | docstring_ = docstring; |
158 | return *this; |
159 | } |
160 | |
161 | /// Set whether node has side effects. |
162 | NodeBuilder &setHasSideEffects(bool hasSideEffects) { |
163 | hasSideEffects_ = hasSideEffects; |
164 | return *this; |
165 | } |
166 | |
167 | NodeBuilder &addOverwrittenInput(const std::string &name) { |
168 | // Find the index of the overwritten input. |
169 | for (unsigned idx = 0, e = nodeInputs_.size(); idx < e; ++idx) { |
170 | if (nodeInputs_[idx] == name) { |
171 | nodeOverwrittenInputs_.push_back(idx); |
172 | return *this; |
173 | } |
174 | } |
175 | llvm_unreachable("Cannot register an overwritten input that is not a known " |
176 | "input of a node" ); |
177 | } |
178 | |
179 | NodeBuilder &dataParallel() { |
180 | isDataParallel_ = true; |
181 | return *this; |
182 | } |
183 | |
184 | NodeBuilder &() { |
185 | hasExtraResults_ = true; |
186 | return *this; |
187 | } |
188 | |
189 | NodeBuilder &skipAutogenSerialization() { |
190 | skipAutogenSerialization_ = true; |
191 | return *this; |
192 | } |
193 | |
194 | /// Constructs a new gradient node that is based on the current node that we |
195 | /// are building. The gradient node will produce one gradient output for each |
196 | /// input. The rule is that each output becomes an input (named "Output", to |
197 | /// preserve the original name) and each input becomes a gradient output with |
198 | /// the same name. |
199 | NodeBuilder &addGradient(); |
200 | |
201 | /// Helper to add a FusedActivation Member to this node, along with getters |
202 | /// and setters. |
203 | NodeBuilder &addFusedActivation(); |
204 | |
205 | ~NodeBuilder(); |
206 | |
207 | private: |
208 | /// Emit required forward declarations for node members. |
209 | void emitMemberForwardDecls(std::ostream &os) const; |
210 | |
211 | /// Emits the methods that converts an enum case into a textual label. |
212 | void emitEnumModePrinters(std::ostream &os) const; |
213 | |
214 | /// Emit the Node class constructor. |
215 | void emitCtor(std::ostream &os) const; |
216 | |
217 | /// Emits the class members (the fields of the class). |
218 | void emitClassMembers(std::ostream &os) const; |
219 | |
220 | /// Emit the getter for a accessible class member, and optionally a setter. |
221 | void emitMemberGetterSetter(std::ostream &os, const MemberTypeInfo *type, |
222 | const std::string &name) const; |
223 | |
224 | /// Emit setters/getters for each accessible class member. |
225 | void emitSettersGetters(std::ostream &os) const; |
226 | |
227 | /// Emit getters for input/output names and input nodes. |
228 | void emitEdges(std::ostream &os) const; |
229 | |
230 | /// Emit the methods that print a textual summary of the node. |
231 | void emitPrettyPrinter(std::ostream &os) const; |
232 | |
233 | /// Emit the isEqual method that performs node comparisons. |
234 | void emitEquator(std::ostream &os) const; |
235 | |
236 | /// Emit the clone() method copies the node. |
237 | void emitCloner(std::ostream &os) const; |
238 | |
239 | /// Emit the getHash method that computes a hash of a node. |
240 | void emitHasher(std::ostream &os) const; |
241 | |
242 | /// Emit the 'visit' method that implements node visitors. |
243 | void emitVisitor(std::ostream &os) const; |
244 | |
245 | /// Emit the class-level documentation string, if any. |
246 | void emitDocstring(std::ostream &os) const; |
247 | |
248 | /// Emit enums for each of the node's inputs and results indices. |
249 | void emitIndicesEnum(std::ostream &os) const; |
250 | |
251 | /// Emit the class definition for the node. |
252 | void emitNodeClass(std::ostream &os) const; |
253 | |
254 | /// Emit the methods that go into the CPP file and implement the methods that |
255 | /// were declared in the header file. |
256 | void emitCppMethods(std::ostream &os) const; |
257 | |
258 | /// Emit cases for importing to \p os. |
259 | void emitImportMethods(std::ostream &os) const; |
260 | |
261 | /// Emit cases for exporting to \p os. |
262 | void emitExportMethods(std::ostream &os) const; |
263 | |
264 | // \returns whether \p res is contained in \ref ctorTypeParams_. |
265 | bool hasCtorTypeParams(llvm::StringRef res) const; |
266 | }; |
267 | |
268 | class Builder { |
269 | std::ofstream &hStream; |
270 | std::ofstream &cStream; |
271 | std::ofstream &dStream; |
272 | std::ofstream &iStream; |
273 | std::ofstream &eStream; |
274 | |
275 | public: |
276 | /// Create a new top-level builder that holds the three output streams that |
277 | /// point to the header file, cpp file and enum definition file. |
278 | Builder(std::ofstream &H, std::ofstream &C, std::ofstream &D, |
279 | std::ofstream &I, std::ofstream &E) |
280 | : hStream(H), cStream(C), dStream(D), iStream(I), eStream(E) { |
281 | cStream << "#include \"glow/Graph/Nodes.h\"\n" |
282 | "#include \"glow/Base/Type.h\"\n" |
283 | "#include \"glow/Support/Support.h\"\n" |
284 | "using namespace glow;\n" ; |
285 | dStream << "#ifndef DEF_NODE\n#error The macro DEF_NODE was not declared.\n" |
286 | "#endif\n" ; |
287 | } |
288 | |
289 | ~Builder() { dStream << "#undef DEF_NODE" ; } |
290 | |
291 | /// Declare a new node and generate code for it. |
292 | NodeBuilder newNode(const std::string &name) { |
293 | const bool isBackendSpecific = false; |
294 | return NodeBuilder(hStream, cStream, dStream, iStream, eStream, name, |
295 | isBackendSpecific); |
296 | } |
297 | |
298 | /// Declare a new backend specific node and generate code for it. |
299 | NodeBuilder newBackendSpecificNode(const std::string &name) { |
300 | const bool isBackendSpecific = true; |
301 | return NodeBuilder(hStream, cStream, dStream, iStream, eStream, name, |
302 | isBackendSpecific); |
303 | } |
304 | |
305 | /// Declare the node in the def file but don't generate code for it. |
306 | void declareNode(const std::string &name) { |
307 | dStream << "DEF_NODE(" << name << ", " << name << ")\n" ; |
308 | } |
309 | |
310 | /// Include backend-specific verification at the end of the auto-generated |
311 | /// Nodes cpp file. |
312 | void includeBackendSpecificVerification(const std::string &filename) { |
313 | cStream << "\n#include \"" << filename << "\"\n" ; |
314 | } |
315 | |
316 | /// Include header into the auto-generated Nodes include file. |
317 | void (const std::string &filename) { |
318 | hStream << "\n#include \"" << filename << "\"\n" ; |
319 | } |
320 | }; |
321 | |
322 | #endif // GLOW_TOOLS_NODEGEN_NODEBUILDER_H |
323 | |