#include <numeric>
#include <sstream>
#include <iomanip>
#include <stdexcept>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

namespace triton{
namespace codegen{

using namespace llvm;
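
// Small peephole helpers used throughout code generation. They reassociate
// expressions so that constant offsets bubble up to the outermost position,
// e.g. (x + 4) + y becomes (x + y) + 4, which later folds into the constant
// offset of a GEP. This is a best-effort canonicalization: when no pattern
// matches, the helpers fall back to a plain Add/Mul/GEP.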
30Value* adder::operator()(Value *x, Value *y, const std::string& name) {
31 // (x + cst) + y -> (x + y) + cst
32 if(auto* bin = dyn_cast<BinaryOperator>(x))
33 if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
34 if(dyn_cast<Constant>(bin->getOperand(1))){
35 return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
36 bin->getOperand(1));
37 }
38 // (x + (y + cst)) -> (x + y) + cst
39 if(auto* bin = dyn_cast<BinaryOperator>(y))
40 if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
41 if(dyn_cast<Constant>(bin->getOperand(1))){
42 return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
43 bin->getOperand(1));
44 }
45
46 // default
47 return (*builder_)->CreateAdd(x, y, name);
48}
49
50Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
51 // (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
52 if(auto* bin = dyn_cast<BinaryOperator>(x))
53 if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
54 if(dyn_cast<Constant>(bin->getOperand(1)))
55 if(dyn_cast<Constant>(y)){
56 return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
57 (*builder_)->CreateMul(bin->getOperand(1), y));
58 }
59 // default
60 return (*builder_)->CreateMul(x, y, name);
61}
62
63Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
64 // (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
65 if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
66 if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
67 if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
68 return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(),
69 gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2));
70 }
71 // ptr + (off + cst) -> (ptr + off) + cst
72 if(auto* bin = dyn_cast<BinaryOperator>(off))
73 if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
74 if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
75 Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
76 ptr, bin->getOperand(0));
77 return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(),
78 gep, bin->getOperand(1));
79 }
80 // default
81 return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
82 ptr, off, name);
83}
84
85//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
86// return (*builder_)->CreateGEP(ty, ptr, vals, name);
87//}
88
89// types
90#define void_ty builder_->getVoidTy()
91#define f16_ty builder_->getHalfTy()
92#define bf16_ty builder_->getInt16Ty()
93#define f32_ty builder_->getFloatTy()
94#define i1_ty builder_->getInt1Ty()
95#define i8_ty builder_->getInt8Ty()
96#define i16_ty builder_->getInt16Ty()
97#define i32_ty builder_->getInt32Ty()
98#define i64_ty builder_->getInt64Ty()
99#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
100#define ptr_ty(...) PointerType::get(__VA_ARGS__)
101// constants
102#define i16(...) builder_->getInt16(__VA_ARGS__)
103#define i32(...) builder_->getInt32(__VA_ARGS__)
104// ops
105#define and_(...) builder_->CreateAnd(__VA_ARGS__)
106#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
107#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
108#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
109#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
110#define br(...) builder_->CreateBr(__VA_ARGS__)
111#define call(...) builder_->CreateCall(__VA_ARGS__)
112#define cast(...) builder_->CreateCast(__VA_ARGS__)
113#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
114#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
115#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
116#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
117#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
118#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
119#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
120#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
121#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
122#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
123#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
124#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
125#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
126#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
127#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
128#define icmp_uge(...) builder_->CreateICmpUGE(__VA_ARGS__)
129#define icmp_ule(...) builder_->CreateICmpULE(__VA_ARGS__)
130#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
131#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
132#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
133#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
134#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
135#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
136#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
137#define neg(...) builder_->CreateNeg(__VA_ARGS__)
138#define phi(...) builder_->CreatePHI(__VA_ARGS__)
139#define ret(...) builder_->CreateRet(__VA_ARGS__)
140#define select(...) builder_->CreateSelect(__VA_ARGS__)
141#define store(...) builder_->CreateStore(__VA_ARGS__)
142#define sub(...) builder_->CreateSub(__VA_ARGS__)
143#define shl(...) builder_->CreateShl(__VA_ARGS__)
144#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
145#define urem(...) builder_->CreateURem(__VA_ARGS__)
146#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
147#define xor_(...) builder_->CreateXor(__VA_ARGS__)
148
149/**
150 * \brief Convert Triton-IR Type to LLVM-IR Type
151 */
152Type *generator::cvt(ir::type *ty) {
153 // struct
154 if(ty->is_struct_ty()){
155 std::vector<Type*> tys;
156 for(size_t i = 0; i < ty->get_struct_numel(); i++)
157 tys.push_back(cvt(ty->get_struct_type(i)));
158 return StructType::get(builder_->getContext(), tys, true);
159 }
160
161 // function
162 if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
163 Type *ret_ty = cvt(tt->get_return_ty());
164 std::vector<Type*> arg_tys(tt->get_num_params());
165 for(size_t i = 0; i < arg_tys.size(); i++)
166 arg_tys[i] = cvt(tt->get_param_ty(i));
167 return FunctionType::get(ret_ty, arg_tys, false);
168 }
169 // pointer
170 if(ty->is_pointer_ty()){
171 Type *elt_ty = cvt(ty->get_pointer_element_ty());
172 unsigned addr_space = ty->get_pointer_address_space();
173 return ptr_ty(elt_ty, addr_space);
174 }
175 // integer
176 if(ty->is_integer_ty()){
177 unsigned bitwidth = ty->get_integer_bitwidth();
178 return IntegerType::get(*ctx_, bitwidth);
179 }
180 // primitive types
181 switch(ty->get_type_id()){
182 case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
183 case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
184 case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
185 case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
186 case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
187 case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
188 case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
189 case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
190 case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
191 default: break;
192 }
193 // unknown type
194 throw std::runtime_error("unknown conversion from ir::type to Type");
195}
196
197/**
198 * \brief Convert Triton-IR Attribute to LLVM-IR Attribute
199 */
200llvm::Attribute generator::cvt(ir::attribute attr) {
201 switch(attr.get_kind()){
202 case ir::noalias: return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias);
203 case ir::readonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly);
204 case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly);
205 case ir::aligned: return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value());
206 case ir::retune: return llvm::Attribute::get(*ctx_, llvm::Attribute::None);
207 default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
208 }
209}
210
211/**
212 * \brief Constructor of LLVM code generator
213 */
214generator::generator(analysis::axes *a_axes,
215 analysis::layouts *layouts,
216 analysis::align *alignment,
217 analysis::allocation *alloc,
218 analysis::swizzle *swizzle,
219 target *tgt,
220 unsigned num_warps)
221 : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
222 tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
223
224}
225
226/**
227 * \brief Code Generation for `value`
228 */
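// For values materialized in shared memory, decide which base pointer later
// visitors (visit_dot, visit_cts, ...) will see through shmems_:
//  - N-buffered layouts: the phi reads from the rotating base pointer; each
//    "first" value writes at a constant stage_idx * per_stage_elements offset
//    from the pre-allocated base, and the latch writes at the runtime
//    write_smem_idx * per_stage_elements offset.
//  - double-buffered layouts: the phi carries an offset (shoffs_), while the
//    latch and the first value use the "next" and "pre" pointers respectively.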
229void generator::visit_value(ir::value* v) {
230 if(!seen_.insert(v).second)
231 return;
232 if(v->get_type()->is_block_ty()){
233 if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
234 analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
235 analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
236
237 // offset
238 Value *offset = nullptr;
239 // base pointer
240 Value *ptr = shared_ptr_[layout];
241
242 if (n_buffer) {
243 // ptr = base (shared_ptr_[layout]) + smem_idx * size
244 // read_smem_idx
245 if (v == n_buffer->phi) {
246 ptr = shared_ptr_[layout];
247 }
248 // write_smem_idx
249 if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
250 int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
251 int elements = write_smem_idx * layout->get_per_stage_elements();
252 ptr = gep(shared_pre_ptr_[layout], i32(elements));
253 } else if (v == n_buffer->latch) {
254 Value* write_smem_idx = write_smem_idx_[layout];
255 Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
256 ptr = gep(shared_pre_ptr_[layout], elements);
257 }
258 } else if (double_buffer) {
259 if(v == double_buffer->phi)
260 offset = shared_off_[layout];
261 if(v == double_buffer->latch)
262 ptr = shared_next_ptr_[layout];
263 else if(v == double_buffer->first)
264 ptr = shared_pre_ptr_[layout];
265 } // else do nothing
      // what visit_dot & visit_cts & ... see
      shmems_[v] = ptr;
      // currently only latches have an offset (PHINode); it is only used by finalize_share_layout()
      shoffs_[v] = offset;
270 }
271 }
272 // visit operands
273 BasicBlock *current = builder_->GetInsertBlock();
274 auto *inst = dynamic_cast<ir::instruction*>(v);
275 if(inst)
276 for(ir::value *op: inst->ops()){
277 if(dynamic_cast<ir::constant*>(op) || !dynamic_cast<ir::phi_node*>(v))
278 visit_value(op);
279 }
280 init_idx(v);
281 // change insert point for phi node
282 builder_->SetInsertPoint(current);
283 auto *phi = dynamic_cast<ir::phi_node*>(v);
284 if(phi && !current->empty() && current->getFirstNonPHI())
285 builder_->SetInsertPoint(&*current->getFirstNonPHI());
286 // visit user
287 if(auto *usr = dynamic_cast<ir::user*>(v)){
288 if(!dynamic_cast<ir::function*>(usr))
289 usr->accept(this);
290 }
291 // revert insert point
292 if(phi && !current->empty() && current->getFirstNonPHI())
293 builder_->SetInsertPoint(current);
294}
295
296/**
297 * \brief Code Generation for `phi`
298 */
299void generator::visit_phi_node(ir::phi_node* x) {
300 Type *ty = cvt(x->get_type()->get_scalar_ty());
301 for(indices_t idx: idxs_.at(x))
302 vals_[x][idx] = phi(ty, x->get_num_operands());
303}
304
305/**
306 * \brief Code Generation for `call`
307 */
308void generator::visit_call_inst(ir::call_inst* call) {
309 throw std::runtime_error("call not supported! Triton should be inlining everything.");
310}
311
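// A `launch` is lowered using CUDA dynamic parallelism: only thread 0 of the
// block requests a parameter buffer from the runtime (cudaGetParameterBufferV2),
// packs the kernel arguments into it with each argument aligned to its own
// size, and then launches the callee through cudaLaunchDeviceV2 on the
// default (null) stream.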
312void generator::visit_launch_inst(ir::launch_inst *launch) {
313 ir::function* fn = (ir::function*)launch->get_operand(0);
314 // forward-declare cudaGetParameterBufferV2
315 std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
316 ArrayType::get(builder_->getInt32Ty(), 3),
317 ArrayType::get(builder_->getInt32Ty(), 3),
318 builder_->getInt32Ty()};
319 FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
320 Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_);
321 AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
322 AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
323 ConstantInt* _0 = builder_->getInt32(0);
324 ConstantInt* _1 = builder_->getInt32(1);
325 ConstantInt* _2 = builder_->getInt32(2);
326 // create basic block
327 BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
328 BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);
329 Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
330 Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
331 builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
332 builder_->SetInsertPoint(launch_bb);
333
334 //
335 builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
336 builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
337 builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
338 Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
339 builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
340 builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
341 builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
342 Function* called_fn = fns_[fn];
343 Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
344 Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
  // forward-declare cudaLaunchDeviceV2
346 std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
347 FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
348 Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_);
349 // TODO: add branch
350 Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
351 builder_->getInt64(0));
352 BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
353 builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
354 builder_->SetInsertPoint(launch2_bb);
355
356 unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
357 unsigned off = 0;
358 unsigned last_size = 0;
359 for(ir::value* arg: launch->get_values()){
360 Value* curr_arg = vals_[arg][{}];
361 Type* curr_arg_ty = curr_arg->getType();
362 // handle struct alignment
363 off += last_size;
364 unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
365 off = (off + size - 1) / size * size;
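    // e.g. an i32 followed by an i64: off = 4 rounds up to 8 before the i64 is
    // stored, so each argument is aligned to its own size in the buffer.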
366 // get pointer to current arg
367 Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
368 curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
369 // store arg
370 builder_->CreateStore(curr_arg, curr_arg_ptr);
371 last_size = size;
372 }
373 builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
374 builder_->CreateBr(launch_done_bb);
375 // done
376 builder_->SetInsertPoint(launch_done_bb);
377
378}
379
380/**
381 * \brief Code Generation for `binary_operator`
382 */
383void generator::visit_binary_operator(ir::binary_operator*x) {
384 using ll = llvm::Instruction::BinaryOps;
385 using tt = ir::binary_op_t;
386 auto cvt = [](ir::binary_op_t op){
387 switch(op) {
388 case tt::Add: return ll::Add;
389 case tt::FAdd: return ll::FAdd;
390 case tt::Sub: return ll::Sub;
391 case tt::FSub: return ll::FSub;
392 case tt::Mul: return ll::Mul;
393 case tt::FMul: return ll::FMul;
394 case tt::UDiv: return ll::UDiv;
395 case tt::SDiv: return ll::SDiv;
396 case tt::FDiv: return ll::FDiv;
397 case tt::URem: return ll::URem;
398 case tt::SRem: return ll::SRem;
399 case tt::FRem: return ll::FRem;
400 case tt::Shl: return ll::Shl;
401 case tt::LShr: return ll::LShr;
402 case tt::AShr: return ll::AShr;
403 case tt::And: return ll::And;
404 case tt::Or: return ll::Or;
405 case tt::Xor: return ll::Xor;
406 default: throw std::runtime_error("unreachable switch");
407 }
408 };
409// x->print(std::cout);
410 for(indices_t idx: idxs_.at(x)){
411 Value *lhs = vals_[x->get_operand(0)][idx];
412 Value *rhs = vals_[x->get_operand(1)][idx];
413 // manually select bf16 bin op
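    // bf16 values are carried in 16-bit integer registers here (bf16_ty is an
    // i16), so there is no LLVM FP op to fall back on; each add/sub/mul is
    // instead expressed as a single PTX fma.rn.bf16 against a constant
    // 1.0, -1.0 or -0.0.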
414 if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
415 assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
416 if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
417 InlineAsm *bf16_add_asm =
418 InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
419 "{ .reg .b16 c; \n\t"
420 " mov.b16 c, 0x3f80U; \n\t" // 1.0
421 " fma.rn.bf16 $0, $1, c, $2; } \n\t",
422 "=h,h,h", false);
423 vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
424 } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
425 InlineAsm *bf16_sub_asm =
426 InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
427 " { .reg .b16 c; \n\t"
428 " mov.b16 c, 0xbf80U; \n\t" // -1.0
429 " fma.rn.bf16 $0, $2, c, $1;} \n\t",
430 "=h,h,h", false);
431 vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
432 } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
433 InlineAsm *bf16_mul_asm =
434 InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
435 " { .reg .b16 c; \n\t"
                       " mov.b16 c, 0x8000U; \n\t" // -0.0 (so a*b + c == a*b exactly)
437 " fma.rn.bf16 $0, $1, $2, c;} \n\t",
438 "=h,h,h", false);
439 vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
440 } else
441 throw std::runtime_error("invalid bin op for bf16");
442 } else { // not bf16
443 auto op = cvt(x->get_op());
444 if(op == ll::Add)
445 vals_[x][idx] = add(lhs, rhs);
446 else if(op == ll::Mul)
447 vals_[x][idx] = mul(lhs, rhs);
448 else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
449 x->get_type()->get_scalar_ty()->is_fp32_ty()){
450 InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
451 " div.full.f32 $0, $1, $2;", "=r,r,r", false);
452 vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
453
454 }
455 else
456 vals_[x][idx] = bin_op(op, lhs, rhs);
457 }
458 }
459}
460
461/**
462 * \brief Code Generation for `getelementptr`
463 */
464void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
465 for(indices_t idx: idxs_.at(x)){
466 Value *ptr = vals_[x->get_pointer_operand()][idx];
467 std::vector<Value*> vals;
468 for(auto it= x->idx_begin(); it != x->idx_end(); it++)
469 vals.push_back(vals_[*it][idx]);
470 assert(vals.size() == 1);
471 vals_[x][idx] = gep(ptr, vals[0]);
472 }
473}
474
475/**
476 * \brief Code Generation for `icmp`
477 */
478void generator::visit_icmp_inst(ir::icmp_inst* x) {
479 auto cvt = [](ir::cmp_pred_t pred) {
480 using ll = llvm::CmpInst::Predicate;
481 using tt = ir::cmp_pred_t;
482 switch(pred){
483 case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE;
484 case tt::ICMP_EQ: return ll::ICMP_EQ;
485 case tt::ICMP_NE: return ll::ICMP_NE;
486 case tt::ICMP_UGT: return ll::ICMP_UGT;
487 case tt::ICMP_UGE: return ll::ICMP_UGE;
488 case tt::ICMP_ULT: return ll::ICMP_ULT;
489 case tt::ICMP_ULE: return ll::ICMP_ULE;
490 case tt::ICMP_SGT: return ll::ICMP_SGT;
491 case tt::ICMP_SGE: return ll::ICMP_SGE;
492 case tt::ICMP_SLT: return ll::ICMP_SLT;
493 case tt::ICMP_SLE: return ll::ICMP_SLE;
494 case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE;
495 default: throw std::runtime_error("unreachable switch");
496 }
497 };
498
499 for(indices_t idx: idxs_.at(x)){
500 Value *lhs = vals_[x->get_operand(0)][idx];
501 Value *rhs = vals_[x->get_operand(1)][idx];
502 vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs);
503 }
504}
505
506/**
507 * \brief Code Generation for `fcmp`
508 */
509void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
510 auto cvt = [](ir::cmp_pred_t pred) {
511 using ll = llvm::CmpInst::Predicate;
512 using tt = ir::cmp_pred_t;
513 switch(pred){
514 case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE;
515 case tt::FCMP_FALSE: return ll::FCMP_FALSE;
516 case tt::FCMP_OEQ: return ll::FCMP_OEQ;
517 case tt::FCMP_OGT: return ll::FCMP_OGT;
518 case tt::FCMP_OGE: return ll::FCMP_OGE;
519 case tt::FCMP_OLT: return ll::FCMP_OLT;
520 case tt::FCMP_OLE: return ll::FCMP_OLE;
521 case tt::FCMP_ONE: return ll::FCMP_ONE;
522 case tt::FCMP_ORD: return ll::FCMP_ORD;
523 case tt::FCMP_UNO: return ll::FCMP_UNO;
524 case tt::FCMP_UEQ: return ll::FCMP_UEQ;
525 case tt::FCMP_UGT: return ll::FCMP_UGT;
526 case tt::FCMP_UGE: return ll::FCMP_UGE;
527 case tt::FCMP_ULT: return ll::FCMP_ULT;
528 case tt::FCMP_ULE: return ll::FCMP_ULE;
529 case tt::FCMP_UNE: return ll::FCMP_UNE;
530 case tt::FCMP_TRUE: return ll::FCMP_TRUE;
531 case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE;
532 default: throw std::runtime_error("unreachable switch");
533 }
534 };
535 for(indices_t idx: idxs_.at(x)){
536 Value *lhs = vals_[x->get_operand(0)][idx];
537 Value *rhs = vals_[x->get_operand(1)][idx];
538 vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs);
539 }
540}
541
542
543std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
544 in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty);
545 in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty);
546 in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty);
547 in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty);
548 Value *ret0, *ret1, *ret2, *ret3;
549 std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3);
550 return std::make_tuple(ret0, ret1, ret2, ret3);
551}
552
553std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
554 Value *ret0, *ret1, *ret2, *ret3;
555 std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
556 ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
557 ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
558 ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
559 ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
560 return std::make_tuple(ret0, ret1, ret2, ret3);
561}
562
563
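// Convert four packed fp8 values (one i8 each, in the 1-sign/4-exponent/
// 3-mantissa format described in fp16x4_to_fp8x4 below) to four fp16s. This is
// the inverse of fp16x4_to_fp8x4: each byte is moved into the high byte of a
// 16-bit lane (prmt), the sign is stripped, the remaining bits are shifted
// right by one so exponent and mantissa land in their fp16 positions, and the
// sign is OR'd back in.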
564std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
565 Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
566 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
567 "{"
568 ".reg .b32 a<2>, b<2>; \n\t"
569 "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0
570 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
571 "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
572 "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
573 "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position)
574 "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
575 "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
576 "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
577 "}", "=r,=r,r", false);
578 Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
579 packed_in = insert_elt(packed_in, in0, (uint64_t)0);
580 packed_in = insert_elt(packed_in, in1, (uint64_t)1);
581 packed_in = insert_elt(packed_in, in2, (uint64_t)2);
582 packed_in = insert_elt(packed_in, in3, (uint64_t)3);
583 Value *in = bit_cast(packed_in, i32_ty);
584 Value *ret = call(ptx, {in});
585 Value *packed_ret0 = extract_val(ret, {0});
586 Value *packed_ret1 = extract_val(ret, {1});
587 Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
588 Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
589 Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
590 Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
591 return std::make_tuple(ret0, ret1, ret2, ret3);
592}
593
594std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
595 /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa)
596 * fp8 bit representation is seeeemmm
597 * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16.
598 * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16.
599 * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous.
600 * We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's
601 * one more than the number of mantissa bits in fp8).
602 * fp8 = (fp16 & 0x8000) | (((f16 << 1) + 0x0080) & 0x7fff)
603 *
604 * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the
605 * other. To avoid this we zero out the most significant exponent bit. If that bit is set then
606 * the value isn't representable in float8 anyway so we assume it's never set (and give garbage
607 * output if it is). If we were willing to assume the most significant exponent was never set
608 * we could save the first two lop3.b32 instructions below.
609 */
610 InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
611 "{"
612 ".reg .b32 a<2>, b<2>; \n\t"
613 "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1
614 "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1
615 "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff)
616 "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff)
617 "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080
618 "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080
619 "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0
620 "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1
621 "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02
622 "}", "=r,r,r", false);
623 Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2));
624 Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2));
625 packed_in0 = insert_elt(packed_in0, in0, (int)0);
626 packed_in0 = insert_elt(packed_in0, in1, (int)1);
627 packed_in1 = insert_elt(packed_in1, in2, (int)0);
628 packed_in1 = insert_elt(packed_in1, in3, (int)1);
629 Value *in_arg0 = bit_cast(packed_in0, i32_ty);
630 Value *in_arg1 = bit_cast(packed_in1, i32_ty);
631 Value *ret = call(ptx, {in_arg0, in_arg1});
632 Value *ret0 = extract_elt(ret, (int)0);
633 Value *ret1 = extract_elt(ret, (int)1);
634 Value *ret2 = extract_elt(ret, (int)2);
635 Value *ret3 = extract_elt(ret, (int)3);
636 return std::make_tuple(ret0, ret1, ret2, ret3);
637}
638
639std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) {
640 // current exp offset: 15
641 // Add 112 (127-15) to compensate the difference in exponent bias
642 // bf16 = (nosign >> (8-4) + 112 << 7) | sign;
643 // bf16 = (nosign >> 4 + 0x3800) | sign;
644 Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)});
645 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
646 "{"
647 ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t"
648 "prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0
649 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0
650 "and.b32 sign0, a0, 0x80008000; \n\t"
651 "and.b32 sign1, a1, 0x80008000; \n\t"
652 "and.b32 nosign0, a0, 0x7fff7fff; \n\t"
653 "and.b32 nosign1, a1, 0x7fff7fff; \n\t"
654 "shr.b32 nosign0, nosign0, 4; \n\t"
655 "shr.b32 nosign1, nosign1, 4; \n\t"
656 "add.u32 nosign0, nosign0, 0x38003800; \n\t"
657 "add.u32 nosign1, nosign1, 0x38003800; \n\t"
658 "or.b32 $0, sign0, nosign0; \n\t"
659 "or.b32 $1, sign1, nosign1; \n\t"
660 "}", "=r,=r,r", false);
661 Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
662 packed_in = insert_elt(packed_in, in0, (uint64_t)0);
663 packed_in = insert_elt(packed_in, in1, (uint64_t)1);
664 packed_in = insert_elt(packed_in, in2, (uint64_t)2);
665 packed_in = insert_elt(packed_in, in3, (uint64_t)3);
666 Value *in = bit_cast(packed_in, i32_ty);
667 Value *ret = call(ptx, {in});
668 Value *packed_ret0 = extract_val(ret, {0});
669 Value *packed_ret1 = extract_val(ret, {1});
670 Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
671 Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
672 Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
673 Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
674 return std::make_tuple(ret0, ret1, ret2, ret3);
675}
676
677std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
678 /* Assuming fp8 exponent offset is 16. bf16 exponent offset is 127.
679 Max value in fp8: 0b01111111 (0x7f),
680 bf16: 3ff0
681 Min value in fp8: 0b00000000 (0x00)
682 bf16: 0x3c00
683 // @note: +0x8 is for "rounding to nearest zero"
684 fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4;
685 return fp8 | sign; // also permute bytes
686 */
687 InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
688 "{\n\t"
689 ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t"
690 ".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t"
691 "mov.u32 fp8_min, 0x38003800; \n\t"
692 "mov.u32 fp8_max, 0x3ff03ff0; \n\t"
693 "mov.u32 rn_, 0x80008; \n\t"
694 "mov.u32 zero, 0; \n\t"
695 "and.b32 sign0, $1, 0x80008000; \n\t"
696 "and.b32 sign1, $2, 0x80008000; \n\t"
697 "prmt.b32 sign, sign0, sign1, 0x7531; \n\t"
698 "and.b32 nosign0, $1, 0x7fff7fff; \n\t"
699 "and.b32 nosign1, $2, 0x7fff7fff; \n\t"
700
701 ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max)
702 "and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t"
703 "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t"
704 "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t"
705 "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t"
706 "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t"
707 "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t"
708 "or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t"
709 "and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t"
710 "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t"
711 "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t"
712 "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t"
713 "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t"
714 "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t"
715 "or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t"
716
717 "add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero
718 "add.u32 nosign1, nosign1, rn_; \n\t"
719 "sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset
720 "sub.u32 nosign1, nosign1, 0x38003800; \n\t"
721 "shr.u32 nosign0, nosign0, 4; \n\t"
722 "shr.u32 nosign1, nosign1, 4; \n\t"
723 "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t"
724 "or.b32 $0, nosign, sign; \n\t"
725 ""
726 "}", "=r,r,r", false);
727 Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
728 Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
729 packed_in0 = insert_elt(packed_in0, in0, (int)0);
730 packed_in0 = insert_elt(packed_in0, in1, (int)1);
731 packed_in1 = insert_elt(packed_in1, in2, (int)0);
732 packed_in1 = insert_elt(packed_in1, in3, (int)1);
733 Value *in_arg0 = bit_cast(packed_in0, i32_ty);
734 Value *in_arg1 = bit_cast(packed_in1, i32_ty);
735 Value *ret = call(ptx, {in_arg0, in_arg1});
736 Value *ret0 = extract_elt(ret, (int)0);
737 Value *ret1 = extract_elt(ret, (int)1);
738 Value *ret2 = extract_elt(ret, (int)2);
739 Value *ret3 = extract_elt(ret, (int)3);
740 return std::make_tuple(ret0, ret1, ret2, ret3);
741}
742
743Value* generator::bf16_to_fp32(Value *in0){
744 if (tgt_->as_nvidia()->sm() >= 80) {
745 InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
746 "cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
747 return call(ptx, {in0});
748 } else {
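    // Pre-Ampere fallback: a bf16 is just the upper 16 bits of an fp32, so
    // widening is a bit-level shuffle, e.g. bf16 0x3f80 (1.0f) -> fp32 0x3f800000 (1.0f).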
749 Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
750 ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
751 ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
752 return bit_cast(ret, f32_ty);
753 }
754}
755
756Value* generator::fp32_to_bf16(Value *in0){
757 if(tgt_->as_nvidia()->sm() >= 80){
758 InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
759 "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
760 return call(ptx, {in0});
761 }
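  // Pre-Ampere fallback: keep only the upper 16 bits, i.e. truncate the
  // mantissa instead of rounding to nearest like cvt.rn.bf16.f32 does.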
762 return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
763}
764
765/**
766 * \brief Code Generation for `cast`
767 */
768void generator::visit_cast_inst(ir::cast_inst* x) {
769 ir::value *op = x->get_operand(0);
770 ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
771 ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
772 auto x_idxs = idxs_.at(x);
773 auto op_idxs = idxs_.at(op);
774
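  // FP8 conversions are emitted four scalars at a time, so the layout of the
  // result must give each thread a multiple of four contiguous elements.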
775 // <> FP8
776 if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
777 // ensure that conversions can be vectorized
778 int ld = layouts_->get(x)->get_order(0);
779 int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
780 if(contiguous % 4 != 0)
781 throw std::runtime_error("unsupported fp32 -> fp8 conversion");
782
783 // run the conversion
784 auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
785 if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
786 return fp32x4_to_fp8x4(a, b, c, d);
787 if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty())
788 return fp16x4_to_fp8x4(a, b, c, d);
789 if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
790 return fp8x4_to_fp16x4(a, b, c, d);
791 if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
792 return fp8x4_to_fp32x4(a, b, c, d);
793 // fp8 <> bf16
794 if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty())
795 return fp8x4_to_bf16x4(a, b, c, d);
796 if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty())
797 return bf16x4_to_fp8x4(a, b, c, d);
798 throw std::runtime_error("unsupported conversion");
799 };
800 for(size_t i = 0; i < x_idxs.size(); i+=4){
801 std::tie(vals_[x][x_idxs[i+0]],
802 vals_[x][x_idxs[i+1]],
803 vals_[x][x_idxs[i+2]],
804 vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
805 vals_[op][op_idxs[i+1]],
806 vals_[op][op_idxs[i+2]],
807 vals_[op][op_idxs[i+3]]);
808 }
809 return;
810 }
811
812 // <> BF16
813 if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
814 // FP32 -> BF16
815 if(op_sca_ty->is_fp32_ty()){
816 for (indices_t idx: idxs_.at(x)) {
817 Value *arg = vals_[x->get_operand(0)][idx];
818 vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
819 }
820 return;
821 }
822 // BF16 -> FP32
823 if(ret_sca_ty->is_fp32_ty()){
824 for(size_t i = 0; i < x_idxs.size(); i++)
825 vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
826 return;
827 }
828 }
829
830
831 Type *ty = cvt(x->get_type()->get_scalar_ty());
832 auto cvt = [](ir::cast_op_t op){
833 using ll = llvm::Instruction::CastOps;
834 using tt = ir::cast_op_t;
835 switch(op){
836 case tt::Trunc: return ll::Trunc;
837 case tt::ZExt: return ll::ZExt;
838 case tt::SExt: return ll::SExt;
839 case tt::FPTrunc: return ll::FPTrunc;
840 case tt::FPExt: return ll::FPExt;
841 case tt::UIToFP: return ll::UIToFP;
842 case tt::SIToFP: return ll::SIToFP;
843 case tt::FPToUI: return ll::FPToUI;
844 case tt::FPToSI: return ll::FPToSI;
845 case tt::PtrToInt: return ll::PtrToInt;
846 case tt::IntToPtr: return ll::IntToPtr;
847 case tt::BitCast: return ll::BitCast;
848 case tt::AddrSpaceCast: return ll::AddrSpaceCast;
849 default: throw std::runtime_error("unreachable switch");
850 }
851 };
852 for(indices_t idx: idxs_.at(x)){
853 Value *arg = vals_[x->get_operand(0)][idx];
854 vals_[x][idx] = cast(cvt(x->get_op()), arg, ty);
855 }
856}
857
858std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int16_to_float16x8(
859 Value *in0, Value *scale_x512, Value *shift
860){
861 /* unpacking 8 int2s packed into an int16 to 8 float16s
862 * the algorithm is similar to
863 * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
864 fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563
865 */
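  // In short: the eight 2-bit fields are masked out and placed in the low bits
  // of fp16 lanes, scaled back into integer range with f16x2 multiplies by
  // fixed power-of-two constants, and finally mapped through
  // out = value * scale + shift with fma.rn.f16x2 (the caller passes the scale
  // pre-multiplied by 512 and the shift packed as f16x2, see prepare_scale_shift).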
866 Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
867 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
868 "{"
869 ".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits.
870 "and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h
871 "and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0
872 "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h
873 "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0
874 "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00
875 "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00
876 "mov.b32 a0, 0x78007800; \n\t" // a0 = 32768
877 "mov.b32 a1, 0x70007000; \n\t" // a1 = 8192
878 "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
879 "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192.
880 "mov.b32 a0, 0x68006800; \n\t" // a0 = 2048
881 "mov.b32 a1, 0x60006000; \n\t" // a1 = 512
882 "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048.
883 "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512.
884 "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
885 "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
886 "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
887 "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
888 "}", "=r,=r,=r,=r,r,r,r", false);
889
890 Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2));
891 packed_in = insert_elt(packed_in, in0, (int)0);
892 packed_in = insert_elt(packed_in, in0, (int)1);
893 Value *in = bit_cast(packed_in, i32_ty);
894
895 Value *ret = call(ptx, {in, scale_x512, shift});
896 Value *packed_ret0 = extract_val(ret, {0});
897 Value *packed_ret1 = extract_val(ret, {1});
898 Value *packed_ret2 = extract_val(ret, {2});
899 Value *packed_ret3 = extract_val(ret, {3});
900 Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
901 Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
902 Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f
903 Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e
904 Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d
905 Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c
906 Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
907 Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
908 return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
909}
910
911std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int32_to_float16x8(
912 Value *in0, Value *scale_x512, Value *shift
913){
914 /* unpacking 8 int4s packed into an int32 to 8 float16s
915 * the algorithm is similar to
916 * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
917 fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619
918 */
919 Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
920 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
921 "{"
922 ".reg .b32 a<2>, b<4>; \n\t"
923 "and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h
924 "and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0
925 "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h
926 "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0
927 "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d
928 "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0
929 "mov.b32 a0, 0x78007800; \n\t"
930 "mov.b32 a1, 0x68006800; \n\t"
931 "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
932 "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048.
933 "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768.
934 "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048.
935 "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
936 "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
  "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
  "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
939 "}", "=r,=r,=r,=r,r,r,r", false);
940
941 Value *ret = call(ptx, {in0, scale_x512, shift});
942 Value *packed_ret0 = extract_val(ret, {0});
943 Value *packed_ret1 = extract_val(ret, {1});
944 Value *packed_ret2 = extract_val(ret, {2});
945 Value *packed_ret3 = extract_val(ret, {3});
946 Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
947 Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
948 Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f
949 Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e
950 Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d
951 Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c
952 Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
953 Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
954 return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
955}
956
957std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){
958 /* unpacking 4 int8s packed into an int32 to 4 fp16s
959 * the algorithm is similar to
960 * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
961 fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646
962 */
963 Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
964 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
965 "{"
966 ".reg .b32 a, b<2>; \n\t"
967 "prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh
968 "prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd
969 "mov.b32 a, 0x78007800; \n\t"
970 "mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768.
971 "mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768.
972 "fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift.
973 "fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift.
974 "}", "=r,=r,r,r,r", false);
975
976 Value *ret = call(ptx, {in0, scale_x512, shift});
977 Value *packed_ret0 = extract_val(ret, {0});
978 Value *packed_ret1 = extract_val(ret, {1});
979 Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh
980 Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef
981 Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd
982 Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab
983 return std::make_tuple(ret0, ret1, ret2, ret3);
984}
985
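// Broadcast the dequantization scale and shift into both fp16 lanes of an i32
// so that a single fma.rn.f16x2 applies them to two elements at once. The
// scale is pre-multiplied by 512 (0x6000 is 512.0 in fp16), which is what the
// int*_to_float16x* converters above expect for their scale_x512 operand.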
986std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){
987 Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty));
988 Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2));
989 p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0);
990 p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1);
991 p_scale_x512 = bit_cast(p_scale_x512, i32_ty);
992
993 Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2));
994 p_shift = insert_elt(p_shift, shift, (int)0);
995 p_shift = insert_elt(p_shift, shift, (int)1);
996 p_shift = bit_cast(p_shift, i32_ty);
997
998 return std::make_tuple(p_scale_x512, p_shift);
999}
1000
1001/**
1002 * \brief Code Generation for `dequantize`
1003 */
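// Three packings are supported, each expanding to fp16:
//  - 16-bit source element, 8x wider result: eight int2 values per i16
//  - 32-bit source element, 8x wider result: eight int4 values per i32
//  - 32-bit source element, 4x wider result: four int8 values per i32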
1004void generator::visit_dequantize_inst(ir::dequantize_inst* x) {
1005 ir::value *op = x->get_operand(0);
1006
1007 auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits();
1008
1009 auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
1010 auto op_last_dim = (op->get_type()->get_block_shapes()).back();
1011
1012 auto x_idxs = idxs_.at(x);
1013 auto op_idxs = idxs_.at(op);
1014
1015 ir::value *scale = x->get_operand(1);
1016 ir::value *shift = x->get_operand(2);
1017
1018 Value *p_scale_x512, *p_shift;
1019 std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]);
1020
1021 int ld = layouts_->get(x)->get_order(0);
1022 int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
1023
1024 int op_ld = layouts_->get(op)->get_order(0);
1025 int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld);
1026
1027 std::string err_msg;
  err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): "
            + std::to_string(x_idxs.size()) + "; op_idxs.size(): "
            + std::to_string(op_idxs.size()) + "; contiguous: "
            + std::to_string(contiguous) + "; op_contiguous: "
            + std::to_string(op_contiguous) + ". please try adjusting "
            "block_size, num_warps, or using tl.multiple_of to hint the "
            "input/output ptr address.";
1035
1036 if (ret_last_dim == 8 * op_last_dim) {
1037 if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) {
1038 throw std::runtime_error(err_msg);
1039 }
1040
1041 auto cvt = [&](
1042 Value* a, Value* scale, Value* shift
1043 ){
1044 if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s
1045 return int16_to_float16x8(a, scale, shift);
1046 } else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s
1047 return int32_to_float16x8(a, scale, shift);
1048 } else {
1049 throw std::runtime_error("unsupported conversion");
1050 }
1051 };
1052
1053 for(size_t j = 0; j < op_idxs.size(); j++){
1054 size_t i = j * 8;
1055 std::tie(vals_[x][x_idxs[i+0]],
1056 vals_[x][x_idxs[i+1]],
1057 vals_[x][x_idxs[i+2]],
1058 vals_[x][x_idxs[i+3]],
1059 vals_[x][x_idxs[i+4]],
1060 vals_[x][x_idxs[i+5]],
1061 vals_[x][x_idxs[i+6]],
1062 vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
1063 }
1064 } else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s
1065 if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) {
1066 throw std::runtime_error(err_msg);
1067 }
1068
1069 auto cvt = [&](Value* a, Value* scale, Value* shift){
1070 return int32_to_float16x4(a, scale, shift);
1071 };
1072
1073 for(size_t j = 0; j < op_idxs.size(); j++){
1074 size_t i = j * 4;
1075 std::tie(vals_[x][x_idxs[i+0]],
1076 vals_[x][x_idxs[i+1]],
1077 vals_[x][x_idxs[i+2]],
1078 vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
1079 }
1080 } else {
1081 throw std::runtime_error("unsupported dequantization");
1082 }
1083 return;
1084}
1085
1086/**
1087 * \brief Code Generation for `return`
1088 */
1089void generator::visit_return_inst(ir::return_inst* rr) {
1090 ir::value *ret_val = rr->get_return_value();
1091 ret(ret_val ? vals_[ret_val][{}] : nullptr);
1092}
1093
1094/**
1095 * \brief Code Generation for `cond_branch`
1096 */
1097void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) {
1098 BasicBlock *true_dest = bbs_.at(br->get_true_dest());
1099 BasicBlock *false_dest = bbs_.at(br->get_false_dest());
1100 Value *cond = vals_[br->get_cond()][{}];
1101 cond_br(cond, true_dest, false_dest);
1102}
1103
1104/**
1105 * \brief Code Generation for `uncond_branch`
1106 */
1107void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
1108 BasicBlock *dest = bbs_.at(br->get_dest());
1109 br(dest);
1110}
1111
1112/**
1113 * \brief Code Generation for a (synchronous) `load`
1114 */
1115void generator::visit_load_inst(ir::load_inst* x){
1116 BasicBlock *current = builder_->GetInsertBlock();
1117 Module *module = current->getModule();
1118 Value *tid = tgt_->get_local_id(module, *builder_, 0);
1119 Value *lane = urem(tid, i32(32));
1120 ir::value *op = x->get_pointer_operand();
1121 ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
1122 Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
1123 // compute vector width
1124 size_t vec = 1;
1125 bool is_mma_first_row = false;
1126 if(op->get_type()->is_block_ty()){
1127 auto ord = ords_.at(op);
1128 size_t aln = alignment_->get(op, ord[0]);
1129 if(mx){
1130 size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
1131 max_eq = std::max<size_t>(max_eq, 1);
1132 aln = std::min(aln, max_eq);
1133 }
1134 analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
1135 assert(layout);
1136
1137 vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
1138 // TODO: generalize
1139 is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
1140 (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
1141 if(is_mma_first_row)
1142 vec = std::min<size_t>(2, aln);
1143 }
1144 // code generation
1145 auto idxs = idxs_.at(x);
1146 for(size_t i = 0; i < idxs.size(); i += vec){
1147 indices_t idx = idxs[i];
1148 // pointer value
1149 Value *ptr = vals_[op][idx];
1150 // masked load
1151 size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
1152 // input ptr info
1153 GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
1154 size_t in_off;
1155 if(in_gep){
1156 ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
1157 in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
1158 ptr = cst ? in_gep->getPointerOperand() : in_gep;
1159 }
1160 else{
1161 in_off = 0;
1162 }
1163 Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
1164 // if(!op->get_type()->is_block_ty()){
1165 // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0)));
1166 // }
1167 Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
1168 size_t nbits = dtsize*8;
1169 // pack sub-words (< 32/64bits) into words
1170 // each load has width min(nbits*vec, 32/64)
1171 // and there are (nbits * vec)/width of them
1172 int max_word_width = std::max<int>(32, nbits);
1173 int tot_width = nbits*vec;
1174 int width = std::min(tot_width, max_word_width);
1175 int n_words = std::max(1, tot_width / width);
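    // e.g. vec = 8 fp16 elements: nbits = 16, tot_width = 128, width = 32,
    // n_words = 4, producing a single "ld.global.v4.b32 {...}, [...];".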
    bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
    // L2 eviction-policy hints are currently disabled until they are supported in `store`
    has_l2_evict_policy = false;
1179 // -----
1180 // create inline asm string
1181 // -----
1182 std::ostringstream asm_oss;
1183 asm_oss << "@$" << n_words; // predicate
1184 asm_oss << " ld";
1185 if(x->get_is_volatile())
1186 asm_oss << ".volatile";
1187 asm_oss << ".global";
1188 if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
1189 if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
1190 if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
1191 if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
1192 if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
1193 if(n_words > 1)
1194 asm_oss << ".v" << n_words; // vector width
1195 asm_oss << ".b" << width; // word size
1196 asm_oss << " {";
1197 for(int i = 0; i < n_words; i++){ // return values
1198 if(i > 0) asm_oss << ",";
1199 asm_oss << "$" << i;
1200 }
1201 asm_oss << "}";
1202 asm_oss << ", [ $" << n_words + 1; // load
1203 asm_oss << " + " << in_off << "]"; // constant offset
1204 if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
1205 asm_oss << ";";
1206 bool has_other = other && (other != UndefValue::get(other->getType()));
1207 std::vector<Value *> others;
1208 // handle `other` values for indices where the mask
1209 // is false
1210 if(has_other)
1211 for(size_t ii = 0; ii < n_words; ii++){
1212 size_t size = width / nbits;
1213 Value *v = UndefValue::get(vec_ty(ty, size));
1214 for(size_t s = 0; s < size; s++){
1215 ir::value *false_val = mx->get_false_value_operand();
1216 v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
1217 }
1218 v = bit_cast(v, IntegerType::get(*ctx_, width));
1219 // PTX doesn't support mov.u8, so we need to use mov.u16
1220 auto mov_width = width < 16 ? 16 : width;
1221 asm_oss << "\n ";
1222 asm_oss << "@!$" << n_words << " mov.u" << mov_width;
1223 asm_oss << " $" << ii << ", ";
1224 std::ios_base::fmtflags flags(asm_oss.flags());
1225 if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
1226 asm_oss << "0x" << std::hex << cst->getSExtValue();
1227 else{
1228 asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
1229 others.push_back(v);
1230 }
1231 asm_oss.flags(flags);
1232 asm_oss << ";";
1233 }
1234 // ----
1235 // create inline ASM signature
1236 // ---
1237 std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
1238 Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
1239 // ret_ty->print(llvm::outs());
1240 std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
1241 for(Value *v: others)
1242 arg_tys.push_back(v->getType());
1243 if (has_l2_evict_policy)
1244 arg_tys.push_back(i64_ty);
1245 FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
1246 // ---
1247 // create inline ASM constraints
1248 // ---
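    // NVPTX inline-asm constraints: "b" binds a predicate register, "l" a
    // 64-bit and "r" a 32-bit integer register, and "c" is used here for the
    // sub-32-bit case; "=" marks outputs.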
1249 std::string asm_cstrt;
1250 for(int ii = 0; ii < n_words; ii++){
1251 if(ii > 0) asm_cstrt += ",";
1252 asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
1253 }
1254 asm_cstrt += ",b,l";
1255 for(size_t ii = 0; ii < others.size(); ii++){
1256 asm_cstrt += ",";
1257 asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
1258 }
1259 if (has_l2_evict_policy)
1260 asm_cstrt += ",l";
1261 // ---
1262 // finally call inline ASM
1263 // ---
1264 InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
1265 std::vector<Value*> args = {pred, ptr};
1266 for(Value *v: others)
1267 args.push_back(v);
1268 if (has_l2_evict_policy)
1269 args.push_back(policies_.at(x->get_eviction_policy()));
1270
1271
1272 Value *_ret = call(inlineAsm, args);
1273 // if(!op->get_type()->is_block_ty()){
1274 // Value* cond = icmp_eq(tid, i32(0));
1275 // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3));
1276 // Instruction* bar = add_barrier();
1277 // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false);
1278 // builder_->SetInsertPoint(term);
1279 // store(_ret, shptr);
1280 // builder_->SetInsertPoint(bar->getParent());
1281 // _ret = load(shptr);
1282 // add_barrier();
1283 // }
1284
1285 // ---
1286 // extract and store return values
1287 // ---
1288 std::vector<Value *> rets;
1289 for(unsigned int ii = 0; ii < n_words; ii++){
1290 Value *curr;
1291 if(ret_ty->isStructTy())
1292 curr = extract_val(_ret, {ii});
1293 else
1294 curr = _ret;
1295 rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
1296 }
1297 int tmp = (width / (dtsize * 8));
1298 for(size_t ii = 0; ii < vec; ii++)
1299 vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
1300 }
1301}
1302
1303void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
1304 visit_load_inst(x);
1305}
1306void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
1307 visit_load_inst(x);
1308}
1309
1310/**
1311 * \brief Code Generation for a (synchronous) `store`
1312 */
1313
1314void generator::visit_store_inst(ir::store_inst * x){
1315 ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
1316 // operands
1317 ir::value *ptr_op = x->get_pointer_operand();
1318 ir::value *val_op = x->get_value_operand();
1319 ir::value *msk_op = nullptr;
1320 if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
1321 msk_op = msk_st->get_mask_operand();
1322 // vector size
1323 size_t vec = 1;
1324 if(val_op->get_type()->is_block_ty()){
1325 auto ord = ords_.at(x->get_pointer_operand());
1326 size_t aln = alignment_->get(ptr_op, ord[0]);
1327 size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
1328 if(mx){
1329 size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
1330 max_eq = std::max<size_t>(max_eq, 1);
1331 aln = std::min(aln, max_eq);
1332 }
1333 analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
1334 assert(layout);
1335 // vec = std::min(nts, aln);
1336 vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
1337 // TODO: generalize
1338 bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
1339 (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
1340 if(is_mma_first_row)
1341 vec = std::min<size_t>(2, aln);
1342 }
1343 bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
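  // note: the L2 eviction hint is unconditionally disabled for stores below,
  // overriding the capability check above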
1344 has_l2_evict_policy = false;
1345 auto idxs = idxs_.at(val_op);
1346 Type *ty = cvt(val_op->get_type()->get_scalar_ty());
1347 if(ty->isIntegerTy(1))
1348 ty = builder_->getInt8Ty();
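  // iterate over the distributed indices in groups of `vec` scalars; each group
  // is emitted as one (possibly vectorized) predicated st.global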
1349 for(size_t i = 0; i < idxs.size(); i += vec){
1350 indices_t idx = idxs[i];
1351 // pointers
1352 Value *ptr = vals_[ptr_op][idx];
1353 size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
1354 GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
1355 size_t in_off;
1356 if(in_gep){
1357 ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
1358 in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
1359 ptr = cst ? in_gep->getPointerOperand() : in_gep;
1360 }
1361 else{
1362 in_off = 0;
1363 }
1364 // mask
1365 Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
1366 size_t nbits = dtsize*8;
    // pack sub-words (< 32/64 bits) into words
    // each store has width min(nbits*vec, 32/64)
    // and there are (nbits*vec)/width of them
1370 int max_word_width = std::max<int>(32, nbits);
1371 int tot_width = nbits*vec;
1372 int width = std::min(tot_width, max_word_width);
1373 int n_words = std::max(1, tot_width / width);
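    // e.g. vec = 8 fp16 elements: tot_width = 128, width = 32, n_words = 4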
1374 // -----
1375 // create inline asm string
1376 // -----
1377 std::ostringstream asm_oss;
1378 asm_oss << "@$0"; // predicate
1379 asm_oss << " st.global";
1380 if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
1381 if(n_words > 1)
1382 asm_oss << ".v" << n_words; // vector width
1383 asm_oss << ".b" << width; // word size
1384 asm_oss << " [ $1 + " << in_off << "]";
1385 asm_oss << " , {";
    for(int i = 0; i < n_words; i++){ // values to store
1387 if(i > 0) asm_oss << ",";
1388 asm_oss << "$" << 2 + i;
1389 }
1390 asm_oss << "}";
1391 if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
1392 asm_oss << ";";
    // ---
1394 // create inline ASM signature
1395 // ---
1396 Type* val_arg_ty = IntegerType::get(*ctx_, width);
1397 std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
1398 for(int ii = 0; ii < n_words; ii++)
1399 arg_tys.push_back(val_arg_ty);
1400 if (has_l2_evict_policy)
1401 arg_tys.push_back(i64_ty);
1402 FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
1403 // ---
1404 // create inline ASM constraints
1405 // ---
1406 std::string asm_cstrt = "b,l";
1407 for(int ii = 0; ii < n_words; ii++){
1408 asm_cstrt += ",";
1409 asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
1410 }
1411 if (has_l2_evict_policy)
1412 asm_cstrt += ",l";
1413 // ---
1414 // finally call inline ASM
1415 // ---
1416 InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
1417 std::vector<Value*> args = {pred, ptr};
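    // pack the `vec` scalar values into n_words integer words of `width` bits each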
1418 for(unsigned int ii = 0; ii < n_words; ii++){
1419 size_t n_subw = width / nbits;
1420 Value* curr = UndefValue::get(vec_ty(ty, n_subw));
1421 for(unsigned int jj = 0; jj < n_subw; jj++){
1422 Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
1423 if(new_elt->getType()->isIntegerTy(1))
1424 new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
1425 new_elt = bit_cast(new_elt, ty);
1426 curr = builder_->CreateInsertElement(curr, new_elt, jj);
1427 }
1428 args.push_back(bit_cast(curr, val_arg_ty));
1429 }
1430 if (has_l2_evict_policy)
1431 args.push_back(policies_.at(x->get_eviction_policy()));
1432 call(_asm, args);
1433 }
1434}
1435void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
1436 visit_store_inst(x);
1437}
1438void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
1439 visit_store_inst(x);
1440}
1441
1442// --
1443
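/**
 * \brief Code Generation for `extract_value`
 */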
1444void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
1445 auto idxs = idxs_.at(x);
1446 ir::value* agg = x->get_operand(0);
  unsigned extract_idx = x->get_idx();
  for(size_t i = 0; i < idxs.size(); i++){
    auto idx = idxs[i];
    vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {extract_idx});
1451 }
1452}
1453
1454
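/**
 * \brief Code Generation for `insert_value`
 */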
1455void generator::visit_insert_value_inst(ir::insert_value_inst *x){
1456 auto idxs = idxs_.at(x);
1457 ir::value* agg = x->get_operand(0);
1458 ir::value* val = x->get_operand(1);
1459 unsigned insert_idx = x->get_idx();
1460 for(size_t i = 0; i < idxs.size(); i++){
1461 auto idx = idxs[i];
1462 vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx});
1463 }
1464}
1465
1466// --
1467/**
1468 * \brief Code Generation for `cat`
1469 */
1470void generator::visit_cat_inst(ir::cat_inst* x) {
1471 auto idxs = idxs_.at(x);
1472 ir::value* lhs = x->get_operand(0);
1473 ir::value* rhs = x->get_operand(1);
1474 int i = 0;
1475 for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){
1476 vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
1477 }
1478 for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
1479 vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
1480 }
1481}
1482
1483
1484
1485/**
1486 * \brief Code Generation for `reshape`
1487 */
1488void generator::visit_reshape_inst(ir::reshape_inst* x) {
1489 auto idxs = idxs_.at(x);
1490 for(size_t i = 0; i < idxs_.at(x).size(); i ++){
1491 ir::value* op = x->get_operand(0);
1492 vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
1493 };
1494}
1495
1496/**
1497 * \brief Code Generation for `splat`
1498 */
1499void generator::visit_splat_inst(ir::splat_inst* x) {
1500 for(auto idx: idxs_.at(x))
1501 vals_[x][idx] = vals_[x->get_operand(0)][{}];
1502}
1503
1504/**
1505 * \brief Code Generation for `broadcast`
1506 */
1507void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
1508 ir::value* op = x->get_operand(0);
1509 const auto& shape = op->get_type()->get_block_shapes();
1510 for(auto out_idx: idxs_.at(x)){
1511 indices_t in_idx = out_idx;
1512 for(size_t k = 0; k < in_idx.size(); k++)
1513 in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
1514 vals_[x][out_idx] = vals_[op][in_idx];
1515 }
1516// for(size_t i = 0; i < idxs_.at(x).size(); i++)
1517// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
1518}
1519
1520/**
1521 * \brief Code Generation for `downcast`
1522 */
1523void generator::visit_downcast_inst(ir::downcast_inst* x) {
1524 vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}];
1525}
1526
1527/**
1528 * \brief Code Generation for `get_program_id`
1529 */
1530void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
1531 Module *module = builder_->GetInsertBlock()->getModule();
1532 Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis());
1533 vals_[pid][{}] = ret;
1534}
1535
1536/**
1537 * \brief Code Generation for `get_num_programs`
1538 */
1539void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
1540 Module *module = builder_->GetInsertBlock()->getModule();
1541 Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
1542 vals_[np][{}] = ret;
1543}
1544
1545/**
1546 * \brief Code Generation for `exp`
1547 */
1548void generator::visit_exp_inst(ir::exp_inst* x){
1549 Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
1550 std::vector<llvm::Type*> tys = {f32_ty};
1551 FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
1552 InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
1553 for(auto idx: idxs_.at(x)){
1554 Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
1555 // Value *ex2arg = vals_[x->get_operand(0)][idx];
1556 vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
1557 }
1558}
1559
1560/**
1561 * \brief Code Generation for `cos`
1562 */
1563void generator::visit_cos_inst(ir::cos_inst* x){
1564 std::vector<llvm::Type*> tys = {f32_ty};
1565 FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
1566 InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
1567 for(auto idx: idxs_.at(x)){
1568 vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
1569 }
1570}
1571
1572/**
1573 * \brief Code Generation for `umulhi`
1574 */
1575void generator::visit_umulhi_inst(ir::umulhi_inst* x){
1576 std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
1577 FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
1578 InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
1579 for(auto idx: idxs_.at(x)){
1580 Value* lhs = vals_[x->get_operand(0)][idx];
1581 Value* rhs = vals_[x->get_operand(1)][idx];
1582 vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
1583 }
1584 }
1585
1586/**
1587 * \brief Code Generation for `sin`
1588 */
1589void generator::visit_sin_inst(ir::sin_inst* x){
1590 std::vector<llvm::Type*> tys = {f32_ty};
1591 FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
1592 InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
1593 for(auto idx: idxs_.at(x)){
1594 vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
1595 }
1596 }
1597
1598/**
1599 * \brief Code Generation for `log`
1600 */
1601void generator::visit_log_inst(ir::log_inst* x){
1602 Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453);
1603 std::vector<llvm::Type*> tys = {f32_ty};
1604 FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
1605 InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
1606 for(auto idx: idxs_.at(x)){
1607 Value *lg2arg = call(lg2, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
1608 vals_[x][idx] = fmul(lg2arg, rcplog2e);
1609 }
1610}
1611
1612/**
1613 * \brief Code Generation for `atomic_cas`
1614 */
1615void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
1616 BasicBlock *current = builder_->GetInsertBlock();
1617 Module *module = current->getModule();
1618 Value *tid = tgt_->get_local_id(module, *builder_, 0);
1619 Value *pred = icmp_eq(tid, i32(0));
1620// BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
1621// BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
1622 add_barrier();
1623 tgt_->add_memfence(module, *builder_);
1624 Value *atom_ptr;
1625 atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
1626 atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
1627// cond_br(pred, tid_0_bb, tid_0_done_bb);
1628// builder_->SetInsertPoint(tid_0_bb);
1629 Value *cas_ptr = vals_[cas->get_operand(0)][{}];
1630 Value *cas_cmp = vals_[cas->get_operand(1)][{}];
1631 Value *cas_val = vals_[cas->get_operand(2)][{}];
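  // the CAS is predicated so that only thread 0 executes it; its result is then
  // broadcast to the other threads of the CTA through shared memory (atom_ptr)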
1632 std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
1633 FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
1634 InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
1635 add_barrier();
1636 Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
1637 add_barrier();
1638
1639 std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
1640 FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
1641 InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
1642 add_barrier();
1643 call(iasm2, {pred, atom_ptr, old});
1644 tgt_->add_memfence(module, *builder_);
1645 add_barrier();
1646 vals_[cas][{}] = load(atom_ptr);
1647 add_barrier();
1648}
1649
1650/**
1651 * \brief Code Generation for `atomic_rmw`
1652 */
1653void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
1654 ir::value* ptr = atom->get_operand(0);
1655 ir::value* val = atom->get_operand(1);
1656 ir::value* msk = atom->get_operand(2);
1657
1658 // vector size
1659 int vec = 1;
1660 Value *mask = builder_->getInt1(true);
1661 if(atom->get_type()->is_block_ty()){
1662 auto shape = atom->get_type()->get_block_shapes();
1663 int ld = ords_.at(ptr)[0];
1664 unsigned alignment = alignment_->get(ptr, ld);
1665 vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
1666 vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
1667 // mask out inactive threads
1668 analysis::data_layout* layout = layouts_->get(val);
1669 auto curr_axes = a_axes_->get(val);
1670 auto layt_axes = layout->get_axes();
1671 for(unsigned k = 0; k < layt_axes.size(); k++){
1672 unsigned ax = layt_axes.at(k);
1673 distributed_axis dax = axes_.at(ax);
      // if an axis is part of the original layout but not of the current one,
      // the corresponding thread id must be 0
1676 if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
1677 mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
1678 }
    // the last axis may spill over: disable threads whose elements would fall
    // past the end of the tile
1680 Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
1681 int per_thread = 1;
1682 for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
1683 int numel = 1;
1684 for(int s: layout->get_shape()) { numel *= s; }
1685 mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
1686 }
1687
1688
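  // emit one predicated atom.global instruction per group of `vec` elements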
1689 for(int i = 0; i < idxs_.at(val).size(); i += vec){
1690 auto idx = idxs_[val][i];
1691 Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
1692 for(int ii = 0; ii < vec; ii++)
1693 rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
1694 Value *rmw_ptr = vals_[ptr][idx];
1695 Value *rmw_msk = vals_[msk][idx];
1696 rmw_msk = and_(rmw_msk, mask);
1697 if(vec == 1)
1698 rmw_val = extract_elt(rmw_val, i32(0));
1699 Type* ty = rmw_val->getType();
1700 size_t nbits = ty->getScalarSizeInBits();
1701 // extract pointer offset
1702 std::string offset = "";
1703 if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
1704 if(gep->getNumIndices() == 1)
1705 if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
1706 offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
1707 rmw_ptr = gep->getPointerOperand();
1708 }
1709 rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
1710 // asm argument type
1711 std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
1712 // asm function type
1713 FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
1714 // asm string
1715 std::string s_nbits = std::to_string(nbits);
1716 std::string name;
1717 std::string s_ty;
1718 using tt = ir::atomic_rmw_op_t;
1719 switch(atom->get_op()){
      case tt::Or:   name = "or";   s_ty = "b"; break;
      case tt::And:  name = "and";  s_ty = "b"; break;
      case tt::Xor:  name = "xor";  s_ty = "b"; break;
      case tt::Add:  name = "add";  s_ty = "s"; break;
      case tt::Min:  name = "min";  s_ty = "s"; break;
      case tt::Max:  name = "max";  s_ty = "s"; break;
      case tt::UMin: name = "min";  s_ty = "u"; break;
      case tt::UMax: name = "max";  s_ty = "u"; break;
      case tt::FAdd: name = "add";  s_ty = "f"; break;
      case tt::Xchg: name = "exch"; s_ty = "b"; break;
1730 }
1731 std::string s_vec = vec == 2 ? "x2" : "";
1732 std::string mod = nbits == 16 ? ".noftz" : "";
1733
1734 std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
1735 std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? "r" : "h");
1736 std::string constraint = "=" + ty_id + ",b,l," + ty_id;
1737 // create inline asm
1738 InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
1739 // call asm
1740 if(atom->get_type()->is_block_ty())
1741 vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
1742 else{
1743 Module *mod = builder_->GetInsertBlock()->getModule();
1744 tgt_->add_memfence(mod, *builder_);
1745 add_barrier();
1746 Value *tid = tgt_->get_local_id(mod, *builder_, 0);
1747 rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0)));
1748 Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
1749 Value *atom_ptr;
1750 atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "");
1751 atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
1752 store(old, atom_ptr);
1753 add_barrier();
1754 vals_[atom][idx] = load(atom_ptr);
1755 add_barrier();
1756 }
1757 }
1758}
1759
1760/**
1761 * \brief Code Generation for `mma.884` (V100)
1762 */
1763//TODO: clean-up
1764void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
1765 // shapes
1766 auto shape_c = C->get_type()->get_block_shapes();
1767 auto shape_a = A->get_type()->get_block_shapes();
1768 auto shape_b = B->get_type()->get_block_shapes();
1769 // order
1770 auto ord_a = layouts_->get(A)->get_order();
1771 auto ord_b = layouts_->get(B)->get_order();
1772 bool is_a_trans = C->is_trans_a();
1773 // is_a_trans = false;
1774 if(C->is_trans_a()){
1775 std::swap(ord_a[0], ord_a[1]);
1776 std::swap(shape_a[0], shape_a[1]);
1777 std::swap(offset_a_m_, offset_a_k_);
1778 }
1779 // std::cout << "visiting" << std::endl;
1780 // if(C->is_trans_b()){
1781 // std::swap(ord_b[0], ord_b[1]);
1782 // std::swap(shape_b[0], shape_b[1]);
1783 // }
1784 // layouts
1785 analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
1786 analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
1787 analysis::shared_layout* layout_b = layouts_->get(B)->to_shared();
1788 // vectorization
1789 int vec_a = swizzle_->get_vec(layout_a);
1790 int vec_b = swizzle_->get_vec(layout_b);
1791 // strides
1792 bool is_a_row = ord_a[0] != 0;
1793 bool is_b_row = ord_b[0] != 0;
1794 int stride_am = is_a_row ? shape_a[1] : 1;
1795 int stride_ak = is_a_row ? 1 : shape_a[0];
1796 int stride_a0 = is_a_row ? stride_ak : stride_am;
1797 int stride_a1 = is_a_row ? stride_am : stride_ak;
1798 int stride_bn = is_b_row ? 1 : shape_b[0];
1799 int stride_bk = is_b_row ? shape_b[1] : 1;
1800 int stride_b0 = is_b_row ? stride_bn : stride_bk;
1801 int stride_b1 = is_b_row ? stride_bk : stride_bn;
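  // repetition strides: one repetition of the (warp x fragment) pattern spans
  // wpt * fpw * 8 elements of C along m and n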
1802 int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8;
1803 int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8;
1804 int stride_rep_k = 1;
1805 // swizzling
1806 int per_phase_a = swizzle_->get_per_phase(layout_a);
1807 int max_phase_a = swizzle_->get_max_phase(layout_a);
1808 int step_a0 = is_a_row ? stride_rep_k : stride_rep_m;
1809 int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1);
1810 int per_phase_b = swizzle_->get_per_phase(layout_b);
1811 int max_phase_b = swizzle_->get_max_phase(layout_b);
1812 int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
1813 int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);
1814
1815
1816 // max_phase_a = 4;
1817 // vec_a = 8;
1818 // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
1819 // std::cout << vec_a << " " << vec_b << std::endl;
1820
1821 /* --------------------------------- */
1822 /* --- pre-compute pointer lanes --- */
1823 /* --------------------------------- */
1824 BasicBlock* curr_bb = builder_->GetInsertBlock();
1825 BasicBlock* entry = &curr_bb->getParent()->getEntryBlock();
1826 if(entry != curr_bb)
1827 builder_->SetInsertPoint(entry->getTerminator());
1828 Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c];
1829 Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c];
1830 Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a));
1831 std::vector<Value*> off_a(num_ptr_a);
1832 for(int i = 0; i < num_ptr_a; i++){
1833 Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m)));
1834 off_a0i = exact_udiv(off_a0i, i32(vec_a));
1835 off_a0i = xor_(off_a0i, phase_a);
1836 off_a0i = mul(off_a0i, i32(vec_a));
1837 off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
1838 }
1839 Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c];
1840 Value* off_b1 = is_b_row ? offset_b_k_[layout_c] : offset_b_n_[layout_c];
1841 Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b));
1842 std::vector<Value*> off_b(num_ptr_b);
1843 for(int i = 0; i < num_ptr_b; i++){
1844 Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4)));
1845 off_b0i = udiv(off_b0i, i32(vec_b));
1846 off_b0i = xor_(off_b0i, phase_b);
1847 off_b0i = mul(off_b0i, i32(vec_b));
1848 off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
1849 }
1850 builder_->SetInsertPoint(curr_bb);
1851
1852 /* --------------------------------- */
1853 /* --- MMA intrinsic --- */
1854 /* --------------------------------- */
1855 Type *f16x2_ty = vec_ty(f16_ty, 2);
1856 Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty});
1857 std::vector<Type*> arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty,
1858 f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty};
1859 InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false),
1860 " mma.sync.aligned.m8n8k4."
1861 + std::string(is_a_row ? "row" : "col")
1862 + "."
1863 + std::string(is_b_row ? "row" : "col")
1864 + ".f32.f16.f16.f32 "
1865 "{$0, $1, $2, $3, $4, $5, $6, $7}, "
1866 "{$8, $9}, "
1867 "{$10, $11}, "
1868 "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
1869
1870
1871 std::vector<Value*> ptr_a(num_ptr_a);
1872 std::vector<Value*> ptr_b(num_ptr_b);
1873 std::map<std::pair<int, int>, std::pair<Value*, Value*>> has, hbs;
1874 for(int i = 0; i < num_ptr_a; i++)
1875 ptr_a[i] = gep(shmems_[A], off_a[i]);
1876 for(int i = 0; i < num_ptr_b; i++)
1877 ptr_b[i] = gep(shmems_[B], off_b[i]);
1878
1879
1880 // initialize accumulators
1881 std::vector<Value*> acc;
1882 for(indices_t idx: idxs_.at(C))
1883 acc.push_back(vals_[D][idx]);
1884
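  // number of mma repetitions along m and n (rep * shape / shape_per_cta)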
1885 unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0);
1886 unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1);
1887
1888 // create mma & unpack result
1889 auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
1890 auto ha = has[{m, K}];
1891 auto hb = hbs[{n, K}];
1892 // arguments
1893 std::vector<size_t> idx = {
1894 (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
1895 (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
1896 (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
1897 (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
1898 };
1899 std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
1900 for(unsigned i = 0; i < 8; i++)
1901 args.push_back(acc[idx[i]]);
1902 // execute mma
1903 Value *nc = call(mma, args);
1904 // unpack
1905 for(unsigned i = 0; i < 8; i++)
1906 acc[idx[i]] = extract_val(nc, {i});
1907 };
1908
1909 ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
1910 ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
1911
1912 // Cache lds value. If values are prefetched, create phi node
1913 // @param inc: incoming block (0 = header, 1 = loop)
1914 auto register_lds =
1915 [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
1916 if (K == 0 && is_prefetch) {
1917 ir::basic_block* inc_block = phiA->get_incoming_block(inc);
1918 lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
1919 lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
1920 } else
1921 vals[{m, K}] = {val0, val1};
1922 };
1923
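  // load one fragment of A (two f16x2 registers) from shared memory for tile (m, K);
  // when vec_a > 4, the same load also provides the next tile along K or m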
1924 auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
1925 int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
1926 Value* ptra;
1927 if(K==0 && is_prefetch){
1928 if(inc == 0)
1929 ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
1930 else
1931 ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
1932 }
1933 else
1934 ptra = ptr_a[offidx];
1935 int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
1936 int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
1937 Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
1938 Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
1939 // record lds that needs to be moved
1940 if (K == 0 && inc == 1 && is_prefetch)
1941 prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
1942 Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
1943 Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
1944 register_lds(has, m, K, inc, ha00, ha01, is_prefetch);
1945 if(vec_a > 4){
1946 Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
1947 Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
1948 if(is_a_row)
1949 register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch);
1950 else
1951 register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch);
1952 }
1953 };
1954
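  // load one fragment of B (two f16x2 registers) from shared memory for tile (n, K);
  // when vec_b > 4, the same load also provides the next tile along n or K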
1955 auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
1956 int offidx = (is_b_row? n : K/4) % num_ptr_b;
1957 Value* ptrb;
1958 if(K==0 && is_prefetch){
1959 if(inc == 0)
1960 ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
1961 else
1962 ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
1963 } else
1964 ptrb = ptr_b[offidx];
1965
1966 int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
1967 int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
1968 Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
1969 Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
1970 // record lds that needs to be moved
1971 if (K == 0 && inc == 1 && is_prefetch)
1972 prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
1973 Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
1974 Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
1975 register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch);
1976 if(vec_b > 4){
1977 Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
1978 Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
1979 if(is_b_row)
1980 register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch);
1981 else
1982 register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch);
1983 }
1984
1985 };
1986
1987 // update accumulators
1988 if (C->is_prefetched()) {
1989 // create phis
1990 builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
1991 for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
1992 has[{m, 0}].first = phi(f16x2_ty, 2);
1993 has[{m, 0}].second = phi(f16x2_ty, 2);
1994 if (!is_a_row && vec_a>4) {
1995 has[{m+1, 0}].first = phi(f16x2_ty, 2);
1996 has[{m+1, 0}].second = phi(f16x2_ty, 2);
1997 }
1998 }
1999 for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
2000 hbs[{n, 0}].first = phi(f16x2_ty, 2);
2001 hbs[{n, 0}].second = phi(f16x2_ty, 2);
2002 if (is_b_row && vec_b>4) {
2003 hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
2004 hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
2005 }
2006 }
2007
2008 // insert prefetched lds at the end of loop header
2009 builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
2010 for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
2011 load_a(m, 0, 0, true);
2012 for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
2013 load_b(n, 0, 0, true);
2014
2015 // update accumulators
2016 builder_->SetInsertPoint(curr_bb);
2017 for (unsigned K = 0; K < NK; K += 4) {
2018 int NEXTK = (K + 4) % NK;
2019 // prefetch A
2020 for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
2021 load_a(m, NEXTK, 1, true);
2022 // prefetch B
2023 for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
2024 load_b(n, NEXTK, 1, true);
2025 // tensor core ops
2026 for(unsigned m = 0; m < num_m/2; m++)
2027 for(unsigned n = 0; n < num_n/2; n++){
2028 call_mma(m, n, K);
2029 }
2030 }
2031 } else { // not prefetched
2032 for(unsigned K = 0; K < NK; K += 4)
2033 for(unsigned m = 0; m < num_m/2; m++)
2034 for(unsigned n = 0; n < num_n/2; n++) {
2035 if(has.find({m, K}) == has.end())
2036 load_a(m, K, /*inc*/0, /*is_prefetch*/false);
2037 if(hbs.find({n, K}) == hbs.end())
2038 load_b(n, K, /*inc*/0, /*is_prefetch*/false);
2039 call_mma(m, n, K);
2040 }
2041 }
2042
2043 // write back accumulators
2044 for(size_t i = 0; i < idxs_.at(C).size(); i++)
2045 vals_[C][idxs_[C][i]] = acc[i];
2046}
2047
2048namespace {
2049class mma16816_smem_loader {
2050public:
2051 mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
2052 std::vector<unsigned> tile_shape,
2053 std::vector<int> instr_shape, std::vector<int> mat_shape,
2054 int per_phase, int max_phase, int dtsize, Builder *builder,
2055 adder add, multiplier mul, geper gep)
2056 : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
2057 instr_shape_(instr_shape), mat_shape_(mat_shape),
2058 per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
2059 add(add), mul(mul), gep(gep) {
2060 // compute compile-time constant variables & types
2061 c_mat_shape_ = mat_shape[order[0]];
2062 s_mat_shape_ = mat_shape[order[1]];
2063
2064 c_stride_ = tile_shape[order[1]];
2065 s_stride_ = tile_shape[order[0]];
2066
2067 // rule: k must be the fast-changing axis
2068 need_trans_ = k_order_ != order_[0];
2069 can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
2070
    // more pointers are needed along the fast-changing axis
2072 if (can_use_ldmatrix_)
2073 num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
    else // warning: this only works for tf32 that needs transposition
2075 num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
2076 num_ptr_ = std::max<int>(num_ptr_, 2);
2077
2078 // special rule for i8/u8, 4 ptrs for each matrix
2079 if (!can_use_ldmatrix_ && dtsize_ == 1)
2080 num_ptr_ *= 4;
2081
2082 // load_v4 stride (in num of mats)
2083 int load_stride_in_mat[2];
2084 load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
2085 load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
2086 p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
2087 // stride in mat, used by load_v4
2088 s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
2089 }
2090
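  // compute the per-thread shared memory offsets (in elements) of the matrices
  // read by this warp, including the xor-based swizzling of the contiguous axis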
2091 std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
2092 // TODO: this needs to be moved to constructor (and extracted to arr_order)
2093 mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
2094 warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
    // starting matrix logical offset (rename it as base_mat_off?)
2096 Value *mat_off[2] = {nullptr, nullptr};
2097
2098 if (can_use_ldmatrix_) {
2099 // c: lane idx inside a group (a group is a collection of 8 contiguous threads)
2100 // s: group idx (0,1,2,3) inside a warp
2101 Value *c = urem(lane, i32(8));
2102 Value *s = udiv(lane, i32(8));
2103 // We can decompose s => s_0, s_1...
2104 Value *s0 = urem(s, i32(2));
2105 Value *s1 = udiv(s, i32(2));
2106
2107 // We use different orders for a & b for better performance.
2108 Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
2109 Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
2110 mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
2111 mul(nk_mat_arr, i32(mat_arr_stride_)));
2112 mat_off[k_order_] = k_mat_arr;
2113 // physical offset (before swizzling)
2114 Value *c_mat_off = mat_off[order_[0]];
2115 Value *s_mat_off = mat_off[order_[1]];
2116 // offset inside a matrix
2117 Value *s_off_in_mat = c;
2118
2119 std::vector<Value*> offs(num_ptr_);
2120 Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
2121 // pre-compute strided offset
2122 Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
2123 for (int i=0; i < num_ptr_; ++i) {
2124 Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
2125 c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
2126 offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
2127 }
2128 return offs;
2129 } else if (dtsize_ == 4 && need_trans_) {
2130 // load tf32 matrices with lds32
2131 Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
2132 Value *s_off_in_mat = urem(lane, i32(4)); //
2133
2134 Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
2135 std::vector<Value*> offs(num_ptr_);
2136 for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
2137 int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
2138 int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
2139 if (k_mat_arr_int > 0) // we don't need pointers for k
2140 continue;
2141 Value *k_mat_arr = i32(k_mat_arr_int);
2142 Value *nk_mat_arr = i32(nk_mat_arr_int);
2143 // physical offset (before swizzling)
2144 Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
2145 mul(nk_mat_arr, i32(mat_arr_stride_)));
2146 Value *s_mat_off = k_mat_arr; // always 0?
2147 Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
      // FIXME: the (k_order_ == 1 ?) factor below is a really dirty hack
2149 for (int i = 0; i < num_ptr_/2; ++i) {
2150 Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
2151 c_mat_off_i = xor_(c_mat_off_i, phase);
2152 Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
2153 // TODO: move this out of the loop
2154 c_off = urem(c_off, i32(tile_shape_[order_[0]]));
2155 s_off = urem(s_off, i32(tile_shape_[order_[1]]));
2156 offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
2157 }
2158 }
2159 return offs;
2160 // throw std::runtime_error("not implemented");
2161 } else if (dtsize_ == 1 && need_trans_) {
2162 // load i8/u8 matrices with lds8
2163 Value *c_off_in_mat = udiv(lane, i32(4)); //
      Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread loads 4 cols
2165
2166 // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
2167 std::vector<Value*> offs(num_ptr_);
2168 for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
2169 int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
2170 int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
2171 if (k_mat_arr_int > 0) // we don't need pointers for k
2172 continue;
2173 Value *k_mat_arr = i32(k_mat_arr_int);
2174 Value *nk_mat_arr = i32(nk_mat_arr_int);
2175 // physical offset (before swizzling)
2176 Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
2177 mul(nk_mat_arr, i32(mat_arr_stride_)));
2178 Value *s_mat_off = k_mat_arr; // always 0?
2179
2180 for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
2181 for (int elem_off = 0; elem_off < 4; ++elem_off) {
2182 int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
2183
2184 Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
2185 Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
2186
2187 // disable swizzling ...
2188 // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
2189 // c_mat_off_i = xor_(c_mat_off_i, phase);
2190
2191 Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
2192 Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
2193 // To prevent out-of-bound access when the tile is too small
2194 c_off = urem(c_off, i32(tile_shape_[order_[0]]));
2195 s_off = urem(s_off, i32(tile_shape_[order_[1]]));
2196 offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
2197 }
2198 }
2199 }
2200 return offs;
2201 } else
2202 throw std::runtime_error("invalid smem load config");
2203 }
2204
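  // load four 8x8 matrices at once: with a single ldmatrix.x4 when possible,
  // otherwise (transposed tf32 / i8) with scalar lds.32 / lds.8 loads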
2205 std::tuple<Value*, Value*, Value*, Value*>
2206 load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
2207 Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
2208 FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
2209 std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
2210 assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
2211 int mat_idx[2] = {mat0, mat1};
2212 int k = mat_idx[k_order_];
2213
2214 int ptr_idx = -1;
2215 if (can_use_ldmatrix_)
2216 ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
2217 else if (dtsize_ == 4 && need_trans_) // tf32 & trans
2218 ptr_idx = mat_idx[order_[0]];
2219 else // i8 & trans
2220 ptr_idx = mat_idx[order_[0]] * 4;
2221
2222 auto get_ptr = [&](int idx) -> Value* {
2223 Value *ptr = nullptr;
2224 if (k == 0 && is_prefetch) {
2225 if (inc == 0)
2226 ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
2227 else
2228 ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
2229 } else
2230 ptr = ptrs.at(idx);
2231 return ptr;
2232 };
2233 Value *ptr = get_ptr(ptr_idx);
2234
2235 Value *res_v4 = nullptr;
2236 if (can_use_ldmatrix_) {
2237 std::string trans = need_trans_ ? ".trans" : "";
      // the offset (in bytes) along the strided axis is a constant
2239 int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
2240 InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
2241 "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
2242 "{$0, $1, $2, $3}, "
2243 "[$4 + " + std::to_string(s_offset) + "];",
2244 "=r,=r,=r,=r,r", true);
2245 assert(ptr);
2246 res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
2247 if (k == 0 && inc == 1 && is_prefetch)
2248 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
2249 return {extract_val(res_v4, std::vector<unsigned>{0}),
2250 extract_val(res_v4, std::vector<unsigned>{1}),
2251 extract_val(res_v4, std::vector<unsigned>{2}),
2252 extract_val(res_v4, std::vector<unsigned>{3})};
2253 } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
2254 Value *ptr2 = get_ptr(ptr_idx+1);
2255 assert(s_mat_stride_ == 1);
2256 int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
2257 int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
2258 Value *elem0, *elem1, *elem2, *elem3;
2259 if (k_order_ == 1) {
2260 elem0 = load(gep(ptr, i32(s_offset_elem)));
2261 elem1 = load(gep(ptr2, i32(s_offset_elem)));
2262 elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
2263 elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
2264 } else { // for b (k first)
2265 elem0 = load(gep(ptr, i32(s_offset_elem)));
2266 elem2 = load(gep(ptr2, i32(s_offset_elem)));
2267 elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
2268 elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
2269 }
2270 if (k == 0 && inc == 1 && is_prefetch) {
2271 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
2272 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
2273 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
2274 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
2275 }
2276 return {elem0, elem1, elem2, elem3};
2277 } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
2278 Value *ptr00 = get_ptr(ptr_idx);
2279 Value *ptr01 = get_ptr(ptr_idx+1);
2280 Value *ptr02 = get_ptr(ptr_idx+2);
2281 Value *ptr03 = get_ptr(ptr_idx+3);
2282
2283 Value *ptr10 = get_ptr(ptr_idx+4);
2284 Value *ptr11 = get_ptr(ptr_idx+5);
2285 Value *ptr12 = get_ptr(ptr_idx+6);
2286 Value *ptr13 = get_ptr(ptr_idx+7);
2287
2288 assert(s_mat_stride_ == 1);
2289 int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
2290 int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
2291
2292 Value *i8v4_elems[4];
2293 Value *i32_elems[4];
2294 for (int i=0; i<4; ++i)
2295 i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
2296
2297 Value *elem00, *elem01, *elem02, *elem03;
2298 Value *elem10, *elem11, *elem12, *elem13;
2299 Value *elem20, *elem21, *elem22, *elem23;
2300 Value *elem30, *elem31, *elem32, *elem33;
2301 Value *i8_elems[4*4];
2302 if (k_order_ == 1) { //
2303 i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
2304 i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
2305 i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
2306 i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
2307
2308 assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
2309
2310 i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
2311 i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
2312 i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
2313 i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
2314
2315 i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
2316 i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
2317 i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
2318 i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
2319
2320 i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
2321 i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
2322 i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
2323 i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
2324
2325 for (int m=0; m<4; ++m) {
2326 for (int e=0; e<4; ++e)
2327 i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
2328 i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
2329 }
2330 } else { // for b (k first)
2331 i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
2332 i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
2333 i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
2334 i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
2335
2336 assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
2337
2338 i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
2339 i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
2340 i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
2341 i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
2342
2343 i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
2344 i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
2345 i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
2346 i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
2347
2348 i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
2349 i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
2350 i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
2351 i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
2352
2353 for (int m=0; m<4; ++m) {
2354 for (int e=0; e<4; ++e)
2355 i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
2356 i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
2357 }
2358 }
2359 if (k == 0 && inc == 1 && is_prefetch) {
2360 for (int m = 0; m < 4; ++m)
2361 for (int e = 0; e < 4; ++e)
2362 prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
2363 }
2364 return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
2365 } else
2366 throw std::runtime_error("invalid smem load");
2367 }
2368
2369 int get_num_ptr() const { return num_ptr_; }
2370
2371private:
2372 int wpt_;
2373 std::vector<int> order_;
2374 int k_order_;
2375 std::vector<unsigned> tile_shape_;
2376 std::vector<int> instr_shape_;
2377 std::vector<int> mat_shape_;
2378 int per_phase_, max_phase_;
2379 int dtsize_;
2380
2381 // generated
2382 int c_mat_shape_, s_mat_shape_;
2383 int c_stride_, s_stride_;
2384 // p_: on the pointer axis
2385 int p_load_stride_in_mat_;
2386 int s_mat_stride_;
  // stride when moving to the next non-k matrix
2388 int warp_off_stride_;
2389 int mat_arr_stride_; // matrix arrangement (inside a load) stride
2390 bool need_trans_, can_use_ldmatrix_;
2391 int num_ptr_;
2392
2393 Builder *builder_;
2394 adder add;
2395 multiplier mul;
2396 geper gep;
2397};
2398}
2399
2400/**
2401 * \brief Code Generation for `mma.16816` (A100)
2402 */
2403//TODO: clean-up
2404void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
2405 const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
2406 std::map<std::vector<Value*>, std::vector<Value*>> fcs;
2407 for(indices_t idx: idxs_.at(C)){
2408 std::vector<Value*> key(idx.size() - 2);
2409 std::copy(idx.begin() + 2, idx.end(), key.begin());
2410 fcs[key].push_back(vals_[D][idx]);
2411 };
2412 auto shape_a = A->get_type()->get_block_shapes();
2413 auto shape_b = B->get_type()->get_block_shapes();
2414 auto ord_a = layouts_->get(A)->get_order();
2415 if(C->is_trans_a()){
2416 std::swap(ord_a[0], ord_a[1]);
2417 std::swap(shape_a[0], shape_a[1]);
2418 }
2419 auto ord_b = layouts_->get(B)->get_order();
2420 if(C->is_trans_b()){
2421 std::swap(ord_b[0], ord_b[1]);
2422 std::swap(shape_b[0], shape_b[1]);
2423 }
2424 NK = shape_a[1];
2425 analysis::mma_layout* layout = layouts_->get(C)->to_mma();
2426
2427 std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
2428 const int mma_instr_m = mma_instr_shape[0];
2429 const int mma_instr_n = mma_instr_shape[1];
2430 const int mma_instr_k = mma_instr_shape[2];
2431
2432 std::vector<int> mat_shape = layout->get_mma_mat_shape();
2433 const int mat_shape_m = mat_shape[0];
2434 const int mat_shape_n = mat_shape[1];
2435 const int mat_shape_k = mat_shape[2];
2436
2437
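  // number of mma instruction repetitions along m, n and k needed to cover the tile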
2438 const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
2439 const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
2440 const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
2441
2442 // floating point types
2443 Type *fp32_ty = f32_ty;
2444 Type *fp16x2_ty = vec_ty(f16_ty, 2);
2445 Type *bf16x2_ty = vec_ty(bf16_ty, 2);
2446 Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
2447 Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
2448 Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
2449 // integer types
2450 Type *i8x4_ty = vec_ty(i8_ty, 4);
2451 Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
2452 Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});
2453
2454
2455 FunctionType *ldmatrix_ty = nullptr;
2456 FunctionType *mma_ty = nullptr;
2457 Type *phi_ty = nullptr;
2458 Type *smem_ptr_ty = nullptr;
2459
2460 ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
2461 ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
2462 if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
2463 mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
2464 smem_ptr_ty = ptr_ty(f16_ty, 3);
2465 ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
2466 phi_ty = fp16x2_ty;
2467 } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
2468 mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
2469 smem_ptr_ty = ptr_ty(bf16_ty, 3);
2470 ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
2471 phi_ty = bf16x2_ty;
2472 } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
2473 mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
2474 smem_ptr_ty = ptr_ty(fp32_ty, 3);
2475 ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
2476 phi_ty = fp32_ty;
2477 } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
2478 // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
2479 mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
2480 smem_ptr_ty = ptr_ty(i8_ty, 3);
2481 ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
2482 phi_ty = i32_ty;
2483 // mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
2484 // smem_ptr_ty = ptr_ty(i8_ty, 3);
2485 // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
2486 // phi_ty = i8x4_ty;
2487 } else
2488 throw std::runtime_error("mma16816 data type not supported");
2489
2490 // left-hand-side values
2491 std::map<std::pair<unsigned, unsigned>, Value*> ha;
2492 std::map<std::pair<unsigned, unsigned>, Value*> hb;
2493
2494 BasicBlock* CurrBB = builder_->GetInsertBlock();
2495 BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
2496
  // if true, pointer declarations are hoisted to the entry basic block.
  // non-prefetched cases tend to be more limited in resource usage,
  // so we do not pre-compute pointers there, in order to save registers
2500 bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
2501 if(licm_ptrs)
2502 builder_->SetInsertPoint(FirstBB->getTerminator());
2503
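  // decompose the linear thread id into a lane id and 2D warp coordinates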
2504 Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
2505 Value *lane = urem(thread, i32(32));
2506 Value *warp = udiv(thread, i32(32));
2507 Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
2508 Value *warp_m = urem(warp, i32(layout->wpt(0)));
2509 Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
2510 std::vector<Value *>& fc = fcs.begin()->second;
2511
2512 size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
2513 size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
2514
2515 ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
2516 ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
2517 auto register_lds2 =
2518 [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
2519 if (k < 2 && is_prefetch) {
2520 ir::basic_block* inc_block = phiA->get_incoming_block(inc);
2521 lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
2522 } else
2523 vals[{mn, k}] = val;
2524 };
2525
2526 // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
2527 // v (s0_0(0), s1_0(2), | *num_rep_k
2528 // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
2529 // -----------
2530 // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
2531 std::function<void(int,int,int,bool)> load_a;
2532 analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
2533 bool is_a_shared = layout_a != nullptr;
2534 if(is_a_shared) {
2535 const int per_phase_a = swizzle_->get_per_phase(layout_a);
2536 const int max_phase_a = swizzle_->get_max_phase(layout_a);
2537 mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
2538 {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
2539 per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
2540 std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
2541 int num_ptr_a = a_loader.get_num_ptr();
2542 // pointers
2543 std::vector<Value*> ptrs_a(num_ptr_a);
2544 if(licm_ptrs)
2545 builder_->SetInsertPoint(CurrBB);
2546 for(int i = 0; i < num_ptr_a; i++)
2547 ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
2548 if(licm_ptrs)
2549 builder_->SetInsertPoint(FirstBB->getTerminator());
2550 // loading function
2551 load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
2552 auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
2553 shared_next_ptr_[layout_a], off_a, ptrs_a,
2554 ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
2555 register_lds2(ha, m, k, inc, ha0, is_prefetch);
2556 register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
2557 register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
2558 register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
2559 };
2560 }
2561 else {
2562 load_a = [&](int m, int k, int inc, bool is_prefetch) {
2563 distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
2564 int ldm = ax_n.values.size();
2565 if(ldm != num_rep_k*4)
2566 throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
2567 // std::cout << m << " " << k << std::endl;
2568 // std::cout << idxs_[A].size() << std::endl;
2569 // std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
2570 // int ldm = num_rep_k*4;
2571 Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2
2572 Value* ha1 = UndefValue::get(phi_ty);
2573 Value* ha2 = UndefValue::get(phi_ty);
2574 Value* ha3 = UndefValue::get(phi_ty);
2575 ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
2576 ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
2577 ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
2578 ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
2579 ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
2580 ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
2581 ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
2582 ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
2583 ha[{m, k}] = ha0;
2584 ha[{m+1, k}] = ha1;
2585 ha[{m, k+1}] = ha2;
2586 ha[{m+1, k+1}] = ha3;
2587 };
2588 }
2589
2590
2591 // | -> n (col-major)
2592 // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
2593 // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
2594 // -----------
2595 // *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
2596 analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
2597 const int per_phase_b = swizzle_->get_per_phase(layout_b);
2598 const int max_phase_b = swizzle_->get_max_phase(layout_b);
2599 std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
2600 std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
2601 int k_order_b = 0;
2602 // if(C->is_trans_b()){
2603 // std::swap(mma_instr_b[0], mma_instr_b[1]);
2604 // std::swap(mat_shape_b[0], mat_shape_b[1]);
2605 // k_order_b = k_order_b ^ 1;
2606 // std::swap(ord_b[0], ord_b[1]);
2607 // std::swap(shape_b[0], shape_b[1]);
2608 // }
2609
2610 mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
2611 mma_instr_b, mat_shape_b,
2612 per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
2613 std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
2614
2615 if(licm_ptrs)
2616 builder_->SetInsertPoint(CurrBB);
2617 // pointers
2618 int num_ptr_b = b_loader.get_num_ptr();
2619 std::vector<Value*> ptrs_b(num_ptr_b);
2620 for(int i = 0; i < num_ptr_b; i++)
2621 ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
2622
2623
2624 // loading function
2625 std::function<void(int,int,int,bool)> load_b;
2626 load_b = [&](int n, int k, int inc, bool is_prefetch) {
2627 auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
2628 shared_next_ptr_[layout_b], off_b, ptrs_b,
2629 ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
2630 register_lds2(hb, n, k, inc, hb0, is_prefetch);
2631 register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
2632 register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
2633 register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
2634 };
2635
2636
2637
2638 // create mma & unpack result, m, n, k are offsets in mat
2639 auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
2640 InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
2641 " {$0, $1, $2, $3},"
2642 " {$4, $5, $6, $7},"
2643 " {$8, $9},"
2644 " {$10, $11, $12, $13};",
2645 "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
2646 unsigned cols_per_thread = num_rep_n * 2;
2647 std::vector<size_t> idx = {
2648 (m + 0)*cols_per_thread + (n*2 + 0),
2649 (m + 0)*cols_per_thread + (n*2 + 1),
2650 (m + 1)*cols_per_thread + (n*2 + 0),
2651 (m + 1)*cols_per_thread + (n*2 + 1)
2652 };
2653 Value *nc = call(mma_ty, mma_fn,
2654 {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
2655 hb[{n, k}], hb[{n, k+1}],
2656 fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
2657 fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
2658 fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
2659 fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
2660 fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
2661 };
2662 if (C->is_prefetched()) {
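// Software pipelining: operand fragments live in PHI nodes. The loop header issues the
// loads for iteration 0; inside the body, the loads for the next k-step are issued
// before the tensor-core ops of the current step.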
2663 // create phis
2664 builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
2665 for(unsigned m = 0; m < num_rep_m; m++){
2666 ha[{2*m, 0}] = phi(phi_ty, 2);
2667 ha[{2*m+1, 0}] = phi(phi_ty, 2);
2668 ha[{2*m, 1}] = phi(phi_ty, 2);
2669 ha[{2*m+1, 1}] = phi(phi_ty, 2);
2670 }
2671 for(unsigned n = 0; n < num_rep_n; n+=2){
2672 hb[{n, 0}] = phi(phi_ty, 2);
2673 hb[{n+1, 0}] = phi(phi_ty, 2);
2674 hb[{n, 1}] = phi(phi_ty, 2);
2675 hb[{n+1, 1}] = phi(phi_ty, 2);
2676 }
2677 // insert prefetched lds at the end of loop header
2678 builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
2679 for(unsigned m = 0; m < num_rep_m; m++)
2680 load_a(2*m, 0, 0, true);
2681 for(unsigned n = 0; n < num_rep_n; n+=2)
2682 load_b(n, 0, 0, true);
2683 // update accumulators
2684 builder_->SetInsertPoint(CurrBB);
2685 for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
2686 int next_k = (k + 1) % num_rep_k;
2687 // prefetch A
2688 for(unsigned m = 0; m < num_rep_m; m++)
2689 load_a(2*m, 2*next_k, 1, true);
2690 // prefetch B
2691 for(unsigned n = 0; n < num_rep_n; n+=2)
2692 load_b(n, 2*next_k, 1, true);
2693 // tensor core ops
2694 for(unsigned m = 0; m < num_rep_m; m++)
2695 for(unsigned n = 0; n < num_rep_n; n++){
2696 call_mma(2*m, n, 2*k);
2697 }
2698 }
2699 }
2700 else{
2701 for (unsigned k = 0; k < num_rep_k; k++) {
2702 for (unsigned m = 0; m < num_rep_m; m++)
2703 load_a(2*m, 2*k, 0, /*is_prefetch*/false);
2704 for (unsigned n = 0; n < num_rep_n; n+=2)
2705 load_b(n, 2*k, 0, /*is_prefetch*/false);
2706 for (unsigned m = 0; m < num_rep_m; m++)
2707 for (unsigned n = 0; n < num_rep_n; n++)
2708 call_mma(2*m, n, 2*k);
2709 }
2710 }
2711 // write back
2712 unsigned i = 0;
2713 for(indices_t idx: idxs_.at(C)){
2714 std::vector<Value*> key(idx.size() - 2);
2715 std::copy(idx.begin() + 2, idx.end(), key.begin());
2716 if(i >= fcs.at(key).size())
2717 i = 0;
2718 vals_[C][idx] = fcs.at(key)[i++];
2719 };
2720
2721}
2722
2723/**
2724 * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
2725 */
2726void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
2727 auto shape_c = C->get_type()->get_block_shapes();
2728 auto shape_a = A->get_type()->get_block_shapes();
2729 auto shape_b = B->get_type()->get_block_shapes();
2730 auto ord_a = layouts_->get(A)->get_order();
2731 auto ord_b = layouts_->get(B)->get_order();
2732 analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
2733 analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
2734 analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
2735 bool is_a_row = ord_a[0] == 1;
2736 bool is_b_row = ord_b[0] == 1;
2737 std::string a_trans = is_a_row ? "" : ".trans";
2738 std::string b_trans = is_b_row ? ".trans" : "";
2739 int stride_a_m = is_a_row ? shape_a[1] : 1;
2740 int stride_a_k = is_a_row ? 1 : shape_a[0];
2741 int stride_b_n = is_b_row ? 1 : shape_b[0];
2742 int stride_b_k = is_b_row ? shape_b[1] : 1;
2743 int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
2744 int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
2745 int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
2746 int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
2747 int lda = is_a_row ? stride_a_m : stride_a_k;
2748 int ldb = is_b_row ? stride_b_k : stride_b_n;
2749 int per_phase_a = swizzle_->get_per_phase(layout_a);
2750 int max_phase_a = swizzle_->get_max_phase(layout_a);
2751 int per_phase_b = swizzle_->get_per_phase(layout_b);
2752 int max_phase_b = swizzle_->get_max_phase(layout_b);
2753 int num_ptr_a = 8;
2754 int num_ptr_b = 8;
2755 int vec_a = 2;
2756 int vec_b = 4;
2757 distributed_axis ax_m = axes_.at(a_axes_->get(C, 0));
2758 distributed_axis ax_n = axes_.at(a_axes_->get(C, 1));
2759// Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
2760
2761 Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous));
2762 Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0);
2763 std::vector<Value*> off_a(num_ptr_a);
2764 for(int i = 0; i < num_ptr_a; i++){
2765// Value* off_a0i = add(off_a0, i32(is_a_row ? vec_a : layout_c->mts(0)*vec_a));
2766// off_a0i = exact_udiv(off_a0i, i32(vec_a));
2767// off_a0i = xor_(off_a0i, phase_a);
2768// off_a0i = mul(off_a0i, i32(vec_a));
2769 off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
2770 }
2771 Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0);
2772 Value* off_b1 = is_b_row ? i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous));
2773 std::vector<Value*> off_b(num_ptr_b);
2774 for(int i = 0; i < num_ptr_b; i++){
2775// Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b));
2776// off_b0i = exact_udiv(off_b0i, i32(vec_b));
2777// off_b0i = xor_(off_b0i, phase_b);
2778// off_b0i = mul(off_b0i, i32(vec_b));
2779 off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
2780 }
2781 std::vector<Value*> ptrs_a(num_ptr_a);
2782 for(int i = 0; i < num_ptr_a; i++)
2783 ptrs_a[i] = gep(shmems_[A], off_a[i]);
2784 std::vector<Value*> ptrs_b(num_ptr_b);
2785 for(int i = 0; i < num_ptr_b; i++)
2786 ptrs_b[i] = gep(shmems_[B], off_b[i]);
2787
2788 std::map<indices_t, Value*> ret = vals_[D];
2789 std::map<std::pair<int, int>, Value*> has, hbs;
2790 auto ord = layout_c->get_order();
2791 for(unsigned k = 0; k < NK; k++){
2792 int z = 0;
2793 for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
2794 for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
2795 for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
2796 for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
2797 unsigned m = (ord[0] == 1) ? i : j;
2798 unsigned n = (ord[0] == 1) ? j : i;
2799 unsigned mm = (ord[0] == 1) ? ii : jj;
2800 unsigned nn = (ord[0] == 1) ? jj : ii;
2801 if(has.find({m + mm, k}) == has.end()){
2802 Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
2803 Value* va = load(pa);
2804 has[{m + mm, k}] = va;
2805 }
2806 if(hbs.find({n + nn, k}) == hbs.end()){
2807 Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k));
2808 Value* vb = load(pb);
2809 hbs[{n + nn, k}] = vb;
2810 }
2811 ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]});
2812 z++;
2813 }
2814 }
2815
2816 for(indices_t idx: idxs_.at(C)){
2817 vals_[C][idx] = ret[idx];
2818 }
2819}
2820
2821/**
2822 * \brief Code Generation for `dot`
2823 * Dispatches to the appropriate specialized implementation
2824 */
2825void generator::visit_dot_inst(ir::dot_inst* dot) {
2826 Function *fn = builder_->GetInsertBlock()->getParent();
2827 Module *module = fn->getParent();
2828 ir::value *A = dot->get_operand(0);
2829 ir::value *B = dot->get_operand(1);
2830 ir::value *D = dot->get_operand(2);
2831 Type *c_ty = cvt(D->get_type()->get_scalar_ty());
2832 Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
2833 auto A_shapes = A->get_type()->get_block_shapes();
2834 size_t red_axis = 1;
2835 unsigned NK = A_shapes[red_axis];
2836 bool is_outer = NK == 1;
2837 bool is_mma = layouts_->get(dot)->to_mma();
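 // NK == 1 degenerates to an outer product, which the tensor-core paths do not handle.
 // sm < 80 uses the mma.884 (Volta) path, sm >= 80 the mma.16816 (Ampere) path,
 // otherwise fall back to scalar FMA (fp32 only).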
2838 if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
2839 return visit_mma884(dot, A, B, D, NK);
2840 if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
2841 return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
2842 if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
2843 A->get_type()->get_scalar_ty()->is_fp32_ty())
2844 return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
2845 throw std::runtime_error("dot has invalid operand type");
2846}
2847
2848void generator::visit_trans_inst(ir::trans_inst* trans) {
2849 throw std::runtime_error("not supported");
2850}
2851
2852/**
2853 * \brief Code Generation for `sqrt`
2854 */
2855void generator::visit_sqrt_inst(ir::sqrt_inst* x) {
2856 for(indices_t idx: idxs_.at(x)){
2857 Value *val = vals_[x->get_operand(0)][idx];
2858 Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val});
2859 vals_[x][idx] = ret;
2860 }
2861}
2862
2863Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx){
2864 // strides
2865 std::vector<Value*> strides(shapes.size(), builder_->getInt32(0));
2866 strides[order[0]] = builder_->getInt32(1);
2867 for(size_t i = 1; i < idx.size(); i++)
2868 strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]]));
2869 // result
2870 Value *result = builder_->getInt32(0);
2871 for(size_t i = 0; i < idx.size(); i++)
2872 result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i]));
2873 return result;
2874}
2875
2876inline Value* generator::shfl_sync(Value* acc, int32_t i){
2877 Type* ty = acc->getType();
2878 std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
2879 InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
2880 if(ty->getPrimitiveSizeInBits() <= 32)
2881 return call(shfl, {acc, i32(i)});
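 // accumulators wider than 32 bits (e.g. f64) are split into two f32 halves,
 // shuffled independently, and reassembled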
2882 acc = bit_cast(acc, vec_ty(f32_ty, 2));
2883 Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
2884 Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
2885 Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
2886 ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
2887 ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
2888 return bit_cast(ret, ty);
2889}
2890
2891/**
2892 * \brief Code Generation for `reduce` (ND case, fast warp-shuffle path)
2893 */
2894void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
2895 ir::value *arg = x->get_operand(0);
2896 const auto with_index = x->with_index();
2897 unsigned axis = x->get_axis();
2898 analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
2899 const auto &shapes = layout->get_shape();
2900
2901 Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
2902 size_t n_bits = sca_ty->getPrimitiveSizeInBits();
2903 std::string n_bits_str = std::to_string(n_bits);
2904 std::string cst = (n_bits == 64) ? "l" : "r";
2905
2906 FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false);
2907 InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true);
2908 FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
2909 InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
2910
2911 Type *index_ty = IntegerType::get(*ctx_, 32);
2912 FunctionType *st_shared_index_ty =
2913 FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
2914 InlineAsm *st_shared_index = InlineAsm::get(
2915 st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
2916 FunctionType *ld_shared_index_ty =
2917 FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
2918 InlineAsm *ld_shared_index = InlineAsm::get(
2919 ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
2920
2921 Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
2922 Value* warp = udiv(thread, i32(32));
2923 Value* lane = urem(thread, i32(32));
2924
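 // shuffle_width: adjacent lanes that cooperate via shfl along the reduced axis;
 // warps_per_inner: warps spanning that axis (a shared-memory pass combines their partials when > 1);
 // col_per_thread: elements each thread reduces locally before the shuffle.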
2925 unsigned shuffle_width = 0;
2926 unsigned warps_per_inner = 0;
2927 auto arg_vals = vals_.at(arg);
2928 std::vector<indices_t> arg_idxs = idxs_.at(arg);
2929 size_t n_elts = arg_idxs.size();
2930 unsigned col_per_thread = 0;
2931 Value* warp_j = nullptr;
2932 if (analysis::scanline_layout *scanline = layout->to_scanline()) {
2933 std::vector<int> order = layout->get_order();
2934 unsigned mts = scanline->mts(order[0]);
2935 shuffle_width = std::min<int>(mts, 32);
2936 warps_per_inner = std::max<int>(mts / 32, 1);
2937 col_per_thread = shapes[order[0]] / mts;
2938 warp_j = urem(warp, i32(warps_per_inner));
2939 } else if (layout->to_mma()) {
2940 shuffle_width = 4;
2941 warps_per_inner = layout->to_mma()->wpt(1);
2942 col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
2943 warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
2944 }
2945 assert(warp_j != nullptr);
2946
2947 // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
2948 //
2949 Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
2950 cvt(x->get_type()->get_scalar_ty()));
2951 Value *index_base =
2952 with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
2953 IntegerType::get(*ctx_, 32))
2954 : nullptr;
2955
2956 // preds
2957 Value* is_lane0 = icmp_eq(lane, i32(0));
2958 Value* is_warp0 = icmp_eq(warp, i32(0));
2959 Value* is_thread0 = icmp_eq(thread, i32(0));
2960 Value* lane_j = urem(lane, i32(shuffle_width));
2961 if(warps_per_inner > 1)
2962 add_barrier();
2963 // compute the partial reduction for each warp, and store it to shared memory
2964 for(size_t i = 0; i < n_elts/col_per_thread; i++){
2965 std::pair<Value*, Value*> acc;
2966 // reduce within thread
2967 for(size_t j = 0; j < col_per_thread; j++){
2968 auto arg_idx = arg_idxs[i*col_per_thread + j];
2969 bool is_first = j == 0;
2970 do_acc(
2971 acc, [&]() -> Value * { return arg_vals[arg_idx]; },
2972 [&]() -> Value * { return arg_idx[axis]; }, is_first);
2973 }
2974
2975 // reduce within warp
2976 for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
2977 do_acc(
2978 acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
2979 [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
2980 }
2981 // store partial result to shared memory
2982 auto x_idxs = idxs_[x][i];
2983 Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
2984 // single warp on the reduce dimension -- no need to use shmem
2985 if(warps_per_inner==1){
2986 vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
2987 }
2988 else{
2989 Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
2990 call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
2991 if (with_index) {
2992 call(st_shared_index,
2993 {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
2994 }
2995 }
2996 }
2997 if(warps_per_inner==1)
2998 return;
2999 add_barrier();
3000 // at this point, the per-warp partial accumulators are synchronized in shared memory;
3001 // each output element only needs a reduction over `warps_per_inner` values
3002 for(size_t i = 0; i < n_elts/col_per_thread; i++){
3003 auto x_idxs = idxs_[x][i];
3004 Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
3005 Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
3006 std::pair<Value*, Value*> acc;
3007 acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
3008 acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
3009 gep(index_base, ld_off)})
3010 : nullptr;
3011 for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
3012 do_acc(
3013 acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
3014 [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
3015 }
3016 vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
3017 }
3018 // add_barrier();
3019}
3020
3021
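/**
 * \brief Code Generation for `reduce` (ND case, generic shared-memory path)
 */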
3022void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
3023 ir::value *arg = x->get_operand(0);
3024 unsigned axis = x->get_axis();
3025 auto with_index = x->with_index();
3026
3027 // reduce within thread
3028 // index-><current reduced value, current min/max index (optional)>
3029 std::map<indices_t, std::pair<Value*, Value*>> accs;
3030 for(indices_t idx: idxs_.at(arg)){
3031 indices_t pidx = idx;
3032 pidx[axis] = i32(0);
3033 bool is_first = accs.find(pidx) == accs.end();
3034 do_acc(
3035 accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
3036 [&]() -> Value * { return idx[axis]; }, is_first);
3037 };
3038
3039 // reduce within blocks
3040 auto *data_layout = layouts_->get(layouts_->tmp(x));
3041 auto *data_ptr =
3042 cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
3043 auto *index_ptr =
3044 with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
3045 IntegerType::get(*ctx_, 32))
3046 : data_ptr;
3047
3048 auto shape = data_layout->get_shape();
3049 auto order = data_layout->get_order();
3050 Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
3051 for(auto& x: accs) {
3052 // current element being computed
3053 std::pair<Value *, Value *> acc = x.second;
3054 indices_t write_idx = x.first;
3055 write_idx[axis] = lane;
3056 // shared memory write pointer
3057 Value *write_off = shared_off(shape, order, write_idx);
3058 Value *write_ptr = gep(data_ptr, write_off);
3059 Value *index_write_ptr = gep(index_ptr, write_off);
3060 // initialize shared memory
3061 add_barrier();
3062 store(acc.first, write_ptr);
3063 if (with_index) {
3064 store(acc.second, index_write_ptr);
3065 }
3066 // build result
3067 indices_t idx(write_idx.size(), i32(0));
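 // log-step tree reduction in shared memory: each iteration halves the active range along `axis`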
3068 for(size_t i = shape[axis]/2; i > 0; i >>= 1){
3069 idx[axis] = i32(i);
3070 // read pointer
3071 Value *read_msk = icmp_ult(lane, i32(i));
3072 Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
3073 Value *read_ptr = gep(write_ptr, read_off);
3074 Value *index_read_ptr = gep(index_write_ptr, read_off);
3075 add_barrier();
3076 // update accumulator
3077 do_acc(
3078 acc, [&]() -> Value * { return load(read_ptr); },
3079 [&]() -> Value * { return load(index_read_ptr); }, false);
3080 add_barrier();
3081 store(acc.first, write_ptr);
3082 if (with_index) {
3083 store(acc.second, index_write_ptr);
3084 }
3085 }
3086 }
3087 add_barrier();
3088
3089 // write back
3090 for(indices_t idx: idxs_.at(x)){
3091 indices_t read_idx = idx;
3092 read_idx.insert(read_idx.begin() + axis, i32(0));
3093 Value *read_off = shared_off(shape, order, read_idx);
3094 Value *read_ptr =
3095 with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
3096 vals_[x][idx] = load(read_ptr);
3097 };
3098}
3099
3100/**
3101 * \brief Code Generation for `reduce` (generic case)
3102 */
3103void generator::visit_reduce_inst(ir::reduce_inst* x) {
3104 Type *ty = cvt(x->get_type()->get_scalar_ty());
3105 // accumulation function
3106 ir::reduce_inst::op_t op = x->get_op();
3107 auto do_acc_op = [&](Value *x, Value *y) -> Value* {
3108 switch(op){
3109 case ir::reduce_inst::ADD: return add(x, y);
3110 case ir::reduce_inst::SUB: return sub(x, y);
3111 case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
3112 case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
3113 case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
3114 case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
3115 case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
3116 case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
3117 case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
3118 case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
3119 case ir::reduce_inst::FADD: return fadd(x, y);
3120 case ir::reduce_inst::FSUB: return fsub(x, y);
3121 case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
3122 case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
3123 case ir::reduce_inst::FMAX: return max_num(x, y);
3124 case ir::reduce_inst::FMIN: return min_num(x, y);
3125 case ir::reduce_inst::XOR: return xor_(x, y);
3126
3127 default: throw std::runtime_error("unreachable");
3128 }
3129 };
3130
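 // for arg-reductions (with_index), do_acc_op returns a comparison predicate and the
 // selects keep the winning value/index pair; otherwise do_acc_op returns the combined value directly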
3131 auto do_acc = [&](std::pair<Value *, Value *> &acc,
3132 std::function<Value *()> load_value_fn,
3133 std::function<Value *()> load_index_fn,
3134 bool is_first) -> void {
3135 auto *val = load_value_fn();
3136 if (x->with_index()) {
3137 auto *index = load_index_fn();
3138 if (is_first) {
3139 acc.first = val;
3140 acc.second = index;
3141 } else {
3142 Value *ret = do_acc_op(acc.first, val);
3143 acc.first = select(ret, acc.first, val);
3144 acc.second = select(ret, acc.second, index);
3145 }
3146 } else {
3147 acc.first = is_first ? val : do_acc_op(acc.first, val);
3148 }
3149 };
3150
3151 // neutral element
3152 Value *neutral;
3153 switch(op) {
3154 case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
3155 case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
3156 case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
3157 case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
3158 case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
3159 case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
3160 case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
3161 case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
3162 case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
3163 case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
3164 case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
3165 case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
3166 case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
3167 case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
3168 case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
3169 case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
3170 case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
3171 default: throw std::runtime_error("unreachable");
3172 }
3173 ir::value *arg = x->get_operand(0);
3174 bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
3175 bool is_a100_mma = layouts_->is_a100_mma(x);
3176 if (is_coalesced_scanline || is_a100_mma)
3177 visit_reducend_inst_fast(x, do_acc, neutral);
3178 else
3179 visit_reducend_inst(x, do_acc, neutral);
3180}
3181
3182/**
3183 * \brief Code Generation for `select`
3184 */
3185void generator::visit_select_inst(ir::select_inst* x) {
3186 for(indices_t idx: idxs_.at(x)){
3187 vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
3188 vals_[x->get_operand(1)][idx],
3189 vals_[x->get_operand(2)][idx]);
3190 }
3191}
3192
3193
3194
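/**
 * \brief Layout conversion: moves a distributed value from one layout to another by
 * staging tiles through shared memory (store in the input layout, barrier, reload in
 * the output layout)
 */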
3195void generator::visit_layout_convert(ir::value *out, ir::value *in){
3196 ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
3197 // pointer to temporary shared memory
3198 Type *ty = cvt(out->get_type()->get_scalar_ty());
3199
3200 // Orders
3201 analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
3202 analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
3203 Value *base;
3204 int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
3205 // std::cout << off << std::endl;
3206 base = gep(shmem_, i32(off));
3207 base = bit_cast(base, ptr_ty(ty, 3));
3208 std::vector<int> n_reps;
3209 for(int i = 0; i < shape.size(); i++){
3210 int in_per_cta = in_layout->shape_per_cta(i);
3211 int out_per_cta = out_layout->shape_per_cta(i);
3212 int max_per_cta = std::max(in_per_cta, out_per_cta);
3213 n_reps.push_back(shape[i]/max_per_cta);
3214 }
3215 std::vector<std::vector<Value*>> in_ax;
3216 std::vector<std::vector<Value*>> out_ax;
3217 for(int d = 0; d < shape.size(); d++){
3218 in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
3219 out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
3220 }
3221 auto in_ord =
3222 in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
3223 auto out_ord =
3224 out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
3225 // out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
3226 // non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
3227 // and in_ord[0] are contiguous.
3228 int in_vec = out_ord[0] == 0 ? 1
3229 : in_ord[0] == 0 ? 1
3230 : in_layout->contig_per_thread(in_ord[0]);
3231 int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
3232 int pad = std::max(in_vec, out_vec);
3233 Value *in_ld = i32(shape[in_ord[0]] + pad);
3234 Value *out_ld = i32(shape[out_ord[0]] + pad);
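 // the staging tile's leading dimension is padded by `pad` elements, presumably to
 // limit shared-memory bank conflicts between the store and load phases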
3235 for(int i = 0; i < n_reps[0]; i++)
3236 for(int j = 0; j < n_reps[1]; j++){
3237 int max_ii, max_jj;
3238 add_barrier();
3239 max_ii = in_ax[0].size()/n_reps[0];
3240 max_jj = in_ax[1].size()/n_reps[1];
3241 for(int ii = 0; ii < max_ii; ii++)
3242 for(int jj = 0; jj < max_jj; jj+=in_vec){
3243 // shared mem pointer
3244 indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
3245 Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
3246 Value *ptr = gep(base, off);
3247 // stash value to shared mem
3248 Value* vals = UndefValue::get(vec_ty(ty, in_vec));
3249 for(int jjj = 0; jjj < in_vec; jjj++){
3250 indices_t idxs = {in_ax[0][i*max_ii + ii],
3251 in_ax[1][j*max_jj + jj + jjj]};
3252 Value* val = bit_cast(vals_[in][idxs], ty);
3253 vals = insert_elt(vals, val, jjj);
3254 }
3255 ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
3256 store(vals, ptr);
3257 }
3258 add_barrier();
3259 max_ii = out_ax[0].size()/n_reps[0];
3260 max_jj = out_ax[1].size()/n_reps[1];
3261 for(int ii = 0; ii < max_ii; ii++)
3262 for(int jj = 0; jj < max_jj; jj+=out_vec){
3263 // shared mem pointer
3264 indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
3265 Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
3266 Value *ptr = gep(base, off);
3267 ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
3268 // load value from shared mem
3269 Value* vals = load(ptr);
3270 for(int jjj = 0; jjj < out_vec; jjj++){
3271 indices_t idxs = {out_ax[0][i*max_ii + ii],
3272 out_ax[1][j*max_jj + jj + jjj]};
3273 vals_[out][idxs] = extract_elt(vals, jjj);
3274 }
3275 }
3276
3277 }
3278}
3279
3280void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
3281 visit_layout_convert(rc, rc->get_operand(0));
3282}
3283
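/**
 * \brief Code Generation for masked asynchronous global-to-shared copies (cp.async)
 */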
3284void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
3285 unsigned in_vec = 1;
3286 ir::value *arg = x->get_pointer_operand();
3287 analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
3288 analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
3289 auto out_order = out_layout->get_order();
3290 auto in_order = in_layout->get_order();
3291 // tiles
3292 if(out_order == in_order)
3293 in_vec = in_layout->nts(in_order[0]);
3294 int out_vec = swizzle_->get_vec(out_layout);
3295 int min_vec = std::min<int>(out_vec, in_vec);
3296 int s = std::max<int>(out_vec / in_vec, 1);
3297 //
3298 int per_phase = swizzle_->get_per_phase(out_layout);
3299 int max_phase = swizzle_->get_max_phase(out_layout);
3300 //
3301 int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
3302 int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
3303 int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
3304 auto shapes = x->get_type()->get_block_shapes();
3305 BasicBlock* CurrBB = builder_->GetInsertBlock();
3306 BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
3307 std::map<std::pair<int, int>, Value*> tmp;
3308 std::vector<std::pair<Value*, int>> shared;
3309 for(int i = 0; i < idxs_.at(arg).size(); i++){
3310 unsigned id = i / min_vec;
3311 // input ptr info
3312 int id_0 = id % (in_ld/min_vec);
3313 int id_1 = id / (in_ld/min_vec);
3314 int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
3315 int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
3316 int off = (off_1*shapes[in_order[0]] + off_0);
3317 std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
3318 if(tmp.find(key) == tmp.end()){
3319 if(CurrBB != FirstBB)
3320 builder_->SetInsertPoint(FirstBB->getTerminator());
3321 indices_t idx = idxs_.at(arg).at(key.first*in_ld);
3322 Value* phase = udiv(idx[in_order[1]], i32(per_phase));
3323 phase = urem(phase, i32(max_phase));
3324 Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
3325 Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
3326 off_0 = udiv(off_0, i32(min_vec));
3327 off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
3328 off_0 = mul(off_0 , i32(min_vec));
3329 Value* off = add(off_0, off_1);
3330 if(CurrBB != FirstBB)
3331 builder_->SetInsertPoint(CurrBB);
3332 tmp[key] = gep(shmems_[x], {off});
3333 }
3334 shared.push_back({tmp[key], off});
3335 }
3336 size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
3337 for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
3338 auto idx = idxs_[arg][i];
3339 // input ptr info
3340 Value *ptr = vals_[arg][idx];
3341 size_t in_off = 0;
3342 GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
3343 if(in_gep){
3344 ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
3345 in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
3346 ptr= cst ? in_gep->getPointerOperand() : in_gep;
3347 }
3348 // output ptr info
3349 Value* out_base = shared[i].first;
3350 int out_off = shared[i].second*dtsize;
3351 // asm
3352 std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca";
3353// Value* false_value = vals_[x->get_false_value_operand()][idx];
3354// bool is_zero_false_value = false;
3355// if(Constant* cst = dyn_cast<Constant>(false_value))
3356// is_zero_false_value = cst->isZeroValue();
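 // cp.async copies `in_vec*dtsize` bytes; passing a source size of 0 when the mask is
 // false makes the hardware zero-fill the destination instead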
3357 Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
3358 std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
3359 FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false);
3360 InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
3361 call(iasm, {out_base, ptr, src_size});
3362 }
3363
3364 std::string asm_str = "cp.async.commit_group;";
3365 InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
3366 call(iasm);
3367}
3368
3369void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
3370 unsigned in_vec = 1;
3371 ir::value *arg = cts->get_operand(0);
3372 analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
3373 analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
3374 auto out_order = out_layout->get_order();
3375 auto in_order = in_layout->get_order();
3376 // tiles
3377 if(out_order == in_order)
3378 in_vec = in_layout->contig_per_thread(in_order[0]);
3379 int out_vec = swizzle_->get_vec(out_layout);
3380 int min_vec = std::min<int>(out_vec, in_vec);
3381 int s = std::max<int>(out_vec / in_vec, 1);
3382 //
3383 int per_phase = swizzle_->get_per_phase(out_layout);
3384 int max_phase = swizzle_->get_max_phase(out_layout);
3385 //
3386 int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
3387 int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
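 // mma (tensor-core) input layouts have a fixed per-fragment thread arrangement,
 // so the strides and swizzle parameters are hard-coded for that case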
3388 if(in_layout->to_mma()){
3389 mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
3390 mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
3391 per_phase = 1;
3392 max_phase = 8;
3393 }
3394
3395 int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
3396 int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
3397 int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
3398 if(in_layout->to_mma()){
3399 n_shared_0 = 8;
3400 n_shared_1 = 1;
3401 }
3402
3403 BasicBlock* CurrBB = builder_->GetInsertBlock();
3404 BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
3405 auto shapes = cts->get_type()->get_block_shapes();
3406
3407
3408 // store to shared
3409 Value *current = nullptr;
3410 std::map<std::pair<int, int>, Value*> ptrs;
3411 for(int i = 0; i < idxs_.at(arg).size(); i++){
3412 auto idx = idxs_[arg][i];
3413 Value *in_value = vals_[arg][idx];
3414 if(i % min_vec == 0)
3415 current = UndefValue::get(vec_ty(in_value->getType(), min_vec));
3416 current = insert_elt(current, in_value, i % min_vec);
3417 if(i % min_vec == min_vec - 1){
3418 unsigned id = i / min_vec;
3419 // input ptr info
3420 int id_0 = id % (in_ld/min_vec);
3421 int id_1 = id / (in_ld/min_vec);
3422 // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
3423 std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
3424 if(ptrs.find(key) == ptrs.end()){
3425 if(FirstBB->getTerminator())
3426 builder_->SetInsertPoint(FirstBB->getTerminator());
3427 else
3428 builder_->SetInsertPoint(FirstBB);
3429 indices_t idx = idxs_.at(arg).at(key.first*in_ld);
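 // swizzled shared-memory offset: XOR the vector-column index with the row phase
 // so that different rows land in different banks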
3430 Value* phase = udiv(idx[in_order[1]], i32(per_phase));
3431 phase = urem(phase, i32(max_phase));
3432 Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
3433 Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
3434 off_0 = udiv(off_0, i32(min_vec));
3435 off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
3436 off_0 = mul(off_0 , i32(min_vec));
3437 Value* off = add(off_0, off_1);
3438 builder_->SetInsertPoint(CurrBB);
3439 ptrs[key] = gep(shmems_.at(cts), {off});
3440 }
3441 int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
3442 int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
3443 if(in_layout->to_mma()){
3444 off_0 = id_0/n_shared_0*n_shared_0*8;
3445 off_1 = id_1/n_shared_1*n_shared_1*8;
3446 }
3447 int off = (off_1*shapes[in_order[0]] + off_0);
3448 Value* ptr = gep(ptrs[key], {i32(off)});
3449 ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
3450 // asm
3451 store(current, ptr);
3452 }
3453 };
3454}
3455
3456void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) {
3457 throw std::runtime_error("TODO");
3458}
3459
3460Instruction* generator::add_barrier() {
3461 Module *module = builder_->GetInsertBlock()->getModule();
3462 return tgt_->add_barrier(module, *builder_);
3463}
3464
3465void generator::visit_barrier_inst(ir::barrier_inst*) {
3466 add_barrier();
3467}
3468
3469void generator::visit_clock_inst(ir::clock_inst* clock){
3470 InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
3471 vals_[clock][{}] = call(iasm);
3472}
3473
3474void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
3475 InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
3476 vals_[timer][{}] = call(iasm);
3477}
3478
3479
3480
3481void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
3482 ir::value *v = i->get_operand(0);
3483 int inc = i->get_inc();
3484 if (inc == 0) {
3485 // If dot has not been visited, do nothing.
3486 } else {
3487 // If dot has been visited, insert prefetched lds
3488 assert(inc == 1);
3489 assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
3490 "dot hasn't be visited");
3491 // sink lds & extract element
3492 // move lds & all uses to current location
3493 std::stack<Value*> work_stack;
3494 for (Value *value : prefetch_latch_to_bb_[v])
3495 work_stack.push(value);
3496 std::vector<Instruction*> dead_instrs;
3497 while (!work_stack.empty()) {
3498 Value *m = work_stack.top();
3499 work_stack.pop();
3500
3501 for (auto u : m->users())
3502 work_stack.push(u);
3503
3504 assert(isa<Instruction>(m));
3505 auto m_instr = static_cast<Instruction*>(m);
3506
3507 m_instr->removeFromParent();
3508 m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
3509 assert(m_instr->getParent() == &*builder_->GetInsertBlock());
3510 builder_->SetInsertPoint(m_instr->getParent());
3511 }
3512 }
3513}
3514
3515void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
3516 std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
3517 InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
3518 call(iasm);
3519}
3520
3521/**
3522 * \brief Code Generation for `extern_elementwise`
3523 */
3524void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
3525 std::vector<Type *> operand_types;
3526 for (size_t j = 0; j < i->get_num_operands(); j++) {
3527 operand_types.push_back(
3528 cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
3529 }
3530 Type *ret_type = cvt(i->get_type()->get_scalar_ty());
3531 FunctionType *FT =
3532 FunctionType::get(ret_type, std::move(operand_types), false);
3533 Function *F = llvm::cast<llvm::Function>(
3534 mod_->getOrInsertFunction(i->get_symbol_name(), FT).getCallee());
3535 for (auto idx : idxs_.at(i)) {
3536 std::vector<llvm::Value *> args;
3537 for (size_t j = 0; j < i->get_num_operands(); j++) {
3538 args.emplace_back(vals_[i->get_operand(j)][idx]);
3539 }
3540 vals_[i][idx] = call(F, std::move(args));
3541 }
3542 add_extern_lib(i->get_lib_name(), i->get_lib_path());
3543}
3544
3545//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
3546// for(indices_t idx: idxs_.at(x)){
3547// assert(idx.size() == 1);
3548// if(idx[0] == i32(0))
3549// vals_[x][idx] = idx[0];
3550// else{
3551// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
3552// assert(bin_add);
3553// vals_[x][idx] = bin_add->getOperand(0);
3554// }
3555// }
3556//}
3557
3558//void generator::visit_make_range_sta(ir::make_range_sta* x) {
3559// for(indices_t idx: idxs_.at(x)){
3560// assert(idx.size() == 1);
3561// if(idx[0] == i32(0)){
3562// vals_[x][idx] = idx[0];
3563// }
3564// else{
3565// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
3566// assert(bin_add);
3567// Value *cst = bin_add->getOperand(1);
3568// assert(isa<Constant>(cst));
3569// vals_[x][idx] = cst;
3570// }
3571// };
3572//}
3573
3574void generator::visit_make_range(ir::make_range* x) {
3575 for(indices_t idx: idxs_.at(x)){
3576 Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
3577 vals_[x][idx] = add(start, idx[0]);
3578 }
3579}
3580
3581void generator::visit_undef_value(ir::undef_value *x) {
3582 ir::type* sca_ty = x->get_type()->get_scalar_ty();
3583 Type* ty = cvt(sca_ty);
3584 for(indices_t idx: idxs_.at(x))
3585 vals_[x][idx] = llvm::UndefValue::get(ty);
3586}
3587
3588void generator::visit_constant_int(ir::constant_int *x){
3589 Type *ty = cvt(x->get_type()->get_scalar_ty());
3590 for(indices_t idx: idxs_.at(x))
3591 vals_[x][idx] = ConstantInt::get(ty, x->get_value());
3592}
3593
3594void generator::visit_constant_fp(ir::constant_fp *x){
3595 Type *ty = cvt(x->get_type()->get_scalar_ty());
3596 for(indices_t idx: idxs_.at(x)) {
3597 // manually materialize the bf16 constant via inline asm
3598 if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
3599 // take the highest 16 bits of the fp32 bit pattern (truncation)
3600 float fp32_value = x->get_value();
3601 uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
3602 & 0xffff0000) >> 16;
3603 std::stringstream const_str;
3604 const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
3605 InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
3606 " mov.b16 $0, " + const_str.str() + ";",
3607 "=h", false);
3608 vals_[x][idx] = builder_->CreateCall(bf16_const, {});
3609 } else
3610 vals_[x][idx] = ConstantFP::get(ty, x->get_value());
3611 }
3612}
3613
3614void generator::visit_alloc_const(ir::alloc_const *alloc) {
3615 unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
3616 Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty());
3617 Type *array_ty = llvm::ArrayType::get(element_ty, size);
3618 Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
3619 nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
3620 vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4));
3621}
3622
3623
3624void generator::forward_declare(ir::function* fn){
3625 FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
3626 if(!tgt_->is_gpu()){
3627 Type *fn_ret_ty = fn_ty->getReturnType();
3628 std::vector<Type*> fn_args_ty;
3629 for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
3630 fn_args_ty.push_back(fn_ty->getParamType(i));
3631 fn_args_ty.push_back(i32_ty);
3632 fn_args_ty.push_back(i32_ty);
3633 fn_args_ty.push_back(i32_ty);
3634 fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
3635 }
3636 Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
3637 fns_[fn] = ret;
3638}
3639
3640Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
3641 Type *ty) {
3642 unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
3643 Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
3644 return base;
3645}
3646
3647void generator::visit_function(ir::function* fn) {
3648 idxs_.clear();
3649 vals_.clear();
3650 seen_.clear();
3651 LLVMContext &ctx = builder_->getContext();
3652
3653 Function* ret = fns_[fn];
3654
3655
3656 // set attributes
3657 for(auto attr_pair: fn->attrs()){
3658 unsigned id = attr_pair.first;
3659 for(ir::attribute attr: attr_pair.second)
3660 if(attr.is_llvm_attr()){
3661 llvm::Attribute llattr = cvt(attr);
3662 if(llattr.getKindAsEnum() != llvm::Attribute::None)
3663 ret->addAttribute(id, cvt(attr));
3664 }
3665 }
3666 // set metadata
3667 if(tgt_->is_gpu()){
3668 tgt_->set_kernel(*builder_, ctx, mod_, ret);
3669 Metadata *md_args[] = {
3670 ValueAsMetadata::get(ret),
3671 MDString::get(ctx, "maxntidx"),
3672 ValueAsMetadata::get(i32(num_warps_*32))
3673 };
3674 mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
3675 }
3676 // set arguments
3677 for(unsigned i = 0; i < fn->args().size(); i++)
3678 vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
3679 // create blocks
3680 auto blocks = ir::cfg::reverse_post_order(fn);
3681 for(ir::basic_block *block: blocks) {
3682 BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
3683 bbs_[block] = dst_block;
3684 }
3685 builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
3686 // create policies
3687 if(tgt_->as_nvidia()->sm() >= 80)
3688 for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
3689 std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
3690 std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
3691 InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
3692 policies_[evict] = call(iasm);
3693 }
3694 // initialize layouts
3695 for(auto x: layouts_->get_all()){
3696 visit_layout(x.second);
3697 }
3698 // generate LLVM-IR code
3699 for(ir::basic_block *block: blocks)
3700 visit_basic_block(block);
3701 // finalize
3702 finalize_function(fn);
3703}
3704
3705
3706
3707void generator::visit_layout_mma(analysis::mma_layout* layout) {
3708 ir::value *a = nullptr;
3709 ir::value *b = nullptr;
3710 for(ir::value* v: layout->get_values())
3711 if(ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(v)){
3712 a = dot->get_operand(0);
3713 b = dot->get_operand(1);
3714 }
3715 analysis::data_layout* layout_a = layouts_->get(a);
3716 analysis::data_layout* layout_b = layouts_->get(b);
3717
3718 const auto& shape = layout->get_shape();
3719 Value *_1 = i32(1);
3720 Value *_2 = i32(2);
3721 Value *_3 = i32(3);
3722 Value *_4 = i32(4);
3723 Value *_8 = i32(8);
3724 Value *_16 = i32(16);
3725 Value *_32 = i32(32);
3726 int cc = tgt_->as_nvidia()->sm();
3727 std::vector<Value*> idx_m;
3728 std::vector<Value*> idx_n;
3729 std::vector<Value*> idx_z;
3730 //
3731 Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
3732 Value *lane = urem(thread, _32);
3733 Value *warp = udiv(thread, _32);
3734 /* lane offset */
3735 if(cc < 80){
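 // sm < 80: Volta mma (m8n8k4) accumulator layout; the lane id is decomposed into quad/pair offsets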
3736 auto ord_a = layout_a->get_order();
3737 auto ord_b = layout_b->get_order();
3738 bool is_a_row = ord_a[0] != 0;
3739 bool is_b_row = ord_b[0] != 0;
3740 /* warp offset */
3741 Value *warp_0 = urem(warp, i32(layout->wpt(0)));
3742 Value *warp_12 = udiv(warp, i32(layout->wpt(0)));
3743 Value *warp_1 = urem(warp_12, i32(layout->wpt(1)));
3744 Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
3745 Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
3746 // Quad offset
3747 Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0)));
3748 Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1)));
3749 // Pair offset
3750 Value *off_pair_m = udiv(urem(lane, _16), _4);
3751 off_pair_m = urem(off_pair_m, i32(layout->fpw(0)));
3752 off_pair_m = mul(off_pair_m, i32(4));
3753 Value *off_pair_n = udiv(urem(lane, _16), _4);
3754 off_pair_n = udiv(off_pair_n, i32(layout->fpw(0)));
3755 off_pair_n = urem(off_pair_n, i32(layout->fpw(1)));
3756 off_pair_n = mul(off_pair_n, i32(4));
3757 // scale
3758 off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2));
3759 off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2));
3760 off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2));
3761 off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2));
3762 // Quad pair offset
3763 Value *off_lane_m = add(off_pair_m, off_quad_m);
3764 Value *off_lane_n = add(off_pair_n, off_quad_n);
3765 // a offset
3766 offset_a_m_[layout] = add(off_warp_m, off_lane_m);
3767 offset_a_k_[layout] = and_(lane, _3);
3768 // b offsets
3769 offset_b_n_[layout] = add(off_warp_n, off_lane_n);
3770 offset_b_k_[layout] = and_(lane, _3);
3771 // i indices
3772 Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]);
3773 for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0))
3774 for(unsigned mm = 0; mm < layout->rep(0); mm++)
3775 idx_m.push_back(add(offset_c_m, i32(m + mm*2)));
3776 // j indices
3777 Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n));
3778 for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1))
3779 for(unsigned nn = 0; nn < layout->rep(1); nn++){
3780 idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1))));
3781 idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1)));
3782 }
3783 if(is_a_row){
3784 offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4)));
3785 offset_a_k_[layout] = i32(0);
3786 }
3787 if(!is_b_row){
3788 offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4)));
3789 offset_b_k_[layout] = i32(0);
3790 }
3791 /* axes */
3792 axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0};
3793 axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1};
3794 }
3795 else{
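 // sm >= 80: Ampere mma (m16n8k16) accumulator layout; each lane owns pairs of adjacent columns,
 // with rows offset by 8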
3796 /* warp offset */
3797 Value *warp_0 = urem(warp, i32(layout->wpt(0)));
3798 Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1)));
3799 Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
3800 Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
3801 Value *off_lane_m = urem(lane, _16);
3802 Value *off_lane_n = urem(lane, _8);
3803 /* offsets */
3804 // a offset
3805 offset_a_m_[layout] = add(off_warp_m, off_lane_m);
3806 offset_a_k_[layout] = i32(0);
3807 // b offsets
3808 offset_b_n_[layout] = add(off_warp_n, off_lane_n);
3809 offset_b_k_[layout] = i32(0);
3810 // c offset
3811 Value *off_c_m = add(udiv(lane, _4), off_warp_m);
3812 Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n);
3813 for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){
3814 idx_m.push_back(add(off_c_m, i32(m)));
3815 idx_m.push_back(add(off_c_m, i32(m + 8)));
3816 }
3817 for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){
3818 idx_n.push_back(add(off_c_n, i32(n)));
3819 idx_n.push_back(add(off_c_n, i32(n + 1)));
3820 }
3821 /* axes */
3822 axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0};
3823 axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1};
3824 }
3825}
3826
3827void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
3828 Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
3829 auto order = layout->get_order();
3830 const auto& shape = layout->get_shape();
3831 // Delinearize
3832 size_t dim = shape.size();
3833 std::vector<Value*> thread_ids(dim);
3834 for(unsigned k = 0; k < dim - 1; k++){
3835 Constant *dim_k = i32(layout->mts(order[k]));
3836 Value *rem = urem(thread_id, dim_k);
3837 thread_id = udiv(thread_id, dim_k);
3838 thread_ids[order[k]] = rem;
3839 }
3840 Constant *dim_k = i32(layout->mts(order[dim - 1]));
3841 thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
3842
3843 // Create axes
3844 for(unsigned k = 0; k < dim; k++) {
3845 int nts = layout->nts(k);
3846 int mts = layout->mts(k);
3847 std::string str_k = std::to_string(k);
3848 Value *contiguous_k = i32(nts);
3849 Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
3850 unsigned per_cta = layout->shape_per_cta(k);
3851 unsigned per_thread = nts * shape[k] / per_cta;
3852 std::vector<Value*> idx_list(per_thread);
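 // each thread owns `nts` contiguous elements, repeated every `per_cta` elements along the axis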
3853 for(unsigned n = 0 ; n < per_thread; n++){
3854 unsigned offset = n / nts * per_cta + n % nts;
3855 idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
3856 }
3857 axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
3858 }
3859}
3860
3861void generator::visit_layout_shared(analysis::shared_layout* layout) {
3862 Type* ty = cvt(layout->get_type());
3863 PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
3864 if (layout->get_N_buffer()) {
3865 // create pointers
3866 shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
3867 shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
3868
3869 BasicBlock *current = builder_->GetInsertBlock();
3870
3871 auto info = *layout->get_N_buffer();
3872 ir::phi_node *phi = info.phi;
3873 BasicBlock *parent = bbs_.at(phi->get_parent());
3874 if(parent->empty())
3875 builder_->SetInsertPoint(parent);
3876 else if (Instruction *first_non_phi = parent->getFirstNonPHI())
3877 builder_->SetInsertPoint(first_non_phi);
3878 else
3879 builder_->SetInsertPoint(parent);
3880
3881 // create smem_idx
3882 read_smem_idx_[layout] = phi(i32_ty, 2);
3883 write_smem_idx_[layout] = phi(i32_ty, 2);
3884
3885 // create pointers
3886 // ptr of the current iteration
3887 shared_ptr_[layout] = phi(ptr_ty, 2);
3888 // ptr of the next iteration
3889 shared_next_ptr_[layout] = phi(ptr_ty, 2);
3890
3891 builder_->SetInsertPoint(current);
3892 } else if(layout->get_double_buffer()) {
3893 BasicBlock *current = builder_->GetInsertBlock();
3894 auto info = *layout->get_double_buffer();
3895 ir::phi_node *phi = info.phi;
3896 BasicBlock *parent = bbs_.at(phi->get_parent());
3897 if(parent->empty())
3898 builder_->SetInsertPoint(parent);
3899 else
3900 builder_->SetInsertPoint(&*parent->getFirstNonPHI());
3901 // create pointers
3902 shared_ptr_[layout] = phi(ptr_ty, 2);
3903 shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
3904 shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType());
3905 shared_off_[layout] = phi(i32_ty, 2);
3906 shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
3907 builder_->SetInsertPoint(current);
3908 } else{
3909 size_t offset = alloc_->offset(layout);
3910 shared_ptr_[layout] = gep(shmem_, i32(offset));
3911 shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
3912 }
3913}
3914
3915void generator::visit_basic_block(ir::basic_block * block) {
3916
3917 BasicBlock *parent = bbs_[block];
3918 builder_->SetInsertPoint(parent);
3919 for(ir::instruction *i: block->get_inst_list()){
3920 visit_value(i);
3921 // std::cout << "done" << std::endl;
3922 }
3923 // Update ir bb -> llvm bb mapping
3924 bbs_[block] = builder_->GetInsertBlock();
3925}
3926
3927void generator::visit_argument(ir::argument* arg) {
3928
3929}
3930
3931void generator::init_idx(ir::value *v) {
3932 idxs_[v].clear();
3933 if(!v->get_type()->is_block_ty()){
3934 idxs_[v].push_back({});
3935 return;
3936 }
3937 if(layouts_->get(v)->to_shared())
3938 return;
3939 const auto &shapes = v->get_type()->get_block_shapes();
3940 size_t rank = shapes.size();
3941 std::vector<distributed_axis> axes(rank);
3942 std::vector<int> ord(rank);
3943 // compute axes
3944 // std::cout << "axes" << std::endl;
3945 for(size_t d = 0; d < shapes.size(); d++){
3946 // std::cout << d << " " << shapes[d] << std::endl;
3947 // std::cout << a_axes_->get(v, d) << std::endl;
3948 if(shapes[d] > 1){
3949 unsigned x = a_axes_->get(v, d);
3950 axes[d] = axes_.at(x);
3951 }
3952 else{
3953 axes[d].contiguous = 1;
3954 axes[d].values = {i32(0)};
3955 }
3956 }
3957 // std::cout << "axes ok" << std::endl;
3958 // compute order
3959 analysis::data_layout* layout = layouts_->get(v);
3960 std::iota(ord.begin(), ord.end(), 0);
3961 auto cmp = [&](int x, int y) {
3962 unsigned axx = a_axes_->get(v, x);
3963 unsigned axy = a_axes_->get(v, y);
3964 size_t posx = layout->find_axis(axx);
3965 size_t posy = layout->find_axis(axy);
3966 if(posx < rank && posy < rank)
3967 return layout->get_order(posx) < layout->get_order(posy);
3968 return false;
3969 };
3970 std::sort(ord.begin(), ord.end(), cmp);
3971 ords_[v] = ord;
3972 // indices
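 // enumerate the per-thread indices with the most contiguous axis (ord[0]) varying fastest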
3973 if(axes.size() == 1)
3974 for(Value* x0: axes[ord[0]].values){
3975 idxs_[v].push_back({x0});
3976 }
3977 if(axes.size() == 2)
3978 for(Value* x1: axes[ord[1]].values)
3979 for(Value* x0: axes[ord[0]].values){
3980 indices_t idx(2);
3981 idx[ord[0]] = x0;
3982 idx[ord[1]] = x1;
3983 idxs_[v].push_back(idx);
3984 }
3985 if(axes.size() == 3)
3986 for(Value* x2: axes[ord[2]].values)
3987 for(Value* x1: axes[ord[1]].values)
3988 for(Value* x0: axes[ord[0]].values){
3989 indices_t idx(3);
3990 idx[ord[0]] = x0;
3991 idx[ord[1]] = x1;
3992 idx[ord[2]] = x2;
3993 idxs_[v].push_back(idx);
3994 }
3995}
3996
3997void generator::finalize_shared_layout(analysis::shared_layout *shared) {
3998 if (auto n_buffer = shared->get_N_buffer()) {
3999 // if (*_smem_idx == #stages-1) {
4000 // *_smem_idx = 0;
4001 // } else *_smem_idx++;
4002 auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
4003 // insert point
4004 Value *idx = smem_idx[shared];
4005 builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
4006 Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
4007 PHINode *_ret = phi(i32_ty, 2);
4008 Instruction *then_term = nullptr;
4009 Instruction *else_term = nullptr;
4010 Instruction *dummy = builder_->CreateRet(nullptr);
4011 llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
4012 dummy->removeFromParent();
4013 builder_->SetInsertPoint(then_term);
4014 Value *zero_smem_idx = i32(0);
4015 builder_->SetInsertPoint(else_term);
4016 Value *inc_smem_idx = add(idx, i32(1));
4017 builder_->SetInsertPoint(_ret->getParent());
4018 _ret->addIncoming(zero_smem_idx, then_term->getParent());
4019 _ret->addIncoming(inc_smem_idx, else_term->getParent());
4020 // update ir::bb -> llvm::bb mapping
4021 bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
4022 // idx = init_stage;
4023 // loop: ...
4024 if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
4025 idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
4026 idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
4027 } else
4028 throw std::runtime_error("Should be PHINode");
4029 };
4030
4031 // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
4032 finalize_smem_idx(read_smem_idx_, 2);
4033 finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
4034
4035 // finalize pointers
4036 ir::phi_node *pn = n_buffer->phi;
4037 BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
4038 BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
4039 // %curr_ptr = phi %shared_pre_ptr, %next_ptr
4040 // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
4041 if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
4042 curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
4043 curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
    } else
      throw std::runtime_error("finalize_shared_layout: shared_ptr_ is expected to be a PHINode");

    BasicBlock *current = builder_->GetInsertBlock();
    builder_->SetInsertPoint(header->getTerminator());
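    // in the loop header, next_ptr starts at the second stage of the pre-allocated buffer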
    Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
    builder_->SetInsertPoint(current->getTerminator());

    assert(isa<PHINode>(shared_next_ptr_[shared]));
    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);

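    // on the back-edge, next_ptr points at the stage selected by read_smem_idx_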
    Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
    Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
  } else if(shared->get_double_buffer()) {
    auto info = *shared->get_double_buffer();
    ir::phi_node *phi = info.phi;
    PHINode *ptr = cast<PHINode>(shmems_[phi]);
    PHINode *offset = cast<PHINode>(shoffs_[phi]);
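    // wire up the double-buffer phis: the offset is negated on the back-edge
    // (flipping between the two halves of the buffer) and starts at half the
    // buffer size, in elements, on entry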
    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
      ir::basic_block* inc_block = phi->get_incoming_block(n);
      ir::value* inc_val = phi->get_incoming_value(n);
      BasicBlock *llvm_inc_block = bbs_.at(inc_block);
      if(inc_val == info.latch){
        builder_->SetInsertPoint(llvm_inc_block->getTerminator());
        Value *next_offset = neg(offset);
        offset->addIncoming(next_offset, llvm_inc_block);
      }
      else {
        unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
        offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
      }
      ptr->addIncoming(shmems_[inc_val], llvm_inc_block);
    }
  }
}

void generator::finalize_function(ir::function *fn) {
  // finalize double-buffering
  for(const auto& x: layouts_->get_all())
    if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
      finalize_shared_layout(shared);
  // finalize phi
  for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *inst: block->get_inst_list())
      if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
        finalize_phi_node(phi);
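  // add the phi incomings that were deferred during codegen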
  for(auto& x: lazy_phi_incs_)
    std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
}

void generator::finalize_phi_node(ir::phi_node *x) {
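  // shared-memory phis (e.g. double-buffer pointers) are finalized by finalize_shared_layout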
  if(shmems_.find(x) != shmems_.end())
    return;
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::basic_block *_block = x->get_incoming_block(n);
    BasicBlock *block = bbs_.at(_block);
    for(indices_t idx: idxs_.at(x)){
      PHINode *phi = cast<PHINode>(vals_[x][idx]);
      Value *inc = vals_[x->get_incoming_value(n)][idx];
      // x->print(std::cout);
      phi->addIncoming(inc, block);
    }
  }
}

void generator::packed_type(ir::value* i){
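  // sanity check only: the value is expected to have a scanline layout;
  // the computed element type is currently unused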
  Type* dtype = cvt(i->get_type()->get_tile_element_ty());
  auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
  assert(layout);
}

void generator::visit(ir::module &src, llvm::Module &dst) {
  mod_ = &dst;
  ctx_ = &dst.getContext();
  builder_ = new Builder(*ctx_);
  // allocate shared memory
  if(tgt_->is_gpu())
  if(unsigned alloc_size = alloc_->allocated_size()){
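    // declare an external, zero-length i32 array in address space 3 (shared memory)
    // and reinterpret it as an i8 pointer so offsets can be computed in bytes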
    Type *int_8_ty = Type::getInt8Ty(*ctx_);
    Type *int_32_ty = Type::getInt32Ty(*ctx_);
    ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
    Type *shmem_ptr_ty = ptr_ty(int_8_ty, 3);
    GlobalVariable *sh_mem_array =
      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
                         nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
    shmem_ = bit_cast(sh_mem_array, shmem_ptr_ty);
  }
  // instantiate device functions
//  for(ir::function *fn: src.get_function_list())
//  for(ir::basic_block *bb: fn->blocks())
//  for(ir::instruction *i: bb->get_inst_list())
//  if(auto *call = dynamic_cast<ir::call_inst*>(i)){
//    std::cout << "call??" << std::endl;
//  }
  // visit functions
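  // forward-declare every function first so that definitions can reference
  // one another, then lower each body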
  for(ir::function *fn: src.get_function_list())
    forward_declare(fn);
  for(ir::function *fn: src.get_function_list())
    visit_function(fn);
}

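// register an external library under a unique name; re-registering the same
// name with a different path is an error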
void generator::add_extern_lib(const std::string &lib_name,
                               const std::string &lib_path) {
  if (extern_lib_map_.count(lib_name) == 0) {
    extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
  } else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
    throw std::runtime_error("Library " + lib_name + " has two different paths: (1) " +
                             lib_path + " (2) " + extern_lib_map_.at(lib_name)->path());
  }
}

} // namespace codegen
} // namespace triton
