/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file coproc_sync.cc
 * \brief Plan and insert synchronization and memory barriers
 *        between the host and the co-processor.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir_utils.h"
#include "storage_access.h"

namespace tvm {
namespace tir {

// Visitor to find the set of buffers touched by the co-processor scope.
class CoProcTouchedBuffer : public StmtExprVisitor {
 public:
  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }
  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
  void VisitExpr_(const BufferLoadNode* op) final {
    if (in_scope_) {
      touched_[op->buffer->data.get()].coproc = true;
    } else {
      touched_[op->buffer->data.get()].normal = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitStmt_(const BufferStoreNode* op) final {
    if (in_scope_) {
      touched_[op->buffer->data.get()].coproc = true;
    } else {
      touched_[op->buffer->data.get()].normal = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }
  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      const VarNode* buffer = op->args[1].as<VarNode>();
      if (in_scope_) {
        touched_[buffer].coproc = true;
      } else {
        touched_[buffer].normal = true;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::coproc_scope && !in_scope_) {
      in_scope_ = true;
      IterVar iv = Downcast<IterVar>(op->node);
      coproc_.insert(iv);
      StmtExprVisitor::VisitStmt_(op);
      in_scope_ = false;
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }

  // Touch Entry
  struct TouchEntry {
    bool normal{false};
    bool coproc{false};
  };
  std::unordered_map<const VarNode*, TouchEntry> touched_;
  std::unordered_set<IterVar> coproc_;

 private:
  bool in_scope_{false};
};
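
// Illustrative sketch (hypothetical TIR with a co-processor axis `cp`;
// the names are for exposition only, not taken from this file):
//
//   A[0] = 1;                                    // normal access -> touched_[A].normal
//   attr [cp] coproc_scope = 1 { B[0] = A[0]; }  // coproc access -> touched_[A].coproc,
//                                                //                  touched_[B].coproc
//
// Only A is touched by both sides, so only A needs host/coproc synchronization.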

// Synchronization planning with the co-processor.
class CoProcSyncPlanner : public StorageAccessVisitor {
 public:
  explicit CoProcSyncPlanner(const std::unordered_set<const VarNode*>& touched,
                             const std::string& coproc_name)
      : touched_(touched), coproc_name_(coproc_name) {}

  void Plan(const Stmt& stmt) {
    this->VisitStmt(stmt);
    PlanSync(scope_.back(), nullptr, true);
    if (sync_.size() == 0) {
      sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
    }
  }

  // Write synchronization to be inserted after the recorded stmt.
  std::unordered_map<const Object*, std::vector<Stmt>> sync_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return touched_.count(buf);
  }

  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    return PlanSync(seq, loop, false);
  }

 private:
  // Plan write synchronization if the write is not coherent.
  std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const ForNode* loop,
                                    bool force_sync_at_end) {
    // Detect write conflicts against accesses made by the co-processor.
    std::vector<AccessEntry> co_access;
    bool contain_sync = false;

    auto find_conflict = [&](const AccessEntry& acc) {
      for (const AccessEntry& x : co_access) {
        if (x.buffer.same_as(acc.buffer) &&
            ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) {
          return true;
        }
      }
      return false;
    };
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      bool sync_write = false;
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && find_conflict(acc)) {
          sync_write = true;
          break;
        }
        if (acc.type == kSync) {
          co_access.clear();
          contain_sync = true;
        }
      }
      if (sync_write) {
        ICHECK_NE(i, 0U);
        sync_[seq[i - 1].stmt] = GetSync(co_access);
        co_access.clear();
        contain_sync = true;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0) {
          co_access.push_back(acc);
        }
      }
    }
    bool sync_at_end = force_sync_at_end;
    if (loop != nullptr && !sync_at_end) {
      // loop carry dependency
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        for (const AccessEntry& acc : s.access) {
          if (acc.threads.size() == 0 && find_conflict(acc)) {
            sync_at_end = true;
            break;
          }
        }
        if (sync_.count(s.stmt) || sync_at_end) break;
      }
    }
    if (sync_at_end && co_access.size() != 0) {
      ICHECK_NE(seq.size(), 0);
      contain_sync = true;
      sync_[seq.back().stmt] = GetSync(co_access);
      co_access.clear();
    }
    if (contain_sync) {
      AccessEntry e;
      e.type = kSync;
      co_access.insert(co_access.begin(), e);
    }
    return co_access;
  }
  // Add write synchronization.
  std::vector<Stmt> GetSync(const std::vector<AccessEntry>& co_access) {
    // Does not consider memory coherence; the runtime must provide it.
    ICHECK_NE(co_access.size(), 0U);
    ICHECK_EQ(co_access[0].threads.size(), 1U);
    return GetSync(coproc_name_ + ".coproc_sync");
  }

  std::vector<Stmt> GetSync(std::string sync_name) {
    return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))};
  }

  const std::unordered_set<const VarNode*>& touched_;
  std::string coproc_name_;
};
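
// Illustrative sketch (hypothetical co-processor named "cp"; the intrinsic
// name is derived the same way as in GetSync above): a co-processor write
// followed by a normal read of the same buffer
//
//   attr [cp] coproc_scope = 1 { A[0] = 1; }
//   C[0] = A[0];
//
// is a conflict, so sync_ maps the co-processor statement to
// Evaluate(tir.cp.coproc_sync()), which CoProcSyncInserter later places
// right after that statement.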

// Detect where memory barriers are needed when the co-processor reads or writes memory.
class CoProcBarrierDetector : public StorageAccessVisitor {
 public:
  explicit CoProcBarrierDetector(const std::unordered_set<const VarNode*>& touched,
                                 const std::string& coproc_name)
      : touched_(touched) {
    read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier";
    write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier";
  }

  void PlanReadBarrier(const Stmt& stmt) {
    read_barrier_ = true;
    this->VisitStmt(stmt);
    PlanReadBarrier(scope_.back(), nullptr);
  }
  void PlanWriteBarrier(const Stmt& stmt) {
    read_barrier_ = false;
    this->VisitStmt(stmt);
    PlanWriteBarrier(scope_.back(), nullptr);
  }

  std::unordered_map<const Object*, std::vector<Stmt>> barrier_before_;
  std::unordered_map<const Object*, std::vector<Stmt>> barrier_after_;

 protected:
  bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
    return touched_.count(buf);
  }

  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
    if (read_barrier_) {
      return PlanReadBarrier(seq, loop);
    } else {
      return PlanWriteBarrier(seq, loop);
    }
  }

 private:
  // Plan write barriers at read-after-write points.
  std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
    std::vector<AccessEntry> read_seq;
    std::unordered_map<const VarNode*, std::vector<AccessEntry>> write_set;

    auto fupdate = [&](size_t i, const AccessEntry& acc) {
      auto it = write_set.find(acc.buffer.get());
      if (it != write_set.end()) {
        ICHECK_NE(i, 0U);
        barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second));
        write_set.erase(it);
      }
    };
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && acc.type == kRead) {
          fupdate(i, acc);
          read_seq.push_back(acc);
        }
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0 && acc.type == kWrite) {
          write_set[acc.buffer.get()].push_back(acc);
        }
      }
    }
    // loop carry
    if (loop != nullptr) {
      for (const AccessEntry& acc : read_seq) {
        fupdate(seq.size(), acc);
      }
    }
    for (const auto& kv : write_set) {
      read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end());
    }
    return read_seq;
  }
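
  // Illustrative sketch (hypothetical names, for exposition): a co-processor
  // write followed by a normal read of the same buffer
  //
  //   attr [cp] coproc_scope = 1 { A[0] = 1; }  // coproc write, recorded in write_set
  //   C[0] = A[0];                              // normal read triggers fupdate
  //
  // appends a coproc_write_barrier covering A's written range immediately
  // after the co-processor statement (barrier_after_[seq[i - 1].stmt]).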

  std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
    std::vector<AccessEntry> write_seq;
    std::unordered_map<const VarNode*, std::vector<AccessEntry>> read_set;

    auto fupdate = [&](size_t i, const AccessEntry& acc) {
      auto it = read_set.find(acc.buffer.get());
      if (it != read_set.end()) {
        ICHECK_NE(i, seq.size());
        barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second));
        read_set.erase(it);
      }
    };

    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() == 0 && acc.type == kWrite) {
          fupdate(i, acc);
          write_seq.push_back(acc);
        }
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0 && acc.type == kRead) {
          read_set[acc.buffer.get()].push_back(acc);
        }
      }
    }
    // loop carry
    if (loop != nullptr) {
      for (const AccessEntry& acc : write_seq) {
        fupdate(0, acc);
      }
    }
    for (const auto& kv : read_set) {
      write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end());
    }
    return write_seq;
  }
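
  // Illustrative sketch (the mirror case): a normal write followed later by a
  // co-processor read of the same buffer is found by the reverse scan above,
  // and a coproc_read_barrier is placed right after the writing statement
  // (before seq[i].stmt), so the host write is visible to the co-processor.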
328
329 Stmt MakeBarrier(const std::string& func, const std::vector<AccessEntry>& wvec) {
330 // insert write point
331 Array<arith::IntSet> wset;
332 for (const AccessEntry& acc : wvec) {
333 ICHECK(acc.dtype == wvec[0].dtype);
334 ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory";
335 wset.push_back(acc.touched[0]);
336 }
337 Range none;
338 Range r = arith::Union(wset).CoverRange(none);
339 ICHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
340 PrimExpr min = r->min;
341 PrimExpr extent = r->extent;
342 return Evaluate(Call(DataType::Int(32), Op::Get(func),
343 {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}));
344 }
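
  // For example, if wvec covers elements [0, 16) of a 32-bit buffer A, the
  // emitted statement is, schematically, cp.coproc_write_barrier(A, 32, 0, 16).
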
  // Whether we are currently planning read barriers (vs. write barriers).
  bool read_barrier_{false};
  // Barrier intrinsic names.
  std::string read_barrier_name_;
  std::string write_barrier_name_;
  const std::unordered_set<const VarNode*>& touched_;
};

class CoProcInstDepDetector : public StmtVisitor {
 public:
  explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name)
      : coproc_axis_(coproc_axis) {
    sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push");
    sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop");
  }

  void Plan(const Stmt& stmt) {
    this->VisitStmt(stmt);
    if (last_state_.node != nullptr) {
      MatchFixEnterPop(first_state_);
      MatchFixExitPush(last_state_);
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) {
      const IntImmNode* ctx_id = op->value.as<IntImmNode>();
      ICHECK(ctx_id != nullptr);
      curr_state_.clear();
      curr_state_.node = op->body.get();
      curr_state_.enter_ctx.insert(ctx_id->value);
      curr_state_.exit_ctx.insert(ctx_id->value);
      UpdateState();
    } else {
      StmtVisitor::VisitStmt_(op);
    }
  }

  void VisitStmt_(const ForNode* op) final {
    SyncState temp_first, temp_last;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    this->VisitStmt(op->body);
    curr_state_.clear();
    if (last_state_.node != nullptr) {
      curr_state_.node = op;
      ICHECK(first_state_.node != nullptr);
      // loop carry dependency
      InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop));
      curr_state_.enter_ctx = first_state_.enter_ctx;
      curr_state_.exit_ctx = last_state_.exit_ctx;
    }
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    SyncState temp_first, temp_last, curr_state;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    {
      // then stmt
      this->VisitStmt(op->then_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
      first_state_.clear();
      last_state_.clear();
    }
    if (op->else_case) {
      this->VisitStmt(op->else_case.value());
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end());
      }
    }
    // update in the trace.
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    std::swap(curr_state_, curr_state);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  void VisitStmt_(const WhileNode* op) final {
    // TODO(masahi): Do we need special handling for While nodes?
    LOG(FATAL) << "WhileNode not supported in CoProcSync.";
  }

  // insert_before_ is stored in reverse order:
  // the first element is closest to the node.
  std::unordered_map<const Object*, std::vector<Stmt>> insert_before_;
  std::unordered_map<const Object*, std::vector<Stmt>> insert_after_;

 private:
  // State of a synchronization entry.
  struct SyncState {
    // The statement of the state.
    const Object* node{nullptr};
    // Set of all possible contexts at the entering moment.
    std::unordered_set<int> enter_ctx;
    // Set of all possible contexts at the exit moment.
    std::unordered_set<int> exit_ctx;
    // Existing pops performed at enter.
    std::vector<std::pair<int, int>> enter_pop;
    // Existing pushes performed at exit.
    std::vector<std::pair<int, int>> exit_push;
    // clear the state
    void clear() {
      node = nullptr;
      enter_ctx.clear();
      exit_ctx.clear();
      enter_pop.clear();
      exit_push.clear();
    }
  };
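
  // For example, "attr [cp] coproc_scope = 2 { body }" produces a state with
  // node = body and enter_ctx = exit_ctx = {2}; enter_pop/exit_push stay
  // empty until InjectSync pairs the state with a neighbor.
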
  // Inject proper synchronization into the pair of states.
  // Record the push/pop sequences that could possibly be unmatched,
  // and return the push/pop messages at the enter/exit of the block
  // after considering the existing unmatched events and the added events.
  void InjectSync(const SyncState& prev, const SyncState& next,
                  std::vector<std::pair<int, int>>* prev_exit_push,
                  std::vector<std::pair<int, int>>* next_enter_pop) {
    prev_exit_push->clear();
    next_enter_pop->clear();
    // quick path
    if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 &&
        next.enter_ctx.size() == 1) {
      int from = *prev.exit_ctx.begin();
      int to = *next.enter_ctx.begin();
      if (from != to) {
        insert_after_[prev.node].emplace_back(MakePush(from, to));
        insert_before_[next.node].emplace_back(MakePop(from, to));
        prev_exit_push->emplace_back(std::make_pair(from, to));
        next_enter_pop->emplace_back(std::make_pair(from, to));
      }
      return;
    }
    // complicated path
    std::vector<std::pair<int, int>> vpush = prev.exit_push;
    std::vector<std::pair<int, int>> vpop = next.enter_pop;
    std::vector<std::pair<int, int>> pending;
    for (int from : prev.exit_ctx) {
      for (int to : next.enter_ctx) {
        if (from != to) {
          pending.emplace_back(std::make_pair(from, to));
        }
      }
    }
    // policy 1
    std::vector<Stmt> prev_after, next_before;
    for (const std::pair<int, int>& p : pending) {
      if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) {
        vpush.push_back(p);
        prev_after.emplace_back(MakePush(p.first, p.second));
      }
      if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) {
        vpop.push_back(p);
        next_before.emplace_back(MakePop(p.first, p.second));
      }
    }
    // fix pending
    for (const std::pair<int, int>& p : vpush) {
      if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) {
        prev_after.emplace_back(MakePop(p.first, p.second));
      } else {
        prev_exit_push->push_back(p);
      }
    }
    for (const std::pair<int, int>& p : vpop) {
      if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) {
        next_before.emplace_back(MakePush(p.first, p.second));
      } else {
        next_enter_pop->push_back(p);
      }
    }
    if (prev_after.size() != 0) {
      auto& v1 = insert_after_[prev.node];
      v1.insert(v1.end(), prev_after.begin(), prev_after.end());
    }
    if (next_before.size() != 0) {
      auto& v2 = insert_before_[next.node];
      v2.insert(v2.end(), next_before.begin(), next_before.end());
    }
  }
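
  // Illustrative sketch: transitioning from a statement in ctx 1 to a
  // statement in ctx 2 with no pre-existing pushes/pops takes the quick
  // path and yields
  //
  //   <ctx-1 stmt>
  //   coproc_dep_push(1, 2)   // inserted after the ctx-1 stmt
  //   coproc_dep_pop(1, 2)    // inserted before the ctx-2 stmt
  //   <ctx-2 stmt>
  //
  // so the ctx-2 queue waits until ctx 1 signals the dependency token.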

  void MatchFixEnterPop(const SyncState& state) {
    if (state.enter_pop.size() == 0) return;
    auto& vec = insert_before_[state.node];
    for (const std::pair<int, int>& p : state.enter_pop) {
      vec.push_back(MakePush(p.first, p.second));
    }
  }

  void MatchFixExitPush(const SyncState& state) {
    if (state.exit_push.size() == 0) return;
    auto& vec = insert_after_[state.node];
    for (const std::pair<int, int>& p : state.exit_push) {
      vec.push_back(MakePop(p.first, p.second));
    }
  }

  void UpdateState() {
    if (last_state_.node != nullptr) {
      std::vector<std::pair<int, int>> t1, t2;
      InjectSync(last_state_, curr_state_, &t1, &t2);
      std::swap(last_state_, curr_state_);
    } else {
      ICHECK(first_state_.node == nullptr);
      first_state_ = curr_state_;
      last_state_ = curr_state_;
    }
  }

  Stmt MakePush(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_push_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  Stmt MakePop(int from, int to) {
    return Evaluate(Call(DataType::Int(32), sync_pop_op_,
                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
  }
  // sync states.
  SyncState first_state_, last_state_, curr_state_;
  // Variables
  IterVar coproc_axis_;
  Op sync_push_op_, sync_pop_op_;
};

class CoProcSyncInserter : public StmtMutator {
 public:
  Stmt Insert(Stmt stmt) {
    CoProcTouchedBuffer visitor;
    visitor(stmt);
    if (visitor.coproc_.size() == 0) return stmt;
    std::unordered_set<const VarNode*> touched;

    for (const auto& kv : visitor.touched_) {
      if (kv.second.normal && kv.second.coproc) {
        touched.insert(kv.first);
      }
    }
    ICHECK_EQ(visitor.coproc_.size(), 1U);
    std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint;
    // Plan sync.
    CoProcSyncPlanner sync_planner(touched, coproc_name);
    sync_planner.Plan(stmt);
    for (const auto& kv : sync_planner.sync_) {
      auto& vec = insert_after_[kv.first];
      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
    }
    // Detect barriers.
    CoProcBarrierDetector barrier_detector(touched, coproc_name);
    barrier_detector.PlanReadBarrier(stmt);
    barrier_detector.PlanWriteBarrier(stmt);
    for (const auto& kv : barrier_detector.barrier_before_) {
      auto& vec = insert_before_[kv.first];
      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
    }
    for (const auto& kv : barrier_detector.barrier_after_) {
      auto& vec = insert_after_[kv.first];
      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
    }
    // Detect instruction dependencies.
    CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name);
    sync_detector.Plan(stmt);
    for (const auto& kv : sync_detector.insert_before_) {
      auto& vec = insert_before_[kv.first];
      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
    }
    for (const auto& kv : sync_detector.insert_after_) {
      auto& vec = insert_after_[kv.first];
      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
    }
    return operator()(std::move(stmt));
  }

  Stmt VisitStmt(const Stmt& stmt) final {
    auto it_before = insert_before_.find(stmt.get());
    auto it_after = insert_after_.find(stmt.get());
    Stmt new_stmt = StmtMutator::VisitStmt(stmt);

    return SeqStmt::Flatten(
        it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(), new_stmt,
        it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
  }

 private:
  // insert_before_ is stored in reverse order:
  // the first element is closest to the node.
  std::unordered_map<const Object*, std::vector<Stmt>> insert_before_;
  std::unordered_map<const Object*, std::vector<Stmt>> insert_after_;
};

Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); }

namespace transform {

Pass CoProcSync() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = CoProcSyncInserter().Insert(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
}

TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync);
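
// Usage sketch (assuming an IRModule `mod` whose PrimFuncs contain
// attr::coproc_scope annotations and whose target registers the
// corresponding coproc intrinsics):
//
//   IRModule synced = CoProcSync()(mod);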

}  // namespace transform

}  // namespace tir
}  // namespace tvm