#include <arith.h>
#include <index_compute.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <predicate_compute.h>

#include <lower_index.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

Val* IndexLowering::lowerSrcIndex(Val* src, Val* dst) const {
  if (auto tv = dynamic_cast<TensorView*>(src)) {
    TORCH_INTERNAL_ASSERT(dst->isA<TensorView>());
    return Index::getProducerIndex(tv, dst->as<TensorView>(), for_loops_);
  } else {
    return src;
  }
}

Val* IndexLowering::lowerDstIndex(Val* dst) const {
  if (auto tv = dynamic_cast<TensorView*>(dst)) {
    return Index::getConsumerIndex(tv, for_loops_);
  } else {
    return dst;
  }
}

void IndexLowering::pushBack(Expr* expr) {
  if (active_scope_ == nullptr) {
    lowered_exprs_.push_back(expr);
  } else {
    active_scope_->push_back(expr);
  }
}

Expr* IndexLowering::back() const {
  if (active_scope_ == nullptr) {
    TORCH_INTERNAL_ASSERT(
        !lowered_exprs_.empty(), "IndexLowering::back: empty scope.");
    return lowered_exprs_.back();
  }
  TORCH_INTERNAL_ASSERT(
      !active_scope_->empty(), "IndexLowering::back: empty scope.");
  return active_scope_->exprs().back();
}

void IndexLowering::insertAtTopLevel(Expr* expr) {
  TORCH_INTERNAL_ASSERT(!lowered_exprs_.empty());
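  // Note: this inserts just before the current last top-level expression,
  // which at this point is the loop nest being lowered, so the inserted
  // allocation ends up ahead of its uses.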
  lowered_exprs_.insert(lowered_exprs_.end() - 1, expr);
}

void IndexLowering::handle(const kir::IfThenElse* ite) {
  const auto prev_scope = active_scope_;

  auto new_ite = IrBuilder::create<kir::IfThenElse>(ite->predicate());
  pushBack(new_ite);

  active_scope_ = &new_ite->thenBody();

  for (auto expr : ite->thenBody().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  active_scope_ = &new_ite->elseBody();

  for (auto expr : ite->elseBody().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  active_scope_ = prev_scope;
}

void IndexLowering::handle(const kir::ForLoop* for_loop) {
  const auto prev_scope = active_scope_;

  auto new_for_loop = IrBuilder::create<kir::ForLoop>(for_loop);
  pushBack(new_for_loop);

  active_scope_ = &new_for_loop->body();
  for_loops_.push_back(new_for_loop);

  for (auto expr : for_loop->body().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  for_loops_.pop_back();
  active_scope_ = prev_scope;
}

void IndexLowering::handle(const RNGOp* rop) {
  // Write random tensor indices into the consumer
  // tensor index if the output is a tensor.
  auto out_tv = dynamic_cast<TensorView*>(rop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "rand scalar not yet supported");

  // TensorIndex for philox subsequence and component.
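  // Each output element gets its own position in the philox counter space,
  // derived from its linear logical index, so the generated values do not
  // depend on how the tensor is parallelized.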
  auto philox_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));

  // TensorIndex for writing rand_like output.
  const auto out = lowerDstIndex(out_tv);

  auto lowered = IrBuilder::create<RNGOp>(
      rop->getRNGOpType(),
      out,
      rop->dtype(),
      rop->getParameters(),
      rop->getRNGOffset(),
      philox_index);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(rop, back());
}

void IndexLowering::handle(const FullOp* fop) {
  auto out_tv = dynamic_cast<TensorView*>(fop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // TensorIndex for writing output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered =
      IrBuilder::create<FullOp>(out, fop->getFillValue(), fop->dtype());

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(fop, back());
}

void IndexLowering::handle(const ARangeOp* aop) {
  // Write linear tensor indices into the consumer
  // tensor index if the output is a tensor.
  auto out_tv = dynamic_cast<TensorView*>(aop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // linear index for computing arange output
  auto linear_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));

  // TensorIndex for writing arange output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered = IrBuilder::create<ARangeOp>(
      out, aop->start(), aop->end(), aop->step(), aop->dtype(), linear_index);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(aop, back());
}

void IndexLowering::handle(const EyeOp* eop) {
  auto out_tv = dynamic_cast<TensorView*>(eop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // per-dimension indices for computing eye output
  auto indices = Index::getPerDimLogicalIndex(out_tv, for_loops_);
  TORCH_INTERNAL_ASSERT(indices.size() == 2);
  auto index1 = indices[0];
  auto index2 = indices[1];

  // TensorIndex for writing eye output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered = IrBuilder::create<EyeOp>(out, eop->dtype(), index1, index2);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(eop, back());
}

void IndexLowering::handle(const UnaryOp* uop) {
  const auto in = lowerSrcIndex(uop->in(), uop->out());
  const auto out = lowerDstIndex(uop->out());
  pushBack(IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, in));
  GpuLower::current()->propagateExprInfo(uop, back());
}

void IndexLowering::handle(const BinaryOp* bop) {
  const auto lhs = lowerSrcIndex(bop->lhs(), bop->out());
  const auto rhs = lowerSrcIndex(bop->rhs(), bop->out());
  const auto out = lowerDstIndex(bop->out());
  pushBack(IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs));
  GpuLower::current()->propagateExprInfo(bop, back());
}

void IndexLowering::handle(const TernaryOp* top) {
  const auto in1 = lowerSrcIndex(top->in1(), top->out());
  const auto in2 = lowerSrcIndex(top->in2(), top->out());
  const auto in3 = lowerSrcIndex(top->in3(), top->out());
  const auto out = lowerDstIndex(top->out());
  pushBack(IrBuilder::create<TernaryOp>(
      top->getTernaryOpType(), out, in1, in2, in3));
  GpuLower::current()->propagateExprInfo(top, back());
}

void IndexLowering::handle(const ViewAsScalar* uop) {
  const auto in = lowerSrcIndex(uop->in(), uop->out());
  const auto out = lowerDstIndex(uop->out());
  for (auto loop : for_loops_) {
    if (GpuLower::current()->caMap()->areMapped(
            loop->iter_domain(),
            uop->vector_id()->as<IterDomain>(),
            IdMappingMode::LOOP)) {
      Val* index = loop->index();
      pushBack(
          IrBuilder::create<ViewAsScalar>(out, in, uop->vector_id(), index));
      GpuLower::current()->propagateExprInfo(uop, back());
      return;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Cannot find index for vector dim");
}

namespace {

struct GridCommWorkBufferSizeInfo {
  // Size of overall buffer. Can be expanded for privatization
  Val* size_of_privatized_buffer = nullptr;
  // Size of single buffer.
  Val* buffer_stride = nullptr;
};

// Get the size of the temporary work buffer for grid communication. This is
// used for grid reduction, grid broadcast, and grid welford.
// The buffer is expanded for privatization when the op is not persistent;
// in persistent mode, only grouped IterDomains are privatized.
GridCommWorkBufferSizeInfo getGridCommWorkBufferSize(
    const TensorDomain* td,
    const std::vector<kir::ForLoop*>& for_loops,
    bool is_persistent) {
  // The buffer size is the number of thread blocks multiplied by the
  // number of threads not used for reduction domains.
  // Note: Previously it was calculated based on the shape of the
  // tensor, but it makes more sense to compute the size based on the
  // shape of the thread block and grid since this buffer is used for
  // communication among them. Both methods should result in the same
  // size if the parallel dimensions are exact, but otherwise, just
  // computing the buffer size based on the tensor shape isn't
  // sufficient since there could be extra threads/blocks.
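  //
  // Illustrative sketch (hypothetical launch configuration): with
  // blockDim.x = 256 bound to a reduction domain of td and gridDim.x = 128,
  // TIDx is skipped below and only BIDx contributes, so the single-buffer
  // size would be 128 elements.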
  Val* size_of_single_buffer = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeThreads) {
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    if (isParallelTypeThreadDim(pt) &&
        std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) {
          return out_id->getParallelType() == pt &&
              (out_id->isReduction() || out_id->isBroadcast());
        })) {
      continue;
    }
    size_of_single_buffer =
        SimplifyingIrBuilder::mulExpr(size_of_single_buffer, pt_dim);
  }

  // Expand the buffer for privatization. The buffer is expanded so
  // that each non-reduction IterDomain uses a different part of the
  // buffer. For persistent mode, this expansion is only done for
  // grouped IterDomains.
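  //
  // Illustrative sketch (hypothetical schedule): a non-persistent grid
  // reduction nested inside a serial for-loop of extent 4 gets a buffer
  // four times the single-buffer size, so each loop iteration writes to
  // its own slice of the work buffer.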

  Val* size_of_privatized_buffer = size_of_single_buffer;

  // In persistent mode, if a non-grouped, non-reduction domain is used,
  // double the buffer size to save a final grid sync
  bool is_doubled = false;

  for (auto fl : for_loops) {
    // Buffer sizes of parallelized domains are already taken care of
    if (fl->isTrivial() || fl->iter_domain()->isReduction() ||
        fl->iter_domain()->isThread()) {
      continue;
    }
    // If persistent, i.e., allreduce, only IterDomains with
    // ParallelType::Group are privatized
    if (!is_persistent ||
        fl->iter_domain()->getParallelType() == ParallelType::Group) {
      size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr(
          size_of_privatized_buffer, fl->iter_domain()->extent());
    } else if (is_persistent) {
      is_doubled = true;
    }
  }

  if (is_doubled) {
    size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr(
        size_of_privatized_buffer, IrBuilder::create<Int>(2));
  }

  GridCommWorkBufferSizeInfo info;
  info.size_of_privatized_buffer = size_of_privatized_buffer;
  info.buffer_stride = size_of_single_buffer;
  if (is_doubled) {
    info.buffer_stride = SimplifyingIrBuilder::mulExpr(
        info.buffer_stride, IrBuilder::create<Int>(2));
  }

  return info;
}

Val* getGridSyncBufferSize(
    const TensorDomain* td,
    const std::vector<kir::ForLoop*>& for_loops,
    bool is_persistent) {
  // See the comment above for getGridCommWorkBufferSize.
  Val* buffer_size = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeBIDs) {
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    if (std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) {
          return out_id->getParallelType() == pt &&
              (out_id->isReduction() || out_id->isBroadcast());
        })) {
      continue;
    }
    buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
  }

  // If not persistent, all iteration domains require a separate
  // semaphore for re-entrant grid reductions
  if (!is_persistent) {
    for (auto fl : for_loops) {
      if (fl->isTrivial()) {
        continue;
      }
      if (fl->iter_domain()->isThread()) {
        // already accounted for.
        continue;
      }

      buffer_size = SimplifyingIrBuilder::mulExpr(
          buffer_size, fl->iter_domain()->extent());
    }
  }

  return buffer_size;
}

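// Number of times the enclosing serial loops cause the grid reduction to be
// entered: the product of the extents of all non-trivial, non-parallelized
// for loops surrounding it.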
Val* getEntranceCountGridReduce(std::vector<kir::ForLoop*>& for_loops) {
  Val* grid_reduction_entrances = GpuLower::current()->kernel()->oneVal();

  for (const auto loop : for_loops) {
    if (loop->isTrivial()) {
      continue;
    }
    if (loop->iter_domain()->isThread()) {
      // already accounted for.
      continue;
    }
    // TODO: Does this work for shift/gather?
    grid_reduction_entrances = SimplifyingIrBuilder::mulExpr(
        grid_reduction_entrances, loop->iter_domain()->extent());
  }
  return grid_reduction_entrances;
}

// Linear indexing of for loops for multiple entrances into grid reduce
// TODO: What happens if there's a broadcast that's resolved (not present in the
// grid reduce) but the global buffer isn't expanded?
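//
// Illustrative sketch: for two enclosing serial loops with indices i0, i1 and
// extents e0, e1, the fold below produces i0 * e1 + i1, i.e., a row-major
// linearization of the loop indices.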
Val* getEntranceLinIndGridReduce(std::vector<kir::ForLoop*>& for_loops) {
  Val* linear_index = GpuLower::current()->kernel()->zeroVal();

  for (const auto loop : for_loops) {
    if (loop->isTrivial()) {
      continue;
    }
    if (loop->iter_domain()->isThread()) {
      // already accounted for.
      continue;
    }
    // TODO: Does this work for shift/gather?
    linear_index = SimplifyingIrBuilder::addExpr(
        SimplifyingIrBuilder::mulExpr(
            linear_index, loop->iter_domain()->extent()),
        loop->index());
  }
  return linear_index;
}

} // namespace

void IndexLowering::handle(const ReductionOp* rop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop));

  const auto out_tv = rop->out()->as<TensorView>();
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  const auto out = lowerDstIndex(rop->out());
  const auto in = lowerSrcIndex(rop->in(), rop->out());

  if (has_grid_reduce) {
    handleGridReduction(rop, out, in);
  } else if (has_block_reduce) {
    handleBlockReduction(rop, out, in);
  } else {
    pushBack(
        IrBuilder::create<BinaryOp>(rop->getReductionOpType(), out, out, in));
    GpuLower::current()->propagateExprInfo(rop, back());
  }
}

void IndexLowering::handleBlockReduction(
    const ReductionOp* rop,
    Val* out,
    Val* in) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop));

  ReductionOp* indexed_rop = IrBuilder::create<ReductionOp>(
      rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce());
  if (rop->predicate()) {
    indexed_rop =
        indexed_rop->withPredicate(rop->predicate())->as<ReductionOp>();
  }
  if (rop->writePredicate()) {
    indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate())
                      ->as<ReductionOp>();
  }

  pushBack(indexed_rop);
  GpuLower::current()->propagateExprInfo(rop, back());
}

void IndexLowering::handleGridReduction(
    const ReductionOp* rop,
    Val* out,
    Val* in) {
  const auto out_tv = out->as<kir::TensorIndex>()->view();
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction. ",
      rop->toString());

  // Use a unique buffer for the work and sync flags when called within a
  // loop, unless the reduction is persistent. A grid allreduce implies
  // persistence. However, not being a grid allreduce does not imply
  // non-persistence: currently, if a cooperative grid reduction is
  // required anywhere in the kernel, all grid reductions are done in a
  // persistent manner, so all grid reductions should be consulted.
  // TODO: fix this
  const bool is_persistent = rop->isAllreduce();
  const auto buffer_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  auto work_buffer = allocateUniqueBuffer(
      buffer_size_info.size_of_privatized_buffer,
      out_tv->dtype(),
      false,
      out_tv,
      work_buffer_map_);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_reduction = IrBuilder::create<kir::GridReduction>(
      rop->getReductionOpType(),
      rop->init(),
      out,
      in,
      work_buffer,
      sync_buffer,
      entrance_ind,
      n_entrances,
      rop->isAllreduce());

  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);

  if (rop->predicate()) {
    grid_reduction = grid_reduction->withPredicate(rop->predicate())
                         ->as<kir::GridReduction>();
  }
  if (rop->writePredicate()) {
    grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate())
                         ->as<kir::GridReduction>();
  }

  pushBack(grid_reduction);
  GpuLower::current()->propagateExprInfo(rop, back());

  if (rop->isAllreduce()) {
    allocateUniqueFusedReduction(grid_reduction, out_tv);
  }
}

void IndexLowering::handle(const GroupedReductionOp* grouped_rop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_rop));

  const auto out_tv = ir_utils::getTvOutput(grouped_rop);
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  std::vector<Val*> indexed_outputs(grouped_rop->numExprs());
  std::vector<Val*> indexed_inputs(grouped_rop->numExprs());

  for (const auto i : c10::irange(grouped_rop->numExprs())) {
    indexed_outputs.at(i) = lowerDstIndex(grouped_rop->output(i));
    indexed_inputs.at(i) =
        lowerSrcIndex(grouped_rop->input(i), grouped_rop->output(i));
  }

  if (has_grid_reduce) {
    handleGridReduction(grouped_rop, indexed_outputs, indexed_inputs);
  } else if (has_block_reduce) {
    handleBlockReduction(grouped_rop, indexed_outputs, indexed_inputs);
  } else {
    for (const auto i : c10::irange(grouped_rop->numExprs())) {
      pushBack(IrBuilder::create<BinaryOp>(
          grouped_rop->getReductionOpType(i),
          indexed_outputs.at(i),
          indexed_outputs.at(i),
          indexed_inputs.at(i)));
    }
  }
}

void IndexLowering::handleBlockReduction(
    const GroupedReductionOp* grouped_rop,
    const std::vector<Val*>& outputs,
    const std::vector<Val*>& inputs) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_rop));

  GroupedReductionOp* indexed_rop = IrBuilder::create<GroupedReductionOp>(
      grouped_rop->getReductionOpTypes(),
      grouped_rop->initVals(),
      outputs,
      inputs,
      grouped_rop->isAllreduce());
  if (grouped_rop->predicate()) {
    indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate())
                      ->as<GroupedReductionOp>();
  }
  if (grouped_rop->writePredicate()) {
    indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate())
                      ->as<GroupedReductionOp>();
  }

  pushBack(indexed_rop);
  GpuLower::current()->propagateExprInfo(grouped_rop, back());
}

void IndexLowering::handleGridReduction(
    const GroupedReductionOp* grouped_rop,
    const std::vector<Val*>& outputs,
    const std::vector<Val*>& inputs) {
  const auto out_tv = ir_utils::getTvOutput(grouped_rop);
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction.");

  const bool is_persistent = grouped_rop->isAllreduce();
  auto work_buf_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  std::vector<kir::Allocate*> work_buffers;
  std::transform(
      outputs.begin(),
      outputs.end(),
      std::back_inserter(work_buffers),
      [&](Val* output) {
        return allocateUniqueBuffer(
            work_buf_size_info.size_of_privatized_buffer,
            output->dtype(),
            false,
            output->as<kir::TensorIndex>()->view(),
            work_buffer_map_);
      });

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_reduction = IrBuilder::create<kir::GroupedGridReduction>(
      grouped_rop->getReductionOpTypes(),
      grouped_rop->initVals(),
      outputs,
      inputs,
      work_buffers,
      sync_buffer,
      entrance_ind,
      n_entrances,
      work_buf_size_info.buffer_stride,
      grouped_rop->isAllreduce());

  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);

  if (grouped_rop->predicate()) {
    grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate())
                         ->as<kir::GroupedGridReduction>();
  }
  if (grouped_rop->writePredicate()) {
    grid_reduction =
        grid_reduction->withWritePredicate(grouped_rop->writePredicate())
            ->as<kir::GroupedGridReduction>();
  }

  pushBack(grid_reduction);
  GpuLower::current()->propagateExprInfo(grouped_rop, back());

  if (grouped_rop->isAllreduce()) {
    allocateUniqueFusedReduction(grid_reduction, out_tv);
  }
}

void IndexLowering::handle(const WelfordOp* wop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(wop));

  const auto out_tv = wop->outAvg()->as<TensorView>();
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  if (has_grid_reduce) {
    TORCH_INTERNAL_ASSERT(
        std::none_of(
            out_domain->domain().begin(),
            out_domain->domain().end(),
            [](IterDomain* id) {
              return !id->isThread() && id->isReduction();
            }),
        "Found a reduction stage that has both a non-parallelized ",
        "reduction and a grid reduction. This is not supported, ",
        "please use rfactor to do the serialized reduction first, ",
        "then the grid reduction.");
  }

  // lower IO tensors
  const auto in_var =
      wop->inVar() ? lowerSrcIndex(wop->inVar(), wop->outAvg()) : nullptr;
  const auto in_avg = lowerSrcIndex(wop->inAvg(), wop->outAvg());
  auto in_N = wop->inN();

  // In the rfactor-ed case, the input N is actually a TV
  if (!in_N->isScalar()) {
    in_N = lowerSrcIndex(in_N, wop->outN());
  }

  auto out_avg = lowerDstIndex(wop->outAvg());
  auto out_var = lowerDstIndex(wop->outVar());
  auto out_N = lowerDstIndex(wop->outN());

  WelfordOp* indexed_wop = IrBuilder::create<WelfordOp>(
      out_avg,
      out_var,
      out_N,
      in_avg,
      in_var,
      in_N,
      wop->initAvg(),
      wop->initVar(),
      wop->initN(),
      wop->isAllreduce());

  if (wop->predicate()) {
    indexed_wop = indexed_wop->withPredicate(wop->predicate())->as<WelfordOp>();
  }
  if (wop->writePredicate()) {
    indexed_wop =
        indexed_wop->withWritePredicate(wop->writePredicate())->as<WelfordOp>();
  }

  // Serial welford
  if (!has_block_reduce && !has_grid_reduce) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(wop, back());
    return;
  }

  // Block-only welford
  if (!has_grid_reduce) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(wop, back());
    return;
  }

  handleGridWelford(indexed_wop);
}

void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
  const auto out_tv = indexed_wop->out()->as<kir::TensorIndex>()->view();
  const auto out_domain = out_tv->domain();

  // TODO: See the comment on the same variable in handleGridReduction
  const bool is_persistent = indexed_wop->isAllreduce();
  const auto buffer_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer;
  auto out_avg_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outAvg()->dtype(),
      false,
      indexed_wop->outAvg()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);
  auto out_var_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outVar()->dtype(),
      false,
      indexed_wop->outVar()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);
  auto out_N_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outN()->dtype(),
      false,
      indexed_wop->outN()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_welford = IrBuilder::create<kir::GridWelford>(
      indexed_wop,
      out_var_buffer,
      out_avg_buffer,
      out_N_buffer,
      sync_buffer,
      entrance_ind,
      n_entrances);

  grid_welford = grid_welford->withThreadPredicate(thread_pred);

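  // When the block-level reduction is not fused into the grid welford (i.e.,
  // this is not an allreduce), the block-level welford op is emitted
  // separately below and carries the original predicate; the grid welford is
  // then given an always-true predicate, since the original predicate has
  // already been applied by the block-level op.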
  const bool block_reduce_separated =
      out_domain->hasBlockReduction() && !indexed_wop->isAllreduce();

  if (indexed_wop->predicate()) {
    if (block_reduce_separated) {
      grid_welford = grid_welford
                         ->withPredicate(IrBuilder::create<kir::Predicate>(
                             GpuLower::current()->kernel()->trueVal()))
                         ->as<kir::GridWelford>();
    } else {
      grid_welford = grid_welford->withPredicate(indexed_wop->predicate())
                         ->as<kir::GridWelford>();
    }
  }

  if (indexed_wop->writePredicate()) {
    grid_welford =
        grid_welford->withWritePredicate(indexed_wop->writePredicate())
            ->as<kir::GridWelford>();
  }

  if (block_reduce_separated) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(indexed_wop, back());
  }

  pushBack(grid_welford);
  GpuLower::current()->propagateExprInfo(indexed_wop, back());

  if (indexed_wop->isAllreduce()) {
    // When using the fused reduction, allocate the reduction object at
    // the outer-most scope
    allocateUniqueFusedReduction(grid_welford, out_tv);
  }
}

void IndexLowering::handle(const GroupedWelfordOp* grouped_wop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_wop));

  const auto out_tv = ir_utils::getTvOutput(grouped_wop);
  const auto out_domain = out_tv->domain();

  const bool has_grid_reduce = out_domain->hasGridReduction();

  std::vector<WelfordTriplet> indexed_outputs(grouped_wop->numExprs());
  std::vector<WelfordTriplet> indexed_inputs(grouped_wop->numExprs());

  for (const auto i : c10::irange(grouped_wop->numExprs())) {
    const auto& output = grouped_wop->outputVals().at(i);
    const auto& input = grouped_wop->inputVals().at(i);
    WelfordTriplet indexed_output;
    WelfordTriplet indexed_input;
    for (const auto j : c10::irange(3)) {
      indexed_output.get(j) = lowerDstIndex(output.get(j));
      indexed_input.get(j) = lowerSrcIndex(input.get(j), output.get(j));
    }
    indexed_outputs[i] = indexed_output;
    indexed_inputs[i] = indexed_input;
  }

  if (has_grid_reduce) {
    handleGroupedGridWelford(
        grouped_wop, indexed_outputs, indexed_inputs, grouped_wop->initVals());
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Only grid welford is supported. Validation should have caught non-grid welford grouping.");
  }
}

std::vector<kir::Allocate*> IndexLowering::allocateWelfordWorkBuffer(
    const std::vector<WelfordTriplet>& triplets,
    WelfordTriplet::ValName name,
    Val* buffer_size) {
  std::vector<kir::Allocate*> work_buffers;

  std::transform(
      triplets.begin(),
      triplets.end(),
      std::back_inserter(work_buffers),
      [&](const WelfordTriplet& output) {
        return allocateUniqueBuffer(
            buffer_size,
            output.get(name)->dtype(),
            false,
            output.get(name)->as<TensorView>(),
            work_buffer_map_);
      });

  return work_buffers;
}

void IndexLowering::handleGroupedGridWelford(
    const GroupedWelfordOp* op,
    const std::vector<WelfordTriplet>& output_vals,
    const std::vector<WelfordTriplet>& input_vals,
    const std::vector<WelfordTriplet>& init_vals) {
  const auto out_tv = ir_utils::getTvOutput(op);
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction.");

  const bool is_persistent = op->isAllreduce();
  auto work_buf_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  const auto work_buffers_avg = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::Avg,
      work_buf_size_info.size_of_privatized_buffer);
  const auto work_buffers_var = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::Var,
      work_buf_size_info.size_of_privatized_buffer);
  const auto work_buffers_N = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::N,
      work_buf_size_info.size_of_privatized_buffer);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate needs to be set separately from the main
  // predicate. Do not combine them like other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto indexed_op = IrBuilder::create<kir::GroupedGridWelford>(
      output_vals,
      input_vals,
      init_vals,
      std::array<std::vector<kir::Allocate*>, 3>{
          work_buffers_avg, work_buffers_var, work_buffers_N},
      sync_buffer,
      entrance_ind,
      n_entrances,
      work_buf_size_info.buffer_stride,
      op->isAllreduce());

  indexed_op = indexed_op->withThreadPredicate(thread_pred);

  if (op->predicate()) {
    indexed_op = indexed_op->withPredicate(op->predicate())
                     ->as<kir::GroupedGridWelford>();
  }
  if (op->writePredicate()) {
    indexed_op = indexed_op->withWritePredicate(op->writePredicate())
                     ->as<kir::GroupedGridWelford>();
  }

  pushBack(indexed_op);
  GpuLower::current()->propagateExprInfo(op, back());

  if (op->isAllreduce()) {
    allocateUniqueFusedReduction(indexed_op, out_tv);
  }
}

void IndexLowering::handle(const LoadStoreOp* ldst) {
  const auto in = lowerSrcIndex(ldst->in(), ldst->out());
  const auto out = lowerDstIndex(ldst->out());
  auto new_ldst = IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in)
                      ->withPredicate(ldst->predicate());
  pushBack(new_ldst);
  GpuLower::current()->propagateExprInfo(ldst, back());
}

void IndexLowering::handle(const MmaOp* mma) {
  const auto a = lowerSrcIndex(mma->inA(), mma->out());
  const auto b = lowerSrcIndex(mma->inB(), mma->out());
  const auto out = lowerDstIndex(mma->out());
  auto mma_indexed =
      IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->options());
  pushBack(mma_indexed);
  GpuLower::current()->propagateExprInfo(mma, back());
}

void IndexLowering::handle(const BroadcastOp* bop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(bop));

  const auto out_tv = bop->out()->as<TensorView>();

  const auto out = lowerDstIndex(bop->out());
  const auto in = lowerSrcIndex(bop->in(), bop->out());
  auto indexed_expr =
      IrBuilder::create<BroadcastOp>(out, in, bop->getBroadcastDimFlags());

  const ParallelTypeBitmap parallel_bitmap =
      GpuLower::current()->threadPredMap().getParallelBroadcastDomains(out_tv);

  const bool block_x = parallel_bitmap.get(ParallelType::BIDx);
  const bool block_y = parallel_bitmap.get(ParallelType::BIDy);
  const bool block_z = parallel_bitmap.get(ParallelType::BIDz);

  if (bop->predicate()) {
    indexed_expr =
        indexed_expr->withPredicate(bop->predicate())->as<BroadcastOp>();
  }

  const bool grid_broadcast_needed = block_x || block_y || block_z;
  if (!grid_broadcast_needed) {
    pushBack(indexed_expr);
    GpuLower::current()->propagateExprInfo(bop, back());
    return;
  }

  // Grid broadcast
  const auto out_domain = out_tv->domain();
  const auto work_buffer_size =
      getGridCommWorkBufferSize(out_domain, for_loops_, true)
          .size_of_privatized_buffer;

  auto work_buffer = allocateUniqueBuffer(
      work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_);

  auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  auto grid_broadcast = IrBuilder::create<kir::GridBroadcast>(
      indexed_expr, work_buffer, sync_buffer);

  if (bop->predicate()) {
    grid_broadcast = grid_broadcast->withPredicate(bop->predicate())
                         ->as<kir::GridBroadcast>();
  }

  pushBack(grid_broadcast);
  GpuLower::current()->propagateExprInfo(bop, back());
}

void IndexLowering::handle(const kir::Allocate* allocate) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::Allocate*>(allocate)); // NOLINT
}

void IndexLowering::handle(const kir::BlockSync* sync) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::BlockSync*>(sync)); // NOLINT
}

void IndexLowering::handle(const kir::GridSync* sync) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::GridSync*>(sync)); // NOLINT
}

void IndexLowering::handle(const kir::CpAsyncWait* wait) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::CpAsyncWait*>(wait)); // NOLINT
}

void IndexLowering::handle(const kir::CpAsyncCommit* commit) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::CpAsyncCommit*>(commit)); // NOLINT
}

void IndexLowering::generate(const std::vector<Expr*>& exprs) {
  for (auto expr : exprs) {
    OptOutConstDispatch::handle(expr);
  }
}

kir::Allocate* IndexLowering::allocateUniqueBuffer(
    Val* buffer_size,
    DataType dtype,
    bool zero_init,
    TensorView* out_tv,
    std::unordered_map<TensorView*, kir::Allocate*>& alloc_map) {
  // Return an existing allocation if one exists
  auto it = alloc_map.find(out_tv);
  if (it != alloc_map.end()) {
    return it->second;
  }

  // No existing allocation found. Create a new one.
  auto new_buffer =
      lower_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);

  // Keep track of the allocation
  alloc_map.emplace(out_tv, new_buffer);

  // A buffer may be used in both of the unswitched paths, so it must be
  // placed outside of the current scope. Simply placing it at the
  // top-level scope should work.
  insertAtTopLevel(new_buffer);

  return new_buffer;
}

void IndexLowering::allocateUniqueFusedReduction(
    Expr* expr,
    TensorView* out_tv) {
  auto it = fused_reduction_map_.find(out_tv);
  if (it != fused_reduction_map_.end()) {
    return;
  }

  kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr;
  switch (expr->getExprType().value()) {
    case ExprType::GridReduction:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GridReduction>());
      break;
    case ExprType::GridWelford:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GridWelford>());
      break;
    case ExprType::GroupedGridReduction:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GroupedGridReduction>());
      break;
    case ExprType::GroupedGridWelford:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GroupedGridWelford>());
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
  }

  fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction);

  // When using the fused reduction, allocate the reduction object at
  // the outer-most scope
  insertAtTopLevel(fused_reduction_alloc_reduction);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch