1 | #pragma once |
2 | |
3 | #include <optional> |
4 | #include <unordered_set> |
5 | |
6 | #include "taichi/ir/ir.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | /** |
11 | * A basic block in control-flow graph. |
12 | * A CFGNode contains a reference to a part of the CHI IR, or more precisely, |
13 | * an interval of statements in a Block. |
14 | * The edges in the graph are stored in |prev| and |next|. The control flow is |
15 | * possible to go from any node in |prev| to this node, and is possible to go |
16 | * from this node to any node in |next|. |
17 | */ |
18 | class CFGNode { |
19 | private: |
20 | // For accelerating get_store_forwarding_data() |
21 | std::unordered_set<Block *> parent_blocks_; |
22 | |
23 | public: |
24 | // This node corresponds to block->statements[i] |
25 | // for i in [begin_location, end_location). |
26 | Block *block; |
27 | int begin_location, end_location; |
28 | // Is this node in an offloaded range_for/struct_for? |
29 | bool is_parallel_executed; |
30 | |
31 | // For updating begin/end locations when modifying the block. |
32 | CFGNode *prev_node_in_same_block; |
33 | CFGNode *next_node_in_same_block; |
34 | |
35 | // Edges in the graph |
36 | std::vector<CFGNode *> prev, next; |
37 | |
38 | // Reaching definition analysis |
39 | // https://en.wikipedia.org/wiki/Reaching_definition |
40 | std::unordered_set<Stmt *> reach_gen, reach_kill, reach_in, reach_out; |
41 | |
42 | // Live variable analysis |
43 | // https://en.wikipedia.org/wiki/Live_variable_analysis |
44 | std::unordered_set<Stmt *> live_gen, live_kill, live_in, live_out; |
45 | |
46 | CFGNode(Block *block, |
47 | int begin_location, |
48 | int end_location, |
49 | bool is_parallel_executed, |
50 | CFGNode *prev_node_in_same_block); |
51 | |
52 | // An empty node |
53 | CFGNode(); |
54 | |
55 | static void add_edge(CFGNode *from, CFGNode *to); |
56 | |
57 | // Property methods. |
58 | bool empty() const; |
59 | std::size_t size() const; |
60 | |
61 | // Methods for modifying the underlying CHI IR. |
62 | void erase(int location); |
63 | void insert(std::unique_ptr<Stmt> &&new_stmt, int location); |
64 | void replace_with(int location, |
65 | std::unique_ptr<Stmt> &&new_stmt, |
66 | bool replace_usages = true) const; |
67 | |
68 | // Utility methods. |
69 | static bool contain_variable(const std::unordered_set<Stmt *> &var_set, |
70 | Stmt *var); |
71 | static bool may_contain_variable(const std::unordered_set<Stmt *> &var_set, |
72 | Stmt *var); |
73 | bool reach_kill_variable(Stmt *var) const; |
74 | Stmt *get_store_forwarding_data(Stmt *var, int position) const; |
75 | |
76 | // Analyses and optimizations inside a CFGNode. |
77 | void reaching_definition_analysis(bool after_lower_access); |
78 | bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); |
79 | void gather_loaded_snodes(std::unordered_set<SNode *> &snodes) const; |
80 | void live_variable_analysis(bool after_lower_access); |
81 | bool dead_store_elimination(bool after_lower_access); |
82 | }; |
83 | |
84 | class ControlFlowGraph { |
85 | private: |
86 | // Erase an empty node. |
87 | void erase(int node_id); |
88 | |
89 | public: |
90 | struct LiveVarAnalysisConfig { |
91 | // This is mostly useful for SFG task-level dead store elimination. SFG may |
92 | // detect certain cases where writes to one or more SNodes in a task are |
93 | // eliminable. |
94 | std::unordered_set<const SNode *> eliminable_snodes; |
95 | }; |
96 | std::vector<std::unique_ptr<CFGNode>> nodes; |
97 | const int start_node = 0; |
98 | int final_node{0}; |
99 | |
100 | template <typename... Args> |
101 | CFGNode *push_back(Args &&...args) { |
102 | nodes.emplace_back(std::make_unique<CFGNode>(std::forward<Args>(args)...)); |
103 | return nodes.back().get(); |
104 | } |
105 | |
106 | [[nodiscard]] std::size_t size() const; |
107 | [[nodiscard]] CFGNode *back() const; |
108 | |
109 | void print_graph_structure() const; |
110 | |
111 | /** |
112 | * Perform reaching definition analysis using the worklist algorithm, |
113 | * and store the results in CFGNodes. |
114 | * https://en.wikipedia.org/wiki/Reaching_definition |
115 | * |
116 | * @param after_lower_access |
117 | * When after_lower_access is true, only consider local variables (allocas). |
118 | */ |
119 | void reaching_definition_analysis(bool after_lower_access); |
120 | |
121 | /** |
122 | * Perform live variable analysis using the worklist algorithm, |
123 | * and store the results in CFGNodes. |
124 | * https://en.wikipedia.org/wiki/Live_variable_analysis |
125 | * |
126 | * @param after_lower_access |
127 | * When after_lower_access is true, only consider local variables (allocas). |
128 | * @param config_opt |
129 | * The set of SNodes which is never loaded after this task. |
130 | */ |
131 | void live_variable_analysis( |
132 | bool after_lower_access, |
133 | const std::optional<LiveVarAnalysisConfig> &config_opt); |
134 | |
135 | /** |
136 | * Simplify the graph structure to accelerate other analyses and |
137 | * optimizations. The IR is not modified. |
138 | */ |
139 | void simplify_graph(); |
140 | |
141 | // This pass cannot eliminate container statements properly for now. |
142 | bool unreachable_code_elimination(); |
143 | |
144 | /** |
145 | * Perform store-to-load forwarding and identical store elimination. |
146 | */ |
147 | bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); |
148 | |
149 | /** |
150 | * Perform dead store elimination and identical load elimination. |
151 | */ |
152 | bool dead_store_elimination( |
153 | bool after_lower_access, |
154 | const std::optional<LiveVarAnalysisConfig> &lva_config_opt); |
155 | |
156 | /** |
157 | * Gather the SNodes which is read or partially written in this offloaded |
158 | * task. |
159 | */ |
160 | std::unordered_set<SNode *> gather_loaded_snodes(); |
161 | |
162 | /** |
163 | * Determine all adaptive AD-stacks' necessary size. |
164 | * @param default_ad_stack_size The default AD-stack's size when we are |
165 | * unable to determine some AD-stack's size. |
166 | */ |
167 | void determine_ad_stack_size(int default_ad_stack_size); |
168 | }; |
169 | |
170 | } // namespace taichi::lang |
171 | |