1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/system/profiler.h"
6#include <numeric>
7#include <functional>
8
9namespace taichi::lang {
10
11class LowerMatrixPtr : public BasicStmtVisitor {
12 private:
13 using BasicStmtVisitor::visit;
14 DelayedIRModifier modifier_;
15
16 public:
17 void visit(MatrixPtrStmt *stmt) override {
18 if (stmt->origin->is<MatrixOfGlobalPtrStmt>()) {
19 auto origin = stmt->origin->as<MatrixOfGlobalPtrStmt>();
20 if (stmt->offset->is<ConstStmt>()) {
21 auto offset = stmt->offset->as<ConstStmt>();
22 auto lowered = std::make_unique<GlobalPtrStmt>(
23 origin->snodes[offset->val.val_int()], origin->indices);
24 stmt->replace_usages_with(lowered.get());
25 modifier_.insert_before(stmt, std::move(lowered));
26 modifier_.erase(stmt);
27 } else {
28 TI_ASSERT_INFO(
29 origin->dynamic_indexable,
30 "Element of the MatrixField is not dynamic indexable.\n{}",
31 stmt->tb);
32 auto stride = std::make_unique<ConstStmt>(
33 TypedConstant(origin->dynamic_index_stride));
34 auto offset = std::make_unique<BinaryOpStmt>(
35 BinaryOpType::mul, stmt->offset, stride.get());
36 auto ptr_base =
37 std::make_unique<GlobalPtrStmt>(origin->snodes[0], origin->indices);
38 auto lowered =
39 std::make_unique<MatrixPtrStmt>(ptr_base.get(), offset.get());
40 stmt->replace_usages_with(lowered.get());
41 modifier_.insert_before(stmt, std::move(stride));
42 modifier_.insert_before(stmt, std::move(offset));
43 modifier_.insert_before(stmt, std::move(ptr_base));
44 modifier_.insert_before(stmt, std::move(lowered));
45 modifier_.erase(stmt);
46 }
47 return;
48 }
49 if (stmt->origin->is<ExternalPtrStmt>()) {
50 auto origin = stmt->origin->as<ExternalPtrStmt>();
51 TI_ASSERT(stmt->origin->ret_type.ptr_removed()->is<TensorType>());
52
53 std::vector<Stmt *> indices = origin->indices;
54 indices.push_back(stmt->offset);
55
56 // MatrixPtrStmt has flattened indices, linearization of which is done
57 // during IndexExpression::flatten() Here we need to modify the
58 // element_dim and element_shape a little bit.
59 int element_dim = -1; // AOS Vector
60 std::vector<int> element_shape = {
61 std::accumulate(begin(origin->element_shape),
62 end(origin->element_shape), 1, std::multiplies<>())};
63
64 auto fused = std::make_unique<ExternalPtrStmt>(
65 origin->base_ptr, indices, element_shape, element_dim);
66 fused->ret_type = stmt->ret_type;
67
68 stmt->replace_usages_with(fused.get());
69 modifier_.insert_before(stmt, std::move(fused));
70 modifier_.erase(stmt);
71 return;
72 }
73 if (stmt->origin->is<MatrixOfMatrixPtrStmt>()) {
74 auto origin = stmt->origin->as<MatrixOfMatrixPtrStmt>();
75 TI_ASSERT(stmt->offset->is<ConstStmt>());
76 auto offset = stmt->offset->as<ConstStmt>();
77 stmt->replace_usages_with(origin->stmts[offset->val.val_int()]);
78 modifier_.erase(stmt);
79 return;
80 }
81 }
82
83 static void run(IRNode *node) {
84 LowerMatrixPtr pass;
85 node->accept(&pass);
86 pass.modifier_.modify_ir();
87 }
88};
89
90class RemoveMatrixOfPtr : public BasicStmtVisitor {
91 private:
92 using BasicStmtVisitor::visit;
93 DelayedIRModifier modifier_;
94
95 public:
96 void visit(MatrixOfGlobalPtrStmt *stmt) override {
97 modifier_.erase(stmt);
98 }
99
100 void visit(MatrixOfMatrixPtrStmt *stmt) override {
101 modifier_.erase(stmt);
102 }
103
104 static void run(IRNode *node) {
105 RemoveMatrixOfPtr pass;
106 node->accept(&pass);
107 pass.modifier_.modify_ir();
108 }
109};
110
111namespace irpass {
112
113void lower_matrix_ptr(IRNode *root) {
114 TI_AUTO_PROF;
115 LowerMatrixPtr::run(root);
116 RemoveMatrixOfPtr::run(root);
117}
118
119} // namespace irpass
120
121} // namespace taichi::lang
122