1 | #include <string> |
2 | #include <algorithm> |
3 | #include <iostream> |
4 | #include "triton/ir/basic_block.h" |
5 | #include "triton/ir/builder.h" |
6 | #include "triton/ir/constant.h" |
7 | #include "triton/ir/instructions.h" |
8 | #include "triton/ir/type.h" |
9 | |
10 | namespace triton{ |
11 | namespace ir{ |
12 | |
13 | builder::builder(context &ctx): |
14 | ctx_(ctx), block_(nullptr) {} |
15 | |
16 | //===----------------------------------------------------------------------===// |
17 | // utilities |
18 | //===----------------------------------------------------------------------===// |
19 | void builder::set_insert_point(basic_block::iterator it){ |
20 | block_ = (*it)->get_parent(); |
21 | insert_point_ = it; |
22 | } |
23 | |
24 | void builder::set_insert_point(instruction* i){ |
25 | block_ = i->get_parent(); |
26 | auto it = std::find(block_->begin(), block_->end(), i); |
27 | set_insert_point(it); |
28 | } |
29 | |
30 | |
31 | void builder::set_insert_point_after(instruction* i){ |
32 | block_ = i->get_parent(); |
33 | auto it = std::find(block_->begin(), block_->end(), i); |
34 | set_insert_point(++it); |
35 | } |
36 | |
37 | |
38 | void builder::set_insert_point(basic_block *block){ |
39 | block_ = block; |
40 | insert_point_ = block->end(); |
41 | } |
42 | |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // convenience functions |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | value *builder::get_int1(bool val) |
49 | { return constant_int::get(type::get_int1_ty(ctx_), val); } |
50 | |
51 | value *builder::get_int32(uint32_t val) |
52 | { return constant_int::get(type::get_int32_ty(ctx_), val);} |
53 | |
54 | value *builder::get_int64(uint64_t val) |
55 | { return constant_int::get(type::get_int64_ty(ctx_), val);} |
56 | |
57 | value *builder::get_float16(float val) |
58 | { return constant_fp::get(type::get_fp16_ty(ctx_), val); } |
59 | |
60 | value *builder::get_float32(float val) |
61 | { return constant_fp::get(type::get_fp32_ty(ctx_), val); } |
62 | |
63 | value *builder::get_range(int32_t _lo, int32_t _hi) { |
64 | constant_int* lo = static_cast<constant_int*>(get_int32(_lo)); |
65 | constant_int* hi = static_cast<constant_int*>(get_int32(_hi)); |
66 | return insert(make_range::create(lo, hi)); |
67 | } |
68 | |
69 | type *builder::get_void_ty() |
70 | { return type::get_void_ty(ctx_); } |
71 | |
72 | type *builder::get_int1_ty() |
73 | { return type::get_int1_ty(ctx_); } |
74 | |
75 | type *builder::get_int8_ty() |
76 | { return type::get_int8_ty(ctx_); } |
77 | |
78 | type *builder::get_int16_ty() |
79 | { return type::get_int16_ty(ctx_); } |
80 | |
81 | type *builder::get_int32_ty() |
82 | { return type::get_int32_ty(ctx_); } |
83 | |
84 | type *builder::get_int64_ty() |
85 | { return type::get_int64_ty(ctx_); } |
86 | |
87 | type *builder::get_fp8_ty() |
88 | { return type::get_fp8_ty(ctx_); } |
89 | |
90 | type *builder::get_half_ty() |
91 | { return type::get_fp16_ty(ctx_); } |
92 | |
93 | type *builder::get_bf16_ty() |
94 | { return type::get_bf16_ty(ctx_); } |
95 | |
96 | type *builder::get_float_ty() |
97 | { return type::get_fp32_ty(ctx_); } |
98 | |
99 | type *builder::get_double_ty() |
100 | { return type::get_fp64_ty(ctx_); } |
101 | |
102 | |
103 | //===----------------------------------------------------------------------===// |
104 | // terminator instructions |
105 | //===----------------------------------------------------------------------===// |
106 | |
107 | value* builder::create_br(basic_block *dest){ |
108 | return insert(branch_inst::create(dest)); |
109 | } |
110 | |
111 | value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){ |
112 | return insert(branch_inst::create(cond, if_dest, else_dest)); |
113 | } |
114 | |
115 | value *builder::create_ret_void() { |
116 | return insert(return_inst::create(ctx_)); |
117 | } |
118 | |
119 | value *builder::create_ret(value* val) { |
120 | return insert(return_inst::create(ctx_, val)); |
121 | } |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // dequantize instructions |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){ |
128 | return insert(dequantize_inst::create(src, scale, shift, dst_ty)); |
129 | } |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // cast instructions |
133 | //===----------------------------------------------------------------------===// |
134 | #define DEFINE_CAST_INSTR(SUFFIX, OPCODE)\ |
135 | value *builder::create_ ## SUFFIX(value *src, type *dst_ty){\ |
136 | return create_cast(OPCODE, src, dst_ty);\ |
137 | } |
138 | |
139 | DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) |
140 | DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) |
141 | DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) |
142 | DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) |
143 | DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) |
144 | DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI) |
145 | DEFINE_CAST_INSTR(fp_to_ui, cast_op_t::FPToUI) |
146 | DEFINE_CAST_INSTR(fp_ext, cast_op_t::FPExt) |
147 | DEFINE_CAST_INSTR(fp_trunc, cast_op_t::FPTrunc) |
148 | |
149 | value* builder::create_cast(cast_op_t op, value *v, type *dst_ty){ |
150 | return insert(cast_inst::create(op, v, dst_ty)); |
151 | } |
152 | |
153 | value* builder::create_int_cast(value *src, type *dst_ty, bool is_signed){ |
154 | return insert(cast_inst::create_integer_cast(src, dst_ty, is_signed)); |
155 | } |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // phi instructions |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | phi_node* builder::create_phi(type *ty, unsigned num_reserved){ |
162 | return insert(phi_node::create(ty, num_reserved)); |
163 | } |
164 | |
165 | //===----------------------------------------------------------------------===// |
166 | // call instructions |
167 | //===----------------------------------------------------------------------===// |
168 | |
169 | value *builder::create_call(function* fn, const std::vector<value*>& args){ |
170 | return insert(call_inst::create(fn, args)); |
171 | } |
172 | |
173 | value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){ |
174 | return insert(launch_inst::create(fn, args, grid, num_warps)); |
175 | |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // binary float instructions |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | #define DEFINE_BINARY_FLOAT(SUFFIX, OPCODE)\ |
183 | value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ |
184 | return insert(binary_operator::create(OPCODE, lhs, rhs));\ |
185 | } |
186 | |
187 | // Binary |
188 | DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul) |
189 | DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv) |
190 | DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem) |
191 | DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd) |
192 | DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub) |
193 | |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // binary int instructions |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | |
200 | value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs, |
201 | value *rhs, |
202 | bool has_nuw, bool has_nsw) { |
203 | binary_operator* result = insert(binary_operator::create(op, lhs, rhs)); |
204 | if (has_nuw) result->set_has_no_unsigned_wrap(); |
205 | if (has_nsw) result->set_has_no_signed_wrap(); |
206 | return result; |
207 | } |
208 | |
209 | #define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\ |
210 | value* builder::create_ ## SUFFIX(value *lhs, value *rhs, bool has_nuw, bool has_nsw){\ |
211 | return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, has_nuw, has_nsw);\ |
212 | }\ |
213 | |
214 | #define DEFINE_BINARY_INT(SUFFIX, OPCODE)\ |
215 | value *builder::create_ ## SUFFIX(value *lhs, value *rhs){\ |
216 | return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, false, false);\ |
217 | } |
218 | |
219 | |
220 | |
221 | // Binary |
222 | DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul) |
223 | DEFINE_NOWRAP_BINARY(add, binary_op_t::Add) |
224 | DEFINE_NOWRAP_BINARY(sub, binary_op_t::Sub) |
225 | DEFINE_NOWRAP_BINARY(shl, binary_op_t::Shl) |
226 | DEFINE_NOWRAP_BINARY(ashr, binary_op_t::AShr) |
227 | DEFINE_NOWRAP_BINARY(lshr, binary_op_t::LShr) |
228 | DEFINE_BINARY_INT(sdiv, binary_op_t::SDiv) |
229 | DEFINE_BINARY_INT(udiv, binary_op_t::UDiv) |
230 | DEFINE_BINARY_INT(srem, binary_op_t::SRem) |
231 | DEFINE_BINARY_INT(urem, binary_op_t::URem) |
232 | DEFINE_BINARY_INT(and, binary_op_t::And) |
233 | DEFINE_BINARY_INT(or, binary_op_t::Or) |
234 | DEFINE_BINARY_INT(xor, binary_op_t::Xor) |
235 | |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | // getelementptr instructions |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | value* builder::create_gep(value *ptr, const std::vector<value*>& idx_list){ |
242 | return insert(getelementptr_inst::create(ptr, idx_list)); |
243 | } |
244 | |
245 | //===----------------------------------------------------------------------===// |
246 | // icmp instructions |
247 | //===----------------------------------------------------------------------===// |
248 | |
249 | value *builder::create_icmp(cmp_pred_t pred, value *lhs, value *rhs){ |
250 | return insert(icmp_inst::create(pred, lhs, rhs)); |
251 | } |
252 | |
253 | #define DEFINE_ICMP_INSTR(SUFFIX, OPCODE)\ |
254 | value *builder::create_icmp ## SUFFIX(value *lhs, value *rhs){\ |
255 | return create_icmp(OPCODE, lhs, rhs);\ |
256 | } |
257 | |
258 | // Signed |
259 | DEFINE_ICMP_INSTR(SLE, cmp_pred_t::ICMP_SLE) |
260 | DEFINE_ICMP_INSTR(SLT, cmp_pred_t::ICMP_SLT) |
261 | DEFINE_ICMP_INSTR(SGE, cmp_pred_t::ICMP_SGE) |
262 | DEFINE_ICMP_INSTR(SGT, cmp_pred_t::ICMP_SGT) |
263 | // Unsigned |
264 | DEFINE_ICMP_INSTR(ULE, cmp_pred_t::ICMP_ULE) |
265 | DEFINE_ICMP_INSTR(ULT, cmp_pred_t::ICMP_ULT) |
266 | DEFINE_ICMP_INSTR(UGE, cmp_pred_t::ICMP_UGE) |
267 | DEFINE_ICMP_INSTR(UGT, cmp_pred_t::ICMP_UGT) |
268 | // General |
269 | DEFINE_ICMP_INSTR(EQ, cmp_pred_t::ICMP_EQ) |
270 | DEFINE_ICMP_INSTR(NE, cmp_pred_t::ICMP_NE) |
271 | |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // fcmp instructions |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | value *builder::create_fcmp(cmp_pred_t pred, value *lhs, value *rhs){ |
278 | return insert(fcmp_inst::create(pred, lhs, rhs)); |
279 | } |
280 | |
281 | #define DEFINE_FCMP_INSTR(SUFFIX, OPCODE)\ |
282 | value *builder::create_fcmp ## SUFFIX(value *lhs, value *rhs){\ |
283 | return create_fcmp(OPCODE, lhs, rhs);\ |
284 | } |
285 | |
286 | // Ordered |
287 | DEFINE_FCMP_INSTR(OLE, cmp_pred_t::FCMP_OLE) |
288 | DEFINE_FCMP_INSTR(OLT, cmp_pred_t::FCMP_OLT) |
289 | DEFINE_FCMP_INSTR(OGE, cmp_pred_t::FCMP_OGE) |
290 | DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT) |
291 | DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ) |
292 | DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) |
293 | |
294 | DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE) |
295 | DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT) |
296 | DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE) |
297 | DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT) |
298 | DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ) |
299 | DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) |
300 | |
301 | |
302 | //===----------------------------------------------------------------------===// |
303 | // load/store instructions |
304 | //===----------------------------------------------------------------------===// |
305 | |
306 | value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ |
307 | return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); |
308 | } |
309 | |
310 | value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){ |
311 | return insert(unmasked_store_inst::create(ptr, val, eviction)); |
312 | } |
313 | |
314 | value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ |
315 | return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); |
316 | } |
317 | |
318 | value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){ |
319 | return insert(masked_store_inst::create(ptr, val, mask, eviction)); |
320 | } |
321 | |
322 | //===----------------------------------------------------------------------===// |
323 | // struct instructions |
324 | //===----------------------------------------------------------------------===// |
325 | |
326 | |
327 | // Struct instructions |
328 | value *builder::create_insert_value(value* val, value *elt, size_t idx){ |
329 | return insert(insert_value_inst::create(val, elt, idx)); |
330 | } |
331 | |
332 | value *builder::(value* val, size_t idx) { |
333 | return insert(extract_value_inst::create(val, idx)); |
334 | } |
335 | //===----------------------------------------------------------------------===// |
336 | // block instructions |
337 | //===----------------------------------------------------------------------===// |
338 | |
339 | value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) { |
340 | return insert(reshape_inst::create(arg, shapes)); |
341 | } |
342 | |
343 | value *builder::create_cat(value *lhs, value *rhs) { |
344 | return insert(cat_inst::create(lhs, rhs)); |
345 | } |
346 | |
347 | value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) { |
348 | return insert(splat_inst::create(arg, shapes)); |
349 | } |
350 | |
351 | value *builder::create_broadcast(value *arg, const type::block_shapes_t &shapes) { |
352 | return insert(broadcast_inst::create(arg, shapes)); |
353 | } |
354 | |
355 | value *builder::create_downcast(value *arg) { |
356 | return insert(downcast_inst::create(arg)); |
357 | } |
358 | |
// atomic read-modify-write instructions
360 | |
361 | value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ |
362 | return insert(atomic_rmw_inst::create(op, ptr, val, msk)); |
363 | } |
364 | |
365 | #define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ |
366 | value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ |
367 | return create_atomic_rmw(OPCODE, ptr, val, mask);\ |
368 | } |
369 | |
370 | DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) |
371 | DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) |
372 | DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) |
373 | DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin) |
374 | DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd) |
375 | DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) |
376 | DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) |
377 | DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) |
378 | DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) |
379 | DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) |
380 | |
381 | // Utilities |
382 | value *builder::create_clock() { |
383 | return insert(clock_inst::create(ctx_)); |
384 | } |
385 | |
386 | value *builder::create_globaltimer() { |
387 | return insert(globaltimer_inst::create(ctx_)); |
388 | } |
389 | |
390 | //===----------------------------------------------------------------------===// |
391 | // externs |
392 | //===----------------------------------------------------------------------===// |
393 | |
394 | value *builder::create_extern_elementwise(const std::string &lib_name, |
395 | const std::string &lib_path, |
396 | const std::string &symbol_name, |
397 | const std::vector<value *> &args, |
398 | type *ret_ty) { |
399 | return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name, |
400 | lib_path, symbol_name)); |
401 | } |
402 | |
403 | //===----------------------------------------------------------------------===// |
404 | // built-in instructions |
405 | //===----------------------------------------------------------------------===// |
406 | |
407 | value *builder::create_get_program_id(unsigned axis) { |
408 | return insert(get_program_id_inst::create(ctx_, axis)); |
409 | } |
410 | |
411 | value *builder::create_get_num_programs(unsigned axis) { |
412 | return insert(get_num_programs_inst::create(ctx_, axis)); |
413 | } |
414 | |
415 | value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ |
416 | return insert(atomic_cas_inst::create(ptr, cmp, val)); |
417 | } |
418 | |
419 | |
420 | value *builder::create_exp(value *arg){ |
421 | return insert(exp_inst::create(arg)); |
422 | } |
423 | |
424 | value *builder::create_cos(value *arg){ |
425 | return insert(cos_inst::create(arg)); |
426 | } |
427 | |
428 | value *builder::create_sin(value *arg){ |
429 | return insert(sin_inst::create(arg)); |
430 | } |
431 | |
432 | value *builder::create_log(value *arg){ |
433 | return insert(log_inst::create(arg)); |
434 | } |
435 | |
436 | value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) { |
437 | return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32)); |
438 | } |
439 | |
440 | value *builder::create_trans(value *A, const std::vector<int>& perm) { |
441 | return insert(trans_inst::create(A, perm)); |
442 | } |
443 | |
444 | value *builder::create_sqrt(value *A) { |
445 | return insert(sqrt_inst::create(A)); |
446 | } |
447 | |
448 | value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis) { |
449 | return insert(reduce_inst::create(A, op, axis)); |
450 | } |
451 | |
452 | value *builder::create_select(value *pred, value *if_value, value *else_value){ |
453 | return insert(select_inst::create(pred, if_value, else_value)); |
454 | } |
455 | |
456 | //===----------------------------------------------------------------------===// |
457 | // intrinsic instructions |
458 | //===----------------------------------------------------------------------===// |
459 | |
460 | value *builder::create_umulhi(value *lhs, value *rhs) { |
461 | return insert(umulhi_inst::create(lhs, rhs)); |
462 | } |
463 | |
464 | value *builder::create_copy_to_shared(value *arg) { |
465 | return insert(copy_to_shared_inst::create(arg)); |
466 | } |
467 | |
468 | |
469 | value *builder::create_copy_from_shared(value *arg) { |
470 | return insert(copy_from_shared_inst::create(arg)); |
471 | } |
472 | |
473 | value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) { |
474 | return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction)); |
475 | } |
476 | |
477 | value *builder::create_barrier(const std::string &name) { |
478 | return insert(barrier_inst::create(ctx_)); |
479 | } |
480 | |
481 | value *builder::create_async_wait(int N) { |
482 | return insert(async_wait_inst::create(ctx_, N)); |
483 | } |
484 | |
485 | value *builder::create_prefetch_s(value *arg, int inc) { |
486 | return insert(prefetch_s_inst::create(ctx_, arg, inc)); |
487 | } |
488 | |
489 | |
490 | } |
491 | } |
492 | |