1#include "taichi/ir/analysis.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/program/program.h"
7#include "taichi/util/bit.h"
8
9namespace taichi::lang {
10
11// Algebraic Simplification and Strength Reduction
12class AlgSimp : public BasicStmtVisitor {
13 private:
14 void cast_to_result_type(Stmt *&a, Stmt *stmt) {
15 if (stmt->ret_type != a->ret_type) {
16 auto cast = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, a);
17 cast->cast_type = stmt->ret_type;
18 cast->ret_type = stmt->ret_type;
19 a = cast.get();
20 modifier.insert_before(stmt, std::move(cast));
21 }
22 }
23
24 void replace_with_zero(Stmt *stmt) {
25 auto zero = Stmt::make<ConstStmt>(TypedConstant(stmt->ret_type));
26 stmt->replace_usages_with(zero.get());
27 modifier.insert_before(stmt, std::move(zero));
28 modifier.erase(stmt);
29 }
30
31 void replace_with_one(Stmt *stmt) {
32 auto one = Stmt::make<ConstStmt>(TypedConstant(1));
33 auto one_raw = one.get();
34 modifier.insert_before(stmt, std::move(one));
35 cast_to_result_type(one_raw, stmt);
36 stmt->replace_usages_with(one_raw);
37 modifier.erase(stmt);
38 }
39
40 public:
41 static constexpr int max_weaken_exponent = 32;
42 using BasicStmtVisitor::visit;
43 bool fast_math;
44 DelayedIRModifier modifier;
45
46 explicit AlgSimp(bool fast_math_) : fast_math(fast_math_) {
47 }
48
49 [[nodiscard]] bool is_redundant_cast(const DataType &first_cast,
50 const DataType &second_cast) const {
51 // Tests if second_cast(first_cast(a)) is guaranteed to be equivalent to
52 // second_cast(a).
53 if (!first_cast->is<PrimitiveType>() || !second_cast->is<PrimitiveType>()) {
54 // TODO(type): handle this case
55 return false;
56 }
57 if (is_real(second_cast)) {
58 // float(...(a))
59 return is_real(first_cast) &&
60 data_type_bits(second_cast) <= data_type_bits(first_cast);
61 }
62 if (is_integral(first_cast)) {
63 // int(int(a))
64 return data_type_bits(second_cast) <= data_type_bits(first_cast);
65 }
66 // int(float(a))
67 if (data_type_bits(second_cast) <= data_type_bits(first_cast) * 2) {
68 // f64 can hold any i32 values.
69 return true;
70 } else {
71 // Assume a floating point type can hold any integer values when
72 // fast_math=True.
73 return fast_math;
74 }
75 }
76
77 void visit(UnaryOpStmt *stmt) override {
78 if (stmt->is_cast()) {
79 if (stmt->cast_type == stmt->operand->ret_type) {
80 stmt->replace_usages_with(stmt->operand);
81 modifier.erase(stmt);
82 } else if (stmt->operand->is<UnaryOpStmt>() &&
83 stmt->operand->as<UnaryOpStmt>()->is_cast()) {
84 auto prev_cast = stmt->operand->as<UnaryOpStmt>();
85 if (stmt->op_type == UnaryOpType::cast_bits &&
86 prev_cast->op_type == UnaryOpType::cast_bits) {
87 stmt->operand = prev_cast->operand;
88 modifier.mark_as_modified();
89 } else if (stmt->op_type == UnaryOpType::cast_value &&
90 prev_cast->op_type == UnaryOpType::cast_value &&
91 is_redundant_cast(prev_cast->cast_type, stmt->cast_type)) {
92 stmt->operand = prev_cast->operand;
93 modifier.mark_as_modified();
94 }
95 }
96 }
97 }
98
99 bool optimize_multiplication(BinaryOpStmt *stmt) {
100 // return true iff the IR is modified
101 auto lhs = stmt->lhs->cast<ConstStmt>();
102 auto rhs = stmt->rhs->cast<ConstStmt>();
103 TI_ASSERT(stmt->op_type == BinaryOpType::mul);
104 if (alg_is_one(lhs) || alg_is_one(rhs)) {
105 // 1 * a -> a, a * 1 -> a
106 stmt->replace_usages_with(alg_is_one(lhs) ? stmt->rhs : stmt->lhs);
107 modifier.erase(stmt);
108 return true;
109 }
110 if ((fast_math || is_integral(stmt->ret_type)) &&
111 (alg_is_zero(lhs) || alg_is_zero(rhs))) {
112 // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
113 replace_with_zero(stmt);
114 return true;
115 }
116 if (is_integral(stmt->ret_type) && (alg_is_pot(lhs) || alg_is_pot(rhs))) {
117 // a * pot -> a << log2(pot)
118 if (alg_is_pot(lhs)) {
119 std::swap(stmt->lhs, stmt->rhs);
120 std::swap(lhs, rhs);
121 }
122 int log2rhs = bit::log2int((uint64)rhs->val.val_as_int64());
123 auto new_rhs =
124 Stmt::make<ConstStmt>(TypedConstant(stmt->lhs->ret_type, log2rhs));
125 auto result = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_shl, stmt->lhs,
126 new_rhs.get());
127 result->ret_type = stmt->ret_type;
128 result->set_tb(stmt->tb);
129 stmt->replace_usages_with(result.get());
130 modifier.insert_before(stmt, std::move(new_rhs));
131 modifier.insert_before(stmt, std::move(result));
132 modifier.erase(stmt);
133 return true;
134 }
135 if (alg_is_two(lhs) || alg_is_two(rhs)) {
136 // 2 * a -> a + a, a * 2 -> a + a
137 auto a = stmt->lhs;
138 if (alg_is_two(lhs))
139 a = stmt->rhs;
140 cast_to_result_type(a, stmt);
141 auto sum = Stmt::make<BinaryOpStmt>(BinaryOpType::add, a, a);
142 sum->ret_type = a->ret_type;
143 sum->set_tb(stmt->tb);
144 stmt->replace_usages_with(sum.get());
145 modifier.insert_before(stmt, std::move(sum));
146 modifier.erase(stmt);
147 return true;
148 }
149 return false;
150 }
151
152 bool optimize_division(BinaryOpStmt *stmt) {
153 // return true iff the IR is modified
154 auto rhs = stmt->rhs->cast<ConstStmt>();
155 TI_ASSERT(stmt->op_type == BinaryOpType::div ||
156 stmt->op_type == BinaryOpType::floordiv);
157 if (alg_is_one(rhs) && !(is_real(stmt->lhs->ret_type) &&
158 stmt->op_type == BinaryOpType::floordiv)) {
159 // a / 1 -> a
160 stmt->replace_usages_with(stmt->lhs);
161 modifier.erase(stmt);
162 return true;
163 }
164 if ((fast_math || is_integral(stmt->ret_type)) &&
165 irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
166 // fast_math or integral operands: a / a -> 1
167 replace_with_one(stmt);
168 return true;
169 }
170 if (fast_math && rhs && is_real(rhs->ret_type) &&
171 stmt->op_type != BinaryOpType::floordiv) {
172 if (alg_is_zero(rhs)) {
173 TI_WARN("Potential division by 0\n{}", stmt->tb);
174 } else {
175 // a / const -> a * (1 / const)
176 auto reciprocal =
177 Stmt::make_typed<ConstStmt>(TypedConstant(rhs->ret_type));
178 if (rhs->ret_type->is_primitive(PrimitiveTypeID::f64)) {
179 reciprocal->val.val_float64() = (float64)1.0 / rhs->val.val_float64();
180 } else if (rhs->ret_type->is_primitive(PrimitiveTypeID::f32)) {
181 reciprocal->val.val_float32() = (float32)1.0 / rhs->val.val_float32();
182 } else {
183 TI_NOT_IMPLEMENTED
184 }
185 auto product = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->lhs,
186 reciprocal.get());
187 product->ret_type = stmt->ret_type;
188 stmt->replace_usages_with(product.get());
189 modifier.insert_before(stmt, std::move(reciprocal));
190 modifier.insert_before(stmt, std::move(product));
191 modifier.erase(stmt);
192 return true;
193 }
194 }
195 if (is_integral(stmt->lhs->ret_type) && is_unsigned(stmt->lhs->ret_type) &&
196 alg_is_pot(rhs)) {
197 // (unsigned)a / pot -> a >> log2(pot)
198 int log2rhs = bit::log2int((uint64)rhs->val.val_as_int64());
199 auto new_rhs =
200 Stmt::make<ConstStmt>(TypedConstant(stmt->lhs->ret_type, log2rhs));
201 auto result = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar, stmt->lhs,
202 new_rhs.get());
203 result->ret_type = stmt->ret_type;
204 stmt->replace_usages_with(result.get());
205 modifier.insert_before(stmt, std::move(new_rhs));
206 modifier.insert_before(stmt, std::move(result));
207 modifier.erase(stmt);
208 return true;
209 }
210 return false;
211 }
212
213 void visit(BinaryOpStmt *stmt) override {
214 if (stmt->lhs->ret_type->is<TensorType>() ||
215 stmt->rhs->ret_type->is<TensorType>()) {
216 // TODO: support tensor type
217 return;
218 }
219 auto lhs = stmt->lhs->cast<ConstStmt>();
220 auto rhs = stmt->rhs->cast<ConstStmt>();
221 if (stmt->op_type == BinaryOpType::mul) {
222 optimize_multiplication(stmt);
223 } else if (stmt->op_type == BinaryOpType::div ||
224 stmt->op_type == BinaryOpType::floordiv) {
225 optimize_division(stmt);
226 } else if (stmt->op_type == BinaryOpType::add ||
227 stmt->op_type == BinaryOpType::sub ||
228 stmt->op_type == BinaryOpType::bit_or ||
229 stmt->op_type == BinaryOpType::bit_xor) {
230 if (alg_is_zero(rhs)) {
231 // a +-|^ 0 -> a
232 stmt->replace_usages_with(stmt->lhs);
233 modifier.erase(stmt);
234 } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) {
235 // 0 +|^ a -> a
236 stmt->replace_usages_with(stmt->rhs);
237 modifier.erase(stmt);
238 } else if (stmt->op_type == BinaryOpType::bit_or &&
239 irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
240 // a | a -> a
241 stmt->replace_usages_with(stmt->lhs);
242 modifier.erase(stmt);
243 } else if ((stmt->op_type == BinaryOpType::sub ||
244 stmt->op_type == BinaryOpType::bit_xor) &&
245 (fast_math || is_integral(stmt->ret_type)) &&
246 irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
247 // fast_math or integral operands: a -^ a -> 0
248 replace_with_zero(stmt);
249 }
250 } else if (rhs && stmt->op_type == BinaryOpType::pow) {
251 float64 exponent = rhs->val.val_cast_to_float64();
252 if (exponent == 1) {
253 // a ** 1 -> a
254 stmt->replace_usages_with(stmt->lhs);
255 modifier.erase(stmt);
256 } else if (exponent == 0) {
257 // a ** 0 -> 1
258 replace_with_one(stmt);
259 } else if (exponent == 0.5) {
260 // a ** 0.5 -> sqrt(a)
261 auto a = stmt->lhs;
262 cast_to_result_type(a, stmt);
263 auto result = Stmt::make<UnaryOpStmt>(UnaryOpType::sqrt, a);
264 result->ret_type = a->ret_type;
265 stmt->replace_usages_with(result.get());
266 modifier.insert_before(stmt, std::move(result));
267 modifier.erase(stmt);
268 } else if (exponent == std::round(exponent) && exponent > 0 &&
269 exponent <= max_weaken_exponent) {
270 // a ** n -> Exponentiation by squaring
271 auto a = stmt->lhs;
272 cast_to_result_type(a, stmt);
273 const int exp = exponent;
274 Stmt *result = nullptr;
275 auto a_power_of_2 = a;
276 int current_exponent = 1;
277 while (true) {
278 if (exp & current_exponent) {
279 if (!result)
280 result = a_power_of_2;
281 else {
282 auto new_result = Stmt::make<BinaryOpStmt>(BinaryOpType::mul,
283 result, a_power_of_2);
284 new_result->ret_type = a->ret_type;
285 result = new_result.get();
286 modifier.insert_before(stmt, std::move(new_result));
287 }
288 }
289 current_exponent <<= 1;
290 if (current_exponent > exp)
291 break;
292 auto new_a_power = Stmt::make<BinaryOpStmt>(
293 BinaryOpType::mul, a_power_of_2, a_power_of_2);
294 new_a_power->ret_type = a->ret_type;
295 a_power_of_2 = new_a_power.get();
296 modifier.insert_before(stmt, std::move(new_a_power));
297 }
298 stmt->replace_usages_with(result);
299 modifier.erase(stmt);
300 } else if (exponent == std::round(exponent) && exponent < 0 &&
301 exponent >= -max_weaken_exponent) {
302 // a ** -n -> 1 / a ** n
303 if (is_integral(stmt->lhs->ret_type)) {
304 TI_ERROR("Negative exponent in pow(int, int) is not allowed.");
305 }
306 auto one = Stmt::make<ConstStmt>(TypedConstant(1));
307 auto one_raw = one.get();
308 modifier.insert_before(stmt, std::move(one));
309 cast_to_result_type(one_raw, stmt);
310 auto new_exponent = Stmt::make<UnaryOpStmt>(UnaryOpType::neg, rhs);
311 auto a_to_n = Stmt::make<BinaryOpStmt>(BinaryOpType::pow, stmt->lhs,
312 new_exponent.get());
313 a_to_n->ret_type = stmt->ret_type;
314 auto result =
315 Stmt::make<BinaryOpStmt>(BinaryOpType::div, one_raw, a_to_n.get());
316 stmt->replace_usages_with(result.get());
317 modifier.insert_before(stmt, std::move(new_exponent));
318 modifier.insert_before(stmt, std::move(a_to_n));
319 modifier.insert_before(stmt, std::move(result));
320 modifier.erase(stmt);
321 }
322 } else if (stmt->op_type == BinaryOpType::bit_and) {
323 if (alg_is_minus_one(rhs)) {
324 // a & -1 -> a
325 stmt->replace_usages_with(stmt->lhs);
326 modifier.erase(stmt);
327 } else if (alg_is_minus_one(lhs)) {
328 // -1 & a -> a
329 stmt->replace_usages_with(stmt->rhs);
330 modifier.erase(stmt);
331 } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) {
332 // 0 & a -> 0, a & 0 -> 0
333 replace_with_zero(stmt);
334 } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
335 // a & a -> a
336 stmt->replace_usages_with(stmt->lhs);
337 modifier.erase(stmt);
338 }
339 } else if (stmt->op_type == BinaryOpType::bit_sar ||
340 stmt->op_type == BinaryOpType::bit_shl ||
341 stmt->op_type == BinaryOpType::bit_shr) {
342 if (alg_is_zero(rhs) || alg_is_zero(lhs)) {
343 // a >> 0 -> a
344 // a << 0 -> a
345 // 0 << a -> 0
346 // 0 >> a -> 0
347 TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type);
348 stmt->replace_usages_with(stmt->lhs);
349 modifier.erase(stmt);
350 }
351 } else if (is_comparison(stmt->op_type)) {
352 if ((fast_math || is_integral(stmt->lhs->ret_type)) &&
353 irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
354 // fast_math or integral operands: a == a -> 1, a != a -> 0
355 if (stmt->op_type == BinaryOpType::cmp_eq ||
356 stmt->op_type == BinaryOpType::cmp_ge ||
357 stmt->op_type == BinaryOpType::cmp_le) {
358 replace_with_one(stmt);
359 } else if (stmt->op_type == BinaryOpType::cmp_ne ||
360 stmt->op_type == BinaryOpType::cmp_gt ||
361 stmt->op_type == BinaryOpType::cmp_lt) {
362 replace_with_zero(stmt);
363 } else {
364 TI_NOT_IMPLEMENTED
365 }
366 }
367 }
368 }
369
370 void visit(AssertStmt *stmt) override {
371 auto cond = stmt->cond->cast<ConstStmt>();
372 if (!cond)
373 return;
374 if (!alg_is_zero(cond)) {
375 // this statement has no effect
376 modifier.erase(stmt);
377 }
378 }
379
380 void visit(WhileControlStmt *stmt) override {
381 auto cond = stmt->cond->cast<ConstStmt>();
382 if (!cond)
383 return;
384 if (!alg_is_zero(cond)) {
385 // this statement has no effect
386 modifier.erase(stmt);
387 }
388 }
389
390 static bool alg_is_zero(ConstStmt *stmt) {
391 if (!stmt)
392 return false;
393 return stmt->val.equal_value(0);
394 }
395
396 static bool alg_is_one(ConstStmt *stmt) {
397 if (!stmt)
398 return false;
399 return stmt->val.equal_value(1);
400 }
401
402 static bool alg_is_two(ConstStmt *stmt) {
403 if (!stmt)
404 return false;
405 return stmt->val.equal_value(2);
406 }
407
408 static bool alg_is_minus_one(ConstStmt *stmt) {
409 if (!stmt)
410 return false;
411 return stmt->val.equal_value(-1);
412 }
413
414 static bool alg_is_pot(ConstStmt *stmt) {
415 if (!stmt)
416 return false;
417 if (!is_integral(stmt->val.dt))
418 return false;
419 if (is_signed(stmt->val.dt)) {
420 return bit::is_power_of_two(stmt->val.val_int());
421 } else {
422 return bit::is_power_of_two(stmt->val.val_uint());
423 }
424 }
425
426 static bool run(IRNode *node, bool fast_math) {
427 AlgSimp simplifier(fast_math);
428 bool modified = false;
429 while (true) {
430 node->accept(&simplifier);
431 if (simplifier.modifier.modify_ir())
432 modified = true;
433 else
434 break;
435 }
436 return modified;
437 }
438};
439
440namespace irpass {
441
442bool alg_simp(IRNode *root, const CompileConfig &config) {
443 TI_AUTO_PROF;
444 return AlgSimp::run(root, config.fast_math);
445}
446
447} // namespace irpass
448
449} // namespace taichi::lang
450