// The bit-level loop vectorizer
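//
// This pass rewrites struct-for loops over quant arrays that are marked as
// bit_vectorized: loads and stores through quant int pointers are retyped to
// the quant array's physical type, neighbor accesses such as x[i, j +/- 1]
// are assembled from two physical-word loads plus shifts, and atomic adds are
// lowered onto a small bit-sliced local adder (see transform_atomic_add()).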

#include "taichi/program/program.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/type_factory.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/analysis.h"

namespace taichi::lang {

class BitLoopVectorize : public IRVisitor {
 public:
  bool is_bit_vectorized;
  bool in_struct_for_loop;
  StructForStmt *loop_stmt;
  PrimitiveType *quant_array_physical_type;
  std::unordered_map<Stmt *, std::vector<Stmt *>> transformed_atomics;

  BitLoopVectorize() {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
    is_bit_vectorized = false;
    in_struct_for_loop = false;
    loop_stmt = nullptr;
    quant_array_physical_type = nullptr;
  }

  void visit(Block *stmt_list) override {
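    // take a snapshot of the raw statement pointers first, so that statements
    // inserted by the visitors below do not invalidate this traversal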
    std::vector<Stmt *> statements;
    for (auto &stmt : stmt_list->statements) {
      statements.push_back(stmt.get());
    }
    for (auto stmt : statements) {
      stmt->accept(this);
    }
  }

  void visit(GlobalLoadStmt *stmt) override {
    auto ptr_type = stmt->src->ret_type->as<PointerType>();
    if (in_struct_for_loop && is_bit_vectorized) {
      if (ptr_type->get_pointee_type()->cast<QuantIntType>()) {
        // rewrite the previous GlobalPtrStmt's return type from *qit to
        // *phy_type
        auto ptr = stmt->src->cast<GlobalPtrStmt>();
        auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type(
            quant_array_physical_type, false);
        DataType new_ret_type(ptr_physical_type);
        ptr->ret_type = new_ret_type;
        ptr->is_bit_vectorized = true;
        // check whether the second index j has an offset
        if (ptr->indices.size() == 2) {
          auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[1],
                                                              loop_stmt, 1);
          // TODO: for now we only support [j - 1] and [j + 1];
          // the general case should be easy to implement
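          //
          // Illustration (a rough sketch, not the exact IR): assuming a
          // 32-bit physical type, a vectorized read of x[i, j + 1] combines
          // two physical-word loads:
          //   base      = load(x[i, j])       // quant elements j .. j + 31
          //   offsetted = load(x[i, j + 32])  // quant elements j + 32 .. j + 63
          //   result    = (base >> 1) | (offsetted << 31)
          // The [j - 1] case mirrors this with the shift directions swapped
          // and an index offset of -32.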
          if (diff.linear_related() && diff.certain() &&
              (diff.low == 1 || diff.low == -1)) {
            // construct a ptr to x[i, j]
            auto indices = ptr->indices;
            indices[1] = loop_stmt->body->statements[1].get();
            auto base_ptr =
                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
            base_ptr->ret_type = new_ret_type;
            base_ptr->is_bit_vectorized = true;
            // load x[i, j] (base)
            DataType load_data_type(quant_array_physical_type);
            auto load_base = std::make_unique<GlobalLoadStmt>(base_ptr.get());
            load_base->ret_type = load_data_type;
            // load x[i, j + 1] (offsetted)
            // since we are vectorizing, the actual data loaded is
            // x[i, j + vectorization_width]
            auto vectorization_width = data_type_bits(load_data_type);
            auto offset_constant =
                std::make_unique<ConstStmt>(TypedConstant(vectorization_width));
            auto offset_index_opcode =
                diff.low == -1 ? BinaryOpType::sub : BinaryOpType::add;
            auto offset_index = std::make_unique<BinaryOpStmt>(
                offset_index_opcode, indices[1], offset_constant.get());
            indices[1] = offset_index.get();
            auto offset_ptr =
                std::make_unique<GlobalPtrStmt>(ptr->snode, indices);
            offset_ptr->ret_type = new_ret_type;
            offset_ptr->is_bit_vectorized = true;
            auto load_offsetted =
                std::make_unique<GlobalLoadStmt>(offset_ptr.get());
            load_offsetted->ret_type = load_data_type;
            // create the bit shift and bit_or operations
            auto base_shift_offset =
                std::make_unique<ConstStmt>(TypedConstant(load_data_type, 1));
            auto base_shift_opcode =
                diff.low == -1 ? BinaryOpType::bit_shl : BinaryOpType::bit_sar;
            auto base_shift_op = std::make_unique<BinaryOpStmt>(
                base_shift_opcode, load_base.get(), base_shift_offset.get());

            auto offsetted_shift_offset = std::make_unique<ConstStmt>(
                TypedConstant(load_data_type, vectorization_width - 1));
            auto offsetted_shift_opcode =
                diff.low == -1 ? BinaryOpType::bit_sar : BinaryOpType::bit_shl;
            auto offsetted_shift_op = std::make_unique<BinaryOpStmt>(
                offsetted_shift_opcode, load_offsetted.get(),
                offsetted_shift_offset.get());

            auto or_op = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_or, base_shift_op.get(),
                offsetted_shift_op.get());
            // modify IR
            auto offsetted_shift_op_p = offsetted_shift_op.get();
            stmt->insert_before_me(std::move(base_ptr));
            stmt->insert_before_me(std::move(load_base));
            stmt->insert_before_me(std::move(offset_constant));
            stmt->insert_before_me(std::move(offset_index));
            stmt->insert_before_me(std::move(offset_ptr));
            stmt->insert_before_me(std::move(load_offsetted));
            stmt->insert_before_me(std::move(base_shift_offset));
            stmt->insert_before_me(std::move(base_shift_op));
            stmt->insert_before_me(std::move(offsetted_shift_offset));
            stmt->insert_before_me(std::move(offsetted_shift_op));
            stmt->replace_usages_with(or_op.get());
            offsetted_shift_op_p->insert_after_me(std::move(or_op));
          }
        }
      }
    }
  }

  void visit(GlobalStoreStmt *stmt) override {
    auto ptr_type = stmt->dest->ret_type->as<PointerType>();
    if (in_struct_for_loop && is_bit_vectorized) {
      if (ptr_type->get_pointee_type()->cast<QuantIntType>()) {
        // rewrite the previous GlobalPtrStmt's return type from *qit to
        // *phy_type
        auto ptr = stmt->dest->cast<GlobalPtrStmt>();
        auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type(
            quant_array_physical_type, false);
        DataType new_ret_type(ptr_physical_type);
        ptr->ret_type = new_ret_type;
        ptr->is_bit_vectorized = true;
      }
    }
  }

  void visit(StructForStmt *stmt) override {
    if (stmt->snode->type != SNodeType::quant_array) {
      return;
    }
    bool old_is_bit_vectorized = is_bit_vectorized;
    is_bit_vectorized = stmt->is_bit_vectorized;
    in_struct_for_loop = true;
    loop_stmt = stmt;
    quant_array_physical_type = stmt->snode->physical_type;
    stmt->body->accept(this);
    is_bit_vectorized = old_is_bit_vectorized;
    in_struct_for_loop = false;
    loop_stmt = nullptr;
    quant_array_physical_type = nullptr;
  }

  void visit(BinaryOpStmt *stmt) override {
    // vectorize cmp_eq and bit_and between
    // vectorized data (local adder / quant array elements) and a constant
    if (in_struct_for_loop && is_bit_vectorized) {
      if (stmt->op_type == BinaryOpType::bit_and) {
        // if the rhs is a bit-vectorized stmt and the lhs is a constant 1
        // (usually generated by a boolean expression), we simply replace
        // the stmt with its rhs
        int lhs_val = get_constant_value(stmt->lhs);
        if (lhs_val == 1) {
          if (auto rhs = stmt->rhs->cast<BinaryOpStmt>();
              rhs && rhs->is_bit_vectorized) {
            stmt->replace_usages_with(stmt->rhs);
          }
        }
      } else if (stmt->op_type == BinaryOpType::cmp_eq) {
        if (auto lhs = stmt->lhs->cast<GlobalLoadStmt>()) {
          // case 0: lhs is a vectorized global load from the quant array
          if (auto ptr = lhs->src->cast<GlobalPtrStmt>();
              ptr && ptr->is_bit_vectorized) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: only 1 is supported for now; 0 should be easy to
            // implement with a bit_not on the original bit pattern
            TI_ASSERT(rhs_val == 1);
            // cmp_eq with 1 yields the bit pattern itself

            // to pass CFG analysis and mark the stmt vectorized,
            // create a dummy lhs + 0 here
            auto zero = std::make_unique<ConstStmt>(TypedConstant(0));
            auto add = std::make_unique<BinaryOpStmt>(BinaryOpType::add,
                                                      stmt->lhs, zero.get());
            add->is_bit_vectorized = true;
            // modify IR
            auto zero_p = zero.get();
            stmt->insert_before_me(std::move(zero));
            stmt->replace_usages_with(add.get());
            zero_p->insert_after_me(std::move(add));
          }
        } else if (auto lhs = stmt->lhs->cast<LocalLoadStmt>()) {
          // case 1: lhs is a local load from a local adder structure
          auto it = transformed_atomics.find(lhs->src);
          if (it != transformed_atomics.end()) {
            int32 rhs_val = get_constant_value(stmt->rhs);
            // TODO: only 2 and 3 are supported for now; the other cases
            // should be implemented in a similar fashion
            TI_ASSERT(rhs_val == 2 || rhs_val == 3);
            // 010 and 011 respectively
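            // with (a, b, c) ordered from the highest to the lowest bit of
            // each lane, equality with 2 (binary 010) reduces to ~a & b & ~c,
            // and equality with 3 (binary 011) to ~a & b & c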
            auto &buffer_vec = it->second;
            Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
            // load all three buffers
            auto load_a = std::make_unique<LocalLoadStmt>(a);
            auto load_b = std::make_unique<LocalLoadStmt>(b);
            auto load_c = std::make_unique<LocalLoadStmt>(c);
            // compute not_a first
            auto not_a = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_a.get());
            // b is used as-is, so there is nothing to compute for it
            // compute not_c
            auto not_c = std::make_unique<UnaryOpStmt>(UnaryOpType::bit_not,
                                                       load_c.get());
            // bit_and all three patterns
            auto and_a_b = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, not_a.get(), load_b.get());
            auto and_b_c = std::make_unique<BinaryOpStmt>(
                BinaryOpType::bit_and, and_a_b.get(),
                rhs_val == 2 ? (Stmt *)(not_c.get()) : (Stmt *)(load_c.get()));
            // mark the last stmt as vectorized
            and_b_c->is_bit_vectorized = true;
            // modify IR
            auto and_a_b_p = and_a_b.get();
            stmt->insert_before_me(std::move(load_a));
            stmt->insert_before_me(std::move(load_b));
            stmt->insert_before_me(std::move(load_c));
            stmt->insert_before_me(std::move(not_a));
            stmt->insert_before_me(std::move(not_c));
            stmt->insert_before_me(std::move(and_a_b));
            stmt->replace_usages_with(and_b_c.get());
            and_a_b_p->insert_after_me(std::move(and_b_c));
          }
        }
      }
    }
  }

  void visit(AtomicOpStmt *stmt) override {
    DataType dt(quant_array_physical_type);
    if (in_struct_for_loop && is_bit_vectorized &&
        stmt->op_type == AtomicOpType::add) {
      auto it = transformed_atomics.find(stmt->dest);
      // process a transformed atomic stmt
      if (it != transformed_atomics.end()) {
        auto &buffer_vec = it->second;
        transform_atomic_add(buffer_vec, stmt, dt);
      } else {
        // allocate three buffers a, b, c
        auto alloc_a = std::make_unique<AllocaStmt>(dt);
        auto alloc_b = std::make_unique<AllocaStmt>(dt);
        auto alloc_c = std::make_unique<AllocaStmt>(dt);
        std::vector<Stmt *> buffer_vec{alloc_a.get(), alloc_b.get(),
                                       alloc_c.get()};
        transformed_atomics[stmt->dest] = buffer_vec;
        // modify IR
        stmt->insert_before_me(std::move(alloc_a));
        stmt->insert_before_me(std::move(alloc_b));
        stmt->insert_before_me(std::move(alloc_c));
        transform_atomic_add(buffer_vec, stmt, dt);
      }
    }
  }

  static void run(IRNode *node) {
    BitLoopVectorize inst;
    node->accept(&inst);
  }

 private:
  void transform_atomic_add(const std::vector<Stmt *> &buffer_vec,
                            AtomicOpStmt *stmt,
                            DataType &dt) {
    // To transform an atomic add on a vectorized subarray of a quant array,
    // we use a local adder with three buffers (*a*, *b*, *c*) of the same
    // physical type as the original quant array. Each bit in *a* represents
    // the highest bit of the result, *b* the second bit, and *c* the lowest
    // bit. To add *d* to the subarray, we use bit_xor and bit_and to compute
    // the sum and the carry.
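    //
    // Per lane, the statements built below amount to (a sketch, not the
    // emitted IR), with d = stmt->val:
    //   carry_c = c & d;        c ^= d;        // lowest bit, carry out of c
    //   carry_b = b & carry_c;  b ^= carry_c;  // middle bit, carry out of b
    //   a ^= carry_b;                          // highest bit, carry discarded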
    Stmt *a = buffer_vec[0], *b = buffer_vec[1], *c = buffer_vec[2];
    auto load_c = std::make_unique<LocalLoadStmt>(c);
    auto carry_c = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_c.get(), stmt->val);
    auto sum_c =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, c, stmt->val);
    auto load_b = std::make_unique<LocalLoadStmt>(b);
    auto carry_b = std::make_unique<BinaryOpStmt>(BinaryOpType::bit_and,
                                                  load_b.get(), carry_c.get());
    auto sum_b =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, b, carry_c.get());
    // for a, we do not need to compute its carry
    auto sum_a =
        std::make_unique<AtomicOpStmt>(AtomicOpType::bit_xor, a, carry_b.get());
    // modify IR
    stmt->insert_before_me(std::move(load_c));
    stmt->insert_before_me(std::move(carry_c));
    stmt->insert_before_me(std::move(sum_c));
    stmt->insert_before_me(std::move(load_b));
    stmt->insert_before_me(std::move(carry_b));
    stmt->insert_before_me(std::move(sum_b));
    stmt->insert_before_me(std::move(sum_a));
    // there is no need to replace the stmt here as we
    // will replace it manually later
  }

  int32 get_constant_value(Stmt *stmt) {
    int32 val = -1;
    // the stmt could be a cast stmt
    if (auto cast_stmt = stmt->cast<UnaryOpStmt>();
        cast_stmt && cast_stmt->is_cast() &&
        cast_stmt->op_type == UnaryOpType::cast_value) {
      stmt = cast_stmt->operand;
    }
    if (auto constant_stmt = stmt->cast<ConstStmt>();
        constant_stmt &&
        constant_stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) {
      val = constant_stmt->val.val_i32;
    }
    return val;
  }
};

namespace irpass {

void bit_loop_vectorize(IRNode *root) {
  TI_AUTO_PROF;
  BitLoopVectorize::run(root);
  die(root);
}

}  // namespace irpass

}  // namespace taichi::lang