#include <variant>

#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

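// The Scalarize pass decomposes TensorType statements into per-element scalar
// statements: loads and stores are split through MatrixPtrStmt, element-wise
// ops (unary/binary/ternary/atomic) are replicated per element, and the
// scalar results are regrouped with a MatrixInitStmt so that downstream users
// keep working.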
class Scalarize : public BasicStmtVisitor {
 public:
  ImmediateIRModifier immediate_modifier_;
  DelayedIRModifier delayed_modifier_;

  explicit Scalarize(IRNode *node) : immediate_modifier_(node) {
    node->accept(this);

    delayed_modifier_.modify_ir();
  }
  /*
    "val" of StoreStmt should have already been replaced by a MatrixInitStmt
    in an earlier scalarization step.

    Before:
      StoreStmt(TensorType<4 x i32>* dest, TensorType<4 x i32> val)

    After:
      addr0 = MatrixPtrStmt(TensorType<4 x i32>* dest, 0)
      addr1 = MatrixPtrStmt(TensorType<4 x i32>* dest, 1)
      addr2 = MatrixPtrStmt(TensorType<4 x i32>* dest, 2)
      addr3 = MatrixPtrStmt(TensorType<4 x i32>* dest, 3)

      StoreStmt(i32* addr0, i32 val->cast<MatrixInitStmt>()->val[0])
      StoreStmt(i32* addr1, i32 val->cast<MatrixInitStmt>()->val[1])
      StoreStmt(i32* addr2, i32 val->cast<MatrixInitStmt>()->val[2])
      StoreStmt(i32* addr3, i32 val->cast<MatrixInitStmt>()->val[3])
  */
  template <typename T>
  void scalarize_store_stmt(T *stmt) {
    auto dest_dtype = stmt->dest->ret_type.ptr_removed();
    auto val_dtype = stmt->val->ret_type;
    if (dest_dtype->template is<TensorType>() &&
        val_dtype->template is<TensorType>()) {
      // Needs scalarization
      auto dest_tensor_type = dest_dtype->template as<TensorType>();
      auto val_tensor_type = val_dtype->template as<TensorType>();

      TI_ASSERT(dest_tensor_type->get_shape() == val_tensor_type->get_shape());

      TI_ASSERT(stmt->val->template is<MatrixInitStmt>());
      auto matrix_init_stmt = stmt->val->template as<MatrixInitStmt>();

      int num_elements = val_tensor_type->get_num_elements();
      auto primitive_type = dest_tensor_type->get_element_type();
      for (int i = 0; i < num_elements; i++) {
        auto const_stmt = std::make_unique<ConstStmt>(
            TypedConstant(get_data_type<int32>(), i));

        auto matrix_ptr_stmt =
            std::make_unique<MatrixPtrStmt>(stmt->dest, const_stmt.get());
        matrix_ptr_stmt->ret_type = primitive_type;
        matrix_ptr_stmt->ret_type.set_is_pointer(true);

        auto scalarized_stmt = std::make_unique<T>(matrix_ptr_stmt.get(),
                                                   matrix_init_stmt->values[i]);

        delayed_modifier_.insert_before(stmt, std::move(const_stmt));
        delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
        delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt));
      }

      delayed_modifier_.erase(stmt);
    }
  }

  /*
    Before:
      TensorType<4 x i32> val = LoadStmt(TensorType<4 x i32>* src)

    After:
      i32* addr0 = MatrixPtrStmt(TensorType<4 x i32>* src, 0)
      i32* addr1 = MatrixPtrStmt(TensorType<4 x i32>* src, 1)
      i32* addr2 = MatrixPtrStmt(TensorType<4 x i32>* src, 2)
      i32* addr3 = MatrixPtrStmt(TensorType<4 x i32>* src, 3)

      i32 val0 = LoadStmt(addr0)
      i32 val1 = LoadStmt(addr1)
      i32 val2 = LoadStmt(addr2)
      i32 val3 = LoadStmt(addr3)

      tmp = MatrixInitStmt(val0, val1, val2, val3)
      stmt->replace_all_usages_with(tmp)
  */
  template <typename T>
  void scalarize_load_stmt(T *stmt) {
    auto src_dtype = stmt->src->ret_type.ptr_removed();
    if (src_dtype->template is<TensorType>()) {
      // Needs scalarization
      auto src_tensor_type = src_dtype->template as<TensorType>();

      std::vector<Stmt *> matrix_init_values;
      int num_elements = src_tensor_type->get_num_elements();

      auto primitive_type = src_tensor_type->get_element_type();
      for (int i = 0; i < num_elements; i++) {
        auto const_stmt = std::make_unique<ConstStmt>(
            TypedConstant(get_data_type<int32>(), i));

        auto matrix_ptr_stmt =
            std::make_unique<MatrixPtrStmt>(stmt->src, const_stmt.get());
        matrix_ptr_stmt->ret_type = primitive_type;
        matrix_ptr_stmt->ret_type.set_is_pointer(true);

        auto scalarized_stmt = std::make_unique<T>(matrix_ptr_stmt.get());
        scalarized_stmt->ret_type = primitive_type;

        matrix_init_values.push_back(scalarized_stmt.get());

        delayed_modifier_.insert_before(stmt, std::move(const_stmt));
        delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
        delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = src_dtype;

      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      delayed_modifier_.erase(stmt);
    }
  }

  /*
    Before:
      TensorType<4 x i32> val = UnaryStmt(TensorType<4 x i32> operand)

    * Note that "operand" should have already been scalarized to
      MatrixInitStmt

    After:
      i32 calc_val0 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[0])
      i32 calc_val1 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[1])
      i32 calc_val2 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[2])
      i32 calc_val3 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[3])

      tmp = MatrixInitStmt(calc_val0, calc_val1,
                           calc_val2, calc_val3)

      stmt->replace_all_usages_with(tmp)
  */
  void visit(UnaryOpStmt *stmt) override {
    auto operand_dtype = stmt->operand->ret_type;
    if (operand_dtype->is<TensorType>()) {
      // Needs scalarization
      auto operand_tensor_type = operand_dtype->as<TensorType>();

      TI_ASSERT(stmt->operand->is<MatrixInitStmt>());
      auto operand_matrix_init_stmt = stmt->operand->cast<MatrixInitStmt>();

      TI_ASSERT(operand_matrix_init_stmt->values.size() ==
                operand_tensor_type->get_num_elements());

      std::vector<Stmt *> matrix_init_values;
      int num_elements = operand_tensor_type->get_num_elements();
      auto primitive_type = stmt->ret_type.get_element_type();
      for (int i = 0; i < num_elements; i++) {
        auto unary_stmt = std::make_unique<UnaryOpStmt>(
            stmt->op_type, operand_matrix_init_stmt->values[i]);
        if (stmt->is_cast()) {
          unary_stmt->cast_type = stmt->cast_type.get_element_type();
        }
        unary_stmt->ret_type = primitive_type;
        matrix_init_values.push_back(unary_stmt.get());

        delayed_modifier_.insert_before(stmt, std::move(unary_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = operand_dtype;

      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      delayed_modifier_.erase(stmt);
    }
  }

  /*
    Before:
      TensorType<4 x i32> val = BinaryStmt(TensorType<4 x i32> lhs,
                                           TensorType<4 x i32> rhs)

    * Note that "lhs" and "rhs" should have already been scalarized to
      MatrixInitStmt

    After:
      i32 calc_val0 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[0],
                                 rhs->cast<MatrixInitStmt>()->val[0])
      i32 calc_val1 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[1],
                                 rhs->cast<MatrixInitStmt>()->val[1])
      i32 calc_val2 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[2],
                                 rhs->cast<MatrixInitStmt>()->val[2])
      i32 calc_val3 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[3],
                                 rhs->cast<MatrixInitStmt>()->val[3])

      tmp = MatrixInitStmt(calc_val0, calc_val1,
                           calc_val2, calc_val3)

      stmt->replace_all_usages_with(tmp)
  */
  void visit(BinaryOpStmt *stmt) override {
    auto lhs_dtype = stmt->lhs->ret_type;
    auto rhs_dtype = stmt->rhs->ret_type;
    if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) {
      // Make sure broadcasting has been correctly applied by
      // BinaryOpExpression::type_check().
      TI_ASSERT(lhs_dtype->is<TensorType>() && rhs_dtype->is<TensorType>());
      // However, since the type conversions are delayed until
      // irpass::type_check(), we only check the shapes here.
      TI_ASSERT(lhs_dtype->cast<TensorType>()->get_shape() ==
                rhs_dtype->cast<TensorType>()->get_shape());
      // Scalarization of LoadStmt should have already replaced both operands
      // with MatrixInitStmt.
      TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
      TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());

      auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
      std::vector<Stmt *> lhs_vals = lhs_matrix_init_stmt->values;

      auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>();
      std::vector<Stmt *> rhs_vals = rhs_matrix_init_stmt->values;

      TI_ASSERT(rhs_vals.size() == lhs_vals.size());

      size_t num_elements = lhs_vals.size();
      auto primitive_type = stmt->ret_type.get_element_type();
      std::vector<Stmt *> matrix_init_values;
      for (size_t i = 0; i < num_elements; i++) {
        auto binary_stmt = std::make_unique<BinaryOpStmt>(
            stmt->op_type, lhs_vals[i], rhs_vals[i]);
        matrix_init_values.push_back(binary_stmt.get());
        binary_stmt->ret_type = primitive_type;

        delayed_modifier_.insert_before(stmt, std::move(binary_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = stmt->ret_type;

      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      delayed_modifier_.erase(stmt);
    }
  }

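  /*
    Flattens any TensorType content of a print into scalar contents plus
    bracket/separator strings, then merges adjacent strings. An illustrative
    sketch (assuming "val" was already scalarized to
    MatrixInitStmt(v0, v1, v2, v3) for a 2x2 matrix):

    Before:
      PrintStmt("m = ", TensorType<2 x 2 x i32> val)

    After:
      PrintStmt("m = [[", v0, ", ", v1, "], [", v2, ", ", v3, "]]")
  */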
  void visit(PrintStmt *stmt) override {
    auto &contents = stmt->contents;
    std::vector<std::variant<Stmt *, std::string>> new_contents;
    for (auto &content : contents) {
      if (auto string_ptr = std::get_if<std::string>(&content)) {
        new_contents.push_back(*string_ptr);
      } else {
        Stmt *print_stmt = std::get<Stmt *>(content);
        if (print_stmt->is<MatrixInitStmt>()) {
          auto matrix_init_stmt = print_stmt->cast<MatrixInitStmt>();
          auto tensor_shape =
              print_stmt->ret_type->as<TensorType>()->get_shape();

          bool is_matrix = tensor_shape.size() == 2;
          int m = tensor_shape[0];

          new_contents.push_back("[");
          if (is_matrix) {
            int n = tensor_shape[1];
            for (int i = 0; i < m; i++) {
              new_contents.push_back("[");
              for (int j = 0; j < n; j++) {
                int index = i * n + j;
                new_contents.push_back(matrix_init_stmt->values[index]);
                if (j != n - 1)
                  new_contents.push_back(", ");
              }
              new_contents.push_back("]");

              if (i != m - 1)
                new_contents.push_back(", ");
            }
          } else {
            for (int i = 0; i < m; i++) {
              new_contents.push_back(matrix_init_stmt->values[i]);
              if (i != m - 1)
                new_contents.push_back(", ");
            }
          }
          new_contents.push_back("]");
        } else {
          new_contents.push_back(print_stmt);
        }
      }
    }

    // Merge adjacent string contents into a single string.
    std::vector<std::variant<Stmt *, std::string>> merged_contents;
    std::string merged_string = "";
    for (const auto &content : new_contents) {
      if (auto string_content = std::get_if<std::string>(&content)) {
        merged_string += *string_content;
      } else {
        if (!merged_string.empty()) {
          merged_contents.push_back(merged_string);
          merged_string = "";
        }
        merged_contents.push_back(content);
      }
    }
    if (!merged_string.empty())
      merged_contents.push_back(merged_string);

    delayed_modifier_.insert_before(stmt,
                                    Stmt::make<PrintStmt>(merged_contents));
    delayed_modifier_.erase(stmt);
  }

  /*
    Before:
      TensorType<4 x i32> val = AtomicStmt(TensorType<4 x i32>* dest,
                                           TensorType<4 x i32> val)

    After:
      i32* dest_ptr_0 = MatrixPtrStmt(dest, 0)
      i32* dest_ptr_1 = MatrixPtrStmt(dest, 1)
      i32* dest_ptr_2 = MatrixPtrStmt(dest, 2)
      i32* dest_ptr_3 = MatrixPtrStmt(dest, 3)

      i32 dest_val0 = AtomicStmt(dest_ptr_0,
                                 val->cast<MatrixInitStmt>()->val[0])
      i32 dest_val1 = AtomicStmt(dest_ptr_1,
                                 val->cast<MatrixInitStmt>()->val[1])
      i32 dest_val2 = AtomicStmt(dest_ptr_2,
                                 val->cast<MatrixInitStmt>()->val[2])
      i32 dest_val3 = AtomicStmt(dest_ptr_3,
                                 val->cast<MatrixInitStmt>()->val[3])

      tmp = MatrixInitStmt(dest_val0, dest_val1,
                           dest_val2, dest_val3)

      stmt->replace_all_usages_with(tmp)
  */
  void visit(AtomicOpStmt *stmt) override {
    auto dest_dtype = stmt->dest->ret_type.ptr_removed();
    auto val_dtype = stmt->val->ret_type;
    if (dest_dtype->is<TensorType>() || val_dtype->is<TensorType>()) {
      // Make sure broadcasting has been correctly applied by
      // AtomicOpExpression::type_check().
      TI_ASSERT(dest_dtype->is<TensorType>() && val_dtype->is<TensorType>());
      // However, since the type conversions are delayed until
      // irpass::type_check(), we only check the shapes here.
      TI_ASSERT(dest_dtype->cast<TensorType>()->get_shape() ==
                val_dtype->cast<TensorType>()->get_shape());
      // Scalarization of LoadStmt should have already replaced the val
      // operand with a MatrixInitStmt.
      TI_ASSERT(stmt->val->is<MatrixInitStmt>());

      auto val_matrix_init_stmt = stmt->val->cast<MatrixInitStmt>();
      std::vector<Stmt *> val_values = val_matrix_init_stmt->values;

      size_t num_elements = val_values.size();
      auto primitive_type = stmt->ret_type.get_element_type();

      // Scalarize dest & val
      std::vector<Stmt *> matrix_init_values;
      for (size_t i = 0; i < num_elements; i++) {
        // scalarize to dest_i
        auto const_stmt = std::make_unique<ConstStmt>(
            TypedConstant(get_data_type<int32>(), i));
        auto matrix_ptr_stmt =
            std::make_unique<MatrixPtrStmt>(stmt->dest, const_stmt.get());

        // scalarize to val_i
        auto val_stmt = val_values[i];

        // assemble into a scalarized atomic op
        auto atomic_stmt = std::make_unique<AtomicOpStmt>(
            stmt->op_type, matrix_ptr_stmt.get(), val_stmt);
        atomic_stmt->ret_type = primitive_type;

        matrix_init_values.push_back(atomic_stmt.get());

        delayed_modifier_.insert_before(stmt, std::move(const_stmt));
        delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
        delayed_modifier_.insert_before(stmt, std::move(atomic_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = stmt->ret_type;

      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      delayed_modifier_.erase(stmt);
    }
  }

  /*
    Before:
      TensorType<4 x i32> val = TernaryStmt(TensorType<4 x i32> cond,
                                            TensorType<4 x i32> lhs,
                                            TensorType<4 x i32> rhs)

    After:
      i32 val0 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[0],
                             lhs->cast<MatrixInitStmt>()->val[0],
                             rhs->cast<MatrixInitStmt>()->val[0])

      i32 val1 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[1],
                             lhs->cast<MatrixInitStmt>()->val[1],
                             rhs->cast<MatrixInitStmt>()->val[1])

      i32 val2 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[2],
                             lhs->cast<MatrixInitStmt>()->val[2],
                             rhs->cast<MatrixInitStmt>()->val[2])

      i32 val3 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[3],
                             lhs->cast<MatrixInitStmt>()->val[3],
                             rhs->cast<MatrixInitStmt>()->val[3])

      tmp = MatrixInitStmt(val0, val1, val2, val3)

      stmt->replace_all_usages_with(tmp)
  */
  void visit(TernaryOpStmt *stmt) override {
    auto cond_dtype = stmt->op1->ret_type;
    auto op2_dtype = stmt->op2->ret_type;
    auto op3_dtype = stmt->op3->ret_type;
    if (cond_dtype->is<TensorType>() || op2_dtype->is<TensorType>() ||
        op3_dtype->is<TensorType>()) {
      // Make sure broadcasting has been correctly applied by
      // TernaryOpExpression::type_check().
      TI_ASSERT(cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() &&
                op3_dtype->is<TensorType>());
      // However, since the type conversions are delayed until
      // irpass::type_check(), we only check the shapes here.
      TI_ASSERT(cond_dtype.get_shape() == op2_dtype.get_shape());
      TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape());
      // Scalarization of LoadStmt should have already replaced all operands
      // with MatrixInitStmt.
      TI_ASSERT(stmt->op1->is<MatrixInitStmt>());
      TI_ASSERT(stmt->op2->is<MatrixInitStmt>());
      TI_ASSERT(stmt->op3->is<MatrixInitStmt>());

      auto cond_matrix_init_stmt = stmt->op1->cast<MatrixInitStmt>();
      std::vector<Stmt *> cond_vals = cond_matrix_init_stmt->values;

      auto op2_matrix_init_stmt = stmt->op2->cast<MatrixInitStmt>();
      std::vector<Stmt *> op2_vals = op2_matrix_init_stmt->values;

      auto op3_matrix_init_stmt = stmt->op3->cast<MatrixInitStmt>();
      std::vector<Stmt *> op3_vals = op3_matrix_init_stmt->values;

      TI_ASSERT(cond_vals.size() == op2_vals.size());
      TI_ASSERT(op2_vals.size() == op3_vals.size());

      size_t num_elements = cond_vals.size();
      auto primitive_type = stmt->ret_type.get_element_type();
      std::vector<Stmt *> matrix_init_values;
      for (size_t i = 0; i < num_elements; i++) {
        auto ternary_stmt = std::make_unique<TernaryOpStmt>(
            stmt->op_type, cond_vals[i], op2_vals[i], op3_vals[i]);
        matrix_init_values.push_back(ternary_stmt.get());
        ternary_stmt->ret_type = primitive_type;

        delayed_modifier_.insert_before(stmt, std::move(ternary_stmt));
      }

      auto matrix_init_stmt =
          std::make_unique<MatrixInitStmt>(matrix_init_values);
      matrix_init_stmt->ret_type = stmt->ret_type;

      immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
      delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

      delayed_modifier_.erase(stmt);
    }
  }

  void visit(GlobalStoreStmt *stmt) override {
    scalarize_store_stmt<GlobalStoreStmt>(stmt);
  }

  void visit(LocalStoreStmt *stmt) override {
    scalarize_store_stmt<LocalStoreStmt>(stmt);
  }

  void visit(GlobalLoadStmt *stmt) override {
    scalarize_load_stmt<GlobalLoadStmt>(stmt);
  }

  void visit(LocalLoadStmt *stmt) override {
    scalarize_load_stmt<LocalLoadStmt>(stmt);
  }

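  /*
    A TensorType ArgLoadStmt is rewritten to return its primitive element
    type instead, so that later element accesses go through the
    MatrixPtrStmts generated by the other visitors.
  */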
  void visit(ArgLoadStmt *stmt) override {
    auto ret_type = stmt->ret_type.ptr_removed().get_element_type();
    auto arg_load = std::make_unique<ArgLoadStmt>(stmt->arg_id, ret_type,
                                                  stmt->is_ptr, stmt->is_grad);

    immediate_modifier_.replace_usages_with(stmt, arg_load.get());

    delayed_modifier_.insert_before(stmt, std::move(arg_load));
    delayed_modifier_.erase(stmt);
  }

 private:
  using BasicStmtVisitor::visit;
};

// The GatherScalarizableLocalPointers pass gathers all local TensorType
// allocas that are only indexed with constants, which can then be scalarized
// in the ScalarizeLocalPointers pass.
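//
// An illustrative example (hypothetical statement ids); shared allocas are
// excluded as well:
//   <[Tensor (4) i32]> $1 = alloca          // scalarizable: constant index
//   <*i32> $3 = shift ptr [$1 + $0]         // $0 = const 2
//
//   <[Tensor (4) i32]> $5 = alloca          // not scalarizable
//   <*i32> $7 = shift ptr [$5 + $6]         // $6 is a dynamic index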
class GatherScalarizableLocalPointers : public BasicStmtVisitor {
 private:
  using BasicStmtVisitor::visit;

  std::unordered_map<Stmt *, bool> is_alloca_scalarizable_;

 public:
  void visit(AllocaStmt *stmt) override {
    if (stmt->ret_type.ptr_removed()->is<TensorType>()) {
      TI_ASSERT(is_alloca_scalarizable_.count(stmt) == 0);
      is_alloca_scalarizable_[stmt] = !stmt->is_shared;
    }
  }

  void visit(MatrixPtrStmt *stmt) override {
    if (stmt->origin->is<AllocaStmt>()) {
      TI_ASSERT(is_alloca_scalarizable_.count(stmt->origin) == 1);
      if (!stmt->offset->is<ConstStmt>()) {
        is_alloca_scalarizable_[stmt->origin] = false;
      }
    }
  }

  static std::unordered_set<Stmt *> run(IRNode *node) {
    GatherScalarizableLocalPointers pass;
    node->accept(&pass);
    std::unordered_set<Stmt *> result;
    for (auto &[k, v] : pass.is_alloca_scalarizable_) {
      if (v) {
        result.insert(k);
      }
    }
    return result;
  }
};

class ScalarizeLocalPointers : public BasicStmtVisitor {
 public:
  ImmediateIRModifier immediate_modifier_;
  DelayedIRModifier delayed_modifier_;

  std::unordered_set<Stmt *> scalarizable_allocas_;
  // { original_alloca_stmt : [scalarized_alloca_stmt0, ...] }
  std::unordered_map<Stmt *, std::vector<Stmt *>> scalarized_local_tensor_map_;

  explicit ScalarizeLocalPointers(
      IRNode *node,
      const std::unordered_set<Stmt *> &scalarizable_allocas)
      : immediate_modifier_(node),
        scalarizable_allocas_(scalarizable_allocas) {
    node->accept(this);

    delayed_modifier_.modify_ir();
  }

  /*
    Accessing a scalar variable is always more efficient than accessing an
    element of a vector - the former generates fewer instructions, leading to
    better performance in both compilation and runtime.

    Although we can do nothing about "global" tensors like tensors from
    ArgLoadStmt or GlobalPtrStmt, we can still optimize "local" tensors like
    tensors from AllocaStmt. In this pass, we replace the original TensorType
    AllocaStmt with multiple scalarized PrimitiveTyped AllocaStmts.

    An additional container "scalarized_local_tensor_map_" is used to keep
    track of the scalarized AllocaStmts, for later use in LoadStmt and
    StoreStmt.

    Before:
      TensorType<4 x i32>* addr = AllocaStmt(TensorType<4 x i32>)

    After:
      i32 addr0 = AllocaStmt(i32)
      i32 addr1 = AllocaStmt(i32)
      i32 addr2 = AllocaStmt(i32)
      i32 addr3 = AllocaStmt(i32)

      scalarized_local_tensor_map_[addr] = {addr0, addr1, addr2, addr3}
  */
  void visit(AllocaStmt *stmt) override {
    if (scalarizable_allocas_.count(stmt) == 1) {
      auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>();
      TI_ASSERT(tensor_type != nullptr);
      auto primitive_type = tensor_type->get_element_type();

      TI_ASSERT(scalarized_local_tensor_map_.count(stmt) == 0);
      scalarized_local_tensor_map_[stmt] = {};
      for (int i = 0; i < tensor_type->get_num_elements(); i++) {
        auto scalarized_alloca_stmt =
            std::make_unique<AllocaStmt>(primitive_type);
        scalarized_alloca_stmt->ret_type = primitive_type;

        scalarized_local_tensor_map_[stmt].push_back(
            scalarized_alloca_stmt.get());
        delayed_modifier_.insert_before(stmt,
                                        std::move(scalarized_alloca_stmt));
      }

      delayed_modifier_.erase(stmt);
    }
  }

  /*
    Before:
      MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset)

    After:
      scalarized_alloca_stmt =
          scalarized_local_tensor_map_[alloca_stmt][offset]
      stmt->replace_all_usages_with(scalarized_alloca_stmt)
  */
  void visit(MatrixPtrStmt *stmt) override {
    if (stmt->origin->is<AllocaStmt>() &&
        scalarizable_allocas_.count(stmt->origin) == 1) {
      auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
      auto tensor_type =
          alloca_stmt->ret_type.ptr_removed()->cast<TensorType>();
      TI_ASSERT(tensor_type != nullptr);
      int num_elements = tensor_type->get_num_elements();
      TI_ASSERT(scalarized_local_tensor_map_.count(alloca_stmt));

      const auto &scalarized_alloca_stmts =
          scalarized_local_tensor_map_[alloca_stmt];
      TI_ASSERT(scalarized_alloca_stmts.size() == num_elements);

      TI_ASSERT(stmt->offset->is<ConstStmt>());
      int offset = stmt->offset->cast<ConstStmt>()->val.val_int32();

      TI_ASSERT(offset < scalarized_alloca_stmts.size());
      auto new_stmt = scalarized_alloca_stmts[offset];

      immediate_modifier_.replace_usages_with(stmt, new_stmt);
      delayed_modifier_.erase(stmt);
    }
  }

 private:
  using BasicStmtVisitor::visit;
};

// The ExtractLocalPointers pass aims at removing redundant ConstStmts and
// MatrixPtrStmts generated for any (AllocaStmt, integer) pair by extracting
// a single copy for all future usages.
//
// Example of redundant stmts:
//   <i32> $0 = const 0
//   <i32> $1 = const 1
//   ...
//   <[Tensor (3, 3) f32]> $47738 = alloca
//   <i32> $47739 = const 0 [REDUNDANT]
//   <*f32> $47740 = shift ptr [$47738 + $47739]
//   $47741 : local store [$47740 <- $47713]
//   <i32> $47742 = const 1 [REDUNDANT]
//   <*f32> $47743 = shift ptr [$47738 + $47742]
//   $47744 : local store [$47743 <- $47716]
//   ...
//   <i32> $47812 = const 1 [REDUNDANT]
//   <*f32> $47813 = shift ptr [$47738 + $47812] [REDUNDANT]
//   <f32> $47814 = local load [$47813]
class ExtractLocalPointers : public BasicStmtVisitor {
 public:
  ImmediateIRModifier immediate_modifier_;
  DelayedIRModifier delayed_modifier_;

  std::unordered_map<std::pair<Stmt *, int>,
                     Stmt *,
                     hashing::Hasher<std::pair<Stmt *, int>>>
      first_matrix_ptr_;  // mapping an (AllocaStmt, integer) pair to the
                          // first MatrixPtrStmt representing it
  std::unordered_map<int, Stmt *>
      first_const_;  // mapping an integer to the first ConstStmt
                     // representing it
  Block *top_level_;

  explicit ExtractLocalPointers(IRNode *root) : immediate_modifier_(root) {
    TI_ASSERT(root->is<Block>());
    top_level_ = root->as<Block>();
    root->accept(this);
    delayed_modifier_.modify_ir();
  }

  void visit(MatrixPtrStmt *stmt) override {
    if (stmt->origin->is<AllocaStmt>()) {
      auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
      auto tensor_type =
          alloca_stmt->ret_type.ptr_removed()->cast<TensorType>();
      TI_ASSERT(tensor_type != nullptr);
      if (stmt->offset->is<ConstStmt>()) {
        int offset = stmt->offset->cast<ConstStmt>()->val.val_int32();
        if (first_const_.count(offset) == 0) {
          first_const_[offset] = stmt->offset;
          delayed_modifier_.extract_to_block_front(stmt->offset, top_level_);
        }
        auto key = std::make_pair(alloca_stmt, offset);
        if (first_matrix_ptr_.count(key) == 0) {
          auto extracted = std::make_unique<MatrixPtrStmt>(
              alloca_stmt, first_const_[offset]);
          first_matrix_ptr_[key] = extracted.get();
          delayed_modifier_.insert_after(alloca_stmt, std::move(extracted));
        }
        auto new_stmt = first_matrix_ptr_[key];
        immediate_modifier_.replace_usages_with(stmt, new_stmt);
        delayed_modifier_.erase(stmt);
      }
    }
  }

 private:
  using BasicStmtVisitor::visit;
};

namespace irpass {

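// Pass ordering matters here: Scalarize first rewrites TensorType operations
// into scalar ones, ScalarizeLocalPointers then splits the allocas collected
// by GatherScalarizableLocalPointers into scalar allocas, and
// ExtractLocalPointers runs last to deduplicate the ConstStmt/MatrixPtrStmt
// pairs emitted by the earlier passes.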
void scalarize(IRNode *root) {
  TI_AUTO_PROF;
  Scalarize scalarize_pass(root);
  auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root);
  ScalarizeLocalPointers scalarize_pointers_pass(root, scalarizable_allocas);
  ExtractLocalPointers extract_pointers_pass(root);
}

}  // namespace irpass

}  // namespace taichi::lang