1#pragma once
2
3#ifndef _TRITON_IR_BUILDER_H_
4#define _TRITON_IR_BUILDER_H_
5
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
11
12namespace triton{
13namespace ir{
14
// Forward declarations of IR classes, in alphabetical order.
// NOTE(review): `function` (used by builder::create_call / create_launch) is
// not declared here; presumably it is provided by instructions.h -- confirm.
class basic_block;
class constant_int;
class context;
class instruction;
class phi_node;
class type;
class value;
22
23/* Builder */
24class builder{
25public:
26 typedef basic_block::iterator iterator;
27
28public:
29 // Constructor
30 builder(context &ctx);
31 // Getters
32 // const context& get_context() const { return ctx_; }
33 context& get_context() { return ctx_; }
34
35 // Setters
36 void set_insert_point(iterator instr);
37 void set_insert_point(instruction* i);
38 void set_insert_point_after(instruction* i);
39 void set_insert_point(basic_block* block);
40 basic_block* get_insert_block() { return block_; }
41 iterator get_insert_point() { return insert_point_;}
42 // Constants
43 value *get_int1(bool val);
44 value *get_int32(uint32_t val);
45 value *get_int64(uint64_t val);
46 value *get_float16(float val);
47 value *get_float32(float val);
48 value *get_range(int32_t lo, int32_t hi);
49 // Types
50 type *get_void_ty();
51 type *get_int1_ty();
52 type *get_int8_ty();
53 type *get_int16_ty();
54 type *get_int32_ty();
55 type *get_int64_ty();
56 type *get_fp8_ty();
57 type *get_half_ty();
58 type *get_bf16_ty();
59 type *get_float_ty();
60 type *get_double_ty();
61 // Insert
62 template<typename InstTy>
63 InstTy* insert(InstTy *inst){
64 assert(block_);
65 block_->get_inst_list().insert(insert_point_, inst);
66 inst->set_parent(block_);
67// for(ir::value* op: inst->ops())
68// op->add_use(inst);
69 return inst;
70 }
71 // terminator instructions
72 value* create_br(basic_block *dest);
73 value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
74 value* create_ret_void();
75 value* create_ret(value *ret);
76 // Dequantize instructions
77 value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
78 // Cast instructions
79 value* create_bitcast(value *src, type *dest_ty);
80 value *create_cast(cast_op_t op, value *v, type *dst_ty);
81 value* create_int_to_ptr(value *src, type *dst_ty);
82 value* create_ptr_to_int(value *src, type *dst_ty);
83 value* create_si_to_fp(value *src, type *dst_ty);
84 value* create_ui_to_fp(value *src, type *dst_ty);
85 value* create_fp_to_si(value *src, type *dst_ty);
86 value* create_fp_to_ui(value *src, type *dst_ty);
87 value* create_fp_ext(value *src, type *dst_ty);
88 value* create_fp_trunc(value *src, type *dst_ty);
89 value* create_int_cast(value *src, type *dst_ty, bool is_signed);
90 value *create_downcast(value *arg);
91 // Call instruction
92 value* create_call(function* fn, const std::vector<value*>& args);
93 value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
94 // Phi instruction
95 phi_node* create_phi(type *ty, unsigned num_reserved);
96 // Binary instructions
97 value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
98 value *create_fmul(value *lhs, value *rhs);
99 value *create_fdiv(value *lhs, value *rhs);
100 value *create_frem(value *lhs, value *rhs);
101 value *create_fadd(value *lhs, value *rhs);
102 value *create_fsub(value *lhs, value *rhs);
103 value *create_sdiv(value *lhs, value *rhs);
104 value *create_udiv(value *lhs, value *rhs);
105 value *create_srem(value *lhs, value *rhs);
106 value *create_urem(value *lhs, value *rhs);
107 value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
108 value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
109 value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
110 value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
111 value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
112 value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
113 // GEP
114 value *create_gep(value *ptr, const std::vector<value*>& idx_list);
115 // Comparison (int)
116 value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
117 value *create_icmpSLE(value *lhs, value *rhs);
118 value *create_icmpSLT(value *lhs, value *rhs);
119 value *create_icmpSGE(value *lhs, value *rhs);
120 value *create_icmpSGT(value *lhs, value *rhs);
121 value *create_icmpULE(value *lhs, value *rhs);
122 value *create_icmpULT(value *lhs, value *rhs);
123 value *create_icmpUGE(value *lhs, value *rhs);
124 value *create_icmpUGT(value *lhs, value *rhs);
125 value *create_icmpEQ(value *lhs, value *rhs);
126 value *create_icmpNE(value *lhs, value *rhs);
127 // Comparison (float)
128 value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
129 value *create_fcmpOLT(value *lhs, value *rhs);
130 value *create_fcmpOGT(value *lhs, value *rhs);
131 value *create_fcmpOLE(value *lhs, value *rhs);
132 value *create_fcmpOGE(value *lhs, value *rhs);
133 value *create_fcmpOEQ(value *lhs, value *rhs);
134 value *create_fcmpONE(value *lhs, value *rhs);
135 value *create_fcmpULT(value *lhs, value *rhs);
136 value *create_fcmpUGT(value *lhs, value *rhs);
137 value *create_fcmpULE(value *lhs, value *rhs);
138 value *create_fcmpUGE(value *lhs, value *rhs);
139 value *create_fcmpUEQ(value *lhs, value *rhs);
140 value *create_fcmpUNE(value *lhs, value *rhs);
141 // Logical
142 value *create_and(value *lhs, value *rhs);
143 value *create_xor(value *lhs, value *rhs);
144 value *create_or(value *lhs, value *rhs);
145 // Input/Output
146 value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
147 value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
148 value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
149 value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
150 // Struct instructions
151 value *create_insert_value(value* val, value *elt, size_t idx);
152 value *create_extract_value(value* val, size_t idx);
153 // Block instruction
154 value *create_splat(value *arg, const type::block_shapes_t &shapes);
155 value *create_reshape(value *arg, const type::block_shapes_t &shapes);
156 value *create_cat(value *lhs, value *rhs);
157 value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
158 // Atomic instruction
159 value *create_atomic_cas(value *ptr, value *cmp, value *val);
160 value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
161 value *create_atomic_max(value *ptr, value *val, value *msk);
162 value *create_atomic_umax(value *ptr, value *val, value *msk);
163 value *create_atomic_min(value *ptr, value *val, value *msk);
164 value *create_atomic_umin(value *ptr, value *val, value *msk);
165 value *create_atomic_fadd(value *ptr, value *val, value *msk);
166 value *create_atomic_add(value *ptr, value *val, value *msk);
167 value *create_atomic_and(value *ptr, value *val, value *msk);
168 value *create_atomic_or(value *ptr, value *val, value *msk);
169 value *create_atomic_xor(value *ptr, value *val, value *msk);
170 value *create_atomic_xchg(value *ptr, value *val, value *msk);
171 // Utilities
172 value *create_clock();
173 value *create_globaltimer();
174 // Extern instruction
175 value *create_extern_elementwise(const std::string &lib_name,
176 const std::string &lib_path,
177 const std::string &symbol_name,
178 const std::vector<value *> &args,
179 type *ret_ty);
180 // Built-in instruction
181 value *create_get_program_id(unsigned axis);
182 value *create_get_num_programs(unsigned axis);
183 value *create_exp(value* arg);
184 value *create_cos(value* arg);
185 value *create_sin(value* arg);
186 value *create_log(value* arg);
187 value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
188 value *create_trans(value *A, const std::vector<int> &perm = {});
189 value *create_sqrt(value *A);
190 value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
191 value *create_select(value *pred, value *if_value, value *else_value);
192 // Intrinsics
193 // These have no place in the IR, and hopefully they can be removed at some point
194 value *create_umulhi(value* lhs, value* rhs);
195 value *create_copy_to_shared(value *arg);
196 value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
197 value *create_copy_from_shared(value *arg);
198 value *create_barrier(const std::string &name = "");
199 value *create_async_wait(int N);
200 value *create_prefetch_s(value *arg, int inc);
201
202private:
203 context &ctx_;
204 basic_block *block_;
205 iterator insert_point_;
206};
207
208
209}
210}
211
212#endif
213