#ifndef TRITON_CODEGEN_ANALYSIS_GRID_H_
#define TRITON_CODEGEN_ANALYSIS_GRID_H_

#include <map>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include <cstdint>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"

namespace triton{

namespace ir{
  class value;
  class type;
  class module;
  class instruction;
  class phi_node;
}

namespace codegen{
namespace analysis{

class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;


class layout_visitor {
public:
  virtual void visit_layout(data_layout *);
  virtual void visit_layout_mma(mma_layout*) = 0;
  virtual void visit_layout_scanline(scanline_layout*) = 0;
  virtual void visit_layout_shared(shared_layout*) = 0;
};
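// Usage sketch (illustrative, not part of this header): a concrete visitor
// overrides the three pure virtual hooks and is dispatched through accept():
//
//   struct layout_printer: layout_visitor {
//     void visit_layout_mma(mma_layout* l)           { /* ... */ }
//     void visit_layout_scanline(scanline_layout* l) { /* ... */ }
//     void visit_layout_shared(shared_layout* l)     { /* ... */ }
//   };
//   // for some data_layout* l and layout_printer p:  l->accept(&p);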

class data_layout {
protected:
  enum id_t {
    MMA,
    SCANLINE,
    SHARED
  };

  typedef std::vector<int> axes_t;
  typedef std::vector<unsigned> shape_t;
  typedef std::vector<int> order_t;
  typedef std::vector<ir::value*> values_t;

private:
  template<typename T>
  T* downcast(id_t id) {
    if(id_ == id)
      return static_cast<T*>(this);
    return nullptr;
  }

public:
  data_layout(id_t id,
              const std::vector<int>& axes,
              const std::vector<unsigned> &shape,
              const std::vector<ir::value *> &values,
              analysis::align* align);
  // visitor
  virtual void accept(layout_visitor* vst) = 0;
  // downcast
  mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
  scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
  shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
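  // Example (illustrative): the id_ tag makes these downcasts checked without
  // RTTI -- each returns nullptr unless the tag matches:
  //
  //   if(mma_layout* mma = layout->to_mma())
  //     /* layout is an mma_layout */;
  //   else if(scanline_layout* sl = layout->to_scanline())
  //     /* layout is a scanline_layout */;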
  // accessors
  size_t get_rank() const { return shape_.size(); }
  const shape_t& get_shape() const { return shape_; }
  const order_t& get_order() const { return order_; }
  const values_t& get_values() const { return values_; }
  int get_axis(size_t k) const { return axes_.at(k); }
  std::vector<int> get_axes() const { return axes_; }
  int get_order(size_t k) const { return order_.at(k); }
  // find the position of given axis
  int find_axis(int to_find) const;


private:
  id_t id_;
  axes_t axes_;
  values_t values_;

protected:
  order_t order_;
  shape_t shape_;
};

class distributed_layout: public data_layout{
public:
  distributed_layout(id_t id,
                     const std::vector<int>& axes,
                     const std::vector<unsigned>& shape,
                     const std::vector<ir::value*>& values,
                     analysis::align* align);

  int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
  int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
  virtual int contig_per_thread(size_t k) = 0;

protected:
  std::vector<int> shape_per_cta_;
};
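// Worked example (illustrative values): along axis k with shape 128 and
// shape_per_cta 32, rep_per_cta(k) == 128/32 == 4, i.e. the per-CTA tile is
// repeated 4 times to cover the full tensor along that axis.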

class mma_layout: public distributed_layout {
public:
  enum TensorCoreType : uint8_t {
    // floating-point tensor core instr
    FP32_FP16_FP16_FP32 = 0, // default
    FP32_BF16_BF16_FP32,
    FP32_TF32_TF32_FP32,
    // integer tensor core instr
    INT32_INT1_INT1_INT32, // Not implemented
    INT32_INT4_INT4_INT32, // Not implemented
    INT32_INT8_INT8_INT32, // Not implemented
    //
    NOT_APPLICABLE,
  };

  // mma instruction shapes {m, n, k}; used on NVIDIA GPUs with sm >= 80
  inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
    {FP32_FP16_FP16_FP32, {16, 8, 16}},
    {FP32_BF16_BF16_FP32, {16, 8, 16}},
    {FP32_TF32_TF32_FP32, {16, 8, 8}},

    {INT32_INT1_INT1_INT32, {16, 8, 256}},
    {INT32_INT4_INT4_INT32, {16, 8, 64}},
    {INT32_INT8_INT8_INT32, {16, 8, 32}},
  };

  // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
  inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
    {FP32_FP16_FP16_FP32, {8, 8, 8}},
    {FP32_BF16_BF16_FP32, {8, 8, 8}},
    {FP32_TF32_TF32_FP32, {8, 8, 4}},

    {INT32_INT1_INT1_INT32, {8, 8, 64}},
    {INT32_INT4_INT4_INT32, {8, 8, 32}},
    {INT32_INT8_INT8_INT32, {8, 8, 16}},
  };

  inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
    {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
    {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
    {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},

    {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
    {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
    {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
  };

  // vector length per ldmatrix (16*8/element_size_in_bits)
  inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
    {FP32_FP16_FP16_FP32, 8},
    {FP32_BF16_BF16_FP32, 8},
    {FP32_TF32_TF32_FP32, 4},

    {INT32_INT1_INT1_INT32, 128},
    {INT32_INT4_INT4_INT32, 32},
    {INT32_INT8_INT8_INT32, 16},
  };
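  // Worked example: FP16 elements are 16 bits wide, so one ldmatrix fragment
  // holds 16*8/16 = 8 elements per thread (the entry above); TF32 is 32 bits
  // wide, giving 16*8/32 = 4.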

public:
  mma_layout(size_t num_warps,
             const std::vector<int>& axes,
             const std::vector<unsigned>& shapes,
             const std::vector<ir::value *> &values,
             analysis::align* align, target *tgt,
             shared_layout* layout_a,
             shared_layout* layout_b,
             ir::value *dot);
  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
  // accessor
  int fpw(size_t k) { return fpw_.at(k); }
  int wpt(size_t k) { return wpt_.at(k); }
  int spw(size_t k) { return spw_.at(k); }
  int rep(size_t k) { return rep_.at(k); }
  int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }

  // helpers for generator.cc
  std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
  std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
  std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
  int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
  int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
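  // e.g. with tensor_core_type_ == FP32_TF32_TF32_FP32, get_ptx_instr()
  // returns "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32" and
  // get_mma_instr_shape() returns {16, 8, 8}.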

  // setter
  void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }

private:
  // fragment per warp
  std::vector<int> fpw_;
  // shape per warp
  std::vector<int> spw_;
  // warp per tile
  std::vector<int> wpt_;
  // shape per tile
  std::vector<int> spt_;
  // repetitions
  std::vector<int> rep_;
  // contiguous per thread
  std::vector<int> contig_per_thread_;

  TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};

class scanline_layout: public distributed_layout {
public:
  scanline_layout(size_t num_warps,
                  const std::vector<int>& axes,
                  const std::vector<unsigned>& shape,
                  const std::vector<ir::value *> &values,
                  analysis::align* align,
                  target* tgt);
  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
  // accessor
  int mts(size_t k) { return mts_.at(k); }
  int nts(size_t k) { return nts_.at(k); }
  int contig_per_thread(size_t k) { return nts_.at(k); }

  int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k); }
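  // Worked example (illustrative values, assuming shape_per_cta(k) ==
  // mts(k) * nts(k)): with shape 128 along axis k, mts == 32 threads and
  // nts == 2 contiguous elements per thread, one pass covers 32*2 == 64
  // elements, so per_thread(k) == 2 * 128/64 == 4 elements per thread.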
private:
  // micro tile size. The size of a tile held by a thread block.
  std::vector<int> mts_;
  // nano tile size. The size of a tile held by a thread.
  std::vector<int> nts_;
};

struct double_buffer_info_t {
  ir::value* first;
  ir::value* latch;
  ir::phi_node* phi;
};
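// Example (illustrative IR): for a software-pipelined loop such as
//   %x = phi [%first, ^preheader], [%latch, ^loop]
// `first` is the value loaded before the loop and `latch` the value produced
// inside it; the two halves of the shared buffer swap roles every iteration.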

struct N_buffer_info_t {
  std::vector<ir::value*> firsts; // not necessarily in input order
  ir::value* latch;
  ir::phi_node* phi;
  std::map<ir::value*, int> firsts_idx;
};
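// Generalizes double buffering to N pipeline stages: `firsts` holds the
// values prefetched before the loop (one per stage, positions given by
// firsts_idx) and `latch` the value produced inside the loop body.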

// layout for values materialized in shared memory (e.g. dot operands)
class shared_layout: public data_layout {
private:
  static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
  static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
  static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);

public:
  shared_layout(data_layout *arg,
                const std::vector<int>& axes,
                const std::vector<unsigned>& shapes,
                const std::vector<ir::value *> &values_,
                ir::type *ty,
                analysis::align* align, target *tgt,
                bool is_tmp = false);
  void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
  // accessors
  size_t get_size() { return size_; }
  ir::type* get_type() { return ty_; }
  double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
  N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
  int get_num_stages() const;
  size_t get_per_stage_size() const { return size_ / get_num_stages(); }
  size_t get_per_stage_elements() const;
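  // e.g. (illustrative, assuming size_ is in bytes): with 3 pipeline stages
  // and size_ == 49152, get_per_stage_size() returns 49152/3 == 16384.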
  size_t get_num_per_phase() { return num_per_phase_; }
  ir::value* hmma_dot_a() { return hmma_dot_a_; }
  ir::value* hmma_dot_b() { return hmma_dot_b_; }
  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
  int get_mma_vec() { return mma_vec_; }
  int get_mma_strided() { return mma_strided_; }
  bool allow_swizzle() const { return allow_swizzle_; }
  data_layout* get_arg_layout() { return arg_layout_; }
  bool is_tmp() const { return is_tmp_; }

private:
  size_t size_;
  ir::type *ty_;
  std::shared_ptr<double_buffer_info_t> double_buffer_;
  std::shared_ptr<N_buffer_info_t> N_buffer_;
  size_t num_per_phase_;
  ir::value* hmma_dot_a_;
  ir::value* hmma_dot_b_;
  data_layout* arg_layout_;
  int mma_vec_;
  int mma_strided_;
  bool allow_swizzle_ = true;
  target *tgt_;
  bool is_tmp_;
};


class layouts {
  typedef ir::value* node_t;
  typedef std::map<node_t, std::set<node_t>> graph_t;

private:
  // graph creation
  void connect(ir::value *x, ir::value *y);
  void make_graph(ir::instruction *i);

  void init_hmma_tile(data_layout& layouts);
  void init_scanline_tile(data_layout &layouts);

  void create(size_t id, const std::vector<ir::value*>& values);

  void create_tmp_layout(size_t id, data_layout* arg,
                         const std::vector<int>& axes,
                         const std::vector<unsigned>& shape,
                         ir::instruction* i,
                         bool is_index = false);

public:
  // constructor
  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);

  // accessors
  unsigned layout_of(ir::value *value) const { return groups_.at(value); }
  bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
  bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
  const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
  size_t num_layouts() const { return values_.size(); }
  data_layout* get(size_t id) { return layouts_.at(id); }
  data_layout* get(ir::value *v) { return get(layout_of(v)); }
  std::map<size_t, data_layout*> &get_all() { return layouts_; }
  bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
  int tmp(ir::value* i) { return tmp_.at(i); }
  bool has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
  int tmp_index(ir::value* i) { return tmp_index_.at(i); }
  void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }

  // layout checkers
  bool is_scanline(ir::instruction* i);
  bool is_coalesced_scanline(ir::instruction* i);
  bool is_mma(ir::instruction* i);
  bool is_a100_mma(ir::instruction* i);

  // execution
  void run(ir::module &mod);

private:
  analysis::axes* axes_;
  analysis::align* align_;
  size_t num_warps_;
  target* tgt_;
  tools::graph<ir::value*> graph_;
  std::map<ir::value*, size_t> groups_;
  std::map<size_t, std::vector<ir::value*>> values_;
  std::map<size_t, data_layout*> layouts_;
  std::map<ir::value*, size_t> tmp_;
  std::map<ir::value*, size_t> tmp_index_;
};
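// Usage sketch (illustrative; construction of the prerequisite axes/align
// passes is elided):
//
//   analysis::layouts layouts(axes, align, num_warps, tgt);
//   layouts.run(mod);                  // partition values into layout groups
//   data_layout* dl = layouts.get(v);  // layout of some ir::value* v
//   if(dl->to_shared())
//     /* v is materialized in shared memory */;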

}
}

}

#endif