1#include <transform_iter.h>
2
3#include <c10/util/irange.h>
4#include <ir_utils.h>
5
6namespace torch {
7namespace jit {
8namespace fuser {
9namespace cuda {
10
11// Transform dispatch
12void ReplayTransformations::handle(Expr* e) {
13 switch (e->getExprType().value()) {
14 case (ExprType::Split):
15 case (ExprType::Merge):
16 case (ExprType::Swizzle2D):
17 break;
18 default:
19 TORCH_INTERNAL_ASSERT(
20 false, "Invalid expr type found in transform traversal.");
21 }
22 IterVisitor::handle(e);
23}
24
25// We're going to replay this split operation on the corresponding ID
26void ReplayTransformations::handle(Split* s) {
27 // Grab our input to the split node
28 auto id_in = s->in();
29
30 // Make sure we have a corresponding entry in our map pointing to the ID we're
31 // going to replay the split on
32 auto it = id_map_.find(id_in);
33 if (it == id_map_.end()) {
34 if (error_on_failure_) {
35 TORCH_INTERNAL_ASSERT(
36 false, "Transform traversal failed, dependencies not met.");
37 } else {
38 return;
39 }
40 }
41
42 auto mapped = (*it).second;
43 // Make sure this ID is a leaf ID (meaning it has no uses we generated)
44 TORCH_INTERNAL_ASSERT(
45 leaf_ids_.find(mapped) != leaf_ids_.end(),
46 "Transform traversal failed, modified a node but it was not a leaf node.");
47
48 // Replay the split onto mapped
49 auto outs = IterDomain::split(
50 mapped, s->factor(), s->innerSplit(), s->startOffset(), s->stopOffset());
51 // Remove mapped from the leaf IDs
52 leaf_ids_.erase(mapped);
53
54 // Add outputs to leaf IDs
55 leaf_ids_[outs.first] = counter++;
56 leaf_ids_[outs.second] = counter++;
57
58 // Update our ID map to include these outputs
59 id_map_[s->outer()] = outs.first;
60 id_map_[s->inner()] = outs.second;
61}
62
// We're going to replay this merge operation on the corresponding IDs.
//
// Failure policy: if a merge input has no mapping, the merge can still be
// replayed when the missing side is a broadcast dimension — in that case the
// output is simply remapped to the found (non-broadcast) side. Any other
// missing-input combination is a dependency failure.
void ReplayTransformations::handle(Merge* m) {
  // Grab the inputs to the merge node
  auto id_outer = m->outer();
  auto id_inner = m->inner();

  // Make sure we have a corresponding entry in our map pointing to the IDs
  // we're going to replay the merge on
  auto it_outer = id_map_.find(id_outer);
  auto it_inner = id_map_.find(id_inner);

  const bool outer_found = it_outer != id_map_.end();
  const bool outer_bcast = id_outer->isBroadcast();
  const bool inner_found = it_inner != id_map_.end();
  const bool inner_bcast = id_inner->isBroadcast();

  // If either are not found
  if (!outer_found || !inner_found) {
    // If both aren't found, it's a failure
    // If outer is found && inner is bcast it is not a failure
    // If inner is found && outer is bcast it is not a failure
    if (!(outer_found || inner_found) || (outer_found && !inner_bcast) ||
        (inner_found && !outer_bcast)) {
      if (error_on_failure_) {
        TORCH_INTERNAL_ASSERT(
            false, "Transform traversal failed, dependencies not met.");
      } else {
        // Silent best-effort mode: skip this merge entirely.
        return;
      }
    }
  }

  // If we merge a broadcast dim with a non-broadcast dim, just remap the output
  // to the non-broadcast dim. No replayed expression is created in this case.
  if (inner_found && !outer_found && outer_bcast) {
    id_map_[m->out()] = it_inner->second;
    return;
  }
  if (outer_found && !inner_found && inner_bcast) {
    id_map_[m->out()] = it_outer->second;
    return;
  }

  // Grab the IDs we're going to replay this merge on. Both iterators are
  // valid here: any not-found case either returned above or asserted.
  const auto id_outer_mapped = it_outer->second;
  const auto id_inner_mapped = it_inner->second;

  // Make sure these IDs are leaf IDs (meaning they have no uses we generated)
  TORCH_INTERNAL_ASSERT(
      leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
          leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
      "Transform traversal failed, tried to replay with ",
      id_outer_mapped,
      " and ",
      id_inner_mapped,
      " however one or both are not leaf nodes.");

  // Replay the merge operation
  auto out = IterDomain::merge(id_outer_mapped, id_inner_mapped);

  // Remove inputs from the leaf IDs
  leaf_ids_.erase(id_outer_mapped);
  leaf_ids_.erase(id_inner_mapped);

  // Add the output to the leaf IDs
  leaf_ids_[out] = counter++;

  // Update our ID map with the replayed output
  id_map_[m->out()] = out;
}
133
// Replay (or skip) a 2D swizzle on the mapped pair of IDs. When
// replay_swizzle_ is false the swizzle is ignored: the swizzle outputs are
// mapped straight to the (unchanged) mapped inputs.
void ReplayTransformations::handle(Swizzle2D* swizzle_2d) {
  // Grab our inputs to the swizzle node
  auto id_in_x = swizzle_2d->inX();
  auto id_in_y = swizzle_2d->inY();

  // Make sure we have a corresponding entry in our map pointing to the ID we're
  // going to replay the swizzle on
  auto it_x = id_map_.find(id_in_x);
  auto it_y = id_map_.find(id_in_y);

  if (it_x == id_map_.end() || it_y == id_map_.end()) {
    if (error_on_failure_) {
      TORCH_INTERNAL_ASSERT(
          false, "Transform traversal failed, dependencies not met.");
    } else {
      return;
    }
  }

  auto mapped_x = (*it_x).second;
  auto mapped_y = (*it_y).second;

  // Make sure this ID is a leaf ID (meaning it has no uses we generated)
  TORCH_INTERNAL_ASSERT(
      leaf_ids_.find(mapped_x) != leaf_ids_.end() &&
          leaf_ids_.find(mapped_y) != leaf_ids_.end(),
      "Transform traversal failed, modified a node but it was not a leaf node.");

  // Default: pass the mapped inputs through unchanged (swizzle skipped).
  auto outs = std::make_pair(mapped_x, mapped_y);

  if (replay_swizzle_) {
    // Replay the swizzle onto mapped
    outs = IterDomain::swizzle(swizzle_2d->swizzleType(), mapped_x, mapped_y);

    // Remove mapped from the leaf IDs
    leaf_ids_.erase(mapped_x);
    leaf_ids_.erase(mapped_y);
  }

  // Add outputs to leaf IDs. NOTE(review): when replay_swizzle_ is false this
  // re-inserts mapped_x/mapped_y, overwriting their existing leaf counters
  // with fresh (larger) ones — presumably intended to move them later in the
  // deterministic leaf ordering; confirm if ordering matters here.
  leaf_ids_[outs.first] = counter++;
  leaf_ids_[outs.second] = counter++;

  // Update our ID map to include these outputs
  id_map_[swizzle_2d->outX()] = outs.first;
  id_map_[swizzle_2d->outY()] = outs.second;
}
181
182ReplayTransformations::ReplayTransformations(
183 const std::vector<IterDomain*>& _target_domain,
184 std::unordered_map<IterDomain*, IterDomain*> _id_map,
185 bool _error_on_failure,
186 bool replay_swizzle)
187 : target_domain_(_target_domain),
188 id_map_(std::move(_id_map)),
189 error_on_failure_(_error_on_failure),
190 replay_swizzle_(replay_swizzle) {
191 // Make sure id_map has all the inputs needed to replay target_domain
192 auto inps = IterVisitor::getInputsTo(
193 std::vector<Val*>(target_domain_.begin(), target_domain_.end()));
194
195 if (error_on_failure_)
196 std::for_each(inps.begin(), inps.end(), [this](Val* val) {
197 TORCH_INTERNAL_ASSERT(
198 val->getValType().value() == ValType::IterDomain,
199 "Expected IterDomain only for Replay Transformations, but found ",
200 val);
201 IterDomain* id = val->as<IterDomain>();
202 TORCH_INTERNAL_ASSERT(
203 id_map_.find(id) != id_map_.end(),
204 "Could not find required input: ",
205 id,
206 " in provided id_map.");
207 });
208
209 // Set all the leaf nodes for tracking, all ids start as a leaf and will be
210 // updated based on the transformations
211 for (auto entry : id_map_)
212 leaf_ids_[entry.second] = counter++;
213}
214
// Replays outputs that were generated from ids.first on ids.second.
// Runs the traversal over target_domain_, validates that every target ID got
// a replayed counterpart, and fills leaf_vec_ with the resulting leaf IDs in
// deterministic (creation-order) sequence. May only be called once per
// instance.
void ReplayTransformations::runReplay() {
  TORCH_INTERNAL_ASSERT(
      !ran_replay,
      "Cannot run replay twice without creating a new Replay Class.");
  ran_replay = true;
  // Nothing to replay if there is no target or no seed mapping.
  if (target_domain_.empty() || id_map_.empty())
    return;

  // Switch outDomain to a vector to start the traversal
  std::vector<Val*> traversal_vals(
      target_domain_.begin(), target_domain_.end());
  traverseTo(traversal_vals[0]->fusion(), traversal_vals);

  if (error_on_failure_)
    TORCH_INTERNAL_ASSERT(
        leaf_ids_.size() >= target_domain_.size(),
        "Transform traversal failed, did not find enough output IterDomains.");

  // Validate replay: every target ID must map to a replayed ID, and that
  // replayed ID must still be a leaf (i.e. nothing consumed it afterwards).
  for (auto out : target_domain_) {
    auto it_replayed = id_map_.find(out);
    if (it_replayed == id_map_.end()) {
      if (error_on_failure_) {
        TORCH_INTERNAL_ASSERT(
            false,
            "Transform traversal failed, could not find expected output.");
      }
      continue;
    }

    auto id_replayed = (*it_replayed).second;
    auto it_leaf = leaf_ids_.find(id_replayed);
    TORCH_INTERNAL_ASSERT(
        it_leaf != leaf_ids_.end(),
        "Transform Traversal failed, expected a replayed dim for ",
        out,
        " but one was not created.");
  }

  // Populate leaf_vec_ in a deterministic manner. This is deterministic
  // because size_t in leaf_ids is filled based on operation order.
  std::set<std::pair<IterDomain*, size_t>, id_int_lt> ordered_set;
  for (auto entry : leaf_ids_)
    ordered_set.emplace(entry);

  leaf_vec_.clear();
  leaf_vec_.resize(ordered_set.size());
  // Strip the ordering counters, keeping only the IterDomain pointers.
  std::transform(
      ordered_set.begin(),
      ordered_set.end(),
      leaf_vec_.begin(),
      [](std::pair<IterDomain*, size_t> entry) { return entry.first; });
}
269
// Best-effort matching of target_domain's transformation history onto
// replay_domain. Walks target's exprs in topological order; for each one it
// tries to find the matching replay expr (same inputs via
// target2replay_id_map_, same expr type and parameters). On a match, the
// outputs are added to the map; on a mismatch the expr is skipped — unless a
// replay input is part of an rfactor domain, in which case a mismatch is a
// hard error (rfactor transformations must be reproduced exactly).
BestEffortReplay::BestEffortReplay(
    const std::vector<IterDomain*>& replay_domain,
    const std::vector<IterDomain*>& target_domain,
    std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
    std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map,
    std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map,
    bool skip_swizzle)
    : target2replay_id_map_(std::move(target2replay_map)),
      replay_forward_id_map_(std::move(replay_forward_id_map)),
      target_forward_id_map_(std::move(target_forward_id_map)),
      skip_swizzle_(skip_swizzle) {
  // Every initially-mapped replay ID starts out as a leaf.
  for (auto entry : target2replay_id_map_) {
    leaf_ids_[entry.second] = counter++;
  }

  // Grab expr history of iter domains in target_domain
  std::vector<Expr*> target_exprs = StmtSort::getExprs(
      FusionGuard::getCurFusion(),
      std::vector<Val*>(target_domain.begin(), target_domain.end()));

  // If we check how an IterDomain was generated, it should only use an
  // IterDomain in an expression once. We pull a map from the input
  // IterDomains to the expression consuming them to generate the
  // replay_domain domain. This will be used to propagate the target_domain to
  // replay_domain map.

  // Map replay domain's IterDomains to the Exprs they're used in
  std::vector<Expr*> replay_exprs = StmtSort::getExprs(
      FusionGuard::getCurFusion(),
      std::vector<Val*>(replay_domain.begin(), replay_domain.end()));

  // Track which id's in replay have to be replayed to guarantee rfactor
  // transformations. The iteration domains in the rfactor axes don't have
  // to be used in a matching expression in target, so we want to exclude those.
  // Only the iteration domains [root_domains, rfactor) domains have to be used
  // in matching transformation to guarantee rfactor domain is consistent.
  // However, if any rfactor id was used to produce the rfactor domain, we need
  // transformations on them to match the target exactly.
  std::unordered_set<IterDomain*> replay_rfactor_ids;

  // Track which expressions iteration domains are used, they should only be
  // used in one expression.
  std::unordered_map<IterDomain*, Expr*> replay_id2expr_map;
  for (auto replay_expr : replay_exprs) {
    for (auto id : ir_utils::filterByType<IterDomain>(replay_expr->inputs())) {
      TORCH_INTERNAL_ASSERT(
          replay_id2expr_map.find(id) == replay_id2expr_map.end(),
          "Error trying to map rfactor root domain during replay.",
          " An IterDomain was found to be used in more than one expression.");

      replay_id2expr_map[id] = replay_expr;
    }

    // Only want to forward rfactor in map
    auto out_ids = ir_utils::filterByType<IterDomain>(replay_expr->outputs());
    if (std::any_of(out_ids.begin(), out_ids.end(), [](IterDomain* id) {
          return id->isRFactorProduct();
        })) {
      // Any expr producing an rfactor ID marks all of its inputs as
      // rfactor-relevant.
      auto inp_ids = ir_utils::filterByType<IterDomain>(replay_expr->inputs());
      replay_rfactor_ids.insert(inp_ids.begin(), inp_ids.end());
    }
  }

  // Same single-use map for the target side.
  std::unordered_map<IterDomain*, Expr*> target_id2expr_map;
  for (auto target_expr : target_exprs) {
    for (auto id : ir_utils::filterByType<IterDomain>(target_expr->inputs())) {
      TORCH_INTERNAL_ASSERT(
          target_id2expr_map.insert({id, target_expr}).second,
          "BestEffortReplay : Unexpected multi-use of id",
          id);
    }
  }

  if (skip_swizzle_) {
    // Progress through all swizzle ops if we are skipping
    // swizzles on the mapping.
    skipSwizzles(target_id2expr_map, replay_id2expr_map);
  }

  std::string err_str(
      "Error during replay, a transformation was called that conflicts with an rfactor call.");

  bool any_target_expr_contains_broadcast_id = false;

  // Iterate through target IterDomains' history and compare with what we
  // recorded from replay_domain
  for (auto target_expr : target_exprs) {
    auto target_inps_filtered =
        ir_utils::filterByType<IterDomain>(target_expr->inputs());

    // If any input argument in target expression is in the forward map then
    // forward the mapped IterDomains in replay and continue to the next
    // expression as target_expr cannot match a replay_expr
    if (std::any_of(
            target_inps_filtered.begin(),
            target_inps_filtered.end(),
            [&](IterDomain* target_inp) {
              return this->inTargetForwardMap(target_inp);
            })) {
      for (auto target_inp : target_inps_filtered) {
        if (inTargetForwardMap(target_inp)) {
          auto target2replay_it = target2replay_id_map_.find(target_inp);
          if (target2replay_it != target2replay_id_map_.end()) {
            // Replace target_inp entry in target2replay_id_map_ with forwarded
            // id
            target2replay_id_map_[getTargetForwardedId(target_inp)] =
                target2replay_it->second;
            target2replay_id_map_.erase(target_inp);
          }
        }
      }
      // Continue to next target_expr
      continue;
    }

    std::vector<IterDomain*> target_id_inps(
        target_inps_filtered.begin(), target_inps_filtered.end());

    // Track broadcast inputs; used below to relax rfactor mismatch errors
    // for broadcast-after-view patterns.
    bool target_expr_contains_broadcast_id = std::any_of(
        target_inps_filtered.begin(),
        target_inps_filtered.end(),
        [](IterDomain* id) { return id->isBroadcast(); });
    any_target_expr_contains_broadcast_id =
        any_target_expr_contains_broadcast_id ||
        target_expr_contains_broadcast_id;

    std::vector<IterDomain*> replay_inps =
        std::vector<IterDomain*>(target_id_inps.size(), nullptr);

    bool missing_replay_input = false;

    // Map target_expr inputs to replay domain directly
    for (const auto t_i : c10::irange(target_id_inps.size())) {
      // There might not be a mapping, that could be okay (depends on rfactor
      // checking).
      auto it = target2replay_id_map_.find(target_id_inps[t_i]);
      if (it != target2replay_id_map_.end()) {
        replay_inps[t_i] = getReplayForwardedId(it->second);
      } else {
        missing_replay_input = true;
      }
    }

    // Check if any of the associated replay id's are part of an rfactor domain
    bool replay_has_rfactor_inp = std::any_of(
        replay_inps.begin(),
        replay_inps.end(),
        [&replay_rfactor_ids](IterDomain* id) {
          return id == nullptr ? false
                               : id->isRFactorProduct() &&
                  (replay_rfactor_ids.find(id) != replay_rfactor_ids.end());
        });

    // If some replay id inputs are part of rfactor, make sure all target
    // expression inputs map to a replay input
    if (replay_has_rfactor_inp) {
      bool no_missing_exprs = std::none_of(
          replay_inps.begin(),
          replay_inps.end(),
          [&replay_id2expr_map](IterDomain* id) {
            if (id == nullptr) {
              return true;
            } else {
              return replay_id2expr_map.find(id) == replay_id2expr_map.end();
            }
          });
      // View operation creates a TensorView with rfactor. After view, broadcast
      // operation adds iterDomains for any size-1 dimensions. Therefore, the
      // target domain (broadcast) may contain broadcast ids that are not
      // present in the replay domain (view). In this case, we skip any target
      // expressions that contain broadcast ids.
      TORCH_INTERNAL_ASSERT(
          no_missing_exprs || any_target_expr_contains_broadcast_id, err_str);
    }

    // If any inputs are missing, continue as this expr doesn't match.
    if (missing_replay_input) {
      TORCH_INTERNAL_ASSERT(
          !replay_has_rfactor_inp || any_target_expr_contains_broadcast_id,
          err_str);
      continue;
    }

    // Find which replay_expr maps to the target_expr
    Expr* replay_expr = nullptr;
    // Check if all inputs have the same expression
    bool mismatched_replay_exprs = false;
    for (auto replay_inp : replay_inps) {
      auto it = replay_id2expr_map.find(replay_inp);
      if (it != replay_id2expr_map.end()) {
        if (replay_expr == nullptr) {
          replay_expr = it->second;
        } else {
          mismatched_replay_exprs =
              mismatched_replay_exprs || replay_expr != it->second;
        }
      } else {
        // If no expr is mapped then set mismatched exprs to go to continue to
        // the next target expr
        mismatched_replay_exprs = true;
      }
    }

    // If expressions of mapped inputs don't match, then continue to next target
    // expr
    if (mismatched_replay_exprs || replay_expr == nullptr) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // Inputs must match in both count and (positional) identity.
    bool mismatched_inputs = replay_inps.size() != replay_expr->inputs().size();
    for (size_t i = 0; i < replay_inps.size() && !mismatched_inputs; i++) {
      mismatched_inputs =
          mismatched_inputs || replay_expr->inputs()[i] != replay_inps[i];
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatched input, continue
    if (mismatched_inputs) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's outputs, continue
    if (target_expr->outputs().size() != replay_expr->outputs().size()) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's expression type, continue
    if (replay_expr->getExprType().value() !=
        target_expr->getExprType().value()) {
      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
      continue;
    }

    // If there isn't an rfactor id in the replay's inputs and there's a
    // mismatch in replay_expr's and target_expr's split factor (if a split
    // expr), continue
    if (replay_expr->getExprType().value() == ExprType::Split) {
      auto r_split = replay_expr->as<Split>();
      auto t_split = target_expr->as<Split>();
      if (!r_split->factor()->sameAs(t_split->factor()) ||
          r_split->innerSplit() != t_split->innerSplit() ||
          !r_split->startOffset()->sameAs(t_split->startOffset()) ||
          !r_split->stopOffset()->sameAs(t_split->stopOffset())) {
        TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
        continue;
      }
    }

    // Need to match swizzle type and parameters if
    // not skipping swizzles in this mapping pass.
    if (!skip_swizzle_ && replay_expr->etype() == ExprType::Swizzle2D) {
      auto r_swizzle_2d = replay_expr->as<Swizzle2D>();
      auto t_swizzle_2d = target_expr->as<Swizzle2D>();
      if (!(r_swizzle_2d->swizzleType() == t_swizzle_2d->swizzleType())) {
        TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
        continue;
      }
    }

    // Matched: take replay expr inputs out of map:
    for (const auto t_i : c10::irange(target_id_inps.size())) {
      auto t_inp = target_id_inps[t_i];
      auto r_orig_inp = target2replay_id_map_.at(t_inp);
      auto r_maybe_forwarded_inp = replay_inps[t_i];

      // Remove original target2replay_it->second if it's in leaf_ids
      if (leaf_ids_.find(r_orig_inp) != leaf_ids_.end()) {
        leaf_ids_.erase(r_orig_inp);
      }

      // Check if we used a forwarded id, if so add forwarded id's to tracking.
      if (r_orig_inp != r_maybe_forwarded_inp) {
        forwarded_ids_.emplace_back(r_orig_inp);
      }
    }

    // Add outputs to map.
    for (const auto i : c10::irange(target_expr->outputs().size())) {
      auto t_out = target_expr->output(i);
      auto r_out = replay_expr->output(i);
      if (t_out->getValType() == ValType::IterDomain &&
          r_out->getValType() == ValType::IterDomain) {
        target2replay_id_map_[t_out->as<IterDomain>()] =
            r_out->as<IterDomain>();
        leaf_ids_[r_out->as<IterDomain>()] = counter++;
      }
    }

    if (skip_swizzle_) {
      // Progress through all swizzle ops if we are skipping
      // swizzles on the mapping.
      skipSwizzles(target_id2expr_map, replay_id2expr_map);
    }
  }
}
570
571// Find the first position i where td1[i] is not the same as td2[i].
572// "Same" means the DAG to generate td1[i] and td2[i] are the
573// equivelent.
574int BestEffortReplay::findFirstMismatchedID(
575 const TensorDomain* td1,
576 const TensorDomain* td2) {
577 std::unordered_map<IterDomain*, IterDomain*> id_map;
578 auto rd1 = td1->getRootDomain();
579 auto rd2 = td2->getRootDomain();
580 std::unordered_set<IterDomain*> rd2_set(
581 td2->getRootDomain().begin(), td2->getRootDomain().end());
582
583 // Find matching root IterDomains, we could make this O(nlog(n)) if we could
584 // sort IterDomains.
585 for (auto rd1i : rd1) {
586 for (auto rd2i : rd2) {
587 if (rd1i->sameAs(rd2i) && rd2_set.find(rd2i) != rd2_set.end()) {
588 id_map[rd1i] = rd2i;
589 rd2_set.erase(rd2i);
590 break;
591 }
592 }
593 }
594
595 BestEffortReplay ber(td2->domain(), td1->domain(), id_map);
596 for (const auto i :
597 c10::irange(std::max(td1->domain().size(), td2->domain().size()))) {
598 if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) {
599 return i;
600 }
601 // Order is important.
602 auto td2_axis = ber.getReplay().at(td1->axis(i));
603 if (td2->axis(i) != td2_axis) {
604 return i;
605 }
606 }
607 return std::min(td1->nDims(), td2->nDims());
608}
609
610namespace {
611
// Maps that track information relevant to best effort replay about broadcast
// axes in consumer that are not in producer
//
// For example if we have consumer: T0[i0, b1, b2, i3] and producer:
// T1[i0, i3]
//
// If consumer transformations are:
// -> T[i0, b1o, b1i, b2o, b2i, i3]
// -> T[i0*b1i, b1o, b2o, b2i, i3]
// -> T[i0*b1i*b2o, b1o, b2i, i3]
// -> T[i0*b1i*b2o*i3, b1o, b2i]
//
// forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o
// compliment_map would have the entry i0->b1i and i0*b1i->b2o
//
// The first is to fast forward transformations in consumer involving broadcast
// axes not in producer. The compliment map is to use later to compute what leaf
// nodes we may have after the forwarding process is finished. Leaf nodes are
// only important for replayCasP, so look there to see how this is done. Forward
// map is used for replayCasP and replayPasC.
struct ConsumerForwardingInfo {
 public:
  // Map IterDomain* axes that can safely be forwarded to their output.
  std::unordered_map<IterDomain*, IterDomain*> forwarding_map;

  // Given a forward id map id_input -> id_forwarded
  // Track the other inputs in the expr that id_input is an input to. These will
  // be used to adjust the replay's leaf tracking. Don't need to track one to
  // many as currently transformations on IterDomains can only have maximum 2
  // inputs, but maybe in the future we'll have more.
  std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;

  // Populates forwarding_map and compliment_map by walking the consumer's
  // transformation history and finding merges between an axis that maps to
  // producer and an axis made purely of consumer-only broadcast IDs.
  ConsumerForwardingInfo(
      const TensorView* producer,
      const TensorView* consumer) {
    // Collect which root axes are in consumer that are not in producer because
    // of broadcasting
    std::unordered_set<IterDomain*> consumer_bcast_roots_not_in_producer;

    const auto c2p_root_map =
        PairwiseRootDomainMap(producer, consumer)
            .mapConsumerToProducer(consumer->domain(), producer->domain());

    for (auto consumer_root_id : consumer->getRootDomain()) {
      if (consumer_root_id->isBroadcast()) {
        // A broadcast root with no producer mapping exists only in consumer.
        if (c2p_root_map.find(consumer_root_id) == c2p_root_map.end()) {
          consumer_bcast_roots_not_in_producer.emplace(consumer_root_id);
        }
      }
    }

    // We have root axes in consumer that don't exist in producer, now forward
    // those to include all id's in consumer comprised of only axes not in
    // producer.
    auto consumer_bcast_ids_not_in_producer =
        consumer_bcast_roots_not_in_producer;

    std::vector<Expr*> consumer_history = StmtSort::getExprs(
        FusionGuard::getCurFusion(),
        std::vector<Val*>(
            consumer->domain()->domain().begin(),
            consumer->domain()->domain().end()));

    auto isIdOnlyInConsumer =
        [&consumer_bcast_ids_not_in_producer](IterDomain* input_id) {
          return consumer_bcast_ids_not_in_producer.find(input_id) !=
              consumer_bcast_ids_not_in_producer.end();
        };

    for (auto expr : consumer_history) {
      auto input_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
      // If expr inputs are all in consumer_bcast_ids_not_in_producer, then so
      // are all outputs
      if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
        // add all outputs to not being in producer
        for (auto output_ids :
             ir_utils::filterByType<IterDomain>(expr->outputs())) {
          consumer_bcast_ids_not_in_producer.emplace(output_ids);
        }
      } else if (
          expr->isA<Merge>() &&
          std::any_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
        auto merge_expr = expr->as<Merge>();
        // If
        // - one of the inputs is made of id's in consumer that don't map to
        // producer (bcast axes),
        // - && the other input maps to an id in both consumer and producer
        // - && this is a merge
        // for the sake of BestEffortReplay we can forward the input mapping
        // to both consumer and producer to the output of the expression
        std::vector<IterDomain*> forwarded_ids;
        std::vector<IterDomain*> compliment_ids;

        for (auto input_id : input_ids) {
          if (!isIdOnlyInConsumer(input_id)) {
            // This side maps to producer: forward it through the merge.
            forwarded_ids.emplace_back(input_id);
            forwarding_map.emplace(std::make_pair(input_id, merge_expr->out()));
          } else {
            // Consumer-only side: remember it as the forwarded id's
            // "compliment" for later leaf bookkeeping.
            compliment_ids.push_back(input_id);
          }
        }

        // Set up compliment map
        for (auto forwarded_id : forwarded_ids) {
          compliment_map.emplace(std::make_pair(forwarded_id, compliment_ids));
        }
      }
    }
  }
};
722
// Maps that track information relevant to best effort replay about
// trivial-reduction axes in producer
//
// For example if we have producer: T0[i0, r1, r2, i3] and consumer:
// T1[i0, i3]
//
// If producer transformations are:
// -> T[i0, r1, r2, i3]
// -> T[i0*r1, r2, i3]
// -> T[i0*r1*r2, i3]
//
// forwarding_map would forward i0->i0*r1 and i0*r1->i0*r1*r2
// compliment_map would have the i0->r1 and i0*r1->r2
//
// These two maps are used similarly as ConsumerForwardingInfo. See
// its comments as well.
struct ProducerForwardingInfo {
 public:
  // Map IterDomain* axes that can safely be forwarded to their output.
  std::unordered_map<IterDomain*, IterDomain*> forwarding_map;

  // Given a forward id map id_input -> id_forwarded
  // Track the other inputs in the expr that id_input is an input to. These will
  // be used to adjust the replay's leaf tracking. Don't need to track one to
  // many as currently transformations on IterDomains can only have maximum 2
  // inputs, but maybe in the future we'll have more.
  std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;

  // Populates both maps by scanning producer's merge history for merges of a
  // trivial-reduction axis with a non-reduction axis.
  ProducerForwardingInfo(const TensorView* producer) {
    std::vector<Expr*> producer_history = StmtSort::getExprs(
        FusionGuard::getCurFusion(),
        std::vector<Val*>(
            producer->domain()->domain().begin(),
            producer->domain()->domain().end()));

    for (auto merge : ir_utils::filterByType<Merge>(producer_history)) {
      auto inner = merge->inner();
      auto outer = merge->outer();
      // Exactly one side must be a trivial reduction and the other a
      // non-reduction for the merge to be forwardable.
      if ((inner->isTrivialReduction() && !outer->isReduction()) ||
          (outer->isTrivialReduction() && !inner->isReduction())) {
        auto compliment_id = inner->isTrivialReduction() ? inner : outer;
        auto forwarded_id = inner->isTrivialReduction() ? outer : inner;
        forwarding_map.emplace(std::make_pair(forwarded_id, merge->out()));
        compliment_map.emplace(std::make_pair(
            forwarded_id, std::vector<IterDomain*>{compliment_id}));
      }
    }
  }
};
772
773// Trace chain of swizzles until reaching
774// an IterDomain that's either a leaf or
775// not a producer of any swizzle.
776IterDomain* getSwizzleFinalOutput(
777 IterDomain* id,
778 const std::unordered_map<IterDomain*, Expr*>& id2expr) {
779 bool is_swizzle_input = true;
780
781 // Note: currently not supporting swizzling consumer of another
782 // swizzle id, so this should terminate in 1 iter, but eventually
783 // will try to support stacked swizzles so keeping this pass
784 // generic.
785 while (is_swizzle_input) {
786 auto expr_it = id2expr.find(id);
787
788 // This means id is a leaf that doesn't
789 // have any consumers. Stop iteration in this case.
790 if (expr_it == id2expr.end()) {
791 is_swizzle_input = false;
792 break;
793 }
794
795 if (expr_it->second->etype() == ExprType::Swizzle2D) {
796 // In the case of 2D swizzle ops, just forward
797 // inX to outX and inY to outY.
798 auto expr = expr_it->second->as<Swizzle2D>();
799 if (id == expr->inX()) {
800 id = expr->outX();
801 } else {
802 TORCH_INTERNAL_ASSERT(
803 id == expr->inY(),
804 "unknown input to swizzle op",
805 id->toString(),
806 expr->toString());
807 id = expr->outY();
808 }
809 } else {
810 // Probably unreachable but if the expression
811 // is unknown type assume it is not a swizzle op.
812 is_swizzle_input = false;
813 }
814 }
815
816 return id;
817}
818
819bool isSwizzleInput(
820 IterDomain* input_id,
821 const std::unordered_map<IterDomain*, Expr*>& id2expr) {
822 auto user_expr_it = id2expr.find(input_id);
823
824 if (user_expr_it == id2expr.end()) {
825 return false;
826 }
827
828 return user_expr_it->second->etype() == ExprType::Swizzle2D;
829}
830
831} // namespace
832
// After forwarding, the "compliment" axes (the consumer-only / trivial
// reduction side of each forwarded merge) may have produced their own leaf
// IterDomains. Walk the forwarding chains recorded in forwarded_ids_, collect
// all compliment axes, and add any leaves produced from them into leaf_ids_.
void BestEffortReplay::addComplimentLeafIDs(
    const std::unordered_map<IterDomain*, IterDomain*>& forwarding_map,
    const std::unordered_map<IterDomain*, std::vector<IterDomain*>>&
        compliment_map) {
  // ID's could go through more than one forward iteration in the map before it
  // terminates. Grab every id between the forwarded id, and what it was
  // forwarded to
  std::function<void(IterDomain*, std::vector<IterDomain*>&)>
      collectForwardedIds =
          [&forwarding_map, &collectForwardedIds](
              IterDomain* forward_id,
              std::vector<IterDomain*>& forwarded_ids) -> void {
    if (forwarding_map.find(forward_id) != forwarding_map.end()) {
      forwarded_ids.emplace_back(forward_id);
      // Recurse along the chain until the id is no longer forwarded.
      collectForwardedIds(forwarding_map.at(forward_id), forwarded_ids);
    }
  };

  std::vector<IterDomain*> expanded_forwarded_ids;
  for (auto forwarded_id : forwarded_ids_) {
    collectForwardedIds(forwarded_id, expanded_forwarded_ids);
  }

  // Grab all compliments of forwarded ids.
  std::vector<IterDomain*> compliments;
  for (auto forwarded_id : expanded_forwarded_ids) {
    auto compliment_map_it = compliment_map.find(forwarded_id);
    TORCH_INTERNAL_ASSERT(
        compliment_map_it != compliment_map.end(),
        "Issue tracking forwarded broadcast merges in best effort replay.");
    compliments.insert(
        compliments.end(),
        compliment_map_it->second.begin(),
        compliment_map_it->second.end());
  }

  // Grab all exprs used to make the forwarded compliments
  auto compliment_exprs = StmtSort::getExprs(
      FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()});

  // Figure out if there are any leaves in compliment_exprs that aren't
  // the forwarded id
  std::unordered_map<IterDomain*, size_t> leaf_ids;

  // An id is a leaf iff it's produced by some expr but consumed by none
  // (exprs are in topological order, so erase-on-input is sufficient).
  for (auto expr : compliment_exprs) {
    for (auto inp : ir_utils::filterByType<IterDomain>(expr->inputs())) {
      leaf_ids.erase(inp);
    }
    for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
      // If we used the compliment for forwarded don't add to leaf nodes.
      if (std::find(compliments.begin(), compliments.end(), out) ==
          compliments.end()) {
        leaf_ids.emplace(std::make_pair(out, counter++));
      }
    }
  }

  leaf_ids_.insert(leaf_ids.begin(), leaf_ids.end());
}
892
893BestEffortReplay BestEffortReplay::replayCasP(
894 const TensorView* consumer,
895 const TensorView* producer,
896 int producer_compute_at_axis,
897 const RootDomainMap& root_map) {
898 if (producer_compute_at_axis < 0)
899 producer_compute_at_axis += (int)producer->nDims() + 1;
900
901 TORCH_INTERNAL_ASSERT(
902 producer_compute_at_axis >= 0 &&
903 (unsigned int)producer_compute_at_axis <= producer->nDims(),
904 "Invalid axis provided to BestEffortReplay::replayCasP.");
905
906 // producer ids we need to match in consumer
907 std::vector<IterDomain*> producer_CA_ids(
908 producer->domain()->domain().begin(),
909 producer->domain()->domain().begin() + producer_compute_at_axis);
910 producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
911
912 // If producer has an rfactor root, that's what will match the consumer
913 std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain();
914
915 // Figure out all inputs required to generate the compute_at dimensions. We
916 // need all deps because inputs on producer may be in getRootDomain, but we
917 // may need in rFactorDomain
918 auto all_CA_id_deps = DependencyCheck::getAllValsBetween(
919 {producer_root.begin(), producer_root.end()},
920 {producer_CA_ids.begin(), producer_CA_ids.end()});
921
922 // Figure out minimal set of root IDs needed to produce producer_CA_ids:
923 std::unordered_set<IterDomain*> producer_CA_root_ids;
924 for (IterDomain* id : producer_root) {
925 if (std::find(all_CA_id_deps.begin(), all_CA_id_deps.end(), id) !=
926 all_CA_id_deps.end()) {
927 producer_CA_root_ids.emplace(id);
928 }
929 }
930
931 const auto p2c_root_map = root_map.mapProducerToConsumer(
932 producer->domain(), consumer->domain(), producer_CA_root_ids);
933
934 // See FusionAdvancedComputeAt7 for an example of the forwarding logic
935 ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
936
937 ProducerForwardingInfo producer_forwarding_info(producer);
938
939 auto consumer_replay = BestEffortReplay(
940 consumer->domain()->domain(),
941 producer_CA_ids,
942 p2c_root_map,
943 consumer_forwarding_info.forwarding_map,
944 producer_forwarding_info.forwarding_map);
945
946 consumer_replay.addComplimentLeafIDs(
947 consumer_forwarding_info.forwarding_map,
948 consumer_forwarding_info.compliment_map);
949
950 return consumer_replay;
951}
952
953// Runs a best effort replay that ignores broadcast axes that appear in
954// consumer that are not mapped to producer in root_map.
955BestEffortReplay BestEffortReplay::replayPasC(
956 const TensorView* producer,
957 const TensorView* consumer,
958 int consumer_compute_at_axis,
959 const RootDomainMap& root_map) {
960 if (consumer_compute_at_axis < 0)
961 consumer_compute_at_axis += (int)consumer->nDims() + 1;
962 TORCH_INTERNAL_ASSERT(
963 consumer_compute_at_axis >= 0 &&
964 (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
965 "Invalid axis provided to BestEffortReplay::replayPasC.");
966
967 // consumer ids we need to match in producer
968 std::vector<IterDomain*> consumer_CA_ids(
969 consumer->domain()->domain().begin(),
970 consumer->domain()->domain().begin() + consumer_compute_at_axis);
971
972 // Figure out all inputs required to generate the compute_at dimensions
973 auto consumer_CA_root_vals = IterVisitor::getInputsTo(
974 std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));
975
976 std::unordered_set<IterDomain*> consumer_CA_root_ids;
977 for (auto val : consumer_CA_root_vals) {
978 if (val->getValType().value() == ValType::IterDomain) {
979 consumer_CA_root_ids.emplace(val->as<IterDomain>());
980 }
981 }
982
983 const auto c2p_root_map = root_map.mapConsumerToProducer(
984 consumer->domain(), producer->domain(), consumer_CA_root_ids);
985
986 ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
987
988 ProducerForwardingInfo producer_forwarding_info(producer);
989
990 // Instead of replaying from the root, lets try to play forward the history
991 // of producer if they match ops on consumer. Enforce if we modify an
992 // rfactor axis that those ops must match.
993 auto producer_replay = BestEffortReplay(
994 producer->domain()->domain(),
995 consumer_CA_ids,
996 c2p_root_map,
997 producer_forwarding_info.forwarding_map,
998 consumer_forwarding_info.forwarding_map);
999
1000 producer_replay.addComplimentLeafIDs(
1001 producer_forwarding_info.forwarding_map,
1002 producer_forwarding_info.compliment_map);
1003
1004 return producer_replay;
1005}
1006
// Rewrites target2replay_id_map_ (and leaf_ids_) so that any id feeding a
// swizzle op is replaced by the final output of the swizzle chain, i.e.
// swizzles are transparently "skipped" on both the target and replay sides.
// isSwizzleInput / getSwizzleFinalOutput are project helpers; presumably the
// latter follows a chain of swizzle exprs via the id->expr maps.
void BestEffortReplay::skipSwizzles(
    const std::unordered_map<IterDomain*, Expr*>& target_id2expr,
    const std::unordered_map<IterDomain*, Expr*>& replay_id2expr) {
  // Update target2replay map
  // Each modification erases/inserts in the map being iterated, which
  // invalidates the loop's iterators — so we `break` and restart the scan,
  // terminating once a full pass over the map makes no change.
  bool updated = true;

  while (updated) {
    updated = false;
    // NOTE: `it` is deliberately a by-value copy of the map entry. The entry
    // is erased below while it.first/it.second are still used; a reference
    // would dangle.
    for (auto it : target2replay_id_map_) {
      if (isSwizzleInput(it.first, target_id2expr) ||
          isSwizzleInput(it.second, replay_id2expr)) {
        updated = true;
        auto new_target = getSwizzleFinalOutput(it.first, target_id2expr);
        auto new_replay = getSwizzleFinalOutput(it.second, replay_id2expr);

        // new_target and new_replay will now be the final output
        // skipping all swizzles in between. We'd need to
        // update the mapping and leaf ids to the final outputs.
        target2replay_id_map_.erase(it.first);
        TORCH_INTERNAL_ASSERT(
            target2replay_id_map_.insert(std::make_pair(new_target, new_replay))
                .second,
            "Unexpected replay leaf");
        // Progress the leaf ids if the replay is updated
        if (it.second != new_replay &&
            leaf_ids_.find(it.second) != leaf_ids_.end()) {
          leaf_ids_.erase(it.second);
          leaf_ids_[new_replay] = counter++;
        }
        // Restart iteration: the map was mutated mid-loop.
        break;
      }
    }
  }
}
1041
1042DisjointSets<IterDomain*> BestEffortReplay::getDisjointSets() {
1043 DisjointSets<IterDomain*> result;
1044 const std::unordered_map<IterDomain*, IterDomain*>* maps[3] = {
1045 &target2replay_id_map_, &replay_forward_id_map_, &target_forward_id_map_};
1046 for (auto map : maps) {
1047 for (auto entry : *map) {
1048 result.mapEntries(entry.first, entry.second);
1049 }
1050 }
1051 return result;
1052}
1053
1054} // namespace cuda
1055} // namespace fuser
1056} // namespace jit
1057} // namespace torch
1058