1#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
2#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
3
4#include <map>
5#include <vector>
6
7namespace triton {
8
namespace ir {
// Forward declarations of the IR types referenced by the analysis class in
// this header. The header only ever uses pointers or references to them, so
// their full definitions are not required here (alphabetical order).
class binary_operator;
class broadcast_inst;
class cast_inst;
class cmp_inst;
class dequantize_inst;
class getelementptr_inst;
class module;
class phi_node;
class reshape_inst;
class splat_inst;
class value;
}
22
23namespace codegen{
24namespace analysis{
25
26class align {
27private:
28 struct cst_info {
29 unsigned num_cst;
30 unsigned value;
31 };
32 // helpers
33 std::vector<unsigned> get_shapes(ir::value *v);
34 // populate is_constant
35 std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
36 std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
37 std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
38 std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
39 std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
40 std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
41 std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
42 std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
43 std::vector<cst_info> populate_is_constant_default(ir::value* v);
44 std::vector<cst_info> populate_is_constant(ir::value *v);
45 // populate max_contiguous
46 std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
47 std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
48 std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
49 std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
50 std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
51 std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
52 std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
53 std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
54 std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
55 std::vector<unsigned> populate_max_contiguous(ir::value *v);
56 // populate starting_multiple
57 std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
58 std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
59 std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
60 std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
61 std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
62 std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
63 std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
64 std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
65 std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
66 std::vector<unsigned> populate_starting_multiple(ir::value *v);
67 // populate all maps
68 void populate(ir::value *v);
69
70public:
71 void run(ir::module &mod);
72 unsigned get(ir::value* v, unsigned ax) const;
73 std::vector<unsigned> contiguous(ir::value* v) const;
74 std::vector<cst_info> get_cst_info(ir::value* v) const;
75
76private:
77 std::map<ir::value*, std::vector<cst_info>> is_constant_;
78 std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
79 std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
80};
81
82
83}
84}
85}
86
87#endif
88