1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_nvptx.cc |
22 | * \brief NVPTX code generator. |
23 | */ |
24 | #ifdef TVM_LLVM_VERSION |
25 | |
26 | #include <llvm/ADT/SmallString.h> |
27 | #include <llvm/IR/Attributes.h> |
28 | #include <llvm/IR/Function.h> |
29 | #include <llvm/IR/GlobalValue.h> |
30 | #include <llvm/IR/InlineAsm.h> |
31 | #include <llvm/IR/Instructions.h> |
32 | #include <llvm/IR/Intrinsics.h> |
33 | #if TVM_LLVM_VERSION >= 100 |
34 | #include <llvm/IR/IntrinsicsNVPTX.h> |
35 | #endif |
36 | #include <llvm/IR/LegacyPassManager.h> |
37 | #include <llvm/IR/Metadata.h> |
38 | #include <llvm/IR/Module.h> |
39 | #include <llvm/IR/Type.h> |
40 | #include <llvm/IRReader/IRReader.h> |
41 | #if TVM_LLVM_VERSION >= 100 |
42 | #include <llvm/Support/Alignment.h> |
43 | #endif |
44 | #include <llvm/Support/CodeGen.h> |
45 | #include <llvm/Support/SourceMgr.h> |
46 | #include <llvm/Support/raw_ostream.h> |
47 | #include <llvm/Target/TargetMachine.h> |
48 | #include <llvm/Transforms/IPO/PassManagerBuilder.h> |
49 | #include <tvm/runtime/device_api.h> |
50 | |
51 | #include <memory> |
52 | #include <string> |
53 | #include <utility> |
54 | #include <vector> |
55 | |
56 | #include "../../runtime/cuda/cuda_module.h" |
57 | #include "../build_common.h" |
58 | #include "codegen_llvm.h" |
59 | #include "llvm_instance.h" |
60 | |
61 | namespace tvm { |
62 | namespace codegen { |
63 | |
64 | // NVPTX code generator. |
// NVPTX code generator.
//
// Lowers TIR PrimFuncs to LLVM IR for the NVPTX backend. CUDA-specific
// concerns handled here: kernel annotation via "nvvm.annotations" metadata,
// shared/local memory allocation (shared memory uses address space 3),
// thread/block index intrinsics, and barrier synchronization.
class CodeGenNVPTX : public CodeGenLLVM {
 public:
  void AddFunction(const PrimFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    // annotate as kernel function: the NVPTX backend discovers kernels through
    // module-level "nvvm.annotations" operands of the form {fn, "kernel", 1}
    llvm::LLVMContext* ctx = llvm_target_->GetContext();
    module_->getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(llvm::MDNode::get(
            *ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"),
                   llvm::ValueAsMetadata::get(ConstInt32(1))}));
  }

  // Emit storage for an Allocate node. Shared-scope buffers become
  // address-space-3 globals; local-scope buffers become stack allocas.
  void VisitStmt_(const AllocateNode* op) final {
    ICHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }

    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      // Dynamic shared memory: address space == 3, declared as a zero-sized
      // extern global; the actual size is supplied at kernel launch time.
      buf =
          AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage);
    } else {
      size_t constant_size = op->ConstantAllocationSize();
      ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";

      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
      }
      if (storage_scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
        });
#if TVM_LLVM_VERSION >= 110
        auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
        unsigned alignment = alloca->getAlignment();
#endif
        // Only ever raise the alloca's alignment, never lower it.
        if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
          alloca->setAlignment(llvm::Align(info.alignment));
#else
          alloca->setAlignment(info.alignment);
#endif
        }
        buf = alloca;
      } else {
        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        // Static shared memory: address space == 3
        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
                                   llvm::GlobalValue::PrivateLinkage);
      }
    }

    // Cast to the element pointer type while preserving the address space.
    buf = builder_->CreatePointerCast(
        buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
    ICHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  // rank 1 maps to threadIdx.{x,y,z} (tid sregs);
  // rank 0 maps to blockIdx.{x,y,z} (ctaid sregs).
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
          break;
        default:
          LOG(FATAL) << "unknown thread idx";
      }
    } else {
      ICHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
          break;
        default:
          LOG(FATAL) << "unknown thread idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  // Lower tvm_storage_sync: "warp" is currently a no-op; "shared"/"shared.dyn"
  // map to the CTA-wide barrier intrinsic nvvm.barrier0.
  llvm::Value* CreateStorageSync(const CallNode* op) final {
    const std::string& sync = op->args[0].as<StringImmNode>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
      return nullptr;
    } else if (sync == "shared" || sync == "shared.dyn") {
      llvm::Function* f =
          llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::nvvm_barrier0);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
    }
  }

#if TVM_LLVM_VERSION < 160
  // This function only works with the legacy pass manager.
  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }
#endif

  // Demote unused __nv_* libdevice definitions before running the regular
  // LLVM optimization pipeline.
  void Optimize() final {
    for (auto& f : *module_) {
      auto fname = static_cast<std::string>(f.getName());
      if (fname.substr(0, 4) != "__nv") continue;
      // This is to strip off unused __nv_* functions from the final module
      // The one that is actually used will be inlined at call site
      // Adapted from Halide's runtime linker
      if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) {
        f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
      }
    }
    CodeGenLLVM::Optimize();
  }

  // NVPTX-specific intrinsic lowering; defined below the class.
  llvm::Value* CreateIntrinsic(const CallNode* op) override;

 protected:
  void InitTarget() final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget();
  }
};
215 | |
216 | // Check if this is a warp shuffle intrinsic call and match its |
217 | // corresponding nvvm intrinsic. Return true if the match is successful. |
218 | static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) { |
219 | // Only 32 bit data type is supported. |
220 | if (op->dtype.is_vector() || op->dtype.bits() != 32) { |
221 | return false; |
222 | } |
223 | |
224 | // Intrinsic lookup table. |
225 | // It is difficult to emit _sync verion that works on Pascal. |
226 | // We ignore the mask and only emit the non-sync version for nvptx. |
227 | llvm::Intrinsic::ID ids[] = { |
228 | llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32, |
229 | llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32, |
230 | llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32}; |
231 | |
232 | int offset = 0; |
233 | if (op->op.same_as(builtin::tvm_warp_shuffle())) { |
234 | offset = 0; |
235 | } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) { |
236 | offset = 2; |
237 | } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) { |
238 | offset = 4; |
239 | } else { |
240 | return false; |
241 | } |
242 | |
243 | *id = ids[offset + op->dtype.is_float()]; |
244 | return true; |
245 | } |
246 | |
// Lower NVPTX-specific TIR intrinsics: warp shuffles, activemask, and
// atomic add. Everything else is delegated to the generic LLVM lowering.
llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (GetWarpShuffleIntrinsic(op, &id)) {
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    // Ignore the first mask operand and remove the last
    // redundant warp_size..
    size_t n_args = op->args.size() - 1;
    for (size_t i = 1; i < n_args; ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      arg_type.push_back(arg_value.back()->getType());
    }
    // The shuffle returns the same type as its first (value) operand.
    llvm::Type* return_type = arg_type[0];
    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
    return builder_->CreateCall(func, arg_value);
  } else if (op->op.same_as(builtin::tvm_warp_activemask())) {
    // Only nvptx target may keep this intrinsic at this point.
    // PTX assembly: asm "activemask.b32 r1;"
    auto fty = llvm::FunctionType::get(t_int32_, false);
    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
    return builder_->CreateCall(val);
  } else if (op->op.same_as(builtin::atomic_add())) {
    ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
    llvm::Value* v0 = MakeValue(op->args[0]);
    llvm::Value* v1 = MakeValue(op->args[1]);
    if (op->args[1]->dtype.is_float()) {
// Float atomic add (AtomicRMWInst::FAdd) only exists in LLVM 9+.
#if TVM_LLVM_VERSION >= 90
// LLVM 13+ requires an explicit (possibly empty) alignment argument.
#if TVM_LLVM_VERSION >= 130
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(),
                                       llvm::AtomicOrdering::Monotonic);
#else
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
                                       llvm::AtomicOrdering::Monotonic);
#endif
#else
      LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
    }
// Integer atomic add.
#if TVM_LLVM_VERSION >= 130
    return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1, llvm::MaybeAlign(),
                                     llvm::AtomicOrdering::Monotonic);
#else
    return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
                                     llvm::AtomicOrdering::Monotonic);
#endif
  }
  // Fall back to the generic lowering.
  return CodeGenLLVM::CreateIntrinsic(op);
}
295 | |
296 | int GetCUDAComputeVersion(const Target& target) { |
297 | Optional<String> mcpu = target->GetAttr<String>("mcpu" ); |
298 | ICHECK(mcpu.defined()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target" ; |
299 | std::string sm_version = mcpu.value(); |
300 | return std::stoi(sm_version.substr(3)); |
301 | } |
302 | |
303 | runtime::Module BuildNVPTX(IRModule mod, Target target) { |
304 | LLVMInstance llvm_instance; |
305 | With<LLVMTarget> llvm_target(llvm_instance, target); |
306 | |
307 | int compute_ver = GetCUDAComputeVersion(target); |
308 | auto cg = std::make_unique<CodeGenNVPTX>(); |
309 | |
310 | cg->Init("TVMPTXModule" , llvm_target.get(), false, false, false); |
311 | |
312 | cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { |
313 | ICHECK(kv.second->template IsInstance<PrimFuncNode>()) |
314 | << "Can only lower IR Module with PrimFuncs" ; |
315 | return Downcast<PrimFunc>(kv.second); |
316 | }); |
317 | |
318 | llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); |
319 | const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path" ); |
320 | if (flibdevice_path != nullptr) { |
321 | std::string path = (*flibdevice_path)(compute_ver); |
322 | if (path.length() != 0) { |
323 | std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(path); |
324 | mlib->setTargetTriple(llvm_target->GetTargetTriple()); |
325 | mlib->setDataLayout(tm->createDataLayout()); |
326 | cg->AddLinkModule(std::move(mlib)); |
327 | } |
328 | } |
329 | std::unique_ptr<llvm::Module> module = cg->Finish(); |
330 | llvm::SmallString<8> data_ptx, data_ll; |
331 | llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll); |
332 | dest_ptx.SetUnbuffered(); |
333 | dest_ll.SetUnbuffered(); |
334 | // print ll |
335 | module->print(dest_ll, nullptr); |
336 | std::string ll(data_ll.begin(), data_ll.end()); |
337 | // emit ptx |
338 | llvm::legacy::PassManager pass; |
339 | #if TVM_LLVM_VERSION <= 60 |
340 | ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) |
341 | << "Cannot emit target CGFT_ObjectFile" ; |
342 | #elif TVM_LLVM_VERSION <= 90 |
343 | ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == |
344 | 0) |
345 | << "Cannot emit target CGFT_ObjectFile" ; |
346 | #else |
347 | ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) |
348 | << "Cannot emit target CGFT_ObjectFile" ; |
349 | #endif |
350 | pass.run(*module); |
351 | std::string ptx(data_ptx.begin(), data_ptx.end()); |
352 | return CUDAModuleCreate(ptx, "ptx" , ExtractFuncInfo(mod), ll); |
353 | } |
354 | |
// Entry point used by the TVM build pipeline: "target.build.nvptx" -> BuildNVPTX.
TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX);

// Factory returning a fresh CodeGenNVPTX instance as an opaque pointer;
// ownership transfers to the caller of this packed function.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      *rv = static_cast<void*>(new CodeGenNVPTX());
    });
361 | |
362 | } // namespace codegen |
363 | } // namespace tvm |
364 | |
365 | #endif // TVM_LLVM_VERSION |
366 | |