1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/transforms.h" |
3 | #include "taichi/ir/analysis.h" |
4 | #include "taichi/ir/pass.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/program/compile_config.h" |
7 | #include "taichi/program/extension.h" |
8 | #include "taichi/program/function.h" |
9 | #include "taichi/program/kernel.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | namespace irpass { |
14 | namespace { |
15 | |
16 | std::function<void(const std::string &)> |
17 | make_pass_printer(bool verbose, const std::string &kernel_name, IRNode *ir) { |
18 | if (!verbose) { |
19 | return [](const std::string &) {}; |
20 | } |
21 | return [ir, kernel_name](const std::string &pass) { |
22 | TI_INFO("[{}] {}:" , kernel_name, pass); |
23 | std::cout << std::flush; |
24 | irpass::re_id(ir); |
25 | irpass::print(ir); |
26 | std::cout << std::flush; |
27 | }; |
28 | } |
29 | |
30 | } // namespace |
31 | |
32 | void compile_to_offloads(IRNode *ir, |
33 | const CompileConfig &config, |
34 | Kernel *kernel, |
35 | bool verbose, |
36 | AutodiffMode autodiff_mode, |
37 | bool ad_use_stack, |
38 | bool start_from_ast) { |
39 | TI_AUTO_PROF; |
40 | |
41 | auto print = make_pass_printer(verbose, kernel->get_name(), ir); |
42 | print("Initial IR" ); |
43 | |
44 | if (!verbose && config.print_preprocessed_ir && start_from_ast) { |
45 | TI_INFO("[{}] {}:" , kernel->get_name(), "Preprocessed IR" ); |
46 | std::cout << std::flush; |
47 | irpass::re_id(ir); |
48 | irpass::print(ir); |
49 | std::cout << std::flush; |
50 | } |
51 | |
52 | if (autodiff_mode == AutodiffMode::kReverse) { |
53 | irpass::reverse_segments(ir); |
54 | print("Segment reversed (for autodiff)" ); |
55 | } |
56 | |
57 | if (start_from_ast) { |
58 | irpass::frontend_type_check(ir); |
59 | irpass::lower_ast(ir); |
60 | print("Lowered" ); |
61 | } |
62 | |
63 | irpass::compile_taichi_functions(ir, config); |
64 | |
65 | irpass::eliminate_immutable_local_vars(ir); |
66 | print("Immutable local vars eliminated" ); |
67 | |
68 | if (config.real_matrix_scalarize) { |
69 | irpass::scalarize(ir); |
70 | |
71 | // Remove redundant MatrixInitStmt inserted during scalarization |
72 | irpass::die(ir); |
73 | print("Scalarized" ); |
74 | } |
75 | |
76 | irpass::lower_matrix_ptr(ir); |
77 | print("Matrix ptr lowered" ); |
78 | |
79 | irpass::type_check(ir, config); |
80 | print("Typechecked" ); |
81 | irpass::analysis::verify(ir); |
82 | |
83 | if (kernel->is_evaluator) { |
84 | TI_ASSERT(autodiff_mode == AutodiffMode::kNone); |
85 | |
86 | irpass::demote_operations(ir, config); |
87 | print("Operations demoted" ); |
88 | |
89 | irpass::offload(ir, config); |
90 | print("Offloaded" ); |
91 | irpass::analysis::verify(ir); |
92 | return; |
93 | } |
94 | |
95 | // TODO: strictly enforce bit vectorization for x86 cpu and CUDA now |
96 | // create a separate CompileConfig flag for the new pass |
97 | if (arch_is_cpu(config.arch) || config.arch == Arch::cuda || |
98 | config.arch == Arch::amdgpu) { |
99 | irpass::bit_loop_vectorize(ir); |
100 | irpass::type_check(ir, config); |
101 | print("Bit Loop Vectorized" ); |
102 | irpass::analysis::verify(ir); |
103 | } |
104 | |
105 | irpass::full_simplify( |
106 | ir, config, |
107 | {false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone, |
108 | kernel->program}); |
109 | print("Simplified I" ); |
110 | irpass::analysis::verify(ir); |
111 | |
112 | if (is_extension_supported(config.arch, Extension::mesh)) { |
113 | irpass::analysis::gather_meshfor_relation_types(ir); |
114 | } |
115 | |
116 | if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) { |
117 | // Check whether the kernel obeys the autodiff limitation e.g., gloabl data |
118 | // access rule |
119 | // This check should be performed in the forward kernel i.e., autodiff_mode |
120 | // == AutodiffMode::kCheckAutodiffValid |
121 | irpass::demote_atomics(ir, config); |
122 | irpass::differentiation_validation_check(ir, config, kernel->get_name()); |
123 | irpass::analysis::verify(ir); |
124 | } |
125 | |
126 | if (autodiff_mode == AutodiffMode::kReverse || |
127 | autodiff_mode == AutodiffMode::kForward) { |
128 | // Remove local atomics here so that we don't have to handle their gradients |
129 | irpass::demote_atomics(ir, config); |
130 | |
131 | irpass::full_simplify(ir, config, |
132 | {false, /*autodiff_enabled*/ true, kernel->program}); |
133 | irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack); |
134 | // TODO: Be carefull with the full_simplify when do high-order autodiff |
135 | irpass::full_simplify(ir, config, |
136 | {false, /*autodiff_enabled*/ false, kernel->program}); |
137 | print("Gradient" ); |
138 | irpass::analysis::verify(ir); |
139 | } |
140 | |
141 | if (config.check_out_of_bound) { |
142 | irpass::check_out_of_bound(ir, config, {kernel->get_name()}); |
143 | print("Bound checked" ); |
144 | irpass::analysis::verify(ir); |
145 | } |
146 | |
147 | irpass::flag_access(ir); |
148 | print("Access flagged I" ); |
149 | irpass::analysis::verify(ir); |
150 | |
151 | irpass::full_simplify(ir, config, |
152 | {false, /*autodiff_enabled*/ false, kernel->program}); |
153 | print("Simplified II" ); |
154 | irpass::analysis::verify(ir); |
155 | |
156 | irpass::offload(ir, config); |
157 | print("Offloaded" ); |
158 | irpass::analysis::verify(ir); |
159 | |
160 | // TODO: This pass may be redundant as cfg_optimization() is already called |
161 | // in full_simplify(). |
162 | if (config.opt_level > 0 && config.cfg_optimization) { |
163 | irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false, |
164 | !config.real_matrix_scalarize); |
165 | print("Optimized by CFG" ); |
166 | irpass::analysis::verify(ir); |
167 | } |
168 | |
169 | irpass::flag_access(ir); |
170 | print("Access flagged II" ); |
171 | |
172 | irpass::full_simplify(ir, config, |
173 | {false, /*autodiff_enabled*/ false, kernel->program}); |
174 | print("Simplified III" ); |
175 | irpass::analysis::verify(ir); |
176 | } |
177 | |
178 | void offload_to_executable(IRNode *ir, |
179 | const CompileConfig &config, |
180 | Kernel *kernel, |
181 | bool verbose, |
182 | bool determine_ad_stack_size, |
183 | bool lower_global_access, |
184 | bool make_thread_local, |
185 | bool make_block_local) { |
186 | TI_AUTO_PROF; |
187 | |
188 | auto print = make_pass_printer(verbose, kernel->get_name(), ir); |
189 | |
190 | // TODO: This is just a proof that we can demote struct-fors after offloading. |
191 | // Eventually we might want the order to be TLS/BLS -> demote struct-for. |
192 | // For now, putting this after TLS will disable TLS, because it can only |
193 | // handle range-fors at this point. |
194 | |
195 | auto amgr = std::make_unique<AnalysisManager>(); |
196 | |
197 | print("Start offload_to_executable" ); |
198 | irpass::analysis::verify(ir); |
199 | |
200 | if (config.detect_read_only) { |
201 | irpass::detect_read_only(ir); |
202 | print("Detect read-only accesses" ); |
203 | } |
204 | |
205 | irpass::demote_atomics(ir, config); |
206 | print("Atomics demoted I" ); |
207 | irpass::analysis::verify(ir); |
208 | if (config.cache_loop_invariant_global_vars) { |
209 | irpass::cache_loop_invariant_global_vars(ir, config); |
210 | print("Cache loop-invariant global vars" ); |
211 | } |
212 | |
213 | if (config.demote_dense_struct_fors) { |
214 | irpass::demote_dense_struct_fors(ir); |
215 | irpass::type_check(ir, config); |
216 | print("Dense struct-for demoted" ); |
217 | irpass::analysis::verify(ir); |
218 | } |
219 | |
220 | if (is_extension_supported(config.arch, Extension::mesh) && |
221 | config.demote_no_access_mesh_fors) { |
222 | irpass::demote_no_access_mesh_fors(ir); |
223 | irpass::type_check(ir, config); |
224 | print("No-access mesh-for demoted" ); |
225 | irpass::analysis::verify(ir); |
226 | } |
227 | |
228 | if (make_thread_local) { |
229 | irpass::make_thread_local(ir, config); |
230 | print("Make thread local" ); |
231 | } |
232 | |
233 | if (is_extension_supported(config.arch, Extension::mesh)) { |
234 | irpass::make_mesh_thread_local(ir, config, {kernel->get_name()}); |
235 | print("Make mesh thread local" ); |
236 | if (config.make_mesh_block_local && config.arch == Arch::cuda) { |
237 | irpass::make_mesh_block_local(ir, config, {kernel->get_name()}); |
238 | print("Make mesh block local" ); |
239 | irpass::full_simplify( |
240 | ir, config, {false, /*autodiff_enabled*/ false, kernel->program}); |
241 | print("Simplified X" ); |
242 | } |
243 | } |
244 | |
245 | if (make_block_local) { |
246 | irpass::make_block_local(ir, config, {kernel->get_name()}); |
247 | print("Make block local" ); |
248 | } |
249 | |
250 | if (is_extension_supported(config.arch, Extension::mesh)) { |
251 | irpass::demote_mesh_statements(ir, config, {kernel->get_name()}); |
252 | print("Demote mesh statements" ); |
253 | } |
254 | |
255 | irpass::demote_atomics(ir, config); |
256 | print("Atomics demoted II" ); |
257 | irpass::analysis::verify(ir); |
258 | |
259 | if (is_extension_supported(config.arch, Extension::quant) && |
260 | config.quant_opt_atomic_demotion) { |
261 | irpass::analysis::gather_uniquely_accessed_bit_structs(ir, amgr.get()); |
262 | } |
263 | |
264 | irpass::remove_range_assumption(ir); |
265 | print("Remove range assumption" ); |
266 | |
267 | irpass::remove_loop_unique(ir); |
268 | print("Remove loop_unique" ); |
269 | irpass::analysis::verify(ir); |
270 | |
271 | if (lower_global_access) { |
272 | irpass::full_simplify(ir, config, |
273 | {false, /*autodiff_enabled*/ false, kernel->program}); |
274 | print("Simplified before lower access" ); |
275 | irpass::lower_access(ir, config, {kernel->no_activate, true}); |
276 | print("Access lowered" ); |
277 | irpass::analysis::verify(ir); |
278 | |
279 | irpass::die(ir); |
280 | print("DIE" ); |
281 | irpass::analysis::verify(ir); |
282 | |
283 | irpass::flag_access(ir); |
284 | print("Access flagged III" ); |
285 | irpass::analysis::verify(ir); |
286 | } |
287 | |
288 | irpass::demote_operations(ir, config); |
289 | print("Operations demoted" ); |
290 | |
291 | irpass::full_simplify( |
292 | ir, config, |
293 | {lower_global_access, /*autodiff_enabled*/ false, kernel->program}); |
294 | print("Simplified IV" ); |
295 | |
296 | if (determine_ad_stack_size) { |
297 | irpass::determine_ad_stack_size(ir, config); |
298 | print("Autodiff stack size determined" ); |
299 | } |
300 | |
301 | if (is_extension_supported(config.arch, Extension::quant)) { |
302 | irpass::optimize_bit_struct_stores(ir, config, amgr.get()); |
303 | print("Bit struct stores optimized" ); |
304 | } |
305 | |
306 | // Final field registration correctness & type checking |
307 | irpass::type_check(ir, config); |
308 | irpass::analysis::verify(ir); |
309 | } |
310 | |
311 | void compile_to_executable(IRNode *ir, |
312 | const CompileConfig &config, |
313 | Kernel *kernel, |
314 | AutodiffMode autodiff_mode, |
315 | bool ad_use_stack, |
316 | bool verbose, |
317 | bool lower_global_access, |
318 | bool make_thread_local, |
319 | bool make_block_local, |
320 | bool start_from_ast) { |
321 | TI_AUTO_PROF; |
322 | |
323 | compile_to_offloads(ir, config, kernel, verbose, autodiff_mode, ad_use_stack, |
324 | start_from_ast); |
325 | |
326 | offload_to_executable( |
327 | ir, config, kernel, verbose, |
328 | /*determine_ad_stack_size=*/autodiff_mode == AutodiffMode::kReverse && |
329 | ad_use_stack, |
330 | lower_global_access, make_thread_local, make_block_local); |
331 | } |
332 | |
333 | void compile_function(IRNode *ir, |
334 | const CompileConfig &config, |
335 | Function *func, |
336 | AutodiffMode autodiff_mode, |
337 | bool verbose, |
338 | bool start_from_ast) { |
339 | TI_AUTO_PROF; |
340 | |
341 | auto print = make_pass_printer(verbose, func->get_name(), ir); |
342 | print("Initial IR" ); |
343 | |
344 | if (autodiff_mode == AutodiffMode::kReverse) { |
345 | irpass::reverse_segments(ir); |
346 | print("Segment reversed (for autodiff)" ); |
347 | } |
348 | |
349 | if (start_from_ast) { |
350 | irpass::frontend_type_check(ir); |
351 | irpass::lower_ast(ir); |
352 | print("Lowered" ); |
353 | } |
354 | |
355 | if (config.real_matrix_scalarize) { |
356 | irpass::scalarize(ir); |
357 | |
358 | // Remove redundant MatrixInitStmt inserted during scalarization |
359 | irpass::die(ir); |
360 | print("Scalarized" ); |
361 | } |
362 | |
363 | irpass::lower_access(ir, config, {{}, true}); |
364 | print("Access lowered" ); |
365 | irpass::analysis::verify(ir); |
366 | |
367 | irpass::die(ir); |
368 | print("DIE" ); |
369 | irpass::analysis::verify(ir); |
370 | |
371 | irpass::flag_access(ir); |
372 | print("Access flagged III" ); |
373 | irpass::analysis::verify(ir); |
374 | |
375 | irpass::type_check(ir, config); |
376 | print("Typechecked" ); |
377 | |
378 | irpass::demote_operations(ir, config); |
379 | print("Operations demoted" ); |
380 | |
381 | irpass::full_simplify( |
382 | ir, config, {false, autodiff_mode != AutodiffMode::kNone, func->program}); |
383 | print("Simplified" ); |
384 | irpass::analysis::verify(ir); |
385 | } |
386 | |
387 | } // namespace irpass |
388 | |
389 | } // namespace taichi::lang |
390 | |