#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_

#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"

namespace triton{

namespace ir{
  class value;
  class type;
  class module;
  class instruction;
  class phi_node;
}

namespace codegen{
namespace analysis{

class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;


class layout_visitor {
public:
  virtual void visit_layout(data_layout *);
  virtual void visit_layout_mma(mma_layout*) = 0;
  virtual void visit_layout_scanline(scanline_layout*) = 0;
  virtual void visit_layout_shared(shared_layout*) = 0;
};

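// Base class for all layouts. A layout groups ir::values that share the same
// axes/shape and records the order in which their dimensions are laid out;
// the concrete kinds are mma_layout, scanline_layout and shared_layout.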
class data_layout {
protected:
  enum id_t {
    MMA,
    SCANLINE,
    SHARED
  };

  typedef std::vector<int> axes_t;
  typedef std::vector<unsigned> shape_t;
  typedef std::vector<int> order_t;
  typedef std::vector<ir::value*> values_t;

private:
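  // Safe downcast: returns this as a T* when the stored id matches, and
  // nullptr otherwise (e.g. to_mma() yields nullptr for a shared layout).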
  template<typename T>
  T* downcast(id_t id) {
    if(id_ == id)
      return static_cast<T*>(this);
    return nullptr;
  }

public:
  data_layout(id_t id,
              const std::vector<int>& axes,
              const std::vector<unsigned> &shape,
              const std::vector<ir::value *> &values,
              analysis::align* align);
  // visitor
  virtual void accept(layout_visitor* vst) = 0;
  // downcast
  mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
  scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
  shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
  // accessors
  size_t get_rank() { return shape_.size(); }
  const shape_t& get_shape() const { return shape_; }
  const order_t& get_order() const { return order_; }
  const values_t& get_values() const { return values_;}
  int get_axis(size_t k) const { return axes_.at(k); }
  std::vector<int> get_axes() const { return axes_; }
  int get_order(size_t k) const { return order_.at(k); }
  // find the position of the given axis
  int find_axis(int to_find) const;

private:
  id_t id_;
  axes_t axes_;
  values_t values_;

protected:
  order_t order_;
  shape_t shape_;
};

class distributed_layout: public data_layout{
public:
  distributed_layout(id_t id,
                     const std::vector<int>& axes,
                     const std::vector<unsigned>& shape,
                     const std::vector<ir::value*>& values,
                     analysis::align* align);

  int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
  int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
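  // e.g. for shape_ = {128} and shape_per_cta_ = {32}, rep_per_cta(0) == 4:
  // the per-CTA tile is repeated four times to cover the full dimension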
  virtual int contig_per_thread(size_t k) = 0;

protected:
  std::vector<int> shape_per_cta_;
};

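// Distributed layout used when a dot product is lowered to tensor-core
// (mma.sync) instructions; the static tables below describe the supported
// instruction variants.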
class mma_layout: public distributed_layout {
public:
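  // Variant names follow the operand-type order of the PTX instruction,
  // i.e. <D>_<A>_<B>_<C> for D = A*B + C (cf. mma_instr_ptx_ below).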
  enum TensorCoreType : uint8_t {
    // floating-point tensor core instr
    FP32_FP16_FP16_FP32 = 0, // default
    FP32_BF16_BF16_FP32,
    FP32_TF32_TF32_FP32,
    // integer tensor core instr
    INT32_INT1_INT1_INT32, // Not implemented
    INT32_INT4_INT4_INT32, // Not implemented
    INT32_INT8_INT8_INT32, // Not implemented
    //
    NOT_APPLICABLE,
  };

  // mma instruction shape (m-n-k); used on NVIDIA GPUs with sm >= 80
  inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
    {FP32_FP16_FP16_FP32, {16, 8, 16}},
    {FP32_BF16_BF16_FP32, {16, 8, 16}},
    {FP32_TF32_TF32_FP32, {16, 8, 8}},

    {INT32_INT1_INT1_INT32, {16, 8, 256}},
    {INT32_INT4_INT4_INT32, {16, 8, 64}},
    {INT32_INT8_INT8_INT32, {16, 8, 32}},
  };

  // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
  inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
    {FP32_FP16_FP16_FP32, {8, 8, 8}},
    {FP32_BF16_BF16_FP32, {8, 8, 8}},
    {FP32_TF32_TF32_FP32, {8, 8, 4}},

    {INT32_INT1_INT1_INT32, {8, 8, 64}},
    {INT32_INT4_INT4_INT32, {8, 8, 32}},
    {INT32_INT8_INT8_INT32, {8, 8, 16}},
  };

  inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
    {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
    {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
    {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},

    {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
    {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
    {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
  };

  // vector length per ldmatrix (16*8 / element_size_in_bits)
  inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
    {FP32_FP16_FP16_FP32, 8},
    {FP32_BF16_BF16_FP32, 8},
    {FP32_TF32_TF32_FP32, 4},

    {INT32_INT1_INT1_INT32, 128},
    {INT32_INT4_INT4_INT32, 32},
    {INT32_INT8_INT8_INT32, 16},
  };

public:
  mma_layout(size_t num_warps,
             const std::vector<int>& axes,
             const std::vector<unsigned>& shapes,
             const std::vector<ir::value *> &values,
             analysis::align* align, target *tgt,
             shared_layout* layout_a,
             shared_layout* layout_b,
             ir::value *dot);
  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
  // accessors
  int fpw(size_t k) { return fpw_.at(k); }
  int wpt(size_t k) { return wpt_.at(k); }
  int spw(size_t k) { return spw_.at(k); }
  int rep(size_t k) { return rep_.at(k); }
  int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }

  // helpers for generator.cc
  std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
  std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
  std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
  int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
  int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }

  // setter
  void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }

private:
  // fragments per warp
  std::vector<int> fpw_;
  // shape per warp
  std::vector<int> spw_;
  // warps per tile
  std::vector<int> wpt_;
  // shape per tile
  std::vector<int> spt_;
  // repetitions
  std::vector<int> rep_;
  // contiguous elements per thread
  std::vector<int> contig_per_thread_;

  TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};

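// Distributed layout for values that are not computed with tensor cores:
// each thread owns nts_ contiguous elements, and mts_ threads span each
// dimension (see the member comments below).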
class scanline_layout: public distributed_layout {
public:
  scanline_layout(size_t num_warps,
                  const std::vector<int>& axes,
                  const std::vector<unsigned>& shape,
                  const std::vector<ir::value *> &values,
                  analysis::align* align,
                  target* tgt);
  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
  // accessors
  int mts(size_t k) { return mts_.at(k); }
  int nts(size_t k) { return nts_.at(k); }
  int contig_per_thread(size_t k) { return nts_.at(k); }

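  // total number of elements held by each thread along dimension k
  // (= contig_per_thread(k) * rep_per_cta(k))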
  int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k); }
private:
  // micro tile size. The size of a tile held by a thread block.
  std::vector<int> mts_;
  // nano tile size. The size of a tile held by a thread.
  std::vector<int> nts_;
};

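// Book-keeping for a shared-memory value that can be double-buffered:
// 'first' is the value defined before the loop, 'latch' the value produced
// on the loop back-edge, and 'phi' the phi node that merges them.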
struct double_buffer_info_t {
  ir::value* first;
  ir::value* latch;
  ir::phi_node* phi;
};

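// N-stage generalization of double_buffer_info_t, used when more than two
// pipeline stages are materialized in shared memory.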
struct N_buffer_info_t {
  std::vector<ir::value*> firsts; // not necessarily in input order
  ir::value* latch;
  ir::phi_node* phi;
  std::map<ir::value*, int> firsts_idx;
};

// layout for values materialized in shared memory (e.g. dot operands)
class shared_layout: public data_layout {
private:
  static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
  static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
  static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);

public:
  shared_layout(data_layout *arg,
                const std::vector<int>& axes,
                const std::vector<unsigned>& shapes,
                const std::vector<ir::value *> &values_,
                ir::type *ty,
                analysis::align* align, target *tgt,
                bool is_tmp = false);
  void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
  // accessors
  size_t get_size() { return size_; }
  ir::type* get_type() { return ty_; }
  double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
  N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
  int get_num_stages() const;
  size_t get_per_stage_size() const { return size_ / get_num_stages(); }
  size_t get_per_stage_elements() const;
  size_t get_num_per_phase() { return num_per_phase_; }
  ir::value* hmma_dot_a() { return hmma_dot_a_; }
  ir::value* hmma_dot_b() { return hmma_dot_b_; }
  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
  int get_mma_vec() { return mma_vec_; }
  int get_mma_strided() { return mma_strided_; }
  bool allow_swizzle() const { return allow_swizzle_; }
  data_layout* get_arg_layout() { return arg_layout_; }
  bool is_tmp() const { return is_tmp_; }

private:
  size_t size_;
  ir::type *ty_;
  std::shared_ptr<double_buffer_info_t> double_buffer_;
  std::shared_ptr<N_buffer_info_t> N_buffer_;
  size_t num_per_phase_;
  ir::value* hmma_dot_a_;
  ir::value* hmma_dot_b_;
  data_layout* arg_layout_;
  int mma_vec_;
  int mma_strided_;
  bool allow_swizzle_ = true;
  target *tgt_;
  bool is_tmp_;
};

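// Analysis pass that partitions the ir::values of a module into groups and
// assigns a data_layout to each group. Illustrative use, assuming the
// prerequisite axes/align analyses have already been run:
//
//   analysis::layouts layouts(&axes, &align, num_warps, tgt);
//   layouts.run(mod);
//   data_layout* l = layouts.get(v);            // layout of ir::value* v
//   if(mma_layout* mma = l->to_mma()) { /* tensor-core path */ }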
class layouts {
  typedef ir::value* node_t;
  typedef std::map<node_t, std::set<node_t>> graph_t;

private:
  // graph creation
  void connect(ir::value *x, ir::value *y);
  void make_graph(ir::instruction *i);

  void init_hmma_tile(data_layout& layouts);
  void init_scanline_tile(data_layout &layouts);

  void create(size_t id, const std::vector<ir::value*>& values);

  void create_tmp_layout(size_t id, data_layout* arg,
                         const std::vector<int>& axes,
                         const std::vector<unsigned>& shape,
                         ir::instruction* i,
                         bool is_index = false);

public:
  // constructor
  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);

  // accessors
  unsigned layout_of(ir::value *value) const { return groups_.at(value); }
  bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
  bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
  const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
  size_t num_layouts() const { return values_.size(); }
  data_layout* get(size_t id) { return layouts_.at(id); }
  data_layout* get(ir::value *v) { return get(layout_of(v)); }
  std::map<size_t, data_layout*> &get_all() { return layouts_; }
  bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
  int tmp(ir::value* i) { return tmp_.at(i); }
  bool has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
  int tmp_index(ir::value* i) { return tmp_index_.at(i); }
  void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }

  // layout checkers
  bool is_scanline(ir::instruction* i);

  bool is_coalesced_scanline(ir::instruction* i);

  bool is_mma(ir::instruction* i);

  bool is_a100_mma(ir::instruction* i);

  // execution
  void run(ir::module &mod);

private:
  analysis::axes* axes_;
  analysis::align* align_;
  size_t num_warps_;
  target* tgt_;
  tools::graph<ir::value*> graph_;
  std::map<ir::value*, size_t> groups_;
  std::map<size_t, std::vector<ir::value*>> values_;
  std::map<size_t, data_layout*> layouts_;
  std::map<ir::value*, size_t> tmp_;
  std::map<ir::value*, size_t> tmp_index_;
};

}  // namespace analysis
}  // namespace codegen

}  // namespace triton

#endif