#include <dispatch.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_insert_syncs.h>
#include <lower_utils.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! Scan through Kernel IR for-loops to insert Sync nodes to avoid
//! Write-After-Read (WAR) race conditions.
//!
//! Example:
//!   for () {
//!     smem_buf[threadIdx.x] = x;
//!     __syncthreads();
//!     buf[threadIdx.x] = smem_buf[threadIdx.x + 1];
//!   }
//!
//! In this case, an additional syncthreads is needed at the end of the
//! loop body to avoid a hazard with smem_buf.

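// Illustrative sketch (not generated verbatim by this pass): for the example
// above, the WAR sync inserter would end up emitting a trailing sync so the
// next iteration cannot overwrite smem_buf before it has been read:
//
//   for () {
//     smem_buf[threadIdx.x] = x;
//     __syncthreads();
//     buf[threadIdx.x] = smem_buf[threadIdx.x + 1];
//     __syncthreads();  // inserted WAR sync
//   }
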
//! Keeps track of the allocations of SMEM TVs
class SmemAllocMap {
 public:
  //! Insert a new node if it's a SMEM allocation
  void insert(kir::Allocate* alloc) {
    if (auto tv = dynamic_cast<TensorView*>(alloc->buffer())) {
      if (tv->getMemoryType() == MemoryType::Shared) {
        // Note that a TensorView can have two allocations due to
        // unswitch.
        auto p = map_.insert({tv, alloc});
        // If there's an existing entry, reset it with the new
        // alloc. Currently, the existing alloc is actually the same
        // as the new one, as each expression is simply inserted into
        // both the then and else parts of the unswitched loop, but
        // this should be changed.
        if (!p.second) {
          p.first->second = alloc;
        }
      }
    }
  }

  //! Run through aliases to get the buffer that is actually allocated for a
  //! given TV
  TensorView* getRealBuffer(TensorView* tv) const {
    auto it = map_.find(tv);
    TORCH_INTERNAL_ASSERT(
        it != map_.end(), "Allocation not found for ", tv->toString());
    const kir::Allocate* alloc = it->second;
    while (alloc->alias()) {
      alloc = alloc->alias();
    }
    auto buf = alloc->buffer();
    TORCH_INTERNAL_ASSERT(buf->isA<TensorView>());
    return buf->as<TensorView>();
  }

 private:
  std::unordered_map<TensorView*, kir::Allocate*> map_;
};
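
// Illustrative sketch (hypothetical TV names): if the allocation of T2 is
// aliased to the allocation of T1, e.g.
//
//   __shared__ float T1[128];
//   // alias: T2 reuses T1's buffer
//
// then getRealBuffer(T2) returns T1, so WAR analysis is performed on the
// buffer that is actually allocated rather than on the alias.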

struct WarMemoryInfo {
  // True if there's a sync after the last read within the alloc loop.
  bool sync_after_read = false;

  // True if there's a sync before the first write. There can be multiple
  // writes due to memory aliasing.
  bool sync_before_write = false;

  // Has there been a read of this memory location?
  bool read_hit = false;

  // Has there been *the* write to this memory location? Assumes a single write
  // instruction (this pass needs to run before conditionals are added to the
  // code).
  bool write_hit = false;

  // For loop this TV is compute_at'ed in.
  kir::ForLoop* ca_loop = nullptr;
};
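
// Illustrative sketch of how the flags above evolve for one buffer within its
// allocation loop (assuming a simple write -> sync -> read sequence and an
// enclosing loop that actually iterates):
//
//   T_smem = ...;      // write_hit = true; no sync seen yet, so
//                      // sync_before_write stays false
//   __syncthreads();   // RAW sync; read_hit is still false here
//   ... = T_smem;      // read_hit = true, sync_after_read = false
//
// Since neither sync_before_write nor sync_after_read is set when the
// allocation loop closes, a WAR sync is inserted at the end of the loop body.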

// To prevent shared memory from being overwritten before it is read, a
// synchronization point has to be inserted either between the allocation of an
// SMEM buffer and where we write into it, or after the buffer's last read
// before exiting the allocation's scope.
//
// e.g.
//  for i:
//    "alloc A" in shared memory - This is really marked by the compute_at point
//    sync_loc_0
//    for j:
//      sync_loc_1
//      for k:
//        sync_loc_2
//        A = ...
//      for k:
//        ... = ... A
//    for j:
//      for k:
//        ... = ... A
//        sync_loc_3
//      sync_loc_4
//    sync_loc_5
//
// All sync locations here provide valid protection that memory in A is
// finished being read before it is overwritten in the next iteration.
//
// Insertion of syncthreads will be done from the innermost position to the
// outermost. If a sync protecting the buffer is not already placed, the
// location preferred for the syncthreads is the last possible position. One
// future optimization could be to not sync on the last iteration of the loop
// the sync is placed in.
class WarSyncInserter : private kir::ExprMutator {
 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
    WarSyncInserter inserter(exprs);
    return inserter.exprs_;
  }

 private:
  //! Insert Sync nodes at the end of a given for-loop when a WAR
  //! hazard may happen.
  WarSyncInserter(const std::vector<Expr*>& exprs) {
    auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap();
    for (const auto& entry : lower_alloc_info_map) {
      alloc_map_.insert(entry.first);
    }
    kir::ExprMutator::traverseAndInsert(exprs);
  }

  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(
        ite->elseBody().empty(),
        "Pass does not support conditional flow,",
        " needs to be done before conditional execution is lowered.");
    kir::ExprMutator::handle(ite);
  }

  void handle(kir::BlockSync* sync) final {
    // Register the sync for the active for loop
    sync_hit_.back() = true;
    // Run through the active allocations; if a read was hit, register that
    // there was a sync after the read. If there are subsequent reads on this
    // buffer, sync_after_read will be cleared again.
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.back().read_hit) {
        alloc_stack.back().sync_after_read = true;
      }
    }
  }

  void handle(kir::GridSync* sync) final {
    // Register the sync for the active for loop
    sync_hit_.back() = true;
    // Run through the active allocations; if a read was hit, register that
    // there was a sync after the read. If there are subsequent reads on this
    // buffer, sync_after_read will be cleared again.
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.back().read_hit) {
        alloc_stack.back().sync_after_read = true;
      }
    }
  }

  // Checks if fl or any loop within it has hit a sync
  bool syncWithin(kir::ForLoop* fl) {
    // If outermost scope, check the first sync_hit_ position
    if (fl == nullptr) {
      return sync_hit_[0];
    }

    // Find the for loop we want to look within
    auto fl_it = std::find(for_loops_.begin(), for_loops_.end(), fl);

    // Convert it to an index, but add one for the outermost scope
    auto fl_i = std::distance(for_loops_.begin(), fl_it) + 1;

    // Start at that index and see if there are syncs within that for loop
    for (auto i : c10::irange(fl_i, sync_hit_.size())) {
      if (sync_hit_[i]) {
        return true;
      }
    }
    return false;
  }

  void handle(Expr* expr) final {
    // If not a tensor view expression continue with dispatch
    if (!ir_utils::isTvOp(expr)) {
      kir::ExprMutator::handle(expr);
      return;
    }

    // Mark that a write has been hit for all output tvs
    auto out_tvs = ir_utils::filterByType<TensorView>(expr->outputs());
    for (auto out_tv : out_tvs) {
      if (out_tv->getMemoryType() != MemoryType::Shared ||
          GpuLower::current()->syncMap().needsRawSync(out_tv).none()) {
        continue;
      }

      auto& entry = getMemInfo(out_tv);

      // If this is the first write and there's a sync in one of the loops
      // after the compute at loop, then this buffer is protected.
      if (syncWithin(entry.ca_loop) && !entry.write_hit) {
        entry.sync_before_write = true;
      }
      entry.write_hit = true;
    }

    // Mark that a read has been hit; if sync_after_read was set, clear it.
    auto inp_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
    for (auto inp_tv : inp_tvs) {
      if (inp_tv->getMemoryType() != MemoryType::Shared ||
          GpuLower::current()->syncMap().needsRawSync(inp_tv).none()) {
        continue;
      }

      auto& entry = getMemInfo(inp_tv);
      entry.read_hit = true;
      // Clear sync_after_read if it was set, since there is now a read after
      // that sync.
      entry.sync_after_read = false;
    }
  }

  void handle(kir::ForLoop* for_loop) final {
    // Push loop scope information
    auto prev_within_iter_loop_ = within_iter_loop_;
    sync_hit_.push_back(false);

    // If there is no real iterating loop, WAR syncs aren't necessary
    within_iter_loop_ = within_iter_loop_ || !for_loop->isTrivial();

    // Process the expressions in the for loop
    kir::ExprMutator::handle(for_loop);

    // Sync analysis and cleanup:
    //
    // Pop for loop stack inside WarMemoryInfo structs if they match this one.
    // Erase empty entries so we don't continue to search over them.
    //
    // Insert a sync at the end of this for loop if any of the entries require
    // it.
    std::vector<TensorView*> to_erase;
    bool insert_sync = false;
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.size() && alloc_stack.back().ca_loop == for_loop) {
        if (!alloc_stack.back().sync_after_read &&
            !alloc_stack.back().sync_before_write) {
          insert_sync = within_iter_loop_;
        }

        alloc_stack.pop_back();
        if (alloc_stack.empty()) {
          to_erase.push_back(entry.first);
        }
      }
    }

    for (auto tv : to_erase) {
      smem_allocations_.erase(tv);
    }

    // A WAR sync is necessary in this loop; register its insertion.
    if (insert_sync) {
      auto sync_expr = IrBuilder::create<kir::BlockSync>(true);
      kir::ExprMutator::registerInsertAfter(
          for_loop->body().exprs().back(), sync_expr, &for_loop->body());
      handle(sync_expr);
    }

    // Pop for loop scope information
    sync_hit_.pop_back();
    within_iter_loop_ = prev_within_iter_loop_;
  }

  // Create a new WarMemoryInfo entry if required and return a reference to it,
  // else return the WarMemoryInfo associated with tv
  WarMemoryInfo& getMemInfo(TensorView* tv) {
    auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv);
    auto alloc_it = smem_allocations_.find(maybe_aliased_tv);
    auto ca_loop =
        lower_utils::getAllocInformation(tv, for_loops_).init_for_loop;
    if (alloc_it == smem_allocations_.end()) {
      WarMemoryInfo mem_info;
      mem_info.ca_loop = ca_loop;
      auto entry_it =
          smem_allocations_
              .insert(std::make_pair(
                  maybe_aliased_tv, std::vector<WarMemoryInfo>({mem_info})))
              .first;
      return entry_it->second.back();
    } else if (
        maybe_aliased_tv != tv && alloc_it->second.back().ca_loop != ca_loop) {
      WarMemoryInfo mem_info;
      mem_info.ca_loop = ca_loop;
      auto& alloc_stack = alloc_it->second;
      alloc_stack.push_back(mem_info);
      return alloc_stack.back();
    }
    return alloc_it->second.back();
  }

  //! Allocation map of SMEM buffers. Needed because of SMEM buffer aliasing:
  //! we need to track the root of the alias to properly insert WAR hazard
  //! syncs.
  SmemAllocMap alloc_map_;

  //! Is there a loop nest that has a non-trivial iteration (extent != 1) and
  //! is not bound to a block/thread dimension? This indicates whether a WAR
  //! sync is necessary; otherwise the Expr is not in an iterating for loop.
  bool within_iter_loop_ = false;

  // Track which loops have hit a sync. Used to see if there's a sync before a
  // write.
  std::vector<bool> sync_hit_ = {false};

  // Keep track of the active allocations we need to protect. The key is the
  // "getRealBuffer" tv, not the raw tv. There can be multiple WarMemoryInfo's
  // because of aliasing. If the "getRealBuffer" tv has a compute at position
  // outside the alias tv, each aliased tv in a unique ca_loop has to be
  // tracked separately for WAR insertion.
  std::unordered_map<TensorView*, std::vector<WarMemoryInfo>> smem_allocations_;
};

class ValidatePlacementAfterWrites : private kir::IrVisitor {
 public:
  //! Validate that no expr in writes is found under loop
  static void validate(
      kir::ForLoop* loop,
      const std::unordered_set<Expr*>& writes) {
    ValidatePlacementAfterWrites validator(writes);
    validator.handle(loop);
  }

 private:
  using kir::IrVisitor::handle;

  ValidatePlacementAfterWrites(const std::unordered_set<Expr*>& writes)
      : writes_(writes) {}

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
    } else {
      TORCH_INTERNAL_ASSERT(
          writes_.find(expr) == writes_.end(),
          "Block sync must be placed after ",
          expr->toString());
    }
  }

 private:
  const std::unordered_set<Expr*>& writes_;
};

namespace {

Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
  // See the comment above for getGridCommWorkBufferSize.
  TORCH_INTERNAL_ASSERT(
      ptb.hasBID(),
      "Detected needing a grid sync but no grid bits set in bitmap.");
  Val* buffer_size = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeBIDs) {
    // Synchronized within pt, so all blocks of this PT use the same
    // sync buffer location, and thus no need to expand the sync
    // buffer size.
    if (ptb.get(pt)) {
      continue;
    }
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim);
  }
  return buffer_size;
}
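
// Illustrative example (assuming a typical launch): if a grid sync is needed
// across BIDx only (the BIDy/BIDz bits are not set in ptb), then every
// (BIDy, BIDz) "slice" of the grid needs its own semaphore slot, so the
// returned size is roughly gridDim.y * gridDim.z; parallel types that are
// unused or have extent 1 are skipped.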

} // namespace

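// Inserts RAW (Read-After-Write) syncs. Conceptually, when a shared- or
// global-memory tensor is written by some threads and later read by others, a
// sync is required between the write and the read. Illustrative sketch of the
// output of this pass for the shared-memory case (hypothetical tensors):
//
//   T_smem[threadIdx.x] = T_in[...];
//   __syncthreads();                      // inserted RAW block sync
//   T_out[...] = T_smem[threadIdx.x ^ 1];
//
// For grid-wide dependencies, a kir::GridSync plus the allocation of its
// semaphore buffer is inserted instead.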
class ReadAfterWriteSyncs : public kir::ExprMutator {
 private:
  using kir::ExprMutator::handle;

  //! Traverse up the loop stack from loops_it, and if a halo loop is
  //! found, place a given sync expr before the outermost halo loop.
  // TODO: What needs to be done here for gmem comm?
  bool insertBeforeHaloLoop(
      std::vector<kir::ForLoop*>::iterator loops_it,
      Expr* sync_expr,
      Expr* maybe_alloc,
      const std::unordered_set<Expr*>& writes) {
    std::vector<kir::ForLoop*>::iterator halo_loop_it;
    bool halo_loop_found = false;

    while (true) {
      if ((*loops_it)->iter_domain()->isThreadDim() &&
          (*loops_it)->iter_domain()->extent() != (*loops_it)->stop()) {
        halo_loop_found = true;
        halo_loop_it = loops_it;
      }

      if (loops_it == for_loops_.begin()) {
        break;
      }
      --loops_it;
    }

    // No halo loop found. Do not place the sync expr here. Return
    // false to indicate nothing is done.
    if (!halo_loop_found) {
      return false;
    }

    auto halo_loop = *halo_loop_it;

    // Make sure there's no write to the smem buffer inside the halo
    // loop. syncthreads is moved before the halo loop, so having
    // writes inside the loop would invalidate the consistency.
    ValidatePlacementAfterWrites::validate(halo_loop, writes);

    if (halo_loop_it == for_loops_.begin()) {
      // Place in global scope
      auto place_before_it = std::find(exprs_.begin(), exprs_.end(), halo_loop);
      TORCH_INTERNAL_ASSERT(place_before_it != exprs_.end());
      exprs_.insert(place_before_it, sync_expr);
    } else {
      auto place_in = *(halo_loop_it - 1);
      kir::ExprMutator::registerInsertBefore(
          halo_loop, sync_expr, &place_in->body());
      if (maybe_alloc != nullptr) {
        kir::ExprMutator::registerInsertBefore(
            halo_loop, maybe_alloc, &place_in->body());
      }
    }

    return true;
  }
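
  // Illustrative sketch of the halo case above (hypothetical loop shapes): if
  // the loop stack contains a thread-parallel loop whose stop differs from its
  // extent, e.g.
  //
  //   for i:                       // serial loop
  //     for tidx in [0, bdimx+1):  // halo-extended, extent != stop
  //       ... = T_smem[...];
  //
  // then the sync (and any grid-sync buffer allocation) is registered before
  // the outermost such halo loop rather than inside it, after validating that
  // none of the protected writes live inside that loop.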

  void handle(Expr* expr) final {
    if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
      kir::ExprMutator::handle(expr);
      return;
    }

    // An identical but separate flow of timing for cpasync_wait. The insertion
    // and tracking mechanism is the same as for RAW sync insertion, since
    // cp.async only writes smem. Currently the only interaction realized by
    // the ordering in this function is that, when we need both a cpasync wait
    // and a block sync before the same expr, we place the wait before the
    // block sync; there shouldn't currently be any normal case where we
    // explicitly want the wait after a block sync. (See the sketch after this
    // function.)
    if (cpasync_wait_before_.size() > 0 &&
        cpasync_wait_before_.front() == expr) {
      cpasync_wait_before_.pop_front();
      auto last_writes = last_cpasync_writes_.front();
      last_cpasync_writes_.pop_front();

      auto sync_expr = IrBuilder::create<kir::CpAsyncWait>();
      insertSyncExpr(last_writes, expr, sync_expr, nullptr);
    }

    if (sync_before_.size() > 0 && sync_before_.front().first == expr) {
      auto sync_bitmap = sync_before_.front().second;
      sync_before_.pop_front();
      auto last_writes = last_writes_.front();
      last_writes_.pop_front();
      // Found that a sync is needed

      // TODO: Explicitly test the 3 cases below
      Expr* sync_expr = nullptr;
      kir::Allocate* maybe_alloc = nullptr;
      if (sync_bitmap.hasBID()) {
        maybe_alloc = lower_utils::allocGlobalBufferForGridComm(
            getGridSyncBufferSize(sync_bitmap), DataType::Int, true);
        sync_expr = IrBuilder::create<kir::GridSync>(
            sync_bitmap, maybe_alloc->buffer());
      } else {
        sync_expr = IrBuilder::create<kir::BlockSync>(false); // not a WAR sync
      }

      insertSyncExpr(last_writes, expr, sync_expr, maybe_alloc);
    }
  }
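
  // Illustrative sketch of the resulting ordering when both a cp.async wait
  // and a RAW block sync are needed before the same consumer expression
  // (hypothetical tensors; the wait is shown in cp.async pseudo-syntax):
  //
  //   cp.async T_smem <- T_gmem;   // asynchronous smem write
  //   cp.async::wait;              // kir::CpAsyncWait, placed first
  //   __syncthreads();             // kir::BlockSync (RAW)
  //   ... = T_smem[...];           // consumer expression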

  // Find where a sync needs to be inserted and insert the given sync. This is
  // very similar to how allocations are placed: simply place the sync before
  // the expression at the common alloc point of the producers (really
  // last_writes, because we may be syncing on other exprs besides the
  // producers of this one). See the placement sketch after this function.
  void insertSyncExpr(
      const std::unordered_set<Expr*>& last_writes,
      Expr* insert_before_expr,
      Expr* sync_expr,
      Expr* maybe_alloc) {
    // The expressions in last_writes are those we're protecting the read
    // from. To figure out which loop we need a syncthreads in, take the
    // innermost compute at for loop of all the outputs of the last writes.
    std::unordered_set<kir::ForLoop*> sync_within;

    for (auto last_write : last_writes) {
      auto write_out_tv = ir_utils::getTvOutput(last_write);
      TORCH_INTERNAL_ASSERT(
          write_out_tv != nullptr,
          "Error in RAW sync insertion, expecting a TV expr, but didn't find one.");
      if (write_out_tv->getComputeAtPosition() == 0) {
        continue;
      }

      auto local_id =
          write_out_tv->axis((int)write_out_tv->getComputeAtPosition() - 1);

      auto loops_it = std::find_if(
          for_loops_.begin(), for_loops_.end(), [&local_id](const auto& loop) {
            return GpuLower::current()->caMap()->areMapped(
                loop->iter_domain(), local_id, IdMappingMode::PERMISSIVE);
          });

      TORCH_INTERNAL_ASSERT(
          loops_it != for_loops_.end(),
          "Could not find loop associated with the alloc position of ",
          write_out_tv->toString());

      sync_within.emplace(*loops_it);
    }

    // The for loop the sync needs to be in
    kir::ForLoop* sync_within_fl = nullptr;
    for (auto fl : for_loops_) {
      if (sync_within.count(fl)) {
        sync_within_fl = fl;
      }
    }

    if (sync_within_fl == nullptr) {
      // Sync should be placed at global scope, after its outermost loop if
      // it has one.
      Expr* place_before =
          for_loops_.size() > 0 ? for_loops_[0] : insert_before_expr;
      // Find location in exprs_
      auto place_before_it =
          std::find(exprs_.begin(), exprs_.end(), place_before);
      TORCH_INTERNAL_ASSERT(
          place_before_it != exprs_.end(),
          "Could not figure out where to place synchronization. ",
          "Tried to place after, ",
          place_before->toString(),
          ", but could not find this expression at the global scope.");
      if (maybe_alloc != nullptr) {
        registerInsertBefore(place_before, maybe_alloc, nullptr);
      }
      registerInsertBefore(*(place_before_it), sync_expr, nullptr);
    } else {
      auto sync_within_loop_it =
          std::find(for_loops_.begin(), for_loops_.end(), sync_within_fl);

      // A block sync must be placed before halo-extended loops
      if (insertBeforeHaloLoop(
              sync_within_loop_it, sync_expr, maybe_alloc, last_writes)) {
        return;
      }

      auto place_in = *sync_within_loop_it;
      Expr* place_before = nullptr;

      if (sync_within_loop_it + 1 == for_loops_.end()) {
        // Inline, place before the expression
        place_before = insert_before_expr;
      } else {
        place_before = *(sync_within_loop_it + 1);
      }

      registerInsertBefore(place_before, sync_expr, &place_in->body());
      if (maybe_alloc != nullptr) {
        registerInsertBefore(place_before, maybe_alloc, &place_in->body());
      }
    }
  }
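
  // Illustrative placement sketch (hypothetical tensors/loops): if the last
  // write's output has its compute-at position at the j-loop and the consumer
  // sits in a sibling k-loop, the sync is registered inside the j-loop, right
  // before the loop containing the consumer:
  //
  //   for i:
  //     for j:               // innermost compute-at loop of the last writes
  //       for k:
  //         T_smem[...] = ...;
  //       __syncthreads();   // inserted before the consumer's loop
  //       for k:
  //         ... = T_smem[...];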

  void handle(kir::IfThenElse*) final {
    TORCH_INTERNAL_ASSERT(
        false,
        "Pass does not support conditional statements, ",
        "this pass should be run before any conditionals are placed in code.");
  }

  // Return a set of expressions that modify shared-memory
  // tensors. Expressions are excluded when syncthreads are already
  // placed.
  std::unordered_set<Expr*> isModifiedSharedMemory(
      const std::unordered_map<Val*, Expr*>& smem,
      const std::vector<Val*>& tvs,
      bool check_sync_map = true) const {
    std::unordered_set<Expr*> last_writes;
    for (auto tv : ir_utils::filterByType<TensorView>(tvs)) {
      if (check_sync_map &&
          GpuLower::current()->syncMap().needsRawSync(tv).none()) {
        continue;
      }
      if (tv->getMemoryType() != MemoryType::Shared) {
        continue;
      }
      auto it = smem.find(tv);
      if (it != smem.end()) {
        last_writes.insert(it->second);
      }
    }
    return last_writes;
  }

  std::unordered_set<Expr*> isModifiedGlobalMemory(
      const std::unordered_map<Val*, Expr*>& gmem,
      const std::vector<Val*>& tvs) const {
    std::unordered_set<Expr*> last_writes;
    for (auto tv : ir_utils::filterByType<TensorView>(tvs)) {
      if (GpuLower::current()->syncMap().needsRawSync(tv).none()) {
        continue;
      }
      auto it = gmem.find(tv);
      if (it != gmem.end()) {
        last_writes.insert(it->second);
      }
    }
    return last_writes;
  }

  ReadAfterWriteSyncs(const std::vector<Expr*>& _exprs) {
    // Fusion shared_memory values
    // Tracks if shared memory is modified
    std::unordered_map<Val*, Expr*> smem;
    // Tracks if shared memory is asynchronously modified
    std::unordered_map<Val*, Expr*> smem_async;
    std::unordered_map<Val*, Expr*> gmem;

    // Flatten all the expressions
    auto flattened_exprs = ir_utils::flattenScopedExprs(_exprs);

    Expr* prev_tv_expr = nullptr;
    for (auto expr : flattened_exprs) {
      if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
        continue;
      }

      auto last_gmem_writes = isModifiedGlobalMemory(gmem, expr->inputs());
      if (!last_gmem_writes.empty()) {
        TORCH_INTERNAL_ASSERT(
            prev_tv_expr != nullptr,
            "Can't require sync on inputs, however, detected it's needed.");
        ParallelTypeBitmap bitmap;
        for (auto entry : gmem) {
          TORCH_INTERNAL_ASSERT(entry.first->isA<TensorView>());
          auto sync_bits = GpuLower::current()->syncMap().needsRawSync(
              entry.first->as<TensorView>());
          bitmap |= sync_bits;
        }

        sync_before_.emplace_back(std::make_pair(expr, bitmap));
        last_writes_.push_back(last_gmem_writes);
        gmem.clear();
      }

      auto last_smem_writes = isModifiedSharedMemory(smem, expr->inputs());
      auto last_async_smem_writes =
          isModifiedSharedMemory(smem_async, expr->inputs(), false);

      // Keep track of async smem writes before the current
      // expr, following largely the same logic as block sync.
      if (!last_async_smem_writes.empty()) {
        cpasync_wait_before_.push_back(expr);
        std::unordered_set<Expr*> async_smem_writes;
        for (auto it : smem_async) {
          async_smem_writes.insert(it.second);
        }
        last_cpasync_writes_.push_back(async_smem_writes);
        smem_async.clear();
      }

      if (!last_smem_writes.empty()) {
        TORCH_INTERNAL_ASSERT(
            prev_tv_expr != nullptr,
            "Can't require sync on inputs, however, detected it's needed.");
        ParallelTypeBitmap bitmap;
        bitmap.set(ParallelType::TIDx);
        bitmap.set(ParallelType::TIDy);
        bitmap.set(ParallelType::TIDz);
        sync_before_.emplace_back(std::make_pair(expr, bitmap));

        // Before clearing `smem`, put all the currently pending smem writes
        // in last_writes_. This will make sure all the smem writes will
        // be taken into consideration when deciding which loopnest level
        // to insert the block sync at. See FusionRAWSyncInsertionPlace4.
        std::unordered_set<Expr*> smem_writes;
        for (auto it : smem) {
          // No need to keep track of shared mem writes that do not
          // require a RAW block sync.
          if (GpuLower::current()
                  ->syncMap()
                  .needsRawSync(it.first->as<TensorView>())
                  .hasTID()) {
            smem_writes.insert(it.second);
          }
        }
        last_writes_.push_back(smem_writes);
        smem.clear();
      }

      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
        // Double buffered tensors do not need a RAW sync to be inserted
        // here, except for the initial load part, which is taken care of
        // separately by DoubleBufferInserter.
        if (tv->getMemoryType() == MemoryType::Shared &&
            !(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
          smem[tv] = expr;

          // Only keep track of async writes in smem_async
          if (ir_utils::isCpAsyncOp(expr)) {
            smem_async[tv] = expr;
          }
        }
        if (tv->getMemoryType() == MemoryType::Global) {
          gmem[tv] = expr;
        }
      }

      prev_tv_expr = expr;
    }

    kir::ExprMutator::traverseAndInsert(_exprs);

    TORCH_INTERNAL_ASSERT(
        sync_before_.empty(), "Didn't place all required syncs.");
  }

 private:
  //! Keep track of expressions that must be followed by syncthreads
  std::deque<std::pair<Expr*, ParallelTypeBitmap>> sync_before_;

  //! Keep track of write expressions that must be placed before
  //! syncthreads.
  //!
  //! A syncthreads is placed before each expression in sync_before_. However,
  //! if the expression is inside a loop with halo, the syncthreads must be
  //! placed before that loop. last_writes_ keeps track of the expressions
  //! modifying the smem buffer each syncthreads is used for, so that the
  //! syncthreads is not placed before those write expressions.
  std::deque<std::unordered_set<Expr*>> last_writes_;

  //! Keep track of expressions that must wait for cp.async to finish.
  std::deque<Expr*> cpasync_wait_before_;

  //! Keep track of write expressions that must be placed before
  //! cp.async wait.
  std::deque<std::unordered_set<Expr*>> last_cpasync_writes_;

 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& loop_nests) {
    ReadAfterWriteSyncs inserter(loop_nests);
    return inserter.exprs_;
  }
};

} // namespace

std::vector<Expr*> insertRawThreadSynchronization(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization");
  return ReadAfterWriteSyncs::insert(exprs);
}

std::vector<Expr*> insertWarThreadSynchronization(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization");
  return WarSyncInserter::insert(exprs);
}
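
// Illustrative usage sketch (hedged; the actual call sites and pass ordering
// live in the lowering pipeline, e.g. lower2device.cpp):
//
//   std::vector<Expr*> exprs = ...;  // lowered loop-nest expressions
//   exprs = insertRawThreadSynchronization(exprs);
//   // ... other lowering passes ...
//   exprs = insertWarThreadSynchronization(exprs);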

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch