#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"
#include <cstddef>
#include <functional>
#include <numeric>
8 | |
9 | namespace taichi::lang { |
10 | |
11 | class LowerMatrixPtr : public BasicStmtVisitor { |
12 | private: |
13 | using BasicStmtVisitor::visit; |
14 | DelayedIRModifier modifier_; |
15 | |
16 | public: |
17 | void visit(MatrixPtrStmt *stmt) override { |
18 | if (stmt->origin->is<MatrixOfGlobalPtrStmt>()) { |
19 | auto origin = stmt->origin->as<MatrixOfGlobalPtrStmt>(); |
20 | if (stmt->offset->is<ConstStmt>()) { |
21 | auto offset = stmt->offset->as<ConstStmt>(); |
22 | auto lowered = std::make_unique<GlobalPtrStmt>( |
23 | origin->snodes[offset->val.val_int()], origin->indices); |
24 | stmt->replace_usages_with(lowered.get()); |
25 | modifier_.insert_before(stmt, std::move(lowered)); |
26 | modifier_.erase(stmt); |
27 | } else { |
28 | TI_ASSERT_INFO( |
29 | origin->dynamic_indexable, |
30 | "Element of the MatrixField is not dynamic indexable.\n{}" , |
31 | stmt->tb); |
32 | auto stride = std::make_unique<ConstStmt>( |
33 | TypedConstant(origin->dynamic_index_stride)); |
34 | auto offset = std::make_unique<BinaryOpStmt>( |
35 | BinaryOpType::mul, stmt->offset, stride.get()); |
36 | auto ptr_base = |
37 | std::make_unique<GlobalPtrStmt>(origin->snodes[0], origin->indices); |
38 | auto lowered = |
39 | std::make_unique<MatrixPtrStmt>(ptr_base.get(), offset.get()); |
40 | stmt->replace_usages_with(lowered.get()); |
41 | modifier_.insert_before(stmt, std::move(stride)); |
42 | modifier_.insert_before(stmt, std::move(offset)); |
43 | modifier_.insert_before(stmt, std::move(ptr_base)); |
44 | modifier_.insert_before(stmt, std::move(lowered)); |
45 | modifier_.erase(stmt); |
46 | } |
47 | return; |
48 | } |
49 | if (stmt->origin->is<ExternalPtrStmt>()) { |
50 | auto origin = stmt->origin->as<ExternalPtrStmt>(); |
51 | TI_ASSERT(stmt->origin->ret_type.ptr_removed()->is<TensorType>()); |
52 | |
53 | std::vector<Stmt *> indices = origin->indices; |
54 | indices.push_back(stmt->offset); |
55 | |
56 | // MatrixPtrStmt has flattened indices, linearization of which is done |
57 | // during IndexExpression::flatten() Here we need to modify the |
58 | // element_dim and element_shape a little bit. |
59 | int element_dim = -1; // AOS Vector |
60 | std::vector<int> element_shape = { |
61 | std::accumulate(begin(origin->element_shape), |
62 | end(origin->element_shape), 1, std::multiplies<>())}; |
63 | |
64 | auto fused = std::make_unique<ExternalPtrStmt>( |
65 | origin->base_ptr, indices, element_shape, element_dim); |
66 | fused->ret_type = stmt->ret_type; |
67 | |
68 | stmt->replace_usages_with(fused.get()); |
69 | modifier_.insert_before(stmt, std::move(fused)); |
70 | modifier_.erase(stmt); |
71 | return; |
72 | } |
73 | if (stmt->origin->is<MatrixOfMatrixPtrStmt>()) { |
74 | auto origin = stmt->origin->as<MatrixOfMatrixPtrStmt>(); |
75 | TI_ASSERT(stmt->offset->is<ConstStmt>()); |
76 | auto offset = stmt->offset->as<ConstStmt>(); |
77 | stmt->replace_usages_with(origin->stmts[offset->val.val_int()]); |
78 | modifier_.erase(stmt); |
79 | return; |
80 | } |
81 | } |
82 | |
83 | static void run(IRNode *node) { |
84 | LowerMatrixPtr pass; |
85 | node->accept(&pass); |
86 | pass.modifier_.modify_ir(); |
87 | } |
88 | }; |
89 | |
90 | class RemoveMatrixOfPtr : public BasicStmtVisitor { |
91 | private: |
92 | using BasicStmtVisitor::visit; |
93 | DelayedIRModifier modifier_; |
94 | |
95 | public: |
96 | void visit(MatrixOfGlobalPtrStmt *stmt) override { |
97 | modifier_.erase(stmt); |
98 | } |
99 | |
100 | void visit(MatrixOfMatrixPtrStmt *stmt) override { |
101 | modifier_.erase(stmt); |
102 | } |
103 | |
104 | static void run(IRNode *node) { |
105 | RemoveMatrixOfPtr pass; |
106 | node->accept(&pass); |
107 | pass.modifier_.modify_ir(); |
108 | } |
109 | }; |
110 | |
111 | namespace irpass { |
112 | |
113 | void lower_matrix_ptr(IRNode *root) { |
114 | TI_AUTO_PROF; |
115 | LowerMatrixPtr::run(root); |
116 | RemoveMatrixOfPtr::run(root); |
117 | } |
118 | |
119 | } // namespace irpass |
120 | |
121 | } // namespace taichi::lang |
122 | |