1 | #include <algorithm> |
2 | #include <functional> |
3 | #include <iterator> |
4 | #include <type_traits> |
5 | |
6 | #include "taichi/ir/analysis.h" |
7 | #include "taichi/ir/ir.h" |
8 | #include "taichi/ir/statements.h" |
9 | #include "taichi/ir/transforms.h" |
10 | #include "taichi/ir/visitors.h" |
11 | #include "taichi/system/profiler.h" |
12 | |
13 | namespace taichi::lang { |
14 | |
15 | namespace { |
16 | |
17 | // Find the destinations of global atomic reductions that can be demoted into |
18 | // TLS buffer. |
19 | template <typename T> |
20 | std::vector<std::pair<T *, AtomicOpType>> find_global_reduction_destinations( |
21 | OffloadedStmt *offload, |
22 | const std::function<bool(T *)> &dest_checker) { |
23 | static_assert(std::is_same_v<T, GlobalPtrStmt> || |
24 | std::is_same_v<T, GlobalTemporaryStmt>); |
25 | // Gather all atomic add/sub/max/min destinations and record corresponding op |
26 | // type on the first appearance of a destination. |
27 | // Only one op type will be allowed on one destination (add/sub is an |
28 | // exception because add/sub can be mixed together in delta calculation). |
29 | // We use std::vector instead of std::map to keep a deterministic order here. |
30 | std::vector<std::pair<T *, AtomicOpType>> atomic_destinations; |
31 | // TODO: this is again an abuse since it gathers nothing. Need to design a IR |
32 | // map/reduce system |
33 | auto atomics = irpass::analysis::gather_statements(offload, [&](Stmt *stmt) { |
34 | if (auto atomic_op = stmt->cast<AtomicOpStmt>()) { |
35 | if (atomic_op->op_type == AtomicOpType::add || |
36 | atomic_op->op_type == AtomicOpType::sub || |
37 | atomic_op->op_type == AtomicOpType::max || |
38 | atomic_op->op_type == AtomicOpType::min) { |
39 | // Local atomics do not count. |
40 | if (auto dest = atomic_op->dest->cast<T>()) { |
41 | if (std::find_if(atomic_destinations.begin(), |
42 | atomic_destinations.end(), |
43 | [&](const std::pair<T *, AtomicOpType> &elem) { |
44 | return elem.first == dest; |
45 | }) == atomic_destinations.end()) { |
46 | atomic_destinations.push_back( |
47 | {dest, atomic_op->op_type == AtomicOpType::sub |
48 | ? AtomicOpType::add |
49 | : atomic_op->op_type}); |
50 | } |
51 | } |
52 | } |
53 | } |
54 | return false; |
55 | }); |
56 | |
57 | std::vector<std::pair<T *, AtomicOpType>> valid_reduction_values; |
58 | for (auto dest : atomic_destinations) { |
59 | // check if there is any other global load/store/atomic operations |
60 | auto related_global_mem_ops = |
61 | irpass::analysis::gather_statements(offload, [&](Stmt *stmt) { |
62 | if (auto load = stmt->cast<GlobalLoadStmt>()) { |
63 | if (irpass::analysis::maybe_same_address(load->src, dest.first)) { |
64 | return true; |
65 | } |
66 | } else if (auto store = stmt->cast<GlobalStoreStmt>()) { |
67 | if (irpass::analysis::maybe_same_address(store->dest, dest.first)) { |
68 | return true; |
69 | } |
70 | } else if (auto atomic = stmt->cast<AtomicOpStmt>()) { |
71 | if (irpass::analysis::maybe_same_address(atomic->dest, |
72 | dest.first)) { |
73 | return !((atomic->op_type == AtomicOpType::sub && |
74 | dest.second == AtomicOpType::add) || |
75 | atomic->op_type == dest.second); |
76 | } |
77 | } |
78 | for (auto &op : stmt->get_operands()) { |
79 | if (op == nullptr) |
80 | continue; |
81 | // Make sure the values of related atomic operations are not used. |
82 | if (auto atomic = op->cast<AtomicOpStmt>()) { |
83 | if (irpass::analysis::maybe_same_address(atomic->dest, |
84 | dest.first)) { |
85 | return true; |
86 | } |
87 | } |
88 | } |
89 | return false; // Now we are sure the statement is not related to the |
90 | // destination |
91 | }); |
92 | if (related_global_mem_ops.empty() && dest_checker(dest.first)) { |
93 | valid_reduction_values.push_back(dest); |
94 | } |
95 | } |
96 | return valid_reduction_values; |
97 | } |
98 | |
99 | void make_thread_local_offload(OffloadedStmt *offload) { |
100 | if (offload->task_type != OffloadedTaskType::range_for && |
101 | offload->task_type != OffloadedTaskType::struct_for) |
102 | return; |
103 | |
104 | std::vector<std::pair<Stmt *, AtomicOpType>> valid_reduction_values; |
105 | { |
106 | auto valid_global_ptrs = find_global_reduction_destinations<GlobalPtrStmt>( |
107 | offload, [](GlobalPtrStmt *dest) { |
108 | // We can only optimized reductions to global ptrs with form like |
109 | // loss[None] (0-D fields) for now. |
110 | // No TLS on quant types. |
111 | return (dest->snode->type == SNodeType::place) && |
112 | dest->indices.empty() && dest->snode->dt->is<PrimitiveType>(); |
113 | }); |
114 | auto valid_global_tmps = |
115 | find_global_reduction_destinations<GlobalTemporaryStmt>( |
116 | offload, [](auto *) { return true; }); |
117 | std::copy(valid_global_ptrs.begin(), valid_global_ptrs.end(), |
118 | std::back_inserter(valid_reduction_values)); |
119 | std::copy(valid_global_tmps.begin(), valid_global_tmps.end(), |
120 | std::back_inserter(valid_reduction_values)); |
121 | } |
122 | |
123 | std::size_t tls_offset = 0; |
124 | |
125 | // TODO: sort thread local storage variables according to dtype_size to |
126 | // reduce buffer fragmentation. |
127 | for (auto dest : valid_reduction_values) { |
128 | auto data_type = dest.first->ret_type.ptr_removed(); |
129 | auto dtype_size = data_type_size(data_type); |
130 | // Step 1: |
131 | // Create thread local storage |
132 | { |
133 | if (offload->tls_prologue == nullptr) { |
134 | offload->tls_prologue = std::make_unique<Block>(); |
135 | offload->tls_prologue->parent_stmt = offload; |
136 | } |
137 | |
138 | // ensure alignment |
139 | tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size; |
140 | |
141 | auto tls_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>( |
142 | tls_offset, TypeFactory::get_instance().get_pointer_type(data_type)); |
143 | |
144 | auto zero = offload->tls_prologue->insert( |
145 | std::make_unique<ConstStmt>( |
146 | dest.second == AtomicOpType::max ? get_min_value(data_type) |
147 | : dest.second == AtomicOpType::min ? get_max_value(data_type) |
148 | : TypedConstant(data_type, 0)), |
149 | -1); |
150 | // Zero-fill |
151 | // TODO: do not use GlobalStore for TLS ptr. |
152 | offload->tls_prologue->push_back<GlobalStoreStmt>(tls_ptr, zero); |
153 | } |
154 | |
155 | // Step 2: |
156 | // Make loop body accumulate to TLS ptr instead of global ptr |
157 | { |
158 | auto tls_ptr = offload->body->insert( |
159 | Stmt::make<ThreadLocalPtrStmt>( |
160 | tls_offset, |
161 | TypeFactory::get_instance().get_pointer_type(data_type)), |
162 | 0); |
163 | dest.first->replace_usages_with(tls_ptr); |
164 | } |
165 | |
166 | // Step 3: |
167 | // Atomic-add thread local contribution to its global version |
168 | { |
169 | if (offload->tls_epilogue == nullptr) { |
170 | offload->tls_epilogue = std::make_unique<Block>(); |
171 | offload->tls_epilogue->parent_stmt = offload; |
172 | } |
173 | auto tls_ptr = offload->tls_epilogue->push_back<ThreadLocalPtrStmt>( |
174 | tls_offset, TypeFactory::get_instance().get_pointer_type(data_type)); |
175 | // TODO: do not use global load from TLS. |
176 | auto tls_load = offload->tls_epilogue->push_back<GlobalLoadStmt>(tls_ptr); |
177 | auto global_ptr = offload->tls_epilogue->insert( |
178 | std::unique_ptr<Stmt>( |
179 | (Stmt *)irpass::analysis::clone(dest.first).release()), |
180 | -1); |
181 | offload->tls_epilogue->insert( |
182 | AtomicOpStmt::make_for_reduction(dest.second, global_ptr, tls_load), |
183 | -1); |
184 | } |
185 | |
186 | // allocate storage for the TLS variable |
187 | tls_offset += dtype_size; |
188 | } |
189 | |
190 | offload->tls_size = std::max(std::size_t(1), tls_offset); |
191 | } |
192 | |
193 | } // namespace |
194 | |
195 | namespace irpass { |
196 | |
197 | // This pass should happen after offloading but before lower_access |
198 | void make_thread_local(IRNode *root, const CompileConfig &config) { |
199 | TI_AUTO_PROF; |
200 | if (auto root_block = root->cast<Block>()) { |
201 | for (auto &offload : root_block->statements) { |
202 | make_thread_local_offload(offload->cast<OffloadedStmt>()); |
203 | } |
204 | } else { |
205 | make_thread_local_offload(root->as<OffloadedStmt>()); |
206 | } |
207 | type_check(root, config); |
208 | } |
209 | |
210 | } // namespace irpass |
211 | |
212 | } // namespace taichi::lang |
213 | |