/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <llvm/ADT/SmallString.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsAMDGPU.h>
#endif
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IRReader/IRReader.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "../../runtime/rocm/rocm_module.h"
#include "../build_common.h"
#include "codegen_llvm.h"
#include "llvm_instance.h"

namespace tvm {
namespace codegen {

namespace {

// Calls the device API to query the maximum number of threads per block.
static inline int DetectROCMmaxThreadsPerBlock() {
  Device tvm_dev;
  tvm_dev.device_type = kDLROCM;
  tvm_dev.device_id = 0;
  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_dev, true);
  if (api != nullptr) {
    TVMRetValue val;
    api->GetAttr(tvm_dev, tvm::runtime::kExist, &val);
    if (val.operator int() == 1) {
      tvm::runtime::DeviceAPI::Get(tvm_dev)->GetAttr(tvm_dev, tvm::runtime::kMaxThreadsPerBlock,
                                                     &val);
      return val.operator int();
    }
  }
  LOG(WARNING) << "Cannot get maximum number of threads for AMD codegen";
  return 256;  // see the discussion at PR #4342 for the choice of default
}

}  // namespace

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  CodeGenAMDGPU() = default;
  virtual ~CodeGenAMDGPU() = default;

  void AddFunction(const PrimFunc& f) final {
    // Add the function with a void return type.
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
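    // The "amdgpu-flat-work-group-size" attribute takes a "min,max" pair; the
    // AMDGPU backend uses it to bound the flat workgroup sizes it must support
    // when allocating registers and LDS for this kernel.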
    std::ostringstream attr;
    attr << "1," << DetectROCMmaxThreadsPerBlock();
    function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
  }

  void VisitStmt_(const AllocateNode* op) final {
    ICHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));

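    // Address spaces below follow the AMDGPU convention: 3 is LDS
    // (workgroup-shared) memory; private per-thread buffers are emitted as
    // plain allocas (see the TODO below about address space 5).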
    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      LOG(WARNING) << "Dynamic shared memory support for rocm is experimental.";
      buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
                                 llvm::GlobalValue::ExternalLinkage);
    } else {
      size_t constant_size = op->ConstantAllocationSize();
      ICHECK_GT(constant_size, 0) << "Can only handle constant-size stack allocation in GPU";

      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
      }
      // 16 bytes is the maximum necessary alignment on AMD devices.
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (storage_scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
        });
#if TVM_LLVM_VERSION >= 110
        auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
        unsigned alignment = alloca->getAlignment();
#endif
        if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
          alloca->setAlignment(llvm::Align(info.alignment));
#else
          alloca->setAlignment(info.alignment);
#endif
        }
        buf = alloca;
      } else {
        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
                                   llvm::GlobalValue::PrivateLinkage);
      }
    }

    buf = builder_->CreatePointerCast(
        buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
    ICHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
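    // In TVM's thread scope convention, rank 1 maps to threadIdx.* (workitem
    // intrinsics) and rank 0 to blockIdx.* (workgroup intrinsics).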
    llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::amdgcn_workitem_id_z;
          break;
        default:
          LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      ICHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_z;
          break;
        default:
          LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const CallNode* op) final {
    const std::string& sync = op->args[0].as<StringImmNode>()->value;
    if (sync == "warp") {
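      // No barrier is emitted here: threads within an AMD GCN wavefront
      // execute in lockstep, so warp-level sync is a no-op.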
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f =
          llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Unsupported sync " << sync;
    }
  }

#if TVM_LLVM_VERSION < 160
  // This function only works with the legacy pass manager.
  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }
#endif

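  // Global buffers live in AMDGPU address space 1.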
  unsigned GetGlobalAddressSpace() const final { return 1; }

  llvm::Value* CreateIntrinsic(const CallNode* op) final {
    if (op->op.same_as(builtin::atomic_add())) {
      ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
      llvm::Value* v0 = MakeValue(op->args[0]);
      llvm::Value* v1 = MakeValue(op->args[1]);
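      // atomic_add lowers to an LLVM atomicrmw with monotonic (relaxed)
      // ordering. FAdd requires LLVM 9 or newer, and LLVM 13+ takes an
      // explicit alignment argument.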
      if (op->args[1]->dtype.is_float()) {
#if TVM_LLVM_VERSION >= 90
#if TVM_LLVM_VERSION >= 130
        return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(),
                                         llvm::AtomicOrdering::Monotonic);
#else
        return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
                                         llvm::AtomicOrdering::Monotonic);
#endif
#else
        LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
      }
#if TVM_LLVM_VERSION >= 130
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1, llvm::MaybeAlign(),
                                       llvm::AtomicOrdering::Monotonic);
#else
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
                                       llvm::AtomicOrdering::Monotonic);
#endif
    }
    return CodeGenLLVM::CreateIntrinsic(op);
  }

 protected:
  void InitTarget() final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget();
  }
};

runtime::Module BuildAMDGPU(IRModule mod, Target target) {
  LLVMInstance llvm_instance;

  With<LLVMTarget> llvm_target(llvm_instance, target);
#if TVM_LLVM_VERSION < 90
  LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
  // Lower versions will crash when loading the bitcode, see
  // issue #4087 for a discussion
#endif
  auto cg = std::make_unique<CodeGenAMDGPU>();

  cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false);

  cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
    ICHECK(kv.second->template IsInstance<PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    return Downcast<PrimFunc>(kv.second);
  });

  llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
  const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
  ICHECK(find_rocm_bitcodes != nullptr) << "Cannot find tvm_callback_rocm_bitcode_path";
  Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)();

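  // Link the ROCm device bitcode libraries (ocml, ockl, etc.) into the module,
  // marking every function always-inline so the math routines fold into the
  // generated kernels.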
  for (auto& bitcode_path : bitcode_files) {
    std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(bitcode_path);
    mlib->setTargetTriple(llvm_target->GetTargetTriple());
    mlib->setDataLayout(tm->createDataLayout());

    for (llvm::Function& f : mlib->functions()) {
      f.addFnAttr(llvm::Attribute::AlwaysInline);
    }
    cg->AddLinkModule(std::move(mlib));
  }

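  // Emit three artifacts from the finished module: an object file (later
  // linked into the hsaco binary), the LLVM IR, and the target assembly. The
  // module is cloned because running a codegen pass pipeline mutates it.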
  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  module->print(dest_ll, nullptr);
#if TVM_LLVM_VERSION <= 60
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
#else
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(*module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(*module.get());
#endif
  llvm::legacy::PassManager pass;

#if TVM_LLVM_VERSION <= 60
  ICHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
  ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
#else
  ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
#endif
  pass.run(*mObj);
  std::string obj(dataObj.begin(), dataObj.end());

  llvm::legacy::PassManager passAsm;
#if TVM_LLVM_VERSION <= 60
  ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#elif TVM_LLVM_VERSION <= 90
  ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr,
                                 llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#else
  ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#endif
  passAsm.run(*mAsm);
  std::string assembly(dataAsm.begin(), dataAsm.end());

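  // tvm.contrib.rocm registers this callback; it links the object file into an
  // HSA code object (hsaco), typically by invoking ld.lld.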
  const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link");
  ICHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist; "
                       << "import tvm.contrib.rocm to register it";

  TVMByteArray arr;
  arr.data = &obj[0];
  arr.size = obj.length();

  std::string hsaco = (*f)(arr);
  std::string ll(data_ll.begin(), data_ll.end());
  return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly);
}

TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU);
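// A minimal usage sketch from the Python side (names are illustrative,
// assuming TVM was built with LLVM and the ROCm runtime available):
//   import tvm
//   from tvm.contrib import rocm  # registers tvm_callback_rocm_link
//   rt_mod = tvm.build(ir_mod, target="rocm")  # dispatches to BuildAMDGPU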

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      *rv = static_cast<void*>(new CodeGenAMDGPU());
    });

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_LLVM_VERSION