1 | #include "taichi/ir/analysis.h" |
2 | #include "taichi/ir/ir.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | class BinaryOpSimp : public BasicStmtVisitor { |
10 | public: |
11 | using BasicStmtVisitor::visit; |
12 | bool fast_math; |
13 | DelayedIRModifier modifier; |
14 | bool operand_swapped; |
15 | |
16 | explicit BinaryOpSimp(bool fast_math_) |
17 | : fast_math(fast_math_), operand_swapped(false) { |
18 | } |
19 | |
20 | bool try_rearranging_const_rhs(BinaryOpStmt *stmt) { |
21 | // Returns true if the statement is modified. |
22 | auto binary_lhs = stmt->lhs->cast<BinaryOpStmt>(); |
23 | auto const_rhs = stmt->rhs->cast<ConstStmt>(); |
24 | if (!binary_lhs || !const_rhs) { |
25 | return false; |
26 | } |
27 | auto const_lhs_rhs = binary_lhs->rhs->cast<ConstStmt>(); |
28 | if (!const_lhs_rhs || binary_lhs->lhs->is<ConstStmt>()) { |
29 | return false; |
30 | } |
31 | auto op1 = binary_lhs->op_type; |
32 | auto op2 = stmt->op_type; |
33 | // Disables (a / b) * c -> a / (b / c), (a * b) / c -> a * (b / c) |
34 | // when the data type is integral. |
35 | if (is_integral(stmt->ret_type) && |
36 | ((op1 == BinaryOpType::div && op2 == BinaryOpType::mul) || |
37 | (op1 == BinaryOpType::mul && op2 == BinaryOpType::div))) { |
38 | return false; |
39 | } |
40 | BinaryOpType new_op2; |
41 | // original: |
42 | // stmt = (a op1 b) op2 c |
43 | // rearrange to: |
44 | // stmt = a op1 (b op2 c) |
45 | if (can_rearrange_associative(op1, op2, new_op2)) { |
46 | auto bin_op = Stmt::make<BinaryOpStmt>(new_op2, const_lhs_rhs, const_rhs); |
47 | bin_op->ret_type = stmt->ret_type; |
48 | auto new_stmt = |
49 | Stmt::make<BinaryOpStmt>(op1, binary_lhs->lhs, bin_op.get()); |
50 | new_stmt->ret_type = stmt->ret_type; |
51 | |
52 | modifier.insert_before(stmt, std::move(bin_op)); |
53 | // Replace stmt now to avoid being "simplified" again |
54 | stmt->replace_usages_with(new_stmt.get()); |
55 | modifier.insert_before(stmt, std::move(new_stmt)); |
56 | modifier.erase(stmt); |
57 | return true; |
58 | } |
59 | // original: |
60 | // stmt = (a >> b) << b |
61 | // rearrange to: |
62 | // stmt = a & (-(1 << b)) |
63 | if ((op1 == BinaryOpType::bit_shr || op1 == BinaryOpType::bit_sar) && |
64 | op2 == BinaryOpType::bit_shl && |
65 | irpass::analysis::same_value(const_lhs_rhs, const_rhs)) { |
66 | int64 mask = -((int64)1 << (uint64)const_rhs->val.val_as_int64()); |
67 | auto mask_stmt = |
68 | Stmt::make<ConstStmt>(TypedConstant(stmt->ret_type, mask)); |
69 | auto new_stmt = Stmt::make<BinaryOpStmt>( |
70 | BinaryOpType::bit_and, binary_lhs->lhs, mask_stmt.get()); |
71 | new_stmt->ret_type = stmt->ret_type; |
72 | |
73 | modifier.insert_before(stmt, std::move(mask_stmt)); |
74 | // Replace stmt now to avoid being "simplified" again |
75 | stmt->replace_usages_with(new_stmt.get()); |
76 | modifier.insert_before(stmt, std::move(new_stmt)); |
77 | modifier.erase(stmt); |
78 | return true; |
79 | } |
80 | return false; |
81 | } |
82 | |
83 | void visit(BinaryOpStmt *stmt) override { |
84 | // Swap lhs and rhs if lhs is a const and op is commutative. |
85 | auto const_lhs = stmt->lhs->cast<ConstStmt>(); |
86 | if (const_lhs && is_commutative(stmt->op_type) && |
87 | !stmt->rhs->is<ConstStmt>()) { |
88 | stmt->lhs = stmt->rhs; |
89 | stmt->rhs = const_lhs; |
90 | operand_swapped = true; |
91 | } |
92 | // Disable other optimizations if fast_math=True and the data type is not |
93 | // integral. |
94 | if (!fast_math && !is_integral(stmt->ret_type)) { |
95 | return; |
96 | } |
97 | |
98 | if (try_rearranging_const_rhs(stmt)) { |
99 | return; |
100 | } |
101 | |
102 | // Miscellaneous optimizations. |
103 | // original: |
104 | // stmt = a - (a & b) |
105 | // rearrange to: |
106 | // stmt = a & ~b |
107 | auto *binary_rhs = stmt->rhs->cast<BinaryOpStmt>(); |
108 | if (binary_rhs && stmt->op_type == BinaryOpType::sub && |
109 | binary_rhs->op_type == BinaryOpType::bit_and && |
110 | irpass::analysis::same_value(stmt->lhs, binary_rhs->lhs)) { |
111 | auto mask_stmt = |
112 | Stmt::make<UnaryOpStmt>(UnaryOpType::bit_not, binary_rhs->rhs); |
113 | mask_stmt->ret_type = binary_rhs->rhs->ret_type; |
114 | auto new_stmt = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, stmt->lhs, |
115 | mask_stmt.get()); |
116 | new_stmt->ret_type = stmt->ret_type; |
117 | |
118 | modifier.insert_before(stmt, std::move(mask_stmt)); |
119 | // Replace stmt now to avoid being "simplified" again |
120 | stmt->replace_usages_with(new_stmt.get()); |
121 | modifier.insert_before(stmt, std::move(new_stmt)); |
122 | modifier.erase(stmt); |
123 | return; |
124 | } |
125 | } |
126 | |
127 | static bool can_rearrange_associative(BinaryOpType op1, |
128 | BinaryOpType op2, |
129 | BinaryOpType &new_op2) { |
130 | if ((op1 == BinaryOpType::add || op1 == BinaryOpType::sub) && |
131 | (op2 == BinaryOpType::add || op2 == BinaryOpType::sub)) { |
132 | if (op1 == BinaryOpType::add) |
133 | new_op2 = op2; |
134 | else |
135 | new_op2 = |
136 | (op2 == BinaryOpType::add ? BinaryOpType::sub : BinaryOpType::add); |
137 | return true; |
138 | } |
139 | if ((op1 == BinaryOpType::mul || op1 == BinaryOpType::div) && |
140 | (op2 == BinaryOpType::mul || op2 == BinaryOpType::div)) { |
141 | if (op1 == BinaryOpType::mul) |
142 | new_op2 = op2; |
143 | else |
144 | new_op2 = |
145 | (op2 == BinaryOpType::mul ? BinaryOpType::div : BinaryOpType::mul); |
146 | return true; |
147 | } |
148 | // for bit operations it holds when two ops are the same |
149 | if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or || |
150 | op1 == BinaryOpType::bit_xor) && |
151 | op1 == op2) { |
152 | new_op2 = op2; |
153 | return true; |
154 | } |
155 | if ((op1 == BinaryOpType::bit_shl || op1 == BinaryOpType::bit_shr || |
156 | op1 == BinaryOpType::bit_sar) && |
157 | op1 == op2) { |
158 | // (a << b) << c -> a << (b + c) |
159 | // (a >> b) >> c -> a >> (b + c) |
160 | new_op2 = BinaryOpType::add; |
161 | return true; |
162 | } |
163 | return false; |
164 | } |
165 | |
166 | static bool is_commutative(BinaryOpType op) { |
167 | return op == BinaryOpType::add || op == BinaryOpType::mul || |
168 | op == BinaryOpType::bit_and || op == BinaryOpType::bit_or || |
169 | op == BinaryOpType::bit_xor; |
170 | } |
171 | |
172 | static bool run(IRNode *node, bool fast_math) { |
173 | BinaryOpSimp simplifier(fast_math); |
174 | bool modified = false; |
175 | while (true) { |
176 | node->accept(&simplifier); |
177 | if (simplifier.modifier.modify_ir()) { |
178 | modified = true; |
179 | } else |
180 | break; |
181 | } |
182 | return modified || simplifier.operand_swapped; |
183 | } |
184 | }; |
185 | |
186 | namespace irpass { |
187 | |
188 | bool binary_op_simplify(IRNode *root, const CompileConfig &config) { |
189 | TI_AUTO_PROF; |
190 | return BinaryOpSimp::run(root, config.fast_math); |
191 | } |
192 | |
193 | } // namespace irpass |
194 | |
195 | } // namespace taichi::lang |
196 | |