1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file coproc_sync.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/tir/builtin.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include <unordered_map> |
30 | #include <unordered_set> |
31 | |
32 | #include "ir_utils.h" |
33 | #include "storage_access.h" |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | // Visitor to find touched set by co-processor scope. |
39 | class CoProcTouchedBuffer : public StmtExprVisitor { |
40 | public: |
41 | void VisitExpr_(const LoadNode* op) final { |
42 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
43 | } |
44 | void VisitStmt_(const StoreNode* op) final { |
45 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
46 | } |
47 | void VisitExpr_(const BufferLoadNode* op) final { |
48 | if (in_scope_) { |
49 | touched_[op->buffer->data.get()].coproc = true; |
50 | } else { |
51 | touched_[op->buffer->data.get()].normal = true; |
52 | } |
53 | StmtExprVisitor::VisitExpr_(op); |
54 | } |
55 | void VisitStmt_(const BufferStoreNode* op) final { |
56 | if (in_scope_) { |
57 | touched_[op->buffer->data.get()].coproc = true; |
58 | } else { |
59 | touched_[op->buffer->data.get()].normal = true; |
60 | } |
61 | StmtExprVisitor::VisitStmt_(op); |
62 | } |
63 | void VisitExpr_(const CallNode* op) final { |
64 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
65 | const VarNode* buffer = op->args[1].as<VarNode>(); |
66 | if (in_scope_) { |
67 | touched_[buffer].coproc = true; |
68 | } else { |
69 | touched_[buffer].normal = true; |
70 | } |
71 | } |
72 | StmtExprVisitor::VisitExpr_(op); |
73 | } |
74 | void VisitStmt_(const AttrStmtNode* op) final { |
75 | if (op->attr_key == attr::coproc_scope && !in_scope_) { |
76 | in_scope_ = true; |
77 | IterVar iv = Downcast<IterVar>(op->node); |
78 | coproc_.insert(iv); |
79 | StmtExprVisitor::VisitStmt_(op); |
80 | in_scope_ = false; |
81 | } else { |
82 | StmtExprVisitor::VisitStmt_(op); |
83 | } |
84 | } |
85 | |
86 | // Touch Entry |
87 | struct TouchEntry { |
88 | bool normal{false}; |
89 | bool coproc{false}; |
90 | }; |
91 | std::unordered_map<const VarNode*, TouchEntry> touched_; |
92 | std::unordered_set<IterVar> coproc_; |
93 | |
94 | private: |
95 | bool in_scope_{false}; |
96 | }; |
97 | |
98 | // Synchronization planning with co-processor. |
99 | class CoProcSyncPlanner : public StorageAccessVisitor { |
100 | public: |
101 | explicit CoProcSyncPlanner(const std::unordered_set<const VarNode*>& touched, |
102 | const std::string& coproc_name) |
103 | : touched_(touched), coproc_name_(coproc_name) {} |
104 | |
105 | void Plan(const Stmt& stmt) { |
106 | this->VisitStmt(stmt); |
107 | PlanSync(scope_.back(), nullptr, true); |
108 | if (sync_.size() == 0) { |
109 | sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync" ); |
110 | } |
111 | } |
112 | |
113 | // Write synchronization to be inserted before or after stmt. |
114 | std::unordered_map<const Object*, std::vector<Stmt>> sync_; |
115 | |
116 | protected: |
117 | bool Enabled(const VarNode* buf, const StorageScope& scope) const final { |
118 | return touched_.count(buf); |
119 | } |
120 | |
121 | // Plan the sync |
122 | std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final { |
123 | return PlanSync(seq, loop, false); |
124 | } |
125 | |
126 | private: |
127 | // Plan write synchronization if write is not coherent |
128 | std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const ForNode* loop, |
129 | bool force_sync_at_end) { |
130 | // detect write barriers |
131 | // access by the co-processor. |
132 | std::vector<AccessEntry> co_access; |
133 | bool contain_sync = false; |
134 | |
135 | auto find_conflict = [&](const AccessEntry& acc) { |
136 | for (const AccessEntry& x : co_access) { |
137 | if (x.buffer.same_as(acc.buffer) && |
138 | ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { |
139 | return true; |
140 | } |
141 | } |
142 | return false; |
143 | }; |
144 | for (size_t i = 0; i < seq.size(); ++i) { |
145 | const StmtEntry& s = seq[i]; |
146 | bool sync_write = false; |
147 | for (const AccessEntry& acc : s.access) { |
148 | if (acc.threads.size() == 0 && find_conflict(acc)) { |
149 | sync_write = true; |
150 | break; |
151 | } |
152 | if (acc.type == kSync) { |
153 | co_access.clear(); |
154 | contain_sync = true; |
155 | } |
156 | } |
157 | if (sync_write) { |
158 | ICHECK_NE(i, 0U); |
159 | sync_[seq[i - 1].stmt] = GetSync(co_access); |
160 | co_access.clear(); |
161 | contain_sync = true; |
162 | } |
163 | for (const AccessEntry& acc : s.access) { |
164 | if (acc.threads.size() != 0) { |
165 | co_access.push_back(acc); |
166 | } |
167 | } |
168 | } |
169 | bool sync_at_end = force_sync_at_end; |
170 | if (loop != nullptr && !sync_at_end) { |
171 | // loop carray dependency |
172 | for (size_t i = 0; i < seq.size(); ++i) { |
173 | const StmtEntry& s = seq[i]; |
174 | for (const AccessEntry& acc : s.access) { |
175 | if (acc.threads.size() == 0 && find_conflict(acc)) { |
176 | sync_at_end = true; |
177 | break; |
178 | } |
179 | } |
180 | if (sync_.count(s.stmt) || sync_at_end) break; |
181 | } |
182 | } |
183 | if (sync_at_end && co_access.size() != 0) { |
184 | ICHECK_NE(seq.size(), 0); |
185 | contain_sync = true; |
186 | sync_[seq.back().stmt] = GetSync(co_access); |
187 | co_access.clear(); |
188 | } |
189 | if (contain_sync) { |
190 | AccessEntry e; |
191 | e.type = kSync; |
192 | co_access.insert(co_access.begin(), e); |
193 | } |
194 | return co_access; |
195 | } |
196 | // Add write Synchronization |
197 | std::vector<Stmt> GetSync(const std::vector<AccessEntry>& co_access) { |
198 | // Does not consider memory coherence, need runtime. |
199 | ICHECK_NE(co_access.size(), 0U); |
200 | ICHECK_EQ(co_access[0].threads.size(), 1U); |
201 | return GetSync(coproc_name_ + ".coproc_sync" ); |
202 | } |
203 | |
204 | std::vector<Stmt> GetSync(std::string sync_name) { |
205 | return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))}; |
206 | } |
207 | |
208 | const std::unordered_set<const VarNode*>& touched_; |
209 | std::string coproc_name_; |
210 | }; |
211 | |
212 | // Detect memory barriers when coproc read/write memory |
213 | class CoProcBarrierDetector : public StorageAccessVisitor { |
214 | public: |
215 | explicit CoProcBarrierDetector(const std::unordered_set<const VarNode*>& touched, |
216 | const std::string& coproc_name) |
217 | : touched_(touched) { |
218 | read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier" ; |
219 | write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier" ; |
220 | } |
221 | |
222 | void PlanReadBarrier(const Stmt& stmt) { |
223 | read_barrier_ = true; |
224 | this->VisitStmt(stmt); |
225 | PlanReadBarrier(scope_.back(), nullptr); |
226 | } |
227 | void PlanWriteBarrier(const Stmt& stmt) { |
228 | read_barrier_ = false; |
229 | this->VisitStmt(stmt); |
230 | PlanWriteBarrier(scope_.back(), nullptr); |
231 | } |
232 | |
233 | std::unordered_map<const Object*, std::vector<Stmt>> barrier_before_; |
234 | std::unordered_map<const Object*, std::vector<Stmt>> barrier_after_; |
235 | |
236 | protected: |
237 | bool Enabled(const VarNode* buf, const StorageScope& scope) const final { |
238 | return touched_.count(buf); |
239 | } |
240 | |
241 | // Plan the sync |
242 | std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final { |
243 | if (read_barrier_) { |
244 | return PlanReadBarrier(seq, loop); |
245 | } else { |
246 | return PlanWriteBarrier(seq, loop); |
247 | } |
248 | } |
249 | |
250 | private: |
251 | // Plan write barrier at Read after write point. |
252 | std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const ForNode* loop) { |
253 | std::vector<AccessEntry> read_seq; |
254 | std::unordered_map<const VarNode*, std::vector<AccessEntry>> write_set; |
255 | |
256 | auto fupdate = [&](size_t i, const AccessEntry& acc) { |
257 | auto it = write_set.find(acc.buffer.get()); |
258 | if (it != write_set.end()) { |
259 | ICHECK_NE(i, 0U); |
260 | barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); |
261 | write_set.erase(it); |
262 | } |
263 | }; |
264 | for (size_t i = 0; i < seq.size(); ++i) { |
265 | const StmtEntry& s = seq[i]; |
266 | for (const AccessEntry& acc : s.access) { |
267 | if (acc.threads.size() == 0 && acc.type == kRead) { |
268 | fupdate(i, acc); |
269 | read_seq.push_back(acc); |
270 | } |
271 | } |
272 | for (const AccessEntry& acc : s.access) { |
273 | if (acc.threads.size() != 0 && acc.type == kWrite) { |
274 | write_set[acc.buffer.get()].push_back(acc); |
275 | } |
276 | } |
277 | } |
278 | // loop carry |
279 | if (loop != nullptr) { |
280 | for (const AccessEntry& acc : read_seq) { |
281 | fupdate(seq.size(), acc); |
282 | } |
283 | } |
284 | for (const auto& kv : write_set) { |
285 | read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); |
286 | } |
287 | return read_seq; |
288 | } |
289 | |
290 | std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const ForNode* loop) { |
291 | std::vector<AccessEntry> write_seq; |
292 | std::unordered_map<const VarNode*, std::vector<AccessEntry>> read_set; |
293 | |
294 | auto fupdate = [&](size_t i, const AccessEntry& acc) { |
295 | auto it = read_set.find(acc.buffer.get()); |
296 | if (it != read_set.end()) { |
297 | ICHECK_NE(i, seq.size()); |
298 | barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); |
299 | read_set.erase(it); |
300 | } |
301 | }; |
302 | |
303 | for (size_t i = seq.size(); i != 0; --i) { |
304 | const StmtEntry& s = seq[i - 1]; |
305 | for (const AccessEntry& acc : s.access) { |
306 | if (acc.threads.size() == 0 && acc.type == kWrite) { |
307 | fupdate(i, acc); |
308 | write_seq.push_back(acc); |
309 | } |
310 | } |
311 | for (const AccessEntry& acc : s.access) { |
312 | if (acc.threads.size() != 0 && acc.type == kRead) { |
313 | read_set[acc.buffer.get()].push_back(acc); |
314 | } |
315 | } |
316 | } |
317 | // loop carry |
318 | if (loop != nullptr) { |
319 | for (const AccessEntry& acc : write_seq) { |
320 | fupdate(0, acc); |
321 | } |
322 | } |
323 | for (const auto& kv : read_set) { |
324 | write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); |
325 | } |
326 | return write_seq; |
327 | } |
328 | |
329 | Stmt MakeBarrier(const std::string& func, const std::vector<AccessEntry>& wvec) { |
330 | // insert write point |
331 | Array<arith::IntSet> wset; |
332 | for (const AccessEntry& acc : wvec) { |
333 | ICHECK(acc.dtype == wvec[0].dtype); |
334 | ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory" ; |
335 | wset.push_back(acc.touched[0]); |
336 | } |
337 | Range none; |
338 | Range r = arith::Union(wset).CoverRange(none); |
339 | ICHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; |
340 | PrimExpr min = r->min; |
341 | PrimExpr extent = r->extent; |
342 | return Evaluate(Call(DataType::Int(32), Op::Get(func), |
343 | {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent})); |
344 | } |
345 | // Write barrier name |
346 | bool read_barrier_{false}; |
347 | std::string read_barrier_name_; |
348 | std::string write_barrier_name_; |
349 | const std::unordered_set<const VarNode*>& touched_; |
350 | }; |
351 | |
// Detect instruction-stream dependencies between different co-processor
// contexts and insert matching dep_push / dep_pop intrinsic pairs.
//
// A "context" is the integer value of a coproc_scope attribute on the given
// coproc_axis_. Whenever control transfers from context `from` to context
// `to` (from != to), a push(from, to) is emitted after the producer and a
// pop(from, to) before the consumer. Results are collected in
// insert_before_ / insert_after_.
class CoProcInstDepDetector : public StmtVisitor {
 public:
  explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name)
      : coproc_axis_(coproc_axis) {
    sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push");
    sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop");
  }

  // Run planning over stmt; afterwards, fix up any push/pop that remained
  // unmatched at the very beginning / end of the whole body.
  void Plan(const Stmt& stmt) {
    this->VisitStmt(stmt);
    if (last_state_.node != nullptr) {
      MatchFixEnterPop(first_state_);
      MatchFixExitPush(last_state_);
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) {
      // A leaf region running in a single context: its enter and exit
      // contexts are both the attribute's integer value.
      const IntImmNode* ctx_id = op->value.as<IntImmNode>();
      ICHECK(ctx_id != nullptr);
      curr_state_.clear();
      curr_state_.node = op->body.get();
      curr_state_.enter_ctx.insert(ctx_id->value);
      curr_state_.exit_ctx.insert(ctx_id->value);
      UpdateState();
    } else {
      StmtVisitor::VisitStmt_(op);
    }
  }

  void VisitStmt_(const ForNode* op) final {
    // Visit the loop body with fresh first/last state, then fold the body's
    // summary back into the enclosing trace.
    SyncState temp_first, temp_last;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    this->VisitStmt(op->body);
    curr_state_.clear();
    if (last_state_.node != nullptr) {
      curr_state_.node = op;
      ICHECK(first_state_.node != nullptr);
      // loop carry dependency
      InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop));
      curr_state_.enter_ctx = first_state_.enter_ctx;
      curr_state_.exit_ctx = last_state_.exit_ctx;
    }
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    // Each branch is summarized independently; unmatched push/pop inside a
    // branch are closed within that branch (MatchFix*), and the if-node's
    // enter/exit context sets are the union over the branches taken.
    SyncState temp_first, temp_last, curr_state;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    {
      // then stmt
      this->VisitStmt(op->then_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
      first_state_.clear();
      last_state_.clear();
    }
    if (op->else_case) {
      this->VisitStmt(op->else_case.value());
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
    }
    // update in the trace.
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    std::swap(curr_state_, curr_state);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  void VisitStmt_(const WhileNode* op) final {
    // TODO(masahi): Do we need a special handling for While nodes?
    LOG(FATAL) << "WhileNode not supported in CoProcSync.";
  }

  // insert before is stored in reverse order
  // the first element is closest to the node.
  std::unordered_map<const Object*, std::vector<Stmt>> insert_before_;
  std::unordered_map<const Object*, std::vector<Stmt>> insert_after_;

 private:
  // state in the sync entry
  struct SyncState {
    // The statement of the state.
    const Object* node{nullptr};
    // Set of all possible contexts in the entering moment.
    std::unordered_set<int> enter_ctx;
    // Set of all possible contexts in the exit moment.
    std::unordered_set<int> exit_ctx;
    // existing pop performed at enter
    std::vector<std::pair<int, int>> enter_pop;
    // existing push performed at exit
    std::vector<std::pair<int, int>> exit_push;
    // clear the state
    void clear() {
      node = nullptr;
      enter_ctx.clear();
      exit_ctx.clear();
      enter_pop.clear();
      exit_push.clear();
    }
  };
  // inject proper sync into the pair
  // record the push/pop sequence that could be possibly un-matched.
  // return the push/pop message at enter/exit of the Block
  // after considering the existing unmatcheded events and added events
  void InjectSync(const SyncState& prev, const SyncState& next,
                  std::vector<std::pair<int, int>>* prev_exit_push,
                  std::vector<std::pair<int, int>>* next_enter_pop) {
    prev_exit_push->clear();
    next_enter_pop->clear();
    // quick path
    // Single-context handoff with no pre-existing events: emit exactly one
    // push/pop pair when the context actually changes.
    if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 &&
        next.enter_ctx.size() == 1) {
      int from = *prev.exit_ctx.begin();
      int to = *next.enter_ctx.begin();
      if (from != to) {
        insert_after_[prev.node].emplace_back(MakePush(from, to));
        insert_before_[next.node].emplace_back(MakePop(from, to));
        prev_exit_push->emplace_back(std::make_pair(from, to));
        next_enter_pop->emplace_back(std::make_pair(from, to));
      }
      return;
    }
    // complicate path.
    // Every (exit ctx, enter ctx) pair with differing contexts needs a
    // push/pop dependency edge, unless one side already has the event.
    std::vector<std::pair<int, int>> vpush = prev.exit_push;
    std::vector<std::pair<int, int>> vpop = next.enter_pop;
    std::vector<std::pair<int, int>> pending;
    for (int from : prev.exit_ctx) {
      for (int to : next.enter_ctx) {
        if (from != to) {
          pending.emplace_back(std::make_pair(from, to));
        }
      }
    }
    // policy 1
    // Add the missing half of each pending edge on either side.
    std::vector<Stmt> prev_after, next_before;
    for (const std::pair<int, int>& p : pending) {
      if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) {
        vpush.push_back(p);
        prev_after.emplace_back(MakePush(p.first, p.second));
      }
      if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) {
        vpop.push_back(p);
        next_before.emplace_back(MakePop(p.first, p.second));
      }
    }
    // fix pending
    // Close any push without a matching pop (and vice versa); matched pairs
    // are reported back to the caller as still-outstanding at the boundary.
    for (const std::pair<int, int>& p : vpush) {
      if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) {
        prev_after.emplace_back(MakePop(p.first, p.second));
      } else {
        prev_exit_push->push_back(p);
      }
    }
    for (const std::pair<int, int>& p : vpop) {
      if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) {
        next_before.emplace_back(MakePush(p.first, p.second));
      } else {
        next_enter_pop->push_back(p);
      }
    }
    if (prev_after.size() != 0) {
      auto& v1 = insert_after_[prev.node];
      v1.insert(v1.end(), prev_after.begin(), prev_after.end());
    }
    if (next_before.size() != 0) {
      auto& v2 = insert_before_[next.node];
      v2.insert(v2.end(), next_before.begin(), next_before.end());
    }
  }

  // Balance pops that have no earlier producer: prepend the matching pushes
  // before the state's node.
  void MatchFixEnterPop(const SyncState& state) {
    if (state.enter_pop.size() == 0) return;
    auto& vec = insert_before_[state.node];
    for (const std::pair<int, int>& p : state.enter_pop) {
      vec.push_back(MakePush(p.first, p.second));
    }
  }

  // Balance pushes that have no later consumer: append the matching pops
  // after the state's node.
  void MatchFixExitPush(const SyncState& state) {
    if (state.exit_push.size() == 0) return;
    auto& vec = insert_after_[state.node];
    for (const std::pair<int, int>& p : state.exit_push) {
      vec.push_back(MakePop(p.first, p.second));
    }
  }

  // Fold curr_state_ into the running trace: sync it against the previous
  // state, or initialize the trace if this is the first state seen.
  void UpdateState() {
    if (last_state_.node != nullptr) {
      std::vector<std::pair<int, int>> t1, t2;
      InjectSync(last_state_, curr_state_, &t1, &t2);
      std::swap(last_state_, curr_state_);
    } else {
      ICHECK(first_state_.node == nullptr);
      first_state_ = curr_state_;
      last_state_ = curr_state_;
    }
  }

  // Build evaluate(coproc_dep_push(from, to)).
  Stmt MakePush(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_push_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  // Build evaluate(coproc_dep_pop(from, to)).
  Stmt MakePop(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_pop_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  // sync states.
  SyncState first_state_, last_state_, curr_state_;
  // Variables
  IterVar coproc_axis_;
  Op sync_push_op_, sync_pop_op_;
};
583 | |
584 | class CoProcSyncInserter : public StmtMutator { |
585 | public: |
586 | Stmt Insert(Stmt stmt) { |
587 | CoProcTouchedBuffer visitor; |
588 | visitor(stmt); |
589 | if (visitor.coproc_.size() == 0) return stmt; |
590 | std::unordered_set<const VarNode*> touched; |
591 | |
592 | for (const auto& kv : visitor.touched_) { |
593 | if (kv.second.normal && kv.second.coproc) { |
594 | touched.insert(kv.first); |
595 | } |
596 | } |
597 | ICHECK_EQ(visitor.coproc_.size(), 1U); |
598 | std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint; |
599 | // plan sync. |
600 | CoProcSyncPlanner sync_planner(touched, coproc_name); |
601 | sync_planner.Plan(stmt); |
602 | for (const auto& kv : sync_planner.sync_) { |
603 | auto& vec = insert_after_[kv.first]; |
604 | vec.insert(vec.end(), kv.second.begin(), kv.second.end()); |
605 | } |
606 | // Detect barrier |
607 | CoProcBarrierDetector barrier_detector(touched, coproc_name); |
608 | barrier_detector.PlanReadBarrier(stmt); |
609 | barrier_detector.PlanWriteBarrier(stmt); |
610 | for (const auto& kv : barrier_detector.barrier_before_) { |
611 | auto& vec = insert_before_[kv.first]; |
612 | vec.insert(vec.end(), kv.second.begin(), kv.second.end()); |
613 | } |
614 | for (const auto& kv : barrier_detector.barrier_after_) { |
615 | auto& vec = insert_after_[kv.first]; |
616 | vec.insert(vec.end(), kv.second.begin(), kv.second.end()); |
617 | } |
618 | // Detect barrier |
619 | CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); |
620 | sync_detector.Plan(stmt); |
621 | for (const auto& kv : sync_detector.insert_before_) { |
622 | auto& vec = insert_before_[kv.first]; |
623 | vec.insert(vec.end(), kv.second.begin(), kv.second.end()); |
624 | } |
625 | for (const auto& kv : sync_detector.insert_after_) { |
626 | auto& vec = insert_after_[kv.first]; |
627 | vec.insert(vec.end(), kv.second.begin(), kv.second.end()); |
628 | } |
629 | return operator()(std::move(stmt)); |
630 | } |
631 | |
632 | Stmt VisitStmt(const Stmt& stmt) final { |
633 | auto it_before = insert_before_.find(stmt.get()); |
634 | auto it_after = insert_after_.find(stmt.get()); |
635 | Stmt new_stmt = StmtMutator::VisitStmt(stmt); |
636 | |
637 | return SeqStmt::Flatten( |
638 | it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(), new_stmt, |
639 | it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>()); |
640 | } |
641 | |
642 | private: |
643 | // insert before is stored in reverse order |
644 | // the first element is closest to the node. |
645 | std::unordered_map<const Object*, std::vector<Stmt>> insert_before_; |
646 | std::unordered_map<const Object*, std::vector<Stmt>> insert_after_; |
647 | }; |
648 | |
649 | Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } |
650 | |
651 | namespace transform { |
652 | |
653 | Pass CoProcSync() { |
654 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
655 | auto* n = f.CopyOnWrite(); |
656 | n->body = CoProcSyncInserter().Insert(std::move(n->body)); |
657 | return f; |
658 | }; |
659 | return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync" , {}); |
660 | } |
661 | |
662 | TVM_REGISTER_GLOBAL("tir.transform.CoProcSync" ).set_body_typed(CoProcSync); |
663 | |
664 | } // namespace transform |
665 | |
666 | } // namespace tir |
667 | } // namespace tvm |
668 | |