1#include <contiguity.h>
2#include <index_compute.h>
3#include <ir_utils.h>
4#include <lower2device.h>
5#include <lower_index_compute.h>
6#include <lower_magic_zero.h>
7#include <lower_utils.h>
8#include <lower_validation.h>
9#include <transform_iter.h>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16IndexFromIdGraph::IndexFromIdGraph(
17 IndexCompute index_,
18 IndexCompute concrete_index_,
19 std::unordered_map<IterDomain*, Val*> initial_concrete_index_map_,
20 std::vector<IterDomain*> loop_domains_)
21 : index(index_),
22 concrete_index(concrete_index_),
23 initial_concrete_index_map(initial_concrete_index_map_),
24 resolved_loop_domains(loop_domains_) {}
25
26namespace {
27
// Maps all producer domains to the consumer with broadcast
// forwarding. Used to find the allocation position.
// TODO: should this be an ir_util? It doesn't seem to be
// used much, though.
32std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
33 const TensorView* producer_tv,
34 const TensorView* consumer_tv) {
35 // This map has forwarded broadcast axes, it should only be used to compute
36 // the allocation position of the producer, and to figure out which producer
37 // indices are mapped to consumer trivial reductions.
38 std::unordered_map<IterDomain*, IterDomain*> p2c_alloc_map;
39
  // We want to replay producer as consumer instead of the other way around,
  // since the consumer may have broadcast axes the producer lacks merged
  // into loops the producer will use. If we replayed consumer as producer,
  // we wouldn't have this information in the mapping.
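  // For example (an illustrative sketch; tensor names are hypothetical):
  //   tv_p = [I0]
  //   tv_c = broadcast(tv_p) -> [I0, B1], then merge -> [I0*B1]
  // Replaying producer as consumer forwards B1, so tv_p's I0 gets associated
  // with the consumer's merged loop I0*B1; the opposite replay direction
  // would not recover that association.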
44 auto replay_PasC = BestEffortReplay::replayPasC(
45 producer_tv,
46 consumer_tv,
47 -1,
48 PairwiseRootDomainMap(producer_tv, consumer_tv));
49
50 // Grab consumer domain entries and reverse replay map. TODO: Maybe
51 // TransformReplay::replayPasC could return this map
52 for (auto id : consumer_tv->domain()->domain()) {
53 const auto& c2p_map = replay_PasC.getReplay();
54 auto c2p_it = c2p_map.find(id);
55 if (c2p_it != c2p_map.end()) {
56 auto c_id = c2p_it->first;
57 auto p_id = c2p_it->second;
58 p2c_alloc_map[p_id] = c_id;
59 }
60 }
61
62 return p2c_alloc_map;
63}
64
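// Illustrative usage of invertOneToOneMap below (hypothetical IterDomain*
// values a, b, x, y):
//   invertOneToOneMap({{a, x}, {b, y}}) returns {{x, a}, {y, b}}.
//   invertOneToOneMap({{a, x}, {b, x}}) trips the assertion, since two keys
//   map to the same value and the inverted map would no longer be one-to-one.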
65std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap(
66 const std::unordered_map<IterDomain*, IterDomain*>& map) {
67 std::unordered_map<IterDomain*, IterDomain*> inverted;
68 for (const auto& kv : map) {
69 bool inserted = inverted.emplace(kv.second, kv.first).second;
70 TORCH_INTERNAL_ASSERT(
71 inserted,
72 "Multiple mappings to the same value detected: ",
73 kv.second->toString());
74 }
75 return inverted;
76}
77
//! A struct to keep track of necessary parameters used in
//! configuring the index compute pass.
//! These parameters are needed to propagate the indexing from the leaf nodes
//! of the TVs and loop nests to the TVs' rfactor domains during
//! index_compute.cpp::IndexCompute passes.
//! TODO:
//!  Would expect this list to become shorter over time,
//!  as more info can be determined holistically.
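//! For example, in the global indexing path (getLinearIndexParameters below),
//! initial_concrete_id_index binds each loop's exact concrete IterDomain to
//! that loop's index variable, or to the loop start value for trivial loops.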
86struct IndexingParameters {
87 //! Initial binding of index math to concrete iterdomain ids,
88 //! from the loop nest analysis.
89 std::unordered_map<IterDomain*, Val*> initial_concrete_id_index;
90
91 //! (Used in non-global indexing) the concrete iterdomains that
92 //! we want to skip or merge into contiguous indexing paths.
93 std::unordered_set<IterDomain*> zero_domains;
94
  //! (Used in non-global indexing) the preferred path along which
  //!  contiguously merged indices are propagated backward.
97 std::unordered_set<IterDomain*> preferred_concrete_ids;
98
99 //! The inferred halo padded extents of the concrete iterdomains.
100 std::unordered_map<IterDomain*, Val*> concrete_id_to_halo_extent;
101};
102
103// Initial loop index map for global producer or consumer case.
104IndexingParameters getLinearIndexParameters(
105 const LoopIndexing& loop_indexing,
106 bool index_producer = false) {
107 IndexingParameters index_parameters;
108
109 auto& loops = loop_indexing.loops();
110 auto& loop_domain = loop_indexing.loopDomains();
111 auto& loop_index_map = index_parameters.initial_concrete_id_index;
112
113 for (auto loop_idx : c10::irange(loops.size())) {
114 auto loop = loops[loop_idx];
115 auto index_domain = GpuLower::current()->caMap()->getConcreteMappedID(
116 loop_domain[loop_idx], IdMappingMode::EXACT);
117 if (loop->isTrivial()) {
      // This is useful information in the case of
      // MisalignedVectorize and the double buffer epilogue, etc.
120 loop_index_map[index_domain] = loop->start();
121 } else {
      // By default, use the loop's pre-allocated index variable.
123 loop_index_map[index_domain] = loop->index();
124 }
125 }
126
127 // Derive the halo extents from the loop indexing result.
128 index_parameters.concrete_id_to_halo_extent =
129 GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
130 loop_indexing);
131
132 protectNonPredicateIndexWithMagicZero(
133 loops,
134 loop_indexing.loopDomains(),
135 index_parameters.initial_concrete_id_index);
136
  // Set up the double buffer increment for the producer case.
  // TODO: these double buffer index calculations could be unified
  // in follow-ups.
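  // For example, with a stage depth of 2 (classic double buffering), the
  // producer's loop index is offset by (2 - 1) = 1 below, so the producer
  // writes one stage ahead of the stage being read.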
140 if (index_producer) {
141 auto double_buffer_loop =
142 GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop(
143 loop_indexing.consumerTv(), loops, true);
144
145 for (auto loop_idx : c10::irange(loops.size())) {
146 auto loop = loops[loop_idx];
147 if (loop == double_buffer_loop) {
148 TORCH_INTERNAL_ASSERT(
149 !loop->isTrivial(), "The double buffer loop must be materialized");
150
151 auto loop_id = loop_indexing.loopDomains()[loop_idx];
152
153 auto concrete_loop_id =
154 GpuLower::current()->caMap()->getConcreteMappedID(
155 loop_id, IdMappingMode::EXACT);
156
157 auto stage_depth =
158 GpuLower::current()->doubleBufferInfo().getStageDepthFor(
159 loop->iter_domain());
160 index_parameters.initial_concrete_id_index[concrete_loop_id] =
161 SimplifyingIrBuilder::addExpr(
162 index_parameters.initial_concrete_id_index[concrete_loop_id],
163 SimplifyingIrBuilder::create<Int>(stage_depth - 1));
164 }
165 }
166 }
167
168 return index_parameters;
169}
170
171// Initial index parameters for shared and local case
172IndexingParameters getNonGlobalInitialIndexParameters(
173 const LoopIndexing& loop_indexing,
174 const TensorView* consumer_tv,
175 bool index_producer = false,
176 const TensorView* producer_tv = nullptr,
177 std::unordered_map<IterDomain*, IterDomain*> p2c_map = {}) {
178 IndexingParameters index_parameters;
179 const auto& loops = loop_indexing.loops();
180 const auto& loop_domains = loop_indexing.loopDomains();
181
182 // TODO:
183 // The non-global path should become shorter as we
184 // pull more info into id graph.
185 std::unordered_map<IterDomain*, IterDomain*> alloc_id_map;
186
187 if (index_producer) {
188 alloc_id_map = mapAllProducerDomainsToConsumer(producer_tv, consumer_tv);
189 }
190
191 auto alloc_tv = index_producer ? producer_tv : consumer_tv;
192 auto alloc_info = lower_utils::getAllocInformation(
193 alloc_tv, loops, alloc_id_map, index_producer);
194
195 std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
196 std::unordered_set<kir::ForLoop*> zero_loops;
197
198 kir::ForLoop* double_buffer_loop = nullptr;
199
200 if (index_producer) {
201 double_buffer_loop =
202 GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop(
203 consumer_tv, loops, true);
204 }
205
206 std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV(
207 alloc_tv,
208 loops,
209 alloc_info.init_for_loop,
210 !index_producer,
211 double_buffer_loop);
212
213 ensureStaticIndexing(alloc_tv, alloc_info.init_for_loop, loops, alloc_id_map);
214
215 TORCH_INTERNAL_ASSERT(
216 loops.size() <= loop_domains.size(),
217 "Loop domain didn't replay all loops");
218
219 for (auto loop_idx : c10::irange(loops.size())) {
220 auto loop = loops[loop_idx];
221 auto loop_domain = loop_domains[loop_idx];
222
223 auto concrete_loop_domain =
224 GpuLower::current()->caMap()->getConcreteMappedID(
225 loop_domain, IdMappingMode::EXACT);
226
227 index_parameters.initial_concrete_id_index[concrete_loop_domain] =
228 loop_to_ind_map.at(loop);
229
230 if (zero_loops.count(loop)) {
231 index_parameters.zero_domains.insert(concrete_loop_domain);
232 }
233 }
234
235 // Derive preferred path from loop indexing result.
236 const TensorView* target_tv = index_producer ? producer_tv : consumer_tv;
237 index_parameters.preferred_concrete_ids = buildLoopIndexingPreferredPath(
238 target_tv, loop_indexing, index_producer, p2c_map);
239
240 // Derive the halo extents from the loop indexing result.
241 index_parameters.concrete_id_to_halo_extent =
242 GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
243 loop_indexing);
244
245 return index_parameters;
246}
247
//! Initial index parameters for predicates; adjusts the loop-to-index
//!  map according to the information annotated on the loop nest.
//!
//! TODO:
//!  This function is mostly copy-pasted from the previous implementation
//!  at this step; further cleanup is possible since:
//!  1. Much of the loop-to-index adjustment will be issued from the id graph.
//!  2. Much of the initial index logic could be shared across all
//!     three variants.
257IndexingParameters getPredicateInitialIndexParameters(
258 const LoopIndexing& loop_indexing,
259 TensorView* consumer_tv,
260 kir::ForLoop* unswitch_or_vec_loop,
261 IterDomain* double_buffer_axis,
262 bool is_start_predicate) {
263 IndexingParameters index_parameters;
264 const auto& loops = loop_indexing.loops();
265 const auto& loop_domains = loop_indexing.loopDomains();
266
267 // This shouldn't be needed.
268 TORCH_INTERNAL_ASSERT(
269 loops.size() <= loop_domains.size(),
270 "Loop domain didn't replay all loops");
271
272 std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
273
274 // Fill initial index with each forloop's index.
275 std::transform(
276 loops.begin(),
277 loops.end(),
278 std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
279 [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
280
281 // Generate unswitch loop to index map.
282 if (unswitch_or_vec_loop != nullptr) {
    // Vectorized predicates are different from unswitch predicates. For
    // unswitch predicates, all loops within the unswitch (the outermost
    // unswitch) are indexed with loop->extent - 1. With vectorized
    // predicates, only the vectorized loop is indexed this way.
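    // For example (illustrative loop nest; names are hypothetical), given
    // loops [unswitch(I0), serial(I1), vectorize(I2)]:
    //  - an unswitch predicate indexes I0, I1 and I2 at (extent - 1) for the
    //    stop predicate, or 0 for the start predicate;
    //  - a vectorize predicate only replaces the index of the vectorized
    //    loop I2, leaving the other loops at their loop indices.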
287 bool vectorized_pred =
288 unswitch_or_vec_loop->iter_domain()->getParallelType() ==
289 ParallelType::Vectorize;
290
291 bool within_unswitch = false;
292
293 for (const auto loop_i : c10::irange(loops.size())) {
294 auto loop = loops[loop_i];
295 auto loop_id = loop->iter_domain();
296 auto loop_pt = loop_id->getParallelType();
297 auto ref_id = loop_domains.at(loop_i);
298
299 if (loop == unswitch_or_vec_loop) {
300 within_unswitch = true;
301 }
302
303 if (within_unswitch) {
        // Rely on the reference to check broadcasting. The for loop could be
        // broadcast on a constant value from an unroll split. Since the
        // reference may convert this to an iter domain, that for loop could
        // still be valid to generate a predicate from.
308
        // Note that loop->stop() is not used below. Instead,
        // loop->iter_domain()->extent() is used, which is uniform
        // across the mapped domains irrespective of halo. Predicates are
        // compared with each other to pick the most restrictive ones. The
        // comparison is done using only the offset, which is the
        // term added to the index. So, the index term must be the
        // same among all predicates, otherwise the comparison would
        // be invalid. The effect of halo is added to the offset
        // term. See getUnswitchStopOffset.
318
319 if (ref_id->isBroadcast()) {
320 // Ignore indexing into broadcasted dimensions.
321 continue;
322 } else if (loop_id->isThread()) {
          // When parallelized, if the loop stop is the same as the
          // extent of the associated IterDomain, i.e., no extra
          // iterations for halo, predicating with the threading index
          // is sufficient for both the start and stop
          // predicates. That isn't the case if the loop has halo, and
          // in that case either the minimum or the maximum value of the
          // iteration domain needs to be used.
          //
          // Note: Better performance was obtained when the use of
          // threadIdx in unswitch predicates was avoided. More
          // specifically, in the Hdiff stencil example, instead of
          // predicating with threadIdx.x for both the start and stop
          // predicates, using zero and (blockDim.x - 1) for the start
          // and stop predicates, respectively, resulted in less
          // register pressure. The alternative codegen can be done by
          // adding this to the first if condition:
          // loop_id->isBlockDim(). This would not be a concern if the
          // else part could be omitted, so canOmitElseClause should
          // be used as well.
342 if (loop->stop() == loop_id->extent()) {
343 loop_to_ind_map[loop] = loop->start();
344 } else if (is_start_predicate) {
345 loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
346 } else {
            // Note that the parallel dimension is used rather than
            // loop->stop(). See the above comment.
349 loop_to_ind_map[loop] =
350 GpuLower::current()->parallelDimensionMap().get(loop_pt);
351 }
352 } else if (is_start_predicate) {
353 loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
354 } else {
          // Similar to the above, loop_id->extent() is
          // used here instead of loop->stop(). See the above comment.
357 loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr(
358 loop_id->extent(), GpuLower::current()->kernel()->oneVal());
359 }
360 }
361
      // If this is a vectorized predicate, bail once the vectorized loop has
      // been found. Don't continue treating the remaining loops as unswitched.
364 if (vectorized_pred && within_unswitch) {
365 break;
366 }
367 }
368 }
369
  // Modify trivial loops to use the loop start value.
  // FIXME: eventually this should all be lifted into the id graph.
372 for (const auto loop : loops) {
373 auto& idx = loop_to_ind_map.at(loop);
374 // If the loop is trivial, the loop index can only be the loop
375 // start value.
376 if (idx == loop->index() && loop->isTrivial()) {
377 idx = loop->start();
378 }
379 }
380
381 // Increment double buffer loop index
382 if (double_buffer_axis != nullptr) {
383 auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop(
384 double_buffer_axis, loops, true);
385 if (db_loop != nullptr) {
386 auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop);
387 TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end());
388 auto cur_index = loop_to_ind_map_it->second;
      // If cur_index is not the same as the index of db_loop, the index
      // must have been modified to support unswitch. In that case, it is
      // not necessary to move the index ahead for double buffering.
393 auto stage_depth =
394 GpuLower::current()->doubleBufferInfo().getStageDepthFor(
395 db_loop->iter_domain());
396 if (cur_index == db_loop->index()) {
397 loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr(
398 cur_index, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
399 }
400 }
401 }
402
403 // Convert loop-to-ind map to concrete-to-ind map
404 for (int loop_idx : c10::irange(loops.size())) {
405 auto loop = loops.at(loop_idx);
406 auto concrete_loop_domain =
407 GpuLower::current()->caMap()->getConcreteMappedID(
408 loop_domains.at(loop_idx), IdMappingMode::EXACT);
409 index_parameters.initial_concrete_id_index[concrete_loop_domain] =
410 loop_to_ind_map.at(loop);
411 }
412
413 // Note that, unlike non-predicate indexing, magic-zero insertion is
414 // not done at this point but is done individually for each indexed
415 // domain. See Index::getReferenceRootPredicates.
416
417 // Derive the halo extents from the loop indexing result.
418 index_parameters.concrete_id_to_halo_extent =
419 GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
420 loop_indexing);
421
422 return index_parameters;
423}
424
425} // namespace
426
427class LoopIndexingAnalysis {
428 public:
429 static LoopIndexing fromLoopAndConsumer(
430 const std::vector<kir::ForLoop*>& loops,
431 const TensorView* consumer_tv) {
432 LoopIndexingAnalysis analysis(loops, consumer_tv);
433 return analysis.getLoopIndexing();
434 }
435
436 private:
437 explicit LoopIndexingAnalysis(
438 const std::vector<kir::ForLoop*>& loops,
439 const TensorView* consumer_tv);
440
441 //! Populate derived information into a LoopIndexing
442 //! data structure.
443 LoopIndexing getLoopIndexing() {
444 LoopIndexing indexing;
445 indexing.loops_ = loops_;
446 indexing.consumer_tv_ = consumer_tv_;
447 indexing.loop_root_ = loop_root_domains_;
448 indexing.loop_domains_ = loop_domains_.vector();
449 indexing.index_exprs_ = replayed_exprs_;
450 indexing.out_of_line_exprs_ = out_of_line_exprs_;
451 return indexing;
452 }
453
454 //! Validates that the current loop structure is well formed, in the sense
455 //! that ca_map would not map any two loops in the loop nest together.
456 void validateLoopStructure(const std::vector<kir::ForLoop*>& loops);
457
  //! Starts at the loop iter domains and traverses back into history on the
  //! concrete IDs in the exact map, calling visitExpr on the expressions
  //! encountered along the way.
461 void traverseFromDomainVals();
462
  //! Concretizes the given iterdomain and records the visit (in deterministic
  //! order) in terms of the exact mapped concrete id. Records the mapping
  //! from the concrete id back to the original id in
  //! "concrete_to_original_id_" and returns the concrete id.
467 IterDomain* concretizeAndVisitId(IterDomain* id);
468
469 //! If an equivalent expression has already been processed this function
470 //! simply returns. Otherwise puts the exact concrete IDs of inputs in
471 //! consumed_concrete_, and concrete IDs of outputs in produced_concrete_.
472 //! Then adds the expression to replayed_exprs_.
473 void visitExpr(Expr* expr);
474
  //! Iterates through the provided vals, calls concretizeAndVisitId on them,
  //! and returns true if any of the returned ids are in existing_ids. This is
  //! used to check if inputs or outputs of ID expressions have already been
  //! produced/consumed in the traversal. Indexing only needs to
  //! consume/produce one IterDomain per exact disjoint set.
480 bool visitIdsAndCheckDuplication(
481 const std::vector<Val*>& vals,
482 const std::unordered_set<IterDomain*>& existing_ids);
483
484 //! Fills loop_domains_ with the corresponding replayed_concrete_id mapping to
485 //! the provided loops. Must be done after the exact iterdomain "replay"
486 //! (traverseFromDomainVals). loop_domains_ are the original_id not the
487 //! concrete_id (translated with concrete_to_original_id). These iter domains
488 //! are used to grab the history that will be replayed in IndexCompute. We're
489 //! looking for "new" root domains and subsequent transformations, filling in
490 //! any missing "outputs" (or inputs for backward traversal). Then fills
491 //! loop_domains_ with all of these iter domains.
492 void constructLoopDomains();
493
  //! Fills out_of_line_exprs_ by traversing the selected list of
  //! expressions in reverse topological order and collecting iterdomains
  //! on the indexing paths that only involve leaf IDs to the right
  //! of the consumer's CA axis.
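  //! For example (illustrative schedule), with consumer_tv scheduled as
  //! [I0, I1] and a compute-at position of 1, only I1 is to the right of the
  //! CA axis; an expression whose outputs all map to I1 (e.g. the split that
  //! produced I1) is collected as out of line, and its inputs are then
  //! treated as out of line as well.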
498 void collectOutOfLineExprs();
499
500 private:
501 //! Original loop nest input to derive info from.
502 const std::vector<kir::ForLoop*>& loops_;
503
504 //! Original consumer tv to derive view info from.
505 const TensorView* consumer_tv_ = nullptr;
506
  // Exact concrete domains that have been used
  // in the traversal connection.
509 std::unordered_set<IterDomain*> produced_concrete_;
510 std::unordered_set<IterDomain*> consumed_concrete_;
511
512 //! Iterdomains that the corresponding loops are generated from.
513 std::vector<IterDomain*> initial_loop_domain_ids_;
514
515 //! All Id's in consumer's transform history
516 std::vector<Val*> all_consumer_id_vals_;
517
518 //! Concrete iterdomains visited in the domain traversal,
519 //! in the order they are visited in traverseFromDomainVals.
520 VectorOfUniqueEntries<IterDomain*> replayed_concrete_ids_;
521
522 //! Keeping track of the original visited id's before they
523 //! were concretized.
524 std::unordered_map<IterDomain*, IterDomain*> concrete_to_original_id_;
525
526 //! Map from concrete id to its single consumer on the selected
527 //! iterdomain expression list.
528 std::unordered_map<IterDomain*, Expr*> concrete_id_to_consumer_;
529
530 //! Source domains that all the Iterdomain transforms
531 //! in the loop nest originated from.
532 std::vector<IterDomain*> loop_root_domains_;
533
534 //! Leaf domains representing the original loop structure
535 VectorOfUniqueEntries<IterDomain*> loop_domains_;
536
537 //! Selected list of exprs that will produce and consume each
538 //! of the exact concrete ids from the loop nest exactly once.
539 std::vector<Expr*> replayed_exprs_;
540
541 //! Set of expressions from the selected list that can be
542 //! resolved from axes on the right of ca axes.
543 std::vector<Expr*> out_of_line_exprs_;
544};
545
546LoopIndexingAnalysis::LoopIndexingAnalysis(
547 const std::vector<kir::ForLoop*>& loops,
548 const TensorView* consumer_tv)
549 : loops_(loops), consumer_tv_(consumer_tv) {
550 // Validate consistency in given loop nest
551 validateLoopStructure(loops);
552
553 // Populate initial loop iter domains.
554 std::transform(
555 loops.begin(),
556 loops.end(),
557 std::back_inserter(initial_loop_domain_ids_),
558 [](kir::ForLoop* fl) { return fl->iter_domain(); });
559
560 // Collect consumer id's for view rfactor traversal.
561 all_consumer_id_vals_ = DependencyCheck::getAllValsBetween(
562 {consumer_tv->getRootDomain().begin(),
563 consumer_tv->getRootDomain().end()},
564 {consumer_tv->domain()->domain().begin(),
565 consumer_tv->domain()->domain().end()});
566
  // Resolve the definition of each exact concrete id involved in the whole
  // loop nest transform history.
569 traverseFromDomainVals();
570
571 // Construct concrete to consumer map. The replayed exprs are guaranteed to
572 // consume each concrete id once so this map is well defined.
573 for (auto expr : replayed_exprs_) {
574 for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
575 auto concrete_input_id =
576 GpuLower::current()->caMap()->getConcreteMappedID(
577 input_id, IdMappingMode::EXACT);
578 concrete_id_to_consumer_[concrete_input_id] = expr;
579 }
580 }
581
582 // Reconstruct the iterdomain view of the original loopnest after resolving
583 // the exact definition of each index.
584 constructLoopDomains();
585
  // Collect the set of indexing expressions that can be
  // resolved out of line.
588 collectOutOfLineExprs();
589}
590
591void LoopIndexingAnalysis::validateLoopStructure(
592 const std::vector<kir::ForLoop*>& loops) {
593 // Throw an error when two loops are mapped with each other, which
594 // violates an assumption that unique mappings between concrete
595 // IterDomains and the IterDomains of the loop structure must be
596 // established. It should be a reasonable assumption, but fusions
597 // like below won't work:
598 // tv0 = [I0]
599 // tv1 = broadcast(tv0, {true, false});
600 // tv2 = broadcast(tv0, {false, true});
601 // tv3 = tv1 + tv2
602 // Notice that the two axes of each of tv1, tv2 and tv3 are mapped
603 // with each other. We believe it is unlikely this limitation
604 // becomes a real concern in practice.
605 // Map concrete id to the original loop iter domain.
606 std::unordered_map<IterDomain*, IterDomain*> concrete_to_loop;
607 for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) {
608 // Largely duplicating original logic
609 auto loop_id = (*it_i)->iter_domain();
610 auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
611 loop_id, IdMappingMode::EXACT);
612
613 TORCH_INTERNAL_ASSERT(
614 !concrete_to_loop.count(concrete_loop_id),
615 "Unsupported loop structure. Two loops are mapped together.",
616 loop_id->toString(),
617 " and ",
618 concrete_to_loop.at(concrete_loop_id)->toString());
619
620 concrete_to_loop[concrete_loop_id] = loop_id;
621 }
622}
623
624void LoopIndexingAnalysis::traverseFromDomainVals() {
  // Order is really important here; start with the outermost for loops in a
  // depth-first manner. The outermost loops are topologically closer to the
  // outputs, so their broadcast dimensions are "more" resolved than those
  // towards the innermost loops.
629 std::deque<IterDomain*> to_visit(
630 initial_loop_domain_ids_.begin(), initial_loop_domain_ids_.end());
631 std::unordered_set<Expr*> visited_exprs;
632 std::unordered_set<IterDomain*> visited_ids;
633
634 while (!to_visit.empty()) {
635 auto out_id = to_visit.front();
636 to_visit.pop_front();
637
638 if (!visited_ids.emplace(out_id).second) {
639 continue;
640 }
641 auto expr = out_id->definition();
642
643 if (auto rfactor_id =
644 getRfactorIDToTraverse(out_id, all_consumer_id_vals_)) {
645 to_visit.emplace_front(rfactor_id);
646 }
647
    // IDs will be copied for the reference as we replay transformations. If
    // there were no transformations on an iteration domain, a copy of the
    // iteration domain for the reference is made here.
651 if (expr == nullptr) {
652 if (std::find(
653 initial_loop_domain_ids_.begin(),
654 initial_loop_domain_ids_.end(),
655 out_id) != initial_loop_domain_ids_.end()) {
656 concretizeAndVisitId(out_id);
657 }
658 continue;
659 }
660
661 if (!visited_exprs.emplace(expr).second) {
662 continue;
663 }
664
665 visitExpr(expr);
666
667 auto inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
    // Make sure to insert at the beginning of the deque to maintain correct
    // ordering.
670 to_visit.insert(to_visit.begin(), inp_ids.begin(), inp_ids.end());
671 }
672}
673
674IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) {
675 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
676 id, IdMappingMode::EXACT);
677 if (replayed_concrete_ids_.pushBack(concrete_id)) {
678 concrete_to_original_id_[concrete_id] = id;
679 }
680 return concrete_id;
681}
682
683namespace {
684// Alias used for std::transform
685IterDomain* exactConcreteId(IterDomain* id) {
686 return GpuLower::current()->caMap()->getConcreteMappedID(
687 id, IdMappingMode::EXACT);
688}
689} // namespace
690
691void LoopIndexingAnalysis::visitExpr(Expr* expr) {
692 if (auto swizzle2d = dynamic_cast<Swizzle2D*>(expr)) {
    // Swizzle outputs are already forwarded through
    // by the exact CA map, so currently they are just
    // ignored in the replay pass, except
    // that we want to note this node as visited.
697 concretizeAndVisitId(swizzle2d->outX());
698 concretizeAndVisitId(swizzle2d->outY());
699 return;
700 }
701
702 // Current implementation just tries to
703 // follow the exact behavior of reference replay
704 // except that no expr was actually "replayed".
705
706 // Record all inputs, and stop if current expr
707 // duplicates id consumption or production.
708 if (visitIdsAndCheckDuplication(expr->inputs(), consumed_concrete_)) {
709 return;
710 }
711 if (visitIdsAndCheckDuplication(expr->outputs(), produced_concrete_)) {
712 return;
713 }
714
715 // Record the expr if no duplication on input or output found
716 replayed_exprs_.push_back(expr);
717
718 // Record the consumed and produced concrete ids by the newly
719 // recorded expression.
720 auto consumed_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
721 std::transform(
722 consumed_ids.begin(),
723 consumed_ids.end(),
724 std::inserter(consumed_concrete_, consumed_concrete_.end()),
725 exactConcreteId);
726
727 auto produced_ids = ir_utils::filterByType<IterDomain>(expr->outputs());
728 std::transform(
729 produced_ids.begin(),
730 produced_ids.end(),
731 std::inserter(produced_concrete_, produced_concrete_.end()),
732 exactConcreteId);
733}
734
735bool LoopIndexingAnalysis::visitIdsAndCheckDuplication(
736 const std::vector<Val*>& vals,
737 const std::unordered_set<IterDomain*>& existing_ids) {
738 bool duplication = false;
739 for (auto id : ir_utils::filterByType<IterDomain>(vals)) {
740 duplication = duplication || existing_ids.count(concretizeAndVisitId(id));
741 }
742 return duplication;
743}
744
745void LoopIndexingAnalysis::constructLoopDomains() {
746 for (auto loop_id : initial_loop_domain_ids_) {
747 // Find the replayed_concrete_id mapping to the loop id.
748 auto ref_id_it = std::find_if(
749 replayed_concrete_ids_.vector().begin(),
750 replayed_concrete_ids_.vector().end(),
751 [&](IterDomain* concrete_id) {
          return
              // Make sure the replayed_concrete_id is a leaf ID
              !concrete_id_to_consumer_.count(concrete_id) &&
              // Use the permissive map so the selected ID indeed represents
              // the loop.
              // Note: see PR https://github.com/csarofeen/pytorch/pull/1960
              // and issue https://github.com/csarofeen/pytorch/issues/1873
              //
              // This mapping lookup is part of a staged indexing scheme.
              // When we find a replayed exact id that exactly maps to the
              // loop id, it means the indexing involved in this loop can be
              // resolved "locally", i.e. with, and only with, the iterdomains
              // on the given consumer tv.
              // When we cannot find an exact mapping, the permissive mapping
              // helps defer the indexing resolution for this loop nest level
              // to other iterdomain expressions from tvs that are further
              // concretized, usually further down the consumer chain of the
              // given consumer tv.
              //
              // Intuitively, exact mapping of two iterdomains should imply
              // permissive mapping of them as well, and if that were the
              // case, only looking up the permissive mapping would be enough
              // to address both of the cases above.
              // FIXME: But currently exact mapping does not imply permissive
              // mapping (see issue:
              // https://github.com/csarofeen/pytorch/issues/1963), which
              // means we should check both exact and permissive mappings
              // here.
              (GpuLower::current()->caMap()->areMapped(
                   concrete_id, loop_id, IdMappingMode::EXACT) ||
               GpuLower::current()->caMap()->areMapped(
                   concrete_id, loop_id, IdMappingMode::PERMISSIVE));
787 });
788
789 TORCH_INTERNAL_ASSERT(
790 ref_id_it != replayed_concrete_ids_.vector().end(),
791 "Could not find required iter domain in reference replay: ",
792 loop_id->toString());
793
794 auto ref_id = *ref_id_it;
795 loop_domains_.pushBack(concrete_to_original_id_.at(ref_id));
796 }
797
798 // Construct the root domain as the inputs of the replayed domain
799 auto loops_replayed_domain_vals =
800 ir_utils::filterByType<Val>(loop_domains_.vector());
801 auto root_domain_vals = IterVisitor::getInputsTo(
802 {loops_replayed_domain_vals.begin(), loops_replayed_domain_vals.end()});
803
804 // Fill loop roots:
805 auto root_domain_ids = ir_utils::filterByType<IterDomain>(root_domain_vals);
806 loop_root_domains_ =
807 std::vector<IterDomain*>(root_domain_ids.begin(), root_domain_ids.end());
808
  // The domain may have dangling iteration domains, i.e. the inner output of
  // a split but not the outer. Find which replayed vals are dependent on the
  // root domains.
812 auto all_replayed_vals =
813 ir_utils::filterByType<Val>(replayed_concrete_ids_.vector());
814 auto all_ids_from_root = DependencyCheck::getAllValsBetween(
815 {root_domain_vals.begin(), root_domain_vals.end()},
816 {all_replayed_vals.begin(), all_replayed_vals.end()});
817
  // Fill in all dangling outputs, as otherwise the backwards visitor in index
  // compute will complain about not having all outputs of the traversal.
820 for (auto id : ir_utils::filterByType<IterDomain>(all_ids_from_root)) {
821 if (id->uses().empty()) {
822 loop_domains_.pushBack(GpuLower::current()->caMap()->getConcreteMappedID(
823 id, IdMappingMode::EXACT));
824 }
825 }
826}
827
828IndexFromIdGraph getTensorIndexFromIdGraph(
829 const std::vector<kir::ForLoop*>& loops,
830 const TensorView* consumer_tv,
831 const TensorView* producer_tv,
832 bool is_global,
833 std::unordered_map<IterDomain*, IterDomain*> c2p_map) {
834 bool index_producer = producer_tv != nullptr;
835 auto target_tv = index_producer ? producer_tv : consumer_tv;
836
837 auto loop_indexing =
838 LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);
839
840 IndexingParameters index_parameters;
841
842 std::unordered_map<IterDomain*, IterDomain*> p2c_map;
843
844 // The p2c map is only needed when indexing producer
845 // as producer has replayed ids.
846 if (index_producer) {
847 p2c_map = invertOneToOneMap(c2p_map);
848 }
849
850 if (is_global) {
851 index_parameters = getLinearIndexParameters(loop_indexing, index_producer);
852 } else {
853 index_parameters = getNonGlobalInitialIndexParameters(
854 loop_indexing, consumer_tv, index_producer, producer_tv, p2c_map);
855 }
856
857 IndexCompute indexing(
858 index_parameters.initial_concrete_id_index,
859 index_parameters.zero_domains,
860 index_parameters.preferred_concrete_ids,
861 index_parameters.concrete_id_to_halo_extent);
862
863 // Run first backward traversal to generate
864 // loop nest based indexing math.
865 indexing.run(loop_indexing);
866
867 // Populate indexing through exact map from initial indexing
868 auto consumer_root = index_producer ? consumer_tv->getRootDomain()
869 : consumer_tv->getMaybeRFactorDomain();
870
871 // First collect all iterdomains in consumer transform history.
872 auto all_consumer_vals = DependencyCheck::getAllValsBetween(
873 {consumer_root.begin(), consumer_root.end()},
874 {consumer_tv->domain()->domain().begin(),
875 consumer_tv->domain()->domain().end()});
876
877 // Indexable domains are the concrete id's we visited when
878 // traversing the "reference" indexing pass.
879 std::unordered_map<IterDomain*, IterDomain*> initial_indexable_map;
880
881 // Map the concrete id indexing back to the producer or consumer tv
882 std::unordered_map<IterDomain*, IterDomain*> index_update_map;
883
884 for (IterDomain* consumer_id :
885 ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
886 // Track the non-concrete id we were trying to bind index
887 // to, whether from producer or consumer.
888 auto target_id = consumer_id;
889
890 // use mapped producer id when indexing producer
891 if (index_producer) {
892 auto target_id_it = c2p_map.find(consumer_id);
893 if (target_id_it == c2p_map.end()) {
894 // consumer id not found in c2p map
895 // skip binding for this id.
896 continue;
897 }
898 target_id = target_id_it->second;
899 }
900
901 // Exact id will have to be pulled from consumer side as the
902 // producer side are replayed ids.
903 auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
904 consumer_id, IdMappingMode::EXACT);
905
906 index_update_map[exact_concrete_id] = target_id;
907
908 // Keep track of concrete id's that were used for indexing.
909 if (indexing.indexMap().count(exact_concrete_id)) {
910 initial_indexable_map[exact_concrete_id] = exact_concrete_id;
911 }
912 }
913
914 // No contig indexing was done in reference indexing
915 ContigIDs contig_finder(
916 target_tv->domain()->domain(),
917 target_tv->getMaybeRFactorDomain(),
918 target_tv->domain()->contiguity(),
919 {},
920 indexing.indexMap(),
921 GpuLower::current()->divisbleSplitSet(),
922 GpuLower::current()->caMap(),
923 GpuLower::current()->haloInfo(),
924 GpuLower::current()->concretizedBroadcastDomains(),
925 p2c_map);
926
927 auto target_indexing = indexing.updateIndexCompute(
928 target_tv->domain(), index_update_map, contig_finder);
929
930 // Fill validation info.
931 // TODO: cleanup seems possible.
932 if (index_producer) {
933 fillProducerVectorizedContigRootDomains(
934 producer_tv, consumer_tv, c2p_map, contig_finder);
935 } else {
936 fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder);
937 }
938
939 return IndexFromIdGraph(
940 target_indexing,
941 indexing,
942 index_parameters.initial_concrete_id_index,
943 loop_indexing.loopDomains());
944}
945
946IndexFromIdGraph getPredicateIndexingFromIdGraph(
947 const std::vector<kir::ForLoop*>& loops,
948 TensorView* consumer_tv,
949 kir::ForLoop* unswitch_or_vec_loop,
950 IterDomain* double_buffer_axis,
951 bool is_start_predicate) {
952 // Run replay pass on the loop nest to generate the deterministic
953 // traversal info from loop structure.
954 auto loop_indexing =
955 LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);
956
957 // Bind initial index variables to the loop nodes and adjust
958 // according to loop and unswitch info.
959 auto index_parameters = getPredicateInitialIndexParameters(
960 loop_indexing,
961 consumer_tv,
962 unswitch_or_vec_loop,
963 double_buffer_axis,
964 is_start_predicate);
965
966 // Run first backward traversal to generate
967 // loop nest based indexing math.
968 IndexCompute indexing(
969 index_parameters.initial_concrete_id_index,
970 index_parameters.zero_domains,
971 index_parameters.preferred_concrete_ids,
972 index_parameters.concrete_id_to_halo_extent);
973
974 indexing.run(loop_indexing);
975
976 // Map the concrete id indexing back to consumer tv
977 std::unordered_map<IterDomain*, IterDomain*> index_update_map;
978
979 // First collect all iterdomains in consumer transform history.
980 auto all_consumer_vals = DependencyCheck::getAllValsBetween(
981 {consumer_tv->getMaybeRFactorDomain().begin(),
982 consumer_tv->getMaybeRFactorDomain().end()},
983 {consumer_tv->domain()->domain().begin(),
984 consumer_tv->domain()->domain().end()});
985
986 for (IterDomain* consumer_id :
987 ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
988 // Track the non-concrete id we were trying to bind index
989 // to, whether from producer or consumer.
990 auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
991 consumer_id, IdMappingMode::EXACT);
992 index_update_map[exact_concrete_id] = consumer_id;
993 }
994
  // No contiguity info is used in the predicate indexing pass; the predicate
  // generation logic that uses the index math generated here will take
  // contiguity into account. Send an empty ContigIDs object so nothing is
  // marked as contiguous.
999 auto contig_finder = ContigIDs::getNonContigIDs();
1000
1001 // Run second backward traversal to map back to the consumer_tv
1002 auto target_indexing = indexing.updateIndexCompute(
1003 consumer_tv->domain(), index_update_map, contig_finder);
1004
1005 return IndexFromIdGraph(
1006 target_indexing,
1007 indexing,
1008 index_parameters.initial_concrete_id_index,
1009 loop_indexing.loopDomains());
1010}
1011
1012namespace {
1013
1014class LoopIndexingTraversal {
1015 enum class TraversalOrder { ForwardTopological, BackwardTopological };
1016
1017 public:
1018 static std::vector<Expr*> forwardTopologicalOrder(
1019 const std::vector<Expr*>& exprs) {
1020 LoopIndexingTraversal traversal(exprs, TraversalOrder::ForwardTopological);
1021 return traversal.getExprList();
1022 }
1023
1024 static std::vector<Expr*> backwardTopologicalOrder(
1025 const std::vector<Expr*>& exprs) {
1026 LoopIndexingTraversal traversal(exprs, TraversalOrder::BackwardTopological);
1027 return traversal.getExprList();
1028 }
1029
1030 private:
1031 explicit LoopIndexingTraversal(
1032 const std::vector<Expr*>& exprs,
1033 TraversalOrder traversal_order);
1034
1035 // Returns the vals following the expression in either
1036 // forward or backward order.
1037 const std::vector<Val*>& nextValsInTraversalOrder(Expr* expr);
1038
1039 // Returns the vals that the expression follows in either
1040 // forward or backward order.
1041 const std::vector<Val*>& prevValsInTraversalOrder(Expr* expr);
1042
1043 // Returns the sorted list according to the given traversal order.
1044 std::vector<Expr*> getExprList();
1045
1046 private:
1047 // Reference to original un-sorted expression list.
1048 const std::vector<Expr*>& exprs_;
1049
1050 // The traversal order in this pass.
1051 const TraversalOrder traversal_order_ = TraversalOrder::ForwardTopological;
1052
  // Internal record of concrete ids and their corresponding
  // iterdomain expressions that define the exact index.
1055 std::unordered_map<IterDomain*, Expr*> concrete_id_to_dependency_;
1056};
1057
1058LoopIndexingTraversal::LoopIndexingTraversal(
1059 const std::vector<Expr*>& exprs,
1060 TraversalOrder traversal_order)
1061 : exprs_(exprs), traversal_order_(traversal_order) {
1062 // Populate concrete id dependencies:
1063 for (auto expr : exprs_) {
1064 auto next_ids =
1065 ir_utils::filterByType<IterDomain>(nextValsInTraversalOrder(expr));
1066 for (auto id : next_ids) {
1067 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
1068 id, IdMappingMode::EXACT);
1069 TORCH_INTERNAL_ASSERT(
1070 concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr))
1071 .second,
1072 "Repeated dependency, invalid iterdomain traversal.");
1073 }
1074 }
1075}
1076
1077const std::vector<Val*>& LoopIndexingTraversal::nextValsInTraversalOrder(
1078 Expr* expr) {
1079 switch (traversal_order_) {
1080 case TraversalOrder::ForwardTopological:
1081 return expr->outputs();
1082 break;
1083 case TraversalOrder::BackwardTopological:
1084 return expr->inputs();
1085 break;
1086
1087 default:
1088 TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order");
1089 }
1090 return expr->inputs();
1091}
1092
1093const std::vector<Val*>& LoopIndexingTraversal::prevValsInTraversalOrder(
1094 Expr* expr) {
1095 switch (traversal_order_) {
1096 case TraversalOrder::ForwardTopological:
1097 return expr->inputs();
1098 break;
1099 case TraversalOrder::BackwardTopological:
1100 return expr->outputs();
1101 break;
1102
1103 default:
1104 TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order");
1105 }
1106 return expr->inputs();
1107}
1108
1109std::vector<Expr*> LoopIndexingTraversal::getExprList() {
1110 std::deque<Expr*> to_visit(exprs_.begin(), exprs_.end());
1111
1112 // pre-allocate result space.
1113 std::vector<Expr*> result;
1114 result.reserve(exprs_.size());
1115
  // Keeps track of visited and inserted expressions.
  // An expr is "visited" once it has been placed in the result list.
  // An expr is "inserted" once the traversal has put the expr on
  // the top of the stack. Repeated insertion of the same
  // expression would never be observed if the underlying
  // dependency graph of the expressions is cycle free.
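  // For example, if expr A consumes an id produced by expr B while expr B
  // consumes an id produced by expr A, the traversal would try to push one of
  // them onto the stack a second time before it ever becomes ready, and the
  // assertion below fires.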
1122 std::unordered_set<Expr*> visited, inserted;
1123
1124 while (!to_visit.empty()) {
1125 auto top = to_visit.front();
1126 if (visited.count(top)) {
1127 to_visit.pop_front();
1128 continue;
1129 }
1130
1131 bool ready = true;
1132
1133 for (auto prev_id :
1134 ir_utils::filterByType<IterDomain>(prevValsInTraversalOrder(top))) {
1135 auto prev_expr_it = concrete_id_to_dependency_.find(
1136 GpuLower::current()->caMap()->getConcreteMappedID(
1137 prev_id, IdMappingMode::EXACT));
1138 if (prev_expr_it != concrete_id_to_dependency_.end()) {
1139 auto prev_expr = prev_expr_it->second;
1140 if (!visited.count(prev_expr)) {
1141 ready = false;
1142 to_visit.push_front(prev_expr);
1143 TORCH_INTERNAL_ASSERT(
1144 inserted.insert(prev_expr).second,
1145 "Circular dependency in loop index expressions.");
1146 break;
1147 }
1148 }
1149 }
1150
1151 if (ready) {
1152 visited.insert(top);
1153 result.emplace_back(top);
1154 to_visit.pop_front();
1155 }
1156 }
1157
1158 return result;
1159}
1160
1161} // namespace
1162
1163void LoopIndexingAnalysis::collectOutOfLineExprs() {
  // Keep track of all the ids that can be resolved without using
  // iterdomains to the left of the CA axes.
1166 std::unordered_set<IterDomain*> out_of_line_ids;
1167
1168 // Start the set with all the leaf ids.
1169 std::transform(
1170 consumer_tv_->domain()->domain().begin() +
1171 consumer_tv_->getComputeAtPosition(),
1172 consumer_tv_->domain()->domain().end(),
1173 std::inserter(out_of_line_ids, out_of_line_ids.end()),
1174 exactConcreteId);
1175
1176 // Get the original selected list of index expressions
1177 // in reverse topological order.
1178 auto backward_expr_list =
1179 LoopIndexingTraversal::backwardTopologicalOrder(replayed_exprs_);
1180
1181 for (auto expr : backward_expr_list) {
1182 auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
1183 if (
1184 // Check that all of the outputs are out of line
1185 std::all_of(
1186 id_outputs.begin(),
1187 id_outputs.end(),
1188 [&out_of_line_ids](IterDomain* id) {
1189 return out_of_line_ids.count(
1190 GpuLower::current()->caMap()->getConcreteMappedID(
1191 id, IdMappingMode::EXACT));
1192 })) {
1193 // Record out of line expression
1194 out_of_line_exprs_.push_back(expr);
1195
1196 // Add all of the expression inputs as out of line id's.
1197 auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs());
1198 std::transform(
1199 id_inputs.begin(),
1200 id_inputs.end(),
1201 std::inserter(out_of_line_ids, out_of_line_ids.end()),
1202 exactConcreteId);
1203 }
1204 }
1205}
1206
1207std::vector<Expr*> LoopIndexing::getForwardExprList() const {
1208 return LoopIndexingTraversal::forwardTopologicalOrder(index_exprs_);
1209}
1210
1211std::vector<Expr*> LoopIndexing::getBackwardExprList() const {
1212 return LoopIndexingTraversal::backwardTopologicalOrder(index_exprs_);
1213}
1214
1215std::unordered_set<IterDomain*> LoopIndexing::getAllExactConcreteIdSet() const {
1216 std::unordered_set<IterDomain*> all_id_set;
1217 for (auto expr : index_exprs_) {
1218 auto out_ids = ir_utils::filterByType<IterDomain>(expr->outputs());
1219 std::transform(
1220 out_ids.begin(),
1221 out_ids.end(),
1222 std::inserter(all_id_set, all_id_set.end()),
1223 exactConcreteId);
1224
1225 auto in_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
1226 std::transform(
1227 in_ids.begin(),
1228 in_ids.end(),
1229 std::inserter(all_id_set, all_id_set.end()),
1230 exactConcreteId);
1231 }
1232 return all_id_set;
1233}
1234
1235namespace {
1236
1237//! Returns true if id is mapped together with any id in
1238//! the vector ids by permissive compute at map.
1239bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) {
1240 return std::any_of(ids.begin(), ids.end(), [&](Val* val) {
1241 return val->isA<IterDomain>() &&
1242 GpuLower::current()->caMap()->areMapped(
1243 id, val->as<IterDomain>(), IdMappingMode::PERMISSIVE);
1244 });
1245}
1246
1247class LoopIndexingPreferredPathCompute : public IterVisitor {
1248 public:
1249 static std::unordered_set<IterDomain*> compute(
1250 const TensorView* original_tv,
1251 const LoopIndexing& loop_indexing,
1252 bool use_replay_map,
1253 const std::unordered_map<IterDomain*, IterDomain*>& p2c_map) {
1254 LoopIndexingPreferredPathCompute compute;
1255
1256 auto all_concrete_ids = loop_indexing.getAllExactConcreteIdSet();
1257
1258 // Annotate all ids
1259 auto all_original_ids = DependencyCheck::getAllValsBetween(
1260 {original_tv->getMaybeRFactorDomain().begin(),
1261 original_tv->getMaybeRFactorDomain().end()},
1262 {original_tv->domain()->domain().begin(),
1263 original_tv->domain()->domain().end()});
1264
1265 for (auto original_id :
1266 ir_utils::filterByType<IterDomain>(all_original_ids)) {
1267 auto mapped_id = original_id;
1268 if (use_replay_map) {
1269 auto c_id_it = p2c_map.find(original_id);
1270 if (c_id_it == p2c_map.end()) {
1271 continue;
1272 }
1273 mapped_id = c_id_it->second;
1274 }
1275 auto concrete_original_id =
1276 GpuLower::current()->caMap()->getConcreteMappedID(
1277 mapped_id, IdMappingMode::EXACT);
1278 if (all_concrete_ids.count(concrete_original_id)) {
1279 if (original_id->isBroadcast() || original_id->isReduction() ||
1280 original_id->isStride()) {
1281 continue;
1282 }
1283 compute.preferred_path_.insert(concrete_original_id);
1284 }
1285 }
1286
1287 for (auto expr : loop_indexing.getForwardExprList()) {
1288 compute.handle(expr);
1289 }
1290
1291 return compute.preferred_path_;
1292 }
1293
1294 private:
1295 void handle(Expr* e) override {
1296 // If an input ID is marked, propagate the marking to outputs of the
1297 // expression
1298 auto all_iter_inputs = ir_utils::filterByType<IterDomain>(e->inputs());
1299 if (std::any_of(
1300 all_iter_inputs.begin(),
1301 all_iter_inputs.end(),
1302 [&](IterDomain* inp_id) {
1303 return this->preferred_path_.find(
1304 GpuLower::current()->caMap()->getConcreteMappedID(
1305 inp_id, IdMappingMode::EXACT)) !=
1306 this->preferred_path_.end();
1307 })) {
1308 auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs());
1309
1310 std::transform(
1311 all_iter_outputs.begin(),
1312 all_iter_outputs.end(),
1313 std::inserter(preferred_path_, preferred_path_.end()),
1314 exactConcreteId);
1315 }
1316 }
1317
1318 std::unordered_set<IterDomain*> preferred_path_;
1319};
1320
1321} // namespace
1322
1323// External interface for preferred path propagation.
1324std::unordered_set<IterDomain*> buildLoopIndexingPreferredPath(
1325 const TensorView* original_tv,
1326 const LoopIndexing& loop_indexing,
1327 bool use_replay_map,
1328 std::unordered_map<IterDomain*, IterDomain*> p2c_map) {
1329 return LoopIndexingPreferredPathCompute::compute(
1330 original_tv, loop_indexing, use_replay_map, p2c_map);
1331}
1332
1333// Get an rfactor IterDomain that is mapped with an IterDomain. If
1334// multiple such IDs exist, select one whose input IDs are mapped with
1335// the consumer IDs. This is to ensure the path from the leaf
1336// IterDomains to the root matches with the consumer tensor.
1337IterDomain* getRfactorIDToTraverse(
1338 IterDomain* id,
1339 const std::vector<Val*>& consumer_all_ids) {
1340 const auto& rfactor_ids =
1341 GpuLower::current()->caMap()->getViewRfactorDomainsOfIdGroup(
1342 id, IdMappingMode::PERMISSIVE);
1343
1344 if (rfactor_ids.empty()) {
1345 return nullptr;
1346 }
1347
1348 for (auto rfactor_id : rfactor_ids) {
1349 auto def = rfactor_id->definition();
1350 if (def == nullptr) {
1351 continue;
1352 }
1353
1354 auto rfactor_id_inputs = ir_utils::filterByType<IterDomain>(def->inputs());
1355 if (std::all_of(
1356 rfactor_id_inputs.begin(),
1357 rfactor_id_inputs.end(),
1358 [&](IterDomain* rfactor_id_input) {
1359 return isPermissivelyMappedWithAny(
1360 rfactor_id_input, consumer_all_ids);
1361 })) {
1362 return rfactor_id;
1363 }
1364 }
1365
1366 // No mapped ID found, which means the consumer is a post-view
1367 // tensor. In that case, it shouldn't matter which view path to
1368 // traverse, so just return the first one.
1369 return rfactor_ids.at(0);
1370}
1371
1372} // namespace cuda
1373} // namespace fuser
1374} // namespace jit
1375} // namespace torch
1376