1#include <string>
2#include <algorithm>
3#include <iostream>
4#include "triton/ir/basic_block.h"
5#include "triton/ir/builder.h"
6#include "triton/ir/constant.h"
7#include "triton/ir/instructions.h"
8#include "triton/ir/type.h"
9
10namespace triton{
11namespace ir{
12
13builder::builder(context &ctx):
14 ctx_(ctx), block_(nullptr) {}
15
16//===----------------------------------------------------------------------===//
17// utilities
18//===----------------------------------------------------------------------===//
19void builder::set_insert_point(basic_block::iterator it){
20 block_ = (*it)->get_parent();
21 insert_point_ = it;
22}
23
24void builder::set_insert_point(instruction* i){
25 block_ = i->get_parent();
26 auto it = std::find(block_->begin(), block_->end(), i);
27 set_insert_point(it);
28}
29
30
31void builder::set_insert_point_after(instruction* i){
32 block_ = i->get_parent();
33 auto it = std::find(block_->begin(), block_->end(), i);
34 set_insert_point(++it);
35}
36
37
38void builder::set_insert_point(basic_block *block){
39 block_ = block;
40 insert_point_ = block->end();
41}
42
43
44//===----------------------------------------------------------------------===//
45// convenience functions
46//===----------------------------------------------------------------------===//
47
48value *builder::get_int1(bool val)
49{ return constant_int::get(type::get_int1_ty(ctx_), val); }
50
51value *builder::get_int32(uint32_t val)
52{ return constant_int::get(type::get_int32_ty(ctx_), val);}
53
54value *builder::get_int64(uint64_t val)
55{ return constant_int::get(type::get_int64_ty(ctx_), val);}
56
57value *builder::get_float16(float val)
58{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
59
60value *builder::get_float32(float val)
61{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
62
63value *builder::get_range(int32_t _lo, int32_t _hi) {
64 constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
65 constant_int* hi = static_cast<constant_int*>(get_int32(_hi));
66 return insert(make_range::create(lo, hi));
67}
68
69type *builder::get_void_ty()
70{ return type::get_void_ty(ctx_); }
71
72type *builder::get_int1_ty()
73{ return type::get_int1_ty(ctx_); }
74
75type *builder::get_int8_ty()
76{ return type::get_int8_ty(ctx_); }
77
78type *builder::get_int16_ty()
79{ return type::get_int16_ty(ctx_); }
80
81type *builder::get_int32_ty()
82{ return type::get_int32_ty(ctx_); }
83
84type *builder::get_int64_ty()
85{ return type::get_int64_ty(ctx_); }
86
87type *builder::get_fp8_ty()
88{ return type::get_fp8_ty(ctx_); }
89
90type *builder::get_half_ty()
91{ return type::get_fp16_ty(ctx_); }
92
93type *builder::get_bf16_ty()
94{ return type::get_bf16_ty(ctx_); }
95
96type *builder::get_float_ty()
97{ return type::get_fp32_ty(ctx_); }
98
99type *builder::get_double_ty()
100{ return type::get_fp64_ty(ctx_); }
101
102
103//===----------------------------------------------------------------------===//
104// terminator instructions
105//===----------------------------------------------------------------------===//
106
107value* builder::create_br(basic_block *dest){
108 return insert(branch_inst::create(dest));
109}
110
111value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
112 return insert(branch_inst::create(cond, if_dest, else_dest));
113}
114
115value *builder::create_ret_void() {
116 return insert(return_inst::create(ctx_));
117}
118
119value *builder::create_ret(value* val) {
120 return insert(return_inst::create(ctx_, val));
121}
122
123//===----------------------------------------------------------------------===//
124// dequantize instructions
125//===----------------------------------------------------------------------===//
126
127value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){
128 return insert(dequantize_inst::create(src, scale, shift, dst_ty));
129}
130
131//===----------------------------------------------------------------------===//
132// cast instructions
133//===----------------------------------------------------------------------===//
134#define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\
135 value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\
136 return create_cast(OPCODE, src, dst_ty);\
137 }
138
139DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
140DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
141DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
142DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
143DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
144DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI)
145DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI)
146DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt)
147DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc)
148
149value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){
150 return insert(cast_inst::create(op, v, dst_ty));
151}
152
153value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){
154 return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed));
155}
156
157//===----------------------------------------------------------------------===//
158// phi instructions
159//===----------------------------------------------------------------------===//
160
161phi_node* builder::create_phi(type *ty, unsigned num_reserved){
162 return insert(phi_node::create(ty, num_reserved));
163}
164
165//===----------------------------------------------------------------------===//
166// call instructions
167//===----------------------------------------------------------------------===//
168
169value *builder::create_call(function* fn, const std::vector<value*>& args){
170 return insert(call_inst::create(fn, args));
171}
172
173value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
174 return insert(launch_inst::create(fn, args, grid, num_warps));
175
176}
177
178//===----------------------------------------------------------------------===//
179// binary float instructions
180//===----------------------------------------------------------------------===//
181
182#define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\
183 value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\
184 return insert(binary_operator::create(OPCODE, lhs, rhs));\
185 }
186
187// Binary
188DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
189DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
190DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
191DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
192DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
193
194
195//===----------------------------------------------------------------------===//
196// binary int instructions
197//===----------------------------------------------------------------------===//
198
199
200value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
201 value *rhs,
202 bool has_nuw, bool has_nsw) {
203 binary_operator* result = insert(binary_operator::create(op, lhs, rhs));
204 if (has_nuw) result->set_has_no_unsigned_wrap();
205 if (has_nsw) result->set_has_no_signed_wrap();
206 return result;
207}
208
// Generates builder::create_<SUFFIX>(lhs, rhs, has_nuw, has_nsw), forwarding
// to create_insert_nuwnswb_binop with the fixed opcode OPCODE.
// The final line deliberately carries no continuation backslash, so the macro
// ends here regardless of what follows it.
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
  value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\
    return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\
  }
213
// Generates builder::create_<SUFFIX>(lhs, rhs) for integer operations without
// wrap flags: forwards to create_insert_nuwnswb_binop with both flags false.
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)                                 \
  value *builder::create_##SUFFIX(value *lhs, value *rhs) {               \
    return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);   \
  }
218
219
220
221// Binary
222DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
223DEFINE_NOWRAP_BINARY(add, binary_op_t::Add)
224DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub)
225DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl)
226DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr)
227DEFINE_NOWRAP_BINARY(lshr, binary_op_t::LShr)
228DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv)
229DEFINE_BINARY_INT(udiv, binary_op_t::UDiv)
230DEFINE_BINARY_INT(srem, binary_op_t::SRem)
231DEFINE_BINARY_INT(urem, binary_op_t::URem)
232DEFINE_BINARY_INT(and, binary_op_t::And)
233DEFINE_BINARY_INT(or, binary_op_t::Or)
234DEFINE_BINARY_INT(xor, binary_op_t::Xor)
235
236
237//===----------------------------------------------------------------------===//
238// getelementptr instructions
239//===----------------------------------------------------------------------===//
240
241value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list){
242 return insert(getelementptr_inst::create(ptr, idx_list));
243}
244
245//===----------------------------------------------------------------------===//
246// icmp instructions
247//===----------------------------------------------------------------------===//
248
249value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){
250 return insert(icmp_inst::create(pred, lhs, rhs));
251}
252
253#define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\
254 value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\
255 return create_icmp(OPCODE, lhs, rhs);\
256 }
257
258// Signed
259DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE)
260DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT)
261DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE)
262DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT)
263// Unsigned
264DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE)
265DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT)
266DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE)
267DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT)
268// General
269DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ)
270DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE)
271
272
273//===----------------------------------------------------------------------===//
274// fcmp instructions
275//===----------------------------------------------------------------------===//
276
277value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){
278 return insert(fcmp_inst::create(pred, lhs, rhs));
279}
280
281#define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\
282 value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\
283 return create_fcmp(OPCODE, lhs, rhs);\
284 }
285
286// Ordered
287DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE)
288DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT)
289DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE)
290DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
291DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
292DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
293
294DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE)
295DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT)
296DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE)
297DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT)
298DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ)
299DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
300
301
302//===----------------------------------------------------------------------===//
303// load/store instructions
304//===----------------------------------------------------------------------===//
305
306value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
307 return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
308}
309
310value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
311 return insert(unmasked_store_inst::create(ptr, val, eviction));
312}
313
314value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
315 return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
316}
317
318value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
319 return insert(masked_store_inst::create(ptr, val, mask, eviction));
320}
321
322//===----------------------------------------------------------------------===//
323// struct instructions
324//===----------------------------------------------------------------------===//
325
326
327// Struct instructions
328value *builder::create_insert_value(value* val, value *elt, size_t idx){
329 return insert(insert_value_inst::create(val, elt, idx));
330}
331
332value *builder::create_extract_value(value* val, size_t idx) {
333 return insert(extract_value_inst::create(val, idx));
334}
335//===----------------------------------------------------------------------===//
336// block instructions
337//===----------------------------------------------------------------------===//
338
339value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
340 return insert(reshape_inst::create(arg, shapes));
341}
342
343value *builder::create_cat(value *lhs, value *rhs) {
344 return insert(cat_inst::create(lhs, rhs));
345}
346
347value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
348 return insert(splat_inst::create(arg, shapes));
349}
350
351value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) {
352 return insert(broadcast_inst::create(arg, shapes));
353}
354
355value *builder::create_downcast(value *arg) {
356 return insert(downcast_inst::create(arg));
357}
358
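//===----------------------------------------------------------------------===//
//                               atomic instructions
//===----------------------------------------------------------------------===//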
360
361value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
362 return insert(atomic_rmw_inst::create(op, ptr, val, msk));
363}
364
365#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
366 value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
367 return create_atomic_rmw(OPCODE, ptr, val, mask);\
368 }
369
370DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
371DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
372DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
373DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
374DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
375DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
376DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
377DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
378DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
379DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
380
381// Utilities
382value *builder::create_clock() {
383 return insert(clock_inst::create(ctx_));
384}
385
386value *builder::create_globaltimer() {
387 return insert(globaltimer_inst::create(ctx_));
388}
389
390//===----------------------------------------------------------------------===//
391// externs
392//===----------------------------------------------------------------------===//
393
394value *builder::create_extern_elementwise(const std::string &lib_name,
395 const std::string &lib_path,
396 const std::string &symbol_name,
397 const std::vector<value *> &args,
398 type *ret_ty) {
399 return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name,
400 lib_path, symbol_name));
401}
402
403//===----------------------------------------------------------------------===//
404// built-in instructions
405//===----------------------------------------------------------------------===//
406
407value *builder::create_get_program_id(unsigned axis) {
408 return insert(get_program_id_inst::create(ctx_, axis));
409}
410
411value *builder::create_get_num_programs(unsigned axis) {
412 return insert(get_num_programs_inst::create(ctx_, axis));
413}
414
415value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
416 return insert(atomic_cas_inst::create(ptr, cmp, val));
417}
418
419
420value *builder::create_exp(value *arg){
421 return insert(exp_inst::create(arg));
422}
423
424value *builder::create_cos(value *arg){
425 return insert(cos_inst::create(arg));
426}
427
428value *builder::create_sin(value *arg){
429 return insert(sin_inst::create(arg));
430}
431
432value *builder::create_log(value *arg){
433 return insert(log_inst::create(arg));
434}
435
436value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
437 return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
438}
439
440value *builder::create_trans(value *A, const std::vector<int>& perm) {
441 return insert(trans_inst::create(A, perm));
442}
443
444value *builder::create_sqrt(value *A) {
445 return insert(sqrt_inst::create(A));
446}
447
448value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) {
449 return insert(reduce_inst::create(A, op, axis));
450}
451
452value *builder::create_select(value *pred, value *if_value, value *else_value){
453 return insert(select_inst::create(pred, if_value, else_value));
454}
455
456//===----------------------------------------------------------------------===//
457// intrinsic instructions
458//===----------------------------------------------------------------------===//
459
460value *builder::create_umulhi(value *lhs, value *rhs) {
461 return insert(umulhi_inst::create(lhs, rhs));
462}
463
464value *builder::create_copy_to_shared(value *arg) {
465 return insert(copy_to_shared_inst::create(arg));
466}
467
468
469value *builder::create_copy_from_shared(value *arg) {
470 return insert(copy_from_shared_inst::create(arg));
471}
472
473value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
474 return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
475}
476
477value *builder::create_barrier(const std::string &name) {
478 return insert(barrier_inst::create(ctx_));
479}
480
481value *builder::create_async_wait(int N) {
482 return insert(async_wait_inst::create(ctx_, N));
483}
484
485value *builder::create_prefetch_s(value *arg, int inc) {
486 return insert(prefetch_s_inst::create(ctx_, arg, inc));
487}
488
489
490}
491}
492