1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_cpu.cc
22 */
23#ifdef TVM_LLVM_VERSION
24
25#include "codegen_cpu.h"
26
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/ADT/StringRef.h>
29#include <llvm/IR/Argument.h>
30#include <llvm/IR/Attributes.h>
31#include <llvm/IR/BasicBlock.h>
32#include <llvm/IR/CallingConv.h>
33#include <llvm/IR/Comdat.h>
34#include <llvm/IR/Constants.h>
35#include <llvm/IR/DIBuilder.h>
36#include <llvm/IR/DebugInfoMetadata.h>
37#include <llvm/IR/DebugLoc.h>
38#include <llvm/IR/DerivedTypes.h>
39#include <llvm/IR/Function.h>
40#include <llvm/IR/GlobalVariable.h>
41#include <llvm/IR/Instructions.h>
42#include <llvm/IR/LLVMContext.h>
43#include <llvm/IR/MDBuilder.h>
44#include <llvm/IR/Metadata.h>
45#include <llvm/IR/Module.h>
46#if TVM_LLVM_VERSION >= 100
47#include <llvm/Support/Alignment.h>
48#endif
49#include <llvm/Support/raw_ostream.h>
50#include <llvm/Target/TargetMachine.h>
51#include <llvm/Transforms/Utils/ModuleUtils.h>
52#include <tvm/runtime/c_runtime_api.h>
53#include <tvm/runtime/module.h>
54#include <tvm/tir/analysis.h>
55
56#include <algorithm>
57#include <memory>
58#include <unordered_map>
59#include <unordered_set>
60
61#include "../func_registry_generator.h"
62#include "../metadata_utils.h"
63#include "llvm_instance.h"
64
65namespace tvm {
66namespace codegen {
67
// Constructor and destructor are defined out of line (not inline in the header) because of
// std::unique_ptr members; presumably their pointee types are only complete in .cc files.
// See the corresponding comment in codegen_llvm.cc for details.
CodeGenCPU::CodeGenCPU() = default;
CodeGenCPU::~CodeGenCPU() = default;
72
/*!
 * \brief Initialize the CPU code generator on top of CodeGenLLVM::Init.
 *
 * Creates the debug-info state and builds LLVM types mirroring the TVM runtime ABI
 * (DLDevice, DLDataType, DLTensor, TVMValue, TVMParallelGroupEnv, callback signatures).
 * Declares runtime entry points where requested, then sets up the module-level context
 * globals via InitGlobalContext().
 *
 * \param module_name Name of the LLVM module (forwarded to CodeGenLLVM::Init).
 * \param llvm_target Target description providing the LLVM context and target machine.
 * \param system_lib Build as a system library. Together with !target_c_runtime this declares
 *        TVMBackendRegisterSystemLibSymbol so exported symbols can be registered.
 * \param dynamic_lookup If set (or if system_lib is set), runtime functions such as
 *        TVMFuncCall are declared as external functions in the module.
 * \param target_c_runtime Whether the output targets the C runtime.
 */
void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib,
                      bool dynamic_lookup, bool target_c_runtime) {
  CodeGenLLVM::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime);
  dbg_info_ = CreateDebugInfo(module_.get());
  // TVMValue is modeled below as a struct wrapping a single double, so sizes must agree.
  static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
  func_handle_map_.clear();
  export_system_symbols_.clear();

  // Runtime types.

  t_tvm_shape_index_ =
      llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits());
  // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
  // typedef struct { DLDeviceType device_type; int device_id; } DLDevice;
  t_tvm_device_ = llvm::StructType::create({t_int_, t_int_});
  // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
  // typedef struct { uint8_t code; uint8_t bits; uint16_t lanes; } DLDataType;
  t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
  // Defined in include/tvm/runtime/c_runtime_api.h:
  // typedef void* TVMFunctionHandle;
  t_tvm_func_handle_ = t_void_p_;
  // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h:
  // typedef struct { ... } DLTensor;
  // Field order: data, device, ndim, dtype, shape, strides, byte_offset (CreateStructRefPtr
  // indexes these positions).
  t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_device_, t_int_, t_tvm_type_,
                                           t_tvm_shape_index_->getPointerTo(),
                                           t_tvm_shape_index_->getPointerTo(), t_int64_});
  // Defined in include/tvm/runtime/c_runtime_api.h:
  // typedef union { ... } TVMValue;
  t_tvm_value_ = llvm::StructType::create({t_float64_});
  // Defined in include/tvm/runtime/c_backend_api.h:
  // typedef struct { void* sync_handle; int32_t num_task; } TVMParallelGroupEnv;
  // NOTE(review): sync_handle is modeled as int32_t* rather than void*; presumably only its
  // pointer-ness matters for the layout.
  t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_});
  // Defined in include/tvm/runtime/c_backend_api.h:
  // typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args,
  //                                      TVMValue* out_ret_value, int* out_ret_tcode,
  //                                      void* resource_handle);
  ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get(
      t_int_,
      {t_void_p_, t_int_->getPointerTo(), t_int_, t_void_p_, t_int_->getPointerTo(), t_void_p_},
      false);
  // Registry entry {name, packed function pointer}; presumably mirrors the C runtime's
  // function registry layout -- TODO confirm against the CRT headers.
  t_tvm_crt_func_registry_ = llvm::StructType::create(
      {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()});
  t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()});
  // Defined in include/tvm/runtime/c_backend_api.h:
  // typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
  ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
      t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false);
  // TBAA type node attached to loads of context pointers (see GetContextPtr).
  md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);

  // Runtime functions.

  // Defined in include/tvm/runtime/c_runtime_api.h:
  // int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args,
  //                 TVMValue* ret_val, int* ret_type_code);
  ftype_tvm_func_call_ = llvm::FunctionType::get(
      t_int_,
      {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_,
       t_tvm_value_->getPointerTo(), t_int_->getPointerTo()},
      false);
  // Defined in include/tvm/runtime/c_backend_api.h:
  // int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out);
  ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(
      t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false);
  // Defined in include/tvm/runtime/c_runtime_api.h:
  // void TVMAPISetLastError(const char* msg);
  ftype_tvm_api_set_last_error_ =
      llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false);
  // Defined in include/tvm/runtime/c_backend_api.h:
  // int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);
  ftype_tvm_parallel_launch_ = llvm::FunctionType::get(
      t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false);
  // Defined in include/tvm/runtime/c_backend_api.h:
  // int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
  ftype_tvm_parallel_barrier_ =
      llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false);
  // Callback passed to the static-init entry: int callback(void* cdata).
  ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false);
  // Static-init entry (see CreateStaticInit):
  // int init(void** handle, callback, void* cdata, int nbytes).
  ftype_tvm_static_init_ =
      llvm::FunctionType::get(t_int_,
                              {t_void_p_->getPointerTo(),
                               ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_},
                              false);
  // initialize TVM runtime API
  if (system_lib && !target_c_runtime) {
    // We will need this in environment for backward registration.
    // Defined in include/tvm/runtime/c_backend_api.h:
    // int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
    f_tvm_register_system_symbol_ = llvm::Function::Create(
        llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
        llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
  } else {
    f_tvm_register_system_symbol_ = nullptr;
  }
  // Declare runtime functions as direct external references. Otherwise InitGlobalContext may
  // route calls through function-pointer globals instead.
  if (dynamic_lookup || system_lib) {
    f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage,
                                              "TVMFuncCall", module_.get());
    f_tvm_get_func_from_env_ =
        llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage,
                               "TVMBackendGetFuncFromEnv", module_.get());
    f_tvm_api_set_last_error_ =
        llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage,
                               "TVMAPISetLastError", module_.get());
    f_tvm_parallel_launch_ =
        llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage,
                               "TVMBackendParallelLaunch", module_.get());
    f_tvm_parallel_barrier_ =
        llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage,
                               "TVMBackendParallelBarrier", module_.get());
  }
  // These must be set before InitGlobalContext, which reads target_c_runtime_.
  target_c_runtime_ = target_c_runtime;
  is_system_lib_ = system_lib;
  InitGlobalContext(dynamic_lookup);
}
185
186llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) {
187#if TVM_LLVM_VERSION >= 50
188 llvm::SmallVector<llvm::Metadata*, 4> paramTys;
189
190 paramTys.push_back(GetDebugType(f->ret_type));
191 for (const auto& param : f->params) {
192 paramTys.push_back(GetDebugType(GetType(param)));
193 }
194
195 auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType(
196 dbg_info_->di_builder_->getOrCreateTypeArray(paramTys));
197
198 bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage);
199
200 // TODO(driazati): determine the IRModule name instead of hardcoding 'main.tir'
201#if TVM_LLVM_VERSION >= 80
202 auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true,
203 /*IsOptimized=*/true);
204 auto* DIFunction = dbg_info_->di_builder_->createFunction(
205 /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"",
206 /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy,
207 /*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagZero, /*SPFlags=*/SPFlags);
208#else
209 auto* DIFunction = dbg_info_->di_builder_->createFunction(
210 /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"",
211 /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy,
212 /*isLocalToUnit=*/local_to_unit, /*isDefinition=*/true, /*ScopeLine=*/0,
213 /*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true);
214#endif
215 return DIFunction;
216#else
217 return nullptr;
218#endif
219}
220
221void CodeGenCPU::AddFunction(const PrimFunc& f) {
222#if TVM_LLVM_VERSION >= 50
223 di_subprogram_ = CreateDebugFunction(f);
224#endif
225 EmitDebugLocation(f->span);
226 CodeGenLLVM::AddFunction(f);
227 if (f_tvm_register_system_symbol_ != nullptr) {
228 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
229 ICHECK(global_symbol.defined())
230 << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
231 export_system_symbols_.emplace_back(
232 std::make_pair(global_symbol.value().operator std::string(), function_));
233 }
234 AddDebugInformation(f, function_);
235}
236
237// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv
238void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
239#if TVM_LLVM_VERSION >= 50
240 ICHECK(di_subprogram_);
241 f_llvm->setSubprogram(di_subprogram_);
242 ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_);
243
244 IRBuilder builder(&f_llvm->getEntryBlock());
245 if (!f_llvm->getEntryBlock().empty()) {
246 builder.SetInsertPoint(&f_llvm->getEntryBlock().front());
247 }
248 llvm::DebugLoc DL;
249 builder.SetCurrentDebugLocation(DL);
250 llvm::LLVMContext* ctx = llvm_target_->GetContext();
251 for (size_t i = 0; i < f_llvm->arg_size(); ++i) {
252 auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i));
253 std::string paramName = "arg" + std::to_string(i + 1);
254 auto param = dbg_info_->di_builder_->createParameterVariable(
255 di_subprogram_, paramName, i + 1, dbg_info_->file_, 0,
256 GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)),
257 /*alwaysPreserve=*/true);
258 auto* store = builder.CreateStore(f_llvm->arg_begin() + i, paramAlloca);
259 auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_);
260 dbg_info_->di_builder_->insertDeclare(paramAlloca, param,
261 dbg_info_->di_builder_->createExpression(),
262 llvm::DebugLoc(di_loc), store);
263 }
264 dbg_info_->di_builder_->finalizeSubprogram(f_llvm->getSubprogram());
265 auto* scope = f_llvm->getSubprogram();
266 if (!scope) {
267 return;
268 }
269
270 for (auto& BB : *f_llvm) {
271 for (auto& I : BB) {
272 if (I.getDebugLoc()) {
273 continue;
274 }
275 auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope);
276 I.setDebugLoc(llvm::DebugLoc(di_loc));
277 }
278 }
279#endif
280}
281
282llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir) {
283 return GetDebugType(ty_tir, GetLLVMType(ty_tir));
284}
285llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) {
286 if (ty_llvm == t_void_) {
287 return nullptr;
288 } else if (ty_llvm == llvm::Type::getFloatTy(*llvm_target_->GetContext())) {
289 return dbg_info_->di_builder_->createBasicType("float", 32, llvm::dwarf::DW_ATE_float);
290 } else if (ty_llvm == t_int8_) {
291 return dbg_info_->di_builder_->createBasicType("int8", 8, llvm::dwarf::DW_ATE_signed);
292 } else if (ty_llvm == t_int32_) {
293 return dbg_info_->di_builder_->createBasicType("int32", 32, llvm::dwarf::DW_ATE_signed);
294 } else if (ty_llvm->isPointerTy()) {
295 auto* ptr_type = ty_tir.as<PointerTypeNode>();
296 ICHECK(ptr_type != nullptr || GetRuntimeDataType(ty_tir).is_handle())
297 << "Got LLVM pointer type from non-pointer IR type: " << ty_tir;
298 auto* pointee_type = ptr_type != nullptr ? GetDebugType(ptr_type->element_type,
299 GetLLVMType(ptr_type->element_type))
300 : nullptr;
301 return dbg_info_->di_builder_->createPointerType(pointee_type,
302 ty_llvm->getPrimitiveSizeInBits());
303 } else {
304 std::string type_str;
305 llvm::raw_string_ostream rso(type_str);
306 ty_llvm->print(rso);
307 LOG(FATAL) << "Unknown LLVM type:" << rso.str();
308 }
309 return nullptr;
310}
311
312void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
313 llvm::Function* f = module_->getFunction(entry_func_name);
314 ICHECK(f) << "Function " << entry_func_name << "does not in module";
315 llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
316 llvm::GlobalVariable* global =
317 new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
318 runtime::symbol::tvm_module_main);
319#if TVM_LLVM_VERSION >= 100
320 global->setAlignment(llvm::Align(1));
321#else
322 global->setAlignment(1);
323#endif
324 // comdat is needed for windows select any linking to work
325 // set comdat to Any(weak linking)
326 if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
327 llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
328 comdat->setSelectionKind(llvm::Comdat::Any);
329 global->setComdat(comdat);
330 }
331
332 global->setInitializer(
333 llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name));
334 global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
335}
336
337std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
338 // link modules
339 if (dbg_info_ != nullptr) {
340 dbg_info_->di_builder_->finalize();
341 }
342 return CodeGenLLVM::Finish();
343}
344
345CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf,
346 llvm::Value* index, int kind) {
347 if (kind < builtin::kArrKindBound_) {
348 if (buf->getType() == t_void_p_) {
349 buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
350 } else {
351 ICHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
352 }
353 }
354 switch (kind) {
355 case builtin::kArrAddr: {
356 return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index));
357 }
358 case builtin::kArrData: {
359 llvm::Type* member_type = t_tvm_array_->getStructElementType(0);
360 llvm::Value* member_addr =
361 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)});
362 return TypedPointer(member_type, member_addr);
363 }
364 case builtin::kArrShape: {
365 llvm::Type* member_type = t_tvm_array_->getStructElementType(4);
366 llvm::Value* member_addr =
367 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)});
368 return TypedPointer(member_type, member_addr);
369 }
370 case builtin::kArrStrides: {
371 llvm::Type* member_type = t_tvm_array_->getStructElementType(5);
372 llvm::Value* member_addr =
373 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)});
374 return TypedPointer(member_type, member_addr);
375 }
376 case builtin::kArrNDim: {
377 llvm::Type* member_type = t_tvm_array_->getStructElementType(2);
378 llvm::Value* member_addr =
379 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)});
380 return TypedPointer(member_type, member_addr);
381 }
382 case builtin::kArrTypeCode: {
383 llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0);
384 llvm::Value* member_addr =
385 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)});
386 return TypedPointer(member_type, member_addr);
387 }
388 case builtin::kArrTypeBits: {
389 llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1);
390 llvm::Value* member_addr =
391 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)});
392 return TypedPointer(member_type, member_addr);
393 }
394 case builtin::kArrTypeLanes: {
395 llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2);
396 llvm::Value* member_addr =
397 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)});
398 return TypedPointer(member_type, member_addr);
399 }
400 case builtin::kArrByteOffset: {
401 llvm::Type* member_type = t_tvm_array_->getStructElementType(6);
402 llvm::Value* member_addr =
403 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)});
404 return TypedPointer(member_type, member_addr);
405 }
406 case builtin::kArrDeviceId: {
407 llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1);
408 llvm::Value* member_addr =
409 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)});
410 return TypedPointer(member_type, member_addr);
411 }
412 case builtin::kArrDeviceType: {
413 llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0);
414 llvm::Value* member_addr =
415 builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)});
416 return TypedPointer(member_type, member_addr);
417 }
418 case builtin::kTVMValueContent: {
419 ICHECK_EQ(t.lanes(), 1);
420 ICHECK(t.is_handle() || t.bits() == 64);
421 if (t.is_int()) {
422 buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
423 return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index));
424 } else if (t.is_float()) {
425 buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
426 return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index));
427 } else {
428 ICHECK(t.is_handle());
429 buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
430 buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index);
431 return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()));
432 }
433 }
434 default:
435 LOG(FATAL) << "unknown field code";
436 }
437}
438
439llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol,
440 const Array<PrimExpr>& args, bool skip_first_arg) {
441 std::vector<llvm::Value*> arg_values;
442 for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
443 arg_values.push_back(MakeValue(args[i]));
444 }
445 std::vector<llvm::Type*> arg_types;
446 for (llvm::Value* v : arg_values) {
447 arg_types.push_back(v->getType());
448 }
449 llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false);
450 // Check if it is available in global function table as injected function.
451 auto it = gv_func_map_.find(global_symbol);
452 if (it != gv_func_map_.end()) {
453 if (it->second == nullptr) {
454 gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol);
455 it = gv_func_map_.find(global_symbol);
456 }
457#if TVM_LLVM_VERSION >= 90
458 auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second));
459#else
460 auto ext_callee = GetContextPtr(it->second);
461#endif
462 return builder_->CreateCall(ext_callee, arg_values);
463 } else {
464 llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
465 if (f == nullptr) {
466 f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
467 MakeStringRef(global_symbol), module_.get());
468 }
469#if TVM_LLVM_VERSION >= 90
470 auto ext_callee = llvm::FunctionCallee(f);
471#else
472 auto ext_callee = f;
473#endif
474 return builder_->CreateCall(ext_callee, arg_values);
475 }
476}
477
478llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) {
479 llvm::GlobalVariable* gv = new llvm::GlobalVariable(
480 *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name);
481#if TVM_LLVM_VERSION >= 100
482 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type)));
483#else
484 gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
485#endif
486 gv->setInitializer(llvm::Constant::getNullValue(p_type));
487 gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
488 // comdat is needed for windows select any linking to work
489 // set comdat to Any(weak linking)
490 if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
491 llvm::Comdat* comdat = module_->getOrInsertComdat(name);
492 comdat->setSelectionKind(llvm::Comdat::Any);
493 gv->setComdat(comdat);
494 }
495 return gv;
496}
497
498llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
499 ICHECK(gv != nullptr);
500#if TVM_LLVM_VERSION >= 110
501 llvm::LoadInst* faddr =
502 builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment()));
503#elif TVM_LLVM_VERSION >= 80
504 llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment());
505#else
506 llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
507#endif
508 faddr->setMetadata("tbaa",
509 md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
510 return faddr;
511}
512
513void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) {
514 // Module context
515 gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);
516 // Register back the locations.
517 if (f_tvm_register_system_symbol_ != nullptr && !target_c_runtime_) {
518 export_system_symbols_.emplace_back(
519 std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
520 } else {
521 if (!dynamic_lookup) {
522 gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
523 gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(),
524 "__TVMBackendGetFuncFromEnv");
525 gv_tvm_api_set_last_error_ =
526 InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
527 gv_tvm_parallel_launch_ =
528 InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
529 gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(),
530 "__TVMBackendParallelBarrier");
531 // Mark as context functions
532 gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
533 gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
534 }
535 }
536}
537
538llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
539 // create emit codes that checks and load the function.
540 llvm::LLVMContext* ctx = llvm_target_->GetContext();
541 auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail", function_);
542 auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end", function_);
543 auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0));
544 builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
545 builder_->SetInsertPoint(fail_block);
546 // return the code.
547 builder_->CreateRet(retcode);
548 // otherwise set it to be new end.
549 builder_->SetInsertPoint(end_block);
550 return end_block;
551}
552
/*!
 * \brief Outline the body of a compute-scope attribute into its own internal function.
 *
 * The free variables of op->body become parameters of a new internal function named by
 * op->value (a StringImm). The current function calls it and returns early on a non-zero
 * result. The body is then emitted into the new function, followed by `ret 0`, and insertion
 * resumes after the call.
 */
void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
  EmitDebugLocation(op);
  /*! \brief maintain states that should be guarded when step into compute scope */
  // Enter/Exit swap function_, analyzer_ and var_map_ with fresh state owned by this struct,
  // so the parent's state is restored on exit. Presumably With<> calls EnterWithScope on
  // construction and ExitWithScope on destruction.
  struct ComputeScopeStates {
    explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}

    void EnterWithScope() {
      std::swap(function_, parent_->function_);
      std::swap(analyzer_, parent_->analyzer_);
      std::swap(var_map_, parent_->var_map_);
    }

    void ExitWithScope() {
      std::swap(function_, parent_->function_);
      std::swap(analyzer_, parent_->analyzer_);
      std::swap(var_map_, parent_->var_map_);
    }

    llvm::Function* function_{nullptr};
    std::unordered_map<const VarNode*, llvm::Value*> var_map_;
    std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
    CodeGenCPU* parent_;
  };

  // There are two reasons why we create another function for compute_scope
  // - Make sure the generated compute function is clearly separately(though it can get inlined)
  // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
  //   This is easier than set the alias scope manually.
  // Values are taken from the parent's var_map_ (before the scope guard below swaps it out).
  Array<Var> vargs = tir::UndefinedVars(op->body, {});
  std::vector<llvm::Value*> arg_values;
  std::vector<llvm::Type*> arg_types;
  for (Var v : vargs) {
    llvm::Value* value = MakeValue(v);
    value->setName(v->name_hint.c_str());
    arg_values.push_back(value);
    arg_types.push_back(value->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false);
  // $xxx_compute_ functions are not global. They should be marked as static (via InternalLinkage)
  // to call them correctly on MIPS platform (CALL16 issue)
  // Linkage ld Error: CALL16 reloc at 0x290 not against global symbol
  const StringImmNode* value = op->value.as<StringImmNode>();
  ICHECK(value != nullptr);
  llvm::Function* fcompute = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
                                                    MakeStringRef(value->value), module_.get());
  SetTargetAttributes(fcompute);

  // Call the outlined function from the current block; a non-zero result is returned early.
  llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // enter compute scope and setup compute function.
  With<ComputeScopeStates> scope_states_guard(this);
  // Bind each free variable to the corresponding argument in the (fresh) var_map_.
  size_t idx = 0;
  for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
    llvm::Argument* v = &(*it);
    const Var& var = vargs[idx];
    var_map_[var.get()] = v;
    if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
      // set non alias.
#if TVM_LLVM_VERSION >= 50
      fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
#else
      fcompute->setDoesNotAlias(idx + 1);
#endif
      // always not inline compute function to make the code structure clean
      // (added only when at least one handle parameter was marked noalias).
      fcompute->addFnAttr(llvm::Attribute::NoInline);
    }
    // Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
    auto f = alloc_storage_info_.find(var.get());
    if (f != alloc_storage_info_.end()) {
      unsigned align = f->second.alignment;
      if (align > 1) {
        auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align);
        fcompute->addParamAttr(idx, attr);
      }
    }
#endif
  }

  // Emit the body into the outlined function; it returns 0 on success.
  function_ = fcompute;
  auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
  builder_->SetInsertPoint(compute_entry);
  this->VisitStmt(op->body);
  builder_->CreateRet(ConstInt32(0));
  // Resume emission in the caller, after the success check.
  builder_->SetInsertPoint(compute_call_end);
}
639
640CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array<Var>& vfields,
641 uint64_t* num_bytes,
642 std::string struct_name) {
643 if (vfields.size() == 0) {
644 *num_bytes = 0U;
645 return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_));
646 }
647 std::vector<llvm::Type*> fields;
648 for (Var v : vfields) {
649 auto it = var_map_.find(v.get());
650 ICHECK(it != var_map_.end());
651 fields.push_back(it->second->getType());
652 }
653 llvm::StructType* ctype = struct_name.size() ? llvm::StructType::create(fields, struct_name)
654 : llvm::StructType::create(fields);
655 llvm::AllocaInst* cvalue =
656 WithFunctionEntry([&]() { return builder_->CreateAlloca(ctype, ConstInt32(1)); });
657 llvm::Value* zero = ConstInt32(0);
658 for (size_t i = 0; i < vfields.size(); ++i) {
659 builder_->CreateStore(var_map_.at(vfields[i].get()),
660 builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)}));
661 }
662 *num_bytes = data_layout_->getTypeAllocSize(ctype);
663 return TypedPointer(ctype, cvalue);
664}
665
666void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array<Var>& vfields,
667 std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
668 for (size_t i = 0; i < vfields.size(); ++i) {
669 llvm::Type* field_type = cdata.type->getStructElementType(i);
670 llvm::Value* field_addr =
671 builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)});
672 llvm::Value* load =
673 builder_->CreateLoad(field_type, field_addr, std::string(vfields[i]->name_hint));
674 (*vmap)[vfields[i].get()] = load;
675 }
676}
677
/*!
 * \brief Emit a parallel launch of \p body through the runtime parallel-launch entry.
 *
 * The free variables of \p body are packed into a closure (PackClosureData; a null void* if
 * there are none). A private `__tvm_parallel_lambda` with the FTVMParallelLambda signature is
 * created. The current function calls the launcher with (lambda, closure, num_task) and returns
 * early on a non-zero code. The lambda unpacks the closure, binds task_id and num_task (loaded
 * from field 1 of penv), emits \p body, and returns 0.
 *
 * \param body Statement to run in parallel; must contain a parallel loop (checked).
 * \param num_task Task count passed as-is to the launcher.
 * \param name Suffix for the closure struct type name ("closure_" + name).
 */
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
                             "__tvm_parallel_lambda", module_.get());
  SetTargetAttributes(f);

  // allocate and setup the closure, call the closure.
  Array<Var> vfields = tir::UndefinedVars(body, {});
  uint64_t nbytes;
  TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name);
#if TVM_LLVM_VERSION >= 90
  auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
  auto launch_callee = RuntimeTVMParallelLaunch();
#endif
  llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
      launch_callee,
      {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
  // Setup the closure function.
  auto* lambda_entry =
      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
  builder_->SetInsertPoint(lambda_entry);
  // Lambda arguments, in order: (int task_id, TVMParallelGroupEnv* penv, void* cdata).
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  task_id->setName("task_id");
  llvm::Value* penv = &(*it++);
  // Reinterpret the void* cdata argument as the closure type produced by PackClosureData.
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // setup parallel env
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", DataType::Int(32));
  par_env.num_task = Var("num_task", DataType::Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  // num_task is field 1 of TVMParallelGroupEnv.
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      t_int32_,
      builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
      "num_task");
  par_env.penv = penv;
  // Swap in the lambda's codegen state (function, parallel env, analyzer, var map) to emit body.
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  // par_env now holds the state used while emitting body; presumably parallel_loop_count is
  // incremented when a parallel loop is visited -- TODO confirm in the For visitor.
  ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
  builder_->SetInsertPoint(par_launch_end);
}
734
735llvm::Value* CodeGenCPU::CreateStaticHandle() {
736 llvm::GlobalVariable* gv =
737 new llvm::GlobalVariable(*module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage,
738 nullptr, "__tvm_static_handle");
739#if TVM_LLVM_VERSION >= 100
740 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_)));
741#else
742 gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
743#endif
744 gv->setInitializer(llvm::Constant::getNullValue(t_void_p_));
745 return gv;
746}
747
/*!
 * \brief Compile `body` into a closure and emit a call that hands it to `init_fname`.
 *
 * At the current insertion point this emits
 *   init_fname(static_handle, __tvm_static_init_lambda, closure_data, closure_nbytes)
 * and checks its return code with CheckCallSuccess. `body` becomes the body of the
 * private function "__tvm_static_init_lambda", which returns 0. Whether and when the
 * closure actually runs is decided by `init_fname` and is not visible here.
 * \param init_fname Name of the external init function. It is declared with
 *        ftype_tvm_static_init_ if the module does not already contain it.
 * \param body Statement compiled into the closure.
 */
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage,
                             "__tvm_static_init_lambda", module_.get());
  SetTargetAttributes(f);
  // A fresh static handle global is passed to the init function as its first argument.
  llvm::Value* gv = CreateStaticHandle();
  // Reuse an existing declaration/definition of init_fname, otherwise declare it externally.
  llvm::Function* finit = module_->getFunction(init_fname);
  if (finit == nullptr) {
    finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage,
                                   init_fname, module_.get());
  }
  // allocate and setup the closure, call the closure.
  // Variables used but not defined inside `body` are captured into the closure data.
  uint64_t nbytes;
  Array<Var> vfields = tir::UndefinedVars(body, {});
  TypedPointer cdata = PackClosureData(vfields, &nbytes);
  llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
      finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)}));
  // Setup the closure function.
  auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f);
  builder_->SetInsertPoint(lambda_entry);
  // The closure's first parameter is the packed closure data; re-type it to the packed layout.
  auto it = f->arg_begin();
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // Static init is not supported inside a parallel launch environment.
  ICHECK(parallel_env_.penv == nullptr);
  // Compile `body` with a fresh analyzer and variable map, targeting the closure function.
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(function_, f);
  // Resume emitting in the enclosing function after the checked init call.
  builder_->SetInsertPoint(init_end);
}
787
/*!
 * \brief Emit code that yields a (lazily resolved, cached) packed-function handle.
 *
 * One internal global ".tvm_func.<fname>" is created per distinct `fname` and remembered
 * in func_handle_map_. The emitted code:
 *   1. loads the cached handle;
 *   2. if it is non-null, jumps straight to "handle_init_end";
 *   3. otherwise ("handle_init") calls the TVMGetFuncFromEnv callee with the module context
 *      (gv_mod_ctx_), the function name and an out-slot, checks the return code, stores
 *      the resulting handle into the global, and jumps to "handle_init_end";
 *   4. merges both paths with a phi, which is returned.
 * \param fname Name of the packed function to look up.
 * \return The handle value, usable in the "handle_init_end" block.
 */
llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
  // We will store the packed function handle in global space.
  // Initialize it during the first call.
  llvm::DataLayout layout(module_.get());
  uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
  auto it = func_handle_map_.find(fname);

  llvm::GlobalVariable* hptr;
  if (it == func_handle_map_.end()) {
    // create global location for the handle
    // create the function handle
    hptr =
        new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false,
                                 llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
#if TVM_LLVM_VERSION >= 100
    hptr->setAlignment(llvm::Align(align));
#else
    hptr->setAlignment(align);
#endif
    // Null means "not resolved yet".
    hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
    func_handle_map_[fname] = hptr;
  } else {
    hptr = it->second;
  }
  // create emit codes that checks and load the function.
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // Remember the block that performs the fast-path load; it is one phi predecessor.
  llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
  auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_);
  auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align);
#else
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
  llvm::Value* handle_not_null =
      builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
  // Branch weights from md_very_likely_branch_ (the name suggests the cached path is hot).
  builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_);
  // Initialize the handle if needed.
  builder_->SetInsertPoint(init_block);
  // Out-slot receiving the handle from the lookup call.
  llvm::Value* out =
      WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
                                                         llvm::Align(gv_mod_ctx_->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
                                                         gv_mod_ctx_->getAlignment());
#else
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
  ctx_load->setMetadata(
      "tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
#if TVM_LLVM_VERSION >= 90
  auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
#else
  auto env_callee = RuntimeTVMGetFuncFromEnv();
#endif
  llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out});
  // CheckCallSuccess moves emission into a new block; that block is the phi's second predecessor.
  init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
  llvm::Value* loaded_handle =
      builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align);
#else
  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
  // Store the handle
  builder_->CreateStore(loaded_handle, hptr);
  builder_->CreateBr(end_block);
  // end block
  builder_->SetInsertPoint(end_block);
  llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
  phi->addIncoming(handle, pre_block);
  phi->addIncoming(loaded_handle, init_block);
  return phi;
}
867
868CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args,
869 const DataType& r_type,
870 const int64_t begin, const int64_t end,
871 bool use_string_lookup) {
872 PackedCall pc;
873 std::string func_name = args[0].as<StringImmNode>()->value;
874 // call the function
875 int64_t nargs = end - begin;
876 ICHECK_GE(nargs, 0);
877 llvm::Value* stack_value = MakeValue(args[1]);
878 llvm::Value* stack_tcode = MakeValue(args[2]);
879 llvm::Value* arg_value = builder_->CreateInBoundsGEP(
880 t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
881 ConstInt32(begin));
882 TypedPointer arg_tcode =
883 CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
884 llvm::Value* ret_value = builder_->CreateInBoundsGEP(
885 t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
886 ConstInt32(end));
887 TypedPointer ret_tcode =
888 CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));
889
890 llvm::FunctionType* callee_ftype = nullptr;
891 llvm::Value* callee_value = nullptr;
892 std::vector<llvm::Value*> call_args;
893
894 if (use_string_lookup) {
895 callee_ftype = ftype_tvm_func_call_;
896 callee_value = RuntimeTVMFuncCall();
897 call_args.push_back(GetPackedFuncHandle(func_name));
898 call_args.insert(call_args.end(),
899 {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr});
900 } else {
901 callee_ftype = ftype_tvm_backend_packed_c_func_;
902 callee_value = module_->getFunction(func_name);
903 if (callee_value == nullptr) {
904 callee_value =
905 llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
906 func_name, module_.get());
907 }
908
909 nargs -= 1;
910 call_args.insert(call_args.end(), {
911 builder_->CreateBitCast(arg_value, t_void_p_),
912 arg_tcode.addr,
913 ConstInt32(nargs),
914 builder_->CreateBitCast(ret_value, t_void_p_),
915 ret_tcode.addr,
916 });
917 call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_));
918 }
919#if TVM_LLVM_VERSION >= 90
920 auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value);
921#else
922 (void)callee_ftype; // use callee_ftype to avoid unused variable warning when using older LLVM.
923 auto call_callee = callee_value;
924#endif
925 llvm::Value* call = builder_->CreateCall(call_callee, call_args);
926
927 llvm::BasicBlock* end_block = CheckCallSuccess(call);
928
929 // Load the return value and cast it to the designated type (r_type).
930 DataType r_api_type = tir::APIType(r_type);
931 llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
932 llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
933#if TVM_LLVM_VERSION >= 110
934 llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
935#elif TVM_LLVM_VERSION >= 80
936 llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
937#else
938 llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
939#endif
940 pc.ret_value = CreateCast(r_api_type, r_type, rvalue);
941
942 // Load the return type code.
943#if TVM_LLVM_VERSION >= 110
944 pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
945#elif TVM_LLVM_VERSION >= 80
946 pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
947#else
948 pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
949#endif
950
951 pc.end_block = end_block;
952 return pc;
953}
954
955llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) {
956 auto expected_num_args = use_string_lookup ? 5U : 6U;
957 ICHECK_EQ(op->args.size(), expected_num_args);
958 PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
959 op->args[4].as<IntImmNode>()->value, use_string_lookup);
960 return pc.ret_value;
961}
962
963llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
964 ICHECK_EQ(op->args.size(), 6U);
965 PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
966 op->args[4].as<IntImmNode>()->value, true);
967 llvm::LLVMContext* ctx = llvm_target_->GetContext();
968 // Get traced value.
969 llvm::Value* traced_value = MakeValue(op->args[5]);
970 // The update_block handles case when we need to update the return value.
971 llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_);
972 // The continue_block handles case when we need to return original
973 // traced value.
974 llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_);
975
976 // Check the ret_type_code and create cmp instruction.
977 llvm::Value* cmp =
978 builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr));
979 builder_->CreateCondBr(cmp, update_block, continue_block);
980 builder_->SetInsertPoint(update_block);
981 builder_->CreateBr(continue_block);
982 builder_->SetInsertPoint(continue_block);
983 // The return value depends on from what bb we come from.
984 llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
985 phi_rvalue->addIncoming(pc.ret_value, update_block);
986 phi_rvalue->addIncoming(traced_value, pc.end_block);
987 return phi_rvalue;
988}
989
990llvm::Value* CodeGenCPU::RuntimeTVMFuncCall() {
991 if (f_tvm_func_call_ != nullptr) return f_tvm_func_call_;
992 return GetContextPtr(gv_tvm_func_call_);
993}
994
995llvm::Value* CodeGenCPU::RuntimeTVMGetFuncFromEnv() {
996 if (f_tvm_get_func_from_env_ != nullptr) return f_tvm_get_func_from_env_;
997 return GetContextPtr(gv_tvm_get_func_from_env_);
998}
999llvm::Value* CodeGenCPU::RuntimeTVMAPISetLastError() {
1000 if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
1001 return GetContextPtr(gv_tvm_api_set_last_error_);
1002}
1003llvm::Value* CodeGenCPU::RuntimeTVMParallelLaunch() {
1004 if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_;
1005 return GetContextPtr(gv_tvm_parallel_launch_);
1006}
1007
1008llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
1009 if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_;
1010 return GetContextPtr(gv_tvm_parallel_barrier_);
1011}
1012
/*!
 * \brief Defines LLVM Types for each Metadata member type.
 *
 * CodeGenCPU::DefineMetadata brace-initializes the scalar fields of this struct positionally,
 * so their declaration order is part of the contract: do not reorder them.
 * In DefineMetadata both t_uint8 and t_bool are set to the LLVM i8 type.
 */
struct MetadataLlvmTypes {
  llvm::Type* t_float64;
  llvm::Type* t_uint8;
  llvm::Type* t_int64;
  llvm::Type* t_bool;
  llvm::Type* t_cstring;
  llvm::Type* t_void_p;
  // Struct type used for DataType-valued fields (three i8 members in DefineMetadata).
  llvm::StructType* t_data_type;

  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */
  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
};
1026
1027class MetadataTypeDefiner : public AttrVisitor {
1028 public:
1029 MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types)
1030 : ctx_{ctx}, llvm_types_{llvm_types} {}
1031
1032 void Visit(const char* key, double* value) final {
1033 elements_.emplace_back(llvm_types_->t_float64);
1034 }
1035 void Visit(const char* key, int64_t* value) final {
1036 elements_.emplace_back(llvm_types_->t_int64);
1037 }
1038 void Visit(const char* key, uint64_t* value) final {
1039 elements_.emplace_back(llvm_types_->t_int64);
1040 }
1041 void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); }
1042 void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); }
1043 void Visit(const char* key, std::string* value) final {
1044 elements_.emplace_back(llvm_types_->t_cstring);
1045 }
1046 void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); }
1047 void Visit(const char* key, DataType* value) final {
1048 elements_.emplace_back(llvm_types_->t_data_type);
1049 }
1050 void Visit(const char* key, runtime::NDArray* value) final {
1051 elements_.emplace_back(llvm_types_->t_int64);
1052 elements_.emplace_back(llvm_types_->t_void_p);
1053 }
1054
1055 private:
1056 void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
1057 elements_.emplace_back(llvm::PointerType::getUnqual(
1058 llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
1059 if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
1060 return;
1061 }
1062
1063 if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
1064 return;
1065 }
1066 to_visit_[metadata->get_c_struct_name()] = metadata;
1067 }
1068
1069 public:
1070 using MetadataKind = runtime::metadata::MetadataKind;
1071
1072 void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
1073 switch (arr->kind) {
1074 case MetadataKind::kUint64: // LLVM encodes signed and unsigned with same types.
1075 case MetadataKind::kInt64:
1076 elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
1077 break;
1078 case MetadataKind::kBool:
1079 elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
1080 break;
1081 case MetadataKind::kString:
1082 elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
1083 break;
1084 case MetadataKind::kHandle:
1085 CHECK(false) << "Do not support handle";
1086 break;
1087 case MetadataKind::kMetadata:
1088 if (llvm_types_->structs_by_type_key.count(arr->type_key)) {
1089 elements_.emplace_back(
1090 llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
1091 }
1092 break;
1093 default:
1094 CHECK(false) << "Unsupported metadata kind " << arr->kind;
1095 break;
1096 }
1097 }
1098
1099 void Visit(const char* key, ObjectRef* value) final {
1100 const runtime::metadata::MetadataArrayNode* arr =
1101 value->as<runtime::metadata::MetadataArrayNode>();
1102 if (arr != nullptr) {
1103 VisitArray(arr);
1104 } else {
1105 elements_.emplace_back(
1106 llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
1107 }
1108 }
1109
1110 void DefineType(runtime::metadata::MetadataBase metadata) {
1111 ICHECK(elements_.empty());
1112 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
1113 llvm_types_->structs_by_type_key[metadata->GetTypeKey()] =
1114 llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name());
1115 elements_.clear();
1116 }
1117
1118 llvm::LLVMContext* ctx_;
1119 struct MetadataLlvmTypes* llvm_types_;
1120 ::std::unordered_set<::std::string> visited_;
1121 ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_;
1122 ::std::vector<llvm::Type*> elements_;
1123};
1124
1125class MetadataSerializerLLVM : public AttrVisitor {
1126 using MetadataKind = runtime::metadata::MetadataKind;
1127
1128 public:
1129 MetadataSerializerLLVM(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types)
1130 : codegen_{codegen}, llvm_types_{llvm_types} {}
1131
1132 void Visit(const char* key, double* value) final {
1133 elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value));
1134 }
1135 void Visit(const char* key, int64_t* value) final {
1136 elements_.back().emplace_back(llvm::ConstantInt::get(
1137 llvm_types_->t_int64, static_cast<uint64_t>(*value), true /* isSigned */));
1138 }
1139 void Visit(const char* key, uint64_t* value) final {
1140 elements_.back().emplace_back(
1141 llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */));
1142 }
1143 void Visit(const char* key, int* value) final {
1144 elements_.back().emplace_back(
1145 llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */));
1146 }
1147 void Visit(const char* key, bool* value) final {
1148 elements_.back().emplace_back(llvm::ConstantInt::get(
1149 llvm_types_->t_uint8, static_cast<uint64_t>(*value), false /* isSigned */));
1150 }
1151 void Visit(const char* key, std::string* value) final {
1152 elements_.back().emplace_back(codegen_->GetConstString(*value));
1153 }
1154 void Visit(const char* key, void** value) final {
1155 CHECK(false) << "Do not support serializing void*";
1156 }
1157 void Visit(const char* key, DataType* value) final {
1158 elements_.back().emplace_back(llvm::ConstantStruct::get(
1159 llvm_types_->t_data_type,
1160 {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */),
1161 llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */),
1162 llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)}));
1163 }
1164
1165 // Serializing NDArray as tuple of len, data
1166 void Visit(const char* key, runtime::NDArray* value) final {
1167 std::string bytes;
1168 dmlc::MemoryStringStream stream(&bytes);
1169 value->Save(&stream);
1170 elements_.back().emplace_back(
1171 llvm::ConstantInt::get(llvm_types_->t_int64, bytes.length(), true /* isSigned */));
1172 elements_.back().emplace_back(codegen_->GetConstString(bytes));
1173 }
1174
1175 void VisitMetadata(runtime::metadata::MetadataBase metadata) {
1176 elements_.emplace_back(std::vector<llvm::Constant*>());
1177 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
1178 auto struct_elements = elements_.back();
1179 elements_.pop_back();
1180 auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()];
1181 ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key="
1182 << metadata->GetTypeKey();
1183 CHECK_EQ(struct_elements.size(), struct_ty->getNumElements());
1184 auto out = llvm::ConstantStruct::get(struct_ty, struct_elements);
1185 if (elements_.size() > 0) {
1186 elements_.back().push_back(out);
1187 } else {
1188 last_production_ = out;
1189 }
1190 }
1191
1192 void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
1193 llvm::Type* element_type;
1194 switch (arr->kind) {
1195 case MetadataKind::kInt64:
1196 element_type = llvm_types_->t_int64;
1197 break;
1198 case MetadataKind::kUint64:
1199 element_type = llvm_types_->t_int64;
1200 break;
1201 case MetadataKind::kBool:
1202 element_type = llvm_types_->t_uint8;
1203 break;
1204 case MetadataKind::kString:
1205 element_type = llvm_types_->t_cstring;
1206 break;
1207 case MetadataKind::kMetadata: {
1208 element_type = llvm_types_->structs_by_type_key[arr->type_key];
1209 ICHECK(element_type != nullptr)
1210 << "Did not find LLVM StructType* for type_key=" << arr->type_key;
1211 break;
1212 }
1213 default:
1214 LOG(FATAL) << "unknown metadata kind " << arr->kind;
1215 break;
1216 }
1217
1218 elements_.emplace_back(std::vector<llvm::Constant*>());
1219 for (auto o : arr->array) {
1220 if (o->IsInstance<FloatImmNode>()) {
1221 double value = Downcast<FloatImm>(o)->value;
1222 Visit(nullptr, &value);
1223 }
1224 if (o->IsInstance<IntImmNode>()) {
1225 auto value = Downcast<IntImm>(o)->value;
1226 Visit(nullptr, &value);
1227 } else if (o->IsInstance<StringObj>()) {
1228 ::std::string value = Downcast<String>(o);
1229 Visit(nullptr, &value);
1230 } else {
1231 // nested array not possible.
1232 VisitMetadata(Downcast<runtime::metadata::MetadataBase>(o));
1233 }
1234 }
1235 auto array = elements_.back();
1236 elements_.pop_back();
1237 CHECK(element_type != nullptr);
1238 auto arr_ty = llvm::ArrayType::get(element_type, array.size());
1239 auto llvm_arr = llvm::ConstantArray::get(arr_ty, array);
1240
1241 if (elements_.size() > 0) {
1242 elements_.back().emplace_back(
1243 codegen_->GetGlobalConstant(llvm_arr, "", llvm::GlobalValue::PrivateLinkage));
1244 } else {
1245 last_production_ = llvm_arr;
1246 }
1247 }
1248
1249 void Visit(const char* key, ObjectRef* value) final {
1250 const runtime::metadata::MetadataArrayNode* arr =
1251 value->as<runtime::metadata::MetadataArrayNode>();
1252 if (arr != nullptr) {
1253 VisitArray(arr);
1254 return;
1255 }
1256
1257 runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(*value);
1258 VisitMetadata(metadata);
1259 }
1260
1261 llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) {
1262 Visit(nullptr, &metadata);
1263 ICHECK(last_production_);
1264 return codegen_->GetGlobalConstant(last_production_);
1265 }
1266
1267 CodeGenLLVM* codegen_;
1268 MetadataLlvmTypes* llvm_types_;
1269 llvm::LLVMContext* ctx_;
1270 llvm::Module* module_;
1271 std::vector<std::vector<llvm::Constant*>> elements_;
1272 llvm::Constant* last_production_;
1273};
1274
/*!
 * \brief Emit `metadata` as LLVM constants plus an exported accessor function.
 *
 * Steps:
 *   1. Collect the metadata classes reachable from `metadata` (via
 *      DiscoverComplexTypesVisitor), seeded with a sample ConstantInfoMetadata.
 *   2. Define one LLVM struct type per collected class, in queue order.
 *   3. Serialize `metadata` into a global constant.
 *   4. Define the exported function runtime::symbol::tvm_get_c_metadata with the packed-C
 *      signature. It stores the constant's address into its 4th argument, stores
 *      kTVMOpaqueHandle into its 5th argument, and returns 0.
 */
void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) {
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // Positional initialization: the order must match MetadataLlvmTypes' field order.
  MetadataLlvmTypes llvm_types{
      t_float64_ /* t_float64 */,
      llvm::Type::getInt8Ty(*ctx) /* t_uint8 */,
      t_int64_ /* t_int64 */,
      llvm::Type::getInt8Ty(*ctx) /* t_bool */,
      t_char_->getPointerTo() /* t_cstring */,
      t_void_p_ /* t_void_p */,
      llvm::StructType::create(*ctx, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */,
  };

  // create sample ConstantInfoMetadata instance for MetadataTypeDefiner
  // NOTE(review): presumably this guarantees the ConstantInfoMetadata struct type is defined
  // even when `metadata` references no constants — confirm.
  std::string bytes;
  runtime::NDArray ci = runtime::NDArray::Empty({0}, DataType::UInt(8), Device{kDLCPU});
  dmlc::MemoryStringStream stream(&bytes);
  ci.Save(&stream);
  // `di` points into `bytes`, which outlives every use below.
  TVMConstantInfo di =
      TVMConstantInfo{"default-none", 0, static_cast<int64_t>(bytes.size()), bytes.c_str()};

  std::vector<runtime::metadata::MetadataBase> queue;
  queue.push_back(runtime::metadata::ConstantInfoMetadata(&di));

  metadata::DiscoverComplexTypesVisitor discover_complex{&queue};
  discover_complex.Discover(metadata);

  // Struct types are defined in queue order; undefined queue entries are skipped.
  MetadataTypeDefiner definer{ctx, &llvm_types};
  for (auto md : queue) {
    if (md.defined()) {
      definer.DefineType(md);
    }
  }

  MetadataSerializerLLVM serializer{this, &llvm_types};
  auto metadata_constant_gv = serializer.Serialize(metadata);

  function_ =
      llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
                             runtime::symbol::tvm_get_c_metadata, module_.get());
  SetTargetAttributes(function_);
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
  builder_->SetInsertPoint(entry_point_entry);

  // Argument 3 of the packed-C signature: output value slot receives the metadata pointer.
  auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo());
  builder_->CreateStore(builder_->CreateBitCast(metadata_constant_gv, t_void_p_), ret_values_p);

  // Argument 4: output type-code slot receives kTVMOpaqueHandle.
  auto ret_tcode = builder_->CreateBitCast(GetArg(function_, 4), t_int_->getPointerTo());
  builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);

  builder_->CreateRet(ConstInt32(0));
}
1329
1330void CodeGenCPU::DefineFunctionRegistry(Array<String> func_names) {
1331 ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
1332 Array<String> symbols;
1333 std::vector<llvm::Constant*> funcs;
1334 for (auto sym : func_names) {
1335 symbols.push_back(sym);
1336 auto* sym_func =
1337 llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::GlobalValue::ExternalLinkage,
1338 sym.operator std::string(), module_.get());
1339
1340 funcs.emplace_back(sym_func);
1341 }
1342 llvm::ArrayType* t_tvm_crt_func_ptrs =
1343 llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size());
1344 llvm::DataLayout layout(module_.get());
1345
1346 llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable(
1347 *module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage,
1348 llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs");
1349
1350 uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo());
1351#if TVM_LLVM_VERSION >= 100
1352 func_registry_ptrs->setAlignment(llvm::Align(align));
1353#else
1354 func_registry_ptrs->setAlignment(align);
1355#endif
1356 llvm::GlobalVariable* func_registry = new llvm::GlobalVariable(
1357 *module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage,
1358 llvm::ConstantStruct::get(
1359 t_tvm_crt_func_registry_,
1360 {GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)),
1361 llvm::ConstantExpr::getBitCast(func_registry_ptrs,
1362 ftype_tvm_backend_packed_c_func_->getPointerTo())}),
1363 "_tvm_crt_func_registry");
1364 llvm::GlobalVariable* module = new llvm::GlobalVariable(
1365 *module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage,
1366 llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module");
1367
1368 // Now build TVMSystemLibEntryPoint.
1369 llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false);
1370 function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
1371 "TVMSystemLibEntryPoint", module_.get());
1372 SetTargetAttributes(function_);
1373 llvm::BasicBlock* entry_point_entry =
1374 llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_);
1375 builder_->SetInsertPoint(entry_point_entry);
1376 builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_));
1377}
1378
1379void CodeGenCPU::AddStartupFunction() {
1380 if (!target_c_runtime_) {
1381 llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
1382 function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
1383 "__tvm_module_startup", module_.get());
1384 SetTargetAttributes(function_);
1385 llvm::BasicBlock* startup_entry =
1386 llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_);
1387 builder_->SetInsertPoint(startup_entry);
1388 for (const auto& kv : export_system_symbols_) {
1389 llvm::Value* name = GetConstString(kv.first);
1390 builder_->CreateCall(f_tvm_register_system_symbol_,
1391 {name, builder_->CreateBitCast(kv.second, t_void_p_)});
1392 }
1393 llvm::appendToGlobalCtors(*module_, function_, 65535);
1394 builder_->CreateRet(nullptr);
1395 }
1396}
1397
/*!
 * \brief Lower the CPU-specific TIR builtins; everything else goes to CodeGenLLVM.
 *
 * Handled here: tvm_call_packed_lowered, tvm_call_trace_packed_lowered,
 * tvm_call_cpacked_lowered, tvm_static_handle, tvm_throw_last_error, tvm_struct_get,
 * tvm_struct_set and tvm_stack_alloca.
 * \param op The intrinsic call.
 * \return The LLVM value produced by the intrinsic.
 */
llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
  if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    return CreateCallPacked(op, true /* use_string_lookup */);
  } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
    return CreateCallTracePacked(op);
  } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
    return CreateCallPacked(op, false /* use_string_lookup */);
  } else if (op->op.same_as(builtin::tvm_static_handle())) {
    return CreateStaticHandle();
  } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
    // Return -1 from the current function. The builder then needs a fresh insertion block
    // for any following code, so "cont" is created right after the current block; nothing
    // emitted here branches to it.
    builder_->CreateRet(ConstInt32(-1));
    auto next_block = std::next(builder_->GetInsertBlock()->getIterator());
    llvm::BasicBlock* new_bb =
        llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont", function_, &*next_block);
    builder_->SetInsertPoint(new_bb);
    return ConstInt32(-1);
  } else if (op->op.same_as(builtin::tvm_struct_get())) {
    // args: (struct handle, index, field kind). kArrAddr yields the field address as void*;
    // other kinds load the field.
    ICHECK_EQ(op->args.size(), 3U);
    int kind = op->args[2].as<IntImmNode>()->value;
    TypedPointer ref =
        CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
    if (kind == builtin::kArrAddr) {
      return builder_->CreatePointerCast(ref.addr, t_void_p_);
    } else {
      return builder_->CreateLoad(ref.type, ref.addr);
    }
  } else if (op->op.same_as(builtin::tvm_struct_set())) {
    // args: (struct handle, index, field kind, value). Returns the constant 0.
    ICHECK_EQ(op->args.size(), 4U);
    int kind = op->args[2].as<IntImmNode>()->value;
    llvm::Value* value = MakeValue(op->args[3]);
    TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
                                          MakeValue(op->args[1]), kind);
    ICHECK(kind != builtin::kArrAddr);
    // Pointer values are cast to the field's type before being stored.
    if (value->getType()->isPointerTy()) {
      value = builder_->CreatePointerCast(value, ref.type);
    }
    builder_->CreateStore(value, ref.addr);
    return ConstInt32(0);
  } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
    // args: (kind string, constant element count). The alloca is emitted through
    // WithFunctionEntry (presumably in the entry block so it is a static alloca — confirm).
    ICHECK_EQ(op->args.size(), 2U);
    const std::string& type = op->args[0].as<StringImmNode>()->value;
    return WithFunctionEntry([&]() -> llvm::AllocaInst* {
      const int64_t* pval = as_const_int(op->args[1]);
      ICHECK(pval) << "require stack alloca to contain constant value";
      llvm::Value* num = ConstInt32(pval[0]);
      if (type == "shape") {
        return builder_->CreateAlloca(t_tvm_shape_index_, num);
      } else if (type == "arg_value") {
        return builder_->CreateAlloca(t_tvm_value_, num);
      } else if (type == "arg_tcode") {
        return builder_->CreateAlloca(t_int_, num);
      } else if (type == "array") {
        return builder_->CreateAlloca(t_tvm_array_, num);
      } else {
        // Does not return: LOG(FATAL) aborts lowering.
        LOG(FATAL) << "Unknown stack alloca type " << type;
      }
    });
  } else {
    return CodeGenLLVM::CreateIntrinsic(op);
  }
}
1459
1460void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
1461 EmitDebugLocation(op);
1462 llvm::Value* cond = MakeValue(op->condition);
1463 std::ostringstream os;
1464 os << "Assert fail: " << op->condition;
1465 if (op->message.as<StringImmNode>()) {
1466 os << ", " << op->message.as<StringImmNode>()->value;
1467 }
1468 llvm::Value* msg = GetConstString(os.str());
1469 llvm::LLVMContext* ctx = llvm_target_->GetContext();
1470 auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail", function_);
1471 auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end", function_);
1472 builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
1473 // fail condition.
1474 builder_->SetInsertPoint(fail_block);
1475
1476#if TVM_LLVM_VERSION >= 90
1477 auto err_callee =
1478 llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError());
1479#else
1480 auto err_callee = RuntimeTVMAPISetLastError();
1481#endif
1482 builder_->CreateCall(err_callee, {msg});
1483 builder_->CreateRet(ConstInt32(-1));
1484 // otherwise set it to be new end.
1485 builder_->SetInsertPoint(end_block);
1486 CodeGenLLVM::VisitStmt_(op);
1487}
1488
1489void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
1490 EmitDebugLocation(op);
1491 if (op->attr_key == tir::attr::coproc_uop_scope) {
1492 const StringImmNode* value = op->value.as<StringImmNode>();
1493 ICHECK(value != nullptr);
1494 this->CreateStaticInit(value->value, op->body);
1495 } else if (op->attr_key == tir::attr::compute_scope) {
1496 this->CreateComputeScope(op);
1497 } else if (tir::attr::IsPragmaKey(op->attr_key)) {
1498 if (op->attr_key == "pragma_parallel_stride_pattern") {
1499 ICHECK(parallel_env_.penv != nullptr)
1500 << "Pragma parallel_stride_pattern only valid in parallel launch";
1501 parallel_env_.stride_pattern = true;
1502 this->VisitStmt(op->body);
1503 } else if (op->attr_key == "pragma_parallel_launch_point") {
1504 CreateParallelLaunch(op->body, 0, "pragma_parallel");
1505 } else if (op->attr_key == "pragma_parallel_barrier_when_finish") {
1506 ICHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment";
1507 ICHECK(!parallel_env_.in_parallel_loop)
1508 << "Cannot not place within parallel loop as the workload may differ, "
1509 << " place it between parallel and parallel_launch_point";
1510 this->VisitStmt(op->body);
1511#if TVM_LLVM_VERSION >= 90
1512 auto bar_callee =
1513 llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier());
1514#else
1515 auto bar_callee = RuntimeTVMParallelBarrier();
1516#endif
1517 builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv});
1518 } else if (op->attr_key == tir::attr::pragma_import_llvm) {
1519 const StringImmNode* value = op->value.as<StringImmNode>();
1520 ICHECK(value != nullptr);
1521 this->HandleImport(value->value);
1522 this->VisitStmt(op->body);
1523 } else {
1524 LOG(WARNING) << "Unknown pragma " << op->attr_key;
1525 this->VisitStmt(op->body);
1526 }
1527 } else {
1528 CodeGenLLVM::VisitStmt_(op);
1529 }
1530}
1531
1532void CodeGenCPU::VisitStmt_(const ForNode* op) {
1533 EmitDebugLocation(op);
1534 ICHECK(is_zero(op->min));
1535 if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
1536 CodeGenLLVM::VisitStmt_(op);
1537 } else if (op->kind == ForKind::kParallel) {
1538 if (parallel_env_.penv == nullptr) {
1539 CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
1540 op->thread_binding, op->annotations),
1541 0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
1542 } else {
1543 // already in parallel env.
1544 ICHECK(parallel_env_.task_id.defined());
1545 ICHECK(parallel_env_.num_task.defined());
1546 ICHECK(parallel_env_.penv != nullptr);
1547 DataType t = op->extent.dtype();
1548 PrimExpr num_task = cast(t, parallel_env_.num_task);
1549 PrimExpr task_id = cast(t, parallel_env_.task_id);
1550 ICHECK(!parallel_env_.in_parallel_loop)
1551 << "Nested parallel loop is not supported by threadpool, try fuse them instead";
1552 parallel_env_.in_parallel_loop = true;
1553 if (parallel_env_.stride_pattern) {
1554 CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
1555 op->loop_var, op->body);
1556 } else {
1557 PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
1558 PrimExpr begin = min(task_id * step, op->extent);
1559 PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
1560 CreateSerialFor(MakeValue(begin), MakeValue(end),
1561 llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
1562 }
1563 parallel_env_.in_parallel_loop = false;
1564 ++parallel_env_.parallel_loop_count;
1565 }
1566 } else {
1567 LOG(FATAL) << "cannot handle for type " << op->kind;
1568 }
1569}
1570
1571TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu")
1572 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
1573 *rv = static_cast<void*>(new CodeGenCPU());
1574 });
1575
1576} // namespace codegen
1577} // namespace tvm
1578
1579#endif // TVM_LLVM_VERSION
1580