1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/LLVMIRCodeGen/LLVMIRGen.h" |
18 | #include "glow/Base/DimType.h" |
19 | |
20 | #include "glow/LLVMIRCodeGen/AllocationsInfo.h" |
21 | #include "glow/LLVMIRCodeGen/CommandLine.h" |
22 | #include "glow/LLVMIRCodeGen/LLVMBackend.h" |
23 | |
24 | #include "glow/Graph/Graph.h" |
25 | #include "glow/IR/IRUtils.h" |
26 | #include "glow/IR/Instrs.h" |
27 | #include "glow/Quantization/Base/Base.h" |
28 | |
29 | #include "llvm/ExecutionEngine/ExecutionEngine.h" |
30 | #include "llvm/IR/LegacyPassManager.h" |
31 | #include "llvm/IR/Verifier.h" |
32 | #include "llvm/IRReader/IRReader.h" |
33 | #include "llvm/Support/FileSystem.h" |
34 | #include "llvm/Support/Path.h" |
35 | #include "llvm/Support/SourceMgr.h" |
36 | #include "llvm/Support/TargetSelect.h" |
37 | #include "llvm/Target/TargetMachine.h" |
38 | |
39 | using namespace glow; |
40 | using llvm::cast; |
41 | using llvm::dyn_cast; |
42 | using llvm::isa; |
43 | |
/// Command-line flag: dump the generated LLVM IR module to stdout, both before
/// and after the optimization pipeline runs (see finishCodeGen).
static llvm::cl::opt<bool>
    dumpIR("dump-llvm-ir",
           llvm::cl::desc("Dump the LLVM-IR of the jitted code"),
           llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));

/// Command-line flag: dump the final textual assembly of the jitted code.
static llvm::cl::opt<bool>
    dumpJitAsm("dump-llvm-asm",
               llvm::cl::desc("Dump the textual assembly of the jitted code"),
               llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));

/// Command-line flag ("-g"): emit debug information for debuggers. Declared
/// non-static, so it is visible to other translation units.
llvm::cl::opt<bool>
    emitDebugInfo("g", llvm::cl::desc("Emit debug information for debuggers"),
                  llvm::cl::init(false), llvm::cl::cat(getLLVMBackendCat()));

/// Limitation of number of arguments for `emitDataParallelKernel`.
constexpr static size_t kArgLimit = 64;
60 | |
61 | /// Query the TargetMachine to get the pointer size in bits |
62 | static unsigned getPointerNumBits(const llvm::TargetMachine &TM) { |
63 | return TM.getPointerSize(0) * 8; |
64 | } |
65 | |
/// Construct an IR generator for the IR function \p F. \p allocationsInfo
/// describes the memory allocation decisions for the function's values;
/// \p mainEntryName is the desired name of the generated entry point (it is
/// legalized, and defaults to "main" when empty — see setMainEntryName);
/// \p libjitBC is the compiled-in bitcode image of the libjit library.
LLVMIRGen::LLVMIRGen(const IRFunction *F, AllocationsInfo &allocationsInfo,
                     std::string mainEntryName, llvm::StringRef libjitBC)
    : F_(F), ctx_(std::make_unique<llvm::LLVMContext>()),
      allocationsInfo_(allocationsInfo), libjitBC_(libjitBC) {
  // Legalize main entry name.
  setMainEntryName(mainEntryName);
}

/// Same as the constructor above, additionally storing \p objectRegistry —
/// a set of pre-compiled object buffers (presumably linkable into the bundle;
/// see getObjectRegistry/getBundleObjects).
LLVMIRGen::LLVMIRGen(const IRFunction *F, AllocationsInfo &allocationsInfo,
                     std::string mainEntryName, llvm::StringRef libjitBC,
                     llvm::ArrayRef<llvm::MemoryBufferRef> objectRegistry)
    : F_(F), ctx_(std::make_unique<llvm::LLVMContext>()),
      allocationsInfo_(allocationsInfo), libjitBC_(libjitBC),
      objectRegistry_(objectRegistry) {
  // Legalize main entry name.
  setMainEntryName(mainEntryName);
}
83 | |
84 | /// Mutex to protect LLVM's TargetRegistry. |
85 | static std::mutex initTargetMutex; |
86 | |
87 | void LLVMIRGen::initTargetOptions(llvm::TargetOptions &targetOpts, |
88 | const LLVMBackendOptions &backendOpts) { |
89 | if (backendOpts.getFloatABI().hasValue()) { |
90 | targetOpts.FloatABIType = backendOpts.getFloatABI().getValue(); |
91 | } |
92 | if (!backendOpts.getABIName().empty()) { |
93 | targetOpts.MCOptions.ABIName = backendOpts.getABIName(); |
94 | } |
95 | } |
96 | |
/// Create the llvm::TargetMachine (stored in TM_) used for code generation,
/// configured from the backend options \p opts. When no explicit target is
/// given, the host is targeted with its own CPU and feature set.
void LLVMIRGen::initTargetMachine(const LLVMBackendOptions &opts) {
  // LLVM's TargetRegistry is not thread safe so we add a critical section.
  std::lock_guard<std::mutex> g(initTargetMutex);

  // Register all targets/MCs/printers/parsers so selectTarget can find them.
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmPrinters();
  llvm::InitializeAllAsmParsers();

  llvm::TargetOptions targetOpts;
  // Initialize target options in a backend-specific way.
  initTargetOptions(targetOpts, opts);

  if (opts.getTarget().empty()) {
    // No explicit target triple: select a target machine for the host,
    // using the host CPU and host features.
    TM_.reset(llvm::EngineBuilder()
                  .setCodeModel(opts.getCodeModel())
                  .setRelocationModel(opts.getRelocModel())
                  .setTargetOptions(targetOpts)
                  .selectTarget(llvm::Triple(), opts.getArch(),
                                LLVMBackend::getHostCPU(),
                                LLVMBackend::getHostFeatures()));
  } else {
    // Cross-compilation: use the explicitly requested triple, CPU and
    // feature list from the backend options.
    TM_.reset(llvm::EngineBuilder()
                  .setCodeModel(opts.getCodeModel())
                  .setRelocationModel(opts.getRelocModel())
                  .setTargetOptions(targetOpts)
                  .selectTarget(llvm::Triple(opts.getTarget()), opts.getArch(),
                                opts.getCPU(), opts.getTargetFeatures()));
  }
  assert(TM_ && "Could not initialize the target machine");
}
128 | |
129 | llvm::StringRef LLVMIRGen::getBundleName() const { return bundleName_; } |
130 | |
131 | void LLVMIRGen::setBundleName(const std::string &name) { |
132 | bundleName_ = name.empty() ? "bundle" : legalizeName(name); |
133 | } |
134 | |
135 | llvm::StringRef LLVMIRGen::getSavedBundleName() const { |
136 | return savedBundleName_; |
137 | } |
138 | |
139 | void LLVMIRGen::setSavedBundleName(const std::string &name) { |
140 | assert(!name.empty() && "Name cannot be empty" ); |
141 | savedBundleName_ = name; |
142 | } |
143 | |
144 | std::string LLVMIRGen::getMainEntryName() const { return mainEntryName_; } |
145 | |
146 | void LLVMIRGen::setMainEntryName(std::string name) { |
147 | mainEntryName_ = name.empty() ? "main" : legalizeName(name); |
148 | } |
149 | |
150 | llvm::ArrayRef<llvm::MemoryBufferRef> LLVMIRGen::getObjectRegistry() const { |
151 | return objectRegistry_; |
152 | } |
153 | |
154 | void LLVMIRGen::setObjectRegistry( |
155 | llvm::ArrayRef<llvm::MemoryBufferRef> objectRegistry) { |
156 | objectRegistry_ = objectRegistry; |
157 | } |
158 | |
159 | std::vector<std::string> LLVMIRGen::getBundleObjects() const { |
160 | // Default list of object names. |
161 | auto bundleObjects = bundleObjects_; |
162 | // Add object names enforced from command line interface. |
163 | for (auto bundleObject : bundleObjectsOpt) { |
164 | bundleObjects.push_back(bundleObject); |
165 | } |
166 | return bundleObjects; |
167 | } |
168 | |
169 | void LLVMIRGen::addBundleObject(llvm::StringRef objectName) { |
170 | // Add bundle object if not already added. |
171 | auto it = |
172 | std::find(bundleObjects_.begin(), bundleObjects_.end(), objectName.str()); |
173 | if (it == bundleObjects_.end()) { |
174 | bundleObjects_.push_back(objectName.str()); |
175 | } |
176 | } |
177 | |
178 | /// Load base addresses of different memory areas so that they can be easily |
179 | /// reused during codegen. |
/// Load base addresses of different memory areas so that they can be easily
/// reused during codegen. \p builder must be positioned inside the entry
/// function, whose arguments are (in order): constant weights base, mutable
/// weights base, activations base, offsets array — see the entry-point API
/// documented in performCodeGen.
void LLVMIRGen::loadBaseAddresses(llvm::IRBuilder<> &builder) {
  auto *F = builder.GetInsertBlock()->getParent();

  // Load the base addresses at the beginning of the entry function once they
  // are set. They won't change after this point and all relative addressing
  // computations will simply use them.
  auto sizeTTy = builder.getIntNTy(getLibjitSizeTWidth());
  // Argument 2: base address of the activations area.
  baseActivationsAddr_ = builder.CreatePtrToInt(F->args().begin() + 2, sizeTTy);
  // Argument 0: base address of the constant weights area.
  baseConstantWeightVarsAddr_ =
      builder.CreatePtrToInt(F->args().begin(), sizeTTy);
  // Argument 1: base address of the mutable weights area.
  baseMutableWeightVarsAddr_ =
      builder.CreatePtrToInt(F->args().begin() + 1, sizeTTy);
  // Argument 3: pointer to the per-value offsets array (kept as a pointer,
  // not converted to an integer).
  offsetsArray_ = F->args().begin() + 3;
}
194 | |
195 | // Search for the standard library bitcode file on disk and load it into an |
196 | // LLVM module. We search for the standard library around the current executable |
197 | // and also in the current directory. |
198 | static std::unique_ptr<llvm::Module> |
199 | loadStandardLibrary(llvm::LLVMContext *ctx, llvm::StringRef filename, |
200 | llvm::StringRef libjitBC) { |
201 | using llvm::sys::path::append; |
202 | using llvm::sys::path::parent_path; |
203 | |
204 | llvm::SMDiagnostic error; |
205 | |
206 | // Parse the compiled-in image of libjit and return the resulting Module. |
207 | // checking for and reporting errors from parseIR. |
208 | |
209 | auto mod = llvm::parseIR( |
210 | llvm::MemoryBufferRef( |
211 | llvm::StringRef(reinterpret_cast<const char *>(libjitBC.data()), |
212 | libjitBC.size()), |
213 | "libjit.bc" ), |
214 | error, *ctx); |
215 | |
216 | if (!mod) { |
217 | error.print("LLVMIRGen" , llvm::errs()); |
218 | } |
219 | return mod; |
220 | } |
221 | |
222 | /// Register a diagnostics handler that prevents the compiler from printing to |
223 | /// stdout. |
/// Register a diagnostics handler that prevents the compiler from printing to
/// stdout. The handler installed on \p ctx simply discards every diagnostic.
static void registerEmptyDiagHandler(llvm::LLVMContext &ctx) {
  ctx.setDiagnosticHandlerCallBack(
      [](const llvm::DiagnosticInfo &DI, void *Context) {
        // Do not emit any warnings or diagnostics when JITting.
      });
}
230 | |
/// Prepare the LLVM module for code generation: load libjit as the base
/// module, silence LLVM diagnostics, attach the target data layout and set up
/// debug-info emission. Requires initTargetMachine to have run first (the
/// data layout is taken from the target machine).
void LLVMIRGen::initCodeGen() {
  // Load the jit library as a new module.
  llmodule_ = loadStandardLibrary(&getLLVMContext(), "libjit.bc", libjitBC_);
  CHECK(llmodule_.get()) << "Unable to load the JIT library.";

  // By default, LLVM would emit some diagnostics, remarks, etc. It is fine for
  // a static compiler, but not necessary for a JIT. Let's disable it by
  // providing a dummy diagnostics handler, that does not emit anything.
  // In particular, this allows us to get rid of the annoying "cannot vectorize"
  // warnings.
  registerEmptyDiagHandler(getLLVMContext());

  // Assign the target information to the module.
  llmodule_->setDataLayout(getTargetMachine().createDataLayout());

  // Initialize the debug information emission.
  initDebugInfo();
}
249 | |
250 | /// \returns the LLVM type corresponding to the type of elements stored in \p |
251 | /// val. |
252 | llvm::Type *LLVMIRGen::getElementType(llvm::IRBuilder<> &builder, |
253 | const Value *val) { |
254 | switch (val->getElementType()) { |
255 | case ElemKind::Int64ITy: |
256 | return builder.getInt64Ty(); |
257 | case ElemKind::FloatTy: |
258 | return builder.getFloatTy(); |
259 | case ElemKind::Float16Ty: |
260 | llvm_unreachable("Not implemented" ); |
261 | case ElemKind::BFloat16Ty: |
262 | llvm_unreachable("Not implemented" ); |
263 | case ElemKind::Float64Ty: |
264 | return builder.getDoubleTy(); |
265 | case ElemKind::Int8QTy: |
266 | return builder.getInt8Ty(); |
267 | case ElemKind::UInt8QTy: |
268 | llvm_unreachable("Not implemented" ); |
269 | case ElemKind::Int16QTy: |
270 | return builder.getInt16Ty(); |
271 | case ElemKind::Int32QTy: |
272 | return builder.getInt32Ty(); |
273 | case ElemKind::Int64QTy: |
274 | return builder.getInt64Ty(); |
275 | case ElemKind::UInt8ITy: |
276 | return builder.getInt8Ty(); |
277 | case ElemKind::Int32ITy: |
278 | return builder.getInt32Ty(); |
279 | case ElemKind::UInt8FusedQTy: |
280 | return builder.getInt8Ty(); |
281 | case ElemKind::UInt8FusedFP16QTy: |
282 | return builder.getInt8Ty(); |
283 | case ElemKind::UInt4FusedFP16QTy: |
284 | return builder.getInt8Ty(); |
285 | case ElemKind::UInt4FusedQTy: |
286 | return builder.getInt8Ty(); |
287 | case ElemKind::BoolTy: |
288 | static_assert(sizeof(bool) == sizeof(int8_t), |
289 | "Bool is expected to be the same size as int8." ); |
290 | return builder.getInt8Ty(); |
291 | } |
292 | return nullptr; |
293 | } |
294 | |
/// Create the entry function in the LLVM module and emit the LLVM IR for the
/// whole Glow IR function into it.
void LLVMIRGen::performCodeGen() {
  // Create the entry function into the LLVM module.
  auto int8PtrTy = llvm::Type::getInt8PtrTy(getLLVMContext());
  auto dimTPtrTy = llvm::Type::getIntNPtrTy(getLLVMContext(), DIM_T_BITWIDTH);
  // The entry point has the following API:
  // int entry(uint8_t *baseConstantWeightVars,
  //           uint8_t *baseInoutWeightVars,
  //           uint8_t *baseActivations,
  //           dim_t *offsets);
  llvm::Type *retTy =
      llvm::Type::getIntNTy(getLLVMContext(), getLibjitIntWidth());
  llvm::FunctionType *jitFuncTy = llvm::FunctionType::get(
      retTy, {int8PtrTy, int8PtrTy, int8PtrTy, dimTPtrTy}, false);
  // NOTE(review): the function is created under the literal name "main"
  // rather than getMainEntryName() — presumably renamed later; confirm.
  llvmF_ = llvm::Function::Create(jitFuncTy, llvm::Function::ExternalLinkage,
                                  "main", llmodule_.get());
  emittedLLVMFunctions_.emplace_back(llvmF_);

  // Setup the entry basic block and initialize the IR builder.
  llvm::BasicBlock *entry_bb =
      llvm::BasicBlock::Create(getLLVMContext(), "entry", llvmF_);
  builder_ = glow::make_unique<llvm::IRBuilder<>>(entry_bb);
  // Terminate the function with a return instruction.
  auto zero = builder_->getIntN(getLibjitIntWidth(), 0);
  auto *ret = builder_->CreateRet(zero);
  // Emit all the code before the return instruction.
  builder_->SetInsertPoint(ret);

  // Number the instructions (used by debug info), emit per-function debug
  // info, load the memory-area base addresses and generate the body.
  instrNumbering_.reset(new InstructionNumbering(*F_));
  generateFunctionDebugInfo();
  loadBaseAddresses(*builder_);
  generateLLVMIRForModule(*builder_);
}
327 | |
/// Finalize code generation: verify the module (unless debug info is still
/// pending), run the optimization pipeline, finalize debug info, and honor
/// the -dump-llvm-ir / -dump-llvm-asm command-line flags.
void LLVMIRGen::finishCodeGen() {
  if (dumpIR) {
    llvm::outs() << "LLVM module before optimizations:\n";
    llmodule_->print(llvm::outs(), nullptr);
  }
  // Perform verification if no debug info is being emitted.
  // Otherwise, the verification is performed later by
  // generateDebugInfo, once the debug info emission is finalized.
  if (!emitDebugInfo) {
    // Perform verification, but ignore any debug info errors for now.
    // Debug info errors will be checked later by generateDebugInfo.
    bool brokenDebugInfo = false;
    (void)brokenDebugInfo;
    assert(!llvm::verifyModule(getModule(), &llvm::errs(), &brokenDebugInfo) &&
           "LLVM module verification error");
  }

  // Optimize the module.
  optimizeLLVMModule(&getModule(), getTargetMachine());

  // Generate debug information.
  generateModuleDebugInfo();

  if (dumpIR) {
    llvm::outs() << "LLVM module after optimizations:\n";
    llmodule_->print(llvm::outs(), nullptr);
  }

  if (dumpJitAsm) {
    // Run the codegen passes into an in-memory buffer and print the textual
    // assembly. The addPassesToEmitFile signature differs across the LLVM
    // versions this file supports, hence the preprocessor branches.
    llvm::SmallVector<char, 0> asmBuffer;
    llvm::raw_svector_ostream asmStream(asmBuffer);
    llvm::legacy::PassManager PM;
#if FACEBOOK_INTERNAL && LLVM_VERSION_MAJOR < 8
    getTargetMachine().addPassesToEmitFile(
        PM, asmStream, llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#elif LLVM_VERSION_MAJOR < 10
    getTargetMachine().addPassesToEmitFile(
        PM, asmStream, nullptr,
        llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile);
#else
    getTargetMachine().addPassesToEmitFile(PM, asmStream, nullptr,
                                           llvm::CGFT_AssemblyFile);
#endif

    PM.run(*llmodule_);
    llvm::outs() << asmStream.str();
  }
}
376 | |
377 | llvm::Value *LLVMIRGen::emitValueAddress(llvm::IRBuilder<> &builder, |
378 | const glow::Value *val) { |
379 | assert(allocationsInfo_.allocatedAddress_.count(val) && |
380 | "Value address was not allocated" ); |
381 | llvm::Type *T = nullptr; |
382 | |
383 | switch (val->getElementType()) { |
384 | case ElemKind::FloatTy: |
385 | T = llvm::Type::getFloatPtrTy(getLLVMContext()); |
386 | break; |
387 | case ElemKind::Float16Ty: |
388 | T = llvm::Type::getInt16PtrTy(getLLVMContext()); |
389 | break; |
390 | case ElemKind::BFloat16Ty: |
391 | T = llvm::Type::getInt16PtrTy(getLLVMContext()); |
392 | break; |
393 | case ElemKind::Int8QTy: |
394 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
395 | break; |
396 | case ElemKind::UInt8QTy: |
397 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
398 | break; |
399 | case ElemKind::Int16QTy: |
400 | T = llvm::Type::getInt16PtrTy(getLLVMContext()); |
401 | break; |
402 | case ElemKind::Int32QTy: |
403 | T = llvm::Type::getInt32PtrTy(getLLVMContext()); |
404 | break; |
405 | case ElemKind::Int64QTy: |
406 | T = llvm::Type::getInt64PtrTy(getLLVMContext()); |
407 | break; |
408 | case ElemKind::Int64ITy: |
409 | T = llvm::Type::getInt64PtrTy(getLLVMContext()); |
410 | break; |
411 | case ElemKind::Int32ITy: |
412 | T = llvm::Type::getInt32PtrTy(getLLVMContext()); |
413 | break; |
414 | case ElemKind::UInt8ITy: |
415 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
416 | break; |
417 | case ElemKind::UInt8FusedQTy: |
418 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
419 | break; |
420 | case ElemKind::UInt8FusedFP16QTy: |
421 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
422 | break; |
423 | case ElemKind::UInt4FusedFP16QTy: |
424 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
425 | break; |
426 | case ElemKind::UInt4FusedQTy: |
427 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
428 | break; |
429 | case ElemKind::BoolTy: |
430 | T = llvm::Type::getInt8PtrTy(getLLVMContext()); |
431 | break; |
432 | default: |
433 | LOG(FATAL) << "Unsupported element type: " |
434 | << Type::getElementName(val->getElementType()).str(); |
435 | } |
436 | |
437 | assert(allocationsInfo_.valueNumbers_.count(val)); |
438 | auto &kindAndValue = allocationsInfo_.valueNumbers_[val]; |
439 | |
440 | // Get the required base address. |
441 | llvm::Value *baseAddrValue = nullptr; |
442 | switch (kindAndValue.first) { |
443 | case AllocationsInfo::ValueKind::Activation: |
444 | baseAddrValue = baseActivationsAddr_; |
445 | break; |
446 | case AllocationsInfo::ValueKind::ConstantWeight: |
447 | baseAddrValue = baseConstantWeightVarsAddr_; |
448 | break; |
449 | case AllocationsInfo::ValueKind::MutableWeight: |
450 | baseAddrValue = baseMutableWeightVarsAddr_; |
451 | break; |
452 | } |
453 | |
454 | // Use relative addressing. |
455 | // Get offset. |
456 | auto sizeTTy = builder.getIntNTy(getLibjitSizeTWidth()); |
457 | auto dimTTy = builder.getIntNTy(DIM_T_BITWIDTH); |
458 | |
459 | auto valueIdx = llvm::ConstantInt::get(dimTTy, kindAndValue.second); |
460 | auto offsetAddr = builder.CreateGEP(dimTTy, offsetsArray_, valueIdx); |
461 | auto offsetValue = builder.CreateLoad(dimTTy, offsetAddr); |
462 | // Add offset to the base address. |
463 | llvm::Value *addr = builder.CreateAdd( |
464 | baseAddrValue, builder.CreateZExt(offsetValue, sizeTTy)); |
465 | return builder.CreateIntToPtr(addr, T); |
466 | } |
467 | |
/// Emit the global "offsetsArray": a constant dim_t array holding, for each
/// numbered value in \p allocationsInfo, its allocated address. \returns the
/// global cast to a dim_t pointer. Any pre-existing global with the same name
/// is replaced and erased.
llvm::Value *
LLVMIRGen::emitConstOffsetsArray(llvm::IRBuilder<> &builder,
                                 const AllocationsInfo &allocationsInfo) {
  constexpr const char *offsetsArrayName = "offsetsArray";
  auto dimTType = builder.getIntNTy(DIM_T_BITWIDTH);
  std::vector<llvm::Constant *> elems(allocationsInfo.valueNumbers_.size());
  dim_t maxOffset = 0;
  for (auto &I : allocationsInfo.valueNumbers_) {
    auto *V = I.first;
    // I.second.second is the value's number; it indexes the array slot that
    // receives the value's allocated address.
    auto offset = I.second.second;
    elems[offset] = llvm::ConstantInt::get(
        dimTType, allocationsInfo.allocatedAddress_.lookup(V));
    maxOffset = std::max(maxOffset, (dim_t)offset);
  }
  // NOTE(review): this assumes value numbers are dense in [0, size) — a gap
  // would leave a null entry in elems; confirm against AllocationsInfo.
  elems.resize(maxOffset + 1);
  auto *arr = llvm::ConstantArray::get(
      llvm::ArrayType::get(dimTType, elems.size()), elems);
  // Ensure that the same casted global variable is used for the equivalent
  // const arrays. This is important for the later function specialization pass.
  // LLVM does not do it automatically for this code pattern involving global
  // variables. It also reduces the number of variables.
  auto &constArrayVar = constArrayPtrs_[arr];
  auto oldG =
      getModule().getGlobalVariable(offsetsArrayName, /* allowInternal */ true);
  if (constArrayVar && constArrayVar->getType() == dimTType->getPointerTo()) {
    // Cache hit: the identical array was emitted before; reuse it.
    return constArrayVar;
  }
  if (oldG) {
    // Rename the previous array out of the way so the new global can take
    // the canonical name.
    oldG->setName("offsetsArrayOld");
  }
  auto *M = builder.GetInsertBlock()->getModule();
  auto *G = new llvm::GlobalVariable(*M, arr->getType(), true,
                                     llvm::GlobalValue::InternalLinkage, arr,
                                     offsetsArrayName);
  constArrayVar = builder.CreateBitCast(G, dimTType->getPointerTo());
  if (oldG) {
    // Replace the old offsetsArray by the new one and remove the old.
    oldG->replaceAllUsesWith(G);
    oldG->eraseFromParent();
  }
  return constArrayVar;
}
510 | |
511 | llvm::Value *LLVMIRGen::emitConstI32Array(llvm::IRBuilder<> &builder, |
512 | llvm::ArrayRef<int32_t> vals) { |
513 | std::vector<llvm::Constant *> elems; |
514 | for (auto I : vals) { |
515 | elems.push_back(builder.getInt32(I)); |
516 | } |
517 | return emitConstArray(builder, elems, builder.getInt32Ty()); |
518 | } |
519 | |
520 | llvm::Value *LLVMIRGen::emitConstFloatArray(llvm::IRBuilder<> &builder, |
521 | llvm::ArrayRef<float> vals) { |
522 | std::vector<llvm::Constant *> elems; |
523 | for (auto I : vals) { |
524 | elems.push_back(llvm::ConstantFP::get( |
525 | llvm::Type::getFloatTy(getLLVMContext()), (float)I)); |
526 | } |
527 | return emitConstArray(builder, elems, |
528 | llvm::Type::getFloatTy(getLLVMContext())); |
529 | } |
530 | |
/// Emit \p vals (each bitcast to \p elemTy) as an internal constant global
/// array. \returns the global cast to an \p elemTy pointer. Equivalent arrays
/// are deduplicated via constArrayPtrs_.
llvm::Value *LLVMIRGen::emitConstArray(llvm::IRBuilder<> &builder,
                                       llvm::ArrayRef<llvm::Constant *> vals,
                                       llvm::Type *elemTy) {
  std::vector<llvm::Constant *> elems;
  for (auto I : vals) {
    elems.push_back(cast<llvm::Constant>(builder.CreateBitCast(I, elemTy)));
  }
  auto *arr = llvm::ConstantArray::get(
      llvm::ArrayType::get(elemTy, elems.size()), elems);
  // Ensure that the same casted global variable is used for the equivalent
  // const arrays. This is important for the later function specialization pass.
  // LLVM does not do it automatically for this code pattern involving global
  // variables. It also reduces the number of variables.
  auto &constArrayVar = constArrayPtrs_[arr];
  if (constArrayVar && constArrayVar->getType() == elemTy->getPointerTo())
    return constArrayVar;

  auto *M = builder.GetInsertBlock()->getModule();

  auto *G = new llvm::GlobalVariable(*M, arr->getType(), true,
                                     llvm::GlobalValue::InternalLinkage, arr);
  constArrayVar = builder.CreateBitCast(G, elemTy->getPointerTo());
  return constArrayVar;
}
555 | |
556 | void LLVMIRGen::emitArrayStore(llvm::IRBuilder<> &builder, |
557 | llvm::ArrayRef<llvm::Value *> vals, |
558 | llvm::Value *basePtr, unsigned baseIdx) { |
559 | for (size_t idx = 0, end = vals.size(); idx < end; ++idx) { |
560 | assert(vals[idx]->getType()->getPointerTo() == basePtr->getType() && |
561 | "Mismatch between pointer and value type!" ); |
562 | auto *storeIdx = builder.getInt32(idx + baseIdx); |
563 | auto *storeAddr = builder.CreateGEP(basePtr, storeIdx); |
564 | builder.CreateStore(vals[idx], storeAddr); |
565 | } |
566 | } |
567 | |
568 | llvm::Value *LLVMIRGen::emitValueDims(llvm::IRBuilder<> &builder, |
569 | const glow::Value *val) { |
570 | auto dims = val->dims(); |
571 | return emitConstDimTArray(builder, dims); |
572 | } |
573 | |
/// Emit the fused activation arguments of instruction \p I as a constant
/// float array (used for floating-point kernels; the quantized counterpart is
/// emitConstQuantActivationArgs).
template <class InstructionTy>
llvm::Value *LLVMIRGen::emitConstFloatActivationArgs(llvm::IRBuilder<> &builder,
                                                     const InstructionTy *I) {
  return emitConstFloatArray(builder, I->getFusedActivationArgs());
}
579 | |
/// Emit the fused activation arguments of instruction \p I, quantized
/// according to the instruction's destination type, as a constant i32 array.
/// The transformation of the float arguments depends on the activation kind;
/// unsupported kinds abort.
template <class InstructionTy>
llvm::Value *LLVMIRGen::emitConstQuantActivationArgs(llvm::IRBuilder<> &builder,
                                                     const InstructionTy *I) {
  auto actArgsF = I->getFusedActivationArgs();
  std::vector<int32_t> actArgsQ;
  auto *destTy = I->getDest()->getType();
  switch (I->getFusedActivation()) {
  case FusedActivation::NONE:
  case FusedActivation::RELU:
    // NONE/RELU carry no parameters; the emitted array is empty.
    assert(actArgsF.size() == 0 && "Invalid number of activation parameters!");
    break;
  case FusedActivation::CLIP: {
    // For Clip we quantize min/max using the output quantization params.
    assert(actArgsF.size() == 2 &&
           "Invalid number of parameters for fused Clip activation!");
    float minF = actArgsF[0];
    float maxF = actArgsF[1];
    TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()};
    int32_t minQ = quantization::quantize<int32_t>(minF, TQP);
    int32_t maxQ = quantization::quantize<int32_t>(maxF, TQP);
    actArgsQ.push_back(minQ);
    actArgsQ.push_back(maxQ);
    break;
  }
  case FusedActivation::SIGMOID:
    LOG(FATAL) << "Fused Sigmoid for quantized type not supported!";
    break;
  case FusedActivation::TANH:
    LOG(FATAL) << "Fused Tanh for quantized type not supported!";
    break;
  case FusedActivation::LEAKY_RELU: {
    // For LeakyRelu we transform the alpha parameter into pre/post/scale.
    assert(actArgsF.size() == 1 &&
           "Invalid number of parameters for fused LeakyRelu activation!");
    float alpha = actArgsF[0];
    auto alphaScaleParam = quantization::quantizeScaleOffset32To8(alpha, 0);
    actArgsQ.push_back(alphaScaleParam.pre);
    actArgsQ.push_back(alphaScaleParam.post);
    actArgsQ.push_back(alphaScaleParam.scale);
    break;
  }
  default:
    LOG(FATAL) << "Unsupported fused activation type!";
  }
  return emitConstI32Array(builder, actArgsQ);
}
626 | |
627 | llvm::Value *LLVMIRGen::emitValueSize(llvm::IRBuilder<> &builder, |
628 | const glow::Value *val) { |
629 | return builder.getIntN(DIM_T_BITWIDTH, val->size()); |
630 | } |
631 | |
632 | llvm::Value *LLVMIRGen::emitConstF32(llvm::IRBuilder<> &builder, float val) { |
633 | return llvm::ConstantFP::get(llvm::Type::getFloatTy(getLLVMContext()), val); |
634 | } |
635 | |
636 | llvm::Value *LLVMIRGen::emitConstI32(llvm::IRBuilder<> &builder, int32_t val) { |
637 | return builder.getInt32(val); |
638 | } |
639 | |
640 | llvm::Value *LLVMIRGen::emitConstI16(llvm::IRBuilder<> &builder, int16_t val) { |
641 | return builder.getInt16(val); |
642 | } |
643 | |
644 | llvm::Value *LLVMIRGen::emitConstI8(llvm::IRBuilder<> &builder, int8_t val) { |
645 | return builder.getInt8(val); |
646 | } |
647 | |
648 | llvm::Value *LLVMIRGen::emitConstI1(llvm::IRBuilder<> &builder, bool val) { |
649 | return builder.getInt1(val); |
650 | } |
651 | |
652 | llvm::Value *LLVMIRGen::emitConstSizeT(llvm::IRBuilder<> &builder, size_t val) { |
653 | return builder.getIntN(getLibjitSizeTWidth(), val); |
654 | } |
655 | |
656 | llvm::Value *LLVMIRGen::emitConstDimT(llvm::IRBuilder<> &builder, dim_t val) { |
657 | return builder.getIntN(sizeof(dim_t) * 8, val); |
658 | } |
659 | |
660 | llvm::Value *LLVMIRGen::emitConst(llvm::IRBuilder<> &builder, float val, |
661 | glow::ElemKind kind) { |
662 | switch (kind) { |
663 | case ElemKind::FloatTy: |
664 | return llvm::ConstantFP::get(llvm::Type::getFloatTy(getLLVMContext()), val); |
665 | case ElemKind::Float16Ty: |
666 | llvm_unreachable("Not implemented" ); |
667 | case ElemKind::BFloat16Ty: |
668 | llvm_unreachable("Not implemented" ); |
669 | case ElemKind::Float64Ty: |
670 | return llvm::ConstantFP::get(llvm::Type::getDoubleTy(getLLVMContext()), |
671 | val); |
672 | case ElemKind::Int64ITy: |
673 | return builder.getInt64(static_cast<int64_t>(val)); |
674 | case ElemKind::Int8QTy: |
675 | return builder.getInt8(static_cast<int8_t>(val)); |
676 | case ElemKind::UInt8QTy: |
677 | llvm_unreachable("Not implemented" ); |
678 | case ElemKind::Int16QTy: |
679 | return builder.getInt16(static_cast<int16_t>(val)); |
680 | case ElemKind::Int32QTy: |
681 | return builder.getInt32(static_cast<int32_t>(val)); |
682 | case ElemKind::UInt8ITy: |
683 | return builder.getInt8(static_cast<uint8_t>(val)); |
684 | case ElemKind::Int64QTy: |
685 | return builder.getInt64(static_cast<int64_t>(val)); |
686 | case ElemKind::Int32ITy: |
687 | return builder.getInt32(static_cast<int32_t>(val)); |
688 | case ElemKind::UInt8FusedQTy: |
689 | return builder.getInt8(static_cast<int8_t>(val)); |
690 | case ElemKind::UInt8FusedFP16QTy: |
691 | return builder.getInt8(static_cast<int8_t>(val)); |
692 | case ElemKind::UInt4FusedFP16QTy: |
693 | return builder.getInt8(static_cast<int8_t>(val)); |
694 | case ElemKind::UInt4FusedQTy: |
695 | return builder.getInt8(static_cast<int8_t>(val)); |
696 | case ElemKind::BoolTy: |
697 | return builder.getInt8(static_cast<int8_t>(val)); |
698 | } |
699 | llvm_unreachable("Unknown element type" ); |
700 | } |
701 | |
/// Emit \p str as a private, null-terminated global string constant.
/// \returns the global cast to an i8 pointer.
llvm::Value *LLVMIRGen::emitStringConst(llvm::IRBuilder<> &builder,
                                        llvm::StringRef str) {
  llvm::Constant *constStrArray =
      llvm::ConstantDataArray::getString(getLLVMContext(), str, true);
  llvm::GlobalVariable *gvarStr = new llvm::GlobalVariable(
      *llmodule_, constStrArray->getType(), true,
      llvm::GlobalValue::PrivateLinkage, constStrArray, ".str");
  // setAlignment changed its parameter type to llvm::MaybeAlign in LLVM 10.
#if LLVM_VERSION_MAJOR >= 10
  gvarStr->setAlignment(llvm::MaybeAlign(1));
#else
  gvarStr->setAlignment(1);
#endif
  // Add unnamed_addr attribute to enable constmerge pass.
  gvarStr->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);

  return builder.CreateBitCast(gvarStr, builder.getInt8PtrTy());
}
719 | |
720 | void LLVMIRGen::markArgAsUnspecialized(llvm::Value *val) { |
721 | dontSpecializeArgsSet_.insert(val); |
722 | } |
723 | |
724 | static std::string createName(const std::string &name, ElemKind elemTy) { |
725 | switch (elemTy) { |
726 | case ElemKind::FloatTy: |
727 | return name + "_f" ; |
728 | case ElemKind::Float16Ty: |
729 | return name + "_fp16" ; |
730 | case ElemKind::BFloat16Ty: |
731 | return name + "_bfloat16" ; |
732 | case ElemKind::Int8QTy: |
733 | return name + "_i8" ; |
734 | case ElemKind::Int16QTy: |
735 | return name + "_i16" ; |
736 | case ElemKind::Int32QTy: |
737 | return name + "_i32" ; |
738 | case ElemKind::Int32ITy: |
739 | return name + "_i32" ; |
740 | case ElemKind::Int64ITy: |
741 | return name + "_u" ; |
742 | case ElemKind::BoolTy: |
743 | return name + "_b" ; |
744 | default: |
745 | LOG(FATAL) << "Unsupported element type: " |
746 | << Type::getElementName(elemTy).str(); |
747 | } |
748 | } |
749 | |
750 | llvm::Function * |
751 | LLVMIRGen::getFunction(const std::string &name, |
752 | llvm::ArrayRef<glow::ElemKind> elemTyArray) { |
753 | auto strName = "libjit_" + name; |
754 | |
755 | for (auto elTy : elemTyArray) { |
756 | strName = createName(strName, elTy); |
757 | } |
758 | auto *F = llmodule_->getFunction(strName); |
759 | CHECK(F) << "Unable to load the function: " << strName.c_str(); |
760 | return F; |
761 | } |
762 | |
763 | llvm::Function *LLVMIRGen::getFunction(const std::string &name) { |
764 | return getFunction(name, llvm::ArrayRef<ElemKind>{}); |
765 | } |
766 | |
767 | llvm::Function *LLVMIRGen::getFunction(const std::string &name, |
768 | ElemKind elemTy) { |
769 | return getFunction(name, llvm::ArrayRef<ElemKind>{elemTy}); |
770 | } |
771 | |
772 | llvm::Function *LLVMIRGen::getLLVMFunction() { return llvmF_; } |
773 | |
/// Emit a call to \p callee with \p args. When \p checked is true and the
/// callee returns an integer, additional control flow is emitted: if the
/// result is non-zero (an error code), the enclosing entry function returns
/// that code immediately; otherwise execution continues. \returns the call
/// instruction; on exit the builder is positioned in the continuation block.
llvm::CallInst *LLVMIRGen::createCall(llvm::IRBuilder<> &builder,
                                      llvm::Function *callee,
                                      llvm::ArrayRef<llvm::Value *> args,
                                      bool checked) {
#ifndef NDEBUG
  // Debug-only sanity check that the call site matches the callee signature.
  llvm::FunctionType *FTy = callee->getFunctionType();
  assert((args.size() == FTy->getNumParams() ||
          (FTy->isVarArg() && args.size() > FTy->getNumParams())) &&
         "Calling a function with bad signature: wrong number of arguments.");

  for (unsigned i = 0; i != args.size(); ++i) {
    assert((i >= FTy->getNumParams() ||
            FTy->getParamType(i) == args[i]->getType()) &&
           "Calling a function with a bad signature: argument type mismatch.");
  }
#endif
  if (!checked || !callee->getReturnType()->isIntegerTy()) {
    // Unchecked (or non-integer-returning) call: no error propagation.
    return builder.CreateCall(callee, args);
  }
  // Check if callee returned an error, i.e. non-zero result.
  // Emit a return with this error code in this case.
  auto *result = builder.CreateCall(callee, args);
  auto *zero = builder.getIntN(result->getType()->getIntegerBitWidth(), 0);
  auto *cond = builder.CreateICmpNE(result, zero);
  auto insertionPoint = builder.GetInsertPoint();
  auto *currentBB = result->getParent();
  // Split the current block after the call: "cont_bb" holds the rest of the
  // original code, "error_bb" returns the error code.
  auto *falseBB =
      currentBB->splitBasicBlock(builder.GetInsertPoint(), "cont_bb");
  auto *trueBB = llvm::BasicBlock::Create(getLLVMContext(), "error_bb",
                                          result->getFunction());
  // Replace the unconditional branch created by splitBasicBlock with a
  // conditional branch on the error check.
  builder.SetInsertPoint(currentBB->getTerminator());
  builder.CreateCondBr(cond, trueBB, falseBB);
  currentBB->getTerminator()->eraseFromParent();
  // error_bb: return the (bit-cast) error code from the entry function.
  builder.SetInsertPoint(trueBB);
  auto *castedResult =
      builder.CreateBitCast(result, builder.getIntNTy(getLibjitIntWidth()));
  builder.CreateRet(castedResult);
  // Continue emitting code in the continuation block.
  builder.SetInsertPoint(falseBB, insertionPoint);
  builder.SetInsertPoint(falseBB->getTerminator());
  return result;
}
815 | |
816 | llvm::CallInst * |
817 | LLVMIRGen::createCheckedCall(llvm::IRBuilder<> &builder, llvm::Function *callee, |
818 | llvm::ArrayRef<llvm::Value *> args) { |
819 | return createCall(builder, callee, args, /* checked */ true); |
820 | } |
821 | |
822 | llvm::CallInst * |
823 | LLVMIRGen::createUncheckedCall(llvm::IRBuilder<> &builder, |
824 | llvm::Function *callee, |
825 | llvm::ArrayRef<llvm::Value *> args) { |
826 | return createCall(builder, callee, args, /* checked */ false); |
827 | } |
828 | |
829 | std::pair<llvm::BasicBlock *, llvm::BasicBlock *> |
830 | LLVMIRGen::createLoop(llvm::IRBuilder<> &builder, llvm::LLVMContext &ctx, |
831 | llvm::Value *numElements) const { |
832 | auto dimTTy = builder.getIntNTy(DIM_T_BITWIDTH); |
833 | auto *initVal = llvm::ConstantInt::get(dimTTy, 0); |
834 | |
835 | // Make the new basic block for the loop header. Insert it after current |
836 | // block. |
837 | llvm::Function *func = builder.GetInsertBlock()->getParent(); |
838 | auto * = builder.GetInsertBlock(); |
839 | auto *loopBB = llvm::BasicBlock::Create(ctx, "loop" , func); |
840 | |
841 | // Insert a jump from the current block to the loopBB. |
842 | builder.CreateBr(loopBB); |
843 | |
844 | // Start insertion in LoopBB. |
845 | builder.SetInsertPoint(loopBB); |
846 | |
847 | // Create the PHI node with an entry for initial value. |
848 | llvm::PHINode *var = builder.CreatePHI(dimTTy, 2); |
849 | var->addIncoming(initVal, preheaderBB); |
850 | |
851 | // Emit the step value. |
852 | auto *stepVal = llvm::ConstantInt::get(dimTTy, 1); |
853 | auto *nextVal = builder.CreateAdd(var, stepVal, "nextvar" , /* HasNUW */ true, |
854 | /* HasNSW */ true); |
855 | // Compute the end condition. |
856 | auto *endCond = builder.CreateICmpULT(nextVal, numElements, "loopcond" ); |
857 | |
858 | // Create the "after loop" block and insert it. |
859 | auto *afterBB = llvm::BasicBlock::Create(ctx, "afterloop" , func); |
860 | |
861 | // Insert the conditional branch at the end of the loopBB. |
862 | auto *backEdge = builder.CreateCondBr(endCond, loopBB, afterBB); |
863 | // Add explicit loop llvm.loop.vectorize.enable metadata to the generated |
864 | // loop to help the LLVM vectorizer. Without this metadata, LLVM loop |
865 | // vectorizer bails on long data-parallel loops with a lot of operations. This |
866 | // metadata forces it to vectorize them anyways. |
867 | llvm::SmallVector<llvm::Metadata *, 4> args; |
868 | // Reserve operand 0 for loop id self reference. |
869 | // |
870 | // Initialize it with a special temporary metadata node, which is typically |
871 | // used to create cyclic metadata structures. tmpMD is a unique_ptr and thus |
872 | // will be freed automatically when it goes out of scope. |
873 | llvm::TempMDTuple tmpMD = llvm::MDNode::getTemporary(ctx, llvm::None); |
874 | args.push_back(tmpMD.get()); |
875 | llvm::Metadata *Vals[] = { |
876 | // Reserve operand 0 for loop id self reference. |
877 | llvm::MDString::get(ctx, "llvm.loop.vectorize.enable" ), |
878 | llvm::ConstantAsMetadata::get( |
879 | llvm::ConstantInt::get(llvm::Type::getInt1Ty(ctx), true))}; |
880 | args.push_back(llvm::MDNode::get(ctx, Vals)); |
881 | auto *loopMD = llvm::MDNode::get(ctx, args); |
882 | // Set the first operand to itself. |
883 | loopMD->replaceOperandWith(0, loopMD); |
884 | backEdge->setMetadata(llvm::LLVMContext::MD_loop, loopMD); |
885 | // Add a new entry to the PHI node for the backedge. |
886 | var->addIncoming(nextVal, loopBB); |
887 | builder.SetInsertPoint(afterBB); |
888 | return std::make_pair(loopBB, afterBB); |
889 | } |
890 | |
891 | /// Emit the address of the buffer \p v inside a data-parallel kernel \p kernel |
892 | /// using the mapping provided by \p bufferToArgNum. |
893 | llvm::Value * |
894 | LLVMIRGen::emitBufferAddress(llvm::IRBuilder<> &builder, Value *val, |
895 | llvm::Function *kernel, |
896 | llvm::DenseMap<Value *, int> &bufferToArgNum) { |
897 | assert(bufferToArgNum.count(val) && "Buffer should be in the map" ); |
898 | return kernel->args().begin() + bufferToArgNum[val]; |
899 | } |
900 | |
/// Implementation of emitDataParallelKernel where we guarantee that the number
/// of arguments will be bound by 64.
/// Emits one "stacked" kernel function containing a single loop over the
/// elements, with the body of each instruction in \p bundle emitted inside
/// that loop, then emits a call to the kernel from \p builder.
/// \param builder the builder of the calling function.
/// \param bundle the data-parallel instructions to stack into one kernel.
/// \param argTypes LLVM types of the kernel's buffer parameters.
/// \param bufferToArgNum maps each buffer to its kernel argument index.
/// \param buffers buffer addresses passed to the kernel at the call site.
void LLVMIRGen::emitDataParallelKernelImpl(
    llvm::IRBuilder<> &builder, llvm::ArrayRef<const Instruction *> bundle,
    llvm::ArrayRef<llvm::Type *> argTypes,
    llvm::DenseMap<Value *, int> &bufferToArgNum,
    llvm::ArrayRef<llvm::Value *> buffers) {
  // Nothing to emit for an empty bundle.
  if (bundle.empty()) {
    return;
  }
  // Create stacked kernel function type.
  llvm::Type *voidTy = llvm::Type::getVoidTy(getLLVMContext());
  llvm::FunctionType *kernelFuncTy =
      llvm::FunctionType::get(voidTy, argTypes, false);
  auto *kernelFunc =
      llvm::Function::Create(kernelFuncTy, llvm::Function::InternalLinkage,
                             "libjit_stacked_kernel" , llmodule_.get());
  // Mark all kernel function buffer parameters as no-alias, because above
  // we ensured that they are uniqued.
  for (unsigned paramIdx = 0; paramIdx < bufferToArgNum.size(); ++paramIdx) {
    kernelFunc->addParamAttr(paramIdx, llvm::Attribute::AttrKind::NoAlias);
  }

  // Create the entry BB.
  llvm::BasicBlock *entryBB =
      llvm::BasicBlock::Create(getLLVMContext(), "entry" , kernelFunc);
  llvm::IRBuilder<> kernelBuilder(entryBB);
  // Number of tensor elements. Taken from the first instruction's first
  // operand; the bundle is built from shape-compatible instructions.
  auto *numElements =
      emitValueSize(kernelBuilder, bundle[0]->getOperand(0).first);
  // Create a loop inside the stacked kernel function being generated.
  auto loopBBs = createLoop(kernelBuilder, getLLVMContext(), numElements);

  // Get the index parameter of the loop.
  // This is the PHI node of the BB.
  auto *kernelLoopIdx = dyn_cast<llvm::PHINode>(loopBBs.first->begin());
  assert(kernelLoopIdx && "Could not find the loop index" );
  // Insert the body of the loop right after the PHI node.
  kernelBuilder.SetInsertPoint(loopBBs.first->getFirstNonPHIOrDbg());
  // Iterate over stacked instructions and create a kernel invocations per
  // instruction.
  for (auto &BI : bundle) {
    // Name of the stacked operation to be invoked.
    assert(canBePartOfDataParallelKernel(BI) &&
           "Data parallel operation is expected" );
    generateLLVMIRForDataParallelInstr(kernelBuilder, BI, kernelFunc,
                                       bufferToArgNum, kernelLoopIdx);
  }
  // Terminate the kernel in the after-loop block.
  kernelBuilder.SetInsertPoint(loopBBs.second);
  // Add a return.
  kernelBuilder.CreateRetVoid();

  setCurrentDebugLocation(builder, *bundle.begin());
  // Emit a call of the kernel.
  createUncheckedCall(builder, kernelFunc, buffers);
  // Emit debug info for the generated data-parallel kernel.
  generateFunctionDebugInfo(kernelFunc);
}
959 | |
/// Emit the function that implements a data-parallel kernel and calls it.
///
/// The generated kernel functions get buffers as their parameters. The buffers
/// are uniqued, so that any buffer is passed as argument to the kernel function
/// only once. This allows us to mark all parameters of the generated kernel as
/// noalias. As a result, the LLVM optimizer makes use of the noalias attributes
/// and produces nicely vectorized code for the generated data-parallel kernels.
/// Note that we will emit a kernel whenever the number of arguments (aka unique
/// buffers) exceeds `kArgLimit`.
void LLVMIRGen::emitDataParallelKernel(
    llvm::IRBuilder<> &builder, llvm::ArrayRef<const Instruction *> bundle) {
  if (bundle.empty())
    return;
  // Types of arguments for the kernel function being generated.
  llvm::SmallVector<llvm::Type *, 32> argTypes;
  // Map each buffer used by the kernel to the argument number of the kernel
  // function. This ensures that same buffer is always mapped to the same
  // argument.
  llvm::DenseMap<Value *, int> bufferToArgNum;
  // Buffers to be passed to the kernel function as arguments.
  llvm::SmallVector<llvm::Value *, 32> buffers;
  // Hold a group of instructions whose unique buffer size is no more than
  // `kArgLimit` and ship it for processing
  llvm::SmallVector<const Instruction *, 32> batchedBundle;
  // Collect unique buffers up to `kArgLimit` used by the instructions of the
  // kernel.
  for (const auto I : bundle) {
    // If adding the buffers of current instruction might make the total number
    // of unique buffer exceed `kArgLimit`, we need to emit the kernel and start
    // over. Note the "might" as this method is pessimistic, because number of
    // buffers from current instruction might not be unique. Trade-off here is
    // that the algorithm is cleaner and in practice, if we over-estimate the
    // argument size by several, it does not matter too much.
    if (argTypes.size() + I->getOperands().size() > kArgLimit) {
      // Flush the batch collected so far, then reset all per-kernel state so
      // the next batch starts with a fresh argument list.
      emitDataParallelKernelImpl(builder, batchedBundle, argTypes,
                                 bufferToArgNum, buffers);
      batchedBundle.clear();
      argTypes.clear();
      bufferToArgNum.clear();
      buffers.clear();
    }

    // Add the instruction to the current bundle and process its operands
    batchedBundle.push_back(I);
    for (const auto &Op : I->getOperands()) {
      auto *buf = Op.first;
      // Register each buffer only once; repeated uses reuse the same
      // argument slot.
      if (!bufferToArgNum.count(buf)) {
        bufferToArgNum[buf] = argTypes.size();
        buffers.push_back(emitValueAddress(builder, buf));
        argTypes.push_back(getElementType(builder, buf)->getPointerTo());
      }
    }
  }
  // Emit the final (possibly partial) batch.
  emitDataParallelKernelImpl(builder, batchedBundle, argTypes, bufferToArgNum,
                             buffers);
}
1016 | |
1017 | /// Check if the provided operand overlaps with an operand of an instruction |
1018 | /// already in the bundle, but is not exactly the same memory region. |
1019 | /// Such memory regions cannot be considered data-parallel in the scope of the |
1020 | /// same kernel. |
1021 | /// |
1022 | /// \param allocationsInfo information about allocations |
1023 | /// \param bundle current bundle of stacked instructions |
1024 | /// \param op the operand to be checked for overlaps with the \p bundle. |
1025 | static bool isOverlappingWithAnyBundleBufferOperands( |
1026 | AllocationsInfo &allocationsInfo, |
1027 | llvm::SmallVectorImpl<const Instruction *> &bundle, |
1028 | const Instruction::Operand &op) { |
1029 | auto *buf = op.first; |
1030 | auto addr1 = allocationsInfo.allocatedAddress_[buf]; |
1031 | auto size1 = buf->getSizeInBytes(); |
1032 | for (auto bi : bundle) { |
1033 | for (auto bop : bi->getOperands()) { |
1034 | // Only input operands never interfere. |
1035 | if (bop.second == OperandKind::In && op.second == OperandKind::In) { |
1036 | continue; |
1037 | } |
1038 | auto buf2 = bop.first; |
1039 | auto addr2 = allocationsInfo.allocatedAddress_[buf2]; |
1040 | auto size2 = buf2->getSizeInBytes(); |
1041 | if (addr1 == addr2 && size1 == size2) { |
1042 | // The two buffers are the exact same memory region. The operations |
1043 | // cannot be within the same bundle because the buffer pointers are |
1044 | // "noalias" qualified, so the kernel operations can be reordered by |
1045 | // LLVM's optimizations. |
1046 | // TODO investigate if removing "noalias" can be used to create bigger |
1047 | // and faster bundles. |
1048 | return true; |
1049 | } |
1050 | if ((addr1 >= addr2 && addr1 < addr2 + size2) || |
1051 | (addr2 >= addr1 && addr2 < addr1 + size1)) { |
1052 | // Two intervals overlap, but are not the same. |
1053 | return true; |
1054 | } |
1055 | } |
1056 | } |
1057 | return false; |
1058 | } |
1059 | |
/// Returns true if the input \p x equals the single candidate \p candidate.
template <typename T> bool matchPair(T x, T candidate) {
  return x == candidate;
}

/// Base case of the variadic matcher: an empty candidate list never matches.
template <typename T> bool matchPair(T x) { return false; }

/// Returns true if the input \p x matches with at least one of the inputs in
/// the variadic list \p candidate, \p rest ....
template <typename T, typename... Args>
bool matchPair(T x, T candidate, Args... rest) {
  // Try the first candidate; otherwise recurse into the remaining ones.
  if (x == candidate) {
    return true;
  }
  return matchPair(x, rest...);
}
1069 | |
/// Generate LLVM IR for all instructions of F_: consecutive shape-compatible
/// data-parallel instructions are collected into bundles and emitted as
/// stacked kernels; all other instructions are emitted individually.
void LLVMIRGen::generateLLVMIRForModule(llvm::IRBuilder<> &builder) {
  // Go over the instructions and try to group them into bundles.
  auto &instrs = F_->getInstrs();

  // Group instructions into bundles of shape compatible data parallel
  // instructions and emit them.
  llvm::SmallVector<const Instruction *, 32> bundle;
  for (auto &I : instrs) {
    if (!canBePartOfDataParallelKernel(&I)) {
      // Ignore memory management instructions as they are handled by the
      // MemoryManager and are NOPs for a JIT.
      // Note: these do NOT flush the current bundle, so bundling continues
      // across them.
      if (isa<AllocActivationInst>(&I) || isa<DeallocActivationInst>(&I) ||
          isa<TensorViewInst>(&I)) {
        generateLLVMIRForInstr(builder, &I);
        continue;
      }
      // A non-data-parallel instruction: emit the pending bundle first to
      // preserve execution order, then emit this instruction on its own.
      emitDataParallelKernel(builder, bundle);
      bundle.clear();
      generateLLVMIRForInstr(builder, &I);
      continue;
    }

    // This is a data parallel instruction.

    // Check if the current instruction is shape compatible with the bundle.
    bool isBundleCompatible = true;
    if (!bundle.empty()) {
      auto val = I.getOperand(0).first;
      auto bundleVal = bundle.back()->getOperand(0).first;
      // Check if shapes have the same amount of elements.
      isBundleCompatible = val->size() == bundleVal->size();
    }

    // Check all mutated operands of the current instruction. Their memory
    // regions should not have a non-exact overlap with any operands of the
    // bundled instructions. In case this condition does not hold, the current
    // instruction cannot be included into the data-parallel bundle, because
    // overlapping operand buffers are not data parallel.
    for (auto &op : I.getOperands()) {
      // If the mutated operand buffer overlaps with any buffer already used by
      // the bundle, the current instruction cannot become a part of the bundle.
      if (isOverlappingWithAnyBundleBufferOperands(allocationsInfo_, bundle,
                                                   op)) {
        isBundleCompatible = false;
        break;
      }
    }

    // If the instruction cannot be added to the current bundle, emit the kernel
    // for the current bundle and start a new bundle.
    if (!isBundleCompatible) {
      emitDataParallelKernel(builder, bundle);
      bundle.clear();
    }
    // Add a data parallel instruction to the bundle.
    bundle.push_back(&I);
  }

  // Flush any trailing bundle left after the last instruction.
  emitDataParallelKernel(builder, bundle);
}
1130 | |
1131 | void LLVMIRGen::generateLLVMIRForDataParallelInstr( |
1132 | llvm::IRBuilder<> &builder, const glow::Instruction *I, |
1133 | llvm::Function *kernel, llvm::DenseMap<Value *, int> &bufferToArgNum, |
1134 | llvm::Value *loopCount) { |
1135 | setCurrentDebugLocation(builder, I); |
1136 | assert(canBePartOfDataParallelKernel(I) && |
1137 | "Instruction cannot be part of a data parallel kernel" ); |
1138 | switch (I->getKind()) { |
1139 | |
1140 | #define ARITHMETIC_UNARY_OP_WITH_IMM_CASE(INST_NAME_, FUN_NAME_, VALUE_) \ |
1141 | case Kinded::Kind::INST_NAME_##InstKind: { \ |
1142 | auto *AN = cast<INST_NAME_##Inst>(I); \ |
1143 | auto *dest = AN->getDest(); \ |
1144 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \ |
1145 | auto *elementTy = getElementType(builder, dest); \ |
1146 | auto value = AN->get##VALUE_(); \ |
1147 | auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType()); \ |
1148 | auto *pointerNull = \ |
1149 | llvm::ConstantPointerNull::get(elementTy->getPointerTo()); \ |
1150 | if (dest->getType()->isQuantizedType()) { \ |
1151 | auto *destTy = dest->getType(); \ |
1152 | /* Quantize value based on the output type. */ \ |
1153 | /* Perform this early and let jit library to work */ \ |
1154 | /* with quantized number. */ \ |
1155 | TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()}; \ |
1156 | if (destTy->getElementType() == ElemKind::Int8QTy) { \ |
1157 | auto quantizedValue = quantization::quantize<int8_t>(value, TQP); \ |
1158 | auto *val = emitConstI8(builder, quantizedValue); \ |
1159 | auto *stackedOpCall = createUncheckedCall( \ |
1160 | builder, F, {loopCount, val, pointerNull, pointerNull}); \ |
1161 | auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \ |
1162 | "buffer.element.addr"); \ |
1163 | builder.CreateStore(stackedOpCall, destAddr); \ |
1164 | } else if (destTy->getElementType() == ElemKind::Int16QTy) { \ |
1165 | auto quantizedValue = quantization::quantize<int16_t>(value, TQP); \ |
1166 | auto *val = emitConstI16(builder, quantizedValue); \ |
1167 | auto *stackedOpCall = createUncheckedCall( \ |
1168 | builder, F, {loopCount, val, pointerNull, pointerNull}); \ |
1169 | auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \ |
1170 | "buffer.element.addr"); \ |
1171 | builder.CreateStore(stackedOpCall, destAddr); \ |
1172 | } else { \ |
1173 | llvm_unreachable("Quantization precision not supported."); \ |
1174 | } \ |
1175 | } else { \ |
1176 | auto *val = emitConst(builder, value, dest->getElementType()); \ |
1177 | auto *stackedOpCall = createUncheckedCall( \ |
1178 | builder, F, {loopCount, val, pointerNull, pointerNull}); \ |
1179 | auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \ |
1180 | "buffer.element.addr"); \ |
1181 | builder.CreateStore(stackedOpCall, destAddr); \ |
1182 | } \ |
1183 | break; \ |
1184 | } |
1185 | ARITHMETIC_UNARY_OP_WITH_IMM_CASE(Splat, "splat" , Value); |
1186 | #undef ARITHMETIC_UNARY_OP_WITH_IMM_CASE |
1187 | |
1188 | case Kinded::Kind::TouchInstKind: |
1189 | // do nothing; |
1190 | break; |
1191 | |
1192 | case Kinded::Kind::ElementSelectInstKind: { |
1193 | auto *ES = cast<ElementSelectInst>(I); |
1194 | auto *dest = ES->getDest(); |
1195 | auto *cond = ES->getCond(); |
1196 | auto *lhs = ES->getLHS(); |
1197 | auto *rhs = ES->getRHS(); |
1198 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); |
1199 | auto *condPtr = emitBufferAddress(builder, cond, kernel, bufferToArgNum); |
1200 | auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum); |
1201 | auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum); |
1202 | |
1203 | // Need _kernel suffix since these operations are implemented as |
1204 | // "data-parallel" kernels in libjit. |
1205 | auto *F = getFunction("elementselect_kernel" , lhs->getElementType()); |
1206 | |
1207 | if (lhs->getType()->isQuantizedType()) { |
1208 | auto *destTy = dest->getType(); |
1209 | auto *lhsTy = lhs->getType(); |
1210 | auto *rhsTy = rhs->getType(); |
1211 | |
1212 | auto *destOffset = emitConstI32(builder, destTy->getOffset()); |
1213 | auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset()); |
1214 | auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset()); |
1215 | |
1216 | // The selected value will be either lhs = s_l * (i_l - o_l) or |
1217 | // rhs = s_r * (i_r - o_r); the stored result that must be computed is |
1218 | // therefore one of: |
1219 | // (i) i_d = (s_l / s_d) * (i_l - o_l) + o_d |
1220 | // (ii) i_d = (s_r / s_d) * (i_r - o_r) + o_d |
1221 | float destScale = destTy->getScale(); |
1222 | auto lhsScaleParams = quantization::quantizeScaleOffset32To8( |
1223 | lhsTy->getScale() / destScale, lhsTy->getOffset()); |
1224 | auto rhsScaleParams = quantization::quantizeScaleOffset32To8( |
1225 | rhsTy->getScale() / destScale, rhsTy->getOffset()); |
1226 | |
1227 | auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre); |
1228 | auto *lhsPost = emitConstI32(builder, lhsScaleParams.post); |
1229 | auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale); |
1230 | auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre); |
1231 | auto *rhsPost = emitConstI32(builder, rhsScaleParams.post); |
1232 | auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale); |
1233 | |
1234 | auto *stackedOpCall = createUncheckedCall( |
1235 | builder, F, |
1236 | {loopCount, condPtr, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset, |
1237 | lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale}); |
1238 | auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, |
1239 | loopCount, "buffer.element.addr" ); |
1240 | builder.CreateStore(stackedOpCall, destAddr); |
1241 | } else { |
1242 | auto *stackedOpCall = |
1243 | createUncheckedCall(builder, F, {loopCount, condPtr, lhsPtr, rhsPtr}); |
1244 | auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, |
1245 | loopCount, "buffer.element.addr" ); |
1246 | builder.CreateStore(stackedOpCall, destAddr); |
1247 | } |
1248 | break; |
1249 | } |
1250 | case Kinded::Kind::IntLookupTableInstKind: { |
1251 | auto *lookupTable = cast<IntLookupTableInst>(I); |
1252 | auto *dest = lookupTable->getDest(); |
1253 | auto *src = lookupTable->getSrc(); |
1254 | auto *mapping = lookupTable->getMapping(); |
1255 | |
1256 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); |
1257 | auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum); |
1258 | auto *mappingPtr = |
1259 | emitBufferAddress(builder, mapping, kernel, bufferToArgNum); |
1260 | |
1261 | auto *F = getFunction("intlookuptable_kernel" , dest->getElementType()); |
1262 | auto *stackedOpCall = |
1263 | builder.CreateCall(F, {loopCount, srcPtr, mappingPtr}); |
1264 | auto *destType = getElementType(builder, dest); |
1265 | auto *destAddr = |
1266 | builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr" ); |
1267 | builder.CreateStore(stackedOpCall, destAddr); |
1268 | |
1269 | break; |
1270 | } |
1271 | #define ARITHMETIC_UNARY_OP_CASE(INST_NAME_, FUN_NAME_) \ |
1272 | case Kinded::Kind::INST_NAME_##InstKind: { \ |
1273 | auto *AN = cast<INST_NAME_##Inst>(I); \ |
1274 | auto *dest = AN->getDest(); \ |
1275 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \ |
1276 | auto *srcPtr = \ |
1277 | emitBufferAddress(builder, AN->getSrc(), kernel, bufferToArgNum); \ |
1278 | auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType()); \ |
1279 | auto *elementTy = getElementType(builder, dest); \ |
1280 | auto *pointerNull = \ |
1281 | llvm::ConstantPointerNull::get(elementTy->getPointerTo()); \ |
1282 | auto *stackedOpCall = createUncheckedCall( \ |
1283 | builder, F, {loopCount, srcPtr, pointerNull, pointerNull}); \ |
1284 | auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, \ |
1285 | loopCount, "buffer.element.addr"); \ |
1286 | builder.CreateStore(stackedOpCall, destAddr); \ |
1287 | break; \ |
1288 | } |
1289 | |
1290 | ARITHMETIC_UNARY_OP_CASE(Sigmoid, "sigmoid" ); |
1291 | ARITHMETIC_UNARY_OP_CASE(Tanh, "tanh" ); |
1292 | ARITHMETIC_UNARY_OP_CASE(ElementLog, "element_log" ); |
1293 | ARITHMETIC_UNARY_OP_CASE(ElementExp, "element_exp" ); |
1294 | ARITHMETIC_UNARY_OP_CASE(ElementAbs, "element_abs" ); |
1295 | ARITHMETIC_UNARY_OP_CASE(ElementNeg, "element_neg" ); |
1296 | ARITHMETIC_UNARY_OP_CASE(ElementFloor, "element_floor" ); |
1297 | ARITHMETIC_UNARY_OP_CASE(ElementCeil, "element_ceil" ); |
1298 | ARITHMETIC_UNARY_OP_CASE(ElementRound, "element_round" ); |
1299 | ARITHMETIC_UNARY_OP_CASE(ElementSqrt, "element_sqrt" ); |
1300 | ARITHMETIC_UNARY_OP_CASE(ElementErf, "element_erf" ); |
1301 | ARITHMETIC_UNARY_OP_CASE(ElementRsqrt, "element_rsqrt" ); |
1302 | ARITHMETIC_UNARY_OP_CASE(ElementReciprocal, "element_reciprocal" ); |
1303 | ARITHMETIC_UNARY_OP_CASE(ElementSin, "element_sin" ); |
1304 | ARITHMETIC_UNARY_OP_CASE(ElementCos, "element_cos" ); |
1305 | #undef ARITHMETIC_UNARY_OP_CASE |
1306 | |
1307 | case Kinded::Kind::ReluInstKind: { |
1308 | auto *RI = cast<ReluInst>(I); |
1309 | auto *src = RI->getSrc(); |
1310 | auto *dest = RI->getDest(); |
1311 | auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum); |
1312 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); |
1313 | auto srcTy = src->getType(); |
1314 | auto destTy = dest->getType(); |
1315 | |
1316 | auto *F = getFunction("element_relu" , dest->getElementType()); |
1317 | llvm::CallInst *stackedOpCall = nullptr; |
1318 | if (dest->getElementType() == ElemKind::Int8QTy) { |
1319 | auto *srcOffset = |
1320 | emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset())); |
1321 | auto *destOffset = |
1322 | emitConstI8(builder, static_cast<int8_t>(destTy->getOffset())); |
1323 | auto destScaleParams = quantization::quantizeScaleOffset32To8( |
1324 | srcTy->getScale() / destTy->getScale(), 0); |
1325 | auto *destPre = emitConstI32(builder, destScaleParams.pre); |
1326 | auto *destPost = emitConstI32(builder, destScaleParams.post); |
1327 | auto *destScale = emitConstI32(builder, destScaleParams.scale); |
1328 | stackedOpCall = createCall(builder, F, |
1329 | {loopCount, srcPtr, srcOffset, destOffset, |
1330 | destPre, destPost, destScale}); |
1331 | } else if (dest->getElementType() == ElemKind::FloatTy) { |
1332 | stackedOpCall = createCall(builder, F, {loopCount, srcPtr}); |
1333 | } else { |
1334 | LOG(FATAL) << "Type is not supported" ; |
1335 | } |
1336 | auto *elementTy = getElementType(builder, dest); |
1337 | auto *destAddr = |
1338 | builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" ); |
1339 | builder.CreateStore(stackedOpCall, destAddr); |
1340 | break; |
1341 | } |
1342 | |
1343 | case Kinded::Kind::ClipInstKind: { |
1344 | auto *CI = cast<ClipInst>(I); |
1345 | auto *src = CI->getSrc(); |
1346 | auto *dest = CI->getDest(); |
1347 | auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum); |
1348 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); |
1349 | auto srcTy = src->getType(); |
1350 | auto destTy = dest->getType(); |
1351 | float clipMinF = CI->getMin(); |
1352 | float clipMaxF = CI->getMax(); |
1353 | |
1354 | auto *F = getFunction("element_clip" , dest->getElementType()); |
1355 | llvm::CallInst *stackedOpCall = nullptr; |
1356 | if (dest->getElementType() == ElemKind::Int8QTy) { |
1357 | TensorQuantizationParams srcTQP{src->getType()->getScale(), |
1358 | src->getType()->getOffset()}; |
1359 | int8_t clipMinQ = quantization::quantize<int8_t>(clipMinF, srcTQP); |
1360 | int8_t clipMaxQ = quantization::quantize<int8_t>(clipMaxF, srcTQP); |
1361 | auto *clipMin = emitConstI8(builder, clipMinQ); |
1362 | auto *clipMax = emitConstI8(builder, clipMaxQ); |
1363 | auto *srcOffset = |
1364 | emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset())); |
1365 | auto *destOffset = |
1366 | emitConstI8(builder, static_cast<int8_t>(destTy->getOffset())); |
1367 | auto destScaleParams = quantization::quantizeScaleOffset32To8( |
1368 | srcTy->getScale() / destTy->getScale(), 0); |
1369 | auto *destPre = emitConstI32(builder, destScaleParams.pre); |
1370 | auto *destPost = emitConstI32(builder, destScaleParams.post); |
1371 | auto *destScale = emitConstI32(builder, destScaleParams.scale); |
1372 | stackedOpCall = |
1373 | createCall(builder, F, |
1374 | {loopCount, srcPtr, clipMin, clipMax, srcOffset, |
1375 | destOffset, destPre, destPost, destScale}); |
1376 | } else if (dest->getElementType() == ElemKind::FloatTy) { |
1377 | auto *clipMin = emitConstF32(builder, clipMinF); |
1378 | auto *clipMax = emitConstF32(builder, clipMaxF); |
1379 | stackedOpCall = |
1380 | createCall(builder, F, {loopCount, srcPtr, clipMin, clipMax}); |
1381 | } else { |
1382 | LOG(FATAL) << "Type is not supported" ; |
1383 | } |
1384 | auto *elementTy = getElementType(builder, dest); |
1385 | auto *destAddr = |
1386 | builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" ); |
1387 | builder.CreateStore(stackedOpCall, destAddr); |
1388 | break; |
1389 | } |
1390 | |
1391 | case Kinded::Kind::LeakyReluInstKind: { |
1392 | auto *LI = cast<LeakyReluInst>(I); |
1393 | auto *src = LI->getSrc(); |
1394 | auto *dest = LI->getDest(); |
1395 | auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum); |
1396 | auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); |
1397 | auto srcTy = src->getType(); |
1398 | auto destTy = dest->getType(); |
1399 | |
1400 | auto *F = getFunction("element_leaky_relu" , dest->getElementType()); |
1401 | llvm::CallInst *stackedOpCall = nullptr; |
1402 | if (dest->getElementType() == ElemKind::Int8QTy) { |
1403 | auto *srcOffset = |
1404 | emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset())); |
1405 | auto *destOffset = |
1406 | emitConstI8(builder, static_cast<int8_t>(destTy->getOffset())); |
1407 | // Scale parameters for the positive input domain. |
1408 | auto posParams = quantization::quantizeScaleOffset32To8( |
1409 | srcTy->getScale() / destTy->getScale(), 0); |
1410 | auto *posPre = emitConstI32(builder, posParams.pre); |
1411 | auto *posPost = emitConstI32(builder, posParams.post); |
1412 | auto *posScale = emitConstI32(builder, posParams.scale); |
1413 | // Scale parameters for the negative input domain. |
1414 | auto negParams = quantization::quantizeScaleOffset32To8( |
1415 | srcTy->getScale() * LI->getAlpha() / destTy->getScale(), 0); |
1416 | auto *negPre = emitConstI32(builder, negParams.pre); |
1417 | auto *negPost = emitConstI32(builder, negParams.post); |
1418 | auto *negScale = emitConstI32(builder, negParams.scale); |
1419 | stackedOpCall = |
1420 | createCall(builder, F, |
1421 | {loopCount, srcPtr, srcOffset, destOffset, posPre, posPost, |
1422 | posScale, negPre, negPost, negScale}); |
1423 | } else if (dest->getElementType() == ElemKind::FloatTy) { |
1424 | auto *alpha = emitConstF32(builder, LI->getAlpha()); |
1425 | stackedOpCall = createCall(builder, F, {loopCount, srcPtr, alpha}); |
1426 | } else { |
1427 | LOG(FATAL) << "Type is not supported" ; |
1428 | } |
1429 | auto *elementTy = getElementType(builder, dest); |
1430 | auto *destAddr = |
1431 | builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" ); |
1432 | builder.CreateStore(stackedOpCall, destAddr); |
1433 | break; |
1434 | } |
1435 | |
  case Kinded::Kind::ElementIsNaNInstKind: {
    // Data-parallel NaN test: for the current loopCount index, call the
    // libjit kernel on src[loopCount] and store its result to
    // dest[loopCount].
    auto *AN = cast<ElementIsNaNInst>(I);
    auto *src = AN->getSrc();
    auto *dest = AN->getDest();
    // Base addresses of the operand buffers inside the fused kernel.
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    // Kernel is selected by the *source* element type; the destination holds
    // the boolean-like result.
    auto *F = getFunction("element_is_nan_kernel" , src->getElementType());
    auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1450 | |
  case Kinded::Kind::QuantizeInstKind: {
    // Data-parallel quantization: the kernel receives the destination's
    // scale/offset as immediates and converts src[loopCount] into the
    // destination's quantized domain.
    auto *QI = cast<QuantizeInst>(I);
    auto *src = QI->getSrc();
    auto *dest = QI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *destTy = dest->getType();
    // Quantization parameters of the output type, materialized as constants.
    auto *destScale = emitConstF32(builder, destTy->getScale());
    auto *destOffset = emitConstI32(builder, destTy->getOffset());
    auto *F = getFunction("element_quantize_kernel" , dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, destScale, destOffset});
    auto *destType = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1470 | |
  case Kinded::Kind::DequantizeInstKind: {
    // Data-parallel dequantization: the kernel receives the *source* type's
    // scale/offset and produces a float result.
    auto *DI = cast<DequantizeInst>(I);
    auto *src = DI->getSrc();
    auto *dest = DI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcTy = src->getType();
    auto *srcScale = emitConstF32(builder, srcTy->getScale());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
    auto *F = getFunction("element_dequantize_kernel" , src->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, srcScale, srcOffset});
    // The destination is addressed as float elements, matching the kernel's
    // float return value.
    auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, loopCount,
                                       "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1489 | |
  case Kinded::Kind::RescaleQuantizedInstKind: {
    // Data-parallel rescale between two quantized types. The float ratio
    // srcScale/destScale is converted into integer pre-shift/post-shift/scale
    // parameters so the kernel can rescale without floating point.
    auto *RQI = cast<RescaleQuantizedInst>(I);
    auto *dest = RQI->getDest();
    auto *src = RQI->getSrc();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);

    auto *destType = dest->getType();
    auto *srcType = src->getType();

    auto rescaleParams = quantization::quantizeScaleOffset32To8(
        srcType->getScale() / destType->getScale(), srcType->getOffset());

    auto *destOffset = emitConstI32(builder, destType->getOffset());
    auto *srcOffset = emitConstI32(builder, srcType->getOffset());
    auto *preShift = emitConstI32(builder, rescaleParams.pre);
    auto *postShift = emitConstI32(builder, rescaleParams.post);
    auto *scale = emitConstI32(builder, rescaleParams.scale);
    auto *F = getFunction("element_rescale_kernel" , dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F,
        {loopCount, srcPtr, destOffset, srcOffset, preShift, postShift, scale});
    // Result is stored as i8 — rescale operates on int8 quantized buffers.
    auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
                                       "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1518 | |
  case Kinded::Kind::CopyInstKind: {
    // Data-parallel element copy: dest[loopCount] = src[loopCount].
    auto *CI = cast<CopyInst>(I);
    auto *dest = CI->getDest();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr =
        emitBufferAddress(builder, CI->getSrc(), kernel, bufferToArgNum);
    auto *F = getFunction("copy_kernel" , dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    // The kernel signature has two extra pointer slots that copy does not
    // use; pass typed null pointers to fill them.
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());
    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, pointerNull, pointerNull});
    auto *destAddr = builder.CreateGEP(getElementType(builder, dest), destPtr,
                                       loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1536 | |
// Emits one switch case for a data-parallel binary arithmetic instruction.
// Quantized operands are rescaled into the destination's quantization domain
// using integer pre-shift/post-shift/scale parameters computed at compile
// time; otherwise the plain kernel is called when the destination's element
// kind is one of the kinds listed in __VA_ARGS__ (checked via matchPair).
// NOTE: comments must stay outside the macro body — a `//` comment before a
// line-continuation backslash would swallow the continuation.
#define ARITHMETIC_BINARY_OP_CASE(INST_NAME_, FUN_NAME_, ...)                  \
  case Kinded::Kind::INST_NAME_##InstKind: {                                   \
    auto *AN = cast<INST_NAME_##Inst>(I);                                      \
    auto *dest = AN->getDest();                                                \
    auto *lhs = AN->getLHS();                                                  \
    auto *rhs = AN->getRHS();                                                  \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);  \
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);    \
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);    \
                                                                               \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());        \
    auto *elementTy = getElementType(builder, dest);                           \
    auto *pointerNull =                                                        \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());             \
    bool typesMatched = matchPair(dest->getElementType(), __VA_ARGS__);        \
    if (lhs->getType()->isQuantizedType()) {                                   \
      auto *destTy = dest->getType();                                          \
      auto *lhsTy = lhs->getType();                                            \
      auto *rhsTy = rhs->getType();                                            \
                                                                               \
      auto *destOffset = emitConstI32(builder, destTy->getOffset());           \
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());             \
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());             \
                                                                               \
      float destScale = destTy->getScale();                                    \
                                                                               \
      auto lhsScaleParams = quantization::quantizeScaleOffset32To8(            \
          lhsTy->getScale() / destScale, lhsTy->getOffset());                  \
      auto rhsScaleParams = quantization::quantizeScaleOffset32To8(            \
          rhsTy->getScale() / destScale, rhsTy->getOffset());                  \
                                                                               \
      auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre);                \
      auto *lhsPost = emitConstI32(builder, lhsScaleParams.post);              \
      auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale);            \
      auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre);                \
      auto *rhsPost = emitConstI32(builder, rhsScaleParams.post);              \
      auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale);            \
                                                                               \
      auto *stackedOpCall = createUncheckedCall(                               \
          builder, F,                                                          \
          {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset,        \
           lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});             \
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,         \
                                         loopCount, "buffer.element.addr");    \
      builder.CreateStore(stackedOpCall, destAddr);                            \
    } else if (typesMatched) {                                                 \
      auto *stackedOpCall = createUncheckedCall(                               \
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});               \
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,        \
                                         "buffer.element.addr");               \
      builder.CreateStore(stackedOpCall, destAddr);                            \
    } else {                                                                   \
      llvm_unreachable("Unsupported Type in " #INST_NAME_);                    \
    }                                                                          \
    break;                                                                     \
  }
  // Instantiations: each lists the non-quantized element kinds supported by
  // the corresponding libjit kernel.
  ARITHMETIC_BINARY_OP_CASE(ElementAdd, "element_add" , ElemKind::FloatTy,
                            ElemKind::Int32ITy, ElemKind::Int64ITy);
  ARITHMETIC_BINARY_OP_CASE(ElementSub, "element_sub" , ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementMax, "element_max" , ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementMin, "element_min" , ElemKind::FloatTy);
  ARITHMETIC_BINARY_OP_CASE(ElementPow, "element_pow" , ElemKind::FloatTy);
#undef ARITHMETIC_BINARY_OP_CASE
1600 | |
  case Kinded::Kind::ElementNotInstKind: {
    // Data-parallel logical NOT: dest[loopCount] = !src[loopCount].
    auto *NI = cast<ElementNotInst>(I);
    auto *dest = NI->getDest();
    auto *src = NI->getSrc();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    // Kernel is selected by the source element type.
    auto *F = getFunction("element_not_kernel" , src->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1615 | |
  case Kinded::Kind::ElementAndInstKind: {
    // Data-parallel logical AND: dest[loopCount] =
    // lhs[loopCount] && rhs[loopCount]. Kernel is selected by the LHS
    // element type.
    auto *AI = cast<ElementAndInst>(I);
    auto *dest = AI->getDest();
    auto *lhs = AI->getLHS();
    auto *rhs = AI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_and_kernel" , lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementOrInstKind: {
    // Data-parallel logical OR; same shape as the AND case above.
    auto *OI = cast<ElementOrInst>(I);
    auto *dest = OI->getDest();
    auto *lhs = OI->getLHS();
    auto *rhs = OI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_or_kernel" , lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementXorInstKind: {
    // Data-parallel logical XOR; same shape as the AND case above.
    auto *XI = cast<ElementXorInst>(I);
    auto *dest = XI->getDest();
    auto *lhs = XI->getLHS();
    auto *rhs = XI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_xor_kernel" , lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr" );
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1669 | |
  case Kinded::Kind::ElementCmpEQInstKind:
  case Kinded::Kind::ElementCmpNEQInstKind:
  case Kinded::Kind::ElementCmpLTInstKind:
  case Kinded::Kind::ElementCmpLTEInstKind: {
    // Shared lowering for all element-wise comparison instructions: dispatch
    // on the concrete instruction to pick operands and the libjit kernel
    // name, then emit a single call.
    Value *dest = nullptr;
    Value *lhs = nullptr;
    Value *rhs = nullptr;
    std::string kernelName;

    if (auto *CEQI = dyn_cast<ElementCmpEQInst>(I)) {
      dest = CEQI->getDest();
      lhs = CEQI->getLHS();
      rhs = CEQI->getRHS();
      kernelName = "element_cmp_eq_kernel" ;
    } else if (auto *CNEQI = dyn_cast<ElementCmpNEQInst>(I)) {
      dest = CNEQI->getDest();
      lhs = CNEQI->getLHS();
      rhs = CNEQI->getRHS();
      kernelName = "element_cmp_neq_kernel" ;
    } else if (auto *CLTEI = dyn_cast<ElementCmpLTEInst>(I)) {
      dest = CLTEI->getDest();
      lhs = CLTEI->getLHS();
      rhs = CLTEI->getRHS();
      kernelName = "element_cmp_lte_kernel" ;
    } else if (auto *CLTI = dyn_cast<ElementCmpLTInst>(I)) {
      dest = CLTI->getDest();
      lhs = CLTI->getLHS();
      rhs = CLTI->getRHS();
      kernelName = "element_cmp_lt_kernel" ;
    } else {
      llvm_unreachable(
          "Missmatch between Instruction Kind and instruction instance." );
    }

    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction(kernelName.c_str(), lhs->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // We can divide both sides of the comparison by the rhs scale since it is
      // strictly positive; this saves one rescale within the backend. The
      // inequalities are:
      //     s_l * (i_l - o_l) <= s_r * (i_r - o_r)
      // <=> (s_l / s_r) * (i_l - o_l) <= i_r - o_r
      float scale = lhsTy->getScale() / rhsTy->getScale();
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *cmpPre = emitConstI32(builder, scaleParams.pre);
      auto *cmpPost = emitConstI32(builder, scaleParams.post);
      auto *cmpScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, lhsOffset, rhsOffset,
                               cmpPre, cmpPost, cmpScale});
      // Comparison result is stored as an i8 element.
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall =
          createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
      auto *elementTy = getElementType(builder, dest);
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }
1747 | |
  case Kinded::Kind::ElementMulInstKind: {
    // Data-parallel multiply. Quantized inputs fold both input scales and
    // the output scale into one integer rescale; float/int inputs call the
    // plain kernel.
    auto *MI = cast<ElementMulInst>(I);
    auto *dest = MI->getDest();
    auto *lhs = MI->getLHS();
    auto *rhs = MI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("element_mul_kernel" , dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    // Filler for the kernel's unused pointer slot in the non-quantized path.
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The multiplicative scale factor is s_l * s_r / s_d due to the equation
      // s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r)
      // => i_d = (s_l * s_r / s_d) * (i_l - o_l) * (i_r - o_r) + o_d
      float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *mulPre = emitConstI32(builder, scaleParams.pre);
      auto *mulPost = emitConstI32(builder, scaleParams.post);
      auto *mulScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
                               rhsOffset, mulPre, mulPost, mulScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    } else if (lhs->getType()->getElementType() == ElemKind::Int64ITy ||
               lhs->getType()->getElementType() == ElemKind::Int32ITy ||
               lhs->getType()->getElementType() == ElemKind::FloatTy) {
      auto *stackedOpCall = createUncheckedCall(
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      LOG_ASSERT(false) << "Unsupported element type for Mul." ;
    }
    break;
  }
1802 | |
  case Kinded::Kind::ElementDivInstKind: {
    // Data-parallel divide; mirrors the multiply case but with the inverse
    // scale composition for the quantized path.
    auto *MI = cast<ElementDivInst>(I);
    auto *dest = MI->getDest();
    auto *lhs = MI->getLHS();
    auto *rhs = MI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("element_div_kernel" , dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    // Filler for the kernel's unused pointer slot in the non-quantized path.
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The division scale factor is s_l / (s_r * s_d) due to the equation
      // s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
      // => i_d = (s_l / (s_r * s_d)) * ((i_l - o_l) / (i_r - o_r)) + o_d
      float scale =
          lhsTy->getScale() / (rhsTy->getScale() * destTy->getScale());
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *divPre = emitConstI32(builder, scaleParams.pre);
      auto *divPost = emitConstI32(builder, scaleParams.post);
      auto *divScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
                               rhsOffset, divPre, divPost, divScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *elementTy = getElementType(builder, dest);
      auto *stackedOpCall = createUncheckedCall(
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr" );
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }
1855 | |
  case Kinded::Kind::ModuloInstKind: {
    // Data-parallel modulo with a compile-time divisor. Two kernel variants
    // exist depending on whether the result's sign follows the divisor
    // (Python-style) or the dividend (C-style).
    auto *MI = cast<ModuloInst>(I);
    auto *dest = MI->getDest();
    auto *src = MI->getSrc();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    // The divisor is emitted as an i64 constant regardless of element type.
    auto *divisor = emitConst(builder, MI->getDivisor(), ElemKind::Int64ITy);
    llvm::Function *F = nullptr;
    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    if (MI->getSignFollowDivisor()) {
      F = getFunction("element_modulo_kernel_sign_follow" ,
                      dest->getElementType());
    } else {
      F = getFunction("element_modulo_kernel_no_sign_follow" ,
                      dest->getElementType());
    }
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, divisor, srcPtr});
    llvm::Value *destAddr = nullptr;
    // Address the destination with the matching integer width (i64 or i32).
    if (dest->getElementType() == ElemKind::Int64ITy) {
      destAddr = builder.CreateGEP(builder.getInt64Ty(), destPtr, loopCount,
                                   "buffer.element.addr" );
    } else {
      destAddr = builder.CreateGEP(builder.getInt32Ty(), destPtr, loopCount,
                                   "buffer.element.addr" );
    }
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }
1886 | |
1887 | default: |
1888 | std::string sBuf; |
1889 | llvm::raw_string_ostream s(sBuf); |
1890 | I->dump(s); |
1891 | LOG(FATAL) << "Cannot select the instruction: " << s.str(); |
1892 | } |
1893 | } |
1894 | |
1895 | Tensor LLVMIRGen::getTensorForConstantValue(Value *value) { |
1896 | // Since we can't get the variable from a glow::Value directly, |
1897 | // we need to traverse the var list and find the one matching the given |
1898 | // Value. |
1899 | Tensor tensor; |
1900 | auto *F_ = getIRFunction(); |
1901 | for (auto &v : F_->findConstants()) { |
1902 | assert(isa<WeightVar>(F_->getWeightForNode(v))); |
1903 | auto *w = cast<glow::Value>(F_->getWeightForNode(v)); |
1904 | if (w == value) { |
1905 | tensor.assign(&v->getPayload()); |
1906 | break; |
1907 | } |
1908 | } |
1909 | CHECK(tensor.getUnsafePtr()) << "Can't find the constant value!" ; |
1910 | return tensor; |
1911 | } |
1912 | |
1913 | void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder, |
1914 | const glow::Instruction *I) { |
1915 | setCurrentDebugLocation(builder, I); |
1916 | assert((!canBePartOfDataParallelKernel(I)) && |
1917 | "data parallel instructions are not handled here" ); |
1918 | switch (I->getKind()) { |
  case Kinded::Kind::MatMulInstKind: {
    // Whole-tensor matrix multiply: dest = lhs x rhs, lowered to a single
    // libjit "matmul" call that receives buffer pointers and dimensions.
    auto *MM = cast<MatMulInst>(I);
    auto *dest = MM->getDest();
    auto *lhs = MM->getLHS();
    auto *rhs = MM->getRHS();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *lhsPtr = emitValueAddress(builder, lhs);
    auto *rhsPtr = emitValueAddress(builder, rhs);

    auto *destDims = emitValueDims(builder, dest);
    auto *lhsDims = emitValueDims(builder, lhs);
    auto *rhsDims = emitValueDims(builder, rhs);

    auto *F = getFunction("matmul" , dest->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // Fold s_l * s_r / s_d into one integer rescale (pre/post shift +
      // scale) applied to the accumulated products.
      auto outScaleParams = quantization::quantizeScaleOffset32To8(
          lhsTy->getScale() * rhsTy->getScale() / destTy->getScale(), 0);

      auto *outPre = emitConstI32(builder, outScaleParams.pre);
      auto *outPost = emitConstI32(builder, outScaleParams.post);
      auto *outScale = emitConstI32(builder, outScaleParams.scale);

      createCall(builder, F,
                 {destPtr, lhsPtr, rhsPtr, destDims, lhsDims, rhsDims,
                  destOffset, lhsOffset, rhsOffset, outPre, outPost, outScale});
    } else {
      createCall(builder, F,
                 {destPtr, lhsPtr, rhsPtr, destDims, lhsDims, rhsDims});
    }
    break;
  }
1959 | |
  case Kinded::Kind::QuantizationProfileInstKind: {
    // Profiling instrumentation: record the float input tensor's value
    // distribution into a histogram plus computation-info buffer, used later
    // to derive quantization parameters.
    auto *QP = cast<QuantizationProfileInst>(I);
    auto *hist = QP->getHistogram();
    auto *compInfo = QP->getComputationInfo();
    auto *inputTensor = QP->getInputTensor();

    auto *histPtr = emitValueAddress(builder, hist);
    auto *compInfoPtr = emitValueAddress(builder, compInfo);
    auto *inputTensorInfoPtr = emitValueAddress(builder, inputTensor);

    auto *histDims = emitValueDims(builder, hist);
    // Profiling is only defined for float inputs.
    assert(inputTensor->getElementType() == ElemKind::FloatTy &&
           "None float Tensor type for Quantization Profile Instruction." );
    auto *tensorSize = emitConstDimT(builder, inputTensor->getType()->size());

    auto *F = getFunction("quantization_profile" );
    createCall(
        builder, F,
        {inputTensorInfoPtr, tensorSize, compInfoPtr, histPtr, histDims});
    break;
  }
1981 | |
  case Kinded::Kind::FullyConnectedInstKind: {
    // Fully-connected layer: dest = src * weights + bias, lowered to a
    // single "fc" libjit call. The quantized path additionally passes
    // offsets and integer rescale parameters for the bias and output.
    auto *FCI = cast<FullyConnectedInst>(I);
    auto *dest = FCI->getDest();
    auto *src = FCI->getSrc();
    auto *weights = FCI->getWeights();
    auto *bias = FCI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *weightsPtr = emitValueAddress(builder, weights);
    auto *biasPtr = emitValueAddress(builder, bias);
    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *weightsDims = emitValueDims(builder, weights);
    auto *biasDims = emitValueDims(builder, bias);

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *weightsTy = weights->getType();
      auto *biasTy = bias->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *weightsOffset = emitConstI32(builder, weightsTy->getOffset());
      auto *biasOffset = emitConstI32(builder, biasTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * weightsTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasTy->getScale() / matMulScale, 0);
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      auto *biasPre = emitConstI32(builder, biasScaleParam.pre);
      auto *biasPost = emitConstI32(builder, biasScaleParam.post);
      auto *biasScale = emitConstI32(builder, biasScaleParam.scale);
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      // The kernel is specialized on both the output and bias element types
      // (the bias may be quantized to a different width than the output).
      auto *F =
          getFunction("fc" , {dest->getElementType(), bias->getElementType()});
      createCall(builder, F,
                 {destPtr, srcPtr, weightsPtr, biasPtr, destDims, srcDims,
                  weightsDims, biasDims, destOffset, srcOffset, weightsOffset,
                  biasOffset, biasPre, biasPost, biasScale, outPre, outPost,
                  outScale});
    } else {
      auto *F = getFunction("fc" , dest->getElementType());
      createCall(builder, F,
                 {destPtr, srcPtr, weightsPtr, biasPtr, destDims, srcDims,
                  weightsDims, biasDims});
    }
    break;
  }
2042 | |
  case Kinded::Kind::RowwiseQuantizedFullyConnectedInstKind: {
    // Row-wise quantized FC: every weight row has its own scale, so the
    // bias/output rescale parameters are computed per row at compile time
    // and passed to the kernel as constant arrays.
    auto *RWQFC = cast<RowwiseQuantizedFullyConnectedInst>(I);

    // The per-row scales live in a compile-time Constant; pull its payload
    // so the parameters can be precomputed here.
    auto scalesT = getTensorForConstantValue(RWQFC->getScales());
    auto scalesH = scalesT.getHandle();
    size_t rowNum = scalesH.dims()[0];
    float inputScale = RWQFC->getSrc()->getType()->getScale();

    float bScale = RWQFC->getBias()->getType()->getScale();
    int32_t bOffset = RWQFC->getBias()->getType()->getOffset();

    float outputScale = RWQFC->getDest()->getType()->getScale();

    // One pre/post/scale triple per weight row, for both the bias addition
    // and the output rescale.
    std::vector<llvm::Constant *> biasPreV(rowNum);
    std::vector<llvm::Constant *> biasPostV(rowNum);
    std::vector<llvm::Constant *> biasScaleV(rowNum);
    std::vector<llvm::Constant *> outputPreV(rowNum);
    std::vector<llvm::Constant *> outputPostV(rowNum);
    std::vector<llvm::Constant *> outputScaleV(rowNum);

    for (size_t i = 0; i < rowNum; i++) {
      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = inputScale * scalesH.raw(i);

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam =
          quantization::quantizeScaleOffset32To8(bScale / matMulScale, bOffset);
      auto outScaleParam =
          quantization::quantizeScaleOffset32To8(matMulScale / outputScale, 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      biasPreV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                           biasScaleParam.pre, true);
      biasPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                            biasScaleParam.post, true);
      biasScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                             biasScaleParam.scale, true);
      outputPreV[i] =
          llvm::ConstantInt::get(builder.getInt32Ty(), outScaleParam.pre, true);
      outputPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                              outScaleParam.post, true);
      outputScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                               outScaleParam.scale, true);
    }

    auto *dest = RWQFC->getDest();
    auto *src = RWQFC->getSrc();
    auto *weights = RWQFC->getWeights();
    auto *bias = RWQFC->getBias();
    auto *weightsOffsets = RWQFC->getOffsets();

    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *weightsPtr = emitValueAddress(builder, weights);
    auto *biasPtr = emitValueAddress(builder, bias);
    auto *weightsOffsetsPtr = emitValueAddress(builder, weightsOffsets);
    // Materialize the per-row parameter vectors as constant i32 arrays.
    auto *biasPrePtr = emitConstArray(builder, biasPreV, builder.getInt32Ty());
    auto *biasPostPtr =
        emitConstArray(builder, biasPostV, builder.getInt32Ty());
    auto *biasScalePtr =
        emitConstArray(builder, biasScaleV, builder.getInt32Ty());
    auto *outputPrePtr =
        emitConstArray(builder, outputPreV, builder.getInt32Ty());
    auto *outputPostPtr =
        emitConstArray(builder, outputPostV, builder.getInt32Ty());
    auto *outputScalePtr =
        emitConstArray(builder, outputScaleV, builder.getInt32Ty());

    auto *srcDims = emitValueDims(builder, src);
    auto *weightsDims = emitValueDims(builder, weights);
    auto *destDims = emitValueDims(builder, dest);
    auto *biasDims = emitValueDims(builder, bias);
    auto *row = emitConstDimT(builder, weightsOffsets->dims()[0]);

    auto *destOffset = emitConstI32(builder, dest->getType()->getOffset());
    auto *srcOffset = emitConstI32(builder, src->getType()->getOffset());
    auto *biasOffset = emitConstI32(builder, bOffset);

    // Pick the kernel variant by output/bias element widths; only i8 output
    // with i8 or i32 bias is implemented.
    llvm::Function *F = nullptr;
    if ((dest->getElementType() == ElemKind::Int8QTy) &&
        (bias->getElementType() == ElemKind::Int8QTy)) {
      F = getFunction("rowwise_quantized_fc_i8_i8" );
    } else if ((dest->getElementType() == ElemKind::Int8QTy) &&
               (bias->getElementType() == ElemKind::Int32QTy)) {
      F = getFunction("rowwise_quantized_fc_i8_i32" );
    } else {
      LOG(FATAL) << "Unsupported element/bias type for "
                    "RowwiseQuantizedFullyConnectedInst" ;
    }

    createCall(builder, F,
               {destPtr, srcPtr, weightsPtr, biasPtr, weightsOffsetsPtr,
                biasPrePtr, biasPostPtr, biasScalePtr, outputPrePtr,
                outputPostPtr, outputScalePtr, destDims, srcDims, weightsDims,
                biasDims, row, destOffset, srcOffset, biasOffset});
    break;
  }
2142 | |
  case Kinded::Kind::BatchedAddInstKind: {
    // dest[n, ...] = batch[n, ...] + slice[...]: broadcast-add one slice over
    // the first (batch) dimension of the batch tensor.
    auto *BA = cast<BatchedAddInst>(I);
    auto *dest = BA->getDest();
    auto *batch = BA->getBatch();
    auto *slice = BA->getSlice();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *batchPtr = emitValueAddress(builder, batch);
    auto *slicePtr = emitValueAddress(builder, slice);

    // Split the batch shape into (numSlice, sliceSize) =
    // (dims[0], product of the remaining dims).
    auto bdim = flattenCdr(batch->dims());
    auto *numSlice = emitConstDimT(builder, bdim.first);
    auto *sliceSize = emitConstDimT(builder, bdim.second);

    if (batch->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *batchTy = batch->getType();
      auto *sliceTy = slice->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *batchOffset = emitConstI32(builder, batchTy->getOffset());
      auto *sliceOffset = emitConstI32(builder, sliceTy->getOffset());

      float destScale = destTy->getScale();

      // Here, we select parameters for scaling both summands to the
      // destination scale.
      auto batchScaleParams = quantization::quantizeScaleOffset32To8(
          batchTy->getScale() / destScale, batchTy->getOffset());
      auto sliceScaleParams = quantization::quantizeScaleOffset32To8(
          sliceTy->getScale() / destScale, sliceTy->getOffset());

      // Pre-shift, post-shift and integer scale for each summand.
      auto *batchPre = emitConstI32(builder, batchScaleParams.pre);
      auto *batchPost = emitConstI32(builder, batchScaleParams.post);
      auto *batchScale = emitConstI32(builder, batchScaleParams.scale);
      auto *slicePre = emitConstI32(builder, sliceScaleParams.pre);
      auto *slicePost = emitConstI32(builder, sliceScaleParams.post);
      auto *sliceScale = emitConstI32(builder, sliceScaleParams.scale);

      // Kernel choice depends on the slice's quantized element type
      // (i8 vs. i32); the suffixless name handles the i8 slice.
      llvm::Function *F = nullptr;
      if (sliceTy->getElementType() == ElemKind::Int8QTy) {
        F = getFunction("batchedadd", dest->getElementType());
      } else if (sliceTy->getElementType() == ElemKind::Int32QTy) {
        F = getFunction("batchedadd_i32", dest->getElementType());
      } else {
        LOG(FATAL) << "Type is not supported: "
                   << Type::getElementName(sliceTy->getElementType()).str();
      }
      createCall(builder, F,
                 {destPtr, batchPtr, slicePtr, numSlice, sliceSize, destOffset,
                  batchOffset, sliceOffset, batchPre, batchPost, batchScale,
                  slicePre, slicePost, sliceScale});
    } else {
      // Non-quantized path needs no rescaling parameters.
      auto *F = getFunction("batchedadd", dest->getElementType());
      createCall(builder, F,
                 {destPtr, batchPtr, slicePtr, numSlice, sliceSize});
    }
    break;
  }
2201 | |
  case Kinded::Kind::BatchedReduceAddInstKind: {
    // Sum-reduce the batch tensor along a single axis into dest.
    auto *BR = cast<BatchedReduceAddInst>(I);
    auto *dest = BR->getDest();
    auto *batch = BR->getBatch();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *batchPtr = emitValueAddress(builder, batch);
    auto *axis = emitConstDimT(builder, BR->getAxis());

    // Pad both shapes up to the maximum rank the kernel supports; the
    // reduced axis collapses to 1 in the destination shape.
    ShapeVector eBatchDims = expandDimsToMax(batch->dims());
    ShapeVector eDestDims = eBatchDims;
    eDestDims[BR->getAxis()] = 1;

    auto *batchDims =
        emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims));
    auto *destDims = emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims));

    auto *F = getFunction("batchedreduceadd", dest->getElementType());

    if (batch->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *batchTy = batch->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *batchOffset = emitConstI32(builder, batchTy->getOffset());

      // BatchedReduceAdd is an accumulation operation, with equations
      //    s_d * (i_d - o_d) = \sum s_b * (i_b - o_b)
      // => i_d - o_d = \sum (s_b / s_d) * (i_b - o_b)
      // => i_d = (s_b / s_d ) * [\sum (i_b - o_b)] + o_d
      auto batchScaleParams = quantization::quantizeScaleOffset32To8(
          batchTy->getScale() / destTy->getScale(), batchTy->getOffset());

      auto *batchPre = emitConstI32(builder, batchScaleParams.pre);
      auto *batchPost = emitConstI32(builder, batchScaleParams.post);
      auto *batchScale = emitConstI32(builder, batchScaleParams.scale);

      createCall(builder, F,
                 {destPtr, batchPtr, destDims, batchDims, destOffset,
                  batchOffset, batchPre, batchPost, batchScale, axis});
    } else {
      auto *destSize = emitConstDimT(builder, dest->size());

      createCall(builder, F,
                 {destPtr, batchPtr, destSize, destDims, batchDims, axis});
    }
    break;
  }
2249 | |
2250 | case Kinded::Kind::BatchedReduceProdInstKind: { |
2251 | auto *BR = cast<BatchedReduceProdInst>(I); |
2252 | auto *dest = BR->getDest(); |
2253 | auto *batch = BR->getBatch(); |
2254 | auto *destPtr = emitValueAddress(builder, dest); |
2255 | auto *batchPtr = emitValueAddress(builder, batch); |
2256 | auto *axis = emitConstDimT(builder, BR->getAxis()); |
2257 | |
2258 | ShapeVector eBatchDims = expandDimsToMax(batch->dims()); |
2259 | ShapeVector eDestDims = eBatchDims; |
2260 | eDestDims[BR->getAxis()] = 1; |
2261 | |
2262 | auto *batchDims = |
2263 | emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims)); |
2264 | auto *destDims = emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims)); |
2265 | |
2266 | auto *F = getFunction("batchedreduceprod" , dest->getElementType()); |
2267 | |
2268 | assert(!batch->getType()->isQuantizedType() && |
2269 | "Quantized implementation for ReduceProd not supported yet." ); |
2270 | |
2271 | auto *destSize = emitConstDimT(builder, dest->size()); |
2272 | |
2273 | createCall(builder, F, |
2274 | {destPtr, batchPtr, destSize, destDims, batchDims, axis}); |
2275 | |
2276 | break; |
2277 | } |
2278 | |
2279 | #define BATCHED_REDUCE_MINMAX_CASE(INST_NAME_, FUN_NAME_) \ |
2280 | case Kinded::Kind::Batched##INST_NAME_##InstKind: { \ |
2281 | auto *BR = cast<Batched##INST_NAME_##Inst>(I); \ |
2282 | auto *dest = BR->getDest(); \ |
2283 | auto *batch = BR->getBatch(); \ |
2284 | auto axes = BR->getAxes(); \ |
2285 | auto *destPtr = emitValueAddress(builder, dest); \ |
2286 | auto *batchPtr = emitValueAddress(builder, batch); \ |
2287 | \ |
2288 | ShapeVector eBatchDims = expandDimsToMax(batch->dims()); \ |
2289 | ShapeVector eDestDims = eBatchDims; \ |
2290 | for (dim_t i = 0; i < axes.size(); i++) { \ |
2291 | eDestDims[axes[i]] = 1; \ |
2292 | } \ |
2293 | \ |
2294 | auto *batchDims = \ |
2295 | emitConstDimTArray(builder, llvm::makeArrayRef(eBatchDims)); \ |
2296 | auto *destDims = \ |
2297 | emitConstDimTArray(builder, llvm::makeArrayRef(eDestDims)); \ |
2298 | \ |
2299 | if (((batch->getElementType() != ElemKind::FloatTy) && \ |
2300 | (batch->getElementType() != ElemKind::Int32ITy) && \ |
2301 | (batch->getElementType() != ElemKind::Int64ITy)) || \ |
2302 | (batch->getElementType() != dest->getElementType())) { \ |
2303 | std::string errStr = "Cannot get function for "; \ |
2304 | std::string name = "INST_NAME_"; \ |
2305 | errStr += name; \ |
2306 | llvm_unreachable(errStr.c_str()); \ |
2307 | } \ |
2308 | \ |
2309 | llvm::Function *F = getFunction(FUN_NAME_, batch->getElementType()); \ |
2310 | if (!batch->getType()->isQuantizedType()) { \ |
2311 | auto *destSize = emitConstSizeT(builder, dest->size()); \ |
2312 | \ |
2313 | createCall(builder, F, \ |
2314 | {destPtr, batchPtr, destSize, destDims, batchDims}); \ |
2315 | } \ |
2316 | break; \ |
2317 | } |
2318 | BATCHED_REDUCE_MINMAX_CASE(ReduceMin, "reducemin" ) |
2319 | BATCHED_REDUCE_MINMAX_CASE(ReduceMax, "reducemax" ) |
2320 | #undef BATCHED_REDUCE_MINMAX_CASE |
2321 | |
  case Kinded::Kind::ConvolutionInstKind: {
    // 2D convolution (NHWC only), float or quantized, with optional fused
    // activation.
    auto *CI = cast<ConvolutionInst>(I);
    assert(CI->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Convolutions");
    auto *dest = CI->getDest();
    auto *src = CI->getSrc();
    auto *filter = CI->getFilter();
    auto *bias = CI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CI->getKernels());
    auto *strides = emitConstDimTArray(builder, CI->getStrides());
    auto *pads = emitConstDimTArray(builder, CI->getPads());
    auto *group = emitConstDimT(builder, CI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CI->getDilation());

    // NHWC: dims()[3] is the channel (depth) dimension.
    auto destDepth = dest->dims()[3];

    // Try to 'block' the convolution on the 'depth' dimension. We will process
    // this number output slices each iteration.
    unsigned unrollDFactor = 1;

    // In libjit_convolution_f function, 'unrollDFactor' output
    // layers will be processed together. Therefore, the number of
    // output layers in each group should be divisible by 'unrollDFactor'
    bool groupDividedBy8 = ((destDepth / CI->getGroup()) % 8) == 0;
    if (groupDividedBy8) {
      unrollDFactor = 8;
    }

    auto *unrollD = emitConstI32(builder, unrollDFactor);

    // Fused activation kind, passed to the kernel as an integer code.
    auto *actType = emitConstI32(builder, CI->getFusedActivation());

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *filterTy = filter->getType();
      auto *biasTy = bias->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *filterOffset = emitConstI32(builder, filterTy->getOffset());
      auto *biasOffset = emitConstI32(builder, biasTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * filterTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasTy->getScale() / matMulScale, biasTy->getOffset());
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      auto *biasPre = emitConstI32(builder, biasScaleParam.pre);
      auto *biasPost = emitConstI32(builder, biasScaleParam.post);
      auto *biasScale = emitConstI32(builder, biasScaleParam.scale);
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      // Emit parameters for fused activation.
      auto *actArgsQuant = emitConstQuantActivationArgs(builder, CI);

      // Kernel is selected on both dest and bias element types (bias may be
      // i8 or i32 quantized).
      auto *F = getFunction("conv2d",
                            {dest->getElementType(), bias->getElementType()});

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims,
                  srcDims, filterDims, biasDims, kernels, strides,
                  pads, group, destOffset, srcOffset, filterOffset,
                  biasOffset, biasPre, biasPost, biasScale, outPre,
                  outPost, outScale, unrollD, dilation, actType,
                  actArgsQuant});
    } else {

      // Emit parameters for fused activation.
      auto *actArgsFloat = emitConstFloatActivationArgs(builder, CI);

      auto *F = getFunction("conv2d", dest->getElementType());

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group, unrollD,
                  dilation, actType, actArgsFloat});
    }
    break;
  }
2421 | |
2422 | case Kinded::Kind::ConvolutionGradInstKind: { |
2423 | auto *CG = cast<ConvolutionGradInst>(I); |
2424 | auto *srcGrad = CG->getSrcGrad(); |
2425 | auto *destGrad = CG->getDestGrad(); |
2426 | auto *src = CG->getSrc(); |
2427 | auto *filterGrad = CG->getFilterGrad(); |
2428 | auto *srcGradPtr = emitValueAddress(builder, srcGrad); |
2429 | auto *destGradPtr = emitValueAddress(builder, destGrad); |
2430 | auto *srcPtr = emitValueAddress(builder, src); |
2431 | auto *filterGradPtr = emitValueAddress(builder, filterGrad); |
2432 | auto *biasGradPtr = emitValueAddress(builder, CG->getBiasGrad()); |
2433 | auto *filterPtr = emitValueAddress(builder, CG->getFilter()); |
2434 | |
2435 | auto *destGradDims = emitValueDims(builder, destGrad); |
2436 | auto *srcDims = emitValueDims(builder, src); |
2437 | auto *filterGradDims = emitValueDims(builder, filterGrad); |
2438 | |
2439 | auto *kernels = emitConstDimTArray(builder, CG->getKernels()); |
2440 | auto *strides = emitConstDimTArray(builder, CG->getStrides()); |
2441 | auto *pads = emitConstDimTArray(builder, CG->getPads()); |
2442 | auto *group = emitConstDimT(builder, CG->getGroup()); |
2443 | auto *dilation = emitConstDimTArray(builder, CG->getDilation()); |
2444 | |
2445 | auto *F = getFunction("convolution_grad" , srcGrad->getElementType()); |
2446 | createCall(builder, F, |
2447 | {srcGradPtr, destGradPtr, srcPtr, filterGradPtr, biasGradPtr, |
2448 | filterPtr, destGradDims, srcDims, filterGradDims, kernels, |
2449 | strides, pads, group, dilation}); |
2450 | break; |
2451 | } |
2452 | |
  case Kinded::Kind::ConvTransposeInstKind: {
    // Transposed convolution (deconvolution), float or quantized.
    // NOTE(review): unlike ConvolutionInst there is no NHWC layout assert
    // here — confirm whether that is intentional.
    auto *CI = cast<ConvTransposeInst>(I);
    auto *dest = CI->getDest();
    auto *src = CI->getSrc();
    auto *filter = CI->getFilter();
    auto *bias = CI->getBias();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CI->getKernels());
    auto *strides = emitConstDimTArray(builder, CI->getStrides());
    auto *pads = emitConstDimTArray(builder, CI->getPads());
    auto *group = emitConstDimT(builder, CI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CI->getDilation());

    const char *kernelName = "conv_transpose";

    auto *F = getFunction(kernelName, dest->getElementType());

    if (src->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *srcTy = src->getType();
      auto *filterTy = filter->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
      auto *filterOffset = emitConstI32(builder, filterTy->getOffset());

      // Calculate the scale of the values that come out of the matrix
      // multiplication part of the calculation.
      float matMulScale = srcTy->getScale() * filterTy->getScale();

      // Calculate the scaling parameters for the bias and output.
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // output calculation.
      auto *outPre = emitConstI32(builder, outScaleParam.pre);
      auto *outPost = emitConstI32(builder, outScaleParam.post);
      auto *outScale = emitConstI32(builder, outScaleParam.scale);

      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group,
                  destOffset, srcOffset, filterOffset, outPre, outPost,
                  outScale, dilation});
    } else {
      createCall(builder, F,
                 {destPtr, srcPtr, filterPtr, biasPtr, destDims, srcDims,
                  filterDims, biasDims, kernels, strides, pads, group,
                  dilation});
    }
    break;
  }
2515 | |
  case Kinded::Kind::ChannelwiseQuantizedConvolutionInstKind: {
    // Convolution where the filter and bias are quantized per output channel:
    // each channel has its own scale/offset, so the rescaling parameters are
    // precomputed here as constant arrays, one entry per channel.
    auto *CQCI = cast<ChannelwiseQuantizedConvolutionInst>(I);
    auto *dest = CQCI->getDest();
    auto *src = CQCI->getSrc();
    auto *filter = CQCI->getFilter();
    auto *bias = CQCI->getBias();
    auto *filterScales = CQCI->getFilterScales();
    auto *filterOffsets = CQCI->getFilterOffsets();
    auto *biasScales = CQCI->getBiasScales();
    auto *biasOffsets = CQCI->getBiasOffsets();

    auto *destTy = dest->getType();
    auto *srcTy = src->getType();

    // The per-channel scales are compile-time constants: read them back as
    // tensors so the integer scaling parameters can be computed now.
    auto filterScalesT = getTensorForConstantValue(filterScales);
    auto filterScalesH = filterScalesT.getHandle<float>();

    auto biasScalesT = getTensorForConstantValue(biasScales);
    auto biasScalesH = biasScalesT.getHandle<float>();

    // Compute quantization parameters for each channel.
    auto channelNum = dest->dims().back();
    std::vector<llvm::Constant *> biasPreV(channelNum);
    std::vector<llvm::Constant *> biasPostV(channelNum);
    std::vector<llvm::Constant *> biasScaleV(channelNum);
    std::vector<llvm::Constant *> outputPreV(channelNum);
    std::vector<llvm::Constant *> outputPostV(channelNum);
    std::vector<llvm::Constant *> outputScaleV(channelNum);
    for (size_t i = 0; i < channelNum; i++) {

      // Compute the scaling parameters for bias and output.
      float matMulScale = srcTy->getScale() * filterScalesH.raw(i);
      auto biasScaleParam = quantization::quantizeScaleOffset32To8(
          biasScalesH.raw(i) / matMulScale, 0);
      auto outScaleParam = quantization::quantizeScaleOffset32To8(
          matMulScale / destTy->getScale(), 0);

      // Pass the pre-shift, post-shift and integer scale parameters for the
      // bias and output calculation.
      biasPreV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                           biasScaleParam.pre, true);
      biasPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                            biasScaleParam.post, true);
      biasScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                             biasScaleParam.scale, true);
      outputPreV[i] =
          llvm::ConstantInt::get(builder.getInt32Ty(), outScaleParam.pre, true);
      outputPostV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                              outScaleParam.post, true);
      outputScaleV[i] = llvm::ConstantInt::get(builder.getInt32Ty(),
                                               outScaleParam.scale, true);
    }

    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *filterPtr = emitValueAddress(builder, filter);
    auto *biasPtr = emitValueAddress(builder, bias);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);
    auto *filterDims = emitValueDims(builder, filter);
    auto *biasDims = emitValueDims(builder, bias);

    auto *kernels = emitConstDimTArray(builder, CQCI->getKernels());
    auto *strides = emitConstDimTArray(builder, CQCI->getStrides());
    auto *pads = emitConstDimTArray(builder, CQCI->getPads());
    auto *group = emitConstDimT(builder, CQCI->getGroup());
    auto *dilation = emitConstDimTArray(builder, CQCI->getDilation());

    auto *destOffset = emitConstI32(builder, destTy->getOffset());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
    // Per-channel filter/bias offsets are passed as runtime arrays.
    auto *filterOffsetsPtr = emitValueAddress(builder, filterOffsets);
    auto *biasOffsetsPtr = emitValueAddress(builder, biasOffsets);

    // Materialize the precomputed per-channel scaling parameters.
    auto *biasPrePtr = emitConstArray(builder, biasPreV, builder.getInt32Ty());
    auto *biasPostPtr =
        emitConstArray(builder, biasPostV, builder.getInt32Ty());
    auto *biasScalePtr =
        emitConstArray(builder, biasScaleV, builder.getInt32Ty());
    auto *outputPrePtr =
        emitConstArray(builder, outputPreV, builder.getInt32Ty());
    auto *outputPostPtr =
        emitConstArray(builder, outputPostV, builder.getInt32Ty());
    auto *outputScalePtr =
        emitConstArray(builder, outputScaleV, builder.getInt32Ty());

    // 5D source means a 3D convolution; otherwise 2D.
    bool isConv3D = (srcTy->dims().size() == 5);
    auto *F = getFunction(isConv3D ? "channelwise_quantized_conv3d"
                                   : "channelwise_quantized_conv2d",
                          {dest->getElementType(), bias->getElementType()});

    auto *actType = emitConstI32(builder, CQCI->getFusedActivation());
    auto *actArgsQuant = emitConstQuantActivationArgs(builder, CQCI);

    createCall(builder, F,
               {destPtr, srcPtr, filterPtr, biasPtr,
                destDims, srcDims, filterDims, biasDims,
                kernels, strides, pads, group,
                dilation, destOffset, srcOffset, filterOffsetsPtr,
                biasOffsetsPtr, biasPrePtr, biasPostPtr, biasScalePtr,
                outputPrePtr, outputPostPtr, outputScalePtr, actType,
                actArgsQuant});
    break;
  }
2620 | |
2621 | case Kinded::Kind::CrossEntropyLossInstKind: { |
2622 | auto *CI = cast<CrossEntropyLossInst>(I); |
2623 | auto *P = CI->getP(); |
2624 | auto *labels = CI->getLabels(); |
2625 | auto *CE = CI->getCE(); |
2626 | |
2627 | auto *CEPtr = emitValueAddress(builder, CE); |
2628 | auto *PPtr = emitValueAddress(builder, P); |
2629 | auto *labelsPtr = emitValueAddress(builder, labels); |
2630 | auto *dims = emitValueDims(builder, P); |
2631 | |
2632 | auto *F = getFunction("cross_entropy_loss" , |
2633 | {CE->getElementType(), labels->getElementType()}); |
2634 | createCall(builder, F, {CEPtr, PPtr, labelsPtr, dims}); |
2635 | break; |
2636 | } |
2637 | |
2638 | case Kinded::Kind::LengthsToRangesInstKind: { |
2639 | auto *LTR = cast<LengthsToRangesInst>(I); |
2640 | auto *dest = LTR->getDest(); |
2641 | auto *lengths = LTR->getLengths(); |
2642 | auto *destPtr = emitValueAddress(builder, dest); |
2643 | auto *lengthsPtr = emitValueAddress(builder, lengths); |
2644 | auto *size = emitConstDimT(builder, lengths->dims()[0]); |
2645 | auto *F = getFunction("lengths_to_ranges" , dest->getElementType()); |
2646 | createCall(builder, F, {destPtr, lengthsPtr, size}); |
2647 | break; |
2648 | } |
2649 | |
2650 | case Kinded::Kind::LengthsSumInstKind: { |
2651 | auto *LS = cast<LengthsSumInst>(I); |
2652 | auto *dest = LS->getDest(); |
2653 | auto *data = LS->getData(); |
2654 | auto *lengths = LS->getLengths(); |
2655 | |
2656 | auto *destPtr = emitValueAddress(builder, dest); |
2657 | auto *dataPtr = emitValueAddress(builder, data); |
2658 | auto *lengthsPtr = emitValueAddress(builder, lengths); |
2659 | |
2660 | auto *lengthsSize = emitConstDimT(builder, lengths->size()); |
2661 | auto *dataType = data->getType(); |
2662 | auto *destSize = emitConstDimT(builder, dest->size()); |
2663 | auto *sliceSize = |
2664 | emitConstDimT(builder, dataType->size() / dataType->dims()[0]); |
2665 | |
2666 | auto *F = getFunction("lengths_sum" , data->getElementType()); |
2667 | createCall( |
2668 | builder, F, |
2669 | {destPtr, dataPtr, lengthsPtr, destSize, lengthsSize, sliceSize}); |
2670 | break; |
2671 | } |
2672 | |
2673 | case Kinded::Kind::LocalResponseNormalizationInstKind: { |
2674 | auto *LRN = cast<LocalResponseNormalizationInst>(I); |
2675 | auto *dest = LRN->getDest(); |
2676 | auto *src = LRN->getSrc(); |
2677 | auto *destPtr = emitValueAddress(builder, dest); |
2678 | auto *srcPtr = emitValueAddress(builder, src); |
2679 | auto *scalePtr = emitValueAddress(builder, LRN->getScale()); |
2680 | |
2681 | auto *destDims = emitValueDims(builder, dest); |
2682 | auto *srcDims = emitValueDims(builder, src); |
2683 | auto *halfWindow = emitConstDimT(builder, LRN->getHalfWindowSize()); |
2684 | auto *alpha = emitConstF32(builder, LRN->getAlpha()); |
2685 | auto *beta = emitConstF32(builder, LRN->getBeta()); |
2686 | auto *k = emitConstF32(builder, LRN->getK()); |
2687 | |
2688 | auto *F = |
2689 | getFunction("local_response_normalization" , dest->getElementType()); |
2690 | createCall(builder, F, |
2691 | {destPtr, srcPtr, scalePtr, destDims, srcDims, halfWindow, alpha, |
2692 | beta, k}); |
2693 | break; |
2694 | } |
2695 | |
  case Kinded::Kind::LocalResponseNormalizationGradInstKind: {
    // Backward pass of LRN; reuses the scale tensor cached by the forward
    // pass to compute the input gradient.
    auto *LRNG = llvm::cast<LocalResponseNormalizationGradInst>(I);
    auto *srcGrad = LRNG->getSrcGrad();
    auto *dest = LRNG->getDest();
    auto *srcGradPtr = emitValueAddress(builder, srcGrad);
    auto *destGradPtr = emitValueAddress(builder, LRNG->getDestGrad());
    auto *srcPtr = emitValueAddress(builder, LRNG->getSrc());
    auto *destPtr = emitValueAddress(builder, dest);
    auto *scalePtr = emitValueAddress(builder, LRNG->getScale());

    auto *destDims = emitValueDims(builder, dest);

    auto *halfWindow = emitConstDimT(builder, LRNG->getHalfWindowSize());
    auto *alpha = emitConstF32(builder, LRNG->getAlpha());
    auto *beta = emitConstF32(builder, LRNG->getBeta());

    auto *F = getFunction("local_response_normalization_grad",
                          srcGrad->getElementType());
    createCall(builder, F,
               {srcGradPtr, destGradPtr, srcPtr, destPtr, scalePtr, destDims,
                halfWindow, alpha, beta});
    break;
  }
2719 | |
  case Kinded::Kind::MaxPoolInstKind: {
    // Max pooling over the spatial dimensions (NHWC only).
    auto *PM = cast<MaxPoolInst>(I);
    assert(PM->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Pools");
    auto *dest = PM->getDest();
    auto *src = PM->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);

    auto *kernels = emitConstDimTArray(builder, PM->getKernels());
    auto *strides = emitConstDimTArray(builder, PM->getStrides());
    auto *pads = emitConstDimTArray(builder, PM->getPads());

    auto *F = getFunction("max_pool", dest->getElementType());

    if (src->getType()->isQuantizedType()) {
      // Quantized path: only the destination offset is passed — the kernel
      // gets no rescaling parameters (presumably src and dest share scale;
      // confirm against the libjit kernel signature).
      auto *destOffset = emitConstI32(builder, dest->getType()->getOffset());
      createCall(builder, F,
                 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
                  destOffset});
    } else {
      createCall(builder, F,
                 {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads});
    }
    break;
  }
2749 | |
  case Kinded::Kind::MaxPoolWithArgmaxInstKind: {
    // Max pooling that also records, in the argmax tensor, the index of each
    // selected maximum (used by the gradient instruction below).
    auto *PMXY = cast<MaxPoolWithArgmaxInst>(I);
    assert(PMXY->getLayout() == NHWC &&
           "Glow CPU Backend supports only NHWC Pools");
    auto *dest = PMXY->getDest();
    auto *src = PMXY->getSrc();
    auto *destPtr = emitValueAddress(builder, dest);
    auto *srcPtr = emitValueAddress(builder, src);
    auto *argMax = PMXY->getArgmax();
    auto *argmaxPtr = emitValueAddress(builder, argMax);

    auto *destDims = emitValueDims(builder, dest);
    auto *srcDims = emitValueDims(builder, src);

    auto *kernels = emitConstDimTArray(builder, PMXY->getKernels());
    auto *strides = emitConstDimTArray(builder, PMXY->getStrides());
    auto *pads = emitConstDimTArray(builder, PMXY->getPads());

    // Kernel is selected on both the data and the index element types.
    auto *F = getFunction("max_pool_argmax",
                          {dest->getElementType(), argMax->getElementType()});
    createCall(builder, F,
               {srcPtr, destPtr, argmaxPtr, srcDims, destDims, kernels, strides,
                pads});
    break;
  }
2775 | |
2776 | case Kinded::Kind::MaxPoolWithArgmaxGradInstKind: { |
2777 | auto *PMG = cast<MaxPoolWithArgmaxGradInst>(I); |
2778 | auto *srcGrad = PMG->getSrcGrad(); |
2779 | auto *srcGradPtr = emitValueAddress(builder, srcGrad); |
2780 | auto *destGradPtr = emitValueAddress(builder, PMG->getDestGrad()); |
2781 | auto *argMax = PMG->getArgmax(); |
2782 | auto *argmaxPtr = emitValueAddress(builder, argMax); |
2783 | |
2784 | auto *srcGradDims = emitValueDims(builder, srcGrad); |
2785 | auto *destDims = emitValueDims(builder, PMG->getDest()); |
2786 | |
2787 | auto *F = getFunction("max_pool_argmax_grad" , {srcGrad->getElementType(), |
2788 | argMax->getElementType()}); |
2789 | createCall(builder, F, |
2790 | {srcGradPtr, destGradPtr, argmaxPtr, srcGradDims, destDims}); |
2791 | break; |
2792 | } |
2793 | |
2794 | case Kinded::Kind::ArgMaxInstKind: { |
2795 | auto *AM = cast<ArgMaxInst>(I); |
2796 | auto *dest = AM->getDest(); |
2797 | auto *src = AM->getSrc(); |
2798 | auto *destPtr = emitValueAddress(builder, dest); |
2799 | auto *srcPtr = emitValueAddress(builder, src); |
2800 | auto *srcDims = emitValueDims(builder, src); |
2801 | auto *srcNumDims = emitConstSizeT(builder, src->dims().size()); |
2802 | auto *axis = emitConstSizeT(builder, AM->getAxis()); |
2803 | auto *F = |
2804 | getFunction("arg_max" , {src->getElementType(), dest->getElementType()}); |
2805 | createCall(builder, F, {srcPtr, destPtr, srcDims, srcNumDims, axis}); |
2806 | break; |
2807 | } |
2808 | |
2809 | case Kinded::Kind::ArgMinInstKind: { |
2810 | auto *AM = cast<ArgMinInst>(I); |
2811 | auto *dest = AM->getDest(); |
2812 | auto *src = AM->getSrc(); |
2813 | auto *destPtr = emitValueAddress(builder, dest); |
2814 | auto *srcPtr = emitValueAddress(builder, src); |
2815 | auto *srcDims = emitValueDims(builder, src); |
2816 | auto *srcNumDims = emitConstSizeT(builder, src->dims().size()); |
2817 | auto *axis = emitConstSizeT(builder, AM->getAxis()); |
2818 | auto *F = |
2819 | getFunction("arg_min" , {src->getElementType(), dest->getElementType()}); |
2820 | createCall(builder, F, {srcPtr, destPtr, srcDims, srcNumDims, axis}); |
2821 | break; |
2822 | } |
2823 | |
case Kinded::Kind::AvgPoolInstKind: {
  // Average pooling over the spatial dimensions of an NHWC tensor.
  auto *PA = cast<AvgPoolInst>(I);
  assert(PA->getLayout() == NHWC &&
         "Glow CPU Backend supports only NHWC Pools" );
  auto *dest = PA->getDest();
  auto *src = PA->getSrc();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *srcPtr = emitValueAddress(builder, src);

  auto *destDims = emitValueDims(builder, dest);
  auto *srcDims = emitValueDims(builder, src);

  auto *kernels = emitConstDimTArray(builder, PA->getKernels());
  auto *strides = emitConstDimTArray(builder, PA->getStrides());
  auto *pads = emitConstDimTArray(builder, PA->getPads());
  auto *countIncludePads = emitConstI1(builder, PA->getCountIncludePads());

  auto *F = getFunction("avg_pool" , dest->getElementType());

  if (src->getType()->isQuantizedType()) {
    // Quantized path: fold the src/dest scale ratio (and, when padding is
    // counted, the constant filter area) into one fixed-point rescale
    // (pre/post shifts + integer scale) handed to the kernel.
    auto *destTy = dest->getType();
    auto *srcTy = src->getType();
    auto *destOffset = emitConstI32(builder, destTy->getOffset());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
    // When we count the padding pixels in the normalizing factor we include
    // the filter area in the scaling parameters since it is a constant.
    float scale = srcTy->getScale() / destTy->getScale();
    if (PA->getCountIncludePads()) {
      scale = scale / (PA->getKernels()[0] * PA->getKernels()[1]);
    }
    auto outScaleParam = quantization::quantizeScaleOffset32To8(scale, 0);
    auto *outPre = emitConstI32(builder, outScaleParam.pre);
    auto *outPost = emitConstI32(builder, outScaleParam.post);
    auto *outScale = emitConstI32(builder, outScaleParam.scale);
    createCall(builder, F,
               {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
                countIncludePads, destOffset, srcOffset, outPre, outPost,
                outScale});
  } else {
    // Float path: the kernel normalizes directly.
    createCall(builder, F,
               {srcPtr, destPtr, srcDims, destDims, kernels, strides, pads,
                countIncludePads});
  }
  break;
}
2869 | |
2870 | case Kinded::Kind::AdaptiveAvgPoolInstKind: { |
2871 | auto *PA = cast<AdaptiveAvgPoolInst>(I); |
2872 | |
2873 | auto *dest = PA->getDest(); |
2874 | auto *src = PA->getSrc(); |
2875 | auto *destPtr = emitValueAddress(builder, dest); |
2876 | auto *srcPtr = emitValueAddress(builder, src); |
2877 | |
2878 | auto *destDims = emitValueDims(builder, dest); |
2879 | auto *srcDims = emitValueDims(builder, src); |
2880 | |
2881 | auto *F = getFunction("adaptive_avg_pool" , dest->getElementType()); |
2882 | createCall(builder, F, {srcPtr, destPtr, srcDims, destDims}); |
2883 | break; |
2884 | } |
2885 | |
case Kinded::Kind::AvgPoolGradInstKind: {
  // Backprop for average pooling: distributes each output gradient evenly
  // over its pooling window.
  auto *PAG = cast<AvgPoolGradInst>(I);
  auto *srcGrad = PAG->getSrcGrad();
  auto *srcGradPtr = emitValueAddress(builder, srcGrad);
  auto *destGradPtr = emitValueAddress(builder, PAG->getDestGrad());

  auto *srcGradDims = emitValueDims(builder, srcGrad);
  auto *destDims = emitValueDims(builder, PAG->getDest());

  auto *kernels = emitConstDimTArray(builder, PAG->getKernels());
  auto *strides = emitConstDimTArray(builder, PAG->getStrides());
  auto *pads = emitConstDimTArray(builder, PAG->getPads());
  // Whether padding pixels are counted in the averaging denominator.
  auto *countIncludePads = emitConstI1(builder, PAG->getCountIncludePads());

  auto *F = getFunction("avg_pool_grad" , srcGrad->getElementType());
  createCall(builder, F,
             {srcGradPtr, destGradPtr, srcGradDims, destDims, kernels,
              strides, pads, countIncludePads});
  break;
}
2906 | |
case Kinded::Kind::SoftMaxInstKind: {
  // Softmax over the inner dimension. The quantized path replaces the
  // floating-point exponentials with a precomputed fixed-point lookup
  // table plus fixed-point rescaling parameters.
  auto *SM = cast<SoftMaxInst>(I);
  auto *dest = SM->getDest();
  auto *src = SM->getSrc();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *srcPtr = emitValueAddress(builder, src);
  auto *destDims = emitValueDims(builder, dest);
  auto *srcDims = emitValueDims(builder, src);
  auto *F = getFunction("softmax" , dest->getElementType());

  if (src->getType()->isQuantizedType()) {
    std::vector<int32_t> lut;

    // Compute lookup table containing all the exponentials based on the
    // formula e^(scale * value), where scale is the input scale of
    // the quantized input data and value is a value from [-255, 0].
    for (int32_t i = 0; i < 256; i++) {
      auto exponent =
          FixedPointUInt32(exp(src->getType()->getScale() * (i - 255)), 1)
              .getFixedVal();
      lut.push_back(exponent);
    }

    auto *lutPtr = emitConstI32Array(builder, lut);
    auto *outOffset = emitConstI32(builder, dest->getType()->getOffset());
    // 'size' is the inner-dimension length, i.e. the number of LUT values
    // summed per row; sumIntegerPart is the integer-bit budget for that sum.
    float size = static_cast<float>(src->getType()->dims()[1]);
    auto *sumIntegerPart = emitConstI32(builder, ceil(log2(size)));

    // When 'size' is an exact power of two, grant one extra integer bit
    // for the accumulated sum.
    if (ceil(log2(size)) == floor(log2(size))) {
      sumIntegerPart = emitConstI32(builder, ceil(log2(size)) + 1);
    }

    // Fixed-point representation of 1 / outputScale, plus its integer-bit
    // count so the kernel can align the final rescale.
    FixedPointUInt32 invScaleFixedPoint =
        FixedPointUInt32(1.f / dest->getType()->getScale());
    auto *invScale = emitConstI32(builder, invScaleFixedPoint.getFixedVal());
    auto *invScalePoint =
        emitConstI32(builder, invScaleFixedPoint.getIntBits());
    createCall(builder, F,
               {srcPtr, destPtr, srcDims, lutPtr, outOffset, invScale,
                sumIntegerPart, invScalePoint});
  } else {
    createCall(builder, F, {srcPtr, destPtr, srcDims, destDims});
  }

  break;
}
2953 | |
case Kinded::Kind::SoftMaxGradInstKind: {
  // Backprop for softmax, using the forward output and the selected
  // (label) indices.
  auto *SMG = cast<SoftMaxGradInst>(I);
  auto *srcGrad = SMG->getSrcGrad();
  auto *selected = SMG->getSelected();
  auto *srcGradPtr = emitValueAddress(builder, srcGrad);
  auto *destPtr = emitValueAddress(builder, SMG->getOrigDest());
  auto *selectedPtr = emitValueAddress(builder, selected);

  auto *srcGradDims = emitValueDims(builder, srcGrad);
  auto *selectedDims = emitValueDims(builder, selected);

  // Specialized on the gradient element type and the label index type.
  auto *F = getFunction("softmax_grad" , {srcGrad->getElementType(),
                                         selected->getElementType()});
  createCall(builder, F,
             {srcGradPtr, destPtr, selectedPtr, srcGradDims, selectedDims});
  break;
}
2971 | |
case Kinded::Kind::TopKInstKind: {
  // Select the K largest values (and their indices) along the innermost
  // dimension of the input.
  auto *TI = cast<TopKInst>(I);
  auto *input = TI->getInput();
  auto *valuesPtr = emitValueAddress(builder, TI->getValues());
  auto *indicesPtr = emitValueAddress(builder, TI->getIndices());
  auto *inputPtr = emitValueAddress(builder, input);
  // Pre-allocated scratch buffer used by the kernel internally.
  auto *scratchPtr = emitValueAddress(builder, TI->getScratch());

  auto *k = emitConstDimT(builder, TI->getK());
  // n = length of the innermost dimension; size = total element count.
  auto *n = emitConstDimT(builder, input->dims().back());
  auto *size = emitConstDimT(builder, input->size());

  auto indicesTy = TI->getIndices()->getElementType();
  auto *F = getFunction("topk" , {input->getElementType(), indicesTy});

  createCall(builder, F,
             {valuesPtr, indicesPtr, inputPtr, scratchPtr, k, n, size});
  break;
}
2991 | |
2992 | case Kinded::Kind::SpaceToDepthInstKind: { |
2993 | auto *SI = cast<SpaceToDepthInst>(I); |
2994 | auto *dest = SI->getDest(); |
2995 | auto *src = SI->getSrc(); |
2996 | |
2997 | auto *dstPtr = emitValueAddress(builder, dest); |
2998 | auto *srcPtr = emitValueAddress(builder, src); |
2999 | |
3000 | auto *dstDims = emitValueDims(builder, dest); |
3001 | auto *srcDims = emitValueDims(builder, src); |
3002 | |
3003 | unsigned blockSize = SI->getBlockSize(); |
3004 | |
3005 | auto *F = getFunction("space_to_depth" , src->getElementType()); |
3006 | createCall( |
3007 | builder, F, |
3008 | {srcPtr, dstPtr, emitConstDimT(builder, blockSize), srcDims, dstDims}); |
3009 | break; |
3010 | } |
3011 | |
case Kinded::Kind::TransposeInstKind: {
  // Permute the dimensions of 'src' into 'dest' according to the shuffle
  // mask.
  auto *TI = cast<TransposeInst>(I);
  auto *dest = TI->getDest();
  auto *src = TI->getSrc();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *srcPtr = emitValueAddress(builder, src);

  auto *destDims = emitValueDims(builder, dest);
  auto *srcDims = emitValueDims(builder, src);

  // Convert the mask to size_t type.
  ShapeVector shuffSizeT;
  for (auto D : TI->getShuffle()) {
    shuffSizeT.push_back((size_t)D);
  }

  auto *shuffle = emitConstDimTArray(builder, llvm::makeArrayRef(shuffSizeT));
  auto *len = emitConstDimT(builder, TI->getShuffle().size());

  auto *F = getFunction("transpose" , dest->getElementType());
  createCall(builder, F, {srcPtr, destPtr, srcDims, destDims, shuffle, len});
  break;
}
3035 | |
3036 | case Kinded::Kind::FlipInstKind: { |
3037 | auto *FI = cast<FlipInst>(I); |
3038 | auto *dest = FI->getDest(); |
3039 | auto *src = FI->getSrc(); |
3040 | auto *destPtr = emitValueAddress(builder, dest); |
3041 | auto *srcPtr = emitValueAddress(builder, src); |
3042 | auto *dims = emitValueDims(builder, src); |
3043 | auto *axis = emitConstDimT(builder, FI->getAxis()); |
3044 | auto *dimsSize = emitConstDimT(builder, src->getType()->dims().size()); |
3045 | auto *F = getFunction("flip" , src->getElementType()); |
3046 | createCall(builder, F, {srcPtr, destPtr, dims, axis, dimsSize}); |
3047 | break; |
3048 | } |
3049 | |
// Alloc and Dealloc instructions are handled by the memory allocator.
// No code is emitted for any of these three kinds here.
case Kinded::Kind::AllocActivationInstKind:
case Kinded::Kind::DeallocActivationInstKind:
case Kinded::Kind::TensorViewInstKind:
  break;
3055 | |
case Kinded::Kind::InsertTensorInstKind: {
  // Copy 'src' into a region of 'dest' starting at 'offsets', optionally
  // repeated 'count' times along 'axis'.
  auto *ITI = llvm::cast<InsertTensorInst>(I);
  auto *dest = ITI->getDest();
  auto *src = ITI->getSrc();
  auto offsets = ITI->getOffsets();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *srcPtr = emitValueAddress(builder, src);

  auto *destDims = emitValueDims(builder, dest);
  auto *srcDims = emitValueDims(builder, src);

  auto *destDimsSize = emitConstDimT(builder, dest->getType()->dims().size());
  auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size());
  auto *offsetsPtr = emitConstDimTArray(builder, offsets);
  auto *offsetsArraySize = emitConstDimT(builder, offsets.size());
  auto *count = emitConstDimT(builder, ITI->getCount());
  auto *axis = emitConstDimT(builder, ITI->getAxis());

  // Don't specialize the offsetPtr because we typically generate lots of
  // extracts from different offsets and specializing on this argument does
  // not speed things up.
  markArgAsUnspecialized(offsetsPtr);

  auto *F = getFunction("insert_tensor" , dest->getElementType());
  createCall(builder, F,
             {destPtr, srcPtr, offsetsPtr, destDims, srcDims, destDimsSize,
              srcDimsSize, offsetsArraySize, count, axis});
  break;
}
3085 | |
case Kinded::Kind::ExtractTensorInstKind: {
  // Copy the region of 'src' starting at 'offsets' (with dest's shape)
  // into 'dest'. Mirrors the InsertTensor case with src/dest roles swapped.
  auto *ITI = llvm::cast<ExtractTensorInst>(I);
  auto *dest = ITI->getDest();
  auto *src = ITI->getSrc();
  auto offsets = ITI->getOffsets();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *srcPtr = emitValueAddress(builder, src);

  auto *destDims = emitValueDims(builder, dest);
  auto *srcDims = emitValueDims(builder, src);

  auto *destDimsSize = emitConstDimT(builder, dest->getType()->dims().size());
  auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size());
  auto *offsetsPtr = emitConstDimTArray(builder, offsets);
  auto *offsetsArraySize = emitConstDimT(builder, offsets.size());

  // Don't specialize the offsetPtr because we typically generate lots of
  // extracts from different offsets and specializing on this argument does
  // not speed things up.
  markArgAsUnspecialized(offsetsPtr);

  auto *F = getFunction("extract_tensor" , dest->getElementType());
  createCall(builder, F,
             {srcPtr, destPtr, offsetsPtr, srcDims, destDims, srcDimsSize,
              destDimsSize, offsetsArraySize});
  break;
}
3113 | |
case Kinded::Kind::GatherInstKind: {
  // Gather slices of 'data' selected by 'indices' into 'dest'; 'axis'
  // (BatchDims) separates the batch dimensions from the gathered slice.
  auto *GI = llvm::cast<GatherInst>(I);
  auto *dest = GI->getDest();
  auto *data = GI->getData();
  auto *indices = GI->getIndices();
  unsigned axis = GI->getBatchDims();

  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *indicesPtr = emitValueAddress(builder, indices);

  auto *indicesSize = emitConstDimT(builder, indices->size());

  auto *dataType = data->getType();

  // The size of the sample in the batch.
  size_t sampleSize = dataType->getSliceSize(axis);
  // The size of the slices that we gather.
  size_t sliceSize = dataType->getSliceSize(axis + 1);
  // The size of each sample in the batch.
  size_t numSamples = dataType->size() / sampleSize;

  auto *sliceSizeVal = emitConstDimT(builder, sliceSize);
  auto *numSamplesVal = emitConstDimT(builder, numSamples);
  auto *sampleSizeVal = emitConstDimT(builder, sampleSize);

  // Dispatching function depending on the input type of Indices.
  llvm::Function *F = nullptr;
  if (indices->getElementType() == ElemKind::Int64ITy) {
    F = getFunction("gather64" , dest->getElementType());
  } else if (indices->getElementType() == ElemKind::Int32ITy) {
    F = getFunction("gather32" , dest->getElementType());
  }
  if (!F) {
    llvm_unreachable("Cannot get function for Gather. "
                     "Indices input of Gather has to be int32 or int64" );
  }
  createCall(builder, F,
             {destPtr, dataPtr, indicesPtr, indicesSize, sliceSizeVal,
              numSamplesVal, sampleSizeVal});
  break;
}
3156 | |
case Kinded::Kind::GatherNDInstKind: {
  // N-dimensional gather: each entry of the last dimension of 'indices'
  // is a multi-dimensional coordinate into 'data' (after the leading
  // batch dims). All counts/strides are precomputed here so the kernel
  // only does pointer arithmetic.
  auto *GI = llvm::cast<GatherNDInst>(I);
  auto *dest = GI->getDest();
  auto *data = GI->getData();
  auto *indices = GI->getIndices();
  unsigned batchDims = GI->getBatchDims();

  auto dataDims = data->dims();
  auto indicesDims = indices->dims();
  dim_t indicesDimLast = indicesDims.back();

  // Compute batch count.
  dim_t batchCount = 1;
  for (size_t idx = 0; idx < batchDims; ++idx) {
    batchCount *= dataDims[idx];
  }

  // Compute input slice count.
  dim_t inpSliceCount = 1;
  for (size_t idx = batchDims; idx < batchDims + indicesDimLast; ++idx) {
    inpSliceCount *= dataDims[idx];
  }

  // Compute output slice count.
  dim_t outSliceCount = 1;
  for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) {
    outSliceCount *= indicesDims[idx];
  }

  // Compute slice size (in bytes).
  dim_t sliceSize = data->getType()->getElementSize();
  for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size();
       idx++) {
    sliceSize *= dataDims[idx];
  }

  // Get indices dimension products.
  // indicesDimProd[i] is the stride (in slices) contributed by coordinate
  // i of an index tuple; built innermost-out like a row-major stride table.
  std::vector<dim_t> indicesDimProd(indicesDimLast);
  indicesDimProd[indicesDimLast - 1] = 1;
  for (ssize_t idx = static_cast<ssize_t>(indicesDimLast) - 2; idx >= 0;
       idx--) {
    indicesDimProd[idx] =
        indicesDimProd[idx + 1] * dataDims[batchDims + idx + 1];
  }

  // Emit pointers.
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *indicesPtr = emitValueAddress(builder, indices);

  // Emit parameters.
  auto *batchCountArg = emitConstDimT(builder, batchCount);
  auto *inpSliceCountArg = emitConstDimT(builder, inpSliceCount);
  auto *outSliceCountArg = emitConstDimT(builder, outSliceCount);
  auto *sliceSizeArg = emitConstDimT(builder, sliceSize);
  auto *indicesDimLastArg = emitConstDimT(builder, indicesDimLast);
  auto *indicesDimProdArg =
      emitConstDimTArray(builder, llvm::makeArrayRef(indicesDimProd));

  // Specialized on the data element type and the index type.
  llvm::Function *F = getFunction(
      "gather_nd" , {data->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destPtr, dataPtr, indicesPtr, batchCountArg, inpSliceCountArg,
              outSliceCountArg, sliceSizeArg, indicesDimLastArg,
              indicesDimProdArg});
  break;
}
3224 | |
case Kinded::Kind::GatherRangesInstKind: {
  // Gather (start, length) ranges of 'data' into 'output', recording the
  // total gathered length per example into 'lengths'.
  auto *GRI = llvm::cast<GatherRangesInst>(I);
  auto *output = GRI->getOutput();
  auto *lengths = GRI->getLengths();
  auto *data = GRI->getData();
  auto *ranges = GRI->getRanges();

  auto *outputPtr = emitValueAddress(builder, output);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *rangesPtr = emitValueAddress(builder, ranges);

  auto rangesType = ranges->getType();

  // The number of examples in ranges.
  size_t numExamples = rangesType->dims()[0];
  // The number of range pairs in each example.
  size_t exampleSize = rangesType->dims()[1];

  auto *numExamplesVal = emitConstDimT(builder, numExamples);
  auto *exampleSizeVal = emitConstDimT(builder, exampleSize);

  // Dispatching function depending on the input type of Ranges.
  llvm::Function *F = nullptr;
  if (ranges->getElementType() == ElemKind::Int64ITy) {
    F = getFunction("gatherranges64" , output->getElementType());
  } else if (ranges->getElementType() == ElemKind::Int32ITy) {
    F = getFunction("gatherranges32" , output->getElementType());
  }
  if (!F) {
    llvm_unreachable("Cannot get function for GatherRanges. "
                     "Ranges input of GatherRanges has to be int32 or int64" );
  }
  createCall(builder, F,
             {outputPtr, lengthsPtr, dataPtr, rangesPtr, numExamplesVal,
              exampleSizeVal});
  break;
}
3263 | |
case Kinded::Kind::LengthsRangeFillInstKind: {
  // For each entry L in 'lengths', append the sequence 0..L-1 to 'dest'.
  auto *LRFI = llvm::cast<LengthsRangeFillInst>(I);
  auto *dest = LRFI->getDest();
  auto *lengths = LRFI->getLengths();

  auto *destPtr = emitValueAddress(builder, dest);
  auto *lengthsPtr = emitValueAddress(builder, lengths);

  auto *lengthsSize = emitConstDimT(builder, lengths->size());

  // Specialized on the destination element type only.
  auto *F = getFunction("lengths_range_fill" , dest->getElementType());
  createCall(builder, F, {lengthsPtr, destPtr, lengthsSize});
  break;
}
3279 | |
case Kinded::Kind::ScatterDataInstKind: {
  // Write 'slices' into 'data' at the rows addressed by 'indices'.
  // The 'cumulative' flag is forwarded to the kernel to select its update
  // mode.
  auto *SDI = llvm::cast<ScatterDataInst>(I);
  auto *data = SDI->getData();
  auto *indices = SDI->getIndices();
  auto *slices = SDI->getSlices();

  auto *dataPtr = emitValueAddress(builder, data);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *slicesPtr = emitValueAddress(builder, slices);
  auto *dataDims = emitValueDims(builder, data);

  // indices is 2-D: [number of updates, coordinate length per update].
  auto *indicesCnt = emitConstDimT(builder, indices->getType()->dims()[0]);
  auto *indicesSize = emitConstDimT(builder, indices->getType()->dims()[1]);
  auto *slicesType = slices->getType();
  auto *sliceSize =
      emitConstDimT(builder, slicesType->size() / slicesType->dims()[0]);
  auto *isCumulative = emitConstI1(builder, SDI->getCumulative());
  auto *F = getFunction("scatterdata" ,
                        {data->getElementType(), indices->getElementType()});
  if (data->getType()->isQuantizedType()) {
    // Quantized path: data and slices may have different quantization
    // parameters, so pass both scale/offset pairs to the kernel.
    auto *dataScale = emitConstF32(builder, data->getType()->getScale());
    auto *dataOffset = emitConstI32(builder, data->getType()->getOffset());
    auto *sliceScale = emitConstF32(builder, slices->getType()->getScale());
    auto *sliceOffset = emitConstI32(builder, slices->getType()->getOffset());
    createCall(builder, F,
               {dataPtr, dataDims, indicesPtr, slicesPtr, indicesCnt,
                indicesSize, sliceSize, isCumulative, dataScale, dataOffset,
                sliceScale, sliceOffset});
  } else {
    createCall(builder, F,
               {dataPtr, dataDims, indicesPtr, slicesPtr, indicesCnt,
                indicesSize, sliceSize, isCumulative});
  }
  break;
}
3315 | |
case Kinded::Kind::SparseLengthsSumInstKind: {
  // For each segment (given by 'lengths'), sum the rows of 'data'
  // addressed by 'indices' into one output row.
  auto *SI = cast<SparseLengthsSumInst>(I);
  auto *dest = SI->getDest();
  auto *data = SI->getData();
  auto *indices = SI->getIndices();
  auto *lengths = SI->getLengths();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  // One segment per lengths entry; lineSize is the width of one data row.
  auto *segments = emitConstDimT(builder, lengths->dims()[0]);
  auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *F = getFunction("sparse_lengths_sum" ,
                        {dest->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destPtr, dataPtr, indicesPtr, lengthsPtr, segments, lineSize});
  break;
}
3334 | |
case Kinded::Kind::SparseLengthsWeightedSumInstKind: {
  // Like SparseLengthsSum, but each gathered row is scaled by the
  // corresponding entry of 'weights' before being accumulated.
  auto *SI = cast<SparseLengthsWeightedSumInst>(I);
  auto *dest = SI->getDest();
  auto *data = SI->getData();
  auto *weights = SI->getWeights();
  auto *indices = SI->getIndices();
  auto *lengths = SI->getLengths();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  auto *segments = emitConstDimT(builder, lengths->dims()[0]);
  auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *F = getFunction("sparse_lengths_weighted_sum" ,
                        {dest->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destPtr, dataPtr, weightsPtr, indicesPtr, lengthsPtr, segments,
              lineSize});
  break;
}
3356 | |
case Kinded::Kind::EmbeddingInstKind: {
  // Look up rows of the 2-D 'weights' table by 'indices'.
  // padIdx/scale/sparse options are forwarded to the kernel.
  auto *SI = cast<EmbeddingInst>(I);
  auto *dest = SI->getDest();
  auto *weights = SI->getWeights();
  auto *indices = SI->getIndices();
  auto *padIdx = emitConstSizeT(builder, SI->getPadIdx());
  auto *scale = emitConstI1(builder, SI->getScale());
  auto *sparse = emitConstI1(builder, SI->getSparse());
  auto *destPtr = emitValueAddress(builder, dest);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *indDims = emitValueDims(builder, indices);
  auto *indSize = emitConstDimT(builder, indices->dims().size());
  assert(weights->dims().size() == 2 && "weights must be 2-D" );
  // weights is [numEmbedding, embeddingDim].
  auto *numEmbedding = emitConstDimT(builder, weights->dims()[0]);
  auto *embeddingDim = emitConstDimT(builder, weights->dims()[1]);
  auto *F = getFunction("embedding" , dest->getElementType());
  createCall(builder, F,
             {destPtr, weightsPtr, indicesPtr, indDims, indSize, numEmbedding,
              embeddingDim, padIdx, scale, sparse});
  break;
}
3379 | |
case Kinded::Kind::EmbeddingBagInstKind: {
  // Weighted sum of embedding rows, with segments described by 'offsets'
  // into 'indices' (instead of per-segment lengths).
  auto *SI = cast<EmbeddingBagInst>(I);
  auto *dest = SI->getDest();
  auto *data = SI->getData();
  auto *weights = SI->getWeights();
  auto *indices = SI->getIndices();
  auto *offsets = SI->getOffsets();
  // Whether 'offsets' carries a trailing end-offset entry.
  auto *hasEndOffset = emitConstI1(builder, SI->getHasEndOffset());
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *offsetsPtr = emitValueAddress(builder, offsets);
  auto *segments = emitConstDimT(builder, offsets->dims()[0]);
  auto *totalLength = emitConstDimT(builder, indices->dims()[0]);
  auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *F = getFunction("embedding_bag" , dest->getElementType());
  createCall(builder, F,
             {destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
              lineSize, totalLength, hasEndOffset});
  break;
}
3402 | |
case Kinded::Kind::SparseLengthsWeightedSumGradInstKind: {
  // Backprop for SparseLengthsWeightedSum: produces gradients for both
  // 'data' and 'weights'.
  auto *SI = cast<SparseLengthsWeightedSumGradInst>(I);
  auto *destGrad = SI->getDestGrad();
  auto *dataGrad = SI->getDataGrad();
  auto *weightsGrad = SI->getWeightsGrad();
  auto *data = SI->getData();
  auto *weights = SI->getWeights();
  auto *indices = SI->getIndices();
  auto *lengths = SI->getLengths();
  auto *destGradPtr = emitValueAddress(builder, destGrad);
  auto *dataGradPtr = emitValueAddress(builder, dataGrad);
  auto *weightsGradPtr = emitValueAddress(builder, weightsGrad);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  auto *segments = emitConstDimT(builder, lengths->dims()[0]);
  // Byte size of the data gradient buffer (float elements), used by the
  // kernel to zero-initialize it.
  auto *dataGradRawSize =
      emitConstDimT(builder, dataGrad->size() * sizeof(float));
  auto *lineSize =
      emitConstDimT(builder, dataGrad->size() / dataGrad->dims()[0]);
  auto *F =
      getFunction("sparse_lengths_weighted_sum_grad" ,
                  {destGrad->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destGradPtr, dataGradPtr, weightsGradPtr, dataPtr, weightsPtr,
              indicesPtr, lengthsPtr, segments, lineSize, dataGradRawSize});
  break;
}
3432 | |
case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumInstKind: {
  // SparseLengthsWeightedSum over row-wise quantized 'data': each row has
  // its own scale/offset stored in the separate 'scales'/'offsets' tensors.
  auto *N = cast<RowwiseQuantizedSparseLengthsWeightedSumInst>(I);
  auto *dest = N->getDest();
  auto *data = N->getData();
  auto *scales = N->getScales();
  auto *offsets = N->getOffsets();
  auto *weights = N->getWeights();
  auto *indices = N->getIndices();
  auto *lengths = N->getLengths();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *scalesPtr = emitValueAddress(builder, scales);
  auto *offsetsPtr = emitValueAddress(builder, offsets);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  auto *segments = emitConstDimT(builder, lengths->dims()[0]);
  auto *lineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *F = getFunction("rowwise_quantized_sparse_lengths_weighted_sum" ,
                        {dest->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destPtr, dataPtr, scalesPtr, offsetsPtr, weightsPtr, indicesPtr,
              lengthsPtr, segments, lineSize});
  break;
}
3458 | |
case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumInstKind: {
  // As the row-wise quantized variant, but the per-row scale/offset are
  // fused into each data row, so input and output row widths differ.
  auto *N = cast<FusedRowwiseQuantizedSparseLengthsWeightedSumInst>(I);
  auto *dest = N->getDest();
  auto *data = N->getData();
  auto *weights = N->getWeights();
  auto *indices = N->getIndices();
  auto *lengths = N->getLengths();
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *lengthsPtr = emitValueAddress(builder, lengths);
  auto *segments = emitConstDimT(builder, lengths->dims()[0]);
  // Input rows include the fused quantization params; output rows do not.
  auto *inLineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *outLineSize = emitConstDimT(builder, dest->size() / dest->dims()[0]);
  auto *F = getFunction("fused_rowwise_quantized_sparse_lengths_weighted_sum" ,
                        {dest->getElementType(), indices->getElementType()});
  createCall(builder, F,
             {destPtr, dataPtr, weightsPtr, indicesPtr, lengthsPtr, segments,
              inLineSize, outLineSize});
  break;
}
3481 | |
case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsInstKind: {
  // EmbeddingBag over fused byte row-wise quantized data, with segments
  // described by 'offsets' into 'indices'.
  auto *N = cast<EmbeddingBagByteRowwiseOffsetsInst>(I);
  auto *dest = N->getDest();
  auto *data = N->getData();
  auto *weights = N->getWeights();
  auto *indices = N->getIndices();
  auto *offsets = N->getOffsets();
  // Whether 'offsets' carries a trailing end-offset entry.
  auto *hasEndOffset = emitConstI1(builder, N->getHasEndOffset());
  auto *destPtr = emitValueAddress(builder, dest);
  auto *dataPtr = emitValueAddress(builder, data);
  auto *weightsPtr = emitValueAddress(builder, weights);
  auto *indicesPtr = emitValueAddress(builder, indices);
  auto *offsetsPtr = emitValueAddress(builder, offsets);
  auto *segments = emitConstDimT(builder, offsets->dims()[0]);
  auto *numIndices = emitConstDimT(builder, indices->dims()[0]);
  // Input rows include the fused quantization params; output rows do not.
  auto *inLineSize = emitConstDimT(builder, data->size() / data->dims()[0]);
  auto *outLineSize = emitConstDimT(builder, dest->size() / dest->dims()[0]);
  auto *F = getFunction("embedding_bag_byte_rowwise_offsets" ,
                        dest->getElementType());
  createCall(builder, F,
             {destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
              numIndices, inLineSize, outLineSize, hasEndOffset});
  break;
}
3506 | |
3507 | case Kinded::Kind::DebugPrintInstKind: { |
3508 | auto *DPI = llvm::cast<DebugPrintInst>(I); |
3509 | auto *src = DPI->getSrc(); |
3510 | auto *srcPtr = emitValueAddress(builder, src); |
3511 | srcPtr = builder.CreateBitCast(srcPtr, builder.getInt8PtrTy()); |
3512 | auto *srcDims = emitValueDims(builder, src); |
3513 | auto *srcDimsSize = emitConstDimT(builder, src->getType()->dims().size()); |
3514 | auto *srcSize = emitConstSizeT(builder, src->getType()->size()); |
3515 | auto *srcSizeBytes = |
3516 | emitConstSizeT(builder, src->getType()->getSizeInBytes()); |
3517 | auto *srcElemKind = |
3518 | emitConstDimT(builder, static_cast<size_t>(src->getElementType())); |
3519 | auto *name = emitStringConst(builder, I->getName()); |
3520 | auto *filename = emitStringConst(builder, DPI->getFileName()); |
3521 | auto srcTypeStr = src->getType()->toString(); |
3522 | |
3523 | std::string format = DPI->getFormat(); |
3524 | if (format == "console" ) { |
3525 | // Dump tensor in console. |
3526 | auto *F = getFunction("dump_tensor_console" ); |
3527 | createCall(builder, F, {srcPtr, srcDims, srcDimsSize, srcElemKind, name}); |
3528 | |
3529 | } else if (format == "bin" ) { |
3530 | // Dump tensor in file in binary format. |
3531 | auto *F = getFunction("dump_tensor_bin" ); |
3532 | auto * = emitStringConst(builder, srcTypeStr); |
3533 | createCall(builder, F, {srcPtr, srcSizeBytes, filename, header}); |
3534 | |
3535 | } else if (format == "txt" ) { |
3536 | // Dump tensor in file in text format. |
3537 | auto *F = getFunction("dump_tensor_txt" , src->getElementType()); |
3538 | auto * = emitStringConst(builder, srcTypeStr); |
3539 | createCall(builder, F, {srcPtr, srcSize, filename, header}); |
3540 | |
3541 | } else if (format == "rawbin" ) { |
3542 | // Dump tensor in file in raw binary format. |
3543 | auto *F = getFunction("dump_tensor_bin" ); |
3544 | auto * = emitStringConst(builder, "" ); |
3545 | createCall(builder, F, {srcPtr, srcSizeBytes, filename, header}); |
3546 | |
3547 | } else if (format == "rawtxt" ) { |
3548 | // Dump tensor in file in raw text format. |
3549 | auto *F = getFunction("dump_tensor_txt" , src->getElementType()); |
3550 | auto * = emitStringConst(builder, "" ); |
3551 | createCall(builder, F, {srcPtr, srcSize, filename, header}); |
3552 | |
3553 | } else { |
3554 | LOG(FATAL) << "Invalid 'Format' attribute for DebugPrint instruction!" ; |
3555 | } |
3556 | break; |
3557 | } |
3558 | |
  case Kinded::Kind::InstrumentInstKind: {
    // Emits a call to an externally-defined instrumentation callback, either
    // before or after the instruction being instrumented. Operand addresses
    // and sizes are communicated through the opInfo scratch buffer.
    auto *instrumentI = llvm::cast<InstrumentInst>(I);
    auto *opInfo = instrumentI->getOperandsInfo();

    // Instruction being instrumented.
    Instruction *instrRef = instrumentI->getInstrRef();

    // Emit instruction ID and instruction kind.
    llvm::Type *intTy =
        llvm::Type::getIntNTy(getLLVMContext(), getLibjitIntWidth());
    auto *ID = llvm::ConstantInt::get(intTy, instrumentI->getID());
    auto *kind = llvm::ConstantInt::get(intTy, (int)(instrRef->getKind()));

    // Emit number of input and output operands.
    auto inpNum = instrRef->getNumInputs();
    auto outNum = instrRef->getNumOutputs();
    auto opNum = inpNum + outNum;
    auto *opInp = llvm::ConstantInt::get(intTy, inpNum);
    auto *opOut = llvm::ConstantInt::get(intTy, outNum);

    // Emit opInfo address as uint8_t*.
    assert(opInfo->getType()->getSizeInBytes() >= 2 * sizeof(int64_t) &&
           "Not enough memory allocated for instrumentation!");
    auto *opInfoPtr = emitValueAddress(builder, opInfo);
    opInfoPtr = builder.CreateBitCast(opInfoPtr, builder.getInt8PtrTy());

    // Emit opAddr address as uint8_t** starting from offset 0.
    auto *opAddrPtr =
        builder.CreateGEP(opInfoPtr, llvm::ConstantInt::get(intTy, 0));
    opAddrPtr = builder.CreateBitCast(opAddrPtr,
                                      builder.getInt8PtrTy()->getPointerTo());

    // Emit opSize address as int* starting from offset opNum * sizeof(int64_t).
    auto *opSizePtr = builder.CreateGEP(
        opInfoPtr, llvm::ConstantInt::get(intTy, opNum * sizeof(int64_t)));
    opSizePtr = builder.CreateBitCast(opSizePtr, intTy->getPointerTo());

    // Generate instrumentation.
    auto instrumentKind = instrumentI->getInstrumentKind();
    if (instrumentKind == InstrumentKind::Before) {

      // Operands addresses and sizes.
      std::vector<llvm::Value *> opAddrArray;
      std::vector<llvm::Value *> opSizeArray;

      // Get addresses and sizes for the input operands.
      // Inputs are listed first, then outputs, so the two passes below
      // produce the ordering documented in the callback API.
      for (const auto &op : instrRef->getOperands()) {
        if (op.second == OperandKind::Out) {
          continue;
        }
        // Emit operand address as uint8_t* variable.
        auto *opAddr = emitValueAddress(builder, op.first);
        opAddr = builder.CreateBitCast(opAddr, builder.getInt8PtrTy());
        opAddrArray.push_back(opAddr);
        // Emit operand size in bytes as int constant.
        auto *opSize = llvm::ConstantInt::get(
            intTy, op.first->getType()->getSizeInBytes());
        opSizeArray.push_back(opSize);
      }
      assert(opAddrArray.size() == inpNum && "Inconsistent size!");

      // Get addresses and sizes for the output operands.
      for (const auto &op : instrRef->getOperands()) {
        if (op.second == OperandKind::In) {
          continue;
        }
        // Emit operand address as uint8_t* variable.
        auto *opAddr = emitValueAddress(builder, op.first);
        opAddr = builder.CreateBitCast(opAddr, builder.getInt8PtrTy());
        opAddrArray.push_back(opAddr);
        // Emit operand size in bytes as int constant.
        auto *opSize = llvm::ConstantInt::get(
            intTy, op.first->getType()->getSizeInBytes());
        opSizeArray.push_back(opSize);
      }
      assert(opAddrArray.size() == opNum && "Inconsistent size!");

      // Write the addresses of the operands in the opAddr.
      emitArrayStore(builder, opAddrArray, opAddrPtr);

      // Write the sizes of the operands in opSize.
      emitArrayStore(builder, opSizeArray, opSizePtr);

      // Create callback call.
      auto *F = getFunction("instrument_before");
      createCall(builder, F, {ID, kind, opInp, opOut, opAddrPtr, opSizePtr});

    } else if (instrumentKind == InstrumentKind::After) {

      // Create callback call.
      // NOTE(review): the After path does not repopulate opAddrPtr/opSizePtr;
      // presumably they were filled by the matching Before instrument sharing
      // the same opInfo buffer — confirm against the instrumentation pass.
      auto *F = getFunction("instrument_after");
      createCall(builder, F, {ID, kind, opInp, opOut, opAddrPtr, opSizePtr});

    } else {
      llvm_unreachable("Instrumentation kind not supported!");
    }
    // Print the IR instrumentation callback API.
    printInstrumentIR_ = true;
    break;
  }
3659 | |
3660 | case Kinded::Kind::TraceEventInstKind: { |
3661 | auto *TEI = llvm::cast<TraceEventInst>(I); |
3662 | auto *data = TEI->getData(); |
3663 | auto *offset = emitConstDimT(builder, TEI->getIndex()); |
3664 | auto *dataPtr = emitValueAddress(builder, data); |
3665 | auto *F = getFunction("write_timestamp" ); |
3666 | createCall(builder, F, {dataPtr, offset}); |
3667 | break; |
3668 | } |
3669 | |
3670 | case Kinded::Kind::ResizeNearestInstKind: { |
3671 | auto *RNI = llvm::cast<ResizeNearestInst>(I); |
3672 | auto *result = RNI->getDest(); |
3673 | auto *input = RNI->getSrc(); |
3674 | auto *resultPtr = emitValueAddress(builder, result); |
3675 | auto *inputPtr = emitValueAddress(builder, input); |
3676 | |
3677 | auto *scalePtr = emitConstFloatArray(builder, RNI->getScale()); |
3678 | auto *destDims = emitValueDims(builder, result); |
3679 | auto *srcDims = emitValueDims(builder, input); |
3680 | auto *F = getFunction("resizenearest" , input->getElementType()); |
3681 | createCall(builder, F, {resultPtr, inputPtr, scalePtr, srcDims, destDims}); |
3682 | break; |
3683 | } |
3684 | |
3685 | case Kinded::Kind::ResizeBilinearInstKind: { |
3686 | auto *RBI = llvm::cast<ResizeBilinearInst>(I); |
3687 | auto *result = RBI->getDest(); |
3688 | auto *input = RBI->getSrc(); |
3689 | auto *resultPtr = emitValueAddress(builder, result); |
3690 | auto *inputPtr = emitValueAddress(builder, input); |
3691 | |
3692 | CHECK_EQ(RBI->getScale()[0], 1.0) << "Scaling batch not supported." ; |
3693 | CHECK_EQ(RBI->getScale()[3], 1.0) << "Scaling channel not supported." ; |
3694 | |
3695 | auto *scalePtr = emitConstFloatArray(builder, RBI->getScale()); |
3696 | auto *destDims = emitValueDims(builder, result); |
3697 | auto *srcDims = emitValueDims(builder, input); |
3698 | auto *F = getFunction("resizebilinear" , input->getElementType()); |
3699 | createCall(builder, F, {resultPtr, inputPtr, scalePtr, srcDims, destDims}); |
3700 | break; |
3701 | } |
3702 | |
  case Kinded::Kind::NonMaxSuppressionInstKind: {
    // Lowers the non-max-suppression instruction to the libjit "nms" kernel,
    // specialized on the element type of the indices output.
    auto *NMSI = llvm::cast<NonMaxSuppressionInst>(I);
    auto boxes = NMSI->getBoxes();
    auto scores = NMSI->getScores();
    auto indices = NMSI->getIndices();
    auto numDetected = NMSI->getNumberOfSelectedIndices();
    float iouThreshold = NMSI->getIouThreshold();
    int64_t maxBoxesPerClass = NMSI->getMaxOutputBoxesPerClass();
    float scoreThreshold = NMSI->getScoreThreshold();
    int centerPointBox = NMSI->getCenterPointBox();
    bool isV4 = NMSI->getIsTFVersion4();

    // Emit the operand addresses.
    auto *boxesPtr = emitValueAddress(builder, boxes);
    auto *scoresPtr = emitValueAddress(builder, scores);
    auto *indicesPtr = emitValueAddress(builder, indices);
    auto *numDetectedPtr = emitValueAddress(builder, numDetected);

    // Emit the scalar parameters.
    auto *maxBoxesPerClassVal = emitConstI32(builder, maxBoxesPerClass);
    auto *centerPointBoxVal = emitConstI32(builder, centerPointBox);
    auto *iouThresholdVal = emitConstF32(builder, iouThreshold);
    auto *scoreThresholdVal = emitConstF32(builder, scoreThreshold);

    // Emit the dimensions and ranks of each tensor operand.
    auto *boxesDimVal = emitValueDims(builder, boxes);
    auto *scoreDimVal = emitValueDims(builder, scores);
    auto *indicesDimVal = emitValueDims(builder, indices);
    auto *boxesDimSizeVal = emitConstDimT(builder, boxes->dims().size());
    auto *scoresDimSizeVal = emitConstDimT(builder, scores->dims().size());
    auto *indicesDimSizeVal = emitConstDimT(builder, indices->dims().size());
    auto *isV4Val = emitConstI1(builder, isV4);

    auto *F = getFunction("nms", indices->getElementType());
    createCall(builder, F,
               {indicesPtr, numDetectedPtr, boxesPtr, boxesDimVal,
                boxesDimSizeVal, scoresPtr, scoreDimVal, scoresDimSizeVal,
                indicesDimVal, indicesDimSizeVal, centerPointBoxVal,
                maxBoxesPerClassVal, iouThresholdVal, scoreThresholdVal,
                isV4Val});
    break;
  }
3742 | |
  case Kinded::Kind::TFLiteDetectionPostProcessInstKind: {
    // Lowers the TFLite detection post-process instruction to the libjit
    // "tflite_detection_post_process_f" kernel.
    auto *DPPI = llvm::cast<TFLiteDetectionPostProcessInst>(I);
    auto boxes = DPPI->getBoxes();
    auto scores = DPPI->getScores();
    auto anchors = DPPI->getAnchors();
    auto detectionBoxes = DPPI->getDetectionBoxes();
    auto detectionClasses = DPPI->getDetectionClasses();
    auto detectionScores = DPPI->getDetectionScores();
    auto numDetections = DPPI->getNumDetections();
    auto scratch = DPPI->getScratch();

    // Emit pointers.
    auto *boxesPtr = emitValueAddress(builder, boxes);
    auto *scoresPtr = emitValueAddress(builder, scores);
    auto *anchorsPtr = emitValueAddress(builder, anchors);
    auto *detectionBoxesPtr = emitValueAddress(builder, detectionBoxes);
    auto *detectionClassesPtr = emitValueAddress(builder, detectionClasses);
    auto *detectionScoresPtr = emitValueAddress(builder, detectionScores);
    auto *numDetectionsPtr = emitValueAddress(builder, numDetections);
    auto *scratchPtr = emitValueAddress(builder, scratch);

    // Emit parameters. Box and class counts are taken from the operand
    // shapes; the scale factors are passed inverted (1/scale) so the kernel
    // can multiply instead of divide.
    auto *numBoxes = emitConstI32(builder, boxes->dims()[1]);
    auto *numTotalClasses = emitConstI32(builder, scores->dims()[2]);
    auto *numClasses = emitConstI32(builder, DPPI->getNumClasses());
    auto *maxDetections = emitConstI32(builder, DPPI->getMaxDetections());
    auto *maxClassesPerDetection =
        emitConstI32(builder, DPPI->getMaxClassesPerDetection());
    auto *maxDetectionsPerClass =
        emitConstI32(builder, DPPI->getMaxDetectionsPerClass());
    auto *iouThreshold = emitConstF32(builder, DPPI->getIouThreshold());
    auto *scoreThreshold = emitConstF32(builder, DPPI->getScoreThreshold());
    auto *xScaleInv = emitConstF32(builder, 1.0f / DPPI->getXScale());
    auto *yScaleInv = emitConstF32(builder, 1.0f / DPPI->getYScale());
    auto *hScaleInv = emitConstF32(builder, 1.0f / DPPI->getHScale());
    auto *wScaleInv = emitConstF32(builder, 1.0f / DPPI->getWScale());
    auto *regularNMS = emitConstI1(builder, DPPI->getRegularNMS());

    // Current implementation only supports batch size 1.
    assert(boxes->dims()[0] == 1 &&
           "TFLiteDetectionPostProcess batch not supported!");

    // Call function.
    auto *F = getFunction("tflite_detection_post_process_f");
    createCall(builder, F,
               {boxesPtr,
                scoresPtr,
                anchorsPtr,
                detectionBoxesPtr,
                detectionClassesPtr,
                detectionScoresPtr,
                numDetectionsPtr,
                scratchPtr,
                numBoxes,
                numTotalClasses,
                numClasses,
                maxDetections,
                maxClassesPerDetection,
                maxDetectionsPerClass,
                iouThreshold,
                scoreThreshold,
                xScaleInv,
                yScaleInv,
                hScaleInv,
                wScaleInv,
                regularNMS});
    break;
  }
3811 | |
  case Kinded::Kind::AudioSpectrogramInstKind: {
    // Lowers the audio spectrogram instruction to the libjit
    // "audio_spectrogram" kernel. The two scratch operands are temporary
    // buffers handed to the kernel (windowing and FFT outputs, per their
    // names).
    auto *ASI = llvm::cast<AudioSpectrogramInst>(I);
    auto winOutScratch = ASI->getWinOutScratch();
    auto fftOutScratch = ASI->getFftOutScratch();
    auto spectrogram = ASI->getSpectrogram();
    auto input = ASI->getInput();
    auto window = ASI->getWindow();
    auto twiddleFactors = ASI->getTwiddleFactors();
    auto bitReverseIndices = ASI->getBitReverseIndices();
    auto complexToRealWeights = ASI->getComplexToRealWeights();
    int64_t windowSize = ASI->getWindowSize();
    int64_t windowStride = ASI->getWindowStride();
    bool magnitudeSquared = ASI->getMagnitudeSquared();

    // Emit the operand addresses.
    auto *winOutScratchPtr = emitValueAddress(builder, winOutScratch);
    auto *fftOutScratchPtr = emitValueAddress(builder, fftOutScratch);
    auto *spectrogramPtr = emitValueAddress(builder, spectrogram);
    auto *inputPtr = emitValueAddress(builder, input);
    auto *windowPtr = emitValueAddress(builder, window);
    auto *twiddleFactorsPtr = emitValueAddress(builder, twiddleFactors);
    auto *bitReverseIndicesPtr = emitValueAddress(builder, bitReverseIndices);
    auto *complexToRealWeightsPtr =
        emitValueAddress(builder, complexToRealWeights);
    // Emit the output shape and the scalar windowing parameters.
    auto *spectrogramDimVal = emitValueDims(builder, spectrogram);
    auto *inputLengthVal = emitConstDimT(builder, input->size());
    auto *windowSizeVal = emitConstDimT(builder, windowSize);
    auto *windowStrideVal = emitConstDimT(builder, windowStride);
    auto *magnitudeSquaredVal = emitConstI1(builder, magnitudeSquared);

    auto *F = getFunction("audio_spectrogram", spectrogram->getElementType());
    createCall(builder, F,
               {winOutScratchPtr, fftOutScratchPtr, spectrogramPtr, inputPtr,
                windowPtr, twiddleFactorsPtr, bitReverseIndicesPtr,
                complexToRealWeightsPtr, spectrogramDimVal, inputLengthVal,
                windowSizeVal, windowStrideVal, magnitudeSquaredVal});
    break;
  }
3849 | |
  case Kinded::Kind::MFCCInstKind: {
    // Lowers the MFCC (mel-frequency cepstral coefficients) instruction to
    // the libjit "mfcc" kernel, specialized on the coefficients element type.
    auto *MFCCI = llvm::cast<MFCCInst>(I);
    auto scratch = MFCCI->getScratch();
    auto coefficients = MFCCI->getCoefficients();
    auto spectrogram = MFCCI->getSpectrogram();
    auto melWeights = MFCCI->getMelWeights();
    auto melRanges = MFCCI->getMelRanges();
    auto dctMat = MFCCI->getDctMat();
    int64_t filterBankCount = MFCCI->getFilterBankCount();

    // Emit the operand addresses.
    auto *scratchPtr = emitValueAddress(builder, scratch);
    auto *coefficientsPtr = emitValueAddress(builder, coefficients);
    auto *spectrogramPtr = emitValueAddress(builder, spectrogram);
    auto *melWeightsPtr = emitValueAddress(builder, melWeights);
    auto *melRangesPtr = emitValueAddress(builder, melRanges);
    auto *dctMatPtr = emitValueAddress(builder, dctMat);
    // Emit the shapes of the output and input tensors and the filter bank
    // count parameter.
    auto *coefficientsDimVal = emitValueDims(builder, coefficients);
    auto *spectrogramDimVal = emitValueDims(builder, spectrogram);
    auto *filterBankCountVal = emitConstDimT(builder, filterBankCount);

    auto *F = getFunction("mfcc", coefficients->getElementType());
    createCall(builder, F,
               {scratchPtr, coefficientsPtr, spectrogramPtr, melWeightsPtr,
                melRangesPtr, dctMatPtr, coefficientsDimVal, spectrogramDimVal,
                filterBankCountVal});
    break;
  }
3877 | |
3878 | case Kinded::Kind::ConvertToInstKind: { |
3879 | auto *CTI = llvm::cast<ConvertToInst>(I); |
3880 | auto *input = CTI->getInput(); |
3881 | auto *output = CTI->getResult(); |
3882 | |
3883 | auto *inputVal = emitValueAddress(builder, input); |
3884 | auto *outptVal = emitValueAddress(builder, output); |
3885 | auto *dimsVal = emitValueDims(builder, output); |
3886 | auto *dimSizeVal = emitConstDimT(builder, output->dims().size()); |
3887 | |
3888 | auto *F = getFunction("convertTo" , |
3889 | {output->getElementType(), input->getElementType()}); |
3890 | |
3891 | createCall(builder, F, {outptVal, inputVal, dimsVal, dimSizeVal}); |
3892 | break; |
3893 | } |
3894 | |
  default:
    // Unsupported instruction kind: dump the textual form of the instruction
    // and abort, since silently skipping it would emit incorrect code.
    std::string sBuf;
    llvm::raw_string_ostream s(sBuf);
    I->dump(s);
    LOG(FATAL) << "Cannot select the instruction: " << s.str();
3900 | } |
3901 | } |
3902 | |
/// \returns the width (in bits) of the size_t type on the compilation
/// target, which equals the target pointer width.
unsigned LLVMIRGen::getTargetSizeTWidth() const {
  return getPointerNumBits(*TM_);
}
3906 | |
3907 | unsigned LLVMIRGen::getLibjitSizeTWidth() const { |
3908 | auto *sizeTVar = getModule().getGlobalVariable("libjit_sizeTVar" , |
3909 | /* allowInternal */ true); |
3910 | assert(sizeTVar && "libjit_sizeTVar is not found" ); |
3911 | return sizeTVar->getType()->getPointerElementType()->getIntegerBitWidth(); |
3912 | } |
3913 | |
3914 | unsigned LLVMIRGen::getLibjitIntWidth() const { |
3915 | auto *intVar = getModule().getGlobalVariable("libjit_intVar" , |
3916 | /* allowInternal */ true); |
3917 | assert(intVar && "libjit_intVar is not found" ); |
3918 | return intVar->getType()->getPointerElementType()->getIntegerBitWidth(); |
3919 | } |
3920 | |
/// \returns true if \p call may be specialized. The base implementation
/// allows specialization of every call.
bool LLVMIRGen::isEligibleForSpecialization(const llvm::CallInst *call) {
  return true;
}
3924 | |
/// \returns true if instruction \p I can be included in a fused
/// data-parallel kernel, which here means the instruction itself is marked
/// data parallel.
bool LLVMIRGen::canBePartOfDataParallelKernel(
    const glow::Instruction *I) const {
  return I->isDataParallel();
}
3929 | |
/// Extra bundle header file content with the IR instrumentation callback API.
/// This text is appended verbatim to the bundle header extra content when IR
/// instrumentation was generated (see printInstrumentIR_).
static const char *instrumentIRApi =
    R"RAW(
// -----------------------------------------------------------------------------
// Callback function used for Glow IR instruction instrumentation:
// - This callback is called by the bundle BEFORE executing each instruction.
// - This callback must be defined by the bundle user application.
// ARGUMENTS:
//   id     - Instruction instance ID.
//   kind   - Instruction kind (type).
//   opInp  - Number of input operands.
//   opOut  - Number of output operands.
//   opAddr - Array with addresses for all operands. The addresses are listed
//            first for the input operands and then for the output operands.
//            The array contains opInp + opOut addresses.
//   opSize - Array with sizes (in bytes) for all operands. The sizes are listed
//            first for the input operands and then for the output operands.
//            The array contains opInp + opOut sizes.
// NOTES:
// - This callback should be used to dump only the input operands since the
//   output operands are not yet computed/written when this callback is used.
// - This callback uses C linkage therefore if the callback is implemented in a
//   .cpp file you must enclose the implementation in extern "C" {}.
// - Look in the metafile "instrument-ir.info" generated during compile-time
//   to see more information about the instrumented instructions.
// -----------------------------------------------------------------------------
void glow_instrument_before(int id, int kind, int opInp, int opOut, uint8_t **opAddr, int *opSize);

// -----------------------------------------------------------------------------
// Callback function used for Glow IR instruction instrumentation:
// - This callback is called by the bundle AFTER executing each instruction.
// - This callback must be defined by the bundle user application.
// ARGUMENTS:
//   id     - Instruction instance ID.
//   kind   - Instruction kind (type).
//   opInp  - Number of input operands.
//   opOut  - Number of output operands.
//   opAddr - Array with addresses for all operands. The addresses are listed
//            first for the input operands and then for the output operands.
//            The array contains opInp + opOut addresses.
//   opSize - Array with sizes (in bytes) for all operands. The sizes are listed
//            first for the input operands and then for the output operands.
//            The array contains opInp + opOut sizes.
// NOTES:
// - This callback should be used to dump only the output operands since some
//   of the input operands might have been overwritten for instructions which
//   perform in-place computation.
// - This callback uses C linkage therefore if the callback is implemented in a
//   .cpp file you must enclose the implementation in extern "C" {}.
// - Look in the metafile "instrument-ir.info" generated during compile-time
//   to see more information about the instrumented instructions.
// -----------------------------------------------------------------------------
void glow_instrument_after(int id, int kind, int opInp, int opOut, uint8_t **opAddr, int *opSize);
)RAW";
3984 | |
3985 | std::string LLVMIRGen::() const { |
3986 | std::string = "" ; |
3987 | // Print IR instrumentation callback API. |
3988 | if (printInstrumentIR_) { |
3989 | headerExtra += std::string(instrumentIRApi); |
3990 | } |
3991 | return headerExtra; |
3992 | } |
3993 | |