1#pragma once
2
3#ifndef _TRITON_SELECTION_GENERATOR_H_
4#define _TRITON_SELECTION_GENERATOR_H_
5
6#include "triton/ir/visitor.h"
7#include "triton/ir/instructions.h"
8#include "triton/codegen/analysis/layout.h"
9#include "triton/codegen/extern_lib.h"
10#include <functional>
11
12// forward
namespace llvm {

// Context and IR containers.
class LLVMContext;
class Module;
class Function;
class BasicBlock;

// Values, instructions and attributes.
class Value;
class Constant;
class Instruction;
class PHINode;
class Attribute;

// Types.
class Type;
class ArrayType;
class StructType;

// IR builder template and the two classes used as its template arguments
// by the Builder alias in triton::codegen.
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename FolderTy, typename InserterTy>
class IRBuilder;

}  // namespace llvm
31
32namespace triton{
33
34namespace ir{
35class attribute;
36class load_inst;
37class store_inst;
38}
39
40namespace codegen{
41
42// forward
43namespace analysis{
44class liveness;
45class tiles;
46class align;
47class allocation;
48class cts;
49class axes;
50class layouts;
51class swizzle;
52}
53// typedef
54typedef llvm::IRBuilder<llvm::ConstantFolder,
55 llvm::IRBuilderDefaultInserter> Builder;
56typedef llvm::LLVMContext LLVMContext;
57typedef llvm::Type Type;
58typedef llvm::Value Value;
59typedef llvm::Attribute Attribute;
60typedef llvm::BasicBlock BasicBlock;
61typedef llvm::Module Module;
62typedef llvm::Instruction Instruction;
63typedef llvm::Constant Constant;
64typedef llvm::ArrayType ArrayType;
65typedef llvm::Function Function;
66typedef std::vector<Value*> indices_t;
67class target;
68
69}
70}
71
72namespace triton{
73namespace codegen{
74
75struct distributed_axis {
76 int contiguous;
77 std::vector<Value*> values;
78 Value* thread_id;
79};
80
81class adder{
82public:
83 adder(Builder** builder): builder_(builder) { }
84 Value* operator()(Value* x, Value* y, const std::string& name = "");
85
86private:
87 Builder** builder_;
88};
89
90class multiplier{
91public:
92 multiplier(Builder** builder): builder_(builder) { }
93 Value* operator()(Value* x, Value* y, const std::string& name = "");
94private:
95 Builder** builder_;
96};
97
98class geper{
99public:
100 geper(Builder** builder): builder_(builder) { }
101 Value* operator()(Value *ptr, Value* off, const std::string& name = "");
102 Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
103
104private:
105 Builder** builder_;
106};
107
108class generator: public ir::visitor, public analysis::layout_visitor {
109private:
110 void init_idx(ir::value *x);
111 Instruction* add_barrier();
112 Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
113 void finalize_shared_layout(analysis::shared_layout*);
114 void finalize_function(ir::function*);
115 void finalize_phi_node(ir::phi_node*);
116
117private:
118 Type *cvt(ir::type *ty);
119 llvm::Attribute cvt(ir::attribute attr);
120 void packed_type(ir::value* i);
121 void forward_declare(ir::function* fn);
122 Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
123
124 private:
125 typedef std::function<void(
126 std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
127 std::function<Value *()> load_index_fn, bool is_first)>
128 acc_fn_t;
129
130 public:
131 generator(analysis::axes *a_axes,
132 analysis::layouts *layouts,
133 analysis::align *alignment,
134 analysis::allocation *alloc,
135 analysis::swizzle *swizzle,
136 target *tgt,
137 unsigned num_warps);
138
139 void visit_value(ir::value* v);
140 void visit_call_inst(ir::call_inst*);
141 void visit_launch_inst(ir::launch_inst *);
142 void visit_phi_node(ir::phi_node*);
143 void visit_binary_operator(ir::binary_operator*);
144 void visit_getelementptr_inst(ir::getelementptr_inst*);
145 void visit_icmp_inst(ir::icmp_inst*);
146 void visit_fcmp_inst(ir::fcmp_inst*);
147 std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
148 std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
149 std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
150 std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
151 std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
152 std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
153 Value* bf16_to_fp32(Value *in0);
154 Value* fp32_to_bf16(Value *in0);
155 std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
156 Value *in0, Value *scale_x512, Value *shift
157 );
158 std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
159 Value *in0, Value *scale_x512, Value *shift
160 );
161 std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
162 std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
163 void visit_dequantize_inst(ir::dequantize_inst*);
164 void visit_cast_inst(ir::cast_inst*);
165 void visit_return_inst(ir::return_inst*);
166 void visit_cond_branch_inst(ir::cond_branch_inst*);
167 void visit_uncond_branch_inst(ir::uncond_branch_inst*);
168 void visit_load_inst(ir::load_inst*);
169 void visit_unmasked_load_inst(ir::unmasked_load_inst*);
170 void visit_masked_load_inst(ir::masked_load_inst*);
171 void visit_store_inst(ir::store_inst*);
172 void visit_unmasked_store_inst(ir::unmasked_store_inst*);
173 void visit_masked_store_inst(ir::masked_store_inst*);
174 void visit_cat_inst(ir::cat_inst*);
175 void visit_extract_value_inst(ir::extract_value_inst *);
176 void visit_insert_value_inst(ir::insert_value_inst *);
177 void visit_reshape_inst(ir::reshape_inst*);
178 void visit_splat_inst(ir::splat_inst*);
179 void visit_broadcast_inst(ir::broadcast_inst*);
180 void visit_downcast_inst(ir::downcast_inst*);
181 void visit_exp_inst(ir::exp_inst*);
182 void visit_cos_inst(ir::cos_inst*);
183 void visit_umulhi_inst(ir::umulhi_inst* x);
184 void visit_sin_inst(ir::sin_inst*);
185 void visit_log_inst(ir::log_inst*);
186 void visit_get_program_id_inst(ir::get_program_id_inst*);
187 void visit_get_num_programs_inst(ir::get_num_programs_inst*);
188 void visit_atomic_cas_inst(ir::atomic_cas_inst*);
189 void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
190 void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
191 void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
192 void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
193 void visit_dot_inst(ir::dot_inst*);
194 void visit_trans_inst(ir::trans_inst*);
195 void visit_sqrt_inst(ir::sqrt_inst*);
196 Value* shfl_sync(Value* acc, int32_t i);
197 void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
198 void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
199 void visit_reduce_inst(ir::reduce_inst*);
200 void visit_select_inst(ir::select_inst*);
201 void visit_layout_convert(ir::value *out, ir::value *in);
202 void visit_cvt_layout_inst(ir::cvt_layout_inst*);
203 void visit_masked_load_async_inst(ir::masked_load_async_inst*);
204 void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
205 void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
206 void visit_barrier_inst(ir::barrier_inst*);
207 void visit_prefetch_s_inst(ir::prefetch_s_inst*);
208 void visit_async_wait_inst(ir::async_wait_inst*);
209// void visit_make_range_dyn(ir::make_range_dyn*);
210 void visit_make_range(ir::make_range*);
211 void visit_clock_inst(ir::clock_inst*);
212 void visit_globaltimer_inst(ir::globaltimer_inst*);
213 void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
214// void visit_make_range_sta(ir::make_range_sta*);
215 void visit_undef_value(ir::undef_value*);
216 void visit_constant_int(ir::constant_int*);
217 void visit_constant_fp(ir::constant_fp*);
218 void visit_alloc_const(ir::alloc_const*);
219 void visit_function(ir::function*);
220 void visit_basic_block(ir::basic_block*);
221 void visit_argument(ir::argument*);
222 void visit(ir::module &, llvm::Module &);
223
224 // layouts
225 void visit_layout_mma(analysis::mma_layout*);
226 void visit_layout_scanline(analysis::scanline_layout*);
227 void visit_layout_shared(analysis::shared_layout*);
228
229 // Add a new external library based on given name and path if it doesn't exist
230 void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
231
232 // Get all external libraries
233 const ExternLibMap &get_extern_lib_map() {
234 return extern_lib_map_;
235 }
236
237 private:
238 LLVMContext *ctx_;
239 Builder* builder_;
240 Module *mod_;
241
242 std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
243
244 analysis::axes *a_axes_;
245 analysis::swizzle *swizzle_;
246 std::map<unsigned, distributed_axis> axes_;
247 target *tgt_;
248 analysis::layouts *layouts_;
249 analysis::align *alignment_;
250 analysis::allocation *alloc_;
251 Value *shmem_;
252 std::set<ir::value*> seen_;
253
254 unsigned num_warps_;
255
256 std::map<analysis::data_layout*, Value*> offset_a_m_;
257 std::map<analysis::data_layout*, Value*> offset_a_k_;
258 std::map<analysis::data_layout*, Value*> offset_b_k_;
259 std::map<analysis::data_layout*, Value*> offset_b_n_;
260
261 /// layout -> base ptr
262 std::map<analysis::data_layout*, Value*> shared_ptr_;
263 std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
264 std::map<analysis::data_layout*, Value*> shared_next_ptr_;
265 /// offset for double-buffered layout
266 std::map<analysis::data_layout*, Value*> shared_off_;
267
268 /// Base shmem pointer of ir value
269 std::map<ir::value*, Value*> shmems_;
270 std::map<ir::value*, Value*> shoffs_;
271 std::map<ir::value*, std::vector<indices_t>> idxs_;
272 std::map<ir::value*, std::map<indices_t, Value*>> vals_;
273 /// idx for multi-stage pipeline
274 std::map<analysis::data_layout*, Value*> read_smem_idx_;
275 std::map<analysis::data_layout*, Value*> write_smem_idx_;
276
277 /// triton bb -> llvm bb
278 std::map<ir::value*, BasicBlock *> bbs_;
279 std::map<ir::value*, std::vector<int>> ords_;
280 std::map<ir::value*, Function*> fns_;
281
282 // helper for creating llvm values
283 adder add;
284 multiplier mul;
285 geper gep;
286
287 /// PHI nodes
288 std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
289
290 /// Record prefetch instrs that needs to be moved
291 std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
292
293 // Eviction policies
294 std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
295};
296
297}
298}
299
300#endif
301