1#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
2#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
3
#include "triton/tools/graph.h"

#include <cstddef>
#include <map>
#include <utility>
#include <vector>
7
8namespace triton{
9
// Forward declarations of the IR types referenced by this header. They are
// only used through pointers/references here, so the full IR definitions do
// not need to be included.
namespace ir {
class instruction;
class module;
class value;
} // namespace ir
15
16namespace codegen{
17namespace analysis{
18
19class axes {
20 typedef std::pair<ir::value*, unsigned> node_t;
21
22private:
23 // update graph
24 void update_graph_store(ir::instruction *i);
25 void update_graph_reduce(ir::instruction *i);
26 void update_graph_reshape(ir::instruction *i);
27 void update_graph_trans(ir::instruction *i);
28 void update_graph_dequantize(ir::instruction *i);
29 void update_graph_broadcast(ir::instruction *i);
30 void update_graph_dot(ir::instruction *i);
31 void update_graph_elementwise(ir::instruction *i,
32 bool is_masked_load_async=false);
33 void update_graph_no_edge(ir::instruction *i);
34 void update_graph(ir::instruction *i);
35
36public:
37 axes();
38 void run(ir::module &mod);
39 // accessors
40 int get(ir::value *value, unsigned dim);
41 std::vector<int> get(ir::value *value);
42
43private:
44 tools::graph<node_t> graph_;
45 std::map<node_t, size_t> axes_;
46};
47
48}
49}
50
51}
52
53#endif
54