1#include "taichi/ir/analysis.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6
7namespace taichi::lang {
8
9class BinaryOpSimp : public BasicStmtVisitor {
10 public:
11 using BasicStmtVisitor::visit;
12 bool fast_math;
13 DelayedIRModifier modifier;
14 bool operand_swapped;
15
16 explicit BinaryOpSimp(bool fast_math_)
17 : fast_math(fast_math_), operand_swapped(false) {
18 }
19
20 bool try_rearranging_const_rhs(BinaryOpStmt *stmt) {
21 // Returns true if the statement is modified.
22 auto binary_lhs = stmt->lhs->cast<BinaryOpStmt>();
23 auto const_rhs = stmt->rhs->cast<ConstStmt>();
24 if (!binary_lhs || !const_rhs) {
25 return false;
26 }
27 auto const_lhs_rhs = binary_lhs->rhs->cast<ConstStmt>();
28 if (!const_lhs_rhs || binary_lhs->lhs->is<ConstStmt>()) {
29 return false;
30 }
31 auto op1 = binary_lhs->op_type;
32 auto op2 = stmt->op_type;
33 // Disables (a / b) * c -> a / (b / c), (a * b) / c -> a * (b / c)
34 // when the data type is integral.
35 if (is_integral(stmt->ret_type) &&
36 ((op1 == BinaryOpType::div && op2 == BinaryOpType::mul) ||
37 (op1 == BinaryOpType::mul && op2 == BinaryOpType::div))) {
38 return false;
39 }
40 BinaryOpType new_op2;
41 // original:
42 // stmt = (a op1 b) op2 c
43 // rearrange to:
44 // stmt = a op1 (b op2 c)
45 if (can_rearrange_associative(op1, op2, new_op2)) {
46 auto bin_op = Stmt::make<BinaryOpStmt>(new_op2, const_lhs_rhs, const_rhs);
47 bin_op->ret_type = stmt->ret_type;
48 auto new_stmt =
49 Stmt::make<BinaryOpStmt>(op1, binary_lhs->lhs, bin_op.get());
50 new_stmt->ret_type = stmt->ret_type;
51
52 modifier.insert_before(stmt, std::move(bin_op));
53 // Replace stmt now to avoid being "simplified" again
54 stmt->replace_usages_with(new_stmt.get());
55 modifier.insert_before(stmt, std::move(new_stmt));
56 modifier.erase(stmt);
57 return true;
58 }
59 // original:
60 // stmt = (a >> b) << b
61 // rearrange to:
62 // stmt = a & (-(1 << b))
63 if ((op1 == BinaryOpType::bit_shr || op1 == BinaryOpType::bit_sar) &&
64 op2 == BinaryOpType::bit_shl &&
65 irpass::analysis::same_value(const_lhs_rhs, const_rhs)) {
66 int64 mask = -((int64)1 << (uint64)const_rhs->val.val_as_int64());
67 auto mask_stmt =
68 Stmt::make<ConstStmt>(TypedConstant(stmt->ret_type, mask));
69 auto new_stmt = Stmt::make<BinaryOpStmt>(
70 BinaryOpType::bit_and, binary_lhs->lhs, mask_stmt.get());
71 new_stmt->ret_type = stmt->ret_type;
72
73 modifier.insert_before(stmt, std::move(mask_stmt));
74 // Replace stmt now to avoid being "simplified" again
75 stmt->replace_usages_with(new_stmt.get());
76 modifier.insert_before(stmt, std::move(new_stmt));
77 modifier.erase(stmt);
78 return true;
79 }
80 return false;
81 }
82
83 void visit(BinaryOpStmt *stmt) override {
84 // Swap lhs and rhs if lhs is a const and op is commutative.
85 auto const_lhs = stmt->lhs->cast<ConstStmt>();
86 if (const_lhs && is_commutative(stmt->op_type) &&
87 !stmt->rhs->is<ConstStmt>()) {
88 stmt->lhs = stmt->rhs;
89 stmt->rhs = const_lhs;
90 operand_swapped = true;
91 }
92 // Disable other optimizations if fast_math=True and the data type is not
93 // integral.
94 if (!fast_math && !is_integral(stmt->ret_type)) {
95 return;
96 }
97
98 if (try_rearranging_const_rhs(stmt)) {
99 return;
100 }
101
102 // Miscellaneous optimizations.
103 // original:
104 // stmt = a - (a & b)
105 // rearrange to:
106 // stmt = a & ~b
107 auto *binary_rhs = stmt->rhs->cast<BinaryOpStmt>();
108 if (binary_rhs && stmt->op_type == BinaryOpType::sub &&
109 binary_rhs->op_type == BinaryOpType::bit_and &&
110 irpass::analysis::same_value(stmt->lhs, binary_rhs->lhs)) {
111 auto mask_stmt =
112 Stmt::make<UnaryOpStmt>(UnaryOpType::bit_not, binary_rhs->rhs);
113 mask_stmt->ret_type = binary_rhs->rhs->ret_type;
114 auto new_stmt = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, stmt->lhs,
115 mask_stmt.get());
116 new_stmt->ret_type = stmt->ret_type;
117
118 modifier.insert_before(stmt, std::move(mask_stmt));
119 // Replace stmt now to avoid being "simplified" again
120 stmt->replace_usages_with(new_stmt.get());
121 modifier.insert_before(stmt, std::move(new_stmt));
122 modifier.erase(stmt);
123 return;
124 }
125 }
126
127 static bool can_rearrange_associative(BinaryOpType op1,
128 BinaryOpType op2,
129 BinaryOpType &new_op2) {
130 if ((op1 == BinaryOpType::add || op1 == BinaryOpType::sub) &&
131 (op2 == BinaryOpType::add || op2 == BinaryOpType::sub)) {
132 if (op1 == BinaryOpType::add)
133 new_op2 = op2;
134 else
135 new_op2 =
136 (op2 == BinaryOpType::add ? BinaryOpType::sub : BinaryOpType::add);
137 return true;
138 }
139 if ((op1 == BinaryOpType::mul || op1 == BinaryOpType::div) &&
140 (op2 == BinaryOpType::mul || op2 == BinaryOpType::div)) {
141 if (op1 == BinaryOpType::mul)
142 new_op2 = op2;
143 else
144 new_op2 =
145 (op2 == BinaryOpType::mul ? BinaryOpType::div : BinaryOpType::mul);
146 return true;
147 }
148 // for bit operations it holds when two ops are the same
149 if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or ||
150 op1 == BinaryOpType::bit_xor) &&
151 op1 == op2) {
152 new_op2 = op2;
153 return true;
154 }
155 if ((op1 == BinaryOpType::bit_shl || op1 == BinaryOpType::bit_shr ||
156 op1 == BinaryOpType::bit_sar) &&
157 op1 == op2) {
158 // (a << b) << c -> a << (b + c)
159 // (a >> b) >> c -> a >> (b + c)
160 new_op2 = BinaryOpType::add;
161 return true;
162 }
163 return false;
164 }
165
166 static bool is_commutative(BinaryOpType op) {
167 return op == BinaryOpType::add || op == BinaryOpType::mul ||
168 op == BinaryOpType::bit_and || op == BinaryOpType::bit_or ||
169 op == BinaryOpType::bit_xor;
170 }
171
172 static bool run(IRNode *node, bool fast_math) {
173 BinaryOpSimp simplifier(fast_math);
174 bool modified = false;
175 while (true) {
176 node->accept(&simplifier);
177 if (simplifier.modifier.modify_ir()) {
178 modified = true;
179 } else
180 break;
181 }
182 return modified || simplifier.operand_swapped;
183 }
184};
185
186namespace irpass {
187
188bool binary_op_simplify(IRNode *root, const CompileConfig &config) {
189 TI_AUTO_PROF;
190 return BinaryOpSimp::run(root, config.fast_math);
191}
192
193} // namespace irpass
194
195} // namespace taichi::lang
196