#include <arith.h>
#include <index_compute.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <lower_index_compute.h>
#include <lower_shift.h>
#include <lower_utils.h>

#include <functional>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
Expr* ShiftPredicateInserter::insert(
    Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    Bool* thread_pred,
    bool within_unswitch) {
  const auto gpu_lower = GpuLower::current();

  TensorView* out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  const bool needs_shift_predicate =
      gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
  if (!needs_shift_predicate) {
    return expr;
  }

  // The conditional branches to create:
  //
  // if (shift_pred) {
  //   consumer = producer;
  // } else {
  //   if (padding_pred) {
  //     consumer = 0;
  //   }
  // }
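  //
  // Informally, shift_pred checks that the access, including the halo
  // region, is in bounds, while padding_pred selects the out-of-range
  // positions that are zero-filled instead of computed.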

  kir::Predicate* thread_pred_expr = nullptr;
  if (within_unswitch) {
    thread_pred_expr = IrBuilder::create<kir::Predicate>(thread_pred);
  }

  kir::Predicate* shift_pred = within_unswitch
      ? thread_pred_expr
      : IrBuilder::create<kir::Predicate>(
            PredicateType::Shift, expr, thread_pred);

  // If the expr involves a thread-block barrier, set the predicate of
  // the expr with shift_pred. Since the expr is not a shift, it is
  // safe to omit the padding.
  if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
    return expr->withPredicate(shift_pred);
  }

  auto shift_ite = IrBuilder::create<kir::IfThenElse>(shift_pred);

  auto& scope = loops.back()->body();

  // Insert the if statement
  scope.insert_before(expr, shift_ite);

  // Remove the expr from the list
  scope.erase(expr);

  // Place the expr inside the if statement
  shift_ite->thenBody().push_back(expr);

  // No padding condition is required if this is within unswitch.
  if (within_unswitch) {
    return expr;
  }

  // Padding by zero
  kir::Predicate* padding_pred = IrBuilder::create<kir::Predicate>(
      PredicateType::Padding, expr, thread_pred);
  auto bounds_ite = IrBuilder::create<kir::IfThenElse>(padding_pred);
  const int pad_value = 0;
  auto pad_expr = IrBuilder::create<UnaryOp>(
      UnaryOpType::Set, out_tv, IrBuilder::create<Int>(pad_value));
  bounds_ite->thenBody().push_back(pad_expr);
  // Insert the else block
  shift_ite->elseBody().push_back(bounds_ite);

  return expr;
}

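// Convention (informal note): widths_[0] holds the halo size at the
// start of the axis and widths_[1] the halo size at the end. For
// example, a producer axis read by both shift(tv, {1}) and
// shift(tv, {-1}) ends up with AxisHaloInfo <1, 1>, so width()
// returns 2.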
int AxisHaloInfo::width() const {
  return width(0) + width(1);
}

int AxisHaloInfo::width(int pos) const {
  TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
  return widths_[pos];
}

void AxisHaloInfo::setWidth(int pos, int width) {
  TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
  widths_[pos] = width;
}

void AxisHaloInfo::merge(int pos, int other) {
  auto new_width = std::max(width(pos), other);
  setWidth(pos, new_width);
}

void AxisHaloInfo::merge(const AxisHaloInfo& other) {
  for (const auto i : c10::irange(widths_.size())) {
    merge(i, other.width(i));
  }
}

bool AxisHaloInfo::hasHalo() const {
  return std::any_of(
      widths_.begin(), widths_.end(), [](auto w) { return w != 0; });
}

std::string AxisHaloInfo::toString() const {
  std::stringstream ss;
  ss << "<" << width(0) << ", " << width(1) << ">";
  return ss.str();
}

bool HaloInfo::hasRootAxisInfo(IterDomain* id) const {
  return root_axis_map_.find(id) != root_axis_map_.end();
}

const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
  // TODO: Enable this check, was failing in many tests
  // TORCH_INTERNAL_ASSERT(
  //     id->definition() == nullptr || id->isRFactorProduct(),
  //     "Invalid IterDomain: ",
  //     id);
  auto it = root_axis_map_.find(id);
  TORCH_INTERNAL_ASSERT(
      it != root_axis_map_.end(),
      "Halo root axis info not found for ",
      id->toString());
  return it->second;
}

void HaloInfo::setRootAxisInfo(
    IterDomain* id,
    const AxisHaloInfo& root_axis_info) {
  root_axis_map_[id] = root_axis_info;

  initializeFromRootAxisInfo(id);
}

HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map)
    // Make a copy of the permissive map for extent comparators
    : permissive_map_(ca_map->idGraph().permissiveNodes()) {
  const auto vals = fusion->usedMathVals();
  auto tvs = ir_utils::filterByType<TensorView>(vals);

  // Initialize all root axis info
  for (auto tv : tvs) {
    for (auto root_axis : tv->getRootDomain()) {
      setRootAxisInfo(root_axis, AxisHaloInfo());
    }
    // Just add placeholders so lookups do not fail. Reduction and
    // rfactor support is not yet in place.
    if (tv->hasRFactor()) {
      for (auto rf_root_axis : tv->getRFactorDomain()) {
        setRootAxisInfo(rf_root_axis, AxisHaloInfo());
      }
    }
  }

  // Propagate backward the halo information of root axes from fusion
  // outputs to inputs
  auto exprs = fusion->exprs();
  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    auto expr = *it;
    if (!expr->outputs()[0]->isA<TensorView>()) {
      continue;
    }

    propagateRootAxisInfo(expr);
  }

  // Propagate halo information from root axes down to leaf axes
  for (auto tv : tvs) {
    build(tv->domain());
  }

  if (isDebugDumpEnabled(DebugDumpOption::Halo)) {
    std::cout << toString() << std::endl;
  }

  // Note that validation requires consumer halo info
  for (auto tv : tvs) {
    validate(tv, ca_map);
  }
}

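// Propagate root-axis halo info from every TensorView output of expr
// to every TensorView input; non-TV inputs and outputs are skipped.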
void HaloInfo::propagateRootAxisInfo(Expr* expr) {
  for (auto output : expr->outputs()) {
    auto out_tv = dynamic_cast<TensorView*>(output);
    if (out_tv == nullptr) {
      continue;
    }
    for (auto input : expr->inputs()) {
      auto in_tv = dynamic_cast<TensorView*>(input);
      if (in_tv == nullptr) {
        continue;
      }
      propagateRootAxisInfo(in_tv, out_tv, expr);
    }
  }
}

void HaloInfo::propagateRootAxisInfo(
    TensorView* producer,
    TensorView* consumer,
    Expr* expr) {
  // Do not add halo to input tensors
  if (producer->isFusionInput()) {
    return;
  }

  auto c2p = PairwiseRootDomainMap(producer, consumer)
                 .mapConsumerToProducer(consumer->domain(), producer->domain());

  const auto& c_root = consumer->getRootDomain();

  for (const auto i : c10::irange(c_root.size())) {
    auto c_id = c_root[i];
    auto it = c2p.find(c_id);
    if (it == c2p.end()) {
      // Nothing to propagate
      continue;
    }

    // Propagate root-axis halo info from c_id to p_id

    auto p_id = it->second;

    AxisHaloInfo p_info;
    if (hasRootAxisInfo(p_id)) {
      p_info = getRootAxisInfo(p_id);
    }
    const auto c_info = getRootAxisInfo(c_id);

    // If the root axes are broadcast, no halo should be associated
    // with them.
    if (c_id->isBroadcast()) {
      TORCH_INTERNAL_ASSERT(!c_info.hasHalo());
      p_info.merge(c_info);
      setRootAxisInfo(p_id, p_info);
      continue;
    } else if (p_id->isRFactorProduct()) {
      TORCH_INTERNAL_ASSERT(
          !c_info.hasHalo(),
          "Propagating halo info to an rfactor producer domain is not yet supported.");
      continue;
    }

    // If the defining expression is a shift, adjust the producer halo
    // width based on the shift offset. If the shift offset is
    // positive, create halo at offset zero of the producer axis so
    // that the consumer can safely access the producer. If the offset
    // is negative, halo is created at the other end of the axis.
    // If the expr is not a shift, just merge the consumer halo info
    // into the producer halo info so that the producer halo can be the
    // maximum of all its consumers.
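    //
    // A small worked example (informal): for a shift with offset == 1
    // on axis i, the producer gets p_info.merge(0, c_info.width(0) + 1),
    // i.e., one extra element of halo at the start. For a gather with
    // window 3 and pad {1, 1} on axis i, the producer halo becomes
    // <c0 + 1, c1 + 3 - 1 - 1> = <c0 + 1, c1 + 1>.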
    if (auto shift_op = dynamic_cast<ShiftOp*>(expr)) {
      const auto offset = shift_op->offset(i);
      if (offset == 0) {
        p_info.merge(c_info);
      } else {
        int pos = (offset > 0) ? 0 : 1;
        p_info.merge(pos, c_info.width(pos) + std::abs(offset));
      }
    } else if (auto gather_op = dynamic_cast<GatherOp*>(expr)) {
      const auto window_dim = gather_op->windowShape()[i];
      if (window_dim == 1) {
        p_info.merge(c_info);
        setRootAxisInfo(p_id, p_info);
        continue;
      }
      const auto pad_dim0 = gather_op->padWidth()[i][0];
      p_info.merge(0, c_info.width(0) + pad_dim0);
      // The right-side halo is propagated as:
      //   consumer_right_halo + (window_dim - 1 - left_padding)
      p_info.merge(1, c_info.width(1) + window_dim - 1 - pad_dim0);
    } else {
      p_info.merge(c_info);
    }
    setRootAxisInfo(p_id, p_info);
  }
}

void HaloInfo::insertToInheritanceMap(
    TensorDomain* td,
    IterDomain* parent,
    IterDomain* child) {
  // Check each root domain to see if its set includes the parent. If
  // so, add the child to the same set.
  bool inserted = false;
  for (auto root_axis : td->getRootDomain()) {
    auto it = inheritance_map_.find(root_axis);
    if (it == inheritance_map_.end()) {
      continue;
    }
    auto& id_set = it->second;
    if (id_set.find(parent) != id_set.end()) {
      id_set.insert(child);
      inserted = true;
    }
  }
  // A matching set must have been found; otherwise the parent is not
  // tracked by any root domain, which should not happen.
  TORCH_INTERNAL_ASSERT(inserted);
}

void HaloInfo::initializeFromRootAxisInfo(IterDomain* id) {
  TORCH_INTERNAL_ASSERT(hasRootAxisInfo(id));

  const auto& halo_info = getRootAxisInfo(id);
  auto halo_width = halo_info.width();

  if (!halo_info.hasHalo()) {
    setHaloWidth(id, 0);
    return;
  }

  auto expanded_extent =
      IrBuilder::addExpr(id->extent(), IrBuilder::create<Int>(halo_width));
  extent_map_[id] = expanded_extent;
  halo_width_map_[id] = halo_width;

  inheritance_map_[id] = {id};
}

void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) {
  halo_width_map_[id] = halo_width;
}

// Propagate extent information from root axes to descendants
void HaloInfo::build(TensorDomain* td) {
  auto exprs = DependencyCheck::getAllExprsBetween(
      {td->getMaybeRFactorDomain().begin(), td->getMaybeRFactorDomain().end()},
      {td->domain().begin(), td->domain().end()});

  // Track IDs that are generated by merging halo-extended IDs
  std::unordered_set<IterDomain*> merged_shifted_ids;

  // Propagate halo information by traversing IterDomain
  // expressions. We populate extent_map_ and halo_width_map_.
  // - extent_map_ maps each axis to a Val* representing its extent
  // including halo. If no mapping exists for a particular axis in
  // extent_map_, the axis does not have halo.
  // - halo_width_map_ just maps to the integer size of the halo,
  // which is used for extent comparison (e.g., extentLessEqual).
  //
  // - When expr is a split: if the halo width of the input axis is
  // zero, both split outputs get zero halo in halo_width_map_, and no
  // mapping is added to extent_map_. Otherwise, the halo is
  // propagated only to the inner output, which gets the same halo
  // width and a mapping in extent_map_.
  //
  // One major assumption here is that splitting an axis that is
  // an output of merging halo-extended axes is not allowed. This is
  // because it is unclear how to split the halo part of the merged
  // axis. This is unlikely to be a real limitation in practice.
  //
  // - When expr is a merge: if either input has halo, a mapping for
  // the output is created in extent_map_, and the output is added to
  // merged_shifted_ids to record that it is a merge output of
  // halo-extended axes. No mapping is created in halo_width_map_
  // (see the comment on HaloInfo::halo_width_map_ in lower_shift.h).
  // If neither input has halo, the output is simply mapped to zero in
  // halo_width_map_.
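  //
  // Worked example (informal): a root axis with extent N and halo
  // <1, 1> has halo width 2 and expanded extent N + 2. Splitting it
  // by a factor of 4 gives an outer output with halo width 0 and an
  // inner output with halo width 2 and expanded extent 4 + 2 = 6.
  // Merging that inner output with a halo-free axis of extent M then
  // yields an output with expanded extent M * 6, an entry in
  // merged_shifted_ids, and no entry in halo_width_map_.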

  for (auto expr : exprs) {
    if (auto split = dynamic_cast<Split*>(expr)) {
      // Merge-then-split of halo-extended IDs is not allowed
      TORCH_INTERNAL_ASSERT(
          merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(),
          "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed");

      auto in_id = split->in();

      // If no halo info is found, nothing needs to be done. This ID
      // must be an ancestor of a domain set by setRootAxisInfo.
      if (!hasHaloWidth(in_id)) {
        continue;
      }

      const auto halo_width = getHaloWidth(in_id);

      if (halo_width == 0) {
        setHaloWidth(split->outer(), 0);
        setHaloWidth(split->inner(), 0);
        continue;
      }

      // Propagate to the inner domain
      auto out_id = split->inner();

      auto expanded_extent =
          SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width);
      extent_map_.insert({out_id, expanded_extent});

      setHaloWidth(split->outer(), 0);
      setHaloWidth(split->inner(), halo_width);

      insertToInheritanceMap(td, in_id, split->inner());
    } else if (auto merge = dynamic_cast<Merge*>(expr)) {
      // If either of the two inputs has halo extension, propagate it
      // to the merged output ID
      auto inner_extent = getExtent(merge->inner());
      auto outer_extent = getExtent(merge->outer());
      if (inner_extent != nullptr || outer_extent != nullptr) {
        if (inner_extent == nullptr) {
          inner_extent = merge->inner()->extent();
        } else {
          insertToInheritanceMap(td, merge->inner(), merge->out());
        }
        if (outer_extent == nullptr) {
          outer_extent = merge->outer()->extent();
        } else {
          insertToInheritanceMap(td, merge->outer(), merge->out());
        }
        auto expanded_extent =
            SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent);
        extent_map_.insert({merge->out(), expanded_extent});
        // Splitting the output of this merge is not allowed, so
        // remember it
        merged_shifted_ids.insert(merge->out());
        // Note that halo_width_map_ is not updated
      } else {
        setHaloWidth(merge->out(), 0);
      }
    } else if (auto swizzle = dynamic_cast<Swizzle2D*>(expr)) {
      // Assume no halo on swizzled domains for now.
      TORCH_INTERNAL_ASSERT(
          getExtent(swizzle->inX()) == nullptr,
          "Halo is not supported with swizzle. Halo-extended ID: ",
          swizzle->inX()->toString(),
          " used in ",
          swizzle->toString());
      TORCH_INTERNAL_ASSERT(
          getExtent(swizzle->inY()) == nullptr,
          "Halo is not supported with swizzle. Halo-extended ID: ",
          swizzle->inY()->toString(),
          " used in ",
          swizzle->toString());
      for (auto id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
        setHaloWidth(id, 0);
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr);
    }
  }
}

//! Restriction 1: When allocation is outside of a shifted
//! axis, the shifted axis must be guaranteed to have a smaller extent
//! than the concrete axis. For now, shifted axes always mean expanded
//! allocations when the axis is located inside the allocation
//! point. This restriction is validated at the allocation lowering
//! pass.
//!
//! Restriction 2: If an expanded axis is parallelized, its memory
//! must be accessible by all other threads. More specifically:
//! - TIDx: It must be on shared memory. May want to consider
//! utilizing the shuffle instructions as well.
//! - BIDx: Not supported. If on global memory, Cooperative Launch
//! could be used to support it; however, it's unclear in what
//! situations block-level parallelization should be used.
//!
//! Other types of parallelization should be supported, except for
//! vectorization. Vectorization should eventually be supported but
//! needs further work.
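//!
//! Example (informal): if an axis of tv is halo-extended, parallelized
//! with TIDx, and tv is consumed by a shift or gather, the validation
//! below requires tv to be allocated on shared memory. If the same
//! axis were parallelized with BIDx instead, validation fails.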
void HaloInfo::validate(
    TensorView* tv,
    std::shared_ptr<const ComputeAtMap> ca_map) const {
  const auto mem_type = tv->getMemoryType();

  for (auto axis : tv->domain()->domain()) {
    auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP);

    // The extent is assumed to be the same
    TORCH_INTERNAL_ASSERT(
        extentEqual(axis, concrete_id),
        "Axis does not have the same exact size as its concrete ID due to halo extension.",
        " Tensor: T",
        tv->name(),
        ", Axis: ",
        axis,
        ", concrete ID: ",
        concrete_id);

    auto halo_extent = getExtent(axis);

    // If no halo extent is associated with this axis, the axis is not
    // extended.
    if (halo_extent == nullptr) {
      continue;
    }

    // Enforce restrictions on parallelization and memory type
    const auto ptype = concrete_id->getParallelType();

    if (ptype == ParallelType::Serial) {
      continue;
    }

    // Only threading parallelism is considered for now
    TORCH_CHECK(
        isParallelTypeThread(ptype), "Unsupported parallel type: ", ptype);

    bool shared_mem_needed = false;
    for (auto use : tv->uses()) {
      if (!ir_utils::isTvOp(use)) {
        continue;
      }
      if (use->isA<ShiftOp>() || use->isA<GatherOp>()) {
        shared_mem_needed = true;
        break;
      }
      auto consumer = use->outputs()[0]->as<TensorView>();
      // Find the corresponding axis in the consumer
      auto it = std::find_if(
          consumer->domain()->domain().begin(),
          consumer->domain()->domain().end(),
          [&](IterDomain* consumer_axis) {
            return ca_map->areMapped(
                axis, consumer_axis, IdMappingMode::PERMISSIVE);
          });
      if (it == consumer->domain()->domain().end()) {
        continue;
      }
      if (!extentEqual(axis, *it)) {
        shared_mem_needed = true;
        break;
      }
    }

    if (!shared_mem_needed) {
      continue;
    }

    if (isParallelTypeThreadDim(ptype)) {
      // If all the consumers have the same extent and none of the
      // expressions is a shift, any memory type would be fine.
      // Otherwise, the buffer must be accessible by all threads
      // involved in the parallelization.
      TORCH_CHECK(
          mem_type == MemoryType::Shared,
          "TV",
          tv->name(),
          " must be allocated on shared memory as its halo-extended axis is parallelized by ",
          ptype);

    } else if (isParallelTypeBlockDim(ptype)) {
      TORCH_CHECK(
          false,
          "Block-based parallelization of a halo-extended axis is not supported: ",
          axis);
    }
  }
}

Val* HaloInfo::getExtent(IterDomain* id) const {
  auto it = extent_map_.find(id);
  if (it != extent_map_.end()) {
    return it->second;
  } else {
    return nullptr;
  }
}

int HaloInfo::getHaloWidth(IterDomain* id) const {
  auto it = halo_width_map_.find(id);
  TORCH_INTERNAL_ASSERT(
      it != halo_width_map_.end(), "Halo width not found for ", id->toString());
  return it->second;
}

bool HaloInfo::hasHaloWidth(IterDomain* id) const {
  return halo_width_map_.find(id) != halo_width_map_.end();
}

const std::unordered_set<IterDomain*>& HaloInfo::getChildDomains(
    IterDomain* root_id) const {
  auto it = inheritance_map_.find(root_id);
  TORCH_INTERNAL_ASSERT(
      it != inheritance_map_.end(),
      "Domain not found in the inheritance map: ",
      root_id);
  return it->second;
}

bool HaloInfo::isHaloInherited(IterDomain* root_id, IterDomain* id) const {
  return getChildDomains(root_id).count(id) > 0;
}

std::unordered_set<IterDomain*> HaloInfo::getRootDomains(IterDomain* id) const {
  std::unordered_set<IterDomain*> id_set;

  for (const auto& kv : inheritance_map_) {
    if (kv.second.count(id) > 0) {
      id_set.insert(kv.first);
    }
  }

  return id_set;
}

namespace {

//! Check whether the comparison operator cmp holds for the extents of
//! id1 and id2, including their halo. The comparison is done
//! conservatively, so false negatives are possible.
//!
//! It is assumed that id1 and id2 are mapped with the CA loop map, so
//! all that is checked here is their halo sizes, using
//! HaloInfo::halo_width_map_. Since that map has no entries for merged
//! axes, the inputs of a merge are compared individually, and the
//! merge output compares true only when both of its inputs do.
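//!
//! Example (informal): if id1 and id2 are loop-mapped axes with halo
//! widths 2 and 3, extentLessEqual(id1, id2) holds while
//! extentEqual(id1, id2) does not.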
template <typename Cmp>
bool extentCompare(
    const HaloInfo& halo_map,
    IterDomain* id1,
    IterDomain* id2,
    Cmp cmp,
    const DisjointSets<IterDomain*>& permissive_map) {
  TORCH_INTERNAL_ASSERT(
      permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare");

  // It's invalid to compare two axes when only one of them has halo.

  if (halo_map.hasHaloWidth(id1)) {
    TORCH_INTERNAL_ASSERT(
        halo_map.hasHaloWidth(id2), "Invalid comparison: ", id1, " and ", id2);
    // Both axes have halo. We assume the axes themselves have equal
    // extents, excluding halo, as they are mapped with the CA
    // map. So, we just need to compare the halo width of each axis.
    return cmp(halo_map.getHaloWidth(id1), halo_map.getHaloWidth(id2));
  } else {
    TORCH_INTERNAL_ASSERT(!halo_map.hasHaloWidth(id2));
    // Neither has halo. The only way this can happen is when both
    // axes are outputs of merge expressions, so each merge input is
    // recursively compared, and true is returned only when both
    // input comparisons hold.
    if (auto merge1 = dynamic_cast<Merge*>(id1->definition())) {
      auto merge2 = dynamic_cast<Merge*>(id2->definition());
      TORCH_INTERNAL_ASSERT(
          merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
      auto inner_cmp = extentCompare(
          halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map);
      auto outer_cmp = extentCompare(
          halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map);
      return inner_cmp && outer_cmp;
    } else {
      // Unreachable: no other case is considered.
      TORCH_INTERNAL_ASSERT(false, "Invalid comparison: ", id1, " and ", id2);
    }
  }
}

} // namespace

bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const {
  return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_);
}

bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const {
  return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_);
}

std::string HaloInfo::toString() const {
  std::stringstream ss;

  ss << "HaloInfo:\n";

  if (root_axis_map_.empty()) {
    return ss.str();
  }

  Fusion* fusion = root_axis_map_.begin()->first->fusion();

  auto used_vals = DependencyCheck::getAllValsBetween(
      {fusion->inputs().begin(), fusion->inputs().end()}, fusion->outputs());

  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    const auto& root = tv->getRootDomain();
    ss << "TV" << tv->name() << " root domain: ";
    for (auto axis : root) {
      ss << axis << " -> " << getRootAxisInfo(axis).toString() << ", ";
    }
    ss << "\n";
  }

  return ss.str();
}

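// Informally, a shift predicate is needed when the output has a
// halo-extended root axis, or when the definition is a shift with a
// non-zero offset or a gather with a non-trivial window on a
// non-broadcast axis.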
bool HaloInfo::needsShiftPredicate(Expr* expr) const {
  // During lowering, shift and gather exprs turn into unary ops, but
  // we really need the original shift or gather expr here, so take a
  // roundabout route and grab it through the output TV's definition:
  auto tv_out = ir_utils::getTvOutput(expr);
  auto consumer_td = tv_out->domain();
  auto shift_expr = dynamic_cast<ShiftOp*>(tv_out->definition());
  auto gather_expr = dynamic_cast<GatherOp*>(tv_out->definition());
  for (const auto i : c10::irange(consumer_td->getRootDomain().size())) {
    auto consumer_id = consumer_td->getRootDomain()[i];
    const auto consumer_halo_info = getRootAxisInfo(consumer_id);
    if (consumer_halo_info.hasHalo() ||
        (shift_expr != nullptr && shift_expr->offset(i) != 0 &&
         !consumer_id->isBroadcast()) ||
        (gather_expr != nullptr && gather_expr->windowShape()[i] != 1 &&
         !consumer_id->isBroadcast())) {
      return true;
    }
  }
  return false;
}

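// Build a map from EXACT concrete IterDomains to their halo-extended
// extents for the given loop indexing. This mirrors HaloInfo::build,
// except that every ID is first resolved to its EXACT concrete ID
// through the CA map.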
std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
    const LoopIndexing& loop_indexing) const {
  // Use a local workspace to avoid re-defining halo info.
  HaloInfo local_halo_info = *GpuLower::current()->haloInfo();

  auto global_halo_info = GpuLower::current()->haloInfo();

  // Setup root:
  for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) {
    auto consumer_index_concrete_id =
        GpuLower::current()->caMap()->getConcreteMappedID(
            consumer_root_id, IdMappingMode::EXACT);
    local_halo_info.setRootAxisInfo(
        consumer_index_concrete_id,
        global_halo_info->getRootAxisInfo(consumer_root_id));
  }

  // Track IDs that are generated by merging halo-extended IDs
  std::unordered_set<IterDomain*> merged_shifted_ids;

  for (auto expr : loop_indexing.getForwardExprList()) {
    if (auto split = dynamic_cast<Split*>(expr)) {
      // Merge-then-split of halo-extended IDs is not allowed
      TORCH_INTERNAL_ASSERT(
          merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(),
          "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed");

      auto in_id = GpuLower::current()->caMap()->getConcreteMappedID(
          split->in(), IdMappingMode::EXACT);

      // If no halo info is found, nothing needs to be done. This ID
      // must be an ancestor of a domain set by setRootAxisInfo.
      if (!local_halo_info.hasHaloWidth(in_id)) {
        continue;
      }

      const auto halo_width = local_halo_info.getHaloWidth(in_id);

      if (halo_width == 0) {
        local_halo_info.setHaloWidth(
            GpuLower::current()->caMap()->getConcreteMappedID(
                split->outer(), IdMappingMode::EXACT),
            0);
        local_halo_info.setHaloWidth(
            GpuLower::current()->caMap()->getConcreteMappedID(
                split->inner(), IdMappingMode::EXACT),
            0);
        continue;
      }

      // Propagate to the inner domain
      auto out_id = GpuLower::current()->caMap()->getConcreteMappedID(
          split->inner(), IdMappingMode::EXACT);

      auto expanded_extent =
          SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width);
      local_halo_info.extent_map_.insert({out_id, expanded_extent});

      local_halo_info.setHaloWidth(
          GpuLower::current()->caMap()->getConcreteMappedID(
              split->outer(), IdMappingMode::EXACT),
          0);
      local_halo_info.setHaloWidth(
          GpuLower::current()->caMap()->getConcreteMappedID(
              split->inner(), IdMappingMode::EXACT),
          halo_width);

      // TODO: add support for the inheritance map
    } else if (auto merge = dynamic_cast<Merge*>(expr)) {
      // If either of the two inputs has halo extension, propagate it
      // to the merged output ID
      auto inner_extent = local_halo_info.getExtent(
          GpuLower::current()->caMap()->getConcreteMappedID(
              merge->inner(), IdMappingMode::EXACT));
      auto outer_extent = local_halo_info.getExtent(
          GpuLower::current()->caMap()->getConcreteMappedID(
              merge->outer(), IdMappingMode::EXACT));
      if (inner_extent != nullptr || outer_extent != nullptr) {
        if (inner_extent == nullptr) {
          inner_extent = merge->inner()->extent();
        }
        if (outer_extent == nullptr) {
          outer_extent = merge->outer()->extent();
        }
        auto expanded_extent =
            SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent);
        local_halo_info.extent_map_.insert(
            {GpuLower::current()->caMap()->getConcreteMappedID(
                 merge->out(), IdMappingMode::EXACT),
             expanded_extent});
        // Splitting the output of this merge is not allowed, so
        // remember it
        merged_shifted_ids.insert(
            GpuLower::current()->caMap()->getConcreteMappedID(
                merge->out(), IdMappingMode::EXACT));
        // Note that halo_width_map_ is not updated
      } else {
        local_halo_info.setHaloWidth(
            GpuLower::current()->caMap()->getConcreteMappedID(
                merge->out(), IdMappingMode::EXACT),
            0);
      }
    } else if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(expr)) {
      // Swizzle with halo is not yet supported; just set the width
      // to zero for now.
      TORCH_INTERNAL_ASSERT(
          local_halo_info.getHaloWidth(
              GpuLower::current()->caMap()->getConcreteMappedID(
                  swizzle_2d->inX(), IdMappingMode::EXACT)) == 0 &&
              local_halo_info.getHaloWidth(
                  GpuLower::current()->caMap()->getConcreteMappedID(
                      swizzle_2d->inY(), IdMappingMode::EXACT)) == 0,
          "Swizzle on ID with halo not yet supported.");
      local_halo_info.setHaloWidth(
          GpuLower::current()->caMap()->getConcreteMappedID(
              swizzle_2d->outX(), IdMappingMode::EXACT),
          0);
      local_halo_info.setHaloWidth(
          GpuLower::current()->caMap()->getConcreteMappedID(
              swizzle_2d->outY(), IdMappingMode::EXACT),
          0);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr);
    }
  }

  return local_halo_info.extent_map_;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch