1 | #pragma once |
2 | |
3 | #ifndef _TRITON_IR_BUILDER_H_ |
4 | #define _TRITON_IR_BUILDER_H_ |
5 | |
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
11 | |
12 | namespace triton{ |
13 | namespace ir{ |
14 | |
// Forward declarations of the IR entities that appear only as pointers or
// references in the builder interface below.
class basic_block;
class constant_int;
class context;
// `function` is taken by pointer in create_call / create_launch; declaring it
// here keeps this header from depending on a transitive include for the name.
class function;
class instruction;
class phi_node;
class type;
class value;
22 | |
23 | /* Builder */ |
24 | class builder{ |
25 | public: |
26 | typedef basic_block::iterator iterator; |
27 | |
28 | public: |
29 | // Constructor |
30 | builder(context &ctx); |
31 | // Getters |
32 | // const context& get_context() const { return ctx_; } |
33 | context& get_context() { return ctx_; } |
34 | |
35 | // Setters |
36 | void set_insert_point(iterator instr); |
37 | void set_insert_point(instruction* i); |
38 | void set_insert_point_after(instruction* i); |
39 | void set_insert_point(basic_block* block); |
40 | basic_block* get_insert_block() { return block_; } |
41 | iterator get_insert_point() { return insert_point_;} |
42 | // Constants |
43 | value *get_int1(bool val); |
44 | value *get_int32(uint32_t val); |
45 | value *get_int64(uint64_t val); |
46 | value *get_float16(float val); |
47 | value *get_float32(float val); |
48 | value *get_range(int32_t lo, int32_t hi); |
49 | // Types |
50 | type *get_void_ty(); |
51 | type *get_int1_ty(); |
52 | type *get_int8_ty(); |
53 | type *get_int16_ty(); |
54 | type *get_int32_ty(); |
55 | type *get_int64_ty(); |
56 | type *get_fp8_ty(); |
57 | type *get_half_ty(); |
58 | type *get_bf16_ty(); |
59 | type *get_float_ty(); |
60 | type *get_double_ty(); |
61 | // Insert |
62 | template<typename InstTy> |
63 | InstTy* insert(InstTy *inst){ |
64 | assert(block_); |
65 | block_->get_inst_list().insert(insert_point_, inst); |
66 | inst->set_parent(block_); |
67 | // for(ir::value* op: inst->ops()) |
68 | // op->add_use(inst); |
69 | return inst; |
70 | } |
71 | // terminator instructions |
72 | value* create_br(basic_block *dest); |
73 | value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); |
74 | value* create_ret_void(); |
75 | value* create_ret(value *ret); |
76 | // Dequantize instructions |
77 | value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty); |
78 | // Cast instructions |
79 | value* create_bitcast(value *src, type *dest_ty); |
80 | value *create_cast(cast_op_t op, value *v, type *dst_ty); |
81 | value* create_int_to_ptr(value *src, type *dst_ty); |
82 | value* create_ptr_to_int(value *src, type *dst_ty); |
83 | value* create_si_to_fp(value *src, type *dst_ty); |
84 | value* create_ui_to_fp(value *src, type *dst_ty); |
85 | value* create_fp_to_si(value *src, type *dst_ty); |
86 | value* create_fp_to_ui(value *src, type *dst_ty); |
87 | value* create_fp_ext(value *src, type *dst_ty); |
88 | value* create_fp_trunc(value *src, type *dst_ty); |
89 | value* create_int_cast(value *src, type *dst_ty, bool is_signed); |
90 | value *create_downcast(value *arg); |
91 | // Call instruction |
92 | value* create_call(function* fn, const std::vector<value*>& args); |
93 | value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps); |
94 | // Phi instruction |
95 | phi_node* create_phi(type *ty, unsigned num_reserved); |
96 | // Binary instructions |
97 | value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw); |
98 | value *create_fmul(value *lhs, value *rhs); |
99 | value *create_fdiv(value *lhs, value *rhs); |
100 | value *create_frem(value *lhs, value *rhs); |
101 | value *create_fadd(value *lhs, value *rhs); |
102 | value *create_fsub(value *lhs, value *rhs); |
103 | value *create_sdiv(value *lhs, value *rhs); |
104 | value *create_udiv(value *lhs, value *rhs); |
105 | value *create_srem(value *lhs, value *rhs); |
106 | value *create_urem(value *lhs, value *rhs); |
107 | value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
108 | value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
109 | value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
110 | value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
111 | value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
112 | value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); |
113 | // GEP |
114 | value *create_gep(value *ptr, const std::vector<value*>& idx_list); |
115 | // Comparison (int) |
116 | value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs); |
117 | value *create_icmpSLE(value *lhs, value *rhs); |
118 | value *create_icmpSLT(value *lhs, value *rhs); |
119 | value *create_icmpSGE(value *lhs, value *rhs); |
120 | value *create_icmpSGT(value *lhs, value *rhs); |
121 | value *create_icmpULE(value *lhs, value *rhs); |
122 | value *create_icmpULT(value *lhs, value *rhs); |
123 | value *create_icmpUGE(value *lhs, value *rhs); |
124 | value *create_icmpUGT(value *lhs, value *rhs); |
125 | value *create_icmpEQ(value *lhs, value *rhs); |
126 | value *create_icmpNE(value *lhs, value *rhs); |
127 | // Comparison (float) |
128 | value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs); |
129 | value *create_fcmpOLT(value *lhs, value *rhs); |
130 | value *create_fcmpOGT(value *lhs, value *rhs); |
131 | value *create_fcmpOLE(value *lhs, value *rhs); |
132 | value *create_fcmpOGE(value *lhs, value *rhs); |
133 | value *create_fcmpOEQ(value *lhs, value *rhs); |
134 | value *create_fcmpONE(value *lhs, value *rhs); |
135 | value *create_fcmpULT(value *lhs, value *rhs); |
136 | value *create_fcmpUGT(value *lhs, value *rhs); |
137 | value *create_fcmpULE(value *lhs, value *rhs); |
138 | value *create_fcmpUGE(value *lhs, value *rhs); |
139 | value *create_fcmpUEQ(value *lhs, value *rhs); |
140 | value *create_fcmpUNE(value *lhs, value *rhs); |
141 | // Logical |
142 | value *create_and(value *lhs, value *rhs); |
143 | value *create_xor(value *lhs, value *rhs); |
144 | value *create_or(value *lhs, value *rhs); |
145 | // Input/Output |
146 | value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); |
147 | value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction); |
148 | value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); |
149 | value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction); |
150 | // Struct instructions |
151 | value *create_insert_value(value* val, value *elt, size_t idx); |
152 | value *(value* val, size_t idx); |
153 | // Block instruction |
154 | value *create_splat(value *arg, const type::block_shapes_t &shapes); |
155 | value *create_reshape(value *arg, const type::block_shapes_t &shapes); |
156 | value *create_cat(value *lhs, value *rhs); |
157 | value *create_broadcast(value *arg, const type::block_shapes_t &shapes); |
158 | // Atomic instruction |
159 | value *create_atomic_cas(value *ptr, value *cmp, value *val); |
160 | value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk); |
161 | value *create_atomic_max(value *ptr, value *val, value *msk); |
162 | value *create_atomic_umax(value *ptr, value *val, value *msk); |
163 | value *create_atomic_min(value *ptr, value *val, value *msk); |
164 | value *create_atomic_umin(value *ptr, value *val, value *msk); |
165 | value *create_atomic_fadd(value *ptr, value *val, value *msk); |
166 | value *create_atomic_add(value *ptr, value *val, value *msk); |
167 | value *create_atomic_and(value *ptr, value *val, value *msk); |
168 | value *create_atomic_or(value *ptr, value *val, value *msk); |
169 | value *create_atomic_xor(value *ptr, value *val, value *msk); |
170 | value *create_atomic_xchg(value *ptr, value *val, value *msk); |
171 | // Utilities |
172 | value *create_clock(); |
173 | value *create_globaltimer(); |
174 | // Extern instruction |
175 | value *create_extern_elementwise(const std::string &lib_name, |
176 | const std::string &lib_path, |
177 | const std::string &symbol_name, |
178 | const std::vector<value *> &args, |
179 | type *ret_ty); |
180 | // Built-in instruction |
181 | value *create_get_program_id(unsigned axis); |
182 | value *create_get_num_programs(unsigned axis); |
183 | value *create_exp(value* arg); |
184 | value *create_cos(value* arg); |
185 | value *create_sin(value* arg); |
186 | value *create_log(value* arg); |
187 | value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32); |
188 | value *create_trans(value *A, const std::vector<int> &perm = {}); |
189 | value *create_sqrt(value *A); |
190 | value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); |
191 | value *create_select(value *pred, value *if_value, value *else_value); |
192 | // Intrinsics |
193 | // These have no place in the IR, and hopefully they can be removed at some point |
194 | value *create_umulhi(value* lhs, value* rhs); |
195 | value *create_copy_to_shared(value *arg); |
196 | value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY); |
197 | value *create_copy_from_shared(value *arg); |
198 | value *create_barrier(const std::string &name = "" ); |
199 | value *create_async_wait(int N); |
200 | value *create_prefetch_s(value *arg, int inc); |
201 | |
202 | private: |
203 | context &ctx_; |
204 | basic_block *block_; |
205 | iterator insert_point_; |
206 | }; |
207 | |
208 | |
209 | } |
210 | } |
211 | |
212 | #endif |
213 | |