1#include <compute_at_map.h>
2
3#include <disjoint_set.h>
4#include <ir_utils.h>
5#include <lower2device.h>
6#include <root_domain_map.h>
7#include <transform_iter.h>
8
9#include <tuple>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15namespace {
16
17// Is the provided IterDomain an Leaf of provided TensorView and within its
18// computeAtPosition
19bool idIsAComputeAtLeafDomain(IterDomain* id, TensorView* tv) {
20 auto begin = tv->domain()->domain().begin();
21 auto end = tv->domain()->domain().begin() + tv->getComputeAtPosition();
22 return std::find(begin, end, id) != end;
23}
24
25// Is the provided IterDomain an Leaf of provided TensorView
26bool idIsALeafDomain(IterDomain* id, TensorView* tv) {
27 auto begin = tv->domain()->domain().begin();
28 auto end = tv->domain()->domain().end();
29 return std::find(begin, end, id) != end;
30}
31
32} // namespace
33
// Builds the full iter-domain graph for the fusion. Unless allow_self_mapping
// is set, asserts that build() found no tensor whose own domains were mapped
// to each other (see findFirstSelfMapping below).
IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) {
  build(fusion);

  if (!allow_self_mapping) {
    // self_mapping_info_ is populated at the end of build(); its tuple layout
    // is (tv, id1, id2, domain-kind string). hasSelfMapping() presumably
    // reflects whether that optional holds a value.
    TORCH_INTERNAL_ASSERT(
        !hasSelfMapping(),
        "Unsupported domain mapping detected in ",
        std::get<0>(*self_mapping_info_)->toString(),
        ". ",
        std::get<3>(*self_mapping_info_),
        " domains, ",
        std::get<1>(*self_mapping_info_)->toString(),
        " and ",
        std::get<2>(*self_mapping_info_)->toString(),
        ", are mapped with each other.");
  }
}
51
52//! Map corresponding inputs and outputs of swizzle op together
53//! on the given disjoint set, if the given id is an output
54//! of a swizzle operator.
55//!
56//! The current usage of swizzle operator is local to each tensor
57//! itself, so they should not affect exact or permissive mapping
58//! between iterdomains on different tensor domains.
59//! TODO:
60//! Exact mapping based index hoisting of swizzled iterdomains
61//! is disabled currently and will be re-enabled in the next
62//! few build out steps.
63void mapMaybeSwizzleOp(
64 DisjointSets<IterDomain*>& disjoint_sets,
65 IterDomain* id) {
66 if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(id->definition())) {
67 // Map each input to its corresponding output on the given
68 // disjoint set.
69 disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX());
70 disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY());
71 }
72}
73
74bool IterDomainGraph::exprsMap(
75 Expr* first,
76 Expr* second,
77 bool forward,
78 const DisjointSets<IterDomain*>& id_map) {
79 if (first == nullptr || second == nullptr) {
80 return false;
81 }
82
83 if (first->etype() != second->etype()) {
84 return false;
85 }
86
87 TORCH_INTERNAL_ASSERT(
88 first->etype() == ExprType::Merge || first->etype() == ExprType::Split,
89 "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
90 first->toString());
91
92 auto first_ids = ir_utils::filterByType<IterDomain>(
93 forward ? first->inputs() : first->outputs())
94 .vector();
95
96 auto second_ids = ir_utils::filterByType<IterDomain>(
97 forward ? second->inputs() : second->outputs())
98 .vector();
99
100 TORCH_INTERNAL_ASSERT(
101 first_ids.size() == second_ids.size(),
102 "Expected number of ",
103 (forward ? "inputs" : "outputs"),
104 " to match for\n",
105 first->toString(),
106 second->toString());
107
108 {
109 std::vector<std::pair<IterDomain*, IterDomain*>> zipped_ids;
110
111 std::transform(
112 first_ids.begin(),
113 first_ids.end(),
114 second_ids.begin(),
115 std::back_inserter(zipped_ids),
116 [](IterDomain* first, IterDomain* second) {
117 return std::make_pair(first, second);
118 });
119
120 if (std::any_of(
121 zipped_ids.begin(),
122 zipped_ids.end(),
123 [&](std::pair<IterDomain*, IterDomain*> id_pair) {
124 return !id_map.strictAreMapped(id_pair.first, id_pair.second);
125 })) {
126 return false;
127 }
128 }
129
130 if (first->isA<Merge>() && !forward) {
131 // Can't back prop through merge without making sure one dimension actually
132 // is identical extents.
133 auto merge0 = first->as<Merge>();
134 auto merge1 = second->as<Merge>();
135
136 auto extent_0o = merge0->outer()->extent();
137 auto extent_0i = merge0->inner()->extent();
138 auto extent_1o = merge1->outer()->extent();
139 auto extent_1i = merge1->inner()->extent();
140
141 auto extent_0_match = extent_0o->sameAs(extent_1o) ||
142 (extent_0o->isConstInt() && extent_1o->isConstInt() &&
143 extent_0o->evaluateInt() == extent_1o->evaluateInt());
144
145 auto extent_1_match = extent_0i->sameAs(extent_1i) ||
146 (extent_0i->isConstInt() && extent_1i->isConstInt() &&
147 extent_0i->evaluateInt() == extent_1i->evaluateInt());
148
149 if (!(extent_0_match || extent_1_match)) {
150 return false;
151 }
152 }
153
154 if (first->isA<Split>()) {
155 auto first_split = first->as<Split>();
156 auto second_split = second->as<Split>();
157 if (!first_split->factor()->sameAs(second_split->factor()) ||
158 first_split->innerSplit() != second_split->innerSplit() ||
159 !first_split->startOffset()->sameAs(second_split->startOffset()) ||
160 !first_split->stopOffset()->sameAs(second_split->stopOffset())) {
161 return false;
162 }
163 }
164
165 return true;
166}
167
168void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
169 if (first == nullptr || second == nullptr) {
170 return;
171 }
172
173 if (!exprsMap(first, second, forward, exact_nodes_)) {
174 return;
175 }
176
177 auto first_ids = ir_utils::filterByType<IterDomain>(
178 forward ? first->outputs() : first->inputs())
179 .vector();
180 auto second_ids = ir_utils::filterByType<IterDomain>(
181 forward ? second->outputs() : second->inputs())
182 .vector();
183 TORCH_INTERNAL_ASSERT(
184 first_ids.size() == second_ids.size(),
185 "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n",
186 first->toString(),
187 "\nand\n",
188 second->toString());
189 for (auto out_i : c10::irange(first_ids.size())) {
190 exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
191 permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
192 }
193}
194
195namespace {
196
197// Returns a pair of mapped IDs
198c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
199 const std::vector<IterDomain*>& ids,
200 const IterDomainGraph& id_graph) {
201 for (auto id1 : ids) {
202 for (auto id2 : ids) {
203 if (id1 == id2) {
204 continue;
205 }
206 if (id_graph.permissiveNodes().disjointSetMap().at(id1)->has(id2)) {
207 return std::make_pair(id1, id2);
208 }
209 }
210 }
211
212 return {};
213}
214
215// It is assumed that for any tensor represented by a list of domains,
216// those domains should never be mapped with each other. It may be
217// possible to lift this assumption, but it's unclear if it could
218// matter in practice.
219c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
220findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) {
221 for (auto tv : ir_utils::allTvs(fusion)) {
222 // For each tensor, make sure root, rfactor and leaf domains
223 // should not include domains that are mapped with another domain
224 // in the same set of domains. This may be overly conservative,
225 // and it maybe enough to check the root domains.
226
227 // Root domains
228 auto self_mappped_root_pair =
229 detectMappablePair(tv->getRootDomain(), id_graph);
230 if (self_mappped_root_pair.has_value()) {
231 return std::make_tuple(
232 tv,
233 self_mappped_root_pair->first,
234 self_mappped_root_pair->second,
235 "Root");
236 }
237
238 // Rfactor domains
239 if (tv->hasRFactor()) {
240 auto self_mappped_rf_pair =
241 detectMappablePair(tv->getRFactorDomain(), id_graph);
242 if (self_mappped_rf_pair.has_value()) {
243 return std::make_tuple(
244 tv,
245 self_mappped_rf_pair->first,
246 self_mappped_rf_pair->second,
247 "RFactor");
248 }
249 }
250
251 // Leaf domains
252 auto self_mappped_leaf_pair =
253 detectMappablePair(tv->domain()->domain(), id_graph);
254 if (self_mappped_leaf_pair.has_value()) {
255 return std::make_tuple(
256 tv,
257 self_mappped_leaf_pair->first,
258 self_mappped_leaf_pair->second,
259 "Leaf");
260 }
261 }
262 return c10::nullopt;
263}
264
265} // namespace
266
// Populates all of the disjoint-set maps (permissive, exact, loop, sibling)
// and the producer/consumer relations for every IterDomain in the fusion:
//   1. initialize a singleton set per id,
//   2. map ids across each expression's outputs and producer/consumer pairs,
//   3. map ids through matching rfactor (view) transformations,
//   4. record any self-mapping for the constructor's validation.
void IterDomainGraph::build(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Initialize a node for every iteration domain
  for (auto tv : ir_utils::allTvs(fusion)) {
    const auto& root_domain = tv->getRootDomain();
    const auto& domain = tv->domain()->domain();

    // Grab all values in the history of the tensor view's domain
    auto all_vals = DependencyCheck::getAllValsBetween(
        {root_domain.begin(), root_domain.end()},
        {domain.begin(), domain.end()});

    // Filter so we only have iteration domains (ignore Ints used in split)
    auto all_ids = ir_utils::filterByType<IterDomain>(all_vals);

    // Check is this domain is a consumer of a view-like operation
    bool view_like_domain = tv->domain()->hasViewLikeRFactor();

    for (auto id : all_ids) {
      // Check if this id is a view like rfactor id
      bool is_view_rfactor_id = false;
      if (view_like_domain && id->isRFactorProduct()) {
        // If the tensor domain is a view like domain, and the iteration domain
        // is marked as an rfactor product and is in the rfactor domain, it's a
        // view like rfactor iteration domain
        const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain();
        if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) !=
            rfactor_domain.end()) {
          is_view_rfactor_id = true;
        }
      }
      bool is_leaf_id =
          std::find(domain.begin(), domain.end(), id) != domain.end();
      initializeId(id, is_view_rfactor_id, is_leaf_id);
    }
  }

  // All ID's are initialized, start connecting them on the permissive, exact,
  // and loop dimensions.

  for (auto expr : fusion->exprs()) {
    // Only tensor ops can introduce id-to-id mappings.
    if (!ir_utils::isTvOp(expr)) {
      continue;
    }

    auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
    TensorView* first_output_tv = nullptr;

    for (auto c_tv : tv_outputs) {
      if (first_output_tv == nullptr) {
        first_output_tv = c_tv;
      } else {
        // Map multi outputs of an expression to each other. c is current
        // output, and f as first output. Keep consistent with the later section
        // of producer and consumers. Which here producer is now "first output",
        // and consumer is still consumer. One exception is how the
        // domains left of CA positions are handled in the Parallel
        // map. Those domains are not mapped in producer and consumer
        // mappings as they do not share loops, but are mapped in the
        // case of mapping multiple outputs since they do share the
        // same loops.

        TORCH_INTERNAL_ASSERT(
            c_tv->getRootDomain().size() ==
                first_output_tv->getRootDomain().size(),
            "Multiple outputs with mismatched dimensions is not supported. ",
            "Only supported case is welford op where all outputs tvs have identical domains.");
        // p->f, c->c
        std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
        for (const auto i :
             c10::irange(first_output_tv->getRootDomain().size())) {
          c2f_root_map.insert(std::make_pair(
              c_tv->getRootDomain()[i], first_output_tv->getRootDomain()[i]));
        }

        // Multi output mapping, outputs are required to have the same domain
        // and same transformations, so they can be mapped in permissive/exact,
        // and when within compute at position of domain()->domain() in the
        // parallel map.
        auto replay_FasC = BestEffortReplay(
            first_output_tv->domain()->domain(),
            c_tv->domain()->domain(),
            c2f_root_map);

        auto c2f_map = replay_FasC.getReplay();

        // Map the entire replay map between the multiple
        // consumers even for the Parallel map as they share the same
        // loop.
        for (auto entry : c2f_map) {
          auto c_id = entry.first;
          auto f_id = entry.second;
          // Map the id's together
          permissive_nodes_.mapEntries(f_id, c_id);
          exact_nodes_.mapEntries(f_id, c_id);
          if (idIsALeafDomain(f_id, first_output_tv)) {
            loop_nodes_.mapEntries(f_id, c_id);
          }
          sibling_sets_.mapEntries(f_id, c_id);
        }
      }

      auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());

      for (auto p_tv : tv_inputs) {
        // If outside computeAt axis, we don't want to directly map
        // consumer/producer as their thread mappings could change as long as
        // it's across shared/global memory.
        auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
        const auto& permissive_c2p_root_map =
            pairwise_map.mapConsumerToProducer(c_tv->domain(), p_tv->domain());

        // Look for matching ID transformations in producer and consumer, replay
        // producer as consumer. We want to replay producer as consumer instead
        // of the other way around since consumer may have some broadcasted axes
        // producer doesn't have merged into loops producer may use. If we did
        // consumer as producer we wouldn't have this information in the
        // mapping. If we're using this map for indexing, we do not want to
        // propagate broadcast mismatches. If we're using it to identify loop
        // nests, we do want to propagate mismatches.
        auto permissive_replay_PasC =
            BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);

        const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
        // NOTE(review): taken by value, not reference — presumably a copy of
        // the replay's disjoint sets; confirm whether a const& would suffice.
        const auto permissive_disjoint_sets =
            permissive_replay_PasC.getDisjointSets();

        // For exact mapings do not map any broadcast dimensions to
        // non-broadcast dimensions. Prevent any broadcasted axes being mapped
        // to non-broadcasted axes.
        auto exact_c2p_root_map =
            PairwiseRootDomainMap(p_tv, c_tv, true)
                .mapConsumerToProducer(c_tv->domain(), p_tv->domain());

        // Same as permissive above but for exact
        auto exact_replay_PasC = BestEffortReplay(
            p_tv->domain()->domain(),
            c_tv->domain()->domain(),
            exact_c2p_root_map);

        const auto& exact_c2p_map = exact_replay_PasC.getReplay();

        for (auto entry : exact_c2p_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          exact_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);

          // Add the swizzle inputs to the same
          // disjoint set as well if either c_id
          // or p_id is swizzle output.
          mapMaybeSwizzleOp(exact_nodes_, p_id);
          mapMaybeSwizzleOp(exact_nodes_, c_id);
        }

        for (auto entry : permissive_c2p_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
            loop_nodes_.mapEntries(c_id, p_id);
          } else {
            // When there are trivial reductions merged with other dims, `p_id`
            // might not be a compute at leaf domain of `p_tv`, but it actually
            // has an equivalent compute at leaf domain. For that case, we map
            // the equivalent compute at leaf domain.
            for (unsigned int i = 0; i < p_tv->getComputeAtPosition(); i++) {
              auto id = p_tv->axis(i);
              if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
                loop_nodes_.mapEntries(c_id, id);
              }
            }
          }
          permissive_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);

          // Add the swizzle inputs to the same
          // disjoint set as well if either c_id
          // or p_id is swizzle output.
          mapMaybeSwizzleOp(permissive_nodes_, p_id);
          mapMaybeSwizzleOp(permissive_nodes_, c_id);
        }

        // Make sure we always get root mapping for the permissive map.
        // Because of forwarding we could otherwise miss some root mappings.
        for (auto entry : permissive_c2p_root_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          // Map the id's together
          permissive_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);
        }
      }
    }
  }

  // Explicitly map through rfactor transformations, if we have an op like:
  //
  // T1[x, y*z] = view(T0[x*y, z])
  // T3[x, y*z] = view(T2[x*y, z])
  // T4 = T0 + T2
  //
  // We want to map T1 and T3's rfactor transformations together by playing the
  // transformations forward since their root domains map. If instead we have:
  //
  // T1[x, y*z] = view(T0[x*y, z])
  // T3[x, y*z] = view(T2[x*y, z])
  // T4 = T1 + T3
  //
  // Then we wouldn't have a mapping of T1 and T3's root domain, we'd have a
  // mapping of their rfactor domain, so we would want to map T1 and T3's
  // rfactor transformations starting at their rfactor domains.
  //
  // Therefore we'll explicitly map rfactor transformation iteration domains
  // forward and backwards. Something similar could happen with rfactor of root
  // domains, though it seems mapping rfactor reduction domains aren't that
  // important. Mapping view transformations is more important since view is
  // part of the compute definition so having the map through the
  // transformations makes it easy to check if different view operations are
  // consistent with eachother.

  auto all_tvs = ir_utils::allTvs(fusion);
  std::vector<TensorView*> all_consumer_tvs;
  std::copy_if(
      all_tvs.begin(),
      all_tvs.end(),
      std::back_inserter(all_consumer_tvs),
      [](TensorView* tv) { return !tv->isFusionInput() && tv->hasRFactor(); });

  // IterDomains could have multiple uses defined in the fusion if multiple
  // transformations were redefined (more than one transform propagation pass
  // was run and retransformed sections of the graph). We're going to make a new
  // uses map so we can easily process the actual uses of IterDomains. We
  // actually only need rfactor uses for this section of mapping, so we'll limit
  // this map to only rfactor transformations.
  std::unordered_map<IterDomain*, Expr*> rfactor_id_uses;

  // Order of traversal is important for processing all the rfactor ids as the
  // first pass will go forward through expressions and the second pass will
  // traverse backwards through them. ID's will be unique in this vector,
  // enforced when building it since it's built with rfactor_id_uses.
  std::vector<IterDomain*> rfactor_id_order;

  // Grab all the rfactor ids.
  for (auto consumer_tv : all_consumer_tvs) {
    auto exprs = StmtSort::getExprs(
        fusion,
        {consumer_tv->getMaybeRFactorDomain().begin(),
         consumer_tv->getMaybeRFactorDomain().end()});
    for (auto expr : exprs) {
      auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
      TORCH_INTERNAL_ASSERT(
          expr->isA<Split>() || expr->isA<Merge>(),
          "Wasn't expecting the expression type of:\n",
          expr->toString(),
          "\nto be an expression defined in an rfactor transformation.");
      for (auto rfactor_inp_id : rfactor_inp_ids) {
        // NOTE(review): the message below streams the Expr* itself rather
        // than its toString(); likely prints a pointer — consider fixing.
        TORCH_INTERNAL_ASSERT(
            rfactor_id_uses.find(rfactor_inp_id) == rfactor_id_uses.end(),
            "Was expecting iter domains to only have one active transformation but found id ",
            rfactor_inp_id->toString(),
            " used in\n",
            rfactor_id_uses.at(rfactor_inp_id),
            "\nand\n",
            expr->toString());
        rfactor_id_uses.emplace(std::make_pair(rfactor_inp_id, expr));
        rfactor_id_order.push_back(rfactor_inp_id);
      }
    }
    // Terminal rfactor ids (the rfactor domain itself) get a nullptr use so
    // they still appear in the traversal order.
    for (auto rfactor_id : consumer_tv->getMaybeRFactorDomain()) {
      if (rfactor_id->isRFactorProduct()) {
        rfactor_id_uses.emplace(std::make_pair(rfactor_id, nullptr));
        rfactor_id_order.push_back(rfactor_id);
      }
    }
  }

  // if prop_forward we're going forward through transformations and
  // expressions, meaning if inputs of expressions map then we map their
  // outputs, otherwise we're traversing backwards, meaning if outputs of
  // expressions map then we map their inputs.
  for (auto prop_forward : {true, false}) {
    std::unordered_set<Expr*> visited_exprs;

    for (auto rfactor_id_i : c10::irange(rfactor_id_order.size())) {
      auto first_rfactor_id = prop_forward
          ? rfactor_id_order[rfactor_id_i]
          : rfactor_id_order[rfactor_id_order.size() - 1 - rfactor_id_i];

      // At should be safe since we made rfactor_id_order and rfactor_id_uses at
      // the same time so they should have the same exact entries.
      auto first_expr = prop_forward ? rfactor_id_uses.at(first_rfactor_id)
                                     : first_rfactor_id->definition();

      if (first_expr == nullptr) {
        continue;
      }

      if (visited_exprs.find(first_expr) != visited_exprs.end()) {
        continue;
      }
      visited_exprs.emplace(first_expr);

      // Only need to be concerned here with mapping across rfactor iter
      // domains, so isolate out those.
      auto all_exact_map_ids = exact_nodes_.getDisjointSetOf(first_rfactor_id);
      std::vector<IterDomain*> exact_map_rf_ids;
      std::copy_if(
          all_exact_map_ids.vector().begin(),
          all_exact_map_ids.vector().end(),
          std::back_inserter(exact_map_rf_ids),
          [](IterDomain* id) { return id->isRFactorProduct(); });

      for (auto exact_map_rf_id : exact_map_rf_ids) {
        if (exact_map_rf_id == first_rfactor_id) {
          continue;
        }
        // If there's an input with an rfactor domain we could have an exact
        // mapped rfactor id that's on the input meaning it wouldn't have an
        // entry in rfactor_id_uses
        auto other_use =
            rfactor_id_uses.find(exact_map_rf_id) == rfactor_id_uses.end()
            ? nullptr
            : rfactor_id_uses.at(exact_map_rf_id);
        auto other_expr =
            prop_forward ? other_use : exact_map_rf_id->definition();

        if (other_expr == nullptr) {
          continue;
        }

        if (visited_exprs.find(other_expr) != visited_exprs.end()) {
          continue;
        }

        mapThroughExpr(first_expr, other_expr, prop_forward);
      }
    }
  }
  // Record the first self-mapped pair (if any); the constructor asserts on it
  // unless self mapping was explicitly allowed.
  self_mapping_info_ = findFirstSelfMapping(fusion, *this);
}
611
612void IterDomainGraph::initializeId(
613 IterDomain* id,
614 bool is_view_rfactor_id,
615 bool is_leaf_id) {
616 permissive_nodes_.initializeSet(id);
617 exact_nodes_.initializeSet(id);
618 if (is_leaf_id) {
619 loop_nodes_.initializeSet(id);
620 }
621 consumers_[id] = {};
622 producers_[id] = {};
623 sibling_sets_.initializeSet(id);
624
625 all_ids_.pushBack(id);
626
627 if (is_view_rfactor_id) {
628 view_rfactor_ids_.emplace(id);
629 }
630}
631
// Constructs the ComputeAtMap: the initializer list builds the underlying
// IterDomainGraph for the fusion, then build() derives the remaining state
// (trivial reduction info and concrete ids).
ComputeAtMap::ComputeAtMap(Fusion* fusion)
    : id_graph_(fusion), fusion_(fusion) {
  build(fusion);
}
636
// Derives state on top of the already-constructed id_graph_: trivial
// reduction information for the fusion, then the concrete id per disjoint set
// (buildConcreteIds, defined elsewhere in this file).
void ComputeAtMap::build(Fusion* fusion) {
  trivial_reduction_info_.build(fusion);
  buildConcreteIds();
}
641
642void ComputeAtMap::validateAndPropagatePType() {
643 for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) {
644 ParallelType common_ptype = ParallelType::Serial;
645 for (auto id : loop_disjoint_set->vector()) {
646 auto id_ptype = id->getParallelType();
647 TORCH_INTERNAL_ASSERT(
648 id_ptype == common_ptype || id_ptype == ParallelType::Serial ||
649 common_ptype == ParallelType::Serial,
650 "Issue validating parallel type disjoint ptype is, ",
651 common_ptype,
652 " but found in the set the id: ",
653 id->toString());
654 common_ptype =
655 common_ptype == ParallelType::Serial ? id_ptype : common_ptype;
656 }
657
658 for (auto id : loop_disjoint_set->vector()) {
659 id->parallelize(common_ptype);
660 }
661 }
662}
663
// Allocates one index variable per loop-map disjoint set, choosing in order:
// a parallel index (threadIdx/blockIdx), the constant zero for pure broadcast
// sets, a per-stage index map for double-buffered loops, or a fresh Int.
void ComputeAtMap::allocateIndexVariables() {
  // Run through all disjoint sets registered in loop map,
  // all lowered kir::ForLoop will correspond to one of the disjoint sets
  // and we only need one index variable for each set.
  for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) {
    // Deliberately uninitialized: set by the any_of predicate below and only
    // read on the branch where the predicate returned true.
    ParallelType ptype;
    // first allocate thread and grid parallel indices:
    //  The validation pass will check that the parallel bindings within the
    //  loop nodes are consistent so all the loops within this disjoint set
    //  will be realized implicitly using parallel index variables.
    if (std::any_of(
            loop_disjoint_set->vector().begin(),
            loop_disjoint_set->vector().end(),
            [&ptype](IterDomain* id) {
              if (id->isThread() &&
                  // Halo extended parallel loops currently are handled
                  // differently and an index variable would still
                  // be allocated in this case.
                  (GpuLower::current()->haloInfo()->getExtent(id) == nullptr)) {
                ptype = id->getParallelType();
                return true;
              }
              return false;
            })) {
      loop_index_variable_map_[loop_disjoint_set.get()] =
          NamedScalar::getParallelIndex(ptype);
      continue;
    }

    // All loops in this set are non-parallel, non-concretized broadcast
    //  iterdomains, their "index variable" should be zero.
    if (std::all_of(
            loop_disjoint_set->vector().begin(),
            loop_disjoint_set->vector().end(),
            [](IterDomain* id) { return id->isBroadcast(); })) {
      loop_index_variable_map_[loop_disjoint_set.get()] = fusion_->zeroVal();
      continue;
    }

    // Allocate variable for the iterdomains:
    auto concrete_loop_id_it = concrete_id_cache_.find(loop_disjoint_set);
    TORCH_INTERNAL_ASSERT(
        concrete_loop_id_it != concrete_id_cache_.end(),
        "Concrete id not computed");

    auto concrete_loop_id = concrete_loop_id_it->second;

    // Need to allocate double buffered loop differently.
    if (GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(
            concrete_loop_id)) {
      // Allocate index variable for each stage of the double buffered loop.
      double_buffered_loop_index_variable_map_[loop_disjoint_set.get()] =
          std::make_unique<DoubleBufferIndices>(DoubleBufferIndices(
              {{DoubleBufferLoopStage::Prolog,
                IrBuilder::create<Int>(c10::nullopt)},
               {DoubleBufferLoopStage::Main,
                IrBuilder::create<Int>(c10::nullopt)},
               {DoubleBufferLoopStage::Epilog,
                IrBuilder::create<Int>(c10::nullopt)}}));
    } else {
      // Everything now should be serial concrete loops,
      //   we just allocate a loop index integer for each set of loops.
      loop_index_variable_map_[loop_disjoint_set.get()] =
          IrBuilder::create<Int>(c10::nullopt);
    }
  }
}
731
732Val* ComputeAtMap::getIndexVariable(
733 IterDomain* id,
734 DoubleBufferLoopStage double_buffer_loop_stage) const {
735 TORCH_INTERNAL_ASSERT(
736 id_graph_.loopNodes().mappingExists(id),
737 "Index Variable: no index variable allocated as ",
738 id->toString(),
739 " is not registered in loop map");
740 const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id));
741
742 // Check if this loop was modified by double buffer pass.
743 bool is_double_buffer_iterdomain =
744 GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(id);
745
746 if (is_double_buffer_iterdomain) {
747 // Use dedicated double buffer index variable if the loop is double buffer
748 // loop
749 if (double_buffer_loop_stage == DoubleBufferLoopStage::NotApplicable) {
750 // The double buffered loop stages are created after the loop nest
751 // lowering phase so this function will be querried before the double
752 // buffer pass. At that point, no forloop has any double buffer
753 // stage defined, and we just default to using the main stage index.
754 double_buffer_loop_stage = DoubleBufferLoopStage::Main;
755 }
756 return double_buffered_loop_index_variable_map_.at(loop_set)->at(
757 double_buffer_loop_stage);
758 } else {
759 return loop_index_variable_map_.at(loop_set);
760 }
761}
762
763bool ComputeAtMap::areMapped(
764 IterDomain* id0,
765 IterDomain* id1,
766 IdMappingMode mode) const {
767 return disjointSetOf(id0, mode)->has(id1);
768}
769
770namespace {
771
// Validate a LOOP concrete ID has the complete ID set required for
// indexing. See issue #1655 and FusionIncompleteConcreteID for an
// example fusion that fails with this validation. Fixing this issue
// would require creating a reference IterDomain with all the
// necessary root ID for for loop extent generation, for indexing, and for
// predication.
//
// root_ids_of_all_ids and root_ids_of_concrete_id consist of EXACT
// concrete IDs.
//
// Asserts (with a detailed message) if any root id of the disjoint set is
// not covered by the chosen concrete id and cannot be excused by one of the
// two exemptions below.
void validateCompletenessOfLoopConcreteID(
    IterDomain* concrete_id,
    const ComputeAtMap& ca_map,
    const TrivialReductionInfo& trivial_reduction_info,
    // All root id's of all IDs in the disjoint id set
    const std::unordered_set<IterDomain*>& root_ids_of_all_ids,
    // Map from a root id to the concrete id's it's represented in
    const std::unordered_set<IterDomain*>& root_ids_of_concrete_id,
    const std::unordered_map<IterDomain*, std::vector<IterDomain*>>&
        root_id_to_maybe_concrete_ids,
    // Disjoint set just for printing
    const std::vector<IterDomain*>& id_set,
    // All the candidate concrete IDs found for this disjoint id set
    const std::vector<IterDomain*>& maybe_concrete_ids) {
  // Root ids that neither exemption below can excuse; non-empty => error.
  std::vector<IterDomain*> root_ids_not_found_with_concrete_id;

  for (auto root_id : root_ids_of_all_ids) {
    // Directly covered by the concrete id's roots — fine.
    if (root_ids_of_concrete_id.find(root_id) !=
        root_ids_of_concrete_id.end()) {
      continue;
    }

    // None of the root IDs of the conrete ID is exactly mapped with
    // root_id.

    // Exemption 1:
    // It is still a valid concrete ID if it has a non-broadcast
    // root ID that is mapped with root_id.
    if ((root_id->isBroadcast() || trivial_reduction_info.isDerived(root_id)) &&
        std::any_of(
            root_ids_of_concrete_id.begin(),
            root_ids_of_concrete_id.end(),
            [&](auto root_id_of_concrete_id) {
              return !root_id_of_concrete_id->isBroadcast() &&
                  !trivial_reduction_info.isDerived(root_id_of_concrete_id) &&
                  ca_map.areMapped(
                      root_id,
                      root_id_of_concrete_id,
                      IdMappingMode::PERMISSIVE);
            })) {
      continue;
    }

    // Exemption 2:
    // If all of the corresponding maybe-concrete IDs are exactly
    // mapped with the concrete ID, this missing root_id is not a
    // problem. This can happen with reduction rfactor, e.g.,
    // FusionAdvancedLowering1.
    if (std::all_of(
            root_id_to_maybe_concrete_ids.at(root_id).begin(),
            root_id_to_maybe_concrete_ids.at(root_id).end(),
            [&](auto maybe_concrete_id) {
              return ca_map.areMapped(
                  concrete_id, maybe_concrete_id, IdMappingMode::EXACT);
            })) {
      continue;
    }

    root_ids_not_found_with_concrete_id.push_back(root_id);
  }

  if (root_ids_not_found_with_concrete_id.empty()) {
    return;
  }

  // Error detected as some root IDs are not accounted for by the
  // concrete ID.
  std::stringstream error_msg;
  error_msg << "IDs: " << ir_utils::toString(id_set);
  error_msg << ", concrete ID: " << concrete_id->toString();
  error_msg << ", maybe concrete IDs: "
            << ir_utils::toString(maybe_concrete_ids);
  error_msg << ", all root IDs:";
  for (auto root_id : root_ids_of_all_ids) {
    error_msg << " " << root_id->toString();
  }
  error_msg << ", root IDs not found with concrete ID: ";
  for (auto id : root_ids_not_found_with_concrete_id) {
    error_msg << " " << id->toString();
  }
  TORCH_INTERNAL_ASSERT(
      false, "Concrete ID failed to cover all root IDs. ", error_msg.str());
}
862
863} // namespace
864
865IterDomain* ComputeAtMap::computeConcreteId(
866 IterDomain* id,
867 IdMappingMode mode) {
868 const auto& disjoint_set_shared_ptr = disjointSetOf(id, mode);
869
870 TORCH_INTERNAL_ASSERT(
871 disjoint_set_shared_ptr->vector().size(),
872 "Empty disjoint set found for ",
873 id->toString());
874
875 if (disjoint_set_shared_ptr->vector().size() == 1) {
876 // If only one entry in the disjoint set, by definition the existing ID has
877 // to be the concrete ID.
878 return disjoint_set_shared_ptr->vector().front();
879 }
880
881 // Grab a set of candidate concrete_ids, we track towards the consumers in the
882 // ID group as one of those is guaranteed to be a valid concrete id.
883 VectorOfUniqueEntries<IterDomain*> maybe_concrete_ids;
884 for (auto id : disjoint_set_shared_ptr->vector()) {
885 bool id_output = true;
886 for (auto consumer_id : id_graph_.consumers().at(id).vector()) {
887 if (disjoint_set_shared_ptr->has(consumer_id)) {
888 id_output = false;
889 break;
890 }
891 }
892 if (id_output) {
893 maybe_concrete_ids.pushBack(id);
894 }
895 }
896
897 // Shouldn't ever happen, it would mean there's an error somewhere in the
898 // graph.
899 TORCH_INTERNAL_ASSERT(
900 maybe_concrete_ids.vector().size(),
901 "No potential concrete_id's found for ",
902 id->toString());
903
904 if (maybe_concrete_ids.vector().size() == 1) {
905 return maybe_concrete_ids.vector().front();
906 }
907
908 // The concrete_id should have the most roots it can trace back to that are
909 // iter domains, (non-broadcast/non-reduction). We don't trace back through
910 // view operations, so the one with the most iter root domains is the concrete
911 // ID.
912 IterDomain* concrete_id = nullptr;
913 int max_iter_root_count = 0;
914 int max_bcast_root_count = 0;
915
916 // For the LOOP map, the concrete ID must account for all root IDs
917 // of all of the IDs in each disjoit set. At least those ID's that are
918 // non-broadcast/non-reduction. As broadcast is only important here if it's
919 // concretized in the set. Track information so we can later make sure the
920 // concrete id has accounted for all iter domains meaning it has a correct
921 // loop size.
922 std::unordered_set<IterDomain*> root_ids_of_all_ids;
923 std::unordered_set<IterDomain*> root_ids_of_concrete_id;
924 std::unordered_map<IterDomain*, std::vector<IterDomain*>>
925 root_id_to_maybe_concrete_ids;
926
927 // Populate the above information, look for the concrete id, validate the loop
928 // concrete ID.
929 for (auto maybe_concrete_id : maybe_concrete_ids.vector()) {
930 std::unordered_set<IterDomain*> root_ids;
931 std::deque<IterDomain*> to_visit;
932
933 to_visit.push_back(maybe_concrete_id);
934 while (to_visit.size()) {
935 auto current_id = to_visit.front();
936 to_visit.pop_front();
937 if (isViewRfactor(current_id)) {
938 root_ids.emplace(current_id);
939 continue;
940 }
941
942 // push back producer IterDomains or add root if they don't exist
943 auto producer_vals = ir_utils::producerValsOf(current_id);
944 auto producer_ids = ir_utils::filterByType<IterDomain>(producer_vals);
945
946 if (producer_ids.empty()) {
947 root_ids.emplace(current_id);
948 } else {
949 to_visit.insert(
950 to_visit.end(), producer_ids.begin(), producer_ids.end());
951 }
952 }
953
954 if (mode == IdMappingMode::LOOP) {
955 std::transform(
956 root_ids.begin(),
957 root_ids.end(),
958 std::inserter(root_ids_of_all_ids, root_ids_of_all_ids.end()),
959 [&](const auto root_id) {
960 auto exact_concrete_id =
961 getConcreteMappedID(root_id, IdMappingMode::EXACT);
962 root_id_to_maybe_concrete_ids[exact_concrete_id].push_back(
963 maybe_concrete_id);
964 return exact_concrete_id;
965 });
966 }
967
968 int bcast_root_count = std::count_if(
969 root_ids.begin(), root_ids.end(), [&](IterDomain* root_id) {
970 return root_id->isBroadcast()
971 // TODO: This shouldn't have a negative impact, but (emperically)
972 // might not be necessary
973 || trivial_reduction_info_.isDerived(root_id);
974 });
975 int iter_root_count = (int)root_ids.size() - bcast_root_count;
976 if (iter_root_count > max_iter_root_count ||
977 (iter_root_count == max_iter_root_count &&
978 bcast_root_count > max_bcast_root_count)) {
979 max_iter_root_count = iter_root_count;
980 max_bcast_root_count = bcast_root_count;
981 concrete_id = maybe_concrete_id;
982
983 // If we update the concrete_id, then update the root_ids_of_concrete_id
984 // to reflect this id
985 if (mode == IdMappingMode::LOOP) {
986 root_ids_of_concrete_id.clear();
987 std::transform(
988 root_ids.begin(),
989 root_ids.end(),
990 std::inserter(
991 root_ids_of_concrete_id, root_ids_of_concrete_id.end()),
992 [&](const auto root_id) {
993 return getConcreteMappedID(root_id, IdMappingMode::EXACT);
994 });
995 }
996 }
997 } // end maybe_concrete_id
998
999 TORCH_INTERNAL_ASSERT(
1000 concrete_id != nullptr,
1001 "Something went wrong, could not find a concrete id.");
1002
1003 if (mode == IdMappingMode::LOOP) {
1004 // Validate the concrete id has influence from all the roots of all the
1005 // consumers that will map to this concete id in the loop map. This means
1006 // all the consumers in all expressions of the loop nest generated based on
1007 // this concrete ID will have their roots mapping to this concrete ID
1008 // represented in the extent of this concrete id.
1009 validateCompletenessOfLoopConcreteID(
1010 concrete_id,
1011 *this,
1012 trivial_reduction_info_,
1013 root_ids_of_all_ids,
1014 root_ids_of_concrete_id,
1015 root_id_to_maybe_concrete_ids,
1016 disjoint_set_shared_ptr->vector(),
1017 maybe_concrete_ids.vector());
1018 }
1019
1020 return concrete_id;
1021}
1022
1023void ComputeAtMap::buildConcreteIds() {
1024 for (const auto& disjoint_set_shared_ptr :
1025 id_graph_.permissiveNodes().disjointSets()) {
1026 TORCH_INTERNAL_ASSERT(
1027 disjoint_set_shared_ptr->vector().size(),
1028 "Cannot compute concrete id of empty set.");
1029 auto first_id = disjoint_set_shared_ptr->vector().front();
1030 auto concrete_id = computeConcreteId(first_id, IdMappingMode::PERMISSIVE);
1031 concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
1032 }
1033
1034 for (const auto& disjoint_set_shared_ptr :
1035 id_graph_.exactNodes().disjointSets()) {
1036 TORCH_INTERNAL_ASSERT(
1037 disjoint_set_shared_ptr->vector().size(),
1038 "Cannot compute concrete id of empty set.");
1039 auto first_id = disjoint_set_shared_ptr->vector().front();
1040 auto concrete_id = computeConcreteId(first_id, IdMappingMode::EXACT);
1041 concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
1042 }
1043
1044 for (const auto& disjoint_set_shared_ptr :
1045 id_graph_.loopNodes().disjointSets()) {
1046 TORCH_INTERNAL_ASSERT(
1047 disjoint_set_shared_ptr->vector().size(),
1048 "Cannot compute concrete id of empty set.");
1049 auto first_id = disjoint_set_shared_ptr->vector().front();
1050 auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP);
1051 concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
1052 }
1053}
1054
1055IterDomain* ComputeAtMap::getConcreteMappedID(
1056 IterDomain* id,
1057 IdMappingMode mode) const {
1058 auto disjoint_set_shared_ptr = disjointSetOf(id, mode);
1059
1060 TORCH_INTERNAL_ASSERT(
1061 disjoint_set_shared_ptr->vector().size() > 0,
1062 "Empty disjoint set found for ",
1063 id->toString());
1064
1065 auto cache_it = concrete_id_cache_.find(disjoint_set_shared_ptr);
1066
1067 TORCH_INTERNAL_ASSERT(
1068 cache_it != concrete_id_cache_.end(),
1069 "Could not find concrete id for: ",
1070 id->toString(),
1071 " with mode ",
1072 mode);
1073
1074 return cache_it->second;
1075}
1076
1077namespace {
1078
1079std::string idGraphNodesToString(
1080 const ComputeAtMap& ca_map,
1081 IdMappingMode mode) {
1082 std::stringstream ss;
1083 const auto& disjoint_sets = ca_map.getIdSets(mode);
1084 for (const auto& s_ptr : disjoint_sets.disjointSets()) {
1085 const auto& set = *s_ptr;
1086 IterDomain* concrete_id = nullptr;
1087 if (!set.empty()) {
1088 auto id = set.front();
1089 concrete_id = ca_map.getConcreteMappedID(id, mode);
1090 }
1091 ss << " {";
1092 for (auto entry : set.vector()) {
1093 ss << abstractToString(entry);
1094 if (entry == concrete_id) {
1095 ss << "*";
1096 }
1097 if (entry != set.back()) {
1098 ss << "; ";
1099 }
1100 }
1101 ss << " }\n";
1102 }
1103 return ss.str();
1104}
1105
1106} // namespace
1107
1108std::string ComputeAtMap::toString() const {
1109 std::stringstream ss;
1110 ss << "Compute at map { \n";
1111 ss << "Permissive map:\n"
1112 << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE);
1113 ss << "Exact map:\n" << idGraphNodesToString(*this, IdMappingMode::EXACT);
1114 ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP);
1115 ss << "Consumer maps:\n";
1116 for (auto entry : id_graph_.consumers()) {
1117 ss << " " << entry.first->toString() << " :: " << entry.second.toString()
1118 << "\n";
1119 }
1120
1121 ss << "Producer maps:\n";
1122 for (auto entry : id_graph_.producers()) {
1123 ss << " " << entry.first->toString() << " :: " << entry.second.toString()
1124 << "\n";
1125 }
1126
1127 ss << "Sibling map:\n" << id_graph_.siblings().toString() << "\n";
1128
1129 ss << "} compute at map" << std::endl;
1130 return ss.str();
1131}
1132
1133bool ComputeAtMap::isViewRfactor(IterDomain* ref_id) const {
1134 return id_graph_.viewRfactorIds().find(ref_id) !=
1135 id_graph_.viewRfactorIds().end();
1136}
1137
1138std::vector<IterDomain*> ComputeAtMap::getViewRfactorDomainsOfIdGroup(
1139 IterDomain* ref_id,
1140 IdMappingMode mode) const {
1141 auto disjoint_set = disjointSetOf(ref_id, mode);
1142 std::vector<IterDomain*> rfactor_ids;
1143 for (auto disjoint_id : disjoint_set->vector()) {
1144 if (id_graph_.viewRfactorIds().find(disjoint_id) !=
1145 id_graph_.viewRfactorIds().end()) {
1146 rfactor_ids.push_back(disjoint_id);
1147 }
1148 }
1149 return rfactor_ids;
1150}
1151
1152const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& ComputeAtMap::
1153 disjointSetOf(IterDomain* id, IdMappingMode mode) const {
1154 TORCH_INTERNAL_ASSERT(
1155 idExistsInMap(id),
1156 id->toString(),
1157 " has not been processed in this Compute At Map, yet the disjoint set for it was requested.");
1158 return getIdSets(mode).disjointSetMap().at(id);
1159}
1160
// Select the disjoint-set graph backing the requested mapping mode.
// The trailing assert fires only if `mode` is outside the three handled
// enumerators (e.g. an uninitialized or corrupted value).
const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
    IdMappingMode mode) const {
  switch (mode) {
    case IdMappingMode::PERMISSIVE:
      return id_graph_.permissiveNodes();
    case IdMappingMode::EXACT:
      return id_graph_.exactNodes();
    case IdMappingMode::LOOP:
      return id_graph_.loopNodes();
  }
  TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided.");
}
1173
1174bool ComputeAtMap::idExistsInMap(IterDomain* id) const {
1175 return getIdSets(IdMappingMode::EXACT).disjointSetMap().find(id) !=
1176 getIdSets(IdMappingMode::EXACT).disjointSetMap().end();
1177}
1178
1179} // namespace cuda
1180} // namespace fuser
1181} // namespace jit
1182} // namespace torch
1183