#include <dispatch.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_insert_syncs.h>
#include <lower_utils.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! Scan through Kernel IR for-loops to insert Sync nodes to avoid
//! Write-After-Read (WAR) race conditions.
//!
//! Example:
//!   for () {
//!     smem_buf[threadIdx.x] = x;
//!     __syncthreads();
//!     buf[threadIdx.x] = smem_buf[threadIdx.x + 1];
//!   }
//!
//! In this case, an additional __syncthreads() is needed at the end of the
//! loop body to avoid a WAR hazard on smem_buf.
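//!
//! With the WAR sync inserted by this pass, the loop becomes:
//!   for () {
//!     smem_buf[threadIdx.x] = x;
//!     __syncthreads();
//!     buf[threadIdx.x] = smem_buf[threadIdx.x + 1];
//!     __syncthreads(); // WAR sync inserted at the end of the loop body
//!   }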

//! Keeps track of the allocations of SMEM TVs
class SmemAllocMap {
 public:
  //! Insert a new node if it's an SMEM allocation
  void insert(kir::Allocate* alloc) {
    if (auto tv = dynamic_cast<TensorView*>(alloc->buffer())) {
      if (tv->getMemoryType() == MemoryType::Shared) {
        // Note that a TensorView can have two allocations due to
        // unswitch.
        auto p = map_.insert({tv, alloc});
        // If there's an existing entry, reset it with the new
        // alloc. Currently, the existing alloc is actually the same
        // as the new one, as each expression is simply inserted into
        // both the then and else parts of the unswitched loop, but
        // this should be changed.
        if (!p.second) {
          p.first->second = alloc;
        }
      }
    }
  }

  //! Run through aliases to get the buffer that is actually allocated for a
  //! given TV
  TensorView* getRealBuffer(TensorView* tv) const {
    auto it = map_.find(tv);
    TORCH_INTERNAL_ASSERT(
        it != map_.end(), "Allocation not found for ", tv->toString());
    const kir::Allocate* alloc = it->second;
    while (alloc->alias()) {
      alloc = alloc->alias();
    }
    auto buf = alloc->buffer();
    TORCH_INTERNAL_ASSERT(buf->isA<TensorView>());
    return buf->as<TensorView>();
  }

 private:
  std::unordered_map<TensorView*, kir::Allocate*> map_;
};

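// Per-buffer WAR bookkeeping used by WarSyncInserter below. In the example
// at the top of this file, the write to smem_buf sets write_hit and the read
// sets read_hit; since no sync follows the read within the allocating loop,
// neither sync_after_read nor sync_before_write ends up set, so a WAR sync
// is inserted when the allocating loop's scope is closed.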
struct WarMemoryInfo {
  // True if there's a sync after the last read within the alloc loop.
  bool sync_after_read = false;

  // True if there's a sync before the first write. There can be multiple
  // writes due to memory aliasing.
  bool sync_before_write = false;

  // Has there been a read of this memory location?
  bool read_hit = false;

  // Has there been *the* write to this memory location? Assumes a single
  // write instruction (this pass needs to run before conditionals are added
  // to the code).
  bool write_hit = false;

  // For loop this TV is compute_at'ed in.
  kir::ForLoop* ca_loop = nullptr;
};

// To prevent shared memory from being overwritten before it is read, a
// synchronization point has to be inserted either between the allocation of
// an SMEM buffer and where we write into it, or after the buffer's last read
// before exiting the allocation's scope.
//
// e.g.
// for i:
//   "alloc A" in shared memory - This is really marked by the compute_at point
//   sync_loc_0
//   for j:
//     sync_loc_1
//     for k:
//       sync_loc_2
//       A = ...
//     for k:
//       ... = ... A
//   for j:
//     for k:
//       ... = ... A
//       sync_loc_3
//     sync_loc_4
//   sync_loc_5
//
// All sync locations here provide valid protection, ensuring the memory in A
// is finished being read before it is overwritten in the next iteration.
//
// Insertion of sync threads will be done from the innermost position to the
// outermost. If a sync protecting the buffer is not already placed, the
// location preferred for the sync threads is the last possible position
// (sync_loc_5 in the example above). One future optimization could be to not
// sync on the last iteration of the loop the sync is placed in.
class WarSyncInserter : private kir::ExprMutator {
 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
    WarSyncInserter inserter(exprs);
    return inserter.exprs_;
  }

 private:
  //! Insert Sync nodes at the end of a given for-loop when a WAR
  //! hazard may happen.
  WarSyncInserter(const std::vector<Expr*>& exprs) {
    auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap();
    for (const auto& entry : lower_alloc_info_map) {
      alloc_map_.insert(entry.first);
    }
    kir::ExprMutator::traverseAndInsert(exprs);
  }

  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(
        ite->elseBody().empty(),
        "Pass does not support conditional flow,",
        " needs to be done before conditional execution is lowered.");
    kir::ExprMutator::handle(ite);
  }

  void handle(kir::BlockSync* sync) final {
    // Register the sync for the active for loop
    sync_hit_.back() = true;
    // Run through the active allocations; if a read was hit, register that
    // there was a sync after the read. If there are subsequent reads on this
    // buffer, sync_after_read will be cleared.
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.back().read_hit) {
        alloc_stack.back().sync_after_read = true;
      }
    }
  }

  void handle(kir::GridSync* sync) final {
    // Register the sync for the active for loop
    sync_hit_.back() = true;
    // Run through the active allocations; if a read was hit, register that
    // there was a sync after the read. If there are subsequent reads on this
    // buffer, sync_after_read will be cleared.
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.back().read_hit) {
        alloc_stack.back().sync_after_read = true;
      }
    }
  }

  // Checks if fl or loops within it have hit a sync
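  //
  // Indexing note: sync_hit_[0] tracks the scope outside all for loops, and
  // for_loops_[i] corresponds to sync_hit_[i + 1]. For example, with
  // for_loops_ = {fl0, fl1} and a sync recorded only inside fl1,
  // syncWithin(fl0) and syncWithin(fl1) both return true, while
  // syncWithin(nullptr) returns false.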
  bool syncWithin(kir::ForLoop* fl) {
    // If fl is the outermost scope, check the first sync_hit_ position
    if (fl == nullptr) {
      return sync_hit_[0];
    }

    // Find the for loop we want to look within
    auto fl_it = std::find(for_loops_.begin(), for_loops_.end(), fl);

    // Convert it to an index, but add one for the outermost scope
    auto fl_i = std::distance(for_loops_.begin(), fl_it) + 1;

    // Start at that index and see if there are syncs within that for loop
    for (auto i : c10::irange(fl_i, sync_hit_.size())) {
      if (sync_hit_[i]) {
        return true;
      }
    }
    return false;
  }

  void handle(Expr* expr) final {
    // If not a tensor view expression, continue with dispatch
    if (!ir_utils::isTvOp(expr)) {
      kir::ExprMutator::handle(expr);
      return;
    }

    // Mark that a write has been hit for all output TVs
    auto out_tvs = ir_utils::filterByType<TensorView>(expr->outputs());
    for (auto out_tv : out_tvs) {
      if (out_tv->getMemoryType() != MemoryType::Shared ||
          GpuLower::current()->syncMap().needsRawSync(out_tv).none()) {
        continue;
      }

      auto& entry = getMemInfo(out_tv);

      // If this is the first write and there's a sync in one of the loops
      // within the compute at loop, then this buffer is protected.
      if (syncWithin(entry.ca_loop) && !entry.write_hit) {
        entry.sync_before_write = true;
      }
      entry.write_hit = true;
    }

    // Mark that a read was hit; if sync_after_read was set, clear it.
    auto inp_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
    for (auto inp_tv : inp_tvs) {
      if (inp_tv->getMemoryType() != MemoryType::Shared ||
          GpuLower::current()->syncMap().needsRawSync(inp_tv).none()) {
        continue;
      }

      auto& entry = getMemInfo(inp_tv);
      entry.read_hit = true;
      // Clear sync_after_read if it was set, since this new read is not
      // protected by the earlier sync.
      entry.sync_after_read = false;
    }
  }

  void handle(kir::ForLoop* for_loop) final {
    // Push loop scope information
    auto prev_within_iter_loop_ = within_iter_loop_;
    sync_hit_.push_back(false);

    // If there is no real iterating loop, WAR syncs aren't necessary
    within_iter_loop_ = within_iter_loop_ || !for_loop->isTrivial();

    // Process the expressions in the for loop
    kir::ExprMutator::handle(for_loop);

    // Sync analysis and cleanup:
    //
    // Pop the for loop stack inside WarMemoryInfo structs if they match this
    // loop. Erase empty entries so we don't continue to search over them.
    //
    // Insert a sync at the end of this for loop if any of the entries
    // require it.
    std::vector<TensorView*> to_erase;
    bool insert_sync = false;
    for (auto& entry : smem_allocations_) {
      auto& alloc_stack = entry.second;
      if (alloc_stack.size() && alloc_stack.back().ca_loop == for_loop) {
        if (!alloc_stack.back().sync_after_read &&
            !alloc_stack.back().sync_before_write) {
          insert_sync = within_iter_loop_;
        }

        alloc_stack.pop_back();
        if (alloc_stack.empty()) {
          to_erase.push_back(entry.first);
        }
      }
    }

    for (auto tv : to_erase) {
      smem_allocations_.erase(tv);
    }

    // A WAR sync is necessary in this loop, register its insertion.
    if (insert_sync) {
      auto sync_expr = IrBuilder::create<kir::BlockSync>(true);
      kir::ExprMutator::registerInsertAfter(
          for_loop->body().exprs().back(), sync_expr, &for_loop->body());
      handle(sync_expr);
    }

    // Pop for loop scope information
    sync_hit_.pop_back();
    within_iter_loop_ = prev_within_iter_loop_;
  }

  // Create a new WarMemoryInfo entry if required and return a reference to it,
  // else return the WarMemoryInfo associated with tv
  WarMemoryInfo& getMemInfo(TensorView* tv) {
    auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv);
    auto alloc_it = smem_allocations_.find(maybe_aliased_tv);
    auto ca_loop =
        lower_utils::getAllocInformation(tv, for_loops_).init_for_loop;
    if (alloc_it == smem_allocations_.end()) {
      WarMemoryInfo mem_info;
      mem_info.ca_loop = ca_loop;
      auto entry_it =
          smem_allocations_
              .insert(std::make_pair(
                  maybe_aliased_tv, std::vector<WarMemoryInfo>({mem_info})))
              .first;
      return entry_it->second.back();
    } else if (
        maybe_aliased_tv != tv && alloc_it->second.back().ca_loop != ca_loop) {
      WarMemoryInfo mem_info;
      mem_info.ca_loop = ca_loop;
      auto& alloc_stack = alloc_it->second;
      alloc_stack.push_back(mem_info);
      return alloc_stack.back();
    }
    return alloc_it->second.back();
  }

  //! Allocation map of SMEM buffers. Needed because of SMEM buffer aliasing:
  //! we have to track the root of the alias to properly insert WAR hazard
  //! syncs.
  SmemAllocMap alloc_map_;

  //! Is there a loop nest that has a non-trivial iteration (extent != 1) and
  //! is not bound to a block/thread? This indicates whether a WAR sync is
  //! necessary; otherwise the Expr is not in an iterating for loop.
  bool within_iter_loop_ = false;

  // Track which loops have hit a sync. Used to see if there's a sync before
  // a write.
  std::vector<bool> sync_hit_ = {false};

  // Keep track of the active allocations we need to protect. The key is the
  // "getRealBuffer" TV, not the raw TV. There can be multiple WarMemoryInfo
  // entries because of aliasing. If the "getRealBuffer" TV has a compute at
  // outside the alias TV, each aliased TV with a unique ca_loop has to be
  // tracked separately for WAR insertion.
  std::unordered_map<TensorView*, std::vector<WarMemoryInfo>> smem_allocations_;
};

class ValidatePlacementAfterWrites : private kir::IrVisitor {
 public:
  //! Validate that no expr in writes is found under loop
  static void validate(
      kir::ForLoop* loop,
      const std::unordered_set<Expr*>& writes) {
    ValidatePlacementAfterWrites validator(writes);
    validator.handle(loop);
  }

 private:
  using kir::IrVisitor::handle;

  ValidatePlacementAfterWrites(const std::unordered_set<Expr*>& writes)
      : writes_(writes) {}

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
    } else {
      TORCH_INTERNAL_ASSERT(
          writes_.find(expr) == writes_.end(),
          "Block sync must be placed after ",
          expr->toString());
    }
  }

 private:
  const std::unordered_set<Expr*>& writes_;
};

namespace {

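// Compute the number of grid-sync semaphore slots needed. The buffer size is
// the product of the extents of the block-parallel dimensions that do NOT
// participate in the sync, since each independent block index along those
// dimensions needs its own slot. For example (hypothetical dims), with
// gridDim = {BIDx: 8, BIDy: 4, BIDz: 1} and a sync over BIDx only, the
// returned size is 4.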
Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
  // See the comment above for getGridCommWorkBufferSize.
  TORCH_INTERNAL_ASSERT(
      ptb.hasBID(),
      "Detected needing a grid sync but no grid bits set in bitmap.");
  Val* buffer_size = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeBIDs) {
    // Synchronized within pt, so all blocks of this PT use the same
    // sync buffer location, and thus no need to expand the sync
    // buffer size.
    if (ptb.get(pt)) {
      continue;
    }
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim);
  }
  return buffer_size;
}

} // namespace

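// Insert syncthreads / grid syncs to protect read-after-write (RAW) hazards.
// The constructor first scans the flattened expression list, recording for
// each expression whether a block or grid sync (sync_before_) or a cp.async
// wait (cpasync_wait_before_) must be placed before it, along with the last
// writes being protected. A second traversal (traverseAndInsert) then places
// each sync at the appropriate loop-nest level via insertSyncExpr.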
class ReadAfterWriteSyncs : public kir::ExprMutator {
 private:
  using kir::ExprMutator::handle;

  //! Traverse up the loop stack from loops_it and, if a halo loop is
  //! found, place a given sync expr before the outermost halo loop.
  // TODO: What needs to be done here for gmem comm?
  bool insertBeforeHaloLoop(
      std::vector<kir::ForLoop*>::iterator loops_it,
      Expr* sync_expr,
      Expr* maybe_alloc,
      const std::unordered_set<Expr*>& writes) {
    std::vector<kir::ForLoop*>::iterator halo_loop_it;
    bool halo_loop_found = false;

    while (true) {
      if ((*loops_it)->iter_domain()->isThreadDim() &&
          (*loops_it)->iter_domain()->extent() != (*loops_it)->stop()) {
        halo_loop_found = true;
        halo_loop_it = loops_it;
      }

      if (loops_it == for_loops_.begin()) {
        break;
      }
      --loops_it;
    }

    // No halo loop found. Do not place the sync expr here. Return
    // false to indicate nothing is done.
    if (!halo_loop_found) {
      return false;
    }

    auto halo_loop = *halo_loop_it;

    // Make sure there's no write to the smem buffer inside the halo
    // loop. syncthreads is moved before the halo loop, so having
    // writes inside the loop would invalidate the consistency.
    ValidatePlacementAfterWrites::validate(halo_loop, writes);

    if (halo_loop_it == for_loops_.begin()) {
      // place in global scope
      auto place_before_it = std::find(exprs_.begin(), exprs_.end(), halo_loop);
      TORCH_INTERNAL_ASSERT(place_before_it != exprs_.end());
      exprs_.insert(place_before_it, sync_expr);
    } else {
      auto place_in = *(halo_loop_it - 1);
      kir::ExprMutator::registerInsertBefore(
          halo_loop, sync_expr, &place_in->body());
      if (maybe_alloc != nullptr) {
        kir::ExprMutator::registerInsertBefore(
            halo_loop, maybe_alloc, &place_in->body());
      }
    }

    return true;
  }

  void handle(Expr* expr) final {
    if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
      kir::ExprMutator::handle(expr);
      return;
    }

    // An identical but separate flow for cp.async wait insertion.
    // The insertion and tracking mechanism is the same as for RAW
    // sync insertion, since cp.async only writes smem.
    // Currently the only interaction realized by the ordering in this
    // function is that, when we need both a cp.async wait and a block sync
    // before the same expr, we place the wait before the block sync, since
    // currently there shouldn't be any normal case where we explicitly want
    // the wait after a block sync.
    if (cpasync_wait_before_.size() > 0 &&
        cpasync_wait_before_.front() == expr) {
      cpasync_wait_before_.pop_front();
      auto last_writes = last_cpasync_writes_.front();
      last_cpasync_writes_.pop_front();

      auto sync_expr = IrBuilder::create<kir::CpAsyncWait>();
      insertSyncExpr(last_writes, expr, sync_expr, nullptr);
    }

    if (sync_before_.size() > 0 && sync_before_.front().first == expr) {
      auto sync_bitmap = sync_before_.front().second;
      sync_before_.pop_front();
      auto last_writes = last_writes_.front();
      last_writes_.pop_front();
      // Found that a sync is needed

      // TODO: Explicitly test the 3 cases below
      Expr* sync_expr = nullptr;
      kir::Allocate* maybe_alloc = nullptr;
      if (sync_bitmap.hasBID()) {
        maybe_alloc = lower_utils::allocGlobalBufferForGridComm(
            getGridSyncBufferSize(sync_bitmap), DataType::Int, true);
        sync_expr = IrBuilder::create<kir::GridSync>(
            sync_bitmap, maybe_alloc->buffer());
      } else {
        sync_expr = IrBuilder::create<kir::BlockSync>(false); // not a war sync
      }

      insertSyncExpr(last_writes, expr, sync_expr, maybe_alloc);
    }
  }

  // Find where a sync needs to be inserted and insert the given sync.
  // This is very similar to how allocations are placed: simply place the sync
  // before the expression at the common alloc point of the producers (really
  // last_writes, because we may be syncing other exprs besides the producers
  // of this one).
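  //
  // For example (hypothetical loop structure), if a last write targets a TV
  // computed at position 1 within loops {i, j, k}, the loop mapped to its
  // axis 0 is loop i, so (absent halo loops) the sync is placed inside loop
  // i: directly before insert_before_expr if i is the innermost open loop,
  // otherwise before the next nested loop (j).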
  void insertSyncExpr(
      const std::unordered_set<Expr*>& last_writes,
      Expr* insert_before_expr,
      Expr* sync_expr,
      Expr* maybe_alloc) {
    // The expressions in last_writes are those we're protecting the read
    // from. To figure out which loop we need a syncthreads in, take the
    // innermost compute at for loop of all the outputs of the last writes.
    std::unordered_set<kir::ForLoop*> sync_within;

    for (auto last_write : last_writes) {
      auto write_out_tv = ir_utils::getTvOutput(last_write);
      TORCH_INTERNAL_ASSERT(
          write_out_tv != nullptr,
          "Error in RAW sync insertion, expecting a TV expr, but didn't find one.");
      if (write_out_tv->getComputeAtPosition() == 0) {
        continue;
      }

      auto local_id =
          write_out_tv->axis((int)write_out_tv->getComputeAtPosition() - 1);

      auto loops_it = std::find_if(
          for_loops_.begin(), for_loops_.end(), [&local_id](const auto& loop) {
            return GpuLower::current()->caMap()->areMapped(
                loop->iter_domain(), local_id, IdMappingMode::PERMISSIVE);
          });

      TORCH_INTERNAL_ASSERT(
          loops_it != for_loops_.end(),
          "Could not find loop associated with the alloc position of ",
          write_out_tv->toString());

      sync_within.emplace(*loops_it);
    }

    // The for loop the sync needs to be in
    kir::ForLoop* sync_within_fl = nullptr;
    for (auto fl : for_loops_) {
      if (sync_within.count(fl)) {
        sync_within_fl = fl;
      }
    }

    if (sync_within_fl == nullptr) {
      // The sync should be placed at global scope, before its outermost loop
      // if it has one.
      Expr* place_before =
          for_loops_.size() > 0 ? for_loops_[0] : insert_before_expr;
      // Find location in exprs_
      auto place_before_it =
          std::find(exprs_.begin(), exprs_.end(), place_before);
      TORCH_INTERNAL_ASSERT(
          place_before_it != exprs_.end(),
          "Could not figure out where to place synchronization. ",
          "Tried to place before ",
          place_before->toString(),
          ", but could not find this expression at the global scope.");
      if (maybe_alloc != nullptr) {
        registerInsertBefore(place_before, maybe_alloc, nullptr);
      }
      registerInsertBefore(*(place_before_it), sync_expr, nullptr);
    } else {
      auto sync_within_loop_it =
          std::find(for_loops_.begin(), for_loops_.end(), sync_within_fl);

      // A block sync must be placed before halo-extended loops
      if (insertBeforeHaloLoop(
              sync_within_loop_it, sync_expr, maybe_alloc, last_writes)) {
        return;
      }

      auto place_in = *sync_within_loop_it;
      Expr* place_before = nullptr;

      if (sync_within_loop_it + 1 == for_loops_.end()) {
        // Inline, place before expr
        place_before = insert_before_expr;
      } else {
        place_before = *(sync_within_loop_it + 1);
      }

      registerInsertBefore(place_before, sync_expr, &place_in->body());
      if (maybe_alloc != nullptr) {
        registerInsertBefore(place_before, maybe_alloc, &place_in->body());
      }
    }
  }

  void handle(kir::IfThenElse*) final {
    TORCH_INTERNAL_ASSERT(
        false,
        "Pass does not support conditional statements, ",
        "this pass should be run before any conditionals are placed in code.");
  }

  // Return a set of expressions that modify shared-memory tensors.
  // Expressions are excluded when a syncthreads has already been placed
  // after them.
  std::unordered_set<Expr*> isModifiedSharedMemory(
      const std::unordered_map<Val*, Expr*>& smem,
      const std::vector<Val*>& tvs,
      bool check_sync_map = true) const {
    std::unordered_set<Expr*> last_writes;
    for (auto tv : ir_utils::filterByType<TensorView>(tvs)) {
      if (check_sync_map &&
          GpuLower::current()->syncMap().needsRawSync(tv).none()) {
        continue;
      }
      if (tv->getMemoryType() != MemoryType::Shared) {
        continue;
      }
      auto it = smem.find(tv);
      if (it != smem.end()) {
        last_writes.insert(it->second);
      }
    }
    return last_writes;
  }

  std::unordered_set<Expr*> isModifiedGlobalMemory(
      const std::unordered_map<Val*, Expr*>& gmem,
      const std::vector<Val*>& tvs) const {
    std::unordered_set<Expr*> last_writes;
    for (auto tv : ir_utils::filterByType<TensorView>(tvs)) {
      if (GpuLower::current()->syncMap().needsRawSync(tv).none()) {
        continue;
      }
      auto it = gmem.find(tv);
      if (it != gmem.end()) {
        last_writes.insert(it->second);
      }
    }
    return last_writes;
  }

  ReadAfterWriteSyncs(const std::vector<Expr*>& _exprs) {
    // Fusion shared_memory values
    // Tracks if shared memory is modified
    std::unordered_map<Val*, Expr*> smem;
    // Tracks if shared memory is asynchronously modified
    std::unordered_map<Val*, Expr*> smem_async;
    std::unordered_map<Val*, Expr*> gmem;

    // Flatten all the expressions
    auto flattened_exprs = ir_utils::flattenScopedExprs(_exprs);

    Expr* prev_tv_expr = nullptr;
    for (auto expr : flattened_exprs) {
      if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
        continue;
      }

      auto last_gmem_writes = isModifiedGlobalMemory(gmem, expr->inputs());
      if (!last_gmem_writes.empty()) {
        TORCH_INTERNAL_ASSERT(
            prev_tv_expr != nullptr,
            "Can't require sync on inputs, however, detected it's needed.");
        ParallelTypeBitmap bitmap;
        for (auto entry : gmem) {
          TORCH_INTERNAL_ASSERT(entry.first->isA<TensorView>());
          auto sync_bits = GpuLower::current()->syncMap().needsRawSync(
              entry.first->as<TensorView>());
          bitmap |= sync_bits;
        }

        sync_before_.emplace_back(std::make_pair(expr, bitmap));
        last_writes_.push_back(last_gmem_writes);
        gmem.clear();
      }

      auto last_smem_writes = isModifiedSharedMemory(smem, expr->inputs());
      auto last_async_smem_writes =
          isModifiedSharedMemory(smem_async, expr->inputs(), false);

      // Keep track of async smem writes before the current
      // expr, following largely the same logic as block sync.
      if (!last_async_smem_writes.empty()) {
        cpasync_wait_before_.push_back(expr);
        std::unordered_set<Expr*> async_smem_writes;
        for (auto it : smem_async) {
          async_smem_writes.insert(it.second);
        }
        last_cpasync_writes_.push_back(async_smem_writes);
        smem_async.clear();
      }

      if (!last_smem_writes.empty()) {
        TORCH_INTERNAL_ASSERT(
            prev_tv_expr != nullptr,
            "Can't require sync on inputs, however, detected it's needed.");
        ParallelTypeBitmap bitmap;
        bitmap.set(ParallelType::TIDx);
        bitmap.set(ParallelType::TIDy);
        bitmap.set(ParallelType::TIDz);
        sync_before_.emplace_back(std::make_pair(expr, bitmap));

        // Before clearing `smem`, put all the currently pending smem writes
        // in last_writes_. This will make sure all the smem writes will
        // be taken into consideration when deciding which loopnest level
        // to insert the block sync at. See FusionRAWSyncInsertionPlace4.
        std::unordered_set<Expr*> smem_writes;
        for (auto it : smem) {
          // No need to keep track of shared mem writes that do not
          // require a RAW block sync.
          if (GpuLower::current()
                  ->syncMap()
                  .needsRawSync(it.first->as<TensorView>())
                  .hasTID()) {
            smem_writes.insert(it.second);
          }
        }
        last_writes_.push_back(smem_writes);
        smem.clear();
      }

      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
        // Double buffered tensors do not need a RAW sync to be inserted
        // here, except for the initial load part, which is taken care of
        // separately by DoubleBufferInserter.
        if (tv->getMemoryType() == MemoryType::Shared &&
            !(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
          smem[tv] = expr;

          // only keep track of async writes in smem_async
          if (ir_utils::isCpAsyncOp(expr)) {
            smem_async[tv] = expr;
          }
        }
        if (tv->getMemoryType() == MemoryType::Global) {
          gmem[tv] = expr;
        }
      }

      prev_tv_expr = expr;
    }

    kir::ExprMutator::traverseAndInsert(_exprs);

    TORCH_INTERNAL_ASSERT(
        sync_before_.empty(), "Didn't place all required syncs.");
  }

 private:
  //! Keep track of expressions that must be followed by syncthreads
  std::deque<std::pair<Expr*, ParallelTypeBitmap>> sync_before_;

  //! Keep track of write expressions that must be placed before
  //! syncthreads.
  //!
  //! A syncthreads is placed before each expression in sync_before_.
  //! However, if the expression is inside a loop with halo, the sync must be
  //! placed before that loop. last_writes_ keeps track of the expressions
  //! modifying the smem buffer each syncthreads protects, so that the sync
  //! is not placed before those write expressions.
  std::deque<std::unordered_set<Expr*>> last_writes_;

  //! Keep track of expressions that must wait for cp.async to finish.
  std::deque<Expr*> cpasync_wait_before_;

  //! Keep track of write expressions that must be placed before the
  //! cp.async wait.
  std::deque<std::unordered_set<Expr*>> last_cpasync_writes_;

 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& loop_nests) {
    ReadAfterWriteSyncs inserter(loop_nests);
    return inserter.exprs_;
  }
};

} // namespace

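//! Insert syncs (block syncs, grid syncs, and cp.async waits) needed to
//! protect read-after-write hazards on shared and global memory.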
std::vector<Expr*> insertRawThreadSynchronization(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization");
  return ReadAfterWriteSyncs::insert(exprs);
}

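//! Insert block syncs at the ends of loop bodies to protect write-after-read
//! hazards on shared memory across loop iterations.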
std::vector<Expr*> insertWarThreadSynchronization(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization");
  return WarSyncInserter::insert(exprs);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch