1 | #include <compute_at_map.h> |
2 | |
3 | #include <disjoint_set.h> |
4 | #include <ir_utils.h> |
5 | #include <lower2device.h> |
6 | #include <root_domain_map.h> |
7 | #include <transform_iter.h> |
8 | |
9 | #include <tuple> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | namespace { |
16 | |
17 | // Is the provided IterDomain an Leaf of provided TensorView and within its |
18 | // computeAtPosition |
19 | bool idIsAComputeAtLeafDomain(IterDomain* id, TensorView* tv) { |
20 | auto begin = tv->domain()->domain().begin(); |
21 | auto end = tv->domain()->domain().begin() + tv->getComputeAtPosition(); |
22 | return std::find(begin, end, id) != end; |
23 | } |
24 | |
25 | // Is the provided IterDomain an Leaf of provided TensorView |
26 | bool idIsALeafDomain(IterDomain* id, TensorView* tv) { |
27 | auto begin = tv->domain()->domain().begin(); |
28 | auto end = tv->domain()->domain().end(); |
29 | return std::find(begin, end, id) != end; |
30 | } |
31 | |
32 | } // namespace |
33 | |
34 | IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { |
35 | build(fusion); |
36 | |
37 | if (!allow_self_mapping) { |
38 | TORCH_INTERNAL_ASSERT( |
39 | !hasSelfMapping(), |
40 | "Unsupported domain mapping detected in " , |
41 | std::get<0>(*self_mapping_info_)->toString(), |
42 | ". " , |
43 | std::get<3>(*self_mapping_info_), |
44 | " domains, " , |
45 | std::get<1>(*self_mapping_info_)->toString(), |
46 | " and " , |
47 | std::get<2>(*self_mapping_info_)->toString(), |
48 | ", are mapped with each other." ); |
49 | } |
50 | } |
51 | |
52 | //! Map corresponding inputs and outputs of swizzle op together |
53 | //! on the given disjoint set, if the given id is an output |
54 | //! of a swizzle operator. |
55 | //! |
56 | //! The current usage of swizzle operator is local to each tensor |
57 | //! itself, so they should not affect exact or permissive mapping |
58 | //! between iterdomains on different tensor domains. |
59 | //! TODO: |
60 | //! Exact mapping based index hoisting of swizzled iterdomains |
61 | //! is disabled currently and will be re-enabled in the next |
62 | //! few build out steps. |
63 | void mapMaybeSwizzleOp( |
64 | DisjointSets<IterDomain*>& disjoint_sets, |
65 | IterDomain* id) { |
66 | if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(id->definition())) { |
67 | // Map each input to its corresponding output on the given |
68 | // disjoint set. |
69 | disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX()); |
70 | disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY()); |
71 | } |
72 | } |
73 | |
// Returns true when "first" and "second" perform equivalent transformations:
// same expression type (only Merge and Split are supported here), their
// inputs (when forward) or outputs (when not forward) positionally exactly
// mapped in id_map, and matching transformation parameters (split
// factor/direction/offsets; for backward merges, a provably matching extent).
bool IterDomainGraph::exprsMap(
    Expr* first,
    Expr* second,
    bool forward,
    const DisjointSets<IterDomain*>& id_map) {
  // A missing expression can never match anything.
  if (first == nullptr || second == nullptr) {
    return false;
  }

  // Different expression types never map to each other.
  if (first->etype() != second->etype()) {
    return false;
  }

  TORCH_INTERNAL_ASSERT(
      first->etype() == ExprType::Merge || first->etype() == ExprType::Split,
      "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
      first->toString());

  // Gather the iter domains to compare: inputs when propagating forward,
  // outputs when propagating backward.
  auto first_ids = ir_utils::filterByType<IterDomain>(
                       forward ? first->inputs() : first->outputs())
                       .vector();

  auto second_ids = ir_utils::filterByType<IterDomain>(
                        forward ? second->inputs() : second->outputs())
                        .vector();

  TORCH_INTERNAL_ASSERT(
      first_ids.size() == second_ids.size(),
      "Expected number of ",
      (forward ? "inputs" : "outputs"),
      " to match for\n",
      first->toString(),
      second->toString());

  {
    // Pair the ids up positionally and require every pair to be exactly
    // mapped in id_map; any unmapped pair means the expressions differ.
    std::vector<std::pair<IterDomain*, IterDomain*>> zipped_ids;

    std::transform(
        first_ids.begin(),
        first_ids.end(),
        second_ids.begin(),
        std::back_inserter(zipped_ids),
        [](IterDomain* first, IterDomain* second) {
          return std::make_pair(first, second);
        });

    if (std::any_of(
            zipped_ids.begin(),
            zipped_ids.end(),
            [&](std::pair<IterDomain*, IterDomain*> id_pair) {
              return !id_map.strictAreMapped(id_pair.first, id_pair.second);
            })) {
      return false;
    }
  }

  if (first->isA<Merge>() && !forward) {
    // Can't back prop through merge without making sure one dimension actually
    // is identical extents.
    auto merge0 = first->as<Merge>();
    auto merge1 = second->as<Merge>();

    auto extent_0o = merge0->outer()->extent();
    auto extent_0i = merge0->inner()->extent();
    auto extent_1o = merge1->outer()->extent();
    auto extent_1i = merge1->inner()->extent();

    // Extents "match" if they are the same expression or both evaluate to
    // equal constant integers.
    auto extent_0_match = extent_0o->sameAs(extent_1o) ||
        (extent_0o->isConstInt() && extent_1o->isConstInt() &&
         extent_0o->evaluateInt() == extent_1o->evaluateInt());

    auto extent_1_match = extent_0i->sameAs(extent_1i) ||
        (extent_0i->isConstInt() && extent_1i->isConstInt() &&
         extent_0i->evaluateInt() == extent_1i->evaluateInt());

    // At least one of the outer/inner extents has to provably match to map
    // backward through a merge.
    if (!(extent_0_match || extent_1_match)) {
      return false;
    }
  }

  if (first->isA<Split>()) {
    // Splits must agree on factor, inner/outer direction, and both offsets
    // to be considered the same transformation.
    auto first_split = first->as<Split>();
    auto second_split = second->as<Split>();
    if (!first_split->factor()->sameAs(second_split->factor()) ||
        first_split->innerSplit() != second_split->innerSplit() ||
        !first_split->startOffset()->sameAs(second_split->startOffset()) ||
        !first_split->stopOffset()->sameAs(second_split->stopOffset())) {
      return false;
    }
  }

  return true;
}
167 | |
168 | void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { |
169 | if (first == nullptr || second == nullptr) { |
170 | return; |
171 | } |
172 | |
173 | if (!exprsMap(first, second, forward, exact_nodes_)) { |
174 | return; |
175 | } |
176 | |
177 | auto first_ids = ir_utils::filterByType<IterDomain>( |
178 | forward ? first->outputs() : first->inputs()) |
179 | .vector(); |
180 | auto second_ids = ir_utils::filterByType<IterDomain>( |
181 | forward ? second->outputs() : second->inputs()) |
182 | .vector(); |
183 | TORCH_INTERNAL_ASSERT( |
184 | first_ids.size() == second_ids.size(), |
185 | "This should be unreachable, if transformation expressions match, their number of inputs and outputs should as well.\n However found:\n" , |
186 | first->toString(), |
187 | "\nand\n" , |
188 | second->toString()); |
189 | for (auto out_i : c10::irange(first_ids.size())) { |
190 | exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); |
191 | permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); |
192 | } |
193 | } |
194 | |
195 | namespace { |
196 | |
197 | // Returns a pair of mapped IDs |
198 | c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair( |
199 | const std::vector<IterDomain*>& ids, |
200 | const IterDomainGraph& id_graph) { |
201 | for (auto id1 : ids) { |
202 | for (auto id2 : ids) { |
203 | if (id1 == id2) { |
204 | continue; |
205 | } |
206 | if (id_graph.permissiveNodes().disjointSetMap().at(id1)->has(id2)) { |
207 | return std::make_pair(id1, id2); |
208 | } |
209 | } |
210 | } |
211 | |
212 | return {}; |
213 | } |
214 | |
215 | // It is assumed that for any tensor represented by a list of domains, |
216 | // those domains should never be mapped with each other. It may be |
217 | // possible to lift this assumption, but it's unclear if it could |
218 | // matter in practice. |
219 | c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>> |
220 | findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) { |
221 | for (auto tv : ir_utils::allTvs(fusion)) { |
222 | // For each tensor, make sure root, rfactor and leaf domains |
223 | // should not include domains that are mapped with another domain |
224 | // in the same set of domains. This may be overly conservative, |
225 | // and it maybe enough to check the root domains. |
226 | |
227 | // Root domains |
228 | auto self_mappped_root_pair = |
229 | detectMappablePair(tv->getRootDomain(), id_graph); |
230 | if (self_mappped_root_pair.has_value()) { |
231 | return std::make_tuple( |
232 | tv, |
233 | self_mappped_root_pair->first, |
234 | self_mappped_root_pair->second, |
235 | "Root" ); |
236 | } |
237 | |
238 | // Rfactor domains |
239 | if (tv->hasRFactor()) { |
240 | auto self_mappped_rf_pair = |
241 | detectMappablePair(tv->getRFactorDomain(), id_graph); |
242 | if (self_mappped_rf_pair.has_value()) { |
243 | return std::make_tuple( |
244 | tv, |
245 | self_mappped_rf_pair->first, |
246 | self_mappped_rf_pair->second, |
247 | "RFactor" ); |
248 | } |
249 | } |
250 | |
251 | // Leaf domains |
252 | auto self_mappped_leaf_pair = |
253 | detectMappablePair(tv->domain()->domain(), id_graph); |
254 | if (self_mappped_leaf_pair.has_value()) { |
255 | return std::make_tuple( |
256 | tv, |
257 | self_mappped_leaf_pair->first, |
258 | self_mappped_leaf_pair->second, |
259 | "Leaf" ); |
260 | } |
261 | } |
262 | return c10::nullopt; |
263 | } |
264 | |
265 | } // namespace |
266 | |
// Builds all of the disjoint-set maps for the fusion:
//  1) Initialize a set entry for every IterDomain reachable in each tensor's
//     transformation history.
//  2) For every tensor op, map sibling outputs to each other, then map
//     producer/consumer ids through pairwise root maps and best-effort
//     replays (filling the exact, permissive, and loop maps).
//  3) Explicitly propagate mappings through rfactor (view) transformations,
//     both forward and backward.
// Finally records the first detected self mapping (if any) so the
// constructor can validate against it.
void IterDomainGraph::build(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Initialize a node for every iteration domain
  for (auto tv : ir_utils::allTvs(fusion)) {
    const auto& root_domain = tv->getRootDomain();
    const auto& domain = tv->domain()->domain();

    // Grab all values in the history of the tensor view's domain
    auto all_vals = DependencyCheck::getAllValsBetween(
        {root_domain.begin(), root_domain.end()},
        {domain.begin(), domain.end()});

    // Filter so we only have iteration domains (ignore Ints used in split)
    auto all_ids = ir_utils::filterByType<IterDomain>(all_vals);

    // Check is this domain is a consumer of a view-like operation
    bool view_like_domain = tv->domain()->hasViewLikeRFactor();

    for (auto id : all_ids) {
      // Check if this id is a view like rfactor id
      bool is_view_rfactor_id = false;
      if (view_like_domain && id->isRFactorProduct()) {
        // If the tensor domain is a view like domain, and the iteration domain
        // is marked as an rfactor product and is in the rfactor domain, it's a
        // view like rfactor iteration domain
        const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain();
        if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) !=
            rfactor_domain.end()) {
          is_view_rfactor_id = true;
        }
      }
      // Leaf ids are the ones present in the tensor's current domain.
      bool is_leaf_id =
          std::find(domain.begin(), domain.end(), id) != domain.end();
      initializeId(id, is_view_rfactor_id, is_leaf_id);
    }
  }

  // All ID's are initialized, start connecting them on the permissive, exact,
  // and loop dimensions.

  for (auto expr : fusion->exprs()) {
    // Only tensor ops contribute mappings.
    if (!ir_utils::isTvOp(expr)) {
      continue;
    }

    auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
    TensorView* first_output_tv = nullptr;

    for (auto c_tv : tv_outputs) {
      if (first_output_tv == nullptr) {
        first_output_tv = c_tv;
      } else {
        // Map multi outputs of an expression to each other. c is current
        // output, and f as first output. Keep consistent with the later section
        // of producer and consumers. Which here producer is now "first output",
        // and consumer is still consumer. One exception is how the
        // domains left of CA positions are handled in the Parallel
        // map. Those domains are not mapped in producer and consumer
        // mappings as they do not share loops, but are mapped in the
        // case of mapping multiple outputs since they do share the
        // same loops.

        TORCH_INTERNAL_ASSERT(
            c_tv->getRootDomain().size() ==
                first_output_tv->getRootDomain().size(),
            "Multiple outputs with mismatched dimensions is not supported. ",
            "Only supported case is welford op where all outputs tvs have identical domains.");
        // p->f, c->c
        std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
        for (const auto i :
             c10::irange(first_output_tv->getRootDomain().size())) {
          c2f_root_map.insert(std::make_pair(
              c_tv->getRootDomain()[i], first_output_tv->getRootDomain()[i]));
        }

        // Multi output mapping, outputs are required to have the same domain
        // and same transformations, so they can be mapped in permissive/exact,
        // and when within compute at position of domain()->domain() in the
        // parallel map.
        auto replay_FasC = BestEffortReplay(
            first_output_tv->domain()->domain(),
            c_tv->domain()->domain(),
            c2f_root_map);

        auto c2f_map = replay_FasC.getReplay();

        // Map the entire replay map between the multiple
        // consumers even for the Parallel map as they share the same
        // loop.
        for (auto entry : c2f_map) {
          auto c_id = entry.first;
          auto f_id = entry.second;
          // Map the id's together
          permissive_nodes_.mapEntries(f_id, c_id);
          exact_nodes_.mapEntries(f_id, c_id);
          if (idIsALeafDomain(f_id, first_output_tv)) {
            loop_nodes_.mapEntries(f_id, c_id);
          }
          sibling_sets_.mapEntries(f_id, c_id);
        }
      }

      // Now map this consumer against every tensor input of the expression.
      auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());

      for (auto p_tv : tv_inputs) {
        // If outside computeAt axis, we don't want to directly map
        // consumer/producer as their thread mappings could change as long as
        // it's across shared/global memory.
        auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
        const auto& permissive_c2p_root_map =
            pairwise_map.mapConsumerToProducer(c_tv->domain(), p_tv->domain());

        // Look for matching ID transformations in producer and consumer, replay
        // producer as consumer. We want to replay producer as consumer instead
        // of the other way around since consumer may have some broadcasted axes
        // producer doesn't have merged into loops producer may use. If we did
        // consumer as producer we wouldn't have this information in the
        // mapping. If we're using this map for indexing, we do not want to
        // propagate broadcast mismatches. If we're using it to identify loop
        // nests, we do want to propagate mismatches.
        auto permissive_replay_PasC =
            BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);

        const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
        const auto permissive_disjoint_sets =
            permissive_replay_PasC.getDisjointSets();

        // For exact mapings do not map any broadcast dimensions to
        // non-broadcast dimensions. Prevent any broadcasted axes being mapped
        // to non-broadcasted axes.
        auto exact_c2p_root_map =
            PairwiseRootDomainMap(p_tv, c_tv, true)
                .mapConsumerToProducer(c_tv->domain(), p_tv->domain());

        // Same as permissive above but for exact
        auto exact_replay_PasC = BestEffortReplay(
            p_tv->domain()->domain(),
            c_tv->domain()->domain(),
            exact_c2p_root_map);

        const auto& exact_c2p_map = exact_replay_PasC.getReplay();

        // Populate the exact map and producer/consumer relations.
        for (auto entry : exact_c2p_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          exact_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);

          // Add the swizzle inputs to the same
          // disjoint set as well if either c_id
          // or p_id is swizzle output.
          mapMaybeSwizzleOp(exact_nodes_, p_id);
          mapMaybeSwizzleOp(exact_nodes_, c_id);
        }

        // Populate the permissive map, and the loop map for ids within the
        // producer's compute at position.
        for (auto entry : permissive_c2p_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
            loop_nodes_.mapEntries(c_id, p_id);
          } else {
            // When there are trivial reductions merged with other dims, `p_id`
            // might not be a compute at leaf domain of `p_tv`, but it actually
            // has an equivalent compute at leaf domain. For that case, we map
            // the equivalent compute at leaf domain.
            for (unsigned int i = 0; i < p_tv->getComputeAtPosition(); i++) {
              auto id = p_tv->axis(i);
              if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
                loop_nodes_.mapEntries(c_id, id);
              }
            }
          }
          permissive_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);

          // Add the swizzle inputs to the same
          // disjoint set as well if either c_id
          // or p_id is swizzle output.
          mapMaybeSwizzleOp(permissive_nodes_, p_id);
          mapMaybeSwizzleOp(permissive_nodes_, c_id);
        }

        // Make sure we always get root mapping for the permissive map.
        // Because of forwarding we could otherwise miss some root mappings.
        for (auto entry : permissive_c2p_root_map) {
          auto c_id = entry.first;
          auto p_id = entry.second;
          // Map the id's together
          permissive_nodes_.mapEntries(c_id, p_id);
          consumers_.at(p_id).pushBack(c_id);
          producers_.at(c_id).pushBack(p_id);
        }
      }
    }
  }

  // Explicitly map through rfactor transformations, if we have an op like:
  //
  // T1[x, y*z] = view(T0[x*y, z])
  // T3[x, y*z] = view(T2[x*y, z])
  // T4 = T0 + T2
  //
  // We want to map T1 and T3's rfactor transformations together by playing the
  // transformations forward since their root domains map. If instead we have:
  //
  // T1[x, y*z] = view(T0[x*y, z])
  // T3[x, y*z] = view(T2[x*y, z])
  // T4 = T1 + T3
  //
  // Then we wouldn't have a mapping of T1 and T3's root domain, we'd have a
  // mapping of their rfactor domain, so we would want to map T1 and T3's
  // rfactor transformations starting at their rfactor domains.
  //
  // Therefore we'll explicitly map rfactor transformation iteration domains
  // forward and backwards. Something similar could happen with rfactor of root
  // domains, though it seems mapping rfactor reduction domains aren't that
  // important. Mapping view transformations is more important since view is
  // part of the compute definition so having the map through the
  // transformations makes it easy to check if different view operations are
  // consistent with eachother.

  auto all_tvs = ir_utils::allTvs(fusion);
  std::vector<TensorView*> all_consumer_tvs;
  std::copy_if(
      all_tvs.begin(),
      all_tvs.end(),
      std::back_inserter(all_consumer_tvs),
      [](TensorView* tv) { return !tv->isFusionInput() && tv->hasRFactor(); });

  // IterDomains could have multiple uses defined in the fusion if multiple
  // transformations were redefined (more than one transform propagation pass
  // was run and retransformed sections of the graph). We're going to make a new
  // uses map so we can easily process the actual uses of IterDomains. We
  // actually only need rfactor uses for this section of mapping, so we'll limit
  // this map to only rfactor transformations.
  std::unordered_map<IterDomain*, Expr*> rfactor_id_uses;

  // Order of traversal is important for processing all the rfactor ids as the
  // first pass will go forward through expressions and the second pass will
  // traverse backwards through them. ID's will be unique in this vector,
  // enforced when building it since it's built with rfactor_id_uses.
  std::vector<IterDomain*> rfactor_id_order;

  // Grab all the rfactor ids.
  for (auto consumer_tv : all_consumer_tvs) {
    auto exprs = StmtSort::getExprs(
        fusion,
        {consumer_tv->getMaybeRFactorDomain().begin(),
         consumer_tv->getMaybeRFactorDomain().end()});
    for (auto expr : exprs) {
      auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
      TORCH_INTERNAL_ASSERT(
          expr->isA<Split>() || expr->isA<Merge>(),
          "Wasn't expecting the expression type of:\n",
          expr->toString(),
          "\nto be an expression defined in an rfactor transformation.");
      for (auto rfactor_inp_id : rfactor_inp_ids) {
        TORCH_INTERNAL_ASSERT(
            rfactor_id_uses.find(rfactor_inp_id) == rfactor_id_uses.end(),
            "Was expecting iter domains to only have one active transformation but found id ",
            rfactor_inp_id->toString(),
            " used in\n",
            rfactor_id_uses.at(rfactor_inp_id),
            "\nand\n",
            expr->toString());
        rfactor_id_uses.emplace(std::make_pair(rfactor_inp_id, expr));
        rfactor_id_order.push_back(rfactor_inp_id);
      }
    }
    // Terminal rfactor ids (those in the rfactor domain itself) have no use
    // within the rfactor transformations; record them with a nullptr use.
    for (auto rfactor_id : consumer_tv->getMaybeRFactorDomain()) {
      if (rfactor_id->isRFactorProduct()) {
        rfactor_id_uses.emplace(std::make_pair(rfactor_id, nullptr));
        rfactor_id_order.push_back(rfactor_id);
      }
    }
  }

  // if prop_forward we're going forward through transformations and
  // expressions, meaning if inputs of expressions map then we map their
  // outputs, otherwise we're traversing backwards, meaning if outputs of
  // expressions map then we map their inputs.
  for (auto prop_forward : {true, false}) {
    std::unordered_set<Expr*> visited_exprs;

    // Forward pass walks the ids in recorded order; backward pass walks them
    // in reverse.
    for (auto rfactor_id_i : c10::irange(rfactor_id_order.size())) {
      auto first_rfactor_id = prop_forward
          ? rfactor_id_order[rfactor_id_i]
          : rfactor_id_order[rfactor_id_order.size() - 1 - rfactor_id_i];

      // At should be safe since we made rfactor_id_order and rfactor_id_uses at
      // the same time so they should have the same exact entries.
      auto first_expr = prop_forward ? rfactor_id_uses.at(first_rfactor_id)
                                     : first_rfactor_id->definition();

      if (first_expr == nullptr) {
        continue;
      }

      if (visited_exprs.find(first_expr) != visited_exprs.end()) {
        continue;
      }
      visited_exprs.emplace(first_expr);

      // Only need to be concerned here with mapping across rfactor iter
      // domains, so isolate out those.
      auto all_exact_map_ids = exact_nodes_.getDisjointSetOf(first_rfactor_id);
      std::vector<IterDomain*> exact_map_rf_ids;
      std::copy_if(
          all_exact_map_ids.vector().begin(),
          all_exact_map_ids.vector().end(),
          std::back_inserter(exact_map_rf_ids),
          [](IterDomain* id) { return id->isRFactorProduct(); });

      for (auto exact_map_rf_id : exact_map_rf_ids) {
        if (exact_map_rf_id == first_rfactor_id) {
          continue;
        }
        // If there's an input with an rfactor domain we could have an exact
        // mapped rfactor id that's on the input meaning it wouldn't have an
        // entry in rfactor_id_uses
        auto other_use =
            rfactor_id_uses.find(exact_map_rf_id) == rfactor_id_uses.end()
            ? nullptr
            : rfactor_id_uses.at(exact_map_rf_id);
        auto other_expr =
            prop_forward ? other_use : exact_map_rf_id->definition();

        if (other_expr == nullptr) {
          continue;
        }

        if (visited_exprs.find(other_expr) != visited_exprs.end()) {
          continue;
        }

        // Map the results of the two expressions together if they perform
        // the same transformation on exactly mapped ids.
        mapThroughExpr(first_expr, other_expr, prop_forward);
      }
    }
  }
  // Record the first self mapping (if any); the constructor validates
  // against it unless self mapping is explicitly allowed.
  self_mapping_info_ = findFirstSelfMapping(fusion, *this);
}
611 | |
612 | void IterDomainGraph::initializeId( |
613 | IterDomain* id, |
614 | bool is_view_rfactor_id, |
615 | bool is_leaf_id) { |
616 | permissive_nodes_.initializeSet(id); |
617 | exact_nodes_.initializeSet(id); |
618 | if (is_leaf_id) { |
619 | loop_nodes_.initializeSet(id); |
620 | } |
621 | consumers_[id] = {}; |
622 | producers_[id] = {}; |
623 | sibling_sets_.initializeSet(id); |
624 | |
625 | all_ids_.pushBack(id); |
626 | |
627 | if (is_view_rfactor_id) { |
628 | view_rfactor_ids_.emplace(id); |
629 | } |
630 | } |
631 | |
// Constructs the IterDomainGraph for the fusion, then runs the
// ComputeAtMap's own build step (trivial reduction info and concrete ids).
ComputeAtMap::ComputeAtMap(Fusion* fusion)
    : id_graph_(fusion), fusion_(fusion) {
  build(fusion);
}
636 | |
void ComputeAtMap::build(Fusion* fusion) {
  // Gather which iter domains come from trivial reductions first, then
  // compute the concrete id for each disjoint set.
  trivial_reduction_info_.build(fusion);
  buildConcreteIds();
}
641 | |
642 | void ComputeAtMap::validateAndPropagatePType() { |
643 | for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { |
644 | ParallelType common_ptype = ParallelType::Serial; |
645 | for (auto id : loop_disjoint_set->vector()) { |
646 | auto id_ptype = id->getParallelType(); |
647 | TORCH_INTERNAL_ASSERT( |
648 | id_ptype == common_ptype || id_ptype == ParallelType::Serial || |
649 | common_ptype == ParallelType::Serial, |
650 | "Issue validating parallel type disjoint ptype is, " , |
651 | common_ptype, |
652 | " but found in the set the id: " , |
653 | id->toString()); |
654 | common_ptype = |
655 | common_ptype == ParallelType::Serial ? id_ptype : common_ptype; |
656 | } |
657 | |
658 | for (auto id : loop_disjoint_set->vector()) { |
659 | id->parallelize(common_ptype); |
660 | } |
661 | } |
662 | } |
663 | |
// Allocates one index Val per loop disjoint set: the named parallel index
// for thread/grid parallelized sets, zero for all-broadcast sets, per-stage
// integers for double buffered loops, and a fresh integer otherwise.
void ComputeAtMap::allocateIndexVariables() {
  // Run through all disjoint sets registered in loop map,
  // all lowered kir::ForLoop will correspond to one of the disjoint sets
  // and we only need one index variable for each set.
  for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) {
    // Set by the any_of predicate below when a parallel id is found; only
    // read on the branch where the predicate returned true.
    ParallelType ptype;
    // first allocate thread and grid parallel indices:
    //  The validation pass will check that the parallel bindings within the
    //  loop nodes are consistent so all the loops within this disjoint set
    //  will be realized implicitly using parallel index variables.
    if (std::any_of(
            loop_disjoint_set->vector().begin(),
            loop_disjoint_set->vector().end(),
            [&ptype](IterDomain* id) {
              if (id->isThread() &&
                  // Halo extended parallel loops currently are handled
                  // differently and an index variable would still
                  // be allocated in this case.
                  (GpuLower::current()->haloInfo()->getExtent(id) == nullptr)) {
                ptype = id->getParallelType();
                return true;
              }
              return false;
            })) {
      loop_index_variable_map_[loop_disjoint_set.get()] =
          NamedScalar::getParallelIndex(ptype);
      continue;
    }

    // All loops in this set are non-parallel, non-concretized broadcast
    //  iterdomains, their "index variable" should be zero.
    if (std::all_of(
            loop_disjoint_set->vector().begin(),
            loop_disjoint_set->vector().end(),
            [](IterDomain* id) { return id->isBroadcast(); })) {
      loop_index_variable_map_[loop_disjoint_set.get()] = fusion_->zeroVal();
      continue;
    }

    // Allocate variable for the iterdomains:
    auto concrete_loop_id_it = concrete_id_cache_.find(loop_disjoint_set);
    TORCH_INTERNAL_ASSERT(
        concrete_loop_id_it != concrete_id_cache_.end(),
        "Concrete id not computed");

    auto concrete_loop_id = concrete_loop_id_it->second;

    // Need to allocate double buffered loop differently.
    if (GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(
            concrete_loop_id)) {
      // Allocate index variable for each stage of the double buffered loop.
      double_buffered_loop_index_variable_map_[loop_disjoint_set.get()] =
          std::make_unique<DoubleBufferIndices>(DoubleBufferIndices(
              {{DoubleBufferLoopStage::Prolog,
                IrBuilder::create<Int>(c10::nullopt)},
               {DoubleBufferLoopStage::Main,
                IrBuilder::create<Int>(c10::nullopt)},
               {DoubleBufferLoopStage::Epilog,
                IrBuilder::create<Int>(c10::nullopt)}}));
    } else {
      // Everything now should be serial concrete loops,
      //   we just allocate a loop index integer for each set of loops.
      loop_index_variable_map_[loop_disjoint_set.get()] =
          IrBuilder::create<Int>(c10::nullopt);
    }
  }
}
731 | |
732 | Val* ComputeAtMap::getIndexVariable( |
733 | IterDomain* id, |
734 | DoubleBufferLoopStage double_buffer_loop_stage) const { |
735 | TORCH_INTERNAL_ASSERT( |
736 | id_graph_.loopNodes().mappingExists(id), |
737 | "Index Variable: no index variable allocated as " , |
738 | id->toString(), |
739 | " is not registered in loop map" ); |
740 | const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id)); |
741 | |
742 | // Check if this loop was modified by double buffer pass. |
743 | bool is_double_buffer_iterdomain = |
744 | GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(id); |
745 | |
746 | if (is_double_buffer_iterdomain) { |
747 | // Use dedicated double buffer index variable if the loop is double buffer |
748 | // loop |
749 | if (double_buffer_loop_stage == DoubleBufferLoopStage::NotApplicable) { |
750 | // The double buffered loop stages are created after the loop nest |
751 | // lowering phase so this function will be querried before the double |
752 | // buffer pass. At that point, no forloop has any double buffer |
753 | // stage defined, and we just default to using the main stage index. |
754 | double_buffer_loop_stage = DoubleBufferLoopStage::Main; |
755 | } |
756 | return double_buffered_loop_index_variable_map_.at(loop_set)->at( |
757 | double_buffer_loop_stage); |
758 | } else { |
759 | return loop_index_variable_map_.at(loop_set); |
760 | } |
761 | } |
762 | |
// Returns true if id0 and id1 belong to the same disjoint set under the
// given mapping mode.
bool ComputeAtMap::areMapped(
    IterDomain* id0,
    IterDomain* id1,
    IdMappingMode mode) const {
  return disjointSetOf(id0, mode)->has(id1);
}
769 | |
770 | namespace { |
771 | |
772 | // Validate a LOOP concrete ID has the complete ID set required for |
773 | // indexing. See issue #1655 and FusionIncompleteConcreteID for an |
774 | // example fusion that fails with this validation. Fixing this issue |
775 | // would require creating a reference IterDomain with all the |
776 | // necessary root ID for for loop extent generation, for indexing, and for |
777 | // predication. |
778 | // |
779 | // root_ids_of_all_ids and root_ids_of_concrete_id consist of EXACT |
780 | // concrete IDs. |
// Asserts (with a detailed error message) when the chosen LOOP concrete ID
// does not account for every EXACT-concrete root ID found across the whole
// disjoint set. A missing root is tolerated in two cases handled below:
// the root is a broadcast/trivial-reduction that is permissively mapped to
// a real root of the concrete ID, or every candidate covering that root is
// exactly mapped with the concrete ID anyway.
void validateCompletenessOfLoopConcreteID(
    IterDomain* concrete_id,
    const ComputeAtMap& ca_map,
    const TrivialReductionInfo& trivial_reduction_info,
    // All root id's of all IDs in the disjoint id set
    const std::unordered_set<IterDomain*>& root_ids_of_all_ids,
    // Root id's of the concrete ID itself (EXACT-concrete representatives)
    const std::unordered_set<IterDomain*>& root_ids_of_concrete_id,
    // Map from a root id to the candidate concrete id's it's represented in
    const std::unordered_map<IterDomain*, std::vector<IterDomain*>>&
        root_id_to_maybe_concrete_ids,
    // Disjoint set just for printing
    const std::vector<IterDomain*>& id_set,
    // All the candidate concrete IDs found for this disjoint id set
    const std::vector<IterDomain*>& maybe_concrete_ids) {
  // Roots not covered by the concrete ID; non-empty triggers the assert below.
  std::vector<IterDomain*> root_ids_not_found_with_concrete_id;

  for (auto root_id : root_ids_of_all_ids) {
    // Directly covered: root_id is one of the concrete ID's own roots.
    if (root_ids_of_concrete_id.find(root_id) !=
        root_ids_of_concrete_id.end()) {
      continue;
    }

    // None of the root IDs of the concrete ID is exactly mapped with
    // root_id.

    // It is still a valid concrete ID if it has a non-broadcast
    // root ID that is mapped with root_id.
    if ((root_id->isBroadcast() || trivial_reduction_info.isDerived(root_id)) &&
        std::any_of(
            root_ids_of_concrete_id.begin(),
            root_ids_of_concrete_id.end(),
            [&](auto root_id_of_concrete_id) {
              return !root_id_of_concrete_id->isBroadcast() &&
                  !trivial_reduction_info.isDerived(root_id_of_concrete_id) &&
                  ca_map.areMapped(
                      root_id,
                      root_id_of_concrete_id,
                      IdMappingMode::PERMISSIVE);
            })) {
      continue;
    }

    // If all of the corresponding maybe-concrete IDs are exactly
    // mapped with the concrete ID, this missing root_id is not a
    // problem. This can happen with reduction rfactor, e.g.,
    // FusionAdvancedLowering1.
    // NOTE(review): `.at(root_id)` is looked up twice here; harmless but
    // could be hoisted.
    if (std::all_of(
            root_id_to_maybe_concrete_ids.at(root_id).begin(),
            root_id_to_maybe_concrete_ids.at(root_id).end(),
            [&](auto maybe_concrete_id) {
              return ca_map.areMapped(
                  concrete_id, maybe_concrete_id, IdMappingMode::EXACT);
            })) {
      continue;
    }

    root_ids_not_found_with_concrete_id.push_back(root_id);
  }

  if (root_ids_not_found_with_concrete_id.empty()) {
    return;
  }

  // Error detected as some root IDs are not accounted for by the
  // concrete ID. Build a descriptive message and fail.
  std::stringstream error_msg;
  error_msg << "IDs: " << ir_utils::toString(id_set);
  error_msg << ", concrete ID: " << concrete_id->toString();
  error_msg << ", maybe concrete IDs: "
            << ir_utils::toString(maybe_concrete_ids);
  error_msg << ", all root IDs:";
  for (auto root_id : root_ids_of_all_ids) {
    error_msg << " " << root_id->toString();
  }
  error_msg << ", root IDs not found with concrete ID: ";
  for (auto id : root_ids_not_found_with_concrete_id) {
    error_msg << " " << id->toString();
  }
  TORCH_INTERNAL_ASSERT(
      false, "Concrete ID failed to cover all root IDs. ", error_msg.str());
}
862 | |
863 | } // namespace |
864 | |
865 | IterDomain* ComputeAtMap::computeConcreteId( |
866 | IterDomain* id, |
867 | IdMappingMode mode) { |
868 | const auto& disjoint_set_shared_ptr = disjointSetOf(id, mode); |
869 | |
870 | TORCH_INTERNAL_ASSERT( |
871 | disjoint_set_shared_ptr->vector().size(), |
872 | "Empty disjoint set found for " , |
873 | id->toString()); |
874 | |
875 | if (disjoint_set_shared_ptr->vector().size() == 1) { |
876 | // If only one entry in the disjoint set, by definition the existing ID has |
877 | // to be the concrete ID. |
878 | return disjoint_set_shared_ptr->vector().front(); |
879 | } |
880 | |
881 | // Grab a set of candidate concrete_ids, we track towards the consumers in the |
882 | // ID group as one of those is guaranteed to be a valid concrete id. |
883 | VectorOfUniqueEntries<IterDomain*> maybe_concrete_ids; |
884 | for (auto id : disjoint_set_shared_ptr->vector()) { |
885 | bool id_output = true; |
886 | for (auto consumer_id : id_graph_.consumers().at(id).vector()) { |
887 | if (disjoint_set_shared_ptr->has(consumer_id)) { |
888 | id_output = false; |
889 | break; |
890 | } |
891 | } |
892 | if (id_output) { |
893 | maybe_concrete_ids.pushBack(id); |
894 | } |
895 | } |
896 | |
897 | // Shouldn't ever happen, it would mean there's an error somewhere in the |
898 | // graph. |
899 | TORCH_INTERNAL_ASSERT( |
900 | maybe_concrete_ids.vector().size(), |
901 | "No potential concrete_id's found for " , |
902 | id->toString()); |
903 | |
904 | if (maybe_concrete_ids.vector().size() == 1) { |
905 | return maybe_concrete_ids.vector().front(); |
906 | } |
907 | |
908 | // The concrete_id should have the most roots it can trace back to that are |
909 | // iter domains, (non-broadcast/non-reduction). We don't trace back through |
910 | // view operations, so the one with the most iter root domains is the concrete |
911 | // ID. |
912 | IterDomain* concrete_id = nullptr; |
913 | int max_iter_root_count = 0; |
914 | int max_bcast_root_count = 0; |
915 | |
916 | // For the LOOP map, the concrete ID must account for all root IDs |
917 | // of all of the IDs in each disjoit set. At least those ID's that are |
918 | // non-broadcast/non-reduction. As broadcast is only important here if it's |
919 | // concretized in the set. Track information so we can later make sure the |
920 | // concrete id has accounted for all iter domains meaning it has a correct |
921 | // loop size. |
922 | std::unordered_set<IterDomain*> root_ids_of_all_ids; |
923 | std::unordered_set<IterDomain*> root_ids_of_concrete_id; |
924 | std::unordered_map<IterDomain*, std::vector<IterDomain*>> |
925 | root_id_to_maybe_concrete_ids; |
926 | |
927 | // Populate the above information, look for the concrete id, validate the loop |
928 | // concrete ID. |
929 | for (auto maybe_concrete_id : maybe_concrete_ids.vector()) { |
930 | std::unordered_set<IterDomain*> root_ids; |
931 | std::deque<IterDomain*> to_visit; |
932 | |
933 | to_visit.push_back(maybe_concrete_id); |
934 | while (to_visit.size()) { |
935 | auto current_id = to_visit.front(); |
936 | to_visit.pop_front(); |
937 | if (isViewRfactor(current_id)) { |
938 | root_ids.emplace(current_id); |
939 | continue; |
940 | } |
941 | |
942 | // push back producer IterDomains or add root if they don't exist |
943 | auto producer_vals = ir_utils::producerValsOf(current_id); |
944 | auto producer_ids = ir_utils::filterByType<IterDomain>(producer_vals); |
945 | |
946 | if (producer_ids.empty()) { |
947 | root_ids.emplace(current_id); |
948 | } else { |
949 | to_visit.insert( |
950 | to_visit.end(), producer_ids.begin(), producer_ids.end()); |
951 | } |
952 | } |
953 | |
954 | if (mode == IdMappingMode::LOOP) { |
955 | std::transform( |
956 | root_ids.begin(), |
957 | root_ids.end(), |
958 | std::inserter(root_ids_of_all_ids, root_ids_of_all_ids.end()), |
959 | [&](const auto root_id) { |
960 | auto exact_concrete_id = |
961 | getConcreteMappedID(root_id, IdMappingMode::EXACT); |
962 | root_id_to_maybe_concrete_ids[exact_concrete_id].push_back( |
963 | maybe_concrete_id); |
964 | return exact_concrete_id; |
965 | }); |
966 | } |
967 | |
968 | int bcast_root_count = std::count_if( |
969 | root_ids.begin(), root_ids.end(), [&](IterDomain* root_id) { |
970 | return root_id->isBroadcast() |
971 | // TODO: This shouldn't have a negative impact, but (emperically) |
972 | // might not be necessary |
973 | || trivial_reduction_info_.isDerived(root_id); |
974 | }); |
975 | int iter_root_count = (int)root_ids.size() - bcast_root_count; |
976 | if (iter_root_count > max_iter_root_count || |
977 | (iter_root_count == max_iter_root_count && |
978 | bcast_root_count > max_bcast_root_count)) { |
979 | max_iter_root_count = iter_root_count; |
980 | max_bcast_root_count = bcast_root_count; |
981 | concrete_id = maybe_concrete_id; |
982 | |
983 | // If we update the concrete_id, then update the root_ids_of_concrete_id |
984 | // to reflect this id |
985 | if (mode == IdMappingMode::LOOP) { |
986 | root_ids_of_concrete_id.clear(); |
987 | std::transform( |
988 | root_ids.begin(), |
989 | root_ids.end(), |
990 | std::inserter( |
991 | root_ids_of_concrete_id, root_ids_of_concrete_id.end()), |
992 | [&](const auto root_id) { |
993 | return getConcreteMappedID(root_id, IdMappingMode::EXACT); |
994 | }); |
995 | } |
996 | } |
997 | } // end maybe_concrete_id |
998 | |
999 | TORCH_INTERNAL_ASSERT( |
1000 | concrete_id != nullptr, |
1001 | "Something went wrong, could not find a concrete id." ); |
1002 | |
1003 | if (mode == IdMappingMode::LOOP) { |
1004 | // Validate the concrete id has influence from all the roots of all the |
1005 | // consumers that will map to this concete id in the loop map. This means |
1006 | // all the consumers in all expressions of the loop nest generated based on |
1007 | // this concrete ID will have their roots mapping to this concrete ID |
1008 | // represented in the extent of this concrete id. |
1009 | validateCompletenessOfLoopConcreteID( |
1010 | concrete_id, |
1011 | *this, |
1012 | trivial_reduction_info_, |
1013 | root_ids_of_all_ids, |
1014 | root_ids_of_concrete_id, |
1015 | root_id_to_maybe_concrete_ids, |
1016 | disjoint_set_shared_ptr->vector(), |
1017 | maybe_concrete_ids.vector()); |
1018 | } |
1019 | |
1020 | return concrete_id; |
1021 | } |
1022 | |
1023 | void ComputeAtMap::buildConcreteIds() { |
1024 | for (const auto& disjoint_set_shared_ptr : |
1025 | id_graph_.permissiveNodes().disjointSets()) { |
1026 | TORCH_INTERNAL_ASSERT( |
1027 | disjoint_set_shared_ptr->vector().size(), |
1028 | "Cannot compute concrete id of empty set." ); |
1029 | auto first_id = disjoint_set_shared_ptr->vector().front(); |
1030 | auto concrete_id = computeConcreteId(first_id, IdMappingMode::PERMISSIVE); |
1031 | concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id; |
1032 | } |
1033 | |
1034 | for (const auto& disjoint_set_shared_ptr : |
1035 | id_graph_.exactNodes().disjointSets()) { |
1036 | TORCH_INTERNAL_ASSERT( |
1037 | disjoint_set_shared_ptr->vector().size(), |
1038 | "Cannot compute concrete id of empty set." ); |
1039 | auto first_id = disjoint_set_shared_ptr->vector().front(); |
1040 | auto concrete_id = computeConcreteId(first_id, IdMappingMode::EXACT); |
1041 | concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id; |
1042 | } |
1043 | |
1044 | for (const auto& disjoint_set_shared_ptr : |
1045 | id_graph_.loopNodes().disjointSets()) { |
1046 | TORCH_INTERNAL_ASSERT( |
1047 | disjoint_set_shared_ptr->vector().size(), |
1048 | "Cannot compute concrete id of empty set." ); |
1049 | auto first_id = disjoint_set_shared_ptr->vector().front(); |
1050 | auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP); |
1051 | concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id; |
1052 | } |
1053 | } |
1054 | |
1055 | IterDomain* ComputeAtMap::getConcreteMappedID( |
1056 | IterDomain* id, |
1057 | IdMappingMode mode) const { |
1058 | auto disjoint_set_shared_ptr = disjointSetOf(id, mode); |
1059 | |
1060 | TORCH_INTERNAL_ASSERT( |
1061 | disjoint_set_shared_ptr->vector().size() > 0, |
1062 | "Empty disjoint set found for " , |
1063 | id->toString()); |
1064 | |
1065 | auto cache_it = concrete_id_cache_.find(disjoint_set_shared_ptr); |
1066 | |
1067 | TORCH_INTERNAL_ASSERT( |
1068 | cache_it != concrete_id_cache_.end(), |
1069 | "Could not find concrete id for: " , |
1070 | id->toString(), |
1071 | " with mode " , |
1072 | mode); |
1073 | |
1074 | return cache_it->second; |
1075 | } |
1076 | |
1077 | namespace { |
1078 | |
1079 | std::string idGraphNodesToString( |
1080 | const ComputeAtMap& ca_map, |
1081 | IdMappingMode mode) { |
1082 | std::stringstream ss; |
1083 | const auto& disjoint_sets = ca_map.getIdSets(mode); |
1084 | for (const auto& s_ptr : disjoint_sets.disjointSets()) { |
1085 | const auto& set = *s_ptr; |
1086 | IterDomain* concrete_id = nullptr; |
1087 | if (!set.empty()) { |
1088 | auto id = set.front(); |
1089 | concrete_id = ca_map.getConcreteMappedID(id, mode); |
1090 | } |
1091 | ss << " {" ; |
1092 | for (auto entry : set.vector()) { |
1093 | ss << abstractToString(entry); |
1094 | if (entry == concrete_id) { |
1095 | ss << "*" ; |
1096 | } |
1097 | if (entry != set.back()) { |
1098 | ss << "; " ; |
1099 | } |
1100 | } |
1101 | ss << " }\n" ; |
1102 | } |
1103 | return ss.str(); |
1104 | } |
1105 | |
1106 | } // namespace |
1107 | |
1108 | std::string ComputeAtMap::toString() const { |
1109 | std::stringstream ss; |
1110 | ss << "Compute at map { \n" ; |
1111 | ss << "Permissive map:\n" |
1112 | << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE); |
1113 | ss << "Exact map:\n" << idGraphNodesToString(*this, IdMappingMode::EXACT); |
1114 | ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP); |
1115 | ss << "Consumer maps:\n" ; |
1116 | for (auto entry : id_graph_.consumers()) { |
1117 | ss << " " << entry.first->toString() << " :: " << entry.second.toString() |
1118 | << "\n" ; |
1119 | } |
1120 | |
1121 | ss << "Producer maps:\n" ; |
1122 | for (auto entry : id_graph_.producers()) { |
1123 | ss << " " << entry.first->toString() << " :: " << entry.second.toString() |
1124 | << "\n" ; |
1125 | } |
1126 | |
1127 | ss << "Sibling map:\n" << id_graph_.siblings().toString() << "\n" ; |
1128 | |
1129 | ss << "} compute at map" << std::endl; |
1130 | return ss.str(); |
1131 | } |
1132 | |
1133 | bool ComputeAtMap::isViewRfactor(IterDomain* ref_id) const { |
1134 | return id_graph_.viewRfactorIds().find(ref_id) != |
1135 | id_graph_.viewRfactorIds().end(); |
1136 | } |
1137 | |
1138 | std::vector<IterDomain*> ComputeAtMap::getViewRfactorDomainsOfIdGroup( |
1139 | IterDomain* ref_id, |
1140 | IdMappingMode mode) const { |
1141 | auto disjoint_set = disjointSetOf(ref_id, mode); |
1142 | std::vector<IterDomain*> rfactor_ids; |
1143 | for (auto disjoint_id : disjoint_set->vector()) { |
1144 | if (id_graph_.viewRfactorIds().find(disjoint_id) != |
1145 | id_graph_.viewRfactorIds().end()) { |
1146 | rfactor_ids.push_back(disjoint_id); |
1147 | } |
1148 | } |
1149 | return rfactor_ids; |
1150 | } |
1151 | |
1152 | const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& ComputeAtMap:: |
1153 | disjointSetOf(IterDomain* id, IdMappingMode mode) const { |
1154 | TORCH_INTERNAL_ASSERT( |
1155 | idExistsInMap(id), |
1156 | id->toString(), |
1157 | " has not been processed in this Compute At Map, yet the disjoint set for it was requested." ); |
1158 | return getIdSets(mode).disjointSetMap().at(id); |
1159 | } |
1160 | |
1161 | const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets( |
1162 | IdMappingMode mode) const { |
1163 | switch (mode) { |
1164 | case IdMappingMode::PERMISSIVE: |
1165 | return id_graph_.permissiveNodes(); |
1166 | case IdMappingMode::EXACT: |
1167 | return id_graph_.exactNodes(); |
1168 | case IdMappingMode::LOOP: |
1169 | return id_graph_.loopNodes(); |
1170 | } |
1171 | TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided." ); |
1172 | } |
1173 | |
1174 | bool ComputeAtMap::idExistsInMap(IterDomain* id) const { |
1175 | return getIdSets(IdMappingMode::EXACT).disjointSetMap().find(id) != |
1176 | getIdSets(IdMappingMode::EXACT).disjointSetMap().end(); |
1177 | } |
1178 | |
1179 | } // namespace cuda |
1180 | } // namespace fuser |
1181 | } // namespace jit |
1182 | } // namespace torch |
1183 | |