1 | #include "taichi/ir/analysis.h" |
2 | #include "taichi/ir/ir.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/system/profiler.h" |
7 | |
8 | #include <deque> |
9 | #include <set> |
10 | |
11 | namespace taichi::lang { |
12 | |
// IR pass that demotes atomic read-modify-write operations (AtomicOpStmt)
// into explicit load -> binary-op -> store sequences whenever atomicity is
// provably unnecessary: the destination is thread/stack local, the enclosing
// offloaded task runs serially, or the destination address is uniquely
// accessed per loop iteration. Demoted code is cheaper and enables further
// optimizations downstream.
class DemoteAtomics : public BasicStmtVisitor {
 private:
  // Per-offload analysis results, refreshed in visit(OffloadedStmt*):
  // SNode -> the single GlobalPtrStmt that uniquely accesses it inside the
  // current parallel loop. A missing/nullptr entry means "not uniquely
  // accessed". NOTE: lookups below use operator[], which default-inserts
  // nullptr for unseen SNodes — that is relied upon as the "not unique" case.
  std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;
  // External-array argument id -> the unique ExternalPtrStmt accessing it
  // in the current loop (same nullptr-means-not-unique convention).
  std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_;

 public:
  using BasicStmtVisitor::visit;

  // The offloaded task currently being traversed; nullptr outside any task.
  OffloadedStmt *current_offloaded;
  // Collects replacements during traversal; applied in run() via modify_ir()
  // so the IR is not mutated while being visited.
  DelayedIRModifier modifier;

  DemoteAtomics() {
    current_offloaded = nullptr;
  }

  void visit(AtomicOpStmt *stmt) override {
    bool demote = false;
    // True when the destination lives in function-local storage (alloca),
    // which selects Local* load/store statements in the rewrite below.
    bool is_local = false;
    if (current_offloaded) {
      // A single-threaded CPU task has no concurrent writers.
      if (arch_is_cpu(current_offloaded->device) &&
          current_offloaded->num_cpu_threads == 1) {
        demote = true;
      }
      // Thread-local storage is private to each thread by construction.
      if (stmt->dest->is<ThreadLocalPtrStmt>()) {
        demote = true;
      }
      // Serial tasks have no parallelism at all.
      if (current_offloaded->task_type == OffloadedTaskType::serial) {
        demote = true;
      }
      // Parallel loops: demote only if the per-iteration access to the
      // destination is provably unique (no cross-iteration races).
      if (!demote &&
          (current_offloaded->task_type == OffloadedTaskType::range_for ||
           current_offloaded->task_type == OffloadedTaskType::mesh_for ||
           current_offloaded->task_type == OffloadedTaskType::struct_for)) {
        if (stmt->dest->is<GlobalPtrStmt>()) {
          // Start optimistic, then veto below if any condition fails.
          demote = true;
          auto dest = stmt->dest->as<GlobalPtrStmt>();
          auto snode = dest->snode;
          if (loop_unique_ptr_[snode] == nullptr ||
              loop_unique_ptr_[snode]->indices.empty()) {
            // not uniquely accessed
            demote = false;
          }
          if (current_offloaded->mem_access_opt.has_flag(
                  snode, SNodeAccessFlag::block_local) ||
              current_offloaded->mem_access_opt.has_flag(
                  snode, SNodeAccessFlag::mesh_local)) {
            // BLS does not support write access yet so we keep atomic_adds.
            demote = false;
          }
          // demote from-end atomics
          // Mesh-for special case: an access indexed by a (possibly chained)
          // mesh index conversion of the loop's own mesh index is unique per
          // iteration, so re-enable demotion even after the vetoes above.
          if (current_offloaded->task_type == OffloadedTaskType::mesh_for) {
            if (dest->indices.size() == 1 &&
                dest->indices[0]->is<MeshIndexConversionStmt>()) {
              auto idx = dest->indices[0]->as<MeshIndexConversionStmt>()->idx;
              while (idx->is<MeshIndexConversionStmt>()) {  // special case: l2g
                                                            // + g2r
                idx = idx->as<MeshIndexConversionStmt>()->idx;
              }
              if (idx->is<LoopIndexStmt>() &&
                  idx->as<LoopIndexStmt>()->is_mesh_index() &&
                  loop_unique_ptr_[stmt->dest->as<GlobalPtrStmt>()->snode] !=
                      nullptr) {
                demote = true;
              }
            }
          }
        } else if (stmt->dest->is<ExternalPtrStmt>()) {
          // Same uniqueness reasoning for external (ndarray) destinations,
          // keyed by the kernel argument id instead of the SNode.
          ExternalPtrStmt *dest_ptr = stmt->dest->as<ExternalPtrStmt>();
          demote = true;
          if (dest_ptr->indices.empty()) {
            demote = false;
          }
          // NOTE(review): assumes the external pointer's base is always an
          // ArgLoadStmt at this stage of lowering — as<> would assert
          // otherwise.
          ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
          int arg_id = arg_load_stmt->arg_id;
          if (loop_unique_arr_ptr_[arg_id] == nullptr) {
            // Not loop unique
            demote = false;
          }
          // TODO: Is BLS / Mem Access Opt a thing for any_arr?
        }
      }
    }
    // Stack-local destinations (direct alloca, or a matrix element whose
    // origin is an alloca) are always private; demote unconditionally.
    if (stmt->dest->is<AllocaStmt>() ||
        (stmt->dest->is<MatrixPtrStmt>() &&
         stmt->dest->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>())) {
      demote = true;
      is_local = true;
    }

    // Quantized-float destinations cannot be updated atomically at all;
    // force demotion (with a warning, since this drops atomicity).
    if (auto dest_pointer_type = stmt->dest->ret_type->cast<PointerType>()) {
      if (dest_pointer_type->get_pointee_type()->is<QuantFloatType>()) {
        TI_WARN(
            "AtomicOp on QuantFloatType is not supported. "
            "Demoting to non-atomic RMW.\n{}",
            stmt->tb);
        demote = true;
      }
    }

    if (demote) {
      // replace atomics with load, add, store
      auto bin_type = atomic_to_binary_op_type(stmt->op_type);
      auto ptr = stmt->dest;
      auto val = stmt->val;

      auto new_stmts = VecStatement();
      Stmt *load;
      if (is_local) {
        load = new_stmts.push_back<LocalLoadStmt>(ptr);
        auto bin = new_stmts.push_back<BinaryOpStmt>(bin_type, load, val);
        new_stmts.push_back<LocalStoreStmt>(ptr, bin);
      } else {
        load = new_stmts.push_back<GlobalLoadStmt>(ptr);
        auto bin = new_stmts.push_back<BinaryOpStmt>(bin_type, load, val);
        new_stmts.push_back<GlobalStoreStmt>(ptr, bin);
      }
      // For a taichi program like `c = ti.atomic_add(a, b)`, the IR looks
      // like the following
      //
      // $c  = # lhs memory
      // $d  = atomic add($a, $b)
      // $e  : store [$c <- $d]
      //
      // If this gets demoted, the IR is translated into:
      //
      // $c  = # lhs memory
      // $d' = load $a             <-- added by demote_atomic
      // $e' = add $d' $b
      // $f  : store [$a <- $e']   <-- added by demote_atomic
      // $g  : store [$c <- ???]   <-- store the old value into lhs $c
      //
      // Naively relying on Block::replace_with() would incorrectly fill $f
      // into ???, because $f is a store stmt that doesn't have a return
      // value. The correct thing is to replace |stmt| $d with the loaded
      // old value $d'.
      // See also: https://github.com/taichi-dev/taichi/issues/332
      stmt->replace_usages_with(load);
      modifier.replace_with(stmt, std::move(new_stmts),
                            /*replace_usages=*/false);
    }
  }

  void visit(OffloadedStmt *stmt) override {
    current_offloaded = stmt;
    // Refresh the uniqueness analysis for parallel loop tasks; other task
    // types never consult these maps.
    if (stmt->task_type == OffloadedTaskType::range_for ||
        stmt->task_type == OffloadedTaskType::mesh_for ||
        stmt->task_type == OffloadedTaskType::struct_for) {
      auto uniquely_accessed_pointers =
          irpass::analysis::gather_uniquely_accessed_pointers(stmt);
      loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first);
      loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second);
    }
    // We don't need to visit TLS/BLS prologues/epilogues.
    if (stmt->body) {
      stmt->body->accept(this);
    }
    current_offloaded = nullptr;
  }

  // Runs the pass to a fixed point: each iteration visits the whole IR and
  // applies queued replacements; stops when no change is made. Returns true
  // iff the IR was modified at least once.
  static bool run(IRNode *node) {
    DemoteAtomics demoter;
    bool modified = false;
    while (true) {
      node->accept(&demoter);
      if (demoter.modifier.modify_ir()) {
        modified = true;
      } else {
        break;
      }
    }
    return modified;
  }
};
186 | |
187 | namespace irpass { |
188 | |
189 | bool demote_atomics(IRNode *root, const CompileConfig &config) { |
190 | TI_AUTO_PROF; |
191 | bool modified = DemoteAtomics::run(root); |
192 | type_check(root, config); |
193 | return modified; |
194 | } |
195 | |
196 | } // namespace irpass |
197 | |
198 | } // namespace taichi::lang |
199 | |