// The bit-level loop vectorizer
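//
// This pass rewrites loads, stores, binary ops, and atomic adds inside a
// struct-for over a quant array whose `is_bit_vectorized` flag is set, so
// that each loop iteration operates on one whole word of the array's
// physical type. As a sketch: with 1-bit quant ints packed into a 32-bit
// physical type, a load of x[i, j] becomes a load of the 32-bit word that
// holds 32 consecutive elements, and element-wise boolean logic becomes
// plain bitwise arithmetic on that word.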

#include "taichi/program/program.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/type_factory.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/analysis.h"

namespace taichi::lang {

class BitLoopVectorize : public IRVisitor {
 public:
  bool is_bit_vectorized;
  bool in_struct_for_loop;
  StructForStmt *loop_stmt;
  PrimitiveType *quant_array_physical_type;
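  // Maps the destination of a transformed atomic add to the three local
  // adder buffers {a, b, c} allocated for it (see visit(AtomicOpStmt *)
  // and transform_atomic_add()).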
  std::unordered_map<Stmt *, std::vector<Stmt *>> transformed_atomics;

  BitLoopVectorize() {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
    is_bit_vectorized = false;
    in_struct_for_loop = false;
    loop_stmt = nullptr;
    quant_array_physical_type = nullptr;
  }

  void visit(Block *stmt_list) override {
    std::vector<Stmt *> statements;
    for (auto &stmt : stmt_list->statements) {
      statements.push_back(stmt.get());
    }
    for (auto stmt : statements) {
      stmt->accept(this);
    }
  }

  void visit(GlobalLoadStmt *stmt) override {
    auto ptr_type = stmt->src->ret_type->as<PointerType>();
    if (in_struct_for_loop && is_bit_vectorized) {
      if (ptr_type->get_pointee_type()->cast<QuantIntType>()) {
        // Rewrite the source GlobalPtrStmt's return type from *qit to
        // *physical_type.
        auto ptr = stmt->src->cast<GlobalPtrStmt>();
        auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type(
            quant_array_physical_type, false);
        DataType new_ret_type(ptr_physical_type);
        ptr->ret_type = new_ret_type;
        ptr->is_bit_vectorized = true;
        // Check whether the innermost index j carries a constant offset.
        if (ptr->indices.size() == 2) {
          auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[1],
                                                              loop_stmt, 1);
          // TODO: for now we only support [j - 1] and [j + 1];
          // the general case should be easy to implement.
          if (diff.linear_related() && diff.certain() &&
              (diff.low == 1 || diff.low == -1)) {
            // Construct a ptr to x[i, j].
            auto indices = ptr->indices;
            indices[1] = loop_stmt->body->statements[1].get();
            auto base_ptr =
                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
            base_ptr->ret_type = new_ret_type;
            base_ptr->is_bit_vectorized = true;
            // Load x[i, j] (the base word).
            DataType load_data_type(quant_array_physical_type);
            auto load_base = std::make_unique<GlobalLoadStmt>(base_ptr.get());
            load_base->ret_type = load_data_type;
            // Load the offsetted word. Since we are vectorizing, the word
            // holding the neighboring lanes of x[i, j + 1] is actually
            // x[i, j + vectorization_width] (and x[i, j - vectorization_width]
            // for the [j - 1] case).
            auto vectorization_width = data_type_bits(load_data_type);
            auto offset_constant =
                std::make_unique<ConstStmt>(TypedConstant(vectorization_width));
            auto offset_index_opcode =
                diff.low == -1 ? BinaryOpType::sub : BinaryOpType::add;
            auto offset_index = std::make_unique<BinaryOpStmt>(
                offset_index_opcode, indices[1], offset_constant.get());
            indices[1] = offset_index.get();
            auto offset_ptr =
                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
            offset_ptr->ret_type = new_ret_type;
            offset_ptr->is_bit_vectorized = true;
            auto load_offsetted =
                std::make_unique<GlobalLoadStmt>(offset_ptr.get());
            load_offsetted->ret_type = load_data_type;
            // Create the bit-shift and bit-or operations that recombine the
            // two words:
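            // (Sketch, assuming an unsigned 32-bit physical type with lane k
            // stored at bit k: for [j + 1] each lane needs its neighbor's
            // bit, so the result is (base >> 1) | (offsetted << 31); for
            // [j - 1] the shift directions are mirrored, giving
            // (base << 1) | (offsetted >> 31).)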
            auto base_shift_offset =
                std::make_unique<ConstStmt>(TypedConstant(load_data_type, 1));
            auto base_shift_opcode =
                diff.low == -1 ? BinaryOpType::bit_shl : BinaryOpType::bit_sar;
            auto base_shift_op = std::make_unique<BinaryOpStmt>(
                base_shift_opcode, load_base.get(), base_shift_offset.get());

            auto offsetted_shift_offset = std::make_unique<ConstStmt>(
                TypedConstant(load_data_type, vectorization_width - 1));
            auto offsetted_shift_opcode =
                diff.low == -1 ? BinaryOpType::bit_sar : BinaryOpType::bit_shl;
            auto offsetted_shift_op = std::make_unique<BinaryOpStmt>(
                offsetted_shift_opcode, load_offsetted.get(),
                offsetted_shift_offset.get());

            auto or_op = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_or, base_shift_op.get(),
                offsetted_shift_op.get());
            // modify IR
            auto offsetted_shift_op_p = offsetted_shift_op.get();
            stmt->insert_before_me(std::move(base_ptr));
            stmt->insert_before_me(std::move(load_base));
            stmt->insert_before_me(std::move(offset_constant));
            stmt->insert_before_me(std::move(offset_index));
            stmt->insert_before_me(std::move(offset_ptr));
            stmt->insert_before_me(std::move(load_offsetted));
            stmt->insert_before_me(std::move(base_shift_offset));
            stmt->insert_before_me(std::move(base_shift_op));
            stmt->insert_before_me(std::move(offsetted_shift_offset));
            stmt->insert_before_me(std::move(offsetted_shift_op));
            stmt->replace_usages_with(or_op.get());
            offsetted_shift_op_p->insert_after_me(std::move(or_op));
          }
        }
      }
    }
  }

  void visit(GlobalStoreStmt *stmt) override {
    auto ptr_type = stmt->dest->ret_type->as<PointerType>();
    if (in_struct_for_loop && is_bit_vectorized) {
      if (ptr_type->get_pointee_type()->cast<QuantIntType>()) {
        // Rewrite the destination GlobalPtrStmt's return type from *qit to
        // *physical_type.
        auto ptr = stmt->dest->cast<GlobalPtrStmt>();
        auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type(
            quant_array_physical_type, false);
        DataType new_ret_type(ptr_physical_type);
        ptr->ret_type = new_ret_type;
        ptr->is_bit_vectorized = true;
      }
    }
  }

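  // This visitor gates the whole pass: only a struct-for over a quant array
  // with `is_bit_vectorized` set is transformed. A sketch of the kind of
  // kernel it targets (the Python-side syntax below is illustrative only,
  // not taken from this file):
  //   for i, j in x:  # x: 2D quant array, struct-for marked bit-vectorized
  //     y[i, j] = x[i, j - 1] == 1
  // where each (i, j) iteration is then widened to one physical word.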
  void visit(StructForStmt *stmt) override {
    if (stmt->snode->type != SNodeType::quant_array) {
      return;
    }
    bool old_is_bit_vectorized = is_bit_vectorized;
    is_bit_vectorized = stmt->is_bit_vectorized;
    in_struct_for_loop = true;
    loop_stmt = stmt;
    quant_array_physical_type = stmt->snode->physical_type;
    stmt->body->accept(this);
    is_bit_vectorized = old_is_bit_vectorized;
    in_struct_for_loop = false;
    loop_stmt = nullptr;
    quant_array_physical_type = nullptr;
  }

  void visit(BinaryOpStmt *stmt) override {
    // Vectorize cmp_eq and bit_and between vectorized data (local adder
    // buffers / quant array elements) and a constant.
    if (in_struct_for_loop && is_bit_vectorized) {
      if (stmt->op_type == BinaryOpType::bit_and) {
        // If the rhs is a bit-vectorized stmt and the lhs is a constant 1
        // (usually generated by a boolean expression), simply replace
        // the stmt with its rhs.
        int lhs_val = get_constant_value(stmt->lhs);
        if (lhs_val == 1) {
          if (auto rhs = stmt->rhs->cast<BinaryOpStmt>();
              rhs && rhs->is_bit_vectorized) {
            stmt->replace_usages_with(stmt->rhs);
          }
        }
      } else if (stmt->op_type == BinaryOpType::cmp_eq) {
        if (auto lhs = stmt->lhs->cast<GlobalLoadStmt>()) {
          // Case 0: the lhs is a vectorized global load from the quant array.
          if (auto ptr = lhs->src->cast<GlobalPtrStmt>();
              ptr && ptr->is_bit_vectorized) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: only 1 is supported for now; 0 should be easy to
            // implement with a bit_not on the original bit pattern.
            TI_ASSERT(rhs_val == 1);
            // cmp_eq with 1 yields the bit pattern itself.

            // To pass CFG analysis and mark the stmt as vectorized,
            // create a dummy lhs + 0 here.
            auto zero = std::make_unique<ConstStmt>(TypedConstant(0));
            auto add = std::make_unique<BinaryOpStmt>(BinaryOpType::add,
                                                      stmt->lhs, zero.get());
            add->is_bit_vectorized = true;
            // modify IR
            auto zero_p = zero.get();
            stmt->insert_before_me(std::move(zero));
            stmt->replace_usages_with(add.get());
            zero_p->insert_after_me(std::move(add));
          }
        } else if (auto lhs = stmt->lhs->cast<LocalLoadStmt>()) {
          // Case 1: the lhs is a local load from a local adder structure.
          auto it = transformed_atomics.find(lhs->src);
          if (it != transformed_atomics.end()) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: only 2 and 3 are supported for now; the other values
            // can be implemented in a similar fashion.
            TI_ASSERT(rhs_val == 2 || rhs_val == 3);
            // Binary 010 and 011, respectively.
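            // Each lane of the adder holds the 3-bit value
            // (a << 2) | (b << 1) | c, so per lane
            //   value == 2 (0b010)  iff  ~a & b & ~c,
            //   value == 3 (0b011)  iff  ~a & b & c,
            // which is exactly the mask built below.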
            auto &buffer_vec = it->second;
            Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
            // load all three buffers
            auto load_a = std::make_unique<LocalLoadStmt>(a);
            auto load_b = std::make_unique<LocalLoadStmt>(b);
            auto load_c = std::make_unique<LocalLoadStmt>(c);
            // compute not_a first
            auto not_a = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_a.get());
            // b is used as-is, so there is nothing to compute for it;
            // compute not_c
            auto not_c = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_c.get());
            // bit_and all three patterns
            auto and_a_b = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, not_a.get(), load_b.get());
            auto and_b_c = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, and_a_b.get(),
                rhs_val == 2 ? (Stmt *)(not_c.get()) : (Stmt *)(load_c.get()));
            // mark the last stmt as vectorized
            and_b_c->is_bit_vectorized = true;
            // modify IR
            auto and_a_b_p = and_a_b.get();
            stmt->insert_before_me(std::move(load_a));
            stmt->insert_before_me(std::move(load_b));
            stmt->insert_before_me(std::move(load_c));
            stmt->insert_before_me(std::move(not_a));
            stmt->insert_before_me(std::move(not_c));
            stmt->insert_before_me(std::move(and_a_b));
            stmt->replace_usages_with(and_b_c.get());
            and_a_b_p->insert_after_me(std::move(and_b_c));
          }
        }
      }
    }
  }

  void visit(AtomicOpStmt *stmt) override {
    DataType dt(quant_array_physical_type);
    if (in_struct_for_loop && is_bit_vectorized &&
        stmt->op_type == AtomicOpType::add) {
      auto it = transformed_atomics.find(stmt->dest);
      // process a transformed atomic stmt
      if (it != transformed_atomics.end()) {
        auto &buffer_vec = it->second;
        transform_atomic_add(buffer_vec, stmt, dt);
      } else {
        // alloc three buffers a, b, c
        auto alloc_a = std::make_unique<AllocaStmt>(dt);
        auto alloc_b = std::make_unique<AllocaStmt>(dt);
        auto alloc_c = std::make_unique<AllocaStmt>(dt);
        std::vector<Stmt *> buffer_vec{alloc_a.get(), alloc_b.get(),
                                       alloc_c.get()};
        transformed_atomics[stmt->dest] = buffer_vec;
        // modify IR
        stmt->insert_before_me(std::move(alloc_a));
        stmt->insert_before_me(std::move(alloc_b));
        stmt->insert_before_me(std::move(alloc_c));
        transform_atomic_add(buffer_vec, stmt, dt);
      }
    }
  }

  static void run(IRNode *node) {
    BitLoopVectorize inst;
    node->accept(&inst);
  }

 private:
  void transform_atomic_add(const std::vector<Stmt *> &buffer_vec,
                            AtomicOpStmt *stmt,
                            DataType &dt) {
    // To transform an atomic add on a vectorized subarray of a quant array,
    // we use a local adder with three buffers (*a*, *b*, *c*) of the same
    // physical type as the original quant array. Each bit of *a* holds the
    // highest bit of the per-lane result, *b* the middle bit, and *c* the
    // lowest bit. To add *d* to the subarray, we use bit_xor and bit_and to
    // compute the sum and the carry.
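    // Per lane, this is a ripple-carry update over the three bit-planes:
    //   carry_c = c & d;  c ^= d;              // low bit and its carry
    //   carry_b = b & carry_c;  b ^= carry_c;  // middle bit
    //   a ^= carry_b;                          // high bit; carry out dropped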
    Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
    auto load_c = std::make_unique<LocalLoadStmt>(c);
    auto carry_c = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_c.get(), stmt->val);
    auto sum_c =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, c, stmt->val);
    auto load_b = std::make_unique<LocalLoadStmt>(b);
    auto carry_b = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_b.get(), carry_c.get());
    auto sum_b =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, b, carry_c.get());
    // for a, we do not need to compute its carry
    auto sum_a =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, a, carry_b.get());
    // modify IR
    stmt->insert_before_me(std::move(load_c));
    stmt->insert_before_me(std::move(carry_c));
    stmt->insert_before_me(std::move(sum_c));
    stmt->insert_before_me(std::move(load_b));
    stmt->insert_before_me(std::move(carry_b));
    stmt->insert_before_me(std::move(sum_b));
    stmt->insert_before_me(std::move(sum_a));
    // there is no need to replace the stmt here as we
    // will replace it manually later
  }

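  // Returns the value of `stmt` if it is an i32 constant (looking through a
  // single value cast), or -1 if it is not.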
  int32 get_constant_value(Stmt *stmt) {
    int32 val = -1;
    // the stmt could be a cast stmt
    if (auto cast_stmt = stmt->cast<UnaryOpStmt>();
        cast_stmt && cast_stmt->is_cast() &&
        cast_stmt->op_type == UnaryOpType::cast_value) {
      stmt = cast_stmt->operand;
    }
    if (auto constant_stmt = stmt->cast<ConstStmt>();
        constant_stmt &&
        constant_stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) {
      val = constant_stmt->val.val_i32;
    }
    return val;
  }
};

namespace irpass {

void bit_loop_vectorize(IRNode *root) {
  TI_AUTO_PROF;
  BitLoopVectorize::run(root);
  die(root);
}

}  // namespace irpass

}  // namespace taichi::lang
339 | |