1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_llvm.cc
22 */
23#ifdef TVM_LLVM_VERSION
24// Part of the code are adapted from Halide's CodeGen_LLVM
25#include "codegen_llvm.h"
26
27#include <llvm/ADT/ArrayRef.h>
28#include <llvm/ADT/SmallVector.h>
29#include <llvm/ADT/StringRef.h>
30#include <llvm/ADT/Triple.h>
31#include <llvm/Analysis/TargetTransformInfo.h>
32#if TVM_LLVM_VERSION >= 50
33#include <llvm/BinaryFormat/Dwarf.h>
34#else
35#include <llvm/Support/Dwarf.h>
36#endif
37#if TVM_LLVM_VERSION >= 60
38#include <llvm/CodeGen/TargetSubtargetInfo.h>
39#else
40#include <llvm/Target/TargetSubtargetInfo.h>
41#endif
42#include <llvm/IR/Argument.h>
43#include <llvm/IR/Attributes.h>
44#include <llvm/IR/BasicBlock.h>
45#include <llvm/IR/CallingConv.h>
46#include <llvm/IR/Constants.h>
47#include <llvm/IR/DIBuilder.h>
48#include <llvm/IR/DataLayout.h>
49#include <llvm/IR/DebugInfoMetadata.h>
50#include <llvm/IR/DerivedTypes.h>
51#if TVM_LLVM_VERSION >= 150
52#include <llvm/IR/FMF.h>
53#else
54#include <llvm/IR/Operator.h>
55#endif
56#include <llvm/IR/Function.h>
57#include <llvm/IR/GlobalVariable.h>
58#include <llvm/IR/Instructions.h>
59#include <llvm/IR/Intrinsics.h>
60#include <llvm/IR/LLVMContext.h>
61#include <llvm/IR/MDBuilder.h>
62#include <llvm/IR/Metadata.h>
63#include <llvm/IR/Module.h>
64#include <llvm/IR/Type.h>
65#include <llvm/IRReader/IRReader.h>
66#include <llvm/Linker/Linker.h>
67#include <llvm/Pass.h>
68#if TVM_LLVM_VERSION >= 160
69#include <llvm/IR/Verifier.h> // For VerifierPass
70#include <llvm/Passes/PassBuilder.h>
71#include <llvm/Passes/StandardInstrumentations.h>
72#else
73#include <llvm/IR/LegacyPassManager.h>
74#include <llvm/Transforms/IPO/PassManagerBuilder.h>
75#endif
76#if TVM_LLVM_VERSION >= 100
77#include <llvm/Support/Alignment.h>
78#include <llvm/Support/TypeSize.h>
79#endif
80#include <llvm/Support/CodeGen.h>
81#include <llvm/Support/MemoryBuffer.h>
82#include <llvm/Support/SourceMgr.h>
83#include <llvm/Target/TargetMachine.h>
84#include <llvm/Transforms/IPO.h>
85#include <llvm/Transforms/Utils/ModuleUtils.h>
86#include <tvm/runtime/c_runtime_api.h>
87#include <tvm/runtime/crt/error_codes.h>
88#include <tvm/runtime/device_api.h>
89#include <tvm/tir/op.h>
90
91#include <algorithm>
92#include <functional>
93#include <memory>
94#include <sstream>
95#include <string>
96#include <utility>
97#include <vector>
98
99#include "../../arith/pattern_match.h"
100#include "../build_common.h"
101#include "../func_registry_generator.h"
102#include "codegen_params.h"
103#include "llvm_instance.h"
104
105namespace tvm {
106namespace codegen {
107
// CodeGenLLVM has members of type std::unique_ptr<T>. These members will be
// instantiated in the constructor, which will require that the type T is
// complete at that point. Put the constructor (and destructor) here, since
// all types should be complete here.
CodeGenLLVM::CodeGenLLVM() = default;
CodeGenLLVM::~CodeGenLLVM() = default;
CodeGenLLVM::DebugInfo::~DebugInfo() = default;
115
116std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(LLVMTarget* llvm_target) {
117 std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName();
118 std::string factory_template = "tvm.codegen.llvm.target_";
119 void* handle = nullptr;
120 if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) {
121 handle = (*f)();
122 } else if (const PackedFunc* f = runtime::Registry::Get(factory_template + "cpu")) {
123 handle = (*f)();
124 } else {
125 LOG(FATAL) << "no factory function for codegen for target " << target;
126 }
127 if (handle) {
128 return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
129 } else {
130 LOG(FATAL) << "unable to create codegen for target " << target;
131 }
132}
133
// Initialize the codegen state for a fresh module: create the LLVM module,
// IR builder and metadata builder, cache commonly used LLVM types, and set
// up TBAA roots and branch-weight metadata before configuring the target.
// NOTE(review): system_lib/dynamic_lookup/target_c_runtime are unused here —
// presumably consumed by subclass overrides; confirm against derived codegens.
void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib,
                       bool dynamic_lookup, bool target_c_runtime) {
  llvm_target_ = llvm_target;
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  builder_.reset(new IRBuilder(*ctx));
  module_.reset(new llvm::Module(module_name, *ctx));
  md_builder_.reset(new llvm::MDBuilder(*ctx));
  // types
  t_void_ = llvm::Type::getVoidTy(*ctx);
  // void* is modeled as i8* in the module's global address space.
  t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace());
  t_int_ = llvm::Type::getInt32Ty(*ctx);
  t_char_ = llvm::Type::getInt8Ty(*ctx);
  t_int8_ = llvm::Type::getInt8Ty(*ctx);
  t_int16_ = llvm::Type::getInt16Ty(*ctx);
  t_int32_ = llvm::Type::getInt32Ty(*ctx);
  t_int64_ = llvm::Type::getInt64Ty(*ctx);
  t_float64_ = llvm::Type::getDoubleTy(*ctx);
  // meta data
  // Heavily biased branch weights (2^20 : 1) for branches we expect taken.
  md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  InitTarget();
}
157
158void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); }
159
// Configure the module for the selected target: triple and data layout,
// a per-architecture default for the native vector width, and (on x86 with
// LLVM >= 15) detection of the float16 calling convention.
void CodeGenLLVM::InitTarget() {
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  // Only pick a default if no subclass/caller has set the width already.
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      // Unknown architecture: fall back to a conservative width and warn.
      native_vector_bits_ = 128;
      std::string arch_name = std::string(tm->getTargetTriple().getArchName());
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }

#if TVM_LLVM_VERSION >= 60
  bool use_float16_abi = false;
#if TVM_LLVM_VERSION >= 150
  // For conversions between _Float16 and float, LLVM uses runtime functions
  // __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used
  // "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision
  // values can be passed in XMM registers (i.e. as floating-point). This happens
  // when the compilation target has SSE2 enabled (either directly, or by enabling
  // a feature that implies SSE2).
  // Because the names of the conversion functions remain unchanged, it is impossible
  // for TVM to provide them in the runtime, and have them work in both cases.
  // To alleviate this issue, emit these functions directly into the target module
  // after detecting whether or not to use floating-point ABI. To allow the linker
  // to remove potential duplicates (or if they are unused), they are weak and
  // reside in a separate section (ELF).
  llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch();
  if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) {
    // Detect if SSE2 is enabled. This determines whether float16 ABI is used.
    // Feature strings are resolved per-function, so build a tiny throwaway
    // module carrying the same CPU/feature attributes and query its subtarget.
    std::stringstream os;
    const char fname[] = "test_sse2";
    os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n"
       << "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\""
       << llvm_target_->GetCPU() << "\" ";
    if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) {
      os << "\"target-features\"=\"" << fs << "\" ";
    }
    os << "}\n";
    auto mod = llvm_target_->GetInstance().ParseIR(os.str());
    auto* test_sse2 = mod->getFunction(fname);
    ICHECK(test_sse2 != nullptr) << "Module creation error";
    use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2");
  }
#endif  // TVM_LLVM_VERSION >= 150

  // Call this function only with LLVM >= 6.0. The code it emits uses "dso_local"
  // which was introduced in LLVM 6.
  EmitFloat16ConversionBuiltins(use_float16_abi);
#endif  // TVM_LLVM_VERSION >= 60
}
220
221void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
222
223void CodeGenLLVM::InitFuncState() {
224 var_map_.clear();
225 alias_var_set_.clear();
226 alloc_storage_info_.clear();
227 volatile_buf_.clear();
228 analyzer_.reset(new arith::Analyzer());
229}
230
231void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
232 this->InitFuncState();
233
234 ICHECK_EQ(f->buffer_map.size(), 0U)
235 << "Cannot codegen function with buffer_map, please lower them first";
236
237 std::vector<llvm::Type*> param_types;
238 is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
239 for (Var param : f->params) {
240 param_types.push_back(GetLLVMType(param));
241 if (!is_restricted_ && param.dtype().is_handle()) {
242 alias_var_set_.insert(param.get());
243 }
244 }
245 // TODO(tvm-team):
246 // Update the function type to respect the ret_type field of f.
247 // Once we allow more flexibility in the PrimFunc.
248 llvm::FunctionType* ftype =
249 llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
250
251 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
252 ICHECK(global_symbol.defined())
253 << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
254 function_ = module_->getFunction(MakeStringRef(global_symbol.value()));
255 if (function_ == nullptr) {
256 function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
257 MakeStringRef(global_symbol.value()), module_.get());
258 }
259 function_->setCallingConv(llvm::CallingConv::C);
260 function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
261 SetTargetAttributes(function_);
262
263 // set var map and align information
264 auto arg_it = function_->arg_begin();
265 for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
266 llvm::Argument* v = &(*arg_it);
267 const Var& var = f->params[i];
268 var_map_[var.get()] = v;
269 v->setName(std::string(var->name_hint));
270 if (is_restricted_) {
271 if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
272 // set non alias.
273#if TVM_LLVM_VERSION >= 50
274 function_->addParamAttr(i, llvm::Attribute::NoAlias);
275#else
276 function_->setDoesNotAlias(i + 1);
277#endif
278 }
279 }
280 }
281 llvm::LLVMContext* ctx = llvm_target_->GetContext();
282 llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
283 builder_->SetInsertPoint(entry);
284 this->VisitStmt(f->body);
285
286 // Add alignment attribute if needed.
287#if TVM_LLVM_VERSION >= 50
288 for (size_t i = 0; i < f->params.size(); ++i) {
289 const Var& var = f->params[i];
290 auto f = alloc_storage_info_.find(var.get());
291 if (f != alloc_storage_info_.end()) {
292 unsigned align = f->second.alignment;
293 if (align > 1) {
294 auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align);
295 function_->addParamAttr(i, attr);
296 }
297 }
298 }
299#endif
300
301 EmitDebugLocation(f->span);
302 if (ret_void) {
303 builder_->CreateRetVoid();
304 } else {
305 builder_->CreateRet(ConstInt32(0));
306 }
307}
308
309std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
310 this->AddStartupFunction();
311 for (size_t i = 0; i < link_modules_.size(); ++i) {
312 ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
313 << "Failed to link modules";
314 }
315 link_modules_.clear();
316 // optimize
317 this->Optimize();
318 return std::move(module_);
319}
320
321void CodeGenLLVM::HandleImport(const std::string& code) {
322 llvm::StringRef code_str(code);
323 std::unique_ptr<llvm::Module> mlib;
324 if (code_str.endswith(".ll") || code_str.endswith(".bc")) {
325 mlib = llvm_target_->GetInstance().LoadIR(code);
326 } else {
327 mlib = llvm_target_->GetInstance().ParseIR(code);
328 }
329
330 mlib->setTargetTriple(llvm_target_->GetTargetTriple());
331 mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout());
332 // mark all the functions as force inline
333 for (llvm::Function& f : mlib->functions()) {
334 f.removeFnAttr(llvm::Attribute::NoInline);
335 f.addFnAttr(llvm::Attribute::AlwaysInline);
336 f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
337 }
338 // add to linker libraries.
339 this->AddLinkModule(std::move(mlib));
340}
341
// Queue a module to be linked into the main module when Finish() runs.
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.emplace_back(std::move(mod));
}
345
// Not supported by the base codegen; subclasses that need an entry-point
// wrapper must override this.
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
  LOG(FATAL) << "not implemented";
}
349
350llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; }
351
352llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { LOG(FATAL) << "not implemented"; }
353
354#if TVM_LLVM_VERSION >= 160
355
356// Use new pass manager
357
// Run the standard LLVM optimization pipeline over the module using the new
// pass manager (LLVM >= 16). The pipeline level mirrors the TargetMachine's
// codegen optimization level.
void CodeGenLLVM::Optimize() {
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();

  // Debug knobs: pass-execution logging and per-pass IR verification.
  bool debug_logging = false;
  bool verify_each = false;

  llvm::PipelineTuningOptions pto = llvm::PipelineTuningOptions();
  llvm::PassInstrumentationCallbacks pic;
  llvm::PassBuilder builder(tm, pto, std::nullopt, &pic);

  // The four analysis managers must be registered and cross-wired before
  // any pipeline can be parsed/run.
  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager cgam;
  llvm::ModuleAnalysisManager mam;
  builder.registerLoopAnalyses(lam);
  builder.registerFunctionAnalyses(fam);
  builder.registerCGSCCAnalyses(cgam);
  builder.registerModuleAnalyses(mam);
  builder.crossRegisterProxies(lam, fam, cgam, mam);

  // Construct the default pass pipeline depending on the opt level.
  std::string pipeline;
  switch (llvm_target_->GetOptLevel()) {
    case llvm::CodeGenOpt::Level::None:
      pipeline = "default<O0>";
      break;
    case llvm::CodeGenOpt::Level::Less:
      pipeline = "default<O1>";
      break;
    case llvm::CodeGenOpt::Level::Default:
      pipeline = "default<O2>";
      break;
    default:
      // CodeGenOpt::Level::Aggressive
      pipeline = "default<O3>";
      break;
  }

  llvm::StandardInstrumentations si(*llvm_target_->GetContext(), debug_logging, verify_each);
  si.registerCallbacks(pic, &fam);
  llvm::ModulePassManager mpass;
  if (verify_each) {
    mpass.addPass(llvm::VerifierPass());
  }
  if (auto err = builder.parsePassPipeline(mpass, pipeline)) {
    LOG(FATAL) << "error parsing pass pipeline '" << pipeline
               << "':" << llvm::toString(std::move(err)) << '\n';
  }

  mpass.run(*module_, mam);
}
409
410#else // TVM_LLVM_VERSION
411
// Thin wrapper around the legacy function pass manager; exists so that
// subclasses/debugging hooks can intercept pass registration.
class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
};
418
// Thin wrapper around the legacy module pass manager; see FPassManager.
class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
};
424
425void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {}
426
// Run the standard LLVM optimization pipeline over the module using the
// legacy pass manager (LLVM < 16), mirroring the TargetMachine's opt level.
void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
  // Supply target-specific cost models to both pass managers.
  mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;

  // Use the same opt-level as specified in TargetMachine for running passes
  llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel();

  switch (opt_level) {
    case llvm::CodeGenOpt::Level::None:
      builder.OptLevel = 0;
      break;
    case llvm::CodeGenOpt::Level::Less:
      builder.OptLevel = 1;
      break;

    case llvm::CodeGenOpt::Level::Default:
      builder.OptLevel = 2;
      break;

    default:
      // CodeGenOpt::Level::Aggressive
      builder.OptLevel = 3;
  }

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  // Give subclasses a chance to tweak the builder before populating.
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  tm->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  // Function passes run per-function, then module passes over the whole module.
  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}
481#endif // TVM_LLVM_VERSION
482
// Native vector width in bits; the base codegen ignores the storage scope
// (device codegens may differentiate by scope).
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}
486
487unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; }
488
489llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
490 if (dtype.is_handle()) {
491 ICHECK_EQ(dtype.lanes(), 1);
492 return t_void_p_;
493 }
494 if (dtype.is_void()) {
495 return t_void_;
496 }
497 llvm::Type* etype = nullptr;
498 llvm::LLVMContext* ctx = llvm_target_->GetContext();
499 if (dtype.is_int() || dtype.is_uint()) {
500 etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
501 } else if (dtype.is_float()) {
502 switch (dtype.bits()) {
503 case 16:
504 etype = llvm::Type::getHalfTy(*ctx);
505 break;
506 case 32:
507 etype = llvm::Type::getFloatTy(*ctx);
508 break;
509 case 64:
510 etype = llvm::Type::getDoubleTy(*ctx);
511 break;
512 default:
513 LOG(FATAL) << "do not support " << dtype;
514 }
515 }
516 if (dtype.lanes() != 1) {
517#if TVM_LLVM_VERSION >= 110
518 return llvm::FixedVectorType::get(etype, dtype.lanes());
519#else
520 return llvm::VectorType::get(etype, dtype.lanes());
521#endif
522 } else {
523 return etype;
524 }
525} // namespace codegen
526
// Map a TVM IR Type to an LLVM type. PrimTypes go through DTypeToLLVMType;
// pointer types recurse on their element type.
llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
  if (auto* ptr = type.as<PrimTypeNode>()) {
    return DTypeToLLVMType(ptr->dtype);
  } else if (auto* ptr = type.as<PointerTypeNode>()) {
    // LLVM IR doesn't allow void*, so we need to recognize this
    // pattern explicitly.
    if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
      if (primtype->dtype.is_void()) {
        return t_void_p_;
      }
    }
    // TODO(tvm-team) consider put storage scope into the pointer type.
    return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace());
  } else if (IsVoidType(type)) {
    return t_void_;
  } else {
    LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type";
  }
}
546
// Convenience overload: LLVM type of an expression's inferred TVM type.
llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
  return GetLLVMType(GetType(expr));
}
550
551// Add tbaa alias information for load
552//
553// use a binary tree typed system to declare information
554// and allow alias to be distinguished across nodes.
555//
556// This trick comes from Halide's CodeGen_LLVM
557//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index,
                               DataType access_dtype) {
  // Buffers that may alias all share one TBAA node, so no access through
  // them is ever disambiguated from another.
  if (alias_var_set_.count(buffer_var) != 0) {
    // Mark all possibly aliased pointer as same type.
    llvm::MDNode* meta = md_tbaa_alias_set_;
    inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
    return;
  }

  int64_t base = 0, width = 0;
  arith::PVar<IntImm> pbase, pstride;
  arith::PVar<int> planes;
  // create meta-data for alias analysis
  // Use a group of binary tree ranges of memory banks.
  // xwith: extent of the access in elements (0 if the index is not a
  // constant or a constant-base ramp, in which case no range info is added).
  int64_t xwith = 0;
  if (arith::ramp(pbase, pstride, planes).Match(index)) {
    base = pbase.Eval()->value;
    xwith = planes.Eval() * pstride.Eval()->value;
  } else if (auto* ptr = index.as<tir::IntImmNode>()) {
    base = ptr->value;
    xwith = 1;
  }
  // adjust address index unit to byte
  const int64_t unit_bit_width = 8;
  const int64_t access_elem_bits = access_dtype.bits() * access_dtype.lanes();
  base = base * access_elem_bits / unit_bit_width;
  // Round the byte extent up.
  xwith = (xwith * access_elem_bits + unit_bit_width - 1) / unit_bit_width;
  if (xwith > 0) {
    // Grow width to the next power of two covering the extent, then widen
    // further until base is width-aligned, so [base, base+width) nests in
    // the power-of-two bank hierarchy below.
    width = 1;
    while (width < xwith) {
      width *= 2;
    }
    while (base % width) {
      base -= base % width;
      width *= 2;
    }
  }

  llvm::MDNode* meta = md_tbaa_root_;
  std::ostringstream buffer_addr;
  buffer_addr << buffer_var;
  meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);

  // create a tree-shape access structure.
  // Each level halves the bank size; accesses in disjoint banks get
  // distinct TBAA nodes and are thus known not to alias.
  if (width != 0) {
    for (int64_t w = 1024; w >= width; w /= 2) {
      int64_t b = (base / w) * w;
      std::stringstream os;
      os << buffer_var << ".w" << w << ".b" << b;
      meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
    }
  }
  inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
612
// Compute the provable alignment (in bytes) of the access buf_var[index]
// with element type t, and the native vector width for the buffer's scope.
// The alignment is derived from the modular set of the index: each factor
// of two shared by base and coefficient doubles the provable alignment,
// capped by the buffer's declared allocation alignment.
void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
                               int* p_alignment, int* p_native_bits) {
  int max_align_bits = t.bits();
  auto it = alloc_storage_info_.find(buf_var);
  if (it != alloc_storage_info_.end()) {
    const StorageInfo& info = it->second;
    *p_native_bits =
        NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef<Var>(buf_var))));
    max_align_bits = info.alignment * 8;
  } else {
    // No allocation record (e.g. external buffer): use the default width.
    *p_native_bits = native_vector_bits_;
  }

  // index is congruent to base modulo coeff.
  arith::ModularSet me = analyzer_->modular_set(index);
  int64_t base = me->base;
  int64_t coeff = me->coeff;

  int align_bits = t.bits();
  while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
    base = base / 2;
    coeff = coeff / 2;
    align_bits *= 2;
  }
  // Never report less than byte alignment.
  if (align_bits < 8) {
    align_bits = 8;
  }
  *p_alignment = align_bits / 8;
}
641
// Create a module-level array global named "shmem" in the given address
// space with the requested alignment; used for shared-memory allocations.
llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size,
                                                        unsigned int shared_address_space,
                                                        int alignment,
                                                        llvm::GlobalValue::LinkageTypes linkage) {
  llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
  llvm::GlobalVariable* global =
      new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
                               llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
  global->setAlignment(llvm::Align(alignment));
#else
  global->setAlignment(alignment);
#endif
  return global;
}
657
// Build a DWARF debug-info scaffold (DIBuilder, file, compile unit) for the
// module, using a placeholder "main.tir" file name.
std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
  auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
#else
  // llvm::make_unique was removed in LLVM 10 in favor of std::make_unique.
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
  // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
  debug_info->file_ = debug_info->di_builder_->createFile("main.tir", ".");
  const int runtime_version = 0;
  const bool is_optimized = false;
  const char* compiler_flags = "";
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      /*Lang=*/llvm::dwarf::DW_LANG_C, /*File=*/debug_info->file_, /*Producer=*/"TVM", is_optimized,
      compiler_flags, runtime_version);
  return debug_info;
}
676
// Broadcast a scalar into a vector of `lanes` elements: insert the scalar
// into lane 0 of an undef vector, then shuffle with an all-zero mask so
// every lane reads lane 0.
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
#if TVM_LLVM_VERSION >= 110
  llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
#else
  llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
#endif
  llvm::Constant* undef = llvm::UndefValue::get(type);
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
  // getSplat's signature changed across LLVM versions; all three forms
  // build a constant vector of zeros of length `lanes`.
#if TVM_LLVM_VERSION >= 120
  llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
#elif TVM_LLVM_VERSION >= 110
  llvm::Constant* mask =
      llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
#else
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
#endif
  return builder_->CreateShuffleVector(value, undef, mask);
}
696
697llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
698 int num_elems = GetVectorNumElements(vec);
699 if (extent == num_elems && begin == 0) return vec;
700 ICHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
701 std::vector<llvm::Constant*> indices;
702 indices.reserve(extent);
703 for (int i = 0; i < extent; ++i) {
704 if (begin + i >= 0 && begin + i < num_elems) {
705 indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
706 } else {
707 indices.push_back(llvm::UndefValue::get(t_int32_));
708 }
709 }
710 return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
711}
712
713llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
714 int num_elems = GetVectorNumElements(vec);
715#if TVM_LLVM_VERSION >= 110
716 std::vector<int> indices;
717#else
718 std::vector<unsigned> indices;
719#endif
720 for (int i = 0; i < num_elems; ++i) {
721 indices.push_back(num_elems - i - 1);
722 }
723 return builder_->CreateShuffleVector(vec, vec, indices);
724}
725
726llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
727 llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
728 int num_elems = GetVectorNumElements(vec);
729 if (num_elems == target_lanes) return vec;
730 ICHECK_LT(num_elems, target_lanes);
731 for (int i = 0; i < num_elems; ++i) {
732 mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
733 }
734 return builder_->CreateShuffleVector(vec, vec, mask);
735}
736
// Concatenate a list of vectors (scalars are first promoted to 1-lane
// vectors) into one vector of the combined lane count, using a tree of
// pairwise shuffles for O(log n) shuffle depth.
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane
  // LLVM vector types.
  for (size_t i = 0, e = vecs.size(); i != e; ++i) {
    llvm::Value* v = vecs[i];
    if (!v->getType()->isVectorTy()) {
#if TVM_LLVM_VERSION >= 110
      llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1);
#else
      llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1);
#endif
      vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0));
    }
  }

  // concat vector, tree shape reduction
  // Track the true lane count: pairwise padding can leave undef lanes that
  // are sliced away at the end.
  int total_lanes = 0;

  for (llvm::Value* v : vecs) {
    total_lanes += GetVectorNumElements(v);
  }
  while (vecs.size() > 1) {
    std::vector<llvm::Value*> new_vecs;
    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      const size_t lhs_lanes = GetVectorNumElements(lhs);
      const size_t rhs_lanes = GetVectorNumElements(rhs);
      // A shuffle needs both operands the same width; pad the shorter one.
      if (lhs_lanes < rhs_lanes) {
        lhs = CreateVecPad(lhs, rhs_lanes);
      } else if (rhs_lanes < lhs_lanes) {
        rhs = CreateVecPad(rhs, lhs_lanes);
      }
      const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
#if TVM_LLVM_VERSION >= 110
      std::vector<int> mask;
#else
      std::vector<unsigned> mask;
#endif
      // Select lhs's original lanes, then rhs's original lanes (which start
      // at shared_lanes in the two-operand shuffle index space).
      for (size_t i = 0; i < lhs_lanes; ++i) {
        mask.push_back(i);
      }
      for (size_t i = 0; i < rhs_lanes; ++i) {
        mask.push_back(shared_lanes + i);
      }
      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
    }
    // Odd element carries over to the next round unchanged.
    if (vecs.size() % 2 != 0) {
      new_vecs.push_back(vecs.back());
    }
    vecs.swap(new_vecs);
  }
  // Trim any padding lanes introduced during pairwise merging.
  return CreateVecSlice(vecs[0], 0, total_lanes);
}
791
// Emit a serial loop `for (i = begin; i < end; i += stride) body` as the
// classic three-block structure (for_begin / for_body / for_end) with a PHI
// node carrying the loop variable.
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
                                  const Var& loop_var, const Stmt& body) {
  EmitDebugLocation(body->span);
  llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
  std::string loop_var_name = loop_var->name_hint;
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_);
  auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_);
  auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  // PHI has two incoming edges: `begin` from the preheader, and the
  // incremented value from the loop backedge (added after the body below).
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->setName(loop_var->name_hint.c_str());
  loop_value->addIncoming(begin, pre_block);
  ICHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  auto lt = CreateLT(loop_var.dtype(), loop_value, end);
  // Loops are expected to iterate, so bias the branch toward the body.
  builder_->CreateCondBr(lt, for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  // The loop variable goes out of scope with the loop.
  var_map_.erase(loop_var.get());
  llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
  // Use the CURRENT insert block for the backedge: the body may have
  // created new blocks, so for_body is not necessarily the predecessor.
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}
818
// cast operator
// Emit a value cast from dtype `from` to dtype `to`, choosing the LLVM cast
// instruction by the source/destination type classes. Branch order matters:
// handle and uint1 destinations are special-cased before the generic rules.
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
  llvm::Type* target = DTypeToLLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_uint() && to.bits() == 1) {
    // Casting to bool is a comparison against zero, not a truncation.
    if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
      return builder_->CreateFCmpONE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
    }
  } else if (!from.is_float() && !to.is_float()) {
    // int <-> int: sign-extend iff the SOURCE is signed.
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    if (to.bits() < 8) {
      // Sub-byte unsigned targets: convert to u8 first, then narrow.
      value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    // Remaining case: float <-> float of different widths.
    ICHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}
853
854llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name,
855 llvm::GlobalValue::LinkageTypes linkage_type) {
856 llvm::Type* ty = const_data->getType();
857 llvm::GlobalVariable* global =
858 new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name);
859#if TVM_LLVM_VERSION >= 100
860 global->setAlignment(llvm::Align(1));
861#else
862 global->setAlignment(1);
863#endif
864 llvm::Constant* zero = ConstInt32(0);
865 llvm::Constant* indices[] = {zero, zero};
866 llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices);
867 return ptr;
868}
869
870llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
871 auto it = str_map_.find(str);
872 if (it != str_map_.end()) return it->second;
873 auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str);
874 auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage);
875 str_map_[str] = ptr;
876 return ptr;
877}
878
879CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
880 DataType buffer_element_dtype,
881 llvm::ArrayRef<llvm::Value*> indices,
882 DataType value_dtype) {
883 ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
884 llvm::Value* index = indices[0];
885
886 llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
887 ICHECK(buffer_ptr_type != nullptr);
888 auto address_space = buffer_ptr_type->getAddressSpace();
889
890 llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype);
891 llvm::PointerType* element_ptr_type =
892 DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space);
893 llvm::Type* value_type = DTypeToLLVMType(value_dtype);
894 llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space);
895
896 ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer";
897
898 if (buffer_ptr_type != element_ptr_type) {
899 buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type);
900 }
901 ICHECK(!HasAlignmentPadding(buffer_element_dtype))
902 << "DType " << buffer_element_dtype
903 << " has padding for alignment. TVM data arrays are expected to be densely packed, with no "
904 "padding for alignment.";
905 llvm::Value* value_ptr = builder_->CreateInBoundsGEP(element_type, buffer_ptr, index);
906
907 if (element_ptr_type != value_ptr_type) {
908 value_ptr = builder_->CreatePointerCast(value_ptr, value_ptr_type);
909 }
910
911 return TypedPointer(value_type, value_ptr);
912}
913
914llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
915 auto it = var_map_.find(v);
916 ICHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
917 return it->second;
918}
919
920void CodeGenLLVM::CreatePrintf(const std::string& format,
921 llvm::ArrayRef<llvm::Value*> format_args) {
922 EmitDebugLocation();
923 llvm::Function* func_printf = module_->getFunction("printf");
924 if (func_printf == nullptr) {
925 llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, true);
926 func_printf =
927 llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "printf", module_.get());
928 }
929
930 llvm::Function* func_fflush = module_->getFunction("fflush");
931 if (!func_fflush) {
932 llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, {t_void_p_}, false);
933 func_fflush =
934 llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "fflush", module_.get());
935 }
936
937 llvm::Value* str = builder_->CreateGlobalStringPtr(format);
938 str->setName("printf_format_str");
939
940 std::vector<llvm::Value*> printf_args = {str};
941 printf_args.insert(printf_args.end(), format_args.begin(), format_args.end());
942 builder_->CreateCall(func_printf, printf_args);
943
944 // Call fflush() immediately, as this utility is intended for debug
945 // purposes. A segfault occurring within the generated LLVM code
946 // would otherwise leave the stdout buffer unflushed.
947 llvm::Value* null_stream = llvm::ConstantPointerNull::get(t_void_p_);
948 null_stream->setName("null_stream");
949 builder_->CreateCall(func_fflush, {null_stream});
950}
951
952llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) {
953 EmitDebugLocation();
954 llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level);
955 llvm::Function* builtin =
956 llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress);
957 llvm::Value* call = builder_->CreateCall(builtin, level_val);
958 call->setName("return_addr");
959
960 return call;
961}
962
963llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol,
964 const Array<PrimExpr>& args, bool skip_first_arg) {
965 std::vector<llvm::Value*> arg_value;
966 std::vector<llvm::Type*> arg_type;
967 for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
968 arg_value.push_back(MakeValue(args[i]));
969 arg_type.push_back(arg_value.back()->getType());
970 }
971 llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
972 llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
973 if (f == nullptr) {
974 f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, MakeStringRef(global_symbol),
975 module_.get());
976 }
977 llvm::CallInst* call = builder_->CreateCall(f, arg_value);
978 return call;
979}
980
llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
                                              llvm::ArrayRef<llvm::Type*> arg_types) {
  // Obtain (or create) the declaration of the LLVM intrinsic `id`, resolving
  // any overloaded parameter types against the requested return/argument
  // types. Returns nullptr when no overload of `id` matches.
  llvm::Module* module = module_.get();

  if (!llvm::Intrinsic::isOverloaded(id)) {
    // Non-overloaded intrinsics have a fixed signature; no matching needed.
    return llvm::Intrinsic::getDeclaration(module, id, {});
  }

  llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
  llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
  llvm::SmallVector<llvm::Type*, 4> overload_types;

#if TVM_LLVM_VERSION >= 90
  // Checks whether the intrinsic can be instantiated with signature f_ty,
  // filling `overload_types` with the deduced overload parameters on success.
  auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
    overload_types.clear();
    llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
    auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
    if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      // matchIntrinsicVarArg returns true when the vararg-ness disagrees.
      bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
      if (error) {
        return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
      }
    }
    return match;
  };

  // First, try matching the signature assuming non-vararg case.
  auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
  switch (try_match(fn_ty, false)) {
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
      // The return type doesn't match, there is nothing else to do.
      return nullptr;
    case llvm::Intrinsic::MatchIntrinsicTypes_Match:
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
      break;
  }

  // Keep adding one type at a time (starting from empty list), and
  // try matching the vararg signature.
  llvm::SmallVector<llvm::Type*, 4> var_types;
  for (int i = 0, e = arg_types.size(); i <= e; ++i) {
    if (i > 0) var_types.push_back(arg_types[i - 1]);
    auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
    if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    }
  }
  // Failed to identify the type.
  return nullptr;

#else  // TVM_LLVM_VERSION
  // Pre-LLVM-9 API: match the return type and then each argument type in
  // sequence against the intrinsic's descriptor table.
  llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
  // matchIntrinsicType returns true on error.
  if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
    return nullptr;
  }
  for (llvm::Type* t : arg_types) {
    if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
      return nullptr;
    }
  }
  return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif  // TVM_LLVM_VERSION
}
1046
1047void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) {
1048 const std::string& cpu = llvm_target_->GetCPU();
1049 if (!cpu.empty()) {
1050 func->addFnAttr("target-cpu", cpu);
1051 }
1052 const std::string& features = llvm_target_->GetTargetFeatureString();
1053 if (!features.empty()) {
1054 func->addFnAttr("target-features", features);
1055 }
1056}
1057
void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
  // Emits weak definitions of the soft-float conversion helpers
  // __truncsfhf2 (float -> half) and __extendhfsf2 (half -> float) as a
  // separate IR module queued on link_modules_. The 16-bit side of each
  // helper is passed either as a native `half` or as a raw `i16`, selected
  // by `use_float16_abi`.
  //
  // The LLVM IR for these function was obtained by compiling
  //
  // For integer ABI:
  //   __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
  //   __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
  // For floating-point ABI:
  //   __truncXfYf2__<float, uint32_t, 23, _Float16, uint16_t, 10>(a);
  //   __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a);

  // Function body shared by both ABIs; leaves the halved result as an i16
  // in %vlast (the ABI-specific epilogue is appended below).
  static const char trunc_body[] =  // __truncsfhf2
      "  %v0 = bitcast float %a0 to i32\n"
      "  %v1 = and i32 %v0, 2147483647\n"
      "  %v2 = add nsw i32 %v1, -947912704\n"
      "  %v3 = add nsw i32 %v1, -1199570944\n"
      "  %v4 = icmp ult i32 %v2, %v3\n"
      "  br i1 %v4, label %b1, label %b5\n"
      "b1:\n"
      "  %v5 = lshr i32 %v0, 13\n"
      "  %v6 = and i32 %v5, 65535\n"
      "  %v7 = add nuw nsw i32 %v6, -114688\n"
      "  %v8 = and i32 %v0, 8191\n"
      "  %v9 = icmp ugt i32 %v8, 4096\n"
      "  br i1 %v9, label %b2, label %b3\n"
      "b2:\n"
      "  %v10 = add nuw nsw i32 %v6, -114687\n"
      "  br label %b13\n"
      "b3:\n"
      "  %v11 = icmp eq i32 %v8, 4096\n"
      "  br i1 %v11, label %b4, label %b13\n"
      "b4:\n"
      "  %v12 = and i32 %v7, 65535\n"
      "  %v13 = and i32 %v5, 1\n"
      "  %v14 = add nuw nsw i32 %v12, %v13\n"
      "  br label %b13\n"
      "b5:\n"
      "  %v15 = icmp ugt i32 %v1, 2139095040\n"
      "  br i1 %v15, label %b6, label %b7\n"
      "b6:\n"
      "  %v16 = lshr i32 %v0, 13\n"
      "  %v17 = and i32 %v16, 511\n"
      "  %v18 = or i32 %v17, 32256\n"
      "  br label %b13\n"
      "b7:\n"
      "  %v19 = icmp ugt i32 %v1, 1199570943\n"
      "  br i1 %v19, label %b13, label %b8\n"
      "b8:\n"
      "  %v20 = icmp ult i32 %v1, 754974720\n"
      "  br i1 %v20, label %b13, label %b9\n"
      "b9:\n"
      "  %v21 = lshr i32 %v1, 23\n"
      "  %v22 = sub nsw i32 113, %v21\n"
      "  %v23 = and i32 %v0, 8388607\n"
      "  %v24 = or i32 %v23, 8388608\n"
      "  %v25 = add nsw i32 %v21, -81\n"
      "  %v26 = shl i32 %v24, %v25\n"
      "  %v27 = icmp ne i32 %v26, 0\n"
      "  %v28 = lshr i32 %v24, %v22\n"
      "  %v29 = zext i1 %v27 to i32\n"
      "  %v30 = lshr i32 %v28, 13\n"
      "  %v31 = and i32 %v28, 8191\n"
      "  %v32 = or i32 %v31, %v29\n"
      "  %v33 = icmp ugt i32 %v32, 4096\n"
      "  br i1 %v33, label %b10, label %b11\n"
      "b10:\n"
      "  %v34 = add nuw nsw i32 %v30, 1\n"
      "  br label %b13\n"
      "b11:\n"
      "  %v35 = icmp eq i32 %v32, 4096\n"
      "  br i1 %v35, label %b12, label %b13\n"
      "b12:\n"
      "  %v36 = and i32 %v30, 1\n"
      "  %v37 = add nuw nsw i32 %v36, %v30\n"
      "  br label %b13\n"
      "b13:\n"
      "  %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n"
      "                 [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n"
      "                 [ %v30, %b11 ]\n"
      "  %v39 = lshr i32 %v0, 16\n"
      "  %v40 = and i32 %v39, 32768\n"
      "  %v41 = or i32 %v38, %v40\n"
      "  %vlast = trunc i32 %v41 to i16\n";

  // Function body shared by both ABIs; expects the raw half bits in %vinp
  // (the ABI-specific prologue that defines %vinp is prepended below).
  static const char extend_body[] =  // __extendhfsf2
      "  %v1 = and i16 %vinp, 32767\n"
      "  %v2 = zext i16 %v1 to i32\n"
      "  %v3 = add nsw i16 %v1, -1024\n"
      "  %v4 = icmp ult i16 %v3, 30720\n"
      "  br i1 %v4, label %b1, label %b2\n"
      "b1:\n"
      "  %v5 = shl nuw nsw i32 %v2, 13\n"
      "  %v6 = add nuw nsw i32 %v5, 939524096\n"
      "  br label %b6\n"
      "b2:\n"
      "  %v7 = icmp ugt i16 %v1, 31743\n"
      "  br i1 %v7, label %b3, label %b4\n"
      "b3:\n"
      "  %v8 = shl nuw nsw i32 %v2, 13\n"
      "  %v9 = or i32 %v8, 2139095040\n"
      "  br label %b6\n"
      "b4:\n"
      "  %v10 = icmp eq i16 %v1, 0\n"
      "  br i1 %v10, label %b6, label %b5\n"
      "b5:\n"
      "  %v11 = icmp ult i16 %v1, 256\n"
      "  %v12 = lshr i32 %v2, 8\n"
      "  %v13 = select i1 %v11, i32 %v2, i32 %v12\n"
      "  %v14 = select i1 %v11, i32 32, i32 24\n"
      "  %v15 = icmp ult i32 %v13, 16\n"
      "  %v16 = lshr i32 %v13, 4\n"
      "  %v17 = add nsw i32 %v14, -4\n"
      "  %v18 = select i1 %v15, i32 %v13, i32 %v16\n"
      "  %v19 = select i1 %v15, i32 %v14, i32 %v17\n"
      "  %v20 = icmp ult i32 %v18, 4\n"
      "  %v21 = lshr i32 %v18, 2\n"
      "  %v22 = add nsw i32 %v19, -2\n"
      "  %v23 = select i1 %v20, i32 %v18, i32 %v21\n"
      "  %v24 = select i1 %v20, i32 %v19, i32 %v22\n"
      "  %v25 = icmp ult i32 %v23, 2\n"
      "  %v26 = sub nsw i32 0, %v23\n"
      "  %v27 = select i1 %v25, i32 %v26, i32 -2\n"
      "  %v28 = add nsw i32 %v27, %v24\n"
      "  %v29 = add nsw i32 %v28, -8\n"
      "  %v30 = shl i32 %v2, %v29\n"
      "  %v31 = xor i32 %v30, 8388608\n"
      "  %v32 = shl i32 %v28, 23\n"
      "  %v33 = sub i32 1124073472, %v32\n"
      "  %v34 = or i32 %v31, %v33\n"
      "  br label %b6\n"
      "b6:\n"
      "  %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n"
      "  %v36 = and i16 %vinp, -32768\n"
      "  %v37 = zext i16 %v36 to i32\n"
      "  %v38 = shl nuw i32 %v37, 16\n"
      "  %v39 = or i32 %v35, %v38\n"
      "  %v40 = bitcast i32 %v39 to float\n"
      "  ret float %v40\n"
      "}\n";

  // IR type used for the 16-bit side of the helpers' signatures.
  std::string short_type = use_float16_abi ? "half" : "i16";

  // Glue between the ABI type and the i16 bit pattern the bodies work on.
  std::string short_cast_in, short_cast_out;
  if (use_float16_abi) {
    short_cast_in = "  %vinp = bitcast half %a0 to i16\n";
    short_cast_out = "  %vres = bitcast i16 %vlast to half\n";
  } else {
    // No-ops that preserve the i16 values.
    short_cast_in = "  %vinp = add i16 %a0, 0\n";
    short_cast_out = "  %vres = add i16 %vlast, 0\n";
  }

  llvm::Triple triple(llvm_target_->GetTargetTriple());

  // On ELF targets, place the helpers into a dedicated text section.
  static const char elf_section_name[] = ".text.tvm.fp16.conv";
  std::string section = triple.getObjectFormat() == llvm::Triple::ELF
                            ? std::string("section \"") + elf_section_name + "\" "
                            : "";

  // `weak` linkage lets a stronger definition elsewhere take precedence.
  std::string trunc_header = "define weak dso_local " + short_type +
                             " @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section +
                             "{\nb0:\n";
  std::string trunc_return = "  ret " + short_type + " %vres\n}\n";

  std::string extend_header = "define weak dso_local float @__extendhfsf2(" + short_type +
                              " %a0) local_unnamed_addr #0 " + section + "{\nb0:\n";

  // truncate = trunc_header + trunc_body + short_cast_out + trunc_return
  // extend = extend_header + short_cast_in + extend_body

  // Give the helpers the same target-cpu/features as the rest of the module.
  std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" +
                           llvm_target_->GetCPU() + "\" \"target-features\"=\"" +
                           llvm_target_->GetTargetFeatureString() + "\" }\n";

  auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout();
  std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" +
                          "target datalayout = \"" + data_layout.getStringRepresentation() +
                          "\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return +
                          extend_header + short_cast_in + extend_body + attributes;

  // Parse the assembled IR into its own module and queue it for linking.
  auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir);
  link_modules_.push_back(std::move(builtins_module));
}
1240
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
  // Lower a TIR builtin call to the corresponding LLVM construct.
  if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
    // Direct LLVM intrinsic call: args[0] is the intrinsic id, args[1] the
    // number of leading call arguments that participate in overload
    // resolution, args[2:] the call arguments themselves.
    ICHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
    int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      if (i - 2 < static_cast<size_t>(num_signature)) {
        arg_type.push_back(arg_value.back()->getType());
      }
    }
    // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
    // returns int32. This causes problems because prefetch is one of
    // those intrinsics that is generated automatically via the
    // tvm.intrin.rule mechanism. Any other intrinsic with a type
    // mismatch will have to be treated specially here.
    // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
    // type as LLVM.
    llvm::Type* return_type =
        (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op)) : t_void_;
    llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
    ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
#if TVM_LLVM_VERSION >= 130
              << llvm::Intrinsic::getBaseName(id).str();
#else
              << llvm::Intrinsic::getName(id, {});
#endif
    return builder_->CreateCall(f, arg_value);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_or())) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_not())) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_left())) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_right())) {
    // Arithmetic shift for signed values, logical shift for unsigned.
    if (op->args[0].dtype().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return CreateStorageSync(op);
  } else if (op->op.same_as(builtin::address_of())) {
    // Address of a buffer element, returned as a char* in the buffer's
    // address space.
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    ICHECK(op->args.size() == 1 && load);

    // A ramp index is reduced to its base so the pointer refers to the
    // first lane of the vectorized access.
    Array<PrimExpr> indices = load->indices;
    if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
      indices.Set(indices.size() - 1, r->base);
    }

    std::vector<llvm::Value*> indices_val;
    for (const auto& index : indices) {
      indices_val.push_back(MakeValue(index));
    }

    TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
                                              indices_val, load->dtype);
    unsigned addrspace =
        llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
    return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
  } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
    // reinterpret of constant zero yields the null pointer.
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->op.same_as(builtin::isnullptr())) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    // Reassemble a 64-bit immediate from its two 32-bit halves.
    ICHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
  } else if (op->op.same_as(builtin::if_then_else())) {
    // Conditional evaluation: only the taken branch is executed, unlike
    // Select which evaluates both arms. Lowered to explicit basic blocks
    // joined by a phi node.
    ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition";
    llvm::LLVMContext* ctx = llvm_target_->GetContext();
    auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
    auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
    auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    // MakeValue may have moved the insertion point; the phi must reference
    // the block that actually produces the value.
    llvm::BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    llvm::BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->op.same_as(builtin::ret())) {
    auto const* val = op->args[0].as<IntImmNode>();
    ICHECK(val) << "the tir.ret should be transformed to return zero "
                << "before the llvm code generation.";
    ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to "
                             << "return zero before the llvm code generation.";
    builder_->CreateRet(ConstInt32(0));
    // LLVM allows exactly one terminator in a single basic block
    // append a new dummy basic block to avoid error.
    llvm::BasicBlock* ret_dummy =
        llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_);
    builder_->SetInsertPoint(ret_dummy);
    return ret_dummy;
  } else if (op->op.same_as(builtin::reinterpret())) {
    llvm::Type* target = DTypeToLLVMType(op->dtype);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->op.same_as(builtin::isnan())) {
    // TODO(hgt312): set fast math flag
    // Unordered self-comparison is true iff the operand is NaN.
    llvm::Value* a = MakeValue(op->args[0]);
    return builder_->CreateFCmpUNO(a, a);
  } else if (op->op.same_as(builtin::vectorlow())) {
    // First half of the vector's lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = GetVectorNumElements(v);
    return CreateVecSlice(v, 0, l / 2);
  } else if (op->op.same_as(builtin::vectorhigh())) {
    // Second half of the vector's lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = GetVectorNumElements(v);
    return CreateVecSlice(v, l / 2, l / 2);
  } else if (op->op.same_as(builtin::vectorcombine())) {
    // Concatenate two equal-width vectors via a shuffle.
    llvm::Value* v0 = MakeValue(op->args[0]);
    llvm::Value* v1 = MakeValue(op->args[1]);
    int num_elems = GetVectorNumElements(v0) * 2;
#if TVM_LLVM_VERSION >= 110
    std::vector<int> indices;
#else
    std::vector<unsigned> indices;
#endif
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else if (op->op.same_as(builtin::atomic_add())) {
    // TODO(masahi): Support atomic for CPU backend
    LOG(FATAL) << "CPU backend does not support atomic add yet.";
  } else if (op->op.same_as(builtin::start_profile_intrinsic()) ||
             op->op.same_as(builtin::end_profile_intrinsic())) {
    LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
    return nullptr;
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->op;
  }
}
1390
1391void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
1392 if (const RampNode* ramp = e.as<RampNode>()) {
1393 for (int i = 0; i < ramp->dtype.lanes(); ++i) {
1394 PrimExpr offset = ramp->base + (ramp->stride * i);
1395 f(i, MakeValue(offset));
1396 }
1397 } else {
1398 llvm::Value* value = MakeValue(e);
1399 for (int i = 0; i < e.dtype().lanes(); ++i) {
1400 f(i, builder_->CreateExtractElement(value, i));
1401 }
1402 }
1403}
1404
1405// Visitors
1406llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
1407
1408llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
1409 return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
1410}
1411llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
1412 return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
1413}
1414
1415llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
1416 return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
1417}
1418
1419llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
1420
// Defines CodeGenLLVM::Create{Op}(t, a, b) and the matching VisitExpr_
// overload for an arithmetic binary operator (Add/Sub/Mul below):
//   - signed int of >= 32 bits:   the "nsw" (no-signed-wrap) instruction
//   - unsigned int of >= 32 bits: the "nuw" (no-unsigned-wrap) instruction
//   - narrower integers:          the plain (wrapping) instruction
//   - float:                      the floating-point instruction
// The nsw/nuw flags tell LLVM the operation is assumed not to overflow,
// which enables additional optimizations.
#define DEFINE_CODEGEN_BINARY_OP(Op)                                                 \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNSW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else if (t.is_uint()) {                                                        \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNUW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else {                                                                         \
      ICHECK(t.is_float());                                                          \
      return builder_->CreateF##Op(a, b);                                            \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b));                \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
1447
// Defines CodeGenLLVM::Create{Op}(t, a, b) and the matching VisitExpr_
// overload for a comparison (LT/LE/GT/GE below). Signed ints use the ICmpS*
// predicate, unsigned ints ICmpU*, and floats the ordered FCmpO* predicate
// (false when either operand is NaN). Note the visitor dispatches on the
// operand dtype, not on the boolean result dtype.
#define DEFINE_CODEGEN_CMP_OP(Op)                                                    \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      return builder_->CreateICmpS##Op(a, b);                                        \
    } else if (t.is_uint()) {                                                        \
      return builder_->CreateICmpU##Op(a, b);                                        \
    } else {                                                                         \
      ICHECK(t.is_float());                                                          \
      return builder_->CreateFCmpO##Op(a, b);                                        \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b));            \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
1467
1468llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
1469 llvm::Value* a = MakeValue(op->a);
1470 llvm::Value* b = MakeValue(op->b);
1471 if (op->dtype.is_int()) {
1472 return builder_->CreateSDiv(a, b);
1473 } else if (op->dtype.is_uint()) {
1474 return builder_->CreateUDiv(a, b);
1475 } else {
1476 ICHECK(op->dtype.is_float());
1477 return builder_->CreateFDiv(a, b);
1478 }
1479}
1480
1481llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
1482 llvm::Value* a = MakeValue(op->a);
1483 llvm::Value* b = MakeValue(op->b);
1484 if (op->dtype.is_int()) {
1485 return builder_->CreateSRem(a, b);
1486 } else if (op->dtype.is_uint()) {
1487 return builder_->CreateURem(a, b);
1488 } else {
1489 ICHECK(op->dtype.is_float());
1490 return builder_->CreateFRem(a, b);
1491 }
1492}
1493
1494llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
1495 llvm::Value* a = MakeValue(op->a);
1496 llvm::Value* b = MakeValue(op->b);
1497 return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
1498}
1499
1500llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
1501 llvm::Value* a = MakeValue(op->a);
1502 llvm::Value* b = MakeValue(op->b);
1503 return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
1504}
1505
1506llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
1507 llvm::Value* a = MakeValue(op->a);
1508 llvm::Value* b = MakeValue(op->b);
1509 if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
1510 return builder_->CreateICmpEQ(a, b);
1511 } else {
1512 return builder_->CreateFCmpOEQ(a, b);
1513 }
1514}
1515
1516llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
1517 llvm::Value* a = MakeValue(op->a);
1518 llvm::Value* b = MakeValue(op->b);
1519 if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
1520 return builder_->CreateICmpNE(a, b);
1521 } else {
1522 return builder_->CreateFCmpONE(a, b);
1523 }
1524}
1525
1526llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
1527 return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
1528}
1529
1530llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
1531 return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
1532}
1533
1534llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
1535 return builder_->CreateNot(MakeValue(op->a));
1536}
1537
1538llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
1539 return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
1540 MakeValue(op->false_value));
1541}
1542
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
  // Evaluate a let-binding: bind op->var to op->value, then evaluate the
  // body with that binding in scope.
  auto it = let_binding_.find(op->var);
  if (it != let_binding_.end()) {
    // The same Var may appear in multiple Let nodes only when bound to a
    // structurally identical value expression.
    ICHECK(deep_equal_(it->second->value, op->value))
        << "Let cannot bind the same var to two different values";
  } else {
    let_binding_[op->var] = op;
  }
  auto var_value = MakeValue(op->value);
  // Register the value before generating the body so references to the
  // variable inside the body resolve through var_map_.
  var_map_[op->var.get()] = var_value;
  var_value->setName(op->var->name_hint.c_str());
  // Inform the arithmetic analyzer of the binding for simplification.
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}
1557
1558llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
1559 LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead.";
1560}
1561
1562bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
1563 const llvm::DataLayout& data_layout = module_->getDataLayout();
1564 int bytes = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype));
1565 int bytes_scalar = data_layout.getTypeAllocSize(DTypeToLLVMType(dtype.element_of()));
1566 return bytes != bytes_scalar * dtype.lanes();
1567}
1568
void CodeGenLLVM::BufferAccessHelper(
    Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
    std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
                                     bool is_volatile)>
        make_instruction) {
  // Shared driver for buffer loads/stores.  Computes the element pointer(s)
  // for `indices`, invokes `make_instruction` once per generated access, and
  // attaches alias metadata to each resulting instruction.
  //
  // `make_instruction` receives the typed element pointer, the sub-element
  // index within a vectorized access (-1 for a scalar access), the alignment
  // in bytes, and whether the access must be volatile.
  DataType buffer_element_dtype = buffer->dtype;

  ICHECK_GE(indices.size(), 1)
      << "Buffer " << buffer->name << " is accessed with no indices. "
      << "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen.";

  // Only the last index is allowed to be multi-lane. All earlier
  // indices must be scalar. This only matters for subclasses of
  // CodeGenLLVM, because the default implementation of GetBufferPtr
  // requires 1-d indices.
  std::vector<llvm::Value*> earlier_index_values;
  for (size_t i = 0; i < indices.size() - 1; i++) {
    ICHECK_EQ(indices[i].dtype().lanes(), 1)
        << "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
        << ". Multi-lane indices are only supported as the last index.";
    earlier_index_values.push_back(MakeValue(indices[i]));
  }

  PrimExpr last_index = indices[indices.size() - 1];
  ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

  // Record index and elemtype in original form used for alias info
  PrimExpr last_index_origin = last_index;
  DataType buffer_element_dtype_origin = buffer_element_dtype;

  bool is_volatile = volatile_buf_.count(buffer->data.get());

  // If the buffer index is a contiguous ramp node, we only need to
  // access the first element, then cast to the value type.
  if (const RampNode* ramp_index = last_index.as<RampNode>()) {
    if (is_one(ramp_index->stride)) {
      last_index = ramp_index->base;
    }
  }

  // All TVM arrays are densely packed. If the vectorized LLVM type
  // contains padding for alignment, we need to index based on the
  // size of the scalar type to avoid introducing that padding.
  if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
    last_index = buffer_element_dtype.lanes() * last_index;
    buffer_element_dtype = buffer_element_dtype.element_of();
  }

  int alignment;
  if (last_index.dtype().lanes() == 1) {
    // If we are accessing with a single index, then the vectorized
    // element being accessed may require more alignment than the
    // underlying data type.
    int native_bits;
    GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
  } else {
    // Otherwise, alignment is based on the return value's scalar
    // type.
    ICHECK_GE(value_dtype.bits(), 8);
    alignment = value_dtype.bits() / 8;
  }

  // Emit one access per lane of the last index (a single access when the
  // last index is scalar, e.g. after the contiguous-ramp rewrite above).
  llvm::Value* cached_vector_index = nullptr;
  for (int i = 0; i < last_index.dtype().lanes(); ++i) {
    llvm::Value* last_index_value;
    int subelement_i = i;
    if (const RampNode* ramp = last_index.as<RampNode>()) {
      // Non-contiguous ramp: compute each lane's index directly.
      PrimExpr offset = ramp->base + (ramp->stride * i);
      last_index_value = MakeValue(offset);
    } else if (last_index.dtype().lanes() > 1) {
      // Arbitrary vector index: materialize it once, then extract lanes.
      if (i == 0) {
        cached_vector_index = MakeValue(last_index);
      }
      last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
    } else {
      // Scalar index: signal the scalar case with subelement_i == -1.
      last_index_value = MakeValue(last_index);
      subelement_i = -1;
    }

    std::vector<llvm::Value*> all_index_values = earlier_index_values;
    all_index_values.push_back(last_index_value);

    TypedPointer buffer_ptr =
        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
                        value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
    auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
    AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
  }
}
1658
// Generate a load from a TIR buffer, possibly scalarized into multiple
// element loads that are then reassembled into a vector.
llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
  DataType value_dtype = op->dtype;

  // One LLVM load instruction is collected per access that
  // BufferAccessHelper emits (a single entry when the whole access can
  // be done as one scalar/vector load).
  std::vector<llvm::Value*> loads;

  auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment,
                                  bool is_volatile) {
    // The CreateAlignedLoad signature changed across LLVM releases:
    // >=11 takes llvm::Align, >=8 takes the pointee type explicitly.
#if TVM_LLVM_VERSION >= 110
    auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
                                            llvm::Align(alignment), is_volatile);
#elif TVM_LLVM_VERSION >= 80
    auto load =
        builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile);
#else
    auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
#endif

    loads.push_back(load);
    return load;
  };

  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
  // non-flat indices will result in an error in CreateBufferPtr, but
  // a subclass may override CreateBufferPtr.
  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load);

  if (loads.size() == 1) {
    return loads[0];
  } else {
    // Multiple scalar loads: insert each one as a lane of a vector of
    // the requested dtype.
    llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype));
    for (size_t i = 0; i < loads.size(); i++) {
      ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i));
    }
    return ret;
  }
}
1695
1696llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
1697 if (auto* ptr_op = op->op.as<OpNode>()) {
1698 auto call_op = GetRef<Op>(ptr_op);
1699 if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
1700 // call extern intrinsic
1701 ICHECK_GE(op->args.size(), 1U);
1702 auto global_symbol = Downcast<StringImm>(op->args[0]);
1703 return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), global_symbol->value, op->args,
1704 true);
1705 } else if (op_attr_global_symbol_.count(call_op)) {
1706 // call extern if the op itself have a global symbol.
1707 return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
1708 op->args, false);
1709 } else {
1710 VLOG(2) << "CreateIntrinsic: " << GetRef<Call>(op);
1711 auto x = CreateIntrinsic(op);
1712 VLOG(2) << "CreateIntrinsic done";
1713 return x;
1714 }
1715 } else {
1716 ICHECK(op->op.as<GlobalVarNode>());
1717 LOG(FATAL) << "Do not yet support cross function call";
1718 }
1719}
1720
1721llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
1722 llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
1723 for (int i = 0; i < op->lanes; ++i) {
1724 vec = builder_->CreateInsertElement(
1725 vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i));
1726 }
1727 return vec;
1728}
1729
1730llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
1731 std::vector<llvm::Value*> vecs(op->vectors.size());
1732 int total_lanes = 0;
1733 for (int i = 0, e = op->vectors.size(); i < e; ++i) {
1734 vecs[i] = VisitExpr(op->vectors[i]);
1735 total_lanes += op->vectors[i].dtype().lanes();
1736 }
1737 llvm::Value* v0 = CreateVecConcat(vecs);
1738 std::vector<uint32_t> idx(op->indices.size());
1739 for (int i = 0, e = op->indices.size(); i < e; ++i) {
1740 const int64_t* val = as_const_int(op->indices[i]);
1741 ICHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
1742 << "but get " << op->indices[i] << "\n";
1743 idx[i] = *val;
1744 }
1745 llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
1746 auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
1747 // If the output is a single-element vector, convert it back to a scalar.
1748 if (idx.size() == 1) {
1749 res = builder_->CreateExtractElement(res, ConstInt32(0));
1750 }
1751 return res;
1752}
1753
1754llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
1755 return CreateBroadcast(MakeValue(op->value), op->lanes);
1756}
1757
// StoreNode is deprecated in favor of BufferStoreNode; reaching this
// visitor indicates a lowering bug upstream, so abort.
void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead.";
}
1761
// Generate a store to a TIR buffer, possibly scalarized into multiple
// per-lane stores by BufferAccessHelper.
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
  EmitDebugLocation(op);
  DataType value_dtype = op->value.dtype();
  Var buffer_var = op->buffer->data;

  llvm::Value* value = MakeValue(op->value);

  // Emit one store per access produced by BufferAccessHelper.  When
  // subelement_i != -1 the helper is scalarizing a vector value, so only
  // the corresponding lane of `value` is stored.
  auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment,
                                  bool is_volatile) {
    llvm::Value* to_store = value;
    if (subelement_i != -1) {
      to_store = builder_->CreateExtractElement(value, subelement_i);
    }
    // LLVM >= 11 requires alignment wrapped in llvm::Align.
#if TVM_LLVM_VERSION >= 110
    return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment),
                                        is_volatile);
#else
    return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile);
#endif
  };

  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
  // non-flat indices will result in an error in CreateBufferPtr, but
  // a subclass may override CreateBufferPtr.
  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store);
}
1788
1789void CodeGenLLVM::VisitStmt_(const ForNode* op) {
1790 EmitDebugLocation(op);
1791 ICHECK(is_zero(op->min));
1792 analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
1793 if (op->kind == ForKind::kUnrolled) {
1794 LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
1795 << " consider set unroll_explicit=True";
1796 } else {
1797 ICHECK(op->kind == ForKind::kSerial);
1798 }
1799 CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
1800 llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
1801}
1802
1803void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
1804 EmitDebugLocation(op);
1805 llvm::LLVMContext* ctx = llvm_target_->GetContext();
1806 auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_);
1807 auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_);
1808 auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge", function_);
1809 builder_->CreateBr(while_cond);
1810 builder_->SetInsertPoint(while_cond);
1811 builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
1812 builder_->SetInsertPoint(while_body);
1813 this->VisitStmt(op->body);
1814 builder_->CreateBr(while_cond);
1815 builder_->SetInsertPoint(while_merge);
1816}
1817
1818void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
1819 EmitDebugLocation(op);
1820 llvm::Value* cond = MakeValue(op->condition);
1821 llvm::LLVMContext* ctx = llvm_target_->GetContext();
1822 auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
1823 auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
1824 if (op->else_case) {
1825 auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
1826 builder_->CreateCondBr(cond, then_block, else_block);
1827 builder_->SetInsertPoint(then_block);
1828 this->VisitStmt(op->then_case);
1829 builder_->CreateBr(end_block);
1830 builder_->SetInsertPoint(else_block);
1831 this->VisitStmt(op->else_case.value());
1832 builder_->CreateBr(end_block);
1833 } else {
1834 builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
1835 builder_->SetInsertPoint(then_block);
1836 this->VisitStmt(op->then_case);
1837 builder_->CreateBr(end_block);
1838 }
1839 builder_->SetInsertPoint(end_block);
1840}
1841
1842void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
1843 EmitDebugLocation(op);
1844 auto data = op->data.value();
1845 auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data);
1846 std::string symbol_name = op->buffer_var->name_hint;
1847 llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
1848 *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);
1849
1850 var_map_[op->buffer_var.operator->()] = param_symbol;
1851 this->VisitStmt(op->body);
1852}
1853
// Lower a TIR allocation to a fixed-size stack alloca in the function
// entry block, with alignment negotiated between the requested storage
// alignment and what LLVM assigns.
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
  EmitDebugLocation(op);
  ICHECK_EQ(op->extents.size(), 1)
      << "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
      << op->buffer_var->name_hint << " is " << op->extents << "-d";

  ICHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;

  int32_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation";
  StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
  // Derive a default alignment when none was requested explicitly.
  if (constant_size % 4 == 0 && info.alignment == 0) {
    info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
  }
  // maximum necessary alignment in the NV devices
  if (info.alignment > 16) {
    info.alignment = 16;
  }
  // Place the alloca in the entry block so LLVM treats it as a static
  // stack slot rather than a dynamic allocation.
  llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
    return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
  });
#if TVM_LLVM_VERSION >= 110
  auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
  unsigned alignment = alloca->getAlignment();
#endif
  // Raise the alloca's alignment only if the requested one is stricter.
  if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
    alloca->setAlignment(llvm::Align(info.alignment));
#else
    alloca->setAlignment(info.alignment);
#endif
  }
  // Record the alignment LLVM actually applied.
#if TVM_LLVM_VERSION >= 110
  info.alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
  info.alignment = alloca->getAlignment();
#endif

  buf = alloca;

  // Cast to a pointer to the element type, preserving the alloca's
  // address space.
  buf = builder_->CreatePointerCast(
      buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
  buf->setName(op->buffer_var->name_hint.c_str());

  ICHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1904
1905void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
1906 EmitDebugLocation(op);
1907 if (op->attr_key == tir::attr::thread_extent) {
1908 IterVar iv = Downcast<IterVar>(op->node);
1909 if (iv->thread_tag.length() != 0) {
1910 if (!var_map_.count(iv->var.get())) {
1911 var_map_[iv->var.get()] = GetThreadIndex(iv);
1912 analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
1913 }
1914 }
1915 } else if (op->attr_key == tir::attr::storage_alignment) {
1916 const VarNode* v = op->node.as<VarNode>();
1917 ICHECK(v);
1918 alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
1919 if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
1920 builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1921 alloc_storage_info_[v].alignment);
1922 }
1923 } else if (op->attr_key == tir::attr::volatile_scope) {
1924 const VarNode* v = op->node.as<VarNode>();
1925 ICHECK(v);
1926 volatile_buf_.insert(v);
1927 }
1928 this->VisitStmt(op->body);
1929}
1930
// The assert condition is registered as an arithmetic constraint (scoped
// by the RAII context) while the body is generated; no runtime check is
// emitted here.
void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
  EmitDebugLocation(op);
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
1937
1938void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
1939 EmitDebugLocation(op);
1940 const VarNode* v = op->var.get();
1941 ICHECK(!var_map_.count(v));
1942 if (v->dtype.is_handle()) {
1943 if (!is_restricted_) {
1944 alias_var_set_.insert(v);
1945 }
1946 }
1947 llvm::Value* value = MakeValue(op->value);
1948 value->setName(v->name_hint.c_str());
1949 var_map_[v] = value;
1950 analyzer_->Bind(op->var, op->value);
1951 if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
1952 builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1953 alloc_storage_info_[v].alignment);
1954 }
1955 this->VisitStmt(op->body);
1956}
1957
1958void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
1959 EmitDebugLocation(op);
1960 for (Stmt stmt : op->seq) {
1961 this->VisitStmt(stmt);
1962 }
1963}
1964
1965void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
1966 EmitDebugLocation(op);
1967 MakeValue(op->value);
1968}
1969
// Attach the span's line/column as the current debug location for
// subsequently emitted instructions, scoped to the current function's
// DISubprogram.  No-op when debug info is absent or the span is undefined.
void CodeGenLLVM::EmitDebugLocation(const Span& span) {
#if TVM_LLVM_VERSION >= 50
  if (di_subprogram_ == nullptr) {
    // debug info is not always generated outside of CPU codegen
    return;
  }
  if (!span.defined()) {
    VLOG(0) << "Cannot emit debug location for undefined span";
    return;
  }
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_));
  builder_->SetCurrentDebugLocation(loc);
#endif
}
1985
1986void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); }
1987void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); }
1988
1989} // namespace codegen
1990} // namespace tvm
1991
1992#endif // TVM_LLVM_VERSION
1993