1 | #include <ir_iostream.h> |
2 | #include <ir_utils.h> |
3 | #include <iter_visitor.h> |
4 | #include <root_domain_map.h> |
5 | |
6 | #include <sstream> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | std::unordered_map<IterDomain*, IterDomain*> RootDomainMap:: |
14 | mapProducerToConsumer( |
15 | const TensorDomain* producer, |
16 | const TensorDomain* consumer, |
17 | const std::unordered_set<IterDomain*>& root_dims_to_map) const { |
18 | return map(producer, consumer, root_dims_to_map, true); |
19 | } |
20 | |
21 | std::unordered_map<IterDomain*, IterDomain*> RootDomainMap:: |
22 | mapProducerToConsumer( |
23 | const TensorDomain* producer, |
24 | const TensorDomain* consumer) const { |
25 | std::unordered_set<IterDomain*> root_dims_to_map( |
26 | producer->getMaybeRFactorDomain().begin(), |
27 | producer->getMaybeRFactorDomain().end()); |
28 | return mapProducerToConsumer(producer, consumer, root_dims_to_map); |
29 | } |
30 | |
31 | std::unordered_map<IterDomain*, IterDomain*> RootDomainMap:: |
32 | mapConsumerToProducer( |
33 | const TensorDomain* consumer, |
34 | const TensorDomain* producer, |
35 | const std::unordered_set<IterDomain*>& root_dims_to_map) const { |
36 | return map(producer, consumer, root_dims_to_map, false); |
37 | } |
38 | |
39 | std::unordered_map<IterDomain*, IterDomain*> RootDomainMap:: |
40 | mapConsumerToProducer( |
41 | const TensorDomain* consumer, |
42 | const TensorDomain* producer) const { |
43 | std::unordered_set<IterDomain*> root_dims_to_map( |
44 | consumer->getRootDomain().begin(), consumer->getRootDomain().end()); |
45 | return mapConsumerToProducer(consumer, producer, root_dims_to_map); |
46 | } |
47 | |
48 | PairwiseRootDomainMap::PairwiseRootDomainMap( |
49 | const TensorView* producer, |
50 | const TensorView* consumer, |
51 | bool is_exact) |
52 | : producer_tv_(producer), consumer_tv_(consumer), is_exact_(is_exact) { |
53 | TORCH_INTERNAL_ASSERT(producer != nullptr); |
54 | TORCH_INTERNAL_ASSERT(consumer != nullptr); |
55 | TORCH_INTERNAL_ASSERT(producer->fusion() == consumer->fusion()); |
56 | // Make sure they are really a producer and its consumer |
57 | TORCH_INTERNAL_ASSERT( |
58 | producer->isConsumerOf(consumer), |
59 | "Not a producer-consumer pair: " , |
60 | producer, |
61 | ", " , |
62 | consumer); |
63 | } |
64 | |
std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
    const TensorDomain* producer,
    const TensorDomain* consumer,
    const std::unordered_set<IterDomain*>& root_dims_to_map,
    bool producer_to_consumer) const {
  // Sanity check that the given producer and consumer domains are
  // really the TensorDomains of the producer and consumer TensorViews
  // given to the constructor.
  TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
  TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);

  // Transpose permutes the axes, so the positional lock-step walk
  // below does not apply; use the new2old-based mapping instead.
  if (consumer_tv_->definition()->isA<TransposeOp>()) {
    return mapTranspose(
        producer, consumer, root_dims_to_map, producer_to_consumer);
  }

  // When the consumer is produced by a BroadcastOp, these flags mark
  // which consumer root axes are newly created broadcast domains with
  // no producer counterpart.
  std::vector<bool> broadcast_flags;
  if (BroadcastOp* bop =
          dynamic_cast<BroadcastOp*>(consumer_tv_->definition())) {
    broadcast_flags = bop->getBroadcastDimFlags();
  }

  std::unordered_map<IterDomain*, IterDomain*> dom_map;
  // Producer reduction axes have no counterpart in the consumer, so
  // drop them before walking the two root domains positionally.
  const auto producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();
  // Walk both root domains in lock step, pairing axes by position.
  size_t itc = 0, itp = 0;
  while (itc < consumer_root.size() && itp < producer_root.size()) {
    IterDomain* producer_id = producer_root[itp];
    IterDomain* consumer_id = consumer_root[itc];

    // When the consumer ID is a new broadcast domain, there is no
    // mapping for it.
    if (!broadcast_flags.empty() && broadcast_flags.at(itc)) {
      TORCH_INTERNAL_ASSERT(consumer_id->isBroadcast());
      itc++;
      continue;
    }

    // In exact mapping, do not map broadcast domains with
    // non-broadcast domains
    if (is_exact_ && producer_id->isBroadcast() != consumer_id->isBroadcast()) {
      itc++;
      itp++;
      continue;
    }

    // Orient the pair according to the requested mapping direction.
    IterDomain* map_key_id = producer_id;
    IterDomain* map_value_id = consumer_id;
    if (!producer_to_consumer) {
      std::swap(map_key_id, map_value_id);
    }

    // Only record the pair when the caller asked for this key.
    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
      dom_map.insert(std::make_pair(map_key_id, map_value_id));
    }
    itc++;
    itp++;
  }
  return dom_map;
}
126 | |
//! Map producer and consumer root domains across a TransposeOp using
//! its new2old axis permutation rather than positional alignment.
std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::
    mapTranspose(
        const TensorDomain* producer,
        const TensorDomain* consumer,
        const std::unordered_set<IterDomain*>& root_dims_to_map,
        bool producer_to_consumer) const {
  const auto producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();

  std::unordered_map<IterDomain*, IterDomain*> dom_map;

  TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
  TORCH_INTERNAL_ASSERT(top != nullptr);

  // new2old[i] is the producer axis that consumer axis i comes from.
  const auto& new2old = top->new2old();
  for (const auto i : c10::irange(consumer_root.size())) {
    IterDomain* map_key_id = producer_root[new2old[i]];
    IterDomain* map_value_id = consumer_root[i];

    // In exact mapping, do not map broadcast domains with
    // non-broadcast domains
    if (is_exact_ && map_key_id->isBroadcast() != map_value_id->isBroadcast()) {
      continue;
    }

    // Orient the pair according to the requested mapping direction.
    if (!producer_to_consumer) {
      std::swap(map_key_id, map_value_id);
    }

    // Only record the pair when the caller asked for this key.
    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
      dom_map.insert(std::make_pair(map_key_id, map_value_id));
    }
  }
  return dom_map;
}
163 | |
164 | std::string PairwiseRootDomainMap::toString() const { |
165 | std::stringstream ss; |
166 | ss << "{producer: " << producer() << ", consumer: " << consumer(); |
167 | auto p2c = mapProducerToConsumer(producer()->domain(), consumer()->domain()); |
168 | for (auto pair : p2c) { |
169 | ss << ", " << pair.first->toString() << " -> " << pair.second->toString(); |
170 | } |
171 | ss << "}" ; |
172 | return ss.str(); |
173 | } |
174 | |
175 | namespace { |
176 | |
//! Return an iterator to the entry for key in map container m,
//! inserting init_value first when the key is not yet present.
template <typename T>
auto ensureMapping(
    T& m,
    const typename T::key_type& key,
    const typename T::mapped_type& init_value) {
  auto it = m.find(key);
  if (it != m.end()) {
    return it;
  }
  return m.insert({key, init_value}).first;
}
188 | |
189 | TensorView* lookUpTv(const TensorDomain* td) { |
190 | Fusion* fusion = FusionGuard::getCurFusion(); |
191 | for (auto tv : ir_utils::filterByType<TensorView>(fusion->vals())) { |
192 | if (tv->domain() == td) { |
193 | return tv; |
194 | } |
195 | } |
196 | return nullptr; |
197 | } |
198 | |
199 | } // namespace |
200 | |
201 | std::string DomainKey::toString() const { |
202 | std::stringstream ss; |
203 | if (id()) { |
204 | ss << id(); |
205 | } else { |
206 | ss << "null" ; |
207 | } |
208 | if (concreteId()) { |
209 | ss << " (concrete: " << concreteId() << ")" ; |
210 | } |
211 | ss << " in " ; |
212 | if (td()) { |
213 | auto tv = lookUpTv(td()); |
214 | TORCH_INTERNAL_ASSERT(tv != nullptr, "No TV found for " , td()->toString()); |
215 | ss << "T" << tv->name() << "[ " << td()->getRootDomain() << " ]" ; |
216 | if (td()->hasRFactor()) { |
217 | ss << " (Rfactor: [ " << td()->getMaybeRFactorDomain() << " ])" ; |
218 | } |
219 | } else { |
220 | ss << "null" ; |
221 | } |
222 | return ss.str(); |
223 | } |
224 | |
225 | UnmappableReductionDomains::UnmappableReductionDomains() { |
226 | Fusion* fusion = FusionGuard::getCurFusion(); |
227 | traverse(fusion); |
228 | } |
229 | |
230 | namespace { |
231 | |
//! Find all domains that a given domain is dependent on
class FindInputDomains : BackwardVisitor {
 private:
  // Seed the result set with the (tv, id) pair itself; the backward
  // traversal then grows input_keys_ toward the fusion inputs.
  FindInputDomains(TensorView* tv, const IterDomain* id)
      : BackwardVisitor(false), tv_(tv) {
    input_keys_.insert(DomainKey(tv_->domain(), id));
  }

  // Run the backward traversal from tv_ and return the accumulated
  // set of input domain keys.
  DomainKeySet find() {
    traverseTo(tv_->fusion(), {tv_});
    return input_keys_;
  }

  // Visit one expression: propagate membership from each TV output to
  // each TV input (the traversal is backward, from outputs to inputs).
  void handle(Expr* expr) override {
    for (auto output : expr->outputs()) {
      if (!output->isA<TensorView>()) {
        continue;
      }
      for (auto input : expr->inputs()) {
        if (!input->isA<TensorView>()) {
          continue;
        }
        propagate(input->as<TensorView>(), output->as<TensorView>());
      }
    }
  }

  // For each output root domain already known to be in the set, add
  // the producer root domain it maps to (if any).
  void propagate(TensorView* in_tv, TensorView* out_tv) {
    auto c2p = PairwiseRootDomainMap(in_tv, out_tv)
                   .mapConsumerToProducer(out_tv->domain(), in_tv->domain());
    for (auto root_dom : out_tv->getRootDomain()) {
      DomainKey out_key({out_tv->domain(), root_dom});
      if (input_keys_.find(out_key) == input_keys_.end()) {
        continue;
      }
      auto input_id_it = c2p.find(root_dom);
      if (input_id_it == c2p.end()) {
        continue;
      }
      DomainKey input_key(in_tv->domain(), input_id_it->second);
      input_keys_.insert(input_key);
    }
  }

 private:
  // Starting tensor whose domain dependency is being traced.
  TensorView* tv_ = nullptr;
  // All domain keys found so far that the seed domain depends on.
  DomainKeySet input_keys_;

 public:
  // Entry point: all domains that id of tv depends on.
  static DomainKeySet find(TensorView* tv, const IterDomain* id) {
    return FindInputDomains(tv, id).find();
  }
};
285 | |
286 | } // namespace |
287 | |
//! Record a reduction output: collect its reduction root domains, the
//! root domains of all downstream consumers (which must not be mapped
//! with the reduction), and the input domains the reduction depends on.
void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
  // Reduction root domains of this output.
  std::vector<DomainKey> reduction_keys;
  for (const auto id : out_tv->getRootDomain()) {
    if (id->isReduction()) {
      DomainKey key(out_tv->domain(), id);
      reduction_keys.push_back(key);
      reduction_domains_.insert({key, {}});
    }
  }
  // Every root domain of every TV that (transitively) uses this output
  // is incompatible with each reduction domain found above.
  auto use_chains = DependencyCheck::getAllUseChains(out_tv);
  for (const auto& chain : use_chains) {
    for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
      // Do not include the tensor itself in its consumers
      if (tv == out_tv) {
        continue;
      }
      const auto& root_domain = tv->getRootDomain();
      for (const auto& id : root_domain) {
        DomainKey consumer_key(tv->domain(), id);
        for (const auto& reduction_key : reduction_keys) {
          reduction_domains_.at(reduction_key).insert(consumer_key);
        }
      }
    }
  }
  // Also record the input domains each reduction domain depends on.
  for (const auto& reduction_key : reduction_keys) {
    reduction_domain_inputs_.insert(
        {reduction_key, FindInputDomains::find(out_tv, reduction_key.id())});
  }
}
318 | |
319 | void UnmappableReductionDomains::handle(ReductionOp* op) { |
320 | // Builds a map from reduction domains to consumer domains. |
321 | TensorView* out_tv = op->out()->as<TensorView>(); |
322 | handleReductionOutput(out_tv); |
323 | } |
324 | |
325 | void UnmappableReductionDomains::handle(GroupedReductionOp* op) { |
326 | // Builds a map from reduction domains to consumer domains. |
327 | for (auto out : op->outputs()) { |
328 | handleReductionOutput(out->as<TensorView>()); |
329 | } |
330 | } |
331 | |
332 | void UnmappableReductionDomains::handle(MmaOp* mma) { |
333 | // Builds a map from reduction domains to consumer domains. |
334 | TensorView* out_tv = mma->out()->as<TensorView>(); |
335 | handleReductionOutput(out_tv); |
336 | } |
337 | |
338 | void UnmappableReductionDomains::handle(WelfordOp* op) { |
339 | // Builds a map from reduction domains to consumer domains. |
340 | handleReductionOutput(op->outAvg()->as<TensorView>()); |
341 | handleReductionOutput(op->outVar()->as<TensorView>()); |
342 | handleReductionOutput(op->outN()->as<TensorView>()); |
343 | } |
344 | |
345 | bool UnmappableReductionDomains::isReductionOutputMapped( |
346 | const DomainKeySet& consumer_domains, |
347 | const ComputeAtRootDomainMap& root_map) const { |
348 | // Check each reduction domain if any of the consumer domains |
349 | // conflicts with it |
350 | for (const auto& kv : reduction_domains_) { |
351 | const DomainKey& reduction_domain = kv.first; |
352 | // Domains that must not be mapped with the reduction domain |
353 | const DomainKeySet& incompatible_domains = kv.second; |
354 | // Input domains to the reduction domain |
355 | const auto& input_keys = reduction_domain_inputs_.at(reduction_domain); |
356 | // Check if any of the consumer domains is an input to the |
357 | // reduction |
358 | auto it = std::find_if( |
359 | consumer_domains.begin(), |
360 | consumer_domains.end(), |
361 | [&](const auto& consumer_domain) { |
362 | return std::find( |
363 | input_keys.begin(), input_keys.end(), consumer_domain) != |
364 | input_keys.end(); |
365 | }); |
366 | // None of the consumer domains is used for the reduction |
367 | // domain. They should be safe with respect to this reduction |
368 | // domain |
369 | if (it == consumer_domains.end()) { |
370 | continue; |
371 | } |
372 | |
373 | // A consumer domain that is an input to the reduction domain |
374 | const DomainKey& input_to_reduction = *it; |
375 | |
376 | // Check if mapping input_to_reduction with the other domains in |
377 | // consumer_domains. If there's a domain that is a consumer of the |
378 | // reduction, they must not be mapped together |
379 | for (const auto& consumer_domain : consumer_domains) { |
380 | if (consumer_domain == input_to_reduction) { |
381 | continue; |
382 | } |
383 | if (std::any_of( |
384 | incompatible_domains.begin(), |
385 | incompatible_domains.end(), |
386 | [&](const DomainKey& incompatible_domain) { |
387 | return root_map.canMap( |
388 | consumer_domain.td(), |
389 | consumer_domain.id(), |
390 | incompatible_domain.td(), |
391 | incompatible_domain.id()); |
392 | })) { |
393 | return true; |
394 | } |
395 | } |
396 | } |
397 | return false; |
398 | } |
399 | |
400 | std::string UnmappableReductionDomains::toString() const { |
401 | std::stringstream ss; |
402 | ss << "Reduction-to-consumer map\n" ; |
403 | for (const auto& kv : reduction_domains_) { |
404 | ss << "\tReduction: " << kv.first.toString() << "\n" ; |
405 | for (const auto& mapped_val : kv.second) { |
406 | ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n" ; |
407 | } |
408 | } |
409 | |
410 | ss << "Reduction-to-producer map\n" ; |
411 | for (const auto& kv : reduction_domain_inputs_) { |
412 | ss << "\tReduction: " << kv.first.toString() << "\n" ; |
413 | for (const auto& mapped_val : kv.second) { |
414 | ss << "\t\tProducer domain: " << mapped_val.toString() << "\n" ; |
415 | } |
416 | } |
417 | |
418 | return ss.str(); |
419 | } |
420 | |
421 | void ComputeAtRootDomainMap::build(bool map_through_reduction) { |
422 | // Make sure we start from scratch. Throw away previous results. |
423 | eq_set_.clear(); |
424 | bcast_map_.clear(); |
425 | new_broadcast_domains_.clear(); |
426 | ComputeAtRootDomainMapBuilder builder(*this, map_through_reduction); |
427 | } |
428 | |
//! Whether id_a of td_a can be mapped with id_b of td_b. Both must be
//! root (or rfactor) domains. Broadcast IDs are resolved through the
//! sets of IDs they may be concretized to.
bool ComputeAtRootDomainMap::canMap(
    const TensorDomain* td_a,
    const IterDomain* id_a,
    const TensorDomain* td_b,
    const IterDomain* id_b) const {
  TORCH_INTERNAL_ASSERT(
      id_a->definition() == nullptr || id_a->isRFactorProduct(),
      "Non-root domain is not supported: " ,
      id_a);
  TORCH_INTERNAL_ASSERT(
      id_b->definition() == nullptr || id_b->isRFactorProduct(),
      "Non-root domain is not supported: " ,
      id_b);

  // Forward to overloaded functions
  if (!id_a->isBroadcast() && !id_b->isBroadcast()) {
    return canMap(DomainKey(td_a, id_a), DomainKey(td_b, id_b));
  } else if (!id_a->isBroadcast()) {
    return canMap(DomainKey(td_a, id_a), td_b, id_b);
  } else if (!id_b->isBroadcast()) {
    return canMap(DomainKey(td_b, id_b), td_a, id_a);
  }

  // At this point, both are broadcast. Every pair of concrete IDs of
  // both id_a and id_b needs to be looked at. Whether they are
  // mappable depends on whether the concrete IDs are broadcast or
  // not. Note that a broadcast axis is used a concrete ID when it is
  // part of an output tensor domain, i.e., when it never gets
  // concretized with any non-broadcast axis.

  // If there exists a pair of non-broadcast concrete IDs is not
  // mappable, id_a and id_b can't be mapped together. Otherwise, they
  // can be mapped when there is any mappable pair is found.
  bool mappable_pair_found = false;
  for (const auto& key_a : getConcretizedKeys(td_a, id_a)) {
    for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
      const bool mappable = canMap(key_a, key_b);
      mappable_pair_found = mappable_pair_found || mappable;
      // If both concrete IDs are not broadcast, they must be
      // mappable. Also, if either of the concrete IDs is a reduction,
      // that means a trivial reduction (i.e., broadcast immediately
      // followed by reduction), which does not prevent any mapping.
      if (!key_a.concreteId()->isBroadcast() &&
          !key_b.concreteId()->isBroadcast() &&
          !key_a.concreteId()->isReduction() &&
          !key_b.concreteId()->isReduction() && !mappable) {
        return false;
      }
    }
  }

  return mappable_pair_found;
}
482 | |
//! Whether key_a can be mapped with the broadcast-capable id_b of
//! td_b. Non-broadcast id_b is forwarded to the key/key overload.
bool ComputeAtRootDomainMap::canMap(
    const DomainKey& key_a,
    const TensorDomain* td_b,
    const IterDomain* id_b) const {
  TORCH_INTERNAL_ASSERT(
      id_b->definition() == nullptr || id_b->isRFactorProduct(),
      "Non-root domain is not supported: " ,
      id_b);

  if (!id_b->isBroadcast()) {
    return canMap(key_a, DomainKey(td_b, id_b));
  }

  // If id_b is broadcast, look at all the concrete IDs that id_b may
  // be concretized to. Whether it is mappable with key_a depends on
  // whether key_a's concrete ID is also broadcast.
  // 1) key_a's concrete ID is also broadcast: They are mappable when
  // there is any mappable concrete ID exists in the concrete ID set
  // of id_b.
  // 2) key_a's concrete ID is not broadcast: Since key_a is indeed
  // concrete, it must be mappable with any of concrete ID of id_b,
  // except when a id_b concrete is broadcast.
  const bool key_a_bcast =
      key_a.concreteId() && key_a.concreteId()->isBroadcast();
  const bool key_a_reduction =
      (key_a.concreteId() && key_a.concreteId()->isReduction()) ||
      key_a.id()->isReduction();
  bool mappable_pair_found = false;
  for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
    const bool mappable = canMap(key_a, key_b);
    mappable_pair_found = mappable_pair_found || mappable;
    // If both concrete IDs are not broadcast, they must be mappable.
    // However, if key_b's concrete ID is a reduction, the concrete ID
    // is a result of a trivial reduction, so it should not prevent
    // any other mapping. Similarly, if key_a is a reduction, it just
    // needs to find any concrete ID of key_b that can be mapped.
    if (!key_a_bcast && !key_b.concreteId()->isBroadcast() &&
        !key_b.concreteId()->isReduction() && !key_a_reduction && !mappable) {
      return false;
    }
  }

  return mappable_pair_found;
}
527 | |
528 | bool ComputeAtRootDomainMap::canMap( |
529 | const DomainKey& key_a, |
530 | const DomainKey& key_b) const { |
531 | return key_a == key_b || eq_set_.permissiveAreMapped(key_a, key_b); |
532 | } |
533 | |
//! Replicate all map state recorded for td so that td_alias is
//! treated identically: bcast_map_ entries, equivalence-set
//! membership, and new-broadcast-domain records.
void ComputeAtRootDomainMap::setAlias(
    const TensorDomain* td,
    const TensorDomain* td_alias) {
  // Build additions in a copy: inserting into bcast_map_ while
  // iterating it would invalidate the iterators.
  auto tmp_bcast_map = bcast_map_;
  for (const auto& kv : bcast_map_) {
    const auto& bcast_map_key = kv.first;
    const auto& bcast_concrete_id_set = kv.second;
    if (bcast_map_key.td() == td) {
      DomainKey alias_key(td_alias, bcast_map_key.id());
      tmp_bcast_map.insert({alias_key, bcast_concrete_id_set});
    }
  }
  bcast_map_ = tmp_bcast_map;

  // Join each equivalence-set entry of td with a matching alias entry
  // so both end up in the same set.
  auto all_elements = eq_set_.getAllElements();
  for (const auto& key : all_elements.vector()) {
    if (key.td() == td) {
      DomainKey alias_key(td_alias, key.id(), key.concreteId());
      eq_set_.mapEntries(key, alias_key);
    }
  }

  // Same copy-then-assign pattern for the new-broadcast-domain set.
  auto tmp_new_broadcast_domains = new_broadcast_domains_;
  for (const auto& key : new_broadcast_domains_) {
    if (key.td() == td) {
      DomainKey alias_key(td_alias, key.id());
      tmp_new_broadcast_domains.insert(alias_key);
    }
  }
  new_broadcast_domains_ = tmp_new_broadcast_domains;
}
565 | |
566 | std::vector<DomainKey> ComputeAtRootDomainMap::getConcretizedKeys( |
567 | const TensorDomain* td, |
568 | const IterDomain* id) const { |
569 | DomainKey key(td, id); |
570 | auto it = bcast_map_.find(key); |
571 | TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: " , key.toString()); |
572 | std::vector<DomainKey> domains; |
573 | std::transform( |
574 | it->second.begin(), |
575 | it->second.end(), |
576 | std::back_inserter(domains), |
577 | [&](const IterDomain* concrete_id) { |
578 | return DomainKey(td, id, concrete_id); |
579 | }); |
580 | return domains; |
581 | } |
582 | |
583 | std::unordered_set<const IterDomain*>& ComputeAtRootDomainMap:: |
584 | getConcretizedDomains(const TensorDomain* td, const IterDomain* id) { |
585 | DomainKey key(td, id); |
586 | auto it = bcast_map_.find(key); |
587 | TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: " , key.toString()); |
588 | return it->second; |
589 | } |
590 | |
591 | std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap:: |
592 | mapBestEffort( |
593 | const TensorDomain* from_td, |
594 | const std::vector<IterDomain*>& from_root, |
595 | const TensorDomain* to_td, |
596 | const std::vector<IterDomain*>& to_root) const { |
597 | std::unordered_map<IterDomain*, IterDomain*> id_map; |
598 | for (auto& from_id : from_root) { |
599 | for (const auto& to_id : to_root) { |
600 | if (canMap(from_td, from_id, to_td, to_id)) { |
601 | TORCH_INTERNAL_ASSERT( |
602 | id_map.insert({from_id, to_id}).second, |
603 | "Multiple matching ID detected for " , |
604 | from_id); |
605 | } |
606 | } |
607 | } |
608 | return id_map; |
609 | } |
610 | |
//! Map root domains between producer and consumer in the requested
//! direction, erroring out when a requested domain cannot be mapped
//! and is not one of the known-acceptable unmapped cases.
std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::map(
    const TensorDomain* producer,
    const TensorDomain* consumer,
    const std::unordered_set<IterDomain*>& root_dims_to_map,
    bool producer_to_consumer) const {
  const auto& producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  const auto& consumer_root = consumer->getRootDomain();
  // Select the traversal direction; "from" domains drive the loop.
  const TensorDomain* from_td = producer_to_consumer ? producer : consumer;
  const TensorDomain* to_td = producer_to_consumer ? consumer : producer;
  const auto& from_ids = producer_to_consumer ? producer_root : consumer_root;
  const auto& to_ids = producer_to_consumer ? consumer_root : producer_root;
  std::unordered_map<IterDomain*, IterDomain*> id_map =
      mapBestEffort(from_td, from_ids, to_td, to_ids);
  for (auto& from_id : from_ids) {
    if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) {
      // Remove mapping if exists
      id_map.erase(from_id);
      continue;
    }
    if (id_map.find(from_id) != id_map.end()) {
      continue;
    }
    // Matching ID not found. It's an error unless the following three cases:
    // 1. from_id is a new broadcast of a consumer domain; or
    // 2. from_id is a window axis of a consumer domain; or
    // 3. from_id is a ViewAsScalar domain
    // Note that reduction domains are removed from the producer root domain.
    if (!producer_to_consumer &&
        (new_broadcast_domains_.find(DomainKey(from_td, from_id)) !=
             new_broadcast_domains_.end() ||
         from_id->getIterType() == IterType::VectorComponent ||
         (window_axes_.count(from_id) > 0))) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        false,
        "Mapping IterDomain " ,
        from_id,
        " of " ,
        from_td,
        " not possible as it would require recomputing the source tensor." ,
        " Producer root: " ,
        producer_root,
        ". Consumer root: " ,
        consumer_root,
        ". Mapping: " ,
        this->toString());
  }
  return id_map;
}
662 | |
663 | std::unordered_set<IterDomain*> ComputeAtRootDomainMap::getMappableDims( |
664 | const TensorDomain* producer, |
665 | const TensorDomain* consumer) const { |
666 | //! This funciton previously used mapBestEffort but it can fail when |
667 | //! a domain is mapped to multitple domains, which can happen with |
668 | //! views. Since we only need to find mappable domains, just |
669 | //! grab any domain that is mapped in a pairwise way. |
670 | |
671 | const auto& producer_root = producer->getMaybeRFactorDomain(); |
672 | const auto& consumer_root = consumer->getRootDomain(); |
673 | |
674 | std::unordered_set<IterDomain*> mappable_ids; |
675 | |
676 | for (const auto& p_id : producer_root) { |
677 | for (const auto& c_id : consumer_root) { |
678 | if (canMap(producer, p_id, consumer, c_id)) { |
679 | mappable_ids.emplace(p_id); |
680 | mappable_ids.emplace(c_id); |
681 | } |
682 | } |
683 | } |
684 | |
685 | return mappable_ids; |
686 | } |
687 | |
688 | std::string ComputeAtRootDomainMap::toString() const { |
689 | return eq_set_.toString(); |
690 | } |
691 | |
//! Build root_map by traversing the current fusion backward from its
//! outputs. Any entries left in pending_map_ after the traversal
//! indicate an internal inconsistency and trigger an assert.
ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder(
    ComputeAtRootDomainMap& root_map,
    bool map_through_reduction)
    : BackwardVisitor(false),
      root_map_(root_map),
      map_through_reduction_(map_through_reduction) {
  Fusion* fusion = FusionGuard::getCurFusion();
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  traverseTo(fusion, fusion->outputs(), false);
  // Diagnostic dump before asserting: show which producer keys were
  // never finalized during the traversal.
  if (!pending_map_.empty()) {
    std::stringstream ss;
    ss << "pending map:\n" ;
    for (auto& kv : pending_map_) {
      ss << "\t" << kv.first.toString() << "\n" ;
      for (auto& dk : kv.second) {
        ss << "\t\t" << dk.toString() << "\n" ;
      }
    }
    std::cerr << ss.str();
  }
  TORCH_INTERNAL_ASSERT(pending_map_.empty());
}
714 | |
// Set concrete domains for broadcast domains that never get joined
// with a concrete domain. Just set its own domain as a concrete
// domain, which is not concrete but is sufficient for this analysis.
void ComputeAtRootDomainMapBuilder::initializeBcastMap(
    const TensorView* tv,
    const IterDomain* id) {
  TORCH_INTERNAL_ASSERT(id->isBroadcast(), "Not a broadcast axis" );
  auto key = DomainKey(tv->domain(), id);
  auto it = root_map_.bcast_map_.find(key);
  if (it != root_map_.bcast_map_.end()) {
    // already initialized.
    return;
  }

  // This initialization should be only used for: 1) fusion output
  // tensors, 2) outputs of multi-consumer expressions that are not
  // fusion outputs, and 3) view outputs as broadcasts can be merged
  // with non-broadcast domains, resulting in non-broadcast rfactor
  // domains.
  TORCH_INTERNAL_ASSERT(
      tv->isFusionOutput() || tv->definition()->outputs().size() > 1 ||
          tv->isDefinitionType(ExprType::ViewOp),
      "Invalid tensor to initialize bcast map: t" ,
      tv->name());
  // Use the broadcast domain itself as its (placeholder) concrete ID.
  root_map_.bcast_map_.insert({key, {id}});
}
741 | |
742 | void ComputeAtRootDomainMapBuilder::addToPendingList( |
743 | const DomainKey& producer, |
744 | const DomainKey& consumer) { |
745 | auto it = ensureMapping(pending_map_, producer, {}); |
746 | auto& consumer_set = it->second; |
747 | consumer_set.insert(consumer); |
748 | } |
749 | |
750 | void ComputeAtRootDomainMapBuilder::setMapped( |
751 | const DomainKey& producer, |
752 | const DomainKey& consumer) { |
753 | root_map_.eq_set_.mapEntries(producer, consumer); |
754 | } |
755 | |
756 | void ComputeAtRootDomainMapBuilder::setInvalid( |
757 | const DomainKey& key1, |
758 | const DomainKey& key2) { |
759 | invalid_mappings_.emplace_back(key1, key2); |
760 | } |
761 | |
//! Whether mapping all keys in domains together would join any pair
//! that has been recorded as invalid via setInvalid().
bool ComputeAtRootDomainMapBuilder::isInvalid(
    const DomainKeySet& domains) const {
  // First, collect all invalid mappings for each of the keys in domains
  DomainKeyMap<DomainKeySet> invalid_key_map;
  for (const auto& key : domains) {
    DomainKeySet invalid_keys;
    for (const auto& invalid_pair : invalid_mappings_) {
      // If key maps with one side of an invalid pair, the other side
      // becomes forbidden for it.
      if (root_map_.canMap(key, invalid_pair.first)) {
        invalid_keys.insert(invalid_pair.second);
      } else if (root_map_.canMap(key, invalid_pair.second)) {
        invalid_keys.insert(invalid_pair.first);
      }
    }
    invalid_key_map.emplace(key, invalid_keys);
  }

  // Next, check if any pair is invalid to map.
  const auto num_keys = domains.size();
  const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()});
  for (const auto i : c10::irange(num_keys)) {
    const auto& key_i = domains_vec[i];
    // If no invalid keys found for key_i, it can be skipped.
    const auto invalid_key_map_it = invalid_key_map.find(key_i);
    if (invalid_key_map_it == invalid_key_map.end()) {
      continue;
    }

    // Set of keys that are invalid to be mapped with key_i.
    const DomainKeySet& invalid_keys_for_i = invalid_key_map_it->second;

    // If any other key in domains is identified mappable with any of
    // the keys in this set, the mapping with key_i is invalid.
    for (const auto j : c10::irange(i + 1, num_keys)) {
      const auto& key_j = domains_vec[j];
      if (std::any_of(
              invalid_keys_for_i.begin(),
              invalid_keys_for_i.end(),
              [&](const auto& invalid_key_for_i) {
                return root_map_.canMap(key_j, invalid_key_for_i);
              })) {
        return true;
      }
    }
  }
  return false;
}
808 | |
//! Tentatively map a producer root ID with a consumer root ID,
//! propagating broadcast concretization info and deferring the final
//! decision to the pending list.
void ComputeAtRootDomainMapBuilder::setMaybeMapped(
    const TensorDomain* producer_td,
    const IterDomain* producer_id,
    const TensorDomain* consumer_td,
    const IterDomain* consumer_id) {
  const DomainKey producer_key(producer_td, producer_id);
  const DomainKey consumer_key(consumer_td, consumer_id);

  // Make sure a (possibly empty) bcast_map_ entry exists for the
  // producer broadcast before any concrete IDs are propagated into it.
  if (producer_id->isBroadcast()) {
    ensureMapping(root_map_.bcast_map_, producer_key, {});
  }

  if (consumer_id->isBroadcast()) {
    TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
    // Get bcast_map_ entry for consumer_id
    const auto consumer_bcast_domains =
        root_map_.getConcretizedKeys(consumer_td, consumer_id);
    auto& producer_domains =
        root_map_.getConcretizedDomains(producer_td, producer_id);

    // If consumer id is broadcasted, make sure to propagate its concrete_id(s)
    // to producer
    for (const auto& consumer_bcast_key : consumer_bcast_domains) {
      const auto concrete_id = consumer_bcast_key.concreteId();
      const DomainKey producer_bcast_key(producer_td, producer_id, concrete_id);
      producer_domains.insert(concrete_id);
      addToPendingList(producer_bcast_key, consumer_bcast_key);
    }
  } else {
    // Consumer is concrete. If the producer is a broadcast, it gets
    // concretized by this consumer ID.
    auto producer_concrete_key = producer_key;
    if (producer_id->isBroadcast()) {
      const auto concrete_id = consumer_id;
      auto& producer_domains =
          root_map_.getConcretizedDomains(producer_td, producer_id);
      producer_concrete_key = DomainKey(producer_td, producer_id, concrete_id);
      producer_domains.insert(concrete_id);
    }
    addToPendingList(producer_concrete_key, consumer_key);
  }
}
849 | |
850 | void ComputeAtRootDomainMapBuilder::handle(Expr* e) { |
851 | // Avoid visiting expressions multiple times |
852 | if (visited_.find(e) != visited_.end()) { |
853 | return; |
854 | } |
855 | BackwardVisitor::handle(e); |
856 | visited_.insert(e); |
857 | } |
858 | |
// Maps each input root axis of a pointwise or reduction expression to the
// positionally corresponding root axis of its tensor output(s).
void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) {
  // Scalar-valued expressions have no domains to map.
  if (e->output(0)->getValType() != ValType::TensorView) {
    return;
  }

  // Broadcast is handled separately, so e should never be BroadcastOp.
  TORCH_INTERNAL_ASSERT(e->getExprType() != ExprType::BroadcastOp);

  TORCH_INTERNAL_ASSERT(e->outputs().size() >= 1);
  const TensorView* out_tv = e->output(0)->as<TensorView>();
  const TensorDomain* out_td = out_tv->domain();
  const auto& out_root = out_td->getRootDomain();

  // Record equalities from output to all the inputs
  // ignores non-concretizable broadcasts
  for (auto* in_tv : ir_utils::filterByType<TensorView>(e->inputs())) {
    const TensorDomain* in_td = in_tv->domain();
    // Reduction axes of the input have no counterpart in the output root
    // domain, so drop them before the positional mapping.
    std::vector<IterDomain*> in_root =
        TensorDomain::noReductions(in_tv->getMaybeRFactorDomain());
    TORCH_INTERNAL_ASSERT(
        in_root.size() == out_root.size(),
        "\nExpression: " ,
        e,
        "\nInput root domain: " ,
        in_root,
        "\nOutput root domain: " ,
        out_root);
    for (const auto it : c10::irange(in_root.size())) {
      if (e->outputs().size() > 1) {
        // Multi-output expressions: map this input axis to the positionally
        // corresponding root axis of every output.
        TORCH_INTERNAL_ASSERT(
            e->isA<WelfordOp>() || e->isA<GroupedReductionOp>() ||
            e->isA<GroupedWelfordOp>(),
            "Unknown multi-output Expr type " ,
            e->getExprType().value(),
            " is found" );
        for (auto out : e->outputs()) {
          auto out_tv = out->as<TensorView>();
          auto out_td = out_tv->domain();
          // NOTE(review): getRootDomain() is copied here on every
          // iteration; binding by const reference would avoid the copy.
          auto out_root = out_td->getRootDomain();
          setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
        }
      } else {
        setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
      }
    }
  }
}
906 | |
// Maps the input root axes of a broadcast op to the non-broadcast output
// root axes, walking both domains in lockstep. Output axes flagged as new
// broadcast dimensions have no input counterpart and are recorded in
// new_broadcast_domains_ instead.
void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) {
  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
  const auto in_root =
      TensorDomain::noReductions(in_td->getMaybeRFactorDomain());
  const auto& out_root = out_td->getRootDomain();
  // One flag per output root axis; true means the axis is a new broadcast
  // domain introduced by this op.
  const auto& bcast_dim_flags = op->getBroadcastDimFlags();
  TORCH_INTERNAL_ASSERT(
      out_root.size() == bcast_dim_flags.size(),
      "dim flags: " ,
      bcast_dim_flags,
      ", out root: " ,
      out_root);
  auto in_it = in_root.begin();
  auto out_it = out_root.begin();
  while (in_it != in_root.end() && out_it != out_root.end()) {
    if (bcast_dim_flags.at(std::distance(out_root.begin(), out_it))) {
      // new broadcast dim. No matching dimension in the input
      // tensor.
      root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
      ++out_it;
      continue;
    }
    setMaybeMapped(in_td, *in_it, out_td, *out_it);
    ++in_it;
    ++out_it;
  }
  // At this point, the input domain should have been scanned
  // entirely.
  TORCH_INTERNAL_ASSERT(
      in_it == in_root.end(),
      "Unmatched domain detected: " ,
      *in_it,
      " of " ,
      in_td);
  // On the other hand, the output may still have some domains left,
  // and they must be new broadcast domains.
  for (; out_it != out_root.end(); ++out_it) {
    TORCH_INTERNAL_ASSERT(
        bcast_dim_flags.at(std::distance(out_root.begin(), out_it)),
        "Unmatched domain detected: " ,
        *out_it,
        " of " ,
        out_td);
    root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
  }
}
954 | |
// Maps each input root axis of a ViewAsScalar op to the positionally
// corresponding output root axis. The output has exactly one extra trailing
// axis (the vector-component axis), which is left unmapped.
void ComputeAtRootDomainMapBuilder::handle(ViewAsScalar* op) {
  const TensorView* out_tv = op->output(0)->as<TensorView>();
  const TensorDomain* out_td = out_tv->domain();
  const auto& out_root = out_td->getRootDomain();

  const TensorView* in_tv = op->input(0)->as<TensorView>();
  const TensorDomain* in_td = in_tv->domain();

  std::vector<IterDomain*> in_root =
      TensorDomain::noReductions(in_tv->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      in_root.size() + 1 == out_root.size(),
      "\nExpression: " ,
      op,
      "\nInput root domain: " ,
      in_root,
      "\nOutput root domain: " ,
      out_root);
  // Walk both domains in lockstep; the input is exhausted first because the
  // output has one more axis.
  auto in_it = in_root.begin();
  auto out_it = out_root.begin();
  while (in_it != in_root.end() && out_it != out_root.end()) {
    setMaybeMapped(in_td, *in_it, out_td, *out_it);
    ++in_it;
    ++out_it;
  }
  // out_it now points at the last output axis, which must be the
  // vector-component axis added by this op.
  TORCH_INTERNAL_ASSERT(
      (*out_it)->isVectorComponent(),
      "The last dim of ViewDtypeOp's output must be a ViewAsScalar" );
}
984 | |
985 | void ComputeAtRootDomainMapBuilder::handle(TransposeOp* op) { |
986 | const TensorDomain* in_td = op->in()->as<TensorView>()->domain(); |
987 | std::vector<IterDomain*> in_root = |
988 | TensorDomain::noReductions(in_td->getMaybeRFactorDomain()); |
989 | |
990 | const TensorDomain* out_td = op->out()->as<TensorView>()->domain(); |
991 | const auto& out_root = out_td->getRootDomain(); |
992 | |
993 | TORCH_INTERNAL_ASSERT(in_root.size() == out_root.size()); |
994 | |
995 | const auto& new2old = op->new2old(); |
996 | |
997 | for (const auto it : c10::irange(out_root.size())) { |
998 | setMaybeMapped(in_td, in_root[new2old[it]], out_td, out_root[it]); |
999 | } |
1000 | } |
1001 | |
1002 | void ComputeAtRootDomainMapBuilder::handle(GatherOp* op) { |
1003 | const TensorDomain* in_td = op->in()->as<TensorView>()->domain(); |
1004 | const TensorDomain* out_td = op->out()->as<TensorView>()->domain(); |
1005 | const auto in_root = |
1006 | TensorDomain::noReductions(in_td->getMaybeRFactorDomain()); |
1007 | const auto& out_root = out_td->getRootDomain(); |
1008 | |
1009 | // Only maps the input root axes. Do not map the new window axes. |
1010 | for (const auto it : c10::irange(in_root.size())) { |
1011 | setMaybeMapped(in_td, in_root[it], out_td, out_root[it]); |
1012 | } |
1013 | |
1014 | // Keep track of window axes so that they can be skipped when |
1015 | // mapping root domains |
1016 | for (const auto it : c10::irange(in_root.size(), out_root.size())) { |
1017 | root_map_.window_axes_.insert(out_root[it]); |
1018 | } |
1019 | } |
1020 | |
1021 | void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( |
1022 | const DomainKey& key) { |
1023 | auto it = pending_map_.find(key); |
1024 | if (it == pending_map_.end()) { |
1025 | return; |
1026 | } |
1027 | const auto& pending_set = it->second; |
1028 | // All entries in key_set must be equivalent with each other. |
1029 | TORCH_INTERNAL_ASSERT(pending_set.size() > 0); |
1030 | bool consistent = safeToMap(pending_set); |
1031 | for (const auto pending_key : pending_set) { |
1032 | if (consistent) { |
1033 | setMapped(key, pending_key); |
1034 | } else { |
1035 | setInvalid(key, pending_key); |
1036 | } |
1037 | } |
1038 | // This entry should never be used again, so remove it. |
1039 | pending_map_.erase(it); |
1040 | } |
1041 | |
1042 | void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( |
1043 | const TensorDomain* td, |
1044 | IterDomain* id) { |
1045 | if (id->isBroadcast()) { |
1046 | for (const auto& key : root_map_.getConcretizedKeys(td, id)) { |
1047 | mapAllPendingMappings(key); |
1048 | } |
1049 | } else { |
1050 | mapAllPendingMappings(DomainKey(td, id)); |
1051 | } |
1052 | } |
1053 | |
void ComputeAtRootDomainMapBuilder::handle(RNGOp* rop) {
  // An RNG op has no tensor inputs, so there are no producer-consumer
  // mappings to record; just resolve pending mappings on the output tensor.
  handle(rop->output(0)->as<TensorView>());
}
1057 | |
// Resolves all pending mappings for the root/rfactor axes of tv. For
// tensors with a view-like rfactor domain, mappings recorded on rfactor
// axes are first propagated to the root axes they are derived from.
void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
  const TensorDomain* td = tv->domain();
  const auto rfactor = TensorDomain::noReductions(td->getMaybeRFactorDomain());
  for (auto id : rfactor) {
    if (id->isBroadcast()) {
      initializeBcastMap(tv, id);
    }
    mapAllPendingMappings(td, id);
  }

  // When tv has an rfactor domain, propagate the domain mappings from
  // each of the rfactor axes to the dependent root axes.
  if (td->hasViewLikeRFactor()) {
    std::unordered_set<Val*> root_set(
        {td->getRootDomain().begin(), td->getRootDomain().end()});
    for (auto rf_id : rfactor) {
      if (!rf_id->isRFactorProduct()) {
        continue;
      }
      // Map rf_id with every root axis it depends on (excluding itself).
      auto dep = DependencyCheck::getAllValsBetween(root_set, {rf_id});
      for (auto id : ir_utils::filterByType<IterDomain>(dep)) {
        if (root_set.find(id) == root_set.end() || rf_id == id) {
          continue;
        }
        setMaybeMapped(td, id, td, rf_id);
      }
    }
    // Once mappings for rfactor axes are propagated to root axes,
    // aggregates them at each root axis
    for (auto id : tv->getRootDomain()) {
      if (id->isBroadcast()) {
        // There can be broadcast domains that appear at root domains but
        // are removed at rfactor domains as they are merged into
        // non-reduction domains. Initialize the map for those broadcast
        // domains.
        initializeBcastMap(tv, id);
      }
      mapAllPendingMappings(td, id);
    }
  }
}
1099 | |
1100 | // Checks whether all consumers of a producer can be joined without |
1101 | // introducing unsupported mappings, i.e., requiring recomputations. |
1102 | bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { |
1103 | if (domains.size() <= 1) { |
1104 | return true; |
1105 | } |
1106 | |
1107 | // Can't map if reduction output domains would be mapped |
1108 | if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) && |
1109 | !map_through_reduction_) { |
1110 | return false; |
1111 | } |
1112 | // Make sure mapping these domains won't cause any invalid mapping |
1113 | if (isInvalid(domains)) { |
1114 | return false; |
1115 | } |
1116 | return true; |
1117 | } |
1118 | |
1119 | namespace { |
// Builds the exact equivalence sets of root domains: visits every
// expression in the fusion and joins each producer root ID with the
// consumer root IDs it is pairwise-mapped to in exact mode.
class ExactRootDomainMapBuilder : private IterVisitor {
 public:
  // Fills eq_sets (owned by the caller) as a side effect of construction.
  ExactRootDomainMapBuilder(
      Fusion* fusion,
      DisjointSets<const IterDomain*>& eq_sets)
      : eq_sets_(eq_sets) {
    traverseTo(fusion, fusion->outputs());
  }

 private:
  using IterVisitor::handle;

  // For every producer-consumer tensor pair of expr, compute the exact
  // pairwise root-domain map and merge each mapped pair into eq_sets_.
  void handle(Expr* expr) final {
    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
      for (auto consumer :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        PairwiseRootDomainMap pwise_map(producer, consumer, true);
        const auto mappings = pwise_map.mapProducerToConsumer(
            producer->domain(), consumer->domain());
        for (const auto& mapping : mappings) {
          eq_sets_.mapEntries(mapping.first, mapping.second);
        }
      }
    }
  }

 private:
  // Equivalence sets being populated; owned by the caller.
  DisjointSets<const IterDomain*>& eq_sets_;
};
1149 | |
1150 | } // namespace |
1151 | |
ExactRootDomainMap::ExactRootDomainMap(Fusion* fusion) {
  // The builder traverses the fusion and fills eq_sets_ as a side effect of
  // its construction.
  ExactRootDomainMapBuilder builder(fusion, eq_sets_);
}
1155 | |
1156 | bool ExactRootDomainMap::areMapped( |
1157 | const IterDomain* id_a, |
1158 | const IterDomain* id_b) const { |
1159 | // With expand going into a view operation there can be an instance where an |
1160 | // iteration root domain in the consumer resolves the broadcast from the |
1161 | // producer, then immediately rfactors it. In this case the consumer root is |
1162 | // not mapped exactly to any other domain, so it might no have an entry in |
1163 | // eq_sets_. eq_sets_.strictAreMapped would throw in this case so just return |
1164 | // false if a mapping doesn't exist. |
1165 | if (!eq_sets_.mappingExists(id_a) || !eq_sets_.mappingExists(id_b)) { |
1166 | return false; |
1167 | } |
1168 | return eq_sets_.strictAreMapped(id_a, id_b); |
1169 | } |
1170 | |
1171 | std::unordered_map<IterDomain*, IterDomain*> ExactRootDomainMap::map( |
1172 | const TensorDomain* producer, |
1173 | const TensorDomain* consumer, |
1174 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
1175 | bool producer_to_consumer) const { |
1176 | const auto& producer_root = |
1177 | TensorDomain::noReductions(producer->getMaybeRFactorDomain()); |
1178 | const auto& consumer_root = consumer->getRootDomain(); |
1179 | const auto& from_ids = producer_to_consumer ? producer_root : consumer_root; |
1180 | const auto& to_ids = producer_to_consumer ? consumer_root : producer_root; |
1181 | |
1182 | std::unordered_map<IterDomain*, IterDomain*> id_map; |
1183 | |
1184 | for (auto& from_id : from_ids) { |
1185 | if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) { |
1186 | continue; |
1187 | } |
1188 | for (const auto& to_id : to_ids) { |
1189 | if (areMapped(from_id, to_id)) { |
1190 | TORCH_INTERNAL_ASSERT( |
1191 | id_map.insert({from_id, to_id}).second, |
1192 | "Multiple matching ID detected for " , |
1193 | from_id); |
1194 | } |
1195 | } |
1196 | } |
1197 | |
1198 | return id_map; |
1199 | } |
1200 | |
std::string ExactRootDomainMap::toString() const {
  // Delegate to the disjoint sets' printer, which lists each equivalence
  // class of exactly mapped iteration domains.
  return eq_sets_.toString();
}
1204 | |
1205 | } // namespace cuda |
1206 | } // namespace fuser |
1207 | } // namespace jit |
1208 | } // namespace torch |
1209 | |