#include <algorithm>
#include <functional>
#include <iterator>
#include <type_traits>

#include "taichi/ir/analysis.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"

namespace taichi::lang {

namespace {

// Find the destinations of global atomic reductions that can be demoted into
// the TLS (thread-local storage) buffer.
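// |dest_checker| lets the caller impose extra constraints on a candidate
// destination (for example, make_thread_local_offload() below only accepts
// 0-D place SNodes of primitive type for GlobalPtrStmt destinations).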
template <typename T>
std::vector<std::pair<T *, AtomicOpType>> find_global_reduction_destinations(
    OffloadedStmt *offload,
    const std::function<bool(T *)> &dest_checker) {
  static_assert(std::is_same_v<T, GlobalPtrStmt> ||
                std::is_same_v<T, GlobalTemporaryStmt>);
  // Gather all atomic add/sub/max/min destinations and record the
  // corresponding op type on the first appearance of a destination.
  // Only one op type is allowed per destination (add/sub is an exception
  // because add/sub can be mixed together in delta calculation).
  // We use std::vector instead of std::map to keep a deterministic order here.
  std::vector<std::pair<T *, AtomicOpType>> atomic_destinations;
  // TODO: this is again an abuse since it gathers nothing. Need to design an
  // IR map/reduce system.
  auto atomics = irpass::analysis::gather_statements(offload, [&](Stmt *stmt) {
    if (auto atomic_op = stmt->cast<AtomicOpStmt>()) {
      if (atomic_op->op_type == AtomicOpType::add ||
          atomic_op->op_type == AtomicOpType::sub ||
          atomic_op->op_type == AtomicOpType::max ||
          atomic_op->op_type == AtomicOpType::min) {
        // Local atomics do not count.
        if (auto dest = atomic_op->dest->cast<T>()) {
          if (std::find_if(atomic_destinations.begin(),
                           atomic_destinations.end(),
                           [&](const std::pair<T *, AtomicOpType> &elem) {
                             return elem.first == dest;
                           }) == atomic_destinations.end()) {
            atomic_destinations.push_back(
                {dest, atomic_op->op_type == AtomicOpType::sub
                           ? AtomicOpType::add
                           : atomic_op->op_type});
          }
        }
      }
    }
    return false;
  });

  std::vector<std::pair<T *, AtomicOpType>> valid_reduction_values;
  for (auto dest : atomic_destinations) {
    // Check whether there are any other global load/store/atomic operations
    // that may touch this destination.
    auto related_global_mem_ops =
        irpass::analysis::gather_statements(offload, [&](Stmt *stmt) {
          if (auto load = stmt->cast<GlobalLoadStmt>()) {
            if (irpass::analysis::maybe_same_address(load->src, dest.first)) {
              return true;
            }
          } else if (auto store = stmt->cast<GlobalStoreStmt>()) {
            if (irpass::analysis::maybe_same_address(store->dest, dest.first)) {
              return true;
            }
          } else if (auto atomic = stmt->cast<AtomicOpStmt>()) {
            if (irpass::analysis::maybe_same_address(atomic->dest,
                                                     dest.first)) {
              return !((atomic->op_type == AtomicOpType::sub &&
                        dest.second == AtomicOpType::add) ||
                       atomic->op_type == dest.second);
            }
          }
          for (auto &op : stmt->get_operands()) {
            if (op == nullptr)
              continue;
            // Make sure the values of related atomic operations are not used.
            if (auto atomic = op->cast<AtomicOpStmt>()) {
              if (irpass::analysis::maybe_same_address(atomic->dest,
                                                       dest.first)) {
                return true;
              }
            }
          }
          // Now we are sure the statement is not related to the destination.
          return false;
        });
    if (related_global_mem_ops.empty() && dest_checker(dest.first)) {
      valid_reduction_values.push_back(dest);
    }
  }
  return valid_reduction_values;
}

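// Demote global reductions inside a parallel range-for / struct-for into
// thread-local accumulation. Conceptually (pseudocode sketch of the intent,
// not actual IR):
//
//   for i in range(n):              # per thread:
//     total[None] += f(i)    ==>    tls = 0               (TLS prologue)
//                                   for i in chunk:
//                                     tls += f(i)         (loop body)
//                                   total[None] += tls    (TLS epilogue)
//
// so each thread issues a single global atomic per reduction destination
// instead of one per iteration.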
void make_thread_local_offload(OffloadedStmt *offload) {
  if (offload->task_type != OffloadedTaskType::range_for &&
      offload->task_type != OffloadedTaskType::struct_for)
    return;

  std::vector<std::pair<Stmt *, AtomicOpType>> valid_reduction_values;
  {
    auto valid_global_ptrs = find_global_reduction_destinations<GlobalPtrStmt>(
        offload, [](GlobalPtrStmt *dest) {
          // We can only optimize reductions to global ptrs of the form
          // loss[None] (0-D fields) for now.
          // No TLS on quant types.
          return (dest->snode->type == SNodeType::place) &&
                 dest->indices.empty() && dest->snode->dt->is<PrimitiveType>();
        });
    auto valid_global_tmps =
        find_global_reduction_destinations<GlobalTemporaryStmt>(
            offload, [](auto *) { return true; });
    std::copy(valid_global_ptrs.begin(), valid_global_ptrs.end(),
              std::back_inserter(valid_reduction_values));
    std::copy(valid_global_tmps.begin(), valid_global_tmps.end(),
              std::back_inserter(valid_reduction_values));
  }

  std::size_t tls_offset = 0;

  // TODO: sort thread local storage variables according to dtype_size to
  // reduce buffer fragmentation.
  for (auto dest : valid_reduction_values) {
    auto data_type = dest.first->ret_type.ptr_removed();
    auto dtype_size = data_type_size(data_type);
    // Step 1:
    // Create thread local storage
    {
      if (offload->tls_prologue == nullptr) {
        offload->tls_prologue = std::make_unique<Block>();
        offload->tls_prologue->parent_stmt = offload;
      }

      // ensure alignment
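      // E.g., with tls_offset == 4 and dtype_size == 8, the line below pads
      // tls_offset to 8; if tls_offset is already a multiple of dtype_size,
      // nothing is added.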
      tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size;

      auto tls_ptr = offload->tls_prologue->push_back<ThreadLocalPtrStmt>(
          tls_offset, TypeFactory::get_instance().get_pointer_type(data_type));

      auto zero = offload->tls_prologue->insert(
          std::make_unique<ConstStmt>(
              dest.second == AtomicOpType::max ? get_min_value(data_type)
              : dest.second == AtomicOpType::min ? get_max_value(data_type)
                                                 : TypedConstant(data_type, 0)),
          -1);
      // Fill the TLS slot with the reduction identity (0 for add, the type's
      // minimum for max, the type's maximum for min).
      // TODO: do not use GlobalStore for TLS ptr.
      offload->tls_prologue->push_back<GlobalStoreStmt>(tls_ptr, zero);
    }

    // Step 2:
    // Make loop body accumulate to TLS ptr instead of global ptr
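    // After replace_usages_with() below, the atomic ops gathered earlier
    // operate on the thread-local slot, so the hot loop no longer touches the
    // global destination at all.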
    {
      auto tls_ptr = offload->body->insert(
          Stmt::make<ThreadLocalPtrStmt>(
              tls_offset,
              TypeFactory::get_instance().get_pointer_type(data_type)),
          0);
      dest.first->replace_usages_with(tls_ptr);
    }

    // Step 3:
    // Atomically combine the thread-local partial result into its global
    // destination (add for add/sub reductions, max/min otherwise).
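    // The TLS epilogue runs once per thread after the loop body finishes, so
    // each thread contributes a single atomic per reduction destination.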
    {
      if (offload->tls_epilogue == nullptr) {
        offload->tls_epilogue = std::make_unique<Block>();
        offload->tls_epilogue->parent_stmt = offload;
      }
      auto tls_ptr = offload->tls_epilogue->push_back<ThreadLocalPtrStmt>(
          tls_offset, TypeFactory::get_instance().get_pointer_type(data_type));
      // TODO: do not use global load from TLS.
      auto tls_load = offload->tls_epilogue->push_back<GlobalLoadStmt>(tls_ptr);
      auto global_ptr = offload->tls_epilogue->insert(
          std::unique_ptr<Stmt>(
              (Stmt *)irpass::analysis::clone(dest.first).release()),
          -1);
      offload->tls_epilogue->insert(
          AtomicOpStmt::make_for_reduction(dest.second, global_ptr, tls_load),
          -1);
    }

    // allocate storage for the TLS variable
    tls_offset += dtype_size;
  }

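  // Keep the TLS buffer at least one byte so it is never zero-sized, even if
  // no reduction was demoted in this offload.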
  offload->tls_size = std::max(std::size_t(1), tls_offset);
}

}  // namespace

namespace irpass {

// This pass should happen after offloading but before lower_access.
void make_thread_local(IRNode *root, const CompileConfig &config) {
  TI_AUTO_PROF;
  if (auto root_block = root->cast<Block>()) {
    for (auto &offload : root_block->statements) {
      make_thread_local_offload(offload->cast<OffloadedStmt>());
    }
  } else {
    make_thread_local_offload(root->as<OffloadedStmt>());
  }
  type_check(root, config);
}

}  // namespace irpass

}  // namespace taichi::lang