#include <instrumentation.h>
#include <ir_iostream.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_allocation.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

class AllocationInserter : public kir::ExprMutator {
 private:
  using kir::ExprMutator::handle;

  // An expanded version of BasicAllocInfo in lower_utils.h that tracks
  // additional information
  struct AllocationInformation {
    // The for loop that the initialization of this allocation must be
    // placed in, nullptr if not within a loop
    kir::ForLoop* init_for_loop = nullptr;

    // The expression that the initialization of this allocation must
    // be placed before
    Expr* init_place_before = nullptr;

    // Keep track of the actual allocation loop. This can be different
    // from init_for_loop only with unswitched shared memory allocations,
    // which are moved to outer loops to avoid duplicated allocations
    // (see issue #1133).
    kir::ForLoop* alloc_for_loop = nullptr;

    // The expression that this allocation must be placed
    // before. Similar to alloc_for_loop, this is different from
    // init_place_before only with unswitched shared memory allocations.
    Expr* alloc_place_before = nullptr;

    // The allocation position relative to the buffer
    size_t alloc_pos = 0;

    // The buffer this allocation is for
    TensorView* buffer = nullptr;

    // Info to transfer to GPU lower
    bool has_halo = false;

    // Local IterDomains that this allocation covers
    std::unique_ptr<std::vector<IterDomain*>> allocation_domains;
  };

  // Find allocation point
  // Fills info.buffer, info.alloc_pos, info.init_for_loop,
  // info.init_place_before, info.alloc_for_loop, info.alloc_place_before
  void fillAllocationInformation(AllocationInformation& info, Expr* expr) {
    auto loop_alloc_info =
        lower_utils::getAllocInformation(info.buffer, for_loops_);

    info.init_for_loop = loop_alloc_info.init_for_loop;
    info.alloc_for_loop = loop_alloc_info.alloc_for_loop;
    info.alloc_pos = loop_alloc_info.alloc_pos;

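    // Given a loop fl, return the loop immediately nested inside it in fls
    // (the loops are ordered outermost to innermost). Asserts if fl is not
    // found in fls or is already the innermost loop.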
    auto next_fl = [](kir::ForLoop* fl,
                      const std::vector<kir::ForLoop*>& fls) {
      for (auto i : c10::irange(fls.size())) {
        if (fl == fls[i]) {
          if (i + 1 < fls.size()) {
            return fls[i + 1];
          }
        }
      }
      TORCH_INTERNAL_ASSERT(false, "Could not find desired loop.");
    };

    if (info.init_for_loop == nullptr) {
      info.init_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr;
    } else {
      if (info.init_for_loop == for_loops_.back()) {
        // Inline allocation, place before expr
        info.init_place_before = expr;
      } else {
        // Place allocation after the last computeAt axis
        // TODO: may be more efficient to place before the first non-computeAt
        // axis
        info.init_place_before = next_fl(info.init_for_loop, for_loops_);
      }
    }

    // Set the allocation loop and the place_before expression in the
    // same way as the initialization loop and place_before expression
    if (info.alloc_for_loop == info.init_for_loop) {
      info.alloc_for_loop = info.init_for_loop;
      info.alloc_place_before = info.init_place_before;
    } else {
      if (info.alloc_for_loop == nullptr) {
        info.alloc_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr;
      } else {
        // Since there must be an inner unswitched domain,
        // alloc_for_loop should never be the inner-most loop.
        TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops_.back());
        info.alloc_place_before = next_fl(info.alloc_for_loop, for_loops_);
      }
    }
  }

  // Create initialization expression if init_val is non-null.
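  // The result is the Set expression wrapped in one for loop per
  // non-reduction, non-broadcast axis inside the allocation position,
  // innermost loops created first. Roughly, for init_dims = {i0, i1}:
  //
  //   for i0 in 0..extent(i0):
  //     for i1 in 0..extent(i1):
  //       buffer = init_val;
  //
  // where a halo-extended extent is used whenever one is registered for
  // the axis.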
  Expr* createInitExpr(AllocationInformation& info, Val* init_val) {
    if (init_val == nullptr) {
      return nullptr;
    }

    std::vector<IterDomain*> init_dims;
    for (const auto axis_i :
         c10::irange(info.alloc_pos, info.buffer->nDims())) {
      if (info.buffer->axis(axis_i)->isReduction() ||
          info.buffer->axis(axis_i)->isBroadcast()) {
        continue;
      }
      auto concrete_id = gpu_lower->caMap()->getConcreteMappedID(
          info.buffer->axis(axis_i), IdMappingMode::LOOP);
      init_dims.push_back(concrete_id);
    }
    Expr* init_expr =
        IrBuilder::create<UnaryOp>(UnaryOpType::Set, info.buffer, init_val);
    for (auto init_loop_it = init_dims.rbegin();
         init_loop_it != init_dims.rend();
         ++init_loop_it) {
      auto id = *init_loop_it;
      kir::ForLoop* new_loop = nullptr;
      auto extent_with_halo = gpu_lower->haloInfo()->getExtent(id);
      if (extent_with_halo) {
        new_loop = IrBuilder::create<kir::ForLoop>(
            id,
            IrBuilder::create<Int>(c10::nullopt),
            nullptr,
            extent_with_halo,
            nullptr,
            false,
            nullptr,
            false,
            DoubleBufferLoopStage::NotApplicable);
      } else {
        new_loop = IrBuilder::create<kir::ForLoop>(id);
      }
      new_loop->body().push_back(init_expr);
      init_expr = new_loop;
    }
    return init_expr;
  }

  std::vector<Val*> getGlobalAllocationSizes(AllocationInformation& info) {
    const auto& domain = info.buffer->domain();
    const auto& maybe_rfactor_domain = domain->hasRFactor()
        ? domain->getRFactorDomain()
        : domain->getRootDomain();

    std::vector<Val*> alloc_dims;

    for (const auto id : maybe_rfactor_domain) {
      if (id->isReduction() || id->isStride() || id->isBroadcast()) {
        continue;
      }
      auto extent = id->extent();
      // Use halo-extended extent if found
      auto halo_extent = gpu_lower->haloInfo()->getRootAxisInfo(id);
      if (halo_extent.hasHalo()) {
        extent = IrBuilder::addExpr(
            extent, IrBuilder::create<Int>(halo_extent.width()));
      }
      alloc_dims.push_back(extent);
    }

    return alloc_dims;
  }

  // Get allocation extents of root axes with halo
  //
  // Allocation can be done with leaf IDs with halo as well, but the
  // allocation size could be larger than necessary.
  //
  // For example, suppose the shift offset of an axis is 1. When it is
  // split by N, the halo size of the inner output is N+1. When the
  // allocation only has the inner split output, the allocation size
  // would be N+1. If that ID is further split by M, the output
  // extents would be N/M and M+1. The allocation size based on the
  // leaves would be N/M*(M+1) or N+N/M, which is larger than N+1.
  //
  // This function tries to propagate back halo information to root
  // axes to avoid inflating allocations. It fails when merged domains
  // are split and only one of the split outputs is used for
  // allocations since in such a case we can't un-merge and properly
  // determine the extents of the merge inputs. Currently, that
  // results in an exception, but it may be more reasonable to simply
  // fall back to the leaf-based allocation.
  //
  // See the FusionShiftDoubleSplit test for an example case.
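  //
  // As a concrete (hypothetical) instance of the example above: with N = 8
  // and M = 2, the leaf-based allocation would be 8/2 * (2+1) = 12 elements,
  // whereas propagating the halo back to the root gives 8 + 1 = 9.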
  std::vector<Val*> getNonGlobalAllocExprWithHalo(
      TensorView* tv,
      const std::vector<IterDomain*>& alloc_domains) {
    std::vector<Val*> start_vals;
    std::transform(
        alloc_domains.begin(),
        alloc_domains.end(),
        std::back_inserter(start_vals),
        [](IterDomain* dom) { return dom->as<Val>(); });

    // Get all exprs involved in generating the allocation IDs
    auto exprs = StmtSort::getExprs(tv->fusion(), start_vals);

    // Get the halo extent if found
    auto getExtent = [this](IterDomain* id) {
      auto extent = gpu_lower->haloInfo()->getExtent(id);
      if (extent == nullptr) {
        extent = id->extent();
      }
      return extent;
    };

    std::unordered_map<IterDomain*, Val*> known_extents;

    // IterDomains that are allocated fully. For example, if an ID is
    // split and only one of the split outputs is used for allocation,
    // that's not considered full. Only full domains can be unmerged,
    // which is needed to propagate back the halo information to root
    // domains.
    std::unordered_set<IterDomain*> full_domains;

    for (auto alloc_domain : alloc_domains) {
      known_extents.insert({alloc_domain, getExtent(alloc_domain)});
      full_domains.insert(alloc_domain);
    }

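    // Walk the expressions in reverse, i.e., from the allocation IDs back
    // toward the root domain, transferring the known extents of outputs to
    // their inputs where possible.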
    for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
      auto expr = *it;
      if (auto merge = dynamic_cast<Merge*>(expr)) {
        auto out_it = known_extents.find(merge->out());
        // If nothing is known about the out id, no propagation can be
        // done. Note that's not necessarily an error.
        if (out_it == known_extents.end()) {
          continue;
        }
        // Similarly, if the out id does not cover its full extent, we
        // can't un-merge it.
        if (full_domains.find(merge->out()) == full_domains.end()) {
          continue;
        }
        // Since the extent of the out id is full, the extent of each
        // of the input axes is also full
        known_extents.insert({merge->inner(), getExtent(merge->inner())});
        full_domains.insert(merge->inner());
        known_extents.insert({merge->outer(), getExtent(merge->outer())});
        full_domains.insert(merge->outer());
        known_extents.erase(out_it);
      } else if (auto split = dynamic_cast<Split*>(expr)) {
        auto inner = split->inner();
        const auto inner_it = known_extents.find(inner);
        auto outer = split->outer();
        const auto outer_it = known_extents.find(outer);
        if (inner_it != known_extents.end() &&
            outer_it != known_extents.end()) {
          if (full_domains.find(inner) != full_domains.end() &&
              full_domains.find(outer) != full_domains.end()) {
            known_extents.insert({split->in(), getExtent(split->in())});
            full_domains.insert(split->in());
          } else {
            known_extents.insert(
                {split->in(),
                 IrBuilder::mulExpr(outer_it->second, inner_it->second)});
          }
          known_extents.erase(inner_it);
          known_extents.erase(outer_it);
        } else if (inner_it != known_extents.end()) {
          known_extents.insert({split->in(), inner_it->second});
          known_extents.erase(inner_it);
        } else if (outer_it != known_extents.end()) {
          known_extents.insert({split->in(), outer_it->second});
          known_extents.erase(outer_it);
        }
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unexpected expr: ", expr);
      }
    }

    std::vector<Val*> alloc_dims;

    for (auto root_axis : tv->getRootDomain()) {
      auto it = known_extents.find(root_axis);
      if (it == known_extents.end()) {
        continue;
      }
      alloc_dims.push_back(it->second);
      known_extents.erase(it);
    }

    // known_extents should have only mappings for root axes, so
    // if anything remains in the map, it's an error
    if (!known_extents.empty()) {
      std::stringstream ss;
      for (auto kv : known_extents) {
        ss << kv.first << " ";
      }
      TORCH_INTERNAL_ASSERT(
          false, "Non-root axes found for TV", tv->name(), ": ", ss.str());
    }

    return alloc_dims;
  }

  std::vector<Val*> getNonGlobalAllocExpr(AllocationInformation& info) {
    const auto memory_type = info.buffer->getMemoryType();
    TORCH_INTERNAL_ASSERT(
        memory_type != MemoryType::Global,
        "Invalid memory type: ",
        memory_type);

    std::vector<Val*> alloc_dims;

    bool has_halo = false;
    std::vector<IterDomain*> alloc_domains;

    info.allocation_domains = std::make_unique<std::vector<IterDomain*>>();

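    // Per-axis rules applied below:
    //  - Axes outside the allocation position (axis_i < alloc_pos) only
    //    contribute when the buffer is shared with respect to them: shared
    //    memory with the axis parallelized by TIDs, or global memory with
    //    the axis parallelized by TIDs or BIDs.
    //  - Axes at or inside the allocation position contribute unless they
    //    are bound to a parallel dimension the buffer does not materialize:
    //    BIDs for shared memory, TIDs or BIDs for local memory.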
    for (const auto axis_i : c10::irange(info.buffer->nDims())) {
      const auto local_id = info.buffer->axis(axis_i);

      // Don't use reduction/stride/broadcast axis in the allocation
      // computation
      if (local_id->isReduction() || local_id->isStride() ||
          local_id->isBroadcast()) {
        continue;
      }

      auto concrete_id = gpu_lower->caMap()->getConcreteMappedID(
          info.buffer->axis(axis_i), IdMappingMode::LOOP);
      const bool is_block_dim =
          isParallelTypeBlockDim(concrete_id->getParallelType());
      const bool is_thread_dim =
          isParallelTypeThreadDim(concrete_id->getParallelType());
      const bool is_thread =
          isParallelTypeThread(concrete_id->getParallelType());

      if (axis_i < info.alloc_pos) {
        // Even when the axis is outside the allocation position, if the
        // tensor is shared with respect to the axis, the buffer size
        // needs to be expanded for the axis. Sharing occurs in two
        // cases: 1) the tensor is on shared memory with the axis
        // parallelized by TIDs, and 2) the tensor is on global memory
        // with the axis parallelized by TIDs or BIDs.
        if (!((memory_type == MemoryType::Shared && is_thread_dim) ||
              (memory_type == MemoryType::Global && is_thread))) {
          continue;
        }
        alloc_domains.push_back(info.buffer->axis(axis_i));
      } else {
        if (
            // If shared memory, don't use any IDs bound to a grid dimension
            (memory_type == MemoryType::Shared && is_block_dim) ||
            // If local memory, don't use any IDs bound to a grid or block
            // dimension
            (memory_type == MemoryType::Local && is_thread)) {
          continue;
        }
        alloc_domains.push_back(info.buffer->axis(axis_i));
      }

      auto extent = concrete_id->extent();

      if (gpu_lower->haloInfo()->getExtent(info.buffer->axis(axis_i)) !=
          nullptr) {
        has_halo = true;
      }

      alloc_dims.push_back(extent);
      info.allocation_domains->push_back(local_id);
    }

    // When an axis with halo extension is detected, propagate back
    // the halo extents from leaf IDs to root IDs
    if (has_halo) {
      info.has_halo = true;
      return getNonGlobalAllocExprWithHalo(info.buffer, alloc_domains);
    }

    return alloc_dims;
  }

  kir::Allocate* createAllocExpr(AllocationInformation& info, bool is_output) {
    if (is_output) {
      return nullptr;
    }

    std::vector<Val*> alloc_dims;
    const MemoryType memory_type = info.buffer->getMemoryType();

    if (memory_type == MemoryType::Global) {
      alloc_dims = getGlobalAllocationSizes(info);
    } else {
      alloc_dims = getNonGlobalAllocExpr(info);
    }

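    // If every axis was filtered out but the tensor is not fully reduced,
    // allocate a single element so the buffer still exists.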
    if (alloc_dims.size() == 0 &&
        info.buffer->domain()->noReductions().size() != 0) {
      alloc_dims.push_back(info.buffer->container()->oneVal());
    }

    // Multiply the allocation size by the number of buffering stages if
    // double- or circular-buffered. Record the original size for indexing.
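    // For example (hypothetical sizes): with alloc_dims = {4, 8}, the
    // original size 4 * 8 = 32 is recorded for indexing and the dimensions
    // become {4, 8, 2} for a double-buffered tensor, or {4, 8, D} for a
    // circular-buffered tensor of depth D.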
    if (info.buffer->isDoubleBuffered() || info.buffer->isCircularBuffered()) {
      Val* original_alloc_size = nullptr;
      for (auto alloc_dim : alloc_dims) {
        if (original_alloc_size == nullptr) {
          original_alloc_size = alloc_dim;
        } else {
          original_alloc_size =
              IrBuilder::mulExpr(original_alloc_size, alloc_dim);
        }
      }
      GpuLower::current()->doubleBufferInfo().setOriginalAllocSize(
          info.buffer, original_alloc_size);
      int double_buffer_stage = 2;
      if (info.buffer->isCircularBuffered()) {
        double_buffer_stage = info.buffer->circularBufferDepth();
      }
      alloc_dims.push_back(IrBuilder::create<Int>(double_buffer_stage));
    }

    // Create the allocation node
    return IrBuilder::create<kir::Allocate>(
        info.buffer, info.buffer->getMemoryType(), alloc_dims);
  }

  void handle(Expr* expr) override {
    if (!ir_utils::isTvOp(expr) || expr->isA<kir::Allocate>()) {
      ExprMutator::handle(expr);
      return;
    }

    // Found where the allocation needs to be inserted

    for (const auto i : c10::irange(expr->outputs().size())) {
      auto out = expr->output(i);
      if (!out->isA<TensorView>()) {
        continue;
      }

      auto out_tv = out->as<TensorView>();
      auto default_val =
          gpu_lower->predicateElimination().getInitValue(out_tv);

      Val* init = nullptr;
      if (expr->isA<ReductionOp>() && out_tv->hasReduction()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Reduction should not have a default initialization value for predicate elimination.");
        init = expr->as<ReductionOp>()->init();
      } else if (expr->isA<GroupedReductionOp>() && out_tv->hasReduction()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Reduction should not have a default initialization value for predicate elimination.");
        init = expr->as<GroupedReductionOp>()->initVal(i);
      } else if (expr->isA<MmaOp>()) {
        init = expr->as<MmaOp>()->init();
      } else if (expr->isA<WelfordOp>()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Welford should not have a default initialization value for predicate elimination.");
        const auto welford = expr->as<WelfordOp>();
        if (out->name() == welford->outVar()->name()) {
          init = welford->initVar() == nullptr ? IrBuilder::create<Double>(0)
                                               : welford->initVar();
        } else if (out->name() == welford->outAvg()->name()) {
          init = welford->initAvg() == nullptr ? IrBuilder::create<Double>(0)
                                               : welford->initAvg();
        } else {
          TORCH_INTERNAL_ASSERT(
              out->name() == welford->outN()->name(), "Unreachable");
          init = welford->initN();
        }
      } else if (expr->isA<GroupedWelfordOp>()) {
        TORCH_INTERNAL_ASSERT(
            default_val == nullptr,
            "Welford should not have a default initialization value for predicate elimination.");
        init = expr->as<GroupedWelfordOp>()->getInitValOfOutput(out);
      } else if (default_val != nullptr) {
        init = default_val;
      }

      const bool is_output = out->isFusionOutput();

      // Don't need to alloc outputs, and if we don't need to initialize
      // we're done.
      if (is_output && init == nullptr) {
        continue;
      }

      AllocationInformation allocation;
      allocation.buffer = out_tv;
      fillAllocationInformation(allocation, expr);

      auto alloc_expr = createAllocExpr(allocation, is_output);
      auto init_expr = createInitExpr(allocation, init);

      // Write information to GPULower
      writeInfoToGPULower(allocation, alloc_expr);

      // Register allocations before initializations to keep them in the
      // right order
      if (alloc_expr != nullptr) {
        if (allocation.buffer->getMemoryType() == MemoryType::Shared) {
          // Shared allocations go at the beginning of scope
          TORCH_INTERNAL_ASSERT(!exprs_.empty());
          registerInsertBefore(exprs_[0], alloc_expr, nullptr);
        } else {
          TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr);
          kir::Scope* scope = allocation.alloc_for_loop == nullptr
              ? nullptr
              : &allocation.alloc_for_loop->body();
          registerInsertBefore(
              allocation.alloc_place_before, alloc_expr, scope);
        }
      }

      if (init_expr != nullptr) {
        TORCH_INTERNAL_ASSERT(allocation.init_place_before != nullptr);
        kir::Scope* scope = allocation.init_for_loop == nullptr
            ? nullptr
            : &allocation.init_for_loop->body();
        registerInsertBefore(allocation.init_place_before, init_expr, scope);
      }
    }
  }

  // Sends alloc_expr, info.has_halo, info.allocation_domains to GpuLower
  void writeInfoToGPULower(
      const AllocationInformation& allocation,
      kir::Allocate* alloc_expr) {
    auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap();
    if (alloc_expr == nullptr) {
      // Skip output allocation.
      return;
    }
    TORCH_INTERNAL_ASSERT(
        !lower_alloc_info_map.count(alloc_expr),
        "duplicated allocation info entry");

    // Create info entry for GPULower
    auto lower_alloc_info_ptr = std::make_unique<LocalAllocationInfo>();
    lower_alloc_info_ptr->alloc_expr = alloc_expr;
    lower_alloc_info_ptr->has_halo = allocation.has_halo;
    if (allocation.allocation_domains) {
      lower_alloc_info_ptr->alloc_domains = *(allocation.allocation_domains);
    }

    // Write entry to the stored map
    lower_alloc_info_map[alloc_expr] = std::move(lower_alloc_info_ptr);
  }

  void handle(kir::IfThenElse*) final {
    TORCH_INTERNAL_ASSERT(
        false,
        "Pass does not support conditional statements, ",
        "this pass should be run before any conditionals are placed in code.");
  }

  AllocationInserter(const std::vector<Expr*>& exprs)
      : gpu_lower(GpuLower::current()) {
    kir::ExprMutator::traverseAndInsert(exprs);
  }

 private:
  GpuLower* gpu_lower;

 public:
  static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
    AllocationInserter inserter(exprs);
    return inserter.exprs_;
  }
};

} // namespace

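// Entry point of this pass: inserts allocation (and, where an initial value
// is required, initialization) expressions for the tensors produced by the
// given expressions, returning the mutated expression list.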
std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations");
  return AllocationInserter::insert(exprs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch