#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

// Flag accesses to be either weak (non-activating) or strong (activating).
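// All GlobalPtrStmts start out weak; stores and atomic read-modify-writes
// then re-flag their destination pointers as strong, since a write may need
// to activate sparse SNodes along the access path. Sketch:
//   for i in x:
//     y[i] = x[i]  # the pointer into y becomes strong; x stays weak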
class FlagAccess : public IRVisitor {
 public:
  explicit FlagAccess(IRNode *node) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
    node->accept(this);
  }

  void visit(Block *stmt_list) override {  // block itself has no id
    for (auto &stmt : stmt_list->statements) {
      stmt->accept(this);
    }
  }

  void visit(IfStmt *if_stmt) override {
    if (if_stmt->true_statements) {
      if_stmt->true_statements->accept(this);
    }
    if (if_stmt->false_statements) {
      if_stmt->false_statements->accept(this);
    }
  }

  void visit(WhileStmt *stmt) override {
    stmt->body->accept(this);
  }

  void visit(RangeForStmt *for_stmt) override {
    for_stmt->body->accept(this);
  }

  void visit(StructForStmt *for_stmt) override {
    for_stmt->body->accept(this);
  }

  void visit(MeshForStmt *for_stmt) override {
    for_stmt->body->accept(this);
  }

  void visit(OffloadedStmt *stmt) override {
    stmt->all_blocks_accept(this);
  }

  // Pointers are assumed to be visited before the global loads/stores that
  // use them. Default every pointer to weak; stores and atomics re-flag
  // their destinations below.
  void visit(GlobalPtrStmt *stmt) override {
    stmt->activate = false;
  }

  void visit(GlobalStoreStmt *stmt) override {
    if (stmt->dest->is<GlobalPtrStmt>()) {
      stmt->dest->as<GlobalPtrStmt>()->activate = true;
    }
    if (stmt->dest->is<MatrixPtrStmt>()) {
      // A store through a MatrixPtrStmt whose origin is still an unlowered
      // GlobalPtrStmt must activate that origin pointer.
      if (stmt->dest->as<MatrixPtrStmt>()->is_unlowered_global_ptr()) {
        stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>()->activate =
            true;
      }
    }
  }

  void visit(AtomicOpStmt *stmt) override {
    if (stmt->dest->is<GlobalPtrStmt>()) {
      stmt->dest->as<GlobalPtrStmt>()->activate = true;
    }
  }
};

// For struct-fors, weaken accesses to the variables currently being looped
// over. E.g.
//   for i in x:
//     x[i] = 0
// Although we are writing to x[i], i only loops over the active elements of
// x, so no extra activation is needed. Note that the indices of the x
// accesses must be exactly the loop indices for this optimization to be
// correct.

class WeakenAccess : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  explicit WeakenAccess(IRNode *node) {
    allow_undefined_visitor = true;
    invoke_default_visitor = false;
    current_struct_for_ = nullptr;
    current_offload_ = nullptr;
    node->accept(this);
  }

  void visit(Block *stmt_list) override {  // block itself has no id
    for (auto &stmt : stmt_list->statements) {
      stmt->accept(this);
    }
  }

  void visit(StructForStmt *stmt) override {
    current_struct_for_ = stmt;
    stmt->body->accept(this);
    current_struct_for_ = nullptr;
  }

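  // A struct-for appears as a StructForStmt before offloading and as an
  // OffloadedStmt with task_type struct_for afterwards; both are tracked so
  // the pass works at either stage.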
  void visit(OffloadedStmt *stmt) override {
    current_offload_ = stmt;
    if (stmt->body) {
      stmt->body->accept(this);
    }
    current_offload_ = nullptr;
  }

  // Walk up past containers that add no sparsity (place, dense, bit_struct,
  // quant_array) to the nearest sparsity-introducing ancestor.
  static SNode *least_sparse_ancestor(SNode *a) {
    while (a->type == SNodeType::place || a->type == SNodeType::dense ||
           a->type == SNodeType::bit_struct ||
           a->type == SNodeType::quant_array) {
      a = a->parent;
    }
    return a;
  }

  static bool share_sparsity(SNode *a, SNode *b) {
    return least_sparse_ancestor(a) == least_sparse_ancestor(b);
  }
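  // E.g. two `place` SNodes under the same `pointer` SNode share sparsity:
  // a struct-for over one only visits cells where the common sparse ancestor
  // is already active, so same-index accesses to the other need no further
  // activation.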

  // Weaken an activating access if it happens inside a struct-for whose loop
  // SNode shares sparsity with the accessed SNode and is addressed exactly
  // by the loop indices.
  void visit(GlobalPtrStmt *stmt) override {
    if (stmt->activate) {
      bool is_struct_for =
          (current_offload_ && current_offload_->task_type ==
                                   OffloadedStmt::TaskType::struct_for) ||
          current_struct_for_;
      if (is_struct_for) {
        bool same_as_loop_snode = true;
        SNode *loop_snode = nullptr;
        if (current_struct_for_) {
          loop_snode = current_struct_for_->snode;
        } else {
          loop_snode = current_offload_->snode;
        }
        TI_ASSERT(loop_snode);
        if (!share_sparsity(stmt->snode, loop_snode)) {
          same_as_loop_snode = false;
        }
        if (stmt->indices.size() == loop_snode->num_active_indices) {
          for (int i = 0; i < loop_snode->num_active_indices; i++) {
            auto ind = stmt->indices[i];
            // TODO: vectorized cases?
            if (auto loop_var = ind->cast<LoopIndexStmt>()) {
              if (loop_var->index != i) {
                same_as_loop_snode = false;
              }
            } else {
              same_as_loop_snode = false;
            }
          }
        } else {
          // A dimension mismatch means the indices cannot all be loop
          // indices, so keep the access strong.
          same_as_loop_snode = false;
        }
        if (same_as_loop_snode) {
          stmt->activate = false;
        }
      }
    }
  }

 private:
  OffloadedStmt *current_offload_;
  StructForStmt *current_struct_for_;
};

namespace irpass {

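// Runs both passes in order: FlagAccess first marks every access weak and
// strengthens store/atomic destinations; WeakenAccess then removes
// activations made redundant by struct-for loops.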
void flag_access(IRNode *root) {
  TI_AUTO_PROF;
  FlagAccess flag_access(root);
  WeakenAccess weaken_access(root);
}

}  // namespace irpass

}  // namespace taichi::lang