/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_storage_sync.cc
 * \brief Insert storage synchronization between conflicting accesses to the given storage scope.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "storage_access.h"

namespace tvm {
namespace tir {

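// Detect where a thread storage sync (barrier) is required for the given storage
// scope.  The planner walks each statement sequence, tracking reads and writes
// that are not yet separated by a sync, and records in `syncs_inserted_` every
// statement that must be preceded by a sync.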
class ThreadSyncPlanner : public StorageAccessVisitor {
 public:
  explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}

  // The syncs inserted before each statement
  std::unordered_set<const Object*> syncs_inserted_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return in_device_env() && scope == sync_scope_;
  }
  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    // Unsynced reads and writes
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // Simulation-based approach to find dependencies: walk the sequence once;
    // if it is a loop, a second pass below accounts for loop-carried effects.
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      // If a sync is inserted, clear the accesses it already covers.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // Add the reads/writes of the current statement.
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      if (sync_before_stmt) {
        ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
        syncs_inserted_.insert(s.stmt);
      }
    }
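    // For a loop body, do one more pass with loop_carry = true: accesses left
    // exposed at the end of an iteration may conflict with accesses at the
    // beginning of the next iteration, so an early statement may still need a
    // sync before it.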
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0) break;
        if (reads.empty() && writes.empty()) break;
        bool sync_before_stmt = false;
        for (const AccessEntry& acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kSync) {
            reads.clear();
            writes.clear();
          }
        }
        if (sync_before_stmt) {
          ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
          syncs_inserted_.insert(s.stmt);
          break;
        }
      }
    }
    // Return the exposed entries, removing unnecessary ones.
    int sync_count = 0;
    // head: accesses before the first sync; tail: accesses after the last sync.
    std::vector<AccessEntry> head, tail;
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry& s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(esync);
        }
        ++sync_count;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(esync);
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry& e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }

 private:
  // Check whether any previous access conflicts with the current one.
  bool FindConflict(const std::vector<AccessEntry>& prev, const AccessEntry& curr,
                    bool loop_carry) {
    for (const AccessEntry& x : prev) {
      if (FindConflict(x, curr, loop_carry)) {
        return true;
      }
    }
    return false;
  }

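  // Check whether two accesses to the same buffer conflict.  They do not
  // conflict if both touch provably the same single index and that index
  // depends on the innermost thread variable (each thread stays on its own
  // element), or if the previous access is a double-buffer write that is read
  // back without a loop-carried dependence.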
  bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) {
    // Accesses to different buffers do not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;
    }

    // Assumes no race between threads.
    // The same index value means no conflict.
    // TODO(tqchen) more standard set based testing.
    bool has_same_index = true;
    // Even if the accesses have the same index, that index must depend on
    // the innermost thread id to avoid a race condition.
    bool depends_on_thread_index = true;
    const VarNode* thread_index_var = nullptr;
    if (!curr.threads.empty()) {
      thread_index_var = curr.threads.back()->var.get();
    }

    for (size_t i = 0; i < prev.touched.size(); i++) {
      const auto& prev_intset = prev.touched[i];
      const auto& curr_intset = curr.touched[i];

      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
        PrimExpr prev_index = prev_intset.PointValue();
        PrimExpr curr_index = curr_intset.PointValue();
        has_same_index = ExprDeepEqual()(prev_index, curr_index);
        if (thread_index_var != nullptr) {
          auto f_uses_thread_index = [=](const tvm::tir::VarNode* parameter) {
            return parameter == thread_index_var;
          };
          depends_on_thread_index = depends_on_thread_index &&
                                    UsesVar(curr_index, f_uses_thread_index) &&
                                    UsesVar(prev_index, f_uses_thread_index);
        }
      } else {
        has_same_index = false;
      }

      if (!(has_same_index && depends_on_thread_index)) {
        break;
      }
    }
    if (has_same_index && depends_on_thread_index) {
      return false;
    }

    // A read from a double buffer that was previously swapped out does not
    // conflict, unless the dependence is loop-carried.
    if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
      return false;
    }

    // If nothing else allows sharing the same buffer, then they are
    // in conflict.
    return true;
  }

 private:
  // synchronization scope
  StorageScope sync_scope_;
};

// There are cases where a necessary syncthreads is not inserted by ThreadSyncInserter.
// For example, syncthreads is needed after async_wait_queue in the second loop below,
// but since ThreadSyncInserter is not aware of the asynchronous semantics, it cannot tell
// that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
//    async_commit_queue(0):
//       async_scope:
//          shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline epilogue
// for i in range(3):
//    async_wait_queue(0, 2 - i):
//       local[...] = shared[(i + 125) % 4]

// This class adds syncthreads after all async_wait_queue. That includes syncthreads that
// can be inserted by ThreadSyncInserter as well, but ThreadSyncInserter will not insert
// duplicate syncthreads if it finds an existing one at the synchronization point.
class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
 public:
  explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {}

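  // Rewrite an `async_wait_queue_scope` attribute so that its body begins with
  // a storage sync; the nested `async_wait_inflight_count` attribute is rebuilt
  // around the new body.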
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::async_wait_queue_scope) {
      auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                                {StringImm(sync_scope_.to_string())}));
      auto inner = op->body.as<AttrStmtNode>();
      ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
      auto zero = make_zero(DataType::Int(32));
      auto new_body = SeqStmt({sync, inner->body});
      return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value,
                      AttrStmt(zero, tir::attr::async_wait_inflight_count, inner->value, new_body));
    }
    return StmtExprMutator::VisitStmt_(op);
  }

 private:
  StorageScope sync_scope_;
};

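// Insert the planned syncs before the statements recorded by ThreadSyncPlanner.
// For the global scope it emits a global barrier instead of a plain storage
// sync, and it collects read/write statistics of global buffers so that buffers
// that are both read and written can be marked volatile when the barrier is
// initialized.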
class ThreadSyncInserter : public StmtExprMutator {
 public:
  ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs)
      : sync_scope_(sync_scope), syncs_(syncs) {}

  Stmt VisitStmt(const Stmt& stmt) final {
    if (syncs_.size() == 0) return stmt;
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (sync_scope_.rank == StorageRank::kGlobal) {
        barrier = MakeGlobalBarrier();
      } else {
        barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                                {StringImm(sync_scope_.to_string())}));
      }
      // Mutate after the query: mutation may produce a new object that no
      // longer matches the pointer recorded by the planner.
      auto ret = StmtExprMutator::VisitStmt(stmt);
      ret = SeqStmt({barrier, ret});
      return ret;
    } else {
      return StmtExprMutator::VisitStmt(stmt);
    }
  }
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer->data].read_count;
    }
    return StmtExprMutator::VisitExpr_(op);
  }
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer->data].write_count;
    }
    return StmtExprMutator::VisitStmt_(op);
  }
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      bool temp = true;
      std::swap(temp, in_thread_env_);
      thread_extents_.push_back(op);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      std::swap(temp, in_thread_env_);
      // When leaving the outermost thread scope, set up the global barrier.
      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
        ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
        num_blocks_ = PrimExpr();
        is_lead_ = PrimExpr();
      }
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

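  // Accesses made through tvm_access_ptr also contribute to the read/write
  // statistics of global buffers: bit 0 of the access mask marks a read,
  // bit 1 a write.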
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
      op = expr.as<CallNode>();
      ICHECK_EQ(op->args.size(), 5U);
      Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>()));
      const IntImmNode* flag = op->args[4].as<IntImmNode>();
      if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[buffer_var].read_count;
      }
      if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[buffer_var].write_count;
      }
      return expr;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

 private:
  // RW statistics about data
  struct Entry {
    int read_count{0};
    int write_count{0};
  };

  // Get current storage scope.
  StorageScope GetScope(Var buffer_var) const {
    return StorageScope::Create(GetPtrStorageScope(buffer_var));
  }

  // private functions.
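  // Wrap the outermost thread-extent attribute: emit a packed call to
  // tvm_prepare_global_barrier, mark global buffers that are both read and
  // written as volatile, and prepend tvm_global_barrier_kinit to the body.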
  Stmt InitGlobalBarrier(const AttrStmtNode* op) {
    ICHECK(op != nullptr);
    Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
    Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
    Stmt body = op->body;
    for (const auto& kv : rw_stats_) {
      const auto& e = kv.second;
      if (e.read_count != 0 && e.write_count != 0) {
        body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
      }
    }
    rw_stats_.clear();
    Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
    body = SeqStmt({kinit, body});
    body = AttrStmt(op->node, op->attr_key, op->value, body);
    return SeqStmt({prep, body});
  }
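  // Build the global barrier call.  On first use, compute the total number of
  // blocks from the rank-0 (block index) thread extents and a predicate that
  // selects the leading thread among the rank-1 (thread index) variables;
  // subsequent calls reuse the cached expressions.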
  Stmt MakeGlobalBarrier() {
    ICHECK(sync_scope_.rank == StorageRank::kGlobal);
    if (!num_blocks_.defined()) {
      ICHECK(!is_lead_.defined());
      num_work_dim_ = thread_extents_.size();
      for (const AttrStmtNode* attr : thread_extents_) {
        IterVar iv = Downcast<IterVar>(attr->node);
        runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
        if (s.rank == 0) {
          num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
        } else if (s.rank == 1) {
          PrimExpr cond = iv->var == make_zero(iv->var.dtype());
          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
        }
      }
    } else {
      ICHECK_EQ(num_work_dim_, thread_extents_.size());
    }
    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                         {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
  }
  // Data fields.
  // synchronization scope
  StorageScope sync_scope_;
  const std::unordered_set<const Object*>& syncs_;
  // The read/write statistics of global buffers.
  std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
  // Whether we are currently inside a thread_extent scope (for the global barrier).
  bool in_thread_env_{false};
  // Thread extents of the enclosing thread_extent attributes, outermost first.
  std::vector<const AttrStmtNode*> thread_extents_;
  size_t num_work_dim_{0};
  PrimExpr num_blocks_;
  PrimExpr is_lead_;
};

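// Plan and insert the syncs required for the given storage scope.  For the
// plain "shared" scope, syncs after async_wait_queue are inserted first;
// ThreadSyncInserter then avoids duplicating syncs at those points.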
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::Create(storage_scope);
  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
    stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
  }
  ThreadSyncPlanner planner(sync_scope);
  planner(stmt);
  return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}

namespace transform {

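// The pass wrapper exposed to the pass infrastructure as "tir.ThreadSync".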
Pass ThreadSync(String storage_scope) {
  auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = ThreadSync(std::move(n->body), storage_scope);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}

TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync);

}  // namespace transform
}  // namespace tir
}  // namespace tvm