#include <ir_utils.h>
#include <kernel_ir.h>
#include <lower2device.h>

#include <lower_double_buffer.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

unsigned int getDoubleBufferAxisPosition(const TensorView* tv) {
  // Double-buffering prefetches the next subregion of the tensor by
  // doubling the allocation. The subregion is defined by the axes
  // from the CA position to the inner-most position. There must be
  // at least one axis that is outside (left) of the CA position,
  // which defines the loop where prefetching is applied. Therefore,
  // the CA position must be larger than 0.
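  //
  // For example (an illustrative schedule, not taken from any particular
  // fusion; tensor and axis names are made up), a shared-memory tensor
  // scheduled as
  //
  //   T1_smem [ iS0{I0/128}, threadIdx.x{128} ] ca_pos( 1 )
  //
  // gets its double-buffer axis at position 0 (iS0): the subregion
  // [threadIdx.x{128}] is what gets double-buffered, and the iS0 loop is
  // where the prefetch of the next iteration's subregion is issued.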

  TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0);

  // An unrolled axis must not exist outside of the double-buffer axis
  auto first_unroll_it = std::find_if(
      tv->domain()->domain().begin(),
      tv->domain()->domain().end(),
      [](const auto axis) {
        return axis->getParallelType() == ParallelType::Unroll;
      });

  const int first_unroll_pos =
      std::distance(tv->domain()->domain().begin(), first_unroll_it);

  const int unroll_or_ca_pos =
      std::min((int)tv->getComputeAtPosition(), first_unroll_pos);

  TORCH_INTERNAL_ASSERT(
      unroll_or_ca_pos > 0,
      "Invalid tensor to double-buffer. Valid double buffer axis not found due to Unroll. ",
      tv->toString());

  int valid_pos = -1;
  // Skip parallelized or broadcast axes, as they are not materialized as
  // sequential loops and thus cannot host the prefetching
  for (int i = unroll_or_ca_pos - 1; i >= 0; --i) {
    auto pt = tv->axis(i)->getParallelType();
    if (!isParallelTypeThread(pt) && !tv->axis(i)->isBroadcast()) {
      valid_pos = i;
      break;
    }
  }

  TORCH_INTERNAL_ASSERT(
      valid_pos >= 0,
      "Invalid tensor to double-buffer. Valid double buffer axis not found. ",
      tv->toString());

  return valid_pos;
}

IterDomain* getDoubleBufferAxis(const TensorView* tv) {
  return tv->axis((int)getDoubleBufferAxisPosition(tv));
}

void validateDoubleBufferedTensor(const TensorView* tv) {
  auto double_buffer_pos = getDoubleBufferAxisPosition(tv);

  // Like vectorization, only UnaryOp::Set or LoadStoreOp with another
  // TensorView is considered.
  auto def = tv->definition();
  TORCH_INTERNAL_ASSERT(
      (def->isA<UnaryOp>() &&
       def->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set) ||
          // Load store op should generally support double buffering.
          def->isA<LoadStoreOp>(),
      "Invalid tensor to double-buffer. Only tensors defined by UnaryOp::Set or LoadStoreOp are supported: ",
      def->toString());

  TORCH_INTERNAL_ASSERT(
      def->input(0)->isA<TensorView>(),
      "Invalid tensor to double-buffer. Only tensors defined by UnaryOp::Set or LoadStoreOp with a TensorView input are supported: ",
      def->toString());

  // Require the producer tensor to have been computed entirely for
  // the double-buffering loop. Otherwise, the producer itself would
  // also need to be double-buffered.
  auto producer = def->input(0)->as<TensorView>();
  TORCH_INTERNAL_ASSERT(
      producer->getComputeAtPosition() <= double_buffer_pos,
      "Invalid tensor to double-buffer. The computeAt position of the producer tensor must be moved left: ",
      producer->toString());

  // Not strictly necessary, but only gmem -> smem or local, and smem ->
  // local, copies are allowed.
  const auto p_mem_type = producer->getMemoryType();
  const auto c_mem_type = tv->getMemoryType();
  TORCH_INTERNAL_ASSERT(
      (p_mem_type == MemoryType::Global &&
       (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) ||
          (p_mem_type == MemoryType::Shared && c_mem_type == MemoryType::Local),
      "Invalid tensor to double-buffer: ",
      tv->toString(),
      ". Producer memory type: ",
      p_mem_type,
      ". Consumer memory type: ",
      c_mem_type);

  return;
}

namespace {

// Initial inspection of a fusion to find and validate double buffered tensors
class DoubleBufferFusionInspector : private IterVisitor {
 public:
  DoubleBufferFusionInspector(Fusion* fusion, DoubleBufferInfo& db_info)
      : db_info_(db_info) {
    traverse(fusion);
  }

 private:
  using IterVisitor::handle;

  void handle(TensorView* tv) final {
    if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
      return;
    }

    TORCH_INTERNAL_ASSERT(
        tv->definition(), "Fusion input shouldn't be double buffered.", tv);

    validateDoubleBufferedTensor(tv);

    auto db_axis = getDoubleBufferAxis(tv);

    db_info_.setDoubleBufferAxis(tv, db_axis);
  }

 private:
  DoubleBufferInfo& db_info_;
};

// The epilogue loop is only created when the producer of a double
// buffer tensor is on smem, in which case it would otherwise require
// an additional predicate to guard buffer overruns. When the producer
// is on gmem, no such additional predicate is needed, so no epilogue
// loop is created.
bool requireEpilogue(const std::vector<Expr*>& exprs) {
  return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
    return expr->input(0)->as<TensorView>()->getMemoryType() ==
        MemoryType::Shared;
  });
}

// Replicates double buffer loops for Prologue, Main, and
// Epilogue. Prologue only copies the load expressions of double
// buffered tensors, whereas Epilogue copies every expression other
// than the loads. Main copies everything.
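//
// For example, with depth-2 double buffering of an smem -> local copy
// (an illustrative sketch of the generated structure; names are made up,
// the real output is kir::ForLoops rather than C++ code, and the buffer
// index switching shown here is inserted later by index lowering):
//
//   // Prologue: load the first element into buffer 0
//   for (int i = 0; i < 1; ++i) {
//     A_local[i % 2] = A_smem[i];
//   }
//   // Main: prefetch the next element while consuming the current one
//   for (int i = 0; i < extent - 1; ++i) {
//     A_local[(i + 1) % 2] = A_smem[i + 1];
//     ... = A_local[i % 2];
//   }
//   // Epilogue: consume the last element; no prefetch, so no overrun
//   for (int i = extent - 1; i < extent; ++i) {
//     ... = A_local[i % 2];
//   }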
class DoubleBufferLoopCloner : public kir::IrVisitor {
 public:
  static kir::ForLoop* clone(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& double_buffer_load_exprs,
      DoubleBufferLoopStage loop_type) {
    DoubleBufferLoopCloner cloner(
        double_buffer_loop, double_buffer_load_exprs, loop_type);
    cloner.clone();
    return cloner.cloned_top_level_loop_;
  }

 private:
  DoubleBufferLoopCloner(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& double_buffer_load_exprs,
      DoubleBufferLoopStage loop_type)
      : double_buffer_loop_(double_buffer_loop),
        double_buffer_load_exprs_(double_buffer_load_exprs),
        loop_type_(loop_type) {}

  using kir::IrVisitor::handle;

  void clone() {
    const auto gpu_lower = GpuLower::current();

    // Cloning the double buffer loop as follows, where D is the stage
    // depth (2 for plain double buffering, N for circular buffering):
    //
    // Prologue: 0 to (D - 1)
    // Main: 0 to (extent - (D - 1)) when an epilogue is needed,
    //       otherwise 0 to extent
    // Epilogue: (extent - (D - 1)) to extent

    auto index = GpuLower::current()->caMap()->getIndexVariable(
        double_buffer_loop_->iter_domain(), loop_type_);
    auto start = double_buffer_loop_->start();
    auto stop = double_buffer_loop_->stop();
    auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
        double_buffer_loop_->iter_domain());

    if (loop_type_ == DoubleBufferLoopStage::Prolog) {
      TORCH_INTERNAL_ASSERT(start->isZeroInt());
      stop = SimplifyingIrBuilder::create<Int>(stage_depth - 1);
    } else if (
        loop_type_ == DoubleBufferLoopStage::Main &&
        requireEpilogue(double_buffer_load_exprs_)) {
      // Stop (stage_depth - 1) iterations early so that the main loop
      // ends where the epilogue loop begins.
      stop = IrBuilder::subExpr(
          double_buffer_loop_->stop(),
          SimplifyingIrBuilder::create<Int>(stage_depth - 1));
    } else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
      TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_));
      start = IrBuilder::subExpr(
          double_buffer_loop_->stop(),
          SimplifyingIrBuilder::create<Int>(stage_depth - 1));
    }

    cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
        double_buffer_loop_->iter_domain(),
        index,
        start,
        stop,
        gpu_lower->kernel()->oneVal(),
        false,
        nullptr,
        double_buffer_loop_->isUnrollRequired(),
        loop_type_);

    handle(double_buffer_loop_);

    if (stage_depth > 2) {
      cloned_top_level_loop_->body().push_back(
          IrBuilder::create<kir::CpAsyncCommit>());
    }
  }

  void handle(kir::ForLoop* fl) final {
    kir::ForLoop* cloned_loop = fl == double_buffer_loop_
        ? cloned_top_level_loop_
        : IrBuilder::create<kir::ForLoop>(fl);

    cloned_scopes_.push_back(&cloned_loop->body());

    kir::IrVisitor::handle(fl);

    cloned_scopes_.pop_back();

    // Add the cloned loop into the parent loop body only when the
    // cloned loop contains expressions.
    if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) {
      cloned_scopes_.back()->push_back(cloned_loop);
    }
  }

  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(false, "No IfThenElse should exist yet");
  }

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
      return;
    }

    TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty());

    if (loop_type_ == DoubleBufferLoopStage::Main) {
      cloned_scopes_.back()->push_back(expr);
      return;
    }

    // In Prologue and Epilogue, only the load expressions or only the
    // other expressions are copied, respectively. Note that there can
    // be multiple exprs defining double buffered TVs (e.g., buffer
    // initialization).
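    //
    // For example (an illustrative fusion, not from this pass), given
    //
    //   A_smem = 0;        // buffer initialization
    //   A_smem = A_gmem;   // double buffer load
    //   B = A_smem + 1;    // consumer
    //
    // the Prologue keeps only the two expressions writing A_smem, while
    // the Epilogue keeps only the consumer expression.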

    auto out_tv = ir_utils::getTvOutput(expr);
    const auto is_double_buffer_load_expr = std::any_of(
        double_buffer_load_exprs_.begin(),
        double_buffer_load_exprs_.end(),
        [out_tv](const auto load_expr) {
          auto double_buffer_tv = ir_utils::getTvOutput(load_expr);
          TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr);
          return out_tv == double_buffer_tv;
        });
    if ((loop_type_ == DoubleBufferLoopStage::Prolog &&
         is_double_buffer_load_expr) ||
        (loop_type_ == DoubleBufferLoopStage::Epilog &&
         !is_double_buffer_load_expr)) {
      cloned_scopes_.back()->push_back(expr);
    }
  }

 private:
  kir::ForLoop* double_buffer_loop_ = nullptr;
  const std::vector<Expr*>& double_buffer_load_exprs_;
  const DoubleBufferLoopStage loop_type_;

  kir::ForLoop* cloned_top_level_loop_ = nullptr;
  std::deque<kir::Scope*> cloned_scopes_;
};

using InsertionInfo = std::unordered_map<kir::ForLoop*, std::vector<Expr*>>;

// Traverse lowered loop-nests and find all double buffer loops and
// associated load expressions.
class DoubleBufferLoopNestInspector : private kir::IrVisitor {
 public:
  static InsertionInfo run(const std::vector<Expr*>& exprs) {
    DoubleBufferLoopNestInspector inspector(exprs);
    return inspector.insertion_info_;
  }

 private:
  DoubleBufferLoopNestInspector(const std::vector<Expr*>& exprs) {
    handle(exprs);
  }

  using kir::IrVisitor::handle;

  // Collect double buffer related information on an expr
  // that is a memory load, i.e. a LoadStore or a Set.
  void handlePossibleLoadExpr(Expr* expr) {
    const auto gpu_lower = GpuLower::current();

    auto out_tv = ir_utils::getTvOutput(expr);

    if (out_tv == nullptr) {
      return;
    }

    // Ignore init loop
    if (!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered()) ||
        !expr->input(0)->isA<TensorView>()) {
      return;
    }

    auto double_buffer_loop =
        gpu_lower->doubleBufferInfo().getDoubleBufferLoop(out_tv, for_loops_);

    TORCH_INTERNAL_ASSERT(
        double_buffer_loop != nullptr,
        "No double buffer loop found for a double buffered tensor: ",
        out_tv->toString());

    validateDoubleBufferLoop(double_buffer_loop);

    insertion_info_[double_buffer_loop].push_back(expr);
  }

  void handle(UnaryOp* uop) final {
    handlePossibleLoadExpr(uop);
  }

  void handle(LoadStoreOp* ldst) final {
    handlePossibleLoadExpr(ldst);
  }

  static void validateDoubleBufferLoop(kir::ForLoop* loop) {
    TORCH_INTERNAL_ASSERT(
        loop->start()->isZeroInt(), "Unsupported loop: ", loop->toString());
    TORCH_INTERNAL_ASSERT(
        loop->step()->isOneInt(), "Unsupported loop: ", loop->toString());
    TORCH_INTERNAL_ASSERT(
        !loop->vectorize(),
        "Vectorized loop should not be the allocation loop for double-buffered tensor: ",
        loop->toString());
    TORCH_INTERNAL_ASSERT(
        !loop->vectorize_shift(),
        "Vectorize shift loop should not be the allocation loop for double-buffered tensor: ",
        loop->toString());
  }

  InsertionInfo insertion_info_;
};

// Apply double buffering transformations
class DoubleBufferInserter : private kir::ExprMutator {
 public:
  // When multiple double buffer loops exist, apply
  // transformations to inner-most loops first. A single ExprMutator
  // pass can only process one loop.
  static std::vector<Expr*> run(
      const std::vector<Expr*>& exprs,
      InsertionInfo insertion_info) {
    auto inserted_exprs = exprs;
    while (!insertion_info.empty()) {
      DoubleBufferInserter inserter(inserted_exprs, insertion_info);
      inserted_exprs = inserter.exprs_;
    }
    return inserted_exprs;
  }

 private:
  DoubleBufferInserter(
      const std::vector<Expr*>& exprs,
      InsertionInfo& insertion_info)
      : insertion_info_(insertion_info) {
    auto num_double_buffer_loops = insertion_info.size();
    traverseAndInsert(exprs);
    TORCH_INTERNAL_ASSERT(processed_loop_ != nullptr);
    TORCH_INTERNAL_ASSERT(insertion_info.size() == num_double_buffer_loops - 1);
  }

  using kir::ExprMutator::handle;

  void handle(kir::ForLoop* loop) final {
    kir::ExprMutator::handle(loop);

    // If another loop has already been processed, no more loops should
    // be handled in the same pass
    if (processed_loop_ != nullptr) {
      return;
    }

    auto it = insertion_info_.find(loop);
    if (it == insertion_info_.end()) {
      return;
    }

    insert(loop, it->second);
    processed_loop_ = loop;
    insertion_info_.erase(loop);
  }

  void insert(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& loads) {
    auto prologue_loop = DoubleBufferLoopCloner::clone(
        double_buffer_loop, loads, DoubleBufferLoopStage::Prolog);
    registerInsertBefore(double_buffer_loop, prologue_loop);

    auto write_to_smem =
        std::any_of(loads.begin(), loads.end(), [](const Expr* expr) {
          return expr->output(0)->as<TensorView>()->getMemoryType() ==
              MemoryType::Shared;
        });

    // RAW sync is not inserted for double buffered tensors. The only
    // exception is the prologue load.
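    //
    // The overall structure emitted by this function is roughly the
    // following (an illustrative sketch; which pieces actually appear
    // depends on the conditions checked below):
    //
    //   prologue loop;
    //   cp.async wait;       // only when the loads are async copies
    //   __syncthreads();     // only when a RAW TID sync is needed
    //   main loop;
    //   epilogue loop;       // only when the producers are on smem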
    bool insert_cpasync_wait = false;
    if (write_to_smem) {
      // Here the initial sync before entering the double buffer loop is
      // inserted.

      // If any of the double buffered tensors in this double buffer
      // loop is loaded by an async copy, we want to wait for the gmem
      // loads to finish before synchronizing the block.
      if (std::any_of(loads.begin(), loads.end(), ir_utils::isCpAsyncOp)) {
        auto stage_depth =
            GpuLower::current()->doubleBufferInfo().getStageDepthFor(
                double_buffer_loop->iter_domain());
        auto cp_async_wait =
            IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
        registerInsertBefore(double_buffer_loop, cp_async_wait);
        insert_cpasync_wait = true;
      }

      // Insert the initial block sync before entering the main loop.
      if (std::any_of(loads.begin(), loads.end(), [](Expr* expr) {
            return GpuLower::current()
                ->syncMap()
                .needsRawSync(ir_utils::getTvOutput(expr))
                .hasTID();
          })) {
        // If any of the double buffered loads requires a sync, as indicated
        // by the sync info map, insert the sync before entering the double
        // buffer loop.
        // TODO:
        // Double buffering in gmem is currently not supported, and
        // supporting it is not a short- to mid-term priority.
        auto sync = IrBuilder::create<kir::BlockSync>(false);
        registerInsertBefore(double_buffer_loop, sync);
      }
    }

    auto main_loop = DoubleBufferLoopCloner::clone(
        double_buffer_loop, loads, DoubleBufferLoopStage::Main);

    registerReplace(double_buffer_loop, main_loop);

    // Insert the wait instruction in this pass instead
    // of relying on the WAR sync pass to do it.
    // The WAR sync pass today would insert the wait function
    // exactly where we need it, but the purpose of this wait
    // insertion isn't exactly WAR protection.
    //
    // TODO: [Double Buffer Sync]
    // We might eventually want to move the block sync inserted
    // by the WAR pass here as well, since this sync insertion is kind
    // of both WAR and RAW (or neither RAW nor WAR, depending
    // on how we look at it).
    // E.g. in the case when an intermediate
    // tensor is double buffered.
    //
    //   __block_sync();    // This is the initial sync
    //   For i in ...       // Double buffer loop
    //     A[i%2] = ...;
    //     ... = A[1-i%2];
    //     __block_sync();  // sync within loop
    //     ...
    //
    // The "sync within loop" can be placed anywhere in the
    // double buffer loop, while in the case of RAW and WAR
    // there'd be extra insertion point restrictions.
    // We are currently not actively exploring opportunities
    // with this property of "double buffer sync", so this
    // is more conceptual at the moment, aka low priority.
    if (insert_cpasync_wait) {
      insertCpAsyncWaitInMainLoop(main_loop);
    }

    if (requireEpilogue(loads)) {
      auto epilogue_loop = DoubleBufferLoopCloner::clone(
          double_buffer_loop, loads, DoubleBufferLoopStage::Epilog);
      registerInsertAfter(double_buffer_loop, epilogue_loop);
    }
  }

  // Simple conservative rule for inserting the async copy wait
  // primitive in the double buffer loop:
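  //
  // The result is roughly the following (an illustrative sketch, assuming
  // cp.async loads, stage depth N, and a block sync already inserted by
  // the WAR sync pass; buffer indices are schematic):
  //
  //   for (...) {                          // main loop
  //     A_smem[...] = cp.async load of a future iteration;
  //     ... = A_smem[...];                 // consume the current iteration
  //     cp.async wait, keeping (N - 2) groups in flight;  // inserted here
  //     __syncthreads();                   // WAR sync from the WAR pass
  //   }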
  void insertCpAsyncWaitInMainLoop(kir::ForLoop* main_loop) {
    TORCH_INTERNAL_ASSERT(
        !main_loop->body().empty(),
        "Double buffer sync insertion: empty main loop.");
    // Note: This pass explicitly assumes that WAR sync has been
    // inserted, so it would need to be updated if we re-order the
    // passes. The cleanups suggested in [Double Buffer Sync]
    // would resolve this dependency on pass ordering.
    auto end_of_loop_expr = main_loop->body().exprs().back();
    auto stage_depth = GpuLower::current()->doubleBufferInfo().getStageDepthFor(
        main_loop->iter_domain());
    auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);

    // Check if a sync has been inserted by the WAR sync pass.
    auto block_sync_it = std::find_if(
        main_loop->body().exprs().rbegin(),
        main_loop->body().exprs().rend(),
        [](const Expr* expr) { return expr->isA<kir::BlockSync>(); });
    if (block_sync_it == main_loop->body().exprs().rend()) {
      // If there's no sync, i.e. no tensor needs cross-thread
      // communication, we still need a wait, but it can be placed
      // anywhere in the loop. We arbitrarily choose to place it at
      // the end.
      main_loop->body().insert_after(end_of_loop_expr, cp_async_wait);
    } else {
      // If a sync has been inserted, the wait needs to be placed
      // before the sync.
      main_loop->body().insert_before(*block_sync_it, cp_async_wait);
    }
  }

 private:
  InsertionInfo& insertion_info_;
  kir::ForLoop* processed_loop_ = nullptr;
};

} // namespace

void DoubleBufferInfo::build(Fusion* fusion) {
  DoubleBufferFusionInspector inspector(fusion, *this);

  // Build double buffered loop ids
  for (auto& info : map_) {
    auto double_buffer_axis = info.second.double_buffer_axis;
    // Keeps track of which loop disjoint set has been
    // double buffered. In index allocation, one index
    // variable would need to be allocated in each
    // double buffer stage.
    concrete_double_buffered_loop_id_.insert(
        GpuLower::current()->caMap()->getConcreteMappedID(
            double_buffer_axis, IdMappingMode::LOOP));
  }
}

bool DoubleBufferInfo::isDoubleBufferedIterDomain(IterDomain* id) {
  auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::LOOP);
  return concrete_double_buffered_loop_id_.count(concrete_loop_id);
}

DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) {
  TORCH_INTERNAL_ASSERT(
      tv->isDoubleBuffered() || tv->isCircularBuffered(),
      "Not a double-buffered tensor: ",
      tv->toString());
  return map_[tv];
}

void DoubleBufferInfo::setDoubleBufferAxis(
    const TensorView* tv,
    IterDomain* axis) {
  getTvInfo(tv).double_buffer_axis = axis;

  // Also validate the stage consistency with the CA map.
  unsigned int stage_depth = 0;
  if (tv->isCircularBuffered()) {
    stage_depth = tv->circularBufferDepth();
  } else {
    // Double buffering is essentially
    // circular buffering with depth 2.
    stage_depth = 2;
  }

  // Set and validate the new stage depth.
  setStageDepth(axis, stage_depth);
}

void DoubleBufferInfo::setStageDepth(IterDomain* id, unsigned int stage_depth) {
  auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::LOOP);

  auto maybe_existing_depth_it = stage_depth_.find(concrete_loop_id);
  if (maybe_existing_depth_it == stage_depth_.end()) {
    stage_depth_[concrete_loop_id] = stage_depth;
  } else {
    TORCH_INTERNAL_ASSERT(
        stage_depth == maybe_existing_depth_it->second,
        "Unsupported multiple depth pipelining, was set to ",
        maybe_existing_depth_it->second,
        " by ",
        maybe_existing_depth_it->first->toString(),
        " and then set to ",
        stage_depth,
        " by ",
        concrete_loop_id->toString());
  }
}

IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) {
  if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
    return nullptr;
  }

  return getTvInfo(tv).double_buffer_axis;
}

unsigned int DoubleBufferInfo::getStageDepthFor(
    IterDomain* double_buffer_axis) {
  auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
      double_buffer_axis, IdMappingMode::LOOP);

  auto maybe_depth_it = stage_depth_.find(concrete_id);

  TORCH_INTERNAL_ASSERT(
      maybe_depth_it != stage_depth_.end(), "Stage depth not found");

  return maybe_depth_it->second;
}

kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
    IterDomain* axis,
    const std::vector<kir::ForLoop*>& loops,
    bool ignore_prologue) {
  auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) {
    return GpuLower::current()->caMap()->areMapped(
               loop->iter_domain(), axis, IdMappingMode::EXACT) &&
        (!ignore_prologue ||
         loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Prolog);
  });

  if (loop_it != loops.end()) {
    return *loop_it;
  } else {
    return nullptr;
  }
}

kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    bool ignore_prologue) {
  auto axis = getDoubleBufferAxis(tv);

  if (axis == nullptr) {
    return nullptr;
  }

  return getDoubleBufferLoop(axis, loops, ignore_prologue);
}

void DoubleBufferInfo::setOriginalAllocSize(
    const TensorView* tv,
    Val* original_alloc_size) {
  getTvInfo(tv).original_alloc_size = original_alloc_size;
}

Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) {
  if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
    return nullptr;
  }

  return getTvInfo(tv).original_alloc_size;
}

std::vector<Expr*> DoubleBufferPass::run(const std::vector<Expr*>& exprs) {
  auto insertion_info = DoubleBufferLoopNestInspector::run(exprs);
  return DoubleBufferInserter::run(exprs, insertion_info);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch