#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H

#include <map>
#include <vector>

namespace triton {

// Forward declarations of the IR node types that the `align` interface below
// refers to only through pointers/references; their full definitions live in
// the IR headers, which this header does not include.
namespace ir {
  class value;
  class module;
  class phi_node;
  class splat_inst;
  class cast_inst;
  class cmp_inst;
  class reshape_inst;
  class dequantize_inst;
  class broadcast_inst;
  class binary_operator;
  class getelementptr_inst;
}

namespace codegen{
namespace analysis{

/// Analysis over an IR module that records three facts for each IR value,
/// each stored as a vector in its own map keyed by `ir::value*`:
///   - `is_constant_`       : constancy information (`cst_info`),
///   - `max_contiguous_`    : maximum contiguity,
///   - `starting_multiple_` : starting multiple.
///
/// NOTE(review): the vectors appear to hold one entry per axis of the value's
/// shape (cf. `get_shapes` and the axis argument of `get`) — confirm against
/// the implementation (align.cc).
class align {
private:
  /// Constancy fact for a single entry of `is_constant_`.
  /// NOTE(review): presumably `num_cst` is the length of a run of equal
  /// elements and `value` the constant itself when known — confirm in the
  /// implementation. Kept as a plain aggregate so brace-initialization works.
  struct cst_info {
    unsigned num_cst;
    unsigned value;
  };
  // helpers
  std::vector<unsigned> get_shapes(ir::value *v);
  // populate is_constant: one routine per supported instruction kind, with
  // `populate_is_constant` presumably dispatching on the kind of `v` and
  // falling back to `_default`.
  // NOTE(review): unlike the other two families there is no `_cast` routine
  // here; casts presumably take the `_default` path — confirm.
  std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
  std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
  std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
  std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
  std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
  std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
  std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
  std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
  std::vector<cst_info> populate_is_constant_default(ir::value* v);
  std::vector<cst_info> populate_is_constant(ir::value *v);
  // populate max_contiguous (same per-instruction-kind structure)
  std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
  std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
  std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
  std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
  std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
  std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
  std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
  std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
  std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
  std::vector<unsigned> populate_max_contiguous(ir::value *v);
  // populate starting_multiple (same per-instruction-kind structure)
  std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
  std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
  std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
  std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
  std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
  std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
  std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
  std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
  std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
  std::vector<unsigned> populate_starting_multiple(ir::value *v);
  // populate all maps
  void populate(ir::value *v);

public:
  /// Runs the analysis on `mod`.
  /// NOTE(review): presumably calls `populate` on the module's values to fill
  /// the three maps below — confirm in align.cc.
  void run(ir::module &mod);
  /// Returns an alignment fact for `v` along axis `ax`.
  /// NOTE(review): which map(s) this reads is defined in align.cc.
  unsigned get(ir::value* v, unsigned ax) const;
  /// Returns the contiguity vector for `v` (presumably `max_contiguous_[v]`).
  std::vector<unsigned> contiguous(ir::value* v) const;
  /// Returns the constancy info for `v` (presumably `is_constant_[v]`).
  std::vector<cst_info> get_cst_info(ir::value* v) const;

private:
  // Per-value analysis results, keyed by IR value.
  std::map<ir::value*, std::vector<cst_info>> is_constant_;
  std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
  std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};


} // namespace analysis
} // namespace codegen
} // namespace triton

#endif // TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
88 |