1 | #include <numeric> |
2 | #include <sstream> |
3 | #include <iomanip> |
4 | #include <stdexcept> |
5 | #include "triton/codegen/selection/generator.h" |
6 | #include "triton/codegen/target.h" |
7 | #include "triton/codegen/analysis/axes.h" |
8 | #include "triton/codegen/analysis/allocation.h" |
9 | #include "triton/codegen/analysis/align.h" |
10 | #include "triton/codegen/analysis/swizzle.h" |
11 | #include "triton/codegen/transform/coalesce.h" |
12 | #include "triton/ir/context.h" |
13 | #include "triton/ir/module.h" |
14 | #include "triton/ir/function.h" |
15 | #include "triton/ir/type.h" |
16 | #include "triton/ir/utils.h" |
17 | #include "llvm/IR/Module.h" |
18 | #include "llvm/IR/IRBuilder.h" |
19 | #include "llvm/IR/IntrinsicsNVPTX.h" |
20 | #include "llvm/IR/BasicBlock.h" |
21 | #include "llvm/IR/Attributes.h" |
22 | #include "llvm/IR/InlineAsm.h" |
23 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
24 | |
25 | namespace triton{ |
26 | namespace codegen{ |
27 | |
28 | using namespace llvm; |
29 | |
30 | Value* adder::operator()(Value *x, Value *y, const std::string& name) { |
31 | // (x + cst) + y -> (x + y) + cst |
32 | if(auto* bin = dyn_cast<BinaryOperator>(x)) |
33 | if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) |
34 | if(dyn_cast<Constant>(bin->getOperand(1))){ |
35 | return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y), |
36 | bin->getOperand(1)); |
37 | } |
38 | // (x + (y + cst)) -> (x + y) + cst |
39 | if(auto* bin = dyn_cast<BinaryOperator>(y)) |
40 | if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) |
41 | if(dyn_cast<Constant>(bin->getOperand(1))){ |
42 | return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)), |
43 | bin->getOperand(1)); |
44 | } |
45 | |
46 | // default |
47 | return (*builder_)->CreateAdd(x, y, name); |
48 | } |
49 | |
50 | Value* multiplier::operator()(Value *x, Value *y, const std::string &name) { |
51 | // (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2) |
52 | if(auto* bin = dyn_cast<BinaryOperator>(x)) |
53 | if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) |
54 | if(dyn_cast<Constant>(bin->getOperand(1))) |
55 | if(dyn_cast<Constant>(y)){ |
56 | return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y), |
57 | (*builder_)->CreateMul(bin->getOperand(1), y)); |
58 | } |
59 | // default |
60 | return (*builder_)->CreateMul(x, y, name); |
61 | } |
62 | |
63 | Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ |
64 | // (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2) |
65 | if(auto* gep = dyn_cast<GetElementPtrInst>(ptr)) |
66 | if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin())) |
67 | if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){ |
68 | return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(), |
69 | gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2)); |
70 | } |
71 | // ptr + (off + cst) -> (ptr + off) + cst |
72 | if(auto* bin = dyn_cast<BinaryOperator>(off)) |
73 | if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) |
74 | if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){ |
75 | Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), |
76 | ptr, bin->getOperand(0)); |
77 | return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(), |
78 | gep, bin->getOperand(1)); |
79 | } |
80 | // default |
81 | return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), |
82 | ptr, off, name); |
83 | } |
84 | |
85 | //Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) { |
86 | // return (*builder_)->CreateGEP(ty, ptr, vals, name); |
87 | //} |
88 | |
89 | // types |
90 | #define void_ty builder_->getVoidTy() |
91 | #define f16_ty builder_->getHalfTy() |
92 | #define bf16_ty builder_->getInt16Ty() |
93 | #define f32_ty builder_->getFloatTy() |
94 | #define i1_ty builder_->getInt1Ty() |
95 | #define i8_ty builder_->getInt8Ty() |
96 | #define i16_ty builder_->getInt16Ty() |
97 | #define i32_ty builder_->getInt32Ty() |
98 | #define i64_ty builder_->getInt64Ty() |
99 | #define vec_ty(type, num_el) VectorType::get(type, num_el, false) |
100 | #define ptr_ty(...) PointerType::get(__VA_ARGS__) |
101 | // constants |
102 | #define i16(...) builder_->getInt16(__VA_ARGS__) |
103 | #define i32(...) builder_->getInt32(__VA_ARGS__) |
104 | // ops |
105 | #define and_(...) builder_->CreateAnd(__VA_ARGS__) |
106 | #define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) |
107 | #define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) |
108 | #define bin_op(...) builder_->CreateBinOp(__VA_ARGS__) |
109 | #define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__) |
110 | #define br(...) builder_->CreateBr(__VA_ARGS__) |
111 | #define call(...) builder_->CreateCall(__VA_ARGS__) |
112 | #define cast(...) builder_->CreateCast(__VA_ARGS__) |
113 | #define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) |
114 | #define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) |
115 | #define (...) builder_->CreateExtractElement(__VA_ARGS__) |
116 | #define (...) builder_->CreateExtractValue(__VA_ARGS__) |
117 | #define fadd(...) builder_->CreateFAdd(__VA_ARGS__) |
118 | #define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) |
119 | #define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__) |
120 | #define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__) |
121 | #define fmul(...) builder_->CreateFMul(__VA_ARGS__) |
122 | #define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) |
123 | #define fsub(...) builder_->CreateFSub(__VA_ARGS__) |
124 | #define icmp(...) builder_->CreateICmp(__VA_ARGS__) |
125 | #define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) |
126 | #define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) |
127 | #define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) |
128 | #define icmp_uge(...) builder_->CreateICmpUGE(__VA_ARGS__) |
129 | #define icmp_ule(...) builder_->CreateICmpULE(__VA_ARGS__) |
130 | #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) |
131 | #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) |
132 | #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) |
133 | #define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr) |
134 | #define lshr(...) builder_->CreateLShr(__VA_ARGS__) |
135 | #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) |
136 | #define min_num(...) builder_->CreateMinNum(__VA_ARGS__) |
137 | #define neg(...) builder_->CreateNeg(__VA_ARGS__) |
138 | #define phi(...) builder_->CreatePHI(__VA_ARGS__) |
139 | #define ret(...) builder_->CreateRet(__VA_ARGS__) |
140 | #define select(...) builder_->CreateSelect(__VA_ARGS__) |
141 | #define store(...) builder_->CreateStore(__VA_ARGS__) |
142 | #define sub(...) builder_->CreateSub(__VA_ARGS__) |
143 | #define shl(...) builder_->CreateShl(__VA_ARGS__) |
144 | #define udiv(...) builder_->CreateUDiv(__VA_ARGS__) |
145 | #define urem(...) builder_->CreateURem(__VA_ARGS__) |
146 | #define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) |
147 | #define xor_(...) builder_->CreateXor(__VA_ARGS__) |
148 | |
149 | /** |
150 | * \brief Convert Triton-IR Type to LLVM-IR Type |
151 | */ |
152 | Type *generator::cvt(ir::type *ty) { |
153 | // struct |
154 | if(ty->is_struct_ty()){ |
155 | std::vector<Type*> tys; |
156 | for(size_t i = 0; i < ty->get_struct_numel(); i++) |
157 | tys.push_back(cvt(ty->get_struct_type(i))); |
158 | return StructType::get(builder_->getContext(), tys, true); |
159 | } |
160 | |
161 | // function |
162 | if(auto* tt = dynamic_cast<ir::function_type*>(ty)){ |
163 | Type *ret_ty = cvt(tt->get_return_ty()); |
164 | std::vector<Type*> arg_tys(tt->get_num_params()); |
165 | for(size_t i = 0; i < arg_tys.size(); i++) |
166 | arg_tys[i] = cvt(tt->get_param_ty(i)); |
167 | return FunctionType::get(ret_ty, arg_tys, false); |
168 | } |
169 | // pointer |
170 | if(ty->is_pointer_ty()){ |
171 | Type *elt_ty = cvt(ty->get_pointer_element_ty()); |
172 | unsigned addr_space = ty->get_pointer_address_space(); |
173 | return ptr_ty(elt_ty, addr_space); |
174 | } |
175 | // integer |
176 | if(ty->is_integer_ty()){ |
177 | unsigned bitwidth = ty->get_integer_bitwidth(); |
178 | return IntegerType::get(*ctx_, bitwidth); |
179 | } |
180 | // primitive types |
181 | switch(ty->get_type_id()){ |
182 | case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); |
183 | case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); |
184 | case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); |
185 | case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type |
186 | case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); |
187 | case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); |
188 | case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); |
189 | case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); |
190 | case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); |
191 | default: break; |
192 | } |
193 | // unknown type |
194 | throw std::runtime_error("unknown conversion from ir::type to Type" ); |
195 | } |
196 | |
197 | /** |
198 | * \brief Convert Triton-IR Attribute to LLVM-IR Attribute |
199 | */ |
200 | llvm::Attribute generator::cvt(ir::attribute attr) { |
201 | switch(attr.get_kind()){ |
202 | case ir::noalias: return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias); |
203 | case ir::readonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly); |
204 | case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly); |
205 | case ir::aligned: return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value()); |
206 | case ir::retune: return llvm::Attribute::get(*ctx_, llvm::Attribute::None); |
207 | default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute" ); |
208 | } |
209 | } |
210 | |
211 | /** |
212 | * \brief Constructor of LLVM code generator |
213 | */ |
214 | generator::generator(analysis::axes *a_axes, |
215 | analysis::layouts *layouts, |
216 | analysis::align *alignment, |
217 | analysis::allocation *alloc, |
218 | analysis::swizzle *swizzle, |
219 | target *tgt, |
220 | unsigned num_warps) |
221 | : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle), |
222 | tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) { |
223 | |
224 | } |
225 | |
226 | /** |
227 | * \brief Code Generation for `value` |
228 | */ |
229 | void generator::visit_value(ir::value* v) { |
230 | if(!seen_.insert(v).second) |
231 | return; |
232 | if(v->get_type()->is_block_ty()){ |
233 | if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){ |
234 | analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer(); |
235 | analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer(); |
236 | |
237 | // offset |
238 | Value *offset = nullptr; |
239 | // base pointer |
240 | Value *ptr = shared_ptr_[layout]; |
241 | |
242 | if (n_buffer) { |
243 | // ptr = base (shared_ptr_[layout]) + smem_idx * size |
244 | // read_smem_idx |
245 | if (v == n_buffer->phi) { |
246 | ptr = shared_ptr_[layout]; |
247 | } |
248 | // write_smem_idx |
249 | if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) { |
250 | int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v); |
251 | int elements = write_smem_idx * layout->get_per_stage_elements(); |
252 | ptr = gep(shared_pre_ptr_[layout], i32(elements)); |
253 | } else if (v == n_buffer->latch) { |
254 | Value* write_smem_idx = write_smem_idx_[layout]; |
255 | Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements())); |
256 | ptr = gep(shared_pre_ptr_[layout], elements); |
257 | } |
258 | } else if (double_buffer) { |
259 | if(v == double_buffer->phi) |
260 | offset = shared_off_[layout]; |
261 | if(v == double_buffer->latch) |
262 | ptr = shared_next_ptr_[layout]; |
263 | else if(v == double_buffer->first) |
264 | ptr = shared_pre_ptr_[layout]; |
265 | } // else do nothing |
266 | // what visit_dot & vist_cts & ... see |
267 | shmems_[v] = ptr; |
268 | // now only latches have offset (PHINode), only used by finalize_share_layout() |
269 | shoffs_[v] = offset; |
270 | } |
271 | } |
272 | // visit operands |
273 | BasicBlock *current = builder_->GetInsertBlock(); |
274 | auto *inst = dynamic_cast<ir::instruction*>(v); |
275 | if(inst) |
276 | for(ir::value *op: inst->ops()){ |
277 | if(dynamic_cast<ir::constant*>(op) || !dynamic_cast<ir::phi_node*>(v)) |
278 | visit_value(op); |
279 | } |
280 | init_idx(v); |
281 | // change insert point for phi node |
282 | builder_->SetInsertPoint(current); |
283 | auto *phi = dynamic_cast<ir::phi_node*>(v); |
284 | if(phi && !current->empty() && current->getFirstNonPHI()) |
285 | builder_->SetInsertPoint(&*current->getFirstNonPHI()); |
286 | // visit user |
287 | if(auto *usr = dynamic_cast<ir::user*>(v)){ |
288 | if(!dynamic_cast<ir::function*>(usr)) |
289 | usr->accept(this); |
290 | } |
291 | // revert insert point |
292 | if(phi && !current->empty() && current->getFirstNonPHI()) |
293 | builder_->SetInsertPoint(current); |
294 | } |
295 | |
296 | /** |
297 | * \brief Code Generation for `phi` |
298 | */ |
299 | void generator::visit_phi_node(ir::phi_node* x) { |
300 | Type *ty = cvt(x->get_type()->get_scalar_ty()); |
301 | for(indices_t idx: idxs_.at(x)) |
302 | vals_[x][idx] = phi(ty, x->get_num_operands()); |
303 | } |
304 | |
305 | /** |
306 | * \brief Code Generation for `call` |
307 | */ |
308 | void generator::visit_call_inst(ir::call_inst* call) { |
309 | throw std::runtime_error("call not supported! Triton should be inlining everything." ); |
310 | } |
311 | |
312 | void generator::visit_launch_inst(ir::launch_inst *launch) { |
313 | ir::function* fn = (ir::function*)launch->get_operand(0); |
314 | // forward-declare cudaGetParameterBufferV2 |
315 | std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0), |
316 | ArrayType::get(builder_->getInt32Ty(), 3), |
317 | ArrayType::get(builder_->getInt32Ty(), 3), |
318 | builder_->getInt32Ty()}; |
319 | FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false); |
320 | Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2" , mod_); |
321 | AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]); |
322 | AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]); |
323 | ConstantInt* _0 = builder_->getInt32(0); |
324 | ConstantInt* _1 = builder_->getInt32(1); |
325 | ConstantInt* _2 = builder_->getInt32(2); |
326 | // create basic block |
327 | BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done" , builder_->GetInsertBlock()->getParent()); |
328 | BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch" , launch_done_bb->getParent(), launch_done_bb); |
329 | Value *tid = tgt_->get_local_id(mod_, *builder_, 0); |
330 | Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0)); |
331 | builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb); |
332 | builder_->SetInsertPoint(launch_bb); |
333 | |
334 | // |
335 | builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0})); |
336 | builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1})); |
337 | builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2})); |
338 | Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]); |
339 | builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0})); |
340 | builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1})); |
341 | builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2})); |
342 | Function* called_fn = fns_[fn]; |
343 | Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]); |
344 | Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)}); |
345 | // forwrd-declare cudaLaunchDeviceV2 |
346 | std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()}; |
347 | FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false); |
348 | Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2" , mod_); |
349 | // TODO: add branch |
350 | Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()), |
351 | builder_->getInt64(0)); |
352 | BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2" , launch_done_bb->getParent(), launch_done_bb); |
353 | builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb); |
354 | builder_->SetInsertPoint(launch2_bb); |
355 | |
356 | unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace(); |
357 | unsigned off = 0; |
358 | unsigned last_size = 0; |
359 | for(ir::value* arg: launch->get_values()){ |
360 | Value* curr_arg = vals_[arg][{}]; |
361 | Type* curr_arg_ty = curr_arg->getType(); |
362 | // handle struct alignment |
363 | off += last_size; |
364 | unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8; |
365 | off = (off + size - 1) / size * size; |
366 | // get pointer to current arg |
367 | Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off)); |
368 | curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space)); |
369 | // store arg |
370 | builder_->CreateStore(curr_arg, curr_arg_ptr); |
371 | last_size = size; |
372 | } |
373 | builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)}); |
374 | builder_->CreateBr(launch_done_bb); |
375 | // done |
376 | builder_->SetInsertPoint(launch_done_bb); |
377 | |
378 | } |
379 | |
380 | /** |
381 | * \brief Code Generation for `binary_operator` |
382 | */ |
383 | void generator::visit_binary_operator(ir::binary_operator*x) { |
384 | using ll = llvm::Instruction::BinaryOps; |
385 | using tt = ir::binary_op_t; |
386 | auto cvt = [](ir::binary_op_t op){ |
387 | switch(op) { |
388 | case tt::Add: return ll::Add; |
389 | case tt::FAdd: return ll::FAdd; |
390 | case tt::Sub: return ll::Sub; |
391 | case tt::FSub: return ll::FSub; |
392 | case tt::Mul: return ll::Mul; |
393 | case tt::FMul: return ll::FMul; |
394 | case tt::UDiv: return ll::UDiv; |
395 | case tt::SDiv: return ll::SDiv; |
396 | case tt::FDiv: return ll::FDiv; |
397 | case tt::URem: return ll::URem; |
398 | case tt::SRem: return ll::SRem; |
399 | case tt::FRem: return ll::FRem; |
400 | case tt::Shl: return ll::Shl; |
401 | case tt::LShr: return ll::LShr; |
402 | case tt::AShr: return ll::AShr; |
403 | case tt::And: return ll::And; |
404 | case tt::Or: return ll::Or; |
405 | case tt::Xor: return ll::Xor; |
406 | default: throw std::runtime_error("unreachable switch" ); |
407 | } |
408 | }; |
409 | // x->print(std::cout); |
410 | for(indices_t idx: idxs_.at(x)){ |
411 | Value *lhs = vals_[x->get_operand(0)][idx]; |
412 | Value *rhs = vals_[x->get_operand(1)][idx]; |
413 | // manually select bf16 bin op |
414 | if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) { |
415 | assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty()); |
416 | if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b |
417 | InlineAsm *bf16_add_asm = |
418 | InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), |
419 | "{ .reg .b16 c; \n\t" |
420 | " mov.b16 c, 0x3f80U; \n\t" // 1.0 |
421 | " fma.rn.bf16 $0, $1, c, $2; } \n\t" , |
422 | "=h,h,h" , false); |
423 | vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs}); |
424 | } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a |
425 | InlineAsm *bf16_sub_asm = |
426 | InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), |
427 | " { .reg .b16 c; \n\t" |
428 | " mov.b16 c, 0xbf80U; \n\t" // -1.0 |
429 | " fma.rn.bf16 $0, $2, c, $1;} \n\t" , |
430 | "=h,h,h" , false); |
431 | vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs}); |
432 | } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0 |
433 | InlineAsm *bf16_mul_asm = |
434 | InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), |
435 | " { .reg .b16 c; \n\t" |
436 | " mov.b16 c, 0x8000U; \n\t" // 0.0 |
437 | " fma.rn.bf16 $0, $1, $2, c;} \n\t" , |
438 | "=h,h,h" , false); |
439 | vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs}); |
440 | } else |
441 | throw std::runtime_error("invalid bin op for bf16" ); |
442 | } else { // not bf16 |
443 | auto op = cvt(x->get_op()); |
444 | if(op == ll::Add) |
445 | vals_[x][idx] = add(lhs, rhs); |
446 | else if(op == ll::Mul) |
447 | vals_[x][idx] = mul(lhs, rhs); |
448 | else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && |
449 | x->get_type()->get_scalar_ty()->is_fp32_ty()){ |
450 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), |
451 | " div.full.f32 $0, $1, $2;" , "=r,r,r" , false); |
452 | vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); |
453 | |
454 | } |
455 | else |
456 | vals_[x][idx] = bin_op(op, lhs, rhs); |
457 | } |
458 | } |
459 | } |
460 | |
461 | /** |
462 | * \brief Code Generation for `getelementptr` |
463 | */ |
464 | void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) { |
465 | for(indices_t idx: idxs_.at(x)){ |
466 | Value *ptr = vals_[x->get_pointer_operand()][idx]; |
467 | std::vector<Value*> vals; |
468 | for(auto it= x->idx_begin(); it != x->idx_end(); it++) |
469 | vals.push_back(vals_[*it][idx]); |
470 | assert(vals.size() == 1); |
471 | vals_[x][idx] = gep(ptr, vals[0]); |
472 | } |
473 | } |
474 | |
475 | /** |
476 | * \brief Code Generation for `icmp` |
477 | */ |
478 | void generator::visit_icmp_inst(ir::icmp_inst* x) { |
479 | auto cvt = [](ir::cmp_pred_t pred) { |
480 | using ll = llvm::CmpInst::Predicate; |
481 | using tt = ir::cmp_pred_t; |
482 | switch(pred){ |
483 | case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE; |
484 | case tt::ICMP_EQ: return ll::ICMP_EQ; |
485 | case tt::ICMP_NE: return ll::ICMP_NE; |
486 | case tt::ICMP_UGT: return ll::ICMP_UGT; |
487 | case tt::ICMP_UGE: return ll::ICMP_UGE; |
488 | case tt::ICMP_ULT: return ll::ICMP_ULT; |
489 | case tt::ICMP_ULE: return ll::ICMP_ULE; |
490 | case tt::ICMP_SGT: return ll::ICMP_SGT; |
491 | case tt::ICMP_SGE: return ll::ICMP_SGE; |
492 | case tt::ICMP_SLT: return ll::ICMP_SLT; |
493 | case tt::ICMP_SLE: return ll::ICMP_SLE; |
494 | case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE; |
495 | default: throw std::runtime_error("unreachable switch" ); |
496 | } |
497 | }; |
498 | |
499 | for(indices_t idx: idxs_.at(x)){ |
500 | Value *lhs = vals_[x->get_operand(0)][idx]; |
501 | Value *rhs = vals_[x->get_operand(1)][idx]; |
502 | vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs); |
503 | } |
504 | } |
505 | |
506 | /** |
507 | * \brief Code Generation for `fcmp` |
508 | */ |
509 | void generator::visit_fcmp_inst(ir::fcmp_inst* x) { |
510 | auto cvt = [](ir::cmp_pred_t pred) { |
511 | using ll = llvm::CmpInst::Predicate; |
512 | using tt = ir::cmp_pred_t; |
513 | switch(pred){ |
514 | case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE; |
515 | case tt::FCMP_FALSE: return ll::FCMP_FALSE; |
516 | case tt::FCMP_OEQ: return ll::FCMP_OEQ; |
517 | case tt::FCMP_OGT: return ll::FCMP_OGT; |
518 | case tt::FCMP_OGE: return ll::FCMP_OGE; |
519 | case tt::FCMP_OLT: return ll::FCMP_OLT; |
520 | case tt::FCMP_OLE: return ll::FCMP_OLE; |
521 | case tt::FCMP_ONE: return ll::FCMP_ONE; |
522 | case tt::FCMP_ORD: return ll::FCMP_ORD; |
523 | case tt::FCMP_UNO: return ll::FCMP_UNO; |
524 | case tt::FCMP_UEQ: return ll::FCMP_UEQ; |
525 | case tt::FCMP_UGT: return ll::FCMP_UGT; |
526 | case tt::FCMP_UGE: return ll::FCMP_UGE; |
527 | case tt::FCMP_ULT: return ll::FCMP_ULT; |
528 | case tt::FCMP_ULE: return ll::FCMP_ULE; |
529 | case tt::FCMP_UNE: return ll::FCMP_UNE; |
530 | case tt::FCMP_TRUE: return ll::FCMP_TRUE; |
531 | case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE; |
532 | default: throw std::runtime_error("unreachable switch" ); |
533 | } |
534 | }; |
535 | for(indices_t idx: idxs_.at(x)){ |
536 | Value *lhs = vals_[x->get_operand(0)][idx]; |
537 | Value *rhs = vals_[x->get_operand(1)][idx]; |
538 | vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs); |
539 | } |
540 | } |
541 | |
542 | |
543 | std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ |
544 | in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty); |
545 | in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty); |
546 | in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty); |
547 | in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty); |
548 | Value *ret0, *ret1, *ret2, *ret3; |
549 | std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3); |
550 | return std::make_tuple(ret0, ret1, ret2, ret3); |
551 | } |
552 | |
553 | std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ |
554 | Value *ret0, *ret1, *ret2, *ret3; |
555 | std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3); |
556 | ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty); |
557 | ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty); |
558 | ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty); |
559 | ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty); |
560 | return std::make_tuple(ret0, ret1, ret2, ret3); |
561 | } |
562 | |
563 | |
564 | std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){ |
565 | Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); |
566 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), |
567 | "{" |
568 | ".reg .b32 a<2>, b<2>; \n\t" |
569 | "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0 |
570 | "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 |
571 | "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) |
572 | "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) |
573 | "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position) |
574 | "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position) |
575 | "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign) |
576 | "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign) |
577 | "}" , "=r,=r,r" , false); |
578 | Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); |
579 | packed_in = insert_elt(packed_in, in0, (uint64_t)0); |
580 | packed_in = insert_elt(packed_in, in1, (uint64_t)1); |
581 | packed_in = insert_elt(packed_in, in2, (uint64_t)2); |
582 | packed_in = insert_elt(packed_in, in3, (uint64_t)3); |
583 | Value *in = bit_cast(packed_in, i32_ty); |
584 | Value *ret = call(ptx, {in}); |
585 | Value *packed_ret0 = extract_val(ret, {0}); |
586 | Value *packed_ret1 = extract_val(ret, {1}); |
587 | Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); |
588 | Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); |
589 | Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); |
590 | Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); |
591 | return std::make_tuple(ret0, ret1, ret2, ret3); |
592 | } |
593 | |
594 | std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { |
595 | /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa) |
596 | * fp8 bit representation is seeeemmm |
597 | * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16. |
598 | * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16. |
599 | * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous. |
600 | * We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's |
601 | * one more than the number of mantissa bits in fp8). |
602 | * fp8 = (fp16 & 0x8000) | (((f16 << 1) + 0x0080) & 0x7fff) |
603 | * |
604 | * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the |
605 | * other. To avoid this we zero out the most significant exponent bit. If that bit is set then |
606 | * the value isn't representable in float8 anyway so we assume it's never set (and give garbage |
607 | * output if it is). If we were willing to assume the most significant exponent was never set |
608 | * we could save the first two lop3.b32 instructions below. |
609 | */ |
610 | InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), |
611 | "{" |
612 | ".reg .b32 a<2>, b<2>; \n\t" |
613 | "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1 |
614 | "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1 |
615 | "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff) |
616 | "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff) |
617 | "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080 |
618 | "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080 |
619 | "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 |
620 | "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1 |
621 | "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02 |
622 | "}" , "=r,r,r" , false); |
623 | Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2)); |
624 | Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2)); |
625 | packed_in0 = insert_elt(packed_in0, in0, (int)0); |
626 | packed_in0 = insert_elt(packed_in0, in1, (int)1); |
627 | packed_in1 = insert_elt(packed_in1, in2, (int)0); |
628 | packed_in1 = insert_elt(packed_in1, in3, (int)1); |
629 | Value *in_arg0 = bit_cast(packed_in0, i32_ty); |
630 | Value *in_arg1 = bit_cast(packed_in1, i32_ty); |
631 | Value *ret = call(ptx, {in_arg0, in_arg1}); |
632 | Value *ret0 = extract_elt(ret, (int)0); |
633 | Value *ret1 = extract_elt(ret, (int)1); |
634 | Value *ret2 = extract_elt(ret, (int)2); |
635 | Value *ret3 = extract_elt(ret, (int)3); |
636 | return std::make_tuple(ret0, ret1, ret2, ret3); |
637 | } |
638 | |
639 | std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) { |
640 | // current exp offset: 15 |
641 | // Add 112 (127-15) to compensate the difference in exponent bias |
642 | // bf16 = (nosign >> (8-4) + 112 << 7) | sign; |
643 | // bf16 = (nosign >> 4 + 0x3800) | sign; |
644 | Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)}); |
645 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), |
646 | "{" |
647 | ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t" |
648 | "prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0 |
649 | "prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0 |
650 | "and.b32 sign0, a0, 0x80008000; \n\t" |
651 | "and.b32 sign1, a1, 0x80008000; \n\t" |
652 | "and.b32 nosign0, a0, 0x7fff7fff; \n\t" |
653 | "and.b32 nosign1, a1, 0x7fff7fff; \n\t" |
654 | "shr.b32 nosign0, nosign0, 4; \n\t" |
655 | "shr.b32 nosign1, nosign1, 4; \n\t" |
656 | "add.u32 nosign0, nosign0, 0x38003800; \n\t" |
657 | "add.u32 nosign1, nosign1, 0x38003800; \n\t" |
658 | "or.b32 $0, sign0, nosign0; \n\t" |
659 | "or.b32 $1, sign1, nosign1; \n\t" |
660 | "}" , "=r,=r,r" , false); |
661 | Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); |
662 | packed_in = insert_elt(packed_in, in0, (uint64_t)0); |
663 | packed_in = insert_elt(packed_in, in1, (uint64_t)1); |
664 | packed_in = insert_elt(packed_in, in2, (uint64_t)2); |
665 | packed_in = insert_elt(packed_in, in3, (uint64_t)3); |
666 | Value *in = bit_cast(packed_in, i32_ty); |
667 | Value *ret = call(ptx, {in}); |
668 | Value *packed_ret0 = extract_val(ret, {0}); |
669 | Value *packed_ret1 = extract_val(ret, {1}); |
670 | Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); |
671 | Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); |
672 | Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); |
673 | Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); |
674 | return std::make_tuple(ret0, ret1, ret2, ret3); |
675 | } |
676 | |
677 | std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { |
678 | /* Assuming fp8 exponent offset is 16. bf16 exponent offset is 127. |
679 | Max value in fp8: 0b01111111 (0x7f), |
680 | bf16: 3ff0 |
681 | Min value in fp8: 0b00000000 (0x00) |
682 | bf16: 0x3c00 |
683 | // @note: +0x8 is for "rounding to nearest zero" |
684 | fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4; |
685 | return fp8 | sign; // also permute bytes |
686 | */ |
687 | InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), |
688 | "{\n\t" |
689 | ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t" |
690 | ".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t" |
691 | "mov.u32 fp8_min, 0x38003800; \n\t" |
692 | "mov.u32 fp8_max, 0x3ff03ff0; \n\t" |
693 | "mov.u32 rn_, 0x80008; \n\t" |
694 | "mov.u32 zero, 0; \n\t" |
695 | "and.b32 sign0, $1, 0x80008000; \n\t" |
696 | "and.b32 sign1, $2, 0x80008000; \n\t" |
697 | "prmt.b32 sign, sign0, sign1, 0x7531; \n\t" |
698 | "and.b32 nosign0, $1, 0x7fff7fff; \n\t" |
699 | "and.b32 nosign1, $2, 0x7fff7fff; \n\t" |
700 | |
701 | ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max) |
702 | "and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t" |
703 | "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t" |
704 | "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t" |
705 | "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t" |
706 | "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t" |
707 | "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t" |
708 | "or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t" |
709 | "and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t" |
710 | "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t" |
711 | "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t" |
712 | "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t" |
713 | "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t" |
714 | "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t" |
715 | "or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t" |
716 | |
717 | "add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero |
718 | "add.u32 nosign1, nosign1, rn_; \n\t" |
719 | "sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset |
720 | "sub.u32 nosign1, nosign1, 0x38003800; \n\t" |
721 | "shr.u32 nosign0, nosign0, 4; \n\t" |
722 | "shr.u32 nosign1, nosign1, 4; \n\t" |
723 | "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t" |
724 | "or.b32 $0, nosign, sign; \n\t" |
725 | "" |
726 | "}" , "=r,r,r" , false); |
727 | Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2)); |
728 | Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2)); |
729 | packed_in0 = insert_elt(packed_in0, in0, (int)0); |
730 | packed_in0 = insert_elt(packed_in0, in1, (int)1); |
731 | packed_in1 = insert_elt(packed_in1, in2, (int)0); |
732 | packed_in1 = insert_elt(packed_in1, in3, (int)1); |
733 | Value *in_arg0 = bit_cast(packed_in0, i32_ty); |
734 | Value *in_arg1 = bit_cast(packed_in1, i32_ty); |
735 | Value *ret = call(ptx, {in_arg0, in_arg1}); |
736 | Value *ret0 = extract_elt(ret, (int)0); |
737 | Value *ret1 = extract_elt(ret, (int)1); |
738 | Value *ret2 = extract_elt(ret, (int)2); |
739 | Value *ret3 = extract_elt(ret, (int)3); |
740 | return std::make_tuple(ret0, ret1, ret2, ret3); |
741 | } |
742 | |
743 | Value* generator::bf16_to_fp32(Value *in0){ |
744 | if (tgt_->as_nvidia()->sm() >= 80) { |
745 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), |
746 | "cvt.rn.f32.bf16 $0, $1;" , "=r,h" , false); |
747 | return call(ptx, {in0}); |
748 | } else { |
749 | Value *ret = UndefValue::get(vec_ty(i16_ty, 2)); |
750 | ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1); |
751 | ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0); |
752 | return bit_cast(ret, f32_ty); |
753 | } |
754 | } |
755 | |
756 | Value* generator::fp32_to_bf16(Value *in0){ |
757 | if(tgt_->as_nvidia()->sm() >= 80){ |
758 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false), |
759 | "cvt.rn.bf16.f32 $0, $1;" , "=h,r" , false); |
760 | return call(ptx, {in0}); |
761 | } |
762 | return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1); |
763 | } |
764 | |
765 | /** |
766 | * \brief Code Generation for `cast` |
767 | */ |
768 | void generator::visit_cast_inst(ir::cast_inst* x) { |
769 | ir::value *op = x->get_operand(0); |
770 | ir::type* ret_sca_ty = x->get_type()->get_scalar_ty(); |
771 | ir::type* op_sca_ty = op->get_type()->get_scalar_ty(); |
772 | auto x_idxs = idxs_.at(x); |
773 | auto op_idxs = idxs_.at(op); |
774 | |
775 | // <> FP8 |
776 | if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){ |
777 | // ensure that conversions can be vectorized |
778 | int ld = layouts_->get(x)->get_order(0); |
779 | int contiguous = layouts_->get(x)->to_scanline()->nts(ld); |
780 | if(contiguous % 4 != 0) |
781 | throw std::runtime_error("unsupported fp32 -> fp8 conversion" ); |
782 | |
783 | // run the conversion |
784 | auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ |
785 | if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty()) |
786 | return fp32x4_to_fp8x4(a, b, c, d); |
787 | if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty()) |
788 | return fp16x4_to_fp8x4(a, b, c, d); |
789 | if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty()) |
790 | return fp8x4_to_fp16x4(a, b, c, d); |
791 | if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty()) |
792 | return fp8x4_to_fp32x4(a, b, c, d); |
793 | // fp8 <> bf16 |
794 | if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty()) |
795 | return fp8x4_to_bf16x4(a, b, c, d); |
796 | if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty()) |
797 | return bf16x4_to_fp8x4(a, b, c, d); |
798 | throw std::runtime_error("unsupported conversion" ); |
799 | }; |
800 | for(size_t i = 0; i < x_idxs.size(); i+=4){ |
801 | std::tie(vals_[x][x_idxs[i+0]], |
802 | vals_[x][x_idxs[i+1]], |
803 | vals_[x][x_idxs[i+2]], |
804 | vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]], |
805 | vals_[op][op_idxs[i+1]], |
806 | vals_[op][op_idxs[i+2]], |
807 | vals_[op][op_idxs[i+3]]); |
808 | } |
809 | return; |
810 | } |
811 | |
812 | // <> BF16 |
813 | if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ |
814 | // FP32 -> BF16 |
815 | if(op_sca_ty->is_fp32_ty()){ |
816 | for (indices_t idx: idxs_.at(x)) { |
817 | Value *arg = vals_[x->get_operand(0)][idx]; |
818 | vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); |
819 | } |
820 | return; |
821 | } |
822 | // BF16 -> FP32 |
823 | if(ret_sca_ty->is_fp32_ty()){ |
824 | for(size_t i = 0; i < x_idxs.size(); i++) |
825 | vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); |
826 | return; |
827 | } |
828 | } |
829 | |
830 | |
831 | Type *ty = cvt(x->get_type()->get_scalar_ty()); |
832 | auto cvt = [](ir::cast_op_t op){ |
833 | using ll = llvm::Instruction::CastOps; |
834 | using tt = ir::cast_op_t; |
835 | switch(op){ |
836 | case tt::Trunc: return ll::Trunc; |
837 | case tt::ZExt: return ll::ZExt; |
838 | case tt::SExt: return ll::SExt; |
839 | case tt::FPTrunc: return ll::FPTrunc; |
840 | case tt::FPExt: return ll::FPExt; |
841 | case tt::UIToFP: return ll::UIToFP; |
842 | case tt::SIToFP: return ll::SIToFP; |
843 | case tt::FPToUI: return ll::FPToUI; |
844 | case tt::FPToSI: return ll::FPToSI; |
845 | case tt::PtrToInt: return ll::PtrToInt; |
846 | case tt::IntToPtr: return ll::IntToPtr; |
847 | case tt::BitCast: return ll::BitCast; |
848 | case tt::AddrSpaceCast: return ll::AddrSpaceCast; |
849 | default: throw std::runtime_error("unreachable switch" ); |
850 | } |
851 | }; |
852 | for(indices_t idx: idxs_.at(x)){ |
853 | Value *arg = vals_[x->get_operand(0)][idx]; |
854 | vals_[x][idx] = cast(cvt(x->get_op()), arg, ty); |
855 | } |
856 | } |
857 | |
858 | std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int16_to_float16x8( |
859 | Value *in0, Value *scale_x512, Value *shift |
860 | ){ |
861 | /* unpacking 8 int2s packed into an int16 to 8 float16s |
862 | * the algorithm is similar to |
863 | * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ |
864 | fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563 |
865 | */ |
866 | Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); |
867 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), |
868 | "{" |
869 | ".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits. |
870 | "and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h |
871 | "and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0 |
872 | "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h |
873 | "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0 |
874 | "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00 |
875 | "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00 |
876 | "mov.b32 a0, 0x78007800; \n\t" // a0 = 32768 |
877 | "mov.b32 a1, 0x70007000; \n\t" // a1 = 8192 |
878 | "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. |
879 | "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192. |
880 | "mov.b32 a0, 0x68006800; \n\t" // a0 = 2048 |
881 | "mov.b32 a1, 0x60006000; \n\t" // a1 = 512 |
882 | "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048. |
883 | "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512. |
884 | "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. |
885 | "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. |
886 | "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift. |
887 | "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift. |
888 | "}" , "=r,=r,=r,=r,r,r,r" , false); |
889 | |
890 | Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2)); |
891 | packed_in = insert_elt(packed_in, in0, (int)0); |
892 | packed_in = insert_elt(packed_in, in0, (int)1); |
893 | Value *in = bit_cast(packed_in, i32_ty); |
894 | |
895 | Value *ret = call(ptx, {in, scale_x512, shift}); |
896 | Value *packed_ret0 = extract_val(ret, {0}); |
897 | Value *packed_ret1 = extract_val(ret, {1}); |
898 | Value *packed_ret2 = extract_val(ret, {2}); |
899 | Value *packed_ret3 = extract_val(ret, {3}); |
900 | Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h |
901 | Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g |
902 | Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f |
903 | Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e |
904 | Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d |
905 | Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c |
906 | Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b |
907 | Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a |
908 | return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); |
909 | } |
910 | |
911 | std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int32_to_float16x8( |
912 | Value *in0, Value *scale_x512, Value *shift |
913 | ){ |
914 | /* unpacking 8 int4s packed into an int32 to 8 float16s |
915 | * the algorithm is similar to |
916 | * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ |
917 | fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619 |
918 | */ |
919 | Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); |
920 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), |
921 | "{" |
922 | ".reg .b32 a<2>, b<4>; \n\t" |
923 | "and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h |
924 | "and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0 |
925 | "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h |
926 | "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0 |
927 | "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d |
928 | "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0 |
929 | "mov.b32 a0, 0x78007800; \n\t" |
930 | "mov.b32 a1, 0x68006800; \n\t" |
931 | "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. |
932 | "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048. |
933 | "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768. |
934 | "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048. |
935 | "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. |
936 | "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. |
937 | "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out0 = b0 * scale + shift. |
938 | "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out1 = b1 * scale + shift. |
939 | "}" , "=r,=r,=r,=r,r,r,r" , false); |
940 | |
941 | Value *ret = call(ptx, {in0, scale_x512, shift}); |
942 | Value *packed_ret0 = extract_val(ret, {0}); |
943 | Value *packed_ret1 = extract_val(ret, {1}); |
944 | Value *packed_ret2 = extract_val(ret, {2}); |
945 | Value *packed_ret3 = extract_val(ret, {3}); |
946 | Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h |
947 | Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g |
948 | Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f |
949 | Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e |
950 | Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d |
951 | Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c |
952 | Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b |
953 | Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a |
954 | return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); |
955 | } |
956 | |
957 | std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){ |
958 | /* unpacking 4 int8s packed into an int32 to 4 fp16s |
959 | * the algorithm is similar to |
960 | * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ |
961 | fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646 |
962 | */ |
963 | Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); |
964 | InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), |
965 | "{" |
966 | ".reg .b32 a, b<2>; \n\t" |
967 | "prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh |
968 | "prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd |
969 | "mov.b32 a, 0x78007800; \n\t" |
970 | "mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768. |
971 | "mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768. |
972 | "fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift. |
973 | "fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift. |
974 | "}" , "=r,=r,r,r,r" , false); |
975 | |
976 | Value *ret = call(ptx, {in0, scale_x512, shift}); |
977 | Value *packed_ret0 = extract_val(ret, {0}); |
978 | Value *packed_ret1 = extract_val(ret, {1}); |
979 | Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh |
980 | Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef |
981 | Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd |
982 | Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab |
983 | return std::make_tuple(ret0, ret1, ret2, ret3); |
984 | } |
985 | |
986 | std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){ |
987 | Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty)); |
988 | Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2)); |
989 | p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0); |
990 | p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1); |
991 | p_scale_x512 = bit_cast(p_scale_x512, i32_ty); |
992 | |
993 | Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2)); |
994 | p_shift = insert_elt(p_shift, shift, (int)0); |
995 | p_shift = insert_elt(p_shift, shift, (int)1); |
996 | p_shift = bit_cast(p_shift, i32_ty); |
997 | |
998 | return std::make_tuple(p_scale_x512, p_shift); |
999 | } |
1000 | |
1001 | /** |
1002 | * \brief Code Generation for `dequantize` |
1003 | */ |
1004 | void generator::visit_dequantize_inst(ir::dequantize_inst* x) { |
1005 | ir::value *op = x->get_operand(0); |
1006 | |
1007 | auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits(); |
1008 | |
1009 | auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); |
1010 | auto op_last_dim = (op->get_type()->get_block_shapes()).back(); |
1011 | |
1012 | auto x_idxs = idxs_.at(x); |
1013 | auto op_idxs = idxs_.at(op); |
1014 | |
1015 | ir::value *scale = x->get_operand(1); |
1016 | ir::value *shift = x->get_operand(2); |
1017 | |
1018 | Value *p_scale_x512, *p_shift; |
1019 | std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]); |
1020 | |
1021 | int ld = layouts_->get(x)->get_order(0); |
1022 | int contiguous = layouts_->get(x)->to_scanline()->nts(ld); |
1023 | |
1024 | int op_ld = layouts_->get(op)->get_order(0); |
1025 | int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld); |
1026 | |
1027 | std::string err_msg; |
1028 | err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): " |
1029 | + std::to_string(x_idxs.size()) + "; op_idxs.size(): " |
1030 | + std::to_string(op_idxs.size()) + "; contiguous: " |
1031 | + std::to_string(contiguous) + "; op_contiguous: " |
1032 | + std::to_string(op_contiguous) + ". if the condition " |
1033 | "is not met, please try adjusting block_size, num_warps or " |
1034 | "using tl.multiple_of to hint the input/output ptr address." ; |
1035 | |
1036 | if (ret_last_dim == 8 * op_last_dim) { |
1037 | if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) { |
1038 | throw std::runtime_error(err_msg); |
1039 | } |
1040 | |
1041 | auto cvt = [&]( |
1042 | Value* a, Value* scale, Value* shift |
1043 | ){ |
1044 | if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s |
1045 | return int16_to_float16x8(a, scale, shift); |
1046 | } else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s |
1047 | return int32_to_float16x8(a, scale, shift); |
1048 | } else { |
1049 | throw std::runtime_error("unsupported conversion" ); |
1050 | } |
1051 | }; |
1052 | |
1053 | for(size_t j = 0; j < op_idxs.size(); j++){ |
1054 | size_t i = j * 8; |
1055 | std::tie(vals_[x][x_idxs[i+0]], |
1056 | vals_[x][x_idxs[i+1]], |
1057 | vals_[x][x_idxs[i+2]], |
1058 | vals_[x][x_idxs[i+3]], |
1059 | vals_[x][x_idxs[i+4]], |
1060 | vals_[x][x_idxs[i+5]], |
1061 | vals_[x][x_idxs[i+6]], |
1062 | vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); |
1063 | } |
1064 | } else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s |
1065 | if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) { |
1066 | throw std::runtime_error(err_msg); |
1067 | } |
1068 | |
1069 | auto cvt = [&](Value* a, Value* scale, Value* shift){ |
1070 | return int32_to_float16x4(a, scale, shift); |
1071 | }; |
1072 | |
1073 | for(size_t j = 0; j < op_idxs.size(); j++){ |
1074 | size_t i = j * 4; |
1075 | std::tie(vals_[x][x_idxs[i+0]], |
1076 | vals_[x][x_idxs[i+1]], |
1077 | vals_[x][x_idxs[i+2]], |
1078 | vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); |
1079 | } |
1080 | } else { |
1081 | throw std::runtime_error("unsupported dequantization" ); |
1082 | } |
1083 | return; |
1084 | } |
1085 | |
1086 | /** |
1087 | * \brief Code Generation for `return` |
1088 | */ |
1089 | void generator::visit_return_inst(ir::return_inst* rr) { |
1090 | ir::value *ret_val = rr->get_return_value(); |
1091 | ret(ret_val ? vals_[ret_val][{}] : nullptr); |
1092 | } |
1093 | |
1094 | /** |
1095 | * \brief Code Generation for `cond_branch` |
1096 | */ |
1097 | void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) { |
1098 | BasicBlock *true_dest = bbs_.at(br->get_true_dest()); |
1099 | BasicBlock *false_dest = bbs_.at(br->get_false_dest()); |
1100 | Value *cond = vals_[br->get_cond()][{}]; |
1101 | cond_br(cond, true_dest, false_dest); |
1102 | } |
1103 | |
1104 | /** |
1105 | * \brief Code Generation for `uncond_branch` |
1106 | */ |
1107 | void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { |
1108 | BasicBlock *dest = bbs_.at(br->get_dest()); |
1109 | br(dest); |
1110 | } |
1111 | |
1112 | /** |
1113 | * \brief Code Generation for a (synchronous) `load` |
1114 | */ |
1115 | void generator::visit_load_inst(ir::load_inst* x){ |
1116 | BasicBlock *current = builder_->GetInsertBlock(); |
1117 | Module *module = current->getModule(); |
1118 | Value *tid = tgt_->get_local_id(module, *builder_, 0); |
1119 | Value *lane = urem(tid, i32(32)); |
1120 | ir::value *op = x->get_pointer_operand(); |
1121 | ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x); |
1122 | Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); |
1123 | // compute vector width |
1124 | size_t vec = 1; |
1125 | bool is_mma_first_row = false; |
1126 | if(op->get_type()->is_block_ty()){ |
1127 | auto ord = ords_.at(op); |
1128 | size_t aln = alignment_->get(op, ord[0]); |
1129 | if(mx){ |
1130 | size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; |
1131 | max_eq = std::max<size_t>(max_eq, 1); |
1132 | aln = std::min(aln, max_eq); |
1133 | } |
1134 | analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x)); |
1135 | assert(layout); |
1136 | |
1137 | vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln); |
1138 | // TODO: generalize |
1139 | is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && |
1140 | (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1)); |
1141 | if(is_mma_first_row) |
1142 | vec = std::min<size_t>(2, aln); |
1143 | } |
1144 | // code generation |
1145 | auto idxs = idxs_.at(x); |
1146 | for(size_t i = 0; i < idxs.size(); i += vec){ |
1147 | indices_t idx = idxs[i]; |
1148 | // pointer value |
1149 | Value *ptr = vals_[op][idx]; |
1150 | // masked load |
1151 | size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; |
1152 | // input ptr info |
1153 | GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr); |
1154 | size_t in_off; |
1155 | if(in_gep){ |
1156 | ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin()); |
1157 | in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; |
1158 | ptr = cst ? in_gep->getPointerOperand() : in_gep; |
1159 | } |
1160 | else{ |
1161 | in_off = 0; |
1162 | } |
1163 | Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); |
1164 | // if(!op->get_type()->is_block_ty()){ |
1165 | // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); |
1166 | // } |
1167 | Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; |
1168 | size_t nbits = dtsize*8; |
1169 | // pack sub-words (< 32/64bits) into words |
1170 | // each load has width min(nbits*vec, 32/64) |
1171 | // and there are (nbits * vec)/width of them |
1172 | int max_word_width = std::max<int>(32, nbits); |
1173 | int tot_width = nbits*vec; |
1174 | int width = std::min(tot_width, max_word_width); |
1175 | int n_words = std::max(1, tot_width / width); |
1176 | bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; |
1177 | has_l2_evict_policy = false; |
1178 | // has_evict_policy = false; // currently disable until supported in `store` |
1179 | // ----- |
1180 | // create inline asm string |
1181 | // ----- |
1182 | std::ostringstream asm_oss; |
1183 | asm_oss << "@$" << n_words; // predicate |
1184 | asm_oss << " ld" ; |
1185 | if(x->get_is_volatile()) |
1186 | asm_oss << ".volatile" ; |
1187 | asm_oss << ".global" ; |
1188 | if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca" ; |
1189 | if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg" ; |
1190 | if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first" ; |
1191 | if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last" ; |
1192 | if (has_l2_evict_policy) asm_oss << ".L2::cache_hint" ; |
1193 | if(n_words > 1) |
1194 | asm_oss << ".v" << n_words; // vector width |
1195 | asm_oss << ".b" << width; // word size |
1196 | asm_oss << " {" ; |
1197 | for(int i = 0; i < n_words; i++){ // return values |
1198 | if(i > 0) asm_oss << "," ; |
1199 | asm_oss << "$" << i; |
1200 | } |
1201 | asm_oss << "}" ; |
1202 | asm_oss << ", [ $" << n_words + 1; // load |
1203 | asm_oss << " + " << in_off << "]" ; // constant offset |
1204 | if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; |
1205 | asm_oss << ";" ; |
1206 | bool has_other = other && (other != UndefValue::get(other->getType())); |
1207 | std::vector<Value *> others; |
1208 | // handle `other` values for indices where the mask |
1209 | // is false |
1210 | if(has_other) |
1211 | for(size_t ii = 0; ii < n_words; ii++){ |
1212 | size_t size = width / nbits; |
1213 | Value *v = UndefValue::get(vec_ty(ty, size)); |
1214 | for(size_t s = 0; s < size; s++){ |
1215 | ir::value *false_val = mx->get_false_value_operand(); |
1216 | v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s); |
1217 | } |
1218 | v = bit_cast(v, IntegerType::get(*ctx_, width)); |
1219 | // PTX doesn't support mov.u8, so we need to use mov.u16 |
1220 | auto mov_width = width < 16 ? 16 : width; |
1221 | asm_oss << "\n " ; |
1222 | asm_oss << "@!$" << n_words << " mov.u" << mov_width; |
1223 | asm_oss << " $" << ii << ", " ; |
1224 | std::ios_base::fmtflags flags(asm_oss.flags()); |
1225 | if(ConstantInt* cst = dyn_cast<ConstantInt>(v)) |
1226 | asm_oss << "0x" << std::hex << cst->getSExtValue(); |
1227 | else{ |
1228 | asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii; |
1229 | others.push_back(v); |
1230 | } |
1231 | asm_oss.flags(flags); |
1232 | asm_oss << ";" ; |
1233 | } |
1234 | // ---- |
1235 | // create inline ASM signature |
1236 | // --- |
1237 | std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width)); |
1238 | Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0]; |
1239 | // ret_ty->print(llvm::outs()); |
1240 | std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()}; |
1241 | for(Value *v: others) |
1242 | arg_tys.push_back(v->getType()); |
1243 | if (has_l2_evict_policy) |
1244 | arg_tys.push_back(i64_ty); |
1245 | FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); |
1246 | // --- |
1247 | // create inline ASM constraints |
1248 | // --- |
1249 | std::string asm_cstrt; |
1250 | for(int ii = 0; ii < n_words; ii++){ |
1251 | if(ii > 0) asm_cstrt += "," ; |
1252 | asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c" ); |
1253 | } |
1254 | asm_cstrt += ",b,l" ; |
1255 | for(size_t ii = 0; ii < others.size(); ii++){ |
1256 | asm_cstrt += "," ; |
1257 | asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c" ); |
1258 | } |
1259 | if (has_l2_evict_policy) |
1260 | asm_cstrt += ",l" ; |
1261 | // --- |
1262 | // finally call inline ASM |
1263 | // --- |
1264 | InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); |
1265 | std::vector<Value*> args = {pred, ptr}; |
1266 | for(Value *v: others) |
1267 | args.push_back(v); |
1268 | if (has_l2_evict_policy) |
1269 | args.push_back(policies_.at(x->get_eviction_policy())); |
1270 | |
1271 | |
1272 | Value *_ret = call(inlineAsm, args); |
1273 | // if(!op->get_type()->is_block_ty()){ |
1274 | // Value* cond = icmp_eq(tid, i32(0)); |
1275 | // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); |
1276 | // Instruction* bar = add_barrier(); |
1277 | // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); |
1278 | // builder_->SetInsertPoint(term); |
1279 | // store(_ret, shptr); |
1280 | // builder_->SetInsertPoint(bar->getParent()); |
1281 | // _ret = load(shptr); |
1282 | // add_barrier(); |
1283 | // } |
1284 | |
1285 | // --- |
1286 | // extract and store return values |
1287 | // --- |
1288 | std::vector<Value *> rets; |
1289 | for(unsigned int ii = 0; ii < n_words; ii++){ |
1290 | Value *curr; |
1291 | if(ret_ty->isStructTy()) |
1292 | curr = extract_val(_ret, {ii}); |
1293 | else |
1294 | curr = _ret; |
1295 | rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8)))); |
1296 | } |
1297 | int tmp = (width / (dtsize * 8)); |
1298 | for(size_t ii = 0; ii < vec; ii++) |
1299 | vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp); |
1300 | } |
1301 | } |
1302 | |
1303 | void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { |
1304 | visit_load_inst(x); |
1305 | } |
1306 | void generator::visit_masked_load_inst(ir::masked_load_inst* x) { |
1307 | visit_load_inst(x); |
1308 | } |
1309 | |
1310 | /** |
1311 | * \brief Code Generation for a (synchronous) `store` |
1312 | */ |
1313 | |
1314 | void generator::visit_store_inst(ir::store_inst * x){ |
1315 | ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x); |
1316 | // operands |
1317 | ir::value *ptr_op = x->get_pointer_operand(); |
1318 | ir::value *val_op = x->get_value_operand(); |
1319 | ir::value *msk_op = nullptr; |
1320 | if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x)) |
1321 | msk_op = msk_st->get_mask_operand(); |
1322 | // vector size |
1323 | size_t vec = 1; |
1324 | if(val_op->get_type()->is_block_ty()){ |
1325 | auto ord = ords_.at(x->get_pointer_operand()); |
1326 | size_t aln = alignment_->get(ptr_op, ord[0]); |
1327 | size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; |
1328 | if(mx){ |
1329 | size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; |
1330 | max_eq = std::max<size_t>(max_eq, 1); |
1331 | aln = std::min(aln, max_eq); |
1332 | } |
1333 | analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op)); |
1334 | assert(layout); |
1335 | // vec = std::min(nts, aln); |
1336 | vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln); |
1337 | // TODO: generalize |
1338 | bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && |
1339 | (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1)); |
1340 | if(is_mma_first_row) |
1341 | vec = std::min<size_t>(2, aln); |
1342 | } |
1343 | bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; |
1344 | has_l2_evict_policy = false; |
1345 | auto idxs = idxs_.at(val_op); |
1346 | Type *ty = cvt(val_op->get_type()->get_scalar_ty()); |
1347 | if(ty->isIntegerTy(1)) |
1348 | ty = builder_->getInt8Ty(); |
1349 | for(size_t i = 0; i < idxs.size(); i += vec){ |
1350 | indices_t idx = idxs[i]; |
1351 | // pointers |
1352 | Value *ptr = vals_[ptr_op][idx]; |
1353 | size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8); |
1354 | GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr); |
1355 | size_t in_off; |
1356 | if(in_gep){ |
1357 | ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin()); |
1358 | in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; |
1359 | ptr = cst ? in_gep->getPointerOperand() : in_gep; |
1360 | } |
1361 | else{ |
1362 | in_off = 0; |
1363 | } |
1364 | // mask |
1365 | Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue(); |
1366 | size_t nbits = dtsize*8; |
1367 | // pack sub-words (< 32/64bits) into words |
1368 | // each load has width min(nbits*vec, 32/64) |
1369 | // and there are (nbits * vec)/width of them |
1370 | int max_word_width = std::max<int>(32, nbits); |
1371 | int tot_width = nbits*vec; |
1372 | int width = std::min(tot_width, max_word_width); |
1373 | int n_words = std::max(1, tot_width / width); |
1374 | // ----- |
1375 | // create inline asm string |
1376 | // ----- |
1377 | std::ostringstream asm_oss; |
1378 | asm_oss << "@$0" ; // predicate |
1379 | asm_oss << " st.global" ; |
1380 | if (has_l2_evict_policy) asm_oss << ".L2::cache_hint" ; |
1381 | if(n_words > 1) |
1382 | asm_oss << ".v" << n_words; // vector width |
1383 | asm_oss << ".b" << width; // word size |
1384 | asm_oss << " [ $1 + " << in_off << "]" ; |
1385 | asm_oss << " , {" ; |
1386 | for(int i = 0; i < n_words; i++){ // return values |
1387 | if(i > 0) asm_oss << "," ; |
1388 | asm_oss << "$" << 2 + i; |
1389 | } |
1390 | asm_oss << "}" ; |
1391 | if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; |
1392 | asm_oss << ";" ; |
1393 | // ---- |
1394 | // create inline ASM signature |
1395 | // --- |
1396 | Type* val_arg_ty = IntegerType::get(*ctx_, width); |
1397 | std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()}; |
1398 | for(int ii = 0; ii < n_words; ii++) |
1399 | arg_tys.push_back(val_arg_ty); |
1400 | if (has_l2_evict_policy) |
1401 | arg_tys.push_back(i64_ty); |
1402 | FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false); |
1403 | // --- |
1404 | // create inline ASM constraints |
1405 | // --- |
1406 | std::string asm_cstrt = "b,l" ; |
1407 | for(int ii = 0; ii < n_words; ii++){ |
1408 | asm_cstrt += "," ; |
1409 | asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c" ); |
1410 | } |
1411 | if (has_l2_evict_policy) |
1412 | asm_cstrt += ",l" ; |
1413 | // --- |
1414 | // finally call inline ASM |
1415 | // --- |
1416 | InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); |
1417 | std::vector<Value*> args = {pred, ptr}; |
1418 | for(unsigned int ii = 0; ii < n_words; ii++){ |
1419 | size_t n_subw = width / nbits; |
1420 | Value* curr = UndefValue::get(vec_ty(ty, n_subw)); |
1421 | for(unsigned int jj = 0; jj < n_subw; jj++){ |
1422 | Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]]; |
1423 | if(new_elt->getType()->isIntegerTy(1)) |
1424 | new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty()); |
1425 | new_elt = bit_cast(new_elt, ty); |
1426 | curr = builder_->CreateInsertElement(curr, new_elt, jj); |
1427 | } |
1428 | args.push_back(bit_cast(curr, val_arg_ty)); |
1429 | } |
1430 | if (has_l2_evict_policy) |
1431 | args.push_back(policies_.at(x->get_eviction_policy())); |
1432 | call(_asm, args); |
1433 | } |
1434 | } |
1435 | void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { |
1436 | visit_store_inst(x); |
1437 | } |
1438 | void generator::visit_masked_store_inst(ir::masked_store_inst* x) { |
1439 | visit_store_inst(x); |
1440 | } |
1441 | |
1442 | // -- |
1443 | |
1444 | void generator::(ir::extract_value_inst *x) { |
1445 | auto idxs = idxs_.at(x); |
1446 | ir::value* agg = x->get_operand(0); |
1447 | unsigned insert_idx = x->get_idx(); |
1448 | for(size_t i = 0; i < idxs.size(); i++){ |
1449 | auto idx = idxs[i]; |
1450 | vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx}); |
1451 | } |
1452 | } |
1453 | |
1454 | |
1455 | void generator::visit_insert_value_inst(ir::insert_value_inst *x){ |
1456 | auto idxs = idxs_.at(x); |
1457 | ir::value* agg = x->get_operand(0); |
1458 | ir::value* val = x->get_operand(1); |
1459 | unsigned insert_idx = x->get_idx(); |
1460 | for(size_t i = 0; i < idxs.size(); i++){ |
1461 | auto idx = idxs[i]; |
1462 | vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx}); |
1463 | } |
1464 | } |
1465 | |
1466 | // -- |
1467 | /** |
1468 | * \brief Code Generation for `cat` |
1469 | */ |
1470 | void generator::visit_cat_inst(ir::cat_inst* x) { |
1471 | auto idxs = idxs_.at(x); |
1472 | ir::value* lhs = x->get_operand(0); |
1473 | ir::value* rhs = x->get_operand(1); |
1474 | int i = 0; |
1475 | for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){ |
1476 | vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; |
1477 | } |
1478 | for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ |
1479 | vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; |
1480 | } |
1481 | } |
1482 | |
1483 | |
1484 | |
1485 | /** |
1486 | * \brief Code Generation for `reshape` |
1487 | */ |
1488 | void generator::visit_reshape_inst(ir::reshape_inst* x) { |
1489 | auto idxs = idxs_.at(x); |
1490 | for(size_t i = 0; i < idxs_.at(x).size(); i ++){ |
1491 | ir::value* op = x->get_operand(0); |
1492 | vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; |
1493 | }; |
1494 | } |
1495 | |
1496 | /** |
1497 | * \brief Code Generation for `splat` |
1498 | */ |
1499 | void generator::visit_splat_inst(ir::splat_inst* x) { |
1500 | for(auto idx: idxs_.at(x)) |
1501 | vals_[x][idx] = vals_[x->get_operand(0)][{}]; |
1502 | } |
1503 | |
1504 | /** |
1505 | * \brief Code Generation for `broadcast` |
1506 | */ |
1507 | void generator::visit_broadcast_inst(ir::broadcast_inst* x) { |
1508 | ir::value* op = x->get_operand(0); |
1509 | const auto& shape = op->get_type()->get_block_shapes(); |
1510 | for(auto out_idx: idxs_.at(x)){ |
1511 | indices_t in_idx = out_idx; |
1512 | for(size_t k = 0; k < in_idx.size(); k++) |
1513 | in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k]; |
1514 | vals_[x][out_idx] = vals_[op][in_idx]; |
1515 | } |
1516 | // for(size_t i = 0; i < idxs_.at(x).size(); i++) |
1517 | // vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; |
1518 | } |
1519 | |
1520 | /** |
1521 | * \brief Code Generation for `downcast` |
1522 | */ |
1523 | void generator::visit_downcast_inst(ir::downcast_inst* x) { |
1524 | vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}]; |
1525 | } |
1526 | |
1527 | /** |
1528 | * \brief Code Generation for `get_program_id` |
1529 | */ |
1530 | void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) { |
1531 | Module *module = builder_->GetInsertBlock()->getModule(); |
1532 | Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis()); |
1533 | vals_[pid][{}] = ret; |
1534 | } |
1535 | |
1536 | /** |
1537 | * \brief Code Generation for `get_num_programs` |
1538 | */ |
1539 | void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) { |
1540 | Module *module = builder_->GetInsertBlock()->getModule(); |
1541 | Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis()); |
1542 | vals_[np][{}] = ret; |
1543 | } |
1544 | |
1545 | /** |
1546 | * \brief Code Generation for `exp` |
1547 | */ |
1548 | void generator::visit_exp_inst(ir::exp_inst* x){ |
1549 | Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); |
1550 | std::vector<llvm::Type*> tys = {f32_ty}; |
1551 | FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); |
1552 | InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;" , "=f,0" , false); |
1553 | for(auto idx: idxs_.at(x)){ |
1554 | Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); |
1555 | // Value *ex2arg = vals_[x->get_operand(0)][idx]; |
1556 | vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg}); |
1557 | } |
1558 | } |
1559 | |
1560 | /** |
1561 | * \brief Code Generation for `cos` |
1562 | */ |
1563 | void generator::visit_cos_inst(ir::cos_inst* x){ |
1564 | std::vector<llvm::Type*> tys = {f32_ty}; |
1565 | FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); |
1566 | InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;" , "=f,0" , false); |
1567 | for(auto idx: idxs_.at(x)){ |
1568 | vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]}); |
1569 | } |
1570 | } |
1571 | |
1572 | /** |
1573 | * \brief Code Generation for `umulhi` |
1574 | */ |
1575 | void generator::visit_umulhi_inst(ir::umulhi_inst* x){ |
1576 | std::vector<llvm::Type*> tys = {i32_ty, i32_ty}; |
1577 | FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false); |
1578 | InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;" , "=r,r,r" , false); |
1579 | for(auto idx: idxs_.at(x)){ |
1580 | Value* lhs = vals_[x->get_operand(0)][idx]; |
1581 | Value* rhs = vals_[x->get_operand(1)][idx]; |
1582 | vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs}); |
1583 | } |
1584 | } |
1585 | |
1586 | /** |
1587 | * \brief Code Generation for `sin` |
1588 | */ |
1589 | void generator::visit_sin_inst(ir::sin_inst* x){ |
1590 | std::vector<llvm::Type*> tys = {f32_ty}; |
1591 | FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); |
1592 | InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;" , "=f,0" , false); |
1593 | for(auto idx: idxs_.at(x)){ |
1594 | vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]}); |
1595 | } |
1596 | } |
1597 | |
1598 | /** |
1599 | * \brief Code Generation for `log` |
1600 | */ |
1601 | void generator::visit_log_inst(ir::log_inst* x){ |
1602 | Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453); |
1603 | std::vector<llvm::Type*> tys = {f32_ty}; |
1604 | FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); |
1605 | InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;" , "=f,f" , false); |
1606 | for(auto idx: idxs_.at(x)){ |
1607 | Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]}); |
1608 | vals_[x][idx] = fmul(lg2arg, rcplog2e); |
1609 | } |
1610 | } |
1611 | |
1612 | /** |
1613 | * \brief Code Generation for `atomic_cas` |
1614 | */ |
1615 | void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { |
1616 | BasicBlock *current = builder_->GetInsertBlock(); |
1617 | Module *module = current->getModule(); |
1618 | Value *tid = tgt_->get_local_id(module, *builder_, 0); |
1619 | Value *pred = icmp_eq(tid, i32(0)); |
1620 | // BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); |
1621 | // BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); |
1622 | add_barrier(); |
1623 | tgt_->add_memfence(module, *builder_); |
1624 | Value *atom_ptr; |
1625 | atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "" ); |
1626 | atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3)); |
1627 | // cond_br(pred, tid_0_bb, tid_0_done_bb); |
1628 | // builder_->SetInsertPoint(tid_0_bb); |
1629 | Value *cas_ptr = vals_[cas->get_operand(0)][{}]; |
1630 | Value *cas_cmp = vals_[cas->get_operand(1)][{}]; |
1631 | Value *cas_val = vals_[cas->get_operand(2)][{}]; |
1632 | std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;" ; |
1633 | FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false); |
1634 | InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r" , true); |
1635 | add_barrier(); |
1636 | Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val}); |
1637 | add_barrier(); |
1638 | |
1639 | std::string asm2_str = "@$0 st.shared.b32 [$1], $2;" ; |
1640 | FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false); |
1641 | InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r" , true); |
1642 | add_barrier(); |
1643 | call(iasm2, {pred, atom_ptr, old}); |
1644 | tgt_->add_memfence(module, *builder_); |
1645 | add_barrier(); |
1646 | vals_[cas][{}] = load(atom_ptr); |
1647 | add_barrier(); |
1648 | } |
1649 | |
1650 | /** |
1651 | * \brief Code Generation for `atomic_rmw` |
1652 | */ |
1653 | void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { |
1654 | ir::value* ptr = atom->get_operand(0); |
1655 | ir::value* val = atom->get_operand(1); |
1656 | ir::value* msk = atom->get_operand(2); |
1657 | |
1658 | // vector size |
1659 | int vec = 1; |
1660 | Value *mask = builder_->getInt1(true); |
1661 | if(atom->get_type()->is_block_ty()){ |
1662 | auto shape = atom->get_type()->get_block_shapes(); |
1663 | int ld = ords_.at(ptr)[0]; |
1664 | unsigned alignment = alignment_->get(ptr, ld); |
1665 | vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment); |
1666 | vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1); |
1667 | // mask out inactive threads |
1668 | analysis::data_layout* layout = layouts_->get(val); |
1669 | auto curr_axes = a_axes_->get(val); |
1670 | auto layt_axes = layout->get_axes(); |
1671 | for(unsigned k = 0; k < layt_axes.size(); k++){ |
1672 | unsigned ax = layt_axes.at(k); |
1673 | distributed_axis dax = axes_.at(ax); |
1674 | // axis is part of the original layout: thread id should be 0 |
1675 | // but not the current layout |
1676 | if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end()) |
1677 | mask = and_(mask, icmp_eq(dax.thread_id, i32(0))); |
1678 | } |
1679 | // last axis may spillover |
1680 | Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0); |
1681 | int per_thread = 1; |
1682 | for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; } |
1683 | int numel = 1; |
1684 | for(int s: layout->get_shape()) { numel *= s; } |
1685 | mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel))); |
1686 | } |
1687 | |
1688 | |
1689 | for(int i = 0; i < idxs_.at(val).size(); i += vec){ |
1690 | auto idx = idxs_[val][i]; |
1691 | Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec)); |
1692 | for(int ii = 0; ii < vec; ii++) |
1693 | rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii); |
1694 | Value *rmw_ptr = vals_[ptr][idx]; |
1695 | Value *rmw_msk = vals_[msk][idx]; |
1696 | rmw_msk = and_(rmw_msk, mask); |
1697 | if(vec == 1) |
1698 | rmw_val = extract_elt(rmw_val, i32(0)); |
1699 | Type* ty = rmw_val->getType(); |
1700 | size_t nbits = ty->getScalarSizeInBits(); |
1701 | // extract pointer offset |
1702 | std::string offset = "" ; |
1703 | if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr)) |
1704 | if(gep->getNumIndices() == 1) |
1705 | if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){ |
1706 | offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8); |
1707 | rmw_ptr = gep->getPointerOperand(); |
1708 | } |
1709 | rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1)); |
1710 | // asm argument type |
1711 | std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; |
1712 | // asm function type |
1713 | FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); |
1714 | // asm string |
1715 | std::string s_nbits = std::to_string(nbits); |
1716 | std::string name; |
1717 | std::string s_ty; |
1718 | using tt = ir::atomic_rmw_op_t; |
1719 | switch(atom->get_op()){ |
1720 | case tt::Or: name = "or" ; s_ty = "b" ; break; |
1721 | case tt::And: name = "and" ; s_ty = "b" ; break; |
1722 | case tt::Xor: name = "xor" , s_ty = "b" ; break; |
1723 | case tt::Add: name = "add" , s_ty = "s" ; break; |
1724 | case tt::Min: name = "min" , s_ty = "s" ; break; |
1725 | case tt::Max: name = "max" , s_ty = "s" ; break; |
1726 | case tt::UMin: name = "min" , s_ty = "u" ; break; |
1727 | case tt::UMax: name = "max" , s_ty = "u" ; break; |
1728 | case tt::FAdd: name = "add" , s_ty = "f" ; break; |
1729 | case tt::Xchg: name = "exch" , s_ty = "b" ; break; |
1730 | } |
1731 | std::string s_vec = vec == 2 ? "x2" : "" ; |
1732 | std::string mod = nbits == 16 ? ".noftz" : "" ; |
1733 | |
1734 | std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;" ; |
1735 | std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h" ); |
1736 | std::string constraint = "=" + ty_id + ",b,l," + ty_id; |
1737 | // create inline asm |
1738 | InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); |
1739 | // call asm |
1740 | if(atom->get_type()->is_block_ty()) |
1741 | vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val})); |
1742 | else{ |
1743 | Module *mod = builder_->GetInsertBlock()->getModule(); |
1744 | tgt_->add_memfence(mod, *builder_); |
1745 | add_barrier(); |
1746 | Value *tid = tgt_->get_local_id(mod, *builder_, 0); |
1747 | rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0))); |
1748 | Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val})); |
1749 | Value *atom_ptr; |
1750 | atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "" ); |
1751 | atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); |
1752 | store(old, atom_ptr); |
1753 | add_barrier(); |
1754 | vals_[atom][idx] = load(atom_ptr); |
1755 | add_barrier(); |
1756 | } |
1757 | } |
1758 | } |
1759 | |
1760 | /** |
1761 | * \brief Code Generation for `mma.884` (V100) |
1762 | */ |
1763 | //TODO: clean-up |
1764 | void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { |
1765 | // shapes |
1766 | auto shape_c = C->get_type()->get_block_shapes(); |
1767 | auto shape_a = A->get_type()->get_block_shapes(); |
1768 | auto shape_b = B->get_type()->get_block_shapes(); |
1769 | // order |
1770 | auto ord_a = layouts_->get(A)->get_order(); |
1771 | auto ord_b = layouts_->get(B)->get_order(); |
1772 | bool is_a_trans = C->is_trans_a(); |
1773 | // is_a_trans = false; |
1774 | if(C->is_trans_a()){ |
1775 | std::swap(ord_a[0], ord_a[1]); |
1776 | std::swap(shape_a[0], shape_a[1]); |
1777 | std::swap(offset_a_m_, offset_a_k_); |
1778 | } |
1779 | // std::cout << "visiting" << std::endl; |
1780 | // if(C->is_trans_b()){ |
1781 | // std::swap(ord_b[0], ord_b[1]); |
1782 | // std::swap(shape_b[0], shape_b[1]); |
1783 | // } |
1784 | // layouts |
1785 | analysis::mma_layout* layout_c = layouts_->get(C)->to_mma(); |
1786 | analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); |
1787 | analysis::shared_layout* layout_b = layouts_->get(B)->to_shared(); |
1788 | // vectorization |
1789 | int vec_a = swizzle_->get_vec(layout_a); |
1790 | int vec_b = swizzle_->get_vec(layout_b); |
1791 | // strides |
1792 | bool is_a_row = ord_a[0] != 0; |
1793 | bool is_b_row = ord_b[0] != 0; |
1794 | int stride_am = is_a_row ? shape_a[1] : 1; |
1795 | int stride_ak = is_a_row ? 1 : shape_a[0]; |
1796 | int stride_a0 = is_a_row ? stride_ak : stride_am; |
1797 | int stride_a1 = is_a_row ? stride_am : stride_ak; |
1798 | int stride_bn = is_b_row ? 1 : shape_b[0]; |
1799 | int stride_bk = is_b_row ? shape_b[1] : 1; |
1800 | int stride_b0 = is_b_row ? stride_bn : stride_bk; |
1801 | int stride_b1 = is_b_row ? stride_bk : stride_bn; |
1802 | int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8; |
1803 | int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8; |
1804 | int stride_rep_k = 1; |
1805 | // swizzling |
1806 | int per_phase_a = swizzle_->get_per_phase(layout_a); |
1807 | int max_phase_a = swizzle_->get_max_phase(layout_a); |
1808 | int step_a0 = is_a_row ? stride_rep_k : stride_rep_m; |
1809 | int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1); |
1810 | int per_phase_b = swizzle_->get_per_phase(layout_b); |
1811 | int max_phase_b = swizzle_->get_max_phase(layout_b); |
1812 | int step_b0 = is_b_row ? stride_rep_n : stride_rep_k; |
1813 | int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); |
1814 | |
1815 | |
1816 | // max_phase_a = 4; |
1817 | // vec_a = 8; |
1818 | // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl; |
1819 | // std::cout << vec_a << " " << vec_b << std::endl; |
1820 | |
1821 | /* --------------------------------- */ |
1822 | /* --- pre-compute pointer lanes --- */ |
1823 | /* --------------------------------- */ |
1824 | BasicBlock* curr_bb = builder_->GetInsertBlock(); |
1825 | BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); |
1826 | if(entry != curr_bb) |
1827 | builder_->SetInsertPoint(entry->getTerminator()); |
1828 | Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; |
1829 | Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; |
1830 | Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a)); |
1831 | std::vector<Value*> off_a(num_ptr_a); |
1832 | for(int i = 0; i < num_ptr_a; i++){ |
1833 | Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m))); |
1834 | off_a0i = exact_udiv(off_a0i, i32(vec_a)); |
1835 | off_a0i = xor_(off_a0i, phase_a); |
1836 | off_a0i = mul(off_a0i, i32(vec_a)); |
1837 | off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); |
1838 | } |
1839 | Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c]; |
1840 | Value* off_b1 = is_b_row ? offset_b_k_[layout_c] : offset_b_n_[layout_c]; |
1841 | Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b)); |
1842 | std::vector<Value*> off_b(num_ptr_b); |
1843 | for(int i = 0; i < num_ptr_b; i++){ |
1844 | Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4))); |
1845 | off_b0i = udiv(off_b0i, i32(vec_b)); |
1846 | off_b0i = xor_(off_b0i, phase_b); |
1847 | off_b0i = mul(off_b0i, i32(vec_b)); |
1848 | off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); |
1849 | } |
1850 | builder_->SetInsertPoint(curr_bb); |
1851 | |
1852 | /* --------------------------------- */ |
1853 | /* --- MMA intrinsic --- */ |
1854 | /* --------------------------------- */ |
1855 | Type *f16x2_ty = vec_ty(f16_ty, 2); |
1856 | Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}); |
1857 | std::vector<Type*> arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty, |
1858 | f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}; |
1859 | InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false), |
1860 | " mma.sync.aligned.m8n8k4." |
1861 | + std::string(is_a_row ? "row" : "col" ) |
1862 | + "." |
1863 | + std::string(is_b_row ? "row" : "col" ) |
1864 | + ".f32.f16.f16.f32 " |
1865 | "{$0, $1, $2, $3, $4, $5, $6, $7}, " |
1866 | "{$8, $9}, " |
1867 | "{$10, $11}, " |
1868 | "{$0, $1, $2, $3, $4, $5, $6, $7};" , "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7" , false); |
1869 | |
1870 | |
1871 | std::vector<Value*> ptr_a(num_ptr_a); |
1872 | std::vector<Value*> ptr_b(num_ptr_b); |
1873 | std::map<std::pair<int, int>, std::pair<Value*, Value*>> has, hbs; |
1874 | for(int i = 0; i < num_ptr_a; i++) |
1875 | ptr_a[i] = gep(shmems_[A], off_a[i]); |
1876 | for(int i = 0; i < num_ptr_b; i++) |
1877 | ptr_b[i] = gep(shmems_[B], off_b[i]); |
1878 | |
1879 | |
1880 | // initialize accumulators |
1881 | std::vector<Value*> acc; |
1882 | for(indices_t idx: idxs_.at(C)) |
1883 | acc.push_back(vals_[D][idx]); |
1884 | |
1885 | unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0); |
1886 | unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1); |
1887 | |
1888 | // create mma & unpack result |
1889 | auto call_mma = [&](unsigned m, unsigned n, unsigned K) { |
1890 | auto ha = has[{m, K}]; |
1891 | auto hb = hbs[{n, K}]; |
1892 | // arguments |
1893 | std::vector<size_t> idx = { |
1894 | (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, |
1895 | (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, |
1896 | (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, |
1897 | (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m |
1898 | }; |
1899 | std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second}; |
1900 | for(unsigned i = 0; i < 8; i++) |
1901 | args.push_back(acc[idx[i]]); |
1902 | // execute mma |
1903 | Value *nc = call(mma, args); |
1904 | // unpack |
1905 | for(unsigned i = 0; i < 8; i++) |
1906 | acc[idx[i]] = extract_val(nc, {i}); |
1907 | }; |
1908 | |
1909 | ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A); |
1910 | ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B); |
1911 | |
1912 | // Cache lds value. If values are prefetched, create phi node |
1913 | // @param inc: incoming block (0 = header, 1 = loop) |
1914 | auto register_lds = |
1915 | [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { |
1916 | if (K == 0 && is_prefetch) { |
1917 | ir::basic_block* inc_block = phiA->get_incoming_block(inc); |
1918 | lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); |
1919 | lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); |
1920 | } else |
1921 | vals[{m, K}] = {val0, val1}; |
1922 | }; |
1923 | |
1924 | auto load_a = [&](int m, int K, int inc, bool is_prefetch) { |
1925 | int offidx = (is_a_row ? K/4 : m) % num_ptr_a; |
1926 | Value* ptra; |
1927 | if(K==0 && is_prefetch){ |
1928 | if(inc == 0) |
1929 | ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); |
1930 | else |
1931 | ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); |
1932 | } |
1933 | else |
1934 | ptra = ptr_a[offidx]; |
1935 | int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); |
1936 | int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; |
1937 | Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); |
1938 | Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); |
1939 | // record lds that needs to be moved |
1940 | if (K == 0 && inc == 1 && is_prefetch) |
1941 | prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha); |
1942 | Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); |
1943 | Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); |
1944 | register_lds(has, m, K, inc, ha00, ha01, is_prefetch); |
1945 | if(vec_a > 4){ |
1946 | Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); |
1947 | Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); |
1948 | if(is_a_row) |
1949 | register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch); |
1950 | else |
1951 | register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch); |
1952 | } |
1953 | }; |
1954 | |
1955 | auto load_b = [&](int n, int K, int inc, bool is_prefetch) { |
1956 | int offidx = (is_b_row? n : K/4) % num_ptr_b; |
1957 | Value* ptrb; |
1958 | if(K==0 && is_prefetch){ |
1959 | if(inc == 0) |
1960 | ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); |
1961 | else |
1962 | ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); |
1963 | } else |
1964 | ptrb = ptr_b[offidx]; |
1965 | |
1966 | int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; |
1967 | int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); |
1968 | Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); |
1969 | Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); |
1970 | // record lds that needs to be moved |
1971 | if (K == 0 && inc == 1 && is_prefetch) |
1972 | prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb); |
1973 | Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); |
1974 | Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); |
1975 | register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch); |
1976 | if(vec_b > 4){ |
1977 | Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); |
1978 | Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); |
1979 | if(is_b_row) |
1980 | register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch); |
1981 | else |
1982 | register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch); |
1983 | } |
1984 | |
1985 | }; |
1986 | |
1987 | // update accumulators |
1988 | if (C->is_prefetched()) { |
1989 | // create phis |
1990 | builder_->SetInsertPoint(curr_bb->getFirstNonPHI()); |
1991 | for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) { |
1992 | has[{m, 0}].first = phi(f16x2_ty, 2); |
1993 | has[{m, 0}].second = phi(f16x2_ty, 2); |
1994 | if (!is_a_row && vec_a>4) { |
1995 | has[{m+1, 0}].first = phi(f16x2_ty, 2); |
1996 | has[{m+1, 0}].second = phi(f16x2_ty, 2); |
1997 | } |
1998 | } |
1999 | for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) { |
2000 | hbs[{n, 0}].first = phi(f16x2_ty, 2); |
2001 | hbs[{n, 0}].second = phi(f16x2_ty, 2); |
2002 | if (is_b_row && vec_b>4) { |
2003 | hbs[{n+1, 0}].first = phi(f16x2_ty, 2); |
2004 | hbs[{n+1, 0}].second = phi(f16x2_ty, 2); |
2005 | } |
2006 | } |
2007 | |
2008 | // insert prefetched lds at the end of loop header |
2009 | builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); |
2010 | for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) |
2011 | load_a(m, 0, 0, true); |
2012 | for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) |
2013 | load_b(n, 0, 0, true); |
2014 | |
2015 | // update accumulators |
2016 | builder_->SetInsertPoint(curr_bb); |
2017 | for (unsigned K = 0; K < NK; K += 4) { |
2018 | int NEXTK = (K + 4) % NK; |
2019 | // prefetch A |
2020 | for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2) |
2021 | load_a(m, NEXTK, 1, true); |
2022 | // prefetch B |
2023 | for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1) |
2024 | load_b(n, NEXTK, 1, true); |
2025 | // tensor core ops |
2026 | for(unsigned m = 0; m < num_m/2; m++) |
2027 | for(unsigned n = 0; n < num_n/2; n++){ |
2028 | call_mma(m, n, K); |
2029 | } |
2030 | } |
2031 | } else { // not prefetched |
2032 | for(unsigned K = 0; K < NK; K += 4) |
2033 | for(unsigned m = 0; m < num_m/2; m++) |
2034 | for(unsigned n = 0; n < num_n/2; n++) { |
2035 | if(has.find({m, K}) == has.end()) |
2036 | load_a(m, K, /*inc*/0, /*is_prefetch*/false); |
2037 | if(hbs.find({n, K}) == hbs.end()) |
2038 | load_b(n, K, /*inc*/0, /*is_prefetch*/false); |
2039 | call_mma(m, n, K); |
2040 | } |
2041 | } |
2042 | |
2043 | // write back accumulators |
2044 | for(size_t i = 0; i < idxs_.at(C).size(); i++) |
2045 | vals_[C][idxs_[C][i]] = acc[i]; |
2046 | } |
2047 | |
2048 | namespace { |
2049 | class mma16816_smem_loader { |
2050 | public: |
2051 | mma16816_smem_loader(int wpt, std::vector<int> order, int k_order, |
2052 | std::vector<unsigned> tile_shape, |
2053 | std::vector<int> instr_shape, std::vector<int> mat_shape, |
2054 | int per_phase, int max_phase, int dtsize, Builder *builder, |
2055 | adder add, multiplier mul, geper gep) |
2056 | : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), |
2057 | instr_shape_(instr_shape), mat_shape_(mat_shape), |
2058 | per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), |
2059 | add(add), mul(mul), gep(gep) { |
2060 | // compute compile-time constant variables & types |
2061 | c_mat_shape_ = mat_shape[order[0]]; |
2062 | s_mat_shape_ = mat_shape[order[1]]; |
2063 | |
2064 | c_stride_ = tile_shape[order[1]]; |
2065 | s_stride_ = tile_shape[order[0]]; |
2066 | |
2067 | // rule: k must be the fast-changing axis |
2068 | need_trans_ = k_order_ != order_[0]; |
2069 | can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); |
2070 | |
2071 | // we need more pointers at the fast-changing axis, |
2072 | if (can_use_ldmatrix_) |
2073 | num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; |
2074 | else // warning: this only works for tf32 & need transpose |
2075 | num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; |
2076 | num_ptr_ = std::max<int>(num_ptr_, 2); |
2077 | |
2078 | // special rule for i8/u8, 4 ptrs for each matrix |
2079 | if (!can_use_ldmatrix_ && dtsize_ == 1) |
2080 | num_ptr_ *= 4; |
2081 | |
2082 | // load_v4 stride (in num of mats) |
2083 | int load_stride_in_mat[2]; |
2084 | load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2 |
2085 | load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]); |
2086 | p_load_stride_in_mat_ = load_stride_in_mat[order[0]]; |
2087 | // stride in mat, used by load_v4 |
2088 | s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]); |
2089 | } |
2090 | |
2091 | std::vector<Value*> compute_offs(Value *warp_off, Value *lane) { |
2092 | // TODO: this needs to be moved to constructor (and extracted to arr_order) |
2093 | mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_; |
2094 | warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1]; |
2095 | // start matrix logic offset (rename it as base_mat_off?) |
2096 | Value *mat_off[2] = {nullptr, nullptr}; |
2097 | |
2098 | if (can_use_ldmatrix_) { |
2099 | // c: lane idx inside a group (a group is a collection of 8 contiguous threads) |
2100 | // s: group idx (0,1,2,3) inside a warp |
2101 | Value *c = urem(lane, i32(8)); |
2102 | Value *s = udiv(lane, i32(8)); |
2103 | // We can decompose s => s_0, s_1... |
2104 | Value *s0 = urem(s, i32(2)); |
2105 | Value *s1 = udiv(s, i32(2)); |
2106 | |
2107 | // We use different orders for a & b for better performance. |
2108 | Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; |
2109 | Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1; |
2110 | mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), |
2111 | mul(nk_mat_arr, i32(mat_arr_stride_))); |
2112 | mat_off[k_order_] = k_mat_arr; |
2113 | // physical offset (before swizzling) |
2114 | Value *c_mat_off = mat_off[order_[0]]; |
2115 | Value *s_mat_off = mat_off[order_[1]]; |
2116 | // offset inside a matrix |
2117 | Value *s_off_in_mat = c; |
2118 | |
2119 | std::vector<Value*> offs(num_ptr_); |
2120 | Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); |
2121 | // pre-compute strided offset |
2122 | Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); |
2123 | for (int i=0; i < num_ptr_; ++i) { |
2124 | Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_)); |
2125 | c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle |
2126 | offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_))); |
2127 | } |
2128 | return offs; |
2129 | } else if (dtsize_ == 4 && need_trans_) { |
2130 | // load tf32 matrices with lds32 |
2131 | Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] |
2132 | Value *s_off_in_mat = urem(lane, i32(4)); // |
2133 | |
2134 | Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); |
2135 | std::vector<Value*> offs(num_ptr_); |
2136 | for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time |
2137 | int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; |
2138 | int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; |
2139 | if (k_mat_arr_int > 0) // we don't need pointers for k |
2140 | continue; |
2141 | Value *k_mat_arr = i32(k_mat_arr_int); |
2142 | Value *nk_mat_arr = i32(nk_mat_arr_int); |
2143 | // physical offset (before swizzling) |
2144 | Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), |
2145 | mul(nk_mat_arr, i32(mat_arr_stride_))); |
2146 | Value *s_mat_off = k_mat_arr; // always 0? |
2147 | Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); |
2148 | // FIXME: (k_order_ == 1?) is really dirty hack |
2149 | for (int i = 0; i < num_ptr_/2; ++i) { |
2150 | Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2))); |
2151 | c_mat_off_i = xor_(c_mat_off_i, phase); |
2152 | Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); |
2153 | // TODO: move this out of the loop |
2154 | c_off = urem(c_off, i32(tile_shape_[order_[0]])); |
2155 | s_off = urem(s_off, i32(tile_shape_[order_[1]])); |
2156 | offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_))); |
2157 | } |
2158 | } |
2159 | return offs; |
2160 | // throw std::runtime_error("not implemented"); |
2161 | } else if (dtsize_ == 1 && need_trans_) { |
2162 | // load i8/u8 matrices with lds8 |
2163 | Value *c_off_in_mat = udiv(lane, i32(4)); // |
2164 | Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols |
2165 | |
2166 | // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); |
2167 | std::vector<Value*> offs(num_ptr_); |
2168 | for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time |
2169 | int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; |
2170 | int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; |
2171 | if (k_mat_arr_int > 0) // we don't need pointers for k |
2172 | continue; |
2173 | Value *k_mat_arr = i32(k_mat_arr_int); |
2174 | Value *nk_mat_arr = i32(nk_mat_arr_int); |
2175 | // physical offset (before swizzling) |
2176 | Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), |
2177 | mul(nk_mat_arr, i32(mat_arr_stride_))); |
2178 | Value *s_mat_off = k_mat_arr; // always 0? |
2179 | |
2180 | for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) { |
2181 | for (int elem_off = 0; elem_off < 4; ++elem_off) { |
2182 | int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off; |
2183 | |
2184 | Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2))); |
2185 | Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off)); |
2186 | |
2187 | // disable swizzling ... |
2188 | // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); |
2189 | // c_mat_off_i = xor_(c_mat_off_i, phase); |
2190 | |
2191 | Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); |
2192 | Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_))); |
2193 | // To prevent out-of-bound access when the tile is too small |
2194 | c_off = urem(c_off, i32(tile_shape_[order_[0]])); |
2195 | s_off = urem(s_off, i32(tile_shape_[order_[1]])); |
2196 | offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_))); |
2197 | } |
2198 | } |
2199 | } |
2200 | return offs; |
2201 | } else |
2202 | throw std::runtime_error("invalid smem load config" ); |
2203 | } |
2204 | |
2205 | std::tuple<Value*, Value*, Value*, Value*> |
2206 | load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, |
2207 | Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs, |
2208 | FunctionType *ldmatrix_ty, Type *smem_ptr_ty, |
2209 | std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) { |
2210 | assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned" ); |
2211 | int mat_idx[2] = {mat0, mat1}; |
2212 | int k = mat_idx[k_order_]; |
2213 | |
2214 | int ptr_idx = -1; |
2215 | if (can_use_ldmatrix_) |
2216 | ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); |
2217 | else if (dtsize_ == 4 && need_trans_) // tf32 & trans |
2218 | ptr_idx = mat_idx[order_[0]]; |
2219 | else // i8 & trans |
2220 | ptr_idx = mat_idx[order_[0]] * 4; |
2221 | |
2222 | auto get_ptr = [&](int idx) -> Value* { |
2223 | Value *ptr = nullptr; |
2224 | if (k == 0 && is_prefetch) { |
2225 | if (inc == 0) |
2226 | ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty); |
2227 | else |
2228 | ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty); |
2229 | } else |
2230 | ptr = ptrs.at(idx); |
2231 | return ptr; |
2232 | }; |
2233 | Value *ptr = get_ptr(ptr_idx); |
2234 | |
2235 | Value *res_v4 = nullptr; |
2236 | if (can_use_ldmatrix_) { |
2237 | std::string trans = need_trans_ ? ".trans" : "" ; |
2238 | // the offset (in byte) on the strided axis is a constant |
2239 | int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; |
2240 | InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, |
2241 | "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " |
2242 | "{$0, $1, $2, $3}, " |
2243 | "[$4 + " + std::to_string(s_offset) + "];" , |
2244 | "=r,=r,=r,=r,r" , true); |
2245 | assert(ptr); |
2246 | res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); |
2247 | if (k == 0 && inc == 1 && is_prefetch) |
2248 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); |
2249 | return {extract_val(res_v4, std::vector<unsigned>{0}), |
2250 | extract_val(res_v4, std::vector<unsigned>{1}), |
2251 | extract_val(res_v4, std::vector<unsigned>{2}), |
2252 | extract_val(res_v4, std::vector<unsigned>{3})}; |
2253 | } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices |
2254 | Value *ptr2 = get_ptr(ptr_idx+1); |
2255 | assert(s_mat_stride_ == 1); |
2256 | int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; |
2257 | int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; |
2258 | Value *elem0, *elem1, *elem2, *elem3; |
2259 | if (k_order_ == 1) { |
2260 | elem0 = load(gep(ptr, i32(s_offset_elem))); |
2261 | elem1 = load(gep(ptr2, i32(s_offset_elem))); |
2262 | elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); |
2263 | elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); |
2264 | } else { // for b (k first) |
2265 | elem0 = load(gep(ptr, i32(s_offset_elem))); |
2266 | elem2 = load(gep(ptr2, i32(s_offset_elem))); |
2267 | elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); |
2268 | elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); |
2269 | } |
2270 | if (k == 0 && inc == 1 && is_prefetch) { |
2271 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0); |
2272 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1); |
2273 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2); |
2274 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); |
2275 | } |
2276 | return {elem0, elem1, elem2, elem3}; |
2277 | } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices |
2278 | Value *ptr00 = get_ptr(ptr_idx); |
2279 | Value *ptr01 = get_ptr(ptr_idx+1); |
2280 | Value *ptr02 = get_ptr(ptr_idx+2); |
2281 | Value *ptr03 = get_ptr(ptr_idx+3); |
2282 | |
2283 | Value *ptr10 = get_ptr(ptr_idx+4); |
2284 | Value *ptr11 = get_ptr(ptr_idx+5); |
2285 | Value *ptr12 = get_ptr(ptr_idx+6); |
2286 | Value *ptr13 = get_ptr(ptr_idx+7); |
2287 | |
2288 | assert(s_mat_stride_ == 1); |
2289 | int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; |
2290 | int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; |
2291 | |
2292 | Value *i8v4_elems[4]; |
2293 | Value *i32_elems[4]; |
2294 | for (int i=0; i<4; ++i) |
2295 | i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4)); |
2296 | |
2297 | Value *elem00, *elem01, *elem02, *elem03; |
2298 | Value *elem10, *elem11, *elem12, *elem13; |
2299 | Value *elem20, *elem21, *elem22, *elem23; |
2300 | Value *elem30, *elem31, *elem32, *elem33; |
2301 | Value *i8_elems[4*4]; |
2302 | if (k_order_ == 1) { // |
2303 | i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); |
2304 | i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); |
2305 | i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); |
2306 | i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); |
2307 | |
2308 | assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); |
2309 | |
2310 | i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); |
2311 | i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); |
2312 | i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); |
2313 | i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); |
2314 | |
2315 | i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); |
2316 | i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); |
2317 | i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); |
2318 | i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); |
2319 | |
2320 | i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); |
2321 | i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); |
2322 | i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); |
2323 | i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); |
2324 | |
2325 | for (int m=0; m<4; ++m) { |
2326 | for (int e=0; e<4; ++e) |
2327 | i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); |
2328 | i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); |
2329 | } |
2330 | } else { // for b (k first) |
2331 | i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); |
2332 | i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); |
2333 | i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); |
2334 | i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); |
2335 | |
2336 | assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); |
2337 | |
2338 | i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); |
2339 | i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); |
2340 | i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); |
2341 | i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); |
2342 | |
2343 | i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); |
2344 | i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); |
2345 | i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); |
2346 | i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); |
2347 | |
2348 | i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); |
2349 | i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); |
2350 | i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); |
2351 | i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); |
2352 | |
2353 | for (int m=0; m<4; ++m) { |
2354 | for (int e=0; e<4; ++e) |
2355 | i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); |
2356 | i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); |
2357 | } |
2358 | } |
2359 | if (k == 0 && inc == 1 && is_prefetch) { |
2360 | for (int m = 0; m < 4; ++m) |
2361 | for (int e = 0; e < 4; ++e) |
2362 | prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]); |
2363 | } |
2364 | return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]}; |
2365 | } else |
2366 | throw std::runtime_error("invalid smem load" ); |
2367 | } |
2368 | |
2369 | int get_num_ptr() const { return num_ptr_; } |
2370 | |
2371 | private: |
2372 | int wpt_; |
2373 | std::vector<int> order_; |
2374 | int k_order_; |
2375 | std::vector<unsigned> tile_shape_; |
2376 | std::vector<int> instr_shape_; |
2377 | std::vector<int> mat_shape_; |
2378 | int per_phase_, max_phase_; |
2379 | int dtsize_; |
2380 | |
2381 | // generated |
2382 | int c_mat_shape_, s_mat_shape_; |
2383 | int c_stride_, s_stride_; |
2384 | // p_: on the pointer axis |
2385 | int p_load_stride_in_mat_; |
2386 | int s_mat_stride_; |
2387 | // stride when moving to next not-k mat |
2388 | int warp_off_stride_; |
2389 | int mat_arr_stride_; // matrix arrangement (inside a load) stride |
2390 | bool need_trans_, can_use_ldmatrix_; |
2391 | int num_ptr_; |
2392 | |
2393 | Builder *builder_; |
2394 | adder add; |
2395 | multiplier mul; |
2396 | geper gep; |
2397 | }; |
2398 | } |
2399 | |
2400 | /** |
2401 | * \brief Code Generation for `mma.16816` (A100) |
2402 | */ |
2403 | //TODO: clean-up |
2404 | void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { |
2405 | const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes(); |
2406 | std::map<std::vector<Value*>, std::vector<Value*>> fcs; |
2407 | for(indices_t idx: idxs_.at(C)){ |
2408 | std::vector<Value*> key(idx.size() - 2); |
2409 | std::copy(idx.begin() + 2, idx.end(), key.begin()); |
2410 | fcs[key].push_back(vals_[D][idx]); |
2411 | }; |
2412 | auto shape_a = A->get_type()->get_block_shapes(); |
2413 | auto shape_b = B->get_type()->get_block_shapes(); |
2414 | auto ord_a = layouts_->get(A)->get_order(); |
2415 | if(C->is_trans_a()){ |
2416 | std::swap(ord_a[0], ord_a[1]); |
2417 | std::swap(shape_a[0], shape_a[1]); |
2418 | } |
2419 | auto ord_b = layouts_->get(B)->get_order(); |
2420 | if(C->is_trans_b()){ |
2421 | std::swap(ord_b[0], ord_b[1]); |
2422 | std::swap(shape_b[0], shape_b[1]); |
2423 | } |
2424 | NK = shape_a[1]; |
2425 | analysis::mma_layout* layout = layouts_->get(C)->to_mma(); |
2426 | |
2427 | std::vector<int> mma_instr_shape = layout->get_mma_instr_shape(); |
2428 | const int mma_instr_m = mma_instr_shape[0]; |
2429 | const int mma_instr_n = mma_instr_shape[1]; |
2430 | const int mma_instr_k = mma_instr_shape[2]; |
2431 | |
2432 | std::vector<int> mat_shape = layout->get_mma_mat_shape(); |
2433 | const int mat_shape_m = mat_shape[0]; |
2434 | const int mat_shape_n = mat_shape[1]; |
2435 | const int mat_shape_k = mat_shape[2]; |
2436 | |
2437 | |
2438 | const int num_rep_m = shapes[0] / layout->shape_per_cta(0); |
2439 | const int num_rep_n = shapes[1] / layout->shape_per_cta(1); |
2440 | const int num_rep_k = std::max<int>(NK/mma_instr_k, 1); |
2441 | |
2442 | // floating point types |
2443 | Type *fp32_ty = f32_ty; |
2444 | Type *fp16x2_ty = vec_ty(f16_ty, 2); |
2445 | Type *bf16x2_ty = vec_ty(bf16_ty, 2); |
2446 | Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); |
2447 | Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); |
2448 | Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); |
2449 | // integer types |
2450 | Type *i8x4_ty = vec_ty(i8_ty, 4); |
2451 | Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty}); |
2452 | Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty}); |
2453 | |
2454 | |
2455 | FunctionType *ldmatrix_ty = nullptr; |
2456 | FunctionType *mma_ty = nullptr; |
2457 | Type *phi_ty = nullptr; |
2458 | Type *smem_ptr_ty = nullptr; |
2459 | |
2460 | ir::type *A_ir_ty = A->get_type()->get_scalar_ty(); |
2461 | ir::type *B_ir_ty = B->get_type()->get_scalar_ty(); |
2462 | if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) { |
2463 | mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); |
2464 | smem_ptr_ty = ptr_ty(f16_ty, 3); |
2465 | ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); |
2466 | phi_ty = fp16x2_ty; |
2467 | } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { |
2468 | mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); |
2469 | smem_ptr_ty = ptr_ty(bf16_ty, 3); |
2470 | ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); |
2471 | phi_ty = bf16x2_ty; |
2472 | } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { |
2473 | mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); |
2474 | smem_ptr_ty = ptr_ty(fp32_ty, 3); |
2475 | ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); |
2476 | phi_ty = fp32_ty; |
2477 | } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) { |
2478 | // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8) |
2479 | mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); |
2480 | smem_ptr_ty = ptr_ty(i8_ty, 3); |
2481 | ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); |
2482 | phi_ty = i32_ty; |
2483 | // mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); |
2484 | // smem_ptr_ty = ptr_ty(i8_ty, 3); |
2485 | // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); |
2486 | // phi_ty = i8x4_ty; |
2487 | } else |
2488 | throw std::runtime_error("mma16816 data type not supported" ); |
2489 | |
2490 | // left-hand-side values |
2491 | std::map<std::pair<unsigned, unsigned>, Value*> ha; |
2492 | std::map<std::pair<unsigned, unsigned>, Value*> hb; |
2493 | |
2494 | BasicBlock* CurrBB = builder_->GetInsertBlock(); |
2495 | BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); |
2496 | |
2497 | // if true, this will move pointer declarations to the entry basic block |
2498 | // not prefetched cases tend to be more limited in resource usage |
2499 | // so we don't pre-compute ptrs to save registers |
2500 | bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB); |
2501 | if(licm_ptrs) |
2502 | builder_->SetInsertPoint(FirstBB->getTerminator()); |
2503 | |
2504 | Value* thread = tgt_->get_local_id(mod_, *builder_, 0); |
2505 | Value *lane = urem(thread, i32(32)); |
2506 | Value *warp = udiv(thread, i32(32)); |
2507 | Value *warp_mn = udiv(warp, i32(layout->wpt(0))); |
2508 | Value *warp_m = urem(warp, i32(layout->wpt(0))); |
2509 | Value *warp_n = urem(warp_mn, i32(layout->wpt(1))); |
2510 | std::vector<Value *>& fc = fcs.begin()->second; |
2511 | |
2512 | size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; |
2513 | size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; |
2514 | |
2515 | ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A); |
2516 | ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B); |
2517 | auto register_lds2 = |
2518 | [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { |
2519 | if (k < 2 && is_prefetch) { |
2520 | ir::basic_block* inc_block = phiA->get_incoming_block(inc); |
2521 | lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); |
2522 | } else |
2523 | vals[{mn, k}] = val; |
2524 | }; |
2525 | |
2526 | // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride |
2527 | // v (s0_0(0), s1_0(2), | *num_rep_k |
2528 | // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) |
2529 | // ----------- |
2530 | // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) |
2531 | std::function<void(int,int,int,bool)> load_a; |
2532 | analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared(); |
2533 | bool is_a_shared = layout_a != nullptr; |
2534 | if(is_a_shared) { |
2535 | const int per_phase_a = swizzle_->get_per_phase(layout_a); |
2536 | const int max_phase_a = swizzle_->get_max_phase(layout_a); |
2537 | mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, |
2538 | {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, |
2539 | per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); |
2540 | std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane); |
2541 | int num_ptr_a = a_loader.get_num_ptr(); |
2542 | // pointers |
2543 | std::vector<Value*> ptrs_a(num_ptr_a); |
2544 | if(licm_ptrs) |
2545 | builder_->SetInsertPoint(CurrBB); |
2546 | for(int i = 0; i < num_ptr_a; i++) |
2547 | ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); |
2548 | if(licm_ptrs) |
2549 | builder_->SetInsertPoint(FirstBB->getTerminator()); |
2550 | // loading function |
2551 | load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable { |
2552 | auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], |
2553 | shared_next_ptr_[layout_a], off_a, ptrs_a, |
2554 | ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); |
2555 | register_lds2(ha, m, k, inc, ha0, is_prefetch); |
2556 | register_lds2(ha, m+1, k, inc, ha1, is_prefetch); |
2557 | register_lds2(ha, m, k+1, inc, ha2, is_prefetch); |
2558 | register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); |
2559 | }; |
2560 | } |
2561 | else { |
2562 | load_a = [&](int m, int k, int inc, bool is_prefetch) { |
2563 | distributed_axis ax_n = axes_.at(a_axes_->get(A, 1)); |
2564 | int ldm = ax_n.values.size(); |
2565 | if(ldm != num_rep_k*4) |
2566 | throw std::runtime_error("Internal compiler error when trying to fuse matmuls!" ); |
2567 | // std::cout << m << " " << k << std::endl; |
2568 | // std::cout << idxs_[A].size() << std::endl; |
2569 | // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; |
2570 | // int ldm = num_rep_k*4; |
2571 | Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2 |
2572 | Value* ha1 = UndefValue::get(phi_ty); |
2573 | Value* ha2 = UndefValue::get(phi_ty); |
2574 | Value* ha3 = UndefValue::get(phi_ty); |
2575 | ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); |
2576 | ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); |
2577 | ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); |
2578 | ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1)); |
2579 | ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0)); |
2580 | ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1)); |
2581 | ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0)); |
2582 | ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1)); |
2583 | ha[{m, k}] = ha0; |
2584 | ha[{m+1, k}] = ha1; |
2585 | ha[{m, k+1}] = ha2; |
2586 | ha[{m+1, k+1}] = ha3; |
2587 | }; |
2588 | } |
2589 | |
2590 | |
2591 | // | -> n (col-major) |
2592 | // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n |
2593 | // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) |
2594 | // ----------- |
2595 | // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) |
2596 | analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared(); |
2597 | const int per_phase_b = swizzle_->get_per_phase(layout_b); |
2598 | const int max_phase_b = swizzle_->get_max_phase(layout_b); |
2599 | std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n}; |
2600 | std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n}; |
2601 | int k_order_b = 0; |
2602 | // if(C->is_trans_b()){ |
2603 | // std::swap(mma_instr_b[0], mma_instr_b[1]); |
2604 | // std::swap(mat_shape_b[0], mat_shape_b[1]); |
2605 | // k_order_b = k_order_b ^ 1; |
2606 | // std::swap(ord_b[0], ord_b[1]); |
2607 | // std::swap(shape_b[0], shape_b[1]); |
2608 | // } |
2609 | |
2610 | mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b, |
2611 | mma_instr_b, mat_shape_b, |
2612 | per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); |
2613 | std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane); |
2614 | |
2615 | if(licm_ptrs) |
2616 | builder_->SetInsertPoint(CurrBB); |
2617 | // pointers |
2618 | int num_ptr_b = b_loader.get_num_ptr(); |
2619 | std::vector<Value*> ptrs_b(num_ptr_b); |
2620 | for(int i = 0; i < num_ptr_b; i++) |
2621 | ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); |
2622 | |
2623 | |
2624 | // loading function |
2625 | std::function<void(int,int,int,bool)> load_b; |
2626 | load_b = [&](int n, int k, int inc, bool is_prefetch) { |
2627 | auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], |
2628 | shared_next_ptr_[layout_b], off_b, ptrs_b, |
2629 | ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); |
2630 | register_lds2(hb, n, k, inc, hb0, is_prefetch); |
2631 | register_lds2(hb, n+1, k, inc, hb2, is_prefetch); |
2632 | register_lds2(hb, n, k+1, inc, hb1, is_prefetch); |
2633 | register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); |
2634 | }; |
2635 | |
2636 | |
2637 | |
2638 | // create mma & unpack result, m, n, k are offsets in mat |
2639 | auto call_mma = [&](unsigned m, unsigned n, unsigned k) { |
2640 | InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + |
2641 | " {$0, $1, $2, $3}," |
2642 | " {$4, $5, $6, $7}," |
2643 | " {$8, $9}," |
2644 | " {$10, $11, $12, $13};" , |
2645 | "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3" , true); |
2646 | unsigned cols_per_thread = num_rep_n * 2; |
2647 | std::vector<size_t> idx = { |
2648 | (m + 0)*cols_per_thread + (n*2 + 0), |
2649 | (m + 0)*cols_per_thread + (n*2 + 1), |
2650 | (m + 1)*cols_per_thread + (n*2 + 0), |
2651 | (m + 1)*cols_per_thread + (n*2 + 1) |
2652 | }; |
2653 | Value *nc = call(mma_ty, mma_fn, |
2654 | {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], |
2655 | hb[{n, k}], hb[{n, k+1}], |
2656 | fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); |
2657 | fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0}); |
2658 | fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1}); |
2659 | fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2}); |
2660 | fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3}); |
2661 | }; |
2662 | if (C->is_prefetched()) { |
2663 | // create phis |
2664 | builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); |
2665 | for(unsigned m = 0; m < num_rep_m; m++){ |
2666 | ha[{2*m, 0}] = phi(phi_ty, 2); |
2667 | ha[{2*m+1, 0}] = phi(phi_ty, 2); |
2668 | ha[{2*m, 1}] = phi(phi_ty, 2); |
2669 | ha[{2*m+1, 1}] = phi(phi_ty, 2); |
2670 | } |
2671 | for(unsigned n = 0; n < num_rep_n; n+=2){ |
2672 | hb[{n, 0}] = phi(phi_ty, 2); |
2673 | hb[{n+1, 0}] = phi(phi_ty, 2); |
2674 | hb[{n, 1}] = phi(phi_ty, 2); |
2675 | hb[{n+1, 1}] = phi(phi_ty, 2); |
2676 | } |
2677 | // insert prefetched lds at the end of loop header |
2678 | builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); |
2679 | for(unsigned m = 0; m < num_rep_m; m++) |
2680 | load_a(2*m, 0, 0, true); |
2681 | for(unsigned n = 0; n < num_rep_n; n+=2) |
2682 | load_b(n, 0, 0, true); |
2683 | // update accumulators |
2684 | builder_->SetInsertPoint(CurrBB); |
2685 | for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2 |
2686 | int next_k = (k + 1) % num_rep_k; |
2687 | // prefetch A |
2688 | for(unsigned m = 0; m < num_rep_m; m++) |
2689 | load_a(2*m, 2*next_k, 1, true); |
2690 | // prefetch B |
2691 | for(unsigned n = 0; n < num_rep_n; n+=2) |
2692 | load_b(n, 2*next_k, 1, true); |
2693 | // tensor core ops |
2694 | for(unsigned m = 0; m < num_rep_m; m++) |
2695 | for(unsigned n = 0; n < num_rep_n; n++){ |
2696 | call_mma(2*m, n, 2*k); |
2697 | } |
2698 | } |
2699 | } |
2700 | else{ |
2701 | for (unsigned k = 0; k < num_rep_k; k++) { |
2702 | for (unsigned m = 0; m < num_rep_m; m++) |
2703 | load_a(2*m, 2*k, 0, /*is_prefetch*/false); |
2704 | for (unsigned n = 0; n < num_rep_n; n+=2) |
2705 | load_b(n, 2*k, 0, /*is_prefetch*/false); |
2706 | for (unsigned m = 0; m < num_rep_m; m++) |
2707 | for (unsigned n = 0; n < num_rep_n; n++) |
2708 | call_mma(2*m, n, 2*k); |
2709 | } |
2710 | } |
2711 | // write back |
2712 | unsigned i = 0; |
2713 | for(indices_t idx: idxs_.at(C)){ |
2714 | std::vector<Value*> key(idx.size() - 2); |
2715 | std::copy(idx.begin() + 2, idx.end(), key.begin()); |
2716 | if(i >= fcs.at(key).size()) |
2717 | i = 0; |
2718 | vals_[C][idx] = fcs.at(key)[i++]; |
2719 | }; |
2720 | |
2721 | } |
2722 | |
2723 | /** |
2724 | * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default) |
2725 | */ |
2726 | void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) { |
2727 | auto shape_c = C->get_type()->get_block_shapes(); |
2728 | auto shape_a = A->get_type()->get_block_shapes(); |
2729 | auto shape_b = B->get_type()->get_block_shapes(); |
2730 | auto ord_a = layouts_->get(A)->get_order(); |
2731 | auto ord_b = layouts_->get(B)->get_order(); |
2732 | analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline(); |
2733 | analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); |
2734 | analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); |
2735 | bool is_a_row = ord_a[0] == 1; |
2736 | bool is_b_row = ord_b[0] == 1; |
2737 | std::string a_trans = is_a_row ? "" : ".trans" ; |
2738 | std::string b_trans = is_b_row ? ".trans" : "" ; |
2739 | int stride_a_m = is_a_row ? shape_a[1] : 1; |
2740 | int stride_a_k = is_a_row ? 1 : shape_a[0]; |
2741 | int stride_b_n = is_b_row ? 1 : shape_b[0]; |
2742 | int stride_b_k = is_b_row ? shape_b[1] : 1; |
2743 | int stride_a0 = is_a_row ? stride_a_k : stride_a_m; |
2744 | int stride_a1 = is_a_row ? stride_a_m : stride_a_k; |
2745 | int stride_b0 = is_b_row ? stride_b_n : stride_b_k; |
2746 | int stride_b1 = is_b_row ? stride_b_k : stride_b_n; |
2747 | int lda = is_a_row ? stride_a_m : stride_a_k; |
2748 | int ldb = is_b_row ? stride_b_k : stride_b_n; |
2749 | int per_phase_a = swizzle_->get_per_phase(layout_a); |
2750 | int max_phase_a = swizzle_->get_max_phase(layout_a); |
2751 | int per_phase_b = swizzle_->get_per_phase(layout_b); |
2752 | int max_phase_b = swizzle_->get_max_phase(layout_b); |
2753 | int num_ptr_a = 8; |
2754 | int num_ptr_b = 8; |
2755 | int vec_a = 2; |
2756 | int vec_b = 4; |
2757 | distributed_axis ax_m = axes_.at(a_axes_->get(C, 0)); |
2758 | distributed_axis ax_n = axes_.at(a_axes_->get(C, 1)); |
2759 | // Value* thread = tgt_->get_local_id(mod_, *builder_, 0); |
2760 | |
2761 | Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous)); |
2762 | Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0); |
2763 | std::vector<Value*> off_a(num_ptr_a); |
2764 | for(int i = 0; i < num_ptr_a; i++){ |
2765 | // Value* off_a0i = add(off_a0, i32(is_a_row ? vec_a : layout_c->mts(0)*vec_a)); |
2766 | // off_a0i = exact_udiv(off_a0i, i32(vec_a)); |
2767 | // off_a0i = xor_(off_a0i, phase_a); |
2768 | // off_a0i = mul(off_a0i, i32(vec_a)); |
2769 | off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1))); |
2770 | } |
2771 | Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0); |
2772 | Value* off_b1 = is_b_row ? i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous)); |
2773 | std::vector<Value*> off_b(num_ptr_b); |
2774 | for(int i = 0; i < num_ptr_b; i++){ |
2775 | // Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b)); |
2776 | // off_b0i = exact_udiv(off_b0i, i32(vec_b)); |
2777 | // off_b0i = xor_(off_b0i, phase_b); |
2778 | // off_b0i = mul(off_b0i, i32(vec_b)); |
2779 | off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1))); |
2780 | } |
2781 | std::vector<Value*> ptrs_a(num_ptr_a); |
2782 | for(int i = 0; i < num_ptr_a; i++) |
2783 | ptrs_a[i] = gep(shmems_[A], off_a[i]); |
2784 | std::vector<Value*> ptrs_b(num_ptr_b); |
2785 | for(int i = 0; i < num_ptr_b; i++) |
2786 | ptrs_b[i] = gep(shmems_[B], off_b[i]); |
2787 | |
2788 | std::map<indices_t, Value*> ret = vals_[D]; |
2789 | std::map<std::pair<int, int>, Value*> has, hbs; |
2790 | auto ord = layout_c->get_order(); |
2791 | for(unsigned k = 0; k < NK; k++){ |
2792 | int z = 0; |
2793 | for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1])) |
2794 | for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0])) |
2795 | for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++) |
2796 | for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){ |
2797 | unsigned m = (ord[0] == 1) ? i : j; |
2798 | unsigned n = (ord[0] == 1) ? j : i; |
2799 | unsigned mm = (ord[0] == 1) ? ii : jj; |
2800 | unsigned nn = (ord[0] == 1) ? jj : ii; |
2801 | if(has.find({m + mm, k}) == has.end()){ |
2802 | Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); |
2803 | Value* va = load(pa); |
2804 | has[{m + mm, k}] = va; |
2805 | } |
2806 | if(hbs.find({n + nn, k}) == hbs.end()){ |
2807 | Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k)); |
2808 | Value* vb = load(pb); |
2809 | hbs[{n + nn, k}] = vb; |
2810 | } |
2811 | ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]}); |
2812 | z++; |
2813 | } |
2814 | } |
2815 | |
2816 | for(indices_t idx: idxs_.at(C)){ |
2817 | vals_[C][idx] = ret[idx]; |
2818 | } |
2819 | } |
2820 | |
2821 | /** |
2822 | * \brief Code Generation for `dot` |
2823 | * Dispatches to appropriate specialized function |
2824 | */ |
2825 | void generator::visit_dot_inst(ir::dot_inst* dot) { |
2826 | Function *fn = builder_->GetInsertBlock()->getParent(); |
2827 | Module *module = fn->getParent(); |
2828 | ir::value *A = dot->get_operand(0); |
2829 | ir::value *B = dot->get_operand(1); |
2830 | ir::value *D = dot->get_operand(2); |
2831 | Type *c_ty = cvt(D->get_type()->get_scalar_ty()); |
2832 | Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty}); |
2833 | auto A_shapes = A->get_type()->get_block_shapes(); |
2834 | size_t red_axis = 1; |
2835 | unsigned NK = A_shapes[red_axis]; |
2836 | bool is_outer = NK == 1; |
2837 | bool is_mma = layouts_->get(dot)->to_mma(); |
2838 | if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80) |
2839 | return visit_mma884(dot, A, B, D, NK); |
2840 | if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) |
2841 | return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? |
2842 | if (dot->get_type()->get_scalar_ty()->is_fp32_ty() && |
2843 | A->get_type()->get_scalar_ty()->is_fp32_ty()) |
2844 | return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); |
2845 | throw std::runtime_error("dot has invalid operand type" ); |
2846 | } |
2847 | |
2848 | void generator::visit_trans_inst(ir::trans_inst* trans) { |
2849 | throw std::runtime_error("not supported" ); |
2850 | } |
2851 | |
2852 | /** |
2853 | * \brief Code Generation for `sqrt` |
2854 | */ |
2855 | void generator::visit_sqrt_inst(ir::sqrt_inst* x) { |
2856 | for(indices_t idx: idxs_.at(x)){ |
2857 | Value *val = vals_[x->get_operand(0)][idx]; |
2858 | Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val}); |
2859 | vals_[x][idx] = ret; |
2860 | } |
2861 | } |
2862 | |
2863 | Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx){ |
2864 | // strides |
2865 | std::vector<Value*> strides(shapes.size(), builder_->getInt32(0)); |
2866 | strides[order[0]] = builder_->getInt32(1); |
2867 | for(size_t i = 1; i < idx.size(); i++) |
2868 | strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]])); |
2869 | // result |
2870 | Value *result = builder_->getInt32(0); |
2871 | for(size_t i = 0; i < idx.size(); i++) |
2872 | result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i])); |
2873 | return result; |
2874 | } |
2875 | |
2876 | inline Value* generator::shfl_sync(Value* acc, int32_t i){ |
2877 | Type* ty = acc->getType(); |
2878 | std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;" ; |
2879 | InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r" , false); |
2880 | if(ty->getPrimitiveSizeInBits() <= 32) |
2881 | return call(shfl, {acc, i32(i)}); |
2882 | acc = bit_cast(acc, vec_ty(f32_ty, 2)); |
2883 | Value* acc0 = builder_->CreateExtractElement(acc, i32(0)); |
2884 | Value* acc1 = builder_->CreateExtractElement(acc, i32(1)); |
2885 | Value* ret = UndefValue::get(vec_ty(f32_ty, 2)); |
2886 | ret = insert_elt(ret, shfl_sync(acc0, i), i32(0)); |
2887 | ret = insert_elt(ret, shfl_sync(acc1, i), i32(1)); |
2888 | return bit_cast(ret, ty); |
2889 | } |
2890 | |
2891 | /** |
2892 | * \brief Code Generation for `reduce` (ND case) |
2893 | */ |
2894 | void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){ |
2895 | ir::value *arg = x->get_operand(0); |
2896 | const auto with_index = x->with_index(); |
2897 | unsigned axis = x->get_axis(); |
2898 | analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg)); |
2899 | const auto &shapes = layout->get_shape(); |
2900 | |
2901 | Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); |
2902 | size_t n_bits = sca_ty->getPrimitiveSizeInBits(); |
2903 | std::string n_bits_str = std::to_string(n_bits); |
2904 | std::string cst = (n_bits == 64) ? "l" : "r" ; |
2905 | |
2906 | FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); |
2907 | InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;" , "b," + cst + "," + cst, true); |
2908 | FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); |
2909 | InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];" , "=" + cst + ",b," + cst, true); |
2910 | |
2911 | Type *index_ty = IntegerType::get(*ctx_, 32); |
2912 | FunctionType *st_shared_index_ty = |
2913 | FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false); |
2914 | InlineAsm *st_shared_index = InlineAsm::get( |
2915 | st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;" , "b,r,r" , true); |
2916 | FunctionType *ld_shared_index_ty = |
2917 | FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false); |
2918 | InlineAsm *ld_shared_index = InlineAsm::get( |
2919 | ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];" , "=r,b,r" , true); |
2920 | |
2921 | Value* thread = tgt_->get_local_id(mod_, *builder_, 0); |
2922 | Value* warp = udiv(thread, i32(32)); |
2923 | Value* lane = urem(thread, i32(32)); |
2924 | |
2925 | unsigned shuffle_width = 0; |
2926 | unsigned warps_per_inner = 0; |
2927 | auto arg_vals = vals_.at(arg); |
2928 | std::vector<indices_t> arg_idxs = idxs_.at(arg); |
2929 | size_t n_elts = arg_idxs.size(); |
2930 | unsigned col_per_thread = 0; |
2931 | Value* warp_j = nullptr; |
2932 | if (analysis::scanline_layout *scanline = layout->to_scanline()) { |
2933 | std::vector<int> order = layout->get_order(); |
2934 | unsigned mts = scanline->mts(order[0]); |
2935 | shuffle_width = std::min<int>(mts, 32); |
2936 | warps_per_inner = std::max<int>(mts / 32, 1); |
2937 | col_per_thread = shapes[order[0]] / mts; |
2938 | warp_j = urem(warp, i32(warps_per_inner)); |
2939 | } else if (layout->to_mma()) { |
2940 | shuffle_width = 4; |
2941 | warps_per_inner = layout->to_mma()->wpt(1); |
2942 | col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size(); |
2943 | warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; |
2944 | } |
2945 | assert(warp_j != nullptr); |
2946 | |
2947 | // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); |
2948 | // |
2949 | Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)), |
2950 | cvt(x->get_type()->get_scalar_ty())); |
2951 | Value *index_base = |
2952 | with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), |
2953 | IntegerType::get(*ctx_, 32)) |
2954 | : nullptr; |
2955 | |
2956 | // preds |
2957 | Value* is_lane0 = icmp_eq(lane, i32(0)); |
2958 | Value* is_warp0 = icmp_eq(warp, i32(0)); |
2959 | Value* is_thread0 = icmp_eq(thread, i32(0)); |
2960 | Value* lane_j = urem(lane, i32(shuffle_width)); |
2961 | if(warps_per_inner > 1) |
2962 | add_barrier(); |
2963 | // compute partial sum for each warp, and store to shared memory |
2964 | for(size_t i = 0; i < n_elts/col_per_thread; i++){ |
2965 | std::pair<Value*, Value*> acc; |
2966 | // reduce within thread |
2967 | for(size_t j = 0; j < col_per_thread; j++){ |
2968 | auto arg_idx = arg_idxs[i*col_per_thread + j]; |
2969 | bool is_first = j == 0; |
2970 | do_acc( |
2971 | acc, [&]() -> Value * { return arg_vals[arg_idx]; }, |
2972 | [&]() -> Value * { return arg_idx[axis]; }, is_first); |
2973 | } |
2974 | |
2975 | // reduce within warp |
2976 | for(int k = shuffle_width/2 ; k > 0; k >>= 1) { |
2977 | do_acc( |
2978 | acc, [&]() -> Value * { return shfl_sync(acc.first, k); }, |
2979 | [&]() -> Value * { return shfl_sync(acc.second, k); }, false); |
2980 | } |
2981 | // store partial result to shared memory |
2982 | auto x_idxs = idxs_[x][i]; |
2983 | Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; |
2984 | // single warp on the reduce dimension -- no need to use shmem |
2985 | if(warps_per_inner==1){ |
2986 | vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first; |
2987 | } |
2988 | else{ |
2989 | Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); |
2990 | call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); |
2991 | if (with_index) { |
2992 | call(st_shared_index, |
2993 | {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); |
2994 | } |
2995 | } |
2996 | } |
2997 | if(warps_per_inner==1) |
2998 | return; |
2999 | add_barrier(); |
3000 | // at this point, partial accumulator synchronized in shared memory |
3001 | // Just need to reduce `warp_per_inner` numbers in shared memory |
3002 | for(size_t i = 0; i < n_elts/col_per_thread; i++){ |
3003 | auto x_idxs = idxs_[x][i]; |
3004 | Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; |
3005 | Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); |
3006 | std::pair<Value*, Value*> acc; |
3007 | acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); |
3008 | acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true), |
3009 | gep(index_base, ld_off)}) |
3010 | : nullptr; |
3011 | for (int k = warps_per_inner / 2; k > 0; k >>= 1) { |
3012 | do_acc( |
3013 | acc, [&]() -> Value * { return shfl_sync(acc.first, k); }, |
3014 | [&]() -> Value * { return shfl_sync(acc.second, k); }, false); |
3015 | } |
3016 | vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first; |
3017 | } |
3018 | // add_barrier(); |
3019 | } |
3020 | |
3021 | |
3022 | void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) { |
3023 | ir::value *arg = x->get_operand(0); |
3024 | unsigned axis = x->get_axis(); |
3025 | auto with_index = x->with_index(); |
3026 | |
3027 | // reduce within thread |
3028 | // index-><current reduced value, current min/max index (optional)> |
3029 | std::map<indices_t, std::pair<Value*, Value*>> accs; |
3030 | for(indices_t idx: idxs_.at(arg)){ |
3031 | indices_t pidx = idx; |
3032 | pidx[axis] = i32(0); |
3033 | bool is_first = accs.find(pidx) == accs.end(); |
3034 | do_acc( |
3035 | accs[pidx], [&]() -> Value * { return vals_[arg][idx]; }, |
3036 | [&]() -> Value * { return idx[axis]; }, is_first); |
3037 | }; |
3038 | |
3039 | // reduce within blocks |
3040 | auto *data_layout = layouts_->get(layouts_->tmp(x)); |
3041 | auto *data_ptr = |
3042 | cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty())); |
3043 | auto *index_ptr = |
3044 | with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), |
3045 | IntegerType::get(*ctx_, 32)) |
3046 | : data_ptr; |
3047 | |
3048 | auto shape = data_layout->get_shape(); |
3049 | auto order = data_layout->get_order(); |
3050 | Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id; |
3051 | for(auto& x: accs) { |
3052 | // current element being computed |
3053 | std::pair<Value *, Value *> acc = x.second; |
3054 | indices_t write_idx = x.first; |
3055 | write_idx[axis] = lane; |
3056 | // shared memory write pointer |
3057 | Value *write_off = shared_off(shape, order, write_idx); |
3058 | Value *write_ptr = gep(data_ptr, write_off); |
3059 | Value *index_write_ptr = gep(index_ptr, write_off); |
3060 | // initialize shared memory |
3061 | add_barrier(); |
3062 | store(acc.first, write_ptr); |
3063 | if (with_index) { |
3064 | store(acc.second, index_write_ptr); |
3065 | } |
3066 | // build result |
3067 | indices_t idx(write_idx.size(), i32(0)); |
3068 | for(size_t i = shape[axis]/2; i > 0; i >>= 1){ |
3069 | idx[axis] = i32(i); |
3070 | // read pointer |
3071 | Value *read_msk = icmp_ult(lane, i32(i)); |
3072 | Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0)); |
3073 | Value *read_ptr = gep(write_ptr, read_off); |
3074 | Value *index_read_ptr = gep(index_write_ptr, read_off); |
3075 | add_barrier(); |
3076 | // update accumulator |
3077 | do_acc( |
3078 | acc, [&]() -> Value * { return load(read_ptr); }, |
3079 | [&]() -> Value * { return load(index_read_ptr); }, false); |
3080 | add_barrier(); |
3081 | store(acc.first, write_ptr); |
3082 | if (with_index) { |
3083 | store(acc.second, index_write_ptr); |
3084 | } |
3085 | } |
3086 | } |
3087 | add_barrier(); |
3088 | |
3089 | // write back |
3090 | for(indices_t idx: idxs_.at(x)){ |
3091 | indices_t read_idx = idx; |
3092 | read_idx.insert(read_idx.begin() + axis, i32(0)); |
3093 | Value *read_off = shared_off(shape, order, read_idx); |
3094 | Value *read_ptr = |
3095 | with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off); |
3096 | vals_[x][idx] = load(read_ptr); |
3097 | }; |
3098 | } |
3099 | |
3100 | /** |
3101 | * \brief Code Generation for `reduce` (generic case) |
3102 | */ |
3103 | void generator::visit_reduce_inst(ir::reduce_inst* x) { |
3104 | Type *ty = cvt(x->get_type()->get_scalar_ty()); |
3105 | // accumulation function |
3106 | ir::reduce_inst::op_t op = x->get_op(); |
3107 | auto do_acc_op = [&](Value *x, Value *y) -> Value* { |
3108 | switch(op){ |
3109 | case ir::reduce_inst::ADD: return add(x, y); |
3110 | case ir::reduce_inst::SUB: return sub(x, y); |
3111 | case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y); |
3112 | case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y); |
3113 | case ir::reduce_inst::ARGMAX: return icmp_sge(x, y); |
3114 | case ir::reduce_inst::ARGMIN: return icmp_sle(x, y); |
3115 | case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y); |
3116 | case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y); |
3117 | case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); |
3118 | case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); |
3119 | case ir::reduce_inst::FADD: return fadd(x, y); |
3120 | case ir::reduce_inst::FSUB: return fsub(x, y); |
3121 | case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y); |
3122 | case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y); |
3123 | case ir::reduce_inst::FMAX: return max_num(x, y); |
3124 | case ir::reduce_inst::FMIN: return min_num(x, y); |
3125 | case ir::reduce_inst::XOR: return xor_(x, y); |
3126 | |
3127 | default: throw std::runtime_error("unreachable" ); |
3128 | } |
3129 | }; |
3130 | |
3131 | auto do_acc = [&](std::pair<Value *, Value *> &acc, |
3132 | std::function<Value *()> load_value_fn, |
3133 | std::function<Value *()> load_index_fn, |
3134 | bool is_first) -> void { |
3135 | auto *val = load_value_fn(); |
3136 | if (x->with_index()) { |
3137 | auto *index = load_index_fn(); |
3138 | if (is_first) { |
3139 | acc.first = val; |
3140 | acc.second = index; |
3141 | } else { |
3142 | Value *ret = do_acc_op(acc.first, val); |
3143 | acc.first = select(ret, acc.first, val); |
3144 | acc.second = select(ret, acc.second, index); |
3145 | } |
3146 | } else { |
3147 | acc.first = is_first ? val : do_acc_op(acc.first, val); |
3148 | } |
3149 | }; |
3150 | |
3151 | // neutral element |
3152 | Value *neutral; |
3153 | switch(op) { |
3154 | case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break; |
3155 | case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; |
3156 | case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break; |
3157 | case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break; |
3158 | case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break; |
3159 | case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break; |
3160 | case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break; |
3161 | case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break; |
3162 | case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; |
3163 | case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; |
3164 | case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; |
3165 | case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; |
3166 | case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break; |
3167 | case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break; |
3168 | case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; |
3169 | case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; |
3170 | case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break; |
3171 | default: throw std::runtime_error("unreachable" ); |
3172 | } |
3173 | ir::value *arg = x->get_operand(0); |
3174 | bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x); |
3175 | bool is_a100_mma = layouts_->is_a100_mma(x); |
3176 | if (is_coalesced_scanline || is_a100_mma) |
3177 | visit_reducend_inst_fast(x, do_acc, neutral); |
3178 | else |
3179 | visit_reducend_inst(x, do_acc, neutral); |
3180 | } |
3181 | |
3182 | /** |
3183 | * \brief Code Generation for `select` |
3184 | */ |
3185 | void generator::visit_select_inst(ir::select_inst* x) { |
3186 | for(indices_t idx: idxs_.at(x)){ |
3187 | vals_[x][idx] = select(vals_[x->get_operand(0)][idx], |
3188 | vals_[x->get_operand(1)][idx], |
3189 | vals_[x->get_operand(2)][idx]); |
3190 | } |
3191 | } |
3192 | |
3193 | |
3194 | |
3195 | void generator::visit_layout_convert(ir::value *out, ir::value *in){ |
3196 | ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes(); |
3197 | // pointer to temporary shared memory |
3198 | Type *ty = cvt(out->get_type()->get_scalar_ty()); |
3199 | |
3200 | // Orders |
3201 | analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in)); |
3202 | analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out)); |
3203 | Value *base; |
3204 | int off = alloc_->offset(layouts_->get(layouts_->tmp(out))); |
3205 | // std::cout << off << std::endl; |
3206 | base = gep(shmem_, i32(off)); |
3207 | base = bit_cast(base, ptr_ty(ty, 3)); |
3208 | std::vector<int> n_reps; |
3209 | for(int i = 0; i < shape.size(); i++){ |
3210 | int in_per_cta = in_layout->shape_per_cta(i); |
3211 | int out_per_cta = out_layout->shape_per_cta(i); |
3212 | int max_per_cta = std::max(in_per_cta, out_per_cta); |
3213 | n_reps.push_back(shape[i]/max_per_cta); |
3214 | } |
3215 | std::vector<std::vector<Value*>> in_ax; |
3216 | std::vector<std::vector<Value*>> out_ax; |
3217 | for(int d = 0; d < shape.size(); d++){ |
3218 | in_ax.push_back(axes_.at(a_axes_->get(in, d)).values); |
3219 | out_ax.push_back(axes_.at(a_axes_->get(out, d)).values); |
3220 | } |
3221 | auto in_ord = |
3222 | in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order(); |
3223 | auto out_ord = |
3224 | out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order(); |
3225 | // out_ord[0] == 0 or in_order[0] == 0 means the first dimension is |
3226 | // non-contiguous. in_vec can be greater than 0 only if both out_ord[0] and |
3227 | // and in_ord[0] are contiguous. |
3228 | int in_vec = out_ord[0] == 0 ? 1 |
3229 | : in_ord[0] == 0 ? 1 |
3230 | : in_layout->contig_per_thread(in_ord[0]); |
3231 | int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]); |
3232 | int pad = std::max(in_vec, out_vec); |
3233 | Value *in_ld = i32(shape[in_ord[0]] + pad); |
3234 | Value *out_ld = i32(shape[out_ord[0]] + pad); |
3235 | for(int i = 0; i < n_reps[0]; i++) |
3236 | for(int j = 0; j < n_reps[1]; j++){ |
3237 | int max_ii, max_jj; |
3238 | add_barrier(); |
3239 | max_ii = in_ax[0].size()/n_reps[0]; |
3240 | max_jj = in_ax[1].size()/n_reps[1]; |
3241 | for(int ii = 0; ii < max_ii; ii++) |
3242 | for(int jj = 0; jj < max_jj; jj+=in_vec){ |
3243 | // shared mem pointer |
3244 | indices_t offs = {in_ax[0][ii], in_ax[1][jj]}; |
3245 | Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); |
3246 | Value *ptr = gep(base, off); |
3247 | // stash value to shared mem |
3248 | Value* vals = UndefValue::get(vec_ty(ty, in_vec)); |
3249 | for(int jjj = 0; jjj < in_vec; jjj++){ |
3250 | indices_t idxs = {in_ax[0][i*max_ii + ii], |
3251 | in_ax[1][j*max_jj + jj + jjj]}; |
3252 | Value* val = bit_cast(vals_[in][idxs], ty); |
3253 | vals = insert_elt(vals, val, jjj); |
3254 | } |
3255 | ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace())); |
3256 | store(vals, ptr); |
3257 | } |
3258 | add_barrier(); |
3259 | max_ii = out_ax[0].size()/n_reps[0]; |
3260 | max_jj = out_ax[1].size()/n_reps[1]; |
3261 | for(int ii = 0; ii < max_ii; ii++) |
3262 | for(int jj = 0; jj < max_jj; jj+=out_vec){ |
3263 | // shared mem pointer |
3264 | indices_t offs = {out_ax[0][ii], out_ax[1][jj]}; |
3265 | Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); |
3266 | Value *ptr = gep(base, off); |
3267 | ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace())); |
3268 | // load value from shared rem |
3269 | Value* vals = load(ptr); |
3270 | for(int jjj = 0; jjj < out_vec; jjj++){ |
3271 | indices_t idxs = {out_ax[0][i*max_ii + ii], |
3272 | out_ax[1][j*max_jj + jj + jjj]}; |
3273 | vals_[out][idxs] = extract_elt(vals, jjj); |
3274 | } |
3275 | } |
3276 | |
3277 | } |
3278 | } |
3279 | |
3280 | void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) { |
3281 | visit_layout_convert(rc, rc->get_operand(0)); |
3282 | } |
3283 | |
3284 | void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ |
3285 | unsigned in_vec = 1; |
3286 | ir::value *arg = x->get_pointer_operand(); |
3287 | analysis::shared_layout* out_layout = layouts_->get(x)->to_shared(); |
3288 | analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); |
3289 | auto out_order = out_layout->get_order(); |
3290 | auto in_order = in_layout->get_order(); |
3291 | // tiles |
3292 | if(out_order == in_order) |
3293 | in_vec = in_layout->nts(in_order[0]); |
3294 | int out_vec = swizzle_->get_vec(out_layout); |
3295 | int min_vec = std::min<int>(out_vec, in_vec); |
3296 | int s = std::max<int>(out_vec / in_vec, 1); |
3297 | // |
3298 | int per_phase = swizzle_->get_per_phase(out_layout); |
3299 | int max_phase = swizzle_->get_max_phase(out_layout); |
3300 | // |
3301 | int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); |
3302 | int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1); |
3303 | int n_shared_0 = std::max<int>(in_vec / out_vec, 1); |
3304 | auto shapes = x->get_type()->get_block_shapes(); |
3305 | BasicBlock* CurrBB = builder_->GetInsertBlock(); |
3306 | BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); |
3307 | std::map<std::pair<int, int>, Value*> tmp; |
3308 | std::vector<std::pair<Value*, int>> shared; |
3309 | for(int i = 0; i < idxs_.at(arg).size(); i++){ |
3310 | unsigned id = i / min_vec; |
3311 | // input ptr info |
3312 | int id_0 = id % (in_ld/min_vec); |
3313 | int id_1 = id / (in_ld/min_vec); |
3314 | int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); |
3315 | int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); |
3316 | int off = (off_1*shapes[in_order[0]] + off_0); |
3317 | std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0}; |
3318 | if(tmp.find(key) == tmp.end()){ |
3319 | if(CurrBB != FirstBB) |
3320 | builder_->SetInsertPoint(FirstBB->getTerminator()); |
3321 | indices_t idx = idxs_.at(arg).at(key.first*in_ld); |
3322 | Value* phase = udiv(idx[in_order[1]], i32(per_phase)); |
3323 | phase = urem(phase, i32(max_phase)); |
3324 | Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); |
3325 | Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); |
3326 | off_0 = udiv(off_0, i32(min_vec)); |
3327 | off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); |
3328 | off_0 = mul(off_0 , i32(min_vec)); |
3329 | Value* off = add(off_0, off_1); |
3330 | if(CurrBB != FirstBB) |
3331 | builder_->SetInsertPoint(CurrBB); |
3332 | tmp[key] = gep(shmems_[x], {off}); |
3333 | } |
3334 | shared.push_back({tmp[key], off}); |
3335 | } |
3336 | size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; |
3337 | for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ |
3338 | auto idx = idxs_[arg][i]; |
3339 | // input ptr info |
3340 | Value *ptr = vals_[arg][idx]; |
3341 | size_t in_off = 0; |
3342 | GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]); |
3343 | if(in_gep){ |
3344 | ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin()); |
3345 | in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; |
3346 | ptr= cst ? in_gep->getPointerOperand() : in_gep; |
3347 | } |
3348 | // output ptr info |
3349 | Value* out_base = shared[i].first; |
3350 | int out_off = shared[i].second*dtsize; |
3351 | // asm |
3352 | std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca" ; |
3353 | // Value* false_value = vals_[x->get_false_value_operand()][idx]; |
3354 | // bool is_zero_false_value = false; |
3355 | // if(Constant* cst = dyn_cast<Constant>(false_value)) |
3356 | // is_zero_false_value = cst->isZeroValue(); |
3357 | Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); |
3358 | std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;" ; |
3359 | FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false); |
3360 | InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r" , true); |
3361 | call(iasm, {out_base, ptr, src_size}); |
3362 | } |
3363 | |
3364 | std::string asm_str = "cp.async.commit_group;" ; |
3365 | InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "" , true); |
3366 | call(iasm); |
3367 | } |
3368 | |
3369 | void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { |
3370 | unsigned in_vec = 1; |
3371 | ir::value *arg = cts->get_operand(0); |
3372 | analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); |
3373 | analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg)); |
3374 | auto out_order = out_layout->get_order(); |
3375 | auto in_order = in_layout->get_order(); |
3376 | // tiles |
3377 | if(out_order == in_order) |
3378 | in_vec = in_layout->contig_per_thread(in_order[0]); |
3379 | int out_vec = swizzle_->get_vec(out_layout); |
3380 | int min_vec = std::min<int>(out_vec, in_vec); |
3381 | int s = std::max<int>(out_vec / in_vec, 1); |
3382 | // |
3383 | int per_phase = swizzle_->get_per_phase(out_layout); |
3384 | int max_phase = swizzle_->get_max_phase(out_layout); |
3385 | // |
3386 | int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); |
3387 | int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); |
3388 | if(in_layout->to_mma()){ |
3389 | mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]); |
3390 | mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]); |
3391 | per_phase = 1; |
3392 | max_phase = 8; |
3393 | } |
3394 | |
3395 | int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; |
3396 | int n_shared_0 = std::max<int>(in_vec / out_vec, 1); |
3397 | int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1); |
3398 | if(in_layout->to_mma()){ |
3399 | n_shared_0 = 8; |
3400 | n_shared_1 = 1; |
3401 | } |
3402 | |
3403 | BasicBlock* CurrBB = builder_->GetInsertBlock(); |
3404 | BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); |
3405 | auto shapes = cts->get_type()->get_block_shapes(); |
3406 | |
3407 | |
3408 | // store to shared |
3409 | Value *current = nullptr; |
3410 | std::map<std::pair<int, int>, Value*> ptrs; |
3411 | for(int i = 0; i < idxs_.at(arg).size(); i++){ |
3412 | auto idx = idxs_[arg][i]; |
3413 | Value *in_value = vals_[arg][idx]; |
3414 | if(i % min_vec == 0) |
3415 | current = UndefValue::get(vec_ty(in_value->getType(), min_vec)); |
3416 | current = insert_elt(current, in_value, i % min_vec); |
3417 | if(i % min_vec == min_vec - 1){ |
3418 | unsigned id = i / min_vec; |
3419 | // input ptr info |
3420 | int id_0 = id % (in_ld/min_vec); |
3421 | int id_1 = id / (in_ld/min_vec); |
3422 | // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl; |
3423 | std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0}; |
3424 | if(ptrs.find(key) == ptrs.end()){ |
3425 | if(FirstBB->getTerminator()) |
3426 | builder_->SetInsertPoint(FirstBB->getTerminator()); |
3427 | else |
3428 | builder_->SetInsertPoint(FirstBB); |
3429 | indices_t idx = idxs_.at(arg).at(key.first*in_ld); |
3430 | Value* phase = udiv(idx[in_order[1]], i32(per_phase)); |
3431 | phase = urem(phase, i32(max_phase)); |
3432 | Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); |
3433 | Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); |
3434 | off_0 = udiv(off_0, i32(min_vec)); |
3435 | off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); |
3436 | off_0 = mul(off_0 , i32(min_vec)); |
3437 | Value* off = add(off_0, off_1); |
3438 | builder_->SetInsertPoint(CurrBB); |
3439 | ptrs[key] = gep(shmems_.at(cts), {off}); |
3440 | } |
3441 | int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; |
3442 | int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; |
3443 | if(in_layout->to_mma()){ |
3444 | off_0 = id_0/n_shared_0*n_shared_0*8; |
3445 | off_1 = id_1/n_shared_1*n_shared_1*8; |
3446 | } |
3447 | int off = (off_1*shapes[in_order[0]] + off_0); |
3448 | Value* ptr = gep(ptrs[key], {i32(off)}); |
3449 | ptr = bit_cast(ptr, current->getType()->getPointerTo(3)); |
3450 | // asm |
3451 | store(current, ptr); |
3452 | } |
3453 | }; |
3454 | } |
3455 | |
3456 | void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { |
3457 | throw std::runtime_error("TODO" ); |
3458 | } |
3459 | |
3460 | Instruction* generator::add_barrier() { |
3461 | Module *module = builder_->GetInsertBlock()->getModule(); |
3462 | return tgt_->add_barrier(module, *builder_); |
3463 | } |
3464 | |
3465 | void generator::visit_barrier_inst(ir::barrier_inst*) { |
3466 | add_barrier(); |
3467 | } |
3468 | |
3469 | void generator::visit_clock_inst(ir::clock_inst* clock){ |
3470 | InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;" , "=l" , true); |
3471 | vals_[clock][{}] = call(iasm); |
3472 | } |
3473 | |
3474 | void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){ |
3475 | InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;" , "=l" , true); |
3476 | vals_[timer][{}] = call(iasm); |
3477 | } |
3478 | |
3479 | |
3480 | |
3481 | void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) { |
3482 | ir::value *v = i->get_operand(0); |
3483 | int inc = i->get_inc(); |
3484 | if (inc == 0) { |
3485 | // If dot has not been visitied, do nothing. |
3486 | } else { |
3487 | // If dot has been visitied, insert prefetched lds |
3488 | assert(inc == 1); |
3489 | assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() && |
3490 | "dot hasn't be visited" ); |
3491 | // sink lds & extract element |
3492 | // move lds & all uses to current location |
3493 | std::stack<Value*> work_stack; |
3494 | for (Value *value : prefetch_latch_to_bb_[v]) |
3495 | work_stack.push(value); |
3496 | std::vector<Instruction*> dead_instrs; |
3497 | while (!work_stack.empty()) { |
3498 | Value *m = work_stack.top(); |
3499 | work_stack.pop(); |
3500 | |
3501 | for (auto u : m->users()) |
3502 | work_stack.push(u); |
3503 | |
3504 | assert(isa<Instruction>(m)); |
3505 | auto m_instr = static_cast<Instruction*>(m); |
3506 | |
3507 | m_instr->removeFromParent(); |
3508 | m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end())); |
3509 | assert(m_instr->getParent() == &*builder_->GetInsertBlock()); |
3510 | builder_->SetInsertPoint(m_instr->getParent()); |
3511 | } |
3512 | } |
3513 | } |
3514 | |
3515 | void generator::visit_async_wait_inst(ir::async_wait_inst* i) { |
3516 | std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";" ; |
3517 | InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "" , true); |
3518 | call(iasm); |
3519 | } |
3520 | |
3521 | /** |
3522 | * \brief Code Generation for `extern_elementwise` |
3523 | */ |
3524 | void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) { |
3525 | std::vector<Type *> operand_types; |
3526 | for (size_t j = 0; j < i->get_num_operands(); j++) { |
3527 | operand_types.push_back( |
3528 | cvt(i->get_operand(j)->get_type()->get_scalar_ty())); |
3529 | } |
3530 | Type *ret_type = cvt(i->get_type()->get_scalar_ty()); |
3531 | FunctionType *FT = |
3532 | FunctionType::get(ret_type, std::move(operand_types), false); |
3533 | Function *F = llvm::cast<llvm::Function>( |
3534 | mod_->getOrInsertFunction(i->get_symbol_name(), FT).getCallee()); |
3535 | for (auto idx : idxs_.at(i)) { |
3536 | std::vector<llvm::Value *> args; |
3537 | for (size_t j = 0; j < i->get_num_operands(); j++) { |
3538 | args.emplace_back(vals_[i->get_operand(j)][idx]); |
3539 | } |
3540 | vals_[i][idx] = call(F, std::move(args)); |
3541 | } |
3542 | add_extern_lib(i->get_lib_name(), i->get_lib_path()); |
3543 | } |
3544 | |
3545 | //void generator::visit_make_range_dyn(ir::make_range_dyn* x) { |
3546 | // for(indices_t idx: idxs_.at(x)){ |
3547 | // assert(idx.size() == 1); |
3548 | // if(idx[0] == i32(0)) |
3549 | // vals_[x][idx] = idx[0]; |
3550 | // else{ |
3551 | // BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]); |
3552 | // assert(bin_add); |
3553 | // vals_[x][idx] = bin_add->getOperand(0); |
3554 | // } |
3555 | // } |
3556 | //} |
3557 | |
3558 | //void generator::visit_make_range_sta(ir::make_range_sta* x) { |
3559 | // for(indices_t idx: idxs_.at(x)){ |
3560 | // assert(idx.size() == 1); |
3561 | // if(idx[0] == i32(0)){ |
3562 | // vals_[x][idx] = idx[0]; |
3563 | // } |
3564 | // else{ |
3565 | // BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]); |
3566 | // assert(bin_add); |
3567 | // Value *cst = bin_add->getOperand(1); |
3568 | // assert(isa<Constant>(cst)); |
3569 | // vals_[x][idx] = cst; |
3570 | // } |
3571 | // }; |
3572 | //} |
3573 | |
3574 | void generator::visit_make_range(ir::make_range* x) { |
3575 | for(indices_t idx: idxs_.at(x)){ |
3576 | Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value()); |
3577 | vals_[x][idx] = add(start, idx[0]); |
3578 | } |
3579 | } |
3580 | |
3581 | void generator::visit_undef_value(ir::undef_value *x) { |
3582 | ir::type* sca_ty = x->get_type()->get_scalar_ty(); |
3583 | Type* ty = cvt(sca_ty); |
3584 | for(indices_t idx: idxs_.at(x)) |
3585 | vals_[x][idx] = llvm::UndefValue::get(ty); |
3586 | } |
3587 | |
3588 | void generator::visit_constant_int(ir::constant_int *x){ |
3589 | Type *ty = cvt(x->get_type()->get_scalar_ty()); |
3590 | for(indices_t idx: idxs_.at(x)) |
3591 | vals_[x][idx] = ConstantInt::get(ty, x->get_value()); |
3592 | } |
3593 | |
3594 | void generator::visit_constant_fp(ir::constant_fp *x){ |
3595 | Type *ty = cvt(x->get_type()->get_scalar_ty()); |
3596 | for(indices_t idx: idxs_.at(x)) { |
3597 | // manually select bf16 constant |
3598 | if (x->get_type()->get_scalar_ty()->is_bf16_ty()) { |
3599 | // highest 16 bits of fp32 |
3600 | float fp32_value = x->get_value(); |
3601 | uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value) |
3602 | & 0xffff0000) >> 16; |
3603 | std::stringstream const_str; |
3604 | const_str << "0x" << std::hex << bf16_raw << "U" ; // unsigned |
3605 | InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false), |
3606 | " mov.b16 $0, " + const_str.str() + ";" , |
3607 | "=h" , false); |
3608 | vals_[x][idx] = builder_->CreateCall(bf16_const, {}); |
3609 | } else |
3610 | vals_[x][idx] = ConstantFP::get(ty, x->get_value()); |
3611 | } |
3612 | } |
3613 | |
3614 | void generator::visit_alloc_const(ir::alloc_const *alloc) { |
3615 | unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); |
3616 | Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty()); |
3617 | Type *array_ty = llvm::ArrayType::get(element_ty, size); |
3618 | Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage, |
3619 | nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); |
3620 | vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4)); |
3621 | } |
3622 | |
3623 | |
3624 | void generator::forward_declare(ir::function* fn){ |
3625 | FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type()); |
3626 | if(!tgt_->is_gpu()){ |
3627 | Type *fn_ret_ty = fn_ty->getReturnType(); |
3628 | std::vector<Type*> fn_args_ty; |
3629 | for(unsigned i = 0; i < fn_ty->getNumParams(); i++) |
3630 | fn_args_ty.push_back(fn_ty->getParamType(i)); |
3631 | fn_args_ty.push_back(i32_ty); |
3632 | fn_args_ty.push_back(i32_ty); |
3633 | fn_args_ty.push_back(i32_ty); |
3634 | fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); |
3635 | } |
3636 | Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); |
3637 | fns_[fn] = ret; |
3638 | } |
3639 | |
3640 | Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout, |
3641 | Type *ty) { |
3642 | unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); |
3643 | Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space)); |
3644 | return base; |
3645 | } |
3646 | |
3647 | void generator::visit_function(ir::function* fn) { |
3648 | idxs_.clear(); |
3649 | vals_.clear(); |
3650 | seen_.clear(); |
3651 | LLVMContext &ctx = builder_->getContext(); |
3652 | |
3653 | Function* ret = fns_[fn]; |
3654 | |
3655 | |
3656 | // set attributes |
3657 | for(auto attr_pair: fn->attrs()){ |
3658 | unsigned id = attr_pair.first; |
3659 | for(ir::attribute attr: attr_pair.second) |
3660 | if(attr.is_llvm_attr()){ |
3661 | llvm::Attribute llattr = cvt(attr); |
3662 | if(llattr.getKindAsEnum() != llvm::Attribute::None) |
3663 | ret->addAttribute(id, cvt(attr)); |
3664 | } |
3665 | } |
3666 | // set metadata |
3667 | if(tgt_->is_gpu()){ |
3668 | tgt_->set_kernel(*builder_, ctx, mod_, ret); |
3669 | Metadata *md_args[] = { |
3670 | ValueAsMetadata::get(ret), |
3671 | MDString::get(ctx, "maxntidx" ), |
3672 | ValueAsMetadata::get(i32(num_warps_*32)) |
3673 | }; |
3674 | mod_->getOrInsertNamedMetadata("nvvm.annotations" )->addOperand(MDNode::get(ctx, md_args)); |
3675 | } |
3676 | // set arguments |
3677 | for(unsigned i = 0; i < fn->args().size(); i++) |
3678 | vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); |
3679 | // create blocks |
3680 | auto blocks = ir::cfg::reverse_post_order(fn); |
3681 | for(ir::basic_block *block: blocks) { |
3682 | BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); |
3683 | bbs_[block] = dst_block; |
3684 | } |
3685 | builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); |
3686 | // create policies |
3687 | if(tgt_->as_nvidia()->sm() >= 80) |
3688 | for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ |
3689 | std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last" ; |
3690 | std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;" ; |
3691 | InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l" , false); |
3692 | policies_[evict] = call(iasm); |
3693 | } |
3694 | // initialize layouts |
3695 | for(auto x: layouts_->get_all()){ |
3696 | visit_layout(x.second); |
3697 | } |
3698 | // generate LLVM-IR code |
3699 | for(ir::basic_block *block: blocks) |
3700 | visit_basic_block(block); |
3701 | // finalize |
3702 | finalize_function(fn); |
3703 | } |
3704 | |
3705 | |
3706 | |
3707 | void generator::visit_layout_mma(analysis::mma_layout* layout) { |
3708 | ir::value *a = nullptr; |
3709 | ir::value *b = nullptr; |
3710 | for(ir::value* v: layout->get_values()) |
3711 | if(ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(v)){ |
3712 | a = dot->get_operand(0); |
3713 | b = dot->get_operand(1); |
3714 | } |
3715 | analysis::data_layout* layout_a = layouts_->get(a); |
3716 | analysis::data_layout* layout_b = layouts_->get(b); |
3717 | |
3718 | const auto& shape = layout->get_shape(); |
3719 | Value *_1 = i32(1); |
3720 | Value *_2 = i32(2); |
3721 | Value *_3 = i32(3); |
3722 | Value *_4 = i32(4); |
3723 | Value *_8 = i32(8); |
3724 | Value *_16 = i32(16); |
3725 | Value *_32 = i32(32); |
3726 | int cc = tgt_->as_nvidia()->sm(); |
3727 | std::vector<Value*> idx_m; |
3728 | std::vector<Value*> idx_n; |
3729 | std::vector<Value*> idx_z; |
3730 | // |
3731 | Value* thread = tgt_->get_local_id(mod_, *builder_, 0); |
3732 | Value *lane = urem(thread, _32); |
3733 | Value *warp = udiv(thread, _32); |
3734 | /* lane offset */ |
3735 | if(cc < 80){ |
3736 | auto ord_a = layout_a->get_order(); |
3737 | auto ord_b = layout_b->get_order(); |
3738 | bool is_a_row = ord_a[0] != 0; |
3739 | bool is_b_row = ord_b[0] != 0; |
3740 | /* warp offset */ |
3741 | Value *warp_0 = urem(warp, i32(layout->wpt(0))); |
3742 | Value *warp_12 = udiv(warp, i32(layout->wpt(0))); |
3743 | Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); |
3744 | Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); |
3745 | Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); |
3746 | // Quad offset |
3747 | Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0))); |
3748 | Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1))); |
3749 | // Pair offset |
3750 | Value *off_pair_m = udiv(urem(lane, _16), _4); |
3751 | off_pair_m = urem(off_pair_m, i32(layout->fpw(0))); |
3752 | off_pair_m = mul(off_pair_m, i32(4)); |
3753 | Value *off_pair_n = udiv(urem(lane, _16), _4); |
3754 | off_pair_n = udiv(off_pair_n, i32(layout->fpw(0))); |
3755 | off_pair_n = urem(off_pair_n, i32(layout->fpw(1))); |
3756 | off_pair_n = mul(off_pair_n, i32(4)); |
3757 | // scale |
3758 | off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2)); |
3759 | off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2)); |
3760 | off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2)); |
3761 | off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2)); |
3762 | // Quad pair offset |
3763 | Value *off_lane_m = add(off_pair_m, off_quad_m); |
3764 | Value *off_lane_n = add(off_pair_n, off_quad_n); |
3765 | // a offset |
3766 | offset_a_m_[layout] = add(off_warp_m, off_lane_m); |
3767 | offset_a_k_[layout] = and_(lane, _3); |
3768 | // b offsets |
3769 | offset_b_n_[layout] = add(off_warp_n, off_lane_n); |
3770 | offset_b_k_[layout] = and_(lane, _3); |
3771 | // i indices |
3772 | Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); |
3773 | for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)) |
3774 | for(unsigned mm = 0; mm < layout->rep(0); mm++) |
3775 | idx_m.push_back(add(offset_c_m, i32(m + mm*2))); |
3776 | // j indices |
3777 | Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); |
3778 | for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)) |
3779 | for(unsigned nn = 0; nn < layout->rep(1); nn++){ |
3780 | idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); |
3781 | idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); |
3782 | } |
3783 | if(is_a_row){ |
3784 | offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4))); |
3785 | offset_a_k_[layout] = i32(0); |
3786 | } |
3787 | if(!is_b_row){ |
3788 | offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4))); |
3789 | offset_b_k_[layout] = i32(0); |
3790 | } |
3791 | /* axes */ |
3792 | axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; |
3793 | axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; |
3794 | } |
3795 | else{ |
3796 | /* warp offset */ |
3797 | Value *warp_0 = urem(warp, i32(layout->wpt(0))); |
3798 | Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); |
3799 | Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); |
3800 | Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); |
3801 | Value *off_lane_m = urem(lane, _16); |
3802 | Value *off_lane_n = urem(lane, _8); |
3803 | /* offsets */ |
3804 | // a offset |
3805 | offset_a_m_[layout] = add(off_warp_m, off_lane_m); |
3806 | offset_a_k_[layout] = i32(0); |
3807 | // b offsets |
3808 | offset_b_n_[layout] = add(off_warp_n, off_lane_n); |
3809 | offset_b_k_[layout] = i32(0); |
3810 | // c offset |
3811 | Value *off_c_m = add(udiv(lane, _4), off_warp_m); |
3812 | Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); |
3813 | for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){ |
3814 | idx_m.push_back(add(off_c_m, i32(m))); |
3815 | idx_m.push_back(add(off_c_m, i32(m + 8))); |
3816 | } |
3817 | for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){ |
3818 | idx_n.push_back(add(off_c_n, i32(n))); |
3819 | idx_n.push_back(add(off_c_n, i32(n + 1))); |
3820 | } |
3821 | /* axes */ |
3822 | axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; |
3823 | axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; |
3824 | } |
3825 | } |
3826 | |
3827 | void generator::visit_layout_scanline(analysis::scanline_layout* layout) { |
3828 | Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0); |
3829 | auto order = layout->get_order(); |
3830 | const auto& shape = layout->get_shape(); |
3831 | // Delinearize |
3832 | size_t dim = shape.size(); |
3833 | std::vector<Value*> thread_ids(dim); |
3834 | for(unsigned k = 0; k < dim - 1; k++){ |
3835 | Constant *dim_k = i32(layout->mts(order[k])); |
3836 | Value *rem = urem(thread_id, dim_k); |
3837 | thread_id = udiv(thread_id, dim_k); |
3838 | thread_ids[order[k]] = rem; |
3839 | } |
3840 | Constant *dim_k = i32(layout->mts(order[dim - 1])); |
3841 | thread_ids[order[dim - 1]] = urem(thread_id, dim_k); |
3842 | |
3843 | // Create axes |
3844 | for(unsigned k = 0; k < dim; k++) { |
3845 | int nts = layout->nts(k); |
3846 | int mts = layout->mts(k); |
3847 | std::string str_k = std::to_string(k); |
3848 | Value *contiguous_k = i32(nts); |
3849 | Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k); |
3850 | unsigned per_cta = layout->shape_per_cta(k); |
3851 | unsigned per_thread = nts * shape[k] / per_cta; |
3852 | std::vector<Value*> idx_list(per_thread); |
3853 | for(unsigned n = 0 ; n < per_thread; n++){ |
3854 | unsigned offset = n / nts * per_cta + n % nts; |
3855 | idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); |
3856 | } |
3857 | axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]}; |
3858 | } |
3859 | } |
3860 | |
3861 | void generator::visit_layout_shared(analysis::shared_layout* layout) { |
3862 | Type* ty = cvt(layout->get_type()); |
3863 | PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace()); |
3864 | if (layout->get_N_buffer()) { |
3865 | // create pointers |
3866 | shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); |
3867 | shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty); |
3868 | |
3869 | BasicBlock *current = builder_->GetInsertBlock(); |
3870 | |
3871 | auto info = *layout->get_N_buffer(); |
3872 | ir::phi_node *phi = info.phi; |
3873 | BasicBlock *parent = bbs_.at(phi->get_parent()); |
3874 | if(parent->empty()) |
3875 | builder_->SetInsertPoint(parent); |
3876 | else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) { |
3877 | builder_->SetInsertPoint(&*parent->getFirstNonPHI()); |
3878 | } else |
3879 | builder_->SetInsertPoint(parent); |
3880 | |
3881 | // create smem_idx |
3882 | read_smem_idx_[layout] = phi(i32_ty, 2); |
3883 | write_smem_idx_[layout] = phi(i32_ty, 2); |
3884 | |
3885 | // create pointers |
3886 | // ptr of the current iteration |
3887 | shared_ptr_[layout] = phi(ptr_ty, 2); |
3888 | // ptr of the next iteration |
3889 | shared_next_ptr_[layout] = phi(ptr_ty, 2); |
3890 | |
3891 | builder_->SetInsertPoint(current); |
3892 | } else if(layout->get_double_buffer()) { |
3893 | BasicBlock *current = builder_->GetInsertBlock(); |
3894 | auto info = *layout->get_double_buffer(); |
3895 | ir::phi_node *phi = info.phi; |
3896 | BasicBlock *parent = bbs_.at(phi->get_parent()); |
3897 | if(parent->empty()) |
3898 | builder_->SetInsertPoint(parent); |
3899 | else |
3900 | builder_->SetInsertPoint(&*parent->getFirstNonPHI()); |
3901 | // create pointers |
3902 | shared_ptr_[layout] = phi(ptr_ty, 2); |
3903 | shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); |
3904 | shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType()); |
3905 | shared_off_[layout] = phi(i32_ty, 2); |
3906 | shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr" ); |
3907 | builder_->SetInsertPoint(current); |
3908 | } else{ |
3909 | size_t offset = alloc_->offset(layout); |
3910 | shared_ptr_[layout] = gep(shmem_, i32(offset)); |
3911 | shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty); |
3912 | } |
3913 | } |
3914 | |
3915 | void generator::visit_basic_block(ir::basic_block * block) { |
3916 | |
3917 | BasicBlock *parent = bbs_[block]; |
3918 | builder_->SetInsertPoint(parent); |
3919 | for(ir::instruction *i: block->get_inst_list()){ |
3920 | visit_value(i); |
3921 | // std::cout << "done" << std::endl; |
3922 | } |
3923 | // Update ir bb -> llvm bb mapping |
3924 | bbs_[block] = builder_->GetInsertBlock(); |
3925 | } |
3926 | |
3927 | void generator::visit_argument(ir::argument* arg) { |
3928 | |
3929 | } |
3930 | |
3931 | void generator::init_idx(ir::value *v) { |
3932 | idxs_[v].clear(); |
3933 | if(!v->get_type()->is_block_ty()){ |
3934 | idxs_[v].push_back({}); |
3935 | return; |
3936 | } |
3937 | if(layouts_->get(v)->to_shared()) |
3938 | return; |
3939 | const auto &shapes = v->get_type()->get_block_shapes(); |
3940 | size_t rank = shapes.size(); |
3941 | std::vector<distributed_axis> axes(rank); |
3942 | std::vector<int> ord(rank); |
3943 | // compute axes |
3944 | // std::cout << "axes" << std::endl; |
3945 | for(size_t d = 0; d < shapes.size(); d++){ |
3946 | // std::cout << d << " " << shapes[d] << std::endl; |
3947 | // std::cout << a_axes_->get(v, d) << std::endl; |
3948 | if(shapes[d] > 1){ |
3949 | unsigned x = a_axes_->get(v, d); |
3950 | axes[d] = axes_.at(x); |
3951 | } |
3952 | else{ |
3953 | axes[d].contiguous = 1; |
3954 | axes[d].values = {i32(0)}; |
3955 | } |
3956 | } |
3957 | // std::cout << "axes ok" << std::endl; |
3958 | // compute order |
3959 | analysis::data_layout* layout = layouts_->get(v); |
3960 | std::iota(ord.begin(), ord.end(), 0); |
3961 | auto cmp = [&](int x, int y) { |
3962 | unsigned axx = a_axes_->get(v, x); |
3963 | unsigned axy = a_axes_->get(v, y); |
3964 | size_t posx = layout->find_axis(axx); |
3965 | size_t posy = layout->find_axis(axy); |
3966 | if(posx < rank && posy < rank) |
3967 | return layout->get_order(posx) < layout->get_order(posy); |
3968 | return false; |
3969 | }; |
3970 | std::sort(ord.begin(), ord.end(), cmp); |
3971 | ords_[v] = ord; |
3972 | // indices |
3973 | if(axes.size() == 1) |
3974 | for(Value* x0: axes[ord[0]].values){ |
3975 | idxs_[v].push_back({x0}); |
3976 | } |
3977 | if(axes.size() == 2) |
3978 | for(Value* x1: axes[ord[1]].values) |
3979 | for(Value* x0: axes[ord[0]].values){ |
3980 | indices_t idx(2); |
3981 | idx[ord[0]] = x0; |
3982 | idx[ord[1]] = x1; |
3983 | idxs_[v].push_back(idx); |
3984 | } |
3985 | if(axes.size() == 3) |
3986 | for(Value* x2: axes[ord[2]].values) |
3987 | for(Value* x1: axes[ord[1]].values) |
3988 | for(Value* x0: axes[ord[0]].values){ |
3989 | indices_t idx(3); |
3990 | idx[ord[0]] = x0; |
3991 | idx[ord[1]] = x1; |
3992 | idx[ord[2]] = x2; |
3993 | idxs_[v].push_back(idx); |
3994 | } |
3995 | } |
3996 | |
3997 | void generator::finalize_shared_layout(analysis::shared_layout *shared) { |
3998 | if (auto n_buffer = shared->get_N_buffer()) { |
3999 | // if (*_smem_idx == #stages-1) { |
4000 | // *_smem_idx = 0; |
4001 | // } else *_smem_idx++; |
4002 | auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) { |
4003 | // insert point |
4004 | Value *idx = smem_idx[shared]; |
4005 | builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator()); |
4006 | Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1)); |
4007 | PHINode *_ret = phi(i32_ty, 2); |
4008 | Instruction *then_term = nullptr; |
4009 | Instruction *else_term = nullptr; |
4010 | Instruction *dummy = builder_->CreateRet(nullptr); |
4011 | llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr); |
4012 | dummy->removeFromParent(); |
4013 | builder_->SetInsertPoint(then_term); |
4014 | Value *zero_smem_idx = i32(0); |
4015 | builder_->SetInsertPoint(else_term); |
4016 | Value *inc_smem_idx = add(idx, i32(1)); |
4017 | builder_->SetInsertPoint(_ret->getParent()); |
4018 | _ret->addIncoming(zero_smem_idx, then_term->getParent()); |
4019 | _ret->addIncoming(inc_smem_idx, else_term->getParent()); |
4020 | // update ir::bb -> llvm::bb mapping |
4021 | bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock(); |
4022 | // idx = init_stage; |
4023 | // loop: ... |
4024 | if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) { |
4025 | idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0))); |
4026 | idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1))); |
4027 | } else |
4028 | throw std::runtime_error("Should be PHINode" ); |
4029 | }; |
4030 | |
4031 | // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2 |
4032 | finalize_smem_idx(read_smem_idx_, 2); |
4033 | finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1); |
4034 | |
4035 | // finalize pointers |
4036 | ir::phi_node *pn = n_buffer->phi; |
4037 | BasicBlock * = bbs_.at(pn->get_incoming_block(0)); |
4038 | BasicBlock *loop = bbs_.at(pn->get_incoming_block(1)); |
4039 | // %curr_ptr = phi %shared_pre_ptr, %next_ptr |
4040 | // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size)) |
4041 | if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) { |
4042 | curr_ptr->addIncoming(shared_pre_ptr_[shared], header); |
4043 | curr_ptr->addIncoming(shared_next_ptr_[shared], loop); |
4044 | } else |
4045 | throw std::runtime_error("Should be PHINode" ); |
4046 | |
4047 | BasicBlock *current = builder_->GetInsertBlock(); |
4048 | builder_->SetInsertPoint(header->getTerminator()); |
4049 | Value * = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements())); |
4050 | builder_->SetInsertPoint(current->getTerminator()); |
4051 | |
4052 | assert(isa<PHINode>(shared_next_ptr_[shared])); |
4053 | static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header); |
4054 | |
4055 | Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements())); |
4056 | Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset); |
4057 | static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop); |
4058 | } else if(shared->get_double_buffer()) { |
4059 | auto info = *shared->get_double_buffer(); |
4060 | ir::phi_node *phi = info.phi; |
4061 | PHINode *ptr = (PHINode*)shmems_[phi]; |
4062 | PHINode *offset = (PHINode*)shoffs_[phi]; |
4063 | for(unsigned n = 0; n < phi->get_num_incoming(); n++){ |
4064 | ir::basic_block* inc_block = phi->get_incoming_block(n); |
4065 | ir::value* inc_val = phi->get_incoming_value(n); |
4066 | BasicBlock *llvm_inc_block = bbs_.at(inc_block); |
4067 | if(inc_val == info.latch){ |
4068 | builder_->SetInsertPoint(llvm_inc_block->getTerminator()); |
4069 | Value *next_offset = neg(offset); |
4070 | offset->addIncoming(next_offset, llvm_inc_block); |
4071 | } |
4072 | else { |
4073 | unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8; |
4074 | offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block); |
4075 | } |
4076 | ptr->addIncoming(shmems_[inc_val], llvm_inc_block); |
4077 | } |
4078 | } |
4079 | } |
4080 | |
4081 | void generator::finalize_function(ir::function *fn) { |
4082 | // finalize double-buffering |
4083 | for(const auto& x: layouts_->get_all()) |
4084 | if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second)) |
4085 | finalize_shared_layout(shared); |
4086 | // finalize phi |
4087 | for(ir::basic_block *block: fn->blocks()) |
4088 | for(ir::instruction *inst: block->get_inst_list()) |
4089 | if(auto *phi = dynamic_cast<ir::phi_node*>(inst)) |
4090 | finalize_phi_node(phi); |
4091 | for(auto& x: lazy_phi_incs_) |
4092 | std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]); |
4093 | } |
4094 | |
4095 | void generator::finalize_phi_node(ir::phi_node *x) { |
4096 | if(shmems_.find(x) != shmems_.end()) |
4097 | return; |
4098 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
4099 | ir::basic_block *_block = x->get_incoming_block(n); |
4100 | BasicBlock *block = bbs_.at(_block); |
4101 | for(indices_t idx: idxs_.at(x)){ |
4102 | PHINode *phi = (PHINode*)vals_[x][idx]; |
4103 | Value *inc = vals_[x->get_incoming_value(n)][idx]; |
4104 | // x->print(std::cout); |
4105 | phi->addIncoming(inc, block); |
4106 | } |
4107 | } |
4108 | } |
4109 | |
4110 | void generator::packed_type(ir::value* i){ |
4111 | Type* dtype = cvt(i->get_type()->get_tile_element_ty()); |
4112 | auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i)); |
4113 | assert(layout); |
4114 | } |
4115 | |
4116 | void generator::visit(ir::module &src, llvm::Module &dst) { |
4117 | mod_ = &dst; |
4118 | ctx_ = &dst.getContext(); |
4119 | builder_ = new Builder(*ctx_); |
4120 | // allocate shared memory |
4121 | if(tgt_->is_gpu()) |
4122 | if(unsigned alloc_size = alloc_->allocated_size()){ |
4123 | Type *int_8_ty = Type::getInt8Ty(*ctx_); |
4124 | Type *int_32_ty = Type::getInt32Ty(*ctx_); |
4125 | ArrayType *array_ty = ArrayType::get(int_32_ty, 0); |
4126 | Type *ptr_ty = ptr_ty(int_8_ty, 3); |
4127 | GlobalVariable *sh_mem_array = |
4128 | new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, |
4129 | nullptr, "__shared_ptr" , nullptr, GlobalVariable::NotThreadLocal, 3); |
4130 | shmem_ = bit_cast(sh_mem_array, ptr_ty); |
4131 | } |
4132 | // instantiate device functions |
4133 | // for(ir::function *fn: src.get_function_list()) |
4134 | // for(ir::basic_block *bb: fn->blocks()) |
4135 | // for(ir::instruction *i: bb->get_inst_list()) |
4136 | // if(auto *call = dynamic_cast<ir::call_inst*>(i)){ |
4137 | // std::cout << "call??" << std::endl; |
4138 | // } |
4139 | // visit functions |
4140 | for(ir::function *fn: src.get_function_list()) |
4141 | forward_declare(fn); |
4142 | for(ir::function *fn: src.get_function_list()) |
4143 | visit_function(fn); |
4144 | } |
4145 | |
4146 | void generator::add_extern_lib(const std::string &lib_name, |
4147 | const std::string &lib_path) { |
4148 | if (extern_lib_map_.count(lib_name) == 0) { |
4149 | extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path); |
4150 | } else if (extern_lib_map_.at(lib_name)->path() != lib_path) { |
4151 | throw std::runtime_error("A library has multiple paths (1) " + lib_path + |
4152 | " (2) " + extern_lib_map_.at(lib_name)->path()); |
4153 | } |
4154 | } |
4155 | |
4156 | } // namespace codegen |
4157 | } // namespace triton |
4158 | |