1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_amdgpu.cc |
22 | * \brief AMDGPU code generator. |
23 | */ |
24 | #ifdef TVM_LLVM_VERSION |
25 | |
26 | #include <llvm/ADT/SmallString.h> |
27 | #include <llvm/IR/Attributes.h> |
28 | #include <llvm/IR/CallingConv.h> |
29 | #include <llvm/IR/Function.h> |
30 | #include <llvm/IR/GlobalValue.h> |
31 | #include <llvm/IR/Instructions.h> |
32 | #include <llvm/IR/Intrinsics.h> |
33 | #if TVM_LLVM_VERSION >= 100 |
34 | #include <llvm/IR/IntrinsicsAMDGPU.h> |
35 | #endif |
36 | #include <llvm/IR/LegacyPassManager.h> |
37 | #include <llvm/IRReader/IRReader.h> |
38 | #if TVM_LLVM_VERSION >= 100 |
39 | #include <llvm/Support/Alignment.h> |
40 | #endif |
41 | #include <llvm/Support/CodeGen.h> |
42 | #include <llvm/Support/SourceMgr.h> |
43 | #include <llvm/Support/raw_ostream.h> |
44 | #include <llvm/Target/TargetMachine.h> |
45 | #include <llvm/Transforms/IPO/PassManagerBuilder.h> |
46 | #include <llvm/Transforms/Utils/Cloning.h> |
47 | #include <tvm/runtime/c_runtime_api.h> |
48 | #include <tvm/runtime/device_api.h> |
49 | #include <tvm/runtime/registry.h> |
50 | |
51 | #include "../../runtime/rocm/rocm_module.h" |
52 | #include "../build_common.h" |
53 | #include "codegen_llvm.h" |
54 | #include "llvm_instance.h" |
55 | |
56 | namespace tvm { |
57 | namespace codegen { |
58 | |
59 | namespace { |
60 | |
61 | // calls the device api to get the max threads per block |
62 | static inline int DetectROCMmaxThreadsPerBlock() { |
63 | Device tvm_dev; |
64 | tvm_dev.device_type = kDLROCM; |
65 | tvm_dev.device_id = 0; |
66 | tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_dev, true); |
67 | if (api != nullptr) { |
68 | TVMRetValue val; |
69 | api->GetAttr(tvm_dev, tvm::runtime::kExist, &val); |
70 | if (val.operator int() == 1) { |
71 | tvm::runtime::DeviceAPI::Get(tvm_dev)->GetAttr(tvm_dev, tvm::runtime::kMaxThreadsPerBlock, |
72 | &val); |
73 | return val.operator int(); |
74 | } |
75 | } |
76 | LOG(WARNING) << "Cannot get maximum number of threads for AMD codegen" ; |
77 | return 256; // see the discussion at PR #4342 for the choice of default |
78 | } |
79 | |
80 | } // namespace |
81 | |
82 | // AMDGPU code generator. |
// AMDGPU code generator.
//
// Specializes CodeGenLLVM for the AMDGCN target: kernels use the
// AMDGPU_KERNEL calling convention, shared memory is LDS (address space 3),
// global memory is address space 1, and thread/block indices come from the
// amdgcn workitem/workgroup intrinsics.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  CodeGenAMDGPU() = default;
  virtual ~CodeGenAMDGPU() = default;

  // Emits `f` as an AMDGPU kernel: void return value, AMDGPU_KERNEL calling
  // convention, and an "amdgpu-flat-work-group-size" attribute whose upper
  // bound is taken from the device (or the 256 default).
  void AddFunction(const PrimFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
    std::ostringstream attr;
    // Attribute value is a "min,max" pair of flat work-group sizes.
    attr << "1," << DetectROCMmaxThreadsPerBlock();
    function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
  }

  // Lowers an Allocate node. Dynamic shared allocations become externally
  // linked LDS globals of size 0; constant-size shared allocations become
  // private LDS globals; local allocations become function-entry allocas.
  void VisitStmt_(const AllocateNode* op) final {
    ICHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));

    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      LOG(WARNING) << "Dynamic shared memory support for rocm is experimental.";
      // Size 0 + external linkage: the actual extent is supplied at launch.
      buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
                                 llvm::GlobalValue::ExternalLinkage);
    } else {
      size_t constant_size = op->ConstantAllocationSize();
      ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";

      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (storage_scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
        });
#if TVM_LLVM_VERSION >= 110
        auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
        unsigned alignment = alloca->getAlignment();
#endif
        // Only ever raise the alignment above what LLVM chose, never lower it.
        if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
          alloca->setAlignment(llvm::Align(info.alignment));
#else
          alloca->setAlignment(info.alignment);
#endif
        }
        buf = alloca;
      } else {
        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
                                   llvm::GlobalValue::PrivateLinkage);
      }
    }

    // Cast to a pointer to the element type, preserving the address space
    // of the underlying allocation.
    buf = builder_->CreatePointerCast(
        buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
    ICHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  // rank 1 (thread scope) maps to amdgcn_workitem_id_{x,y,z};
  // rank 0 (block scope) maps to amdgcn_workgroup_id_{x,y,z}.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_z;
          break;
        default:
          LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      ICHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_z;
          break;
        default:
          LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  // Emits a storage synchronization barrier. "shared" lowers to
  // amdgcn_s_barrier; "warp" emits nothing (returns nullptr); anything else
  // is a fatal error.
  llvm::Value* CreateStorageSync(const CallNode* op) final {
    const std::string& sync = op->args[0].as<StringImmNode>()->value;
    if (sync == "warp") {
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f =
          llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
    }
  }

#if TVM_LLVM_VERSION < 160
  // This function only works with the legacy pass manager.
  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }
#endif

  // Global memory lives in address space 1 on AMDGPU.
  unsigned GetGlobalAddressSpace() const final { return 1; }

  // Lowers builtin::atomic_add to an LLVM atomicrmw (monotonic ordering);
  // all other intrinsics are delegated to CodeGenLLVM. Only 32-bit operands
  // are supported, and FAdd requires LLVM 9+.
  llvm::Value* CreateIntrinsic(const CallNode* op) final {
    if (op->op.same_as(builtin::atomic_add())) {
      ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
      llvm::Value* v0 = MakeValue(op->args[0]);
      llvm::Value* v1 = MakeValue(op->args[1]);
      if (op->args[1]->dtype.is_float()) {
#if TVM_LLVM_VERSION >= 90
#if TVM_LLVM_VERSION >= 130
        return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(),
                                         llvm::AtomicOrdering::Monotonic);
#else
        return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
                                         llvm::AtomicOrdering::Monotonic);
#endif
#else
        LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
      }
#if TVM_LLVM_VERSION >= 130
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1, llvm::MaybeAlign(),
                                       llvm::AtomicOrdering::Monotonic);
#else
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
                                       llvm::AtomicOrdering::Monotonic);
#endif
    }
    return CodeGenLLVM::CreateIntrinsic(op);
  }

 protected:
  void InitTarget() final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget();
  }
};
249 | |
250 | runtime::Module BuildAMDGPU(IRModule mod, Target target) { |
251 | LLVMInstance llvm_instance; |
252 | |
253 | With<LLVMTarget> llvm_target(llvm_instance, target); |
254 | #if TVM_LLVM_VERSION < 90 |
255 | LOG(FATAL) << "AMDGPU backend requires at least LLVM 9" ; |
256 | // Lower versions will crash when loading the bitcode, see |
257 | // issue #4087 for a discussion |
258 | #endif |
259 | auto cg = std::make_unique<CodeGenAMDGPU>(); |
260 | |
261 | cg->Init("TVMAMDGPUModule" , llvm_target.get(), false, false, false); |
262 | |
263 | cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { |
264 | ICHECK(kv.second->template IsInstance<PrimFuncNode>()) |
265 | << "Can only lower IR Module with PrimFuncs" ; |
266 | return Downcast<PrimFunc>(kv.second); |
267 | }); |
268 | |
269 | llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); |
270 | const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path" ); |
271 | Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)(); |
272 | |
273 | for (auto& bitcode_path : bitcode_files) { |
274 | std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(bitcode_path); |
275 | mlib->setTargetTriple(llvm_target->GetTargetTriple()); |
276 | mlib->setDataLayout(tm->createDataLayout()); |
277 | |
278 | for (llvm::Function& f : mlib->functions()) { |
279 | f.addFnAttr(llvm::Attribute::AlwaysInline); |
280 | } |
281 | cg->AddLinkModule(std::move(mlib)); |
282 | } |
283 | |
284 | std::unique_ptr<llvm::Module> module = cg->Finish(); |
285 | llvm::SmallString<8> dataObj, data_ll, dataAsm; |
286 | llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm); |
287 | destObj.SetUnbuffered(); |
288 | dest_ll.SetUnbuffered(); |
289 | destAsm.SetUnbuffered(); |
290 | module->print(dest_ll, nullptr); |
291 | #if TVM_LLVM_VERSION <= 60 |
292 | std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get()); |
293 | std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get()); |
294 | #else |
295 | std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(*module.get()); |
296 | std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(*module.get()); |
297 | #endif |
298 | llvm::legacy::PassManager pass; |
299 | |
300 | #if TVM_LLVM_VERSION <= 60 |
301 | ICHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) |
302 | << "Cannot emit target CGFT_ObjectFile" ; |
303 | #elif TVM_LLVM_VERSION <= 90 |
304 | ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) |
305 | << "Cannot emit target CGFT_ObjectFile" ; |
306 | #else |
307 | ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) |
308 | << "Cannot emit target CGFT_ObjectFile" ; |
309 | #endif |
310 | pass.run(*mObj); |
311 | std::string obj(dataObj.begin(), dataObj.end()); |
312 | |
313 | llvm::legacy::PassManager passAsm; |
314 | #if TVM_LLVM_VERSION <= 60 |
315 | ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) |
316 | << "Cannot emit target CGFT_AssemblyFile" ; |
317 | #elif TVM_LLVM_VERSION <= 90 |
318 | ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, |
319 | llvm::TargetMachine::CGFT_AssemblyFile) == 0) |
320 | << "Cannot emit target CGFT_AssemblyFile" ; |
321 | #else |
322 | ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) |
323 | << "Cannot emit target CGFT_AssemblyFile" ; |
324 | #endif |
325 | passAsm.run(*mAsm); |
326 | std::string assembly(dataAsm.begin(), dataAsm.end()); |
327 | |
328 | const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link" ); |
329 | ICHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm" ; |
330 | |
331 | TVMByteArray arr; |
332 | arr.data = &obj[0]; |
333 | arr.size = obj.length(); |
334 | |
335 | std::string hsaco = (*f)(arr); |
336 | std::string ll(data_ll.begin(), data_ll.end()); |
337 | return ROCMModuleCreate(hsaco, "hsaco" , ExtractFuncInfo(mod), ll, assembly); |
338 | } |
339 | |
// Register the ROCm build entry point invoked by the target machinery.
TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU);

// Expose a raw CodeGenAMDGPU factory for the LLVM codegen dispatch table.
// NOTE(review): returns an owning raw pointer wrapped in void*; the caller
// is presumably responsible for its lifetime — confirm against the consumer
// of "tvm.codegen.llvm.target_*" registrations.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      *rv = static_cast<void*>(new CodeGenAMDGPU());
    });
346 | |
347 | } // namespace codegen |
348 | } // namespace tvm |
349 | |
350 | #endif // TVM_LLVM_VERSION |
351 | |