1#include <c10/util/irange.h>
2#include <fusion.h>
3#include <ir_all_nodes.h>
4#include <ir_builder.h>
5#include <mutator.h>
6
7#include <vector>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
// Double-dispatch entry points: each forwards to the corresponding
// mutatorDispatch, which routes to the mutate() overload for the node's
// concrete type (defined below).
void OptOutMutator::mutate(Statement* s) {
  Statement::mutatorDispatch(this, s);
}

void OptOutMutator::mutate(Expr* e) {
  Expr::mutatorDispatch(this, e);
}

void OptOutMutator::mutate(Val* v) {
  Val::mutatorDispatch(this, v);
}
25
26void OptOutMutator::registerMutation(Val* val, Val* mutation) {
27 bool val_is_ns = val->vtype() == ValType::NamedScalar;
28 bool mutation_is_ns = mutation->vtype() == ValType::NamedScalar;
29 bool val_is_scalar = val->vtype() == ValType::Scalar;
30 bool mutation_is_scalar = mutation->vtype() == ValType::Scalar;
31 TORCH_INTERNAL_ASSERT(
32 mutation->dtype() == val->dtype() &&
33 (mutation->vtype() == val->vtype() ||
34 ((val_is_ns && mutation_is_scalar) ||
35 (mutation_is_ns && val_is_scalar))),
36 "Mutations are not allowed to change types, tried to go from: (",
37 val->vtype(),
38 ", ",
39 val->dtype(),
40 ") to: (",
41 mutation->vtype(),
42 ", ",
43 mutation->dtype(),
44 ")");
45 mutations[val] = mutation;
46}
47
// Scalar leaf values have no sub-components to traverse, so these overloads
// are deliberate no-ops; replacements for them are registered directly via
// registerMutation.
void OptOutMutator::mutate(Bool* b) {}

void OptOutMutator::mutate(Double* d) {}

void OptOutMutator::mutate(Int* i) {}

void OptOutMutator::mutate(ComplexDouble* c) {}

void OptOutMutator::mutate(NamedScalar* ns) {}
57
58void OptOutMutator::mutate(IterDomain* id) {
59 Val* start = maybeMutated(id->start());
60 Val* extent = maybeMutated(id->extent());
61 Val* expanded_extent = nullptr;
62 if (id->hasExpandedExtent()) {
63 expanded_extent = maybeMutated(id->expandedExtent());
64 }
65 Val* stop_offset = maybeMutated(id->stopOffset());
66 if (start->sameAs(id->start()) && extent->sameAs(id->extent()) &&
67 (!id->hasExpandedExtent() ||
68 expanded_extent->sameAs(id->expandedExtent())) &&
69 stop_offset->sameAs(id->stopOffset())) {
70 return;
71 }
72 registerMutation(
73 id,
74 IterDomainBuilder(id)
75 .start(start)
76 .extent(extent)
77 .stop_offset(stop_offset)
78 .expanded_extent(expanded_extent)
79 .build());
80}
81
82void OptOutMutator::mutate(TensorDomain* td) {
83 bool mutated = false;
84
85 auto updateIdVec = [&](const std::vector<IterDomain*>& ids) {
86 std::vector<IterDomain*> updated_ids;
87 for (auto id : ids) {
88 auto updated_id = maybeMutated(id)->as<IterDomain>();
89 updated_ids.push_back(updated_id);
90 if (!updated_id->sameAs(id)) {
91 mutated = true;
92 }
93 }
94 return updated_ids;
95 };
96
97 std::vector<IterDomain*> root_dom = updateIdVec(td->getRootDomain());
98 std::vector<IterDomain*> rfactor_dom = td->hasRFactor()
99 ? updateIdVec(td->getMaybeRFactorDomain())
100 : std::vector<IterDomain*>();
101 std::vector<IterDomain*> domain = updateIdVec(td->domain());
102
103 if (!mutated) {
104 return;
105 }
106
107 Val* mutated_val = IrBuilder::create<TensorDomain>(
108 td->container(), root_dom, rfactor_dom, domain, td->contiguity());
109 registerMutation(td, mutated_val);
110}
111
112void OptOutMutator::mutate(TensorView* tv) {
113 TensorDomain* td = maybeMutated(tv->domain())->as<TensorDomain>();
114 if (!tv->domain()->sameAs(td)) {
115 tv->setDomain(td);
116 }
117 // Don't register tv mutations as we just want to update the TD
118}
119
// Kernel IR values are not supported by this mutator.
void OptOutMutator::mutate(kir::Predicate*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}

void OptOutMutator::mutate(kir::TensorIndex*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
127
128void OptOutMutator::mutate(FullOp* fop) {
129 Val* out = maybeMutated(fop->output(0));
130 Val* fill_value = maybeMutated(fop->getFillValue());
131
132 if (out->sameAs(fop->output(0))) {
133 return;
134 }
135 auto container = fop->container();
136 container->removeExpr(fop);
137 IrBuilder::create<FullOp>(container, out, fill_value, fop->dtype());
138}
139
140void OptOutMutator::mutate(ARangeOp* aop) {
141 Val* out = maybeMutated(aop->output(0));
142
143 if (out->sameAs(aop->output(0))) {
144 return;
145 }
146 auto container = aop->container();
147 container->removeExpr(aop);
148 IrBuilder::create<ARangeOp>(
149 container,
150 out,
151 aop->start(),
152 aop->end(),
153 aop->step(),
154 aop->dtype(),
155 aop->getLinearLogicalIndex());
156}
157
158void OptOutMutator::mutate(EyeOp* eop) {
159 Val* out = maybeMutated(eop->output(0));
160
161 if (out->sameAs(eop->output(0))) {
162 return;
163 }
164 auto container = eop->container();
165 container->removeExpr(eop);
166 IrBuilder::create<EyeOp>(
167 container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2());
168}
169
170void OptOutMutator::mutate(UnaryOp* uop) {
171 Val* out = maybeMutated(uop->out());
172 Val* in = maybeMutated(uop->in());
173
174 if (out->sameAs(uop->out()) && in->sameAs(uop->in())) {
175 return;
176 }
177 auto container = uop->container();
178 auto uop_type = uop->getUnaryOpType();
179 container->removeExpr(uop);
180 IrBuilder::create<UnaryOp>(container, uop_type, out, in);
181}
182
183void OptOutMutator::mutate(BinaryOp* bop) {
184 Val* out = maybeMutated(bop->out());
185 Val* lhs = maybeMutated(bop->lhs());
186 Val* rhs = maybeMutated(bop->rhs());
187
188 if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) {
189 return;
190 }
191
192 auto container = bop->container();
193 auto bop_type = bop->getBinaryOpType();
194 container->removeExpr(bop);
195 IrBuilder::create<BinaryOp>(container, bop_type, out, lhs, rhs);
196}
197
198void OptOutMutator::mutate(TernaryOp* top) {
199 Val* out = maybeMutated(top->out());
200 Val* in1 = maybeMutated(top->in1());
201 Val* in2 = maybeMutated(top->in2());
202 Val* in3 = maybeMutated(top->in3());
203
204 if (out == top->out() && in1 == top->in1() && in2 == top->in2() &&
205 in3 == top->in3()) {
206 return;
207 }
208
209 auto container = top->container();
210 auto top_type = top->getTernaryOpType();
211 container->removeExpr(top);
212 IrBuilder::create<TernaryOp>(container, top_type, out, in1, in2, in3);
213}
214
215void OptOutMutator::mutate(RNGOp* rop) {
216 Val* out = maybeMutated(rop->output(0));
217 auto& parameters = rop->getParameters();
218 std::vector<Val*> mutated_parameters;
219 for (auto v : parameters) {
220 mutated_parameters.emplace_back(maybeMutated(v));
221 }
222
223 if (out == rop->output(0) && mutated_parameters == parameters) {
224 return;
225 }
226
227 auto container = rop->container();
228 auto rop_type = rop->getRNGOpType();
229 container->removeExpr(rop);
230 IrBuilder::create<RNGOp>(
231 container,
232 rop_type,
233 out,
234 rop->dtype(),
235 mutated_parameters,
236 rop->getRNGOffset(),
237 rop->getPhiloxIndex());
238}
239
240void OptOutMutator::mutate(ReductionOp* rop) {
241 Val* out = maybeMutated(rop->out());
242 Val* in = maybeMutated(rop->in());
243 Val* init = rop->init();
244 if (out->sameAs(rop->out()) && in->sameAs(rop->in()) &&
245 init->sameAs(rop->init())) {
246 return;
247 }
248
249 auto container = rop->container();
250 auto rop_type = rop->getReductionOpType();
251 container->removeExpr(rop);
252 IrBuilder::create<ReductionOp>(
253 container, rop_type, init, out, in, rop->isAllreduce());
254}
255
256void OptOutMutator::mutate(GroupedReductionOp* rop) {
257 bool is_same = true;
258
259 std::vector<Val*> outputs;
260 for (auto out : rop->outputs()) {
261 auto maybe_mutated = maybeMutated(out);
262 is_same = is_same && maybe_mutated->sameAs(out);
263 outputs.push_back(maybe_mutated);
264 }
265
266 std::vector<Val*> inputs;
267 for (auto in : rop->inputs()) {
268 auto maybe_mutated = maybeMutated(in);
269 is_same = is_same && maybe_mutated->sameAs(in);
270 inputs.push_back(maybe_mutated);
271 }
272
273 std::vector<Val*> init_vals;
274 for (auto init : rop->initVals()) {
275 auto maybe_mutated = maybeMutated(init);
276 is_same = is_same && maybe_mutated->sameAs(init);
277 init_vals.push_back(maybe_mutated);
278 }
279
280 if (is_same) {
281 return;
282 }
283
284 auto container = rop->container();
285 const auto& rop_types = rop->getReductionOpTypes();
286 container->removeExpr(rop);
287 IrBuilder::create<GroupedReductionOp>(
288 container, rop_types, init_vals, outputs, inputs, rop->isAllreduce());
289}
290
291namespace {
292inline bool compareOptional(Val* a, Val* b) {
293 if (!a || !b) {
294 return (!a && !b);
295 }
296 return a->sameAs(b);
297}
298
299} // namespace
300
301void OptOutMutator::mutate(WelfordOp* wop) {
302 Val* out_avg = maybeMutated(wop->outAvg());
303 Val* out_var = maybeMutated(wop->outVar());
304 Val* out_N = maybeMutated(wop->outN());
305
306 Val* in_avg = maybeMutated(wop->inAvg());
307 Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr;
308 Val* in_N = maybeMutated(wop->inN());
309
310 Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr;
311 Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr;
312 Val* init_N = maybeMutated(wop->initN());
313
314 const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
315 out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
316 const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
317 compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
318 const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
319 compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());
320
321 if (out_compare && init_compare && in_compare) {
322 return;
323 }
324
325 auto container = wop->container();
326 container->removeExpr(wop);
327 IrBuilder::create<WelfordOp>(
328 container,
329 out_avg,
330 out_var,
331 out_N,
332 in_avg,
333 in_var,
334 in_N,
335 init_avg,
336 init_var,
337 init_N,
338 wop->isAllreduce());
339}
340
341void OptOutMutator::mutate(GroupedWelfordOp* wop) {
342 bool is_same = true;
343
344 std::vector<WelfordTriplet> output_vals;
345 for (const auto& out : wop->outputVals()) {
346 auto maybe_mutated =
347 out.transform([&](Val* val) { return maybeMutated(val); });
348 is_same = is_same && maybe_mutated.sameAs(out);
349 output_vals.push_back(maybe_mutated);
350 }
351
352 std::vector<WelfordTriplet> input_vals;
353 for (const auto& inp : wop->inputVals()) {
354 auto maybe_mutated =
355 inp.transform([&](Val* val) { return maybeMutated(val); });
356 is_same = is_same && maybe_mutated.sameAs(inp);
357 input_vals.push_back(maybe_mutated);
358 }
359
360 std::vector<WelfordTriplet> init_vals;
361 for (const auto& init : wop->initVals()) {
362 auto maybe_mutated =
363 init.transform([&](Val* val) { return maybeMutated(val); });
364 is_same = is_same && maybe_mutated.sameAs(init);
365 init_vals.push_back(maybe_mutated);
366 }
367
368 if (is_same) {
369 return;
370 }
371
372 auto container = wop->container();
373 container->removeExpr(wop);
374 IrBuilder::create<GroupedWelfordOp>(
375 container, output_vals, input_vals, init_vals, wop->isAllreduce());
376}
377
378void OptOutMutator::mutate(MmaOp* mma) {
379 Val* out = maybeMutated(mma->out());
380 Val* in_a = maybeMutated(mma->inA());
381 Val* in_b = maybeMutated(mma->inB());
382 Val* init = mma->init();
383
384 if (out->sameAs(mma->out()) && in_a->sameAs(mma->inA()) &&
385 in_b->sameAs(mma->inB())) {
386 return;
387 }
388
389 auto container = mma->container();
390 auto options = mma->options();
391 container->removeExpr(mma);
392 C10_UNUSED auto new_mma =
393 IrBuilder::create<MmaOp>(container, out, in_a, in_b, init, options);
394}
395
396void OptOutMutator::mutate(LoadStoreOp* ldst) {
397 Val* out = maybeMutated(ldst->out());
398 Val* in = maybeMutated(ldst->in());
399 auto op_type = ldst->opType();
400
401 if (out->sameAs(ldst->out()) && in->sameAs(ldst->in())) {
402 return;
403 }
404
405 auto container = ldst->container();
406 container->removeExpr(ldst);
407 IrBuilder::create<LoadStoreOp>(container, op_type, out, in);
408}
409
410void OptOutMutator::mutate(BroadcastOp* bop) {
411 Val* out = maybeMutated(bop->out());
412 Val* in = maybeMutated(bop->in());
413
414 if (out->sameAs(bop->out()) && in->sameAs(bop->in())) {
415 return;
416 }
417
418 auto container = bop->container();
419 auto flags = bop->getBroadcastDimFlags();
420 container->removeExpr(bop);
421 IrBuilder::create<BroadcastOp>(container, out, in, flags);
422}
423
424void OptOutMutator::mutate(TransposeOp* top) {
425 TensorView* out = maybeMutated(top->out())->as<TensorView>();
426 TensorView* in = maybeMutated(top->in())->as<TensorView>();
427
428 if (out->sameAs(top->out()) && in->sameAs(top->in())) {
429 return;
430 }
431
432 auto container = top->container();
433 auto new2old = top->new2old();
434 container->removeExpr(top);
435 IrBuilder::create<TransposeOp>(container, out, in, new2old);
436}
437
438void OptOutMutator::mutate(ExpandOp* eop) {
439 bool is_same = true;
440
441 TensorView* out = maybeMutated(eop->out())->as<TensorView>();
442 is_same = is_same && out->sameAs(eop->out());
443 TensorView* in = maybeMutated(eop->in())->as<TensorView>();
444 is_same = is_same && in->sameAs(eop->in());
445
446 std::vector<Val*> expanded_extents;
447 expanded_extents.reserve(eop->expanded_extents().size());
448 for (auto expanded_extent : eop->expanded_extents()) {
449 expanded_extents.push_back(maybeMutated(expanded_extent));
450 if (!expanded_extents.back()->sameAs(expanded_extent)) {
451 is_same = false;
452 }
453 }
454
455 if (is_same) {
456 return;
457 }
458
459 auto container = eop->container();
460 container->removeExpr(eop);
461 IrBuilder::create<ExpandOp>(container, out, in, expanded_extents);
462}
463
464void OptOutMutator::mutate(ShiftOp* sop) {
465 Val* out = maybeMutated(sop->out())->asVal();
466 Val* in = maybeMutated(sop->in())->asVal();
467
468 if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
469 return;
470 }
471
472 auto offsets = sop->offsets();
473 auto pad_width = sop->padWidth();
474 auto container = sop->container();
475 container->removeExpr(sop);
476 IrBuilder::create<ShiftOp>(container, out, in, offsets, pad_width);
477}
478
479void OptOutMutator::mutate(GatherOp* op) {
480 Val* out = maybeMutated(op->out())->asVal();
481 Val* in = maybeMutated(op->in())->asVal();
482
483 if (out->sameAs(op->out()) && in->sameAs(op->in())) {
484 return;
485 }
486
487 auto window_shape = op->windowShape();
488 auto pad_width = op->padWidth();
489 auto container = op->container();
490 container->removeExpr(op);
491 IrBuilder::create<GatherOp>(container, out, in, window_shape, pad_width);
492}
493
494void OptOutMutator::mutate(ViewAsScalar* vop) {
495 TensorView* out = maybeMutated(vop->out())->as<TensorView>();
496 TensorView* in = maybeMutated(vop->in())->as<TensorView>();
497
498 if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
499 return;
500 }
501
502 auto container = vop->container();
503 container->removeExpr(vop);
504 IrBuilder::create<ViewAsScalar>(
505 container, out, in, vop->vector_id(), vop->index());
506}
507
508void OptOutMutator::mutate(ViewOp* vop) {
509 TensorView* out = maybeMutated(vop->out())->as<TensorView>();
510 TensorView* in = maybeMutated(vop->in())->as<TensorView>();
511
512 if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
513 return;
514 }
515
516 auto container = vop->container();
517 container->removeExpr(vop);
518 IrBuilder::create<ViewOp>(container, out, in);
519}
520
521void OptOutMutator::mutate(Split* s) {
522 IterDomain* ot = maybeMutated(s->outer())->as<IterDomain>();
523 IterDomain* inr = maybeMutated(s->inner())->as<IterDomain>();
524 IterDomain* in = maybeMutated(s->in())->as<IterDomain>();
525 Val* fact = maybeMutated(s->factor())->as<Val>();
526 Val* start_offset = maybeMutated(s->startOffset());
527 Val* stop_offset = maybeMutated(s->stopOffset());
528
529 if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) &&
530 in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) &&
531 start_offset->sameAs(s->startOffset()) &&
532 stop_offset->sameAs(s->stopOffset())) {
533 return;
534 }
535
536 auto container = s->container();
537 auto inner_split = s->innerSplit();
538 container->removeExpr(s);
539 C10_UNUSED auto new_node = IrBuilder::create<Split>(
540 container, ot, inr, in, fact, inner_split, start_offset, stop_offset);
541}
542
543void OptOutMutator::mutate(Merge* m) {
544 IterDomain* ot = maybeMutated(m->out())->as<IterDomain>();
545 IterDomain* otr = maybeMutated(m->outer())->as<IterDomain>();
546 IterDomain* in = maybeMutated(m->inner())->as<IterDomain>();
547
548 if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) &&
549 in->sameAs(m->inner())) {
550 return;
551 }
552
553 auto container = m->container();
554 container->removeExpr(m);
555 C10_UNUSED auto new_node = IrBuilder::create<Merge>(container, ot, otr, in);
556}
557
558void OptOutMutator::mutate(Swizzle2D* m) {
559 IterDomain* outx = maybeMutated(m->outX())->as<IterDomain>();
560 IterDomain* outy = maybeMutated(m->outY())->as<IterDomain>();
561
562 IterDomain* inx = maybeMutated(m->inX())->as<IterDomain>();
563 IterDomain* iny = maybeMutated(m->inY())->as<IterDomain>();
564
565 auto swizzle_type = m->swizzleType();
566
567 if (outx->sameAs(m->outX()) && outy->sameAs(m->outY()) &&
568 inx->sameAs(m->inX()) && iny->sameAs(m->inY())) {
569 return;
570 }
571 auto container = m->container();
572 container->removeExpr(m);
573 FusionGuard::getCurFusion()->removeExpr(m);
574 C10_UNUSED auto new_node = IrBuilder::create<Swizzle2D>(
575 container, outx, outy, inx, iny, swizzle_type);
576}
577
// Kernel IR expressions are not supported by this mutator; every overload
// below asserts if reached.
void OptOutMutator::mutate(kir::Allocate*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::BlockSync*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridSync*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::CpAsyncWait*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::CpAsyncCommit*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::InitMagicZero*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::UpdateMagicZero*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::ForLoop*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::IfThenElse*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GroupedGridReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridBroadcast*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridWelford*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GroupedGridWelford*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::AllocateFusedReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::Swizzle2DInt*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::PairSelect*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::IntPair*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
632
// Remove an expr from the given container; a simple passthrough kept as a
// member so derived mutators can intercept removal if needed (declared in
// the header -- behavior here is just the forward).
void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) {
  container->removeExpr(expr);
}
636} // namespace cuda
637} // namespace fuser
638} // namespace jit
639} // namespace torch
640