1 | #include <variant> |
2 | |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/ir/transforms.h" |
6 | #include "taichi/ir/visitors.h" |
7 | #include "taichi/program/program.h" |
8 | #include "taichi/system/profiler.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | class Scalarize : public BasicStmtVisitor { |
13 | public: |
14 | ImmediateIRModifier immediate_modifier_; |
15 | DelayedIRModifier delayed_modifier_; |
16 | |
17 | explicit Scalarize(IRNode *node) : immediate_modifier_(node) { |
18 | node->accept(this); |
19 | |
20 | delayed_modifier_.modify_ir(); |
21 | } |
22 | |
23 | /* |
24 | "val" of StoreStmt should have already been replaced by a MatrixInitStmt in |
25 | former scalarization. |
26 | |
27 | Before: |
28 | StoreStmt(TensorType<4 x i32>* dest, TensorType<4 x i32> val) |
29 | |
30 | After: |
31 | addr0 = MatrixPtrStmt(TensorType<4 x i32>* dest, 0) |
32 | addr1 = MatrixPtrStmt(TensorType<4 x i32>* dest, 1) |
33 | addr2 = MatrixPtrStmt(TensorType<4 x i32>* dest, 2) |
34 | addr2 = MatrixPtrStmt(TensorType<4 x i32>* dest, 3) |
35 | |
36 | StoreStmt(i32* addr0, i32 val->cast<MatrixInitStmt>()->val[0]) |
37 | StoreStmt(i32* addr1, i32 val->cast<MatrixInitStmt>()->val[1]) |
38 | StoreStmt(i32* addr2, i32 val->cast<MatrixInitStmt>()->val[2]) |
39 | StoreStmt(i32* addr3, i32 val->cast<MatrixInitStmt>()->val[3]) |
40 | */ |
41 | template <typename T> |
42 | void scalarize_store_stmt(T *stmt) { |
43 | auto dest_dtype = stmt->dest->ret_type.ptr_removed(); |
44 | auto val_dtype = stmt->val->ret_type; |
45 | if (dest_dtype->template is<TensorType>() && |
46 | val_dtype->template is<TensorType>()) { |
47 | // Needs scalarize |
48 | auto dest_tensor_type = dest_dtype->template as<TensorType>(); |
49 | auto val_tensor_type = val_dtype->template as<TensorType>(); |
50 | |
51 | TI_ASSERT(dest_tensor_type->get_shape() == val_tensor_type->get_shape()); |
52 | |
53 | TI_ASSERT(stmt->val->template is<MatrixInitStmt>()); |
54 | auto matrix_init_stmt = stmt->val->template as<MatrixInitStmt>(); |
55 | |
56 | int num_elements = val_tensor_type->get_num_elements(); |
57 | auto primitive_type = dest_tensor_type->get_element_type(); |
58 | for (int i = 0; i < num_elements; i++) { |
59 | auto const_stmt = std::make_unique<ConstStmt>( |
60 | TypedConstant(get_data_type<int32>(), i)); |
61 | |
62 | auto matrix_ptr_stmt = |
63 | std::make_unique<MatrixPtrStmt>(stmt->dest, const_stmt.get()); |
64 | matrix_ptr_stmt->ret_type = primitive_type; |
65 | matrix_ptr_stmt->ret_type.set_is_pointer(true); |
66 | |
67 | auto scalarized_stmt = std::make_unique<T>(matrix_ptr_stmt.get(), |
68 | matrix_init_stmt->values[i]); |
69 | |
70 | delayed_modifier_.insert_before(stmt, std::move(const_stmt)); |
71 | delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); |
72 | delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt)); |
73 | } |
74 | |
75 | delayed_modifier_.erase(stmt); |
76 | } |
77 | } |
78 | |
79 | /* |
80 | Before: |
81 | TensorType<4 x i32> val = LoadStmt(TensorType<4 x i32>* src) |
82 | |
83 | After: |
84 | i32* addr0 = MatrixPtrStmt(TensorType<4 x i32>* src, 0) |
85 | i32* addr1 = MatrixPtrStmt(TensorType<4 x i32>* src, 1) |
86 | i32* addr2 = MatrixPtrStmt(TensorType<4 x i32>* src, 2) |
87 | i32* addr3 = MatrixPtrStmt(TensorType<4 x i32>* src, 3) |
88 | |
89 | i32 val0 = LoadStmt(addr0) |
90 | i32 val1 = LoadStmt(addr1) |
91 | i32 val2 = LoadStmt(addr2) |
92 | i32 val3 = LoadStmt(addr3) |
93 | |
94 | tmp = MatrixInitStmt(val0, val1, val2, val3) |
95 | stmt->replace_all_usages_with(tmp) |
96 | */ |
97 | template <typename T> |
98 | void scalarize_load_stmt(T *stmt) { |
99 | auto src_dtype = stmt->src->ret_type.ptr_removed(); |
100 | if (src_dtype->template is<TensorType>()) { |
101 | // Needs scalarize |
102 | auto src_tensor_type = src_dtype->template as<TensorType>(); |
103 | |
104 | std::vector<Stmt *> matrix_init_values; |
105 | int num_elements = src_tensor_type->get_num_elements(); |
106 | |
107 | auto primitive_type = src_tensor_type->get_element_type(); |
108 | for (size_t i = 0; i < num_elements; i++) { |
109 | auto const_stmt = std::make_unique<ConstStmt>( |
110 | TypedConstant(get_data_type<int32>(), i)); |
111 | |
112 | auto matrix_ptr_stmt = |
113 | std::make_unique<MatrixPtrStmt>(stmt->src, const_stmt.get()); |
114 | matrix_ptr_stmt->ret_type = primitive_type; |
115 | matrix_ptr_stmt->ret_type.set_is_pointer(true); |
116 | |
117 | auto scalarized_stmt = std::make_unique<T>(matrix_ptr_stmt.get()); |
118 | scalarized_stmt->ret_type = primitive_type; |
119 | |
120 | matrix_init_values.push_back(scalarized_stmt.get()); |
121 | |
122 | delayed_modifier_.insert_before(stmt, std::move(const_stmt)); |
123 | delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); |
124 | delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt)); |
125 | } |
126 | |
127 | auto matrix_init_stmt = |
128 | std::make_unique<MatrixInitStmt>(matrix_init_values); |
129 | matrix_init_stmt->ret_type = src_dtype; |
130 | |
131 | immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); |
132 | delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); |
133 | |
134 | delayed_modifier_.erase(stmt); |
135 | } |
136 | } |
137 | |
138 | /* |
139 | |
140 | Before: |
141 | TensorType<4 x i32> val = UnaryStmt(TensorType<4 x i32> operand) |
142 | |
143 | * Note that "operand" should have already been scalarized to |
144 | MatrixInitStmt |
145 | |
146 | After: |
147 | i32 calc_val0 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[0]) |
148 | i32 calc_val1 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[1]) |
149 | i32 calc_val2 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[2]) |
150 | i32 calc_val3 = UnaryStmt(operand->cast<MatrixInitStmt>()->val[3]) |
151 | |
152 | tmp = MatrixInitStmt(calc_val0, calc_val1, |
153 | calc_val2, calc_val3) |
154 | |
155 | stmt->replace_all_usages_with(tmp) |
156 | */ |
157 | void visit(UnaryOpStmt *stmt) override { |
158 | auto operand_dtype = stmt->operand->ret_type; |
159 | if (operand_dtype->is<TensorType>()) { |
160 | // Needs scalarize |
161 | auto operand_tensor_type = operand_dtype->as<TensorType>(); |
162 | |
163 | TI_ASSERT(stmt->operand->is<MatrixInitStmt>()); |
164 | auto operand_matrix_init_stmt = stmt->operand->cast<MatrixInitStmt>(); |
165 | |
166 | TI_ASSERT(operand_matrix_init_stmt->values.size() == |
167 | operand_tensor_type->get_num_elements()); |
168 | |
169 | std::vector<Stmt *> matrix_init_values; |
170 | int num_elements = operand_tensor_type->get_num_elements(); |
171 | auto primitive_type = stmt->ret_type.get_element_type(); |
172 | for (size_t i = 0; i < num_elements; i++) { |
173 | auto unary_stmt = std::make_unique<UnaryOpStmt>( |
174 | stmt->op_type, operand_matrix_init_stmt->values[i]); |
175 | if (stmt->is_cast()) { |
176 | unary_stmt->cast_type = stmt->cast_type.get_element_type(); |
177 | } |
178 | unary_stmt->ret_type = primitive_type; |
179 | matrix_init_values.push_back(unary_stmt.get()); |
180 | |
181 | delayed_modifier_.insert_before(stmt, std::move(unary_stmt)); |
182 | } |
183 | |
184 | auto matrix_init_stmt = |
185 | std::make_unique<MatrixInitStmt>(matrix_init_values); |
186 | matrix_init_stmt->ret_type = operand_dtype; |
187 | |
188 | immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); |
189 | delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); |
190 | |
191 | delayed_modifier_.erase(stmt); |
192 | } |
193 | } |
194 | |
195 | /* |
196 | Before: |
197 | TensorType<4 x i32> val = BinaryStmt(TensorType<4 x i32> lhs, |
198 | TensorType<4 x i32> rhs) |
199 | |
200 | * Note that "lhs" and "rhs" should have already been scalarized to |
201 | MatrixInitStmt |
202 | |
203 | After: |
204 | i32 calc_val0 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[0], |
205 | rhs->cast<MatrixInitStmt>()->val[0]) |
206 | i32 calc_val1 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[1], |
207 | rhs->cast<MatrixInitStmt>()->val[1]) |
208 | i32 calc_val2 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[2], |
209 | rhs->cast<MatrixInitStmt>()->val[2]) |
210 | i32 calc_val3 = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[3], |
211 | rhs->cast<MatrixInitStmt>()->val[3]) |
212 | |
213 | tmp = MatrixInitStmt(calc_val0, calc_val1, |
214 | calc_val2, calc_val3) |
215 | |
216 | stmt->replace_all_usages_with(tmp) |
217 | */ |
218 | void visit(BinaryOpStmt *stmt) override { |
219 | auto lhs_dtype = stmt->lhs->ret_type; |
220 | auto rhs_dtype = stmt->rhs->ret_type; |
221 | if (lhs_dtype->is<TensorType>() || rhs_dtype->is<TensorType>()) { |
222 | // Make sure broadcasting has been correctly applied by |
223 | // BinaryOpExpression::type_check(). |
224 | TI_ASSERT(lhs_dtype->is<TensorType>() && rhs_dtype->is<TensorType>()); |
225 | // However, since the type conversions are delayed until |
226 | // irpass::type_check(), we only check for the shape here. |
227 | TI_ASSERT(lhs_dtype->cast<TensorType>()->get_shape() == |
228 | rhs_dtype->cast<TensorType>()->get_shape()); |
229 | // Scalarization for LoadStmt should have already replaced both operands |
230 | // to MatrixInitStmt. |
231 | TI_ASSERT(stmt->lhs->is<MatrixInitStmt>()); |
232 | TI_ASSERT(stmt->rhs->is<MatrixInitStmt>()); |
233 | |
234 | auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>(); |
235 | std::vector<Stmt *> lhs_vals = lhs_matrix_init_stmt->values; |
236 | |
237 | auto rhs_matrix_init_stmt = stmt->rhs->cast<MatrixInitStmt>(); |
238 | std::vector<Stmt *> rhs_vals = rhs_matrix_init_stmt->values; |
239 | |
240 | TI_ASSERT(rhs_vals.size() == lhs_vals.size()); |
241 | |
242 | size_t num_elements = lhs_vals.size(); |
243 | auto primitive_type = stmt->ret_type.get_element_type(); |
244 | std::vector<Stmt *> matrix_init_values; |
245 | for (size_t i = 0; i < num_elements; i++) { |
246 | auto binary_stmt = std::make_unique<BinaryOpStmt>( |
247 | stmt->op_type, lhs_vals[i], rhs_vals[i]); |
248 | matrix_init_values.push_back(binary_stmt.get()); |
249 | binary_stmt->ret_type = primitive_type; |
250 | |
251 | delayed_modifier_.insert_before(stmt, std::move(binary_stmt)); |
252 | } |
253 | |
254 | auto matrix_init_stmt = |
255 | std::make_unique<MatrixInitStmt>(matrix_init_values); |
256 | matrix_init_stmt->ret_type = stmt->ret_type; |
257 | |
258 | immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); |
259 | delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); |
260 | |
261 | delayed_modifier_.erase(stmt); |
262 | } |
263 | } |
264 | |
265 | void visit(PrintStmt *stmt) override { |
266 | auto &contents = stmt->contents; |
267 | std::vector<std::variant<Stmt *, std::string>> new_contents; |
268 | for (size_t i = 0; i < contents.size(); i++) { |
269 | auto content = contents[i]; |
270 | if (auto string_ptr = std::get_if<std::string>(&content)) { |
271 | new_contents.push_back(*string_ptr); |
272 | } else { |
273 | Stmt *print_stmt = std::get<Stmt *>(content); |
274 | if (print_stmt->is<MatrixInitStmt>()) { |
275 | auto matrix_init_stmt = print_stmt->cast<MatrixInitStmt>(); |
276 | auto tensor_shape = |
277 | print_stmt->ret_type->as<TensorType>()->get_shape(); |
278 | |
279 | bool is_matrix = tensor_shape.size() == 2; |
280 | int m = tensor_shape[0]; |
281 | |
282 | new_contents.push_back("[" ); |
283 | if (is_matrix) { |
284 | int n = tensor_shape[1]; |
285 | for (size_t i = 0; i < m; i++) { |
286 | new_contents.push_back("[" ); |
287 | for (size_t j = 0; j < n; j++) { |
288 | size_t index = i * n + j; |
289 | new_contents.push_back(matrix_init_stmt->values[index]); |
290 | if (j != n - 1) |
291 | new_contents.push_back(", " ); |
292 | } |
293 | new_contents.push_back("]" ); |
294 | |
295 | if (i != m - 1) |
296 | new_contents.push_back(", " ); |
297 | } |
298 | } else { |
299 | for (size_t i = 0; i < m; i++) { |
300 | new_contents.push_back(matrix_init_stmt->values[i]); |
301 | if (i != m - 1) |
302 | new_contents.push_back(", " ); |
303 | } |
304 | } |
305 | new_contents.push_back("]" ); |
306 | } else { |
307 | new_contents.push_back(print_stmt); |
308 | } |
309 | } |
310 | } |
311 | |
312 | // Merge string contents |
313 | std::vector<std::variant<Stmt *, std::string>> merged_contents; |
314 | std::string merged_string = "" ; |
315 | for (const auto &content : new_contents) { |
316 | if (auto string_content = std::get_if<std::string>(&content)) { |
317 | merged_string += *string_content; |
318 | } else { |
319 | if (!merged_string.empty()) { |
320 | merged_contents.push_back(merged_string); |
321 | merged_string = "" ; |
322 | } |
323 | merged_contents.push_back(content); |
324 | } |
325 | } |
326 | if (!merged_string.empty()) |
327 | merged_contents.push_back(merged_string); |
328 | |
329 | delayed_modifier_.insert_before(stmt, |
330 | Stmt::make<PrintStmt>(merged_contents)); |
331 | delayed_modifier_.erase(stmt); |
332 | } |
333 | |
334 | /* |
335 | Before: |
336 | TensorType<4 x i32> val = AtomicStmt(TensorType<4 x i32>* dest, |
337 | TensorType<4 x i32> val) |
338 | |
339 | After: |
340 | i32* dest_ptr_0 = MatrixPtrStmt(dest, 0) |
341 | i32* dest_ptr_1 = MatrixPtrStmt(dest, 1) |
342 | i32* dest_ptr_2 = MatrixPtrStmt(dest, 2) |
343 | i32* dest_ptr_3 = MatrixPtrStmt(dest, 3) |
344 | |
345 | i32 dest_val0 = AtomicStmt(dest_ptr_0, |
346 | val->cast<MatrixInitStmt>()->val[0]) |
347 | i32 dest_val1 = AtomicStmt(dest_ptr_1, |
348 | val->cast<MatrixInitStmt>()->val[1]) |
349 | i32 dest_val2 = AtomicStmt(dest_ptr_2, |
350 | val->cast<MatrixInitStmt>()->val[2]) |
351 | i32 dest_val3 = AtomicStmt(dest_ptr_3, |
352 | val->cast<MatrixInitStmt>()->val[3]) |
353 | |
354 | tmp = MatrixInitStmt(dest_val0, dest_val1, |
355 | dest_val2, dest_val3) |
356 | |
357 | stmt->replace_all_usages_with(tmp) |
358 | */ |
359 | void visit(AtomicOpStmt *stmt) override { |
360 | auto dest_dtype = stmt->dest->ret_type.ptr_removed(); |
361 | auto val_dtype = stmt->val->ret_type; |
362 | if (dest_dtype->is<TensorType>() || val_dtype->is<TensorType>()) { |
363 | // Make sure broadcasting has been correctly applied by |
364 | // AtomicOpExpression::type_check(). |
365 | TI_ASSERT(dest_dtype->is<TensorType>() && val_dtype->is<TensorType>()); |
366 | // However, since the type conversions are delayed until |
367 | // irpass::type_check(), we only check for the shape here. |
368 | TI_ASSERT(dest_dtype->cast<TensorType>()->get_shape() == |
369 | val_dtype->cast<TensorType>()->get_shape()); |
370 | // Scalarization for LoadStmt should have already replaced val operand |
371 | // to MatrixInitStmt. |
372 | TI_ASSERT(stmt->val->is<MatrixInitStmt>()); |
373 | |
374 | auto val_matrix_init_stmt = stmt->val->cast<MatrixInitStmt>(); |
375 | std::vector<Stmt *> val_values = val_matrix_init_stmt->values; |
376 | |
377 | size_t num_elements = val_values.size(); |
378 | auto primitive_type = stmt->ret_type.get_element_type(); |
379 | |
380 | // Scalarize dest & val |
381 | std::vector<Stmt *> matrix_init_values; |
382 | for (size_t i = 0; i < num_elements; i++) { |
383 | // scalarize to dest_i |
384 | auto const_stmt = std::make_unique<ConstStmt>( |
385 | TypedConstant(get_data_type<int32>(), i)); |
386 | auto matrix_ptr_stmt = |
387 | std::make_unique<MatrixPtrStmt>(stmt->dest, const_stmt.get()); |
388 | |
389 | // scalarize to val_i |
390 | auto val_stmt = val_values[i]; |
391 | |
392 | // assemble to scalarized atomic_op |
393 | auto atomic_stmt = std::make_unique<AtomicOpStmt>( |
394 | stmt->op_type, matrix_ptr_stmt.get(), val_stmt); |
395 | atomic_stmt->ret_type = primitive_type; |
396 | |
397 | matrix_init_values.push_back(atomic_stmt.get()); |
398 | |
399 | delayed_modifier_.insert_before(stmt, std::move(const_stmt)); |
400 | delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); |
401 | delayed_modifier_.insert_before(stmt, std::move(atomic_stmt)); |
402 | } |
403 | |
404 | auto matrix_init_stmt = |
405 | std::make_unique<MatrixInitStmt>(matrix_init_values); |
406 | matrix_init_stmt->ret_type = stmt->ret_type; |
407 | |
408 | immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); |
409 | delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); |
410 | |
411 | delayed_modifier_.erase(stmt); |
412 | } |
413 | } |
414 | |
415 | /* |
416 | Before: |
417 | TensorType<4 x i32> val = TernaryStmt(TensorType<4 x i32> cond, |
418 | TensorType<4 x i32> lhs, |
419 | TensorType<4 x i32> rhs) |
420 | |
421 | After: |
422 | i32 val0 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[0], |
423 | lhs->cast<MatrixInitStmt>()->val[0], |
424 | rhs->cast<MatrixInitStmt>()->val[0]) |
425 | |
426 | i32 val1 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[1], |
427 | lhs->cast<MatrixInitStmt>()->val[1], |
428 | rhs->cast<MatrixInitStmt>()->val[1]) |
429 | |
430 | i32 val2 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[2], |
431 | lhs->cast<MatrixInitStmt>()->val[2], |
432 | rhs->cast<MatrixInitStmt>()->val[2]) |
433 | |
434 | i32 val3 = TernaryStmt(cond->cast<MatrixInitStmt>()->val[3], |
435 | lhs->cast<MatrixInitStmt>()->val[3], |
436 | rhs->cast<MatrixInitStmt>()->val[3]) |
437 | |
438 | tmp = MatrixInitStmt(val0, val1, val2, val3) |
439 | |
440 | stmt->replace_all_usages_with(tmp) |
441 | */ |
442 | void visit(TernaryOpStmt *stmt) override { |
443 | auto cond_dtype = stmt->op1->ret_type; |
444 | auto op2_dtype = stmt->op2->ret_type; |
445 | auto op3_dtype = stmt->op3->ret_type; |
446 | if (cond_dtype->is<TensorType>() || op2_dtype->is<TensorType>() || |
447 | op3_dtype->is<TensorType>()) { |
448 | // Make sure broadcasting has been correctly applied by |
449 | // TernaryOpExpression::type_check(). |
450 | TI_ASSERT(cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() && |
451 | op3_dtype->is<TensorType>()); |
452 | // However, since the type conversions are delayed until |
453 | // irpass::type_check(), we only check for the shape here. |
454 | TI_ASSERT(cond_dtype.get_shape() == op2_dtype.get_shape()); |
455 | TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape()); |
456 | // Scalarization for LoadStmt should have already replaced all operands |
457 | // to MatrixInitStmt. |
458 | TI_ASSERT(stmt->op1->is<MatrixInitStmt>()); |
459 | TI_ASSERT(stmt->op2->is<MatrixInitStmt>()); |
460 | TI_ASSERT(stmt->op3->is<MatrixInitStmt>()); |
461 | |
462 | auto cond_matrix_init_stmt = stmt->op1->cast<MatrixInitStmt>(); |
463 | std::vector<Stmt *> cond_vals = cond_matrix_init_stmt->values; |
464 | |
465 | auto op2_matrix_init_stmt = stmt->op2->cast<MatrixInitStmt>(); |
466 | std::vector<Stmt *> op2_vals = op2_matrix_init_stmt->values; |
467 | |
468 | auto op3_matrix_init_stmt = stmt->op3->cast<MatrixInitStmt>(); |
469 | std::vector<Stmt *> op3_vals = op3_matrix_init_stmt->values; |
470 | |
471 | TI_ASSERT(cond_vals.size() == op2_vals.size()); |
472 | TI_ASSERT(op2_vals.size() == op3_vals.size()); |
473 | |
474 | size_t num_elements = cond_vals.size(); |
475 | auto primitive_type = stmt->ret_type.get_element_type(); |
476 | std::vector<Stmt *> matrix_init_values; |
477 | for (size_t i = 0; i < num_elements; i++) { |
478 | auto ternary_stmt = std::make_unique<TernaryOpStmt>( |
479 | stmt->op_type, cond_vals[i], op2_vals[i], op3_vals[i]); |
480 | matrix_init_values.push_back(ternary_stmt.get()); |
481 | ternary_stmt->ret_type = primitive_type; |
482 | |
483 | delayed_modifier_.insert_before(stmt, std::move(ternary_stmt)); |
484 | } |
485 | |
486 | auto matrix_init_stmt = |
487 | std::make_unique<MatrixInitStmt>(matrix_init_values); |
488 | matrix_init_stmt->ret_type = stmt->ret_type; |
489 | |
490 | immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); |
491 | delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); |
492 | |
493 | delayed_modifier_.erase(stmt); |
494 | } |
495 | } |
496 | |
497 | void visit(GlobalStoreStmt *stmt) override { |
498 | scalarize_store_stmt<GlobalStoreStmt>(stmt); |
499 | } |
500 | |
501 | void visit(LocalStoreStmt *stmt) override { |
502 | scalarize_store_stmt<LocalStoreStmt>(stmt); |
503 | } |
504 | |
505 | void visit(GlobalLoadStmt *stmt) override { |
506 | scalarize_load_stmt<GlobalLoadStmt>(stmt); |
507 | } |
508 | |
509 | void visit(LocalLoadStmt *stmt) override { |
510 | scalarize_load_stmt<LocalLoadStmt>(stmt); |
511 | } |
512 | |
513 | void visit(ArgLoadStmt *stmt) override { |
514 | auto ret_type = stmt->ret_type.ptr_removed().get_element_type(); |
515 | auto arg_load = std::make_unique<ArgLoadStmt>(stmt->arg_id, ret_type, |
516 | stmt->is_ptr, stmt->is_grad); |
517 | |
518 | immediate_modifier_.replace_usages_with(stmt, arg_load.get()); |
519 | |
520 | delayed_modifier_.insert_before(stmt, std::move(arg_load)); |
521 | delayed_modifier_.erase(stmt); |
522 | } |
523 | |
524 | private: |
525 | using BasicStmtVisitor::visit; |
526 | }; |
527 | |
528 | // The GatherScalarizableLocalPointers gathers all local TensorType allocas |
529 | // only indexed with constants, which can then be scalarized in the |
530 | // ScalarizeLocalPointers pass. |
531 | class GatherScalarizableLocalPointers : public BasicStmtVisitor { |
532 | private: |
533 | using BasicStmtVisitor::visit; |
534 | |
535 | std::unordered_map<Stmt *, bool> is_alloca_scalarizable_; |
536 | |
537 | public: |
538 | void visit(AllocaStmt *stmt) override { |
539 | if (stmt->ret_type.ptr_removed()->is<TensorType>()) { |
540 | TI_ASSERT(is_alloca_scalarizable_.count(stmt) == 0); |
541 | is_alloca_scalarizable_[stmt] = !stmt->is_shared; |
542 | } |
543 | } |
544 | |
545 | void visit(MatrixPtrStmt *stmt) override { |
546 | if (stmt->origin->is<AllocaStmt>()) { |
547 | TI_ASSERT(is_alloca_scalarizable_.count(stmt->origin) == 1); |
548 | if (!stmt->offset->is<ConstStmt>()) { |
549 | is_alloca_scalarizable_[stmt->origin] = false; |
550 | } |
551 | } |
552 | } |
553 | |
554 | static std::unordered_set<Stmt *> run(IRNode *node) { |
555 | GatherScalarizableLocalPointers pass; |
556 | node->accept(&pass); |
557 | std::unordered_set<Stmt *> result; |
558 | for (auto &[k, v] : pass.is_alloca_scalarizable_) { |
559 | if (v) { |
560 | result.insert(k); |
561 | } |
562 | } |
563 | return result; |
564 | } |
565 | }; |
566 | |
567 | class ScalarizeLocalPointers : public BasicStmtVisitor { |
568 | public: |
569 | ImmediateIRModifier immediate_modifier_; |
570 | DelayedIRModifier delayed_modifier_; |
571 | |
572 | std::unordered_set<Stmt *> scalarizable_allocas_; |
573 | // { original_alloca_stmt : [scalarized_alloca_stmt0, ...] } |
574 | std::unordered_map<Stmt *, std::vector<Stmt *>> scalarized_local_tensor_map_; |
575 | |
576 | explicit ScalarizeLocalPointers( |
577 | IRNode *node, |
578 | const std::unordered_set<Stmt *> &scalarizable_allocas) |
579 | : immediate_modifier_(node), scalarizable_allocas_(scalarizable_allocas) { |
580 | node->accept(this); |
581 | |
582 | delayed_modifier_.modify_ir(); |
583 | } |
584 | |
585 | /* |
586 | Accessing scalar values are always more efficient than accessing elements |
587 | from a vector - the former generates less instructions, leading to better |
588 | performance in both compilation and runtime. |
589 | |
590 | Although we can do nothing about "global" tensors like tensors from |
591 | ArgLoadStmt or GlobalPtrStmt, we can still optimize "local" tensors like |
592 | tensors from AllocaStmt. In this pass, we ask AllocaStmt to allocate |
593 | multiple scalarized PrimitiveTyped variables in replacement of the original |
594 | TensorType. |
595 | |
596 | An additional container "scalarized_local_tensor_map_" is used to keep track |
597 | of the scalarized AllocaStmt, for later use in LoadStmt and StoreStmt. |
598 | |
599 | Before: |
600 | TensorType<4 x i32>* addr = AllocaStmt(TensorType<4 x i32>) |
601 | |
602 | After: |
603 | i32 addr0 = AllocaStmt(i32) |
604 | i32 addr1 = AllocaStmt(i32) |
605 | i32 addr2 = AllocaStmt(i32) |
606 | i32 addr3 = AllocaStmt(i32) |
607 | |
608 | scalarized_local_tensor_map_[addr] = {addr0, addr1, addr2, addr3} |
609 | */ |
610 | void visit(AllocaStmt *stmt) override { |
611 | if (scalarizable_allocas_.count(stmt) == 1) { |
612 | auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>(); |
613 | TI_ASSERT(tensor_type != nullptr); |
614 | auto primitive_type = tensor_type->get_element_type(); |
615 | |
616 | TI_ASSERT(scalarized_local_tensor_map_.count(stmt) == 0); |
617 | scalarized_local_tensor_map_[stmt] = {}; |
618 | for (size_t i = 0; i < tensor_type->get_num_elements(); i++) { |
619 | auto scalarized_alloca_stmt = |
620 | std::make_unique<AllocaStmt>(primitive_type); |
621 | scalarized_alloca_stmt->ret_type = primitive_type; |
622 | |
623 | scalarized_local_tensor_map_[stmt].push_back( |
624 | scalarized_alloca_stmt.get()); |
625 | delayed_modifier_.insert_before(stmt, |
626 | std::move(scalarized_alloca_stmt)); |
627 | } |
628 | |
629 | delayed_modifier_.erase(stmt); |
630 | } |
631 | } |
632 | |
633 | /* |
634 | Before: |
635 | MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset) |
636 | |
637 | After: |
638 | scalarized_alloca_stmt = scalarized_local_tensor_map_[alloca_stmt][offset] |
639 | stmt->replace_all_usages_with(scalarized_alloca_stmt) |
640 | */ |
641 | void visit(MatrixPtrStmt *stmt) override { |
642 | if (stmt->origin->is<AllocaStmt>() && |
643 | scalarizable_allocas_.count(stmt->origin) == 1) { |
644 | auto alloca_stmt = stmt->origin->cast<AllocaStmt>(); |
645 | auto tensor_type = |
646 | alloca_stmt->ret_type.ptr_removed()->cast<TensorType>(); |
647 | TI_ASSERT(tensor_type != nullptr); |
648 | int num_elements = tensor_type->get_num_elements(); |
649 | TI_ASSERT(scalarized_local_tensor_map_.count(alloca_stmt)); |
650 | |
651 | const auto &scalarized_alloca_stmts = |
652 | scalarized_local_tensor_map_[alloca_stmt]; |
653 | TI_ASSERT(scalarized_alloca_stmts.size() == num_elements); |
654 | |
655 | TI_ASSERT(stmt->offset->is<ConstStmt>()); |
656 | int offset = stmt->offset->cast<ConstStmt>()->val.val_int32(); |
657 | |
658 | TI_ASSERT(offset < scalarized_alloca_stmts.size()); |
659 | auto new_stmt = scalarized_alloca_stmts[offset]; |
660 | |
661 | immediate_modifier_.replace_usages_with(stmt, new_stmt); |
662 | delayed_modifier_.erase(stmt); |
663 | } |
664 | } |
665 | |
666 | private: |
667 | using BasicStmtVisitor::visit; |
668 | }; |
669 | |
670 | // The ExtractLocalPointers pass aims at removing redundant ConstStmts and |
671 | // MatrixPtrStmts generated for any (AllocaStmt, integer) pair by extracting |
672 | // a unique copy for any future usage. |
673 | // |
674 | // Example for redundant stmts: |
675 | // <i32> $0 = const 0 |
676 | // <i32> $1 = const 1 |
677 | // ... |
678 | // <[Tensor (3, 3) f32]> $47738 = alloca |
679 | // <i32> $47739 = const 0 [REDUNDANT] |
680 | // <*f32> $47740 = shift ptr [$47738 + $47739] |
681 | // $47741 : local store [$47740 <- $47713] |
682 | // <i32> $47742 = const 1 [REDUNDANT] |
683 | // <*f32> $47743 = shift ptr [$47738 + $47742] |
684 | // $47744 : local store [$47743 <- $47716] |
685 | // ... |
686 | // <i32> $47812 = const 1 [REDUNDANT] |
687 | // <*f32> $47813 = shift ptr [$47738 + $47812] [REDUNDANT] |
688 | // <f32> $47814 = local load [$47813] |
689 | class : public BasicStmtVisitor { |
690 | public: |
691 | ImmediateIRModifier ; |
692 | DelayedIRModifier ; |
693 | |
694 | std::unordered_map<std::pair<Stmt *, int>, |
695 | Stmt *, |
696 | hashing::Hasher<std::pair<Stmt *, int>>> |
697 | ; // mapping an (AllocaStmt, integer) pair to the first |
698 | // MatrixPtrStmt representing it |
699 | std::unordered_map<int, Stmt *> |
700 | ; // mapping an integer to the first ConstStmt representing |
701 | // it |
702 | Block *; |
703 | |
704 | explicit (IRNode *root) : immediate_modifier_(root) { |
705 | TI_ASSERT(root->is<Block>()); |
706 | top_level_ = root->as<Block>(); |
707 | root->accept(this); |
708 | delayed_modifier_.modify_ir(); |
709 | } |
710 | |
711 | void (MatrixPtrStmt *stmt) override { |
712 | if (stmt->origin->is<AllocaStmt>()) { |
713 | auto alloca_stmt = stmt->origin->cast<AllocaStmt>(); |
714 | auto tensor_type = |
715 | alloca_stmt->ret_type.ptr_removed()->cast<TensorType>(); |
716 | TI_ASSERT(tensor_type != nullptr); |
717 | if (stmt->offset->is<ConstStmt>()) { |
718 | int offset = stmt->offset->cast<ConstStmt>()->val.val_int32(); |
719 | if (first_const_.count(offset) == 0) { |
720 | first_const_[offset] = stmt->offset; |
721 | delayed_modifier_.extract_to_block_front(stmt->offset, top_level_); |
722 | } |
723 | auto key = std::make_pair(alloca_stmt, offset); |
724 | if (first_matrix_ptr_.count(key) == 0) { |
725 | auto = std::make_unique<MatrixPtrStmt>( |
726 | alloca_stmt, first_const_[offset]); |
727 | first_matrix_ptr_[key] = extracted.get(); |
728 | delayed_modifier_.insert_after(alloca_stmt, std::move(extracted)); |
729 | } |
730 | auto new_stmt = first_matrix_ptr_[key]; |
731 | immediate_modifier_.replace_usages_with(stmt, new_stmt); |
732 | delayed_modifier_.erase(stmt); |
733 | } |
734 | } |
735 | } |
736 | |
737 | private: |
738 | using BasicStmtVisitor::visit; |
739 | }; |
740 | |
741 | namespace irpass { |
742 | |
743 | void scalarize(IRNode *root) { |
744 | TI_AUTO_PROF; |
745 | Scalarize scalarize_pass(root); |
746 | auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root); |
747 | ScalarizeLocalPointers scalarize_pointers_pass(root, scalarizable_allocas); |
748 | ExtractLocalPointers (root); |
749 | } |
750 | |
751 | } // namespace irpass |
752 | |
753 | } // namespace taichi::lang |
754 | |