/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_nvptx.cc
 * \brief NVPTX code generator.
 */
#ifdef TVM_LLVM_VERSION

#include <llvm/ADT/SmallString.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsNVPTX.h>
#endif
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IRReader/IRReader.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <tvm/runtime/device_api.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../../runtime/cuda/cuda_module.h"
#include "../build_common.h"
#include "codegen_llvm.h"
#include "llvm_instance.h"

namespace tvm {
namespace codegen {

// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
 public:
  void AddFunction(const PrimFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    // annotate as kernel function
    llvm::LLVMContext* ctx = llvm_target_->GetContext();
    module_->getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(llvm::MDNode::get(
            *ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"),
                   llvm::ValueAsMetadata::get(ConstInt32(1))}));
  }

  void VisitStmt_(const AllocateNode* op) final {
    ICHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    // Clamp to 16 bytes, the maximum alignment ever required on NVIDIA devices.
    if (info.alignment > 16) {
      info.alignment = 16;
    }

    auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
    if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
      // Shared memory: address space == 3
      buf =
          AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage);
    } else {
      size_t constant_size = op->ConstantAllocationSize();
      ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";

      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
      }
      if (storage_scope.rank == runtime::StorageRank::kLocal) {
        // const int local_address_space = 5;
        // TODO(tqchen): for newer versions of LLVM, the local address space can be set.
        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
        });
#if TVM_LLVM_VERSION >= 110
        auto alignment = static_cast<unsigned>(alloca->getAlign().value());
#else
        unsigned alignment = alloca->getAlignment();
#endif
        if (alignment < static_cast<unsigned>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
          alloca->setAlignment(llvm::Align(info.alignment));
#else
          alloca->setAlignment(info.alignment);
#endif
        }
        buf = alloca;
      } else {
        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
            << "Can only allocate shared or local memory inside kernel";
        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
                                   llvm::GlobalValue::PrivateLinkage);
      }
    }

    buf = builder_->CreatePointerCast(
        buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
    ICHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
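    // Rank 1 maps to threadIdx.{x,y,z} (PTX tid), rank 0 to blockIdx.{x,y,z} (PTX ctaid).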
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
          break;
        default:
          LOG(FATAL) << "unknown thread idx";
      }
    } else {
      ICHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
          break;
        case 1:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
          break;
        case 2:
          intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
          break;
        default:
          LOG(FATAL) << "unknown thread idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const CallNode* op) final {
    const std::string& sync = op->args[0].as<StringImmNode>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
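      // Returning nullptr emits no instruction: warp-level sync is currently a no-op here.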
      return nullptr;
    } else if (sync == "shared" || sync == "shared.dyn") {
      llvm::Function* f =
          llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::nvvm_barrier0);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Unsupported storage sync scope " << sync;
    }
  }

#if TVM_LLVM_VERSION < 160
  // This function only works with the legacy pass manager.
  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }
#endif

  void Optimize() final {
    for (auto& f : *module_) {
      auto fname = static_cast<std::string>(f.getName());
      if (fname.substr(0, 4) != "__nv") continue;
      // This is to strip off unused __nv_* functions from the final module.
      // The one that is actually used will be inlined at the call site.
      // Adapted from Halide's runtime linker.
      if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) {
        f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
      }
    }
    CodeGenLLVM::Optimize();
  }

  llvm::Value* CreateIntrinsic(const CallNode* op) override;

 protected:
  void InitTarget() final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget();
  }
};

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
  // Only 32 bit data type is supported.
  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
    return false;
  }

  // Intrinsic lookup table.
  // It is difficult to emit a _sync version that works on Pascal,
  // so we ignore the mask and only emit the non-sync variants for nvptx.
  llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_shfl_idx_i32,  llvm::Intrinsic::nvvm_shfl_idx_f32,
      llvm::Intrinsic::nvvm_shfl_up_i32,   llvm::Intrinsic::nvvm_shfl_up_f32,
      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

  int offset = 0;
  if (op->op.same_as(builtin::tvm_warp_shuffle())) {
    offset = 0;
  } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) {
    offset = 2;
  } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) {
    offset = 4;
  } else {
    return false;
  }

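  // The table stores (i32, f32) pairs, so is_float() selects the second entry of a pair.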
  *id = ids[offset + op->dtype.is_float()];
  return true;
}

llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (GetWarpShuffleIntrinsic(op, &id)) {
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    // Skip the leading mask operand and drop the trailing, redundant warp_size argument.
    size_t n_args = op->args.size() - 1;
    for (size_t i = 1; i < n_args; ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      arg_type.push_back(arg_value.back()->getType());
    }
    llvm::Type* return_type = arg_type[0];
    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
    return builder_->CreateCall(func, arg_value);
  } else if (op->op.same_as(builtin::tvm_warp_activemask())) {
    // Only the nvptx target may keep this intrinsic at this point.
    // Emitted as PTX inline assembly: activemask.b32 %0
    auto fty = llvm::FunctionType::get(t_int32_, false);
    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
    return builder_->CreateCall(val);
  } else if (op->op.same_as(builtin::atomic_add())) {
    ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
    llvm::Value* v0 = MakeValue(op->args[0]);
    llvm::Value* v1 = MakeValue(op->args[1]);
    if (op->args[1]->dtype.is_float()) {
#if TVM_LLVM_VERSION >= 90
#if TVM_LLVM_VERSION >= 130
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(),
                                       llvm::AtomicOrdering::Monotonic);
#else
      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
                                       llvm::AtomicOrdering::Monotonic);
#endif
#else
      LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
    }
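    // LLVM 13 added a required alignment argument to IRBuilder::CreateAtomicRMW.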
#if TVM_LLVM_VERSION >= 130
    return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1, llvm::MaybeAlign(),
                                     llvm::AtomicOrdering::Monotonic);
#else
    return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
                                     llvm::AtomicOrdering::Monotonic);
#endif
  }
  return CodeGenLLVM::CreateIntrinsic(op);
}

int GetCUDAComputeVersion(const Target& target) {
  Optional<String> mcpu = target->GetAttr<String>("mcpu");
  ICHECK(mcpu.defined()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target";
  std::string sm_version = mcpu.value();
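  // mcpu has the form "sm_<major><minor>", e.g. "sm_70"; strip the "sm_" prefix.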
  return std::stoi(sm_version.substr(3));
}

runtime::Module BuildNVPTX(IRModule mod, Target target) {
  LLVMInstance llvm_instance;
  With<LLVMTarget> llvm_target(llvm_instance, target);

  int compute_ver = GetCUDAComputeVersion(target);
  auto cg = std::make_unique<CodeGenNVPTX>();

  cg->Init("TVMPTXModule", llvm_target.get(), false, false, false);

  cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
    ICHECK(kv.second->template IsInstance<PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    return Downcast<PrimFunc>(kv.second);
  });

  llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
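  // If available, link in libdevice (NVIDIA's bitcode math library) so that
  // __nv_* math calls in the generated module can be resolved and inlined.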
  const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
  if (flibdevice_path != nullptr) {
    std::string path = (*flibdevice_path)(compute_ver);
    if (path.length() != 0) {
      std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(path);
      mlib->setTargetTriple(llvm_target->GetTargetTriple());
      mlib->setDataLayout(tm->createDataLayout());
      cg->AddLinkModule(std::move(mlib));
    }
  }
  std::unique_ptr<llvm::Module> module = cg->Finish();
  llvm::SmallString<8> data_ptx, data_ll;
  llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
  dest_ptx.SetUnbuffered();
  dest_ll.SetUnbuffered();
  // print ll
  module->print(dest_ll, nullptr);
  std::string ll(data_ll.begin(), data_ll.end());
  // emit ptx
  llvm::legacy::PassManager pass;
#if TVM_LLVM_VERSION <= 60
  ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#elif TVM_LLVM_VERSION <= 90
  ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) ==
         0)
      << "Cannot emit target CGFT_AssemblyFile";
#else
  ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0)
      << "Cannot emit target CGFT_AssemblyFile";
#endif
  pass.run(*module);
  std::string ptx(data_ptx.begin(), data_ptx.end());
  return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll);
}

TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX);

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      *rv = static_cast<void*>(new CodeGenNVPTX());
    });

}  // namespace codegen
}  // namespace tvm

#endif  // TVM_LLVM_VERSION