1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "NodeBuilder.h" |
18 | |
19 | NodeBuilder &NodeBuilder::addMember(MemberType type, const std::string &name, |
20 | bool addSetter) { |
21 | MemberTypeInfo *typeInfo = nullptr; |
22 | |
23 | if (type == MemberType::TypeRef) { |
24 | typeInfo = &kTypeRefTypeInfo; |
25 | } else if (type == MemberType::Float) { |
26 | typeInfo = &kFloatTypeInfo; |
27 | } else if (type == MemberType::Unsigned) { |
28 | typeInfo = &kUnsignedTypeInfo; |
29 | } else if (type == MemberType::Boolean) { |
30 | typeInfo = &kBooleanTypeInfo; |
31 | } else if (type == MemberType::Int64) { |
32 | typeInfo = &kInt64TypeInfo; |
33 | } else if (type == MemberType::String) { |
34 | typeInfo = &kStringTypeInfo; |
35 | } else if (type == MemberType::VectorFloat) { |
36 | typeInfo = &kVectorFloatTypeInfo; |
37 | } else if (type == MemberType::VectorUnsigned) { |
38 | typeInfo = &kVectorUnsignedTypeInfo; |
39 | } else if (type == MemberType::VectorInt64) { |
40 | typeInfo = &kVectorInt64TypeInfo; |
41 | } else if (type == MemberType::VectorSigned) { |
42 | typeInfo = &kVectorSignedTypeInfo; |
43 | } else if (type == MemberType::VectorSizeT) { |
44 | typeInfo = &kVectorSizeTTypeInfo; |
45 | } else if (type == MemberType::VectorDimT) { |
46 | typeInfo = &kVectorDimTTypeInfo; |
47 | } else if (type == MemberType::VectorNodeValue) { |
48 | typeInfo = &kVectorNodeValueTypeInfo; |
49 | } else if (type == MemberType::Enum) { |
50 | typeInfo = &kEnumTypeInfo; |
51 | } else if (type == MemberType::UserDefinedType) { |
52 | llvm_unreachable("addMember should be called with a MemberTypeInfo " |
53 | "parameter in this case" ); |
54 | } else { |
55 | llvm_unreachable("Type not recognized" ); |
56 | } |
57 | |
58 | return addMember(*typeInfo, name, addSetter); |
59 | } |
60 | |
61 | NodeBuilder &NodeBuilder::addFusedActivation() { |
62 | return addMember(MEMBER_TYPE_INFO(glow::FusedActivation), "FusedActivation" , |
63 | /* addSetter */ true) |
64 | .addMember(MemberType::VectorFloat, "FusedActivationArgs" , |
65 | /* addSetter */ true) |
66 | .addExtraMethod("bool hasFusedActivation() const;" , |
67 | "bool " + name_ + |
68 | "Node::hasFusedActivation() const { return " |
69 | "getFusedActivation() != FusedActivation::NONE; }" ); |
70 | } |
71 | |
72 | void NodeBuilder::emitMemberForwardDecls(std::ostream &os) const { |
73 | for (const auto &mem : members_) { |
74 | const std::string &forwardDecl = (mem.first).forwardDecl; |
75 | if (!forwardDecl.empty()) { |
76 | os << forwardDecl << "\n" ; |
77 | } |
78 | } |
79 | |
80 | os << "\n" ; |
81 | } |
82 | |
83 | void NodeBuilder::emitEnumModePrinters(std::ostream &os) const { |
84 | os << "\nconst char *" << name_ << "Node::getModeStr(" << name_ |
85 | << "Node::Mode m) {\n" ; |
86 | os << " static const char *names[] = {" ; |
87 | for (const auto &e : enum_) { |
88 | os << "\"" << e << "\", " ; |
89 | } |
90 | os << "nullptr};\n" ; |
91 | os << " return names[static_cast<int>(m)];\n" ; |
92 | os << "}\n" ; |
93 | } |
94 | |
95 | void NodeBuilder::emitCtor(std::ostream &os) const { |
96 | os << " " << name_ << "Node(llvm::StringRef name" ; |
97 | |
98 | // Generate the external type parameters: |
99 | for (const auto ¶mName : ctorTypeParams_) { |
100 | os << ", TypeRef " << paramName << " " ; |
101 | } |
102 | |
103 | // The enum 'Mode' parameter: |
104 | if (!enum_.empty()) { |
105 | os << ", Mode mode" ; |
106 | } |
107 | |
108 | // The operands of the graph node: |
109 | for (const auto &op : nodeInputs_) { |
110 | os << ", NodeValue " << op; |
111 | } |
112 | |
113 | // Extra class members: |
114 | for (const auto &op : members_) { |
115 | os << ", " << getCtorArgTypename(&op.first) << " " << op.second; |
116 | } |
117 | |
118 | // Initialize the base clases: |
119 | os << ")\n : Node(Kinded::Kind::" << name_ << "NodeKind, name)" ; |
120 | |
121 | // Print the initialization list: |
122 | if (!enum_.empty()) { |
123 | os << ", mode_(mode)" ; |
124 | } |
125 | |
126 | // Initialize the operands: |
127 | for (const auto &op : nodeInputs_) { |
128 | os << ", " << op << "_(" |
129 | << "this, " << op << ")" ; |
130 | } |
131 | |
132 | // Initialize the members: |
133 | for (const auto &op : members_) { |
134 | if ((op.first).type != MemberType::VectorNodeValue) { |
135 | os << ", " << op.second << "_(" << op.second << ")" ; |
136 | continue; |
137 | } |
138 | continue; |
139 | os << ", " << op.second << "_(" << op.second << ".begin(), " << op.second |
140 | << ".end()" |
141 | << ")" ; |
142 | } |
143 | |
144 | // The constructor body: |
145 | os << " {\n" ; |
146 | for (auto &RT : nodeOutputs_) { |
147 | os << " addResult(" << RT.first << ");\n" ; |
148 | } |
149 | |
150 | for (const auto &op : members_) { |
151 | if ((op.first).type != MemberType::VectorNodeValue) { |
152 | continue; |
153 | } |
154 | os << " " << op.second << "_.resize(" << op.second << ".size());\n" ; |
155 | os << " for (size_t idx = 0, e = " << op.second |
156 | << ".size(); idx < e; ++idx) {\n" |
157 | << " " << op.second << "_[idx] = " << op.second << "[idx];\n" |
158 | << " " << op.second << "_[idx].setParent(this);\n" |
159 | << " }\n" ; |
160 | } |
161 | |
162 | os << " }\n" ; |
163 | } |
164 | |
165 | void NodeBuilder::emitClassMembers(std::ostream &os) const { |
166 | // Emit the type of the enum (which is public). |
167 | if (!enum_.empty()) { |
168 | os << " public:\n enum class Mode {\n" ; |
169 | for (const auto &E : enum_) { |
170 | os << " " << E << ",\n" ; |
171 | } |
172 | os << " };\n\n private:\n" ; |
173 | } |
174 | |
175 | // Emit class members: |
176 | if (!enum_.empty()) { |
177 | os << " Mode mode_;\n" ; |
178 | } |
179 | for (const auto &op : nodeInputs_) { |
180 | os << " NodeHandle " << op << "_;\n" ; |
181 | } |
182 | for (const auto &op : members_) { |
183 | os << " " << getStorageTypename(&op.first) << " " << op.second << "_;\n" ; |
184 | } |
185 | } |
186 | |
187 | void NodeBuilder::emitMemberGetterSetter(std::ostream &os, |
188 | const MemberTypeInfo *typeInfo, |
189 | const std::string &name) const { |
190 | // Synthesize the general getter. |
191 | auto typeStr = getReturnTypename(typeInfo); |
192 | os << " " << typeStr << " get" << name << "() const { return " << name |
193 | << "_; }\n" ; |
194 | |
195 | if (typeInfo->addSetter) { |
196 | os << " void set" << name << "(" << typeStr << " a) {" << name |
197 | << "_ = a; }\n" ; |
198 | } |
199 | } |
200 | |
201 | void NodeBuilder::emitSettersGetters(std::ostream &os) const { |
202 | // Print the getters/setters. |
203 | for (const auto &inName : nodeInputs_) { |
204 | os << " const NodeValue get" << inName << "() const { return " << inName |
205 | << "_; }\n" ; |
206 | } |
207 | |
208 | unsigned idx = 0; |
209 | for (const auto &op : nodeOutputs_) { |
210 | os << " NodeValue get" << op.second << "() { return getNthResult(" << idx |
211 | << "); }\n" ; |
212 | os << " const NodeValue get" << op.second |
213 | << "() const { return getNthResult(" << idx << "); }\n" ; |
214 | idx++; |
215 | } |
216 | |
217 | for (const auto &op : members_) { |
218 | emitMemberGetterSetter(os, &op.first, op.second); |
219 | } |
220 | |
221 | // Synthesize the 'classof' method that enables the non-rtti polymorphism. |
222 | os << "\n static bool classof(const Kinded *k) {\n" |
223 | << " return k->getKind() == Kinded::Kind::" << name_ << "NodeKind;\n" |
224 | << " }\n\n" ; |
225 | |
226 | os << "\n bool isOverwrittenNthInput(unsigned idx) const {\n" ; |
227 | for (const auto &overwrittenInput : nodeOverwrittenInputs_) { |
228 | os << " if (idx == " << overwrittenInput << ") return true;\n" ; |
229 | } |
230 | os << " return false;\n" ; |
231 | os << " }\n\n" ; |
232 | |
233 | if (!enum_.empty()) { |
234 | os << " Mode getMode() const { return mode_; }\n" ; |
235 | } |
236 | } |
237 | |
238 | void NodeBuilder::emitEdges(std::ostream &os) const { |
239 | os << "\nunsigned " << name_ << "Node::getNumInputs() const {\n" |
240 | << " return " << nodeInputs_.size(); |
241 | for (const auto &op : members_) { |
242 | if ((op.first).type != MemberType::VectorNodeValue) { |
243 | continue; |
244 | } |
245 | os << " + " << op.second << "_.size()" ; |
246 | } |
247 | os << ";\n}\n" ; |
248 | |
249 | os << "\nstd::string " << name_ |
250 | << "Node::getInputName(unsigned idx) const {\n" ; |
251 | for (size_t i = 0; i < nodeInputs_.size(); i++) { |
252 | os << " if (idx == " << i << ") { return \"" << nodeInputs_[i] |
253 | << "\"; }\n" ; |
254 | } |
255 | os << " idx -= " << nodeInputs_.size() << ";\n" ; |
256 | for (const auto &op : members_) { |
257 | if ((op.first).type != MemberType::VectorNodeValue) { |
258 | continue; |
259 | } |
260 | os << " if (idx < " << op.second << "_.size()) { return \"" << op.second |
261 | << "\" + std::to_string(idx); }\n" |
262 | << " idx -= " << op.second << "_.size();\n" ; |
263 | } |
264 | os << " llvm_unreachable(\"Invalid index\");\n}\n" ; |
265 | |
266 | os << "\nNodeValue " << name_ << "Node::getNthInput(unsigned idx) {\n" ; |
267 | for (size_t i = 0; i < nodeInputs_.size(); i++) { |
268 | os << " if (idx == " << i << ") { return " << nodeInputs_[i] << "_; }\n" ; |
269 | } |
270 | os << " idx -= " << nodeInputs_.size() << ";\n" ; |
271 | for (const auto &op : members_) { |
272 | if ((op.first).type != MemberType::VectorNodeValue) { |
273 | continue; |
274 | } |
275 | os << " if (idx < " << op.second << "_.size()) { return " << op.second |
276 | << "_[idx]; }\n idx -= " << op.second << "_.size();\n" ; |
277 | } |
278 | os << " llvm_unreachable(\"Invalid index\");\n}\n" ; |
279 | |
280 | os << "\nvoid " << name_ |
281 | << "Node::setNthInput(unsigned idx, NodeValue val) {\n" ; |
282 | for (size_t i = 0; i < nodeInputs_.size(); i++) { |
283 | os << " if (idx == " << i << ") { " << nodeInputs_[i] |
284 | << "_ = val; return; }\n" ; |
285 | } |
286 | os << " idx -= " << nodeInputs_.size() << ";\n" ; |
287 | for (const auto &op : members_) { |
288 | if ((op.first).type != MemberType::VectorNodeValue) { |
289 | continue; |
290 | } |
291 | os << " if (idx < " << op.second << "_.size()) { " << op.second |
292 | << "_[idx] = val; return; }\n idx -= " << op.second << "_.size();\n" ; |
293 | } |
294 | os << " llvm_unreachable(\"Invalid index\");\n}\n" ; |
295 | |
296 | os << "\nllvm::StringRef " << name_ |
297 | << "Node::getOutputName(unsigned idx) const {\n" ; |
298 | for (size_t i = 0; i < nodeOutputs_.size(); i++) { |
299 | os << " if (idx == " << i << ") { return \"" << nodeOutputs_[i].second |
300 | << "\"; }\n" ; |
301 | } |
302 | os << " llvm_unreachable(\"Invalid index\");\n}\n" ; |
303 | } |
304 | |
305 | void NodeBuilder::emitPrettyPrinter(std::ostream &os) const { |
306 | os << "\nstd::string " << name_ << "Node::getDebugDesc() const {\n" |
307 | << " DescriptionBuilder db(getKindName());\n" |
308 | << " db.addParam(\"Name\", separateString(getName(), 100, \"\\n\"));\n" ; |
309 | |
310 | os << " if (hasPredicate()) db.addParam(\"Predicate\", \"Yes\");\n" ; |
311 | |
312 | os << " db\n" ; |
313 | if (!enum_.empty()) { |
314 | os << " .addParam(\"Mode\", getModeStr())\n" ; |
315 | } |
316 | |
317 | // Generate description for inputs. |
318 | for (const auto &op : nodeInputs_) { |
319 | os << " .addParam(\"" << op << "\", *(get" << op << "().getType()))\n" ; |
320 | } |
321 | |
322 | for (const auto &mem : members_) { |
323 | // Don't try to print the node operands directly. |
324 | MemberType ty = (mem.first).type; |
325 | if (ty == MemberType::VectorNodeValue) { |
326 | continue; |
327 | } |
328 | |
329 | if (ty == MemberType::Enum) { |
330 | os << " .addParam(\"" << mem.second << "\", static_cast<int>(get" |
331 | << mem.second << "()))\n" ; |
332 | } else { |
333 | os << " .addParam(\"" << mem.second << "\", get" << mem.second |
334 | << "())\n" ; |
335 | } |
336 | } |
337 | os << " .addParam(\"Users\", getNumUsers());\n" ; |
338 | |
339 | for (const auto &mem : members_) { |
340 | if ((mem.first).type != MemberType::VectorNodeValue) { |
341 | continue; |
342 | } |
343 | |
344 | // Make sure that inputs are properly indexed. |
345 | os << " {\n" ; |
346 | os << " unsigned mIndex = 0;\n" ; |
347 | os << " for (const auto &II : get" << mem.second << "()) {\n" |
348 | << " db.addParam(\"" << mem.second |
349 | << "\"+std::to_string(mIndex++), *II.getType());\n" |
350 | << " }\n" |
351 | << " }\n" ; |
352 | } |
353 | |
354 | // Generate description for outputs. |
355 | for (const auto &op : nodeOutputs_) { |
356 | os << " db.addParam(\"" << op.second << "\", *(get" << op.second |
357 | << "().getType()));\n" ; |
358 | } |
359 | |
360 | os << " return db;\n}\n" ; |
361 | } |
362 | |
363 | void NodeBuilder::emitCloner(std::ostream &os) const { |
364 | os << "\nNode* " << name_ << "Node::clone() const {\n" ; |
365 | |
366 | os << " return new " << name_ << "Node(getName()" ; |
367 | |
368 | // Pass the external type arguments: |
369 | for (const auto ¶mName : ctorTypeParams_) { |
370 | os << ", get" << paramName << "().getType()" ; |
371 | } |
372 | |
373 | // The enum 'Mode' parameter: |
374 | if (!enum_.empty()) { |
375 | os << ", getMode()" ; |
376 | } |
377 | |
378 | // The operands of the graph node: |
379 | for (const auto &op : nodeInputs_) { |
380 | os << ", get" << op << "()" ; |
381 | } |
382 | |
383 | // Extra class members: |
384 | for (const auto &op : members_) { |
385 | os << ", get" << op.second << "()" ; |
386 | } |
387 | |
388 | os << ");\n}\n" ; |
389 | } |
390 | |
/// \returns true if \p c can be a part of a valid C/C++ identifier: an
/// underscore or a character for which isalnum() is true.
static bool isIdentifierChar(char c) {
  // isalnum() has undefined behavior for negative arguments other than EOF.
  // Plain char may be signed, so widen through unsigned char first.
  return c == '_' || isalnum(static_cast<unsigned char>(c));
}
393 | |
394 | void NodeBuilder::emitEquator(std::ostream &os) const { |
395 | os << "\nbool " << name_ << "Node::isEqual(const " << name_ |
396 | << "Node &other) const {\n return true" ; |
397 | |
398 | if (!enum_.empty()) { |
399 | os << " &&\n getMode() == other.getMode()" ; |
400 | } |
401 | |
402 | for (const auto &op : nodeInputs_) { |
403 | os << " &&\n " << op << "_ == other." << op << "_" ; |
404 | } |
405 | |
406 | os << " &&\n predicate_ == other.predicate_" ; |
407 | |
408 | for (const auto &mem : members_) { |
409 | // Use custom user-defined comparator functions if available. |
410 | std::string cmpFn = mem.first.cmpFn; |
411 | if (cmpFn.empty() || !isIdentifierChar(cmpFn.at(0))) { |
412 | if (cmpFn.empty()) { |
413 | // Default comparator is ==. |
414 | cmpFn = "==" ; |
415 | } |
416 | os << " &&\n " << mem.second << "_ " << cmpFn << " other." |
417 | << mem.second << "_" ; |
418 | } else { |
419 | os << " &&\n " << cmpFn << "(" << mem.second << "_, other." |
420 | << mem.second << "_)" ; |
421 | } |
422 | } |
423 | |
424 | for (int i = 0, e = nodeOutputs_.size(); i < e; i++) { |
425 | os << " &&\n getType(" << i << ") == other.getType(" << i << ")" ; |
426 | } |
427 | os << ";\n}\n" ; |
428 | } |
429 | |
430 | static bool isVectorType(MemberType ty) { |
431 | return ty == MemberType::VectorFloat || ty == MemberType::VectorNodeValue || |
432 | ty == MemberType::VectorSizeT || ty == MemberType::VectorDimT || |
433 | ty == MemberType::VectorUnsigned || ty == MemberType::VectorInt64 || |
434 | ty == MemberType::VectorSigned; |
435 | } |
436 | |
437 | static bool isFloatVectorType(MemberType ty) { |
438 | return ty == MemberType::VectorFloat; |
439 | } |
440 | |
441 | void NodeBuilder::emitHasher(std::ostream &os) const { |
442 | os << "\nllvm::hash_code " << name_ << "Node::getHash() const {\n" |
443 | << " return llvm::hash_combine(" ; |
444 | |
445 | if (enum_.empty() && nodeInputs_.empty() && members_.empty()) { |
446 | os << "0);\n }\n" ; |
447 | return; |
448 | } |
449 | |
450 | auto delim = "" ; |
451 | if (!enum_.empty()) { |
452 | os << delim << "\n getMode()" ; |
453 | delim = "," ; |
454 | } |
455 | for (const auto &mem : members_) { |
456 | auto ty = (mem.first).type; |
457 | if (ty == MemberType::Float) { |
458 | os << delim << "\n toBinary(" << mem.second << "_)" ; |
459 | } else if (isFloatVectorType(ty)) { |
460 | os << delim |
461 | << "\n [](const std::vector<float>& floatVec) -> llvm::hash_code " |
462 | "{\n std::vector<size_t> sizeVec = toBinary(floatVec);\n " |
463 | " return llvm::hash_combine_range(sizeVec.begin(), " |
464 | "sizeVec.end());\n }(" |
465 | << mem.second << "_)" ; |
466 | } else if (isVectorType(ty)) { |
467 | os << delim << "\n llvm::hash_combine_range(" << mem.second |
468 | << "_.begin(), " << mem.second << "_.end())" ; |
469 | } else if (ty == MemberType::Enum) { |
470 | os << delim << "\n static_cast<int>(" << mem.second << "_)" ; |
471 | } else { |
472 | os << delim << "\n " << mem.second << "_" ; |
473 | } |
474 | delim = "," ; |
475 | } |
476 | |
477 | for (const auto &op : nodeInputs_) { |
478 | os << delim << "\n " << op << "_" ; |
479 | delim = "," ; |
480 | } |
481 | |
482 | os << ");\n}\n" ; |
483 | } |
484 | void NodeBuilder::emitVisitor(std::ostream &os) const { |
485 | os << "\nvoid " << name_ |
486 | << "Node::visit(Node *parent, NodeWalker *visitor) {\n" |
487 | << " if (!visitor->shouldVisit(parent, this)) { return; }\n" |
488 | << " visitor->pre(parent, this);\n" |
489 | << "if (hasPredicate())\n" |
490 | << " getPredicate().getNode()->visit(this, visitor);\n" ; |
491 | |
492 | for (const auto &op : nodeInputs_) { |
493 | os << " get" << op << "().getNode()->visit(this, visitor);\n" ; |
494 | } |
495 | |
496 | for (const auto &op : members_) { |
497 | if ((op.first).type == MemberType::VectorNodeValue) { |
498 | os << " for (auto &I : " << op.second |
499 | << "_) { I.getNode()->visit(this, visitor); }\n" ; |
500 | } |
501 | } |
502 | |
503 | os << " visitor->post(parent, this);\n}\n" ; |
504 | } |
505 | |
506 | void NodeBuilder::emitDocstring(std::ostream &os) const { |
507 | std::istringstream stream(docstring_); |
508 | std::string line; |
509 | while (std::getline(stream, line)) { |
510 | os << "/// " << line << "\n" ; |
511 | } |
512 | } |
513 | |
514 | void NodeBuilder::emitIndicesEnum(std::ostream &os) const { |
515 | os << " enum InputIndices {\n" ; |
516 | for (size_t i = 0; i < nodeInputs_.size(); i++) { |
517 | os << " " ; |
518 | os << nodeInputs_[i]; |
519 | os << "Idx = " << i << ",\n" ; |
520 | } |
521 | os << " };\n\n" ; |
522 | |
523 | os << " enum ResultIndices {\n" ; |
524 | for (int i = 0, e = nodeOutputs_.size(); i < e; i++) { |
525 | os << " " ; |
526 | os << nodeOutputs_[i].second; |
527 | os << "Idx = " << i << ",\n" ; |
528 | } |
529 | os << " };\n\n" ; |
530 | } |
531 | |
532 | void NodeBuilder::emitNodeClass(std::ostream &os) const { |
533 | emitMemberForwardDecls(os); |
534 | |
535 | os << "\nnamespace glow {\n" ; |
536 | |
537 | emitDocstring(os); |
538 | |
539 | os << "class " << name_ << "Node final : public Node {\n" ; |
540 | |
541 | emitClassMembers(os); |
542 | |
543 | os << "\n public:\n" ; |
544 | |
545 | emitIndicesEnum(os); |
546 | emitCtor(os); |
547 | emitSettersGetters(os); |
548 | |
549 | os << " unsigned getNumInputs() const;\n" |
550 | << " std::string getInputName(unsigned idx) const;\n" |
551 | << " NodeValue getNthInput(unsigned idx);\n" |
552 | << " void setNthInput(unsigned idx, NodeValue val);\n" |
553 | << " llvm::StringRef getOutputName(unsigned idx) const;\n" |
554 | << " bool hasSideEffects() const { return " << hasSideEffects_ << "; }\n" |
555 | << " bool isCanonical() const { return " << !isBackendSpecific_ << "; }\n" |
556 | << " bool isDataParallel() const { return " << isDataParallel_ << "; }\n" |
557 | << " std::string getDebugDesc() const;\n" |
558 | << " bool isEqual(const " << name_ << "Node &other) const;\n" |
559 | << " llvm::hash_code getHash() const;\n" |
560 | << " void visit(Node *parent, NodeWalker *visitor);\n" |
561 | << " Node* clone() const;\n" |
562 | << " bool verify() const;\n" ; |
563 | |
564 | if (hasExtraResults_) { |
565 | os << " void addExtraResult(TypeRef T) { addResult(T); }\n" ; |
566 | } |
567 | |
568 | if (!enum_.empty()) { |
569 | os << " const char *getModeStr() const { return getModeStr(mode_); }\n" |
570 | << " static const char *getModeStr(Mode m);\n" ; |
571 | } |
572 | |
573 | for (const auto &m : extraMethods_) { |
574 | os << " " << m.first; |
575 | } |
576 | |
577 | os << "};\n} // namespace glow\n" ; |
578 | } |
579 | |
580 | void NodeBuilder::emitCppMethods(std::ostream &os) const { |
581 | emitEdges(os); |
582 | emitPrettyPrinter(os); |
583 | emitVisitor(os); |
584 | emitEquator(os); |
585 | emitCloner(os); |
586 | emitHasher(os); |
587 | if (!enum_.empty()) { |
588 | emitEnumModePrinters(os); |
589 | } |
590 | |
591 | // Emit the "extra" method bodies. |
592 | for (const auto &m : extraMethods_) { |
593 | os << m.second; |
594 | } |
595 | } |
596 | |
597 | bool NodeBuilder::hasCtorTypeParams(llvm::StringRef res) const { |
598 | for (const std::string &s : ctorTypeParams_) { |
599 | if (s == res) { |
600 | return true; |
601 | } |
602 | } |
603 | return false; |
604 | } |
605 | |
/// Emits the autogenerated deserialization (import) snippet for this node: an
/// `if (typeName == "Glow_<name>")` branch. The generated code:
/// - loads each input with getNodeValueByName(op.input(i));
/// - loads the type of every result that is also a ctor TypeRef param with
///   loadTypeFromAttributes(i, dict);
/// - loads each member with loadAttribute<T>(dict.at("<member>"), *this);
/// - constructs the node via G_->addNode(new <name>Node(...));
/// - attaches an optional "Predicate";
/// - registers the node with addNodeAsOutput().
/// The generated text relies on `typeName`, `op`, `opName`, `dict` and `G_`
/// being in scope at the place where it is included.
///
/// NOTE(review): emitCtor() inserts a `Mode mode` parameter when enum_ is
/// non-empty, but no Mode argument is emitted here, so such nodes presumably
/// must use skipAutogenSerialization — TODO confirm.
/// NOTE(review): ctor type arguments are emitted in nodeOutputs_ order, while
/// emitCtor() declares them in ctorTypeParams_ order; the two must agree.
void NodeBuilder::emitImportMethods(std::ostream &os) const {
  os << "if (typeName == \"Glow_" << name_ << "\") {\n";

  // Load all the inputs, by position in the serialized op's input list.
  for (size_t i = 0, e = nodeInputs_.size(); i < e; i++) {
    auto &op = nodeInputs_[i];
    os << " NodeValue " << op << ";\n";
    os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op
       << ", getNodeValueByName(op.input(" << i << ")));\n\n";
  }

  // Load all the output types. Only results that are also constructor TypeRef
  // parameters need an explicit type; `i` is the result number.
  for (size_t i = 0, e = nodeOutputs_.size(); i < e; i++) {
    auto &op = nodeOutputs_[i];
    if (hasCtorTypeParams(op.second)) {
      os << " TypeRef " << op.second << "OutTy;\n";
      os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op.second
         << "OutTy, loadTypeFromAttributes(" << std::to_string(i)
         << ", dict));\n\n";
    }
  }

  // Load the members; each is read as the constructor argument type.
  for (const auto &op : members_) {
    auto ty = getCtorArgTypename(&op.first);
    os << " " << ty << " " << op.second << ";\n";
    os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op.second;
    os << ", loadAttribute<" << ty << ">(dict.at(\"" << op.second
       << "\"), *this));\n\n";
  }

  // We have all items needed to construct the node, so do so.
  const auto nodeName = name_ + "Node";
  os << " " << nodeName << " *loadedNode = G_->addNode(new " << nodeName
     << "(opName";
  for (const auto &op : nodeOutputs_) {
    if (hasCtorTypeParams(op.second)) {
      os << ", " << op.second << "OutTy";
    }
  }
  for (size_t i = 0, e = nodeInputs_.size(); i < e; i++) {
    auto &op = nodeInputs_[i];
    os << ", " << op;
  }
  for (const auto &op : members_) {
    os << ", " << op.second;
  }
  os << "));\n\n";

  // Now load a predicate if one exists.
  os << " if (dict.count(\"Predicate\")) {\n";
  os << " NodeValue Predicate;\n";
  os << " ASSIGN_VALUE_OR_RETURN_ERR(Predicate, "
        "loadAttribute<NodeValue>(dict.at(\"Predicate\"), *this));\n";
  os << " loadedNode->setPredicate(Predicate);\n";
  os << " }\n\n";

  // Add the node to the Function and return it.
  os << " RETURN_IF_ERR(addNodeAsOutput(op, loadedNode));\n";
  os << " return loadedNode;\n";
  os << "}\n\n";
}
668 | |
669 | void NodeBuilder::emitExportMethods(std::ostream &os) const { |
670 | os << "case glow::Kinded::Kind::" << name_ << "NodeKind: {\n" ; |
671 | os << " auto *N__ = llvm::cast<" << name_ << "Node>(node);\n" ; |
672 | |
673 | // Add the node. Note that Glow custom ops are prefixed with "Glow_" |
674 | os << " opProto = graph.add_node();\n" ; |
675 | os << " opProto->set_op_type(\"Glow_" << name_ << "\");\n" ; |
676 | os << " opProto->set_name(glow::legalizeName(N__->getName()));\n" ; |
677 | |
678 | // Add all of the node's inputs. |
679 | for (const auto &op : nodeInputs_) { |
680 | os << " opProto->add_input(N__->get" << op |
681 | << "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n" ; |
682 | // Note: Add each input's type attributes so that other tools have easy |
683 | // visibility into types. This info may go ignored by the importer. |
684 | os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op |
685 | << "Idx, /* isInput */ true);\n" ; |
686 | } |
687 | |
688 | // Add all of the node's outputs. |
689 | for (const auto &op : nodeOutputs_) { |
690 | os << " opProto->add_output(N__->get" << op.second |
691 | << "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n" ; |
692 | // Note: export the type attributes even if not needed by the importer, so |
693 | // that other tools have easy visibility into types. This info may go |
694 | // ignored by the importer. |
695 | os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op.second |
696 | << "Idx, /* isInput */ false);\n" ; |
697 | } |
698 | |
699 | // Add any members the node has. |
700 | for (const auto &op : members_) { |
701 | os << " addValueAttribute(opProto, \"" << op.second << "\", N__->get" |
702 | << op.second << "());\n" ; |
703 | |
704 | // If the member is a VectorNodeValue then also add the types of the NVs. |
705 | if (op.first.type == MemberType::VectorNodeValue) { |
706 | os << " for (unsigned i = 0, e = N__->get" << op.second |
707 | << "().size(); i < e; i++) {\n" ; |
708 | os << " addTypeAttributes(opProto, N__->get" << op.second |
709 | << "()[i], i, /* isInput */ true, \"" << op.second << "_\");\n" ; |
710 | os << " }\n" ; |
711 | } |
712 | } |
713 | |
714 | // Check if the node has a predicate and add it if so. |
715 | os << " if (N__->hasPredicate()) {\n" ; |
716 | os << " addValueAttribute(opProto, \"Predicate\", " |
717 | "N__->getPredicate().generateNodeOutputName(/* stripResNoFor0thInput " |
718 | "*/ true));\n" ; |
719 | os << " }\n" ; |
720 | |
721 | os << " break;\n" ; |
722 | os << "}\n\n" ; |
723 | } |
724 | |
/// Declares a companion "<name>Grad" node and a getGrad() factory on this node.
///
/// The grad node (built by the local builder GN) gets:
/// - all of this node's members and enum cases, and its isDataParallel_ flag;
/// - inputs: every original input, then for each original output the pair
///   "OriginalOutputFor<out>" and "GradOfOriginalOutputNamed<out>";
/// - results: one "GradOfInputNamed<in>" per original input, typed like that
///   input.
/// GN is a NodeBuilder, so its destructor emits the grad node's class and
/// methods when this function returns.
///
/// getGrad() constructs the grad node with arguments in GN's constructor
/// order (name, mode, inputs, output/gradient pairs, members). It then
/// registers each GradOfInputNamed<in> result as the gradient of the
/// corresponding input.
/// NOTE(review): skipAutogenSerialization_ and hasSideEffects_ are not copied
/// to GN here; confirm the grad node is meant to use GN's defaults for them.
NodeBuilder &NodeBuilder::addGradient() {
  NodeBuilder GN(hStream, cStream, dStream, iStream, eStream, name_ + "Grad",
                 isBackendSpecific_);

  // The new 'Grad' class will have all of the fields of the current class.
  GN.members_ = members_;
  GN.enum_ = enum_;
  GN.isDataParallel_ = isDataParallel_;

  // Add the inputs that we'll use in the grad instruction.
  for (const std::string &in : nodeInputs_) {
    GN.addInput(in);
  }

  // One gradient result per original input, with the same type as that input.
  for (const std::string &in : nodeInputs_) {
    GN.addResult(in + ".getType()", "GradOfInputNamed" + in);
  }

  // For every original output, take both the output and its incoming gradient.
  for (const auto &out : nodeOutputs_) {
    GN.addInput("OriginalOutputFor" + out.second);
    GN.addInput("GradOfOriginalOutputNamed" + out.second);
  }

  // Construct a factory method that builds the new grad node and add
  // it to the current non-grad instruction.

  std::string decl = name_ + "GradNode *getGrad(GraphGradMapper &builder);\n";
  std::stringstream ss;
  ss << "\n" + name_ + "GradNode *" + name_
     << "Node::getGrad(GraphGradMapper &builder) {\n"
     << " auto *x = new " + name_ + "GradNode(getName().str() + \"_grad\"";

  // Both nodes share the same enum cases, so the mode is converted by cast.
  if (enum_.size()) {
    ss << ", (" << name_ << "GradNode::Mode)getMode()";
  }

  // Add the inputs that we'll use in the grad instruction.
  for (const std::string &in : nodeInputs_) {
    ss << ", get" << in << "()";
  }

  // Each original output followed by the gradient recorded for it.
  for (const auto &out : nodeOutputs_) {
    ss << ", get" << out.second << "(), builder.getGradient(get" << out.second
       << "())";
  }

  // Extra class members:
  for (const auto &op : members_) {
    ss << ", get" << op.second << "()";
  }

  ss << ");\n";

  // Register the result of the new node as the gradients of the original node
  // inputs.
  for (const std::string &in : nodeInputs_) {
    ss << " builder.addGradient(get" << in << "(), x->getGradOfInputNamed"
       << in << "());\n";
  }
  ss << " return x;\n}\n";
  addExtraMethod(decl, ss.str());

  return *this;
}
789 | |
790 | NodeBuilder::~NodeBuilder() { |
791 | emitNodeClass(hStream); |
792 | emitCppMethods(cStream); |
793 | if (!skipAutogenSerialization_) { |
794 | emitImportMethods(iStream); |
795 | emitExportMethods(eStream); |
796 | } |
797 | } |
798 | |