1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_cpu.cc |
22 | */ |
23 | #ifdef TVM_LLVM_VERSION |
24 | |
25 | #include "codegen_cpu.h" |
26 | |
27 | #include <llvm/ADT/SmallVector.h> |
28 | #include <llvm/ADT/StringRef.h> |
29 | #include <llvm/IR/Argument.h> |
30 | #include <llvm/IR/Attributes.h> |
31 | #include <llvm/IR/BasicBlock.h> |
32 | #include <llvm/IR/CallingConv.h> |
33 | #include <llvm/IR/Comdat.h> |
34 | #include <llvm/IR/Constants.h> |
35 | #include <llvm/IR/DIBuilder.h> |
36 | #include <llvm/IR/DebugInfoMetadata.h> |
37 | #include <llvm/IR/DebugLoc.h> |
38 | #include <llvm/IR/DerivedTypes.h> |
39 | #include <llvm/IR/Function.h> |
40 | #include <llvm/IR/GlobalVariable.h> |
41 | #include <llvm/IR/Instructions.h> |
42 | #include <llvm/IR/LLVMContext.h> |
43 | #include <llvm/IR/MDBuilder.h> |
44 | #include <llvm/IR/Metadata.h> |
45 | #include <llvm/IR/Module.h> |
46 | #if TVM_LLVM_VERSION >= 100 |
47 | #include <llvm/Support/Alignment.h> |
48 | #endif |
49 | #include <llvm/Support/raw_ostream.h> |
50 | #include <llvm/Target/TargetMachine.h> |
51 | #include <llvm/Transforms/Utils/ModuleUtils.h> |
52 | #include <tvm/runtime/c_runtime_api.h> |
53 | #include <tvm/runtime/module.h> |
54 | #include <tvm/tir/analysis.h> |
55 | |
56 | #include <algorithm> |
57 | #include <memory> |
58 | #include <unordered_map> |
59 | #include <unordered_set> |
60 | |
61 | #include "../func_registry_generator.h" |
62 | #include "../metadata_utils.h" |
63 | #include "llvm_instance.h" |
64 | |
65 | namespace tvm { |
66 | namespace codegen { |
67 | |
// Constructor and destructor are defaulted out-of-line (not inline in the header) because
// members are held in std::unique_ptr; see the matching comment in codegen_llvm.cc.
// Presumably this lets the header get away with forward-declared pointee types — confirm
// against codegen_cpu.h.
CodeGenCPU::CodeGenCPU() = default;
CodeGenCPU::~CodeGenCPU() = default;
72 | |
73 | void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, |
74 | bool dynamic_lookup, bool target_c_runtime) { |
75 | CodeGenLLVM::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); |
76 | dbg_info_ = CreateDebugInfo(module_.get()); |
77 | static_assert(sizeof(TVMValue) == sizeof(double), "invariant" ); |
78 | func_handle_map_.clear(); |
79 | export_system_symbols_.clear(); |
80 | |
81 | // Runtime types. |
82 | |
83 | t_tvm_shape_index_ = |
84 | llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); |
85 | // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: |
86 | // typedef struct { DLDeviceType device_type; int device_id; } DLDevice; |
87 | t_tvm_device_ = llvm::StructType::create({t_int_, t_int_}); |
88 | // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: |
89 | // typedef struct { uint8_t code; uint8_t bits; uint16_t lanes; } DLDataType; |
90 | t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); |
91 | // Defined in include/tvm/runtime/c_runtime_api.h: |
92 | // typedef void* TVMFunctionHandle; |
93 | t_tvm_func_handle_ = t_void_p_; |
94 | // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: |
95 | // typedef struct { ... } DLTensor; |
96 | t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_device_, t_int_, t_tvm_type_, |
97 | t_tvm_shape_index_->getPointerTo(), |
98 | t_tvm_shape_index_->getPointerTo(), t_int64_}); |
99 | // Defined in include/tvm/runtime/c_runtime_api.h: |
100 | // typedef union { ... } TVMValue; |
101 | t_tvm_value_ = llvm::StructType::create({t_float64_}); |
102 | // Defined in include/tvm/runtime/c_backend_api.h: |
103 | // typedef struct { void* sync_handle; int32_t num_task; } TVMParallelGroupEnv; |
104 | t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_}); |
105 | // Defined in include/tvm/runtime/c_backend_api.h: |
106 | // typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, |
107 | // TVMValue* out_ret_value, int* out_ret_tcode, |
108 | // void* resource_handle); |
109 | ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( |
110 | t_int_, |
111 | {t_void_p_, t_int_->getPointerTo(), t_int_, t_void_p_, t_int_->getPointerTo(), t_void_p_}, |
112 | false); |
113 | t_tvm_crt_func_registry_ = llvm::StructType::create( |
114 | {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); |
115 | t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()}); |
116 | // Defined in include/tvm/runtime/c_backend_api.h: |
117 | // typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata); |
118 | ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( |
119 | t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false); |
120 | md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr" , md_tbaa_root_); |
121 | |
122 | // Runtime functions. |
123 | |
124 | // Defined in include/tvm/runtime/c_runtime_api.h: |
125 | // int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, |
126 | // TVMValue* ret_val, int* ret_type_code); |
127 | ftype_tvm_func_call_ = llvm::FunctionType::get( |
128 | t_int_, |
129 | {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, |
130 | t_tvm_value_->getPointerTo(), t_int_->getPointerTo()}, |
131 | false); |
132 | // Defined in include/tvm/runtime/c_backend_api.h: |
133 | // int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); |
134 | ftype_tvm_get_func_from_env_ = llvm::FunctionType::get( |
135 | t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false); |
136 | // Defined in include/tvm/runtime/c_runtime_api.h: |
137 | // void TVMAPISetLastError(const char* msg); |
138 | ftype_tvm_api_set_last_error_ = |
139 | llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false); |
140 | // Defined in include/tvm/runtime/c_backend_api.h: |
141 | // int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task); |
142 | ftype_tvm_parallel_launch_ = llvm::FunctionType::get( |
143 | t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false); |
144 | // Defined in include/tvm/runtime/c_backend_api.h: |
145 | // int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); |
146 | ftype_tvm_parallel_barrier_ = |
147 | llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false); |
148 | ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false); |
149 | ftype_tvm_static_init_ = |
150 | llvm::FunctionType::get(t_int_, |
151 | {t_void_p_->getPointerTo(), |
152 | ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_}, |
153 | false); |
154 | // initialize TVM runtime API |
155 | if (system_lib && !target_c_runtime) { |
156 | // We will need this in environment for backward registration. |
157 | // Defined in include/tvm/runtime/c_backend_api.h: |
158 | // int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); |
159 | f_tvm_register_system_symbol_ = llvm::Function::Create( |
160 | llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false), |
161 | llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol" , module_.get()); |
162 | } else { |
163 | f_tvm_register_system_symbol_ = nullptr; |
164 | } |
165 | if (dynamic_lookup || system_lib) { |
166 | f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage, |
167 | "TVMFuncCall" , module_.get()); |
168 | f_tvm_get_func_from_env_ = |
169 | llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, |
170 | "TVMBackendGetFuncFromEnv" , module_.get()); |
171 | f_tvm_api_set_last_error_ = |
172 | llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage, |
173 | "TVMAPISetLastError" , module_.get()); |
174 | f_tvm_parallel_launch_ = |
175 | llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage, |
176 | "TVMBackendParallelLaunch" , module_.get()); |
177 | f_tvm_parallel_barrier_ = |
178 | llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, |
179 | "TVMBackendParallelBarrier" , module_.get()); |
180 | } |
181 | target_c_runtime_ = target_c_runtime; |
182 | is_system_lib_ = system_lib; |
183 | InitGlobalContext(dynamic_lookup); |
184 | } |
185 | |
186 | llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { |
187 | #if TVM_LLVM_VERSION >= 50 |
188 | llvm::SmallVector<llvm::Metadata*, 4> paramTys; |
189 | |
190 | paramTys.push_back(GetDebugType(f->ret_type)); |
191 | for (const auto& param : f->params) { |
192 | paramTys.push_back(GetDebugType(GetType(param))); |
193 | } |
194 | |
195 | auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType( |
196 | dbg_info_->di_builder_->getOrCreateTypeArray(paramTys)); |
197 | |
198 | bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage); |
199 | |
200 | // TODO(driazati): determine the IRModule name instead of hardcoding 'main.tir' |
201 | #if TVM_LLVM_VERSION >= 80 |
202 | auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, |
203 | /*IsOptimized=*/true); |
204 | auto* DIFunction = dbg_info_->di_builder_->createFunction( |
205 | /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir" , /*LinkageName=*/"" , |
206 | /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy, |
207 | /*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagZero, /*SPFlags=*/SPFlags); |
208 | #else |
209 | auto* DIFunction = dbg_info_->di_builder_->createFunction( |
210 | /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir" , /*LinkageName=*/"" , |
211 | /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy, |
212 | /*isLocalToUnit=*/local_to_unit, /*isDefinition=*/true, /*ScopeLine=*/0, |
213 | /*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true); |
214 | #endif |
215 | return DIFunction; |
216 | #else |
217 | return nullptr; |
218 | #endif |
219 | } |
220 | |
221 | void CodeGenCPU::AddFunction(const PrimFunc& f) { |
222 | #if TVM_LLVM_VERSION >= 50 |
223 | di_subprogram_ = CreateDebugFunction(f); |
224 | #endif |
225 | EmitDebugLocation(f->span); |
226 | CodeGenLLVM::AddFunction(f); |
227 | if (f_tvm_register_system_symbol_ != nullptr) { |
228 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
229 | ICHECK(global_symbol.defined()) |
230 | << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute" ; |
231 | export_system_symbols_.emplace_back( |
232 | std::make_pair(global_symbol.value().operator std::string(), function_)); |
233 | } |
234 | AddDebugInformation(f, function_); |
235 | } |
236 | |
237 | // Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv |
238 | void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { |
239 | #if TVM_LLVM_VERSION >= 50 |
240 | ICHECK(di_subprogram_); |
241 | f_llvm->setSubprogram(di_subprogram_); |
242 | ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_); |
243 | |
244 | IRBuilder builder(&f_llvm->getEntryBlock()); |
245 | if (!f_llvm->getEntryBlock().empty()) { |
246 | builder.SetInsertPoint(&f_llvm->getEntryBlock().front()); |
247 | } |
248 | llvm::DebugLoc DL; |
249 | builder.SetCurrentDebugLocation(DL); |
250 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
251 | for (size_t i = 0; i < f_llvm->arg_size(); ++i) { |
252 | auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i)); |
253 | std::string paramName = "arg" + std::to_string(i + 1); |
254 | auto param = dbg_info_->di_builder_->createParameterVariable( |
255 | di_subprogram_, paramName, i + 1, dbg_info_->file_, 0, |
256 | GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)), |
257 | /*alwaysPreserve=*/true); |
258 | auto* store = builder.CreateStore(f_llvm->arg_begin() + i, paramAlloca); |
259 | auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_); |
260 | dbg_info_->di_builder_->insertDeclare(paramAlloca, param, |
261 | dbg_info_->di_builder_->createExpression(), |
262 | llvm::DebugLoc(di_loc), store); |
263 | } |
264 | dbg_info_->di_builder_->finalizeSubprogram(f_llvm->getSubprogram()); |
265 | auto* scope = f_llvm->getSubprogram(); |
266 | if (!scope) { |
267 | return; |
268 | } |
269 | |
270 | for (auto& BB : *f_llvm) { |
271 | for (auto& I : BB) { |
272 | if (I.getDebugLoc()) { |
273 | continue; |
274 | } |
275 | auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope); |
276 | I.setDebugLoc(llvm::DebugLoc(di_loc)); |
277 | } |
278 | } |
279 | #endif |
280 | } |
281 | |
282 | llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir) { |
283 | return GetDebugType(ty_tir, GetLLVMType(ty_tir)); |
284 | } |
285 | llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) { |
286 | if (ty_llvm == t_void_) { |
287 | return nullptr; |
288 | } else if (ty_llvm == llvm::Type::getFloatTy(*llvm_target_->GetContext())) { |
289 | return dbg_info_->di_builder_->createBasicType("float" , 32, llvm::dwarf::DW_ATE_float); |
290 | } else if (ty_llvm == t_int8_) { |
291 | return dbg_info_->di_builder_->createBasicType("int8" , 8, llvm::dwarf::DW_ATE_signed); |
292 | } else if (ty_llvm == t_int32_) { |
293 | return dbg_info_->di_builder_->createBasicType("int32" , 32, llvm::dwarf::DW_ATE_signed); |
294 | } else if (ty_llvm->isPointerTy()) { |
295 | auto* ptr_type = ty_tir.as<PointerTypeNode>(); |
296 | ICHECK(ptr_type != nullptr || GetRuntimeDataType(ty_tir).is_handle()) |
297 | << "Got LLVM pointer type from non-pointer IR type: " << ty_tir; |
298 | auto* pointee_type = ptr_type != nullptr ? GetDebugType(ptr_type->element_type, |
299 | GetLLVMType(ptr_type->element_type)) |
300 | : nullptr; |
301 | return dbg_info_->di_builder_->createPointerType(pointee_type, |
302 | ty_llvm->getPrimitiveSizeInBits()); |
303 | } else { |
304 | std::string type_str; |
305 | llvm::raw_string_ostream rso(type_str); |
306 | ty_llvm->print(rso); |
307 | LOG(FATAL) << "Unknown LLVM type:" << rso.str(); |
308 | } |
309 | return nullptr; |
310 | } |
311 | |
312 | void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { |
313 | llvm::Function* f = module_->getFunction(entry_func_name); |
314 | ICHECK(f) << "Function " << entry_func_name << "does not in module" ; |
315 | llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); |
316 | llvm::GlobalVariable* global = |
317 | new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr, |
318 | runtime::symbol::tvm_module_main); |
319 | #if TVM_LLVM_VERSION >= 100 |
320 | global->setAlignment(llvm::Align(1)); |
321 | #else |
322 | global->setAlignment(1); |
323 | #endif |
324 | // comdat is needed for windows select any linking to work |
325 | // set comdat to Any(weak linking) |
326 | if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { |
327 | llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); |
328 | comdat->setSelectionKind(llvm::Comdat::Any); |
329 | global->setComdat(comdat); |
330 | } |
331 | |
332 | global->setInitializer( |
333 | llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name)); |
334 | global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); |
335 | } |
336 | |
337 | std::unique_ptr<llvm::Module> CodeGenCPU::Finish() { |
338 | // link modules |
339 | if (dbg_info_ != nullptr) { |
340 | dbg_info_->di_builder_->finalize(); |
341 | } |
342 | return CodeGenLLVM::Finish(); |
343 | } |
344 | |
345 | CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, |
346 | llvm::Value* index, int kind) { |
347 | if (kind < builtin::kArrKindBound_) { |
348 | if (buf->getType() == t_void_p_) { |
349 | buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); |
350 | } else { |
351 | ICHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo()); |
352 | } |
353 | } |
354 | switch (kind) { |
355 | case builtin::kArrAddr: { |
356 | return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index)); |
357 | } |
358 | case builtin::kArrData: { |
359 | llvm::Type* member_type = t_tvm_array_->getStructElementType(0); |
360 | llvm::Value* member_addr = |
361 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)}); |
362 | return TypedPointer(member_type, member_addr); |
363 | } |
364 | case builtin::kArrShape: { |
365 | llvm::Type* member_type = t_tvm_array_->getStructElementType(4); |
366 | llvm::Value* member_addr = |
367 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)}); |
368 | return TypedPointer(member_type, member_addr); |
369 | } |
370 | case builtin::kArrStrides: { |
371 | llvm::Type* member_type = t_tvm_array_->getStructElementType(5); |
372 | llvm::Value* member_addr = |
373 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)}); |
374 | return TypedPointer(member_type, member_addr); |
375 | } |
376 | case builtin::kArrNDim: { |
377 | llvm::Type* member_type = t_tvm_array_->getStructElementType(2); |
378 | llvm::Value* member_addr = |
379 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)}); |
380 | return TypedPointer(member_type, member_addr); |
381 | } |
382 | case builtin::kArrTypeCode: { |
383 | llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0); |
384 | llvm::Value* member_addr = |
385 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)}); |
386 | return TypedPointer(member_type, member_addr); |
387 | } |
388 | case builtin::kArrTypeBits: { |
389 | llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1); |
390 | llvm::Value* member_addr = |
391 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)}); |
392 | return TypedPointer(member_type, member_addr); |
393 | } |
394 | case builtin::kArrTypeLanes: { |
395 | llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2); |
396 | llvm::Value* member_addr = |
397 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)}); |
398 | return TypedPointer(member_type, member_addr); |
399 | } |
400 | case builtin::kArrByteOffset: { |
401 | llvm::Type* member_type = t_tvm_array_->getStructElementType(6); |
402 | llvm::Value* member_addr = |
403 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)}); |
404 | return TypedPointer(member_type, member_addr); |
405 | } |
406 | case builtin::kArrDeviceId: { |
407 | llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1); |
408 | llvm::Value* member_addr = |
409 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)}); |
410 | return TypedPointer(member_type, member_addr); |
411 | } |
412 | case builtin::kArrDeviceType: { |
413 | llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0); |
414 | llvm::Value* member_addr = |
415 | builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)}); |
416 | return TypedPointer(member_type, member_addr); |
417 | } |
418 | case builtin::kTVMValueContent: { |
419 | ICHECK_EQ(t.lanes(), 1); |
420 | ICHECK(t.is_handle() || t.bits() == 64); |
421 | if (t.is_int()) { |
422 | buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); |
423 | return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); |
424 | } else if (t.is_float()) { |
425 | buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); |
426 | return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); |
427 | } else { |
428 | ICHECK(t.is_handle()); |
429 | buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); |
430 | buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); |
431 | return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); |
432 | } |
433 | } |
434 | default: |
435 | LOG(FATAL) << "unknown field code" ; |
436 | } |
437 | } |
438 | |
439 | llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, |
440 | const Array<PrimExpr>& args, bool skip_first_arg) { |
441 | std::vector<llvm::Value*> arg_values; |
442 | for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) { |
443 | arg_values.push_back(MakeValue(args[i])); |
444 | } |
445 | std::vector<llvm::Type*> arg_types; |
446 | for (llvm::Value* v : arg_values) { |
447 | arg_types.push_back(v->getType()); |
448 | } |
449 | llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); |
450 | // Check if it is available in global function table as injected function. |
451 | auto it = gv_func_map_.find(global_symbol); |
452 | if (it != gv_func_map_.end()) { |
453 | if (it->second == nullptr) { |
454 | gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); |
455 | it = gv_func_map_.find(global_symbol); |
456 | } |
457 | #if TVM_LLVM_VERSION >= 90 |
458 | auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); |
459 | #else |
460 | auto ext_callee = GetContextPtr(it->second); |
461 | #endif |
462 | return builder_->CreateCall(ext_callee, arg_values); |
463 | } else { |
464 | llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol)); |
465 | if (f == nullptr) { |
466 | f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, |
467 | MakeStringRef(global_symbol), module_.get()); |
468 | } |
469 | #if TVM_LLVM_VERSION >= 90 |
470 | auto ext_callee = llvm::FunctionCallee(f); |
471 | #else |
472 | auto ext_callee = f; |
473 | #endif |
474 | return builder_->CreateCall(ext_callee, arg_values); |
475 | } |
476 | } |
477 | |
478 | llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { |
479 | llvm::GlobalVariable* gv = new llvm::GlobalVariable( |
480 | *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name); |
481 | #if TVM_LLVM_VERSION >= 100 |
482 | gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type))); |
483 | #else |
484 | gv->setAlignment(data_layout_->getTypeAllocSize(p_type)); |
485 | #endif |
486 | gv->setInitializer(llvm::Constant::getNullValue(p_type)); |
487 | gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); |
488 | // comdat is needed for windows select any linking to work |
489 | // set comdat to Any(weak linking) |
490 | if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { |
491 | llvm::Comdat* comdat = module_->getOrInsertComdat(name); |
492 | comdat->setSelectionKind(llvm::Comdat::Any); |
493 | gv->setComdat(comdat); |
494 | } |
495 | return gv; |
496 | } |
497 | |
498 | llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { |
499 | ICHECK(gv != nullptr); |
500 | #if TVM_LLVM_VERSION >= 110 |
501 | llvm::LoadInst* faddr = |
502 | builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); |
503 | #elif TVM_LLVM_VERSION >= 80 |
504 | llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); |
505 | #else |
506 | llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); |
507 | #endif |
508 | faddr->setMetadata("tbaa" , |
509 | md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); |
510 | return faddr; |
511 | } |
512 | |
513 | void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { |
514 | // Module context |
515 | gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx); |
516 | // Register back the locations. |
517 | if (f_tvm_register_system_symbol_ != nullptr && !target_c_runtime_) { |
518 | export_system_symbols_.emplace_back( |
519 | std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); |
520 | } else { |
521 | if (!dynamic_lookup) { |
522 | gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall" ); |
523 | gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(), |
524 | "__TVMBackendGetFuncFromEnv" ); |
525 | gv_tvm_api_set_last_error_ = |
526 | InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError" ); |
527 | gv_tvm_parallel_launch_ = |
528 | InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch" ); |
529 | gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(), |
530 | "__TVMBackendParallelBarrier" ); |
531 | // Mark as context functions |
532 | gv_func_map_["TVMBackendAllocWorkspace" ] = nullptr; |
533 | gv_func_map_["TVMBackendFreeWorkspace" ] = nullptr; |
534 | } |
535 | } |
536 | } |
537 | |
538 | llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { |
539 | // create emit codes that checks and load the function. |
540 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
541 | auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail" , function_); |
542 | auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end" , function_); |
543 | auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); |
544 | builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); |
545 | builder_->SetInsertPoint(fail_block); |
546 | // return the code. |
547 | builder_->CreateRet(retcode); |
548 | // otherwise set it to be new end. |
549 | builder_->SetInsertPoint(end_block); |
550 | return end_block; |
551 | } |
552 | |
/*!
 * \brief Outline the body of a compute_scope attribute into its own internal function.
 *
 * The free variables of op->body become the parameters of a new function named by the
 * attribute's string value. A call to it (with its return code checked) is emitted at the
 * current position, and the body is generated inside the new function. When this returns,
 * the builder is positioned after the checked call.
 */
void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
  EmitDebugLocation(op);
  /*! \brief maintain states that should be guarded when step into compute scope */
  struct ComputeScopeStates {
    explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}

    // Swap in a fresh function/analyzer/var_map for the outlined function.
    void EnterWithScope() {
      std::swap(function_, parent_->function_);
      std::swap(analyzer_, parent_->analyzer_);
      std::swap(var_map_, parent_->var_map_);
    }

    // Swap the caller's state back in.
    void ExitWithScope() {
      std::swap(function_, parent_->function_);
      std::swap(analyzer_, parent_->analyzer_);
      std::swap(var_map_, parent_->var_map_);
    }

    llvm::Function* function_{nullptr};
    std::unordered_map<const VarNode*, llvm::Value*> var_map_;
    std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
    CodeGenCPU* parent_;
  };

  // There are two reasons why we create another function for compute_scope
  // - Make sure the generated compute function is clearly separately(though it can get inlined)
  // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
  //   This is easier than set the alias scope manually.
  Array<Var> vargs = tir::UndefinedVars(op->body, {});
  std::vector<llvm::Value*> arg_values;
  std::vector<llvm::Type*> arg_types;
  // Lower each free variable in the caller; its LLVM type becomes a parameter type.
  // Note: setName renames the caller-side value itself.
  for (Var v : vargs) {
    llvm::Value* value = MakeValue(v);
    value->setName(v->name_hint.c_str());
    arg_values.push_back(value);
    arg_types.push_back(value->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false);
  // $xxx_compute_ functions are not global. They should be marked as static (via InternalLinkage)
  // to call them correctly on MIPS platform (CALL16 issue)
  // Linkage ld Error: CALL16 reloc at 0x290 not against global symbol
  const StringImmNode* value = op->value.as<StringImmNode>();
  ICHECK(value != nullptr);
  llvm::Function* fcompute = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
                                                    MakeStringRef(value->value), module_.get());
  SetTargetAttributes(fcompute);

  // Emit the call (in the caller) before switching to the new function.
  llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // enter compute scope and setup compute function.
  // NOTE(review): the With<> guard is expected to call EnterWithScope() here and
  // ExitWithScope() when it goes out of scope at the end of this function.
  With<ComputeScopeStates> scope_states_guard(this);
  size_t idx = 0;
  // Bind each free variable to the corresponding parameter of fcompute.
  for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
    llvm::Argument* v = &(*it);
    const Var& var = vargs[idx];
    var_map_[var.get()] = v;
    if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
      // set non alias.
#if TVM_LLVM_VERSION >= 50
      fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
#else
      fcompute->setDoesNotAlias(idx + 1);
#endif
      // Mark the compute function NoInline to keep the code structure clean. As written,
      // this only happens when at least one handle argument received NoAlias.
      fcompute->addFnAttr(llvm::Attribute::NoInline);
    }
    // Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
    auto f = alloc_storage_info_.find(var.get());
    if (f != alloc_storage_info_.end()) {
      unsigned align = f->second.alignment;
      if (align > 1) {
        auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align);
        fcompute->addParamAttr(idx, attr);
      }
    }
#endif
  }

  // Generate the body inside fcompute; it returns 0 on success.
  function_ = fcompute;
  auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
  builder_->SetInsertPoint(compute_entry);
  this->VisitStmt(op->body);
  builder_->CreateRet(ConstInt32(0));
  // Resume emission in the caller, after the checked call.
  builder_->SetInsertPoint(compute_call_end);
}
639 | |
640 | CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array<Var>& vfields, |
641 | uint64_t* num_bytes, |
642 | std::string struct_name) { |
643 | if (vfields.size() == 0) { |
644 | *num_bytes = 0U; |
645 | return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_)); |
646 | } |
647 | std::vector<llvm::Type*> fields; |
648 | for (Var v : vfields) { |
649 | auto it = var_map_.find(v.get()); |
650 | ICHECK(it != var_map_.end()); |
651 | fields.push_back(it->second->getType()); |
652 | } |
653 | llvm::StructType* ctype = struct_name.size() ? llvm::StructType::create(fields, struct_name) |
654 | : llvm::StructType::create(fields); |
655 | llvm::AllocaInst* cvalue = |
656 | WithFunctionEntry([&]() { return builder_->CreateAlloca(ctype, ConstInt32(1)); }); |
657 | llvm::Value* zero = ConstInt32(0); |
658 | for (size_t i = 0; i < vfields.size(); ++i) { |
659 | builder_->CreateStore(var_map_.at(vfields[i].get()), |
660 | builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)})); |
661 | } |
662 | *num_bytes = data_layout_->getTypeAllocSize(ctype); |
663 | return TypedPointer(ctype, cvalue); |
664 | } |
665 | |
666 | void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array<Var>& vfields, |
667 | std::unordered_map<const VarNode*, llvm::Value*>* vmap) { |
668 | for (size_t i = 0; i < vfields.size(); ++i) { |
669 | llvm::Type* field_type = cdata.type->getStructElementType(i); |
670 | llvm::Value* field_addr = |
671 | builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)}); |
672 | llvm::Value* load = |
673 | builder_->CreateLoad(field_type, field_addr, std::string(vfields[i]->name_hint)); |
674 | (*vmap)[vfields[i].get()] = load; |
675 | } |
676 | } |
677 | |
/*!
 * \brief Outline `body` into a private closure function and emit a call that launches it
 *        through the runtime parallel-launch entry point.
 * \param body Statement executed inside the parallel closure.
 * \param num_task Task count passed as the third argument of the launch call.
 * \param name Suffix for the closure struct name ("closure_" + name).
 *
 * The free variables of `body` are packed into a closure struct in the current function and
 * passed (as void*) to the launch call together with the closure function. Inside the closure,
 * the arguments are (task_id, penv, closure data); the captured variables are unpacked from
 * the closure data and num_task is loaded from field 1 of the parallel group env. The codegen
 * state (function_, parallel_env_, analyzer_, var_map_) is swapped to the closure's context
 * while `body` is visited and swapped back afterwards.
 */
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
                             "__tvm_parallel_lambda", module_.get());
  SetTargetAttributes(f);

  // allocate and setup the closure, call the closure.
  Array<Var> vfields = tir::UndefinedVars(body, {});
  uint64_t nbytes;
  TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name);
#if TVM_LLVM_VERSION >= 90
  auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
  auto launch_callee = RuntimeTVMParallelLaunch();
#endif
  // Launch call: (closure function, closure data as void*, num_task). The block returned by
  // CheckCallSuccess is where the caller's codegen resumes at the end of this function.
  llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
      launch_callee,
      {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
  // Setup the closure function.
  auto* lambda_entry =
      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
  builder_->SetInsertPoint(lambda_entry);
  // Closure arguments, in order: task_id, parallel group env, closure data.
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  task_id->setName("task_id");
  llvm::Value* penv = &(*it++);
  // Re-type the incoming void* back to the closure pointer type used when packing.
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // setup parallel env
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", DataType::Int(32));
  par_env.num_task = Var("num_task", DataType::Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  // num_task is read from field 1 of the parallel group env struct.
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      t_int32_,
      builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
      "num_task");
  par_env.penv = penv;
  // Fresh analyzer for the closure body so facts from the enclosing function do not leak in.
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  // After the swap, par_env holds the env state accumulated while visiting `body`.
  ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
  builder_->SetInsertPoint(par_launch_end);
}
734 | |
735 | llvm::Value* CodeGenCPU::CreateStaticHandle() { |
736 | llvm::GlobalVariable* gv = |
737 | new llvm::GlobalVariable(*module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage, |
738 | nullptr, "__tvm_static_handle" ); |
739 | #if TVM_LLVM_VERSION >= 100 |
740 | gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_))); |
741 | #else |
742 | gv->setAlignment(data_layout_->getTypeAllocSize(t_void_p_)); |
743 | #endif |
744 | gv->setInitializer(llvm::Constant::getNullValue(t_void_p_)); |
745 | return gv; |
746 | } |
747 | |
/*!
 * \brief Outline `body` into a private static-init callback and emit a call to `init_fname`.
 * \param init_fname Name of the runtime init function; declared with external linkage if the
 *        module does not already contain it.
 * \param body Statement executed inside the callback.
 *
 * Emits `init_fname(handle_global, callback, closure_data_as_void*, closure_nbytes)` and checks
 * its return code. The callback's first argument is the closure data. The captured variables
 * are unpacked there, and `body` is visited with function_/analyzer_/var_map_ swapped to the
 * callback's context. Must not be called while inside a parallel launch (see ICHECK).
 */
void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage,
                             "__tvm_static_init_lambda", module_.get());
  SetTargetAttributes(f);
  llvm::Value* gv = CreateStaticHandle();
  llvm::Function* finit = module_->getFunction(init_fname);
  if (finit == nullptr) {
    finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage,
                                   init_fname, module_.get());
  }
  // allocate and setup the closure, call the closure.
  uint64_t nbytes;
  Array<Var> vfields = tir::UndefinedVars(body, {});
  TypedPointer cdata = PackClosureData(vfields, &nbytes);
  // The closure byte size is passed as an i32 constant.
  llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
      finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)}));
  // Setup the closure function.
  auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f);
  builder_->SetInsertPoint(lambda_entry);
  auto it = f->arg_begin();
  // Re-type the incoming void* closure argument back to the packed closure pointer type.
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // Static init is not supported inside a parallel closure.
  ICHECK(parallel_env_.penv == nullptr);
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(function_, f);
  builder_->SetInsertPoint(init_end);
}
787 | |
/*!
 * \brief Emit code that yields the PackedFunc handle for `fname`, resolving it lazily.
 *
 * A module-level global ".tvm_func." + fname (internal linkage, null-initialized) caches the
 * handle. It is created once per name and remembered in func_handle_map_. At each call site the
 * emitted code loads the cached handle. If the handle is null, the code calls the
 * get-func-from-env runtime entry with (module context, name string, out-slot), stores the
 * result back into the cache global, and merges both paths with a PHI.
 *
 * \param fname Name of the packed function to resolve.
 * \return The PHI holding the handle. The builder is left in the "handle_init_end" block.
 */
llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
  // We will store the packed function handle in global space.
  // Initialize it during the first call.
  llvm::DataLayout layout(module_.get());
  // The handle type's alloc size is used as the alignment of the global and of every load below.
  uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
  auto it = func_handle_map_.find(fname);

  llvm::GlobalVariable* hptr;
  if (it == func_handle_map_.end()) {
    // create global location for the handle
    // create the function handle
    hptr =
        new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false,
                                 llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
#if TVM_LLVM_VERSION >= 100
    hptr->setAlignment(llvm::Align(align));
#else
    hptr->setAlignment(align);
#endif
    hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
    func_handle_map_[fname] = hptr;
  } else {
    hptr = it->second;
  }
  // create emit codes that checks and load the function.
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // The conditional branch below is emitted into this block, so it is the PHI's
  // "already initialized" predecessor.
  llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
  auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_);
  auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align);
#else
  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
  llvm::Value* handle_not_null =
      builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
  // Fast path (cached handle) is annotated as very likely.
  builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_);
  // Initialize the handle if needed.
  builder_->SetInsertPoint(init_block);
  llvm::Value* out =
      WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
                                                         llvm::Align(gv_mod_ctx_->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
                                                         gv_mod_ctx_->getAlignment());
#else
  llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
  // Tag the module-context load with the md_tbaa_ctx_ptr_ TBAA type.
  ctx_load->setMetadata(
      "tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
#if TVM_LLVM_VERSION >= 90
  auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
#else
  auto env_callee = RuntimeTVMGetFuncFromEnv();
#endif
  llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out});
  // Emission continues in the block returned by CheckCallSuccess (presumably it also leaves the
  // builder there). That block, not the original "handle_init", is the PHI's second predecessor.
  init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
  llvm::Value* loaded_handle =
      builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align);
#else
  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
  // Store the handle
  builder_->CreateStore(loaded_handle, hptr);
  builder_->CreateBr(end_block);
  // end block
  builder_->SetInsertPoint(end_block);
  llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
  phi->addIncoming(handle, pre_block);
  phi->addIncoming(loaded_handle, init_block);
  return phi;
}
867 | |
868 | CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args, |
869 | const DataType& r_type, |
870 | const int64_t begin, const int64_t end, |
871 | bool use_string_lookup) { |
872 | PackedCall pc; |
873 | std::string func_name = args[0].as<StringImmNode>()->value; |
874 | // call the function |
875 | int64_t nargs = end - begin; |
876 | ICHECK_GE(nargs, 0); |
877 | llvm::Value* stack_value = MakeValue(args[1]); |
878 | llvm::Value* stack_tcode = MakeValue(args[2]); |
879 | llvm::Value* arg_value = builder_->CreateInBoundsGEP( |
880 | t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), |
881 | ConstInt32(begin)); |
882 | TypedPointer arg_tcode = |
883 | CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32)); |
884 | llvm::Value* ret_value = builder_->CreateInBoundsGEP( |
885 | t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), |
886 | ConstInt32(end)); |
887 | TypedPointer ret_tcode = |
888 | CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32)); |
889 | |
890 | llvm::FunctionType* callee_ftype = nullptr; |
891 | llvm::Value* callee_value = nullptr; |
892 | std::vector<llvm::Value*> call_args; |
893 | |
894 | if (use_string_lookup) { |
895 | callee_ftype = ftype_tvm_func_call_; |
896 | callee_value = RuntimeTVMFuncCall(); |
897 | call_args.push_back(GetPackedFuncHandle(func_name)); |
898 | call_args.insert(call_args.end(), |
899 | {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); |
900 | } else { |
901 | callee_ftype = ftype_tvm_backend_packed_c_func_; |
902 | callee_value = module_->getFunction(func_name); |
903 | if (callee_value == nullptr) { |
904 | callee_value = |
905 | llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, |
906 | func_name, module_.get()); |
907 | } |
908 | |
909 | nargs -= 1; |
910 | call_args.insert(call_args.end(), { |
911 | builder_->CreateBitCast(arg_value, t_void_p_), |
912 | arg_tcode.addr, |
913 | ConstInt32(nargs), |
914 | builder_->CreateBitCast(ret_value, t_void_p_), |
915 | ret_tcode.addr, |
916 | }); |
917 | call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); |
918 | } |
919 | #if TVM_LLVM_VERSION >= 90 |
920 | auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); |
921 | #else |
922 | (void)callee_ftype; // use callee_ftype to avoid unused variable warning when using older LLVM. |
923 | auto call_callee = callee_value; |
924 | #endif |
925 | llvm::Value* call = builder_->CreateCall(call_callee, call_args); |
926 | |
927 | llvm::BasicBlock* end_block = CheckCallSuccess(call); |
928 | |
929 | // Load the return value and cast it to the designated type (r_type). |
930 | DataType r_api_type = tir::APIType(r_type); |
931 | llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); |
932 | llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); |
933 | #if TVM_LLVM_VERSION >= 110 |
934 | llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); |
935 | #elif TVM_LLVM_VERSION >= 80 |
936 | llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); |
937 | #else |
938 | llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); |
939 | #endif |
940 | pc.ret_value = CreateCast(r_api_type, r_type, rvalue); |
941 | |
942 | // Load the return type code. |
943 | #if TVM_LLVM_VERSION >= 110 |
944 | pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); |
945 | #elif TVM_LLVM_VERSION >= 80 |
946 | pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); |
947 | #else |
948 | pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); |
949 | #endif |
950 | |
951 | pc.end_block = end_block; |
952 | return pc; |
953 | } |
954 | |
955 | llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { |
956 | auto expected_num_args = use_string_lookup ? 5U : 6U; |
957 | ICHECK_EQ(op->args.size(), expected_num_args); |
958 | PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value, |
959 | op->args[4].as<IntImmNode>()->value, use_string_lookup); |
960 | return pc.ret_value; |
961 | } |
962 | |
963 | llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { |
964 | ICHECK_EQ(op->args.size(), 6U); |
965 | PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value, |
966 | op->args[4].as<IntImmNode>()->value, true); |
967 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
968 | // Get traced value. |
969 | llvm::Value* traced_value = MakeValue(op->args[5]); |
970 | // The update_block handles case when we need to update the return value. |
971 | llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block" , function_); |
972 | // The continue_block handles case when we need to return original |
973 | // traced value. |
974 | llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block" , function_); |
975 | |
976 | // Check the ret_type_code and create cmp instruction. |
977 | llvm::Value* cmp = |
978 | builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); |
979 | builder_->CreateCondBr(cmp, update_block, continue_block); |
980 | builder_->SetInsertPoint(update_block); |
981 | builder_->CreateBr(continue_block); |
982 | builder_->SetInsertPoint(continue_block); |
983 | // The return value depends on from what bb we come from. |
984 | llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); |
985 | phi_rvalue->addIncoming(pc.ret_value, update_block); |
986 | phi_rvalue->addIncoming(traced_value, pc.end_block); |
987 | return phi_rvalue; |
988 | } |
989 | |
990 | llvm::Value* CodeGenCPU::RuntimeTVMFuncCall() { |
991 | if (f_tvm_func_call_ != nullptr) return f_tvm_func_call_; |
992 | return GetContextPtr(gv_tvm_func_call_); |
993 | } |
994 | |
995 | llvm::Value* CodeGenCPU::RuntimeTVMGetFuncFromEnv() { |
996 | if (f_tvm_get_func_from_env_ != nullptr) return f_tvm_get_func_from_env_; |
997 | return GetContextPtr(gv_tvm_get_func_from_env_); |
998 | } |
999 | llvm::Value* CodeGenCPU::RuntimeTVMAPISetLastError() { |
1000 | if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_; |
1001 | return GetContextPtr(gv_tvm_api_set_last_error_); |
1002 | } |
1003 | llvm::Value* CodeGenCPU::RuntimeTVMParallelLaunch() { |
1004 | if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_; |
1005 | return GetContextPtr(gv_tvm_parallel_launch_); |
1006 | } |
1007 | |
1008 | llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { |
1009 | if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_; |
1010 | return GetContextPtr(gv_tvm_parallel_barrier_); |
1011 | } |
1012 | |
/*!
 * \brief Defines LLVM Types for each Metadata member type.
 *
 * Filled by aggregate initialization (see DefineMetadata), so the member order below is
 * significant. Shared by MetadataTypeDefiner (which computes struct layouts) and
 * MetadataSerializerLLVM (which emits constants of those layouts).
 */
struct MetadataLlvmTypes {
  llvm::Type* t_float64;
  // Used for bool constants and bool-array elements by the serializer.
  llvm::Type* t_uint8;
  // Used for int64, uint64 and int attributes alike (signedness lives in the constant, not the type).
  llvm::Type* t_int64;
  llvm::Type* t_bool;
  llvm::Type* t_cstring;
  llvm::Type* t_void_p;
  llvm::StructType* t_data_type;

  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */
  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
};
1026 | |
/*!
 * \brief AttrVisitor that computes the LLVM StructType layout of a Metadata node.
 *
 * Each visited attribute appends the LLVM type of one struct field to elements_. DefineType()
 * then turns the accumulated field list into a named StructType and records it in
 * llvm_types_->structs_by_type_key under the node's type key. The field types chosen here must
 * agree with the constants MetadataSerializerLLVM emits for the same attribute kinds.
 */
class MetadataTypeDefiner : public AttrVisitor {
 public:
  MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types)
      : ctx_{ctx}, llvm_types_{llvm_types} {}

  void Visit(const char* key, double* value) final {
    elements_.emplace_back(llvm_types_->t_float64);
  }
  void Visit(const char* key, int64_t* value) final {
    elements_.emplace_back(llvm_types_->t_int64);
  }
  void Visit(const char* key, uint64_t* value) final {
    elements_.emplace_back(llvm_types_->t_int64);
  }
  // `int` attributes occupy a 64-bit field.
  void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); }
  void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); }
  void Visit(const char* key, std::string* value) final {
    elements_.emplace_back(llvm_types_->t_cstring);
  }
  void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); }
  void Visit(const char* key, DataType* value) final {
    elements_.emplace_back(llvm_types_->t_data_type);
  }
  // An NDArray occupies two fields: the serialized byte length and a pointer to the bytes.
  void Visit(const char* key, runtime::NDArray* value) final {
    elements_.emplace_back(llvm_types_->t_int64);
    elements_.emplace_back(llvm_types_->t_void_p);
  }

 private:
  // Appends a pointer to an opaque struct named after the node's C struct name and queues the
  // node in to_visit_ unless it was already seen.
  // NOTE(review): not called from anywhere in this class, and nothing here inserts into
  // visited_ — confirm whether this is dead code.
  void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
    elements_.emplace_back(llvm::PointerType::getUnqual(
        llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
    if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
      return;
    }

    if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
      return;
    }
    to_visit_[metadata->get_c_struct_name()] = metadata;
  }

 public:
  using MetadataKind = runtime::metadata::MetadataKind;

  // An array-valued attribute becomes a single pointer-to-element field.
  void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
    switch (arr->kind) {
      case MetadataKind::kUint64:  // LLVM encodes signed and unsigned with same types.
      case MetadataKind::kInt64:
        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
        break;
      case MetadataKind::kBool:
        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
        break;
      case MetadataKind::kString:
        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
        break;
      case MetadataKind::kHandle:
        CHECK(false) << "Do not support handle";
        break;
      case MetadataKind::kMetadata:
        // NOTE(review): if the element struct type has not been defined yet, no field is
        // appended, which desynchronizes the field count from the serializer (it CHECK_EQs the
        // count). Confirm that the definition order guarantees the type exists.
        if (llvm_types_->structs_by_type_key.count(arr->type_key)) {
          elements_.emplace_back(
              llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
        }
        break;
      default:
        CHECK(false) << "Unsupported metadata kind " << arr->kind;
        break;
    }
  }

  // Non-array object attributes become a pointer to the already-defined struct for their type.
  // NOTE(review): operator[] inserts a null entry if the type key is unknown, and a pointer to
  // a null type would then be built. Confirm that callers define nested types first.
  void Visit(const char* key, ObjectRef* value) final {
    const runtime::metadata::MetadataArrayNode* arr =
        value->as<runtime::metadata::MetadataArrayNode>();
    if (arr != nullptr) {
      VisitArray(arr);
    } else {
      elements_.emplace_back(
          llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
    }
  }

  // Define (and register by type key) the StructType for `metadata` from its visited fields.
  void DefineType(runtime::metadata::MetadataBase metadata) {
    ICHECK(elements_.empty());
    ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
    llvm_types_->structs_by_type_key[metadata->GetTypeKey()] =
        llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name());
    elements_.clear();
  }

  llvm::LLVMContext* ctx_;
  struct MetadataLlvmTypes* llvm_types_;
  ::std::unordered_set<::std::string> visited_;
  ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_;
  // Field types collected for the struct currently being defined.
  ::std::vector<llvm::Type*> elements_;
};
1124 | |
1125 | class MetadataSerializerLLVM : public AttrVisitor { |
1126 | using MetadataKind = runtime::metadata::MetadataKind; |
1127 | |
1128 | public: |
1129 | MetadataSerializerLLVM(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) |
1130 | : codegen_{codegen}, llvm_types_{llvm_types} {} |
1131 | |
1132 | void Visit(const char* key, double* value) final { |
1133 | elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); |
1134 | } |
1135 | void Visit(const char* key, int64_t* value) final { |
1136 | elements_.back().emplace_back(llvm::ConstantInt::get( |
1137 | llvm_types_->t_int64, static_cast<uint64_t>(*value), true /* isSigned */)); |
1138 | } |
1139 | void Visit(const char* key, uint64_t* value) final { |
1140 | elements_.back().emplace_back( |
1141 | llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); |
1142 | } |
1143 | void Visit(const char* key, int* value) final { |
1144 | elements_.back().emplace_back( |
1145 | llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); |
1146 | } |
1147 | void Visit(const char* key, bool* value) final { |
1148 | elements_.back().emplace_back(llvm::ConstantInt::get( |
1149 | llvm_types_->t_uint8, static_cast<uint64_t>(*value), false /* isSigned */)); |
1150 | } |
1151 | void Visit(const char* key, std::string* value) final { |
1152 | elements_.back().emplace_back(codegen_->GetConstString(*value)); |
1153 | } |
1154 | void Visit(const char* key, void** value) final { |
1155 | CHECK(false) << "Do not support serializing void*" ; |
1156 | } |
1157 | void Visit(const char* key, DataType* value) final { |
1158 | elements_.back().emplace_back(llvm::ConstantStruct::get( |
1159 | llvm_types_->t_data_type, |
1160 | {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), |
1161 | llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), |
1162 | llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); |
1163 | } |
1164 | |
1165 | // Serializing NDArray as tuple of len, data |
1166 | void Visit(const char* key, runtime::NDArray* value) final { |
1167 | std::string bytes; |
1168 | dmlc::MemoryStringStream stream(&bytes); |
1169 | value->Save(&stream); |
1170 | elements_.back().emplace_back( |
1171 | llvm::ConstantInt::get(llvm_types_->t_int64, bytes.length(), true /* isSigned */)); |
1172 | elements_.back().emplace_back(codegen_->GetConstString(bytes)); |
1173 | } |
1174 | |
1175 | void VisitMetadata(runtime::metadata::MetadataBase metadata) { |
1176 | elements_.emplace_back(std::vector<llvm::Constant*>()); |
1177 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
1178 | auto struct_elements = elements_.back(); |
1179 | elements_.pop_back(); |
1180 | auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()]; |
1181 | ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key=" |
1182 | << metadata->GetTypeKey(); |
1183 | CHECK_EQ(struct_elements.size(), struct_ty->getNumElements()); |
1184 | auto out = llvm::ConstantStruct::get(struct_ty, struct_elements); |
1185 | if (elements_.size() > 0) { |
1186 | elements_.back().push_back(out); |
1187 | } else { |
1188 | last_production_ = out; |
1189 | } |
1190 | } |
1191 | |
1192 | void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { |
1193 | llvm::Type* element_type; |
1194 | switch (arr->kind) { |
1195 | case MetadataKind::kInt64: |
1196 | element_type = llvm_types_->t_int64; |
1197 | break; |
1198 | case MetadataKind::kUint64: |
1199 | element_type = llvm_types_->t_int64; |
1200 | break; |
1201 | case MetadataKind::kBool: |
1202 | element_type = llvm_types_->t_uint8; |
1203 | break; |
1204 | case MetadataKind::kString: |
1205 | element_type = llvm_types_->t_cstring; |
1206 | break; |
1207 | case MetadataKind::kMetadata: { |
1208 | element_type = llvm_types_->structs_by_type_key[arr->type_key]; |
1209 | ICHECK(element_type != nullptr) |
1210 | << "Did not find LLVM StructType* for type_key=" << arr->type_key; |
1211 | break; |
1212 | } |
1213 | default: |
1214 | LOG(FATAL) << "unknown metadata kind " << arr->kind; |
1215 | break; |
1216 | } |
1217 | |
1218 | elements_.emplace_back(std::vector<llvm::Constant*>()); |
1219 | for (auto o : arr->array) { |
1220 | if (o->IsInstance<FloatImmNode>()) { |
1221 | double value = Downcast<FloatImm>(o)->value; |
1222 | Visit(nullptr, &value); |
1223 | } |
1224 | if (o->IsInstance<IntImmNode>()) { |
1225 | auto value = Downcast<IntImm>(o)->value; |
1226 | Visit(nullptr, &value); |
1227 | } else if (o->IsInstance<StringObj>()) { |
1228 | ::std::string value = Downcast<String>(o); |
1229 | Visit(nullptr, &value); |
1230 | } else { |
1231 | // nested array not possible. |
1232 | VisitMetadata(Downcast<runtime::metadata::MetadataBase>(o)); |
1233 | } |
1234 | } |
1235 | auto array = elements_.back(); |
1236 | elements_.pop_back(); |
1237 | CHECK(element_type != nullptr); |
1238 | auto arr_ty = llvm::ArrayType::get(element_type, array.size()); |
1239 | auto llvm_arr = llvm::ConstantArray::get(arr_ty, array); |
1240 | |
1241 | if (elements_.size() > 0) { |
1242 | elements_.back().emplace_back( |
1243 | codegen_->GetGlobalConstant(llvm_arr, "" , llvm::GlobalValue::PrivateLinkage)); |
1244 | } else { |
1245 | last_production_ = llvm_arr; |
1246 | } |
1247 | } |
1248 | |
1249 | void Visit(const char* key, ObjectRef* value) final { |
1250 | const runtime::metadata::MetadataArrayNode* arr = |
1251 | value->as<runtime::metadata::MetadataArrayNode>(); |
1252 | if (arr != nullptr) { |
1253 | VisitArray(arr); |
1254 | return; |
1255 | } |
1256 | |
1257 | runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(*value); |
1258 | VisitMetadata(metadata); |
1259 | } |
1260 | |
1261 | llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { |
1262 | Visit(nullptr, &metadata); |
1263 | ICHECK(last_production_); |
1264 | return codegen_->GetGlobalConstant(last_production_); |
1265 | } |
1266 | |
1267 | CodeGenLLVM* codegen_; |
1268 | MetadataLlvmTypes* llvm_types_; |
1269 | llvm::LLVMContext* ctx_; |
1270 | llvm::Module* module_; |
1271 | std::vector<std::vector<llvm::Constant*>> elements_; |
1272 | llvm::Constant* last_production_; |
1273 | }; |
1274 | |
/*!
 * \brief Emit `metadata` as LLVM constants plus an exported accessor function.
 *
 * Steps:
 *  1. Build the primitive LLVM types used for metadata fields.
 *  2. Collect complex metadata nodes with DiscoverComplexTypesVisitor. The queue is seeded
 *     with a sample ConstantInfoMetadata.
 *  3. Define one StructType per queued node.
 *  4. Serialize `metadata` into a global constant.
 *  5. Emit `runtime::symbol::tvm_get_c_metadata` (packed-C-func signature, external linkage,
 *     C calling convention, DLL export). It stores the constant's address as void* through
 *     argument 3, stores kTVMOpaqueHandle through argument 4, and returns 0.
 */
void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) {
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  // Aggregate-initialized: order must match the member order of MetadataLlvmTypes.
  // NOTE(review): DLPack declares DLDataType as {uint8 code, uint8 bits, uint16 lanes}. The
  // {i8, i8, i8} layout below differs in lanes width and total size; confirm it matches what
  // the C runtime reads.
  MetadataLlvmTypes llvm_types{
      t_float64_ /* t_float64 */,
      llvm::Type::getInt8Ty(*ctx) /* t_uint8 */,
      t_int64_ /* t_int64 */,
      llvm::Type::getInt8Ty(*ctx) /* t_bool */,
      t_char_->getPointerTo() /* t_cstring */,
      t_void_p_ /* t_void_p */,
      llvm::StructType::create(*ctx, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */,
  };

  // create sample ConstantInfoMetadata instance for MetadataTypeDefiner
  // (presumably so its struct type is defined even if `metadata` holds no constants — confirm).
  // `di` points into `bytes`, and the queued node is built from &di. Both locals outlive the
  // definer loop below.
  std::string bytes;
  runtime::NDArray ci = runtime::NDArray::Empty({0}, DataType::UInt(8), Device{kDLCPU});
  dmlc::MemoryStringStream stream(&bytes);
  ci.Save(&stream);
  TVMConstantInfo di =
      TVMConstantInfo{"default-none", 0, static_cast<int64_t>(bytes.size()), bytes.c_str()};

  std::vector<runtime::metadata::MetadataBase> queue;
  queue.push_back(runtime::metadata::ConstantInfoMetadata(&di));

  metadata::DiscoverComplexTypesVisitor discover_complex{&queue};
  discover_complex.Discover(metadata);

  // Define struct types in queue order; undefined (null) entries are skipped.
  MetadataTypeDefiner definer{ctx, &llvm_types};
  for (auto md : queue) {
    if (md.defined()) {
      definer.DefineType(md);
    }
  }

  MetadataSerializerLLVM serializer{this, &llvm_types};
  auto metadata_constant_gv = serializer.Serialize(metadata);

  function_ =
      llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
                             runtime::symbol::tvm_get_c_metadata, module_.get());
  SetTargetAttributes(function_);
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
  builder_->SetInsertPoint(entry_point_entry);

  // Argument 3 receives the metadata pointer (as void*).
  auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo());
  builder_->CreateStore(builder_->CreateBitCast(metadata_constant_gv, t_void_p_), ret_values_p);

  // Argument 4 receives the type code of that return value.
  auto ret_tcode = builder_->CreateBitCast(GetArg(function_, 4), t_int_->getPointerTo());
  builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);

  builder_->CreateRet(ConstInt32(0));
}
1329 | |
1330 | void CodeGenCPU::DefineFunctionRegistry(Array<String> func_names) { |
1331 | ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime" ; |
1332 | Array<String> symbols; |
1333 | std::vector<llvm::Constant*> funcs; |
1334 | for (auto sym : func_names) { |
1335 | symbols.push_back(sym); |
1336 | auto* sym_func = |
1337 | llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::GlobalValue::ExternalLinkage, |
1338 | sym.operator std::string(), module_.get()); |
1339 | |
1340 | funcs.emplace_back(sym_func); |
1341 | } |
1342 | llvm::ArrayType* t_tvm_crt_func_ptrs = |
1343 | llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size()); |
1344 | llvm::DataLayout layout(module_.get()); |
1345 | |
1346 | llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable( |
1347 | *module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage, |
1348 | llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs" ); |
1349 | |
1350 | uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo()); |
1351 | #if TVM_LLVM_VERSION >= 100 |
1352 | func_registry_ptrs->setAlignment(llvm::Align(align)); |
1353 | #else |
1354 | func_registry_ptrs->setAlignment(align); |
1355 | #endif |
1356 | llvm::GlobalVariable* func_registry = new llvm::GlobalVariable( |
1357 | *module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage, |
1358 | llvm::ConstantStruct::get( |
1359 | t_tvm_crt_func_registry_, |
1360 | {GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)), |
1361 | llvm::ConstantExpr::getBitCast(func_registry_ptrs, |
1362 | ftype_tvm_backend_packed_c_func_->getPointerTo())}), |
1363 | "_tvm_crt_func_registry" ); |
1364 | llvm::GlobalVariable* module = new llvm::GlobalVariable( |
1365 | *module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage, |
1366 | llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module" ); |
1367 | |
1368 | // Now build TVMSystemLibEntryPoint. |
1369 | llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false); |
1370 | function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, |
1371 | "TVMSystemLibEntryPoint" , module_.get()); |
1372 | SetTargetAttributes(function_); |
1373 | llvm::BasicBlock* entry_point_entry = |
1374 | llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry" , function_); |
1375 | builder_->SetInsertPoint(entry_point_entry); |
1376 | builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); |
1377 | } |
1378 | |
1379 | void CodeGenCPU::AddStartupFunction() { |
1380 | if (!target_c_runtime_) { |
1381 | llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false); |
1382 | function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, |
1383 | "__tvm_module_startup" , module_.get()); |
1384 | SetTargetAttributes(function_); |
1385 | llvm::BasicBlock* startup_entry = |
1386 | llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry" , function_); |
1387 | builder_->SetInsertPoint(startup_entry); |
1388 | for (const auto& kv : export_system_symbols_) { |
1389 | llvm::Value* name = GetConstString(kv.first); |
1390 | builder_->CreateCall(f_tvm_register_system_symbol_, |
1391 | {name, builder_->CreateBitCast(kv.second, t_void_p_)}); |
1392 | } |
1393 | llvm::appendToGlobalCtors(*module_, function_, 65535); |
1394 | builder_->CreateRet(nullptr); |
1395 | } |
1396 | } |
1397 | |
1398 | llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { |
1399 | if (op->op.same_as(builtin::tvm_call_packed_lowered())) { |
1400 | return CreateCallPacked(op, true /* use_string_lookup */); |
1401 | } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { |
1402 | return CreateCallTracePacked(op); |
1403 | } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { |
1404 | return CreateCallPacked(op, false /* use_string_lookup */); |
1405 | } else if (op->op.same_as(builtin::tvm_static_handle())) { |
1406 | return CreateStaticHandle(); |
1407 | } else if (op->op.same_as(builtin::tvm_throw_last_error())) { |
1408 | builder_->CreateRet(ConstInt32(-1)); |
1409 | auto next_block = std::next(builder_->GetInsertBlock()->getIterator()); |
1410 | llvm::BasicBlock* new_bb = |
1411 | llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont" , function_, &*next_block); |
1412 | builder_->SetInsertPoint(new_bb); |
1413 | return ConstInt32(-1); |
1414 | } else if (op->op.same_as(builtin::tvm_struct_get())) { |
1415 | ICHECK_EQ(op->args.size(), 3U); |
1416 | int kind = op->args[2].as<IntImmNode>()->value; |
1417 | TypedPointer ref = |
1418 | CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); |
1419 | if (kind == builtin::kArrAddr) { |
1420 | return builder_->CreatePointerCast(ref.addr, t_void_p_); |
1421 | } else { |
1422 | return builder_->CreateLoad(ref.type, ref.addr); |
1423 | } |
1424 | } else if (op->op.same_as(builtin::tvm_struct_set())) { |
1425 | ICHECK_EQ(op->args.size(), 4U); |
1426 | int kind = op->args[2].as<IntImmNode>()->value; |
1427 | llvm::Value* value = MakeValue(op->args[3]); |
1428 | TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), |
1429 | MakeValue(op->args[1]), kind); |
1430 | ICHECK(kind != builtin::kArrAddr); |
1431 | if (value->getType()->isPointerTy()) { |
1432 | value = builder_->CreatePointerCast(value, ref.type); |
1433 | } |
1434 | builder_->CreateStore(value, ref.addr); |
1435 | return ConstInt32(0); |
1436 | } else if (op->op.same_as(builtin::tvm_stack_alloca())) { |
1437 | ICHECK_EQ(op->args.size(), 2U); |
1438 | const std::string& type = op->args[0].as<StringImmNode>()->value; |
1439 | return WithFunctionEntry([&]() -> llvm::AllocaInst* { |
1440 | const int64_t* pval = as_const_int(op->args[1]); |
1441 | ICHECK(pval) << "require stack alloca to contain constant value" ; |
1442 | llvm::Value* num = ConstInt32(pval[0]); |
1443 | if (type == "shape" ) { |
1444 | return builder_->CreateAlloca(t_tvm_shape_index_, num); |
1445 | } else if (type == "arg_value" ) { |
1446 | return builder_->CreateAlloca(t_tvm_value_, num); |
1447 | } else if (type == "arg_tcode" ) { |
1448 | return builder_->CreateAlloca(t_int_, num); |
1449 | } else if (type == "array" ) { |
1450 | return builder_->CreateAlloca(t_tvm_array_, num); |
1451 | } else { |
1452 | LOG(FATAL) << "Unknown stack alloca type " << type; |
1453 | } |
1454 | }); |
1455 | } else { |
1456 | return CodeGenLLVM::CreateIntrinsic(op); |
1457 | } |
1458 | } |
1459 | |
1460 | void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { |
1461 | EmitDebugLocation(op); |
1462 | llvm::Value* cond = MakeValue(op->condition); |
1463 | std::ostringstream os; |
1464 | os << "Assert fail: " << op->condition; |
1465 | if (op->message.as<StringImmNode>()) { |
1466 | os << ", " << op->message.as<StringImmNode>()->value; |
1467 | } |
1468 | llvm::Value* msg = GetConstString(os.str()); |
1469 | llvm::LLVMContext* ctx = llvm_target_->GetContext(); |
1470 | auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail" , function_); |
1471 | auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end" , function_); |
1472 | builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); |
1473 | // fail condition. |
1474 | builder_->SetInsertPoint(fail_block); |
1475 | |
1476 | #if TVM_LLVM_VERSION >= 90 |
1477 | auto err_callee = |
1478 | llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); |
1479 | #else |
1480 | auto err_callee = RuntimeTVMAPISetLastError(); |
1481 | #endif |
1482 | builder_->CreateCall(err_callee, {msg}); |
1483 | builder_->CreateRet(ConstInt32(-1)); |
1484 | // otherwise set it to be new end. |
1485 | builder_->SetInsertPoint(end_block); |
1486 | CodeGenLLVM::VisitStmt_(op); |
1487 | } |
1488 | |
1489 | void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { |
1490 | EmitDebugLocation(op); |
1491 | if (op->attr_key == tir::attr::coproc_uop_scope) { |
1492 | const StringImmNode* value = op->value.as<StringImmNode>(); |
1493 | ICHECK(value != nullptr); |
1494 | this->CreateStaticInit(value->value, op->body); |
1495 | } else if (op->attr_key == tir::attr::compute_scope) { |
1496 | this->CreateComputeScope(op); |
1497 | } else if (tir::attr::IsPragmaKey(op->attr_key)) { |
1498 | if (op->attr_key == "pragma_parallel_stride_pattern" ) { |
1499 | ICHECK(parallel_env_.penv != nullptr) |
1500 | << "Pragma parallel_stride_pattern only valid in parallel launch" ; |
1501 | parallel_env_.stride_pattern = true; |
1502 | this->VisitStmt(op->body); |
1503 | } else if (op->attr_key == "pragma_parallel_launch_point" ) { |
1504 | CreateParallelLaunch(op->body, 0, "pragma_parallel" ); |
1505 | } else if (op->attr_key == "pragma_parallel_barrier_when_finish" ) { |
1506 | ICHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment" ; |
1507 | ICHECK(!parallel_env_.in_parallel_loop) |
1508 | << "Cannot not place within parallel loop as the workload may differ, " |
1509 | << " place it between parallel and parallel_launch_point" ; |
1510 | this->VisitStmt(op->body); |
1511 | #if TVM_LLVM_VERSION >= 90 |
1512 | auto bar_callee = |
1513 | llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); |
1514 | #else |
1515 | auto bar_callee = RuntimeTVMParallelBarrier(); |
1516 | #endif |
1517 | builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); |
1518 | } else if (op->attr_key == tir::attr::pragma_import_llvm) { |
1519 | const StringImmNode* value = op->value.as<StringImmNode>(); |
1520 | ICHECK(value != nullptr); |
1521 | this->HandleImport(value->value); |
1522 | this->VisitStmt(op->body); |
1523 | } else { |
1524 | LOG(WARNING) << "Unknown pragma " << op->attr_key; |
1525 | this->VisitStmt(op->body); |
1526 | } |
1527 | } else { |
1528 | CodeGenLLVM::VisitStmt_(op); |
1529 | } |
1530 | } |
1531 | |
1532 | void CodeGenCPU::VisitStmt_(const ForNode* op) { |
1533 | EmitDebugLocation(op); |
1534 | ICHECK(is_zero(op->min)); |
1535 | if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) { |
1536 | CodeGenLLVM::VisitStmt_(op); |
1537 | } else if (op->kind == ForKind::kParallel) { |
1538 | if (parallel_env_.penv == nullptr) { |
1539 | CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, |
1540 | op->thread_binding, op->annotations), |
1541 | 0, std::string("loop_parallel_" ) + op->loop_var->name_hint.c_str()); |
1542 | } else { |
1543 | // already in parallel env. |
1544 | ICHECK(parallel_env_.task_id.defined()); |
1545 | ICHECK(parallel_env_.num_task.defined()); |
1546 | ICHECK(parallel_env_.penv != nullptr); |
1547 | DataType t = op->extent.dtype(); |
1548 | PrimExpr num_task = cast(t, parallel_env_.num_task); |
1549 | PrimExpr task_id = cast(t, parallel_env_.task_id); |
1550 | ICHECK(!parallel_env_.in_parallel_loop) |
1551 | << "Nested parallel loop is not supported by threadpool, try fuse them instead" ; |
1552 | parallel_env_.in_parallel_loop = true; |
1553 | if (parallel_env_.stride_pattern) { |
1554 | CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), |
1555 | op->loop_var, op->body); |
1556 | } else { |
1557 | PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; |
1558 | PrimExpr begin = min(task_id * step, op->extent); |
1559 | PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); |
1560 | CreateSerialFor(MakeValue(begin), MakeValue(end), |
1561 | llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); |
1562 | } |
1563 | parallel_env_.in_parallel_loop = false; |
1564 | ++parallel_env_.parallel_loop_count; |
1565 | } |
1566 | } else { |
1567 | LOG(FATAL) << "cannot handle for type " << op->kind; |
1568 | } |
1569 | } |
1570 | |
1571 | TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu" ) |
1572 | .set_body([](const TVMArgs& targs, TVMRetValue* rv) { |
1573 | *rv = static_cast<void*>(new CodeGenCPU()); |
1574 | }); |
1575 | |
1576 | } // namespace codegen |
1577 | } // namespace tvm |
1578 | |
1579 | #endif // TVM_LLVM_VERSION |
1580 | |