1#include "taichi/ir/ir.h"
2#include "taichi/ir/transforms.h"
3#include "taichi/ir/analysis.h"
4#include "taichi/ir/pass.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/program/compile_config.h"
7#include "taichi/program/extension.h"
8#include "taichi/program/function.h"
9#include "taichi/program/kernel.h"
10
11namespace taichi::lang {
12
13namespace irpass {
14namespace {
15
16std::function<void(const std::string &)>
17make_pass_printer(bool verbose, const std::string &kernel_name, IRNode *ir) {
18 if (!verbose) {
19 return [](const std::string &) {};
20 }
21 return [ir, kernel_name](const std::string &pass) {
22 TI_INFO("[{}] {}:", kernel_name, pass);
23 std::cout << std::flush;
24 irpass::re_id(ir);
25 irpass::print(ir);
26 std::cout << std::flush;
27 };
28}
29
30} // namespace
31
32void compile_to_offloads(IRNode *ir,
33 const CompileConfig &config,
34 Kernel *kernel,
35 bool verbose,
36 AutodiffMode autodiff_mode,
37 bool ad_use_stack,
38 bool start_from_ast) {
39 TI_AUTO_PROF;
40
41 auto print = make_pass_printer(verbose, kernel->get_name(), ir);
42 print("Initial IR");
43
44 if (!verbose && config.print_preprocessed_ir && start_from_ast) {
45 TI_INFO("[{}] {}:", kernel->get_name(), "Preprocessed IR");
46 std::cout << std::flush;
47 irpass::re_id(ir);
48 irpass::print(ir);
49 std::cout << std::flush;
50 }
51
52 if (autodiff_mode == AutodiffMode::kReverse) {
53 irpass::reverse_segments(ir);
54 print("Segment reversed (for autodiff)");
55 }
56
57 if (start_from_ast) {
58 irpass::frontend_type_check(ir);
59 irpass::lower_ast(ir);
60 print("Lowered");
61 }
62
63 irpass::compile_taichi_functions(ir, config);
64
65 irpass::eliminate_immutable_local_vars(ir);
66 print("Immutable local vars eliminated");
67
68 if (config.real_matrix_scalarize) {
69 irpass::scalarize(ir);
70
71 // Remove redundant MatrixInitStmt inserted during scalarization
72 irpass::die(ir);
73 print("Scalarized");
74 }
75
76 irpass::lower_matrix_ptr(ir);
77 print("Matrix ptr lowered");
78
79 irpass::type_check(ir, config);
80 print("Typechecked");
81 irpass::analysis::verify(ir);
82
83 if (kernel->is_evaluator) {
84 TI_ASSERT(autodiff_mode == AutodiffMode::kNone);
85
86 irpass::demote_operations(ir, config);
87 print("Operations demoted");
88
89 irpass::offload(ir, config);
90 print("Offloaded");
91 irpass::analysis::verify(ir);
92 return;
93 }
94
95 // TODO: strictly enforce bit vectorization for x86 cpu and CUDA now
96 // create a separate CompileConfig flag for the new pass
97 if (arch_is_cpu(config.arch) || config.arch == Arch::cuda ||
98 config.arch == Arch::amdgpu) {
99 irpass::bit_loop_vectorize(ir);
100 irpass::type_check(ir, config);
101 print("Bit Loop Vectorized");
102 irpass::analysis::verify(ir);
103 }
104
105 irpass::full_simplify(
106 ir, config,
107 {false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone,
108 kernel->program});
109 print("Simplified I");
110 irpass::analysis::verify(ir);
111
112 if (is_extension_supported(config.arch, Extension::mesh)) {
113 irpass::analysis::gather_meshfor_relation_types(ir);
114 }
115
116 if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
117 // Check whether the kernel obeys the autodiff limitation e.g., gloabl data
118 // access rule
119 // This check should be performed in the forward kernel i.e., autodiff_mode
120 // == AutodiffMode::kCheckAutodiffValid
121 irpass::demote_atomics(ir, config);
122 irpass::differentiation_validation_check(ir, config, kernel->get_name());
123 irpass::analysis::verify(ir);
124 }
125
126 if (autodiff_mode == AutodiffMode::kReverse ||
127 autodiff_mode == AutodiffMode::kForward) {
128 // Remove local atomics here so that we don't have to handle their gradients
129 irpass::demote_atomics(ir, config);
130
131 irpass::full_simplify(ir, config,
132 {false, /*autodiff_enabled*/ true, kernel->program});
133 irpass::auto_diff(ir, config, autodiff_mode, ad_use_stack);
134 // TODO: Be carefull with the full_simplify when do high-order autodiff
135 irpass::full_simplify(ir, config,
136 {false, /*autodiff_enabled*/ false, kernel->program});
137 print("Gradient");
138 irpass::analysis::verify(ir);
139 }
140
141 if (config.check_out_of_bound) {
142 irpass::check_out_of_bound(ir, config, {kernel->get_name()});
143 print("Bound checked");
144 irpass::analysis::verify(ir);
145 }
146
147 irpass::flag_access(ir);
148 print("Access flagged I");
149 irpass::analysis::verify(ir);
150
151 irpass::full_simplify(ir, config,
152 {false, /*autodiff_enabled*/ false, kernel->program});
153 print("Simplified II");
154 irpass::analysis::verify(ir);
155
156 irpass::offload(ir, config);
157 print("Offloaded");
158 irpass::analysis::verify(ir);
159
160 // TODO: This pass may be redundant as cfg_optimization() is already called
161 // in full_simplify().
162 if (config.opt_level > 0 && config.cfg_optimization) {
163 irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false,
164 !config.real_matrix_scalarize);
165 print("Optimized by CFG");
166 irpass::analysis::verify(ir);
167 }
168
169 irpass::flag_access(ir);
170 print("Access flagged II");
171
172 irpass::full_simplify(ir, config,
173 {false, /*autodiff_enabled*/ false, kernel->program});
174 print("Simplified III");
175 irpass::analysis::verify(ir);
176}
177
178void offload_to_executable(IRNode *ir,
179 const CompileConfig &config,
180 Kernel *kernel,
181 bool verbose,
182 bool determine_ad_stack_size,
183 bool lower_global_access,
184 bool make_thread_local,
185 bool make_block_local) {
186 TI_AUTO_PROF;
187
188 auto print = make_pass_printer(verbose, kernel->get_name(), ir);
189
190 // TODO: This is just a proof that we can demote struct-fors after offloading.
191 // Eventually we might want the order to be TLS/BLS -> demote struct-for.
192 // For now, putting this after TLS will disable TLS, because it can only
193 // handle range-fors at this point.
194
195 auto amgr = std::make_unique<AnalysisManager>();
196
197 print("Start offload_to_executable");
198 irpass::analysis::verify(ir);
199
200 if (config.detect_read_only) {
201 irpass::detect_read_only(ir);
202 print("Detect read-only accesses");
203 }
204
205 irpass::demote_atomics(ir, config);
206 print("Atomics demoted I");
207 irpass::analysis::verify(ir);
208 if (config.cache_loop_invariant_global_vars) {
209 irpass::cache_loop_invariant_global_vars(ir, config);
210 print("Cache loop-invariant global vars");
211 }
212
213 if (config.demote_dense_struct_fors) {
214 irpass::demote_dense_struct_fors(ir);
215 irpass::type_check(ir, config);
216 print("Dense struct-for demoted");
217 irpass::analysis::verify(ir);
218 }
219
220 if (is_extension_supported(config.arch, Extension::mesh) &&
221 config.demote_no_access_mesh_fors) {
222 irpass::demote_no_access_mesh_fors(ir);
223 irpass::type_check(ir, config);
224 print("No-access mesh-for demoted");
225 irpass::analysis::verify(ir);
226 }
227
228 if (make_thread_local) {
229 irpass::make_thread_local(ir, config);
230 print("Make thread local");
231 }
232
233 if (is_extension_supported(config.arch, Extension::mesh)) {
234 irpass::make_mesh_thread_local(ir, config, {kernel->get_name()});
235 print("Make mesh thread local");
236 if (config.make_mesh_block_local && config.arch == Arch::cuda) {
237 irpass::make_mesh_block_local(ir, config, {kernel->get_name()});
238 print("Make mesh block local");
239 irpass::full_simplify(
240 ir, config, {false, /*autodiff_enabled*/ false, kernel->program});
241 print("Simplified X");
242 }
243 }
244
245 if (make_block_local) {
246 irpass::make_block_local(ir, config, {kernel->get_name()});
247 print("Make block local");
248 }
249
250 if (is_extension_supported(config.arch, Extension::mesh)) {
251 irpass::demote_mesh_statements(ir, config, {kernel->get_name()});
252 print("Demote mesh statements");
253 }
254
255 irpass::demote_atomics(ir, config);
256 print("Atomics demoted II");
257 irpass::analysis::verify(ir);
258
259 if (is_extension_supported(config.arch, Extension::quant) &&
260 config.quant_opt_atomic_demotion) {
261 irpass::analysis::gather_uniquely_accessed_bit_structs(ir, amgr.get());
262 }
263
264 irpass::remove_range_assumption(ir);
265 print("Remove range assumption");
266
267 irpass::remove_loop_unique(ir);
268 print("Remove loop_unique");
269 irpass::analysis::verify(ir);
270
271 if (lower_global_access) {
272 irpass::full_simplify(ir, config,
273 {false, /*autodiff_enabled*/ false, kernel->program});
274 print("Simplified before lower access");
275 irpass::lower_access(ir, config, {kernel->no_activate, true});
276 print("Access lowered");
277 irpass::analysis::verify(ir);
278
279 irpass::die(ir);
280 print("DIE");
281 irpass::analysis::verify(ir);
282
283 irpass::flag_access(ir);
284 print("Access flagged III");
285 irpass::analysis::verify(ir);
286 }
287
288 irpass::demote_operations(ir, config);
289 print("Operations demoted");
290
291 irpass::full_simplify(
292 ir, config,
293 {lower_global_access, /*autodiff_enabled*/ false, kernel->program});
294 print("Simplified IV");
295
296 if (determine_ad_stack_size) {
297 irpass::determine_ad_stack_size(ir, config);
298 print("Autodiff stack size determined");
299 }
300
301 if (is_extension_supported(config.arch, Extension::quant)) {
302 irpass::optimize_bit_struct_stores(ir, config, amgr.get());
303 print("Bit struct stores optimized");
304 }
305
306 // Final field registration correctness & type checking
307 irpass::type_check(ir, config);
308 irpass::analysis::verify(ir);
309}
310
311void compile_to_executable(IRNode *ir,
312 const CompileConfig &config,
313 Kernel *kernel,
314 AutodiffMode autodiff_mode,
315 bool ad_use_stack,
316 bool verbose,
317 bool lower_global_access,
318 bool make_thread_local,
319 bool make_block_local,
320 bool start_from_ast) {
321 TI_AUTO_PROF;
322
323 compile_to_offloads(ir, config, kernel, verbose, autodiff_mode, ad_use_stack,
324 start_from_ast);
325
326 offload_to_executable(
327 ir, config, kernel, verbose,
328 /*determine_ad_stack_size=*/autodiff_mode == AutodiffMode::kReverse &&
329 ad_use_stack,
330 lower_global_access, make_thread_local, make_block_local);
331}
332
333void compile_function(IRNode *ir,
334 const CompileConfig &config,
335 Function *func,
336 AutodiffMode autodiff_mode,
337 bool verbose,
338 bool start_from_ast) {
339 TI_AUTO_PROF;
340
341 auto print = make_pass_printer(verbose, func->get_name(), ir);
342 print("Initial IR");
343
344 if (autodiff_mode == AutodiffMode::kReverse) {
345 irpass::reverse_segments(ir);
346 print("Segment reversed (for autodiff)");
347 }
348
349 if (start_from_ast) {
350 irpass::frontend_type_check(ir);
351 irpass::lower_ast(ir);
352 print("Lowered");
353 }
354
355 if (config.real_matrix_scalarize) {
356 irpass::scalarize(ir);
357
358 // Remove redundant MatrixInitStmt inserted during scalarization
359 irpass::die(ir);
360 print("Scalarized");
361 }
362
363 irpass::lower_access(ir, config, {{}, true});
364 print("Access lowered");
365 irpass::analysis::verify(ir);
366
367 irpass::die(ir);
368 print("DIE");
369 irpass::analysis::verify(ir);
370
371 irpass::flag_access(ir);
372 print("Access flagged III");
373 irpass::analysis::verify(ir);
374
375 irpass::type_check(ir, config);
376 print("Typechecked");
377
378 irpass::demote_operations(ir, config);
379 print("Operations demoted");
380
381 irpass::full_simplify(
382 ir, config, {false, autodiff_mode != AutodiffMode::kNone, func->program});
383 print("Simplified");
384 irpass::analysis::verify(ir);
385}
386
387} // namespace irpass
388
389} // namespace taichi::lang
390