1#pragma once
2
#include <atomic>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "taichi/common/trait.h"
#include "taichi/ir/control_flow_graph.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/pass.h"
#include "taichi/transforms/check_out_of_bound.h"
#include "taichi/transforms/constant_fold.h"
#include "taichi/transforms/demote_mesh_statements.h"
#include "taichi/transforms/inlining.h"
#include "taichi/transforms/lower_access.h"
#include "taichi/transforms/make_block_local.h"
#include "taichi/transforms/make_mesh_block_local.h"
#include "taichi/transforms/simplify.h"
20
21namespace taichi::lang {
22
23class ScratchPads;
24
25class Function;
26
27// IR passes
28namespace irpass {
29
30void re_id(IRNode *root);
31void flag_access(IRNode *root);
32void eliminate_immutable_local_vars(IRNode *root);
33void scalarize(IRNode *root);
34void lower_matrix_ptr(IRNode *root);
35bool die(IRNode *root);
36bool simplify(IRNode *root, const CompileConfig &config);
37bool cfg_optimization(
38 IRNode *root,
39 bool after_lower_access,
40 bool autodiff_enabled,
41 bool real_matrix_enabled,
42 const std::optional<ControlFlowGraph::LiveVarAnalysisConfig>
43 &lva_config_opt = std::nullopt);
44bool alg_simp(IRNode *root, const CompileConfig &config);
45bool demote_operations(IRNode *root, const CompileConfig &config);
46bool binary_op_simplify(IRNode *root, const CompileConfig &config);
47bool whole_kernel_cse(IRNode *root);
48bool extract_constant(IRNode *root, const CompileConfig &config);
49bool unreachable_code_elimination(IRNode *root);
50bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config);
51bool cache_loop_invariant_global_vars(IRNode *root,
52 const CompileConfig &config);
53void full_simplify(IRNode *root,
54 const CompileConfig &config,
55 const FullSimplifyPass::Args &args);
56void print(IRNode *root, std::string *output = nullptr);
57void frontend_type_check(IRNode *root);
58void lower_ast(IRNode *root);
59void type_check(IRNode *root, const CompileConfig &config);
60bool inlining(IRNode *root,
61 const CompileConfig &config,
62 const InliningPass::Args &args);
63void bit_loop_vectorize(IRNode *root);
64void slp_vectorize(IRNode *root);
65void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt);
66bool check_out_of_bound(IRNode *root,
67 const CompileConfig &config,
68 const CheckOutOfBoundPass::Args &args);
69void make_thread_local(IRNode *root, const CompileConfig &config);
70std::unique_ptr<ScratchPads> initialize_scratch_pad(OffloadedStmt *root);
71void make_block_local(IRNode *root,
72 const CompileConfig &config,
73 const MakeBlockLocalPass::Args &args);
74void make_mesh_thread_local(IRNode *root,
75 const CompileConfig &config,
76 const MakeBlockLocalPass::Args &args);
77void make_mesh_block_local(IRNode *root,
78 const CompileConfig &config,
79 const MakeMeshBlockLocal::Args &args);
80void demote_mesh_statements(IRNode *root,
81 const CompileConfig &config,
82 const DemoteMeshStatements::Args &args);
83bool remove_loop_unique(IRNode *root);
84bool remove_range_assumption(IRNode *root);
85bool lower_access(IRNode *root,
86 const CompileConfig &config,
87 const LowerAccessPass::Args &args);
88void auto_diff(IRNode *root,
89 const CompileConfig &config,
90 AutodiffMode autodiffMode,
91 bool use_stack = false);
92/**
93 * Check whether the kernel obeys the autodiff limitation e.g., gloabl data
94 * access rule
95 */
96void differentiation_validation_check(IRNode *root,
97 const CompileConfig &config,
98 const std::string &kernel_name);
99/**
100 * Determine all adaptive AD-stacks' size. This pass is idempotent, i.e.,
101 * there are no side effects if called more than once or called when not needed.
102 * @return Whether the IR is modified, i.e., whether there exists adaptive
103 * AD-stacks before this pass.
104 */
105bool determine_ad_stack_size(IRNode *root, const CompileConfig &config);
106bool constant_fold(IRNode *root,
107 const CompileConfig &config,
108 const ConstantFoldPass::Args &args);
109void offload(IRNode *root, const CompileConfig &config);
110bool transform_statements(
111 IRNode *root,
112 std::function<bool(Stmt *)> filter,
113 std::function<void(Stmt *, DelayedIRModifier *)> transformer);
114/**
115 * @param root The IR root to be traversed.
116 * @param filter A function which tells if a statement need to be replaced.
117 * @param generator If a statement |s| need to be replaced, generate a new
118 * statement |s1| with the argument |s|, insert |s1| to where |s| is defined,
119 * remove |s|'s definition, and replace all usages of |s| with |s1|.
120 * @return Whether the IR is modified.
121 */
122bool replace_and_insert_statements(
123 IRNode *root,
124 std::function<bool(Stmt *)> filter,
125 std::function<std::unique_ptr<Stmt>(Stmt *)> generator);
126/**
127 * @param finder If a statement |s| need to be replaced, find the existing
128 * statement |s1| with the argument |s|, remove |s|'s definition, and replace
129 * all usages of |s| with |s1|.
130 */
131bool replace_statements(IRNode *root,
132 std::function<bool(Stmt *)> filter,
133 std::function<Stmt *(Stmt *)> finder);
134void demote_dense_struct_fors(IRNode *root);
135void demote_no_access_mesh_fors(IRNode *root);
136bool demote_atomics(IRNode *root, const CompileConfig &config);
137void reverse_segments(IRNode *root); // for autograd
138void detect_read_only(IRNode *root);
139void optimize_bit_struct_stores(IRNode *root,
140 const CompileConfig &config,
141 AnalysisManager *amgr);
142
143ENUM_FLAGS(ExternalPtrAccess){NONE = 0, READ = 1, WRITE = 2};
144
145/**
146 * Checks the access to external pointers in an offload.
147 *
148 * @param val1
149 * The offloaded statement to check
150 *
151 * @return
152 * The analyzed result.
153 */
154std::unordered_map<int, ExternalPtrAccess> detect_external_ptr_access_in_task(
155 OffloadedStmt *offload);
156
157// compile_to_offloads does the basic compilation to create all the offloaded
158// tasks of a Taichi kernel. It's worth pointing out that this doesn't demote
159// dense struct fors. This is a necessary workaround to prevent the async
160// engine from fusing incompatible offloaded tasks. TODO(Lin): check this
161// comment
162void compile_to_offloads(IRNode *ir,
163 const CompileConfig &config,
164 Kernel *kernel,
165 bool verbose,
166 AutodiffMode autodiff_mode,
167 bool ad_use_stack,
168 bool start_from_ast);
169
170void offload_to_executable(IRNode *ir,
171 const CompileConfig &config,
172 Kernel *kernel,
173 bool verbose,
174 bool determine_ad_stack_size,
175 bool lower_global_access,
176 bool make_thread_local,
177 bool make_block_local);
178// compile_to_executable fully covers compile_to_offloads, and also does
179// additional optimizations so that |ir| can be directly fed into codegen.
180void compile_to_executable(IRNode *ir,
181 const CompileConfig &config,
182 Kernel *kernel,
183 AutodiffMode autodiff_mode,
184 bool ad_use_stack,
185 bool verbose,
186 bool lower_global_access = true,
187 bool make_thread_local = false,
188 bool make_block_local = false,
189 bool start_from_ast = true);
190// Compile a function with some basic optimizations, so that the number of
191// statements is reduced before inlining.
192void compile_function(IRNode *ir,
193 const CompileConfig &config,
194 Function *func,
195 AutodiffMode autodiff_mode,
196 bool verbose,
197 bool start_from_ast);
198
199void ast_to_ir(const CompileConfig &config,
200 Kernel &kernel,
201 bool to_executable = true);
202
203void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config);
204} // namespace irpass
205
206} // namespace taichi::lang
207