/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_storage_sync.cc
 * \brief Insert synchronization barriers between accesses to shared or
 *        global memory within device code.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "storage_access.h"

namespace tvm {
namespace tir {

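/*!
 * \brief Plan where storage synchronization barriers (e.g. __syncthreads
 *  for the "shared" scope on CUDA) must be inserted.
 *
 * The planner walks each statement sequence, tracking reads and writes to
 * buffers in the sync scope that are not yet protected by a barrier. An
 * illustrative sketch in TIR-like pseudocode (buffer names hypothetical):
 *
 *   shared_buf[tx] = A[tx]        # write by thread tx
 *   # <- a sync is planned here: the read below may touch
 *   #    an element written by a different thread.
 *   B[tx] = shared_buf[tx ^ 1]
 *
 * Statements that need a barrier before them are recorded in
 * syncs_inserted_, keyed by the statement node pointer.
 */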
class ThreadSyncPlanner : public StorageAccessVisitor {
 public:
  explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}

  // Statements that must be preceded by a sync barrier.
  std::unordered_set<const Object*> syncs_inserted_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return in_device_env() && scope == sync_scope_;
  }
  // Plan sync insertion for a sequence of statements.
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    // Reads and writes not yet protected by a sync.
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // If this is a loop body, a second pass below accounts for loop-carried
    // effects. This is a simulation-based approach to finding dependencies.
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      // Check whether a sync is needed before this statement.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      // If a sync is inserted, the earlier accesses are no longer exposed.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // Add the reads/writes of the current statement.
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      if (sync_before_stmt) {
        ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
        syncs_inserted_.insert(s.stmt);
      }
    }
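    // For loops, make one more pass over the sequence to catch loop-carried
    // conflicts: accesses left exposed at the end of one iteration may
    // conflict with accesses at the start of the next iteration.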
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0) break;
        if (reads.empty() && writes.empty()) break;
        bool sync_before_stmt = false;
        for (const AccessEntry& acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kSync) {
            reads.clear();
            writes.clear();
          }
        }
        if (sync_before_stmt) {
          ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
          syncs_inserted_.insert(s.stmt);
          break;
        }
      }
    }
    // Return the exposed entries, removing unnecessary ones.
    int sync_count = 0;
    // head holds accesses before the first sync; tail those after the last sync.
    std::vector<AccessEntry> head, tail;
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry& s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(esync);
        }
        ++sync_count;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(esync);
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // Clear the double-buffer flag after a loop is finished.
      for (AccessEntry& e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }

 private:
  // Find an entry in prev that conflicts with curr.
  bool FindConflict(const std::vector<AccessEntry>& prev, const AccessEntry& curr,
                    bool loop_carry) {
    for (const AccessEntry& x : prev) {
      if (FindConflict(x, curr, loop_carry)) {
        return true;
      }
    }
    return false;
  }

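  // Check whether two accesses to the same buffer may conflict across
  // threads. Conservative sketch of the rule used below: when both
  // accesses touch a provably identical single point whose index depends
  // on the innermost thread index (e.g. shared[threadIdx.x] written and
  // then read back by the same thread), no barrier is required.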
  bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) {
    // Accesses to different buffers do not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;
    }

    // Assume no data race between threads: accesses with provably
    // identical indices do not conflict.
    // TODO(tqchen) more standard set based testing.
    bool has_same_index = true;
    // Even if the accesses have the same index, that index must depend on
    // the innermost thread id; otherwise different threads touch the same
    // location and can still race.
    bool depends_on_thread_index = true;
    const VarNode* thread_index_var = nullptr;
    if (!curr.threads.empty()) {
      thread_index_var = curr.threads.back()->var.get();
    }

    for (size_t i = 0; i < prev.touched.size(); i++) {
      const auto& prev_intset = prev.touched[i];
      const auto& curr_intset = curr.touched[i];

      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
        PrimExpr prev_index = prev_intset.PointValue();
        PrimExpr curr_index = curr_intset.PointValue();
        has_same_index = ExprDeepEqual()(prev_index, curr_index);
        if (thread_index_var != nullptr) {
          auto f_uses_thread_index = [=](const tvm::tir::VarNode* parameter) {
            return parameter == thread_index_var;
          };
          depends_on_thread_index = depends_on_thread_index &&
                                    UsesVar(curr_index, f_uses_thread_index) &&
                                    UsesVar(prev_index, f_uses_thread_index);
        }
      } else {
        has_same_index = false;
      }

      if (!(has_same_index && depends_on_thread_index)) {
        break;
      }
    }
    if (has_same_index && depends_on_thread_index) {
      return false;
    }

    // A read of a double buffer that was previously swapped out does not
    // conflict with the earlier write.
    if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
      return false;
    }

    // Nothing else allows sharing the buffer, so the accesses conflict.
    return true;
  }

 private:
  // The synchronization scope.
  StorageScope sync_scope_;
};

// There are cases where a necessary syncthreads is not inserted by
// ThreadSyncInserter. For example, a syncthreads is needed after
// async_wait_queue in the second loop below, but since ThreadSyncInserter
// is not aware of the asynchronous semantics, it cannot tell that the
// syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
//    async_commit_queue(0):
//       async_scope:
//          shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline epilogue
// for i in range(3):
//    async_wait_queue(0, 2 - i):
//       local[...] = shared[(i + 125) % 4]
//
// This class adds a syncthreads after every async_wait_queue. Some of these
// could also be inserted by ThreadSyncInserter, but ThreadSyncInserter will
// not insert a duplicate syncthreads if it finds an existing one at the
// synchronization point.
class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
 public:
  explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {}

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::async_wait_queue_scope) {
      auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                                {StringImm(sync_scope_.to_string())}));
      auto inner = op->body.as<AttrStmtNode>();
      ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
      auto zero = make_zero(DataType::Int(32));
      auto new_body = SeqStmt({sync, inner->body});
      return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value,
                      AttrStmt(zero, tir::attr::async_wait_inflight_count, inner->value, new_body));
    }
    return StmtExprMutator::VisitStmt_(op);
  }

 private:
  StorageScope sync_scope_;
};

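/*!
 * \brief Insert the barriers planned by ThreadSyncPlanner.
 *
 * Each statement recorded in the planned set is rewritten into the
 * sequence {barrier, stmt}. For the "shared" scope the barrier lowers to
 * a call such as (illustrative):
 *
 *   @tir.tvm_storage_sync("shared")
 *
 * For the "global" scope a software global barrier is emitted instead,
 * together with the bookkeeping it needs (see MakeGlobalBarrier and
 * InitGlobalBarrier below).
 */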
class ThreadSyncInserter : public StmtExprMutator {
 public:
  ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs)
      : sync_scope_(sync_scope), syncs_(syncs) {}

  Stmt VisitStmt(const Stmt& stmt) final {
    if (syncs_.size() == 0) return stmt;
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (sync_scope_.rank == StorageRank::kGlobal) {
        barrier = MakeGlobalBarrier();
      } else {
        barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                                {StringImm(sync_scope_.to_string())}));
      }
      // Mutate after the lookup, so the statement pointer used as the key
      // is queried before the statement can change.
      auto ret = StmtExprMutator::VisitStmt(stmt);
      ret = SeqStmt({barrier, ret});
      return ret;
    } else {
      return StmtExprMutator::VisitStmt(stmt);
    }
  }
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer->data].read_count;
    }
    return StmtExprMutator::VisitExpr_(op);
  }
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer->data].write_count;
    }
    return StmtExprMutator::VisitStmt_(op);
  }
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      bool temp = true;
      std::swap(temp, in_thread_env_);
      thread_extents_.push_back(op);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      std::swap(temp, in_thread_env_);
      // When leaving the outermost thread scope, set up the global barrier.
      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
        ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
        num_blocks_ = PrimExpr();
        is_lead_ = PrimExpr();
      }
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
      op = expr.as<CallNode>();
      ICHECK_EQ(op->args.size(), 5U);
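      // tvm_access_ptr args: {type annotation, buffer var, offset, extent,
      // rw_mask}; in rw_mask, bit 0 marks a read and bit 1 marks a write.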
      Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>()));
      const IntImmNode* flag = op->args[4].as<IntImmNode>();
      if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[buffer_var].read_count;
      }
      if ((flag->value & 2) && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[buffer_var].write_count;
      }
      return expr;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

 private:
  // Read/write statistics of a buffer.
  struct Entry {
    int read_count{0};
    int write_count{0};
  };

  // Get the storage scope of a buffer variable.
  StorageScope GetScope(Var buffer_var) const {
    return StorageScope::Create(GetPtrStorageScope(buffer_var));
  }

  // private functions.
  Stmt InitGlobalBarrier(const AttrStmtNode* op) {
    ICHECK(op != nullptr);
    Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
    Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
    Stmt body = op->body;
    for (const auto& kv : rw_stats_) {
      const auto& e = kv.second;
      if (e.read_count != 0 && e.write_count != 0) {
        body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
      }
    }
    rw_stats_.clear();
    Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
    body = SeqStmt({kinit, body});
    body = AttrStmt(op->node, op->attr_key, op->value, body);
    return SeqStmt({prep, body});
  }
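  // Emit the global barrier call. The generated call carries a predicate
  // that selects one lead thread per block and the total number of blocks,
  // e.g. (illustrative):
  //
  //   @tir.tvm_storage_sync("global", is_lead, num_blocks)
  //
  // where num_blocks is the product of the rank-0 (blockIdx) extents and
  // is_lead is the conjunction of (threadIdx == 0) over the rank-1 axes.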
  Stmt MakeGlobalBarrier() {
    ICHECK(sync_scope_.rank == StorageRank::kGlobal);
    if (!num_blocks_.defined()) {
      ICHECK(!is_lead_.defined());
      num_work_dim_ = thread_extents_.size();
      for (const AttrStmtNode* attr : thread_extents_) {
        IterVar iv = Downcast<IterVar>(attr->node);
        runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
        if (s.rank == 0) {
          num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
        } else if (s.rank == 1) {
          PrimExpr cond = iv->var == make_zero(iv->var.dtype());
          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
        }
      }
    } else {
      ICHECK_EQ(num_work_dim_, thread_extents_.size());
    }
    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                         {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
  }
  // data structure.
  StorageScope sync_scope_;
  const std::unordered_set<const Object*>& syncs_;
  // Per-buffer read/write statistics of global storage accesses.
  std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
  // Whether the visitor is currently inside a thread (launch) environment.
  bool in_thread_env_{false};
  // Memoized stack of enclosing thread_extent attributes.
  std::vector<const AttrStmtNode*> thread_extents_;
  size_t num_work_dim_{0};
  PrimExpr num_blocks_;
  PrimExpr is_lead_;
};

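/*!
 * \brief Insert the necessary barriers for buffers of the given storage scope.
 *
 * The pass proceeds in three stages:
 *  1. For the plain "shared" scope, insert a storage sync after every
 *     async_wait_queue (ThreadSyncAfterWaitQueueInserter).
 *  2. Plan the remaining synchronization points (ThreadSyncPlanner).
 *  3. Materialize the planned barriers (ThreadSyncInserter).
 */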
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::Create(storage_scope);
  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
    stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
  }
  ThreadSyncPlanner planner(sync_scope);
  planner(stmt);
  return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}

namespace transform {

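// Create the ThreadSync pass for the given scope. Typical usage from a
// lowering pipeline (an illustrative sketch, not the only configuration):
//
//   auto seq = tvm::transform::Sequential(
//       {tir::transform::ThreadSync("shared"), tir::transform::ThreadSync("global")});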
Pass ThreadSync(String storage_scope) {
  auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = ThreadSync(std::move(n->body), storage_scope);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}

TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync);

}  // namespace transform
}  // namespace tir
}  // namespace tvm