1#include <index_compute.h>
2
3#include <c10/util/Exception.h>
4#include <c10/util/irange.h>
5#include <arith.h>
6#include <contiguity.h>
7#include <expr_evaluator.h>
8#include <instrumentation.h>
9#include <ir_all_nodes.h>
10#include <ir_iostream.h>
11#include <ir_utils.h>
12#include <kernel_expr_evaluator.h>
13#include <lower2device.h>
14#include <lower_double_buffer.h>
15#include <lower_index_compute.h>
16#include <lower_magic_zero.h>
17#include <lower_shift.h>
18#include <lower_unroll.h>
19#include <lower_utils.h>
20#include <lower_validation.h>
21#include <root_domain_map.h>
22#include <transform_iter.h>
23#include <transform_replay.h>
24
25namespace torch {
26namespace jit {
27namespace fuser {
28namespace cuda {
29
30namespace {
31
32//! Offset of an index of a producer axis with respect to its
33//! corresponding consumer index
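//!
//! Informal example (the halo widths here are hypothetical): if the producer
//! root axis has a start pad of 2 and the matching consumer axis a start pad
//! of 1, the producer index is offset by 2 - 1 = 1; if the consumer is
//! additionally defined by a ShiftOp, that op's per-axis offset is
//! subtracted on top of the halo difference.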
34int getProducerHaloOffset(
35 const TensorView* producer_tv,
36 size_t producer_axis,
37 const TensorView* consumer_tv) {
38 auto p2c =
39 PairwiseRootDomainMap(producer_tv, consumer_tv)
40 .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());
41
42 auto producer_id = producer_tv->getMaybeRFactorDomain()[producer_axis];
43
44 auto it = p2c.find(producer_id);
45 // p2c should always have a mapping for producer_id. The only case
46 // where no mapping exists for a producer axis is when it is a
47 // reduction axis. Since this function is only used for indexing
48 // producer tensors, where reduction axes are skipped, producer_id
49 // should never be a reduction axis.
50 TORCH_INTERNAL_ASSERT(it != p2c.end());
51 IterDomain* consumer_id = it->second;
52
53 const auto& halo_map = GpuLower::current()->haloInfo();
54 const auto p_pad = halo_map->getRootAxisInfo(producer_id).width(0);
55 const auto c_pad = halo_map->getRootAxisInfo(consumer_id).width(0);
56
57 auto offset = p_pad - c_pad;
58
59 // If the consumer is a result of shifting the producer, adjust the
60 // producer index per the offsets argument of the shift op.
61 if (auto shift_op = dynamic_cast<const ShiftOp*>(consumer_tv->definition())) {
62 offset -= shift_op->offset(producer_axis);
63 }
64
65 return offset;
66}
67
68//! Offset producer index when necessary
69Val* getProducerIndexWithHalo(
70 const TensorView* producer_tv,
71 size_t producer_axis,
72 Val* producer_index,
73 const TensorView* consumer_tv) {
74 const auto offset =
75 getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
76
77 if (offset == 0) {
78 return producer_index;
79 }
80
81 producer_index = SimplifyingIrBuilder::addExpr(producer_index, offset);
82
83 return producer_index;
84}
85
86//! Create a producer offset based off a consumer index
87//!
88//! \param consumer_root_axis Position of corresponding consumer axis
89//! \param consumer_tv Consumer TensorView
90//! \param index_map Mappings from consumer or reference to indices
91//! \param use_reference_map True when index_map maps reference domains
92//! \param concrete_to_ref_map Mappings from concrete to reference domains
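//!
//! Informal example (hypothetical shapes): for a gather along this axis with
//! a window of extent 3 and padWidth {1, 1}, a window index w yields the
//! offset w - 1, so adding it to the producer index visits positions
//! i - 1, i, and i + 1 of the producer as w goes from 0 to 2.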
93Val* getProducerOffsetWithGather(
94 size_t consumer_root_axis,
95 const TensorView* consumer_tv,
96 const std::unordered_map<IterDomain*, Val*>& index_map,
97 bool use_reference_map = false,
98 const std::unordered_map<IterDomain*, IterDomain*>& concrete_to_ref_map =
99 {}) {
100 const auto gpu_lower = GpuLower::current();
101
102 const auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());
103
104 if (gather_expr == nullptr) {
105 return gpu_lower->kernel()->zeroVal();
106 }
107
108 // If the window extent is one, no specific offsetting
109 // is necessary
110 if (consumer_root_axis >= gather_expr->windowShape().size() ||
111 gather_expr->windowShape()[consumer_root_axis] == 1) {
112 return gpu_lower->kernel()->zeroVal();
113 }
114
115 // Basically, the goal is to build an expression of producer_index +
116 // window_index, so we first need to locate the index expression
117 // that corresponds to the window axis of this producer axis.
118
119 const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
120 auto window_id = consumer_tv->getRootDomain().at(window_axis);
121
122 // When index_map maps a reference tensor, find the corresponding
123 // reference ID of window_id.
124 if (use_reference_map) {
125 auto concrete_window_id = gpu_lower->caMap()->getConcreteMappedID(
126 window_id, IdMappingMode::EXACT);
127 auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id);
128 TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end());
129 window_id = concrete_2_ref_it->second;
130 }
131
132 auto window_idx = index_map.at(window_id);
133
134  // Positive padding at offset zero means the indexing is shifted in the
135  // negative direction.
136 auto pad_width = gather_expr->padWidth()[consumer_root_axis][0];
137
138 // producer offset: window_index - padding
139 auto producer_offset = SimplifyingIrBuilder::subExpr(
140 window_idx, SimplifyingIrBuilder::create<Int>(pad_width));
141 return producer_offset;
142}
143
144//! Create a producer offset based off a consumer index
145//!
146//! \param consumer_root_axis Position of corresponding consumer axis
147//! \param consumer_tv Consumer TensorView
148//! \param index_map Mappings from consumer or concrete domains to indices
149//! \param use_concrete_map True when index_map maps concrete domains
151Val* getConcreteProducerOffsetWithGather(
152 size_t consumer_root_axis,
153 const TensorView* consumer_tv,
154 const std::unordered_map<IterDomain*, Val*>& index_map,
155 bool use_concrete_map = false) {
156 const auto gpu_lower = GpuLower::current();
157
158 const auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());
159
160 if (gather_expr == nullptr) {
161 return gpu_lower->kernel()->zeroVal();
162 }
163
164 // If the window extent is one, no specific offsetting
165 // is necessary
166 if (consumer_root_axis >= gather_expr->windowShape().size() ||
167 gather_expr->windowShape()[consumer_root_axis] == 1) {
168 return gpu_lower->kernel()->zeroVal();
169 }
170
171 // Basically, the goal is to build an expression of producer_index +
172 // window_index, so we first need to locate the index expression
173 // that corresponds to the window axis of this producer axis.
174
175 const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
176 auto window_id = consumer_tv->getRootDomain().at(window_axis);
177
178 Val* window_idx = nullptr;
179
180 if (use_concrete_map) {
181 window_idx = index_map.at(GpuLower::current()->caMap()->getConcreteMappedID(
182 window_id, IdMappingMode::EXACT));
183 } else {
184 window_idx = index_map.at(window_id);
185 }
186
187  // Positive padding at offset zero means the indexing is shifted in the
188  // negative direction.
189 auto pad_width = gather_expr->padWidth()[consumer_root_axis][0];
190
191 // producer offset: window_index - padding
192 auto producer_offset = SimplifyingIrBuilder::subExpr(
193 window_idx, SimplifyingIrBuilder::create<Int>(pad_width));
194 return producer_offset;
195}
196
197//! Offset a producer index of a gather expression
198//!
199//! Given an index of a producer root axis, build a new index
200//! expression that accesses a window position that the current loop
201//! structure refers to. Uses getConcreteProducerOffsetWithGather to
202//! create an offset Val.
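//!
//! Informal example (hypothetical domains): if the producer rfactor domain
//! is [r0, i1, i2] and producer_root_axis is 2, the loop below skips the
//! reduction axis r0 and resolves consumer_axis to 1, which is then used to
//! look up the gather window and padding of that consumer axis.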
203Val* getProducerIndexWithGather(
204 Val* producer_index,
205 size_t producer_root_axis,
206 const TensorView* producer_tv,
207 const TensorView* consumer_tv,
208 const std::unordered_map<IterDomain*, Val*>& concrete_index_map) {
209 auto gather_op = dynamic_cast<const GatherOp*>(consumer_tv->definition());
210
211 // Just return the producer index as is if this is not a gather
212 if (gather_op == nullptr) {
213 return producer_index;
214 }
215
216 // Consumer axis that corresponds to the producer axis
217 int consumer_axis = -1;
218 for (const auto i : c10::irange(producer_root_axis + 1)) {
219 if (producer_tv->getMaybeRFactorDomain()[i]->isReduction() ||
220 producer_tv->getMaybeRFactorDomain()[i]->isStride()) {
221 continue;
222 }
223 ++consumer_axis;
224 }
225
226 TORCH_INTERNAL_ASSERT(
227 consumer_axis >= 0 &&
228 consumer_axis < (int)gather_op->windowShape().size(),
229 "Invalid consumer axis",
230 consumer_axis,
231 ", producer_axis: ",
232 producer_root_axis);
233
234 auto offset = getConcreteProducerOffsetWithGather(
235 consumer_axis, consumer_tv, concrete_index_map, true);
236 return SimplifyingIrBuilder::addExpr(producer_index, offset);
237}
238
239// Adjusts a global consumer index when its root domain is partially
240// split. Note that non-global consumer indices don't need any
241// adjustment.
242Val* getGlobalConsumerOffsetWithPartialSplit(IterDomain* root_id) {
243 auto offset = GpuLower::current()->partialSplitMap().getStartOffset(root_id);
244 if (offset == nullptr) {
245 return GpuLower::current()->kernel()->zeroVal();
246 } else {
247 return offset;
248 }
249}
250
251// Adjusts a global producer index when its root domain and
252// corresponding consumer root domain have non-matching split
253// offsets. Specifically, since producer_index is calculated based on
254// the consumer, if the consumer has a non-zero offset,
255// it needs to be added to the index. When the producer itself
256// has a non-zero split offset, that needs to be subtracted from
257// the index.
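// Informal example (hypothetical offsets): if the consumer root domain was
// split with a start offset of 4 and the producer with a start offset of 0,
// a global-memory producer gets the consumer offset (4) added to its index,
// while a shared/local producer gets the difference (4 - 0), which must
// evaluate to a constant, added instead.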
258Val* getProducerIndexWithPartialSplit(
259 Val* producer_index,
260 IterDomain* producer_root_id,
261 const TensorView* producer_tv,
262 const TensorView* consumer_tv) {
263 const auto gpu_lower = GpuLower::current();
264
265 auto p2c =
266 PairwiseRootDomainMap(producer_tv, consumer_tv)
267 .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());
268
269 auto it = p2c.find(producer_root_id);
270 if (it == p2c.end()) {
271 return producer_index;
272 }
273
274 auto consumer_root_id = it->second;
275
276 auto consumer_offset =
277 gpu_lower->partialSplitMap().getStartOffset(consumer_root_id);
278 consumer_offset = consumer_offset == nullptr ? gpu_lower->kernel()->zeroVal()
279 : consumer_offset;
280
281 auto producer_offset =
282 gpu_lower->partialSplitMap().getStartOffset(producer_root_id);
283 producer_offset = producer_offset == nullptr ? gpu_lower->kernel()->zeroVal()
284 : producer_offset;
285
286 // If the producer is on global memory, it's always allocated
287 // without trimming the out-of-bounds region, so the consumer offset
288 // should be added to the index.
289 if (producer_tv->getMemoryType() == MemoryType::Global) {
290 if (consumer_offset->isZeroInt()) {
291 return producer_index;
292 } else {
293 return SimplifyingIrBuilder::addExpr(producer_index, consumer_offset);
294 }
295 }
296
297 // Non-global case. Difference of the split offsets must be
298  // accounted for.
299
300 auto diff = SimplifyingIrBuilder::subExpr(consumer_offset, producer_offset);
301 kir::ExpressionEvaluator ee;
302 auto diff_eval = ee.evaluate(diff);
303 // We currently only allow constant offsetting
304 TORCH_INTERNAL_ASSERT(diff_eval.has_value(), "Invalid partial split");
305
306 if (diff_eval.value() == 0) {
307 return producer_index;
308 }
309
310 return SimplifyingIrBuilder::addExpr(
311 producer_index,
312 SimplifyingIrBuilder::create<Int>(diff_eval->as<int64_t>()));
313}
314
315} // namespace
316
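// Informal sketch of the backward split rule implemented below (ignoring the
// zero-domain and zero-merged special cases, and using a made-up inner
// extent of 4): for in -> split -> (outer, inner), the input index is
// reconstructed from the output indices as
//   in_idx = outer_idx * 4 + inner_idx
// which is what the general outer_ind * getExtent(inner_id) + inner_ind
// expression below computes.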
317void IndexCompute::handle(Split* split) {
318 auto in_id = maybeGetExactMapConcreteID(split->in()->as<IterDomain>());
319 auto outer_id = maybeGetExactMapConcreteID(split->outer()->as<IterDomain>());
320 auto inner_id = maybeGetExactMapConcreteID(split->inner()->as<IterDomain>());
321
322 auto outer_it = index_map_.find(outer_id);
323 auto inner_it = index_map_.find(inner_id);
324 if (outer_it == index_map_.end() || inner_it == index_map_.end()) {
325 return;
326 }
327
328 const auto outer_ind = outer_it->second;
329 const auto inner_ind = inner_it->second;
330
331 const bool outer_zero = isZero(outer_id);
332 const bool inner_zero = isZero(inner_id);
333
334 // We want to mark as zero merged in if we're working with shared or local
335 // memory, and the dimension we're working with is not part of the allocation,
336 // as we have special propagation rules for that scenario.
337
338 // Maybe clear in_id as it could have been mapped over from another
339 // IndexCompute. Uncertain if this is needed but seems to be safe.
340 bool zero_merged_in = hasZeroMerged(in_id) || hasZeroMerged(inner_id) ||
341 hasZeroMerged(outer_id);
342
343 // If both are zero, the split input is also zero
344 if (inner_zero && outer_zero) {
345 zero_domains_.emplace(in_id);
346 }
347
348 if (zero_merged_in) {
349 zero_merged_in_.emplace(in_id);
350 }
351
352 if (isZero(in_id)) {
353 index_map_[in_id] = GpuLower::current()->kernel()->zeroVal();
354 extent_map_[in_id] = GpuLower::current()->kernel()->zeroVal();
355 } else if (zero_merged_in && outer_zero) {
356 index_map_[in_id] = inner_ind;
357 extent_map_[in_id] = getExtent(inner_id);
358 } else if (zero_merged_in && inner_zero) {
359 index_map_[in_id] = outer_ind;
360 extent_map_[in_id] = getExtent(outer_id);
361 } else {
362 index_map_[in_id] = SimplifyingIrBuilder::addExpr(
363 SimplifyingIrBuilder::mulExpr(outer_ind, getExtent(inner_id)),
364 inner_ind);
365 // The extent should be updated only when its allocation is
366 // partial, i.e., zero_merged_in is true. See PR #1270.
367 if (zero_merged_in) {
368 extent_map_[in_id] = SimplifyingIrBuilder::mulExpr(
369 getExtent(outer_id), getExtent(inner_id));
370 }
371 }
372}
373
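// Informal sketch of the backward merge rule implemented below (ignoring the
// contiguous-indexing and zero-merged special cases, and using a made-up
// inner extent of 4): for (outer, inner) -> merge -> out, the output index is
// split back as
//   outer_idx = out_idx / 4
//   inner_idx = out_idx % 4
// so, e.g., out_idx = 7 yields outer_idx = 1 and inner_idx = 3.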
374void IndexCompute::handle(Merge* merge) {
375 auto out_id = maybeGetExactMapConcreteID(merge->out());
376 auto outer_id = maybeGetExactMapConcreteID(merge->outer());
377 auto inner_id = maybeGetExactMapConcreteID(merge->inner());
378
379 auto out_it = index_map_.find(out_id);
380 if (out_it == index_map_.end()) {
381 return;
382 }
383 auto out_ind = out_it->second;
384
385 auto zero = GpuLower::current()->kernel()->zeroVal();
386
387 if (isZero(out_id)) {
388 index_map_[outer_id] = zero;
389 index_map_[inner_id] = zero;
390 // TODO: Why do we set extent_map_ to zero? This has to be protected by zero
391    // merged in, but it seems logical that the extent would still be one.
392 extent_map_[outer_id] = zero;
393 extent_map_[inner_id] = zero;
394 zero_domains_.emplace(outer_id);
395 zero_domains_.emplace(inner_id);
396 return;
397 }
398
399 if (!hasZeroMerged(out_id) && contig_ids_.find(out_id) != contig_ids_.end()) {
400 // Contiguous indexing path
401 auto input_ids = ir_utils::iterDomainInputsOfOrderedAs(
402 {merge->out()}, td_->getMaybeRFactorDomain());
403
404 // Shouldn't hit this, but don't want to segfault if somehow we do.
405 TORCH_INTERNAL_ASSERT(!input_ids.empty());
406
407    // Try to find the last non-broadcast entry to put the index in if it's a
408    // contiguous merge. This isn't strictly necessary, but there are implicit
409    // assumptions in the indexing logic that broadcasted root domains
410    // can be ignored. This is just to try to match that logic.
411 // Initialize everything to zero.
412 for (auto root_id : input_ids) {
413 index_map_[root_id] = zero;
414 }
415
416 // If all are broadcast we can just send the index to the last entry.
417 if (std::all_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) {
418 // I don't think reductions can be in here, but strictly matching the
419 // logic in the indexing functions like
420 // getNonGlobalConsumerStridedIndices
421 return id->isBroadcast() || id->isReduction() || id->isStride();
422 })) {
423 index_map_[*(input_ids.end() - 1)] = out_ind;
424 } else {
425 for (auto id_it = input_ids.rbegin(); id_it != input_ids.rend();
426 id_it++) {
427 auto id = *id_it;
428 if (id->isBroadcast() || id->isReduction() || id->isStride()) {
429 continue;
430 } else {
431 index_map_[id] = out_ind;
432 break;
433 }
434 }
435 }
436
437 return;
438 }
439
440 Val* inner_extent = getExtent(inner_id);
441
442 // When the reference has halo extent for inner_id, that extent needs to
443 // be used to un-merge
444 if (halo_extent_map_.find(inner_id) != halo_extent_map_.end()) {
445 inner_extent = halo_extent_map_[inner_id];
446 }
447
448 const auto outer_extent = getExtent(outer_id);
449
450 if (inner_id->isBroadcast() && inner_extent->isOneInt()) {
451 // Propagate away from broadcast dims
452 index_map_[outer_id] = out_ind;
453 index_map_[inner_id] = zero;
454
455 extent_map_[outer_id] = getExtent(out_id);
456 if (hasZeroMerged(out_id)) {
457 zero_merged_in_.insert(outer_id);
458 }
459 } else if (outer_id->isBroadcast() && outer_extent->isOneInt()) {
460 // Propagate away from broadcast dims
461 index_map_[outer_id] = zero;
462 index_map_[inner_id] = out_ind;
463
464 extent_map_[inner_id] = getExtent(out_id);
465 if (hasZeroMerged(out_id)) {
466 zero_merged_in_.insert(inner_id);
467 }
468 } else if (hasZeroMerged(out_id)) {
469    // Don't propagate to inner id if it's composed of only broadcast root
470    // domains, unless outer is also all broadcast domains. The index shouldn't
471    // be anything but zero if both inner and outer are all broadcast domains,
472    // but we didn't add a hard check for this. See FusionAdvancedIndexing5_CUDA
473 if (!inner_id->isBroadcast() && !outer_id->isBroadcast()) {
474 // If neither dimension is a broadcast (should be true for reference
475 // indexing) pick the preferred path or the inner path.
476 if (preferred_paths_.find(outer_id) != preferred_paths_.end() &&
477 preferred_paths_.find(inner_id) == preferred_paths_.end()) {
478 // Marked that we should prop through outer, not inner.
479 index_map_[outer_id] = out_ind;
480 extent_map_[outer_id] = getExtent(out_id);
481 index_map_[inner_id] = zero;
482 extent_map_[inner_id] = zero;
483 zero_domains_.emplace(inner_id);
484 } else {
485 // Prop through inner
486 index_map_[inner_id] = out_ind;
487 extent_map_[inner_id] = getExtent(out_id);
488 index_map_[outer_id] = zero;
489 extent_map_[outer_id] = zero;
490 zero_domains_.emplace(outer_id);
491 }
492 } else if (inner_id->isBroadcast() && !outer_id->isBroadcast()) {
493 // Inner is broadcast and outer isn't, prop through outer
494 index_map_[outer_id] = out_ind;
495 extent_map_[outer_id] = getExtent(out_id);
496 index_map_[inner_id] = zero;
497 extent_map_[inner_id] = zero;
498 zero_domains_.emplace(inner_id);
499 } else {
500 // Default to propagating through inner
501 index_map_[inner_id] = out_ind;
502 extent_map_[inner_id] = getExtent(out_id);
503 index_map_[outer_id] = zero;
504 extent_map_[outer_id] = zero;
505 zero_domains_.emplace(outer_id);
506 }
507 zero_merged_in_.emplace(inner_id);
508 zero_merged_in_.emplace(outer_id);
509 } else {
510 index_map_[outer_id] = SimplifyingIrBuilder::divExpr(out_ind, inner_extent);
511 index_map_[inner_id] = SimplifyingIrBuilder::modExpr(out_ind, inner_extent);
512 }
513}
514
515void IndexCompute::handle(Swizzle2D* swizzle_2d) {
516 auto out_x_id = maybeGetExactMapConcreteID(swizzle_2d->outX());
517 auto out_y_id = maybeGetExactMapConcreteID(swizzle_2d->outY());
518 auto in_x_id = maybeGetExactMapConcreteID(swizzle_2d->inX());
519 auto in_y_id = maybeGetExactMapConcreteID(swizzle_2d->inY());
520
521 auto out_x_it = index_map_.find(out_x_id);
522 auto out_y_it = index_map_.find(out_y_id);
523
524 if (out_x_it == index_map_.end() || out_y_it == index_map_.end()) {
525 return;
526 }
527
528 const auto out_x_ind = out_x_it->second;
529 const auto out_y_ind = out_y_it->second;
530
531 if (swizzle_mode_ == SwizzleMode::NoSwizzle ||
532 swizzle_mode_ != swizzle_2d->swizzleMode()) {
533 // Handle inactive swizzles by just passing through index
534    // and extent information.
535
536 TORCH_INTERNAL_ASSERT(
537 index_map_.count(in_x_id) == index_map_.count(in_y_id),
538 "input index should be either both defined or both undefined");
539 if (index_map_.count(in_x_id)) {
540      // Only propagate the original index through if
541      // the input index hasn't already been computed.
542 // TODO:
543 // This part should be cleaner once we remove the
544 // second index traversal pass.
545 return;
546 }
547 index_map_[in_x_id] = out_x_ind;
548 index_map_[in_y_id] = out_y_ind;
549 extent_map_[in_y_id] = getExtent(out_y_id);
550 extent_map_[in_x_id] = getExtent(out_x_id);
551 } else {
552 // Generate integer swizzle math if the
553 // swizzle is activated. See also
554 // [Note on swizzle mode].
555
556 auto out_pair = IrBuilder::swizzle2DIntExpr(
557 out_x_ind,
558 out_y_ind,
559 getExtent(out_x_id),
560 getExtent(out_y_id),
561 swizzle_2d->swizzleType());
562
563 index_map_[in_x_id] =
564 IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
565 index_map_[in_y_id] =
566 IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
567 }
568}
569
570void IndexCompute::handle(Expr* e) {
571 switch (e->getExprType().value()) {
572 case (ExprType::Split):
573 case (ExprType::Merge):
574 case (ExprType::Swizzle2D):
575 break;
576 default:
577 TORCH_INTERNAL_ASSERT(
578 false, "Invalid expr type found in transform traversal.");
579 }
580 BackwardVisitor::handle(e);
581}
582
583IndexCompute::IndexCompute(
584 const TensorDomain* _td,
585 std::unordered_map<IterDomain*, Val*> initial_index_map,
586 std::unordered_map<IterDomain*, Val*> extent_map,
587 std::unordered_set<IterDomain*> zero_domains,
588 std::unordered_set<IterDomain*> zero_merged_in,
589 std::unordered_set<IterDomain*> preferred_paths,
590 std::unordered_map<IterDomain*, Val*> halo_extent_map)
591 : IndexCompute(
592 _td,
593 std::move(initial_index_map),
594 std::move(extent_map),
595 std::move(zero_domains),
596 std::move(zero_merged_in),
597 ContigIDs::getNonContigIDs(),
598 std::move(preferred_paths),
599 std::move(halo_extent_map)) {}
600
601IndexCompute::IndexCompute(
602 const TensorDomain* _td,
603 std::unordered_map<IterDomain*, Val*> initial_index_map,
604 std::unordered_map<IterDomain*, Val*> extent_map,
605 std::unordered_set<IterDomain*> zero_domains,
606 std::unordered_set<IterDomain*> zero_merged_in,
607 const ContigIDs& contig_finder,
608 std::unordered_set<IterDomain*> preferred_paths,
609 std::unordered_map<IterDomain*, Val*> halo_extent_map)
610 : td_(_td),
611 index_map_(std::move(initial_index_map)),
612 extent_map_(std::move(extent_map)),
613 zero_domains_(std::move(zero_domains)),
614 zero_merged_in_(std::move(zero_merged_in)),
615 preferred_paths_(std::move(preferred_paths)),
616 halo_extent_map_(std::move(halo_extent_map)) {
617 FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
618
619 // Make sure we recompute any indices we can that map to a contiguous access
620 // in physical memory.
621 contig_ids_ = contig_finder.contigIDs();
622 root_to_indexed_id_ = contig_finder.rootToIndexedID();
623 const auto& within_contig = contig_finder.withinContigIDs();
624 for (auto contig_id : contig_ids_) {
625 if (index_map_.find(contig_id) != index_map_.end()) {
626 TORCH_INTERNAL_ASSERT(
627 within_contig.find(contig_id) != within_contig.end());
628 for (auto id : within_contig.at(contig_id)) {
629 index_map_.erase(id);
630 }
631 }
632 }
633}
634
635IndexCompute::IndexCompute(
636 std::unordered_map<IterDomain*, Val*> initial_index_map,
637 std::unordered_set<IterDomain*> zero_domains,
638 std::unordered_set<IterDomain*> preferred_paths,
639 std::unordered_map<IterDomain*, Val*> halo_extent_map)
640 : index_map_(std::move(initial_index_map)),
641 zero_domains_(std::move(zero_domains)),
642 preferred_paths_(std::move(preferred_paths)),
643 halo_extent_map_(std::move(halo_extent_map)) {
644 FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
645 concrete_id_pass_ = true;
646 swizzle_mode_ = SwizzleMode::Loop;
647}
648
649void IndexCompute::run(const LoopIndexing& loop_indexing) {
650 TORCH_INTERNAL_ASSERT(
651 concrete_id_pass_, "concrete pass only for this option");
652  // Apply loop swizzles if there are any that output to
653  // the loop domains.
654  // Currently we only support loop swizzles that directly output
655  // to concrete loop domains, and these are validated in
656  // the swizzle validation pass.
657  // TODO:
658  //  will gradually enable replaying and mapping of loop
659  //  swizzles in the IR infrastructure, and once that's piped
660  //  through, this part of the logic will be removed.
661 std::unordered_set<Expr*> visited;
662 for (auto loop_id : loop_indexing.loopDomains()) {
663 auto loop_id_def = loop_id->definition();
664 if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) {
665 if (visited.insert(loop_id_def).second) {
666 handle(loop_id_def);
667 }
668 }
669 }
670
671 // Resolve the index vals that could be resolved with only
672 // the loops that consumer_tv doesn't share with any of its
673 // consumers, i.e. the not-inlined loops that define consumer_tv
674 // values.
675 collectIndexIntoPermissiveMap(loop_indexing);
676
677 // Run through the loop indexing expressions and generate
678 // the indexing integer math for the concrete ids.
679 for (auto expr : loop_indexing.getBackwardExprList()) {
680 // Resolve missing values from permissive map.
681 updateIndexMapFromPermissiveMap(expr);
682
683 handle(expr);
684 }
685}
686
687void IndexCompute::collectIndexIntoPermissiveMap(
688 const LoopIndexing& loop_indexing) {
689  // Visit the expressions that only produce un-inlined iterdomains,
690 // in reverse topological order.
691 for (auto expr : loop_indexing.getBackwardOutOfLineExprList()) {
692 // Compute indexing vals for the expression inputs.
693 //
694    // This stage should run before any other indexing computation to
695    // ensure that all index values computed at this stage are
696    // the ones that can be resolved using only the not-inlined
697    // iterdomains.
698 //
699 auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
700 if (std::all_of(
701 id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) {
702 return index_map_.count(
703 GpuLower::current()->caMap()->getConcreteMappedID(
704 id, IdMappingMode::EXACT));
705 })) {
706 // Visit this expression:
707 // LoopIndexingAnalysis::traverseFromDomainVals made sure that each
708 // concrete index is bound exactly once so computing these expressions
709 // early should still be consistent.
710 handle(expr);
711
712 auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs());
713 for (auto id : id_inputs) {
714 // Collect backward pass results from this expression if they are
715        //  made available by this expression.
716 auto idx_it =
717 index_map_.find(GpuLower::current()->caMap()->getConcreteMappedID(
718 id, IdMappingMode::EXACT));
719
720 if (idx_it != index_map_.end()) {
721 permissive_index_map_
722 [GpuLower::current()->caMap()->getConcreteMappedID(
723 id, IdMappingMode::PERMISSIVE)] = idx_it->second;
724 }
725 }
726 }
727 }
728}
729
730void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) {
731 auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs());
732 for (auto id : id_outputs) {
733 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
734 id, IdMappingMode::EXACT);
735 // Only try to copy index val from permissive map when
736 // the index is missing.
737 if (!index_map_.count(concrete_id)) {
738 auto permissive_id = GpuLower::current()->caMap()->getConcreteMappedID(
739 id, IdMappingMode::PERMISSIVE);
740 // Write the permissive index val into index_map_ if the
741 // missing value is found here.
742 auto permissive_it = permissive_index_map_.find(permissive_id);
743 if (permissive_it != permissive_index_map_.end()) {
744 index_map_[concrete_id] = permissive_it->second;
745 }
746 }
747 }
748}
749
750void IndexCompute::run() {
751 const std::vector<Val*> domain_vals(
752 td_->domain().begin(), td_->domain().end());
753
754 traverseTo(td_->fusion(), domain_vals, false);
755}
756
757IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) {
758 if (concrete_id_pass_) {
759 return GpuLower::current()->caMap()->getConcreteMappedID(
760 id, IdMappingMode::EXACT);
761 }
762 return id;
763}
764
765Val* IndexCompute::getExtent(IterDomain* id) const {
766  // Pick from extent_map_ if available. Previously, parallel
767  // dimensions were used (e.g., blockDim.x); however, that would result
768  // in out-of-bounds errors when the extent of an IterDomain is smaller
769  // than the threading dimension.
770 if (extent_map_.find(id) != extent_map_.end()) {
771 return extent_map_.at(id);
772 } else {
773 return id->extent();
774 }
775}
776
777bool IndexCompute::hasZeroMerged(IterDomain* id) const {
778 return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id);
779}
780
781bool IndexCompute::isZero(IterDomain* id) const {
782 return zero_domains_.find(id) != zero_domains_.end();
783}
784
785IndexCompute IndexCompute::updateIndexCompute(
786 const TensorDomain* new_td,
787 const std::unordered_map<IterDomain*, IterDomain*>& id_map,
788 const ContigIDs& contig_finder) const {
789 FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");
790
791 std::unordered_map<IterDomain*, Val*> updated_index_map;
792 std::unordered_map<IterDomain*, Val*> updated_extent_map;
793 std::unordered_set<IterDomain*> updated_zero_domains;
794 std::unordered_set<IterDomain*> updated_zero_merged_in;
795 std::unordered_map<IterDomain*, Val*> updated_halo_extent_map;
796
797 for (auto id_entry : id_map) {
798 IterDomain* prev_id = id_entry.first;
799 IterDomain* new_id = id_entry.second;
800
801 if (index_map_.find(prev_id) != index_map_.end()) {
802 updated_index_map[new_id] = index_map_.at(prev_id);
803 }
804
805 updated_extent_map[new_id] = getExtent(prev_id);
806
807 if (zero_domains_.find(prev_id) != zero_domains_.end()) {
808 updated_zero_domains.emplace(new_id);
809 }
810
811 if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) {
812 updated_zero_merged_in.emplace(new_id);
813 }
814
815 auto halo_extent_it = halo_extent_map_.find(prev_id);
816 if (halo_extent_it != halo_extent_map_.end()) {
817 updated_halo_extent_map[new_id] = halo_extent_it->second;
818 }
819 }
820
821 IndexCompute updated_index_compute(
822 new_td,
823 updated_index_map,
824 updated_extent_map,
825 updated_zero_domains,
826 updated_zero_merged_in,
827 contig_finder,
828 {},
829 updated_halo_extent_map);
830
831 updated_index_compute.run();
832
833 return updated_index_compute;
834}
835
836namespace {
837// Map indices down to the leaf domains for applying swizzle
838class UpdateLeafIndices : public IterVisitor {
839 public:
840 UpdateLeafIndices(
841 const TensorDomain* td,
842 std::unordered_map<IterDomain*, Val*> initial_index_map,
843 std::unordered_map<IterDomain*, Val*> extent_map)
844 : td_(td),
845 index_map_(std::move(initial_index_map)),
846 extent_map_(std::move(extent_map)) {
847 const std::vector<Val*> domain_vals(
848 td_->domain().begin(), td_->domain().end());
849
850 traverseTo(td_->fusion(), domain_vals, false);
851 }
852
853 const std::unordered_map<IterDomain*, Val*>& indexMap() const {
854 return index_map_;
855 }
856
857 const std::unordered_map<IterDomain*, Val*>& extentMap() const {
858 return extent_map_;
859 }
860
861 private:
862 using IterVisitor::handle;
863
864 void handle(Split* split) override {
865 auto in_id = split->in();
866 auto outer_id = split->outer();
867 auto inner_id = split->inner();
868
869    // Nothing needs to be done when mappings for the output axes
870 // already exist.
871 if (index_map_.find(outer_id) != index_map_.end()) {
872 TORCH_INTERNAL_ASSERT(
873 index_map_.find(inner_id) != index_map_.end(),
874 "Outer exists but inner not found");
875 return;
876 }
877
878 if (!index_map_.count(in_id)) {
879      // Reduction axes on the producer side can be visited during the forward
880      // propagation pass, and the current implementation does not yet
881      // support reduction on swizzled iterdomains, so un-indexed
882      // reduction iterdomains are just ignored for now.
883 TORCH_INTERNAL_ASSERT(
884 in_id->isReduction(), "Undefined index for ", in_id->toString());
885 return;
886 }
887
888 auto factor = split->factor();
889 index_map_[inner_id] =
890 SimplifyingIrBuilder::modExpr(index_map_[in_id], factor);
891 extent_map_[inner_id] = factor;
892 index_map_[outer_id] =
893 SimplifyingIrBuilder::divExpr(index_map_[in_id], factor);
894 extent_map_[outer_id] =
895 SimplifyingIrBuilder::ceilDivExpr(getExtent(in_id), factor);
896 }
897
898 void handle(Merge* merge) override {
899 auto out_id = merge->out();
900 auto outer_id = merge->outer();
901 auto inner_id = merge->inner();
902
903 if (!index_map_.count(outer_id) || !index_map_.count(inner_id)) {
904      // Reduction axes on the producer side can be visited during the forward
905      // propagation pass, and the current implementation does not yet
906      // support reduction on swizzled iterdomains, so un-indexed
907      // reduction iterdomains are just ignored for now.
908 TORCH_INTERNAL_ASSERT(
909 outer_id->isReduction() && inner_id->isReduction(),
910 "Undefined index for ",
911 outer_id->toString(),
912 " and ",
913 inner_id->toString());
914 return;
915 }
916
917    // Nothing needs to be done when mappings for the output axes
918 // already exist.
919 if (index_map_.find(out_id) != index_map_.end()) {
920 return;
921 }
922
923 TORCH_INTERNAL_ASSERT(
924 index_map_.find(outer_id) != index_map_.end(), "Outer ID not found");
925 TORCH_INTERNAL_ASSERT(
926 index_map_.find(inner_id) != index_map_.end(), "Inner ID not found");
927
928 index_map_[out_id] = SimplifyingIrBuilder::addExpr(
929 index_map_[inner_id],
930 SimplifyingIrBuilder::mulExpr(
931 index_map_[outer_id], getExtent(inner_id)));
932
933 extent_map_[out_id] =
934 SimplifyingIrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id));
935 }
936
937 void handle(Swizzle2D* swizzle_2d) override {
938 auto in_x = swizzle_2d->inX();
939 auto in_y = swizzle_2d->inY();
940 auto out_x = swizzle_2d->outX();
941 auto out_y = swizzle_2d->outY();
942
943    // The forward propagation pass still just forwards
944    //  the indices through; the actual swizzle
945    //  will be applied on the backward pass in
946    //  the IndexSwizzle class implementation.
947 index_map_[out_x] = index_map_.at(in_x);
948 extent_map_[out_x] = getExtent(in_x);
949 index_map_[out_y] = index_map_.at(in_y);
950 extent_map_[out_y] = getExtent(in_y);
951 }
952
953  // Returns extent_map_[id] if it exists, otherwise id->extent()
954 Val* getExtent(IterDomain* id) {
955 if (extent_map_.find(id) != extent_map_.end()) {
956 return extent_map_.at(id);
957 } else {
958 return id->extent();
959 }
960 }
961
962 private:
963 const TensorDomain* td_;
964 std::unordered_map<IterDomain*, Val*> index_map_;
965 std::unordered_map<IterDomain*, Val*> extent_map_;
966};
967
968// Returns halo-extended extent if id has halo. Otherwise, just
969// returns id->extent.
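//
// Informal example (hypothetical halo widths): a root axis of extent N with
// a halo width of 1 on each side has a halo-extended extent of N + 2, i.e.,
// normal_extent + halo.width() below.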
970Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) {
971 if (normal_extent == nullptr) {
972 normal_extent = id->extent();
973 }
974
975 const auto& halo = GpuLower::current()->haloInfo()->getRootAxisInfo(id);
976 if (halo.hasHalo()) {
977 auto halo_extent = SimplifyingIrBuilder::addExpr(
978 normal_extent, SimplifyingIrBuilder::create<Int>(halo.width()));
979 return halo_extent;
980 } else {
981 return normal_extent;
982 }
983}
984
985} // namespace
986
987IndexSwizzle::IndexSwizzle(
988 const TensorView* tv,
989 std::unordered_map<IterDomain*, Val*> initial_index_map,
990 std::unordered_map<IterDomain*, Val*> extent_map,
991 std::unordered_set<IterDomain*> zero_domains,
992 std::unordered_set<IterDomain*> zero_merged_in)
993 : IndexCompute(
994 tv->domain(),
995 std::move(initial_index_map),
996 std::move(extent_map),
997 std::move(zero_domains),
998 std::move(zero_merged_in)),
999 tv_(tv),
1000 swizzle_type_(tv->swizzleType()),
1001 ids_to_swizzle_(tv->axesToSwizzle()) {}
1002
1003IndexSwizzle::IndexSwizzle(
1004 const TensorView* tv,
1005 const TensorDomain* domain,
1006 std::unordered_map<IterDomain*, Val*> initial_index_map,
1007 std::unordered_map<IterDomain*, Val*> extent_map,
1008 std::unordered_set<IterDomain*> zero_domains,
1009 std::unordered_set<IterDomain*> zero_merged_in)
1010 : IndexCompute(
1011 domain,
1012 std::move(initial_index_map),
1013 std::move(extent_map),
1014 std::move(zero_domains),
1015 std::move(zero_merged_in)),
1016 tv_(tv),
1017 swizzle_type_(tv->swizzleType()),
1018 ids_to_swizzle_(tv->axesToSwizzle()) {}
1019
1020void IndexSwizzle::run() {
1021 TORCH_INTERNAL_ASSERT(
1022 swizzle_type_ == SwizzleType::NoSwizzle ||
1023 swizzle_type_ == SwizzleType::Transpose,
1024 "Invalid swizzle type");
1025 if (swizzle_type_ == SwizzleType::Transpose) {
1026 // Shifts the second axis by the first axis as ((idx_1 + idx_2) %
1027 // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also
1028 // work if ext is a power of two. Practically, ext should be 32 if
1029 // the data type of the tensor is float, so the latter approach
1030 // should also be fine.
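    // Informal illustration (not part of the original comment): with
    // ext = 32, element (i, j) of the tile is stored at column (i + j) % 32,
    // so a column-wise traversal touches a different column, and hence a
    // different shared-memory bank, on every row.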
1031 TORCH_INTERNAL_ASSERT(tv_->getMemoryType() == MemoryType::Shared);
1032 TORCH_INTERNAL_ASSERT(tv_->axesToSwizzle().size() == 2);
1033
1034 UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
1035 index_map_ = update_leaves.indexMap();
1036 extent_map_ = update_leaves.extentMap();
1037
1038 IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0);
1039 IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1);
1040
1041 if (indexMap().find(id_to_swizzle_i) != indexMap().end() &&
1042 indexMap().find(id_to_swizzle_j) != indexMap().end()) {
1043 auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i);
1044 auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j);
1045
1046 auto swizzled_idx = SimplifyingIrBuilder::modExpr(
1047 SimplifyingIrBuilder::addExpr(idx_to_swizzle_i, idx_to_swizzle_j),
1048 id_to_swizzle_j->extent());
1049 index_map_[id_to_swizzle_j] = swizzled_idx;
1050 swizzled_ids_.insert(id_to_swizzle_j);
1051 IndexCompute::run();
1052 }
1053 } else if (tv_->hasSwizzleOp()) {
1054 // Propagate backward for the annotated swizzle path.
1055 // TODO:
1056 // eventually will unify the two swizzling implementation
1057    //  code paths in a follow-up. Currently just focusing on
1058 // getting the necessary implementation of the swizzle
1059 // operator ready.
1060 //
1061 // At this intermediate state, the legacy swizzle implementation
1062 // takes precedence, i.e. whenever swizzle_type_ is not NoSwizzle,
1063 // the new swizzle op pass is disabled.
1064 UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
1065 index_map_ = update_leaves.indexMap();
1066 extent_map_ = update_leaves.extentMap();
1067 IndexCompute::swizzle_mode_ = SwizzleMode::Data;
1068 IndexCompute::run();
1069 }
1070}
1071
1072void IndexSwizzle::handle(Expr* e) {
1073 auto out_ids = ir_utils::filterByType<IterDomain>(e->outputs());
1074 bool needs_update =
1075 std::any_of(
1076 out_ids.begin(),
1077 out_ids.end(),
1078 [this](IterDomain* id) {
1079 return swizzled_ids_.find(id) != swizzled_ids_.end();
1080 }) ||
1081 (e->isA<Swizzle2D>() &&
1082 e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle &&
1083 e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data);
1084 if (!needs_update) {
1085 return;
1086 }
1087
1088 IndexCompute::handle(e);
1089 for (auto input : ir_utils::filterByType<IterDomain>(e->inputs())) {
1090 swizzled_ids_.insert(input);
1091 }
1092}
1093
1094void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
1095 auto out_x_id = swizzle_2d->outX();
1096 auto out_y_id = swizzle_2d->outY();
1097
1098 auto out_x_it = index_map_.find(out_x_id);
1099 auto out_y_it = index_map_.find(out_y_id);
1100
1101 // TODO: unify the legacy path in all usage
1102 TORCH_INTERNAL_ASSERT(
1103 swizzle_type_ == SwizzleType::NoSwizzle,
1104 "Cannot mix usage of two swizzle implementations");
1105
1106 TORCH_INTERNAL_ASSERT(
1107 out_x_it != index_map_.end() && out_y_it != index_map_.end(),
1108 "Swizzle output indices were not propagated through");
1109
1110 IndexCompute::handle(swizzle_2d);
1111}
1112
1113// Used for local and shared index mapping. Returns a map from loops
1114// to loop indices as well as a set of loops that do not contribute to
1115// indexing.
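// Informal example (hypothetical loop nest): for a Local tensor indexed as a
// consumer, whose allocation point sits inside a threadIdx-parallel loop,
// loops outside the allocation scope and the thread-parallel loop itself map
// to a zero index (and are recorded in the returned zero-loop set), while
// the remaining serial loops inside the allocation scope keep their loop
// indices.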
1116std::pair<
1117 std::unordered_map<kir::ForLoop*, Val*>,
1118 std::unordered_set<kir::ForLoop*>>
1119indexMapFromTV(
1120 const TensorView* tv,
1121 const std::vector<kir::ForLoop*>& loops,
1122 kir::ForLoop* alloc_loop,
1123 bool as_consumer,
1124 kir::ForLoop* double_buffer_loop) {
1125 bool within_alloc = false;
1126 if (alloc_loop == nullptr) {
1127 within_alloc = true;
1128 }
1129
1130 const bool is_global = tv->getMemoryType() == MemoryType::Global;
1131 const bool is_shared = tv->getMemoryType() == MemoryType::Shared;
1132 const bool is_local = tv->getMemoryType() == MemoryType::Local;
1133
1134 std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
1135
1136 // Check if the current op has an implicit loop implemented
1137 // within an mma instruction.
1138 bool within_mma_loops =
1139 std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) {
1140 return fl->iter_domain()->isMma();
1141 });
1142
1143  // When indexed as a producer, the parallel types of the
1144 // producer domains may not be the same as those of the loops, but
1145 // that's still valid parallelization. However, in that case, using
1146 // the parallel types of the loops to decide replacement of indices
1147 // with zero isn't valid. That's only valid when there's a matching
1148 // IterDomain in the producer tensor that has the same parallel
1149 // type.
1150 auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool {
1151 const auto gpu_lower = GpuLower::current();
1152 auto it = std::find_if(
1153 tv->domain()->domain().begin(),
1154 tv->domain()->domain().end(),
1155 [&](IterDomain* tv_id) {
1156 // Matching is done using the index and loop maps. See
1157 // validateParallelize as well.
1158 return gpu_lower->caMap()->areMapped(
1159 id, tv_id, IdMappingMode::EXACT) ||
1160 (GpuLower::current()->caMap()->areMapped(
1161 id, tv_id, IdMappingMode::PERMISSIVE) &&
1162 ir_utils::derivedFromRootCAAxes(tv, tv_id));
1163 });
1164 if (it == tv->domain()->domain().end()) {
1165 return false;
1166 }
1167
1168 auto corresponding_domain = *it;
1169 return corresponding_domain->getParallelType() == id->getParallelType();
1170 };
1171
1172  // Track domains that do not contribute to the resulting
1173 // index. Previously, index->isZeroInt() was used to detect such
1174 // domains, but that's not a reliable method as we may set an
1175 // initial index to zero for unswitch.
1176 std::unordered_set<kir::ForLoop*> zero_loops;
1177
1178 for (auto loop : loops) {
1179 Val* idx = nullptr;
1180 const auto same_parallel_type = as_consumer ||
1181 find_matching_parallel_domain(loop->iter_domain()) ||
1182 // Note && TODO:
1183 // mma swizzled lane_id does not map naturally from producer
1184 // to consumer but they should still be detected as same
1185      //  parallel type. In a follow-up we may want to extend
1186 // find_matching_parallel_domain to cover this case.
1187 (within_mma_loops &&
1188 loop->iter_domain()->getParallelType() == ParallelType::TIDx);
1189 // See also LoopNestGenerator::pushAlloc.
1190 // NOLINTNEXTLINE(bugprone-branch-clone)
1191 if (!within_alloc) {
1192 if ((loop->iter_domain()->isThreadDim() && is_shared) ||
1193 (loop->iter_domain()->isThread() && is_global)) {
1194 idx = loop->index();
1195 } else {
1196 idx = GpuLower::current()->kernel()->zeroVal();
1197 zero_loops.insert(loop);
1198 }
1199 } else if (
1200 // For shared-memory tensors, when a domain is parallelized by
1201 // BID, the index can be replaced with zero as long as the
1202 // tensor has a matching domain that has the same parallel
1203 // type. Matching can be omitted when indexed as a consumer
1204 // since it is always the case. When indexed as a producer, to
1205 // replace it with zero, the same parallel type of BID must be
1206 // used by the producer tensor. Thus, since this is a shared
1207 // memory tensor, when a producer domain is parallelized by
1208 // BID, there must be a matching consumer domain with the same
1209 // parallel type, which must be the IterDomain of the
1210 // loop.
1211 (loop->iter_domain()->isBlockDim() && is_shared &&
1212 same_parallel_type) ||
1213 // Similarly for local memory tensors, zero replacement can be
1214 // only done when there's a matching domain with the same
1215 // parallel type
1216 (loop->iter_domain()->isThread() && is_local && same_parallel_type) ||
1217 // MMA operands are currently indexed in units of "fragments",
1218 // so each mma tensor domain would be zero-ed and the tensor index
1219 // calculated here would be the fragment index.
1220 // TODO: This is a quick WAR to enable iterating over a register array
1221 // of MMA fragments, so we could generate unrolled mma loops.
1222 // Eventually we still want IdGraph to be able to analyze the
1223 // in-register layout of mma fragments for more unified indexing math
1224 // as well as more flexibility in swizzling loops.
1225 (loop->iter_domain()->isMma() && !as_consumer)) {
1226 idx = GpuLower::current()->kernel()->zeroVal();
1227 zero_loops.insert(loop);
1228 } else {
1229 idx = loop->index();
1230 }
1231
1232 // If the loop is trivial, the loop index can only be the loop
1233 // start value.
1234 if (idx == loop->index() && loop->isTrivial()) {
1235 idx = loop->start();
1236 }
1237
1238 if (loop == double_buffer_loop) {
1239 auto stage_depth =
1240 GpuLower::current()->doubleBufferInfo().getStageDepthFor(
1241 loop->iter_domain());
1242 idx = SimplifyingIrBuilder::addExpr(
1243 idx, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
1244 }
1245
1246 loop_to_ind_map[loop] = idx;
1247
1248 if (!within_alloc && loop == alloc_loop) {
1249 within_alloc = true;
1250 }
1251 }
1252 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1253 return {loop_to_ind_map, zero_loops};
1254}
1255
1256//! Set "pragma unroll" required for loops that indexing of Local
1257//! tensors depends on.
1258//!
1259//! \param tv Indexed tensor
1260//! \param alloc_loop Allocation loop of tv
1261//! \param loops The current loop structure
1262//! \param id_map Producer-to-consumer map in case of indexing as producer
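//!
//! Informal rationale (not part of the original comment): Local tensors live
//! in registers, and register arrays cannot be indexed by runtime-variable
//! values without spilling, so any serial loop whose index feeds the
//! tensor's index is marked with requireUnroll() below, which makes the
//! index a compile-time constant in the unrolled code.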
1263void ensureStaticIndexing(
1264 const TensorView* tv,
1265 kir::ForLoop* alloc_loop,
1266 const std::vector<kir::ForLoop*>& loops,
1267 const std::unordered_map<IterDomain*, IterDomain*>& id_map) {
1268 if (tv->getMemoryType() != MemoryType::Local) {
1269 return;
1270 }
1271
1272 bool within_alloc = false;
1273 if (alloc_loop == nullptr) {
1274 within_alloc = true;
1275 }
1276
1277 for (auto loop : loops) {
1278 if (!within_alloc) {
1279 if (loop == alloc_loop) {
1280 within_alloc = true;
1281 }
1282 continue;
1283 }
1284 IterDomain* loop_id = loop->iter_domain();
1285 if (loop->vectorize() || loop_id->isThread()) {
1286 continue;
1287 }
1288 // Look for a domain that is mapped with the loop. If mapped in
1289 // the loop map, the loop index should be used for indexing of the
1290 // tensor, except for broadcast and reduction domains.
1291 auto it = std::find_if(
1292 tv->domain()->domain().begin(),
1293 tv->domain()->domain().end(),
1294 [loop_id, &id_map](IterDomain* id) {
1295 if (id->isBroadcast() || id->isReduction() || id->isStride()) {
1296 return false;
1297 }
1298 auto id_replacement = id_map.find(id);
1299 if (id_replacement != id_map.end()) {
1300 id = id_replacement->second;
1301 }
1302 return GpuLower::current()->caMap()->areMapped(
1303 loop_id, id, IdMappingMode::PERMISSIVE);
1304 });
1305 if (it != tv->domain()->domain().end()) {
1306 loop->requireUnroll();
1307 }
1308 }
1309}
1310
1311namespace {
1312
1313//! Returns an iterdomain that corresponds to the
1314//! indexing sub-expression to hoist or a nullopt
1315//! if the index should not be hoisted.
1316c10::optional<IterDomain*> getMaybeIndexedIdToHoist(
1317 IterDomain* root_id,
1318 const TensorView* tv,
1319 const IndexCompute& indexing,
1320 Val* index) {
1321 if (isOptionDisabled(DisableOption::IndexHoist) ||
1322 index->definition() == nullptr) {
1323 return c10::nullopt;
1324 }
1325
1326 // The old swizzle interface, which should be deprecated, is not
1327 // supported.
1328 if (tv->swizzleType() != SwizzleType::NoSwizzle) {
1329 return c10::nullopt;
1330 }
1331
1332 // New swizzle interface not yet supported
1333 if (tv->hasSwizzleOp()) {
1334 return c10::nullopt;
1335 }
1336
1337 // Find the true indexed domain, which can be a merged contiguous domain.
1338 auto contig_id_it = indexing.rootToContigID().find(root_id);
1339 TORCH_INTERNAL_ASSERT(
1340 contig_id_it != indexing.rootToContigID().end(),
1341 "Consumer indexed ID not found: ",
1342 root_id->toString());
1343 auto indexed_id = contig_id_it->second;
1344 // Make sure this contig ID is indeed indexed
1345 TORCH_INTERNAL_ASSERT(
1346 indexing.indexMap().find(contig_id_it->second) !=
1347 indexing.indexMap().end(),
1348 "Invalid contig index: ",
1349 contig_id_it->second->toString());
1350
1351 return indexed_id;
1352}
1353
1354// Version of hoisting that does not use a reference tensor;
1355// it should eventually replace the other one once the reference
1356// tensor is completely deprecated.
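//
// Informal note on what hoisting buys here (the sub-expression shown is only
// a hypothetical illustration): if several tensors indexed in the same loop
// nest share a sub-expression such as threadIdx.x + blockIdx.x * ext,
// inserting it into commonIndexMap() lets the lowering compute it once and
// reuse the returned Val instead of regenerating the expression per tensor.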
1357Val* hoistConsumerIndex(
1358 IterDomain* consumer_root_id,
1359 const TensorView* consumer_tv,
1360 const IndexCompute& consumer_indexing,
1361 std::vector<IterDomain*> loop_domains,
1362 const std::unordered_map<IterDomain*, Val*> initial_loop_index_map,
1363 const std::vector<kir::ForLoop*>& loops,
1364 Val* index) {
1365 auto maybe_hoisted_consumer_id = getMaybeIndexedIdToHoist(
1366 consumer_root_id, consumer_tv, consumer_indexing, index);
1367
1368 if (!maybe_hoisted_consumer_id.has_value()) {
1369 return index;
1370 }
1371
1372 // Insert the index into the common index map. A previously inserted
1373 // val can be returned.
1374 auto common_index = GpuLower::current()
1375 ->commonIndexMap()
1376 .insert(
1377 maybe_hoisted_consumer_id.value(),
1378 consumer_tv->domain(),
1379 loop_domains,
1380 initial_loop_index_map,
1381 loops,
1382 index)
1383 .first;
1384
1385 return common_index;
1386}
1387
1388std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap(
1389 const std::unordered_map<IterDomain*, IterDomain*>& map) {
1390 std::unordered_map<IterDomain*, IterDomain*> inverted;
1391 for (const auto& kv : map) {
1392 bool inserted = inverted.emplace(kv.second, kv.first).second;
1393 TORCH_INTERNAL_ASSERT(
1394 inserted,
1395 "Multiple mappings to the same value detected: ",
1396 kv.second->toString());
1397 }
1398 return inverted;
1399}
1400
1401Val* hoistProducerIndex(
1402 IterDomain* producer_root_id,
1403 const TensorView* producer_tv,
1404 const IndexCompute& producer_indexing,
1405 const TensorView* consumer_tv,
1406 const std::unordered_map<IterDomain*, IterDomain*>& p2c_map,
1407 std::vector<IterDomain*> loop_domains,
1408 const std::unordered_map<IterDomain*, Val*> initial_loop_index_map,
1409 const std::vector<kir::ForLoop*>& loops,
1410 Val* index) {
1411 auto maybe_indexed_producer_id = getMaybeIndexedIdToHoist(
1412 producer_root_id, producer_tv, producer_indexing, index);
1413
1414 if (!maybe_indexed_producer_id.has_value()) {
1415 return index;
1416 }
1417
1418 // Use the corresponding consumer domain to find matching
1419 // for-loops. Note that there's no CA mapping with the producer
1420 // domains as the producer TensorDomain is a temporary replay
1421 // domain.
1422 auto indexed_consumer_id_it = p2c_map.find(maybe_indexed_producer_id.value());
1423
1424 // There can be no corresponding consumer ID. For example, consider:
1425 // consumer: [b1, i2, i3]
1426 // producer: [i2, i3].
1427 // Suppose the consumer is transformed as:
1428 // consumer: [(b1*i2)*i3]
1429 // Then the producer would be transformed when indexed:
1430 // producer: [i2*i3]
1431 // Assuming i2 and i3 are contiguous, the producer indexing is done
1432  // with the merged i2*i3 domain, but there's no domain in the
1433  // consumer that maps with the producer indexed domain.
1434 // It seems non-trivial to support patterns like this. Skip for now.
1435 if (indexed_consumer_id_it == p2c_map.end()) {
1436 return index;
1437 }
1438
1439 IterDomain* indexed_consumer_id = indexed_consumer_id_it->second;
1440
1441 auto common_index = GpuLower::current()
1442 ->commonIndexMap()
1443 .insert(
1444 indexed_consumer_id,
1445 consumer_tv->domain(),
1446 loop_domains,
1447 initial_loop_index_map,
1448 loops,
1449 index)
1450 .first;
1451
1452 return common_index;
1453}
1454
1455} // namespace
1456
1457std::vector<Val*> Index::getGlobalProducerStridedIndices(
1458 TensorView* producer_tv,
1459 const TensorView* consumer_tv,
1460 const std::vector<kir::ForLoop*>& loops) {
1461 FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
1462 const auto gpu_lower = GpuLower::current();
1463
1464  // Replay producer to look like consumer so we can index on producer,
1465  // since our loop nests look like consumer.
1466 auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
1467 auto producerAsC =
1468 TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
1469 .first;
1470
1471 // Make the producer_tv look like consumer while performing indexing math
1472 ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
1473
1474 // Map sent to best effort replay needs to match the exact incantation for
1475 // compute_at_mode.cpp with MappingMode::Index
1476 auto c2p_root_map =
1477 PairwiseRootDomainMap(producer_tv, consumer_tv, true)
1478 .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain());
1479
1480 // This replay has to be consistent with compute at index map.
1481 BestEffortReplay replay_producer_as_consumer(
1482 producer_tv->domain()->domain(),
1483 consumer_tv->domain()->domain(),
1484 c2p_root_map);
1485
1486 const auto& c2p_map = replay_producer_as_consumer.getReplay();
1487 const auto p2c_map = invertOneToOneMap(c2p_map);
1488
1489 // Forward vectorized IDs to index into producer correctly
1490 // We want p_id to be vectorized like consumer just for the indexing, then we
1491 // need to switch it back later. Store previous state here when changing. We
1492 // need to do this as replaying producer as consumer can use replay best
1493 // effort which means some domains may be producer's original domains.
1494 std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
1495 for (auto entry : c2p_map) {
1496 auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
1497 entry.first, IdMappingMode::EXACT);
1498 auto p_id = entry.second;
1499 if (ref_id->getParallelType() == ParallelType::Vectorize) {
1500 p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
1501 p_id->parallelize(ParallelType::Vectorize);
1502 } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
1503 p_id->parallelize(ParallelType::MisalignedVectorize);
1504 }
1505 }
1506
1507 auto producer_indexing_from_idgraph =
1508 getTensorIndexFromIdGraph(loops, consumer_tv, producer_tv, true, c2p_map);
1509
1510 auto producer_indexing = producer_indexing_from_idgraph.index;
1511
1512 // Revert p_ids
1513 for (auto entry : p_id_backup) {
1514 entry.first->parallelize(entry.second);
1515 }
1516
1517 // Indices should now be mapped onto IterDomains in producer, so just grab
1518 // and use them.
1519 auto root_dom = producer_tv->getMaybeRFactorDomain();
1520
1521 // TODO: Abstract stride logic to reuse with consumer indexing
1522 std::vector<Val*> strides(root_dom.size(), nullptr);
1523 {
1524 int stride_i = 0;
1525 for (const auto i : c10::irange(root_dom.size())) {
1526 if (root_dom[i]->isReduction()) {
1527 strides[i] = GpuLower::current()->kernel()->oneVal();
1528 continue;
1529 }
1530 std::stringstream ss;
1531 ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
1532 strides[i] =
1533 SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
1534 }
1535 }
1536
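  // Informal sketch of the loop below (hypothetical fully contiguous 3-D
  // case): walking dims from innermost to outermost, a contiguous dim takes
  // the running product of inner (halo-extended) extents as its stride, so
  // the strides named above get overwritten as [ext(1) * ext(2), ext(2), 1];
  // a non-contiguous dim keeps its runtime T.stride[...] value and only
  // updates the running product.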
1537 TORCH_INTERNAL_ASSERT(
1538 root_dom.size() == producer_tv->domain()->contiguity().size());
1539 Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
1540 for (const auto i : c10::irange(root_dom.size())) {
1541 auto dim = root_dom.size() - i - 1;
1542 if (root_dom[dim]->isReduction()) {
1543 continue;
1544 }
1545
1546 Val* root_ind = nullptr;
1547 if (producer_indexing.indexMap().find(root_dom[dim]) !=
1548 producer_indexing.indexMap().end()) {
1549 root_ind = producer_indexing.indexMap().at(root_dom[dim]);
1550 } else if (root_dom[dim]->isBroadcast()) {
1551 root_ind = GpuLower::current()->kernel()->zeroVal();
1552 }
1553
1554 TORCH_INTERNAL_ASSERT(
1555 root_ind != nullptr,
1556 "Couldn't find root mapping for ",
1557 producer_tv->toString(),
1558 " dim: ",
1559 dim,
1560 " id: ",
1561 root_dom[dim]->toString());
1562
1563 if (producer_tv->domain()->contiguity()[dim]) {
1564      // If contig, use the stored stride, which may be the previous
1565      // dimension's stride * the previous dimension's size
1566 strides[dim] = cur_contig_stride;
1567 // Prepare for the next dimension which may also be contiguous, multiply
1568 // by extent of this dimension
1569 auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
1570 cur_contig_stride =
1571 SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent);
1572 } else {
      // If this is a non-contiguous dimension, keep the local stride
      // information and set the current stride to local stride * local raw
      // extent
1575 auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
1576 cur_contig_stride =
1577 SimplifyingIrBuilder::mulExpr(strides[dim], root_dim_extent);
1578 }
1579 }
1580
1581 auto vectorize_shift =
1582 loops.empty() ? nullptr : loops.back()->vectorize_shift();
1583
1584 // Global striding
1585 std::vector<Val*> strided_inds(
1586 root_dom.size(), GpuLower::current()->kernel()->zeroVal());
1587 for (const auto i : c10::irange(root_dom.size())) {
1588 // If the domain is derived from a trivial reduction, no indexing
1589 // to create.
1590 if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
1591 gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
1592 continue;
1593 }
1594
1595 TORCH_INTERNAL_ASSERT(
1596 producer_indexing.indexMap().find(root_dom[i]) !=
1597 producer_indexing.indexMap().end(),
1598 "Couldn't find root mapping for TV",
1599 producer_tv->name(),
1600 " dim: ",
1601 i,
1602 " id: ",
1603 root_dom[i]->toString());
1604
1605 auto root_ind = producer_indexing.indexMap().at(root_dom[i]);
1606
1607 // index hoist must be done before the adjustments for halo
1608 root_ind = hoistProducerIndex(
1609 root_dom[i],
1610 producer_tv,
1611 producer_indexing,
1612 consumer_tv,
1613 p2c_map,
1614 producer_indexing_from_idgraph.resolved_loop_domains,
1615 producer_indexing_from_idgraph.initial_concrete_index_map,
1616 loops,
1617 root_ind);
1618
1619 root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
1620
1621 root_ind = getProducerIndexWithGather(
1622 root_ind,
1623 i,
1624 producer_tv,
1625 consumer_tv,
1626 producer_indexing_from_idgraph.concrete_index.indexMap());
1627
1628 root_ind = getProducerIndexWithPartialSplit(
1629 root_ind, root_dom[i], producer_tv, consumer_tv);
1630
1631 if (root_ind->isZeroInt()) {
1632 continue;
1633 } else {
1634 auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
1635 if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
1636 strided_inds[i] =
1637 SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
1638 } else {
1639 strided_inds[i] = strided_ind;
1640 }
1641 }
1642 }
1643
1644 return strided_inds;
1645}
1646
1647namespace {
1648
1649// Maps all producer domains to consumer with broadcast
1650// forwarding. Used to find the allocation position.
1651std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
1652 TensorView* producer_tv,
1653 const TensorView* consumer_tv) {
1654 // This map has forwarded broadcast axes, it should only be used to compute
1655 // the allocation position of the producer, and to figure out which producer
1656 // indices are mapped to consumer trivial reductions.
1657 std::unordered_map<IterDomain*, IterDomain*> p2c_alloc_map;
1658
  // We want to replay producer as consumer instead of the other way around
  // since the consumer may have some broadcasted axes the producer doesn't
  // have that are merged into loops the producer may use. If we replayed
  // consumer as producer we wouldn't have this information in the mapping.
1663 auto replay_PasC = BestEffortReplay::replayPasC(
1664 producer_tv,
1665 consumer_tv,
1666 -1,
1667 PairwiseRootDomainMap(producer_tv, consumer_tv));
1668
1669 // Grab consumer domain entries and reverse replay map. TODO: Maybe
1670 // TransformReplay::replayPasC could return this map
1671 for (auto id : consumer_tv->domain()->domain()) {
1672 const auto& c2p_map = replay_PasC.getReplay();
1673 auto c2p_it = c2p_map.find(id);
1674 if (c2p_it != c2p_map.end()) {
1675 auto c_id = c2p_it->first;
1676 auto p_id = c2p_it->second;
1677 p2c_alloc_map[p_id] = c_id;
1678 }
1679 }
1680
1681 return p2c_alloc_map;
1682}
1683
1684} // namespace
1685
1686// Producer index for either shared or local memory
1687std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
1688 TensorView* producer_tv,
1689 const TensorView* consumer_tv,
1690 const std::vector<kir::ForLoop*>& loops) {
1691 const auto gpu_lower = GpuLower::current();
1692
  // Replay producer to look like consumer so we can index into the producer,
  // since our loop nests look like the consumer's
1695 auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
1696 auto producer_replayed_as_consumer =
1697 TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
1698 .first;
1699
1700 ir_utils::TVDomainGuard domain_guard(
1701 producer_tv, producer_replayed_as_consumer);
1702 const auto p2c_alloc_map =
1703 mapAllProducerDomainsToConsumer(producer_tv, consumer_tv);
1704
  // Map everything we can from reference to producer using the compute at
  // index map. Not all producer IDs exist in the compute at map. The rfactor
  // axes may all be there, but since that hasn't been proven, take a more
  // conservative approach and use the consumer as a proxy between the
  // producer and the reference.
1710 std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_producer;
1711 std::unordered_map<IterDomain*, IterDomain*> c2p_index_map;
1712 std::unordered_map<IterDomain*, IterDomain*> p2c_index_map;
1713
1714 // Map sent to best effort replay needs to match the exact incantation for
1715 // compute_at_mode.cpp with MappingMode::Index
1716 auto c2p_root_map =
1717 PairwiseRootDomainMap(producer_tv, consumer_tv, true)
1718 .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain());
1719
1720 // This replay has to be consistent with compute at index map.
1721 BestEffortReplay replay_producer_as_consumer(
1722 producer_tv->domain()->domain(),
1723 consumer_tv->domain()->domain(),
1724 c2p_root_map);
1725
1726 c2p_index_map = replay_producer_as_consumer.getReplay();
1727 p2c_index_map = invertOneToOneMap(c2p_index_map);
1728
1729 // Forward vectorized IDs to index into producer correctly
  // We want p_id to be vectorized like the consumer just for the indexing,
  // then we need to switch it back later, so store the previous state here
  // before changing it. We need to do this because replaying producer as
  // consumer can use best-effort replay, which means some domains may be the
  // originals.
1734 std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
1735 for (auto entry : c2p_index_map) {
1736 auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
1737 entry.first, IdMappingMode::EXACT);
1738 auto p_id = entry.second;
1739 if (ref_id->getParallelType() == ParallelType::Vectorize) {
1740 p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
1741 p_id->parallelize(ParallelType::Vectorize);
1742 } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
1743 p_id->parallelize(ParallelType::MisalignedVectorize);
1744 }
1745 }
1746
1747 auto producer_indexing_from_idgraph = getTensorIndexFromIdGraph(
1748 loops, consumer_tv, producer_tv, false, c2p_index_map);
1749
1750 auto producer_indexing = producer_indexing_from_idgraph.index;
1751
1752 // Revert p_ids
1753 for (auto entry : p_id_backup) {
1754 entry.first->parallelize(entry.second);
1755 }
1756
1757 IndexSwizzle index_swizzle(
1758 producer_tv,
1759 producer_indexing.indexMap(),
1760 producer_indexing.extentMap(),
1761 producer_indexing.zeroDomains(),
1762 producer_indexing.zeroMergedIn());
1763
1764 index_swizzle.run();
1765
1766 auto producer_swizzled_index = index_swizzle;
1767
1768 if (producer_tv->hasSwizzleOp()) {
    // Special handling is needed for the new swizzle op pass: each swizzle
    // op is local to the tensor, so ReplayPasC will not include the swizzle
    // ops on the producer iterdomain. We therefore need to traverse the
    // producer domain forward before the replay to pick up the swizzle ops.
1776 IndexSwizzle producer_swizzle2d(
1777 producer_tv,
1778 domain_guard.prevDomain(),
1779 producer_indexing.indexMap(),
1780 producer_indexing.extentMap(),
1781 producer_indexing.zeroDomains(),
1782 producer_indexing.zeroMergedIn());
1783 producer_swizzle2d.run();
1784 producer_swizzled_index = producer_swizzle2d;
1785 }
1786
  // TODO: merge the two swizzle compute paths once the new one is ready.
  //  Will need to replace the cyclic shift swizzle with xor since swizzle2d
  //  doesn't have cyclic shift.
1790 const auto& index_map = producer_swizzled_index.indexMap();
1791
1792 const auto& extent_map = producer_indexing.extentMap();
1793 const auto& zero_domain_map = producer_indexing.zeroDomains();
1794 // Indices should now be mapped onto IterDomains in producer, so just grab
1795 // and use them.
1796 auto root_dom = producer_tv->getMaybeRFactorDomain();
1797
1798 // Figure out which root axes we don't need to index
1799 std::unordered_set<IterDomain*> skip_indexing;
1800
1801 for (auto root_id : root_dom) {
1802 // Already taken care of because we can detect no indexing required
1803 if (root_id->isBroadcast() || root_id->isReduction() ||
1804 gpu_lower->trivialReductionInfo().isDerived(root_id) ||
1805 root_id->isStride()) {
1806 skip_indexing.insert(root_id);
1807 continue;
1808 }
1809
1810 // Already an entry for this root domain, continue
1811 if (index_map.find(root_id) != index_map.end()) {
1812 continue;
1813 }
1814
    // Maps to the consumer's trivial reduction, don't index
1816 if (p2c_alloc_map.find(root_id) != p2c_alloc_map.end() &&
1817 gpu_lower->trivialReductionInfo().isDerived(
1818 p2c_alloc_map.at(root_id))) {
1819 skip_indexing.emplace(root_id);
1820 }
1821 }
1822
1823 std::vector<Val*> strided_inds(
1824 root_dom.size(), GpuLower::current()->kernel()->zeroVal());
1825 for (const auto i : c10::irange(root_dom.size())) {
1826 if (skip_indexing.count(root_dom[i])) {
1827 continue;
1828 }
1829
1830 TORCH_INTERNAL_ASSERT(
1831 index_map.find(root_dom[i]) != index_map.end(),
1832 "Couldn't find root mapping for ",
1833 producer_tv->toString(),
1834 " dim: ",
1835 i,
1836 " id: ",
1837 root_dom[i]->toString());
1838
1839 auto root_ind_i = index_map.at(root_dom[i]);
1840
1841 // index hoist must be done before the adjustments for halo
1842 root_ind_i = hoistProducerIndex(
1843 root_dom[i],
1844 producer_tv,
1845 producer_indexing,
1846 consumer_tv,
1847 p2c_index_map,
1848 producer_indexing_from_idgraph.resolved_loop_domains,
1849 producer_indexing_from_idgraph.initial_concrete_index_map,
1850 loops,
1851 root_ind_i);
1852
1853 root_ind_i =
1854 getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
1855
1856 root_ind_i = getProducerIndexWithGather(
1857 root_ind_i,
1858 i,
1859 producer_tv,
1860 consumer_tv,
1861 producer_indexing_from_idgraph.concrete_index.indexMap());
1862
1863 root_ind_i = getProducerIndexWithPartialSplit(
1864 root_ind_i, root_dom[i], producer_tv, consumer_tv);
1865
1866 if (root_ind_i->isZeroInt()) {
1867 continue;
1868 }
1869
1870 // Compute striding for this index.
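    // The stride is the product of the (halo-extended) extents of all inner
    // root domains that are indexed and are not zero domains. Illustration
    // only (hypothetical root domain [I0, I1, I2] with nothing skipped): the
    // strided indices are ind0 * ext(I1) * ext(I2), ind1 * ext(I2), and ind2.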
1871 Val* stride = nullptr;
1872 for (const auto j : c10::irange(i + 1, root_dom.size())) {
1873 if (skip_indexing.count(root_dom[j])) {
1874 continue;
1875 }
1876
1877 TORCH_INTERNAL_ASSERT(
1878 index_map.find(root_dom[j]) != index_map.end(),
1879 "Couldn't find root mapping for ",
1880 producer_tv->name(),
1881 " dim: ",
1882 j,
1883 " id: ",
1884 root_dom[j]->toString());
1885
1886 auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end()
1887 ? root_dom[j]->extent()
1888 : extent_map.at(root_dom[j]);
1889
1890 root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
1891
1892 if (zero_domain_map.count(root_dom[j]) == 0) {
1893 if (stride == nullptr) {
1894 stride = root_ext_j;
1895 } else {
1896 stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j);
1897 }
1898 }
1899 }
1900
1901 if (stride != nullptr) {
1902 strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride);
1903 } else {
1904 strided_inds[i] = root_ind_i;
1905 }
1906 }
1907
1908 if (producer_tv->isDoubleBuffered() || producer_tv->isCircularBuffered()) {
1909 auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
1910 producer_tv, loops, true);
1911 if (db_loop != nullptr) {
1912 auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
1913 db_loop->iter_domain());
1914 auto loop_index =
1915 db_loop->isTrivial() ? db_loop->start() : db_loop->index();
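      // The extra index term selects which of the stage_depth buffer copies
      // to address: (loop_index % stage_depth) * original_alloc_size. For
      // example, with stage_depth == 2 (plain double buffering), even
      // iterations address the first copy and odd iterations the second.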
1916 auto db_switch_index = SimplifyingIrBuilder::modExpr(
1917 loop_index, SimplifyingIrBuilder::create<Int>(stage_depth));
1918 auto original_alloc_size =
1919 gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
1920 auto db_strided_index =
1921 SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
1922 strided_inds.push_back(db_strided_index);
1923 }
1924 }
1925
1926 return strided_inds;
1927}
1928
1929std::vector<Val*> Index::getLinearLogicalIndex(
1930 TensorView* consumer_tv,
1931 const std::vector<kir::ForLoop*>& loops) {
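  // Note: overriding contiguity to all-true makes the indexing treat the
  // tensor as fully contiguous, so the returned per-dimension strided
  // indices add up to a linear logical offset rather than a memory offset
  // based on the actual strides.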
1932 auto guard = ir_utils::overrideContiguityGuard(consumer_tv, true);
1933 return getGlobalConsumerStridedIndices(consumer_tv, loops);
1934}
1935
1936std::vector<Val*> Index::getPerDimLogicalIndex(
1937 TensorView* consumer_tv,
1938 const std::vector<kir::ForLoop*>& loops) {
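  // Note: overriding contiguity to all-false keeps every root domain
  // separate during indexing, so an index is produced for each root/rfactor
  // axis individually instead of being folded into a contiguous group.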
1939 auto guard = ir_utils::overrideContiguityGuard(consumer_tv, false);
1940 IndexFromIdGraph index_from_id_graph =
1941 getTensorIndexFromIdGraph(loops, consumer_tv);
1942 return getRootIndices(consumer_tv, loops, index_from_id_graph);
1943}
1944
1945std::vector<Val*> Index::getStrides(const TensorView* tv) {
  // Strides are associated with the root (or rfactor) domain of the tensor,
  // so grab those domains and use them.
1948 auto root_dom = tv->getMaybeRFactorDomain();
1949
1950 std::vector<Val*> strides(
1951 root_dom.size(), GpuLower::current()->kernel()->oneVal());
1952 {
1953 int stride_i = 0;
1954 for (const auto i : c10::irange(root_dom.size())) {
1955 if (root_dom[i]->isReduction() || root_dom[i]->isStride()) {
1956 strides[i] = GpuLower::current()->kernel()->oneVal();
1957 continue;
1958 }
1959 std::stringstream ss;
1960 ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
1961 strides[i] =
1962 SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
1963 }
1964 }
1965
1966 TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
1967 Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
1968 for (const auto i : c10::irange(root_dom.size())) {
1969 auto dim = root_dom.size() - i - 1;
1970 if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
1971 continue;
1972 }
1973
1974 if (tv->domain()->contiguity()[dim]) {
      // If contiguous, use the stored stride, which may be the previous
      // dimension's stride * the previous dimension's extent
1977 strides[dim] = cur_contig_stride;
1978 // Prepare for the next dimension which may also be contiguous, multiply
1979 // by extent of this dimension
1980 auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
1981 cur_contig_stride =
1982 SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent);
1983 } else {
      // If this is a non-contiguous dimension, keep the local stride
      // information and set the current stride to local stride * local raw
      // extent
1986 cur_contig_stride = SimplifyingIrBuilder::mulExpr(
1987 strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
1988 }
1989 }
1990 return strides;
1991}
1992
1993std::vector<Val*> Index::getRootIndices(
1994 const TensorView* tv,
1995 const std::vector<kir::ForLoop*>& loops,
1996 const IndexFromIdGraph& index_from_id_graph) {
1997 auto gpu_lower = GpuLower::current();
1998 auto root_dom = tv->getMaybeRFactorDomain();
1999 auto indexing = index_from_id_graph.index;
2000
2001 std::vector<Val*> root_inds(
2002 root_dom.size(), GpuLower::current()->kernel()->zeroVal());
2003 for (const auto i : c10::irange(root_dom.size())) {
2004 // See a comment in indexing to root domains in getGlobalProducerIndex.
2005 if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
2006 gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) ||
2007 root_dom[i]->isStride()) {
2008 continue;
2009 }
2010
2011 TORCH_INTERNAL_ASSERT(
2012 indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
2013 "Couldn't find root mapping for ",
2014 tv->toString(),
2015 " dim: ",
2016 i,
2017 " id: ",
2018 root_dom[i]->toString());
2019
2020 auto root_ind = indexing.indexMap().at(root_dom[i]);
2021
2022 // index hoist must be done before the adjustments for halo
2023 root_ind = hoistConsumerIndex(
2024 root_dom[i],
2025 tv,
2026 indexing,
2027 index_from_id_graph.resolved_loop_domains,
2028 index_from_id_graph.initial_concrete_index_map,
2029 loops,
2030 root_ind);
2031
2032 root_ind = SimplifyingIrBuilder::addExpr(
2033 root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
2034 root_inds[i] = root_ind;
2035 }
2036 return root_inds;
2037}
2038
2039std::vector<Val*> Index::getGlobalConsumerStridedIndices(
2040 const TensorView* consumer_tv,
2041 const std::vector<kir::ForLoop*>& loops) {
2042 FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
2043
2044 auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
2045 auto consumer_indexing = index_from_id_graph.index;
2046 auto strides = getStrides(consumer_tv);
2047 auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);
2048
2049 // Global striding
2050 auto vectorize_shift =
2051 loops.empty() ? nullptr : loops.back()->vectorize_shift();
2052 std::vector<Val*> strided_inds(
2053 root_inds.size(), GpuLower::current()->kernel()->zeroVal());
2054 for (const auto i : c10::irange(root_inds.size())) {
2055 if (root_inds[i]->isZeroInt()) {
2056 continue;
2057 } else {
2058 auto strided_ind =
2059 SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
2060 if (i == strides.size() - 1 && vectorize_shift != nullptr) {
2061 strided_inds[i] =
2062 SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
2063 } else {
2064 strided_inds[i] = strided_ind;
2065 }
2066 }
2067 }
2068
2069 TORCH_INTERNAL_ASSERT(
2070 strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
2071
2072 return strided_inds;
2073}
2074
2075// Consumer index for either shared or local memory
2076std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
2077 const TensorView* consumer_tv,
2078 const std::vector<kir::ForLoop*>& loops) {
2079 const auto gpu_lower = GpuLower::current();
2080
2081 auto consumer_indexing_from_idgraph = getTensorIndexFromIdGraph(
2082 loops,
2083 consumer_tv,
2084 // Producer tv
2085 nullptr,
2086 // Index global
2087 false);
2088
2089 auto consumer_indexing = consumer_indexing_from_idgraph.index;
2090
2091 IndexSwizzle index_swizzle(
2092 consumer_tv,
2093 consumer_indexing.indexMap(),
2094 consumer_indexing.extentMap(),
2095 consumer_indexing.zeroDomains(),
2096 consumer_indexing.zeroMergedIn());
2097
2098 index_swizzle.run();
2099
2100 const auto& index_map = index_swizzle.indexMap();
2101 const auto& extent_map = consumer_indexing.extentMap();
2102 const auto& zero_domain_map = consumer_indexing.zeroDomains();
2103
2104 // Indices should now be mapped onto IterDomains in consumer, so just grab
2105 // and use them.
2106 auto root_dom = consumer_tv->getMaybeRFactorDomain();
2107 std::vector<Val*> strided_inds(
2108 root_dom.size(), GpuLower::current()->kernel()->zeroVal());
2109 for (const auto i : c10::irange(root_dom.size())) {
2110 if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
2111 gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) ||
2112 root_dom[i]->isStride()) {
2113 continue;
2114 }
2115
2116 TORCH_INTERNAL_ASSERT(
2117 index_map.find(root_dom[i]) != index_map.end(),
2118 "Couldn't find root mapping for ",
2119 consumer_tv->toString(),
2120 " dim: ",
2121 i,
2122 " id: ",
2123 root_dom[i]->toString());
2124
2125 auto root_ind_i = index_map.at(root_dom[i]);
2126 if (root_ind_i->isZeroInt()) {
2127 continue;
2128 }
2129
2130 // index hoist must be done before the adjustments for halo
2131 root_ind_i = hoistConsumerIndex(
2132 root_dom[i],
2133 consumer_tv,
2134 consumer_indexing,
2135 consumer_indexing_from_idgraph.resolved_loop_domains,
2136 consumer_indexing_from_idgraph.initial_concrete_index_map,
2137 loops,
2138 root_ind_i);
2139
2140 // Compute striding for this index.
2141 Val* stride = nullptr;
2142 for (const auto j : c10::irange(i + 1, root_dom.size())) {
2143 if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() ||
2144 gpu_lower->trivialReductionInfo().isDerived(root_dom[j]) ||
2145 root_dom[j]->isStride()) {
2146 continue;
2147 }
2148
2149 TORCH_INTERNAL_ASSERT(
2150 index_map.find(root_dom[j]) != index_map.end(),
2151 "Couldn't find root mapping for ",
2152 consumer_tv->toString(),
2153 " dim: ",
2154 j,
2155 " id: ",
2156 root_dom[j]->toString());
2157
2158 auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end()
2159 ? root_dom[j]->extent()
2160 : extent_map.at(root_dom[j]);
2161
2162 root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
2163
2164 if (zero_domain_map.count(root_dom[j]) == 0) {
2165 if (stride == nullptr) {
2166 stride = root_ext_j;
2167 } else {
2168 stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j);
2169 }
2170 }
2171 }
2172
2173 if (stride != nullptr) {
2174 strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride);
2175 } else {
2176 strided_inds[i] = root_ind_i;
2177 }
2178 }
2179
2180 // This check was originally done in getConsumerStridedIndices, but
2181 // the number of strided index values depends on the loop where the
2182 // consumer tensor is located. If it's double buffered and not in
2183 // the prologue loop, strided_inds ends up having one more
2184 // index, so it's just much simpler to check here before adding the
2185 // additional index for double buffering.
2186 TORCH_INTERNAL_ASSERT(
2187 strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
2188
2189 if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
2190 auto db_loop =
2191 gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
2192 auto stage_depth =
2193 gpu_lower->doubleBufferInfo().getStageDepthFor(db_loop->iter_domain());
2194 bool is_circular_buffer_loop = stage_depth > 2;
2195 bool is_prolog =
2196 db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;
2197
2198 Val* db_switch_index = nullptr;
2199
    // In the double buffered case we don't materialize the prolog loop, as
    // there will be only one iteration. In the circular buffer case we
    // materialize the prolog loop as well, covering the first N-1
    // iterations, N being the stage depth.
2204 if (!is_prolog || is_circular_buffer_loop) {
2205 if (is_prolog && is_circular_buffer_loop) {
2206 // The buffer switching logic is the same as original index
2207 // in the case of circular buffer prolog.
2208 db_switch_index = db_loop->index();
2209 } else {
2210 // Switching index generated for main loop or epilog component.
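        // For example, with stage_depth == 2 (plain double buffering) this
        // becomes (db_loop->index() + 1) % 2, i.e., the buffer copy other
        // than the one selected by the producer-side switch
        // (loop_index % stage_depth) in getNonGlobalProducerStridedIndices.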
2211 db_switch_index = SimplifyingIrBuilder::modExpr(
2212 SimplifyingIrBuilder::addExpr(
2213 db_loop->index(),
2214 SimplifyingIrBuilder::create<Int>(stage_depth - 1)),
2215 SimplifyingIrBuilder::create<Int>(stage_depth));
2216 }
2217
2218 // Use the generated switching buffer index to access the buffer space.
2219 auto original_alloc_size =
2220 gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
2221 auto db_strided_index =
2222 SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
2223 strided_inds.push_back(db_strided_index);
2224 }
2225 }
2226
2227 return strided_inds;
2228}
2229
2230std::vector<Val*> Index::getProducerStridedIndices(
2231 TensorView* producer,
2232 const TensorView* consumer,
2233 const std::vector<kir::ForLoop*>& loops) {
2234 FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
2235 if (producer->domain()->noReductions().size() == 0) {
2236 return std::vector<Val*>(
2237 producer->getMaybeRFactorDomain().size(),
2238 GpuLower::current()->kernel()->zeroVal());
2239 }
2240
2241 std::vector<Val*> strided_indices;
2242 if (producer->getMemoryType() == MemoryType::Global) {
2243 strided_indices =
2244 getGlobalProducerStridedIndices(producer, consumer, loops);
2245 } else {
2246 strided_indices =
2247 getNonGlobalProducerStridedIndices(producer, consumer, loops);
2248 }
2249
2250 TORCH_INTERNAL_ASSERT(
2251 strided_indices.size() ==
2252 producer->getMaybeRFactorDomain().size() +
2253 (producer->isDoubleBuffered() || producer->isCircularBuffered() ? 1
2254 : 0));
2255
2256 return strided_indices;
2257}
2258
2259// Producer is the inputs of an expression
2260kir::TensorIndex* Index::getProducerIndex(
2261 TensorView* producer,
2262 const TensorView* consumer,
2263 const std::vector<kir::ForLoop*>& loops) {
2264 auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
2265 return SimplifyingIrBuilder::create<kir::TensorIndex>(
2266 producer, strided_indices);
2267}
2268
2269std::vector<Val*> Index::getConsumerStridedIndices(
2270 const TensorView* consumer,
2271 const std::vector<kir::ForLoop*>& loops) {
2272 FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
2273 if (consumer->domain()->noReductions().size() == 0) {
2274 return std::vector<Val*>(
2275 consumer->getMaybeRFactorDomain().size(),
2276 GpuLower::current()->kernel()->zeroVal());
2277 }
2278
2279 std::vector<Val*> strided_indices;
2280 if (consumer->getMemoryType() == MemoryType::Global) {
2281 strided_indices = getGlobalConsumerStridedIndices(consumer, loops);
2282 } else {
2283 strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops);
2284 }
2285
2286 return strided_indices;
2287}
2288
2289// Consumer is the output of an expression
2290kir::TensorIndex* Index::getConsumerIndex(
2291 const TensorView* consumer,
2292 const std::vector<kir::ForLoop*>& loops) {
2293 auto strided_indices = getConsumerStridedIndices(consumer, loops);
2294 return SimplifyingIrBuilder::create<kir::TensorIndex>(
2295 consumer, strided_indices);
2296}
2297
2298namespace {
2299
2300struct PredicateDomainInfo {
2301 public:
2302 // Iteration domain to predicate
2303 IterDomain* id = nullptr;
2304 // The set of iteration domains that make up the id. If this is for
2305 // a non-divisible split, the set only contains the id itself. This
2306 // set is used to remove redundant predicates when gathering
2307 // unswitch predicates.
2308 std::unordered_set<IterDomain*> covered_ids;
2309 // True if this predicate is for a non-divisible split
2310 bool is_non_divisible_split = false;
2311};
2312
// Find iteration domains in the history of a consumer to predicate that are
// comprised only of merge operations. Only return iteration domains that are
// subsequently fed into a split, or are in the provided domain. In other
// words, we don't want to return every IterDomain that's contiguous, just
// the ones closest to the leaves. Predicates are not associated with
// physical memory, so we can treat all of them as contiguous merges.
2319//
2320// TODO: This seems to have a large overlap with ContigIDs. Consider
2321// refactoring.
2322std::vector<PredicateDomainInfo> getPredicateContigIds(
2323 TensorView* consumer_tv,
2324 const std::unordered_map<IterDomain*, Val*>& consumer_index_map) {
2325 const auto gpu_lower = GpuLower::current();
2326
2327 const auto& consumer_root_domain = consumer_tv->getRootDomain();
2328
2329 if (consumer_root_domain.empty()) {
2330 return std::vector<PredicateDomainInfo>();
2331 }
2332
2333 std::unordered_map<IterDomain*, Val*> concrete_index_map;
2334 for (auto entry : consumer_index_map) {
2335 auto c_id = gpu_lower->caMap()->getConcreteMappedID(
2336 entry.first, IdMappingMode::EXACT);
2337 concrete_index_map[c_id] = entry.second;
2338 }
2339
2340 std::vector<bool> predicate_contiguity(consumer_root_domain.size(), true);
2341 std::unordered_set<IterDomain*> final_ids;
2342 for (auto root_i : c10::irange(predicate_contiguity.size())) {
2343 auto root_id = consumer_root_domain[root_i];
2344 if (root_id->maybePartial()) {
2345 final_ids.insert(root_id);
2346 continue;
2347 }
2348 // Shifted or gathered axes need to be predicated at the root domain
2349 auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition());
2350 auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());
2351 if ((shift_expr && shift_expr->offset(root_i) != 0) ||
2352 (gather_expr && root_i < gather_expr->windowShape().size() &&
2353 gather_expr->windowShape().at(root_i) != 1)) {
2354 final_ids.insert(root_id);
2355 }
2356 }
2357
2358 ContigIDs contig_finder(
2359 consumer_tv->domain()->domain(),
2360 consumer_root_domain,
2361 predicate_contiguity,
2362 final_ids,
2363 concrete_index_map,
2364 GpuLower::current()->divisbleSplitSet(),
2365 GpuLower::current()->caMap(),
2366 GpuLower::current()->haloInfo(),
2367 GpuLower::current()->concretizedBroadcastDomains(),
2368 {},
2369 false,
2370 true);
2371
2372 std::vector<PredicateDomainInfo> contig_id_infos;
2373 std::unordered_set<IterDomain*> covered_roots;
2374
2375 // Create entries and return them
2376 for (auto root_id : consumer_root_domain) {
2377 if (covered_roots.count(root_id) > 0) {
2378 continue;
2379 }
2380
2381 auto contig_id_it = contig_finder.rootToIndexedID().find(root_id);
2382
2383 TORCH_INTERNAL_ASSERT(
2384 contig_id_it != contig_finder.rootToIndexedID().end(),
2385 "Error in predicate contiguity analysis, missing index for root ",
2386 root_id->toString());
2387
2388 auto contig_id = contig_id_it->second;
2389
2390 // Pick inputs from the starting domains, i.e.,
2391 // reference_predicated_root_domain.
2392 auto contig_root_ids = contig_finder.indexedRootIDs(contig_id);
2393 covered_roots.insert(contig_root_ids.begin(), contig_root_ids.end());
2394 PredicateDomainInfo contig_id_info;
2395 contig_id_info.id = contig_id;
2396 contig_id_info.covered_ids = std::unordered_set<IterDomain*>(
2397 contig_root_ids.begin(), contig_root_ids.end());
2398 contig_id_infos.push_back(contig_id_info);
2399 }
2400 return contig_id_infos;
2401}
2402
2403std::vector<PredicateDomainInfo> getNonDivisibleConsumerDomainsToPredicate(
2404 TensorView* consumer_tv) {
2405 const auto& non_divisible_split_info =
2406 GpuLower::current()->nonDivisibleSplitInfo();
2407
2408 std::vector<PredicateDomainInfo> pred_info_vec;
2409
2410 auto it = non_divisible_split_info.splitsToPredicate().find(consumer_tv);
2411 if (it == non_divisible_split_info.splitsToPredicate().end()) {
2412 return {};
2413 }
2414
2415 const auto& splits_to_predicate = it->second;
2416
2417 for (auto split : splits_to_predicate) {
2418 PredicateDomainInfo info{split->in(), {split->in()}, true};
2419 pred_info_vec.emplace_back(info);
2420 }
2421
2422 return pred_info_vec;
2423}
2424
2425bool needsPadding(TensorView* tv) {
2426 auto shift_expr = dynamic_cast<ShiftOp*>(tv->definition());
2427 auto gather_expr = dynamic_cast<GatherOp*>(tv->definition());
2428
2429 return (shift_expr != nullptr && shift_expr->hasPadding()) ||
2430 (gather_expr != nullptr && gather_expr->hasPadding());
2431}
2432
2433// Get an additional offset of a stop index when building a predicate
2434// for unswitch. Initial stop indices generated at
2435// getPredicateIndexingFromIdGraph do not take halo into account, and the
2436// adjustment for halo is done as an additional offset to the final index value
2437// so that unswitch predicates can be compared with each other by just looking
2438// at the additional offsets.
2439//
2440// consumer_root_id: the domain for which a stop predicate is being built.
2441int getUnswitchStopOffset(
2442 IterDomain* consumer_root_id,
2443 TensorView* consumer_tv) {
2444 const auto gpu_lower = GpuLower::current();
2445
2446 AxisHaloInfo halo_info =
2447 gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id);
2448
2449 // If the consumer root domain to predicate does not have halo, no
2450 // adjustment is required.
2451 if (!halo_info.hasHalo()) {
2452 return 0;
2453 }
2454
2455 // Find if this contig_id is used in the unswitched domains
2456 auto unswitch_it = std::find_if(
2457 consumer_tv->domain()->domain().begin(),
2458 consumer_tv->domain()->domain().end(),
2459 [](IterDomain* id) {
2460 return id->getParallelType() == ParallelType::Unswitch ||
2461 id->getParallelType() == ParallelType::Unroll ||
2462 id->getParallelType() == ParallelType::Vectorize;
2463 });
2464
2465 // If any of the unswitched leaf domains inherits the halo from the
2466 // root domain, the halo width needs to be added to the stop offset
2467 if (std::any_of(
2468 unswitch_it,
2469 consumer_tv->domain()->domain().end(),
2470 [&gpu_lower, &consumer_root_id](auto leaf_id) {
2471 return gpu_lower->haloInfo()->isHaloInherited(
2472 consumer_root_id, leaf_id);
2473 })) {
2474 return halo_info.width();
2475 } else {
2476 return 0;
2477 }
2478}
2479
2480std::pair<Val*, Val*> getStartAndStopOffsetsForShift(
2481 TensorView* consumer_tv,
2482 IterDomain* consumer_id,
2483 bool padding_predicate) {
2484 TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
2485
2486 auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition());
2487
  // Adjustment is not necessary if the definition is not a shift.
  // Likewise, the padding predicate does not need any adjustment.
2490 if (shift_expr == nullptr || padding_predicate) {
2491 return {
2492 GpuLower::current()->kernel()->zeroVal(),
2493 GpuLower::current()->kernel()->zeroVal()};
2494 }
2495
2496 const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id);
2497
2498 // The first or last N elements, where N is the padding width,
2499 // correspond to the padding predicate.
2500
2501 const auto shift_offset = shift_expr->offset(root_axis_pos);
2502 const auto pad_width = shift_expr->padWidth().at(root_axis_pos);
2503
2504 int start_offset = 0;
2505 int stop_offset = 0;
2506
2507 if (shift_offset > 0) {
2508 start_offset = -pad_width;
2509 } else if (shift_offset < 0) {
2510 stop_offset = pad_width;
2511 }
2512
2513 return {
2514 SimplifyingIrBuilder::create<Int>(start_offset),
2515 SimplifyingIrBuilder::create<Int>(stop_offset)};
2516}
2517
2518std::pair<Val*, Val*> getStartAndStopOffsetsForGather(
2519 TensorView* consumer_tv,
2520 IterDomain* consumer_id,
2521 const std::unordered_map<IterDomain*, Val*>& ref_start_index_map,
2522 const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map,
2523 bool padding_predicate) {
2524 TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
2525
  // Adjustment is not necessary if the definition is not a gather. Likewise,
  // the padding predicate does not need any adjustment.
2528 if (!consumer_tv->definition()->isA<GatherOp>() || padding_predicate) {
2529 return {
2530 GpuLower::current()->kernel()->zeroVal(),
2531 GpuLower::current()->kernel()->zeroVal()};
2532 }
2533
2534 const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id);
2535
2536 auto producer_start_offset = getProducerOffsetWithGather(
2537 root_axis_pos, consumer_tv, ref_start_index_map);
2538
2539 auto producer_stop_offset = getProducerOffsetWithGather(
2540 root_axis_pos, consumer_tv, ref_stop_index_map);
2541
2542 auto consumer_start_offset = GpuLower::current()->kernel()->zeroVal();
2543 auto consumer_stop_offset = GpuLower::current()->kernel()->zeroVal();
2544
2545 if (producer_start_offset->isZeroInt() && producer_stop_offset->isZeroInt()) {
2546 return {consumer_start_offset, consumer_stop_offset};
2547 }
2548
2549 Val* start_offset = nullptr;
2550 Val* stop_offset = nullptr;
2551
  // In the normal case, take the minimum of the start offsets and the
  // maximum of the stop offsets. If there's no padding, the producer
  // offset must always be larger than the consumer offset, so the
  // consumer and producer offsets can always be used for the start and
  // stop offsets, respectively.
2557 const auto pad_left =
2558 consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][0];
2559 const auto pad_right =
2560 consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][1];
2561 const auto window_size =
2562 consumer_tv->definition()->as<GatherOp>()->windowShape()[root_axis_pos];
2563
2564 // consumer index: index
2565 // producer index: index + window_index - pad_left
2566 //
2567 // consumer extent: ext
2568 // producer extent: ext + window_size - 1 - pad_left - pad_right
2569 //
2570 // consumer stop pred: index < ext
2571 // producer stop pred: index + window_index - pad_left < ext + window_size - 1
2572 // - pad_left - pad_right
2573 // -> index + window_index - pad_left - (window_size - 1 -
2574 // pad_left - pad_right) < ext
2575 // -> index + window_index - (window_size - 1 - pad_right) <
2576 // ext
2577 //
2578 // consumer start pred: index >= 0
2579 // producer start pred: index + window_index - pad_left >= 0
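  //
  // For example (illustration only), with window_size = 3 and
  // pad_left = pad_right = 1: producer_ext_adj = 3 - 1 - 1 - 1 = 0 and the
  // producer stop predicate above reduces to index + window_index - 1 < ext.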
2580
2581 const auto producer_ext_adj = window_size - 1 - pad_left - pad_right;
2582 producer_stop_offset = SimplifyingIrBuilder::subExpr(
2583 producer_stop_offset,
2584 SimplifyingIrBuilder::create<Int>(producer_ext_adj));
2585
2586 // As commented above, when pad_left is zero, the consumer predicate
2587 // is always more restrictive than the producer predicate.
2588 if (pad_left == 0) {
2589 start_offset = consumer_start_offset;
2590 } else {
2591 start_offset = SimplifyingIrBuilder::minExpr(
2592 consumer_start_offset, producer_start_offset);
2593 }
2594
2595 // As commented above, when pad_right is zero, the consumer
2596 // predicate is always more restrictive than the producer
2597 // predicate.
2598 if (pad_right == 0) {
2599 stop_offset = consumer_stop_offset;
2600 } else {
2601 stop_offset = SimplifyingIrBuilder::maxExpr(
2602 consumer_stop_offset, producer_stop_offset);
2603 }
2604
2605 TORCH_INTERNAL_ASSERT(start_offset != nullptr);
2606 TORCH_INTERNAL_ASSERT(stop_offset != nullptr);
2607
2608 return {start_offset, stop_offset};
2609}
2610
2611// Get the start and stop limit offsets that define the valid range to
2612// compute. In the simplest case, they are just 0 and
2613// IterDomain::extent. However, IterDomain may have non-zero start and
2614// stop that's different from extent. Also, when IterDomain has halo,
2615// the actual offsets of the logical start and stop positions are
2616// shifted.
2617std::pair<Val*, Val*> getStartAndStopLimitOffsets(
2618 IterDomain* consumer_id,
2619 bool padding_predicate,
2620 bool non_divisible_pred) {
2621 const auto gpu_lower = GpuLower::current();
2622
2623 TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
2624
2625 Val* start_limit = consumer_id->start();
2626 Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset());
2627
2628 if (!non_divisible_pred) {
2629 AxisHaloInfo halo_info =
2630 gpu_lower->haloInfo()->getRootAxisInfo(consumer_id);
2631
2632 // Below, "left" and "right" halo mean halo at offset zero and
2633 // axis extent, respectively.
2634 //
2635 // The consumer axis looks like this:
2636 //
2637 // [0, left halo)[start_limit, stop_limit)[0, right halo)
2638 //
2639 if (!padding_predicate) {
2640 start_limit =
2641 SimplifyingIrBuilder::addExpr(start_limit, halo_info.width(0));
2642 stop_limit =
2643 SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width(0));
2644 } else {
2645 // In case of the padding predicate, the whole range, including both left
2646 // and right halo regions, is computed.
2647 stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width());
2648 }
2649 } else {
2650 // For non-divisible predicates, the index must be predicated such
2651 // that it is less than the extent of the predicated ID +
2652 // halo. Note that getRootAxisInfo doesn't work since consumer_id
2653 // isn't a root domain.
2654 if (gpu_lower->haloInfo()->hasHaloWidth(consumer_id)) {
2655 auto halo = gpu_lower->haloInfo()->getHaloWidth(consumer_id);
2656 stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo);
2657 }
2658 }
2659
2660 return {start_limit, stop_limit};
2661}
2662
2663// Get the offsets for the start and stop predicates. The offsets
2664// are to be added to the index.
2665std::pair<Val*, Val*> getStartAndStopOffsets(
2666 IterDomain* consumer_id,
2667 TensorView* consumer_tv,
2668 const std::unordered_map<IterDomain*, Val*>& consumer_start_index_map,
2669 const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map,
2670 bool padding_predicate,
2671 bool unswitch,
2672 bool non_divisible_pred) {
  // By default, the offsets for the start and stop predicates are
  // just zero. All halo-related adjustments are done at root domains, so
  // when consumer_id is not a root domain, no adjustment is required.
2676 if (consumer_id->definition() != nullptr && !non_divisible_pred) {
2677 return {
2678 GpuLower::current()->kernel()->zeroVal(),
2679 GpuLower::current()->kernel()->zeroVal()};
2680 }
2681
2682 auto consumer_def = consumer_tv->definition();
2683
2684 Val* start_offset = GpuLower::current()->kernel()->zeroVal();
2685 Val* stop_offset = GpuLower::current()->kernel()->zeroVal();
2686
2687 // These adjustments are not required when predicating non-divisible splits
2688 if (!non_divisible_pred) {
2689 if (consumer_def->isA<ShiftOp>()) {
2690 std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift(
2691 consumer_tv, consumer_id, padding_predicate);
2692 } else if (consumer_def->isA<GatherOp>()) {
2693 std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForGather(
2694 consumer_tv,
2695 consumer_id,
2696 consumer_start_index_map,
2697 consumer_stop_index_map,
2698 padding_predicate);
2699 }
2700
2701 // Adjustment for partial split
2702 auto partial_split_offset =
2703 getGlobalConsumerOffsetWithPartialSplit(consumer_id);
2704 start_offset =
2705 SimplifyingIrBuilder::addExpr(start_offset, partial_split_offset);
2706 stop_offset =
2707 SimplifyingIrBuilder::addExpr(stop_offset, partial_split_offset);
2708
2709 // If generating a predicate for unswitch, adjust the stop offset to
2710 // accommodate the addition of halo to the loop stop. See the
2711 // comment in getPredicateIndexingFromIdGraph as well.
2712 if (unswitch) {
2713 TORCH_INTERNAL_ASSERT(
2714 !padding_predicate, "Unswitch should not use the padding predicate");
2715 auto stop_unswitch_offset =
2716 getUnswitchStopOffset(consumer_id, consumer_tv);
2717 stop_offset =
2718 SimplifyingIrBuilder::addExpr(stop_offset, stop_unswitch_offset);
2719 }
2720 }
2721
2722 // Get the boundaries of two ends
2723 auto limits = getStartAndStopLimitOffsets(
2724 consumer_id, padding_predicate, non_divisible_pred);
2725
2726 // At this point, we have everything to create both start and stop
2727 // predicates as:
2728 //
2729 // index + start_offset >= start_limit
2730 // index + stop_offset < extent + stop_limit
2731 //
2732 // In order to enable consolidating unswitch predicates, organize
2733 // the predicates as:
2734 //
2735 // index + (start_offset - start_limit) >= 0
2736 // index + (stop_offset - stop_limit) < extent
2737
2738 start_offset = SimplifyingIrBuilder::subExpr(start_offset, limits.first);
2739 stop_offset = SimplifyingIrBuilder::subExpr(stop_offset, limits.second);
2740
2741 return {start_offset, stop_offset};
2742}
2743
2744// A partial value of a start offset is returned if determined to be
2745// safe. Nullptr is returned if it can be omitted completely.
2746Val* simplifyStartOffset(Val* start_offset) {
2747 // Start predicate can be omitted when start_offset >= 0.
2748 auto offset_val = start_offset->as<Int>()->value();
2749 if (offset_val.has_value() && offset_val.value() >= 0) {
2750 return nullptr;
2751 }
2752
  // start_offset may look like min(0, window_index - pad). Then, we can
  // remove the min and leave only the rhs: since the predicated index is
  // non-negative, the predicate with min(0, rhs) holds exactly when the
  // predicate with just rhs holds.
2755 auto def = dynamic_cast<BinaryOp*>(start_offset->definition());
2756 if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min &&
2757 def->lhs()->isZeroInt()) {
2758 return def->rhs();
2759 }
2760
2761 return start_offset;
2762}
2763
2764bool canOmitStopPredicate(
2765 Val* stop_index,
2766 Val* stop_offset,
2767 IterDomain* contig_id) {
2768 bool index_simple = stop_index->definition() == nullptr;
2769 // The definition may be just adding the magic zero, which can be
2770 // effectively considered "simple"
2771 if (!index_simple && isProtectedWithMagicZero(stop_index)) {
2772 // Make sure the lhs of stop_index is simple.
2773 auto lhs = stop_index->definition()->as<BinaryOp>()->lhs();
2774 if (lhs->definition() == nullptr) {
2775 index_simple = true;
2776 }
2777 }
2778
2779 if (!index_simple) {
2780 return false;
2781 }
2782
2783 const auto gpu_lower = GpuLower::current();
2784
2785 // Stop predicate: stop_index + stop_offset < extent, where
2786 // stop_index ranges from 0 to (extent + halo), so this can be
2787 // omitted if extent + halo + stop_offset < extent, i.e., halo +
2788 // stop_offset <= 0.
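  //
  // For example, with no halo and a zero stop offset the maximum value of
  // stop_index is extent - 1, so stop_index < extent always holds and the
  // predicate can be dropped (subject to the parallel-type checks below).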
2789
2790 auto stop_offset_val = stop_offset->as<Int>()->value();
2791
2792 // If they are not compile-time constant, can't prove the
2793 // condition.
2794 if (!stop_offset_val.has_value()) {
2795 return false;
2796 }
2797
2798 // Note that when a root domain is halo extended, it is the domain
2799 // to be predicated, not its merged contig id even if it exists. So,
2800 // if contig_id does not have root axis info, contig_id is
2801 // guaranteed to have no halo.
2802 auto halo_ext = gpu_lower->haloInfo()->hasRootAxisInfo(contig_id)
2803 ? gpu_lower->haloInfo()->getRootAxisInfo(contig_id).width()
2804 : 0;
2805
2806 if (halo_ext + stop_offset_val.value() > 0) {
2807 return false;
2808 }
2809
2810 // When the domain is parallelized, the parallel dimension must be
2811 // exact. Otherwise, there would be extra threads/blocks that need
2812 // to be predicated out.
2813 if (isParallelTypeThread(contig_id->getParallelType())) {
2814 if (!gpu_lower->parallelDimensionMap().isExact(
2815 contig_id->getParallelType())) {
2816 return false;
2817 }
2818 // If the domain has halo, the loop is expanded by the halo
2819 // extent, so we can't prove the loop extent is the same as the
2820 // parallel dimension.
2821 if (halo_ext != 0) {
2822 return false;
2823 }
2824 }
2825
2826 return true;
2827}
2828
2829std::pair<Val*, Val*> hoistPredicates(
2830 Val* start_index,
2831 Val* stop_index,
2832 const std::vector<kir::ForLoop*>& loops,
2833 std::vector<IterDomain*> loop_domains,
2834 const std::unordered_map<IterDomain*, Val*>& start_initial_loop_index_map,
2835 const std::unordered_map<IterDomain*, Val*>& stop_initial_loop_index_map,
2836 kir::ForLoop* unswitch_or_vec_loop,
2837 IterDomain* predicated_consumer_id,
2838 TensorView* predicated_consumer_tv) {
2839 const std::pair<Val*, Val*> same_indices{start_index, stop_index};
2840
2841 if (isOptionDisabled(DisableOption::IndexHoist)) {
2842 return same_indices;
2843 }
2844
2845 const auto start_is_same_as_stop = stop_index == start_index;
2846
2847 Val* hoisted_stop_index = nullptr;
2848
2849 if (stop_index->definition() == nullptr) {
    // If the index doesn't have an expression, there's nothing to hoist
2851 hoisted_stop_index = stop_index;
2852 } else {
2853 bool inserted = false;
2854 std::tie(hoisted_stop_index, inserted) =
2855 GpuLower::current()->commonIndexMap().insert(
2856 predicated_consumer_id,
2857 predicated_consumer_tv->domain(),
2858 loop_domains,
2859 stop_initial_loop_index_map,
2860 loops,
2861 stop_index);
2862 }
2863
2864 Val* hoisted_start_index = nullptr;
2865 if (start_is_same_as_stop) {
2866 hoisted_start_index = hoisted_stop_index;
2867 } else if (start_index->definition() == nullptr) {
2868 hoisted_start_index = start_index;
2869 } else {
2870 bool inserted = false;
2871 std::tie(hoisted_start_index, inserted) =
2872 GpuLower::current()->commonIndexMap().insert(
2873 predicated_consumer_id,
2874 predicated_consumer_tv->domain(),
2875 loop_domains,
2876 start_initial_loop_index_map,
2877 loops,
2878 start_index);
2879 }
2880
2881 return {hoisted_start_index, hoisted_stop_index};
2882}
2883
2884// Updates a loop index map with a loop index protected by magic zero
2885std::unordered_map<IterDomain*, Val*> updateInitialLoopIndexMap(
2886 const std::unordered_map<IterDomain*, Val*>& initial_loop_index_map,
2887 const IndexMagicZeroInfo& magic_zero_info) {
2888 if (magic_zero_info.original_loop_index != nullptr) {
2889 TORCH_INTERNAL_ASSERT(magic_zero_info.protected_loop_index != nullptr);
2890 auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
2891 magic_zero_info.loop_id, IdMappingMode::EXACT);
2892 auto updated_map = initial_loop_index_map;
2893 updated_map[concrete_loop_id] = magic_zero_info.protected_loop_index;
2894 return updated_map;
2895 } else {
2896 return initial_loop_index_map;
2897 }
2898}
2899
2900} // namespace
2901
2902// Returns predicates and the concrete (by loop map) root domains they cover
2903std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
2904 TensorView* consumer_tv,
2905 const std::vector<kir::ForLoop*>& loops,
2906 kir::ForLoop* unswitch_or_vec_loop,
2907 bool shift_padding) {
2908 FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");
2909
2910 const auto gpu_lower = GpuLower::current();
2911
2912 const bool is_unswitch = unswitch_or_vec_loop != nullptr;
2913
2914 // Nothing needs to be done when padding is not required.
2915 if (shift_padding && !needsPadding(consumer_tv)) {
2916 return {RootPredicateInfo::getFalseInfo()};
2917 }
2918
2919 auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv);
2920
2921 // Generate start and stop indexing from idgraph.
2922 //
2923 // Both start and stop positions may need to be predicated. Indexing
2924 // differs when generating predicates for unswitch.
2925 // NOTE: If we could find-and-replace KIR nodes, we could just
2926 // generate one index map, clone it and replace the loop-to-index
2927 // mappings of unswitched loops for the start predicate.
2928
2929 auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph(
2930 loops, consumer_tv, unswitch_or_vec_loop, db_axis, false);
2931 const auto consumer_stop_indexing = stop_indexing_from_idgraph.index;
2932 const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap();
2933
2934 // If not unswitch, share the same indexing map as the stop index
2935 // map
2936 const auto start_indexing_from_idgraph = is_unswitch
2937 ? getPredicateIndexingFromIdGraph(
2938 loops, consumer_tv, unswitch_or_vec_loop, db_axis, true)
2939 : stop_indexing_from_idgraph;
2940 const auto consumer_start_indexing = start_indexing_from_idgraph.index;
2941 const auto& consumer_start_index_map = consumer_start_indexing.indexMap();
2942
2943 // Get the contiguous ids we need to generate predicates for
2944 auto contig_id_infos =
2945 getPredicateContigIds(consumer_tv, consumer_stop_index_map);
2946
2947 auto non_divisible_splits =
2948 getNonDivisibleConsumerDomainsToPredicate(consumer_tv);
2949 contig_id_infos.insert(
2950 contig_id_infos.end(),
2951 non_divisible_splits.begin(),
2952 non_divisible_splits.end());
2953
2954 std::vector<RootPredicateInfo> pred_info_vec;
2955
2956 for (auto contig_id_entry : contig_id_infos) {
2957 auto contig_id = contig_id_entry.id;
    // No predicates needed for broadcasted indices.
2959 if (contig_id->isBroadcast() ||
2960 gpu_lower->trivialReductionInfo().isDerived(contig_id)) {
2961 continue;
2962 }
2963
2964 auto root_ids = contig_id_entry.covered_ids;
2965
2966 const auto consumer_stop_indexing_it =
2967 consumer_stop_index_map.find(contig_id);
2968
2969 // First condition below happens with Misaligned predicates, where
2970 // inner-most vectorized loops are not included in the loops
2971 // parameter. Predicates involving vectorized loops are separately
2972 // generated in lower_misaligned_vectorization.
2973 //
2974 // Second condition is simply to avoid predication on broadcasting axes as
2975 // it's not required.
2976 if (consumer_stop_indexing_it == consumer_stop_index_map.end() ||
2977 consumer_stop_indexing_it->second->isZeroInt()) {
2978 continue;
2979 }
2980
2981 RootPredicateInfo info;
2982
    // Compute offsets for the start and stop predicates. For non-shift,
    // non-gather ops, there's only a stop predicate, as indices can never
    // be negative. However, for shift and gather, the index may need to
    // be predicated so that it is >= zero.
2987 //
2988 // Furthermore, in case of gather, both producer and consumer
2989 // positions may need to be predicated, so there can be multiple
2990 // offset values.
2991 //
2992 // The final predicates will look like:
2993 // (index + start_offset) >= 0 && (index + stop_offset) < extent.
2994
2995 std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets(
2996 contig_id,
2997 consumer_tv,
2998 consumer_start_index_map,
2999 consumer_stop_index_map,
3000 shift_padding,
3001 unswitch_or_vec_loop != nullptr,
3002 contig_id_entry.is_non_divisible_split);
3003
3004 auto stop_index = consumer_stop_indexing_it->second;
3005 auto start_index = consumer_start_index_map.at(contig_id);
3006
3007 IndexMagicZeroInfo start_magic_zero_info;
3008 IndexMagicZeroInfo stop_magic_zero_info;
3009
3010 // When the start and stop indices are not the same, apply the
3011 // magic-zero protection separately for both of them.
3012 if (stop_index != start_index) {
3013 start_magic_zero_info = protectPredicateIndexWithMagicZero(
3014 start_index, start_indexing_from_idgraph, loops);
3015 stop_magic_zero_info = protectPredicateIndexWithMagicZero(
3016 stop_index, stop_indexing_from_idgraph, loops);
3017 } else {
3018 stop_magic_zero_info = protectPredicateIndexWithMagicZero(
3019 stop_index, stop_indexing_from_idgraph, loops);
3020 start_magic_zero_info = stop_magic_zero_info;
3021 }
3022
3023 start_index = start_magic_zero_info.index;
3024 stop_index = stop_magic_zero_info.index;
3025
3026 // Update the loop-index map with the magic-zero protection info
3027 // before passing it to the hoisting function
3028 std::tie(start_index, stop_index) = hoistPredicates(
3029 start_index,
3030 stop_index,
3031 loops,
3032 stop_indexing_from_idgraph.resolved_loop_domains,
3033 updateInitialLoopIndexMap(
3034 start_indexing_from_idgraph.initial_concrete_index_map,
3035 start_magic_zero_info),
3036 updateInitialLoopIndexMap(
3037 stop_indexing_from_idgraph.initial_concrete_index_map,
3038 stop_magic_zero_info),
3039 unswitch_or_vec_loop,
3040 contig_id,
3041 consumer_tv);
3042
3043 // Build predicates for start positions as:
3044 // start_index + start_offset >= 0
3045 auto start_offset = simplifyStartOffset(info.start_offset_);
3046 if (start_offset == nullptr) {
3047 info.start_predicate_ = GpuLower::current()->kernel()->trueVal();
3048 } else {
3049 auto offsetted_start_index =
3050 SimplifyingIrBuilder::addExpr(start_index, start_offset);
3051 auto start_pred =
3052 SimplifyingIrBuilder::geExpr(
3053 offsetted_start_index, GpuLower::current()->kernel()->zeroVal())
3054 ->as<Bool>();
3055 info.start_predicate_ = start_pred;
3056 }
3057
3058 // Build predicates for stop positions as:
3059 // stop_index + stop_offset < IterDomain::extent
3060 auto stop_offset = info.stop_offset_;
3061 if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) {
3062 info.stop_predicate_ = GpuLower::current()->kernel()->trueVal();
3063 } else {
3064 auto offsetted_stop_index =
3065 SimplifyingIrBuilder::addExpr(stop_index, stop_offset);
3066 auto stop_pred = SimplifyingIrBuilder::ltExpr(
3067 offsetted_stop_index, contig_id->extent())
3068 ->as<Bool>();
3069 info.stop_predicate_ = stop_pred;
3070 }
3071
3072 for (auto consumer_id : contig_id_entry.covered_ids) {
3073 info.root_ids_.insert(consumer_id);
3074 }
3075 pred_info_vec.emplace_back(info);
3076 }
3077
3078 return pred_info_vec;
3079}
3080
3081RootPredicateInfo RootPredicateInfo::getFalseInfo() {
3082 RootPredicateInfo info;
3083 info.start_predicate_ = GpuLower::current()->kernel()->falseVal();
3084 info.stop_predicate_ = GpuLower::current()->kernel()->falseVal();
3085
3086 return info;
3087}
3088
3089} // namespace cuda
3090} // namespace fuser
3091} // namespace jit
3092} // namespace torch
3093