1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_TOOLS_NODEGEN_NODEBUILDER_H
17#define GLOW_TOOLS_NODEGEN_NODEBUILDER_H
18
#include "MemberType.h"
#include "glow/Support/Support.h"

#include "llvm/ADT/ArrayRef.h"

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
31
// Forward declaration of the top-level driver; its full definition appears
// after NodeBuilder in this header.
class Builder;
34
35class NodeBuilder {
36 /// The node name.
37 std::string name_;
38 /// The node operands.
39 std::vector<std::string> nodeInputs_;
40 /// A list of node inputs that are overwritten, i.e. are @out parameters
41 /// essentially.
42 std::vector<unsigned> nodeOverwrittenInputs_;
43 /// Initializes the result types of the nodes. The first argument is the c++
44 /// expression that computes the type. For example "X->getType()". The second
45 /// argument is the name of the return type. Format: (type, name)
46 std::vector<std::pair<std::string, std::string>> nodeOutputs_;
47 /// A list of node members. Format: (type, name).
48 std::vector<std::pair<MemberTypeInfo, std::string>> members_;
49 /// The node enum cases.
50 std::vector<std::string> enum_;
51 /// A list of extra parameter that are declared in the node constructor. The
52 /// arguments are used when creating the result types of the node.
53 std::vector<std::string> ctorTypeParams_;
54 /// Stores the decl and body of a new public method that will be added to the
55 /// class.
56 std::vector<std::pair<std::string, std::string>> extraMethods_;
57 /// Header file stream.
58 std::ofstream &hStream;
59 /// CPP file stream.
60 std::ofstream &cStream;
61 /// Def file stream.
62 std::ofstream &dStream;
63 /// Import file stream.
64 std::ofstream &iStream;
65 /// Export file stream.
66 std::ofstream &eStream;
67 /// Documentation string printed with the class definition.
68 std::string docstring_;
69 /// Whether node has side effects. By default there are no side effects.
70 bool hasSideEffects_{false};
71 /// Specifies if this Node is backend specific.
72 bool isBackendSpecific_{false};
73 /// Specifies if this Node is data parallel.
74 bool isDataParallel_{false};
75 /// Specifies if this Node can have extra results.
76 bool hasExtraResults_{false};
77 /// Specifies if this Node should skip serialization autogen.
78 bool skipAutogenSerialization_{false};
79
80public:
81 NodeBuilder(std::ofstream &H, std::ofstream &C, std::ofstream &D,
82 std::ofstream &I, std::ofstream &E, const std::string &name,
83 bool isBackendSpecific)
84 : name_(name), hStream(H), cStream(C), dStream(D), iStream(I), eStream(E),
85 isBackendSpecific_(isBackendSpecific) {
86 dStream << "DEF_NODE(" << name << "Node, " << name << ")\n";
87 }
88
89 /// Add an operand to the node. The name should start with a capital letter.
90 /// For example: "Input".
91 NodeBuilder &addInput(const std::string &op) {
92 nodeInputs_.push_back(op);
93 return *this;
94 }
95 /// Add a member to the node. Format: type, name.
96 /// The name should start with a capital letter.
97 /// For example: "Filter".
98 /// If \p addSetter then a setter will be generated for this member.
99 NodeBuilder &addMember(MemberType type, const std::string &name,
100 bool addSetter = false);
101 /// Add a member to the node. Format type, name.
102 /// The name should start with a capital letter.
103 /// For example: "Filter".
104 /// If MemberTypeInfo refers to an external user-defined type, this type T
105 /// should satisfy the following requirements:
106 /// * There should be a hash function with a signature like `llvm::hash_code
107 /// hash_value(const T)` which takes T by value, by reference or as a
108 /// pointer, depending on the intended use.
109 /// * There should be a stream output operator with a signature like
110 /// `llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T);`, which
111 /// takes T by value, by reference or as a pointer, depending on the
112 /// intended use.
113 /// * There should be a comparison operator `bool operator==(const T LHS,
114 /// const T RHS)` (or a custom comparator function mentioned in
115 /// MemberTypeInfo::cmpFn), which takes T by reference or by value depending
116 /// on the intended use.
117 /// If \p addSetter then a setter will be generated for this member.
118 NodeBuilder &addMember(MemberTypeInfo typeInfo, const std::string &name,
119 bool addSetter = false) {
120 typeInfo.addSetter = addSetter;
121 members_.push_back({typeInfo, name});
122 return *this;
123 }
124
125 /// Adds the body of a new public method to the class. \p decl is the
126 /// decleration that goes in the header file. \p body is the implementation
127 /// that goes in the cpp file.
128 NodeBuilder &addExtraMethod(const std::string &decl,
129 const std::string &body) {
130 extraMethods_.push_back(std::make_pair(decl, body));
131 return *this;
132 }
133
134 /// Add an field to the enum. The enum name should start with a capital
135 /// letter. For example: "External".
136 NodeBuilder &addEnumCase(const std::string &op) {
137 enum_.push_back(op);
138 return *this;
139 }
140 /// Set the expression that initializes a new return type for the node.
141 /// Example: 'LHS->getType()', "Result".
142 NodeBuilder &addResult(const std::string &ty,
143 const std::string &name = "Result") {
144 nodeOutputs_.push_back({ty, name});
145 return *this;
146 }
147 /// Add a TypeRef parameter to the constructor and use this argument to add
148 /// a result type to the node.
149 NodeBuilder &addResultFromCtorArg(const std::string &name = "Result") {
150 ctorTypeParams_.push_back(name);
151 nodeOutputs_.push_back({name, name});
152 return *this;
153 }
154
155 /// Set the documentation string. Each line will be prepended with "/// ".
156 NodeBuilder &setDocstring(const std::string &docstring) {
157 docstring_ = docstring;
158 return *this;
159 }
160
161 /// Set whether node has side effects.
162 NodeBuilder &setHasSideEffects(bool hasSideEffects) {
163 hasSideEffects_ = hasSideEffects;
164 return *this;
165 }
166
167 NodeBuilder &addOverwrittenInput(const std::string &name) {
168 // Find the index of the overwritten input.
169 for (unsigned idx = 0, e = nodeInputs_.size(); idx < e; ++idx) {
170 if (nodeInputs_[idx] == name) {
171 nodeOverwrittenInputs_.push_back(idx);
172 return *this;
173 }
174 }
175 llvm_unreachable("Cannot register an overwritten input that is not a known "
176 "input of a node");
177 }
178
179 NodeBuilder &dataParallel() {
180 isDataParallel_ = true;
181 return *this;
182 }
183
184 NodeBuilder &hasExtraResults() {
185 hasExtraResults_ = true;
186 return *this;
187 }
188
189 NodeBuilder &skipAutogenSerialization() {
190 skipAutogenSerialization_ = true;
191 return *this;
192 }
193
194 /// Constructs a new gradient node that is based on the current node that we
195 /// are building. The gradient node will produce one gradient output for each
196 /// input. The rule is that each output becomes an input (named "Output", to
197 /// preserve the original name) and each input becomes a gradient output with
198 /// the same name.
199 NodeBuilder &addGradient();
200
201 /// Helper to add a FusedActivation Member to this node, along with getters
202 /// and setters.
203 NodeBuilder &addFusedActivation();
204
205 ~NodeBuilder();
206
207private:
208 /// Emit required forward declarations for node members.
209 void emitMemberForwardDecls(std::ostream &os) const;
210
211 /// Emits the methods that converts an enum case into a textual label.
212 void emitEnumModePrinters(std::ostream &os) const;
213
214 /// Emit the Node class constructor.
215 void emitCtor(std::ostream &os) const;
216
217 /// Emits the class members (the fields of the class).
218 void emitClassMembers(std::ostream &os) const;
219
220 /// Emit the getter for a accessible class member, and optionally a setter.
221 void emitMemberGetterSetter(std::ostream &os, const MemberTypeInfo *type,
222 const std::string &name) const;
223
224 /// Emit setters/getters for each accessible class member.
225 void emitSettersGetters(std::ostream &os) const;
226
227 /// Emit getters for input/output names and input nodes.
228 void emitEdges(std::ostream &os) const;
229
230 /// Emit the methods that print a textual summary of the node.
231 void emitPrettyPrinter(std::ostream &os) const;
232
233 /// Emit the isEqual method that performs node comparisons.
234 void emitEquator(std::ostream &os) const;
235
236 /// Emit the clone() method copies the node.
237 void emitCloner(std::ostream &os) const;
238
239 /// Emit the getHash method that computes a hash of a node.
240 void emitHasher(std::ostream &os) const;
241
242 /// Emit the 'visit' method that implements node visitors.
243 void emitVisitor(std::ostream &os) const;
244
245 /// Emit the class-level documentation string, if any.
246 void emitDocstring(std::ostream &os) const;
247
248 /// Emit enums for each of the node's inputs and results indices.
249 void emitIndicesEnum(std::ostream &os) const;
250
251 /// Emit the class definition for the node.
252 void emitNodeClass(std::ostream &os) const;
253
254 /// Emit the methods that go into the CPP file and implement the methods that
255 /// were declared in the header file.
256 void emitCppMethods(std::ostream &os) const;
257
258 /// Emit cases for importing to \p os.
259 void emitImportMethods(std::ostream &os) const;
260
261 /// Emit cases for exporting to \p os.
262 void emitExportMethods(std::ostream &os) const;
263
264 // \returns whether \p res is contained in \ref ctorTypeParams_.
265 bool hasCtorTypeParams(llvm::StringRef res) const;
266};
267
268class Builder {
269 std::ofstream &hStream;
270 std::ofstream &cStream;
271 std::ofstream &dStream;
272 std::ofstream &iStream;
273 std::ofstream &eStream;
274
275public:
276 /// Create a new top-level builder that holds the three output streams that
277 /// point to the header file, cpp file and enum definition file.
278 Builder(std::ofstream &H, std::ofstream &C, std::ofstream &D,
279 std::ofstream &I, std::ofstream &E)
280 : hStream(H), cStream(C), dStream(D), iStream(I), eStream(E) {
281 cStream << "#include \"glow/Graph/Nodes.h\"\n"
282 "#include \"glow/Base/Type.h\"\n"
283 "#include \"glow/Support/Support.h\"\n"
284 "using namespace glow;\n";
285 dStream << "#ifndef DEF_NODE\n#error The macro DEF_NODE was not declared.\n"
286 "#endif\n";
287 }
288
289 ~Builder() { dStream << "#undef DEF_NODE"; }
290
291 /// Declare a new node and generate code for it.
292 NodeBuilder newNode(const std::string &name) {
293 const bool isBackendSpecific = false;
294 return NodeBuilder(hStream, cStream, dStream, iStream, eStream, name,
295 isBackendSpecific);
296 }
297
298 /// Declare a new backend specific node and generate code for it.
299 NodeBuilder newBackendSpecificNode(const std::string &name) {
300 const bool isBackendSpecific = true;
301 return NodeBuilder(hStream, cStream, dStream, iStream, eStream, name,
302 isBackendSpecific);
303 }
304
305 /// Declare the node in the def file but don't generate code for it.
306 void declareNode(const std::string &name) {
307 dStream << "DEF_NODE(" << name << ", " << name << ")\n";
308 }
309
310 /// Include backend-specific verification at the end of the auto-generated
311 /// Nodes cpp file.
312 void includeBackendSpecificVerification(const std::string &filename) {
313 cStream << "\n#include \"" << filename << "\"\n";
314 }
315
316 /// Include header into the auto-generated Nodes include file.
317 void includeHeader(const std::string &filename) {
318 hStream << "\n#include \"" << filename << "\"\n";
319 }
320};
321
322#endif // GLOW_TOOLS_NODEGEN_NODEBUILDER_H
323