1#include "taichi/ir/analysis.h"
2#include "taichi/ir/ir.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/system/profiler.h"
7
8#include <deque>
9#include <set>
10
11namespace taichi::lang {
12
13class DemoteAtomics : public BasicStmtVisitor {
14 private:
15 std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;
16 std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_;
17
18 public:
19 using BasicStmtVisitor::visit;
20
21 OffloadedStmt *current_offloaded;
22 DelayedIRModifier modifier;
23
24 DemoteAtomics() {
25 current_offloaded = nullptr;
26 }
27
28 void visit(AtomicOpStmt *stmt) override {
29 bool demote = false;
30 bool is_local = false;
31 if (current_offloaded) {
32 if (arch_is_cpu(current_offloaded->device) &&
33 current_offloaded->num_cpu_threads == 1) {
34 demote = true;
35 }
36 if (stmt->dest->is<ThreadLocalPtrStmt>()) {
37 demote = true;
38 }
39 if (current_offloaded->task_type == OffloadedTaskType::serial) {
40 demote = true;
41 }
42 if (!demote &&
43 (current_offloaded->task_type == OffloadedTaskType::range_for ||
44 current_offloaded->task_type == OffloadedTaskType::mesh_for ||
45 current_offloaded->task_type == OffloadedTaskType::struct_for)) {
46 if (stmt->dest->is<GlobalPtrStmt>()) {
47 demote = true;
48 auto dest = stmt->dest->as<GlobalPtrStmt>();
49 auto snode = dest->snode;
50 if (loop_unique_ptr_[snode] == nullptr ||
51 loop_unique_ptr_[snode]->indices.empty()) {
52 // not uniquely accessed
53 demote = false;
54 }
55 if (current_offloaded->mem_access_opt.has_flag(
56 snode, SNodeAccessFlag::block_local) ||
57 current_offloaded->mem_access_opt.has_flag(
58 snode, SNodeAccessFlag::mesh_local)) {
59 // BLS does not support write access yet so we keep atomic_adds.
60 demote = false;
61 }
62 // demote from-end atomics
63 if (current_offloaded->task_type == OffloadedTaskType::mesh_for) {
64 if (dest->indices.size() == 1 &&
65 dest->indices[0]->is<MeshIndexConversionStmt>()) {
66 auto idx = dest->indices[0]->as<MeshIndexConversionStmt>()->idx;
67 while (idx->is<MeshIndexConversionStmt>()) { // special case: l2g
68 // + g2r
69 idx = idx->as<MeshIndexConversionStmt>()->idx;
70 }
71 if (idx->is<LoopIndexStmt>() &&
72 idx->as<LoopIndexStmt>()->is_mesh_index() &&
73 loop_unique_ptr_[stmt->dest->as<GlobalPtrStmt>()->snode] !=
74 nullptr) {
75 demote = true;
76 }
77 }
78 }
79 } else if (stmt->dest->is<ExternalPtrStmt>()) {
80 ExternalPtrStmt *dest_ptr = stmt->dest->as<ExternalPtrStmt>();
81 demote = true;
82 if (dest_ptr->indices.empty()) {
83 demote = false;
84 }
85 ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
86 int arg_id = arg_load_stmt->arg_id;
87 if (loop_unique_arr_ptr_[arg_id] == nullptr) {
88 // Not loop unique
89 demote = false;
90 }
91 // TODO: Is BLS / Mem Access Opt a thing for any_arr?
92 }
93 }
94 }
95 if (stmt->dest->is<AllocaStmt>() ||
96 (stmt->dest->is<MatrixPtrStmt>() &&
97 stmt->dest->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>())) {
98 demote = true;
99 is_local = true;
100 }
101
102 if (auto dest_pointer_type = stmt->dest->ret_type->cast<PointerType>()) {
103 if (dest_pointer_type->get_pointee_type()->is<QuantFloatType>()) {
104 TI_WARN(
105 "AtomicOp on QuantFloatType is not supported. "
106 "Demoting to non-atomic RMW.\n{}",
107 stmt->tb);
108 demote = true;
109 }
110 }
111
112 if (demote) {
113 // replace atomics with load, add, store
114 auto bin_type = atomic_to_binary_op_type(stmt->op_type);
115 auto ptr = stmt->dest;
116 auto val = stmt->val;
117
118 auto new_stmts = VecStatement();
119 Stmt *load;
120 if (is_local) {
121 load = new_stmts.push_back<LocalLoadStmt>(ptr);
122 auto bin = new_stmts.push_back<BinaryOpStmt>(bin_type, load, val);
123 new_stmts.push_back<LocalStoreStmt>(ptr, bin);
124 } else {
125 load = new_stmts.push_back<GlobalLoadStmt>(ptr);
126 auto bin = new_stmts.push_back<BinaryOpStmt>(bin_type, load, val);
127 new_stmts.push_back<GlobalStoreStmt>(ptr, bin);
128 }
129 // For a taichi program like `c = ti.atomic_add(a, b)`, the IR looks
130 // like the following
131 //
132 // $c = # lhs memory
133 // $d = atomic add($a, $b)
134 // $e : store [$c <- $d]
135 //
136 // If this gets demoted, the IR is translated into:
137 //
138 // $c = # lhs memory
139 // $d' = load $a <-- added by demote_atomic
140 // $e' = add $d' $b
141 // $f : store [$a <- $e'] <-- added by demote_atomic
142 // $g : store [$c <- ???] <-- store the old value into lhs $c
143 //
144 // Naively relying on Block::replace_with() would incorrectly fill $f
145 // into ???, because $f is a store stmt that doesn't have a return
146 // value. The correct thing is to replace |stmt| $d with the loaded
147 // old value $d'.
148 // See also: https://github.com/taichi-dev/taichi/issues/332
149 stmt->replace_usages_with(load);
150 modifier.replace_with(stmt, std::move(new_stmts),
151 /*replace_usages=*/false);
152 }
153 }
154
155 void visit(OffloadedStmt *stmt) override {
156 current_offloaded = stmt;
157 if (stmt->task_type == OffloadedTaskType::range_for ||
158 stmt->task_type == OffloadedTaskType::mesh_for ||
159 stmt->task_type == OffloadedTaskType::struct_for) {
160 auto uniquely_accessed_pointers =
161 irpass::analysis::gather_uniquely_accessed_pointers(stmt);
162 loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first);
163 loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second);
164 }
165 // We don't need to visit TLS/BLS prologues/epilogues.
166 if (stmt->body) {
167 stmt->body->accept(this);
168 }
169 current_offloaded = nullptr;
170 }
171
172 static bool run(IRNode *node) {
173 DemoteAtomics demoter;
174 bool modified = false;
175 while (true) {
176 node->accept(&demoter);
177 if (demoter.modifier.modify_ir()) {
178 modified = true;
179 } else {
180 break;
181 }
182 }
183 return modified;
184 }
185};
186
187namespace irpass {
188
189bool demote_atomics(IRNode *root, const CompileConfig &config) {
190 TI_AUTO_PROF;
191 bool modified = DemoteAtomics::run(root);
192 type_check(root, config);
193 return modified;
194}
195
196} // namespace irpass
197
198} // namespace taichi::lang
199