1 | #pragma once |
2 | |
3 | #include <atomic> |
4 | #include <optional> |
5 | #include <unordered_map> |
6 | #include <unordered_set> |
7 | |
8 | #include "taichi/ir/control_flow_graph.h" |
9 | #include "taichi/ir/ir.h" |
10 | #include "taichi/ir/pass.h" |
11 | #include "taichi/transforms/check_out_of_bound.h" |
12 | #include "taichi/transforms/constant_fold.h" |
13 | #include "taichi/transforms/inlining.h" |
14 | #include "taichi/transforms/lower_access.h" |
15 | #include "taichi/transforms/make_block_local.h" |
16 | #include "taichi/transforms/make_mesh_block_local.h" |
17 | #include "taichi/transforms/demote_mesh_statements.h" |
18 | #include "taichi/transforms/simplify.h" |
19 | #include "taichi/common/trait.h" |
20 | |
21 | namespace taichi::lang { |
22 | |
23 | class ScratchPads; |
24 | |
25 | class Function; |
26 | |
27 | // IR passes |
28 | namespace irpass { |
29 | |
30 | void re_id(IRNode *root); |
31 | void flag_access(IRNode *root); |
32 | void eliminate_immutable_local_vars(IRNode *root); |
33 | void scalarize(IRNode *root); |
34 | void lower_matrix_ptr(IRNode *root); |
35 | bool die(IRNode *root); |
36 | bool simplify(IRNode *root, const CompileConfig &config); |
37 | bool cfg_optimization( |
38 | IRNode *root, |
39 | bool after_lower_access, |
40 | bool autodiff_enabled, |
41 | bool real_matrix_enabled, |
42 | const std::optional<ControlFlowGraph::LiveVarAnalysisConfig> |
43 | &lva_config_opt = std::nullopt); |
44 | bool alg_simp(IRNode *root, const CompileConfig &config); |
45 | bool demote_operations(IRNode *root, const CompileConfig &config); |
46 | bool binary_op_simplify(IRNode *root, const CompileConfig &config); |
47 | bool whole_kernel_cse(IRNode *root); |
48 | bool (IRNode *root, const CompileConfig &config); |
49 | bool unreachable_code_elimination(IRNode *root); |
50 | bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config); |
51 | bool cache_loop_invariant_global_vars(IRNode *root, |
52 | const CompileConfig &config); |
53 | void full_simplify(IRNode *root, |
54 | const CompileConfig &config, |
55 | const FullSimplifyPass::Args &args); |
56 | void print(IRNode *root, std::string *output = nullptr); |
57 | void frontend_type_check(IRNode *root); |
58 | void lower_ast(IRNode *root); |
59 | void type_check(IRNode *root, const CompileConfig &config); |
60 | bool inlining(IRNode *root, |
61 | const CompileConfig &config, |
62 | const InliningPass::Args &args); |
63 | void bit_loop_vectorize(IRNode *root); |
64 | void slp_vectorize(IRNode *root); |
65 | void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt); |
66 | bool check_out_of_bound(IRNode *root, |
67 | const CompileConfig &config, |
68 | const CheckOutOfBoundPass::Args &args); |
69 | void make_thread_local(IRNode *root, const CompileConfig &config); |
70 | std::unique_ptr<ScratchPads> initialize_scratch_pad(OffloadedStmt *root); |
71 | void make_block_local(IRNode *root, |
72 | const CompileConfig &config, |
73 | const MakeBlockLocalPass::Args &args); |
74 | void make_mesh_thread_local(IRNode *root, |
75 | const CompileConfig &config, |
76 | const MakeBlockLocalPass::Args &args); |
77 | void make_mesh_block_local(IRNode *root, |
78 | const CompileConfig &config, |
79 | const MakeMeshBlockLocal::Args &args); |
80 | void demote_mesh_statements(IRNode *root, |
81 | const CompileConfig &config, |
82 | const DemoteMeshStatements::Args &args); |
83 | bool remove_loop_unique(IRNode *root); |
84 | bool remove_range_assumption(IRNode *root); |
85 | bool lower_access(IRNode *root, |
86 | const CompileConfig &config, |
87 | const LowerAccessPass::Args &args); |
88 | void auto_diff(IRNode *root, |
89 | const CompileConfig &config, |
90 | AutodiffMode autodiffMode, |
91 | bool use_stack = false); |
92 | /** |
93 | * Check whether the kernel obeys the autodiff limitation e.g., gloabl data |
94 | * access rule |
95 | */ |
96 | void differentiation_validation_check(IRNode *root, |
97 | const CompileConfig &config, |
98 | const std::string &kernel_name); |
99 | /** |
100 | * Determine all adaptive AD-stacks' size. This pass is idempotent, i.e., |
101 | * there are no side effects if called more than once or called when not needed. |
102 | * @return Whether the IR is modified, i.e., whether there exists adaptive |
103 | * AD-stacks before this pass. |
104 | */ |
105 | bool determine_ad_stack_size(IRNode *root, const CompileConfig &config); |
106 | bool constant_fold(IRNode *root, |
107 | const CompileConfig &config, |
108 | const ConstantFoldPass::Args &args); |
109 | void offload(IRNode *root, const CompileConfig &config); |
110 | bool transform_statements( |
111 | IRNode *root, |
112 | std::function<bool(Stmt *)> filter, |
113 | std::function<void(Stmt *, DelayedIRModifier *)> transformer); |
114 | /** |
115 | * @param root The IR root to be traversed. |
116 | * @param filter A function which tells if a statement need to be replaced. |
117 | * @param generator If a statement |s| need to be replaced, generate a new |
118 | * statement |s1| with the argument |s|, insert |s1| to where |s| is defined, |
119 | * remove |s|'s definition, and replace all usages of |s| with |s1|. |
120 | * @return Whether the IR is modified. |
121 | */ |
122 | bool replace_and_insert_statements( |
123 | IRNode *root, |
124 | std::function<bool(Stmt *)> filter, |
125 | std::function<std::unique_ptr<Stmt>(Stmt *)> generator); |
126 | /** |
127 | * @param finder If a statement |s| need to be replaced, find the existing |
128 | * statement |s1| with the argument |s|, remove |s|'s definition, and replace |
129 | * all usages of |s| with |s1|. |
130 | */ |
131 | bool replace_statements(IRNode *root, |
132 | std::function<bool(Stmt *)> filter, |
133 | std::function<Stmt *(Stmt *)> finder); |
134 | void demote_dense_struct_fors(IRNode *root); |
135 | void demote_no_access_mesh_fors(IRNode *root); |
136 | bool demote_atomics(IRNode *root, const CompileConfig &config); |
137 | void reverse_segments(IRNode *root); // for autograd |
138 | void detect_read_only(IRNode *root); |
139 | void optimize_bit_struct_stores(IRNode *root, |
140 | const CompileConfig &config, |
141 | AnalysisManager *amgr); |
142 | |
143 | ENUM_FLAGS(ExternalPtrAccess){NONE = 0, READ = 1, WRITE = 2}; |
144 | |
145 | /** |
146 | * Checks the access to external pointers in an offload. |
147 | * |
148 | * @param val1 |
149 | * The offloaded statement to check |
150 | * |
151 | * @return |
152 | * The analyzed result. |
153 | */ |
154 | std::unordered_map<int, ExternalPtrAccess> detect_external_ptr_access_in_task( |
155 | OffloadedStmt *offload); |
156 | |
157 | // compile_to_offloads does the basic compilation to create all the offloaded |
158 | // tasks of a Taichi kernel. It's worth pointing out that this doesn't demote |
159 | // dense struct fors. This is a necessary workaround to prevent the async |
160 | // engine from fusing incompatible offloaded tasks. TODO(Lin): check this |
161 | // comment |
162 | void compile_to_offloads(IRNode *ir, |
163 | const CompileConfig &config, |
164 | Kernel *kernel, |
165 | bool verbose, |
166 | AutodiffMode autodiff_mode, |
167 | bool ad_use_stack, |
168 | bool start_from_ast); |
169 | |
170 | void offload_to_executable(IRNode *ir, |
171 | const CompileConfig &config, |
172 | Kernel *kernel, |
173 | bool verbose, |
174 | bool determine_ad_stack_size, |
175 | bool lower_global_access, |
176 | bool make_thread_local, |
177 | bool make_block_local); |
178 | // compile_to_executable fully covers compile_to_offloads, and also does |
179 | // additional optimizations so that |ir| can be directly fed into codegen. |
180 | void compile_to_executable(IRNode *ir, |
181 | const CompileConfig &config, |
182 | Kernel *kernel, |
183 | AutodiffMode autodiff_mode, |
184 | bool ad_use_stack, |
185 | bool verbose, |
186 | bool lower_global_access = true, |
187 | bool make_thread_local = false, |
188 | bool make_block_local = false, |
189 | bool start_from_ast = true); |
190 | // Compile a function with some basic optimizations, so that the number of |
191 | // statements is reduced before inlining. |
192 | void compile_function(IRNode *ir, |
193 | const CompileConfig &config, |
194 | Function *func, |
195 | AutodiffMode autodiff_mode, |
196 | bool verbose, |
197 | bool start_from_ast); |
198 | |
199 | void ast_to_ir(const CompileConfig &config, |
200 | Kernel &kernel, |
201 | bool to_executable = true); |
202 | |
203 | void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config); |
204 | } // namespace irpass |
205 | |
206 | } // namespace taichi::lang |
207 | |