#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <root_domain_map.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
    mapProducerToConsumer(
        const TensorDomain* producer,
        const TensorDomain* consumer,
        const std::unordered_set<IterDomain*>& root_dims_to_map) const {
  return map(producer, consumer, root_dims_to_map, true);
}

std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
    mapProducerToConsumer(
        const TensorDomain* producer,
        const TensorDomain* consumer) const {
  std::unordered_set<IterDomain*> root_dims_to_map(
      producer->getMaybeRFactorDomain().begin(),
      producer->getMaybeRFactorDomain().end());
  return mapProducerToConsumer(producer, consumer, root_dims_to_map);
}

std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
    mapConsumerToProducer(
        const TensorDomain* consumer,
        const TensorDomain* producer,
        const std::unordered_set<IterDomain*>& root_dims_to_map) const {
  return map(producer, consumer, root_dims_to_map, false);
}

std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
    mapConsumerToProducer(
        const TensorDomain* consumer,
        const TensorDomain* producer) const {
  std::unordered_set<IterDomain*> root_dims_to_map(
      consumer->getRootDomain().begin(), consumer->getRootDomain().end());
  return mapConsumerToProducer(consumer, producer, root_dims_to_map);
}
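
// Example (illustrative): these mappings are typically queried through a
// concrete subclass such as PairwiseRootDomainMap. Assuming two hypothetical
// TensorViews producer_tv and consumer_tv in a producer-consumer relation:
//
//   PairwiseRootDomainMap root_map(producer_tv, consumer_tv);
//   auto p2c = root_map.mapProducerToConsumer(
//       producer_tv->domain(), consumer_tv->domain());
//
// p2c then maps each mappable producer root (or rfactor) IterDomain to the
// corresponding consumer root IterDomain.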

PairwiseRootDomainMap::PairwiseRootDomainMap(
    const TensorView* producer,
    const TensorView* consumer,
    bool is_exact)
    : producer_tv_(producer), consumer_tv_(consumer), is_exact_(is_exact) {
  TORCH_INTERNAL_ASSERT(producer != nullptr);
  TORCH_INTERNAL_ASSERT(consumer != nullptr);
  TORCH_INTERNAL_ASSERT(producer->fusion() == consumer->fusion());
  // Make sure they are really a producer and its consumer
  TORCH_INTERNAL_ASSERT(
      producer->isConsumerOf(consumer),
      "Not a producer-consumer pair: ",
      producer,
      ", ",
      consumer);
}

std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
    const TensorDomain* producer,
    const TensorDomain* consumer,
    const std::unordered_set<IterDomain*>& root_dims_to_map,
    bool producer_to_consumer) const {
  // Sanity check that the given producer and consumer domains are
  // really the TensorDomains of the producer and consumer TensorViews
  // given to the constructor.
  TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
  TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);

  if (consumer_tv_->definition()->isA<TransposeOp>()) {
    return mapTranspose(
        producer, consumer, root_dims_to_map, producer_to_consumer);
  }

  std::vector<bool> broadcast_flags;
  if (BroadcastOp* bop =
          dynamic_cast<BroadcastOp*>(consumer_tv_->definition())) {
    broadcast_flags = bop->getBroadcastDimFlags();
  }

  std::unordered_map<IterDomain*, IterDomain*> dom_map;
  const auto producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();
  size_t itc = 0, itp = 0;
  while (itc < consumer_root.size() && itp < producer_root.size()) {
    IterDomain* producer_id = producer_root[itp];
    IterDomain* consumer_id = consumer_root[itc];

    // When the consumer ID is a new broadcast domain, there is no
    // mapping for it.
    if (!broadcast_flags.empty() && broadcast_flags.at(itc)) {
      TORCH_INTERNAL_ASSERT(consumer_id->isBroadcast());
      itc++;
      continue;
    }

    // In exact mapping, do not map broadcast domains with
    // non-broadcast domains
    if (is_exact_ &&
        producer_id->isBroadcast() != consumer_id->isBroadcast()) {
      itc++;
      itp++;
      continue;
    }

    IterDomain* map_key_id = producer_id;
    IterDomain* map_value_id = consumer_id;
    if (!producer_to_consumer) {
      std::swap(map_key_id, map_value_id);
    }

    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
      dom_map.insert(std::make_pair(map_key_id, map_value_id));
    }
    itc++;
    itp++;
  }
  return dom_map;
}
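
// Example (illustrative): for a hypothetical consumer defined as
// tv1 = broadcast(tv0, {false, true}), with producer root [I0] and consumer
// root [I0', B0], the broadcast flag at position 1 makes the loop skip B0,
// so the resulting producer-to-consumer map is just {I0 -> I0'}.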

std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::
    mapTranspose(
        const TensorDomain* producer,
        const TensorDomain* consumer,
        const std::unordered_set<IterDomain*>& root_dims_to_map,
        bool producer_to_consumer) const {
  const auto producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();

  std::unordered_map<IterDomain*, IterDomain*> dom_map;

  TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
  TORCH_INTERNAL_ASSERT(top != nullptr);

  const auto& new2old = top->new2old();
  for (const auto i : c10::irange(consumer_root.size())) {
    IterDomain* map_key_id = producer_root[new2old[i]];
    IterDomain* map_value_id = consumer_root[i];

    // In exact mapping, do not map broadcast domains with
    // non-broadcast domains
    if (is_exact_ &&
        map_key_id->isBroadcast() != map_value_id->isBroadcast()) {
      continue;
    }

    if (!producer_to_consumer) {
      std::swap(map_key_id, map_value_id);
    }

    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
      dom_map.insert(std::make_pair(map_key_id, map_value_id));
    }
  }
  return dom_map;
}
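
// Example (illustrative): for a hypothetical 2D transpose with
// new2old = {1, 0}, consumer root position i maps to producer root position
// new2old[i]. With producer root [I0, I1] and consumer root [J0, J1], the
// producer-to-consumer map becomes {I1 -> J0, I0 -> J1}.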

std::string PairwiseRootDomainMap::toString() const {
  std::stringstream ss;
  ss << "{producer: " << producer() << ", consumer: " << consumer();
  auto p2c = mapProducerToConsumer(producer()->domain(), consumer()->domain());
  for (auto pair : p2c) {
    ss << ", " << pair.first->toString() << " -> " << pair.second->toString();
  }
  ss << "}";
  return ss.str();
}

namespace {

template <typename T>
auto ensureMapping(
    T& m,
    const typename T::key_type& key,
    const typename T::mapped_type& init_value) {
  auto it = m.find(key);
  if (it == m.end()) {
    it = m.insert({key, init_value}).first;
  }
  return it;
}

TensorView* lookUpTv(const TensorDomain* td) {
  Fusion* fusion = FusionGuard::getCurFusion();
  for (auto tv : ir_utils::filterByType<TensorView>(fusion->vals())) {
    if (tv->domain() == td) {
      return tv;
    }
  }
  return nullptr;
}

} // namespace

std::string DomainKey::toString() const {
  std::stringstream ss;
  if (id()) {
    ss << id();
  } else {
    ss << "null";
  }
  if (concreteId()) {
    ss << " (concrete: " << concreteId() << ")";
  }
  ss << " in ";
  if (td()) {
    auto tv = lookUpTv(td());
    TORCH_INTERNAL_ASSERT(tv != nullptr, "No TV found for ", td()->toString());
    ss << "T" << tv->name() << "[ " << td()->getRootDomain() << " ]";
    if (td()->hasRFactor()) {
      ss << " (Rfactor: [ " << td()->getMaybeRFactorDomain() << " ])";
    }
  } else {
    ss << "null";
  }
  return ss.str();
}

UnmappableReductionDomains::UnmappableReductionDomains() {
  Fusion* fusion = FusionGuard::getCurFusion();
  traverse(fusion);
}

namespace {

//! Find all domains that a given domain is dependent on
class FindInputDomains : BackwardVisitor {
 private:
  FindInputDomains(TensorView* tv, const IterDomain* id)
      : BackwardVisitor(false), tv_(tv) {
    input_keys_.insert(DomainKey(tv_->domain(), id));
  }

  DomainKeySet find() {
    traverseTo(tv_->fusion(), {tv_});
    return input_keys_;
  }

  void handle(Expr* expr) override {
    for (auto output : expr->outputs()) {
      if (!output->isA<TensorView>()) {
        continue;
      }
      for (auto input : expr->inputs()) {
        if (!input->isA<TensorView>()) {
          continue;
        }
        propagate(input->as<TensorView>(), output->as<TensorView>());
      }
    }
  }

  void propagate(TensorView* in_tv, TensorView* out_tv) {
    auto c2p = PairwiseRootDomainMap(in_tv, out_tv)
                   .mapConsumerToProducer(out_tv->domain(), in_tv->domain());
    for (auto root_dom : out_tv->getRootDomain()) {
      DomainKey out_key({out_tv->domain(), root_dom});
      if (input_keys_.find(out_key) == input_keys_.end()) {
        continue;
      }
      auto input_id_it = c2p.find(root_dom);
      if (input_id_it == c2p.end()) {
        continue;
      }
      DomainKey input_key(in_tv->domain(), input_id_it->second);
      input_keys_.insert(input_key);
    }
  }

 private:
  TensorView* tv_ = nullptr;
  DomainKeySet input_keys_;

 public:
  static DomainKeySet find(TensorView* tv, const IterDomain* id) {
    return FindInputDomains(tv, id).find();
  }
};

} // namespace

void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
  std::vector<DomainKey> reduction_keys;
  for (const auto id : out_tv->getRootDomain()) {
    if (id->isReduction()) {
      DomainKey key(out_tv->domain(), id);
      reduction_keys.push_back(key);
      reduction_domains_.insert({key, {}});
    }
  }
  auto use_chains = DependencyCheck::getAllUseChains(out_tv);
  for (const auto& chain : use_chains) {
    for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
      // Do not include the tensor itself in its consumers
      if (tv == out_tv) {
        continue;
      }
      const auto& root_domain = tv->getRootDomain();
      for (const auto& id : root_domain) {
        DomainKey consumer_key(tv->domain(), id);
        for (const auto& reduction_key : reduction_keys) {
          reduction_domains_.at(reduction_key).insert(consumer_key);
        }
      }
    }
  }
  for (const auto& reduction_key : reduction_keys) {
    reduction_domain_inputs_.insert(
        {reduction_key, FindInputDomains::find(out_tv, reduction_key.id())});
  }
}
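
// Example (illustrative): in a softmax-like pattern such as
//
//   tv1 = sum(tv0, {1});                  // tv0: [I0, I1]
//   tv2 = broadcast(tv1, {false, true});
//   tv3 = div(tv0, tv2);
//
// the reduction key of tv1 is recorded above together with the root domains
// of its consumers (tv2 and tv3) and with the input domains the reduction
// depends on (those of tv0). isReductionOutputMapped below uses this
// information to reject mappings that would tie a reduction input domain to
// a domain that consumes the reduction output.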

void UnmappableReductionDomains::handle(ReductionOp* op) {
  // Builds a map from reduction domains to consumer domains.
  TensorView* out_tv = op->out()->as<TensorView>();
  handleReductionOutput(out_tv);
}

void UnmappableReductionDomains::handle(GroupedReductionOp* op) {
  // Builds a map from reduction domains to consumer domains.
  for (auto out : op->outputs()) {
    handleReductionOutput(out->as<TensorView>());
  }
}

void UnmappableReductionDomains::handle(MmaOp* mma) {
  // Builds a map from reduction domains to consumer domains.
  TensorView* out_tv = mma->out()->as<TensorView>();
  handleReductionOutput(out_tv);
}

void UnmappableReductionDomains::handle(WelfordOp* op) {
  // Builds a map from reduction domains to consumer domains.
  handleReductionOutput(op->outAvg()->as<TensorView>());
  handleReductionOutput(op->outVar()->as<TensorView>());
  handleReductionOutput(op->outN()->as<TensorView>());
}

bool UnmappableReductionDomains::isReductionOutputMapped(
    const DomainKeySet& consumer_domains,
    const ComputeAtRootDomainMap& root_map) const {
  // Check each reduction domain if any of the consumer domains
  // conflicts with it
  for (const auto& kv : reduction_domains_) {
    const DomainKey& reduction_domain = kv.first;
    // Domains that must not be mapped with the reduction domain
    const DomainKeySet& incompatible_domains = kv.second;
    // Input domains to the reduction domain
    const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
    // Check if any of the consumer domains is an input to the
    // reduction
    auto it = std::find_if(
        consumer_domains.begin(),
        consumer_domains.end(),
        [&](const auto& consumer_domain) {
          return std::find(
                     input_keys.begin(), input_keys.end(), consumer_domain) !=
              input_keys.end();
        });
    // None of the consumer domains is used for the reduction
    // domain. They should be safe with respect to this reduction
    // domain
    if (it == consumer_domains.end()) {
      continue;
    }

    // A consumer domain that is an input to the reduction domain
    const DomainKey& input_to_reduction = *it;

    // Check if mapping input_to_reduction with the other domains in
    // consumer_domains. If there's a domain that is a consumer of the
    // reduction, they must not be mapped together
    for (const auto& consumer_domain : consumer_domains) {
      if (consumer_domain == input_to_reduction) {
        continue;
      }
      if (std::any_of(
              incompatible_domains.begin(),
              incompatible_domains.end(),
              [&](const DomainKey& incompatible_domain) {
                return root_map.canMap(
                    consumer_domain.td(),
                    consumer_domain.id(),
                    incompatible_domain.td(),
                    incompatible_domain.id());
              })) {
        return true;
      }
    }
  }
  return false;
}

std::string UnmappableReductionDomains::toString() const {
  std::stringstream ss;
  ss << "Reduction-to-consumer map\n";
  for (const auto& kv : reduction_domains_) {
    ss << "\tReduction: " << kv.first.toString() << "\n";
    for (const auto& mapped_val : kv.second) {
      ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n";
    }
  }

  ss << "Reduction-to-producer map\n";
  for (const auto& kv : reduction_domain_inputs_) {
    ss << "\tReduction: " << kv.first.toString() << "\n";
    for (const auto& mapped_val : kv.second) {
      ss << "\t\tProducer domain: " << mapped_val.toString() << "\n";
    }
  }

  return ss.str();
}

void ComputeAtRootDomainMap::build(bool map_through_reduction) {
  // Make sure we start from scratch. Throw away previous results.
  eq_set_.clear();
  bcast_map_.clear();
  new_broadcast_domains_.clear();
  ComputeAtRootDomainMapBuilder builder(*this, map_through_reduction);
}
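
// Example (illustrative): a typical usage pattern is to construct the map and
// then build it for the fusion that is currently active, e.g.
//
//   ComputeAtRootDomainMap root_map;
//   root_map.build();
//   bool mappable = root_map.canMap(td_a, id_a, td_b, id_b);
//
// where td_a/id_a and td_b/id_b are hypothetical root TensorDomain and
// IterDomain pairs from the same fusion.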

bool ComputeAtRootDomainMap::canMap(
    const TensorDomain* td_a,
    const IterDomain* id_a,
    const TensorDomain* td_b,
    const IterDomain* id_b) const {
  TORCH_INTERNAL_ASSERT(
      id_a->definition() == nullptr || id_a->isRFactorProduct(),
      "Non-root domain is not supported: ",
      id_a);
  TORCH_INTERNAL_ASSERT(
      id_b->definition() == nullptr || id_b->isRFactorProduct(),
      "Non-root domain is not supported: ",
      id_b);

  // Forward to overloaded functions
  if (!id_a->isBroadcast() && !id_b->isBroadcast()) {
    return canMap(DomainKey(td_a, id_a), DomainKey(td_b, id_b));
  } else if (!id_a->isBroadcast()) {
    return canMap(DomainKey(td_a, id_a), td_b, id_b);
  } else if (!id_b->isBroadcast()) {
    return canMap(DomainKey(td_b, id_b), td_a, id_a);
  }

  // At this point, both are broadcast. Every pair of concrete IDs of
  // both id_a and id_b needs to be looked at. Whether they are
  // mappable depends on whether the concrete IDs are broadcast or
  // not. Note that a broadcast axis is used as a concrete ID when it
  // is part of an output tensor domain, i.e., when it never gets
  // concretized with any non-broadcast axis.

  // If there exists a pair of non-broadcast concrete IDs that is not
  // mappable, id_a and id_b can't be mapped together. Otherwise, they
  // can be mapped as long as any mappable pair is found.
  bool mappable_pair_found = false;
  for (const auto& key_a : getConcretizedKeys(td_a, id_a)) {
    for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
      const bool mappable = canMap(key_a, key_b);
      mappable_pair_found = mappable_pair_found || mappable;
      // If both concrete IDs are not broadcast, they must be
      // mappable. Also, if either of the concrete IDs is a reduction,
      // that means a trivial reduction (i.e., broadcast immediately
      // followed by reduction), which does not prevent any mapping.
      if (!key_a.concreteId()->isBroadcast() &&
          !key_b.concreteId()->isBroadcast() &&
          !key_a.concreteId()->isReduction() &&
          !key_b.concreteId()->isReduction() && !mappable) {
        return false;
      }
    }
  }

  return mappable_pair_found;
}

bool ComputeAtRootDomainMap::canMap(
    const DomainKey& key_a,
    const TensorDomain* td_b,
    const IterDomain* id_b) const {
  TORCH_INTERNAL_ASSERT(
      id_b->definition() == nullptr || id_b->isRFactorProduct(),
      "Non-root domain is not supported: ",
      id_b);

  if (!id_b->isBroadcast()) {
    return canMap(key_a, DomainKey(td_b, id_b));
  }

  // If id_b is broadcast, look at all the concrete IDs that id_b may
  // be concretized to. Whether it is mappable with key_a depends on
  // whether key_a's concrete ID is also broadcast.
  // 1) key_a's concrete ID is also broadcast: They are mappable when
  // any mappable concrete ID exists in the concrete ID set of id_b.
  // 2) key_a's concrete ID is not broadcast: Since key_a is indeed
  // concrete, it must be mappable with every concrete ID of id_b,
  // except when the id_b concrete ID is a broadcast.
  const bool key_a_bcast =
      key_a.concreteId() && key_a.concreteId()->isBroadcast();
  const bool key_a_reduction =
      (key_a.concreteId() && key_a.concreteId()->isReduction()) ||
      key_a.id()->isReduction();
  bool mappable_pair_found = false;
  for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
    const bool mappable = canMap(key_a, key_b);
    mappable_pair_found = mappable_pair_found || mappable;
    // If both concrete IDs are not broadcast, they must be mappable.
    // However, if key_b's concrete ID is a reduction, the concrete ID
    // is a result of a trivial reduction, so it should not prevent
    // any other mapping. Similarly, if key_a is a reduction, it just
    // needs to find any concrete ID of key_b that can be mapped.
    if (!key_a_bcast && !key_b.concreteId()->isBroadcast() &&
        !key_b.concreteId()->isReduction() && !key_a_reduction && !mappable) {
      return false;
    }
  }

  return mappable_pair_found;
}

bool ComputeAtRootDomainMap::canMap(
    const DomainKey& key_a,
    const DomainKey& key_b) const {
  return key_a == key_b || eq_set_.permissiveAreMapped(key_a, key_b);
}

void ComputeAtRootDomainMap::setAlias(
    const TensorDomain* td,
    const TensorDomain* td_alias) {
  auto tmp_bcast_map = bcast_map_;
  for (const auto& kv : bcast_map_) {
    const auto& bcast_map_key = kv.first;
    const auto& bcast_concrete_id_set = kv.second;
    if (bcast_map_key.td() == td) {
      DomainKey alias_key(td_alias, bcast_map_key.id());
      tmp_bcast_map.insert({alias_key, bcast_concrete_id_set});
    }
  }
  bcast_map_ = tmp_bcast_map;

  auto all_elements = eq_set_.getAllElements();
  for (const auto& key : all_elements.vector()) {
    if (key.td() == td) {
      DomainKey alias_key(td_alias, key.id(), key.concreteId());
      eq_set_.mapEntries(key, alias_key);
    }
  }

  auto tmp_new_broadcast_domains = new_broadcast_domains_;
  for (const auto& key : new_broadcast_domains_) {
    if (key.td() == td) {
      DomainKey alias_key(td_alias, key.id());
      tmp_new_broadcast_domains.insert(alias_key);
    }
  }
  new_broadcast_domains_ = tmp_new_broadcast_domains;
}

std::vector<DomainKey> ComputeAtRootDomainMap::getConcretizedKeys(
    const TensorDomain* td,
    const IterDomain* id) const {
  DomainKey key(td, id);
  auto it = bcast_map_.find(key);
  TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", key.toString());
  std::vector<DomainKey> domains;
  std::transform(
      it->second.begin(),
      it->second.end(),
      std::back_inserter(domains),
      [&](const IterDomain* concrete_id) {
        return DomainKey(td, id, concrete_id);
      });
  return domains;
}

std::unordered_set<const IterDomain*>& ComputeAtRootDomainMap::
    getConcretizedDomains(const TensorDomain* td, const IterDomain* id) {
  DomainKey key(td, id);
  auto it = bcast_map_.find(key);
  TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", key.toString());
  return it->second;
}

std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::
    mapBestEffort(
        const TensorDomain* from_td,
        const std::vector<IterDomain*>& from_root,
        const TensorDomain* to_td,
        const std::vector<IterDomain*>& to_root) const {
  std::unordered_map<IterDomain*, IterDomain*> id_map;
  for (auto& from_id : from_root) {
    for (const auto& to_id : to_root) {
      if (canMap(from_td, from_id, to_td, to_id)) {
        TORCH_INTERNAL_ASSERT(
            id_map.insert({from_id, to_id}).second,
            "Multiple matching ID detected for ",
            from_id);
      }
    }
  }
  return id_map;
}

std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::map(
    const TensorDomain* producer,
    const TensorDomain* consumer,
    const std::unordered_set<IterDomain*>& root_dims_to_map,
    bool producer_to_consumer) const {
  const auto& producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();
  const TensorDomain* from_td = producer_to_consumer ? producer : consumer;
  const TensorDomain* to_td = producer_to_consumer ? consumer : producer;
  const auto& from_ids = producer_to_consumer ? producer_root : consumer_root;
  const auto& to_ids = producer_to_consumer ? consumer_root : producer_root;
  std::unordered_map<IterDomain*, IterDomain*> id_map =
      mapBestEffort(from_td, from_ids, to_td, to_ids);
  for (auto& from_id : from_ids) {
    if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) {
      // Remove mapping if exists
      id_map.erase(from_id);
      continue;
    }
    if (id_map.find(from_id) != id_map.end()) {
      continue;
    }
    // Matching ID not found. It's an error unless one of the following
    // three cases applies:
    // 1. from_id is a new broadcast of a consumer domain; or
    // 2. from_id is a window axis of a consumer domain; or
    // 3. from_id is a ViewAsScalar domain
    // Note that reduction domains are removed from the producer root domain.
    if (!producer_to_consumer &&
        (new_broadcast_domains_.find(DomainKey(from_td, from_id)) !=
             new_broadcast_domains_.end() ||
         from_id->getIterType() == IterType::VectorComponent ||
         (window_axes_.count(from_id) > 0))) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        false,
        "Mapping IterDomain ",
        from_id,
        " of ",
        from_td,
        " not possible as it would require recomputing the source tensor.",
        " Producer root: ",
        producer_root,
        ". Consumer root: ",
        consumer_root,
        ". Mapping: ",
        this->toString());
  }
  return id_map;
}

std::unordered_set<IterDomain*> ComputeAtRootDomainMap::getMappableDims(
    const TensorDomain* producer,
    const TensorDomain* consumer) const {
  //! This function previously used mapBestEffort, but that can fail when
  //! a domain is mapped to multiple domains, which can happen with
  //! views. Since we only need to find mappable domains, just
  //! grab any domain that is mapped in a pairwise way.

  const auto& producer_root = producer->getMaybeRFactorDomain();
  const auto& consumer_root = consumer->getRootDomain();

  std::unordered_set<IterDomain*> mappable_ids;

  for (const auto& p_id : producer_root) {
    for (const auto& c_id : consumer_root) {
      if (canMap(producer, p_id, consumer, c_id)) {
        mappable_ids.emplace(p_id);
        mappable_ids.emplace(c_id);
      }
    }
  }

  return mappable_ids;
}

std::string ComputeAtRootDomainMap::toString() const {
  return eq_set_.toString();
}

ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder(
    ComputeAtRootDomainMap& root_map,
    bool map_through_reduction)
    : BackwardVisitor(false),
      root_map_(root_map),
      map_through_reduction_(map_through_reduction) {
  Fusion* fusion = FusionGuard::getCurFusion();
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  traverseTo(fusion, fusion->outputs(), false);
  if (!pending_map_.empty()) {
    std::stringstream ss;
    ss << "pending map:\n";
    for (auto& kv : pending_map_) {
      ss << "\t" << kv.first.toString() << "\n";
      for (auto& dk : kv.second) {
        ss << "\t\t" << dk.toString() << "\n";
      }
    }
    std::cerr << ss.str();
  }
  TORCH_INTERNAL_ASSERT(pending_map_.empty());
}

// Set concrete domains for broadcast domains that never get joined
// with a concrete domain. Just set its own domain as a concrete
// domain, which is not concrete but is sufficient for this analysis.
void ComputeAtRootDomainMapBuilder::initializeBcastMap(
    const TensorView* tv,
    const IterDomain* id) {
  TORCH_INTERNAL_ASSERT(id->isBroadcast(), "Not a broadcast axis");
  auto key = DomainKey(tv->domain(), id);
  auto it = root_map_.bcast_map_.find(key);
  if (it != root_map_.bcast_map_.end()) {
    // already initialized.
    return;
  }

  // This initialization should be only used for: 1) fusion output
  // tensors, 2) outputs of multi-consumer expressions that are not
  // fusion outputs, and 3) view outputs as broadcasts can be merged
  // with non-broadcast domains, resulting in non-broadcast rfactor
  // domains.
  TORCH_INTERNAL_ASSERT(
      tv->isFusionOutput() || tv->definition()->outputs().size() > 1 ||
          tv->isDefinitionType(ExprType::ViewOp),
      "Invalid tensor to initialize bcast map: t",
      tv->name());
  root_map_.bcast_map_.insert({key, {id}});
}

void ComputeAtRootDomainMapBuilder::addToPendingList(
    const DomainKey& producer,
    const DomainKey& consumer) {
  auto it = ensureMapping(pending_map_, producer, {});
  auto& consumer_set = it->second;
  consumer_set.insert(consumer);
}

void ComputeAtRootDomainMapBuilder::setMapped(
    const DomainKey& producer,
    const DomainKey& consumer) {
  root_map_.eq_set_.mapEntries(producer, consumer);
}

void ComputeAtRootDomainMapBuilder::setInvalid(
    const DomainKey& key1,
    const DomainKey& key2) {
  invalid_mappings_.emplace_back(key1, key2);
}

bool ComputeAtRootDomainMapBuilder::isInvalid(
    const DomainKeySet& domains) const {
  // First, collect all invalid mappings for each of the keys in domains
  DomainKeyMap<DomainKeySet> invalid_key_map;
  for (const auto& key : domains) {
    DomainKeySet invalid_keys;
    for (const auto& invalid_pair : invalid_mappings_) {
      if (root_map_.canMap(key, invalid_pair.first)) {
        invalid_keys.insert(invalid_pair.second);
      } else if (root_map_.canMap(key, invalid_pair.second)) {
        invalid_keys.insert(invalid_pair.first);
      }
    }
    invalid_key_map.emplace(key, invalid_keys);
  }

  // Next, check if any pair is invalid to map.
  const auto num_keys = domains.size();
  const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()});
  for (const auto i : c10::irange(num_keys)) {
    const auto& key_i = domains_vec[i];
    // If no invalid keys found for key_i, it can be skipped.
    const auto invalid_key_map_it = invalid_key_map.find(key_i);
    if (invalid_key_map_it == invalid_key_map.end()) {
      continue;
    }

    // Set of keys that are invalid to be mapped with key_i.
    const DomainKeySet& invalid_keys_for_i = invalid_key_map_it->second;

    // If any other key in domains is identified mappable with any of
    // the keys in this set, the mapping with key_i is invalid.
    for (const auto j : c10::irange(i + 1, num_keys)) {
      const auto& key_j = domains_vec[j];
      if (std::any_of(
              invalid_keys_for_i.begin(),
              invalid_keys_for_i.end(),
              [&](const auto& invalid_key_for_i) {
                return root_map_.canMap(key_j, invalid_key_for_i);
              })) {
        return true;
      }
    }
  }
  return false;
}

void ComputeAtRootDomainMapBuilder::setMaybeMapped(
    const TensorDomain* producer_td,
    const IterDomain* producer_id,
    const TensorDomain* consumer_td,
    const IterDomain* consumer_id) {
  const DomainKey producer_key(producer_td, producer_id);
  const DomainKey consumer_key(consumer_td, consumer_id);

  if (producer_id->isBroadcast()) {
    ensureMapping(root_map_.bcast_map_, producer_key, {});
  }

  if (consumer_id->isBroadcast()) {
    TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
    // Get bcast_map_ entry for consumer_id
    const auto consumer_bcast_domains =
        root_map_.getConcretizedKeys(consumer_td, consumer_id);
    auto& producer_domains =
        root_map_.getConcretizedDomains(producer_td, producer_id);

    // If consumer id is broadcasted, make sure to propagate its concrete_id(s)
    // to producer
    for (const auto& consumer_bcast_key : consumer_bcast_domains) {
      const auto concrete_id = consumer_bcast_key.concreteId();
      const DomainKey producer_bcast_key(producer_td, producer_id, concrete_id);
      producer_domains.insert(concrete_id);
      addToPendingList(producer_bcast_key, consumer_bcast_key);
    }
  } else {
    auto producer_concrete_key = producer_key;
    if (producer_id->isBroadcast()) {
      const auto concrete_id = consumer_id;
      auto& producer_domains =
          root_map_.getConcretizedDomains(producer_td, producer_id);
      producer_concrete_key = DomainKey(producer_td, producer_id, concrete_id);
      producer_domains.insert(concrete_id);
    }
    addToPendingList(producer_concrete_key, consumer_key);
  }
}
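
// Example (illustrative): when a hypothetical producer broadcast axis B0 is
// paired with a concrete consumer axis I0, the else branch above records I0
// as a concretization of B0 in bcast_map_ and defers the mapping of
// DomainKey(producer_td, B0, I0) with the consumer key until
// mapAllPendingMappings decides whether the pending set is safe to map.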

void ComputeAtRootDomainMapBuilder::handle(Expr* e) {
  // Avoid visiting expressions multiple times
  if (visited_.find(e) != visited_.end()) {
    return;
  }
  BackwardVisitor::handle(e);
  visited_.insert(e);
}

void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) {
  if (e->output(0)->getValType() != ValType::TensorView) {
    return;
  }

  // Broadcast is handled separately, so e should never be BroadcastOp.
  TORCH_INTERNAL_ASSERT(e->getExprType() != ExprType::BroadcastOp);

  TORCH_INTERNAL_ASSERT(e->outputs().size() >= 1);
  const TensorView* out_tv = e->output(0)->as<TensorView>();
  const TensorDomain* out_td = out_tv->domain();
  const auto& out_root = out_td->getRootDomain();

  // Record equalities from output to all the inputs
  // ignores non-concretizable broadcasts
  for (auto* in_tv : ir_utils::filterByType<TensorView>(e->inputs())) {
    const TensorDomain* in_td = in_tv->domain();
    std::vector<IterDomain*> in_root =
        TensorDomain::noReductions(in_tv->getMaybeRFactorDomain());
    TORCH_INTERNAL_ASSERT(
        in_root.size() == out_root.size(),
        "\nExpression: ",
        e,
        "\nInput root domain: ",
        in_root,
        "\nOutput root domain: ",
        out_root);
    for (const auto it : c10::irange(in_root.size())) {
      if (e->outputs().size() > 1) {
        TORCH_INTERNAL_ASSERT(
            e->isA<WelfordOp>() || e->isA<GroupedReductionOp>() ||
                e->isA<GroupedWelfordOp>(),
            "Unknown multi-output Expr type ",
            e->getExprType().value(),
            " is found");
        for (auto out : e->outputs()) {
          auto out_tv = out->as<TensorView>();
          auto out_td = out_tv->domain();
          auto out_root = out_td->getRootDomain();
          setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
        }
      } else {
        setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
      }
    }
  }
}

void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) {
  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
  const auto in_root =
      TensorDomain::noReductions(in_td->getMaybeRFactorDomain());
  const auto& out_root = out_td->getRootDomain();
  const auto& bcast_dim_flags = op->getBroadcastDimFlags();
  TORCH_INTERNAL_ASSERT(
      out_root.size() == bcast_dim_flags.size(),
      "dim flags: ",
      bcast_dim_flags,
      ", out root: ",
      out_root);
  auto in_it = in_root.begin();
  auto out_it = out_root.begin();
  while (in_it != in_root.end() && out_it != out_root.end()) {
    if (bcast_dim_flags.at(std::distance(out_root.begin(), out_it))) {
      // new broadcast dim. No matching dimension in the input
      // tensor.
      root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
      ++out_it;
      continue;
    }
    setMaybeMapped(in_td, *in_it, out_td, *out_it);
    ++in_it;
    ++out_it;
  }
  // At this point, the input domain should have been scanned
  // entirely.
  TORCH_INTERNAL_ASSERT(
      in_it == in_root.end(),
      "Unmatched domain detected: ",
      *in_it,
      " of ",
      in_td);
  // On the other hand, the output may still have some domains left,
  // and they must be new broadcast domains.
  for (; out_it != out_root.end(); ++out_it) {
    TORCH_INTERNAL_ASSERT(
        bcast_dim_flags.at(std::distance(out_root.begin(), out_it)),
        "Unmatched domain detected: ",
        *out_it,
        " of ",
        out_td);
    root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
  }
}
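
// Example (illustrative): broadcasting a hypothetical tv0 with root [I0] into
// tv1 with root [I0', B0] uses the flags {false, true}; I0 is tentatively
// mapped with I0', while B0 is recorded in new_broadcast_domains_ and gets no
// producer-side mapping.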

void ComputeAtRootDomainMapBuilder::handle(ViewAsScalar* op) {
  const TensorView* out_tv = op->output(0)->as<TensorView>();
  const TensorDomain* out_td = out_tv->domain();
  const auto& out_root = out_td->getRootDomain();

  const TensorView* in_tv = op->input(0)->as<TensorView>();
  const TensorDomain* in_td = in_tv->domain();

  std::vector<IterDomain*> in_root =
      TensorDomain::noReductions(in_tv->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      in_root.size() + 1 == out_root.size(),
      "\nExpression: ",
      op,
      "\nInput root domain: ",
      in_root,
      "\nOutput root domain: ",
      out_root);
  auto in_it = in_root.begin();
  auto out_it = out_root.begin();
  while (in_it != in_root.end() && out_it != out_root.end()) {
    setMaybeMapped(in_td, *in_it, out_td, *out_it);
    ++in_it;
    ++out_it;
  }
  TORCH_INTERNAL_ASSERT(
      (*out_it)->isVectorComponent(),
      "The last dim of ViewDtypeOp's output must be a ViewAsScalar");
}

void ComputeAtRootDomainMapBuilder::handle(TransposeOp* op) {
  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
  std::vector<IterDomain*> in_root =
      TensorDomain::noReductions(in_td->getMaybeRFactorDomain());

  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
  const auto& out_root = out_td->getRootDomain();

  TORCH_INTERNAL_ASSERT(in_root.size() == out_root.size());

  const auto& new2old = op->new2old();

  for (const auto it : c10::irange(out_root.size())) {
    setMaybeMapped(in_td, in_root[new2old[it]], out_td, out_root[it]);
  }
}

void ComputeAtRootDomainMapBuilder::handle(GatherOp* op) {
  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
  const auto in_root =
      TensorDomain::noReductions(in_td->getMaybeRFactorDomain());
  const auto& out_root = out_td->getRootDomain();

  // Only maps the input root axes. Do not map the new window axes.
  for (const auto it : c10::irange(in_root.size())) {
    setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
  }

  // Keep track of window axes so that they can be skipped when
  // mapping root domains
  for (const auto it : c10::irange(in_root.size(), out_root.size())) {
    root_map_.window_axes_.insert(out_root[it]);
  }
}
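
// Example (illustrative): for a hypothetical 1D gather with a window of size
// 3, an input root [I0] produces an output root [I0', W0]; only I0 is mapped
// with I0', while the window axis W0 is remembered in window_axes_ so that
// map() can later skip it when no producer counterpart exists.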

void ComputeAtRootDomainMapBuilder::mapAllPendingMappings(
    const DomainKey& key) {
  auto it = pending_map_.find(key);
  if (it == pending_map_.end()) {
    return;
  }
  const auto& pending_set = it->second;
  // All entries in pending_set must be equivalent with each other.
  TORCH_INTERNAL_ASSERT(pending_set.size() > 0);
  bool consistent = safeToMap(pending_set);
  for (const auto pending_key : pending_set) {
    if (consistent) {
      setMapped(key, pending_key);
    } else {
      setInvalid(key, pending_key);
    }
  }
  // This entry should never be used again, so remove it.
  pending_map_.erase(it);
}

void ComputeAtRootDomainMapBuilder::mapAllPendingMappings(
    const TensorDomain* td,
    IterDomain* id) {
  if (id->isBroadcast()) {
    for (const auto& key : root_map_.getConcretizedKeys(td, id)) {
      mapAllPendingMappings(key);
    }
  } else {
    mapAllPendingMappings(DomainKey(td, id));
  }
}

void ComputeAtRootDomainMapBuilder::handle(RNGOp* rop) {
  handle(rop->output(0)->as<TensorView>());
}

void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
  const TensorDomain* td = tv->domain();
  const auto rfactor = TensorDomain::noReductions(td->getMaybeRFactorDomain());
  for (auto id : rfactor) {
    if (id->isBroadcast()) {
      initializeBcastMap(tv, id);
    }
    mapAllPendingMappings(td, id);
  }

  // When tv has an rfactor domain, propagate the domain mappings from
  // each of the rfactor axes to the dependent root axes.
  if (td->hasViewLikeRFactor()) {
    std::unordered_set<Val*> root_set(
        {td->getRootDomain().begin(), td->getRootDomain().end()});
    for (auto rf_id : rfactor) {
      if (!rf_id->isRFactorProduct()) {
        continue;
      }
      auto dep = DependencyCheck::getAllValsBetween(root_set, {rf_id});
      for (auto id : ir_utils::filterByType<IterDomain>(dep)) {
        if (root_set.find(id) == root_set.end() || rf_id == id) {
          continue;
        }
        setMaybeMapped(td, id, td, rf_id);
      }
    }
    // Once mappings for rfactor axes are propagated to root axes,
    // aggregate them at each root axis
    for (auto id : tv->getRootDomain()) {
      if (id->isBroadcast()) {
        // There can be broadcast domains that appear at root domains but
        // are removed at rfactor domains as they are merged into
        // non-broadcast domains. Initialize the map for those broadcast
        // domains.
        initializeBcastMap(tv, id);
      }
      mapAllPendingMappings(td, id);
    }
  }
}
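
// Example (illustrative): for a hypothetical view output whose root domain is
// [I0, I1] and whose rfactor domain is the single merged axis I0*I1, any
// mapping established for the rfactor axis is propagated back to I0 and I1
// above, so that pending mappings can also be aggregated per root axis.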

// Checks whether all consumers of a producer can be joined without
// introducing unsupported mappings, i.e., requiring recomputations.
bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
  if (domains.size() <= 1) {
    return true;
  }

  // Can't map if reduction output domains would be mapped
  if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) &&
      !map_through_reduction_) {
    return false;
  }
  // Make sure mapping these domains won't cause any invalid mapping
  if (isInvalid(domains)) {
    return false;
  }
  return true;
}

namespace {
class ExactRootDomainMapBuilder : private IterVisitor {
 public:
  ExactRootDomainMapBuilder(
      Fusion* fusion,
      DisjointSets<const IterDomain*>& eq_sets)
      : eq_sets_(eq_sets) {
    traverseTo(fusion, fusion->outputs());
  }

 private:
  using IterVisitor::handle;

  void handle(Expr* expr) final {
    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
      for (auto consumer :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        PairwiseRootDomainMap pwise_map(producer, consumer, true);
        const auto mappings = pwise_map.mapProducerToConsumer(
            producer->domain(), consumer->domain());
        for (const auto& mapping : mappings) {
          eq_sets_.mapEntries(mapping.first, mapping.second);
        }
      }
    }
  }

 private:
  DisjointSets<const IterDomain*>& eq_sets_;
};

} // namespace

ExactRootDomainMap::ExactRootDomainMap(Fusion* fusion) {
  ExactRootDomainMapBuilder builder(fusion, eq_sets_);
}
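
// Example (illustrative): unlike ComputeAtRootDomainMap, the exact map is
// built with is_exact = true, so broadcast domains are never mapped with
// non-broadcast domains. A minimal usage sketch, assuming a hypothetical
// Fusion* fusion and root IterDomains id_a and id_b:
//
//   ExactRootDomainMap exact_map(fusion);
//   bool same = exact_map.areMapped(id_a, id_b);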

bool ExactRootDomainMap::areMapped(
    const IterDomain* id_a,
    const IterDomain* id_b) const {
  // With expand going into a view operation there can be an instance where an
  // iteration root domain in the consumer resolves the broadcast from the
  // producer, then immediately rfactors it. In this case the consumer root is
  // not mapped exactly to any other domain, so it might not have an entry in
  // eq_sets_. eq_sets_.strictAreMapped would throw in this case, so just
  // return false if a mapping doesn't exist.
  if (!eq_sets_.mappingExists(id_a) || !eq_sets_.mappingExists(id_b)) {
    return false;
  }
  return eq_sets_.strictAreMapped(id_a, id_b);
}

std::unordered_map<IterDomain*, IterDomain*> ExactRootDomainMap::map(
    const TensorDomain* producer,
    const TensorDomain* consumer,
    const std::unordered_set<IterDomain*>& root_dims_to_map,
    bool producer_to_consumer) const {
  const auto& producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();
  const auto& from_ids = producer_to_consumer ? producer_root : consumer_root;
  const auto& to_ids = producer_to_consumer ? consumer_root : producer_root;

  std::unordered_map<IterDomain*, IterDomain*> id_map;

  for (auto& from_id : from_ids) {
    if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) {
      continue;
    }
    for (const auto& to_id : to_ids) {
      if (areMapped(from_id, to_id)) {
        TORCH_INTERNAL_ASSERT(
            id_map.insert({from_id, to_id}).second,
            "Multiple matching ID detected for ",
            from_id);
      }
    }
  }

  return id_map;
}

std::string ExactRootDomainMap::toString() const {
  return eq_sets_.toString();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch