1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "NodeBuilder.h"
18
19NodeBuilder &NodeBuilder::addMember(MemberType type, const std::string &name,
20 bool addSetter) {
21 MemberTypeInfo *typeInfo = nullptr;
22
23 if (type == MemberType::TypeRef) {
24 typeInfo = &kTypeRefTypeInfo;
25 } else if (type == MemberType::Float) {
26 typeInfo = &kFloatTypeInfo;
27 } else if (type == MemberType::Unsigned) {
28 typeInfo = &kUnsignedTypeInfo;
29 } else if (type == MemberType::Boolean) {
30 typeInfo = &kBooleanTypeInfo;
31 } else if (type == MemberType::Int64) {
32 typeInfo = &kInt64TypeInfo;
33 } else if (type == MemberType::String) {
34 typeInfo = &kStringTypeInfo;
35 } else if (type == MemberType::VectorFloat) {
36 typeInfo = &kVectorFloatTypeInfo;
37 } else if (type == MemberType::VectorUnsigned) {
38 typeInfo = &kVectorUnsignedTypeInfo;
39 } else if (type == MemberType::VectorInt64) {
40 typeInfo = &kVectorInt64TypeInfo;
41 } else if (type == MemberType::VectorSigned) {
42 typeInfo = &kVectorSignedTypeInfo;
43 } else if (type == MemberType::VectorSizeT) {
44 typeInfo = &kVectorSizeTTypeInfo;
45 } else if (type == MemberType::VectorDimT) {
46 typeInfo = &kVectorDimTTypeInfo;
47 } else if (type == MemberType::VectorNodeValue) {
48 typeInfo = &kVectorNodeValueTypeInfo;
49 } else if (type == MemberType::Enum) {
50 typeInfo = &kEnumTypeInfo;
51 } else if (type == MemberType::UserDefinedType) {
52 llvm_unreachable("addMember should be called with a MemberTypeInfo "
53 "parameter in this case");
54 } else {
55 llvm_unreachable("Type not recognized");
56 }
57
58 return addMember(*typeInfo, name, addSetter);
59}
60
61NodeBuilder &NodeBuilder::addFusedActivation() {
62 return addMember(MEMBER_TYPE_INFO(glow::FusedActivation), "FusedActivation",
63 /* addSetter */ true)
64 .addMember(MemberType::VectorFloat, "FusedActivationArgs",
65 /* addSetter */ true)
66 .addExtraMethod("bool hasFusedActivation() const;",
67 "bool " + name_ +
68 "Node::hasFusedActivation() const { return "
69 "getFusedActivation() != FusedActivation::NONE; }");
70}
71
72void NodeBuilder::emitMemberForwardDecls(std::ostream &os) const {
73 for (const auto &mem : members_) {
74 const std::string &forwardDecl = (mem.first).forwardDecl;
75 if (!forwardDecl.empty()) {
76 os << forwardDecl << "\n";
77 }
78 }
79
80 os << "\n";
81}
82
83void NodeBuilder::emitEnumModePrinters(std::ostream &os) const {
84 os << "\nconst char *" << name_ << "Node::getModeStr(" << name_
85 << "Node::Mode m) {\n";
86 os << " static const char *names[] = {";
87 for (const auto &e : enum_) {
88 os << "\"" << e << "\", ";
89 }
90 os << "nullptr};\n";
91 os << " return names[static_cast<int>(m)];\n";
92 os << "}\n";
93}
94
95void NodeBuilder::emitCtor(std::ostream &os) const {
96 os << " " << name_ << "Node(llvm::StringRef name";
97
98 // Generate the external type parameters:
99 for (const auto &paramName : ctorTypeParams_) {
100 os << ", TypeRef " << paramName << " ";
101 }
102
103 // The enum 'Mode' parameter:
104 if (!enum_.empty()) {
105 os << ", Mode mode";
106 }
107
108 // The operands of the graph node:
109 for (const auto &op : nodeInputs_) {
110 os << ", NodeValue " << op;
111 }
112
113 // Extra class members:
114 for (const auto &op : members_) {
115 os << ", " << getCtorArgTypename(&op.first) << " " << op.second;
116 }
117
118 // Initialize the base clases:
119 os << ")\n : Node(Kinded::Kind::" << name_ << "NodeKind, name)";
120
121 // Print the initialization list:
122 if (!enum_.empty()) {
123 os << ", mode_(mode)";
124 }
125
126 // Initialize the operands:
127 for (const auto &op : nodeInputs_) {
128 os << ", " << op << "_("
129 << "this, " << op << ")";
130 }
131
132 // Initialize the members:
133 for (const auto &op : members_) {
134 if ((op.first).type != MemberType::VectorNodeValue) {
135 os << ", " << op.second << "_(" << op.second << ")";
136 continue;
137 }
138 continue;
139 os << ", " << op.second << "_(" << op.second << ".begin(), " << op.second
140 << ".end()"
141 << ")";
142 }
143
144 // The constructor body:
145 os << " {\n";
146 for (auto &RT : nodeOutputs_) {
147 os << " addResult(" << RT.first << ");\n";
148 }
149
150 for (const auto &op : members_) {
151 if ((op.first).type != MemberType::VectorNodeValue) {
152 continue;
153 }
154 os << " " << op.second << "_.resize(" << op.second << ".size());\n";
155 os << " for (size_t idx = 0, e = " << op.second
156 << ".size(); idx < e; ++idx) {\n"
157 << " " << op.second << "_[idx] = " << op.second << "[idx];\n"
158 << " " << op.second << "_[idx].setParent(this);\n"
159 << " }\n";
160 }
161
162 os << " }\n";
163}
164
165void NodeBuilder::emitClassMembers(std::ostream &os) const {
166 // Emit the type of the enum (which is public).
167 if (!enum_.empty()) {
168 os << " public:\n enum class Mode {\n";
169 for (const auto &E : enum_) {
170 os << " " << E << ",\n";
171 }
172 os << " };\n\n private:\n";
173 }
174
175 // Emit class members:
176 if (!enum_.empty()) {
177 os << " Mode mode_;\n";
178 }
179 for (const auto &op : nodeInputs_) {
180 os << " NodeHandle " << op << "_;\n";
181 }
182 for (const auto &op : members_) {
183 os << " " << getStorageTypename(&op.first) << " " << op.second << "_;\n";
184 }
185}
186
187void NodeBuilder::emitMemberGetterSetter(std::ostream &os,
188 const MemberTypeInfo *typeInfo,
189 const std::string &name) const {
190 // Synthesize the general getter.
191 auto typeStr = getReturnTypename(typeInfo);
192 os << " " << typeStr << " get" << name << "() const { return " << name
193 << "_; }\n";
194
195 if (typeInfo->addSetter) {
196 os << " void set" << name << "(" << typeStr << " a) {" << name
197 << "_ = a; }\n";
198 }
199}
200
201void NodeBuilder::emitSettersGetters(std::ostream &os) const {
202 // Print the getters/setters.
203 for (const auto &inName : nodeInputs_) {
204 os << " const NodeValue get" << inName << "() const { return " << inName
205 << "_; }\n";
206 }
207
208 unsigned idx = 0;
209 for (const auto &op : nodeOutputs_) {
210 os << " NodeValue get" << op.second << "() { return getNthResult(" << idx
211 << "); }\n";
212 os << " const NodeValue get" << op.second
213 << "() const { return getNthResult(" << idx << "); }\n";
214 idx++;
215 }
216
217 for (const auto &op : members_) {
218 emitMemberGetterSetter(os, &op.first, op.second);
219 }
220
221 // Synthesize the 'classof' method that enables the non-rtti polymorphism.
222 os << "\n static bool classof(const Kinded *k) {\n"
223 << " return k->getKind() == Kinded::Kind::" << name_ << "NodeKind;\n"
224 << " }\n\n";
225
226 os << "\n bool isOverwrittenNthInput(unsigned idx) const {\n";
227 for (const auto &overwrittenInput : nodeOverwrittenInputs_) {
228 os << " if (idx == " << overwrittenInput << ") return true;\n";
229 }
230 os << " return false;\n";
231 os << " }\n\n";
232
233 if (!enum_.empty()) {
234 os << " Mode getMode() const { return mode_; }\n";
235 }
236}
237
238void NodeBuilder::emitEdges(std::ostream &os) const {
239 os << "\nunsigned " << name_ << "Node::getNumInputs() const {\n"
240 << " return " << nodeInputs_.size();
241 for (const auto &op : members_) {
242 if ((op.first).type != MemberType::VectorNodeValue) {
243 continue;
244 }
245 os << " + " << op.second << "_.size()";
246 }
247 os << ";\n}\n";
248
249 os << "\nstd::string " << name_
250 << "Node::getInputName(unsigned idx) const {\n";
251 for (size_t i = 0; i < nodeInputs_.size(); i++) {
252 os << " if (idx == " << i << ") { return \"" << nodeInputs_[i]
253 << "\"; }\n";
254 }
255 os << " idx -= " << nodeInputs_.size() << ";\n";
256 for (const auto &op : members_) {
257 if ((op.first).type != MemberType::VectorNodeValue) {
258 continue;
259 }
260 os << " if (idx < " << op.second << "_.size()) { return \"" << op.second
261 << "\" + std::to_string(idx); }\n"
262 << " idx -= " << op.second << "_.size();\n";
263 }
264 os << " llvm_unreachable(\"Invalid index\");\n}\n";
265
266 os << "\nNodeValue " << name_ << "Node::getNthInput(unsigned idx) {\n";
267 for (size_t i = 0; i < nodeInputs_.size(); i++) {
268 os << " if (idx == " << i << ") { return " << nodeInputs_[i] << "_; }\n";
269 }
270 os << " idx -= " << nodeInputs_.size() << ";\n";
271 for (const auto &op : members_) {
272 if ((op.first).type != MemberType::VectorNodeValue) {
273 continue;
274 }
275 os << " if (idx < " << op.second << "_.size()) { return " << op.second
276 << "_[idx]; }\n idx -= " << op.second << "_.size();\n";
277 }
278 os << " llvm_unreachable(\"Invalid index\");\n}\n";
279
280 os << "\nvoid " << name_
281 << "Node::setNthInput(unsigned idx, NodeValue val) {\n";
282 for (size_t i = 0; i < nodeInputs_.size(); i++) {
283 os << " if (idx == " << i << ") { " << nodeInputs_[i]
284 << "_ = val; return; }\n";
285 }
286 os << " idx -= " << nodeInputs_.size() << ";\n";
287 for (const auto &op : members_) {
288 if ((op.first).type != MemberType::VectorNodeValue) {
289 continue;
290 }
291 os << " if (idx < " << op.second << "_.size()) { " << op.second
292 << "_[idx] = val; return; }\n idx -= " << op.second << "_.size();\n";
293 }
294 os << " llvm_unreachable(\"Invalid index\");\n}\n";
295
296 os << "\nllvm::StringRef " << name_
297 << "Node::getOutputName(unsigned idx) const {\n";
298 for (size_t i = 0; i < nodeOutputs_.size(); i++) {
299 os << " if (idx == " << i << ") { return \"" << nodeOutputs_[i].second
300 << "\"; }\n";
301 }
302 os << " llvm_unreachable(\"Invalid index\");\n}\n";
303}
304
305void NodeBuilder::emitPrettyPrinter(std::ostream &os) const {
306 os << "\nstd::string " << name_ << "Node::getDebugDesc() const {\n"
307 << " DescriptionBuilder db(getKindName());\n"
308 << " db.addParam(\"Name\", separateString(getName(), 100, \"\\n\"));\n";
309
310 os << " if (hasPredicate()) db.addParam(\"Predicate\", \"Yes\");\n";
311
312 os << " db\n";
313 if (!enum_.empty()) {
314 os << " .addParam(\"Mode\", getModeStr())\n";
315 }
316
317 // Generate description for inputs.
318 for (const auto &op : nodeInputs_) {
319 os << " .addParam(\"" << op << "\", *(get" << op << "().getType()))\n";
320 }
321
322 for (const auto &mem : members_) {
323 // Don't try to print the node operands directly.
324 MemberType ty = (mem.first).type;
325 if (ty == MemberType::VectorNodeValue) {
326 continue;
327 }
328
329 if (ty == MemberType::Enum) {
330 os << " .addParam(\"" << mem.second << "\", static_cast<int>(get"
331 << mem.second << "()))\n";
332 } else {
333 os << " .addParam(\"" << mem.second << "\", get" << mem.second
334 << "())\n";
335 }
336 }
337 os << " .addParam(\"Users\", getNumUsers());\n";
338
339 for (const auto &mem : members_) {
340 if ((mem.first).type != MemberType::VectorNodeValue) {
341 continue;
342 }
343
344 // Make sure that inputs are properly indexed.
345 os << " {\n";
346 os << " unsigned mIndex = 0;\n";
347 os << " for (const auto &II : get" << mem.second << "()) {\n"
348 << " db.addParam(\"" << mem.second
349 << "\"+std::to_string(mIndex++), *II.getType());\n"
350 << " }\n"
351 << " }\n";
352 }
353
354 // Generate description for outputs.
355 for (const auto &op : nodeOutputs_) {
356 os << " db.addParam(\"" << op.second << "\", *(get" << op.second
357 << "().getType()));\n";
358 }
359
360 os << " return db;\n}\n";
361}
362
363void NodeBuilder::emitCloner(std::ostream &os) const {
364 os << "\nNode* " << name_ << "Node::clone() const {\n";
365
366 os << " return new " << name_ << "Node(getName()";
367
368 // Pass the external type arguments:
369 for (const auto &paramName : ctorTypeParams_) {
370 os << ", get" << paramName << "().getType()";
371 }
372
373 // The enum 'Mode' parameter:
374 if (!enum_.empty()) {
375 os << ", getMode()";
376 }
377
378 // The operands of the graph node:
379 for (const auto &op : nodeInputs_) {
380 os << ", get" << op << "()";
381 }
382
383 // Extra class members:
384 for (const auto &op : members_) {
385 os << ", get" << op.second << "()";
386 }
387
388 os << ");\n}\n";
389}
390
/// \returns true if \p c can be part of a valid C/C++ identifier, i.e. it is
/// an underscore or an alphanumeric character according to isalnum().
static bool isIdentifierChar(char c) {
  // isalnum() has undefined behavior for negative arguments other than EOF.
  // Plain `char` is negative for non-ASCII bytes on targets where char is
  // signed, so convert through unsigned char first.
  return c == '_' || isalnum(static_cast<unsigned char>(c));
}
393
394void NodeBuilder::emitEquator(std::ostream &os) const {
395 os << "\nbool " << name_ << "Node::isEqual(const " << name_
396 << "Node &other) const {\n return true";
397
398 if (!enum_.empty()) {
399 os << " &&\n getMode() == other.getMode()";
400 }
401
402 for (const auto &op : nodeInputs_) {
403 os << " &&\n " << op << "_ == other." << op << "_";
404 }
405
406 os << " &&\n predicate_ == other.predicate_";
407
408 for (const auto &mem : members_) {
409 // Use custom user-defined comparator functions if available.
410 std::string cmpFn = mem.first.cmpFn;
411 if (cmpFn.empty() || !isIdentifierChar(cmpFn.at(0))) {
412 if (cmpFn.empty()) {
413 // Default comparator is ==.
414 cmpFn = "==";
415 }
416 os << " &&\n " << mem.second << "_ " << cmpFn << " other."
417 << mem.second << "_";
418 } else {
419 os << " &&\n " << cmpFn << "(" << mem.second << "_, other."
420 << mem.second << "_)";
421 }
422 }
423
424 for (int i = 0, e = nodeOutputs_.size(); i < e; i++) {
425 os << " &&\n getType(" << i << ") == other.getType(" << i << ")";
426 }
427 os << ";\n}\n";
428}
429
430static bool isVectorType(MemberType ty) {
431 return ty == MemberType::VectorFloat || ty == MemberType::VectorNodeValue ||
432 ty == MemberType::VectorSizeT || ty == MemberType::VectorDimT ||
433 ty == MemberType::VectorUnsigned || ty == MemberType::VectorInt64 ||
434 ty == MemberType::VectorSigned;
435}
436
437static bool isFloatVectorType(MemberType ty) {
438 return ty == MemberType::VectorFloat;
439}
440
441void NodeBuilder::emitHasher(std::ostream &os) const {
442 os << "\nllvm::hash_code " << name_ << "Node::getHash() const {\n"
443 << " return llvm::hash_combine(";
444
445 if (enum_.empty() && nodeInputs_.empty() && members_.empty()) {
446 os << "0);\n }\n";
447 return;
448 }
449
450 auto delim = "";
451 if (!enum_.empty()) {
452 os << delim << "\n getMode()";
453 delim = ",";
454 }
455 for (const auto &mem : members_) {
456 auto ty = (mem.first).type;
457 if (ty == MemberType::Float) {
458 os << delim << "\n toBinary(" << mem.second << "_)";
459 } else if (isFloatVectorType(ty)) {
460 os << delim
461 << "\n [](const std::vector<float>& floatVec) -> llvm::hash_code "
462 "{\n std::vector<size_t> sizeVec = toBinary(floatVec);\n "
463 " return llvm::hash_combine_range(sizeVec.begin(), "
464 "sizeVec.end());\n }("
465 << mem.second << "_)";
466 } else if (isVectorType(ty)) {
467 os << delim << "\n llvm::hash_combine_range(" << mem.second
468 << "_.begin(), " << mem.second << "_.end())";
469 } else if (ty == MemberType::Enum) {
470 os << delim << "\n static_cast<int>(" << mem.second << "_)";
471 } else {
472 os << delim << "\n " << mem.second << "_";
473 }
474 delim = ",";
475 }
476
477 for (const auto &op : nodeInputs_) {
478 os << delim << "\n " << op << "_";
479 delim = ",";
480 }
481
482 os << ");\n}\n";
483}
484void NodeBuilder::emitVisitor(std::ostream &os) const {
485 os << "\nvoid " << name_
486 << "Node::visit(Node *parent, NodeWalker *visitor) {\n"
487 << " if (!visitor->shouldVisit(parent, this)) { return; }\n"
488 << " visitor->pre(parent, this);\n"
489 << "if (hasPredicate())\n"
490 << " getPredicate().getNode()->visit(this, visitor);\n";
491
492 for (const auto &op : nodeInputs_) {
493 os << " get" << op << "().getNode()->visit(this, visitor);\n";
494 }
495
496 for (const auto &op : members_) {
497 if ((op.first).type == MemberType::VectorNodeValue) {
498 os << " for (auto &I : " << op.second
499 << "_) { I.getNode()->visit(this, visitor); }\n";
500 }
501 }
502
503 os << " visitor->post(parent, this);\n}\n";
504}
505
506void NodeBuilder::emitDocstring(std::ostream &os) const {
507 std::istringstream stream(docstring_);
508 std::string line;
509 while (std::getline(stream, line)) {
510 os << "/// " << line << "\n";
511 }
512}
513
514void NodeBuilder::emitIndicesEnum(std::ostream &os) const {
515 os << " enum InputIndices {\n";
516 for (size_t i = 0; i < nodeInputs_.size(); i++) {
517 os << " ";
518 os << nodeInputs_[i];
519 os << "Idx = " << i << ",\n";
520 }
521 os << " };\n\n";
522
523 os << " enum ResultIndices {\n";
524 for (int i = 0, e = nodeOutputs_.size(); i < e; i++) {
525 os << " ";
526 os << nodeOutputs_[i].second;
527 os << "Idx = " << i << ",\n";
528 }
529 os << " };\n\n";
530}
531
532void NodeBuilder::emitNodeClass(std::ostream &os) const {
533 emitMemberForwardDecls(os);
534
535 os << "\nnamespace glow {\n";
536
537 emitDocstring(os);
538
539 os << "class " << name_ << "Node final : public Node {\n";
540
541 emitClassMembers(os);
542
543 os << "\n public:\n";
544
545 emitIndicesEnum(os);
546 emitCtor(os);
547 emitSettersGetters(os);
548
549 os << " unsigned getNumInputs() const;\n"
550 << " std::string getInputName(unsigned idx) const;\n"
551 << " NodeValue getNthInput(unsigned idx);\n"
552 << " void setNthInput(unsigned idx, NodeValue val);\n"
553 << " llvm::StringRef getOutputName(unsigned idx) const;\n"
554 << " bool hasSideEffects() const { return " << hasSideEffects_ << "; }\n"
555 << " bool isCanonical() const { return " << !isBackendSpecific_ << "; }\n"
556 << " bool isDataParallel() const { return " << isDataParallel_ << "; }\n"
557 << " std::string getDebugDesc() const;\n"
558 << " bool isEqual(const " << name_ << "Node &other) const;\n"
559 << " llvm::hash_code getHash() const;\n"
560 << " void visit(Node *parent, NodeWalker *visitor);\n"
561 << " Node* clone() const;\n"
562 << " bool verify() const;\n";
563
564 if (hasExtraResults_) {
565 os << " void addExtraResult(TypeRef T) { addResult(T); }\n";
566 }
567
568 if (!enum_.empty()) {
569 os << " const char *getModeStr() const { return getModeStr(mode_); }\n"
570 << " static const char *getModeStr(Mode m);\n";
571 }
572
573 for (const auto &m : extraMethods_) {
574 os << " " << m.first;
575 }
576
577 os << "};\n} // namespace glow\n";
578}
579
580void NodeBuilder::emitCppMethods(std::ostream &os) const {
581 emitEdges(os);
582 emitPrettyPrinter(os);
583 emitVisitor(os);
584 emitEquator(os);
585 emitCloner(os);
586 emitHasher(os);
587 if (!enum_.empty()) {
588 emitEnumModePrinters(os);
589 }
590
591 // Emit the "extra" method bodies.
592 for (const auto &m : extraMethods_) {
593 os << m.second;
594 }
595}
596
597bool NodeBuilder::hasCtorTypeParams(llvm::StringRef res) const {
598 for (const std::string &s : ctorTypeParams_) {
599 if (s == res) {
600 return true;
601 }
602 }
603 return false;
604}
605
606void NodeBuilder::emitImportMethods(std::ostream &os) const {
607 os << "if (typeName == \"Glow_" << name_ << "\") {\n";
608
609 // Load all the inputs.
610 for (size_t i = 0, e = nodeInputs_.size(); i < e; i++) {
611 auto &op = nodeInputs_[i];
612 os << " NodeValue " << op << ";\n";
613 os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op
614 << ", getNodeValueByName(op.input(" << i << ")));\n\n";
615 }
616
617 // Load all the output types.
618 for (size_t i = 0, e = nodeOutputs_.size(); i < e; i++) {
619 auto &op = nodeOutputs_[i];
620 if (hasCtorTypeParams(op.second)) {
621 os << " TypeRef " << op.second << "OutTy;\n";
622 os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op.second
623 << "OutTy, loadTypeFromAttributes(" << std::to_string(i)
624 << ", dict));\n\n";
625 }
626 }
627
628 // Load the members.
629 for (const auto &op : members_) {
630 auto ty = getCtorArgTypename(&op.first);
631 os << " " << ty << " " << op.second << ";\n";
632 os << " ASSIGN_VALUE_OR_RETURN_ERR(" << op.second;
633 os << ", loadAttribute<" << ty << ">(dict.at(\"" << op.second
634 << "\"), *this));\n\n";
635 }
636
637 // We have all items needed to construct the node, so do so.
638 const auto nodeName = name_ + "Node";
639 os << " " << nodeName << " *loadedNode = G_->addNode(new " << nodeName
640 << "(opName";
641 for (const auto &op : nodeOutputs_) {
642 if (hasCtorTypeParams(op.second)) {
643 os << ", " << op.second << "OutTy";
644 }
645 }
646 for (size_t i = 0, e = nodeInputs_.size(); i < e; i++) {
647 auto &op = nodeInputs_[i];
648 os << ", " << op;
649 }
650 for (const auto &op : members_) {
651 os << ", " << op.second;
652 }
653 os << "));\n\n";
654
655 // Now load a predicate if one exists.
656 os << " if (dict.count(\"Predicate\")) {\n";
657 os << " NodeValue Predicate;\n";
658 os << " ASSIGN_VALUE_OR_RETURN_ERR(Predicate, "
659 "loadAttribute<NodeValue>(dict.at(\"Predicate\"), *this));\n";
660 os << " loadedNode->setPredicate(Predicate);\n";
661 os << " }\n\n";
662
663 // Add the node to the Function and return it.
664 os << " RETURN_IF_ERR(addNodeAsOutput(op, loadedNode));\n";
665 os << " return loadedNode;\n";
666 os << "}\n\n";
667}
668
669void NodeBuilder::emitExportMethods(std::ostream &os) const {
670 os << "case glow::Kinded::Kind::" << name_ << "NodeKind: {\n";
671 os << " auto *N__ = llvm::cast<" << name_ << "Node>(node);\n";
672
673 // Add the node. Note that Glow custom ops are prefixed with "Glow_"
674 os << " opProto = graph.add_node();\n";
675 os << " opProto->set_op_type(\"Glow_" << name_ << "\");\n";
676 os << " opProto->set_name(glow::legalizeName(N__->getName()));\n";
677
678 // Add all of the node's inputs.
679 for (const auto &op : nodeInputs_) {
680 os << " opProto->add_input(N__->get" << op
681 << "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n";
682 // Note: Add each input's type attributes so that other tools have easy
683 // visibility into types. This info may go ignored by the importer.
684 os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op
685 << "Idx, /* isInput */ true);\n";
686 }
687
688 // Add all of the node's outputs.
689 for (const auto &op : nodeOutputs_) {
690 os << " opProto->add_output(N__->get" << op.second
691 << "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n";
692 // Note: export the type attributes even if not needed by the importer, so
693 // that other tools have easy visibility into types. This info may go
694 // ignored by the importer.
695 os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op.second
696 << "Idx, /* isInput */ false);\n";
697 }
698
699 // Add any members the node has.
700 for (const auto &op : members_) {
701 os << " addValueAttribute(opProto, \"" << op.second << "\", N__->get"
702 << op.second << "());\n";
703
704 // If the member is a VectorNodeValue then also add the types of the NVs.
705 if (op.first.type == MemberType::VectorNodeValue) {
706 os << " for (unsigned i = 0, e = N__->get" << op.second
707 << "().size(); i < e; i++) {\n";
708 os << " addTypeAttributes(opProto, N__->get" << op.second
709 << "()[i], i, /* isInput */ true, \"" << op.second << "_\");\n";
710 os << " }\n";
711 }
712 }
713
714 // Check if the node has a predicate and add it if so.
715 os << " if (N__->hasPredicate()) {\n";
716 os << " addValueAttribute(opProto, \"Predicate\", "
717 "N__->getPredicate().generateNodeOutputName(/* stripResNoFor0thInput "
718 "*/ true));\n";
719 os << " }\n";
720
721 os << " break;\n";
722 os << "}\n\n";
723}
724
725NodeBuilder &NodeBuilder::addGradient() {
726 NodeBuilder GN(hStream, cStream, dStream, iStream, eStream, name_ + "Grad",
727 isBackendSpecific_);
728
729 // The new 'Grad' class will have all of the fields of the current class.
730 GN.members_ = members_;
731 GN.enum_ = enum_;
732 GN.isDataParallel_ = isDataParallel_;
733
734 // Add the inputs that we'll use in the grad instruction.
735 for (const std::string &in : nodeInputs_) {
736 GN.addInput(in);
737 }
738
739 for (const std::string &in : nodeInputs_) {
740 GN.addResult(in + ".getType()", "GradOfInputNamed" + in);
741 }
742
743 for (const auto &out : nodeOutputs_) {
744 GN.addInput("OriginalOutputFor" + out.second);
745 GN.addInput("GradOfOriginalOutputNamed" + out.second);
746 }
747
748 // Construct a factory method that builds the new grad node and add
749 // it to the current non-grad instruction.
750
751 std::string decl = name_ + "GradNode *getGrad(GraphGradMapper &builder);\n";
752 std::stringstream ss;
753 ss << "\n" + name_ + "GradNode *" + name_
754 << "Node::getGrad(GraphGradMapper &builder) {\n"
755 << " auto *x = new " + name_ + "GradNode(getName().str() + \"_grad\"";
756
757 if (enum_.size()) {
758 ss << ", (" << name_ << "GradNode::Mode)getMode()";
759 }
760
761 // Add the inputs that we'll use in the grad instruction.
762 for (const std::string &in : nodeInputs_) {
763 ss << ", get" << in << "()";
764 }
765
766 for (const auto &out : nodeOutputs_) {
767 ss << ", get" << out.second << "(), builder.getGradient(get" << out.second
768 << "())";
769 }
770
771 // Extra class members:
772 for (const auto &op : members_) {
773 ss << ", get" << op.second << "()";
774 }
775
776 ss << ");\n";
777
778 // Register the result of the new node as the gradients of the original node
779 // inputs.
780 for (const std::string &in : nodeInputs_) {
781 ss << " builder.addGradient(get" << in << "(), x->getGradOfInputNamed"
782 << in << "());\n";
783 }
784 ss << " return x;\n}\n";
785 addExtraMethod(decl, ss.str());
786
787 return *this;
788}
789
790NodeBuilder::~NodeBuilder() {
791 emitNodeClass(hStream);
792 emitCppMethods(cStream);
793 if (!skipAutogenSerialization_) {
794 emitImportMethods(iStream);
795 emitExportMethods(eStream);
796 }
797}
798