1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "InstrBuilder.h"
18#include "glow/Support/Compiler.h"
19
20unsigned InstrBuilder::getOperandIndexByName(llvm::StringRef name) const {
21 for (unsigned i = 0; i < operands_.size(); i++) {
22 if (name == operands_[i].first) {
23 return i;
24 }
25 }
26
27 llvm_unreachable("Can't find an operand with this name");
28}
29
30void InstrBuilder::emitCtor(std::ostream &os) const {
31 os << " " << name_ << "Inst(llvm::StringRef name";
32
33 // The operands of the instruction class:
34 for (const auto &op : operands_) {
35 os << ", Value *" << op.first;
36 }
37
38 // Extra class members:
39 for (const auto &op : members_) {
40 os << ", " << getStorageTypename(&op.first) << " " << op.second;
41 }
42
43 // Initialize the base clases:
44 os << ")\n : Instruction(name, Kinded::Kind::" << name_ << "InstKind, "
45 << ty_ << ", {\n";
46
47 // The operands of the instruction class:
48 for (const auto &op : operands_) {
49 os << " {" << op.first
50 << ", OperandKind::" << getOperandKindStr(op.second) << "},\n";
51 }
52 os << " })";
53
54 // Initialize the members:
55 for (const auto &op : members_) {
56 os << ", " << op.second << "_(" << op.second << ")";
57 }
58
59 // Empty constructor body.
60 os << " {}\n\n";
61}
62
63void InstrBuilder::emitIRBuilderMethods(std::ostream &osH,
64 std::ostream &osB) const {
65 osH << "\n"
66 << name_ << "Inst *create" << name_ << "Inst(llvm::StringRef name";
67 osB << "\n"
68 << name_ << "Inst *IRBuilder::create" << name_
69 << "Inst(llvm::StringRef name";
70
71 // The operands of the instruction class:
72 for (const auto &op : operands_) {
73 // Scratch operands are not exposed in the builder method interface
74 // but only in the instruction constructor.
75 if (op.second == OperandKind::Scratch) {
76 continue;
77 }
78 osH << ", Value *" << op.first;
79 osB << ", Value *" << op.first;
80 }
81
82 // Extra class members:
83 for (const auto &op : members_) {
84 osH << ", " << getStorageTypename(&op.first) << " " << op.second;
85 osB << ", " << getStorageTypename(&op.first) << " " << op.second;
86 }
87 osH << ");\n";
88 osB << ") {\n";
89
90 // Create allocations for the scratch operands.
91 for (const auto &op : operands_) {
92 if (op.second == OperandKind::Scratch) {
93 std::string allocSuffix = llvm::StringRef(op.first).lower();
94 osB << " std::string " << op.first << "Name = name.str() + \"."
95 << allocSuffix << "\";\n";
96 osB << " auto *" << op.first << "Type = F_->getParent()"
97 << "->uniqueType(ElemKind::Int8QTy, {1}, 0.0, 0);\n";
98 osB << " auto *" << op.first << " = createAllocActivationInst("
99 << op.first << "Name, " << op.first << "Type);\n";
100 }
101 }
102
103 // Initialize the base clases:
104 osB << " auto *A = new " << name_ << "Inst(uniqueName(name)";
105
106 // The operands of the instruction class:
107 for (const auto &op : operands_) {
108 osB << ", " << op.first;
109 }
110 // Extra class members:
111 for (const auto &op : members_) {
112 osB << ", " << op.second;
113 }
114 osB << ");\n";
115
116 // Modify allocation sizes based on the instruction requirements.
117 // We allocate at least 1 byte since the memory allocator does not
118 // handle properly allocation sizes of 0.
119 for (const auto &op : operands_) {
120 if (op.second == OperandKind::Scratch) {
121 // A special case is when the instruction already has a member called
122 // "<Operand>Size" for which we allow a different type than dim_t for
123 // flexibility and hence we create a local cast here to dim_t.
124 osB << " dim_t " << op.first << "SizeVar = static_cast<dim_t>(A->get"
125 << op.first << "Size());\n";
126 osB << " " << op.first << "SizeVar = " << op.first << "SizeVar > 0 ? "
127 << op.first << "SizeVar : 1;\n";
128 osB << " auto *" << op.first << "TypeResized = F_->getParent()"
129 << "->uniqueType(ElemKind::Int8QTy, {" << op.first
130 << "SizeVar}, 0.0, 0);\n";
131 osB << " " << op.first << "->setType(" << op.first << "TypeResized);\n";
132 osB << " " << op.first << "->setTy(" << op.first << "TypeResized);\n";
133 }
134 }
135
136 osB << " F_->pushInstr(A);\n return A;\n}\n";
137}
138
139void InstrBuilder::emitInplaceMethod(std::ostream &os) const {
140 os << "\n bool isInplaceOp(unsigned dstIdx, unsigned srcIdx) const {\n";
141 if (!inplaceOperands_.empty()) {
142 for (const auto &curInplaceOperands : inplaceOperands_) {
143 assert(curInplaceOperands.size() > 1 &&
144 "We don't have a pair of inplace args");
145 for (int i = 1, e = curInplaceOperands.size(); i < e; i++) {
146 auto F0 = getOperandIndexByName(curInplaceOperands[0]);
147 auto F1 = getOperandIndexByName(curInplaceOperands[i]);
148 os << " if (" << F0 << " == dstIdx && " << F1
149 << " == srcIdx) { return true; }\n";
150 }
151 }
152 }
153 os << " return false;\n }\n";
154}
155
156void InstrBuilder::emitCanonicalProperty(std::ostream &os) const {
157 os << "\n bool isCanonical() const {\n";
158 os << " return " << (isBackendSpecific_ ? "false" : "true") << ";\n }\n";
159}
160
161void InstrBuilder::emitDataParallelProperty(std::ostream &os) const {
162 os << "\n bool isDataParallel() const {\n";
163 os << " return " << (isDataParallel_ ? "true" : "false") << ";\n }\n";
164}
165
/// Emits the property query methods of the generated instruction class into
/// \p os, in this order: isInplaceOp(), isCanonical(), isDataParallel().
void InstrBuilder::emitProperties(std::ostream &os) const {
  emitInplaceMethod(os);
  emitCanonicalProperty(os);
  emitDataParallelProperty(os);
}
171
172void InstrBuilder::emitClassMembers(std::ostream &os) const {
173 // Emit class members:
174 for (const auto &op : members_) {
175 os << " " << getStorageTypename(&op.first) << " " << op.second << "_;\n";
176 }
177}
178
179void InstrBuilder::emitOperandGetter(std::ostream &os, const std::string &name,
180 int index) const {
181 // Synthesize the general operand getter.
182 os << " Value *get" << name << "() const { return getOperand(" << index
183 << ").first; }\n";
184}
185
186void InstrBuilder::emitMemberGetter(std::ostream &os,
187 const MemberTypeInfo *type,
188 const std::string &name) const {
189 // Synthesize the general getter.
190 auto returnTypeStr = getReturnTypename(type);
191 os << " " << returnTypeStr << " get" << name << "() const { return " << name
192 << "_; }\n";
193}
194
195void InstrBuilder::emitSettersGetters(std::ostream &os) const {
196 // Print the getters/setters.
197 for (int i = 0, e = operands_.size(); i < e; i++) {
198 auto &op = operands_[i];
199 emitOperandGetter(os, op.first, i);
200 }
201
202 for (const auto &op : members_) {
203 emitMemberGetter(os, &op.first, op.second);
204 }
205
206 // Print size getter declarations for scratch operands. The functions will be
207 // manually implemented by the instruction creator.
208 for (const auto &op : operands_) {
209 if (op.second == OperandKind::Scratch) {
210 // A special case is when the instruction already has a member called
211 // "<Operand>Size" for which a getter was already emitted. We detect this
212 // particular case and not emit the getter again.
213 bool hasScratchSizeMember = false;
214 for (const auto &memb : members_) {
215 if (memb.second == (op.first + "Size")) {
216 hasScratchSizeMember = true;
217 break;
218 }
219 }
220 if (!hasScratchSizeMember) {
221 os << " dim_t get" << op.first << "Size() const;\n";
222 }
223 }
224 }
225
226 // Synthesize the 'classof' method that enables the non-rtti polymorphism.
227 os << "\n static bool classof(const Kinded *k) {\n"
228 << " return k->getKind() == Kinded::Kind::" << name_ << "InstKind;\n"
229 << " }\n";
230}
231
232void InstrBuilder::emitPrettyPrinter(std::ostream &os) const {
233 os << "\nvoid " << name_ << "Inst::dump(llvm::raw_ostream &os) const {\n";
234 os << " os << \"%\" << (std::string) getName() << \" = \" << getKindName()"
235 << " << \" \";\n dumpOperands(os);\n";
236
237 if (!members_.empty()) {
238 os << " os << \" {\"\n";
239 bool first = true;
240 for (const auto &mem : members_) {
241 os << " << \"" << (first ? " " : ", ") << mem.second << ": \" << "
242 << "get" << mem.second << "()\n";
243 first = false;
244 }
245 os << " << \"}\";\n";
246 }
247 os << "}\n";
248}
249
250void InstrBuilder::emitCloner(std::ostream &os) const {
251 os << "\nInstruction* " << name_ << "Inst::clone() const {\n";
252
253 os << " return new " << name_ << "Inst(getName()";
254
255 for (const auto &op : operands_) {
256 os << ", get" << op.first << "()";
257 }
258
259 for (const auto &mem : members_) {
260 os << ", get" << mem.second << "()";
261 }
262
263 os << ");\n}\n";
264}
265
266void InstrBuilder::emitGetOperandName(std::ostream &os) const {
267 os << "\nllvm::StringRef " << name_
268 << "Inst::getOperandName(unsigned idx) const {\n";
269 for (size_t i = 0; i < operands_.size(); i++) {
270 os << " if (idx == " << i << ") { return \"" << operands_[i].first
271 << "\"; }\n";
272 }
273 os << " llvm_unreachable(\"Invalid index\");\n}\n";
274}
275
/// Returns a C++ expression for the element type referenced by \p name.
/// Names starting with "ElemKind::" are element-kind constants and are
/// returned verbatim. Anything else is treated as an operand name and turned
/// into `get<name>()->getElementType()`.
std::string getOpElementType(const std::string &name) {
  static const char kElemKindPrefix[] = "ElemKind::";
  const std::string::size_type prefixLen = sizeof(kElemKindPrefix) - 1;
  const bool isElemKindLiteral =
      name.compare(0, prefixLen, kElemKindPrefix) == 0;
  return isElemKindLiteral ? name : "get" + name + "()->getElementType()";
}
283
284void InstrBuilder::emitClass(std::ostream &os) const {
285 os << "\nnamespace glow {\nclass " << name_
286 << "Inst final : public Instruction {\n";
287
288 emitClassMembers(os);
289
290 os << "\n public:\n";
291
292 emitCtor(os);
293 emitSettersGetters(os);
294 emitProperties(os);
295
296 for (const auto &m : extraMethods_) {
297 os << "\n " << m.first << "\n";
298 }
299
300 os << "\n Instruction* clone() const;\n";
301 os << "\n void dump(llvm::raw_ostream &os) const;\n";
302 os << "\n llvm::StringRef getOperandName(unsigned idx) const;\n";
303
304 // If there is no auto-verification then we assume verification is manually
305 // provided.
306 if (autoVerificationPairs_.empty()) {
307 os << " void verify() const;\n";
308 } else {
309 os << " void verify() const {\n";
310 // Generate auto-verification checks for the current type of the node.
311 for (auto &pair : autoVerificationPairs_) {
312 switch (pair.first) {
313 // Generates a check that two operands of an instruction are of the same
314 // type.
315 case VerifyKind::SameType: {
316 for (size_t i = 1, e = pair.second.size(); i < e; i++) {
317 os << " assert(get" << pair.second[0] << "()->getType() == get"
318 << pair.second[i] << "()->getType() && \"Invalid Type\");\n";
319 }
320 break;
321 }
322 // Generates a check that two operands of an instruction are of the same
323 // shape.
324 case VerifyKind::SameShape: {
325 for (size_t i = 1, e = pair.second.size(); i < e; i++) {
326 os << " assert(get" << pair.second[0] << "()->dims().equals(get"
327 << pair.second[i] << "()->dims()) && \"Invalid Shape\");\n";
328 }
329 break;
330 }
331 // Generates a check that two operands of an instruction have elements of
332 // the same type.
333 case VerifyKind::SameElementType: {
334 auto firstOp = getOpElementType(pair.second[0]);
335 for (size_t i = 1, e = pair.second.size(); i < e; i++) {
336 os << " assert(" << firstOp
337 << " == " << getOpElementType(pair.second[i])
338 << " && \"Invalid Element Type\");\n";
339 }
340 break;
341 }
342 // Generates a check that the type of an operand satisfies a specific
343 // check performed by a predicate method on a type.
344 case VerifyKind::TypeCheck: {
345 for (size_t i = 1, e = pair.second.size(); i < e; i++) {
346 os << " assert(get" << pair.second[0] << "()->getType()->"
347 << pair.second[i] << " && \"Invalid Type\");\n";
348 }
349 break;
350 }
351 // No verification check needs to be generated.
352 case VerifyKind::NoVerify: {
353 assert(autoVerificationPairs_.size() == 1);
354 os << " // Nothing to verify.\n";
355 break;
356 }
357 default:
358 llvm_unreachable("Unknown verification kind.");
359 break;
360 }
361 }
362 os << " }\n";
363 }
364 os << "};\n} // namespace glow\n";
365}
366
367void InstrBuilder::emitCppMethods(std::ostream &os) const {
368 emitPrettyPrinter(os);
369 emitCloner(os);
370 emitGetOperandName(os);
371 // Emit the "extra" method bodies.
372 for (const auto &m : extraMethods_) {
373 os << "\n" << m.second << "\n";
374 }
375}
376
/// Code generation happens on destruction: the builder is configured through
/// its add* methods, and everything is emitted when it goes out of scope.
/// - headerStream receives the class declaration.
/// - cppStream receives the out-of-line method definitions.
/// - builderHeaderStream / builderCppStream receive the IRBuilder factory
///   declaration / definition.
/// - irGenStream receives the IRGen case, only when an auto-IRGen node name
///   was set (emitAutoIRGen returns early otherwise).
InstrBuilder::~InstrBuilder() {
  emitClass(headerStream);
  emitCppMethods(cppStream);
  emitIRBuilderMethods(builderHeaderStream, builderCppStream);
  emitAutoIRGen(irGenStream);
}
383
384void InstrBuilder::addGradientInstr(
385 llvm::ArrayRef<llvm::StringRef> originalFields,
386 llvm::ArrayRef<llvm::StringRef> gradFields) {
387 InstrBuilder GI(headerStream, cppStream, defStream, builderHeaderStream,
388 builderCppStream, irGenStream, name_ + "Grad",
389 isBackendSpecific_);
390
391 // The new 'Grad' class will have all of the fields of the current class.
392 GI.ty_ = ty_;
393 GI.members_ = members_;
394 GI.extraMethods_ = extraMethods_;
395
396 // Add the operands that we'll use in the grad instruction.
397 for (const auto &op : operands_) {
398 for (const auto &field : originalFields) {
399 if (field == op.first) {
400 // We may only read from the original weight operands.
401 GI.addOperand(op.first, OperandKind::In);
402 }
403 }
404 }
405
406 // Add the new 'grad' operands for the gradients.
407 for (const auto &op : operands_) {
408 for (const auto &field : gradFields) {
409 if (field == op.first) {
410 GI.addOperand(op.first + "Grad", negateOperandKind(op.second));
411 }
412 }
413 }
414
415 // Copy over the auto-verify information, updating the operand names for
416 // gradient.
417 for (auto &verifPair : autoVerificationPairs_) {
418 auto newPair = std::make_pair(verifPair.first, std::vector<std::string>());
419 for (auto &opName : verifPair.second) {
420 if (std::find(gradFields.begin(), gradFields.end(), opName) !=
421 gradFields.end()) {
422 newPair.second.push_back(opName + "Grad");
423 }
424 if (std::find(originalFields.begin(), originalFields.end(), opName) !=
425 originalFields.end()) {
426 newPair.second.push_back(opName);
427 }
428 }
429 GI.autoVerificationPairs_.push_back(newPair);
430 }
431}
432
433void InstrBuilder::emitAutoIRGen(std::ostream &os) const {
434 if (autoIRGenNodeName.empty()) {
435 return;
436 }
437
438 os << "case glow::Kinded::Kind::" << autoIRGenNodeName << "NodeKind: {\n";
439 os << " auto *CN__ = cast<" << autoIRGenNodeName << "Node>(N);\n";
440
441 // Note: The convention is for Nodes to have 'Input's and 'Output's, and for
442 // Instrs to have 'Src's and 'Dest's. Thus we map between the two below.
443
444 // A list of pairs (nodeResultName, destOpName).
445 llvm::SmallVector<std::pair<std::string, std::string>, 4>
446 nodeResultNameToValueName;
447 for (const auto &opPair : operands_) {
448 // Skip the scratch operands for this instruction since they are not
449 // registered as node operands.
450 if (opPair.second == OperandKind::Scratch) {
451 continue;
452 }
453 if (opPair.second == OperandKind::In) {
454 // All inputs of a node were mapped to the glow::Values already.
455 // So, just lookup for each input operand it's Value by using the
456 // corresponding input of a node as a key.
457 const std::string opNodeName =
458 (opPair.first == "Src") ? "Input" : opPair.first;
459 os << " auto *" << opPair.first << " = valueForNode(CN__->get"
460 << opNodeName << "());\n";
461 } else if (opPair.second == OperandKind::Out) {
462 // Remember for each output operand which result of a node produces it.
463 auto destOpName = opPair.first;
464 auto resNodeName = (destOpName == "Dest") ? "Result" : destOpName;
465 nodeResultNameToValueName.emplace_back(
466 std::make_pair(resNodeName, destOpName));
467 }
468 }
469
470 assert(!nodeResultNameToValueName.empty() &&
471 "Didn't find a result; Maybe using InOut which isn't yet supported");
472 os << " std::string allocName = std::string(N->getName()) + \".res\";\n";
473 // Allocate activations for all output operands.
474 for (auto &kv : nodeResultNameToValueName) {
475 auto &nodeResultName = kv.first;
476 auto &valueName = kv.second;
477 // Create activation for the output operand with name valueName using the
478 // type of the corresponding node result nodeResultName.
479 os << " auto *" << valueName
480 << "__ = builder_.createAllocActivationInst(allocName,"
481 << "CN__->get" << nodeResultName << "().getType());\n";
482 }
483 os << " auto *V = builder_.create" << name_ << "Inst(N->getName()";
484
485 // Pass down all the output operand Values as Instruction's constructor
486 // arguments.
487 for (auto &kv : nodeResultNameToValueName) {
488 auto &valueName = kv.second;
489 os << ", " << valueName << "__";
490 }
491 // Pass down all the input operand Values as Instruction's constructor
492 // arguments.
493 for (const auto &opPair : operands_) {
494 if (opPair.second == OperandKind::In) {
495 os << ", " << opPair.first;
496 }
497 }
498 // Pass down all the additional members as Instruction's constructor
499 // arguments.
500 for (const auto &memPair : members_) {
501 os << ", CN__->get" << memPair.second << "()";
502 }
503 os << ");\n";
504
505 os << " if (N->hasPredicate()) { "
506 "V->setPredicate(valueForNode(N->getPredicate())); }\n";
507 // Register which outputs of a node are mapped to which output operands of the
508 // generated instruction.
509 for (auto &kv : nodeResultNameToValueName) {
510 auto &nodeResultName = kv.first;
511 auto &valueName = kv.second;
512 os << " registerIR(CN__->get" << nodeResultName << "(), V->get"
513 << valueName << "());\n";
514 }
515 os << " nodeToInstr_[N] = V;\n";
516 os << " break;\n";
517 os << "}\n";
518}
519
520InstrBuilder &InstrBuilder::addMember(MemberType type,
521 const std::string &name) {
522 MemberTypeInfo *typeInfo = nullptr;
523
524 if (type == MemberType::TypeRef) {
525 typeInfo = &kTypeRefTypeInfo;
526 } else if (type == MemberType::Float) {
527 typeInfo = &kFloatTypeInfo;
528 } else if (type == MemberType::Unsigned) {
529 typeInfo = &kUnsignedTypeInfo;
530 } else if (type == MemberType::Boolean) {
531 typeInfo = &kBooleanTypeInfo;
532 } else if (type == MemberType::Int64) {
533 typeInfo = &kInt64TypeInfo;
534 } else if (type == MemberType::String) {
535 typeInfo = &kStringTypeInfo;
536 } else if (type == MemberType::VectorFloat) {
537 typeInfo = &kVectorFloatTypeInfo;
538 } else if (type == MemberType::VectorUnsigned) {
539 typeInfo = &kVectorUnsignedTypeInfo;
540 } else if (type == MemberType::VectorInt64) {
541 typeInfo = &kVectorInt64TypeInfo;
542 } else if (type == MemberType::VectorSigned) {
543 typeInfo = &kVectorSignedTypeInfo;
544 } else if (type == MemberType::VectorSizeT) {
545 typeInfo = &kVectorSizeTTypeInfo;
546 } else if (type == MemberType::VectorDimT) {
547 typeInfo = &kVectorDimTTypeInfo;
548 } else if (type == MemberType::VectorNodeValue) {
549 typeInfo = &kVectorNodeValueTypeInfo;
550 } else if (type == MemberType::Enum) {
551 typeInfo = &kEnumTypeInfo;
552 } else if (type == MemberType::UserDefinedType) {
553 llvm_unreachable("addMember should be called with a MemberTypeInfo "
554 "parameter in this case");
555 } else {
556 llvm_unreachable("Type not recognized");
557 }
558
559 return addMember(*typeInfo, name);
560}
561
562InstrBuilder &InstrBuilder::addFusedActivation() {
563 // When adding a fused activation we add the activation type and a vector of
564 // floating point parameters for parameterized activations (e.g. min and max
565 // for Clip or alpha factor for LeakyRelu).
566 return addMember(MEMBER_TYPE_INFO(glow::FusedActivation), "FusedActivation")
567 .addMember(MemberType::VectorFloat, "FusedActivationArgs");
568}
569