1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_llvm.cc |
22 | */ |
23 | #ifdef TVM_LLVM_VERSION |
24 | // Part of the code are adapted from Halide's CodeGen_LLVM |
25 | #include "codegen_llvm.h" |
26 | |
27 | #include <llvm/ADT/ArrayRef.h> |
28 | #include <llvm/ADT/SmallVector.h> |
29 | #include <llvm/ADT/StringRef.h> |
30 | #include <llvm/ADT/Triple.h> |
31 | #include <llvm/Analysis/TargetTransformInfo.h> |
32 | #if TVM_LLVM_VERSION >= 50 |
33 | #include <llvm/BinaryFormat/Dwarf.h> |
34 | #else |
35 | #include <llvm/Support/Dwarf.h> |
36 | #endif |
37 | #if TVM_LLVM_VERSION >= 60 |
38 | #include <llvm/CodeGen/TargetSubtargetInfo.h> |
39 | #else |
40 | #include <llvm/Target/TargetSubtargetInfo.h> |
41 | #endif |
42 | #include <llvm/IR/Argument.h> |
43 | #include <llvm/IR/Attributes.h> |
44 | #include <llvm/IR/BasicBlock.h> |
45 | #include <llvm/IR/CallingConv.h> |
46 | #include <llvm/IR/Constants.h> |
47 | #include <llvm/IR/DIBuilder.h> |
48 | #include <llvm/IR/DataLayout.h> |
49 | #include <llvm/IR/DebugInfoMetadata.h> |
50 | #include <llvm/IR/DerivedTypes.h> |
51 | #if TVM_LLVM_VERSION >= 150 |
52 | #include <llvm/IR/FMF.h> |
53 | #else |
54 | #include <llvm/IR/Operator.h> |
55 | #endif |
56 | #include <llvm/IR/Function.h> |
57 | #include <llvm/IR/GlobalVariable.h> |
58 | #include <llvm/IR/Instructions.h> |
59 | #include <llvm/IR/Intrinsics.h> |
60 | #include <llvm/IR/LLVMContext.h> |
61 | #include <llvm/IR/MDBuilder.h> |
62 | #include <llvm/IR/Metadata.h> |
63 | #include <llvm/IR/Module.h> |
64 | #include <llvm/IR/Type.h> |
65 | #include <llvm/IRReader/IRReader.h> |
66 | #include <llvm/Linker/Linker.h> |
67 | #include <llvm/Pass.h> |
68 | #if TVM_LLVM_VERSION >= 160 |
69 | #include <llvm/IR/Verifier.h> // For VerifierPass |
70 | #include <llvm/Passes/PassBuilder.h> |
71 | #include <llvm/Passes/StandardInstrumentations.h> |
72 | #else |
73 | #include <llvm/IR/LegacyPassManager.h> |
74 | #include <llvm/Transforms/IPO/PassManagerBuilder.h> |
75 | #endif |
76 | #if TVM_LLVM_VERSION >= 100 |
77 | #include <llvm/Support/Alignment.h> |
78 | #include <llvm/Support/TypeSize.h> |
79 | #endif |
80 | #include <llvm/Support/CodeGen.h> |
81 | #include <llvm/Support/MemoryBuffer.h> |
82 | #include <llvm/Support/SourceMgr.h> |
83 | #include <llvm/Target/TargetMachine.h> |
84 | #include <llvm/Transforms/IPO.h> |
85 | #include <llvm/Transforms/Utils/ModuleUtils.h> |
86 | #include <tvm/runtime/c_runtime_api.h> |
87 | #include <tvm/runtime/crt/error_codes.h> |
88 | #include <tvm/runtime/device_api.h> |
89 | #include <tvm/tir/op.h> |
90 | |
91 | #include <algorithm> |
92 | #include <functional> |
93 | #include <memory> |
94 | #include <sstream> |
95 | #include <string> |
96 | #include <utility> |
97 | #include <vector> |
98 | |
99 | #include "../../arith/pattern_match.h" |
100 | #include "../build_common.h" |
101 | #include "../func_registry_generator.h" |
102 | #include "codegen_params.h" |
103 | #include "llvm_instance.h" |
104 | |
105 | namespace tvm { |
106 | namespace codegen { |
107 | |
// CodeGenLLVM has members of type std::unique_ptr<T>. These members will be
// instantiated in the constructor, which will require that the type T is
// complete at that point. Put the constructor (and destructor) here, since
// all types should be complete here.
CodeGenLLVM::CodeGenLLVM() = default;
CodeGenLLVM::~CodeGenLLVM() = default;
CodeGenLLVM::DebugInfo::~DebugInfo() = default;
115 | |
116 | std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(LLVMTarget* llvm_target) { |
117 | std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName(); |
118 | std::string factory_template = "tvm.codegen.llvm.target_" ; |
119 | void* handle = nullptr; |
120 | if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) { |
121 | handle = (*f)(); |
122 | } else if (const PackedFunc* f = runtime::Registry::Get(factory_template + "cpu" )) { |
123 | handle = (*f)(); |
124 | } else { |
125 | LOG(FATAL) << "no factory function for codegen for target " << target; |
126 | } |
127 | if (handle) { |
128 | return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle)); |
129 | } else { |
130 | LOG(FATAL) << "unable to create codegen for target " << target; |
131 | } |
132 | } |
133 | |
134 | void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, |
135 | bool dynamic_lookup, bool target_c_runtime) { |
136 | llvm_target_ = llvm_target; |
137 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
138 | builder_.reset(new IRBuilder(*ctx)); |
139 | module_.reset(new llvm::Module(module_name, *ctx)); |
140 | md_builder_.reset(new llvm::MDBuilder(*ctx)); |
141 | // types |
142 | t_void_ = llvm::Type::getVoidTy(*ctx); |
143 | t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace()); |
144 | t_int_ = llvm::Type::getInt32Ty(*ctx); |
145 | t_char_ = llvm::Type::getInt8Ty(*ctx); |
146 | t_int8_ = llvm::Type::getInt8Ty(*ctx); |
147 | t_int16_ = llvm::Type::getInt16Ty(*ctx); |
148 | t_int32_ = llvm::Type::getInt32Ty(*ctx); |
149 | t_int64_ = llvm::Type::getInt64Ty(*ctx); |
150 | t_float64_ = llvm::Type::getDoubleTy(*ctx); |
151 | // meta data |
152 | md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); |
153 | md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa" ); |
154 | md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias" , md_tbaa_root_); |
155 | InitTarget(); |
156 | } |
157 | |
158 | void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } |
159 | |
// Configure the output module for the selected target machine: target triple,
// data layout, the default native vector width, and (on x86 with LLVM >= 6)
// the float16 <-> float conversion helper functions.
void CodeGenLLVM::InitTarget() {
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  // Only pick a default vector width if a subclass has not already set one.
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      // Unknown architecture: fall back to 128 bits and warn.
      native_vector_bits_ = 128;
      std::string arch_name = std::string(tm->getTargetTriple().getArchName());
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }

#if TVM_LLVM_VERSION >= 60
  bool use_float16_abi = false;
#if TVM_LLVM_VERSION >= 150
  // For conversions between _Float16 and float, LLVM uses runtime functions
  // __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used
  // "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision
  // values can be passed in XMM registers (i.e. as floating-point). This happens
  // when the compilation target has SSE2 enabled (either directly, or by enabling
  // a feature that implies SSE2).
  // Because the names of the conversion functions remain unchanged, it is impossible
  // for TVM to provide them in the runtime, and have them work in both cases.
  // To alleviate this issue, emit these functions directly into the target module
  // after detecting whether or not to use floating-point ABI. To allow the linker
  // to remove potential duplicates (or if they are unused), they are weak and
  // reside in a separate section (ELF).
  llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch();
  if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) {
    // Detect if SSE2 is enabled. This determines whether float16 ABI is used.
    // Build a tiny throwaway module carrying the target CPU/features and ask
    // the subtarget whether +sse2 is in effect.
    std::stringstream os;
    const char fname[] = "test_sse2";
    os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n"
       << "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\""
       << llvm_target_->GetCPU() << "\" ";
    if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) {
      os << "\"target-features\"=\"" << fs << "\" ";
    }
    os << "}\n";
    auto mod = llvm_target_->GetInstance().ParseIR(os.str());
    auto* test_sse2 = mod->getFunction(fname);
    ICHECK(test_sse2 != nullptr) << "Module creation error";
    use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2");
  }
#endif  // TVM_LLVM_VERSION >= 150

  // Call this function only with LLVM >= 6.0. The code it emits uses "dso_local"
  // which was introduced in LLVM 6.
  EmitFloat16ConversionBuiltins(use_float16_abi);
#endif  // TVM_LLVM_VERSION >= 60
}
220 | |
221 | void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } |
222 | |
223 | void CodeGenLLVM::InitFuncState() { |
224 | var_map_.clear(); |
225 | alias_var_set_.clear(); |
226 | alloc_storage_info_.clear(); |
227 | volatile_buf_.clear(); |
228 | analyzer_.reset(new arith::Analyzer()); |
229 | } |
230 | |
231 | void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { |
232 | this->InitFuncState(); |
233 | |
234 | ICHECK_EQ(f->buffer_map.size(), 0U) |
235 | << "Cannot codegen function with buffer_map, please lower them first" ; |
236 | |
237 | std::vector<llvm::Type*> param_types; |
238 | is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias); |
239 | for (Var param : f->params) { |
240 | param_types.push_back(GetLLVMType(param)); |
241 | if (!is_restricted_ && param.dtype().is_handle()) { |
242 | alias_var_set_.insert(param.get()); |
243 | } |
244 | } |
245 | // TODO(tvm-team): |
246 | // Update the function type to respect the ret_type field of f. |
247 | // Once we allow more flexibility in the PrimFunc. |
248 | llvm::FunctionType* ftype = |
249 | llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false); |
250 | |
251 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
252 | ICHECK(global_symbol.defined()) |
253 | << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute" ; |
254 | function_ = module_->getFunction(MakeStringRef(global_symbol.value())); |
255 | if (function_ == nullptr) { |
256 | function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, |
257 | MakeStringRef(global_symbol.value()), module_.get()); |
258 | } |
259 | function_->setCallingConv(llvm::CallingConv::C); |
260 | function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); |
261 | SetTargetAttributes(function_); |
262 | |
263 | // set var map and align information |
264 | auto arg_it = function_->arg_begin(); |
265 | for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) { |
266 | llvm::Argument* v = &(*arg_it); |
267 | const Var& var = f->params[i]; |
268 | var_map_[var.get()] = v; |
269 | v->setName(std::string(var->name_hint)); |
270 | if (is_restricted_) { |
271 | if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) { |
272 | // set non alias. |
273 | #if TVM_LLVM_VERSION >= 50 |
274 | function_->addParamAttr(i, llvm::Attribute::NoAlias); |
275 | #else |
276 | function_->setDoesNotAlias(i + 1); |
277 | #endif |
278 | } |
279 | } |
280 | } |
281 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
282 | llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry" , function_); |
283 | builder_->SetInsertPoint(entry); |
284 | this->VisitStmt(f->body); |
285 | |
286 | // Add alignment attribute if needed. |
287 | #if TVM_LLVM_VERSION >= 50 |
288 | for (size_t i = 0; i < f->params.size(); ++i) { |
289 | const Var& var = f->params[i]; |
290 | auto f = alloc_storage_info_.find(var.get()); |
291 | if (f != alloc_storage_info_.end()) { |
292 | unsigned align = f->second.alignment; |
293 | if (align > 1) { |
294 | auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); |
295 | function_->addParamAttr(i, attr); |
296 | } |
297 | } |
298 | } |
299 | #endif |
300 | |
301 | EmitDebugLocation(f->span); |
302 | if (ret_void) { |
303 | builder_->CreateRetVoid(); |
304 | } else { |
305 | builder_->CreateRet(ConstInt32(0)); |
306 | } |
307 | } |
308 | |
309 | std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() { |
310 | this->AddStartupFunction(); |
311 | for (size_t i = 0; i < link_modules_.size(); ++i) { |
312 | ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i]))) |
313 | << "Failed to link modules" ; |
314 | } |
315 | link_modules_.clear(); |
316 | // optimize |
317 | this->Optimize(); |
318 | return std::move(module_); |
319 | } |
320 | |
321 | void CodeGenLLVM::HandleImport(const std::string& code) { |
322 | llvm::StringRef code_str(code); |
323 | std::unique_ptr<llvm::Module> mlib; |
324 | if (code_str.endswith(".ll" ) || code_str.endswith(".bc" )) { |
325 | mlib = llvm_target_->GetInstance().LoadIR(code); |
326 | } else { |
327 | mlib = llvm_target_->GetInstance().ParseIR(code); |
328 | } |
329 | |
330 | mlib->setTargetTriple(llvm_target_->GetTargetTriple()); |
331 | mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout()); |
332 | // mark all the functions as force inline |
333 | for (llvm::Function& f : mlib->functions()) { |
334 | f.removeFnAttr(llvm::Attribute::NoInline); |
335 | f.addFnAttr(llvm::Attribute::AlwaysInline); |
336 | f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage); |
337 | } |
338 | // add to linker libraries. |
339 | this->AddLinkModule(std::move(mlib)); |
340 | } |
341 | |
342 | void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) { |
343 | link_modules_.emplace_back(std::move(mod)); |
344 | } |
345 | |
346 | void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) { |
347 | LOG(FATAL) << "not implemented" ; |
348 | } |
349 | |
350 | llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented" ; } |
351 | |
352 | llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { LOG(FATAL) << "not implemented" ; } |
353 | |
354 | #if TVM_LLVM_VERSION >= 160 |
355 | |
356 | // Use new pass manager |
357 | |
// Run the default optimization pipeline (new pass manager) over module_ at
// the opt level recorded in the LLVM target.
void CodeGenLLVM::Optimize() {
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();

  // Debugging/verification instrumentation is compiled in but disabled by default.
  bool debug_logging = false;
  bool verify_each = false;

  llvm::PipelineTuningOptions pto = llvm::PipelineTuningOptions();
  llvm::PassInstrumentationCallbacks pic;
  llvm::PassBuilder builder(tm, pto, std::nullopt, &pic);

  // Register analysis managers for all IR levels and cross-wire their proxies
  // so that module passes can query function/loop-level analyses.
  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager cgam;
  llvm::ModuleAnalysisManager mam;
  builder.registerLoopAnalyses(lam);
  builder.registerFunctionAnalyses(fam);
  builder.registerCGSCCAnalyses(cgam);
  builder.registerModuleAnalyses(mam);
  builder.crossRegisterProxies(lam, fam, cgam, mam);

  // Construct the default pass pipeline depending on the opt level.
  std::string pipeline;
  switch (llvm_target_->GetOptLevel()) {
    case llvm::CodeGenOpt::Level::None:
      pipeline = "default<O0>";
      break;
    case llvm::CodeGenOpt::Level::Less:
      pipeline = "default<O1>";
      break;
    case llvm::CodeGenOpt::Level::Default:
      pipeline = "default<O2>";
      break;
    default:
      // CodeGenOpt::Level::Aggressive
      pipeline = "default<O3>";
      break;
  }

  llvm::StandardInstrumentations si(*llvm_target_->GetContext(), debug_logging, verify_each);
  si.registerCallbacks(pic, &fam);
  llvm::ModulePassManager mpass;
  if (verify_each) {
    mpass.addPass(llvm::VerifierPass());
  }
  if (auto err = builder.parsePassPipeline(mpass, pipeline)) {
    LOG(FATAL) << "error parsing pass pipeline '" << pipeline
               << "':" << llvm::toString(std::move(err)) << '\n';
  }

  mpass.run(*module_, mam);
}
409 | |
410 | #else // TVM_LLVM_VERSION |
411 | |
// Thin wrapper over the legacy function pass manager; the overridden add()
// is a hook point for logging/inspecting passes as they are registered.
class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
};
418 | |
// Thin wrapper over the legacy module pass manager; see FPassManager above.
class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
};
424 | |
425 | void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {} |
426 | |
// Run the optimization pipeline (legacy pass manager, LLVM < 16) over module_.
void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
  // Give both pass managers access to target-specific cost information.
  mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;

  // Use the same opt-level as specified in TargetMachine for running passes
  llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel();

  switch (opt_level) {
    case llvm::CodeGenOpt::Level::None:
      builder.OptLevel = 0;
      break;
    case llvm::CodeGenOpt::Level::Less:
      builder.OptLevel = 1;
      break;

    case llvm::CodeGenOpt::Level::Default:
      builder.OptLevel = 2;
      break;

    default:
      // CodeGenOpt::Level::Aggressive
      builder.OptLevel = 3;
  }

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  // Let subclasses adjust the builder before the pipelines are populated.
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  tm->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  // Run function passes over every function, then the module passes.
  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}
481 | #endif // TVM_LLVM_VERSION |
482 | |
483 | int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const { |
484 | return native_vector_bits_; |
485 | } |
486 | |
487 | unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } |
488 | |
// Map a TVM DataType to the corresponding LLVM type. Multi-lane dtypes map
// to fixed vector types; handles map to i8* in the global address space.
llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
  if (dtype.is_handle()) {
    // Handles are always scalar pointers.
    ICHECK_EQ(dtype.lanes(), 1);
    return t_void_p_;
  }
  if (dtype.is_void()) {
    return t_void_;
  }
  llvm::Type* etype = nullptr;
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  if (dtype.is_int() || dtype.is_uint()) {
    etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
  } else if (dtype.is_float()) {
    switch (dtype.bits()) {
      case 16:
        etype = llvm::Type::getHalfTy(*ctx);
        break;
      case 32:
        etype = llvm::Type::getFloatTy(*ctx);
        break;
      case 64:
        etype = llvm::Type::getDoubleTy(*ctx);
        break;
      default:
        LOG(FATAL) << "do not support " << dtype;
    }
  }
  if (dtype.lanes() != 1) {
#if TVM_LLVM_VERSION >= 110
    return llvm::FixedVectorType::get(etype, dtype.lanes());
#else
    return llvm::VectorType::get(etype, dtype.lanes());
#endif
  } else {
    return etype;
  }
}
526 | |
527 | llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { |
528 | if (auto* ptr = type.as<PrimTypeNode>()) { |
529 | return DTypeToLLVMType(ptr->dtype); |
530 | } else if (auto* ptr = type.as<PointerTypeNode>()) { |
531 | // LLVM IR doesn't allow void*, so we need to recognize this |
532 | // pattern explicitly. |
533 | if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) { |
534 | if (primtype->dtype.is_void()) { |
535 | return t_void_p_; |
536 | } |
537 | } |
538 | // TODO(tvm-team) consider put storage scope into the pointer type. |
539 | return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace()); |
540 | } else if (IsVoidType(type)) { |
541 | return t_void_; |
542 | } else { |
543 | LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type" ; |
544 | } |
545 | } |
546 | |
547 | llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { |
548 | return GetLLVMType(GetType(expr)); |
549 | } |
550 | |
551 | // Add tbaa alias information for load |
552 | // |
553 | // use a binary tree typed system to declare information |
554 | // and allow alias to be distinguished across nodes. |
555 | // |
556 | // This trick comes from Halide's CodeGen_LLVM |
557 | // |
558 | void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index, |
559 | DataType access_dtype) { |
560 | if (alias_var_set_.count(buffer_var) != 0) { |
561 | // Mark all possibly aliased pointer as same type. |
562 | llvm::MDNode* meta = md_tbaa_alias_set_; |
563 | inst->setMetadata("tbaa" , md_builder_->createTBAAStructTagNode(meta, meta, 0)); |
564 | return; |
565 | } |
566 | |
567 | int64_t base = 0, width = 0; |
568 | arith::PVar<IntImm> pbase, pstride; |
569 | arith::PVar<int> planes; |
570 | // create meta-data for alias analysis |
571 | // Use a group of binary tree ranges of memory banks. |
572 | int64_t xwith = 0; |
573 | if (arith::ramp(pbase, pstride, planes).Match(index)) { |
574 | base = pbase.Eval()->value; |
575 | xwith = planes.Eval() * pstride.Eval()->value; |
576 | } else if (auto* ptr = index.as<tir::IntImmNode>()) { |
577 | base = ptr->value; |
578 | xwith = 1; |
579 | } |
580 | // adjust address index unit to byte |
581 | const int64_t unit_bit_width = 8; |
582 | const int64_t access_elem_bits = access_dtype.bits() * access_dtype.lanes(); |
583 | base = base * access_elem_bits / unit_bit_width; |
584 | xwith = (xwith * access_elem_bits + unit_bit_width - 1) / unit_bit_width; |
585 | if (xwith > 0) { |
586 | width = 1; |
587 | while (width < xwith) { |
588 | width *= 2; |
589 | } |
590 | while (base % width) { |
591 | base -= base % width; |
592 | width *= 2; |
593 | } |
594 | } |
595 | |
596 | llvm::MDNode* meta = md_tbaa_root_; |
597 | std::ostringstream buffer_addr; |
598 | buffer_addr << buffer_var; |
599 | meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); |
600 | |
601 | // create a tree-shape access structure. |
602 | if (width != 0) { |
603 | for (int64_t w = 1024; w >= width; w /= 2) { |
604 | int64_t b = (base / w) * w; |
605 | std::stringstream os; |
606 | os << buffer_var << ".w" << w << ".b" << b; |
607 | meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); |
608 | } |
609 | } |
610 | inst->setMetadata("tbaa" , md_builder_->createTBAAStructTagNode(meta, meta, 0)); |
611 | } |
612 | |
613 | void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, |
614 | int* p_alignment, int* p_native_bits) { |
615 | int max_align_bits = t.bits(); |
616 | auto it = alloc_storage_info_.find(buf_var); |
617 | if (it != alloc_storage_info_.end()) { |
618 | const StorageInfo& info = it->second; |
619 | *p_native_bits = |
620 | NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef<Var>(buf_var)))); |
621 | max_align_bits = info.alignment * 8; |
622 | } else { |
623 | *p_native_bits = native_vector_bits_; |
624 | } |
625 | |
626 | arith::ModularSet me = analyzer_->modular_set(index); |
627 | int64_t base = me->base; |
628 | int64_t coeff = me->coeff; |
629 | |
630 | int align_bits = t.bits(); |
631 | while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) { |
632 | base = base / 2; |
633 | coeff = coeff / 2; |
634 | align_bits *= 2; |
635 | } |
636 | if (align_bits < 8) { |
637 | align_bits = 8; |
638 | } |
639 | *p_alignment = align_bits / 8; |
640 | } |
641 | |
// Create a module-level array global named "shmem" in the given (shared)
// address space, with the requested element dtype, length, alignment and linkage.
llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size,
                                                        unsigned int shared_address_space,
                                                        int alignment,
                                                        llvm::GlobalValue::LinkageTypes linkage) {
  llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
  llvm::GlobalVariable* global =
      new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
                               llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
  // LLVM >= 10 takes alignment via the llvm::Align wrapper type.
  global->setAlignment(llvm::Align(alignment));
#else
  global->setAlignment(alignment);
#endif
  return global;
}
657 | |
// Create the DWARF debug-info scaffolding (DIBuilder, file, compile unit)
// for a module. The file name is a placeholder until real source spans are
// plumbed through.
std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
  auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
#else
  // Pre-LLVM 10 shipped its own make_unique.
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
  // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
  debug_info->file_ = debug_info->di_builder_->createFile("main.tir", ".");
  const int runtime_version = 0;
  const bool is_optimized = false;
  const char* compiler_flags = "";
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      /*Lang=*/llvm::dwarf::DW_LANG_C, /*File=*/debug_info->file_, /*Producer=*/"TVM", is_optimized,
      compiler_flags, runtime_version);
  return debug_info;
}
676 | |
// Broadcast a scalar into a vector of `lanes` lanes: insert the scalar into
// lane 0 of an undef vector, then shuffle with an all-zero mask so every
// lane reads lane 0.
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
#if TVM_LLVM_VERSION >= 110
  llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
#else
  llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
#endif
  llvm::Constant* undef = llvm::UndefValue::get(type);
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
  // The getSplat signature changed across LLVM versions.
#if TVM_LLVM_VERSION >= 120
  llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
#elif TVM_LLVM_VERSION >= 110
  llvm::Constant* mask =
      llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
#else
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
#endif
  return builder_->CreateShuffleVector(value, undef, mask);
}
696 | |
697 | llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { |
698 | int num_elems = GetVectorNumElements(vec); |
699 | if (extent == num_elems && begin == 0) return vec; |
700 | ICHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n" ; |
701 | std::vector<llvm::Constant*> indices; |
702 | indices.reserve(extent); |
703 | for (int i = 0; i < extent; ++i) { |
704 | if (begin + i >= 0 && begin + i < num_elems) { |
705 | indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i)); |
706 | } else { |
707 | indices.push_back(llvm::UndefValue::get(t_int32_)); |
708 | } |
709 | } |
710 | return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices)); |
711 | } |
712 | |
713 | llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { |
714 | int num_elems = GetVectorNumElements(vec); |
715 | #if TVM_LLVM_VERSION >= 110 |
716 | std::vector<int> indices; |
717 | #else |
718 | std::vector<unsigned> indices; |
719 | #endif |
720 | for (int i = 0; i < num_elems; ++i) { |
721 | indices.push_back(num_elems - i - 1); |
722 | } |
723 | return builder_->CreateShuffleVector(vec, vec, indices); |
724 | } |
725 | |
726 | llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { |
727 | llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); |
728 | int num_elems = GetVectorNumElements(vec); |
729 | if (num_elems == target_lanes) return vec; |
730 | ICHECK_LT(num_elems, target_lanes); |
731 | for (int i = 0; i < num_elems; ++i) { |
732 | mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i)); |
733 | } |
734 | return builder_->CreateShuffleVector(vec, vec, mask); |
735 | } |
736 | |
// Concatenate a list of vectors (or scalars) into one vector whose lane
// count is the sum of the inputs', using a balanced tree of shuffles.
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane
  // LLVM vector types.
  for (size_t i = 0, e = vecs.size(); i != e; ++i) {
    llvm::Value* v = vecs[i];
    if (!v->getType()->isVectorTy()) {
#if TVM_LLVM_VERSION >= 110
      llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1);
#else
      llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1);
#endif
      vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0));
    }
  }

  // concat vector, tree shape reduction
  int total_lanes = 0;

  for (llvm::Value* v : vecs) {
    total_lanes += GetVectorNumElements(v);
  }
  // Each round shuffles adjacent pairs together, halving the list length,
  // padding the shorter operand of each pair with undef lanes first.
  while (vecs.size() > 1) {
    std::vector<llvm::Value*> new_vecs;
    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      const size_t lhs_lanes = GetVectorNumElements(lhs);
      const size_t rhs_lanes = GetVectorNumElements(rhs);
      if (lhs_lanes < rhs_lanes) {
        lhs = CreateVecPad(lhs, rhs_lanes);
      } else if (rhs_lanes < lhs_lanes) {
        rhs = CreateVecPad(rhs, lhs_lanes);
      }
      const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
#if TVM_LLVM_VERSION >= 110
      std::vector<int> mask;
#else
      std::vector<unsigned> mask;
#endif
      // Select the meaningful lanes of lhs, then those of rhs (which start
      // at offset shared_lanes in the padded two-operand shuffle).
      for (size_t i = 0; i < lhs_lanes; ++i) {
        mask.push_back(i);
      }
      for (size_t i = 0; i < rhs_lanes; ++i) {
        mask.push_back(shared_lanes + i);
      }
      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
    }
    // Odd element count: carry the last vector into the next round unchanged.
    if (vecs.size() % 2 != 0) {
      new_vecs.push_back(vecs.back());
    }
    vecs.swap(new_vecs);
  }
  // Trim any padding lanes introduced by the pairwise rounds.
  return CreateVecSlice(vecs[0], 0, total_lanes);
}
791 | |
// Emit a serial for-loop:
//   for (loop_var = begin; loop_var < end; loop_var += stride) body
// using three basic blocks (for_begin: condition/PHI, for_body, for_end).
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
                                  const Var& loop_var, const Stmt& body) {
  EmitDebugLocation(body->span);
  llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
  std::string loop_var_name = loop_var->name_hint;
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_);
  auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_);
  auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  // The PHI merges the initial value (from pre_block) and the incremented
  // value (added below, after the body is emitted).
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->setName(loop_var->name_hint.c_str());
  loop_value->addIncoming(begin, pre_block);
  ICHECK(!var_map_.count(loop_var.get()));
  // Make the loop variable visible while the body is generated.
  var_map_[loop_var.get()] = loop_value;
  auto lt = CreateLT(loop_var.dtype(), loop_value, end);
  // The back-edge is taken far more often than the exit; hint the branch.
  builder_->CreateCondBr(lt, for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  // NOTE: the body may have moved the insert point, so the increment's
  // incoming block is queried from the builder, not assumed to be for_body.
  llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}
818 | |
// cast operator
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
  // Converts `value` (of TVM dtype `from`) to TVM dtype `to`, selecting the
  // LLVM cast instruction based on the source/destination type categories.
  llvm::Type* target = DTypeToLLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_uint() && to.bits() == 1) {
    // Cast to bool: compare against zero instead of truncating bits.
    if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
      return builder_->CreateFCmpONE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
    }
  } else if (!from.is_float() && !to.is_float()) {
    // Integer-to-integer: sign-extend iff the source type is signed.
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    // Sub-byte unsigned targets: convert to u8 first, then narrow with an
    // integer cast.
    if (to.bits() < 8) {
      value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    // Remaining case: float-to-float width change.
    ICHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}
853 | |
llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name,
                                               llvm::GlobalValue::LinkageTypes linkage_type) {
  // Wraps `const_data` in a module-level read-only global variable and
  // returns a constant pointer to its contents. Ownership of the new
  // GlobalVariable passes to `module_`.
  llvm::Type* ty = const_data->getType();
  llvm::GlobalVariable* global =
      new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name);
#if TVM_LLVM_VERSION >= 100
  global->setAlignment(llvm::Align(1));
#else
  global->setAlignment(1);
#endif
  // GEP {0, 0}: decay the global to a pointer to its first element.
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* indices[] = {zero, zero};
  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices);
  return ptr;
}
869 | |
870 | llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { |
871 | auto it = str_map_.find(str); |
872 | if (it != str_map_.end()) return it->second; |
873 | auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); |
874 | auto ptr = GetGlobalConstant(llvm_str, ".str" , llvm::GlobalValue::PrivateLinkage); |
875 | str_map_[str] = ptr; |
876 | return ptr; |
877 | } |
878 | |
CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
                                                       DataType buffer_element_dtype,
                                                       llvm::ArrayRef<llvm::Value*> indices,
                                                       DataType value_dtype) {
  // Computes the address of element `indices[0]` in a flat 1-d buffer whose
  // elements have dtype `buffer_element_dtype`, and returns it as a pointer
  // to `value_dtype` (allowing reinterpreting loads/stores, e.g. vectorized
  // access). The original address space of `buffer_ptr` is preserved.
  ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers." ;
  llvm::Value* index = indices[0];

  llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
  ICHECK(buffer_ptr_type != nullptr);
  auto address_space = buffer_ptr_type->getAddressSpace();

  llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype);
  llvm::PointerType* element_ptr_type =
      DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space);
  llvm::Type* value_type = DTypeToLLVMType(value_dtype);
  llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space);

  ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer" ;

  // Normalize the incoming pointer to element type before indexing.
  if (buffer_ptr_type != element_ptr_type) {
    buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type);
  }
  ICHECK(!HasAlignmentPadding(buffer_element_dtype))
      << "DType " << buffer_element_dtype
      << " has padding for alignment. TVM data arrays are expected to be densely packed, with no "
         "padding for alignment." ;
  llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index);

  // Reinterpret the element address as a pointer to the requested value type.
  if (element_ptr_type != value_ptr_type) {
    value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type);
  }

  return TypedPointer(value_type, value_ptr);
}
913 | |
914 | llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { |
915 | auto it = var_map_.find(v); |
916 | ICHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint; |
917 | return it->second; |
918 | } |
919 | |
920 | void CodeGenLLVM::CreatePrintf(const std::string& format, |
921 | llvm::ArrayRef<llvm::Value*> format_args) { |
922 | EmitDebugLocation(); |
923 | llvm::Function* func_printf = module_->getFunction("printf" ); |
924 | if (func_printf == nullptr) { |
925 | llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, true); |
926 | func_printf = |
927 | llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "printf" , module_.get()); |
928 | } |
929 | |
930 | llvm::Function* func_fflush = module_->getFunction("fflush" ); |
931 | if (!func_fflush) { |
932 | llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, {t_void_p_}, false); |
933 | func_fflush = |
934 | llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "fflush" , module_.get()); |
935 | } |
936 | |
937 | llvm::Value* str = builder_->CreateGlobalStringPtr(format); |
938 | str->setName("printf_format_str" ); |
939 | |
940 | std::vector<llvm::Value*> printf_args = {str}; |
941 | printf_args.insert(printf_args.end(), format_args.begin(), format_args.end()); |
942 | builder_->CreateCall(func_printf, printf_args); |
943 | |
944 | // Call fflush() immediately, as this utility is intended for debug |
945 | // purposes. A segfault occurring within the generated LLVM code |
946 | // would otherwise leave the stdout buffer unflushed. |
947 | llvm::Value* null_stream = llvm::ConstantPointerNull::get(t_void_p_); |
948 | null_stream->setName("null_stream" ); |
949 | builder_->CreateCall(func_fflush, {null_stream}); |
950 | } |
951 | |
952 | llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) { |
953 | EmitDebugLocation(); |
954 | llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level); |
955 | llvm::Function* builtin = |
956 | llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress); |
957 | llvm::Value* call = builder_->CreateCall(builtin, level_val); |
958 | call->setName("return_addr" ); |
959 | |
960 | return call; |
961 | } |
962 | |
llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol,
                                           const Array<PrimExpr>& args, bool skip_first_arg) {
  // Emits a call to the external symbol `global_symbol`, declaring it in the
  // module on first use. When `skip_first_arg` is true, args[0] is not
  // passed to the callee.
  std::vector<llvm::Value*> arg_value;
  std::vector<llvm::Type*> arg_type;
  for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
    arg_value.push_back(MakeValue(args[i]));
    arg_type.push_back(arg_value.back()->getType());
  }
  // The signature is inferred from the lowered argument types (non-variadic).
  llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
  llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
  if (f == nullptr) {
    f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, MakeStringRef(global_symbol),
                               module_.get());
  }
  llvm::CallInst* call = builder_->CreateCall(f, arg_value);
  return call;
}
980 | |
llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
                                              llvm::ArrayRef<llvm::Type*> arg_types)  {
  // Returns the declaration of intrinsic `id` specialized to the signature
  // (ret_type, arg_types), or nullptr if the signature cannot be matched
  // against the intrinsic's overload descriptors.
  llvm::Module* module = module_.get();

  if (!llvm::Intrinsic::isOverloaded(id)) {
    // Non-overloaded intrinsics have a single fixed signature.
    return llvm::Intrinsic::getDeclaration(module, id, {});
  }

  llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
  llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
  llvm::SmallVector<llvm::Type*, 4> overload_types;

#if TVM_LLVM_VERSION >= 90
  // Tries to match `f_ty` against the intrinsic's descriptors; fills
  // `overload_types` with the deduced overload parameters on success.
  auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
    overload_types.clear();
    llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
    auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
    if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
      if (error) {
        return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
      }
    }
    return match;
  };

  // First, try matching the signature assuming non-vararg case.
  auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
  switch (try_match(fn_ty, false)) {
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
      // The return type doesn't match, there is nothing else to do.
      return nullptr;
    case llvm::Intrinsic::MatchIntrinsicTypes_Match:
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
      break;
  }

  // Keep adding one type at a time (starting from empty list), and
  // try matching the vararg signature.
  llvm::SmallVector<llvm::Type*, 4> var_types;
  for (int i = 0, e = arg_types.size(); i <= e; ++i) {
    if (i > 0) var_types.push_back(arg_types[i - 1]);
    auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
    if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    }
  }
  // Failed to identify the type.
  return nullptr;

#else   // TVM_LLVM_VERSION
  // Pre-LLVM-9 path: match the return type and each argument in turn.
  llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
  // matchIntrinsicType returns true on error.
  if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
    return nullptr;
  }
  for (llvm::Type* t : arg_types) {
    if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
      return nullptr;
    }
  }
  return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif  // TVM_LLVM_VERSION
}
1046 | |
1047 | void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { |
1048 | const std::string& cpu = llvm_target_->GetCPU(); |
1049 | if (!cpu.empty()) { |
1050 | func->addFnAttr("target-cpu" , cpu); |
1051 | } |
1052 | const std::string& features = llvm_target_->GetTargetFeatureString(); |
1053 | if (!features.empty()) { |
1054 | func->addFnAttr("target-features" , features); |
1055 | } |
1056 | } |
1057 | |
1058 | void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { |
1059 | // The LLVM IR for these function was obtained by compiling |
1060 | // |
1061 | // For integer ABI: |
1062 | // __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a); |
1063 | // __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a); |
1064 | // For floating-point ABI: |
1065 | // __truncXfYf2__<float, uint32_t, 23, _Float16, uint16_t, 10>(a); |
1066 | // __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a); |
1067 | |
1068 | static const char trunc_body[] = // __truncsfhf2 |
1069 | " %v0 = bitcast float %a0 to i32\n" |
1070 | " %v1 = and i32 %v0, 2147483647\n" |
1071 | " %v2 = add nsw i32 %v1, -947912704\n" |
1072 | " %v3 = add nsw i32 %v1, -1199570944\n" |
1073 | " %v4 = icmp ult i32 %v2, %v3\n" |
1074 | " br i1 %v4, label %b1, label %b5\n" |
1075 | "b1:\n" |
1076 | " %v5 = lshr i32 %v0, 13\n" |
1077 | " %v6 = and i32 %v5, 65535\n" |
1078 | " %v7 = add nuw nsw i32 %v6, -114688\n" |
1079 | " %v8 = and i32 %v0, 8191\n" |
1080 | " %v9 = icmp ugt i32 %v8, 4096\n" |
1081 | " br i1 %v9, label %b2, label %b3\n" |
1082 | "b2:\n" |
1083 | " %v10 = add nuw nsw i32 %v6, -114687\n" |
1084 | " br label %b13\n" |
1085 | "b3:\n" |
1086 | " %v11 = icmp eq i32 %v8, 4096\n" |
1087 | " br i1 %v11, label %b4, label %b13\n" |
1088 | "b4:\n" |
1089 | " %v12 = and i32 %v7, 65535\n" |
1090 | " %v13 = and i32 %v5, 1\n" |
1091 | " %v14 = add nuw nsw i32 %v12, %v13\n" |
1092 | " br label %b13\n" |
1093 | "b5:\n" |
1094 | " %v15 = icmp ugt i32 %v1, 2139095040\n" |
1095 | " br i1 %v15, label %b6, label %b7\n" |
1096 | "b6:\n" |
1097 | " %v16 = lshr i32 %v0, 13\n" |
1098 | " %v17 = and i32 %v16, 511\n" |
1099 | " %v18 = or i32 %v17, 32256\n" |
1100 | " br label %b13\n" |
1101 | "b7:\n" |
1102 | " %v19 = icmp ugt i32 %v1, 1199570943\n" |
1103 | " br i1 %v19, label %b13, label %b8\n" |
1104 | "b8:\n" |
1105 | " %v20 = icmp ult i32 %v1, 754974720\n" |
1106 | " br i1 %v20, label %b13, label %b9\n" |
1107 | "b9:\n" |
1108 | " %v21 = lshr i32 %v1, 23\n" |
1109 | " %v22 = sub nsw i32 113, %v21\n" |
1110 | " %v23 = and i32 %v0, 8388607\n" |
1111 | " %v24 = or i32 %v23, 8388608\n" |
1112 | " %v25 = add nsw i32 %v21, -81\n" |
1113 | " %v26 = shl i32 %v24, %v25\n" |
1114 | " %v27 = icmp ne i32 %v26, 0\n" |
1115 | " %v28 = lshr i32 %v24, %v22\n" |
1116 | " %v29 = zext i1 %v27 to i32\n" |
1117 | " %v30 = lshr i32 %v28, 13\n" |
1118 | " %v31 = and i32 %v28, 8191\n" |
1119 | " %v32 = or i32 %v31, %v29\n" |
1120 | " %v33 = icmp ugt i32 %v32, 4096\n" |
1121 | " br i1 %v33, label %b10, label %b11\n" |
1122 | "b10:\n" |
1123 | " %v34 = add nuw nsw i32 %v30, 1\n" |
1124 | " br label %b13\n" |
1125 | "b11:\n" |
1126 | " %v35 = icmp eq i32 %v32, 4096\n" |
1127 | " br i1 %v35, label %b12, label %b13\n" |
1128 | "b12:\n" |
1129 | " %v36 = and i32 %v30, 1\n" |
1130 | " %v37 = add nuw nsw i32 %v36, %v30\n" |
1131 | " br label %b13\n" |
1132 | "b13:\n" |
1133 | " %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n" |
1134 | " [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n" |
1135 | " [ %v30, %b11 ]\n" |
1136 | " %v39 = lshr i32 %v0, 16\n" |
1137 | " %v40 = and i32 %v39, 32768\n" |
1138 | " %v41 = or i32 %v38, %v40\n" |
1139 | " %vlast = trunc i32 %v41 to i16\n" ; |
1140 | |
1141 | static const char extend_body[] = // __extendhfsf2 |
1142 | " %v1 = and i16 %vinp, 32767\n" |
1143 | " %v2 = zext i16 %v1 to i32\n" |
1144 | " %v3 = add nsw i16 %v1, -1024\n" |
1145 | " %v4 = icmp ult i16 %v3, 30720\n" |
1146 | " br i1 %v4, label %b1, label %b2\n" |
1147 | "b1:\n" |
1148 | " %v5 = shl nuw nsw i32 %v2, 13\n" |
1149 | " %v6 = add nuw nsw i32 %v5, 939524096\n" |
1150 | " br label %b6\n" |
1151 | "b2:\n" |
1152 | " %v7 = icmp ugt i16 %v1, 31743\n" |
1153 | " br i1 %v7, label %b3, label %b4\n" |
1154 | "b3:\n" |
1155 | " %v8 = shl nuw nsw i32 %v2, 13\n" |
1156 | " %v9 = or i32 %v8, 2139095040\n" |
1157 | " br label %b6\n" |
1158 | "b4:\n" |
1159 | " %v10 = icmp eq i16 %v1, 0\n" |
1160 | " br i1 %v10, label %b6, label %b5\n" |
1161 | "b5:\n" |
1162 | " %v11 = icmp ult i16 %v1, 256\n" |
1163 | " %v12 = lshr i32 %v2, 8\n" |
1164 | " %v13 = select i1 %v11, i32 %v2, i32 %v12\n" |
1165 | " %v14 = select i1 %v11, i32 32, i32 24\n" |
1166 | " %v15 = icmp ult i32 %v13, 16\n" |
1167 | " %v16 = lshr i32 %v13, 4\n" |
1168 | " %v17 = add nsw i32 %v14, -4\n" |
1169 | " %v18 = select i1 %v15, i32 %v13, i32 %v16\n" |
1170 | " %v19 = select i1 %v15, i32 %v14, i32 %v17\n" |
1171 | " %v20 = icmp ult i32 %v18, 4\n" |
1172 | " %v21 = lshr i32 %v18, 2\n" |
1173 | " %v22 = add nsw i32 %v19, -2\n" |
1174 | " %v23 = select i1 %v20, i32 %v18, i32 %v21\n" |
1175 | " %v24 = select i1 %v20, i32 %v19, i32 %v22\n" |
1176 | " %v25 = icmp ult i32 %v23, 2\n" |
1177 | " %v26 = sub nsw i32 0, %v23\n" |
1178 | " %v27 = select i1 %v25, i32 %v26, i32 -2\n" |
1179 | " %v28 = add nsw i32 %v27, %v24\n" |
1180 | " %v29 = add nsw i32 %v28, -8\n" |
1181 | " %v30 = shl i32 %v2, %v29\n" |
1182 | " %v31 = xor i32 %v30, 8388608\n" |
1183 | " %v32 = shl i32 %v28, 23\n" |
1184 | " %v33 = sub i32 1124073472, %v32\n" |
1185 | " %v34 = or i32 %v31, %v33\n" |
1186 | " br label %b6\n" |
1187 | "b6:\n" |
1188 | " %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n" |
1189 | " %v36 = and i16 %vinp, -32768\n" |
1190 | " %v37 = zext i16 %v36 to i32\n" |
1191 | " %v38 = shl nuw i32 %v37, 16\n" |
1192 | " %v39 = or i32 %v35, %v38\n" |
1193 | " %v40 = bitcast i32 %v39 to float\n" |
1194 | " ret float %v40\n" |
1195 | "}\n" ; |
1196 | |
1197 | std::string short_type = use_float16_abi ? "half" : "i16" ; |
1198 | |
1199 | std::string short_cast_in, short_cast_out; |
1200 | if (use_float16_abi) { |
1201 | short_cast_in = " %vinp = bitcast half %a0 to i16\n" ; |
1202 | short_cast_out = " %vres = bitcast i16 %vlast to half\n" ; |
1203 | } else { |
1204 | // No-ops that preserve the i16 values. |
1205 | short_cast_in = " %vinp = add i16 %a0, 0\n" ; |
1206 | short_cast_out = " %vres = add i16 %vlast, 0\n" ; |
1207 | } |
1208 | |
1209 | llvm::Triple triple(llvm_target_->GetTargetTriple()); |
1210 | |
1211 | static const char elf_section_name[] = ".text.tvm.fp16.conv" ; |
1212 | std::string section = triple.getObjectFormat() == llvm::Triple::ELF |
1213 | ? std::string("section \"" ) + elf_section_name + "\" " |
1214 | : "" ; |
1215 | |
1216 | std::string = "define weak dso_local " + short_type + |
1217 | " @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section + |
1218 | "{\nb0:\n" ; |
1219 | std::string trunc_return = " ret " + short_type + " %vres\n}\n" ; |
1220 | |
1221 | std::string = "define weak dso_local float @__extendhfsf2(" + short_type + |
1222 | " %a0) local_unnamed_addr #0 " + section + "{\nb0:\n" ; |
1223 | |
1224 | // truncate = trunc_header + trunc_body + short_cast_out + trunc_return |
1225 | // extend = extend_header + short_cast_in + extend_body |
1226 | |
1227 | std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" + |
1228 | llvm_target_->GetCPU() + "\" \"target-features\"=\"" + |
1229 | llvm_target_->GetTargetFeatureString() + "\" }\n" ; |
1230 | |
1231 | auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout(); |
1232 | std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" + |
1233 | "target datalayout = \"" + data_layout.getStringRepresentation() + |
1234 | "\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return + |
1235 | extend_header + short_cast_in + extend_body + attributes; |
1236 | |
1237 | auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir); |
1238 | link_modules_.push_back(std::move(builtins_module)); |
1239 | } |
1240 | |
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
  // Lowers a TVM builtin intrinsic call to LLVM IR; one branch per builtin.
  if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
    // Direct LLVM intrinsic call: args[0] is the intrinsic id, args[1] the
    // number of leading call arguments that drive overload resolution.
    ICHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
    int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      if (i - 2 < static_cast<size_t>(num_signature)) {
        arg_type.push_back(arg_value.back()->getType());
      }
    }
    // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
    // returns int32. This causes problems because prefetch is one of
    // those intrinsics that is generated automatically via the
    // tvm.intrin.rule mechanism. Any other intrinsic with a type
    // mismatch will have to be treated specially here.
    // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
    // type as LLVM.
    llvm::Type* return_type =
        (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op)) : t_void_;
    llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
    ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
#if TVM_LLVM_VERSION >= 130
              << llvm::Intrinsic::getBaseName(id).str();
#else
              << llvm::Intrinsic::getName(id, {});
#endif
    return builder_->CreateCall(f, arg_value);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_or())) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_not())) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_left())) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_right())) {
    // Arithmetic shift for signed operands, logical shift for unsigned.
    if (op->args[0].dtype().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return CreateStorageSync(op);
  } else if (op->op.same_as(builtin::address_of())) {
    // address_of(buffer_load): compute the element address without loading.
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    ICHECK(op->args.size() == 1 && load);

    // For a ramp index, take the address of the ramp's base element.
    Array<PrimExpr> indices = load->indices;
    if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
      indices.Set(indices.size() - 1, r->base);
    }

    std::vector<llvm::Value*> indices_val;
    for (const auto& index : indices) {
      indices_val.push_back(MakeValue(index));
    }

    TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
                                              indices_val, load->dtype);
    unsigned addrspace =
        llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
    // Return the address as a char* in the buffer's address space.
    return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
  } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->op.same_as(builtin::isnullptr())) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    // 64-bit immediate split into two 32-bit halves: args = {low, high}.
    ICHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
  } else if (op->op.same_as(builtin::if_then_else())) {
    // Short-circuit conditional: only the selected branch is evaluated.
    ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition" ;
    llvm::LLVMContext* ctx = llvm_target_->GetContext();
    auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then" , function_);
    auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else" , function_);
    auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end" , function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    // Branch lowering may change the insert block; record the actual
    // predecessors for the PHI node below.
    llvm::BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    llvm::BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->op.same_as(builtin::ret())) {
    auto const* val = op->args[0].as<IntImmNode>();
    ICHECK(val) << "the tir.ret should be transformed to return zero "
                << "before the llvm code generation." ;
    ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to "
                             << "return zero before the llvm code generation." ;
    builder_->CreateRet(ConstInt32(0));
    // LLVM allows exactly one terminator in a single basic block
    // append a new dummy basic block to avoid error.
    llvm::BasicBlock* ret_dummy =
        llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy" , function_);
    builder_->SetInsertPoint(ret_dummy);
    return ret_dummy;
  } else if (op->op.same_as(builtin::reinterpret())) {
    llvm::Type* target = DTypeToLLVMType(op->dtype);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->op.same_as(builtin::isnan())) {
    // TODO(hgt312): set fast math flag
    // x != x (unordered compare) is true only for NaN.
    llvm::Value* a = MakeValue(op->args[0]);
    return builder_->CreateFCmpUNO(a, a);
  } else if (op->op.same_as(builtin::vectorlow())) {
    // Lower half of the vector's lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = GetVectorNumElements(v);
    return CreateVecSlice(v, 0, l / 2);
  } else if (op->op.same_as(builtin::vectorhigh())) {
    // Upper half of the vector's lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = GetVectorNumElements(v);
    return CreateVecSlice(v, l / 2, l / 2);
  } else if (op->op.same_as(builtin::vectorcombine())) {
    // Concatenate two vectors of equal lane count via a shuffle.
    llvm::Value* v0 = MakeValue(op->args[0]);
    llvm::Value* v1 = MakeValue(op->args[1]);
    int num_elems = GetVectorNumElements(v0) * 2;
#if TVM_LLVM_VERSION >= 110
    std::vector<int> indices;
#else
    std::vector<unsigned> indices;
#endif
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else if (op->op.same_as(builtin::atomic_add())) {
    // TODO(masahi): Support atomic for CPU backend
    LOG(FATAL) << "CPU backend does not support atomic add yet." ;
  } else if (op->op.same_as(builtin::start_profile_intrinsic()) ||
             op->op.same_as(builtin::end_profile_intrinsic())) {
    LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
    return nullptr;
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->op;
  }
}
1390 | |
void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
  // Invokes `f(i, lane_value)` once per vector lane of `e`.
  if (const RampNode* ramp = e.as<RampNode>()) {
    // Ramp: compute each lane as base + stride * i, never materializing a vector.
    for (int i = 0; i < ramp->dtype.lanes(); ++i) {
      PrimExpr offset = ramp->base + (ramp->stride * i);
      f(i, MakeValue(offset));
    }
  } else {
    // General case: lower the whole vector once and extract each element.
    llvm::Value* value = MakeValue(e);
    for (int i = 0; i < e.dtype().lanes(); ++i) {
      f(i, builder_->CreateExtractElement(value, i));
    }
  }
}
1404 | |
// Visitors
// Variable reference: return the LLVM value previously bound to this TIR Var.
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
1407 | |
1408 | llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { |
1409 | return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); |
1410 | } |
// Integer immediate: materialized as a sign-extended LLVM integer constant.
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
  return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
}
1414 | |
// Floating-point immediate: materialized as an LLVM FP constant of the dtype.
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
  return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}
1418 | |
1419 | llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } |
1420 | |
// Defines CodeGenLLVM::Create{Op} and the matching VisitExpr_ for an
// arithmetic binary op. Instruction selection by dtype:
//  - signed int, >= 32 bits: NSW (no-signed-wrap) form,
//  - unsigned int, >= 32 bits: NUW (no-unsigned-wrap) form,
//  - narrower integers: plain wrapping form,
//  - float: the floating-point instruction.
#define DEFINE_CODEGEN_BINARY_OP(Op)                                                 \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNSW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else if (t.is_uint()) {                                                        \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNUW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else {                                                                         \
      ICHECK(t.is_float());                                                          \
      return builder_->CreateF##Op(a, b);                                            \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b));                \
  }

// Instantiate the arithmetic visitors.
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
1447 | |
// Defines CodeGenLLVM::Create{Op} and the matching VisitExpr_ for a
// comparison op: signed/unsigned integer compare for (u)int dtypes,
// ordered float compare for float dtypes. Note the comparison dtype comes
// from the operands (op->a), not from the boolean result dtype.
#define DEFINE_CODEGEN_CMP_OP(Op)                                                    \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      return builder_->CreateICmpS##Op(a, b);                                        \
    } else if (t.is_uint()) {                                                        \
      return builder_->CreateICmpU##Op(a, b);                                        \
    } else {                                                                         \
      ICHECK(t.is_float());                                                          \
      return builder_->CreateFCmpO##Op(a, b);                                        \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b));            \
  }

// Instantiate the comparison visitors.
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
1467 | |
1468 | llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { |
1469 | llvm::Value* a = MakeValue(op->a); |
1470 | llvm::Value* b = MakeValue(op->b); |
1471 | if (op->dtype.is_int()) { |
1472 | return builder_->CreateSDiv(a, b); |
1473 | } else if (op->dtype.is_uint()) { |
1474 | return builder_->CreateUDiv(a, b); |
1475 | } else { |
1476 | ICHECK(op->dtype.is_float()); |
1477 | return builder_->CreateFDiv(a, b); |
1478 | } |
1479 | } |
1480 | |
1481 | llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { |
1482 | llvm::Value* a = MakeValue(op->a); |
1483 | llvm::Value* b = MakeValue(op->b); |
1484 | if (op->dtype.is_int()) { |
1485 | return builder_->CreateSRem(a, b); |
1486 | } else if (op->dtype.is_uint()) { |
1487 | return builder_->CreateURem(a, b); |
1488 | } else { |
1489 | ICHECK(op->dtype.is_float()); |
1490 | return builder_->CreateFRem(a, b); |
1491 | } |
1492 | } |
1493 | |
1494 | llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { |
1495 | llvm::Value* a = MakeValue(op->a); |
1496 | llvm::Value* b = MakeValue(op->b); |
1497 | return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); |
1498 | } |
1499 | |
1500 | llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { |
1501 | llvm::Value* a = MakeValue(op->a); |
1502 | llvm::Value* b = MakeValue(op->b); |
1503 | return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); |
1504 | } |
1505 | |
1506 | llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { |
1507 | llvm::Value* a = MakeValue(op->a); |
1508 | llvm::Value* b = MakeValue(op->b); |
1509 | if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { |
1510 | return builder_->CreateICmpEQ(a, b); |
1511 | } else { |
1512 | return builder_->CreateFCmpOEQ(a, b); |
1513 | } |
1514 | } |
1515 | |
1516 | llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) { |
1517 | llvm::Value* a = MakeValue(op->a); |
1518 | llvm::Value* b = MakeValue(op->b); |
1519 | if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { |
1520 | return builder_->CreateICmpNE(a, b); |
1521 | } else { |
1522 | return builder_->CreateFCmpONE(a, b); |
1523 | } |
1524 | } |
1525 | |
1526 | llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) { |
1527 | return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); |
1528 | } |
1529 | |
1530 | llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) { |
1531 | return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); |
1532 | } |
1533 | |
1534 | llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { |
1535 | return builder_->CreateNot(MakeValue(op->a)); |
1536 | } |
1537 | |
1538 | llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { |
1539 | return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), |
1540 | MakeValue(op->false_value)); |
1541 | } |
1542 | |
// Lower a Let expression: materialize the bound value once, record it so any
// occurrence of the variable in the body reuses the same llvm::Value, then
// emit the body.
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
  // The same Let var may be visited more than once; it must always be
  // bound to a structurally equal value.
  auto it = let_binding_.find(op->var);
  if (it != let_binding_.end()) {
    ICHECK(deep_equal_(it->second->value, op->value))
        << "Let cannot bind the same var to two different values" ;
  } else {
    let_binding_[op->var] = op;
  }
  auto var_value = MakeValue(op->value);
  var_map_[op->var.get()] = var_value;
  // Name the SSA value after the TIR variable for readable IR dumps.
  var_value->setName(op->var->name_hint.c_str());
  // Also inform the arithmetic analyzer so later simplification can use it.
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}
1557 | |
// LoadNode is deprecated in favor of BufferLoadNode; reaching this visitor
// indicates a lowering bug in an earlier pass.
llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
  LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead." ;
}
1561 | |
1562 | bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { |
1563 | const llvm::DataLayout& data_layout = module_->getDataLayout(); |
1564 | int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype)); |
1565 | int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of())); |
1566 | return bytes != bytes_scalar * dtype.lanes(); |
1567 | } |
1568 | |
// Shared worker for BufferLoad/BufferStore codegen.  It computes the address
// of every accessed (sub)element of `buffer` and calls `make_instruction`
// once per emitted access so the caller can generate the actual load/store.
// `subelement_i` passed to the callback is the lane index of the value being
// accessed, or -1 when the access covers the whole value at once;
// `alignment` is the byte alignment that may be assumed for the address.
void CodeGenLLVM::BufferAccessHelper(
    Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
    std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
                                     bool is_volatile)>
        make_instruction) {
  DataType buffer_element_dtype = buffer->dtype;

  ICHECK_GE(indices.size(), 1)
      << "Buffer " << buffer->name << " is accessed with no indices. "
      << "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen." ;

  // Only the last index is allowed to be multi-lane. All earlier
  // indices must be scalar. This only matters for subclasses of
  // CodeGenLLVM, because the default implementation of GetBufferPtr
  // requires 1-d indices.
  std::vector<llvm::Value*> earlier_index_values;
  for (size_t i = 0; i < indices.size() - 1; i++) {
    ICHECK_EQ(indices[i].dtype().lanes(), 1)
        << "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
        << ". Multi-lane indices are only supported as the last index." ;
    earlier_index_values.push_back(MakeValue(indices[i]));
  }

  PrimExpr last_index = indices[indices.size() - 1];
  ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

  // Record index and elemtype in original form used for alias info
  PrimExpr last_index_origin = last_index;
  DataType buffer_element_dtype_origin = buffer_element_dtype;

  bool is_volatile = volatile_buf_.count(buffer->data.get());

  // If the buffer index is a contiguous ramp node, we only need to
  // access the first element, then cast to the value type.
  if (const RampNode* ramp_index = last_index.as<RampNode>()) {
    if (is_one(ramp_index->stride)) {
      last_index = ramp_index->base;
    }
  }

  // All TVM arrays are densely packed. If the vectorized LLVM type
  // contains padding for alignment, we need to index based on the
  // size of the scalar type to avoid introducing that padding.
  if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
    last_index = buffer_element_dtype.lanes() * last_index;
    buffer_element_dtype = buffer_element_dtype.element_of();
  }

  int alignment;
  if (last_index.dtype().lanes() == 1) {
    // If we are accessing with a single index, then the vectorized
    // element being accessed may require more alignment than the
    // underlying data type.
    int native_bits;
    GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
  } else {
    // Otherwise, alignment is based on the return value's scalar
    // type.
    ICHECK_GE(value_dtype.bits(), 8);
    alignment = value_dtype.bits() / 8;
  }

  // Emit one access per lane of the (possibly vectorized) last index.
  // A contiguous ramp was collapsed above, so this loop runs once for
  // contiguous or scalar accesses.
  llvm::Value* cached_vector_index = nullptr;
  for (int i = 0; i < last_index.dtype().lanes(); ++i) {
    llvm::Value* last_index_value;
    int subelement_i = i;
    if (const RampNode* ramp = last_index.as<RampNode>()) {
      // Strided ramp: compute each lane's index symbolically.
      PrimExpr offset = ramp->base + (ramp->stride * i);
      last_index_value = MakeValue(offset);
    } else if (last_index.dtype().lanes() > 1) {
      // General vector index: evaluate once, then extract each lane.
      if (i == 0) {
        cached_vector_index = MakeValue(last_index);
      }
      last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
    } else {
      // Scalar index: the whole value is accessed at once.
      last_index_value = MakeValue(last_index);
      subelement_i = -1;
    }

    std::vector<llvm::Value*> all_index_values = earlier_index_values;
    all_index_values.push_back(last_index_value);

    TypedPointer buffer_ptr =
        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
                        value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
    auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
    // Alias info uses the original (pre-transformation) index and dtype.
    AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
  }
}
1658 | |
// Lower a BufferLoad by delegating address computation to
// BufferAccessHelper and emitting one aligned load per generated access.
// When the helper splits the access into multiple scalar loads, the results
// are reassembled into a vector with insertelement.
llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
  DataType value_dtype = op->dtype;

  std::vector<llvm::Value*> loads;

  // Callback invoked by BufferAccessHelper for each computed address.
  // The CreateAlignedLoad signature changed across LLVM versions, hence
  // the preprocessor branches.
  auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment,
                                  bool is_volatile) {
#if TVM_LLVM_VERSION >= 110
    auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
                                            llvm::Align(alignment), is_volatile);
#elif TVM_LLVM_VERSION >= 80
    auto load =
        builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile);
#else
    auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
#endif

    loads.push_back(load);
    return load;
  };

  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
  // non-flat indices will result in an error in CreateBufferPtr, but
  // a subclass may override CreateBufferPtr.
  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load);

  if (loads.size() == 1) {
    // Single access: either a scalar load or one vector-wide load.
    return loads[0];
  } else {
    // Multiple scalar loads: pack them lane by lane into a vector.
    llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype));
    for (size_t i = 0; i < loads.size(); i++) {
      ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i));
    }
    return ret;
  }
}
1695 | |
// Lower a call: extern calls (explicit call_extern/call_pure_extern, or ops
// annotated with a global symbol) become function calls; all other ops are
// dispatched to CreateIntrinsic.  Cross-function calls to GlobalVars are not
// supported by this backend.
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
  if (auto* ptr_op = op->op.as<OpNode>()) {
    auto call_op = GetRef<Op>(ptr_op);
    if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
      // call extern intrinsic
      // args[0] is the symbol name; the remaining args are the call operands.
      ICHECK_GE(op->args.size(), 1U);
      auto global_symbol = Downcast<StringImm>(op->args[0]);
      return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), global_symbol->value, op->args,
                                    true);
    } else if (op_attr_global_symbol_.count(call_op)) {
      // call extern if the op itself have a global symbol.
      return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
                                    op->args, false);
    } else {
      VLOG(2) << "CreateIntrinsic: " << GetRef<Call>(op);
      auto x = CreateIntrinsic(op);
      VLOG(2) << "CreateIntrinsic done" ;
      return x;
    }
  } else {
    ICHECK(op->op.as<GlobalVarNode>());
    LOG(FATAL) << "Do not yet support cross function call" ;
  }
}
1720 | |
1721 | llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { |
1722 | llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); |
1723 | for (int i = 0; i < op->lanes; ++i) { |
1724 | vec = builder_->CreateInsertElement( |
1725 | vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); |
1726 | } |
1727 | return vec; |
1728 | } |
1729 | |
1730 | llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { |
1731 | std::vector<llvm::Value*> vecs(op->vectors.size()); |
1732 | int total_lanes = 0; |
1733 | for (int i = 0, e = op->vectors.size(); i < e; ++i) { |
1734 | vecs[i] = VisitExpr(op->vectors[i]); |
1735 | total_lanes += op->vectors[i].dtype().lanes(); |
1736 | } |
1737 | llvm::Value* v0 = CreateVecConcat(vecs); |
1738 | std::vector<uint32_t> idx(op->indices.size()); |
1739 | for (int i = 0, e = op->indices.size(); i < e; ++i) { |
1740 | const int64_t* val = as_const_int(op->indices[i]); |
1741 | ICHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " |
1742 | << "but get " << op->indices[i] << "\n" ; |
1743 | idx[i] = *val; |
1744 | } |
1745 | llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx); |
1746 | auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask); |
1747 | // If the output is a single-element vector, convert it back to a scalar. |
1748 | if (idx.size() == 1) { |
1749 | res = builder_->CreateExtractElement(res, ConstInt32(0)); |
1750 | } |
1751 | return res; |
1752 | } |
1753 | |
1754 | llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { |
1755 | return CreateBroadcast(MakeValue(op->value), op->lanes); |
1756 | } |
1757 | |
// StoreNode is deprecated in favor of BufferStoreNode; reaching this visitor
// indicates a lowering bug in an earlier pass.
void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead." ;
}
1761 | |
// Lower a BufferStore by delegating address computation to
// BufferAccessHelper and emitting one aligned store per generated access.
// When the helper splits a vector value into scalar accesses, the lane to
// store is extracted from the value via subelement_i.
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
  EmitDebugLocation(op);
  DataType value_dtype = op->value.dtype();
  Var buffer_var = op->buffer->data;

  // Evaluate the stored value once, before any address computation.
  llvm::Value* value = MakeValue(op->value);

  // Callback invoked by BufferAccessHelper for each computed address.
  // subelement_i == -1 means the whole value is stored at once.
  auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment,
                                  bool is_volatile) {
    llvm::Value* to_store = value;
    if (subelement_i != -1) {
      to_store = builder_->CreateExtractElement(value, subelement_i);
    }
    // CreateAlignedStore's signature changed in LLVM 11 (llvm::Align).
#if TVM_LLVM_VERSION >= 110
    return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment),
                                        is_volatile);
#else
    return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile);
#endif
  };

  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
  // non-flat indices will result in an error in CreateBufferPtr, but
  // a subclass may override CreateBufferPtr.
  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store);
}
1788 | |
// Lower a For loop.  Loops are expected to be normalized to min == 0 by
// earlier passes; only serial loops are lowered here (unroll hints fall back
// to a serial loop with a warning, other kinds are rejected).
void CodeGenLLVM::VisitStmt_(const ForNode* op) {
  EmitDebugLocation(op);
  ICHECK(is_zero(op->min));
  // Register the loop var's range with the analyzer for simplification.
  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  if (op->kind == ForKind::kUnrolled) {
    LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
                 << " consider set unroll_explicit=True" ;
  } else {
    ICHECK(op->kind == ForKind::kSerial);
  }
  // Emit a serial loop with a constant step of 1 in the extent's type.
  CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
                  llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
}
1802 | |
1803 | void CodeGenLLVM::VisitStmt_(const WhileNode* op) { |
1804 | EmitDebugLocation(op); |
1805 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
1806 | auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond" , function_); |
1807 | auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body" , function_); |
1808 | auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge" , function_); |
1809 | builder_->CreateBr(while_cond); |
1810 | builder_->SetInsertPoint(while_cond); |
1811 | builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); |
1812 | builder_->SetInsertPoint(while_body); |
1813 | this->VisitStmt(op->body); |
1814 | builder_->CreateBr(while_cond); |
1815 | builder_->SetInsertPoint(while_merge); |
1816 | } |
1817 | |
// Lower an if/then(/else) statement into basic blocks.  When there is no
// else branch the conditional branch carries "very likely" metadata, biasing
// the optimizer toward the then-path (the common pattern for guard checks).
void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
  EmitDebugLocation(op);
  llvm::Value* cond = MakeValue(op->condition);
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then" , function_);
  auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end" , function_);
  if (op->else_case) {
    auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else" , function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case.value());
    builder_->CreateBr(end_block);
  } else {
    // No else branch: annotate the branch as very likely taken.
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  // Continue emission after the merge point.
  builder_->SetInsertPoint(end_block);
}
1841 | |
// Lower AllocateConst: embed the constant NDArray data as an internal,
// read-only global variable in the module and bind the buffer var to it for
// the duration of the body.
void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
  EmitDebugLocation(op);
  auto data = op->data.value();
  // Convert the NDArray contents into an LLVM constant initializer.
  auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data);
  std::string symbol_name = op->buffer_var->name_hint;
  // isConstant=true, InternalLinkage: read-only and private to this module.
  llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
      *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);

  var_map_[op->buffer_var.operator->()] = param_symbol;
  this->VisitStmt(op->body);
}
1853 | |
// Lower Allocate as a stack allocation (alloca) in the function entry block.
// Only constant-size, flat 1-d allocations are supported; the alloca's
// alignment is raised to the requested/derived alignment (capped at 16) and
// the resulting value recorded back into alloc_storage_info_.
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
  EmitDebugLocation(op);
  ICHECK_EQ(op->extents.size(), 1)
      << "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
      << op->buffer_var->name_hint << " is " << op->extents << "-d" ;

  ICHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;

  int32_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation" ;
  StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
  // If no alignment was requested, derive one from the dtype/size.
  if (constant_size % 4 == 0 && info.alignment == 0) {
    info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
  }
  // maximum necessary alignment in the NV devices
  if (info.alignment > 16) {
    info.alignment = 16;
  }
  // Place the alloca in the entry block so it is a static stack slot.
  llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
    return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
  });
  // The alignment accessors/mutators changed across LLVM versions.
#if TVM_LLVM_VERSION >= 110
  auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
  unsigned alignment = alloca->getAlignment();
#endif
  // Only raise the alloca's alignment, never lower it.
  if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
    alloca->setAlignment(llvm::Align(info.alignment));
#else
    alloca->setAlignment(info.alignment);
#endif
  }
  // Record the alignment LLVM actually chose.
#if TVM_LLVM_VERSION >= 110
  info.alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
  info.alignment = alloca->getAlignment();
#endif

  buf = alloca;

  // Cast to a pointer of the element type in the alloca's address space.
  buf = builder_->CreatePointerCast(
      buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
  buf->setName(op->buffer_var->name_hint.c_str());

  ICHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1904 | |
// Handle attribute statements that affect codegen state, then visit the
// body.  Recognized keys: thread_extent (bind GPU thread index vars),
// storage_alignment (record and assert buffer alignment), and
// volatile_scope (mark a buffer's accesses as volatile).
void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
  EmitDebugLocation(op);
  if (op->attr_key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag.length() != 0) {
      // Bind the thread index variable only once per function.
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv);
        analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
      }
    }
  } else if (op->attr_key == tir::attr::storage_alignment) {
    const VarNode* v = op->node.as<VarNode>();
    ICHECK(v);
    alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
    // If the pointer value already exists, emit an llvm.assume of alignment
    // so the optimizer can exploit it.
    if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
      builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
                                          alloc_storage_info_[v].alignment);
    }
  } else if (op->attr_key == tir::attr::volatile_scope) {
    const VarNode* v = op->node.as<VarNode>();
    ICHECK(v);
    // BufferAccessHelper consults this set to emit volatile loads/stores.
    volatile_buf_.insert(v);
  }
  this->VisitStmt(op->body);
}
1930 | |
// Lower AssertStmt: no runtime check is emitted here; the condition is only
// pushed as a constraint into the arithmetic analyzer (scoped to the body)
// so simplification inside the body can assume it holds.
void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
  EmitDebugLocation(op);
  // Scoped constraint: popped automatically when cctx is destroyed.
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
1937 | |
// Lower LetStmt: evaluate the bound value once, record it for uses of the
// variable in the body, and propagate alignment/aliasing knowledge.
void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
  EmitDebugLocation(op);
  const VarNode* v = op->var.get();
  ICHECK(!var_map_.count(v));
  // Handle (pointer) vars may alias other buffers unless the function is
  // marked restricted.
  if (v->dtype.is_handle()) {
    if (!is_restricted_) {
      alias_var_set_.insert(v);
    }
  }
  llvm::Value* value = MakeValue(op->value);
  // Name the SSA value after the TIR variable for readable IR dumps.
  value->setName(v->name_hint.c_str());
  var_map_[v] = value;
  analyzer_->Bind(op->var, op->value);
  // If the variable has a known alignment, emit an llvm.assume so the
  // optimizer can exploit it.
  if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
    builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
                                        alloc_storage_info_[v].alignment);
  }
  this->VisitStmt(op->body);
}
1957 | |
1958 | void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { |
1959 | EmitDebugLocation(op); |
1960 | for (Stmt stmt : op->seq) { |
1961 | this->VisitStmt(stmt); |
1962 | } |
1963 | } |
1964 | |
// Lower Evaluate: emit the expression for its side effects and discard the
// resulting value.
void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
  EmitDebugLocation(op);
  MakeValue(op->value);
}
1969 | |
// Attach a source-level debug location (line/column within the current
// DISubprogram) to subsequently emitted instructions.  A no-op when debug
// info was not set up or when the span is undefined.
void CodeGenLLVM::EmitDebugLocation(const Span& span) {
#if TVM_LLVM_VERSION >= 50
  if (di_subprogram_ == nullptr) {
    // debug info is not always generated outside of CPU codegen
    return;
  }
  if (!span.defined()) {
    VLOG(0) << "Cannot emit debug location for undefined span" ;
    return;
  }
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_));
  builder_->SetCurrentDebugLocation(loc);
#endif
}
1985 | |
// Clear the current debug location (for instructions with no source span).
void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); }
// Convenience overload: take the location from a statement's span.
void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); }
1988 | |
1989 | } // namespace codegen |
1990 | } // namespace tvm |
1991 | |
1992 | #endif // TVM_LLVM_VERSION |
1993 | |