1 | #pragma once |
2 | |
3 | #ifndef _TRITON_SELECTION_GENERATOR_H_ |
4 | #define _TRITON_SELECTION_GENERATOR_H_ |
5 | |
6 | #include "triton/ir/visitor.h" |
7 | #include "triton/ir/instructions.h" |
8 | #include "triton/codegen/analysis/layout.h" |
9 | #include "triton/codegen/extern_lib.h" |
10 | #include <functional> |
11 | |
// forward declarations of the LLVM types named in this header
namespace llvm {

// IR builder template and the two policy types it is instantiated with below
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename FolderTy, typename InserterTy>
class IRBuilder;

// context, module and the code containers
class LLVMContext;
class Module;
class Function;
class BasicBlock;

// values
class Value;
class Constant;
class Instruction;
class PHINode;
class Attribute;

// types
class Type;
class ArrayType;
class StructType;

} // namespace llvm
31 | |
32 | namespace triton{ |
33 | |
34 | namespace ir{ |
35 | class attribute; |
36 | class load_inst; |
37 | class store_inst; |
38 | } |
39 | |
40 | namespace codegen{ |
41 | |
42 | // forward |
43 | namespace analysis{ |
44 | class liveness; |
45 | class tiles; |
46 | class align; |
47 | class allocation; |
48 | class cts; |
49 | class axes; |
50 | class layouts; |
51 | class swizzle; |
52 | } |
53 | // typedef |
54 | typedef llvm::IRBuilder<llvm::ConstantFolder, |
55 | llvm::IRBuilderDefaultInserter> Builder; |
56 | typedef llvm::LLVMContext LLVMContext; |
57 | typedef llvm::Type Type; |
58 | typedef llvm::Value Value; |
59 | typedef llvm::Attribute Attribute; |
60 | typedef llvm::BasicBlock BasicBlock; |
61 | typedef llvm::Module Module; |
62 | typedef llvm::Instruction Instruction; |
63 | typedef llvm::Constant Constant; |
64 | typedef llvm::ArrayType ArrayType; |
65 | typedef llvm::Function Function; |
66 | typedef std::vector<Value*> indices_t; |
67 | class target; |
68 | |
69 | } |
70 | } |
71 | |
72 | namespace triton{ |
73 | namespace codegen{ |
74 | |
75 | struct distributed_axis { |
76 | int contiguous; |
77 | std::vector<Value*> values; |
78 | Value* thread_id; |
79 | }; |
80 | |
81 | class adder{ |
82 | public: |
83 | adder(Builder** builder): builder_(builder) { } |
84 | Value* operator()(Value* x, Value* y, const std::string& name = "" ); |
85 | |
86 | private: |
87 | Builder** builder_; |
88 | }; |
89 | |
90 | class multiplier{ |
91 | public: |
92 | multiplier(Builder** builder): builder_(builder) { } |
93 | Value* operator()(Value* x, Value* y, const std::string& name = "" ); |
94 | private: |
95 | Builder** builder_; |
96 | }; |
97 | |
98 | class geper{ |
99 | public: |
100 | geper(Builder** builder): builder_(builder) { } |
101 | Value* operator()(Value *ptr, Value* off, const std::string& name = "" ); |
102 | Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "" ); |
103 | |
104 | private: |
105 | Builder** builder_; |
106 | }; |
107 | |
108 | class generator: public ir::visitor, public analysis::layout_visitor { |
109 | private: |
110 | void init_idx(ir::value *x); |
111 | Instruction* add_barrier(); |
112 | Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx); |
113 | void finalize_shared_layout(analysis::shared_layout*); |
114 | void finalize_function(ir::function*); |
115 | void finalize_phi_node(ir::phi_node*); |
116 | |
117 | private: |
118 | Type *cvt(ir::type *ty); |
119 | llvm::Attribute cvt(ir::attribute attr); |
120 | void packed_type(ir::value* i); |
121 | void forward_declare(ir::function* fn); |
122 | Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty); |
123 | |
124 | private: |
125 | typedef std::function<void( |
126 | std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn, |
127 | std::function<Value *()> load_index_fn, bool is_first)> |
128 | acc_fn_t; |
129 | |
130 | public: |
131 | generator(analysis::axes *a_axes, |
132 | analysis::layouts *layouts, |
133 | analysis::align *alignment, |
134 | analysis::allocation *alloc, |
135 | analysis::swizzle *swizzle, |
136 | target *tgt, |
137 | unsigned num_warps); |
138 | |
139 | void visit_value(ir::value* v); |
140 | void visit_call_inst(ir::call_inst*); |
141 | void visit_launch_inst(ir::launch_inst *); |
142 | void visit_phi_node(ir::phi_node*); |
143 | void visit_binary_operator(ir::binary_operator*); |
144 | void visit_getelementptr_inst(ir::getelementptr_inst*); |
145 | void visit_icmp_inst(ir::icmp_inst*); |
146 | void visit_fcmp_inst(ir::fcmp_inst*); |
147 | std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3); |
148 | std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); |
149 | std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3); |
150 | std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); |
151 | std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3); |
152 | std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); |
153 | Value* bf16_to_fp32(Value *in0); |
154 | Value* fp32_to_bf16(Value *in0); |
155 | std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8( |
156 | Value *in0, Value *scale_x512, Value *shift |
157 | ); |
158 | std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8( |
159 | Value *in0, Value *scale_x512, Value *shift |
160 | ); |
161 | std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift); |
162 | std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift); |
163 | void visit_dequantize_inst(ir::dequantize_inst*); |
164 | void visit_cast_inst(ir::cast_inst*); |
165 | void visit_return_inst(ir::return_inst*); |
166 | void visit_cond_branch_inst(ir::cond_branch_inst*); |
167 | void visit_uncond_branch_inst(ir::uncond_branch_inst*); |
168 | void visit_load_inst(ir::load_inst*); |
169 | void visit_unmasked_load_inst(ir::unmasked_load_inst*); |
170 | void visit_masked_load_inst(ir::masked_load_inst*); |
171 | void visit_store_inst(ir::store_inst*); |
172 | void visit_unmasked_store_inst(ir::unmasked_store_inst*); |
173 | void visit_masked_store_inst(ir::masked_store_inst*); |
174 | void visit_cat_inst(ir::cat_inst*); |
175 | void (ir::extract_value_inst *); |
176 | void visit_insert_value_inst(ir::insert_value_inst *); |
177 | void visit_reshape_inst(ir::reshape_inst*); |
178 | void visit_splat_inst(ir::splat_inst*); |
179 | void visit_broadcast_inst(ir::broadcast_inst*); |
180 | void visit_downcast_inst(ir::downcast_inst*); |
181 | void visit_exp_inst(ir::exp_inst*); |
182 | void visit_cos_inst(ir::cos_inst*); |
183 | void visit_umulhi_inst(ir::umulhi_inst* x); |
184 | void visit_sin_inst(ir::sin_inst*); |
185 | void visit_log_inst(ir::log_inst*); |
186 | void visit_get_program_id_inst(ir::get_program_id_inst*); |
187 | void visit_get_num_programs_inst(ir::get_num_programs_inst*); |
188 | void visit_atomic_cas_inst(ir::atomic_cas_inst*); |
189 | void visit_atomic_rmw_inst(ir::atomic_rmw_inst*); |
190 | void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); |
191 | void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); |
192 | void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add); |
193 | void visit_dot_inst(ir::dot_inst*); |
194 | void visit_trans_inst(ir::trans_inst*); |
195 | void visit_sqrt_inst(ir::sqrt_inst*); |
196 | Value* shfl_sync(Value* acc, int32_t i); |
197 | void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral); |
198 | void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral); |
199 | void visit_reduce_inst(ir::reduce_inst*); |
200 | void visit_select_inst(ir::select_inst*); |
201 | void visit_layout_convert(ir::value *out, ir::value *in); |
202 | void visit_cvt_layout_inst(ir::cvt_layout_inst*); |
203 | void visit_masked_load_async_inst(ir::masked_load_async_inst*); |
204 | void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); |
205 | void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); |
206 | void visit_barrier_inst(ir::barrier_inst*); |
207 | void visit_prefetch_s_inst(ir::prefetch_s_inst*); |
208 | void visit_async_wait_inst(ir::async_wait_inst*); |
209 | // void visit_make_range_dyn(ir::make_range_dyn*); |
210 | void visit_make_range(ir::make_range*); |
211 | void visit_clock_inst(ir::clock_inst*); |
212 | void visit_globaltimer_inst(ir::globaltimer_inst*); |
213 | void visit_extern_elementwise_inst(ir::extern_elementwise_inst*); |
214 | // void visit_make_range_sta(ir::make_range_sta*); |
215 | void visit_undef_value(ir::undef_value*); |
216 | void visit_constant_int(ir::constant_int*); |
217 | void visit_constant_fp(ir::constant_fp*); |
218 | void visit_alloc_const(ir::alloc_const*); |
219 | void visit_function(ir::function*); |
220 | void visit_basic_block(ir::basic_block*); |
221 | void visit_argument(ir::argument*); |
222 | void visit(ir::module &, llvm::Module &); |
223 | |
224 | // layouts |
225 | void visit_layout_mma(analysis::mma_layout*); |
226 | void visit_layout_scanline(analysis::scanline_layout*); |
227 | void visit_layout_shared(analysis::shared_layout*); |
228 | |
229 | // Add a new external library based on given name and path if it doesn't exist |
230 | void add_extern_lib(const std::string &lib_name, const std::string &lib_path); |
231 | |
232 | // Get all external libraries |
233 | const ExternLibMap &get_extern_lib_map() { |
234 | return extern_lib_map_; |
235 | } |
236 | |
237 | private: |
238 | LLVMContext *ctx_; |
239 | Builder* builder_; |
240 | Module *mod_; |
241 | |
242 | std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_; |
243 | |
244 | analysis::axes *a_axes_; |
245 | analysis::swizzle *swizzle_; |
246 | std::map<unsigned, distributed_axis> axes_; |
247 | target *tgt_; |
248 | analysis::layouts *layouts_; |
249 | analysis::align *alignment_; |
250 | analysis::allocation *alloc_; |
251 | Value *shmem_; |
252 | std::set<ir::value*> seen_; |
253 | |
254 | unsigned num_warps_; |
255 | |
256 | std::map<analysis::data_layout*, Value*> offset_a_m_; |
257 | std::map<analysis::data_layout*, Value*> offset_a_k_; |
258 | std::map<analysis::data_layout*, Value*> offset_b_k_; |
259 | std::map<analysis::data_layout*, Value*> offset_b_n_; |
260 | |
261 | /// layout -> base ptr |
262 | std::map<analysis::data_layout*, Value*> shared_ptr_; |
263 | std::map<analysis::data_layout*, Value*> shared_pre_ptr_; |
264 | std::map<analysis::data_layout*, Value*> shared_next_ptr_; |
265 | /// offset for double-buffered layout |
266 | std::map<analysis::data_layout*, Value*> shared_off_; |
267 | |
268 | /// Base shmem pointer of ir value |
269 | std::map<ir::value*, Value*> shmems_; |
270 | std::map<ir::value*, Value*> shoffs_; |
271 | std::map<ir::value*, std::vector<indices_t>> idxs_; |
272 | std::map<ir::value*, std::map<indices_t, Value*>> vals_; |
273 | /// idx for multi-stage pipeline |
274 | std::map<analysis::data_layout*, Value*> read_smem_idx_; |
275 | std::map<analysis::data_layout*, Value*> write_smem_idx_; |
276 | |
277 | /// triton bb -> llvm bb |
278 | std::map<ir::value*, BasicBlock *> bbs_; |
279 | std::map<ir::value*, std::vector<int>> ords_; |
280 | std::map<ir::value*, Function*> fns_; |
281 | |
282 | // helper for creating llvm values |
283 | adder add; |
284 | multiplier mul; |
285 | geper gep; |
286 | |
287 | /// PHI nodes |
288 | std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_; |
289 | |
290 | /// Record prefetch instrs that needs to be moved |
291 | std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_; |
292 | |
293 | // Eviction policies |
294 | std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_; |
295 | }; |
296 | |
297 | } |
298 | } |
299 | |
300 | #endif |
301 | |