/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/LLVMIRCodeGen/LLVMIRGen.h"
#include "glow/Base/DimType.h"

#include "glow/LLVMIRCodeGen/AllocationsInfo.h"
#include "glow/LLVMIRCodeGen/CommandLine.h"
#include "glow/LLVMIRCodeGen/LLVMBackend.h"

#include "glow/Graph/Graph.h"
#include "glow/IR/IRUtils.h"
#include "glow/IR/Instrs.h"
#include "glow/Quantization/Base/Base.h"

#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"

using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

static llvm::cl::opt<bool>
    dumpIR("dump-llvm-ir",
           llvm::cl::desc("Dump the LLVM-IR of the jitted code"),
           llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));

static llvm::cl::opt<bool>
    dumpJitAsm("dump-llvm-asm",
               llvm::cl::desc("Dump the textual assembly of the jitted code"),
               llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));

llvm::cl::opt<bool>
    emitDebugInfo("g", llvm::cl::desc("Emit debug information for debuggers"),
                  llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));
/// Limit on the number of arguments accepted by `emitDataParallelKernel`.
constexpr static size_t kArgLimit = 64;

/// Query the TargetMachine to get the pointer size in bits.
static unsigned getPointerNumBits(const llvm::TargetMachine &TM) {
  return TM.getPointerSize(0) * 8;
}

LLVMIRGen::LLVMIRGen(const IRFunction *F, AllocationsInfo &allocationsInfo,
                     std::string mainEntryName, llvm::StringRef libjitBC)
    : F_(F), ctx_(std::make_unique<llvm::LLVMContext>()),
      allocationsInfo_(allocationsInfo), libjitBC_(libjitBC) {
  // Legalize main entry name.
  setMainEntryName(mainEntryName);
}

LLVMIRGen::LLVMIRGen(const IRFunction *F, AllocationsInfo &allocationsInfo,
                     std::string mainEntryName, llvm::StringRef libjitBC,
                     llvm::ArrayRef<llvm::MemoryBufferRef> objectRegistry)
    : F_(F), ctx_(std::make_unique<llvm::LLVMContext>()),
      allocationsInfo_(allocationsInfo), libjitBC_(libjitBC),
      objectRegistry_(objectRegistry) {
  // Legalize main entry name.
  setMainEntryName(mainEntryName);
}

/// Mutex to protect LLVM's TargetRegistry.
static std::mutex initTargetMutex;

void LLVMIRGen::initTargetOptions(llvm::TargetOptions &targetOpts,
                                  const LLVMBackendOptions &backendOpts) {
  if (backendOpts.getFloatABI().hasValue()) {
    targetOpts.FloatABIType = backendOpts.getFloatABI().getValue();
  }
  if (!backendOpts.getABIName().empty()) {
    targetOpts.MCOptions.ABIName = backendOpts.getABIName();
  }
}

void LLVMIRGen::initTargetMachine(const LLVMBackendOptions &opts) {
  // LLVM's TargetRegistry is not thread safe so we add a critical section.
  std::lock_guard<std::mutex> g(initTargetMutex);

  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmPrinters();
  llvm::InitializeAllAsmParsers();

  llvm::TargetOptions targetOpts;
  // Initialize target options in a backend-specific way.
  initTargetOptions(targetOpts, opts);

  if (opts.getTarget().empty()) {
    TM_.reset(llvm::EngineBuilder()
                  .setCodeModel(opts.getCodeModel())
                  .setRelocationModel(opts.getRelocModel())
                  .setTargetOptions(targetOpts)
                  .selectTarget(llvm::Triple(), opts.getArch(),
                                LLVMBackend::getHostCPU(),
                                LLVMBackend::getHostFeatures()));
  } else {
    TM_.reset(llvm::EngineBuilder()
                  .setCodeModel(opts.getCodeModel())
                  .setRelocationModel(opts.getRelocModel())
                  .setTargetOptions(targetOpts)
                  .selectTarget(llvm::Triple(opts.getTarget()), opts.getArch(),
                                opts.getCPU(), opts.getTargetFeatures()));
  }
  assert(TM_ && "Could not initialize the target machine");
}

llvm::StringRef LLVMIRGen::getBundleName() const { return bundleName_; }

void LLVMIRGen::setBundleName(const std::string &name) {
  bundleName_ = name.empty() ? "bundle" : legalizeName(name);
}

llvm::StringRef LLVMIRGen::getSavedBundleName() const {
  return savedBundleName_;
}

void LLVMIRGen::setSavedBundleName(const std::string &name) {
  assert(!name.empty() && "Name cannot be empty");
  savedBundleName_ = name;
}

std::string LLVMIRGen::getMainEntryName() const { return mainEntryName_; }

void LLVMIRGen::setMainEntryName(std::string name) {
  mainEntryName_ = name.empty() ? "main" : legalizeName(name);
}

llvm::ArrayRef<llvm::MemoryBufferRef> LLVMIRGen::getObjectRegistry() const {
  return objectRegistry_;
}

void LLVMIRGen::setObjectRegistry(
    llvm::ArrayRef<llvm::MemoryBufferRef> objectRegistry) {
  objectRegistry_ = objectRegistry;
}

std::vector<std::string> LLVMIRGen::getBundleObjects() const {
  // Start with the default list of object names.
  auto bundleObjects = bundleObjects_;
  // Add the object names requested via the command-line interface.
  for (auto bundleObject : bundleObjectsOpt) {
    bundleObjects.push_back(bundleObject);
  }
  return bundleObjects;
}

void LLVMIRGen::addBundleObject(llvm::StringRef objectName) {
  // Add the bundle object only if it was not already added.
  auto it =
      std::find(bundleObjects_.begin(), bundleObjects_.end(), objectName.str());
  if (it == bundleObjects_.end()) {
    bundleObjects_.push_back(objectName.str());
  }
}

/// Load base addresses of different memory areas so that they can be easily
/// reused during codegen.
void LLVMIRGen::loadBaseAddresses(llvm::IRBuilder<> &builder) {
  auto *F = builder.GetInsertBlock()->getParent();

  // Load the base addresses at the beginning of the entry function once they
  // are set. They won't change after this point and all relative addressing
  // computations will simply use them.
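  // For reference, the entry function arguments map as follows (matching the
  // entry API documented in performCodeGen): arg 0 is the constant weights
  // base, arg 1 the mutable weights base, arg 2 the activations base, and
  // arg 3 the offsets array.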
  auto sizeTTy = builder.getIntNTy(getLibjitSizeTWidth());
  baseActivationsAddr_ = builder.CreatePtrToInt(F->args().begin() + 2, sizeTTy);
  baseConstantWeightVarsAddr_ =
      builder.CreatePtrToInt(F->args().begin(), sizeTTy);
  baseMutableWeightVarsAddr_ =
      builder.CreatePtrToInt(F->args().begin() + 1, sizeTTy);
  offsetsArray_ = F->args().begin() + 3;
}

// Load the compiled-in bitcode image of the standard library (libjit) into an
// LLVM module. \p filename is used only as the buffer name reported in
// diagnostics; the actual bitcode comes from the \p libjitBC image linked
// into the compiler.
static std::unique_ptr<llvm::Module>
loadStandardLibrary(llvm::LLVMContext *ctx, llvm::StringRef filename,
                    llvm::StringRef libjitBC) {
  llvm::SMDiagnostic error;

  // Parse the compiled-in image of libjit and return the resulting Module,
  // checking for and reporting any errors from parseIR.
  auto mod = llvm::parseIR(
      llvm::MemoryBufferRef(
          llvm::StringRef(reinterpret_cast<const char *>(libjitBC.data()),
                          libjitBC.size()),
          filename),
      error, *ctx);

  if (!mod) {
    error.print("LLVMIRGen", llvm::errs());
  }
  return mod;
}

/// Register a diagnostics handler that prevents the compiler from printing to
/// stdout.
static void registerEmptyDiagHandler(llvm::LLVMContext &ctx) {
  ctx.setDiagnosticHandlerCallBack(
      [](const llvm::DiagnosticInfo &DI, void *Context) {
        // Do not emit any warnings or diagnostics when JITting.
      });
}

void LLVMIRGen::initCodeGen() {
  // Load the jit library as a new module.
  llmodule_ = loadStandardLibrary(&getLLVMContext(), "libjit.bc", libjitBC_);
  CHECK(llmodule_.get()) << "Unable to load the JIT library.";

  // By default, LLVM would emit some diagnostics, remarks, etc. That is fine
  // for a static compiler, but not necessary for a JIT, so we disable it by
  // installing a dummy diagnostics handler that does not emit anything. In
  // particular, this gets rid of the noisy "cannot vectorize" warnings.
  registerEmptyDiagHandler(getLLVMContext());

  // Assign the target information to the module.
  llmodule_->setDataLayout(getTargetMachine().createDataLayout());

  // Initialize the debug information emission.
  initDebugInfo();
}

/// \returns the LLVM type corresponding to the type of elements stored in \p
/// val.
llvm::Type *LLVMIRGen::getElementType(llvm::IRBuilder<> &builder,
                                      const Value *val) {
  switch (val->getElementType()) {
  case ElemKind::Int64ITy:
    return builder.getInt64Ty();
  case ElemKind::FloatTy:
    return builder.getFloatTy();
  case ElemKind::Float16Ty:
    llvm_unreachable("Not implemented");
  case ElemKind::BFloat16Ty:
    llvm_unreachable("Not implemented");
  case ElemKind::Float64Ty:
    return builder.getDoubleTy();
  case ElemKind::Int8QTy:
    return builder.getInt8Ty();
  case ElemKind::UInt8QTy:
    llvm_unreachable("Not implemented");
  case ElemKind::Int16QTy:
    return builder.getInt16Ty();
  case ElemKind::Int32QTy:
    return builder.getInt32Ty();
  case ElemKind::Int64QTy:
    return builder.getInt64Ty();
  case ElemKind::UInt8ITy:
    return builder.getInt8Ty();
  case ElemKind::Int32ITy:
    return builder.getInt32Ty();
  case ElemKind::UInt8FusedQTy:
    return builder.getInt8Ty();
  case ElemKind::UInt8FusedFP16QTy:
    return builder.getInt8Ty();
  case ElemKind::UInt4FusedFP16QTy:
    return builder.getInt8Ty();
  case ElemKind::UInt4FusedQTy:
    return builder.getInt8Ty();
  case ElemKind::BoolTy:
    static_assert(sizeof(bool) == sizeof(int8_t),
                  "Bool is expected to be the same size as int8.");
    return builder.getInt8Ty();
  }
  return nullptr;
}

void LLVMIRGen::performCodeGen() {
  // Create the entry function into the LLVM module.
  auto int8PtrTy = llvm::Type::getInt8PtrTy(getLLVMContext());
  auto dimTPtrTy = llvm::Type::getIntNPtrTy(getLLVMContext(), DIM_T_BITWIDTH);
  // The entry point has the following API:
  // int entry(uint8_t *baseConstantWeightVars,
  //           uint8_t *baseInoutWeightVars,
  //           uint8_t *baseActivations,
  //           dim_t *offsets);
  llvm::Type *retTy =
      llvm::Type::getIntNTy(getLLVMContext(), getLibjitIntWidth());
  llvm::FunctionType *jitFuncTy = llvm::FunctionType::get(
      retTy, {int8PtrTy, int8PtrTy, int8PtrTy, dimTPtrTy}, false);
  llvmF_ = llvm::Function::Create(jitFuncTy, llvm::Function::ExternalLinkage,
                                  "main", llmodule_.get());
  emittedLLVMFunctions_.emplace_back(llvmF_);

  // Setup the entry basic block and initialize the IR builder.
  llvm::BasicBlock *entry_bb =
      llvm::BasicBlock::Create(getLLVMContext(), "entry", llvmF_);
  builder_ = glow::make_unique<llvm::IRBuilder<>>(entry_bb);
  // Terminate the function with a return instruction.
  auto zero = builder_->getIntN(getLibjitIntWidth(), 0);
  auto *ret = builder_->CreateRet(zero);
  // Emit all the code before the return instruction.
  builder_->SetInsertPoint(ret);

  instrNumbering_.reset(new InstructionNumbering(*F_));
  generateFunctionDebugInfo();
  loadBaseAddresses(*builder_);
  generateLLVMIRForModule(*builder_);
}
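
// An illustrative sketch of the skeleton produced by performCodeGen (names
// and integer widths may differ depending on the target configuration):
//
//   define i32 @main(i8* %constWeights, i8* %mutableWeights,
//                    i8* %activations, i64* %offsets) {
//   entry:
//     ; ...generated code is inserted here, before the terminator...
//     ret i32 0
//   }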

void LLVMIRGen::finishCodeGen() {
  if (dumpIR) {
    llvm::outs() << "LLVM module before optimizations:\n";
    llmodule_->print(llvm::outs(), nullptr);
  }
  // Perform verification if no debug info is being emitted.
  // Otherwise, the verification is performed later by
  // generateDebugInfo, once the debug info emission is finalized.
  if (!emitDebugInfo) {
    // Perform verification, but ignore any debug info errors for now.
    // Debug info errors will be checked later by generateDebugInfo.
    bool brokenDebugInfo = false;
    (void)brokenDebugInfo;
    assert(!llvm::verifyModule(getModule(), &llvm::errs(), &brokenDebugInfo) &&
           "LLVM module verification error");
  }

  // Optimize the module.
  optimizeLLVMModule(&getModule(), getTargetMachine());

  // Generate debug information.
  generateModuleDebugInfo();

  if (dumpIR) {
    llvm::outs() << "LLVM module after optimizations:\n";
    llmodule_->print(llvm::outs(), nullptr);
  }

  if (dumpJitAsm) {
    llvm::SmallVector<char, 0> asmBuffer;
    llvm::raw_svector_ostream asmStream(asmBuffer);
    llvm::legacy::PassManager PM;
#if FACEBOOK_INTERNAL && LLVM_VERSION_MAJOR < 8
    getTargetMachine().addPassesToEmitFile(
        PM, asmStream, llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#elif LLVM_VERSION_MAJOR < 10
    getTargetMachine().addPassesToEmitFile(
        PM, asmStream, nullptr,
        llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#else
    getTargetMachine().addPassesToEmitFile(PM, asmStream, nullptr,
                                           llvm::CGFT_AssemblyFile);
#endif

    PM.run(*llmodule_);
    llvm::outs() << asmStream.str();
  }
}

llvm::Value *LLVMIRGen::emitValueAddress(llvm::IRBuilder<> &builder,
                                         const glow::Value *val) {
  assert(allocationsInfo_.allocatedAddress_.count(val) &&
         "Value address was not allocated");
  llvm::Type *T = nullptr;

  switch (val->getElementType()) {
  case ElemKind::FloatTy:
    T = llvm::Type::getFloatPtrTy(getLLVMContext());
    break;
  case ElemKind::Float16Ty:
    T = llvm::Type::getInt16PtrTy(getLLVMContext());
    break;
  case ElemKind::BFloat16Ty:
    T = llvm::Type::getInt16PtrTy(getLLVMContext());
    break;
  case ElemKind::Int8QTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt8QTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::Int16QTy:
    T = llvm::Type::getInt16PtrTy(getLLVMContext());
    break;
  case ElemKind::Int32QTy:
    T = llvm::Type::getInt32PtrTy(getLLVMContext());
    break;
  case ElemKind::Int64QTy:
    T = llvm::Type::getInt64PtrTy(getLLVMContext());
    break;
  case ElemKind::Int64ITy:
    T = llvm::Type::getInt64PtrTy(getLLVMContext());
    break;
  case ElemKind::Int32ITy:
    T = llvm::Type::getInt32PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt8ITy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt8FusedQTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt8FusedFP16QTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt4FusedFP16QTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::UInt4FusedQTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  case ElemKind::BoolTy:
    T = llvm::Type::getInt8PtrTy(getLLVMContext());
    break;
  default:
    LOG(FATAL) << "Unsupported element type: "
               << Type::getElementName(val->getElementType()).str();
  }

  assert(allocationsInfo_.valueNumbers_.count(val));
  auto &kindAndValue = allocationsInfo_.valueNumbers_[val];

  // Get the required base address.
  llvm::Value *baseAddrValue = nullptr;
  switch (kindAndValue.first) {
  case AllocationsInfo::ValueKind::Activation:
    baseAddrValue = baseActivationsAddr_;
    break;
  case AllocationsInfo::ValueKind::ConstantWeight:
    baseAddrValue = baseConstantWeightVarsAddr_;
    break;
  case AllocationsInfo::ValueKind::MutableWeight:
    baseAddrValue = baseMutableWeightVarsAddr_;
    break;
  }

  // Use relative addressing.
  // Get offset.
  auto sizeTTy = builder.getIntNTy(getLibjitSizeTWidth());
  auto dimTTy = builder.getIntNTy(DIM_T_BITWIDTH);

  auto valueIdx = llvm::ConstantInt::get(dimTTy, kindAndValue.second);
  auto offsetAddr = builder.CreateGEP(dimTTy, offsetsArray_, valueIdx);
  auto offsetValue = builder.CreateLoad(dimTTy, offsetAddr);
  // Add offset to the base address.
  llvm::Value *addr = builder.CreateAdd(
      baseAddrValue, builder.CreateZExt(offsetValue, sizeTTy));
  return builder.CreateIntToPtr(addr, T);
}
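
// Illustrative example of the addressing scheme above: for a value with
// number N living in the activations segment, the emitted IR is equivalent to
//   ElemTy *addr = (ElemTy *)(baseActivationsAddr + offsetsArray[N]);
// where offsetsArray is the constant array produced by emitConstOffsetsArray
// below and passed to the entry function as its fourth argument.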

llvm::Value *
LLVMIRGen::emitConstOffsetsArray(llvm::IRBuilder<> &builder,
                                 const AllocationsInfo &allocationsInfo) {
  constexpr const char *offsetsArrayName = "offsetsArray";
  auto dimTType = builder.getIntNTy(DIM_T_BITWIDTH);
  std::vector<llvm::Constant *> elems(allocationsInfo.valueNumbers_.size());
  dim_t maxOffset = 0;
  for (auto &I : allocationsInfo.valueNumbers_) {
    auto *V = I.first;
    auto offset = I.second.second;
    elems[offset] = llvm::ConstantInt::get(
        dimTType, allocationsInfo.allocatedAddress_.lookup(V));
    maxOffset = std::max(maxOffset, (dim_t)offset);
  }
  elems.resize(maxOffset + 1);
  auto *arr = llvm::ConstantArray::get(
      llvm::ArrayType::get(dimTType, elems.size()), elems);
  // Ensure that the same casted global variable is used for equivalent const
  // arrays. This is important for the later function specialization pass.
  // LLVM does not do it automatically for this code pattern involving global
  // variables. It also reduces the number of variables.
  auto &constArrayVar = constArrayPtrs_[arr];
  auto oldG =
      getModule().getGlobalVariable(offsetsArrayName, /* allowInternal */ true);
  if (constArrayVar && constArrayVar->getType() == dimTType->getPointerTo()) {
    return constArrayVar;
  }
  if (oldG) {
    oldG->setName("offsetsArrayOld");
  }
  auto *M = builder.GetInsertBlock()->getModule();
  auto *G = new llvm::GlobalVariable(*M, arr->getType(), true,
                                     llvm::GlobalValue::InternalLinkage, arr,
                                     offsetsArrayName);
  constArrayVar = builder.CreateBitCast(G, dimTType->getPointerTo());
  if (oldG) {
    // Replace the old offsetsArray with the new one and remove the old one.
    oldG->replaceAllUsesWith(G);
    oldG->eraseFromParent();
  }
  return constArrayVar;
}

llvm::Value *LLVMIRGen::emitConstI32Array(llvm::IRBuilder<> &builder,
                                          llvm::ArrayRef<int32_t> vals) {
  std::vector<llvm::Constant *> elems;
  for (auto I : vals) {
    elems.push_back(builder.getInt32(I));
  }
  return emitConstArray(builder, elems, builder.getInt32Ty());
}

llvm::Value *LLVMIRGen::emitConstFloatArray(llvm::IRBuilder<> &builder,
                                            llvm::ArrayRef<float> vals) {
  std::vector<llvm::Constant *> elems;
  for (auto I : vals) {
    elems.push_back(llvm::ConstantFP::get(
        llvm::Type::getFloatTy(getLLVMContext()), (float)I));
  }
  return emitConstArray(builder, elems,
                        llvm::Type::getFloatTy(getLLVMContext()));
}

llvm::Value *LLVMIRGen::emitConstArray(llvm::IRBuilder<> &builder,
                                       llvm::ArrayRef<llvm::Constant *> vals,
                                       llvm::Type *elemTy) {
  std::vector<llvm::Constant *> elems;
  for (auto I : vals) {
    elems.push_back(cast<llvm::Constant>(builder.CreateBitCast(I, elemTy)));
  }
  auto *arr = llvm::ConstantArray::get(
      llvm::ArrayType::get(elemTy, elems.size()), elems);
  // Ensure that the same casted global variable is used for equivalent const
  // arrays. This is important for the later function specialization pass.
  // LLVM does not do it automatically for this code pattern involving global
  // variables. It also reduces the number of variables.
  auto &constArrayVar = constArrayPtrs_[arr];
  if (constArrayVar && constArrayVar->getType() == elemTy->getPointerTo()) {
    return constArrayVar;
  }

  auto *M = builder.GetInsertBlock()->getModule();

  auto *G = new llvm::GlobalVariable(*M, arr->getType(), true,
                                     llvm::GlobalValue::InternalLinkage, arr);
  constArrayVar = builder.CreateBitCast(G, elemTy->getPointerTo());
  return constArrayVar;
}
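
// Note on uniquing: requests for byte-identical constant arrays of the same
// element type return the same bitcast global, because LLVM uniques the
// ConstantArray used as the map key. For example (illustrative), calling
// emitConstFloatArray(builder, {1.0f, 2.0f}) twice yields the same pointer,
// which the later function specialization pass relies on.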

void LLVMIRGen::emitArrayStore(llvm::IRBuilder<> &builder,
                               llvm::ArrayRef<llvm::Value *> vals,
                               llvm::Value *basePtr, unsigned baseIdx) {
  for (size_t idx = 0, end = vals.size(); idx < end; ++idx) {
    assert(vals[idx]->getType()->getPointerTo() == basePtr->getType() &&
           "Mismatch between pointer and value type!");
    auto *storeIdx = builder.getInt32(idx + baseIdx);
    auto *storeAddr = builder.CreateGEP(basePtr, storeIdx);
    builder.CreateStore(vals[idx], storeAddr);
  }
}

llvm::Value *LLVMIRGen::emitValueDims(llvm::IRBuilder<> &builder,
                                      const glow::Value *val) {
  auto dims = val->dims();
  return emitConstDimTArray(builder, dims);
}

template <class InstructionTy>
llvm::Value *LLVMIRGen::emitConstFloatActivationArgs(llvm::IRBuilder<> &builder,
                                                     const InstructionTy *I) {
  return emitConstFloatArray(builder, I->getFusedActivationArgs());
}

template <class InstructionTy>
llvm::Value *LLVMIRGen::emitConstQuantActivationArgs(llvm::IRBuilder<> &builder,
                                                     const InstructionTy *I) {
  auto actArgsF = I->getFusedActivationArgs();
  std::vector<int32_t> actArgsQ;
  auto *destTy = I->getDest()->getType();
  switch (I->getFusedActivation()) {
  case FusedActivation::NONE:
  case FusedActivation::RELU:
    assert(actArgsF.size() == 0 && "Invalid number of activation parameters!");
    break;
  case FusedActivation::CLIP: {
    // For Clip we quantize min/max using the output quantization params.
    assert(actArgsF.size() == 2 &&
           "Invalid number of parameters for fused Clip activation!");
    float minF = actArgsF[0];
    float maxF = actArgsF[1];
    TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()};
    int32_t minQ = quantization::quantize<int32_t>(minF, TQP);
    int32_t maxQ = quantization::quantize<int32_t>(maxF, TQP);
    actArgsQ.push_back(minQ);
    actArgsQ.push_back(maxQ);
    break;
  }
  case FusedActivation::SIGMOID:
    LOG(FATAL) << "Fused Sigmoid for quantized type not supported!";
    break;
  case FusedActivation::TANH:
    LOG(FATAL) << "Fused Tanh for quantized type not supported!";
    break;
  case FusedActivation::LEAKY_RELU: {
    // For LeakyRelu we transform the alpha parameter into pre/post/scale.
    assert(actArgsF.size() == 1 &&
           "Invalid number of parameters for fused LeakyRelu activation!");
    float alpha = actArgsF[0];
    auto alphaScaleParam = quantization::quantizeScaleOffset32To8(alpha, 0);
    actArgsQ.push_back(alphaScaleParam.pre);
    actArgsQ.push_back(alphaScaleParam.post);
    actArgsQ.push_back(alphaScaleParam.scale);
    break;
  }
  default:
    LOG(FATAL) << "Unsupported fused activation type!";
  }
  return emitConstI32Array(builder, actArgsQ);
}

llvm::Value *LLVMIRGen::emitValueSize(llvm::IRBuilder<> &builder,
                                      const glow::Value *val) {
  return builder.getIntN(DIM_T_BITWIDTH, val->size());
}

llvm::Value *LLVMIRGen::emitConstF32(llvm::IRBuilder<> &builder, float val) {
  return llvm::ConstantFP::get(llvm::Type::getFloatTy(getLLVMContext()), val);
}

llvm::Value *LLVMIRGen::emitConstI32(llvm::IRBuilder<> &builder, int32_t val) {
  return builder.getInt32(val);
}

llvm::Value *LLVMIRGen::emitConstI16(llvm::IRBuilder<> &builder, int16_t val) {
  return builder.getInt16(val);
}

llvm::Value *LLVMIRGen::emitConstI8(llvm::IRBuilder<> &builder, int8_t val) {
  return builder.getInt8(val);
}

llvm::Value *LLVMIRGen::emitConstI1(llvm::IRBuilder<> &builder, bool val) {
  return builder.getInt1(val);
}

llvm::Value *LLVMIRGen::emitConstSizeT(llvm::IRBuilder<> &builder, size_t val) {
  return builder.getIntN(getLibjitSizeTWidth(), val);
}

llvm::Value *LLVMIRGen::emitConstDimT(llvm::IRBuilder<> &builder, dim_t val) {
  return builder.getIntN(sizeof(dim_t) * 8, val);
}

llvm::Value *LLVMIRGen::emitConst(llvm::IRBuilder<> &builder, float val,
                                  glow::ElemKind kind) {
  switch (kind) {
  case ElemKind::FloatTy:
    return llvm::ConstantFP::get(llvm::Type::getFloatTy(getLLVMContext()), val);
  case ElemKind::Float16Ty:
    llvm_unreachable("Not implemented");
  case ElemKind::BFloat16Ty:
    llvm_unreachable("Not implemented");
  case ElemKind::Float64Ty:
    return llvm::ConstantFP::get(llvm::Type::getDoubleTy(getLLVMContext()),
                                 val);
  case ElemKind::Int64ITy:
    return builder.getInt64(static_cast<int64_t>(val));
  case ElemKind::Int8QTy:
    return builder.getInt8(static_cast<int8_t>(val));
  case ElemKind::UInt8QTy:
    llvm_unreachable("Not implemented");
  case ElemKind::Int16QTy:
    return builder.getInt16(static_cast<int16_t>(val));
  case ElemKind::Int32QTy:
    return builder.getInt32(static_cast<int32_t>(val));
  case ElemKind::UInt8ITy:
    return builder.getInt8(static_cast<uint8_t>(val));
  case ElemKind::Int64QTy:
    return builder.getInt64(static_cast<int64_t>(val));
  case ElemKind::Int32ITy:
    return builder.getInt32(static_cast<int32_t>(val));
  case ElemKind::UInt8FusedQTy:
    return builder.getInt8(static_cast<int8_t>(val));
  case ElemKind::UInt8FusedFP16QTy:
    return builder.getInt8(static_cast<int8_t>(val));
  case ElemKind::UInt4FusedFP16QTy:
    return builder.getInt8(static_cast<int8_t>(val));
  case ElemKind::UInt4FusedQTy:
    return builder.getInt8(static_cast<int8_t>(val));
  case ElemKind::BoolTy:
    return builder.getInt8(static_cast<int8_t>(val));
  }
  llvm_unreachable("Unknown element type");
}

llvm::Value *LLVMIRGen::emitStringConst(llvm::IRBuilder<> &builder,
                                        llvm::StringRef str) {
  llvm::Constant *constStrArray =
      llvm::ConstantDataArray::getString(getLLVMContext(), str, true);
  llvm::GlobalVariable *gvarStr = new llvm::GlobalVariable(
      *llmodule_, constStrArray->getType(), true,
      llvm::GlobalValue::PrivateLinkage, constStrArray, ".str");
#if LLVM_VERSION_MAJOR >= 10
  gvarStr->setAlignment(llvm::MaybeAlign(1));
#else
  gvarStr->setAlignment(1);
#endif
  // Add the unnamed_addr attribute to enable the constmerge pass.
  gvarStr->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);

  return builder.CreateBitCast(gvarStr, builder.getInt8PtrTy());
}

void LLVMIRGen::markArgAsUnspecialized(llvm::Value *val) {
  dontSpecializeArgsSet_.insert(val);
}

static std::string createName(const std::string &name, ElemKind elemTy) {
  switch (elemTy) {
  case ElemKind::FloatTy:
    return name + "_f";
  case ElemKind::Float16Ty:
    return name + "_fp16";
  case ElemKind::BFloat16Ty:
    return name + "_bfloat16";
  case ElemKind::Int8QTy:
    return name + "_i8";
  case ElemKind::Int16QTy:
    return name + "_i16";
  case ElemKind::Int32QTy:
    return name + "_i32";
  case ElemKind::Int32ITy:
    return name + "_i32";
  case ElemKind::Int64ITy:
    return name + "_u";
  case ElemKind::BoolTy:
    return name + "_b";
  default:
    LOG(FATAL) << "Unsupported element type: "
               << Type::getElementName(elemTy).str();
  }
}

llvm::Function *
LLVMIRGen::getFunction(const std::string &name,
                       llvm::ArrayRef<glow::ElemKind> elemTyArray) {
  auto strName = "libjit_" + name;

  for (auto elTy : elemTyArray) {
    strName = createName(strName, elTy);
  }
  auto *F = llmodule_->getFunction(strName);
  CHECK(F) << "Unable to load the function: " << strName.c_str();
  return F;
}

llvm::Function *LLVMIRGen::getFunction(const std::string &name) {
  return getFunction(name, llvm::ArrayRef<ElemKind>{});
}

llvm::Function *LLVMIRGen::getFunction(const std::string &name,
                                       ElemKind elemTy) {
  return getFunction(name, llvm::ArrayRef<ElemKind>{elemTy});
}
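
// For example (illustrative names): getFunction("element_add",
// ElemKind::FloatTy) resolves the bitcode function "libjit_element_add_f",
// and getFunction("splat_kernel", ElemKind::Int8QTy) resolves
// "libjit_splat_kernel_i8"; see createName above for the full suffix table.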

llvm::Function *LLVMIRGen::getLLVMFunction() { return llvmF_; }

llvm::CallInst *LLVMIRGen::createCall(llvm::IRBuilder<> &builder,
                                      llvm::Function *callee,
                                      llvm::ArrayRef<llvm::Value *> args,
                                      bool checked) {
#ifndef NDEBUG
  llvm::FunctionType *FTy = callee->getFunctionType();
  assert((args.size() == FTy->getNumParams() ||
          (FTy->isVarArg() && args.size() > FTy->getNumParams())) &&
         "Calling a function with a bad signature: wrong number of arguments.");

  for (unsigned i = 0; i != args.size(); ++i) {
    assert((i >= FTy->getNumParams() ||
            FTy->getParamType(i) == args[i]->getType()) &&
           "Calling a function with a bad signature: argument type mismatch.");
  }
#endif
  if (!checked || !callee->getReturnType()->isIntegerTy()) {
    return builder.CreateCall(callee, args);
  }
  // Check if the callee returned an error, i.e. a non-zero result, and emit a
  // return with this error code in that case.
  auto *result = builder.CreateCall(callee, args);
  auto *zero = builder.getIntN(result->getType()->getIntegerBitWidth(), 0);
  auto *cond = builder.CreateICmpNE(result, zero);
  auto insertionPoint = builder.GetInsertPoint();
  auto *currentBB = result->getParent();
  auto *falseBB =
      currentBB->splitBasicBlock(builder.GetInsertPoint(), "cont_bb");
  auto *trueBB = llvm::BasicBlock::Create(getLLVMContext(), "error_bb",
                                          result->getFunction());
  builder.SetInsertPoint(currentBB->getTerminator());
  builder.CreateCondBr(cond, trueBB, falseBB);
  currentBB->getTerminator()->eraseFromParent();
  builder.SetInsertPoint(trueBB);
  auto *castedResult =
      builder.CreateBitCast(result, builder.getIntNTy(getLibjitIntWidth()));
  builder.CreateRet(castedResult);
  builder.SetInsertPoint(falseBB, insertionPoint);
  builder.SetInsertPoint(falseBB->getTerminator());
  return result;
}
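
// The checked-call path above rewrites the control flow roughly as follows
// (an illustrative pseudo-IR sketch, not the exact emitted names):
//   %r = call i32 @callee(...)
//   %err = icmp ne i32 %r, 0
//   br i1 %err, label %error_bb, label %cont_bb
// error_bb:                       ; propagate the error code to our caller
//   ret i32 %r
// cont_bb:                        ; normal execution continues here
//   ...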

llvm::CallInst *
LLVMIRGen::createCheckedCall(llvm::IRBuilder<> &builder, llvm::Function *callee,
                             llvm::ArrayRef<llvm::Value *> args) {
  return createCall(builder, callee, args, /* checked */ true);
}

llvm::CallInst *
LLVMIRGen::createUncheckedCall(llvm::IRBuilder<> &builder,
                               llvm::Function *callee,
                               llvm::ArrayRef<llvm::Value *> args) {
  return createCall(builder, callee, args, /* checked */ false);
}

std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
LLVMIRGen::createLoop(llvm::IRBuilder<> &builder, llvm::LLVMContext &ctx,
                      llvm::Value *numElements) const {
  auto dimTTy = builder.getIntNTy(DIM_T_BITWIDTH);
  auto *initVal = llvm::ConstantInt::get(dimTTy, 0);

  // Make the new basic block for the loop header. Insert it after the current
  // block.
  llvm::Function *func = builder.GetInsertBlock()->getParent();
  auto *preheaderBB = builder.GetInsertBlock();
  auto *loopBB = llvm::BasicBlock::Create(ctx, "loop", func);

  // Insert a jump from the current block to the loopBB.
  builder.CreateBr(loopBB);

  // Start insertion in loopBB.
  builder.SetInsertPoint(loopBB);

  // Create the PHI node with an entry for the initial value.
  llvm::PHINode *var = builder.CreatePHI(dimTTy, 2);
  var->addIncoming(initVal, preheaderBB);

  // Emit the step value.
  auto *stepVal = llvm::ConstantInt::get(dimTTy, 1);
  auto *nextVal = builder.CreateAdd(var, stepVal, "nextvar", /* HasNUW */ true,
                                    /* HasNSW */ true);
  // Compute the end condition.
  auto *endCond = builder.CreateICmpULT(nextVal, numElements, "loopcond");

  // Create the "after loop" block and insert it.
  auto *afterBB = llvm::BasicBlock::Create(ctx, "afterloop", func);

  // Insert the conditional branch at the end of the loopBB.
  auto *backEdge = builder.CreateCondBr(endCond, loopBB, afterBB);
  // Add explicit llvm.loop.vectorize.enable metadata to the generated loop to
  // help the LLVM vectorizer. Without this metadata, the LLVM loop vectorizer
  // bails on long data-parallel loops with a lot of operations. This metadata
  // forces it to vectorize them anyway.
  llvm::SmallVector<llvm::Metadata *, 4> args;
  // Reserve operand 0 for the loop id self reference.
  //
  // Initialize it with a special temporary metadata node, which is typically
  // used to create cyclic metadata structures. tmpMD is a unique_ptr and thus
  // will be freed automatically when it goes out of scope.
  llvm::TempMDTuple tmpMD = llvm::MDNode::getTemporary(ctx, llvm::None);
  args.push_back(tmpMD.get());
  llvm::Metadata *Vals[] = {
      llvm::MDString::get(ctx, "llvm.loop.vectorize.enable"),
      llvm::ConstantAsMetadata::get(
          llvm::ConstantInt::get(llvm::Type::getInt1Ty(ctx), true))};
  args.push_back(llvm::MDNode::get(ctx, Vals));
  auto *loopMD = llvm::MDNode::get(ctx, args);
  // Set the first operand to itself.
  loopMD->replaceOperandWith(0, loopMD);
  backEdge->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
  // Add a new entry to the PHI node for the backedge.
  var->addIncoming(nextVal, loopBB);
  builder.SetInsertPoint(afterBB);
  return std::make_pair(loopBB, afterBB);
}
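
// The emitted loop has the following shape (an illustrative pseudo-IR sketch,
// with dim_t standing for the DIM_T_BITWIDTH integer type):
//   preheader:
//     br label %loop
//   loop:
//     %i = phi dim_t [ 0, %preheader ], [ %next, %loop ]
//     ; ...the loop body is inserted here, right after the phi...
//     %next = add nuw nsw dim_t %i, 1
//     %cond = icmp ult dim_t %next, %numElements
//     br i1 %cond, label %loop, label %afterloop, !llvm.loop !<id>
//   afterloop:
// Note that the body always runs at least once; callers pass the element
// count of a tensor as %numElements, which is expected to be non-zero.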

/// Emit the address of the buffer \p val inside a data-parallel kernel
/// \p kernel using the mapping provided by \p bufferToArgNum.
llvm::Value *
LLVMIRGen::emitBufferAddress(llvm::IRBuilder<> &builder, Value *val,
                             llvm::Function *kernel,
                             llvm::DenseMap<Value *, int> &bufferToArgNum) {
  assert(bufferToArgNum.count(val) && "Buffer should be in the map");
  return kernel->args().begin() + bufferToArgNum[val];
}

/// Implementation of emitDataParallelKernel where we guarantee that the
/// number of arguments is bounded by \ref kArgLimit (64).
void LLVMIRGen::emitDataParallelKernelImpl(
    llvm::IRBuilder<> &builder, llvm::ArrayRef<const Instruction *> bundle,
    llvm::ArrayRef<llvm::Type *> argTypes,
    llvm::DenseMap<Value *, int> &bufferToArgNum,
    llvm::ArrayRef<llvm::Value *> buffers) {
  if (bundle.empty()) {
    return;
  }
  // Create the stacked kernel function type.
  llvm::Type *voidTy = llvm::Type::getVoidTy(getLLVMContext());
  llvm::FunctionType *kernelFuncTy =
      llvm::FunctionType::get(voidTy, argTypes, false);
  auto *kernelFunc =
      llvm::Function::Create(kernelFuncTy, llvm::Function::InternalLinkage,
                             "libjit_stacked_kernel", llmodule_.get());
  // Mark all kernel function buffer parameters as no-alias, because above
  // we ensured that they are uniqued.
  for (unsigned paramIdx = 0; paramIdx < bufferToArgNum.size(); ++paramIdx) {
    kernelFunc->addParamAttr(paramIdx, llvm::Attribute::AttrKind::NoAlias);
  }

  // Create the entry BB.
  llvm::BasicBlock *entryBB =
      llvm::BasicBlock::Create(getLLVMContext(), "entry", kernelFunc);
  llvm::IRBuilder<> kernelBuilder(entryBB);
  // Number of tensor elements.
  auto *numElements =
      emitValueSize(kernelBuilder, bundle[0]->getOperand(0).first);
  // Create a loop inside the stacked kernel function being generated.
  auto loopBBs = createLoop(kernelBuilder, getLLVMContext(), numElements);

  // Get the index parameter of the loop.
  // This is the PHI node of the BB.
  auto *kernelLoopIdx = dyn_cast<llvm::PHINode>(loopBBs.first->begin());
  assert(kernelLoopIdx && "Could not find the loop index");
  // Insert the body of the loop right after the PHI node.
  kernelBuilder.SetInsertPoint(loopBBs.first->getFirstNonPHIOrDbg());
  // Iterate over the stacked instructions and create one kernel invocation
  // per instruction.
  for (auto &BI : bundle) {
    // Name of the stacked operation to be invoked.
    assert(canBePartOfDataParallelKernel(BI) &&
           "Data parallel operation is expected");
    generateLLVMIRForDataParallelInstr(kernelBuilder, BI, kernelFunc,
                                       bufferToArgNum, kernelLoopIdx);
  }
  kernelBuilder.SetInsertPoint(loopBBs.second);
  // Add a return.
  kernelBuilder.CreateRetVoid();

  setCurrentDebugLocation(builder, *bundle.begin());
  // Emit a call of the kernel.
  createUncheckedCall(builder, kernelFunc, buffers);
  // Emit debug info for the generated data-parallel kernel.
  generateFunctionDebugInfo(kernelFunc);
}

/// Emit the function that implements a data-parallel kernel and calls it.
///
/// The generated kernel functions get buffers as their parameters. The
/// buffers are uniqued, so that any buffer is passed as an argument to the
/// kernel function only once. This allows us to mark all parameters of the
/// generated kernel as noalias. As a result, the LLVM optimizer makes use of
/// the noalias attributes and produces nicely vectorized code for the
/// generated data-parallel kernels. Note that a kernel is emitted (and a new
/// one started) whenever the number of arguments (i.e. unique buffers) would
/// exceed `kArgLimit`.
void LLVMIRGen::emitDataParallelKernel(
    llvm::IRBuilder<> &builder, llvm::ArrayRef<const Instruction *> bundle) {
  if (bundle.empty())
    return;
  // Types of arguments for the kernel function being generated.
  llvm::SmallVector<llvm::Type *, 32> argTypes;
  // Map each buffer used by the kernel to the argument number of the kernel
  // function. This ensures that the same buffer is always mapped to the same
  // argument.
  llvm::DenseMap<Value *, int> bufferToArgNum;
  // Buffers to be passed to the kernel function as arguments.
  llvm::SmallVector<llvm::Value *, 32> buffers;
  // Holds a group of instructions whose number of unique buffers is no more
  // than `kArgLimit`, to be shipped for processing.
  llvm::SmallVector<const Instruction *, 32> batchedBundle;
  // Collect unique buffers up to `kArgLimit` used by the instructions of the
  // kernel.
  for (const auto I : bundle) {
    // If adding the buffers of the current instruction might make the total
    // number of unique buffers exceed `kArgLimit`, emit the kernel and start
    // over. Note the "might": this check is pessimistic, because some of the
    // current instruction's buffers may not be unique. The trade-off is that
    // the algorithm stays simple, and in practice over-estimating the
    // argument count by a few entries does not matter much.
    if (argTypes.size() + I->getOperands().size() > kArgLimit) {
      emitDataParallelKernelImpl(builder, batchedBundle, argTypes,
                                 bufferToArgNum, buffers);
      batchedBundle.clear();
      argTypes.clear();
      bufferToArgNum.clear();
      buffers.clear();
    }

    // Add the instruction to the current bundle and process its operands.
    batchedBundle.push_back(I);
    for (const auto &Op : I->getOperands()) {
      auto *buf = Op.first;
      if (!bufferToArgNum.count(buf)) {
        bufferToArgNum[buf] = argTypes.size();
        buffers.push_back(emitValueAddress(builder, buf));
        argTypes.push_back(getElementType(builder, buf)->getPointerTo());
      }
    }
  }
  emitDataParallelKernelImpl(builder, batchedBundle, argTypes, bufferToArgNum,
                             buffers);
}
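
// Illustrative example of the batching above: with kArgLimit == 64, if the
// current batch already uses 60 unique buffers and the next instruction has
// five operands, the pessimistic check (60 + 5 > 64) flushes the batch as its
// own stacked kernel before collection restarts, even if some of those five
// operands would have deduplicated against buffers already in the batch.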

/// Check if the provided operand overlaps with an operand of an instruction
/// already in the bundle, but is not exactly the same memory region.
/// Such memory regions cannot be considered data-parallel in the scope of the
/// same kernel.
///
/// \param allocationsInfo information about allocations.
/// \param bundle current bundle of stacked instructions.
/// \param op the operand to be checked for overlaps with the \p bundle.
static bool isOverlappingWithAnyBundleBufferOperands(
    AllocationsInfo &allocationsInfo,
    llvm::SmallVectorImpl<const Instruction *> &bundle,
    const Instruction::Operand &op) {
  auto *buf = op.first;
  auto addr1 = allocationsInfo.allocatedAddress_[buf];
  auto size1 = buf->getSizeInBytes();
  for (auto bi : bundle) {
    for (auto bop : bi->getOperands()) {
      // Two input operands never interfere with each other.
      if (bop.second == OperandKind::In && op.second == OperandKind::In) {
        continue;
      }
      auto buf2 = bop.first;
      auto addr2 = allocationsInfo.allocatedAddress_[buf2];
      auto size2 = buf2->getSizeInBytes();
      if (addr1 == addr2 && size1 == size2) {
        // The two buffers are the exact same memory region. The operations
        // cannot be within the same bundle, because the buffer pointers are
        // "noalias" qualified, so the kernel operations could be reordered by
        // LLVM's optimizations.
        // TODO: investigate if removing "noalias" can be used to create
        // bigger and faster bundles.
        return true;
      }
      if ((addr1 >= addr2 && addr1 < addr2 + size2) ||
          (addr2 >= addr1 && addr2 < addr1 + size1)) {
        // The two intervals overlap, but are not the same.
        return true;
      }
    }
  }
  return false;
}
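
// Illustrative example: with byte intervals A = [0, 16) and B = [8, 24), an
// instruction writing B cannot join a bundle that already touches A, because
// the regions partially overlap. Two pure reads of the exact same region are
// allowed, but a read plus a write of it (or two writes) are not, since the
// "noalias" kernel parameters would let LLVM reorder those accesses.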

template <typename T> bool matchPair(T a, T b) { return a == b; }

template <typename T> bool matchPair(T a) { return false; }

/// \returns true if the input \p a matches at least one of the inputs in the
/// variadic list \p b, \p args...
template <typename T, typename... Args>
bool matchPair(T a, T b, Args... args) {
  return a == b || matchPair(a, args...);
}
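
// Usage sketch (illustrative): matchPair(kind, Kinded::Kind::ReluInstKind,
// Kinded::Kind::ClipInstKind) is true iff `kind` equals one of the listed
// kinds; the overloads above peel off one candidate per recursive call.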

void LLVMIRGen::generateLLVMIRForModule(llvm::IRBuilder<> &builder) {
  // Go over the instructions and try to group them into bundles.
  auto &instrs = F_->getInstrs();

  // Group instructions into bundles of shape-compatible data-parallel
  // instructions and emit them.
  llvm::SmallVector<const Instruction *, 32> bundle;
  for (auto &I : instrs) {
    if (!canBePartOfDataParallelKernel(&I)) {
      // Ignore memory management instructions as they are handled by the
      // MemoryManager and are NOPs for a JIT.
      if (isa<AllocActivationInst>(&I) || isa<DeallocActivationInst>(&I) ||
          isa<TensorViewInst>(&I)) {
        generateLLVMIRForInstr(builder, &I);
        continue;
      }
      emitDataParallelKernel(builder, bundle);
      bundle.clear();
      generateLLVMIRForInstr(builder, &I);
      continue;
    }

    // This is a data parallel instruction.

    // Check if the current instruction is shape compatible with the bundle.
    bool isBundleCompatible = true;
    if (!bundle.empty()) {
      auto val = I.getOperand(0).first;
      auto bundleVal = bundle.back()->getOperand(0).first;
      // Check if the shapes have the same number of elements.
      isBundleCompatible = val->size() == bundleVal->size();
    }

    // Check all mutated operands of the current instruction. Their memory
    // regions should not have a non-exact overlap with any operands of the
    // bundled instructions. In case this condition does not hold, the current
    // instruction cannot be included into the data-parallel bundle, because
    // overlapping operand buffers are not data parallel.
    for (auto &op : I.getOperands()) {
      // If the mutated operand buffer overlaps with any buffer already used
      // by the bundle, the current instruction cannot become a part of the
      // bundle.
      if (isOverlappingWithAnyBundleBufferOperands(allocationsInfo_, bundle,
                                                   op)) {
        isBundleCompatible = false;
        break;
      }
    }

    // If the instruction cannot be added to the current bundle, emit the
    // kernel for the current bundle and start a new bundle.
    if (!isBundleCompatible) {
      emitDataParallelKernel(builder, bundle);
      bundle.clear();
    }
    // Add the data parallel instruction to the bundle.
    bundle.push_back(&I);
  }

  emitDataParallelKernel(builder, bundle);
}
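
// Illustrative example of the bundling above: for a sequence such as
//   splat A; elementmax B, C, D; relu E, F; matmul G, E, H
// where the element-wise instructions touch disjoint buffers and produce the
// same number of elements, the first three are stacked into a single kernel;
// the matmul is not data-parallel, so it flushes the bundle and is emitted on
// its own.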

void LLVMIRGen::generateLLVMIRForDataParallelInstr(
    llvm::IRBuilder<> &builder, const glow::Instruction *I,
    llvm::Function *kernel, llvm::DenseMap<Value *, int> &bufferToArgNum,
    llvm::Value *loopCount) {
  setCurrentDebugLocation(builder, I);
  assert(canBePartOfDataParallelKernel(I) &&
         "Instruction cannot be part of a data parallel kernel");
  switch (I->getKind()) {

#define ARITHMETIC_UNARY_OP_WITH_IMM_CASE(INST_NAME_, FUN_NAME_, VALUE_)      \
  case Kinded::Kind::INST_NAME_##InstKind: {                                  \
    auto *AN = cast<INST_NAME_##Inst>(I);                                     \
    auto *dest = AN->getDest();                                               \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \
    auto *elementTy = getElementType(builder, dest);                          \
    auto value = AN->get##VALUE_();                                           \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());       \
    auto *pointerNull =                                                       \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());            \
    if (dest->getType()->isQuantizedType()) {                                 \
      auto *destTy = dest->getType();                                         \
      /* Quantize the value based on the output type. Do this early and let   \
       * the jit library work with an already quantized number. */            \
      TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()};  \
      if (destTy->getElementType() == ElemKind::Int8QTy) {                    \
        auto quantizedValue = quantization::quantize<int8_t>(value, TQP);     \
        auto *val = emitConstI8(builder, quantizedValue);                     \
        auto *stackedOpCall = createUncheckedCall(                            \
            builder, F, {loopCount, val, pointerNull, pointerNull});          \
        auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,     \
                                           "buffer.element.addr");            \
        builder.CreateStore(stackedOpCall, destAddr);                         \
      } else if (destTy->getElementType() == ElemKind::Int16QTy) {            \
        auto quantizedValue = quantization::quantize<int16_t>(value, TQP);    \
        auto *val = emitConstI16(builder, quantizedValue);                    \
        auto *stackedOpCall = createUncheckedCall(                            \
            builder, F, {loopCount, val, pointerNull, pointerNull});          \
        auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,     \
                                           "buffer.element.addr");            \
        builder.CreateStore(stackedOpCall, destAddr);                         \
      } else {                                                                \
        llvm_unreachable("Quantization precision not supported.");            \
      }                                                                       \
    } else {                                                                  \
      auto *val = emitConst(builder, value, dest->getElementType());          \
      auto *stackedOpCall = createUncheckedCall(                              \
          builder, F, {loopCount, val, pointerNull, pointerNull});            \
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,       \
                                         "buffer.element.addr");              \
      builder.CreateStore(stackedOpCall, destAddr);                           \
    }                                                                         \
    break;                                                                    \
  }
  ARITHMETIC_UNARY_OP_WITH_IMM_CASE(Splat, "splat", Value);
#undef ARITHMETIC_UNARY_OP_WITH_IMM_CASE

  case Kinded::Kind::TouchInstKind:
    // Do nothing.
    break;

  case Kinded::Kind::ElementSelectInstKind: {
    auto *ES = cast<ElementSelectInst>(I);
    auto *dest = ES->getDest();
    auto *cond = ES->getCond();
    auto *lhs = ES->getLHS();
    auto *rhs = ES->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *condPtr = emitBufferAddress(builder, cond, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // The _kernel suffix is needed since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("elementselect_kernel", lhs->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The selected value will be either lhs = s_l * (i_l - o_l) or
      // rhs = s_r * (i_r - o_r); the stored result that must be computed is
      // therefore one of:
      // (i)  i_d = (s_l / s_d) * (i_l - o_l) + o_d
      // (ii) i_d = (s_r / s_d) * (i_r - o_r) + o_d
      float destScale = destTy->getScale();
      auto lhsScaleParams = quantization::quantizeScaleOffset32To8(
          lhsTy->getScale() / destScale, lhsTy->getOffset());
      auto rhsScaleParams = quantization::quantizeScaleOffset32To8(
          rhsTy->getScale() / destScale, rhsTy->getOffset());

      auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre);
      auto *lhsPost = emitConstI32(builder, lhsScaleParams.post);
      auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale);
      auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre);
      auto *rhsPost = emitConstI32(builder, rhsScaleParams.post);
      auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale);

      auto *stackedOpCall = createUncheckedCall(
          builder, F,
          {loopCount, condPtr, lhsPtr, rhsPtr, destOffset, lhsOffset,
           rhsOffset, lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall =
          createUncheckedCall(builder, F, {loopCount, condPtr, lhsPtr, rhsPtr});
      auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }
  case Kinded::Kind::IntLookupTableInstKind: {
    auto *lookupTable = cast<IntLookupTableInst>(I);
    auto *dest = lookupTable->getDest();
    auto *src = lookupTable->getSrc();
    auto *mapping = lookupTable->getMapping();

    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *mappingPtr =
        emitBufferAddress(builder, mapping, kernel, bufferToArgNum);

    auto *F = getFunction("intlookuptable_kernel", dest->getElementType());
    auto *stackedOpCall =
        builder.CreateCall(F, {loopCount, srcPtr, mappingPtr});
    auto *destType = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);

    break;
  }
#define ARITHMETIC_UNARY_OP_CASE(INST_NAME_, FUN_NAME_)                       \
  case Kinded::Kind::INST_NAME_##InstKind: {                                  \
    auto *AN = cast<INST_NAME_##Inst>(I);                                     \
    auto *dest = AN->getDest();                                               \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \
    auto *srcPtr =                                                            \
        emitBufferAddress(builder, AN->getSrc(), kernel, bufferToArgNum);     \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());       \
    auto *elementTy = getElementType(builder, dest);                          \
    auto *pointerNull =                                                       \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());            \
    auto *stackedOpCall = createUncheckedCall(                                \
        builder, F, {loopCount, srcPtr, pointerNull, pointerNull});           \
    auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,         \
                                       "buffer.element.addr");                \
    builder.CreateStore(stackedOpCall, destAddr);                             \
    break;                                                                    \
  }

  ARITHMETIC_UNARY_OP_CASE(Sigmoid, "sigmoid");
  ARITHMETIC_UNARY_OP_CASE(Tanh, "tanh");
  ARITHMETIC_UNARY_OP_CASE(ElementLog, "element_log");
  ARITHMETIC_UNARY_OP_CASE(ElementExp, "element_exp");
  ARITHMETIC_UNARY_OP_CASE(ElementAbs, "element_abs");
  ARITHMETIC_UNARY_OP_CASE(ElementNeg, "element_neg");
  ARITHMETIC_UNARY_OP_CASE(ElementFloor, "element_floor");
  ARITHMETIC_UNARY_OP_CASE(ElementCeil, "element_ceil");
  ARITHMETIC_UNARY_OP_CASE(ElementRound, "element_round");
  ARITHMETIC_UNARY_OP_CASE(ElementSqrt, "element_sqrt");
  ARITHMETIC_UNARY_OP_CASE(ElementErf, "element_erf");
  ARITHMETIC_UNARY_OP_CASE(ElementRsqrt, "element_rsqrt");
  ARITHMETIC_UNARY_OP_CASE(ElementReciprocal, "element_reciprocal");
  ARITHMETIC_UNARY_OP_CASE(ElementSin, "element_sin");
  ARITHMETIC_UNARY_OP_CASE(ElementCos, "element_cos");
#undef ARITHMETIC_UNARY_OP_CASE

  case Kinded::Kind::ReluInstKind: {
    auto *RI = cast<ReluInst>(I);
    auto *src = RI->getSrc();
    auto *dest = RI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();

    auto *F = getFunction("element_relu", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      auto destScaleParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *destPre = emitConstI32(builder, destScaleParams.pre);
      auto *destPost = emitConstI32(builder, destScaleParams.post);
      auto *destScale = emitConstI32(builder, destScaleParams.scale);
      stackedOpCall = createCall(builder, F,
                                 {loopCount, srcPtr, srcOffset, destOffset,
                                  destPre, destPost, destScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      stackedOpCall = createCall(builder, F, {loopCount, srcPtr});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ClipInstKind: {
    auto *CI = cast<ClipInst>(I);
    auto *src = CI->getSrc();
    auto *dest = CI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();
    float clipMinF = CI->getMin();
    float clipMaxF = CI->getMax();

    auto *F = getFunction("element_clip", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      TensorQuantizationParams srcTQP{src->getType()->getScale(),
                                      src->getType()->getOffset()};
      int8_t clipMinQ = quantization::quantize<int8_t>(clipMinF, srcTQP);
      int8_t clipMaxQ = quantization::quantize<int8_t>(clipMaxF, srcTQP);
      auto *clipMin = emitConstI8(builder, clipMinQ);
      auto *clipMax = emitConstI8(builder, clipMaxQ);
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      auto destScaleParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *destPre = emitConstI32(builder, destScaleParams.pre);
      auto *destPost = emitConstI32(builder, destScaleParams.post);
      auto *destScale = emitConstI32(builder, destScaleParams.scale);
      stackedOpCall =
          createCall(builder, F,
                     {loopCount, srcPtr, clipMin, clipMax, srcOffset,
                      destOffset, destPre, destPost, destScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      auto *clipMin = emitConstF32(builder, clipMinF);
      auto *clipMax = emitConstF32(builder, clipMaxF);
      stackedOpCall =
          createCall(builder, F, {loopCount, srcPtr, clipMin, clipMax});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::LeakyReluInstKind: {
    auto *LI = cast<LeakyReluInst>(I);
    auto *src = LI->getSrc();
    auto *dest = LI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();

    auto *F = getFunction("element_leaky_relu", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      // Scale parameters for the positive input domain.
      auto posParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *posPre = emitConstI32(builder, posParams.pre);
      auto *posPost = emitConstI32(builder, posParams.post);
      auto *posScale = emitConstI32(builder, posParams.scale);
      // Scale parameters for the negative input domain.
      auto negParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() * LI->getAlpha() / destTy->getScale(), 0);
      auto *negPre = emitConstI32(builder, negParams.pre);
      auto *negPost = emitConstI32(builder, negParams.post);
      auto *negScale = emitConstI32(builder, negParams.scale);
      stackedOpCall =
          createCall(builder, F,
                     {loopCount, srcPtr, srcOffset, destOffset, posPre,
                      posPost, posScale, negPre, negPost, negScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      auto *alpha = emitConstF32(builder, LI->getAlpha());
      stackedOpCall = createCall(builder, F, {loopCount, srcPtr, alpha});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1435
1436 case Kinded::Kind::ElementIsNaNInstKind: {
1437 auto *AN = cast<ElementIsNaNInst>(I);
1438 auto *src = AN->getSrc();
1439 auto *dest = AN->getDest();
1440 auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
1441 auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
1442 auto *F = getFunction("element_is_nan_kernel", src->getElementType());
1443 auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
1444 auto *elementTy = getElementType(builder, dest);
1445 auto *destAddr =
1446 builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
1447 builder.CreateStore(stackedOpCall, destAddr);
1448 break;
1449 }
1450
1451 case Kinded::Kind::QuantizeInstKind: {
1452 auto *QI = cast<QuantizeInst>(I);
1453 auto *src = QI->getSrc();
1454 auto *dest = QI->getDest();
1455 auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
1456 auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
1457 auto *destTy = dest->getType();
1458 auto *destScale = emitConstF32(builder, destTy->getScale());
1459 auto *destOffset = emitConstI32(builder, destTy->getOffset());
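    // The quantize kernel computes round(x / scale) + offset for each float
    // input x, saturated to the numeric range of the destination type.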
    auto *F = getFunction("element_quantize_kernel", dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, destScale, destOffset});
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::DequantizeInstKind: {
    auto *DI = cast<DequantizeInst>(I);
    auto *src = DI->getSrc();
    auto *dest = DI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcTy = src->getType();
    auto *srcScale = emitConstF32(builder, srcTy->getScale());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
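    // The dequantize kernel maps each quantized input q back to the float
    // value scale * (q - offset).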
    auto *F = getFunction("element_dequantize_kernel", src->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, srcScale, srcOffset});
    auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, loopCount,
                                       "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::RescaleQuantizedInstKind: {
    auto *RQI = cast<RescaleQuantizedInst>(I);
    auto *dest = RQI->getDest();
    auto *src = RQI->getSrc();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);

    auto *destType = dest->getType();
    auto *srcType = src->getType();

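    // Rescaling computes i_dest = (s_src / s_dest) * (i_src - o_src) + o_dest.
    // The float ratio s_src / s_dest is encoded as an integer pre/post/scale
    // triple so the kernel only needs fixed-point arithmetic.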
    auto rescaleParams = quantization::quantizeScaleOffset32To8(
        srcType->getScale() / destType->getScale(), srcType->getOffset());

    auto *destOffset = emitConstI32(builder, destType->getOffset());
    auto *srcOffset = emitConstI32(builder, srcType->getOffset());
    auto *preShift = emitConstI32(builder, rescaleParams.pre);
    auto *postShift = emitConstI32(builder, rescaleParams.post);
    auto *scale = emitConstI32(builder, rescaleParams.scale);
    auto *F = getFunction("element_rescale_kernel", dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F,
        {loopCount, srcPtr, destOffset, srcOffset, preShift, postShift, scale});
    auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
                                       "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::CopyInstKind: {
    auto *CI = cast<CopyInst>(I);
    auto *dest = CI->getDest();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr =
        emitBufferAddress(builder, CI->getSrc(), kernel, bufferToArgNum);
    auto *F = getFunction("copy_kernel", dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());
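    // The copy kernel ignores its two extra pointer operands, so null
    // pointers are passed in those argument slots.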
    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, pointerNull, pointerNull});
    auto *destAddr = builder.CreateGEP(getElementType(builder, dest), destPtr,
                                       loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

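// This macro generates the case for a data-parallel binary arithmetic
// instruction. For quantized operands, each side is rescaled into the
// destination domain with its own pre/post/scale triple before the libjit
// kernel combines them; for the non-quantized element types listed in the
// macro's variadic arguments, the kernel is invoked directly.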
#define ARITHMETIC_BINARY_OP_CASE(INST_NAME_, FUN_NAME_, ...)                  \
  case Kinded::Kind::INST_NAME_##InstKind: {                                   \
    auto *AN = cast<INST_NAME_##Inst>(I);                                      \
    auto *dest = AN->getDest();                                                \
    auto *lhs = AN->getLHS();                                                  \
    auto *rhs = AN->getRHS();                                                  \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);  \
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);    \
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);    \
                                                                               \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());        \
    auto *elementTy = getElementType(builder, dest);                           \
    auto *pointerNull =                                                        \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());             \
    bool typesMatched = matchPair(dest->getElementType(), __VA_ARGS__);        \
    if (lhs->getType()->isQuantizedType()) {                                   \
      auto *destTy = dest->getType();                                          \
      auto *lhsTy = lhs->getType();                                            \
      auto *rhsTy = rhs->getType();                                            \
                                                                               \
      auto *destOffset = emitConstI32(builder, destTy->getOffset());           \
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());             \
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());             \
                                                                               \
      float destScale = destTy->getScale();                                    \
                                                                               \
      auto lhsScaleParams = quantization::quantizeScaleOffset32To8(            \
          lhsTy->getScale() / destScale, lhsTy->getOffset());                  \
      auto rhsScaleParams = quantization::quantizeScaleOffset32To8(            \
          rhsTy->getScale() / destScale, rhsTy->getOffset());                  \
                                                                               \
      auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre);                \
      auto *lhsPost = emitConstI32(builder, lhsScaleParams.post);              \
      auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale);            \
      auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre);                \
      auto *rhsPost = emitConstI32(builder, rhsScaleParams.post);              \
      auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale);            \
                                                                               \
      auto *stackedOpCall = createUncheckedCall(                               \
          builder, F,                                                          \
          {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset,        \
           lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});             \
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,         \
                                         loopCount, "buffer.element.addr");    \
      builder.CreateStore(stackedOpCall, destAddr);                            \
    } else if (typesMatched) {                                                 \
      auto *stackedOpCall = createUncheckedCall(                               \
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});               \
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,        \
                                         "buffer.element.addr");               \
      builder.CreateStore(stackedOpCall, destAddr);                            \
    } else {                                                                   \
      llvm_unreachable("Unsupported Type in " #INST_NAME_);                    \
    }                                                                          \
    break;                                                                     \
  }
  ARITHMETIC_BINARY_OP_CASE(ElementAdd, "element_add", ElemKind::FloatTy,
                            ElemKind::Int32ITy, ElemKind::Int64ITy);
  ARITHMETIC_BINARY_OP_CASE(ElementSub, "element_sub", ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementMax, "element_max", ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementMin, "element_min", ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementPow, "element_pow", ElemKind::FloatTy);
#undef ARITHMETIC_BINARY_OP_CASE

  case Kinded::Kind::ElementNotInstKind: {
    auto *NI = cast<ElementNotInst>(I);
    auto *dest = NI->getDest();
    auto *src = NI->getSrc();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *F = getFunction("element_not_kernel", src->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementAndInstKind: {
    auto *AI = cast<ElementAndInst>(I);
    auto *dest = AI->getDest();
    auto *lhs = AI->getLHS();
    auto *rhs = AI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_and_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementOrInstKind: {
    auto *OI = cast<ElementOrInst>(I);
    auto *dest = OI->getDest();
    auto *lhs = OI->getLHS();
    auto *rhs = OI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_or_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementXorInstKind: {
    auto *XI = cast<ElementXorInst>(I);
    auto *dest = XI->getDest();
    auto *lhs = XI->getLHS();
    auto *rhs = XI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_xor_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementCmpEQInstKind:
  case Kinded::Kind::ElementCmpNEQInstKind:
  case Kinded::Kind::ElementCmpLTInstKind:
  case Kinded::Kind::ElementCmpLTEInstKind: {
    Value *dest = nullptr;
    Value *lhs = nullptr;
    Value *rhs = nullptr;
    std::string kernelName;

    if (auto *CEQI = dyn_cast<ElementCmpEQInst>(I)) {
      dest = CEQI->getDest();
      lhs = CEQI->getLHS();
      rhs = CEQI->getRHS();
      kernelName = "element_cmp_eq_kernel";
    } else if (auto *CNEQI = dyn_cast<ElementCmpNEQInst>(I)) {
      dest = CNEQI->getDest();
      lhs = CNEQI->getLHS();
      rhs = CNEQI->getRHS();
      kernelName = "element_cmp_neq_kernel";
    } else if (auto *CLTEI = dyn_cast<ElementCmpLTEInst>(I)) {
      dest = CLTEI->getDest();
      lhs = CLTEI->getLHS();
      rhs = CLTEI->getRHS();
      kernelName = "element_cmp_lte_kernel";
    } else if (auto *CLTI = dyn_cast<ElementCmpLTInst>(I)) {
      dest = CLTI->getDest();
      lhs = CLTI->getLHS();
      rhs = CLTI->getRHS();
      kernelName = "element_cmp_lt_kernel";
    } else {
      llvm_unreachable(
          "Mismatch between Instruction Kind and instruction instance.");
    }

    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction(kernelName.c_str(), lhs->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // We can divide both sides of the comparison by the rhs scale since it
      // is strictly positive; this saves one rescale within the backend. The
      // inequalities are:
      //     s_l * (i_l - o_l) <= s_r * (i_r - o_r)
      // <=> (s_l / s_r) * (i_l - o_l) <= i_r - o_r
      float scale = lhsTy->getScale() / rhsTy->getScale();
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *cmpPre = emitConstI32(builder, scaleParams.pre);
      auto *cmpPost = emitConstI32(builder, scaleParams.post);
      auto *cmpScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, lhsOffset, rhsOffset,
                               cmpPre, cmpPost, cmpScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall =
          createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
      auto *elementTy = getElementType(builder, dest);
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }

  case Kinded::Kind::ElementMulInstKind: {
    auto *MI = cast<ElementMulInst>(I);
    auto *dest = MI->getDest();
    auto *lhs = MI->getLHS();
    auto *rhs = MI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("element_mul_kernel", dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The multiplicative scale factor is s_l * s_r / s_d due to the equation
      //    s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r)
      // => i_d = (s_l * s_r / s_d) * (i_l - o_l) * (i_r - o_r) + o_d
      float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *mulPre = emitConstI32(builder, scaleParams.pre);
      auto *mulPost = emitConstI32(builder, scaleParams.post);
      auto *mulScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
                               rhsOffset, mulPre, mulPost, mulScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else if (lhs->getType()->getElementType() == ElemKind::Int64ITy ||
               lhs->getType()->getElementType() == ElemKind::Int32ITy ||
               lhs->getType()->getElementType() == ElemKind::FloatTy) {
      auto *stackedOpCall = createUncheckedCall(
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      LOG_ASSERT(false) << "Unsupported element type for Mul.";
    }
    break;
  }

  case Kinded::Kind::ElementDivInstKind: {
    auto *MI = cast<ElementDivInst>(I);
    auto *dest = MI->getDest();
    auto *lhs = MI->getLHS();
    auto *rhs = MI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("element_div_kernel", dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The division scale factor is s_l / (s_r * s_d) due to the equation
      //    s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
      // => i_d = (s_l / (s_r * s_d)) * ((i_l - o_l) / (i_r - o_r)) + o_d
      float scale =
          lhsTy->getScale() / (rhsTy->getScale() * destTy->getScale());
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *divPre = emitConstI32(builder, scaleParams.pre);
      auto *divPost = emitConstI32(builder, scaleParams.post);
      auto *divScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
                               rhsOffset, divPre, divPost, divScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall = createUncheckedCall(
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }

  case Kinded::Kind::ModuloInstKind: {
    auto *MI = cast<ModuloInst>(I);
    auto *dest = MI->getDest();
    auto *src = MI->getSrc();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *divisor = emitConst(builder, MI->getDivisor(), ElemKind::Int64ITy);
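    // Two kernel flavors: with "sign_follow" the result takes the sign of the
    // divisor (Python-style modulo), while with "no_sign_follow" it keeps the
    // sign of the dividend (C-style remainder).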
    llvm::Function *F = nullptr;
    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    if (MI->getSignFollowDivisor()) {
      F = getFunction("element_modulo_kernel_sign_follow",
                      dest->getElementType());
    } else {
      F = getFunction("element_modulo_kernel_no_sign_follow",
                      dest->getElementType());
    }
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, divisor, srcPtr});
    llvm::Value *destAddr = nullptr;
    if (dest->getElementType() == ElemKind::Int64ITy) {
      destAddr = builder.CreateGEP(builder.getInt64Ty(), destPtr, loopCount,
                                   "buffer.element.addr");
    } else {
      destAddr = builder.CreateGEP(builder.getInt32Ty(), destPtr, loopCount,
                                   "buffer.element.addr");
    }
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  default:
    std::string sBuf;
    llvm::raw_string_ostream s(sBuf);
    I->dump(s);
    LOG(FATAL) << "Cannot select the instruction: " << s.str();
  }
}

Tensor LLVMIRGen::getTensorForConstantValue(Value *value) {
  // We can't get the Constant directly from a glow::Value, so we traverse the
  // function's constant list and find the constant whose weight matches the
  // given value.
  Tensor tensor;
  auto *F_ = getIRFunction();
  for (auto &v : F_->findConstants()) {
    assert(isa<WeightVar>(F_->getWeightForNode(v)));
    auto *w = cast<glow::Value>(F_->getWeightForNode(v));
    if (w == value) {
      tensor.assign(&v->getPayload());
      break;
    }
  }
  CHECK(tensor.getUnsafePtr()) << "Can't find the constant value!";
  return tensor;
}

void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
                                       const glow::Instruction *I) {
  setCurrentDebugLocation(builder, I);
  assert((!canBePartOfDataParallelKernel(I)) &&
         "data parallel instructions are not handled here");
  switch (I->getKind()) {
  case Kinded::Kind::MatMulInstKind: {
    auto *MM = cast<MatMulInst>(I);
    auto *dest = MM->getDest();
    auto *lhs = MM->getLHS();
    auto *rhs = MM->getRHS();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *lhsPtr = emitValueAddress(builder, lhs);
    auto *rhsPtr = emitValueAddress(builder, rhs);

    auto *destDims = emitValueDims(builder, dest);
    auto *lhsDims = emitValueDims(builder, lhs);
    auto *rhsDims = emitValueDims(builder, rhs);

    auto *F = getFunction("matmul", dest->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

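      // The integer accumulator of the quantized matmul works at scale
      // s_lhs * s_rhs, so the result is rescaled into the destination domain
      // by the factor (s_lhs * s_rhs) / s_dest.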
      auto outScaleParams = quantization::quantizeScaleOffset32To8(
          lhsTy->getScale() * rhsTy->getScale() / destTy->getScale(), 0);

      auto *outPre = emitConstI32(builder, outScaleParams.pre);
      auto *outPost = emitConstI32(builder, outScaleParams.post);
      auto *outScale = emitConstI32(builder, outScaleParams.scale);

      createCall(builder, F,
                 {destPtr, lhsPtr, rhsPtr, destDims, lhsDims, rhsDims,
                  destOffset, lhsOffset, rhsOffset, outPre, outPost, outScale});
    } else {
      createCall(builder, F,
                 {destPtr, lhsPtr, rhsPtr, destDims, lhsDims, rhsDims});
    }
    break;
  }

  case Kinded::Kind::QuantizationProfileInstKind: {
    auto *QP = cast<QuantizationProfileInst>(I);
    auto *hist = QP->getHistogram();
    auto *compInfo = QP->getComputationInfo();
    auto *inputTensor = QP->getInputTensor();

    auto *histPtr = emitValueAddress(builder, hist);
    auto *compInfoPtr = emitValueAddress(builder, compInfo);
    auto *inputTensorInfoPtr = emitValueAddress(builder, inputTensor);

    auto *histDims = emitValueDims(builder, hist);
    assert(inputTensor->getElementType() == ElemKind::FloatTy &&
           "Non-float tensor type for Quantization Profile instruction.");
    auto *tensorSize = emitConstDimT(builder, inputTensor->getType()->size());

    auto *F = getFunction("quantization_profile");
    createCall(
        builder, F,
        {inputTensorInfoPtr, tensorSize, compInfoPtr, histPtr, histDims});
    break;
  }

  case Kinded::Kind::FullyConnectedInstKind: {
    auto *FCI = cast<FullyConnectedInst>(I);
    auto *dest = FCI->getDest();
    auto *src = FCI->getSrc();
    auto *weights = FCI->getWeights();
    auto *bias = FCI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *weightsPtr = emitValueAddress(builder, weights);
    auto *biasPtr = emitValueAddress(builder, bias);
    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *weightsDims = emitValueDims(builder, weights);
    auto *biasDims = emitValueDims(builder, bias);

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *weightsTy = weights->getType();
      auto *biasTy = bias->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *weightsOffset = emitConstI32(builder, weightsTy->getOffset());
      auto *biasOffset = emitConstI32(builder, biasTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * weightsTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasTy->getScale() / matMulScale, 0);
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      auto *biasPre = emitConstI32(builder, biasScaleParam.pre);
      auto *biasPost = emitConstI32(builder, biasScaleParam.post);
      auto *biasScale = emitConstI32(builder, biasScaleParam.scale);
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      auto *F =
          getFunction("fc", {dest->getElementType(), bias->getElementType()});
      createCall(builder, F,
                 {destPtr, srcPtr, weightsPtr, biasPtr, destDims, srcDims,
                  weightsDims, biasDims, destOffset, srcOffset, weightsOffset,
                  biasOffset, biasPre, biasPost, biasScale, outPre, outPost,
                  outScale});
    } else {
      auto *F = getFunction("fc", dest->getElementType());
      createCall(builder, F,
                 {destPtr, srcPtr, weightsPtr, biasPtr, destDims, srcDims,
                  weightsDims, biasDims});
    }
    break;
  }

  case Kinded::Kind::RowwiseQuantizedFullyConnectedInstKind: {
    auto *RWQFC = cast<RowwiseQuantizedFullyConnectedInst>(I);

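    // Each row of the weights matrix carries its own scale, so the bias and
    // output rescale parameters are computed per row instead of once for the
    // whole tensor.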
    auto scalesT = getTensorForConstantValue(RWQFC->getScales());
    auto scalesH = scalesT.getHandle();
    size_t rowNum = scalesH.dims()[0];
    float inputScale = RWQFC->getSrc()->getType()->getScale();

    float bScale = RWQFC->getBias()->getType()->getScale();
    int32_t bOffset = RWQFC->getBias()->getType()->getOffset();

    float outputScale = RWQFC->getDest()->getType()->getScale();

    std::vector<llvm::Constant *> biasPreV(rowNum);
    std::vector<llvm::Constant *> biasPostV(rowNum);
    std::vector<llvm::Constant *> biasScaleV(rowNum);
    std::vector<llvm::Constant *> outputPreV(rowNum);
    std::vector<llvm::Constant *> outputPostV(rowNum);
    std::vector<llvm::Constant *> outputScaleV(rowNum);

    for (size_t i = 0; i < rowNum; i++) {
      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = inputScale * scalesH.raw(i);

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam =
          quantization::quantizeScaleOffset32To8(bScale / matMulScale, bOffset);
      auto outScaleParam =
          quantization::quantizeScaleOffset32To8(matMulScale / outputScale, 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      biasPreV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                           biasScaleParam.pre, true);
      biasPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                            biasScaleParam.post, true);
      biasScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                             biasScaleParam.scale, true);
      outputPreV[i] =
          llvm::ConstantInt::get(builder.getInt32Ty(), outScaleParam.pre, true);
      outputPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                              outScaleParam.post, true);
      outputScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                               outScaleParam.scale, true);
    }

    auto *dest = RWQFC->getDest();
    auto *src = RWQFC->getSrc();
    auto *weights = RWQFC->getWeights();
    auto *bias = RWQFC->getBias();
    auto *weightsOffsets = RWQFC->getOffsets();

    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *weightsPtr = emitValueAddress(builder, weights);
    auto *biasPtr = emitValueAddress(builder, bias);
    auto *weightsOffsetsPtr = emitValueAddress(builder, weightsOffsets);
    auto *biasPrePtr = emitConstArray(builder, biasPreV, builder.getInt32Ty());
    auto *biasPostPtr =
        emitConstArray(builder, biasPostV, builder.getInt32Ty());
    auto *biasScalePtr =
        emitConstArray(builder, biasScaleV, builder.getInt32Ty());
    auto *outputPrePtr =
        emitConstArray(builder, outputPreV, builder.getInt32Ty());
    auto *outputPostPtr =
        emitConstArray(builder, outputPostV, builder.getInt32Ty());
    auto *outputScalePtr =
        emitConstArray(builder, outputScaleV, builder.getInt32Ty());

    auto *srcDims = emitValueDims(builder, src);
    auto *weightsDims = emitValueDims(builder, weights);
    auto *destDims = emitValueDims(builder, dest);
    auto *biasDims = emitValueDims(builder, bias);
    auto *row = emitConstDimT(builder, weightsOffsets->dims()[0]);

    auto *destOffset = emitConstI32(builder, dest->getType()->getOffset());
    auto *srcOffset = emitConstI32(builder, src->getType()->getOffset());
    auto *biasOffset = emitConstI32(builder, bOffset);

    llvm::Function *F = nullptr;
    if ((dest->getElementType() == ElemKind::Int8QTy) &&
        (bias->getElementType() == ElemKind::Int8QTy)) {
      F = getFunction("rowwise_quantized_fc_i8_i8");
    } else if ((dest->getElementType() == ElemKind::Int8QTy) &&
               (bias->getElementType() == ElemKind::Int32QTy)) {
      F = getFunction("rowwise_quantized_fc_i8_i32");
    } else {
      LOG(FATAL) << "Unsupported element/bias type for "
                    "RowwiseQuantizedFullyConnectedInst";
    }

    createCall(builder, F,
               {destPtr, srcPtr, weightsPtr, biasPtr, weightsOffsetsPtr,
                biasPrePtr, biasPostPtr, biasScalePtr, outputPrePtr,
                outputPostPtr, outputScalePtr, destDims, srcDims, weightsDims,
                biasDims, row, destOffset, srcOffset, biasOffset});
    break;
  }

  case Kinded::Kind::BatchedAddInstKind: {
    auto *BA = cast<BatchedAddInst>(I);
    auto *dest = BA->getDest();
    auto *batch = BA->getBatch();
    auto *slice = BA->getSlice();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *batchPtr = emitValueAddress(builder, batch);
    auto *slicePtr = emitValueAddress(builder, slice);

    auto bdim = flattenCdr(batch->dims());
    auto *numSlice = emitConstDimT(builder, bdim.first);
    auto *sliceSize = emitConstDimT(builder, bdim.second);

    if (batch->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *batchTy = batch->getType();
      auto *sliceTy = slice->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *batchOffset = emitConstI32(builder, batchTy->getOffset());
      auto *sliceOffset = emitConstI32(builder, sliceTy->getOffset());

      float destScale = destTy->getScale();

      // Here, we select parameters for scaling both summands to the
      // destination scale.
      auto batchScaleParams = quantization::quantizeScaleOffset32To8(
          batchTy->getScale() / destScale, batchTy->getOffset());
      auto sliceScaleParams = quantization::quantizeScaleOffset32To8(
          sliceTy->getScale() / destScale, sliceTy->getOffset());

      auto *batchPre = emitConstI32(builder, batchScaleParams.pre);
      auto *batchPost = emitConstI32(builder, batchScaleParams.post);
      auto *batchScale = emitConstI32(builder, batchScaleParams.scale);
      auto *slicePre = emitConstI32(builder, sliceScaleParams.pre);
      auto *slicePost = emitConstI32(builder, sliceScaleParams.post);
      auto *sliceScale = emitConstI32(builder, sliceScaleParams.scale);

      llvm::Function *F = nullptr;
      if (sliceTy->getElementType() == ElemKind::Int8QTy) {
        F = getFunction("batchedadd", dest->getElementType());
      } else if (sliceTy->getElementType() == ElemKind::Int32QTy) {
        F = getFunction("batchedadd_i32", dest->getElementType());
      } else {
        LOG(FATAL) << "Type is not supported: "
                   << Type::getElementName(sliceTy->getElementType()).str();
      }
      createCall(builder, F,
                 {destPtr, batchPtr, slicePtr, numSlice, sliceSize, destOffset,
                  batchOffset, sliceOffset, batchPre, batchPost, batchScale,
                  slicePre, slicePost, sliceScale});
    } else {
      auto *F = getFunction("batchedadd", dest->getElementType());
      createCall(builder, F,
                 {destPtr, batchPtr, slicePtr, numSlice, sliceSize});
    }
    break;
  }

  case Kinded::Kind::BatchedReduceAddInstKind: {
    auto *BR = cast<BatchedReduceAddInst>(I);
    auto *dest = BR->getDest();
    auto *batch = BR->getBatch();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *batchPtr = emitValueAddress(builder, batch);
    auto *axis = emitConstDimT(builder, BR->getAxis());

    ShapeVector eBatchDims = expandDimsToMax(batch->dims());
    ShapeVector eDestDims = eBatchDims;
    eDestDims[BR->getAxis()] = 1;

    auto *batchDims =
        emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims));
    auto *destDims = emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims));

    auto *F = getFunction("batchedreduceadd", dest->getElementType());

    if (batch->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *batchTy = batch->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *batchOffset = emitConstI32(builder, batchTy->getOffset());

      // BatchedReduceAdd is an accumulation operation, with equations
      //    s_d * (i_d - o_d) = \sum s_b * (i_b - o_b)
      // => i_d - o_d = \sum (s_b / s_d) * (i_b - o_b)
      // => i_d = (s_b / s_d) * [\sum (i_b - o_b)] + o_d
      auto batchScaleParams = quantization::quantizeScaleOffset32To8(
          batchTy->getScale() / destTy->getScale(), batchTy->getOffset());

      auto *batchPre = emitConstI32(builder, batchScaleParams.pre);
      auto *batchPost = emitConstI32(builder, batchScaleParams.post);
      auto *batchScale = emitConstI32(builder, batchScaleParams.scale);

      createCall(builder, F,
                 {destPtr, batchPtr, destDims, batchDims, destOffset,
                  batchOffset, batchPre, batchPost, batchScale, axis});
    } else {
      auto *destSize = emitConstDimT(builder, dest->size());

      createCall(builder, F,
                 {destPtr, batchPtr, destSize, destDims, batchDims, axis});
    }
    break;
  }

  case Kinded::Kind::BatchedReduceProdInstKind: {
    auto *BR = cast<BatchedReduceProdInst>(I);
    auto *dest = BR->getDest();
    auto *batch = BR->getBatch();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *batchPtr = emitValueAddress(builder, batch);
    auto *axis = emitConstDimT(builder, BR->getAxis());

    ShapeVector eBatchDims = expandDimsToMax(batch->dims());
    ShapeVector eDestDims = eBatchDims;
    eDestDims[BR->getAxis()] = 1;

    auto *batchDims =
        emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims));
    auto *destDims = emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims));

    auto *F = getFunction("batchedreduceprod", dest->getElementType());

    assert(!batch->getType()->isQuantizedType() &&
           "Quantized implementation for ReduceProd not supported yet.");

    auto *destSize = emitConstDimT(builder, dest->size());

    createCall(builder, F,
               {destPtr, batchPtr, destSize, destDims, batchDims, axis});

    break;
  }

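// This macro generates the case for a batched reduce min/max instruction. The
// reduced axes are collapsed to size 1 in the destination dims and the kernel
// is dispatched on the (float or integer) element type; quantized inputs are
// not supported.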
#define BATCHED_REDUCE_MINMAX_CASE(INST_NAME_, FUN_NAME_)                      \
  case Kinded::Kind::Batched##INST_NAME_##InstKind: {                          \
    auto *BR = cast<Batched##INST_NAME_##Inst>(I);                             \
    auto *dest = BR->getDest();                                                \
    auto *batch = BR->getBatch();                                              \
    auto axes = BR->getAxes();                                                 \
    auto *destPtr = emitValueAddress(builder, dest);                           \
    auto *batchPtr = emitValueAddress(builder, batch);                         \
                                                                               \
    ShapeVector eBatchDims = expandDimsToMax(batch->dims());                   \
    ShapeVector eDestDims = eBatchDims;                                        \
    for (dim_t i = 0; i < axes.size(); i++) {                                  \
      eDestDims[axes[i]] = 1;                                                  \
    }                                                                          \
                                                                               \
    auto *batchDims =                                                          \
        emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims));           \
    auto *destDims =                                                           \
        emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims));            \
                                                                               \
    if (((batch->getElementType() != ElemKind::FloatTy) &&                     \
         (batch->getElementType() != ElemKind::Int32ITy) &&                    \
         (batch->getElementType() != ElemKind::Int64ITy)) ||                   \
        (batch->getElementType() != dest->getElementType())) {                 \
      std::string errStr = "Cannot get function for ";                         \
      std::string name = #INST_NAME_;                                          \
      errStr += name;                                                          \
      llvm_unreachable(errStr.c_str());                                        \
    }                                                                          \
                                                                               \
    llvm::Function *F = getFunction(FUN_NAME_, batch->getElementType());       \
    if (!batch->getType()->isQuantizedType()) {                                \
      auto *destSize = emitConstSizeT(builder, dest->size());                  \
                                                                               \
      createCall(builder, F,                                                   \
                 {destPtr, batchPtr, destSize, destDims, batchDims});          \
    }                                                                          \
    break;                                                                     \
  }
  BATCHED_REDUCE_MINMAX_CASE(ReduceMin, "reducemin")
  BATCHED_REDUCE_MINMAX_CASE(ReduceMax, "reducemax")
#undef BATCHED_REDUCE_MINMAX_CASE

  case Kinded::Kind::ConvolutionInstKind: {
    auto *CI = cast<ConvolutionInst>(I);
    assert(CI->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Convolutions");
    auto *dest = CI->getDest();
    auto *src = CI->getSrc();
    auto *filter = CI->getFilter();
    auto *bias = CI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CI->getKernels());
    auto *strides = emitConstDimTArray(builder, CI->getStrides());
    auto *pads = emitConstDimTArray(builder, CI->getPads());
    auto *group = emitConstDimT(builder, CI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CI->getDilation());

    auto destDepth = dest->dims()[3];

    // Try to 'block' the convolution on the 'depth' dimension. We will
    // process this many output slices per iteration.
    unsigned unrollDFactor = 1;

    // In the libjit_convolution_f function, 'unrollDFactor' output layers are
    // processed together. Therefore, the number of output layers in each
    // group should be divisible by 'unrollDFactor'.
    bool groupDividedBy8 = ((destDepth / CI->getGroup()) % 8) == 0;
    if (groupDividedBy8) {
      unrollDFactor = 8;
    }

    auto *unrollD = emitConstI32(builder, unrollDFactor);

    auto *actType = emitConstI32(builder, CI->getFusedActivation());

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *filterTy = filter->getType();
      auto *biasTy = bias->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *filterOffset = emitConstI32(builder, filterTy->getOffset());
      auto *biasOffset = emitConstI32(builder, biasTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * filterTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasTy->getScale() / matMulScale, biasTy->getOffset());
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      auto *biasPre = emitConstI32(builder, biasScaleParam.pre);
      auto *biasPost = emitConstI32(builder, biasScaleParam.post);
      auto *biasScale = emitConstI32(builder, biasScaleParam.scale);
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      // Emit parameters for fused activation.
      auto *actArgsQuant = emitConstQuantActivationArgs(builder, CI);

      auto *F = getFunction("conv2d",
                            {dest->getElementType(), bias->getElementType()});

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims,
                  srcDims, filterDims, biasDims, kernels, strides,
                  pads, group, destOffset, srcOffset, filterOffset,
                  biasOffset, biasPre, biasPost, biasScale, outPre,
                  outPost, outScale, unrollD, dilation, actType,
                  actArgsQuant});
    } else {

      // Emit parameters for fused activation.
      auto *actArgsFloat = emitConstFloatActivationArgs(builder, CI);

      auto *F = getFunction("conv2d", dest->getElementType());

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group, unrollD,
                  dilation, actType, actArgsFloat});
    }
    break;
  }

  case Kinded::Kind::ConvolutionGradInstKind: {
    auto *CG = cast<ConvolutionGradInst>(I);
    auto *srcGrad = CG->getSrcGrad();
    auto *destGrad = CG->getDestGrad();
    auto *src = CG->getSrc();
    auto *filterGrad = CG->getFilterGrad();
    auto *srcGradPtr = emitValueAddress(builder, srcGrad);
    auto *destGradPtr = emitValueAddress(builder, destGrad);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterGradPtr = emitValueAddress(builder, filterGrad);
    auto *biasGradPtr = emitValueAddress(builder, CG->getBiasGrad());
    auto *filterPtr = emitValueAddress(builder, CG->getFilter());

    auto *destGradDims = emitValueDims(builder, destGrad);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterGradDims = emitValueDims(builder, filterGrad);

    auto *kernels = emitConstDimTArray(builder, CG->getKernels());
    auto *strides = emitConstDimTArray(builder, CG->getStrides());
    auto *pads = emitConstDimTArray(builder, CG->getPads());
    auto *group = emitConstDimT(builder, CG->getGroup());
    auto *dilation = emitConstDimTArray(builder, CG->getDilation());

    auto *F = getFunction("convolution_grad", srcGrad->getElementType());
    createCall(builder, F,
               {srcGradPtr, destGradPtr, srcPtr, filterGradPtr, biasGradPtr,
                filterPtr, destGradDims, srcDims, filterGradDims, kernels,
                strides, pads, group, dilation});
    break;
  }

  case Kinded::Kind::ConvTransposeInstKind: {
    auto *CI = cast<ConvTransposeInst>(I);
    auto *dest = CI->getDest();
    auto *src = CI->getSrc();
    auto *filter = CI->getFilter();
    auto *bias = CI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CI->getKernels());
    auto *strides = emitConstDimTArray(builder, CI->getStrides());
    auto *pads = emitConstDimTArray(builder, CI->getPads());
    auto *group = emitConstDimT(builder, CI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CI->getDilation());

    const char *kernelName = "conv_transpose";

    auto *F = getFunction(kernelName, dest->getElementType());

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *filterTy = filter->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *filterOffset = emitConstI32(builder, filterTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * filterTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // output calculation.
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group,
                  destOffset, srcOffset, filterOffset, outPre, outPost,
                  outScale, dilation});
    } else {
      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group,
                  dilation});
    }
    break;
  }

  case Kinded::Kind::ChannelwiseQuantizedConvolutionInstKind: {
    auto *CQCI = cast<ChannelwiseQuantizedConvolutionInst>(I);
    auto *dest = CQCI->getDest();
    auto *src = CQCI->getSrc();
    auto *filter = CQCI->getFilter();
    auto *bias = CQCI->getBias();
    auto *filterScales = CQCI->getFilterScales();
    auto *filterOffsets = CQCI->getFilterOffsets();
    auto *biasScales = CQCI->getBiasScales();
    auto *biasOffsets = CQCI->getBiasOffsets();

    auto *destTy = dest->getType();
    auto *srcTy = src->getType();

    auto filterScalesT = getTensorForConstantValue(filterScales);
    auto filterScalesH = filterScalesT.getHandle<float>();

    auto biasScalesT = getTensorForConstantValue(biasScales);
    auto biasScalesH = biasScalesT.getHandle<float>();

    // Compute quantization parameters for each channel.
    auto channelNum = dest->dims().back();
    std::vector<llvm::Constant *> biasPreV(channelNum);
    std::vector<llvm::Constant *> biasPostV(channelNum);
    std::vector<llvm::Constant *> biasScaleV(channelNum);
    std::vector<llvm::Constant *> outputPreV(channelNum);
    std::vector<llvm::Constant *> outputPostV(channelNum);
    std::vector<llvm::Constant *> outputScaleV(channelNum);
    for (size_t i = 0; i < channelNum; i++) {

      // Compute the scaling parameters for bias and output.
      float matMulScale = srcTy->getScale() * filterScalesH.raw(i);
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasScalesH.raw(i) / matMulScale, 0);
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      biasPreV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                           biasScaleParam.pre, true);
      biasPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                            biasScaleParam.post, true);
      biasScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                             biasScaleParam.scale, true);
      outputPreV[i] =
          llvm::ConstantInt::get(builder.getInt32Ty(), outScaleParam.pre, true);
      outputPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                              outScaleParam.post, true);
      outputScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                               outScaleParam.scale, true);
    }

    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CQCI->getKernels());
    auto *strides = emitConstDimTArray(builder, CQCI->getStrides());
    auto *pads = emitConstDimTArray(builder, CQCI->getPads());
    auto *group = emitConstDimT(builder, CQCI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CQCI->getDilation());

    auto *destOffset = emitConstI32(builder, destTy->getOffset());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
    auto *filterOffsetsPtr = emitValueAddress(builder, filterOffsets);
    auto *biasOffsetsPtr = emitValueAddress(builder, biasOffsets);

    auto *biasPrePtr = emitConstArray(builder, biasPreV, builder.getInt32Ty());
    auto *biasPostPtr =
        emitConstArray(builder, biasPostV, builder.getInt32Ty());
    auto *biasScalePtr =
        emitConstArray(builder, biasScaleV, builder.getInt32Ty());
    auto *outputPrePtr =
        emitConstArray(builder, outputPreV, builder.getInt32Ty());
    auto *outputPostPtr =
        emitConstArray(builder, outputPostV, builder.getInt32Ty());
    auto *outputScalePtr =
        emitConstArray(builder, outputScaleV, builder.getInt32Ty());

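    // A 5-dimensional source selects the 3D variant of the channelwise
    // quantized convolution kernel; otherwise the 2D variant is used.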
    bool isConv3D = (srcTy->dims().size() == 5);
    auto *F = getFunction(isConv3D ? "channelwise_quantized_conv3d"
                                   : "channelwise_quantized_conv2d",
                          {dest->getElementType(), bias->getElementType()});

    auto *actType = emitConstI32(builder, CQCI->getFusedActivation());
    auto *actArgsQuant = emitConstQuantActivationArgs(builder, CQCI);

    createCall(builder, F,
               {destPtr, srcPtr, filterPtr, biasPtr,
                destDims, srcDims, filterDims, biasDims,
                kernels, strides, pads, group,
                dilation, destOffset, srcOffset, filterOffsetsPtr,
                biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
                outputPrePtr, outputPostPtr, outputScalePtr, actType,
                actArgsQuant});
    break;
  }

  case Kinded::Kind::CrossEntropyLossInstKind: {
    auto *CI = cast<CrossEntropyLossInst>(I);
    auto *P = CI->getP();
    auto *labels = CI->getLabels();
    auto *CE = CI->getCE();

    auto *CEPtr = emitValueAddress(builder, CE);
    auto *PPtr = emitValueAddress(builder, P);
    auto *labelsPtr = emitValueAddress(builder, labels);
    auto *dims = emitValueDims(builder, P);

    auto *F = getFunction("cross_entropy_loss",
                          {CE->getElementType(), labels->getElementType()});
    createCall(builder, F, {CEPtr, PPtr, labelsPtr, dims});
    break;
  }

  case Kinded::Kind::LengthsToRangesInstKind: {
    auto *LTR = cast<LengthsToRangesInst>(I);
    auto *dest = LTR->getDest();
    auto *lengths = LTR->getLengths();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *lengthsPtr = emitValueAddress(builder, lengths);
    auto *size = emitConstDimT(builder, lengths->dims()[0]);
    auto *F = getFunction("lengths_to_ranges", dest->getElementType());
    createCall(builder, F, {destPtr, lengthsPtr, size});
    break;
  }

  case Kinded::Kind::LengthsSumInstKind: {
    auto *LS = cast<LengthsSumInst>(I);
    auto *dest = LS->getDest();
    auto *data = LS->getData();
    auto *lengths = LS->getLengths();

    auto *destPtr = emitValueAddress(builder, dest);
    auto *dataPtr = emitValueAddress(builder, data);
    auto *lengthsPtr = emitValueAddress(builder, lengths);

    auto *lengthsSize = emitConstDimT(builder, lengths->size());
    auto *dataType = data->getType();
    auto *destSize = emitConstDimT(builder, dest->size());
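    // The slice size is the number of elements in one first-dimension slice
    // of the data tensor.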
    auto *sliceSize =
        emitConstDimT(builder, dataType->size() / dataType->dims()[0]);

    auto *F = getFunction("lengths_sum", data->getElementType());
    createCall(
        builder, F,
        {destPtr, dataPtr, lengthsPtr, destSize, lengthsSize, sliceSize});
    break;
  }

  case Kinded::Kind::LocalResponseNormalizationInstKind: {
    auto *LRN = cast<LocalResponseNormalizationInst>(I);
    auto *dest = LRN->getDest();
    auto *src = LRN->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *scalePtr = emitValueAddress(builder, LRN->getScale());

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *halfWindow = emitConstDimT(builder, LRN->getHalfWindowSize());
    auto *alpha = emitConstF32(builder, LRN->getAlpha());
    auto *beta = emitConstF32(builder, LRN->getBeta());
    auto *k = emitConstF32(builder, LRN->getK());

    auto *F =
        getFunction("local_response_normalization", dest->getElementType());
    createCall(builder, F,
               {destPtr, srcPtr, scalePtr, destDims, srcDims, halfWindow, alpha,
                beta, k});
    break;
  }

  case Kinded::Kind::LocalResponseNormalizationGradInstKind: {
    auto *LRNG = llvm::cast<LocalResponseNormalizationGradInst>(I);
    auto *srcGrad = LRNG->getSrcGrad();
    auto *dest = LRNG->getDest();
    auto *srcGradPtr = emitValueAddress(builder, srcGrad);
    auto *destGradPtr = emitValueAddress(builder, LRNG->getDestGrad());
    auto *srcPtr = emitValueAddress(builder, LRNG->getSrc());
    auto *destPtr = emitValueAddress(builder, dest);
    auto *scalePtr = emitValueAddress(builder, LRNG->getScale());

    auto *destDims = emitValueDims(builder, dest);

    auto *halfWindow = emitConstDimT(builder, LRNG->getHalfWindowSize());
    auto *alpha = emitConstF32(builder, LRNG->getAlpha());
    auto *beta = emitConstF32(builder, LRNG->getBeta());

    auto *F = getFunction("local_response_normalization_grad",
                          srcGrad->getElementType());
    createCall(builder, F,
               {srcGradPtr, destGradPtr, srcPtr, destPtr, scalePtr, destDims,
                halfWindow, alpha, beta});
    break;
  }

  case Kinded::Kind::MaxPoolInstKind: {
    auto *PM = cast<MaxPoolInst>(I);
    assert(PM->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Pools");
    auto *dest = PM->getDest();
    auto *src = PM->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);

    auto *kernels = emitConstDimTArray(builder, PM->getKernels());
    auto *strides = emitConstDimTArray(builder, PM->getStrides());
    auto *pads = emitConstDimTArray(builder, PM->getPads());

    auto *F = getFunction("max_pool", dest->getElementType());

    if (src->getType()->isQuantizedType()) {
      auto *destOffset = emitConstI32(builder, dest->getType()->getOffset());
      createCall(builder, F,
                 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
                  destOffset});
    } else {
      createCall(builder, F,
                 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads});
    }
    break;
  }

  case Kinded::Kind::MaxPoolWithArgmaxInstKind: {
    auto *PMXY = cast<MaxPoolWithArgmaxInst>(I);
    assert(PMXY->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Pools");
    auto *dest = PMXY->getDest();
    auto *src = PMXY->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *argMax = PMXY->getArgmax();
    auto *argmaxPtr = emitValueAddress(builder, argMax);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);

    auto *kernels = emitConstDimTArray(builder, PMXY->getKernels());
    auto *strides = emitConstDimTArray(builder, PMXY->getStrides());
    auto *pads = emitConstDimTArray(builder, PMXY->getPads());

    auto *F = getFunction("max_pool_argmax",
                          {dest->getElementType(), argMax->getElementType()});
    createCall(builder, F,
               {srcPtr, destPtr, argmaxPtr, srcDims, destDims, kernels, strides,
                pads});
    break;
  }

  case Kinded::Kind::MaxPoolWithArgmaxGradInstKind: {
    auto *PMG = cast<MaxPoolWithArgmaxGradInst>(I);
    auto *srcGrad = PMG->getSrcGrad();
    auto *srcGradPtr = emitValueAddress(builder, srcGrad);
    auto *destGradPtr = emitValueAddress(builder, PMG->getDestGrad());
    auto *argMax = PMG->getArgmax();
    auto *argmaxPtr = emitValueAddress(builder, argMax);

    auto *srcGradDims = emitValueDims(builder, srcGrad);
    auto *destDims = emitValueDims(builder, PMG->getDest());

    auto *F = getFunction("max_pool_argmax_grad",
                          {srcGrad->getElementType(), argMax->getElementType()});
    createCall(builder, F,
               {srcGradPtr, destGradPtr, argmaxPtr, srcGradDims, destDims});
    break;
  }

  case Kinded::Kind::ArgMaxInstKind: {
    auto *AM = cast<ArgMaxInst>(I);
    auto *dest = AM->getDest();
    auto *src = AM->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *srcDims = emitValueDims(builder, src);
    auto *srcNumDims = emitConstSizeT(builder, src->dims().size());
    auto *axis = emitConstSizeT(builder, AM->getAxis());
    auto *F =
        getFunction("arg_max", {src->getElementType(), dest->getElementType()});
    createCall(builder, F, {srcPtr, destPtr, srcDims, srcNumDims, axis});
    break;
  }

  case Kinded::Kind::ArgMinInstKind: {
    auto *AM = cast<ArgMinInst>(I);
    auto *dest = AM->getDest();
    auto *src = AM->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *srcDims = emitValueDims(builder, src);
    auto *srcNumDims = emitConstSizeT(builder, src->dims().size());
    auto *axis = emitConstSizeT(builder, AM->getAxis());
    auto *F =
        getFunction("arg_min", {src->getElementType(), dest->getElementType()});
    createCall(builder, F, {srcPtr, destPtr, srcDims, srcNumDims, axis});
    break;
  }

  case Kinded::Kind::AvgPoolInstKind: {
    auto *PA = cast<AvgPoolInst>(I);
    assert(PA->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Pools");
    auto *dest = PA->getDest();
    auto *src = PA->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);

    auto *kernels = emitConstDimTArray(builder, PA->getKernels());
2837 auto *strides = emitConstDimTArray(builder, PA->getStrides());
2838 auto *pads = emitConstDimTArray(builder, PA->getPads());
2839 auto *countIncludePads = emitConstI1(builder, PA->getCountIncludePads());
2840
2841 auto *F = getFunction("avg_pool", dest->getElementType());
2842
2843 if (src->getType()->isQuantizedType()) {
2844 auto *destTy = dest->getType();
2845 auto *srcTy = src->getType();
2846 auto *destOffset = emitConstI32(builder, destTy->getOffset());
2847 auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
2848      // When the padding pixels are counted in the normalizing factor, the
2849      // divisor is the constant filter area, so it is folded into the scale.
2850 float scale = srcTy->getScale() / destTy->getScale();
2851 if (PA->getCountIncludePads()) {
2852 scale = scale / (PA->getKernels()[0] * PA->getKernels()[1]);
2853 }
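      // Illustrative example (made-up values): srcScale = 0.5, destScale =
      // 0.25 and a 3x3 kernel with countIncludePads give scale =
      // (0.5 / 0.25) / 9 ~= 0.222, requantized below into pre/post/scale.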
2854 auto outScaleParam = quantization::quantizeScaleOffset32To8(scale, 0);
2855 auto *outPre = emitConstI32(builder, outScaleParam.pre);
2856 auto *outPost = emitConstI32(builder, outScaleParam.post);
2857 auto *outScale = emitConstI32(builder, outScaleParam.scale);
2858 createCall(builder, F,
2859 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
2860 countIncludePads, destOffset, srcOffset, outPre, outPost,
2861 outScale});
2862 } else {
2863 createCall(builder, F,
2864 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
2865 countIncludePads});
2866 }
2867 break;
2868 }
2869
2870 case Kinded::Kind::AdaptiveAvgPoolInstKind: {
2871 auto *PA = cast<AdaptiveAvgPoolInst>(I);
2872
2873 auto *dest = PA->getDest();
2874 auto *src = PA->getSrc();
2875 auto *destPtr = emitValueAddress(builder, dest);
2876 auto *srcPtr = emitValueAddress(builder, src);
2877
2878 auto *destDims = emitValueDims(builder, dest);
2879 auto *srcDims = emitValueDims(builder, src);
2880
2881 auto *F = getFunction("adaptive_avg_pool", dest->getElementType());
2882 createCall(builder, F, {srcPtr, destPtr, srcDims, destDims});
2883 break;
2884 }
2885
2886 case Kinded::Kind::AvgPoolGradInstKind: {
2887 auto *PAG = cast<AvgPoolGradInst>(I);
2888 auto *srcGrad = PAG->getSrcGrad();
2889 auto *srcGradPtr = emitValueAddress(builder, srcGrad);
2890 auto *destGradPtr = emitValueAddress(builder, PAG->getDestGrad());
2891
2892 auto *srcGradDims = emitValueDims(builder, srcGrad);
2893 auto *destDims = emitValueDims(builder, PAG->getDest());
2894
2895 auto *kernels = emitConstDimTArray(builder, PAG->getKernels());
2896 auto *strides = emitConstDimTArray(builder, PAG->getStrides());
2897 auto *pads = emitConstDimTArray(builder, PAG->getPads());
2898 auto *countIncludePads = emitConstI1(builder, PAG->getCountIncludePads());
2899
2900 auto *F = getFunction("avg_pool_grad", srcGrad->getElementType());
2901 createCall(builder, F,
2902 {srcGradPtr, destGradPtr, srcGradDims, destDims, kernels,
2903 strides, pads, countIncludePads});
2904 break;
2905 }
2906
2907 case Kinded::Kind::SoftMaxInstKind: {
2908 auto *SM = cast<SoftMaxInst>(I);
2909 auto *dest = SM->getDest();
2910 auto *src = SM->getSrc();
2911 auto *destPtr = emitValueAddress(builder, dest);
2912 auto *srcPtr = emitValueAddress(builder, src);
2913 auto *destDims = emitValueDims(builder, dest);
2914 auto *srcDims = emitValueDims(builder, src);
2915 auto *F = getFunction("softmax", dest->getElementType());
2916
2917 if (src->getType()->isQuantizedType()) {
2918 std::vector<int32_t> lut;
2919
2920      // Compute a lookup table containing all the exponentials, based on the
2921      // formula e^(scale * value), where scale is the input scale of the
2922      // quantized input data and value ranges over [-255, 0].
2923 for (int32_t i = 0; i < 256; i++) {
2924 auto exponent =
2925 FixedPointUInt32(exp(src->getType()->getScale() * (i - 255)), 1)
2926 .getFixedVal();
2927 lut.push_back(exponent);
2928 }
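      // Note: the largest table entry is e^0 == 1.0 (at i == 255), so a
      // single integer bit is enough to represent any value in the table.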
2929
2930 auto *lutPtr = emitConstI32Array(builder, lut);
2931 auto *outOffset = emitConstI32(builder, dest->getType()->getOffset());
2932 float size = static_cast<float>(src->getType()->dims()[1]);
2933 auto *sumIntegerPart = emitConstI32(builder, ceil(log2(size)));
2934
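      // A sum of size entries, each at most 1.0, fits in ceil(log2(size))
      // integer bits unless size is an exact power of two, in which case the
      // maximum sum equals size and needs one extra bit.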
2935 if (ceil(log2(size)) == floor(log2(size))) {
2936 sumIntegerPart = emitConstI32(builder, ceil(log2(size)) + 1);
2937 }
2938
2939 FixedPointUInt32 invScaleFixedPoint =
2940 FixedPointUInt32(1.f / dest->getType()->getScale());
2941 auto *invScale = emitConstI32(builder, invScaleFixedPoint.getFixedVal());
2942 auto *invScalePoint =
2943 emitConstI32(builder, invScaleFixedPoint.getIntBits());
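      // 1 / destScale is emitted as a fixed-point value so the kernel can
      // apply the output scale with integer arithmetic instead of a float
      // divide.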
2944 createCall(builder, F,
2945 {srcPtr, destPtr, srcDims, lutPtr, outOffset, invScale,
2946 sumIntegerPart, invScalePoint});
2947 } else {
2948 createCall(builder, F, {srcPtr, destPtr, srcDims, destDims});
2949 }
2950
2951 break;
2952 }
2953
2954 case Kinded::Kind::SoftMaxGradInstKind: {
2955 auto *SMG = cast<SoftMaxGradInst>(I);
2956 auto *srcGrad = SMG->getSrcGrad();
2957 auto *selected = SMG->getSelected();
2958 auto *srcGradPtr = emitValueAddress(builder, srcGrad);
2959 auto *destPtr = emitValueAddress(builder, SMG->getOrigDest());
2960 auto *selectedPtr = emitValueAddress(builder, selected);
2961
2962 auto *srcGradDims = emitValueDims(builder, srcGrad);
2963 auto *selectedDims = emitValueDims(builder, selected);
2964
2965 auto *F = getFunction("softmax_grad", {srcGrad->getElementType(),
2966 selected->getElementType()});
2967 createCall(builder, F,
2968 {srcGradPtr, destPtr, selectedPtr, srcGradDims, selectedDims});
2969 break;
2970 }
2971
2972 case Kinded::Kind::TopKInstKind: {
2973 auto *TI = cast<TopKInst>(I);
2974 auto *input = TI->getInput();
2975 auto *valuesPtr = emitValueAddress(builder, TI->getValues());
2976 auto *indicesPtr = emitValueAddress(builder, TI->getIndices());
2977 auto *inputPtr = emitValueAddress(builder, input);
2978 auto *scratchPtr = emitValueAddress(builder, TI->getScratch());
2979
2980 auto *k = emitConstDimT(builder, TI->getK());
2981 auto *n = emitConstDimT(builder, input->dims().back());
2982 auto *size = emitConstDimT(builder, input->size());
2983
2984 auto indicesTy = TI->getIndices()->getElementType();
2985 auto *F = getFunction("topk", {input->getElementType(), indicesTy});
2986
2987 createCall(builder, F,
2988 {valuesPtr, indicesPtr, inputPtr, scratchPtr, k, n, size});
2989 break;
2990 }
2991
2992 case Kinded::Kind::SpaceToDepthInstKind: {
2993 auto *SI = cast<SpaceToDepthInst>(I);
2994 auto *dest = SI->getDest();
2995 auto *src = SI->getSrc();
2996
2997 auto *dstPtr = emitValueAddress(builder, dest);
2998 auto *srcPtr = emitValueAddress(builder, src);
2999
3000 auto *dstDims = emitValueDims(builder, dest);
3001 auto *srcDims = emitValueDims(builder, src);
3002
3003 unsigned blockSize = SI->getBlockSize();
3004
3005 auto *F = getFunction("space_to_depth", src->getElementType());
3006 createCall(
3007 builder, F,
3008 {srcPtr, dstPtr, emitConstDimT(builder, blockSize), srcDims, dstDims});
3009 break;
3010 }
3011
3012 case Kinded::Kind::TransposeInstKind: {
3013 auto *TI = cast<TransposeInst>(I);
3014 auto *dest = TI->getDest();
3015 auto *src = TI->getSrc();
3016 auto *destPtr = emitValueAddress(builder, dest);
3017 auto *srcPtr = emitValueAddress(builder, src);
3018
3019 auto *destDims = emitValueDims(builder, dest);
3020 auto *srcDims = emitValueDims(builder, src);
3021
3022 // Convert the mask to size_t type.
3023 ShapeVector shuffSizeT;
3024 for (auto D : TI->getShuffle()) {
3025 shuffSizeT.push_back((size_t)D);
3026 }
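    // For example, a shuffle of {0, 2, 3, 1} transposes NCHW data to NHWC.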
3027
3028 auto *shuffle = emitConstDimTArray(builder, llvm::makeArrayRef(shuffSizeT));
3029 auto *len = emitConstDimT(builder, TI->getShuffle().size());
3030
3031 auto *F = getFunction("transpose", dest->getElementType());
3032 createCall(builder, F, {srcPtr, destPtr, srcDims, destDims, shuffle, len});
3033 break;
3034 }
3035
3036 case Kinded::Kind::FlipInstKind: {
3037 auto *FI = cast<FlipInst>(I);
3038 auto *dest = FI->getDest();
3039 auto *src = FI->getSrc();
3040 auto *destPtr = emitValueAddress(builder, dest);
3041 auto *srcPtr = emitValueAddress(builder, src);
3042 auto *dims = emitValueDims(builder, src);
3043 auto *axis = emitConstDimT(builder, FI->getAxis());
3044 auto *dimsSize = emitConstDimT(builder, src->getType()->dims().size());
3045 auto *F = getFunction("flip", src->getElementType());
3046 createCall(builder, F, {srcPtr, destPtr, dims, axis, dimsSize});
3047 break;
3048 }
3049
3050  // Alloc/Dealloc are handled by the memory allocator; TensorView is a no-op.
3051 case Kinded::Kind::AllocActivationInstKind:
3052 case Kinded::Kind::DeallocActivationInstKind:
3053 case Kinded::Kind::TensorViewInstKind:
3054 break;
3055
3056 case Kinded::Kind::InsertTensorInstKind: {
3057 auto *ITI = llvm::cast<InsertTensorInst>(I);
3058 auto *dest = ITI->getDest();
3059 auto *src = ITI->getSrc();
3060 auto offsets = ITI->getOffsets();
3061 auto *destPtr = emitValueAddress(builder, dest);
3062 auto *srcPtr = emitValueAddress(builder, src);
3063
3064 auto *destDims = emitValueDims(builder, dest);
3065 auto *srcDims = emitValueDims(builder, src);
3066
3067 auto *destDimsSize = emitConstDimT(builder, dest->getType()->dims().size());
3068 auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size());
3069 auto *offsetsPtr = emitConstDimTArray(builder, offsets);
3070 auto *offsetsArraySize = emitConstDimT(builder, offsets.size());
3071 auto *count = emitConstDimT(builder, ITI->getCount());
3072 auto *axis = emitConstDimT(builder, ITI->getAxis());
3073
3074    // Don't specialize the offsetsPtr because we typically generate lots of
3075    // inserts at different offsets, and specializing on this argument does
3076    // not speed things up.
3077 markArgAsUnspecialized(offsetsPtr);
3078
3079 auto *F = getFunction("insert_tensor", dest->getElementType());
3080 createCall(builder, F,
3081 {destPtr, srcPtr, offsetsPtr, destDims, srcDims, destDimsSize,
3082 srcDimsSize, offsetsArraySize, count, axis});
3083 break;
3084 }
3085
3086 case Kinded::Kind::ExtractTensorInstKind: {
3087 auto *ITI = llvm::cast<ExtractTensorInst>(I);
3088 auto *dest = ITI->getDest();
3089 auto *src = ITI->getSrc();
3090 auto offsets = ITI->getOffsets();
3091 auto *destPtr = emitValueAddress(builder, dest);
3092 auto *srcPtr = emitValueAddress(builder, src);
3093
3094 auto *destDims = emitValueDims(builder, dest);
3095 auto *srcDims = emitValueDims(builder, src);
3096
3097 auto *destDimsSize = emitConstDimT(builder, dest->getType()->dims().size());
3098 auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size());
3099 auto *offsetsPtr = emitConstDimTArray(builder, offsets);
3100 auto *offsetsArraySize = emitConstDimT(builder, offsets.size());
3101
3102    // Don't specialize the offsetsPtr because we typically generate lots of
3103    // extracts from different offsets, and specializing on this argument does
3104    // not speed things up.
3105 markArgAsUnspecialized(offsetsPtr);
3106
3107 auto *F = getFunction("extract_tensor", dest->getElementType());
3108 createCall(builder, F,
3109 {srcPtr, destPtr, offsetsPtr, srcDims, destDims, srcDimsSize,
3110 destDimsSize, offsetsArraySize});
3111 break;
3112 }
3113
3114 case Kinded::Kind::GatherInstKind: {
3115 auto *GI = llvm::cast<GatherInst>(I);
3116 auto *dest = GI->getDest();
3117 auto *data = GI->getData();
3118 auto *indices = GI->getIndices();
3119 unsigned axis = GI->getBatchDims();
3120
3121 auto *destPtr = emitValueAddress(builder, dest);
3122 auto *dataPtr = emitValueAddress(builder, data);
3123 auto *indicesPtr = emitValueAddress(builder, indices);
3124
3125 auto *indicesSize = emitConstDimT(builder, indices->size());
3126
3127 auto *dataType = data->getType();
3128
3129    // The size of each sample in the batch.
3130    size_t sampleSize = dataType->getSliceSize(axis);
3131    // The size of the slices that we gather.
3132    size_t sliceSize = dataType->getSliceSize(axis + 1);
3133    // The number of samples in the batch.
3134    size_t numSamples = dataType->size() / sampleSize;
3135
3136 auto *sliceSizeVal = emitConstDimT(builder, sliceSize);
3137 auto *numSamplesVal = emitConstDimT(builder, numSamples);
3138 auto *sampleSizeVal = emitConstDimT(builder, sampleSize);
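    // Illustrative example (batchDims == 0, data dims {8, 4}, assuming
    // getSliceSize(d) returns the element count of a slice starting at
    // dimension d): each index selects one 4-element row, so sliceSize = 4,
    // sampleSize = 32 and numSamples = 1.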
3139
3140    // Dispatch the function depending on the element type of Indices.
3141 llvm::Function *F = nullptr;
3142 if (indices->getElementType() == ElemKind::Int64ITy) {
3143 F = getFunction("gather64", dest->getElementType());
3144 } else if (indices->getElementType() == ElemKind::Int32ITy) {
3145 F = getFunction("gather32", dest->getElementType());
3146 }
3147 if (!F) {
3148 llvm_unreachable("Cannot get function for Gather. "
3149 "Indices input of Gather has to be int32 or int64");
3150 }
3151 createCall(builder, F,
3152 {destPtr, dataPtr, indicesPtr, indicesSize, sliceSizeVal,
3153 numSamplesVal, sampleSizeVal});
3154 break;
3155 }
3156
3157 case Kinded::Kind::GatherNDInstKind: {
3158 auto *GI = llvm::cast<GatherNDInst>(I);
3159 auto *dest = GI->getDest();
3160 auto *data = GI->getData();
3161 auto *indices = GI->getIndices();
3162 unsigned batchDims = GI->getBatchDims();
3163
3164 auto dataDims = data->dims();
3165 auto indicesDims = indices->dims();
3166 dim_t indicesDimLast = indicesDims.back();
3167
3168 // Compute batch count.
3169 dim_t batchCount = 1;
3170 for (size_t idx = 0; idx < batchDims; ++idx) {
3171 batchCount *= dataDims[idx];
3172 }
3173
3174 // Compute input slice count.
3175 dim_t inpSliceCount = 1;
3176 for (size_t idx = batchDims; idx < batchDims + indicesDimLast; ++idx) {
3177 inpSliceCount *= dataDims[idx];
3178 }
3179
3180 // Compute output slice count.
3181 dim_t outSliceCount = 1;
3182 for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) {
3183 outSliceCount *= indicesDims[idx];
3184 }
3185
3186 // Compute slice size (in bytes).
3187 dim_t sliceSize = data->getType()->getElementSize();
3188 for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size();
3189 idx++) {
3190 sliceSize *= dataDims[idx];
3191 }
3192
3193 // Get indices dimension products.
3194 std::vector<dim_t> indicesDimProd(indicesDimLast);
3195 indicesDimProd[indicesDimLast - 1] = 1;
3196 for (ssize_t idx = static_cast<ssize_t>(indicesDimLast) - 2; idx >= 0;
3197 idx--) {
3198 indicesDimProd[idx] =
3199 indicesDimProd[idx + 1] * dataDims[batchDims + idx + 1];
3200 }
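    // Illustrative example (batchDims == 0, data dims {5, 4, 3},
    // indicesDimLast == 2): indicesDimProd = {4, 1}, so an index tuple (i, j)
    // selects input slice i * 4 + j out of the 20 input slices.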
3201
3202 // Emit pointers.
3203 auto *destPtr = emitValueAddress(builder, dest);
3204 auto *dataPtr = emitValueAddress(builder, data);
3205 auto *indicesPtr = emitValueAddress(builder, indices);
3206
3207 // Emit parameters.
3208 auto *batchCountArg = emitConstDimT(builder, batchCount);
3209 auto *inpSliceCountArg = emitConstDimT(builder, inpSliceCount);
3210 auto *outSliceCountArg = emitConstDimT(builder, outSliceCount);
3211 auto *sliceSizeArg = emitConstDimT(builder, sliceSize);
3212 auto *indicesDimLastArg = emitConstDimT(builder, indicesDimLast);
3213 auto *indicesDimProdArg =
3214 emitConstDimTArray(builder, llvm::makeArrayRef(indicesDimProd));
3215
3216 llvm::Function *F = getFunction(
3217 "gather_nd", {data->getElementType(), indices->getElementType()});
3218 createCall(builder, F,
3219 {destPtr, dataPtr, indicesPtr, batchCountArg, inpSliceCountArg,
3220 outSliceCountArg, sliceSizeArg, indicesDimLastArg,
3221 indicesDimProdArg});
3222 break;
3223 }
3224
3225 case Kinded::Kind::GatherRangesInstKind: {
3226 auto *GRI = llvm::cast<GatherRangesInst>(I);
3227 auto *output = GRI->getOutput();
3228 auto *lengths = GRI->getLengths();
3229 auto *data = GRI->getData();
3230 auto *ranges = GRI->getRanges();
3231
3232 auto *outputPtr = emitValueAddress(builder, output);
3233 auto *lengthsPtr = emitValueAddress(builder, lengths);
3234 auto *dataPtr = emitValueAddress(builder, data);
3235 auto *rangesPtr = emitValueAddress(builder, ranges);
3236
3237 auto rangesType = ranges->getType();
3238
3239 // The number of examples in ranges.
3240 size_t numExamples = rangesType->dims()[0];
3241 // The number of range pairs in each example.
3242 size_t exampleSize = rangesType->dims()[1];
3243
3244 auto *numExamplesVal = emitConstDimT(builder, numExamples);
3245 auto *exampleSizeVal = emitConstDimT(builder, exampleSize);
3246
3247    // Dispatch the function depending on the element type of Ranges.
3248 llvm::Function *F = nullptr;
3249 if (ranges->getElementType() == ElemKind::Int64ITy) {
3250 F = getFunction("gatherranges64", output->getElementType());
3251 } else if (ranges->getElementType() == ElemKind::Int32ITy) {
3252 F = getFunction("gatherranges32", output->getElementType());
3253 }
3254 if (!F) {
3255 llvm_unreachable("Cannot get function for GatherRanges. "
3256 "Ranges input of GatherRanges has to be int32 or int64");
3257 }
3258 createCall(builder, F,
3259 {outputPtr, lengthsPtr, dataPtr, rangesPtr, numExamplesVal,
3260 exampleSizeVal});
3261 break;
3262 }
3263
3264 case Kinded::Kind::LengthsRangeFillInstKind: {
3265 auto *LRFI = llvm::cast<LengthsRangeFillInst>(I);
3266 auto *dest = LRFI->getDest();
3267 auto *lengths = LRFI->getLengths();
3268
3269 auto *destPtr = emitValueAddress(builder, dest);
3270 auto *lengthsPtr = emitValueAddress(builder, lengths);
3271
3272 auto *lengthsSize = emitConstDimT(builder, lengths->size());
3273
3274    // Dispatch the function depending on the element type of Dest.
3275 auto *F = getFunction("lengths_range_fill", dest->getElementType());
3276 createCall(builder, F, {lengthsPtr, destPtr, lengthsSize});
3277 break;
3278 }
3279
3280 case Kinded::Kind::ScatterDataInstKind: {
3281 auto *SDI = llvm::cast<ScatterDataInst>(I);
3282 auto *data = SDI->getData();
3283 auto *indices = SDI->getIndices();
3284 auto *slices = SDI->getSlices();
3285
3286 auto *dataPtr = emitValueAddress(builder, data);
3287 auto *indicesPtr = emitValueAddress(builder, indices);
3288 auto *slicesPtr = emitValueAddress(builder, slices);
3289 auto *dataDims = emitValueDims(builder, data);
3290
3291 auto *indicesCnt = emitConstDimT(builder, indices->getType()->dims()[0]);
3292 auto *indicesSize = emitConstDimT(builder, indices->getType()->dims()[1]);
3293 auto *slicesType = slices->getType();
3294 auto *sliceSize =
3295 emitConstDimT(builder, slicesType->size() / slicesType->dims()[0]);
3296 auto *isCumulative = emitConstI1(builder, SDI->getCumulative());
3297 auto *F = getFunction("scatterdata",
3298 {data->getElementType(), indices->getElementType()});
3299 if (data->getType()->isQuantizedType()) {
3300 auto *dataScale = emitConstF32(builder, data->getType()->getScale());
3301 auto *dataOffset = emitConstI32(builder, data->getType()->getOffset());
3302 auto *sliceScale = emitConstF32(builder, slices->getType()->getScale());
3303 auto *sliceOffset = emitConstI32(builder, slices->getType()->getOffset());
3304 createCall(builder, F,
3305 {dataPtr, dataDims, indicesPtr, slicesPtr, indicesCnt,
3306 indicesSize, sliceSize, isCumulative, dataScale, dataOffset,
3307 sliceScale, sliceOffset});
3308 } else {
3309 createCall(builder, F,
3310 {dataPtr, dataDims, indicesPtr, slicesPtr, indicesCnt,
3311 indicesSize, sliceSize, isCumulative});
3312 }
3313 break;
3314 }
3315
3316 case Kinded::Kind::SparseLengthsSumInstKind: {
3317 auto *SI = cast<SparseLengthsSumInst>(I);
3318 auto *dest = SI->getDest();
3319 auto *data = SI->getData();
3320 auto *indices = SI->getIndices();
3321 auto *lengths = SI->getLengths();
3322 auto *destPtr = emitValueAddress(builder, dest);
3323 auto *dataPtr = emitValueAddress(builder, data);
3324 auto *indicesPtr = emitValueAddress(builder, indices);
3325 auto *lengthsPtr = emitValueAddress(builder, lengths);
3326 auto *segments = emitConstDimT(builder, lengths->dims()[0]);
3327 auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3328 auto *F = getFunction("sparse_lengths_sum",
3329 {dest->getElementType(), indices->getElementType()});
3330 createCall(builder, F,
3331 {destPtr, dataPtr, indicesPtr, lengthsPtr, segments, lineSize});
3332 break;
3333 }
3334
3335 case Kinded::Kind::SparseLengthsWeightedSumInstKind: {
3336 auto *SI = cast<SparseLengthsWeightedSumInst>(I);
3337 auto *dest = SI->getDest();
3338 auto *data = SI->getData();
3339 auto *weights = SI->getWeights();
3340 auto *indices = SI->getIndices();
3341 auto *lengths = SI->getLengths();
3342 auto *destPtr = emitValueAddress(builder, dest);
3343 auto *dataPtr = emitValueAddress(builder, data);
3344 auto *weightsPtr = emitValueAddress(builder, weights);
3345 auto *indicesPtr = emitValueAddress(builder, indices);
3346 auto *lengthsPtr = emitValueAddress(builder, lengths);
3347 auto *segments = emitConstDimT(builder, lengths->dims()[0]);
3348 auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3349 auto *F = getFunction("sparse_lengths_weighted_sum",
3350 {dest->getElementType(), indices->getElementType()});
3351 createCall(builder, F,
3352 {destPtr, dataPtr, weightsPtr, indicesPtr, lengthsPtr, segments,
3353 lineSize});
3354 break;
3355 }
3356
3357 case Kinded::Kind::EmbeddingInstKind: {
3358 auto *SI = cast<EmbeddingInst>(I);
3359 auto *dest = SI->getDest();
3360 auto *weights = SI->getWeights();
3361 auto *indices = SI->getIndices();
3362 auto *padIdx = emitConstSizeT(builder, SI->getPadIdx());
3363 auto *scale = emitConstI1(builder, SI->getScale());
3364 auto *sparse = emitConstI1(builder, SI->getSparse());
3365 auto *destPtr = emitValueAddress(builder, dest);
3366 auto *weightsPtr = emitValueAddress(builder, weights);
3367 auto *indicesPtr = emitValueAddress(builder, indices);
3368 auto *indDims = emitValueDims(builder, indices);
3369 auto *indSize = emitConstDimT(builder, indices->dims().size());
3370 assert(weights->dims().size() == 2 && "weights must be 2-D");
3371 auto *numEmbedding = emitConstDimT(builder, weights->dims()[0]);
3372 auto *embeddingDim = emitConstDimT(builder, weights->dims()[1]);
3373 auto *F = getFunction("embedding", dest->getElementType());
3374 createCall(builder, F,
3375 {destPtr, weightsPtr, indicesPtr, indDims, indSize, numEmbedding,
3376 embeddingDim, padIdx, scale, sparse});
3377 break;
3378 }
3379
3380 case Kinded::Kind::EmbeddingBagInstKind: {
3381 auto *SI = cast<EmbeddingBagInst>(I);
3382 auto *dest = SI->getDest();
3383 auto *data = SI->getData();
3384 auto *weights = SI->getWeights();
3385 auto *indices = SI->getIndices();
3386 auto *offsets = SI->getOffsets();
3387 auto *hasEndOffset = emitConstI1(builder, SI->getHasEndOffset());
3388 auto *destPtr = emitValueAddress(builder, dest);
3389 auto *dataPtr = emitValueAddress(builder, data);
3390 auto *weightsPtr = emitValueAddress(builder, weights);
3391 auto *indicesPtr = emitValueAddress(builder, indices);
3392 auto *offsetsPtr = emitValueAddress(builder, offsets);
3393 auto *segments = emitConstDimT(builder, offsets->dims()[0]);
3394 auto *totalLength = emitConstDimT(builder, indices->dims()[0]);
3395 auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3396 auto *F = getFunction("embedding_bag", dest->getElementType());
3397 createCall(builder, F,
3398 {destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
3399 lineSize, totalLength, hasEndOffset});
3400 break;
3401 }
3402
3403 case Kinded::Kind::SparseLengthsWeightedSumGradInstKind: {
3404 auto *SI = cast<SparseLengthsWeightedSumGradInst>(I);
3405 auto *destGrad = SI->getDestGrad();
3406 auto *dataGrad = SI->getDataGrad();
3407 auto *weightsGrad = SI->getWeightsGrad();
3408 auto *data = SI->getData();
3409 auto *weights = SI->getWeights();
3410 auto *indices = SI->getIndices();
3411 auto *lengths = SI->getLengths();
3412 auto *destGradPtr = emitValueAddress(builder, destGrad);
3413 auto *dataGradPtr = emitValueAddress(builder, dataGrad);
3414 auto *weightsGradPtr = emitValueAddress(builder, weightsGrad);
3415 auto *dataPtr = emitValueAddress(builder, data);
3416 auto *weightsPtr = emitValueAddress(builder, weights);
3417 auto *indicesPtr = emitValueAddress(builder, indices);
3418 auto *lengthsPtr = emitValueAddress(builder, lengths);
3419 auto *segments = emitConstDimT(builder, lengths->dims()[0]);
3420 auto *dataGradRawSize =
3421 emitConstDimT(builder, dataGrad->size() * sizeof(float));
3422 auto *lineSize =
3423 emitConstDimT(builder, dataGrad->size() / dataGrad->dims()[0]);
3424 auto *F =
3425 getFunction("sparse_lengths_weighted_sum_grad",
3426 {destGrad->getElementType(), indices->getElementType()});
3427 createCall(builder, F,
3428 {destGradPtr, dataGradPtr, weightsGradPtr, dataPtr, weightsPtr,
3429 indicesPtr, lengthsPtr, segments, lineSize, dataGradRawSize});
3430 break;
3431 }
3432
3433 case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumInstKind: {
3434 auto *N = cast<RowwiseQuantizedSparseLengthsWeightedSumInst>(I);
3435 auto *dest = N->getDest();
3436 auto *data = N->getData();
3437 auto *scales = N->getScales();
3438 auto *offsets = N->getOffsets();
3439 auto *weights = N->getWeights();
3440 auto *indices = N->getIndices();
3441 auto *lengths = N->getLengths();
3442 auto *destPtr = emitValueAddress(builder, dest);
3443 auto *dataPtr = emitValueAddress(builder, data);
3444 auto *scalesPtr = emitValueAddress(builder, scales);
3445 auto *offsetsPtr = emitValueAddress(builder, offsets);
3446 auto *weightsPtr = emitValueAddress(builder, weights);
3447 auto *indicesPtr = emitValueAddress(builder, indices);
3448 auto *lengthsPtr = emitValueAddress(builder, lengths);
3449 auto *segments = emitConstDimT(builder, lengths->dims()[0]);
3450 auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3451 auto *F = getFunction("rowwise_quantized_sparse_lengths_weighted_sum",
3452 {dest->getElementType(), indices->getElementType()});
3453 createCall(builder, F,
3454 {destPtr, dataPtr, scalesPtr, offsetsPtr, weightsPtr, indicesPtr,
3455 lengthsPtr, segments, lineSize});
3456 break;
3457 }
3458
3459 case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumInstKind: {
3460 auto *N = cast<FusedRowwiseQuantizedSparseLengthsWeightedSumInst>(I);
3461 auto *dest = N->getDest();
3462 auto *data = N->getData();
3463 auto *weights = N->getWeights();
3464 auto *indices = N->getIndices();
3465 auto *lengths = N->getLengths();
3466 auto *destPtr = emitValueAddress(builder, dest);
3467 auto *dataPtr = emitValueAddress(builder, data);
3468 auto *weightsPtr = emitValueAddress(builder, weights);
3469 auto *indicesPtr = emitValueAddress(builder, indices);
3470 auto *lengthsPtr = emitValueAddress(builder, lengths);
3471 auto *segments = emitConstDimT(builder, lengths->dims()[0]);
3472 auto *inLineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3473 auto *outLineSize = emitConstDimT(builder, dest->size() / dest->dims()[0]);
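    // In the fused rowwise format each data row carries its own scale and
    // offset after the quantized bytes, so inLineSize (elements per fused
    // input row) differs from outLineSize (elements per output row).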
3474 auto *F = getFunction("fused_rowwise_quantized_sparse_lengths_weighted_sum",
3475 {dest->getElementType(), indices->getElementType()});
3476 createCall(builder, F,
3477 {destPtr, dataPtr, weightsPtr, indicesPtr, lengthsPtr, segments,
3478 inLineSize, outLineSize});
3479 break;
3480 }
3481
3482 case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsInstKind: {
3483 auto *N = cast<EmbeddingBagByteRowwiseOffsetsInst>(I);
3484 auto *dest = N->getDest();
3485 auto *data = N->getData();
3486 auto *weights = N->getWeights();
3487 auto *indices = N->getIndices();
3488 auto *offsets = N->getOffsets();
3489 auto *hasEndOffset = emitConstI1(builder, N->getHasEndOffset());
3490 auto *destPtr = emitValueAddress(builder, dest);
3491 auto *dataPtr = emitValueAddress(builder, data);
3492 auto *weightsPtr = emitValueAddress(builder, weights);
3493 auto *indicesPtr = emitValueAddress(builder, indices);
3494 auto *offsetsPtr = emitValueAddress(builder, offsets);
3495 auto *segments = emitConstDimT(builder, offsets->dims()[0]);
3496 auto *numIndices = emitConstDimT(builder, indices->dims()[0]);
3497 auto *inLineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
3498 auto *outLineSize = emitConstDimT(builder, dest->size() / dest->dims()[0]);
3499 auto *F = getFunction("embedding_bag_byte_rowwise_offsets",
3500 dest->getElementType());
3501 createCall(builder, F,
3502 {destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
3503 numIndices, inLineSize, outLineSize, hasEndOffset});
3504 break;
3505 }
3506
3507 case Kinded::Kind::DebugPrintInstKind: {
3508 auto *DPI = llvm::cast<DebugPrintInst>(I);
3509 auto *src = DPI->getSrc();
3510 auto *srcPtr = emitValueAddress(builder, src);
3511 srcPtr = builder.CreateBitCast(srcPtr, builder.getInt8PtrTy());
3512 auto *srcDims = emitValueDims(builder, src);
3513 auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size());
3514 auto *srcSize = emitConstSizeT(builder, src->getType()->size());
3515 auto *srcSizeBytes =
3516 emitConstSizeT(builder, src->getType()->getSizeInBytes());
3517 auto *srcElemKind =
3518 emitConstDimT(builder, static_cast<size_t>(src->getElementType()));
3519 auto *name = emitStringConst(builder, I->getName());
3520 auto *filename = emitStringConst(builder, DPI->getFileName());
3521 auto srcTypeStr = src->getType()->toString();
3522
3523 std::string format = DPI->getFormat();
3524 if (format == "console") {
3525 // Dump tensor in console.
3526 auto *F = getFunction("dump_tensor_console");
3527 createCall(builder, F, {srcPtr, srcDims, srcDimsSize, srcElemKind, name});
3528
3529 } else if (format == "bin") {
3530 // Dump tensor in file in binary format.
3531 auto *F = getFunction("dump_tensor_bin");
3532 auto *header = emitStringConst(builder, srcTypeStr);
3533 createCall(builder, F, {srcPtr, srcSizeBytes, filename, header});
3534
3535 } else if (format == "txt") {
3536 // Dump tensor in file in text format.
3537 auto *F = getFunction("dump_tensor_txt", src->getElementType());
3538 auto *header = emitStringConst(builder, srcTypeStr);
3539 createCall(builder, F, {srcPtr, srcSize, filename, header});
3540
3541 } else if (format == "rawbin") {
3542 // Dump tensor in file in raw binary format.
3543 auto *F = getFunction("dump_tensor_bin");
3544 auto *header = emitStringConst(builder, "");
3545 createCall(builder, F, {srcPtr, srcSizeBytes, filename, header});
3546
3547 } else if (format == "rawtxt") {
3548 // Dump tensor in file in raw text format.
3549 auto *F = getFunction("dump_tensor_txt", src->getElementType());
3550 auto *header = emitStringConst(builder, "");
3551 createCall(builder, F, {srcPtr, srcSize, filename, header});
3552
3553 } else {
3554 LOG(FATAL) << "Invalid 'Format' attribute for DebugPrint instruction!";
3555 }
3556 break;
3557 }
3558
3559 case Kinded::Kind::InstrumentInstKind: {
3560 auto *instrumentI = llvm::cast<InstrumentInst>(I);
3561 auto *opInfo = instrumentI->getOperandsInfo();
3562
3563 // Instruction being instrumented.
3564 Instruction *instrRef = instrumentI->getInstrRef();
3565
3566 // Emit instruction ID and instruction kind.
3567 llvm::Type *intTy =
3568 llvm::Type::getIntNTy(getLLVMContext(), getLibjitIntWidth());
3569 auto *ID = llvm::ConstantInt::get(intTy, instrumentI->getID());
3570 auto *kind = llvm::ConstantInt::get(intTy, (int)(instrRef->getKind()));
3571
3572 // Emit number of input and output operands.
3573 auto inpNum = instrRef->getNumInputs();
3574 auto outNum = instrRef->getNumOutputs();
3575 auto opNum = inpNum + outNum;
3576 auto *opInp = llvm::ConstantInt::get(intTy, inpNum);
3577 auto *opOut = llvm::ConstantInt::get(intTy, outNum);
3578
3579 // Emit opInfo address as uint8_t*.
3580 assert(opInfo->getType()->getSizeInBytes() >= 2 * sizeof(int64_t) &&
3581 "Not enough memory allocated for instrumentation!");
3582 auto *opInfoPtr = emitValueAddress(builder, opInfo);
3583 opInfoPtr = builder.CreateBitCast(opInfoPtr, builder.getInt8PtrTy());
3584
3585 // Emit opAddr address as uint8_t** starting from offset 0.
3586 auto *opAddrPtr =
3587 builder.CreateGEP(opInfoPtr, llvm::ConstantInt::get(intTy, 0));
3588 opAddrPtr = builder.CreateBitCast(opAddrPtr,
3589 builder.getInt8PtrTy()->getPointerTo());
3590
3591 // Emit opSize address as int* starting from offset opNum * sizeof(int64_t).
3592 auto *opSizePtr = builder.CreateGEP(
3593 opInfoPtr, llvm::ConstantInt::get(intTy, opNum * sizeof(int64_t)));
3594 opSizePtr = builder.CreateBitCast(opSizePtr, intTy->getPointerTo());
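    // Layout of the opInfo buffer implied by the offsets above: opNum address
    // slots of sizeof(int64_t) bytes each, followed by opNum operand-size
    // slots of libjit int width.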
3595
3596 // Generate instrumentation.
3597 auto instrumentKind = instrumentI->getInstrumentKind();
3598 if (instrumentKind == InstrumentKind::Before) {
3599
3600 // Operands addresses and sizes.
3601 std::vector<llvm::Value *> opAddrArray;
3602 std::vector<llvm::Value *> opSizeArray;
3603
3604 // Get addresses and sizes for the input operands.
3605 for (const auto &op : instrRef->getOperands()) {
3606 if (op.second == OperandKind::Out) {
3607 continue;
3608 }
3609 // Emit operand address as uint8_t* variable.
3610 auto *opAddr = emitValueAddress(builder, op.first);
3611 opAddr = builder.CreateBitCast(opAddr, builder.getInt8PtrTy());
3612 opAddrArray.push_back(opAddr);
3613 // Emit operand size in bytes as int constant.
3614 auto *opSize = llvm::ConstantInt::get(
3615 intTy, op.first->getType()->getSizeInBytes());
3616 opSizeArray.push_back(opSize);
3617 }
3618 assert(opAddrArray.size() == inpNum && "Inconsistent size!");
3619
3620 // Get addresses and sizes for the output operands.
3621 for (const auto &op : instrRef->getOperands()) {
3622 if (op.second == OperandKind::In) {
3623 continue;
3624 }
3625 // Emit operand address as uint8_t* variable.
3626 auto *opAddr = emitValueAddress(builder, op.first);
3627 opAddr = builder.CreateBitCast(opAddr, builder.getInt8PtrTy());
3628 opAddrArray.push_back(opAddr);
3629 // Emit operand size in bytes as int constant.
3630 auto *opSize = llvm::ConstantInt::get(
3631 intTy, op.first->getType()->getSizeInBytes());
3632 opSizeArray.push_back(opSize);
3633 }
3634 assert(opAddrArray.size() == opNum && "Inconsistent size!");
3635
3636 // Write the addresses of the operands in the opAddr.
3637 emitArrayStore(builder, opAddrArray, opAddrPtr);
3638
3639 // Write the sizes of the operands in opSize.
3640 emitArrayStore(builder, opSizeArray, opSizePtr);
3641
3642 // Create callback call.
3643 auto *F = getFunction("instrument_before");
3644 createCall(builder, F, {ID, kind, opInp, opOut, opAddrPtr, opSizePtr});
3645
3646 } else if (instrumentKind == InstrumentKind::After) {
3647
3648 // Create callback call.
3649 auto *F = getFunction("instrument_after");
3650 createCall(builder, F, {ID, kind, opInp, opOut, opAddrPtr, opSizePtr});
3651
3652 } else {
3653 llvm_unreachable("Instrumentation kind not supported!");
3654 }
3655    // Request printing of the IR instrumentation callback API in the header.
3656 printInstrumentIR_ = true;
3657 break;
3658 }
3659
3660 case Kinded::Kind::TraceEventInstKind: {
3661 auto *TEI = llvm::cast<TraceEventInst>(I);
3662 auto *data = TEI->getData();
3663 auto *offset = emitConstDimT(builder, TEI->getIndex());
3664 auto *dataPtr = emitValueAddress(builder, data);
3665 auto *F = getFunction("write_timestamp");
3666 createCall(builder, F, {dataPtr, offset});
3667 break;
3668 }
3669
3670 case Kinded::Kind::ResizeNearestInstKind: {
3671 auto *RNI = llvm::cast<ResizeNearestInst>(I);
3672 auto *result = RNI->getDest();
3673 auto *input = RNI->getSrc();
3674 auto *resultPtr = emitValueAddress(builder, result);
3675 auto *inputPtr = emitValueAddress(builder, input);
3676
3677 auto *scalePtr = emitConstFloatArray(builder, RNI->getScale());
3678 auto *destDims = emitValueDims(builder, result);
3679 auto *srcDims = emitValueDims(builder, input);
3680 auto *F = getFunction("resizenearest", input->getElementType());
3681 createCall(builder, F, {resultPtr, inputPtr, scalePtr, srcDims, destDims});
3682 break;
3683 }
3684
3685 case Kinded::Kind::ResizeBilinearInstKind: {
3686 auto *RBI = llvm::cast<ResizeBilinearInst>(I);
3687 auto *result = RBI->getDest();
3688 auto *input = RBI->getSrc();
3689 auto *resultPtr = emitValueAddress(builder, result);
3690 auto *inputPtr = emitValueAddress(builder, input);
3691
3692 CHECK_EQ(RBI->getScale()[0], 1.0) << "Scaling batch not supported.";
3693 CHECK_EQ(RBI->getScale()[3], 1.0) << "Scaling channel not supported.";
3694
3695 auto *scalePtr = emitConstFloatArray(builder, RBI->getScale());
3696 auto *destDims = emitValueDims(builder, result);
3697 auto *srcDims = emitValueDims(builder, input);
3698 auto *F = getFunction("resizebilinear", input->getElementType());
3699 createCall(builder, F, {resultPtr, inputPtr, scalePtr, srcDims, destDims});
3700 break;
3701 }
3702
3703 case Kinded::Kind::NonMaxSuppressionInstKind: {
3704 auto *NMSI = llvm::cast<NonMaxSuppressionInst>(I);
3705 auto boxes = NMSI->getBoxes();
3706 auto scores = NMSI->getScores();
3707 auto indices = NMSI->getIndices();
3708 auto numDetected = NMSI->getNumberOfSelectedIndices();
3709 float iouThreshold = NMSI->getIouThreshold();
3710 int64_t maxBoxesPerClass = NMSI->getMaxOutputBoxesPerClass();
3711 float scoreThreshold = NMSI->getScoreThreshold();
3712 int centerPointBox = NMSI->getCenterPointBox();
3713 bool isV4 = NMSI->getIsTFVersion4();
3714
3715 auto *boxesPtr = emitValueAddress(builder, boxes);
3716 auto *scoresPtr = emitValueAddress(builder, scores);
3717 auto *indicesPtr = emitValueAddress(builder, indices);
3718 auto *numDetectedPtr = emitValueAddress(builder, numDetected);
3719
3720 auto *maxBoxesPerClassVal = emitConstI32(builder, maxBoxesPerClass);
3721 auto *centerPointBoxVal = emitConstI32(builder, centerPointBox);
3722 auto *iouThresholdVal = emitConstF32(builder, iouThreshold);
3723 auto *scoreThresholdVal = emitConstF32(builder, scoreThreshold);
3724
3725 auto *boxesDimVal = emitValueDims(builder, boxes);
3726 auto *scoreDimVal = emitValueDims(builder, scores);
3727 auto *indicesDimVal = emitValueDims(builder, indices);
3728 auto *boxesDimSizeVal = emitConstDimT(builder, boxes->dims().size());
3729 auto *scoresDimSizeVal = emitConstDimT(builder, scores->dims().size());
3730 auto *indicesDimSizeVal = emitConstDimT(builder, indices->dims().size());
3731 auto *isV4Val = emitConstI1(builder, isV4);
3732
3733 auto *F = getFunction("nms", indices->getElementType());
3734 createCall(builder, F,
3735 {indicesPtr, numDetectedPtr, boxesPtr, boxesDimVal,
3736 boxesDimSizeVal, scoresPtr, scoreDimVal, scoresDimSizeVal,
3737 indicesDimVal, indicesDimSizeVal, centerPointBoxVal,
3738 maxBoxesPerClassVal, iouThresholdVal, scoreThresholdVal,
3739 isV4Val});
3740 break;
3741 }
3742
3743 case Kinded::Kind::TFLiteDetectionPostProcessInstKind: {
3744 auto *DPPI = llvm::cast<TFLiteDetectionPostProcessInst>(I);
3745 auto boxes = DPPI->getBoxes();
3746 auto scores = DPPI->getScores();
3747 auto anchors = DPPI->getAnchors();
3748 auto detectionBoxes = DPPI->getDetectionBoxes();
3749 auto detectionClasses = DPPI->getDetectionClasses();
3750 auto detectionScores = DPPI->getDetectionScores();
3751 auto numDetections = DPPI->getNumDetections();
3752 auto scratch = DPPI->getScratch();
3753
3754 // Emit pointers.
3755 auto *boxesPtr = emitValueAddress(builder, boxes);
3756 auto *scoresPtr = emitValueAddress(builder, scores);
3757 auto *anchorsPtr = emitValueAddress(builder, anchors);
3758 auto *detectionBoxesPtr = emitValueAddress(builder, detectionBoxes);
3759 auto *detectionClassesPtr = emitValueAddress(builder, detectionClasses);
3760 auto *detectionScoresPtr = emitValueAddress(builder, detectionScores);
3761 auto *numDetectionsPtr = emitValueAddress(builder, numDetections);
3762 auto *scratchPtr = emitValueAddress(builder, scratch);
3763
3764 // Emit parameters.
3765 auto *numBoxes = emitConstI32(builder, boxes->dims()[1]);
3766 auto *numTotalClasses = emitConstI32(builder, scores->dims()[2]);
3767 auto *numClasses = emitConstI32(builder, DPPI->getNumClasses());
3768 auto *maxDetections = emitConstI32(builder, DPPI->getMaxDetections());
3769 auto *maxClassesPerDetection =
3770 emitConstI32(builder, DPPI->getMaxClassesPerDetection());
3771 auto *maxDetectionsPerClass =
3772 emitConstI32(builder, DPPI->getMaxDetectionsPerClass());
3773 auto *iouThreshold = emitConstF32(builder, DPPI->getIouThreshold());
3774 auto *scoreThreshold = emitConstF32(builder, DPPI->getScoreThreshold());
3775 auto *xScaleInv = emitConstF32(builder, 1.0f / DPPI->getXScale());
3776 auto *yScaleInv = emitConstF32(builder, 1.0f / DPPI->getYScale());
3777 auto *hScaleInv = emitConstF32(builder, 1.0f / DPPI->getHScale());
3778 auto *wScaleInv = emitConstF32(builder, 1.0f / DPPI->getWScale());
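    // The scales are passed as reciprocals so the kernel can multiply by them
    // instead of dividing while decoding boxes.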
3779 auto *regularNMS = emitConstI1(builder, DPPI->getRegularNMS());
3780
3781 // Current implementation only supports batch size 1.
3782 assert(boxes->dims()[0] == 1 &&
3783 "TFLiteDetectionPostProcess batch not supported!");
3784
3785 // Call function.
3786 auto *F = getFunction("tflite_detection_post_process_f");
3787 createCall(builder, F,
3788 {boxesPtr,
3789 scoresPtr,
3790 anchorsPtr,
3791 detectionBoxesPtr,
3792 detectionClassesPtr,
3793 detectionScoresPtr,
3794 numDetectionsPtr,
3795 scratchPtr,
3796 numBoxes,
3797 numTotalClasses,
3798 numClasses,
3799 maxDetections,
3800 maxClassesPerDetection,
3801 maxDetectionsPerClass,
3802 iouThreshold,
3803 scoreThreshold,
3804 xScaleInv,
3805 yScaleInv,
3806 hScaleInv,
3807 wScaleInv,
3808 regularNMS});
3809 break;
3810 }
3811
3812 case Kinded::Kind::AudioSpectrogramInstKind: {
3813 auto *ASI = llvm::cast<AudioSpectrogramInst>(I);
3814 auto winOutScratch = ASI->getWinOutScratch();
3815 auto fftOutScratch = ASI->getFftOutScratch();
3816 auto spectrogram = ASI->getSpectrogram();
3817 auto input = ASI->getInput();
3818 auto window = ASI->getWindow();
3819 auto twiddleFactors = ASI->getTwiddleFactors();
3820 auto bitReverseIndices = ASI->getBitReverseIndices();
3821 auto complexToRealWeights = ASI->getComplexToRealWeights();
3822 int64_t windowSize = ASI->getWindowSize();
3823 int64_t windowStride = ASI->getWindowStride();
3824 bool magnitudeSquared = ASI->getMagnitudeSquared();
3825
3826 auto *winOutScratchPtr = emitValueAddress(builder, winOutScratch);
3827 auto *fftOutScratchPtr = emitValueAddress(builder, fftOutScratch);
3828 auto *spectrogramPtr = emitValueAddress(builder, spectrogram);
3829 auto *inputPtr = emitValueAddress(builder, input);
3830 auto *windowPtr = emitValueAddress(builder, window);
3831 auto *twiddleFactorsPtr = emitValueAddress(builder, twiddleFactors);
3832 auto *bitReverseIndicesPtr = emitValueAddress(builder, bitReverseIndices);
3833 auto *complexToRealWeightsPtr =
3834 emitValueAddress(builder, complexToRealWeights);
3835 auto *spectrogramDimVal = emitValueDims(builder, spectrogram);
3836 auto *inputLengthVal = emitConstDimT(builder, input->size());
3837 auto *windowSizeVal = emitConstDimT(builder, windowSize);
3838 auto *windowStrideVal = emitConstDimT(builder, windowStride);
3839 auto *magnitudeSquaredVal = emitConstI1(builder, magnitudeSquared);
3840
3841 auto *F = getFunction("audio_spectrogram", spectrogram->getElementType());
3842 createCall(builder, F,
3843 {winOutScratchPtr, fftOutScratchPtr, spectrogramPtr, inputPtr,
3844 windowPtr, twiddleFactorsPtr, bitReverseIndicesPtr,
3845 complexToRealWeightsPtr, spectrogramDimVal, inputLengthVal,
3846 windowSizeVal, windowStrideVal, magnitudeSquaredVal});
3847 break;
3848 }
3849
3850 case Kinded::Kind::MFCCInstKind: {
3851 auto *MFCCI = llvm::cast<MFCCInst>(I);
3852 auto scratch = MFCCI->getScratch();
3853 auto coefficients = MFCCI->getCoefficients();
3854 auto spectrogram = MFCCI->getSpectrogram();
3855 auto melWeights = MFCCI->getMelWeights();
3856 auto melRanges = MFCCI->getMelRanges();
3857 auto dctMat = MFCCI->getDctMat();
3858 int64_t filterBankCount = MFCCI->getFilterBankCount();
3859
3860 auto *scratchPtr = emitValueAddress(builder, scratch);
3861 auto *coefficientsPtr = emitValueAddress(builder, coefficients);
3862 auto *spectrogramPtr = emitValueAddress(builder, spectrogram);
3863 auto *melWeightsPtr = emitValueAddress(builder, melWeights);
3864 auto *melRangesPtr = emitValueAddress(builder, melRanges);
3865 auto *dctMatPtr = emitValueAddress(builder, dctMat);
3866 auto *coefficientsDimVal = emitValueDims(builder, coefficients);
3867 auto *spectrogramDimVal = emitValueDims(builder, spectrogram);
3868 auto *filterBankCountVal = emitConstDimT(builder, filterBankCount);
3869
3870 auto *F = getFunction("mfcc", coefficients->getElementType());
3871 createCall(builder, F,
3872 {scratchPtr, coefficientsPtr, spectrogramPtr, melWeightsPtr,
3873 melRangesPtr, dctMatPtr, coefficientsDimVal, spectrogramDimVal,
3874 filterBankCountVal});
3875 break;
3876 }
3877
3878 case Kinded::Kind::ConvertToInstKind: {
3879 auto *CTI = llvm::cast<ConvertToInst>(I);
3880 auto *input = CTI->getInput();
3881 auto *output = CTI->getResult();
3882
3883 auto *inputVal = emitValueAddress(builder, input);
3884    auto *outputVal = emitValueAddress(builder, output);
3885 auto *dimsVal = emitValueDims(builder, output);
3886 auto *dimSizeVal = emitConstDimT(builder, output->dims().size());
3887
3888 auto *F = getFunction("convertTo",
3889 {output->getElementType(), input->getElementType()});
3890
3891    createCall(builder, F, {outputVal, inputVal, dimsVal, dimSizeVal});
3892 break;
3893 }
3894
3895 default:
3896 std::string sBuf;
3897 llvm::raw_string_ostream s(sBuf);
3898 I->dump(s);
3899 LOG(FATAL) << "Cannot select the instruction: " << s.str();
3900 }
3901}
3902
3903unsigned LLVMIRGen::getTargetSizeTWidth() const {
3904 return getPointerNumBits(*TM_);
3905}
3906
3907unsigned LLVMIRGen::getLibjitSizeTWidth() const {
3908 auto *sizeTVar = getModule().getGlobalVariable("libjit_sizeTVar",
3909 /* allowInternal */ true);
3910 assert(sizeTVar && "libjit_sizeTVar is not found");
3911 return sizeTVar->getType()->getPointerElementType()->getIntegerBitWidth();
3912}
3913
3914unsigned LLVMIRGen::getLibjitIntWidth() const {
3915 auto *intVar = getModule().getGlobalVariable("libjit_intVar",
3916 /* allowInternal */ true);
3917 assert(intVar && "libjit_intVar is not found");
3918 return intVar->getType()->getPointerElementType()->getIntegerBitWidth();
3919}
3920
3921bool LLVMIRGen::isEligibleForSpecialization(const llvm::CallInst *call) {
3922 return true;
3923}
3924
3925bool LLVMIRGen::canBePartOfDataParallelKernel(
3926 const glow::Instruction *I) const {
3927 return I->isDataParallel();
3928}
3929
3930/// Extra bundle header file content with the IR instrumentation callback API.
3931static const char *instrumentIRApi =
3932 R"RAW(
3933// -----------------------------------------------------------------------------
3934// Callback function used for Glow IR instruction instrumentation:
3935// - This callback is called by the bundle BEFORE executing each instruction.
3936// - This callback must be defined by the bundle user application.
3937// ARGUMENTS:
3938// id - Instruction instance ID.
3939// kind - Instruction kind (type).
3940// opInp - Number of input operands.
3941// opOut - Number of output operands.
3942// opAddr - Array with addresses for all operands. The addresses are listed
3943// first for the input operands and then for the output operands.
3944// The array contains opInp + opOut addresses.
3945// opSize - Array with sizes (in bytes) for all operands. The sizes are listed
3946// first for the input operands and then for the output operands.
3947// The array contains opInp + opOut sizes.
3948// NOTES:
3949// - This callback should be used to dump only the input operands since the
3950// output operands are not yet computed/written when this callback is used.
3951 // - This callback uses C linkage; therefore, if the callback is implemented
3952 //   in a .cpp file you must enclose the implementation in extern "C" {}.
3953 // - Look in the metafile "instrument-ir.info" generated at compile time for
3954 //   more information about the instrumented instructions.
3955// -----------------------------------------------------------------------------
3956void glow_instrument_before(int id, int kind, int opInp, int opOut, uint8_t **opAddr, int *opSize);
3957
3958// -----------------------------------------------------------------------------
3959// Callback function used for Glow IR instruction instrumentation:
3960// - This callback is called by the bundle AFTER executing each instruction.
3961// - This callback must be defined by the bundle user application.
3962// ARGUMENTS:
3963// id - Instruction instance ID.
3964// kind - Instruction kind (type).
3965// opInp - Number of input operands.
3966// opOut - Number of output operands.
3967// opAddr - Array with addresses for all operands. The addresses are listed
3968// first for the input operands and then for the output operands.
3969// The array contains opInp + opOut addresses.
3970// opSize - Array with sizes (in bytes) for all operands. The sizes are listed
3971// first for the input operands and then for the output operands.
3972// The array contains opInp + opOut sizes.
3973// NOTES:
3974// - This callback should be used to dump only the output operands since some
3975// of the input operands might have been overwritten for instructions which
3976// perform in-place computation.
3977 // - This callback uses C linkage; therefore, if the callback is implemented
3978 //   in a .cpp file you must enclose the implementation in extern "C" {}.
3979 // - Look in the metafile "instrument-ir.info" generated at compile time for
3980 //   more information about the instrumented instructions.
3981// -----------------------------------------------------------------------------
3982void glow_instrument_after(int id, int kind, int opInp, int opOut, uint8_t **opAddr, int *opSize);
3983)RAW";
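// A minimal sketch (illustrative only, not shipped with libjit) of how a
// bundle user application might define the "before" callback declared above;
// the signature is taken verbatim from the header text:
//
//   #include <stdint.h>
//   #include <stdio.h>
//
//   extern "C" void glow_instrument_before(int id, int kind, int opInp,
//                                          int opOut, uint8_t **opAddr,
//                                          int *opSize) {
//     // Only the input operands are valid here; outputs are not yet written.
//     for (int i = 0; i < opInp; i++) {
//       printf("inst %d input %d: addr %p size %d\n", id, i,
//              (void *)opAddr[i], opSize[i]);
//     }
//   }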
3984
3985std::string LLVMIRGen::getBundleHeaderExtra() const {
3986 std::string headerExtra = "";
3987 // Print IR instrumentation callback API.
3988 if (printInstrumentIR_) {
3989 headerExtra += std::string(instrumentIRApi);
3990 }
3991 return headerExtra;
3992}
3993