1 | #include <cmath> |
2 | #include <deque> |
3 | #include <set> |
4 | #include <thread> |
5 | |
6 | #include "taichi/ir/ir.h" |
7 | #include "taichi/ir/snode.h" |
8 | #include "taichi/ir/statements.h" |
9 | #include "taichi/ir/transforms.h" |
10 | #include "taichi/ir/visitors.h" |
11 | #include "taichi/transforms/constant_fold.h" |
12 | #include "taichi/program/program.h" |
13 | |
14 | namespace taichi::lang { |
15 | |
16 | class ConstantFold : public BasicStmtVisitor { |
17 | public: |
18 | using BasicStmtVisitor::visit; |
19 | DelayedIRModifier modifier; |
20 | Program *program; |
21 | CompileConfig compile_config; |
22 | |
23 | explicit ConstantFold(Program *program, const CompileConfig &compile_config) |
24 | : program(program), compile_config(compile_config) { |
25 | this->compile_config.advanced_optimization = false; |
26 | this->compile_config.constant_folding = false; |
27 | this->compile_config.external_optimization_level = 0; |
28 | } |
29 | |
30 | Kernel *get_jit_evaluator_kernel(JITEvaluatorId const &id) { |
31 | auto &cache = program->jit_evaluator_cache; |
32 | // Discussion: |
33 | // https://github.com/taichi-dev/taichi/pull/954#discussion_r423442606 |
34 | std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut); |
35 | auto it = cache.find(id); |
36 | if (it != cache.end()) // cached? |
37 | return it->second.get(); |
38 | |
39 | auto kernel_name = fmt::format("jit_evaluator_{}" , cache.size()); |
40 | auto func = [&id](Kernel *kernel) { |
41 | auto lhstmt = |
42 | Stmt::make<ArgLoadStmt>(/*arg_id=*/0, id.lhs, /*is_ptr=*/false); |
43 | auto rhstmt = |
44 | Stmt::make<ArgLoadStmt>(/*arg_id=*/1, id.rhs, /*is_ptr=*/false); |
45 | pStmt oper; |
46 | if (id.is_binary) { |
47 | oper = Stmt::make<BinaryOpStmt>(id.binary_op(), lhstmt.get(), |
48 | rhstmt.get()); |
49 | oper->set_tb(id.tb); |
50 | } else { |
51 | oper = Stmt::make<UnaryOpStmt>(id.unary_op(), lhstmt.get()); |
52 | if (unary_op_is_cast(id.unary_op())) { |
53 | oper->cast<UnaryOpStmt>()->cast_type = id.rhs; |
54 | } |
55 | } |
56 | auto &ast_builder = kernel->context->builder(); |
57 | auto ret = Stmt::make<ReturnStmt>(oper.get()); |
58 | ast_builder.insert(std::move(lhstmt)); |
59 | if (id.is_binary) { |
60 | ast_builder.insert(std::move(rhstmt)); |
61 | } |
62 | ast_builder.insert(std::move(oper)); |
63 | ast_builder.insert(std::move(ret)); |
64 | }; |
65 | |
66 | auto ker = std::make_unique<Kernel>(*program, func, kernel_name); |
67 | ker->insert_ret(id.ret); |
68 | ker->insert_scalar_param(id.lhs); |
69 | if (id.is_binary) |
70 | ker->insert_scalar_param(id.rhs); |
71 | ker->is_evaluator = true; |
72 | ker->finalize_rets(); |
73 | |
74 | auto *ker_ptr = ker.get(); |
75 | TI_TRACE("Saving JIT evaluator cache entry id={}" , |
76 | std::hash<JITEvaluatorId>{}(id)); |
77 | cache[id] = std::move(ker); |
78 | |
79 | return ker_ptr; |
80 | } |
81 | |
82 | static bool is_good_type(DataType dt) { |
83 | // ConstStmt of `bad` types like `i8` is not supported by LLVM. |
84 | // Discussion: |
85 | // https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727 |
86 | if (dt->is_primitive(PrimitiveTypeID::i32) || |
87 | dt->is_primitive(PrimitiveTypeID::i64) || |
88 | dt->is_primitive(PrimitiveTypeID::u32) || |
89 | dt->is_primitive(PrimitiveTypeID::u64) || |
90 | dt->is_primitive(PrimitiveTypeID::f32) || |
91 | dt->is_primitive(PrimitiveTypeID::f64)) |
92 | return true; |
93 | else |
94 | return false; |
95 | } |
96 | |
97 | bool jit_evaluate_binary_op(TypedConstant &ret, |
98 | BinaryOpStmt *stmt, |
99 | const TypedConstant &lhs, |
100 | const TypedConstant &rhs) { |
101 | if (!is_good_type(ret.dt)) |
102 | return false; |
103 | JITEvaluatorId id{std::this_thread::get_id(), |
104 | (int)stmt->op_type, |
105 | ret.dt, |
106 | lhs.dt, |
107 | rhs.dt, |
108 | compile_config.debug ? stmt->tb : "" , |
109 | true}; |
110 | auto *ker = get_jit_evaluator_kernel(id); |
111 | auto launch_ctx = ker->make_launch_context(); |
112 | launch_ctx.set_arg_raw(0, lhs.val_u64); |
113 | launch_ctx.set_arg_raw(1, rhs.val_u64); |
114 | { |
115 | std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut); |
116 | (*ker)(compile_config, launch_ctx); |
117 | ret.val_i64 = program->fetch_result<int64_t>(0); |
118 | } |
119 | return true; |
120 | } |
121 | |
122 | bool jit_evaluate_unary_op(TypedConstant &ret, |
123 | UnaryOpStmt *stmt, |
124 | const TypedConstant &operand) { |
125 | if (!is_good_type(ret.dt)) |
126 | return false; |
127 | JITEvaluatorId id{std::this_thread::get_id(), |
128 | (int)stmt->op_type, |
129 | ret.dt, |
130 | operand.dt, |
131 | stmt->cast_type, |
132 | "" , |
133 | false}; |
134 | auto *ker = get_jit_evaluator_kernel(id); |
135 | auto launch_ctx = ker->make_launch_context(); |
136 | launch_ctx.set_arg_raw(0, operand.val_u64); |
137 | { |
138 | std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut); |
139 | (*ker)(compile_config, launch_ctx); |
140 | ret.val_i64 = program->fetch_result<int64_t>(0); |
141 | } |
142 | return true; |
143 | } |
144 | |
145 | void visit(BinaryOpStmt *stmt) override { |
146 | auto lhs = stmt->lhs->cast<ConstStmt>(); |
147 | auto rhs = stmt->rhs->cast<ConstStmt>(); |
148 | if (!lhs || !rhs) |
149 | return; |
150 | auto dst_type = stmt->ret_type; |
151 | TypedConstant new_constant(dst_type); |
152 | |
153 | if (stmt->op_type == BinaryOpType::pow) { |
154 | if (is_integral(rhs->ret_type)) { |
155 | auto rhs_val = rhs->val.val_int(); |
156 | if (rhs_val < 0 && is_integral(stmt->ret_type)) { |
157 | TI_ERROR("Negative exponent in pow(int, int) is not allowed." ); |
158 | } |
159 | } |
160 | } |
161 | |
162 | if (jit_evaluate_binary_op(new_constant, stmt, lhs->val, rhs->val)) { |
163 | auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant)); |
164 | stmt->replace_usages_with(evaluated.get()); |
165 | modifier.insert_before(stmt, std::move(evaluated)); |
166 | modifier.erase(stmt); |
167 | } |
168 | } |
169 | |
170 | void visit(UnaryOpStmt *stmt) override { |
171 | if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) { |
172 | stmt->replace_usages_with(stmt->operand); |
173 | modifier.erase(stmt); |
174 | return; |
175 | } |
176 | auto operand = stmt->operand->cast<ConstStmt>(); |
177 | if (!operand) |
178 | return; |
179 | if (stmt->is_cast()) { |
180 | bool cast_available = true; |
181 | TypedConstant new_constant(stmt->ret_type); |
182 | auto operand = stmt->operand->cast<ConstStmt>(); |
183 | if (stmt->op_type == UnaryOpType::cast_bits) { |
184 | new_constant.value_bits = operand->val.value_bits; |
185 | } else { |
186 | if (stmt->cast_type == PrimitiveType::f32) { |
187 | new_constant.val_f32 = float32(operand->val.val_cast_to_float64()); |
188 | } else if (stmt->cast_type == PrimitiveType::f64) { |
189 | new_constant.val_f64 = operand->val.val_cast_to_float64(); |
190 | } else { |
191 | cast_available = false; |
192 | } |
193 | } |
194 | if (cast_available) { |
195 | auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant)); |
196 | stmt->replace_usages_with(evaluated.get()); |
197 | modifier.insert_before(stmt, std::move(evaluated)); |
198 | modifier.erase(stmt); |
199 | return; |
200 | } |
201 | } |
202 | auto dst_type = stmt->ret_type; |
203 | TypedConstant new_constant(dst_type); |
204 | if (jit_evaluate_unary_op(new_constant, stmt, operand->val)) { |
205 | auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant)); |
206 | stmt->replace_usages_with(evaluated.get()); |
207 | modifier.insert_before(stmt, std::move(evaluated)); |
208 | modifier.erase(stmt); |
209 | } |
210 | } |
211 | |
212 | static bool run(IRNode *node, |
213 | Program *program, |
214 | const CompileConfig &compile_config) { |
215 | ConstantFold folder(program, compile_config); |
216 | bool modified = false; |
217 | |
218 | while (true) { |
219 | node->accept(&folder); |
220 | if (folder.modifier.modify_ir()) { |
221 | modified = true; |
222 | } else { |
223 | break; |
224 | } |
225 | } |
226 | |
227 | return modified; |
228 | } |
229 | }; |
230 | |
231 | const PassID ConstantFoldPass::id = "ConstantFoldPass" ; |
232 | |
233 | namespace irpass { |
234 | |
235 | bool constant_fold(IRNode *root, |
236 | const CompileConfig &compile_config, |
237 | const ConstantFoldPass::Args &args) { |
238 | TI_AUTO_PROF; |
239 | if (!compile_config.advanced_optimization) |
240 | return false; |
241 | return ConstantFold::run(root, args.program, compile_config); |
242 | } |
243 | |
244 | } // namespace irpass |
245 | |
246 | } // namespace taichi::lang |
247 | |