1#include <algorithm>
2#include <iostream>
3#include "triton/ir/context.h"
4#include "triton/ir/basic_block.h"
5#include "triton/ir/instructions.h"
6#include "triton/ir/constant.h"
7#include "triton/ir/type.h"
8#include "triton/ir/function.h"
9
10namespace triton{
11namespace ir{
12
13//===----------------------------------------------------------------------===//
14// instruction classes
15//===----------------------------------------------------------------------===//
16
17instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
18 const std::string &name, instruction *next)
19 : user(ty, num_ops, name), id_(ity) {
20 if(next){
21 basic_block *block = next->get_parent();
22 assert(block && "Next instruction is not in a basic block!");
23 auto it = std::find(block->begin(), block->end(), next);
24 block->get_inst_list().insert(it, next);
25 }
26}
27
28void instruction::erase_from_parent() {
29 parent_->erase(this);
30 for(ir::value* op: ops())
31 op->erase_use(this);
32}
33
34bool instruction::has_tile_result_or_op() {
35 bool result = get_type()->is_block_ty();
36 for(unsigned i = 0; i < get_num_operands(); i++)
37 result |= get_operand(i)->get_type()->is_block_ty();
38 return result;
39}
40
41//===----------------------------------------------------------------------===//
42// phi_node classes
43//===----------------------------------------------------------------------===//
44
45phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
46 : instruction(ty, INST_PHI, 0, name, next) {
47 blocks_.reserve(num_reserved);
48}
49
50value* phi_node::get_value_for_block(basic_block * block) {
51 auto it = std::find(blocks_.begin(), blocks_.end(), block);
52 size_t n = std::distance(blocks_.begin(), it);
53 return get_incoming_value(n);
54}
55
56// Set incoming value
57void phi_node::set_incoming_value(unsigned i, value *v){
58 assert(v && "PHI node got a null value!");
59 assert(get_type() == v->get_type() &&
60 "All operands to PHI node must be the same type as the PHI node!");
61 set_operand(i, v);
62}
63
64// Set incoming block
65void phi_node::set_incoming_block(unsigned i, basic_block *block){
66 assert(block && "PHI node got a null basic block!");
67 blocks_[i] = block;
68}
69
70// Add incoming
71void phi_node::add_incoming(value *v, basic_block *block){
72 assert(v && "PHI node got a null value!!");
73 resize_ops(get_num_operands() + 1);
74 blocks_.resize(get_num_operands() + 1);
75 set_incoming_value(get_num_operands() - 1, v);
76 set_incoming_block(get_num_operands() - 1, block);
77}
78
79// Factory methods
80phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
81 return new phi_node(ty, num_reserved, name, next);
82}
83
84//===----------------------------------------------------------------------===//
85// call_inst classes
86//===----------------------------------------------------------------------===//
87
88std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); }
89
90call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
91 : instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
92 for(size_t i = 0; i < values.size(); i++)
93 set_operand(i, values.at(i));
94}
95
96call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
97 return new call_inst(fn, values, name, next);
98}
99
100
101// launch
102
103launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
104 : instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
105 int k = 0;
106 if(grid.size() != 3)
107 throw std::runtime_error("grid must have 3 elements");
108 set_operand(k++, fn);
109 val_begin = k;
110 for(ir::value* v: values)
111 set_operand(k++, v);
112 val_end = k;
113 grid_begin = k;
114 for(ir::value* g: grid)
115 set_operand(k++, g);
116 grid_end = k;
117 set_operand(k++, num_warps);
118}
119
120
121ir::function* launch_inst::get_fn() {
122 return (ir::function*)get_operand(0);
123}
124
125std::vector<ir::value*> launch_inst::get_values() {
126 std::vector<ir::value*> ret;
127 for(int i = val_begin; i < val_end; i++)
128 ret.push_back(get_operand(i));
129 return ret;
130}
131
132std::vector<ir::value*> launch_inst::get_grid() {
133 std::vector<ir::value*> ret;
134 for(int i = grid_begin; i < grid_end; i++)
135 ret.push_back(get_operand(i));
136 return ret;
137}
138
139ir::value* launch_inst::get_num_warps() {
140 return get_operand(grid_end);
141}
142
143
144launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
145 return new launch_inst(fn, values, grid, num_warps, name, next);
146}
147
148
149//===----------------------------------------------------------------------===//
150// binary_operator classes
151//===----------------------------------------------------------------------===//
152
153std::string binary_operator::repr_impl() const {
154 switch(op_) {
155 case Add : return "add";
156 case FAdd : return "fadd";
157 case Sub : return "sub";
158 case FSub : return "fsub";
159 case Mul : return "mul";
160 case FMul : return "fmul";
161 case UDiv : return "udiv";
162 case SDiv : return "sdiv";
163 case FDiv : return "fdiv";
164 case URem : return "urem";
165 case SRem : return "srem";
166 case FRem : return "frem";
167 case Shl : return "shl";
168 case LShr : return "lshr";
169 case AShr : return "ashr";
170 case And : return "and";
171 case Or : return "or";
172 case Xor : return "xor";
173 default: throw std::runtime_error("unknown binary operator");
174 }
175}
176
177bool binary_operator::is_int_div() const {
178 return op_ == binary_op_t::UDiv || op_ == binary_op_t::SDiv;
179}
180
181bool binary_operator::is_int_rem() const {
182 return op_ == binary_op_t::URem || op_ == binary_op_t::SRem;
183}
184
185bool binary_operator::is_shl() const {
186 return op_ == binary_op_t::Shl;
187}
188
189bool binary_operator::is_shr() const {
190 return op_ == binary_op_t::LShr || op_ == binary_op_t::AShr;
191}
192
193bool binary_operator::is_int_mult() const {
194 return op_ == binary_op_t::Mul;
195}
196
197bool binary_operator::is_int_add_sub() const {
198 return op_ == binary_op_t::Add || op_ == binary_op_t::Sub;
199}
200
201
202binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
203 : instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){
204 set_operand(0, lhs);
205 set_operand(1, rhs);
206}
207
208binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs, const std::string &name, instruction *next){
209 assert(lhs->get_type() == rhs->get_type() &&
210 "Cannot create binary operator with two operands of differing type!");
211 return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
212}
213
214//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
215// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
216// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
217// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
218//}
219
220//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
221// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
222// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
223// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
224//}
225
226//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
227// assert(arg->get_type()->is_integer_ty());
228// constant *mask = constant::get_all_ones_value(arg->get_type());
229// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
230//}
231
232//===----------------------------------------------------------------------===//
233// cmp_inst classes
234//===----------------------------------------------------------------------===//
235
236
237
238// cmp_inst
239std::string cmp_inst::repr_impl() const {
240 switch (pred_) {
241 case FCMP_FALSE : return "false";
242 case FCMP_OEQ : return "fcmp_oeq";
243 case FCMP_OGT : return "fcmp_ogt";
244 case FCMP_OGE : return "fcmp_oge";
245 case FCMP_OLT : return "fcmp_olt";
246 case FCMP_OLE : return "fcmp_ole";
247 case FCMP_ONE : return "fcmp_one";
248 case FCMP_ORD : return "fcmp_ord";
249 case FCMP_UNO : return "fcmp_uno";
250 case FCMP_UEQ : return "fcmp_ueq";
251 case FCMP_UGT : return "fcmp_ugt";
252 case FCMP_UGE : return "fcmp_uge";
253 case FCMP_ULT : return "fcmp_ult";
254 case FCMP_ULE : return "fcmp_ule";
255 case FCMP_UNE : return "fcmp_une";
256 case FCMP_TRUE : return "true";
257 case ICMP_EQ : return "icmp_eq";
258 case ICMP_NE : return "icmp_ne";
259 case ICMP_UGT : return "icmp_ugt";
260 case ICMP_UGE : return "icmp_uge";
261 case ICMP_ULT : return "icmp_ult";
262 case ICMP_ULE : return "icmp_ule";
263 case ICMP_SGT : return "icmp_sgt";
264 case ICMP_SGE : return "icmp_sge";
265 case ICMP_SLT : return "icmp_slt";
266 case ICMP_SLE : return "icmp_sle";
267 default: throw std::runtime_error("unreachable");
268 }
269}
270
271cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
272 : instruction(ty, id, 2, name, next), pred_(pred) {
273 set_operand(0, lhs);
274 set_operand(1, rhs);
275}
276
277type* cmp_inst::make_cmp_result_type(type *ty){
278 type* int1_ty = type::get_int1_ty(ty->get_context());
279 if (block_type* tile_ty = dynamic_cast<block_type*>(ty))
280 return block_type::get_same_shapes(int1_ty, tile_ty);
281 return int1_ty;
282}
283
284
285bool cmp_inst::is_fp_predicate(cmp_pred_t pred) {
286 return pred >= FIRST_FCMP_PREDICATE && pred <= LAST_FCMP_PREDICATE;
287}
288
289bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
290 return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
291}
292
293
294// icmp_inst
295icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
296 value *lhs, value *rhs, const std::string &name, instruction *next)
297 : cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ }
298
299icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
300 assert(is_int_predicate(pred));
301 assert(lhs->get_type() == rhs->get_type());
302 type *res_ty = make_cmp_result_type(lhs->get_type());
303 return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
304}
305
306// fcmp_inst
307fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
308 value *lhs, value *rhs, const std::string &name, instruction *next)
309 : cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ }
310
311fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
312 assert(is_fp_predicate(pred));
313 type *res_ty = make_cmp_result_type(lhs->get_type());
314 return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
315}
316
317//===----------------------------------------------------------------------===//
318// unary_inst classes
319//===----------------------------------------------------------------------===//
320
321unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
322 : instruction(ty, id, 1, name, next) {
323 set_operand(0, v);
324}
325
326//===----------------------------------------------------------------------===//
327// dequantize_inst classes
328//===----------------------------------------------------------------------===//
329
330dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next)
331 : instruction(ty, INST_DEQUANTIZE, 3, name, next) {
332 set_operand(0, v);
333 set_operand(1, scale);
334 set_operand(2, shift);
335}
336
337dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){
338 return new dequantize_inst(ty, arg, scale, shift, name, next);
339}
340
341//===----------------------------------------------------------------------===//
342// cast_inst classes
343//===----------------------------------------------------------------------===//
344
345std::string cast_inst::repr_impl() const {
346 switch (op_){
347 case cast_op_t::Trunc: return "trunc";
348 case cast_op_t::ZExt: return "zext";
349 case cast_op_t::SExt: return "sext";
350 case cast_op_t::FPTrunc: return "fp_trunc";
351 case cast_op_t::FPExt: return "fp_ext";
352 case cast_op_t::UIToFP: return "ui_to_fp";
353 case cast_op_t::SIToFP: return "si_to_fp";
354 case cast_op_t::FPToUI: return "fp_to_ui";
355 case cast_op_t::FPToSI: return "fp_to_si";
356 case cast_op_t::PtrToInt: return "ptr_to_int";
357 case cast_op_t::IntToPtr: return "int_to_ptr";
358 case cast_op_t::BitCast: return "bitcast";
359 case cast_op_t::AddrSpaceCast: return "addr_space_cast";
360 default: throw std::runtime_error("unreachable");
361 }
362}
363// TODO
364bool cast_inst::is_valid(cast_op_t op, value *arg, type *ty) {
365 assert(arg->get_type()->is_block_ty() == ty->is_block_ty());
366 return true;
367}
368
369cast_inst *cast_inst::create(cast_op_t op, value *arg, type *ty, const std::string &name, instruction *next){
370 assert(is_valid(op, arg, ty) && "Invalid cast!");
371 // Construct and return the appropriate CastInst subclass
372 switch (op) {
373 case cast_op_t::Trunc: return new trunc_inst (ty, arg, name, next);
374 case cast_op_t::ZExt: return new z_ext_inst (ty, arg, name, next);
375 case cast_op_t::SExt: return new s_ext_inst (ty, arg, name, next);
376 case cast_op_t::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
377 case cast_op_t::FPExt: return new fp_ext_inst (ty, arg, name, next);
378 case cast_op_t::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
379 case cast_op_t::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
380 case cast_op_t::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
381 case cast_op_t::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
382 case cast_op_t::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
383 case cast_op_t::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
384 case cast_op_t::BitCast: return new bit_cast_inst (ty, arg, name, next);
385 case cast_op_t::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
386 default: throw std::runtime_error("unreachable");
387 }
388}
389
390cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
391 type *arg_ty = arg->get_type();
392 assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
393 unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
394 unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
395 cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
396 (arg_bits > dst_bits ? cast_op_t::Trunc :
397 (is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
398 return create(op, arg, ty, name, next);
399}
400
401//===----------------------------------------------------------------------===//
402// terminator_inst classes
403//===----------------------------------------------------------------------===//
404
405
406// return_inst
407return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
408 : terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
409 if(ret_val)
410 set_operand(0, ret_val);
411}
412
413return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){
414 return new return_inst(ctx, ret_val, next);
415}
416
417
418// branch_inst
419branch_inst* branch_inst::create(basic_block *dst, instruction *next) {
420 assert(dst && "Branch destination may not be null!");
421 return new uncond_branch_inst(dst, next);
422}
423
424branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) {
425 assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
426 return new cond_branch_inst(if_dst, else_dst, cond, next);
427}
428
429// uncond_branch_inst
430uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
431 : branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){
432 set_operand(0, dst);
433}
434
435// cond_branch_inst
436cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
437 : branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){
438 assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
439 set_operand(0, if_dst);
440 set_operand(1, else_dst);
441 set_operand(2, cond);
442}
443
444
445//===----------------------------------------------------------------------===//
446// getelementptr_inst classes
447//===----------------------------------------------------------------------===//
448
449getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
450 : instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
451 source_elt_ty(pointee_ty),
452 res_elt_ty(get_indexed_type(pointee_ty, idx)){
453 // sanity check
454 type *expected_ty = get_type()->get_scalar_ty();
455 expected_ty = ((pointer_type*)expected_ty)->get_element_ty();
456 assert(res_elt_ty == expected_ty);
457 // set operands
458 set_operand(0, ptr);
459 for(size_t i = 0; i < idx.size(); i++)
460 set_operand(1 + i, idx[i]);
461}
462
463type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
464 // result pointer type
465 type *ty = x->get_type();
466 unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
467 type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
468 // Tile GEP
469 if(ty->is_block_ty())
470 return block_type::get_same_shapes(ptr_ty, ty);
471 for(value *idx : idx_list)
472 if (idx->get_type()->is_block_ty())
473 return block_type::get_same_shapes(ptr_ty, ty);
474 // Scalar GEP
475 return ptr_ty;
476}
477
478type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector<value *> &idx_list) {
479 if(idx_list.empty())
480 return ty;
481 if(!ty->is_sized())
482 return nullptr;
483 unsigned cur_idx = 1;
484 for(; cur_idx != idx_list.size(); cur_idx++){
485 composite_type *cty = dynamic_cast<composite_type*>(ty);
486 if(!cty || cty->is_pointer_ty())
487 break;
488 value *idx = idx_list[cur_idx];
489 if(!cty->index_valid(idx))
490 break;
491 ty = cty->get_type_at_index(idx);
492 }
493 return (cur_idx == idx_list.size())? ty : nullptr;
494}
495
496type *getelementptr_inst::get_indexed_type(type *ty, const std::vector<value *> &idx_list) {
497 type *result = get_indexed_type_impl(ty, idx_list);
498 assert(result && "invalid GEP type!");
499 return result;
500}
501
502getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next) {
503 type *pointee_ty = ((pointer_type*)(ptr->get_type()->get_scalar_ty()))->get_element_ty();
504 return new getelementptr_inst(pointee_ty, ptr, idx, name, next);
505}
506
507
508//===----------------------------------------------------------------------===//
509// load_inst/store_inst classes
510//===----------------------------------------------------------------------===//
511
512// io_inst
513io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
514 : instruction(ty, id, num_ops, name, next), eviction_(eviction)
515{ }
516
517// load_inst
518load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
519 : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
520{ }
521
522// load
523type *load_inst::get_pointee_type(type *ty) {
524 type *scalar_ty = ty->get_scalar_ty();
525 type *pointee_ty = scalar_ty->get_pointer_element_ty();
526 if(ty->is_block_ty())
527 return block_type::get_same_shapes(pointee_ty, ty);
528 return pointee_ty;
529}
530
531// unmasked_load
532unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
533 : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
534 set_operand(0, ptr);
535}
536
537unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
538 return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
539}
540
541// masked load
542masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
543 bool is_volatile,
544 const std::string &name, instruction *next)
545 : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
546 set_operand(0, ptr);
547 set_operand(1, mask);
548 set_operand(2, false_value);
549}
550
551masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
552 load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
553 bool is_volatile,
554 const std::string &name, instruction *next) {
555 return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
556}
557
558// masked load async
559masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
560 load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
561 const std::string &name, instruction *next)
562 : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
563 set_operand(0, ptr);
564 set_operand(1, mask);
565 set_operand(2, false_value);
566}
567
568masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
569 load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
570 const std::string &name, instruction *next) {
571 return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
572}
573
574// store
575
576store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
577 : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
578{ }
579
580// unmasked_store
581unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
582 const std::string &name, instruction *next)
583 : store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
584 set_operand(0, ptr);
585 set_operand(1, val);
586}
587
588unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
589 const std::string &name, instruction *next) {
590 return new unmasked_store_inst(ptr, val, eviction, name, next);
591}
592
593// masked store
594masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
595 const std::string &name, instruction *next)
596 : store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
597 set_operand(0, ptr);
598 set_operand(1, val);
599 set_operand(2, mask);
600}
601
602masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
603 const std::string &name, instruction *next) {
604 return new masked_store_inst(ptr, val, mask, eviction, name, next);
605}
606
607//===----------------------------------------------------------------------===//
608// struct classes
609//===----------------------------------------------------------------------===//
610
611// insert value
612
613insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
614 : instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
615 set_operand(0, val);
616 set_operand(1, elt);
617}
618
619insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
620 return new insert_value_inst(val, elt, idx, name, next);
621}
622
623
624// extract value
625
626extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
627 : instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
628 set_operand(0, val);
629}
630
631extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
632 return new extract_value_inst(val, idx, name, next);
633}
634
635
636//===----------------------------------------------------------------------===//
637// retile_inst classes
638//===----------------------------------------------------------------------===//
639
640// cat
641
642cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
643 : instruction(block_type::get(x->get_type()->get_scalar_ty(),
644 {x->get_type()->get_block_shapes()[0] +
645 y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
646 set_operand(0, x);
647 set_operand(1, y);
648}
649
650instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
651 return new cat_inst(lhs, rhs, name, next);
652}
653
654// retile
655
656retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
657 const std::string &name, instruction *next)
658 : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
659
660
661
662// reshape
663
664instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
665 const std::string &name, instruction *next) {
666 return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
667}
668
669
670// splat
671
672instruction* splat_inst::create(value *arg, const type::block_shapes_t &shapes,
673 const std::string &name, instruction *next) {
674 return new splat_inst(arg, INST_SPLAT, shapes, name, next);
675}
676
677// broadcast
678
679instruction* broadcast_inst::create(value *arg, const type::block_shapes_t &shapes,
680 const std::string &name, instruction *next) {
681 return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
682}
683
684// downcast
685
686instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
687 return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
688}
689
690
691
692
693//===----------------------------------------------------------------------===//
694// matmul_inst classes
695//===----------------------------------------------------------------------===//
696
697dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
698 const std::string &name, instruction *next)
699 : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
700 set_operand(0, A);
701 set_operand(1, B);
702 set_operand(2, C);
703 allow_tf32_ = allow_tf32;
704}
705
706instruction *dot_inst::create(value *A, value *B, value *C,
707 bool AT, bool BT, bool allow_tf32,
708 const std::string &name, instruction *next) {
709 TransT OPA = AT ? Trans : NoTrans;
710 TransT OPB = BT ? Trans : NoTrans;
711 return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
712}
713
714instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
715 const std::string &name, instruction *next) {
716 return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
717}
718
719instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
720 const std::string &name, instruction *next) {
721 return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
722}
723
724instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
725 const std::string &name, instruction *next) {
726 return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
727}
728
729instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
730 const std::string &name, instruction *next) {
731 return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
732}
733
734//===----------------------------------------------------------------------===//
735// trans instructions
736//===----------------------------------------------------------------------===//
737
738ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
739 // get argument shapes
740 ir::block_type::block_shapes_t arg_shapes = ty->get_block_shapes();
741 // permutate argument shapes
742 perm = init_perm(ty, perm);
743 ir::block_type::block_shapes_t res_shapes = arg_shapes;
744 for(size_t i = 0; i < perm.size(); i++)
745 res_shapes[i] = arg_shapes[perm[i]];
746 // construct type
747 return block_type::get(ty->get_scalar_ty(), res_shapes);
748}
749
750std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
751 if(!perm.empty())
752 return perm;
753 auto size = ty->get_block_shapes().size();
754 std::vector<int> result;
755 result.push_back(size - 1);
756 for(size_t i = 0; i < size - 1; i++)
757 result.push_back(i);
758 return result;
759}
760
761trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
762 : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
763 // sanity check
764 perm_ = init_perm(arg->get_type(), perm);
765 //auto size = arg->get_type()->get_tile_shapes().size();
766 //assert(perm_.size() == size);
767 set_operand(0, arg);
768}
769
770instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
771 return new trans_inst(arg, perm, name, next);
772}
773
774const std::vector<int> trans_inst::get_perm() const {
775 return perm_;
776}
777
778//===----------------------------------------------------------------------===//
779// sqrt instructions
780//===----------------------------------------------------------------------===//
781
782sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
783 : builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){
784 set_operand(0, arg);
785}
786
787instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
788 return new sqrt_inst(arg, name, next);
789}
790
791//===----------------------------------------------------------------------===//
792// reduce instructions
793//===----------------------------------------------------------------------===//
794
795std::string reduce_inst::to_str(op_t op) {
796 switch (op) {
797 case ADD: return "+";
798 case SUB: return "-";
799 case MAX: return "imax";
800 case MIN: return "imin";
801 case FADD: return "+";
802 case FSUB: return "-";
803 case FMAX: return "fmax";
804 case FMIN: return "fmin";
805 default: break;
806 }
807 assert(false);
808 return "";
809}
810
811type* reduce_inst::get_res_type(value *arg, unsigned axis) {
812 ir::block_type::block_shapes_t shapes = arg->get_type()->get_block_shapes();
813 shapes.erase(shapes.begin() + axis);
814 type *scalar_ty = arg->get_type()->get_scalar_ty();
815 if(shapes.empty())
816// shapes.push_back(1);
817 return scalar_ty;
818 return block_type::get(scalar_ty, shapes);
819}
820
821reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
822 : builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
823 op_(op),
824 axis_(axis){
825 set_operand(0, arg);
826}
827
828instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
829 return new reduce_inst(arg, op, axis, name, next);
830}
831
832
833//===----------------------------------------------------------------------===//
834// select instructions
835//===----------------------------------------------------------------------===//
836
837select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
838 : builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){
839 set_operand(0, pred);
840 set_operand(1, if_value);
841 set_operand(2, else_value);
842}
843
844instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) {
845 return new select_inst(pred, if_value, else_value, name, next);
846}
847//===----------------------------------------------------------------------===//
848// builtin instructions
849//===----------------------------------------------------------------------===//
850
851
852// get_program_id
853get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
854 : builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){
855
856}
857
858instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
859 return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
860}
861
// get_num_programs
863get_num_programs_inst::get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
864 : builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
865
866}
867
868instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
869 return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
870}
871
872// atomic_rmw
873
874atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
875 : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
876 set_operand(0, ptr);
877 set_operand(1, val);
878 set_operand(2, msk);
879}
880
881instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
882 return new atomic_rmw_inst(op, ptr, val, msk, name, next);
883}
884
885
886// atomic cas
887
888atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
889 : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
890 set_operand(0, ptr);
891 set_operand(1, cmp);
892 set_operand(2, val);
893}
894
895instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) {
896 return new atomic_cas_inst(ptr, cmp, val, name, next);
897}
898
899
900// umulhi
901
902umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
903 : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
904 set_operand(0, lhs);
905 set_operand(1, rhs);
906}
907
908instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
909 return new umulhi_inst(lhs, rhs, name, next);
910}
911
912
913// exp
914
915exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
916 : builtin_inst(val->get_type(), INST_EXP, 1, name, next) {
917 set_operand(0, val);
918}
919
920instruction* exp_inst::create(value *val, const std::string& name, instruction *next) {
921 return new exp_inst(val, name, next);
922}
923
924// cos
925cos_inst::cos_inst(value *val, const std::string &name, instruction *next)
926 : builtin_inst(val->get_type(), INST_COS, 1, name, next) {
927 set_operand(0, val);
928}
929
930instruction* cos_inst::create(value *val, const std::string& name, instruction *next) {
931 return new cos_inst(val, name, next);
932}
933
934// sin
935sin_inst::sin_inst(value *val, const std::string &name, instruction *next)
936 : builtin_inst(val->get_type(), INST_SIN, 1, name, next) {
937 set_operand(0, val);
938}
939
940instruction* sin_inst::create(value *val, const std::string& name, instruction *next) {
941 return new sin_inst(val, name, next);
942}
943
944
945// log
946
947log_inst::log_inst(value *val, const std::string &name, instruction *next)
948 : builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
949 set_operand(0, val);
950}
951
952instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
953 return new log_inst(val, name, next);
954}
955
956
957//===----------------------------------------------------------------------===//
958// intrinsic instructions
959//===----------------------------------------------------------------------===//
960
// cvt_layout
962cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
963 return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next);
964}
965
966// copy to shared
967copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
968 instruction *next) {
969 return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
970}
971
972// copy from shared
973copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
974 instruction *next) {
975 return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
976}
977
978// barrier
979barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
980 : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
981
982barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
983 return new barrier_inst(ctx, name, next);
984}
985
986async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
987 : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
988
989async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
990 return new async_wait_inst(ctx, N, name, next);
991}
992
993// prefetch_s
994prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
995 return new prefetch_s_inst(ctx, arg, inc, name, next);
996}
997
998// global timer
999globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
1000 : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }
1001
1002globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
1003 return new globaltimer_inst(ctx, name, next);
1004}
1005
1006// extern elementwise
1007extern_elementwise_inst::extern_elementwise_inst(
1008 context &ctx, const std::vector<value *> &args, type *ret_ty,
1009 const std::string &lib_name, const std::string &lib_path,
1010 const std::string &symbol_name, const std::string &name, instruction *next)
1011 : instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), name, next),
1012 lib_name_(lib_name),
1013 lib_path_(lib_path),
1014 symbol_name_(symbol_name) {
1015 for (size_t i = 0; i < args.size(); i++) {
1016 set_operand(i, args[i]);
1017 }
1018}
1019
1020extern_elementwise_inst *extern_elementwise_inst::create(
1021 context &ctx, const std::vector<value *> &args, type *ret_ty,
1022 const std::string &lib_name, const std::string &lib_path,
1023 const std::string &symbol_name, const std::string &name,
1024 instruction *next) {
1025 return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
1026 symbol_name, name, next);
1027}
1028
1029// clock
1030clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
1031 : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
1032
1033clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
1034 return new clock_inst(ctx, name, next);
1035}
1036
1037
1038// make_range
1039make_range::make_range(type *ty, constant_int *first, constant_int *last)
1040 : instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ }
1041
1042make_range *make_range::create(constant_int *first, constant_int *last) {
1043 assert(first->get_type()->is_integer_ty());
1044 assert(first->get_type() == last->get_type());
1045// assert(((constant_int*)first)->get_value() == 0);
1046 type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
1047 return new make_range(ty, first, last);
1048}
1049
1050const constant_int* make_range::get_first() const {
1051 return first_;
1052}
1053
1054const constant_int* make_range::get_last() const {
1055 return last_;
1056}
1057
1058}
1059}
1060