#include <transform_replay.h>

#include <arith.h>
#include <disjoint_set.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_builder.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <maxinfo_propagator.h>
#include <root_domain_map.h>
#include <transform_iter.h>

#include <deque>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

using id_map = std::unordered_map<IterDomain*, IterDomain*>;

namespace {

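// Replays the split/merge history recorded in a target domain onto a new set
// of root IterDomains. Unlike a producer/consumer replay, this is a "self"
// replay: every replayed output is rebuilt from scratch (zero start, fresh
// extents) so the new domain exactly mirrors the target's transformations.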
class ReplaySelf : public ReplayTransformations {
 private:
  // Took a good bit of this from ReplayTransformations::handle(Split...)
  void handle(Split* s) override {
    // Grab input to the split operation
    auto id_in = s->in();

    // Grab our mapping of that ID to the one we're replaying
    auto it = id_map_.find(id_in);

    // Make sure it exists in the map
    TORCH_INTERNAL_ASSERT(
        it != id_map_.end(),
        "Transform traversal failed, dependencies not met.");
    // Grab the ID we're going to replay on
    auto mapped = it->second;

    // This ID should be a leaf ID (meaning it has no uses we generated)
    TORCH_INTERNAL_ASSERT(
        leaf_ids_.find(mapped) != leaf_ids_.end(),
        "Transform traversal failed, modified a node but it was not a leaf node.");

    // Remaining extent after dividing by the split factor: this becomes the
    // outer extent for an inner split, or the inner extent for an outer split.
    Val* remainder = ceilDiv(
        Split::extent(mapped->extent(), s->startOffset(), s->stopOffset()),
        s->factor());
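    // For example, an inner split of an extent-12 domain by factor 4 yields
    // an outer extent of ceilDiv(12, 4) = 3 and an inner extent of 4.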

    // Manually replay the split, following the output of the operations.
    // This is so rfactor ops are replayed correctly.
    IterDomain* ido =
        IterDomainBuilder(s->outer())
            .start(s->container()->zeroVal())
            .extent(s->innerSplit() ? remainder->as<Int>() : s->factor())
            .build();

    // inner IterDomain
    IterDomain* idi =
        IterDomainBuilder(s->inner())
            .start(s->container()->zeroVal())
            .extent(s->innerSplit() ? s->factor() : remainder->as<Int>())
            .build();

    // Generate the split node
    IrBuilder::create<Split>(
        s->container(),
        ido,
        idi,
        mapped,
        s->factor(),
        s->innerSplit(),
        s->startOffset(),
        s->stopOffset());

    // Remove mapped id from leaf IDs
    leaf_ids_.erase(mapped);

    // Add outputs to leaf IDs
    leaf_ids_[ido] = counter++;
    leaf_ids_[idi] = counter++;

    // Update our ID map to include these outputs
    id_map_[s->outer()] = ido;
    id_map_[s->inner()] = idi;
  }

  void handle(Merge* m) override {
    auto id_outer = m->outer();
    auto id_inner = m->inner();

    auto it_outer = id_map_.find(id_outer);
    auto it_inner = id_map_.find(id_inner);

    TORCH_INTERNAL_ASSERT(
        it_outer != id_map_.end() && it_inner != id_map_.end(),
        "Transform traversal failed, dependencies not met.");

    auto id_outer_mapped = it_outer->second;
    auto id_inner_mapped = it_inner->second;

    TORCH_INTERNAL_ASSERT(
        leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
            leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
        "Transform traversal failed, modified ",
        id_outer_mapped,
        " and ",
        id_inner_mapped,
        " however one or both are not leaf nodes.");

    Val* merged_id_size =
        mul(id_outer_mapped->extent(), id_inner_mapped->extent());
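    // The merged extent is the product of the two inputs, e.g. merging an
    // extent-3 outer domain with an extent-4 inner domain yields extent 12.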

    IterDomain* merged_id = IterDomainBuilder(m->out())
                                .start(m->container()->zeroVal())
                                .extent(merged_id_size->as<Int>())
                                .build();

    IrBuilder::create<Merge>(
        m->container(), merged_id, id_outer_mapped, id_inner_mapped);

    // Remove inputs from the leaf IDs
    leaf_ids_.erase(id_outer_mapped);
    leaf_ids_.erase(id_inner_mapped);

    // Add the output to the leaf IDs
    leaf_ids_[merged_id] = counter++;

    id_map_[m->out()] = merged_id;
  }

 public:
  ReplaySelf(const std::vector<IterDomain*>& _target_domain, id_map _id_map)
      : ReplayTransformations(_target_domain, std::move(_id_map), false) {}
};

} // namespace

// Self replay: replay all of `self`'s transformations onto `new_self_root`.
TensorDomain* TransformReplay::fullSelfReplay(
    const TensorDomain* new_self_root,
    const TensorDomain* self) {
  FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay");

  TORCH_INTERNAL_ASSERT(
      new_self_root->getRootDomain().size() == self->getRootDomain().size(),
      "Invalid number of IterDomains provided.");

  // Map for replay, should be pretty simple.
  id_map axis_map;
  {
    size_t i = 0;
    for (auto id : self->getRootDomain()) {
      TORCH_INTERNAL_ASSERT(
          new_self_root->getRootDomain()[i]->isReduction() ==
                  id->isReduction() &&
              new_self_root->getRootDomain()[i]->isRFactorProduct() ==
                  id->isRFactorProduct() &&
              new_self_root->getRootDomain()[i]->isBroadcast() ==
                  id->isBroadcast(),
          "Axes ",
          id,
          " and ",
          new_self_root->getRootDomain()[i],
          " do not match for self replay.");
      axis_map[id] = new_self_root->getRootDomain()[i];
      i++;
    }
  }

  // Replay the transformations onto the new root domain.
  ReplaySelf replay(self->domain(), axis_map);
  std::vector<IterDomain*> new_domain(self->nDims(), nullptr);

  {
    size_t i = 0;
    for (auto id : self->domain()) {
      auto it = replay.getReplay().find(id);
      TORCH_INTERNAL_ASSERT(
          it != replay.getReplay().end(),
          "Error during replay, didn't replay an axis.");
      new_domain[i++] = it->second;
    }

    if (self->hasRFactor()) {
      std::vector<IterDomain*> new_rfactor_domain(
          self->getMaybeRFactorDomain().size(), nullptr);
      size_t i = 0;
      for (auto id : self->getMaybeRFactorDomain()) {
        auto it = replay.getReplay().find(id);
        TORCH_INTERNAL_ASSERT(
            it != replay.getReplay().end(),
            "Error during replay, didn't replay an axis.");
        new_rfactor_domain[i++] = it->second;
      }
      return IrBuilder::create<TensorDomain>(
          self->container(),
          new_self_root->getRootDomain(),
          new_rfactor_domain,
          new_domain,
          self->contiguity());
    }
  }

  return IrBuilder::create<TensorDomain>(
      self->container(),
      new_self_root->getRootDomain(),
      new_domain,
      new_self_root->contiguity());
}
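
// A typical use of fullSelfReplay is sibling propagation (see
// TransformPropagator::propagateSibling below), roughly:
//   to->setDomain(
//       TransformReplay::fullSelfReplay(to->domain(), from->domain()));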

// Producer could have rfactor axes which the consumer may want replayed. We
// can "replay" them as long as it doesn't modify the root rfactor axes. What
// we really want to do is validate that, if we replayed these axes to the
// ones they map to in the consumer, the operations would all be the same.
// Then we want to start the replay of the producer from the rfactor root
// axes, not the root.
std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_compute_at_axis,
    const RootDomainMap& root_map,
    bool replay_swizzle) {
  FUSER_PERF_SCOPE("TransformReplay::replayPasC");

  // If this is a reduction operation, we may call transform_replay on the
  // same tensor view. When this happens, just return the target view.
  if (producer == consumer)
    return {producer->domain(), producer->nDims()};

  if (consumer_compute_at_axis < 0)
    consumer_compute_at_axis += (int)consumer->nDims() + 1;
  TORCH_INTERNAL_ASSERT(
      consumer_compute_at_axis >= 0 &&
          (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
      "Invalid axis in transform replayPasC.");

  // consumer ids we need to match in producer
  std::vector<IterDomain*> consumer_CA_ids(
      consumer->domain()->domain().begin(),
      consumer->domain()->domain().begin() + consumer_compute_at_axis);
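  // For example, if the consumer's leaf domain is [i0, i1, i2] and
  // consumer_compute_at_axis is 2, consumer_CA_ids is [i0, i1]: only these
  // axes need matching counterparts in the producer.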

  // Instead of replaying from the root, let's try to play forward the history
  // of the producer where it matches ops on the consumer. Enforce that if we
  // modify an rfactor axis, those ops must match.
  auto forward_replay = BestEffortReplay::replayPasC(
      producer, consumer, consumer_compute_at_axis, root_map);

  // Make a new map based on all the leaves resulting from best effort replay
  id_map forwarded_replay_map;
  auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
  for (auto entry : forward_replay.getReplay()) {
    if (forward_dangling_leaves.find(entry.second) !=
        forward_dangling_leaves.end()) {
      forwarded_replay_map[entry.first] = entry.second;
      forward_dangling_leaves.erase(entry.second);
    }
  }
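  // Whatever remains in forward_dangling_leaves has no consumer counterpart;
  // those IDs are mapped to themselves further below so they aren't lost.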

  // Replay producer dimensions.
  ReplayTransformations replay_PasC(
      consumer_CA_ids, forwarded_replay_map, false, replay_swizzle);

  auto leaf_ids(replay_PasC.getUnorderedLeafIDs());

  // Remove all ids that map to the compute at axes; we're going to replay the
  // rest. Track all dims needed to match consumer CA dims.
  std::vector<IterDomain*> needed_dims;
  for (auto c_id : consumer_CA_ids) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it == replay_PasC.getReplay().end()) {
      TORCH_INTERNAL_ASSERT(
          c_id->isBroadcast() || c_id->isGather() || c_id->isVectorComponent(),
          "Could not find axis, ",
          c_id,
          ", requested in replay.");
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        leaf_ids.find(it->second) != leaf_ids.end(),
        "Replayed id to match consumer id ",
        c_id,
        " should be a leaf in replay map.");
    leaf_ids.erase(it->second);
    needed_dims.push_back(it->second);
  }

  // leaf_ids now contains all producer ID products that are not used to
  // satisfy the computeAt. Turn it into a map so we can play these IDs
  // forward in the producer (if possible):
  id_map producer_self_replay_map;
  for (auto entry : leaf_ids) {
    producer_self_replay_map[entry.first] = entry.first;
  }

  for (auto entry : forward_dangling_leaves) {
    producer_self_replay_map[entry.first] = entry.first;
  }

  // Check which root domains were used to produce the leaf_ids. We may have
  // picked up extra roots in consumer because of broadcast forwarding.
  std::vector<Val*> unordered_non_root_leaf_vals;
  for (auto leaf_id : replay_PasC.getUnorderedLeafIDs()) {
    if (leaf_id.first->definition() == nullptr) {
      continue;
    } else {
      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
    }
  }

  auto producer_root = producer->getMaybeRFactorDomain();

  // Figure out all id's that have been processed to generate the
  // unordered_non_root_leaf_vals. This needs to be done because we want to
  // match on producer's rfactor domain, not root domain.
  std::unordered_set<IterDomain*> all_processed_ids;
  {
    auto all_processed_vals_vec = DependencyCheck::getAllValsBetween(
        {producer_root.begin(), producer_root.end()},
        unordered_non_root_leaf_vals);
    auto all_processed_ids_vec =
        ir_utils::filterByType<IterDomain>(all_processed_vals_vec);
    all_processed_ids.insert(
        all_processed_ids_vec.begin(), all_processed_ids_vec.end());
  }

  // Any root domain that was not used to generate the computeAt IDs can also
  // be put in the map, so its transformations are forwarded.
  for (auto producer_root_id : producer_root) {
    if (all_processed_ids.find(producer_root_id) == all_processed_ids.end() &&
        std::find(needed_dims.begin(), needed_dims.end(), producer_root_id) ==
            needed_dims.end()) {
      producer_self_replay_map[producer_root_id] = producer_root_id;
    }
  }

  // Play forward the transformations of all producer IDs we can
  auto producer_replayed_leaves = BestEffortReplay(
      producer->domain()->domain(),
      producer->domain()->domain(),
      producer_self_replay_map);

  /*
   * Accumulate axes into the new domain in the following order, making sure
   * to avoid any duplicates:
   *
   * (1) replay_PasC.getReplay holds mappings from axes in consumer compute at
   * axes -> corresponding generated axes in producer
   *
   * (2) Any axes that were not added, that can be mapped directly from an ID
   * in consumer->domain(). These are axes that were "fully replayed" relative
   * to the consumer, even though they weren't in the computeAt range.
   *
   * producer_replayed_leaves now contains ids that we tried to forward
   * back to what they were in producer. If they couldn't be forwarded they're
   * left in their "most forwarded" form, which may be just a remainder of the
   * transformation required to generate the computeAt axes.
   *
   * (3) Axes in producer->domain() that are in producer_replayed_leaves
   *
   * (4) Axes not in producer->domain() that are in producer_replayed_leaves
   *
   */

  std::vector<IterDomain*> new_IDs;
  std::unordered_set<IterDomain*> used_IDs;
  // Add axes in (1)
  for (auto c_id : consumer_CA_ids) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it == replay_PasC.getReplay().end()) {
      TORCH_INTERNAL_ASSERT(
          c_id->isBroadcast() || c_id->isGather() || c_id->isVectorComponent(),
          "Could not find axis, ",
          c_id,
          ", requested in replay.");
      continue;
    }
    new_IDs.push_back(it->second);
    used_IDs.emplace(it->second);
  }

  unsigned int producer_compute_at_axis = new_IDs.size();

  // Add axes in (2)
  for (auto c_id : consumer->domain()->domain()) {
    auto it = replay_PasC.getReplay().find(c_id);
    if (it != replay_PasC.getReplay().end()) {
      auto id = it->second;
      // If the leaf id from ReplayTransformations is used to move
      // forward in BestEffortReplay, it is not a final ID.
      if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) ==
          producer_replayed_leaves.getUnorderedLeafIDs().end()) {
        continue;
      }
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (3)
  for (auto id : producer->domain()->domain()) {
    if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
        producer_replayed_leaves.getUnorderedLeafIDs().end()) {
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (4)
  for (auto id : producer_replayed_leaves.getLeafIDs()) {
    if (used_IDs.find(id) == used_IDs.end()) {
      new_IDs.push_back(id);
    }
  }
  TensorDomain* replayed = IrBuilder::create<TensorDomain>(
      producer->container(),
      producer->getRootDomain(),
      producer->getRFactorDomain(),
      new_IDs,
      producer->domain()->contiguity());

  return {replayed, producer_compute_at_axis};
}
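
// Illustrative sketch (hypothetical shapes): if the consumer's leaf domain is
// [I0/4, 4, I1] and we call replayPasC(producer, consumer, 2, root_map), the
// producer's domain is rebuilt so that its first two leaf axes match
// [I0/4, 4], and the returned compute-at position is 2.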

std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_compute_at_axis,
    const RootDomainMap& root_map) {
  FUSER_PERF_SCOPE("TransformReplay::replayCasP");

  // If this is a reduction operation, we may call transform_replay on the
  // same tensor view. When this happens, just return the target view.
  if (consumer == producer)
    return {consumer->domain(), consumer->nDims()};

  if (producer_compute_at_axis < 0)
    producer_compute_at_axis += (int)producer->nDims() + 1;

  TORCH_INTERNAL_ASSERT(
      producer_compute_at_axis >= 0 &&
          (unsigned int)producer_compute_at_axis <= producer->nDims(),
      "Invalid axis in transform replayCasP.");

  // producer ids we need to match in consumer
  std::vector<IterDomain*> producer_CA_ids(
      producer->domain()->domain().begin(),
      producer->domain()->domain().begin() + producer_compute_at_axis);
  producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
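  // Reduction axes have no counterpart in the consumer, so they're dropped
  // here; e.g. a producer leaf domain [I0, R1, I2] with compute-at position 3
  // gives producer_CA_ids [I0, I2].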

  // Instead of replaying from the root, let's try to forward the history of
  // the consumer where it matches ops on the producer. Enforce that if we
  // modify an rfactor axis, those ops must match.
  BestEffortReplay forward_replay = BestEffortReplay::replayCasP(
      consumer, producer, producer_compute_at_axis, root_map);

  // Track dangling leaves, which can be produced in
  // BestEffortReplay::replayCasP. These don't have any equivalent in the
  // producer, so they're not in the map. We simply map them to themselves so
  // we don't lose them.
  id_map forwarded_replay_map;
  auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
  for (auto entry : forward_replay.getReplay()) {
    if (forward_dangling_leaves.find(entry.second) !=
        forward_dangling_leaves.end()) {
      forwarded_replay_map[entry.first] = entry.second;
      forward_dangling_leaves.erase(entry.second);
    }
  }

  // Replay consumer dimensions.
  ReplayTransformations replay_CasP(
      producer_CA_ids, forwarded_replay_map, false);

  auto leaf_ids(replay_CasP.getUnorderedLeafIDs());

  // Remove all ids that map to the compute at axes; we're going to replay the
  // rest. Track all dims that are needed to match producer CA dims.
  std::vector<IterDomain*> needed_dims;
  for (auto p_id : producer_CA_ids) {
    auto it = replay_CasP.getReplay().find(p_id);
    TORCH_INTERNAL_ASSERT(
        it != replay_CasP.getReplay().end(),
        "Could not find axis, ",
        p_id,
        ", requested in replay.");
    TORCH_INTERNAL_ASSERT(
        leaf_ids.find(it->second) != leaf_ids.end(),
        "Replayed id to match producer id ",
        p_id,
        " should be a leaf in replay map.");
    leaf_ids.erase(it->second);
    needed_dims.push_back(it->second);
  }

  // leaf_ids now contains all consumer ID products that are not used to
  // satisfy the computeAt. Turn it into a map so we can play these IDs
  // forward in the consumer (if possible):
  id_map consumer_self_replay_map;
  for (auto entry : leaf_ids) {
    consumer_self_replay_map[entry.first] = entry.first;
  }

  for (auto entry : forward_dangling_leaves) {
    consumer_self_replay_map[entry.first] = entry.first;
  }

  // Check which root domains were used to produce the leaf_ids. We may have
  // picked up extra roots in consumer because of broadcast forwarding.
  std::vector<Val*> unordered_non_root_leaf_vals;
  for (auto leaf_id : replay_CasP.getUnorderedLeafIDs()) {
    if (leaf_id.first->definition() == nullptr) {
      continue;
    } else {
      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
    }
  }

  auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);

  std::vector<IterDomain*> consumer_root = consumer->getRootDomain();

  // Any root domain that was not used to generate the computeAt IDs can also
  // be put in the map, so its transformations are forwarded.
  for (auto consumer_root_id : consumer_root) {
    if (std::find(
            processed_roots.begin(),
            processed_roots.end(),
            consumer_root_id) == processed_roots.end() &&
        // Don't re-add roots that may have directly mapped in the replay
        std::find(needed_dims.begin(), needed_dims.end(), consumer_root_id) ==
            needed_dims.end()) {
      consumer_self_replay_map[consumer_root_id] = consumer_root_id;
    }
  }

  // Play forward the transformations of all consumer IDs we can
  auto consumer_replayed_leaves = BestEffortReplay(
      consumer->domain()->domain(),
      consumer->domain()->domain(),
      consumer_self_replay_map);

  /*
   * Accumulate axes into the new domain in the following order, making sure
   * to avoid any duplicates:
   *
   * (1) replay_CasP.getReplay holds mappings from axes in producer compute at
   * axes -> corresponding generated axes in consumer
   *
   * (2) Any axes that were not added, that can be mapped directly from an ID
   * in producer->domain(). These are axes that were "fully replayed" relative
   * to the producer, even though they weren't in the computeAt range.
   *
   * consumer_replayed_leaves now contains ids that we tried to forward
   * back to what they were in consumer. If they couldn't be forwarded they're
   * left in their "most forwarded" form, which may be just a remainder of the
   * transformation required to generate the computeAt axes.
   *
   * (3) Axes in consumer->domain() that are in consumer_replayed_leaves
   *
   * (4) Axes not in consumer->domain() that are in consumer_replayed_leaves
   *
   * TODO: Should (2) and (3) be swapped?
   */

  std::vector<IterDomain*> new_IDs;
  std::unordered_set<IterDomain*> used_IDs;
  // Add axes in (1)
  for (auto p_id : producer_CA_ids) {
    auto it = replay_CasP.getReplay().find(p_id);
    TORCH_INTERNAL_ASSERT(
        it != replay_CasP.getReplay().end(),
        "Could not find axis, ",
        p_id,
        ", requested in replay.");
    new_IDs.push_back(it->second);
    used_IDs.emplace(it->second);
  }

  // Add axes in (2)
  for (auto p_id : producer->domain()->domain()) {
    auto it = replay_CasP.getReplay().find(p_id);
    if (it != replay_CasP.getReplay().end()) {
      auto id = it->second;
      // If the leaf id from ReplayTransformations is used to move
      // forward in BestEffortReplay, it is not a final ID.
      if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) ==
          consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
        continue;
      }
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (3)
  for (auto id : consumer->domain()->domain()) {
    if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
        consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
      if (used_IDs.find(id) == used_IDs.end()) {
        new_IDs.push_back(id);
        used_IDs.emplace(id);
      }
    }
  }

  // Add axes in (4)
  for (auto id : consumer_replayed_leaves.getLeafIDs())
    if (used_IDs.find(id) == used_IDs.end())
      new_IDs.push_back(id);

  TensorDomain* replayed = IrBuilder::create<TensorDomain>(
      consumer->container(),
      consumer->getRootDomain(),
      consumer->getRFactorDomain(),
      new_IDs,
      consumer->domain()->contiguity());

  return {replayed, producer_CA_ids.size()};
}

// replay Producer as Consumer
std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int compute_at_axis,
    bool replay_swizzle) {
  // Use the pairwise root map as a default mapper
  PairwiseRootDomainMap root_map(producer, consumer);
  return replayPasC(
      producer, consumer, compute_at_axis, root_map, replay_swizzle);
}

std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int compute_at_axis) {
  // Use the pairwise root map as a default mapper
  PairwiseRootDomainMap root_map(producer, consumer);
  return replayCasP(consumer, producer, compute_at_axis, root_map);
}

// In a PasC replay, we want the producer to exactly match the consumer:
// all the beginning axes in the producer should be mapped to the consumer in
// the same order. Reductions in the producer need to be at the back of the
// producer.
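//
// For example (hypothetical domains): with producer leaf domain [I0, I1, R2]
// and consumer leaf domain [I0, I1], a consumer_pos of 2 matches producer
// position 2, so no replay is needed; the trailing reduction never blocks the
// match because it sits behind all matched axes.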
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
    const TensorView* producer,
    const TensorView* consumer,
    int consumer_pos) {
  FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC");

  const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  id_map c2p_root_map = pairwise_map.mapConsumerToProducer(
      consumer->domain(), producer->domain());

  // Consumer leaf IDs that depend on a consumer root ID mapped to the
  // producer; these must match and cannot be skipped.
  const auto consumer_domain = consumer->domain()->domain();

  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : c2p_root_map) {
    mapped_consumer_roots.emplace(entry.first);
  }

  auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
      mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

  std::unordered_set<Val*> unskippable_consumer_ids(
      unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

  // Walk the producer leaf domain alongside the consumer's.
  const auto producer_domain = producer->domain()->domain();

  auto it_consumer = consumer_domain.begin();
  auto it_producer = producer_domain.begin();

  auto disjoint_sets =
      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
          .getDisjointSets();

  int mismatched_consumer_pos = 0;
  int mismatched_producer_pos = 0;
  while (it_consumer != consumer_domain.end()) {
    if (consumer_pos == mismatched_consumer_pos) {
      return mismatched_producer_pos;
    }

    auto consumer_id = *it_consumer;
    if (unskippable_consumer_ids.count(consumer_id) == 0) {
      ++it_consumer;
      ++mismatched_consumer_pos;
      continue;
    }

    if (it_producer == producer_domain.end()) {
      return -1;
    }

    auto producer_id = *it_producer;
    if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
      ++mismatched_consumer_pos;
      ++mismatched_producer_pos;
      ++it_consumer;
      ++it_producer;
    } else {
      return -1;
    }
  }
  if (consumer_pos == mismatched_consumer_pos) {
    return mismatched_producer_pos;
  }
  return -1;
}

// We want to ignore reductions in the producer in a CasP replay.
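// For example (hypothetical domains): with producer leaf domain [I0, R1, I2]
// and consumer leaf domain [I0, I2], a producer_pos of 3 skips R1 and matches
// consumer position 2.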
int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
    const TensorView* consumer,
    const TensorView* producer,
    int producer_pos) {
  FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP");

  const auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  id_map p2c_root_map = pairwise_map.mapProducerToConsumer(
      producer->domain(), consumer->domain());

  // Producer leaf IDs that are not reductions; reduction IDs can be skipped.
  const auto producer_domain = producer->domain()->domain();
  auto unskippable_producer_ids_vec =
      TensorDomain::noReductions(producer_domain);
  std::unordered_set<IterDomain*> unskippable_producer_ids(
      unskippable_producer_ids_vec.begin(), unskippable_producer_ids_vec.end());

  // Consumer leaf IDs that depend on a consumer root ID mapped from the
  // producer; these must match and cannot be skipped.
  const auto consumer_domain = consumer->domain()->domain();

  std::unordered_set<Val*> mapped_consumer_roots;
  for (auto entry : p2c_root_map) {
    mapped_consumer_roots.emplace(entry.second);
  }

  auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween(
      mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});

  std::unordered_set<Val*> unskippable_consumer_ids(
      unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end());

  auto it_producer = producer_domain.begin();
  auto it_consumer = consumer_domain.begin();

  auto disjoint_sets =
      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
          .getDisjointSets();

  int mismatched_producer_pos = 0;
  int mismatched_consumer_pos = 0;
  while (it_producer != producer_domain.end()) {
    if (producer_pos == mismatched_producer_pos) {
      return mismatched_consumer_pos;
    }

    auto producer_id = *it_producer;
    if (unskippable_producer_ids.count(producer_id) == 0) {
      ++it_producer;
      ++mismatched_producer_pos;
      continue;
    }

    if (it_consumer == consumer_domain.end()) {
      return -1;
    }

    auto consumer_id = *it_consumer;
    if (unskippable_consumer_ids.count(consumer_id) == 0) {
      ++it_consumer;
      ++mismatched_consumer_pos;
      continue;
    }

    if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
      ++mismatched_producer_pos;
      ++mismatched_consumer_pos;
      ++it_producer;
      ++it_consumer;
    } else {
      return -1;
    }
  }
  if (producer_pos == mismatched_producer_pos) {
    return mismatched_consumer_pos;
  }
  return -1;
}

bool TransformReplay::fullSelfMatching(
    const TensorView* replay,
    const TensorView* target) {
  auto replay_root = replay->getRootDomain();
  auto replay_dom = replay->domain()->domain();
  auto target_root = target->getRootDomain();
  auto target_dom = target->domain()->domain();
  std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
  if (replay_root.size() != target_root.size()) {
    return false;
  }
  target2replay_map.reserve(replay_root.size());
  std::transform(
      target_root.begin(),
      target_root.end(),
      replay_root.begin(),
      std::inserter(target2replay_map, target2replay_map.begin()),
      [](auto a, auto b) { return std::make_pair(a, b); });
  BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
  auto r = replay_.getReplay();
  for (int64_t i = 0; i < (int64_t)replay_dom.size(); i++) {
    auto target_id = target_dom[i];
    auto replay_it = r.find(target_id);
    if (replay_it == r.end() || replay_it->second != replay_dom[i]) {
      return false;
    }
  }
  return true;
}

namespace {

// Make sure that, if tv's domain is set to new_td, it doesn't violate tv's
// compute at and max producer positions.
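// findFirstMismatchedID returns the first leaf position at which the two
// domains diverge, so the domains must agree at least up to both positions.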
bool validateDomain(TensorView* tv, TensorDomain* new_td) {
  auto first_mismatch =
      BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
  return first_mismatch >= (int)tv->getMaxProducerPosition() &&
      first_mismatch >= (int)tv->getComputeAtPosition();
}

} // namespace

void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // Note: [Using multiple TransformPropagators]
  // There are cases where we use multiple TransformPropagators along
  // different spanning trees with different references in the same fusion.
  // Some of these spanning trees could overlap. When there are overlapping
  // nodes, TransformPropagator needs to respect the replay of others, because
  // the current TransformPropagator might not have the most information on
  // how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when it is not necessary.
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "TransformPropagator::propagateC2P" << std::endl;
    std::cout << " from: " << from << " @ " << pos << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (new_pos < 0) {
    auto replay = TransformReplay::replayPasC(to, from, pos);
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay.first),
        "Tried to set the domain of ",
        to,
        " to ",
        replay.first,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay.first);
    new_pos = replay.second;
    if (debug) {
      std::cout << " replayed: " << to << " @ " << new_pos << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped. result position: " << new_pos << std::endl;
  }
  replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // See note [Using multiple TransformPropagators]
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "TransformPropagator::propagateP2C" << std::endl;
    std::cout << " from: " << from << " @ " << pos << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (new_pos < 0) {
    auto replay = TransformReplay::replayCasP(to, from, pos);
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay.first),
        "Tried to set the domain of ",
        to,
        " to ",
        replay.first,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay.first);
    new_pos = replay.second;
    if (debug) {
      std::cout << " replayed: " << to << " @ " << new_pos << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped. result position: " << new_pos << std::endl;
  }
  replayed_pos_[to] = new_pos;
}

void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) {
  int pos = replayed_pos_.at(from);
  // See note [Using multiple TransformPropagators]
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "TransformPropagator::propagateSibling" << std::endl;
    std::cout << " from: " << from << " @ " << pos << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (!TransformReplay::fullSelfMatching(to, from)) {
    auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay),
        "Tried to set the domain of ",
        to,
        " to ",
        replay,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay);
    if (debug) {
      std::cout << " replayed: " << to << " @ " << pos << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped. result position: " << pos << std::endl;
  }
  replayed_pos_[to] = pos;
}

TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) {
  if (pos < 0) {
    pos += int64_t(from->nDims()) + 1;
  }
  TORCH_CHECK(
      pos >= 0 && pos <= (int64_t)from->nDims(),
      "TransformPropagator called on a pos outside the valid range.");
  replayed_pos_[from] = pos;
}
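
// A minimal usage sketch (assuming the spanning-tree API from
// maxinfo_propagator.h): propagate `reference`'s transformations to every
// reachable TensorView in the fusion.
//
//   TransformPropagator propagator(reference, -1);
//   MaxRootDomainInfoSpanningTree(reference).traverse(&propagator);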

void MostInlinedTransformPropagator::propagateC2P(
    TensorView* from,
    TensorView* to) {
  int pos = from->nDims();
  // See note [Using multiple TransformPropagators]
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "MostInlinedTransformPropagator::propagateC2P" << std::endl;
    std::cout << " from: " << from << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (new_pos < 0) {
    auto replay = TransformReplay::replayPasC(to, from, pos);
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay.first),
        "Tried to set the domain of ",
        to,
        " to ",
        replay.first,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay.first);
    if (debug) {
      std::cout << " replayed: " << to << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped" << std::endl;
  }
}

void MostInlinedTransformPropagator::propagateP2C(
    TensorView* from,
    TensorView* to) {
  int pos = from->nDims();
  // See note [Using multiple TransformPropagators]
  int new_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "MostInlinedTransformPropagator::propagateP2C" << std::endl;
    std::cout << " from: " << from << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (new_pos < 0) {
    auto replay = TransformReplay::replayCasP(to, from, pos);
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay.first),
        "Tried to set the domain of ",
        to,
        " to ",
        replay.first,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay.first);
    if (debug) {
      std::cout << " replayed: " << to << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped" << std::endl;
  }
}

void MostInlinedTransformPropagator::propagateSibling(
    TensorView* from,
    TensorView* to) {
  // See note [Using multiple TransformPropagators]
  bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator);
  if (debug) {
    std::cout << "MostInlinedTransformPropagator::propagateSibling"
              << std::endl;
    std::cout << " from: " << from << std::endl;
    std::cout << " to: " << to << std::endl;
  }
  if (!TransformReplay::fullSelfMatching(to, from)) {
    auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
    TORCH_INTERNAL_ASSERT(
        validateDomain(to, replay),
        "Tried to set the domain of ",
        to,
        " to ",
        replay,
        " but that would invalidate previously compute at position or max producer position.");
    to->setDomain(replay);
    if (debug) {
      std::cout << " replayed: " << to << std::endl;
    }
  } else if (debug) {
    std::cout << " replay skipped" << std::endl;
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch