1 | #include "taichi/ir/analysis.h" |
2 | #include "taichi/ir/ir.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/program/program.h" |
7 | #include "taichi/util/bit.h" |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | // Algebraic Simplification and Strength Reduction |
12 | class AlgSimp : public BasicStmtVisitor { |
13 | private: |
14 | void cast_to_result_type(Stmt *&a, Stmt *stmt) { |
15 | if (stmt->ret_type != a->ret_type) { |
16 | auto cast = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, a); |
17 | cast->cast_type = stmt->ret_type; |
18 | cast->ret_type = stmt->ret_type; |
19 | a = cast.get(); |
20 | modifier.insert_before(stmt, std::move(cast)); |
21 | } |
22 | } |
23 | |
24 | void replace_with_zero(Stmt *stmt) { |
25 | auto zero = Stmt::make<ConstStmt>(TypedConstant(stmt->ret_type)); |
26 | stmt->replace_usages_with(zero.get()); |
27 | modifier.insert_before(stmt, std::move(zero)); |
28 | modifier.erase(stmt); |
29 | } |
30 | |
31 | void replace_with_one(Stmt *stmt) { |
32 | auto one = Stmt::make<ConstStmt>(TypedConstant(1)); |
33 | auto one_raw = one.get(); |
34 | modifier.insert_before(stmt, std::move(one)); |
35 | cast_to_result_type(one_raw, stmt); |
36 | stmt->replace_usages_with(one_raw); |
37 | modifier.erase(stmt); |
38 | } |
39 | |
40 | public: |
41 | static constexpr int max_weaken_exponent = 32; |
42 | using BasicStmtVisitor::visit; |
43 | bool fast_math; |
44 | DelayedIRModifier modifier; |
45 | |
46 | explicit AlgSimp(bool fast_math_) : fast_math(fast_math_) { |
47 | } |
48 | |
49 | [[nodiscard]] bool is_redundant_cast(const DataType &first_cast, |
50 | const DataType &second_cast) const { |
51 | // Tests if second_cast(first_cast(a)) is guaranteed to be equivalent to |
52 | // second_cast(a). |
53 | if (!first_cast->is<PrimitiveType>() || !second_cast->is<PrimitiveType>()) { |
54 | // TODO(type): handle this case |
55 | return false; |
56 | } |
57 | if (is_real(second_cast)) { |
58 | // float(...(a)) |
59 | return is_real(first_cast) && |
60 | data_type_bits(second_cast) <= data_type_bits(first_cast); |
61 | } |
62 | if (is_integral(first_cast)) { |
63 | // int(int(a)) |
64 | return data_type_bits(second_cast) <= data_type_bits(first_cast); |
65 | } |
66 | // int(float(a)) |
67 | if (data_type_bits(second_cast) <= data_type_bits(first_cast) * 2) { |
68 | // f64 can hold any i32 values. |
69 | return true; |
70 | } else { |
71 | // Assume a floating point type can hold any integer values when |
72 | // fast_math=True. |
73 | return fast_math; |
74 | } |
75 | } |
76 | |
77 | void visit(UnaryOpStmt *stmt) override { |
78 | if (stmt->is_cast()) { |
79 | if (stmt->cast_type == stmt->operand->ret_type) { |
80 | stmt->replace_usages_with(stmt->operand); |
81 | modifier.erase(stmt); |
82 | } else if (stmt->operand->is<UnaryOpStmt>() && |
83 | stmt->operand->as<UnaryOpStmt>()->is_cast()) { |
84 | auto prev_cast = stmt->operand->as<UnaryOpStmt>(); |
85 | if (stmt->op_type == UnaryOpType::cast_bits && |
86 | prev_cast->op_type == UnaryOpType::cast_bits) { |
87 | stmt->operand = prev_cast->operand; |
88 | modifier.mark_as_modified(); |
89 | } else if (stmt->op_type == UnaryOpType::cast_value && |
90 | prev_cast->op_type == UnaryOpType::cast_value && |
91 | is_redundant_cast(prev_cast->cast_type, stmt->cast_type)) { |
92 | stmt->operand = prev_cast->operand; |
93 | modifier.mark_as_modified(); |
94 | } |
95 | } |
96 | } |
97 | } |
98 | |
99 | bool optimize_multiplication(BinaryOpStmt *stmt) { |
100 | // return true iff the IR is modified |
101 | auto lhs = stmt->lhs->cast<ConstStmt>(); |
102 | auto rhs = stmt->rhs->cast<ConstStmt>(); |
103 | TI_ASSERT(stmt->op_type == BinaryOpType::mul); |
104 | if (alg_is_one(lhs) || alg_is_one(rhs)) { |
105 | // 1 * a -> a, a * 1 -> a |
106 | stmt->replace_usages_with(alg_is_one(lhs) ? stmt->rhs : stmt->lhs); |
107 | modifier.erase(stmt); |
108 | return true; |
109 | } |
110 | if ((fast_math || is_integral(stmt->ret_type)) && |
111 | (alg_is_zero(lhs) || alg_is_zero(rhs))) { |
112 | // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 |
113 | replace_with_zero(stmt); |
114 | return true; |
115 | } |
116 | if (is_integral(stmt->ret_type) && (alg_is_pot(lhs) || alg_is_pot(rhs))) { |
117 | // a * pot -> a << log2(pot) |
118 | if (alg_is_pot(lhs)) { |
119 | std::swap(stmt->lhs, stmt->rhs); |
120 | std::swap(lhs, rhs); |
121 | } |
122 | int log2rhs = bit::log2int((uint64)rhs->val.val_as_int64()); |
123 | auto new_rhs = |
124 | Stmt::make<ConstStmt>(TypedConstant(stmt->lhs->ret_type, log2rhs)); |
125 | auto result = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_shl, stmt->lhs, |
126 | new_rhs.get()); |
127 | result->ret_type = stmt->ret_type; |
128 | result->set_tb(stmt->tb); |
129 | stmt->replace_usages_with(result.get()); |
130 | modifier.insert_before(stmt, std::move(new_rhs)); |
131 | modifier.insert_before(stmt, std::move(result)); |
132 | modifier.erase(stmt); |
133 | return true; |
134 | } |
135 | if (alg_is_two(lhs) || alg_is_two(rhs)) { |
136 | // 2 * a -> a + a, a * 2 -> a + a |
137 | auto a = stmt->lhs; |
138 | if (alg_is_two(lhs)) |
139 | a = stmt->rhs; |
140 | cast_to_result_type(a, stmt); |
141 | auto sum = Stmt::make<BinaryOpStmt>(BinaryOpType::add, a, a); |
142 | sum->ret_type = a->ret_type; |
143 | sum->set_tb(stmt->tb); |
144 | stmt->replace_usages_with(sum.get()); |
145 | modifier.insert_before(stmt, std::move(sum)); |
146 | modifier.erase(stmt); |
147 | return true; |
148 | } |
149 | return false; |
150 | } |
151 | |
152 | bool optimize_division(BinaryOpStmt *stmt) { |
153 | // return true iff the IR is modified |
154 | auto rhs = stmt->rhs->cast<ConstStmt>(); |
155 | TI_ASSERT(stmt->op_type == BinaryOpType::div || |
156 | stmt->op_type == BinaryOpType::floordiv); |
157 | if (alg_is_one(rhs) && !(is_real(stmt->lhs->ret_type) && |
158 | stmt->op_type == BinaryOpType::floordiv)) { |
159 | // a / 1 -> a |
160 | stmt->replace_usages_with(stmt->lhs); |
161 | modifier.erase(stmt); |
162 | return true; |
163 | } |
164 | if ((fast_math || is_integral(stmt->ret_type)) && |
165 | irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { |
166 | // fast_math or integral operands: a / a -> 1 |
167 | replace_with_one(stmt); |
168 | return true; |
169 | } |
170 | if (fast_math && rhs && is_real(rhs->ret_type) && |
171 | stmt->op_type != BinaryOpType::floordiv) { |
172 | if (alg_is_zero(rhs)) { |
173 | TI_WARN("Potential division by 0\n{}" , stmt->tb); |
174 | } else { |
175 | // a / const -> a * (1 / const) |
176 | auto reciprocal = |
177 | Stmt::make_typed<ConstStmt>(TypedConstant(rhs->ret_type)); |
178 | if (rhs->ret_type->is_primitive(PrimitiveTypeID::f64)) { |
179 | reciprocal->val.val_float64() = (float64)1.0 / rhs->val.val_float64(); |
180 | } else if (rhs->ret_type->is_primitive(PrimitiveTypeID::f32)) { |
181 | reciprocal->val.val_float32() = (float32)1.0 / rhs->val.val_float32(); |
182 | } else { |
183 | TI_NOT_IMPLEMENTED |
184 | } |
185 | auto product = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->lhs, |
186 | reciprocal.get()); |
187 | product->ret_type = stmt->ret_type; |
188 | stmt->replace_usages_with(product.get()); |
189 | modifier.insert_before(stmt, std::move(reciprocal)); |
190 | modifier.insert_before(stmt, std::move(product)); |
191 | modifier.erase(stmt); |
192 | return true; |
193 | } |
194 | } |
195 | if (is_integral(stmt->lhs->ret_type) && is_unsigned(stmt->lhs->ret_type) && |
196 | alg_is_pot(rhs)) { |
197 | // (unsigned)a / pot -> a >> log2(pot) |
198 | int log2rhs = bit::log2int((uint64)rhs->val.val_as_int64()); |
199 | auto new_rhs = |
200 | Stmt::make<ConstStmt>(TypedConstant(stmt->lhs->ret_type, log2rhs)); |
201 | auto result = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar, stmt->lhs, |
202 | new_rhs.get()); |
203 | result->ret_type = stmt->ret_type; |
204 | stmt->replace_usages_with(result.get()); |
205 | modifier.insert_before(stmt, std::move(new_rhs)); |
206 | modifier.insert_before(stmt, std::move(result)); |
207 | modifier.erase(stmt); |
208 | return true; |
209 | } |
210 | return false; |
211 | } |
212 | |
213 | void visit(BinaryOpStmt *stmt) override { |
214 | if (stmt->lhs->ret_type->is<TensorType>() || |
215 | stmt->rhs->ret_type->is<TensorType>()) { |
216 | // TODO: support tensor type |
217 | return; |
218 | } |
219 | auto lhs = stmt->lhs->cast<ConstStmt>(); |
220 | auto rhs = stmt->rhs->cast<ConstStmt>(); |
221 | if (stmt->op_type == BinaryOpType::mul) { |
222 | optimize_multiplication(stmt); |
223 | } else if (stmt->op_type == BinaryOpType::div || |
224 | stmt->op_type == BinaryOpType::floordiv) { |
225 | optimize_division(stmt); |
226 | } else if (stmt->op_type == BinaryOpType::add || |
227 | stmt->op_type == BinaryOpType::sub || |
228 | stmt->op_type == BinaryOpType::bit_or || |
229 | stmt->op_type == BinaryOpType::bit_xor) { |
230 | if (alg_is_zero(rhs)) { |
231 | // a +-|^ 0 -> a |
232 | stmt->replace_usages_with(stmt->lhs); |
233 | modifier.erase(stmt); |
234 | } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { |
235 | // 0 +|^ a -> a |
236 | stmt->replace_usages_with(stmt->rhs); |
237 | modifier.erase(stmt); |
238 | } else if (stmt->op_type == BinaryOpType::bit_or && |
239 | irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { |
240 | // a | a -> a |
241 | stmt->replace_usages_with(stmt->lhs); |
242 | modifier.erase(stmt); |
243 | } else if ((stmt->op_type == BinaryOpType::sub || |
244 | stmt->op_type == BinaryOpType::bit_xor) && |
245 | (fast_math || is_integral(stmt->ret_type)) && |
246 | irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { |
247 | // fast_math or integral operands: a -^ a -> 0 |
248 | replace_with_zero(stmt); |
249 | } |
250 | } else if (rhs && stmt->op_type == BinaryOpType::pow) { |
251 | float64 exponent = rhs->val.val_cast_to_float64(); |
252 | if (exponent == 1) { |
253 | // a ** 1 -> a |
254 | stmt->replace_usages_with(stmt->lhs); |
255 | modifier.erase(stmt); |
256 | } else if (exponent == 0) { |
257 | // a ** 0 -> 1 |
258 | replace_with_one(stmt); |
259 | } else if (exponent == 0.5) { |
260 | // a ** 0.5 -> sqrt(a) |
261 | auto a = stmt->lhs; |
262 | cast_to_result_type(a, stmt); |
263 | auto result = Stmt::make<UnaryOpStmt>(UnaryOpType::sqrt, a); |
264 | result->ret_type = a->ret_type; |
265 | stmt->replace_usages_with(result.get()); |
266 | modifier.insert_before(stmt, std::move(result)); |
267 | modifier.erase(stmt); |
268 | } else if (exponent == std::round(exponent) && exponent > 0 && |
269 | exponent <= max_weaken_exponent) { |
270 | // a ** n -> Exponentiation by squaring |
271 | auto a = stmt->lhs; |
272 | cast_to_result_type(a, stmt); |
273 | const int exp = exponent; |
274 | Stmt *result = nullptr; |
275 | auto a_power_of_2 = a; |
276 | int current_exponent = 1; |
277 | while (true) { |
278 | if (exp & current_exponent) { |
279 | if (!result) |
280 | result = a_power_of_2; |
281 | else { |
282 | auto new_result = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, |
283 | result, a_power_of_2); |
284 | new_result->ret_type = a->ret_type; |
285 | result = new_result.get(); |
286 | modifier.insert_before(stmt, std::move(new_result)); |
287 | } |
288 | } |
289 | current_exponent <<= 1; |
290 | if (current_exponent > exp) |
291 | break; |
292 | auto new_a_power = Stmt::make<BinaryOpStmt>( |
293 | BinaryOpType::mul, a_power_of_2, a_power_of_2); |
294 | new_a_power->ret_type = a->ret_type; |
295 | a_power_of_2 = new_a_power.get(); |
296 | modifier.insert_before(stmt, std::move(new_a_power)); |
297 | } |
298 | stmt->replace_usages_with(result); |
299 | modifier.erase(stmt); |
300 | } else if (exponent == std::round(exponent) && exponent < 0 && |
301 | exponent >= -max_weaken_exponent) { |
302 | // a ** -n -> 1 / a ** n |
303 | if (is_integral(stmt->lhs->ret_type)) { |
304 | TI_ERROR("Negative exponent in pow(int, int) is not allowed." ); |
305 | } |
306 | auto one = Stmt::make<ConstStmt>(TypedConstant(1)); |
307 | auto one_raw = one.get(); |
308 | modifier.insert_before(stmt, std::move(one)); |
309 | cast_to_result_type(one_raw, stmt); |
310 | auto new_exponent = Stmt::make<UnaryOpStmt>(UnaryOpType::neg, rhs); |
311 | auto a_to_n = Stmt::make<BinaryOpStmt>(BinaryOpType::pow, stmt->lhs, |
312 | new_exponent.get()); |
313 | a_to_n->ret_type = stmt->ret_type; |
314 | auto result = |
315 | Stmt::make<BinaryOpStmt>(BinaryOpType::div, one_raw, a_to_n.get()); |
316 | stmt->replace_usages_with(result.get()); |
317 | modifier.insert_before(stmt, std::move(new_exponent)); |
318 | modifier.insert_before(stmt, std::move(a_to_n)); |
319 | modifier.insert_before(stmt, std::move(result)); |
320 | modifier.erase(stmt); |
321 | } |
322 | } else if (stmt->op_type == BinaryOpType::bit_and) { |
323 | if (alg_is_minus_one(rhs)) { |
324 | // a & -1 -> a |
325 | stmt->replace_usages_with(stmt->lhs); |
326 | modifier.erase(stmt); |
327 | } else if (alg_is_minus_one(lhs)) { |
328 | // -1 & a -> a |
329 | stmt->replace_usages_with(stmt->rhs); |
330 | modifier.erase(stmt); |
331 | } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { |
332 | // 0 & a -> 0, a & 0 -> 0 |
333 | replace_with_zero(stmt); |
334 | } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { |
335 | // a & a -> a |
336 | stmt->replace_usages_with(stmt->lhs); |
337 | modifier.erase(stmt); |
338 | } |
339 | } else if (stmt->op_type == BinaryOpType::bit_sar || |
340 | stmt->op_type == BinaryOpType::bit_shl || |
341 | stmt->op_type == BinaryOpType::bit_shr) { |
342 | if (alg_is_zero(rhs) || alg_is_zero(lhs)) { |
343 | // a >> 0 -> a |
344 | // a << 0 -> a |
345 | // 0 << a -> 0 |
346 | // 0 >> a -> 0 |
347 | TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); |
348 | stmt->replace_usages_with(stmt->lhs); |
349 | modifier.erase(stmt); |
350 | } |
351 | } else if (is_comparison(stmt->op_type)) { |
352 | if ((fast_math || is_integral(stmt->lhs->ret_type)) && |
353 | irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { |
354 | // fast_math or integral operands: a == a -> 1, a != a -> 0 |
355 | if (stmt->op_type == BinaryOpType::cmp_eq || |
356 | stmt->op_type == BinaryOpType::cmp_ge || |
357 | stmt->op_type == BinaryOpType::cmp_le) { |
358 | replace_with_one(stmt); |
359 | } else if (stmt->op_type == BinaryOpType::cmp_ne || |
360 | stmt->op_type == BinaryOpType::cmp_gt || |
361 | stmt->op_type == BinaryOpType::cmp_lt) { |
362 | replace_with_zero(stmt); |
363 | } else { |
364 | TI_NOT_IMPLEMENTED |
365 | } |
366 | } |
367 | } |
368 | } |
369 | |
370 | void visit(AssertStmt *stmt) override { |
371 | auto cond = stmt->cond->cast<ConstStmt>(); |
372 | if (!cond) |
373 | return; |
374 | if (!alg_is_zero(cond)) { |
375 | // this statement has no effect |
376 | modifier.erase(stmt); |
377 | } |
378 | } |
379 | |
380 | void visit(WhileControlStmt *stmt) override { |
381 | auto cond = stmt->cond->cast<ConstStmt>(); |
382 | if (!cond) |
383 | return; |
384 | if (!alg_is_zero(cond)) { |
385 | // this statement has no effect |
386 | modifier.erase(stmt); |
387 | } |
388 | } |
389 | |
390 | static bool alg_is_zero(ConstStmt *stmt) { |
391 | if (!stmt) |
392 | return false; |
393 | return stmt->val.equal_value(0); |
394 | } |
395 | |
396 | static bool alg_is_one(ConstStmt *stmt) { |
397 | if (!stmt) |
398 | return false; |
399 | return stmt->val.equal_value(1); |
400 | } |
401 | |
402 | static bool alg_is_two(ConstStmt *stmt) { |
403 | if (!stmt) |
404 | return false; |
405 | return stmt->val.equal_value(2); |
406 | } |
407 | |
408 | static bool alg_is_minus_one(ConstStmt *stmt) { |
409 | if (!stmt) |
410 | return false; |
411 | return stmt->val.equal_value(-1); |
412 | } |
413 | |
414 | static bool alg_is_pot(ConstStmt *stmt) { |
415 | if (!stmt) |
416 | return false; |
417 | if (!is_integral(stmt->val.dt)) |
418 | return false; |
419 | if (is_signed(stmt->val.dt)) { |
420 | return bit::is_power_of_two(stmt->val.val_int()); |
421 | } else { |
422 | return bit::is_power_of_two(stmt->val.val_uint()); |
423 | } |
424 | } |
425 | |
426 | static bool run(IRNode *node, bool fast_math) { |
427 | AlgSimp simplifier(fast_math); |
428 | bool modified = false; |
429 | while (true) { |
430 | node->accept(&simplifier); |
431 | if (simplifier.modifier.modify_ir()) |
432 | modified = true; |
433 | else |
434 | break; |
435 | } |
436 | return modified; |
437 | } |
438 | }; |
439 | |
440 | namespace irpass { |
441 | |
442 | bool alg_simp(IRNode *root, const CompileConfig &config) { |
443 | TI_AUTO_PROF; |
444 | return AlgSimp::run(root, config.fast_math); |
445 | } |
446 | |
447 | } // namespace irpass |
448 | |
449 | } // namespace taichi::lang |
450 | |