1 | #ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_ |
---|---|
2 | #define _TRITON_CODEGEN_ANALYSIS_AXES_H_ |
3 | |
4 | #include "triton/tools/graph.h" |
5 | #include <map> |
6 | #include <vector> |
7 | |
8 | namespace triton{ |
9 | |
10 | namespace ir{ |
11 | class value; |
12 | class module; |
13 | class instruction; |
14 | } |
15 | |
16 | namespace codegen{ |
17 | namespace analysis{ |
18 | |
19 | class axes { |
20 | typedef std::pair<ir::value*, unsigned> node_t; |
21 | |
22 | private: |
23 | // update graph |
24 | void update_graph_store(ir::instruction *i); |
25 | void update_graph_reduce(ir::instruction *i); |
26 | void update_graph_reshape(ir::instruction *i); |
27 | void update_graph_trans(ir::instruction *i); |
28 | void update_graph_dequantize(ir::instruction *i); |
29 | void update_graph_broadcast(ir::instruction *i); |
30 | void update_graph_dot(ir::instruction *i); |
31 | void update_graph_elementwise(ir::instruction *i, |
32 | bool is_masked_load_async=false); |
33 | void update_graph_no_edge(ir::instruction *i); |
34 | void update_graph(ir::instruction *i); |
35 | |
36 | public: |
37 | axes(); |
38 | void run(ir::module &mod); |
39 | // accessors |
40 | int get(ir::value *value, unsigned dim); |
41 | std::vector<int> get(ir::value *value); |
42 | |
43 | private: |
44 | tools::graph<node_t> graph_; |
45 | std::map<node_t, size_t> axes_; |
46 | }; |
47 | |
48 | } |
49 | } |
50 | |
51 | } |
52 | |
53 | #endif |
54 |