#include <arith.h>
#include <disjoint_set.h>
#include <ir_cloner.h>
#include <ir_interface_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <root_domain_map.h>
#include <transform_iter.h>
#include <transform_rfactor.h>
#include <transform_view.h>

#include <c10/util/irange.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

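// Compares two scalar Vals for equality by dispatching on the concrete
// scalar type (Bool, Double, Int, or NamedScalar). The static entry point
// verifies that both Vals share the same ValType and DataType before
// constructing the checker, so dispatching on v1_ alone is sufficient to
// select the right handler.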
class ScalarCheck : OptInConstDispatch {
 public:
  static bool sameAs(const Val* v1, const Val* v2) {
    if (v1 == v2)
      return true;

    if (v1->getValType() != v2->getValType())
      return false;

    if (v1->getDataType() != v2->getDataType())
      return false;

    ScalarCheck sc(v1, v2);
    return sc.same_;
  }

 private:
  void handle(const Bool* b) final {
    same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>());
  }

  void handle(const Double* d) final {
    same_ = v1_->as<Double>()->sameAs(v2_->as<Double>());
  }

  void handle(const Int* i) final {
    same_ = v1_->as<Int>()->sameAs(v2_->as<Int>());
  }

  void handle(const NamedScalar* ns) final {
    same_ = v1_->as<NamedScalar>()->sameAs(v2_->as<NamedScalar>());
  }

  ScalarCheck(const Val* _v1, const Val* _v2) : v1_(_v1), v2_(_v2) {
    OptInConstDispatch::handle(v1_);
  }

 private:
  const Val* v1_ = nullptr;
  const Val* v2_ = nullptr;
  bool same_ = false;
};

} // namespace

bool areEqualScalars(Val* v1, Val* v2) {
  return ScalarCheck::sameAs(v1, v2);
}

Bool::Bool(IrBuilderPasskey passkey)
    : Val(passkey, ValType::Scalar, DataType::Bool),
      maybe_value_{c10::nullopt} {}

Bool::Bool(IrBuilderPasskey passkey, bool value)
    : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {}

Bool::Bool(IrBuilderPasskey passkey, c10::optional<bool> value)
    : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {}

Bool::Bool(const Bool* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

bool Bool::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<Bool>()) {
    return false;
  }
  const auto other_bool = other->as<Bool>();
  if (isConst() && other_bool->isConst()) {
    return *value() == *(other_bool->value());
  }
  return false;
}

Double::Double(IrBuilderPasskey passkey)
    : Val(passkey, ValType::Scalar, DataType::Double),
      maybe_value_{c10::nullopt} {}

Double::Double(IrBuilderPasskey passkey, ScalarType value)
    : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}

Double::Double(IrBuilderPasskey passkey, c10::optional<ScalarType> value)
    : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}

Double::Double(const Double* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

bool Double::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<Double>()) {
    return false;
  }
  const auto other_double = other->as<Double>();
  if (isConst() && other_double->isConst())
    return *value() == *(other_double->value());
  return false;
}

Int::Int(IrBuilderPasskey passkey)
    : Val(passkey, ValType::Scalar, DataType::Int),
      maybe_value_{c10::nullopt} {}

Int::Int(IrBuilderPasskey passkey, ScalarType value)
    : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {}

Int::Int(IrBuilderPasskey passkey, c10::optional<ScalarType> value)
    : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {}

Int::Int(const Int* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

bool Int::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<Int>()) {
    return false;
  }
  const auto other_int = other->as<Int>();
  if (isConst() && other_int->isConst()) {
    return *value() == *(other_int->value());
  }
  return false;
}

ComplexDouble::ComplexDouble(IrBuilderPasskey passkey)
    : Val(passkey, ValType::Scalar, DataType::ComplexDouble),
      maybe_value_{c10::nullopt} {}

ComplexDouble::ComplexDouble(IrBuilderPasskey passkey, ScalarType value)
    : Val(passkey, ValType::Scalar, DataType::ComplexDouble),
      maybe_value_{value} {}

ComplexDouble::ComplexDouble(
    IrBuilderPasskey passkey,
    c10::optional<ScalarType> value)
    : Val(passkey, ValType::Scalar, DataType::ComplexDouble),
      maybe_value_{value} {}

ComplexDouble::ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

bool ComplexDouble::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<ComplexDouble>()) {
    return false;
  }
  const auto other_complex = other->as<ComplexDouble>();
  if (isConst() && other_complex->isConst())
    return *value() == *(other_complex->value());
  return false;
}

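// FullOp creates a tensor filled with fill_value (analogous to at::full).
// The extents of the output root domain and the fill value are the inputs
// of the expression.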
FullOp::FullOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* fill_value,
    DataType dtype)
    : Expr(passkey, ExprType::FullOp), dtype_(dtype), fill_value_(fill_value) {
  if (out->isA<TensorView>()) {
    // All output root extents are inputs, not just the first one; a full
    // tensor can be multi-dimensional.
    for (auto id : out->as<TensorView>()->getRootDomain()) {
      addInput(id->extent());
    }
  }
  addInput(fill_value);
  addOutput(out);
}

FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      dtype_(src->dtype()),
      fill_value_(ir_cloner->clone(src->fill_value_)) {}

Expr* FullOp::shallowCopy() const {
  auto result = IrBuilder::create<FullOp>(output(0), fill_value_, dtype_);
  result->copyPredicatesFrom(this);
  return result;
}

bool FullOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<FullOp>()) {
    return false;
  }
  const auto other_op = other->as<FullOp>();
  if (dtype_ != other_op->dtype_) {
    return false;
  }
  return Expr::sameAs(other);
}

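// ARangeOp creates a 1-D tensor holding the sequence start, start + step,
// ... up to (but excluding) end, analogous to at::arange. linear_index,
// when non-null, appears to be the linearized output index bound during
// lowering.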
ARangeOp::ARangeOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* start,
    Val* end,
    Val* step,
    DataType dtype,
    Val* linear_index)
    : Expr(passkey, ExprType::ARangeOp),
      dtype_(dtype),
      start_(start),
      end_(end),
      step_(step),
      linear_index_(linear_index) {
  addInput(start);
  addInput(end);
  addInput(step);
  addOutput(out);
}

ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      dtype_(src->dtype()),
      start_(ir_cloner->clone(src->start_)),
      end_(ir_cloner->clone(src->end_)),
      step_(ir_cloner->clone(src->step_)),
      linear_index_(ir_cloner->clone(src->linear_index_)) {}

Expr* ARangeOp::shallowCopy() const {
  auto result = IrBuilder::create<ARangeOp>(
      output(0), start_, end_, step_, dtype_, linear_index_);
  result->copyPredicatesFrom(this);
  return result;
}

bool ARangeOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<ARangeOp>()) {
    return false;
  }
  const auto other_op = other->as<ARangeOp>();
  if (dtype_ != other_op->dtype_) {
    return false;
  }
  if (!start_->sameAs(other_op->start_)) {
    return false;
  }
  if (!end_->sameAs(other_op->end_)) {
    return false;
  }
  if (!step_->sameAs(other_op->step_)) {
    return false;
  }
  if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) {
    return false;
  }
  if ((linear_index_ != nullptr) &&
      !linear_index_->sameAs(other_op->linear_index_)) {
    return false;
  }
  return Expr::sameAs(other);
}

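// EyeOp creates an identity-matrix-style tensor (analogous to at::eye).
// index1/index2, when non-null, appear to be the row and column indices
// bound during lowering.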
EyeOp::EyeOp(
    IrBuilderPasskey passkey,
    Val* out,
    DataType dtype,
    Val* index1,
    Val* index2)
    : Expr(passkey, ExprType::EyeOp),
      dtype_(dtype),
      index1_(index1),
      index2_(index2) {
  if (out->isA<TensorView>()) {
    addInput(out->as<TensorView>()->getRootDomain()[0]->extent());
    if (out->as<TensorView>()->getRootDomain()[1] !=
        out->as<TensorView>()->getRootDomain()[0]) {
      addInput(out->as<TensorView>()->getRootDomain()[1]->extent());
    }
  }
  addOutput(out);
}

EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      dtype_(src->dtype_),
      index1_(ir_cloner->clone(src->index1_)),
      index2_(ir_cloner->clone(src->index2_)) {}

Expr* EyeOp::shallowCopy() const {
  auto result = IrBuilder::create<EyeOp>(output(0), dtype_, index1_, index2_);
  result->copyPredicatesFrom(this);
  return result;
}

bool EyeOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<EyeOp>()) {
    return false;
  }
  const auto other_op = other->as<EyeOp>();
  if (dtype_ != other_op->dtype_) {
    return false;
  }
  if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) {
    return false;
  }
  if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) {
    return false;
  }
  if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) {
    return false;
  }
  if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) {
    return false;
  }
  return Expr::sameAs(other);
}

UnaryOp::UnaryOp(
    IrBuilderPasskey passkey,
    UnaryOpType type,
    Val* out,
    Val* in,
    int rng_offset)
    : Expr(passkey, ExprType::UnaryOp),
      unary_op_type_{type},
      out_{out},
      in_{in} {
  addOutput(out);
  addInput(in);
}

UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      unary_op_type_(src->unary_op_type_),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)) {}

Expr* UnaryOp::shallowCopy() const {
  auto result = IrBuilder::create<UnaryOp>(unary_op_type_, out_, in_);
  result->copyPredicatesFrom(this);
  return result;
}

bool UnaryOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<UnaryOp>()) {
    return false;
  }
  const auto other_op = other->as<UnaryOp>();
  if (getUnaryOpType() != other_op->getUnaryOpType()) {
    return false;
  }
  return Expr::sameAs(other);
}

BinaryOp::BinaryOp(
    IrBuilderPasskey passkey,
    BinaryOpType type,
    Val* out,
    Val* lhs,
    Val* rhs)
    : Expr(passkey, ExprType::BinaryOp),
      binary_op_type_{type},
      out_{out},
      lhs_{lhs},
      rhs_{rhs} {
  addOutput(out);
  addInput(lhs);
  addInput(rhs);
}

BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      binary_op_type_(src->binary_op_type_),
      out_(ir_cloner->clone(src->out_)),
      lhs_(ir_cloner->clone(src->lhs_)),
      rhs_(ir_cloner->clone(src->rhs_)) {}

Expr* BinaryOp::shallowCopy() const {
  auto result = IrBuilder::create<BinaryOp>(binary_op_type_, out_, lhs_, rhs_);
  result->copyPredicatesFrom(this);
  return result;
}

bool BinaryOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<BinaryOp>()) {
    return false;
  }
  const auto other_op = other->as<BinaryOp>();
  if (getBinaryOpType() != other_op->getBinaryOpType()) {
    return false;
  }
  return Expr::sameAs(other);
}

TernaryOp::TernaryOp(
    IrBuilderPasskey passkey,
    TernaryOpType type,
    Val* out,
    Val* in1,
    Val* in2,
    Val* in3)
    : Expr(passkey, ExprType::TernaryOp),
      ternary_op_type_{type},
      out_{out},
      in1_{in1},
      in2_{in2},
      in3_{in3} {
  addOutput(out);
  addInput(in1);
  addInput(in2);
  addInput(in3);
}

TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      ternary_op_type_(src->ternary_op_type_),
      out_(ir_cloner->clone(src->out_)),
      in1_(ir_cloner->clone(src->in1_)),
      in2_(ir_cloner->clone(src->in2_)),
      in3_(ir_cloner->clone(src->in3_)) {}

Expr* TernaryOp::shallowCopy() const {
  auto result =
      IrBuilder::create<TernaryOp>(ternary_op_type_, out_, in1_, in2_, in3_);
  result->copyPredicatesFrom(this);
  return result;
}

bool TernaryOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<TernaryOp>()) {
    return false;
  }
  const auto other_op = other->as<TernaryOp>();
  if (getTernaryOpType() != other_op->getTernaryOpType()) {
    return false;
  }
  return Expr::sameAs(other);
}

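// RNGOp generates random values of the given dtype. The output shape
// extents and the op parameters (e.g. distribution bounds) are inputs;
// rng_offset and philox_index relate to the state of the Philox
// counter-based RNG used by the generated kernel (the philox_index_ name
// suggests the per-element counter bound during lowering).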
RNGOp::RNGOp(
    IrBuilderPasskey passkey,
    RNGOpType type,
    Val* out,
    DataType dtype,
    std::vector<Val*> parameters,
    int rng_offset,
    Val* philox_index)
    : Expr(passkey, ExprType::RNGOp),
      rng_op_type_(type),
      dtype_(dtype),
      parameters_(std::move(parameters)),
      rng_offset_(rng_offset),
      philox_index_(philox_index) {
  if (out->isA<TensorView>()) {
    for (auto id : out->as<TensorView>()->getRootDomain()) {
      shape_.emplace_back(id->extent());
    }
  }
  for (auto v : shape_) {
    addInput(v);
  }
  for (auto v : parameters_) {
    addInput(v);
  }
  addOutput(out);
}

RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      rng_op_type_(src->rng_op_type_),
      dtype_(src->dtype()),
      parameters_(ir_cloner->clone(src->parameters_)),
      rng_offset_(src->rng_offset_),
      philox_index_(ir_cloner->clone(src->philox_index_)) {}

Expr* RNGOp::shallowCopy() const {
  auto result = IrBuilder::create<RNGOp>(
      rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_);
  result->copyPredicatesFrom(this);
  return result;
}

bool RNGOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<RNGOp>()) {
    return false;
  }
  const auto other_op = other->as<RNGOp>();
  if (getRNGOpType() != other_op->getRNGOpType()) {
    return false;
  }
  if (dtype_ != other_op->dtype_) {
    return false;
  }
  if (parameters_.size() != other_op->parameters_.size()) {
    return false;
  }
  for (auto i : c10::irange(parameters_.size())) {
    if (!parameters_[i]->sameAs(other_op->parameters_[i])) {
      return false;
    }
  }
  if (getRNGOffset() != other_op->getRNGOffset()) {
    return false;
  }
  if ((philox_index_ == nullptr) != (other_op->philox_index_ == nullptr)) {
    return false;
  }
  if ((philox_index_ != nullptr) &&
      !philox_index_->sameAs(other_op->philox_index_)) {
    return false;
  }
  return Expr::sameAs(other);
}

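// BroadcastOp inserts new broadcast dimensions into `out`;
// is_broadcast_dims holds one flag per output root domain entry marking
// which dims are the newly broadcast ones.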
BroadcastOp::BroadcastOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in,
    std::vector<bool> is_broadcast_dims)
    : Expr(passkey, ExprType::BroadcastOp),
      out_(out),
      in_(in),
      is_broadcast_dims_(std::move(is_broadcast_dims)) {
  // clang-tidy complains that out_ may be null.
  TORCH_INTERNAL_ASSERT(out_ != nullptr);
  TORCH_INTERNAL_ASSERT(in_ != nullptr);

  auto out_type = out->getValType().value();
  auto in_type = in->getValType().value();

  TORCH_INTERNAL_ASSERT(
      (out_type == ValType::TensorView && in_type == ValType::TensorView) ||
          (out_type == ValType::TensorIndex &&
           in_type == ValType::TensorIndex),
      "Cannot broadcast a non-tensor object.");

  addOutput(out);
  addInput(in);

  if (!out->isA<TensorView>() || !in->isA<TensorView>()) {
    return;
  }

  passkey.ir_container_->registerExpr(exprPasskey(), this);

  // This is a generic check that the root dims of a consumer and producer
  // match. Maybe we shouldn't relegate it to this constructor.
  const auto c_tv = out_->as<TensorView>();
  const auto p_tv = in_->as<TensorView>();

  const auto& c_root = c_tv->getRootDomain();
  const auto& p_root = p_tv->getMaybeRFactorDomain();

  const auto root_p2c =
      PairwiseRootDomainMap(p_tv, c_tv)
          .mapProducerToConsumer(p_tv->domain(), c_tv->domain());

  for (auto id : p_root) {
    if (root_p2c.find(id) == root_p2c.end()) {
      TORCH_INTERNAL_ASSERT(
          id->isReduction() || id->isStride(),
          "Invalid broadcast op: ",
          id,
          ". Non-reduction input dim doesn't match to output.");
    }
  }

  std::unordered_set<IterDomain*> c_mapped;
  for (auto pair_entry : root_p2c) {
    c_mapped.insert(pair_entry.second);
  }

  for (const auto i : c10::irange(c_root.size())) {
    const auto c_id = c_root[i];
    if (c_mapped.find(c_id) != c_mapped.end()) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        c_id->isBroadcast() && is_broadcast_dims_[i],
        "Invalid broadcast op: ",
        c_id,
        ". Non-broadcasted output dim isn't matched from input.");
  }
}

BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      is_broadcast_dims_(src->is_broadcast_dims_) {}

Expr* BroadcastOp::shallowCopy() const {
  auto result = IrBuilder::create<BroadcastOp>(out_, in_, is_broadcast_dims_);
  result->copyPredicatesFrom(this);
  return result;
}

bool BroadcastOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<BroadcastOp>()) {
    return false;
  }
  const auto other_op = other->as<BroadcastOp>();
  if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) {
    return false;
  }
  return Expr::sameAs(other);
}

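// ReductionOp reduces `in` into `out` with the binary op
// reduction_op_type, seeded with the constant init value. is_allreduce
// marks the fused (cross-thread/block) allreduce variant.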
ReductionOp::ReductionOp(
    IrBuilderPasskey passkey,
    BinaryOpType reduction_op_type,
    Val* init,
    Val* out,
    Val* in,
    bool is_allreduce,
    ExprType expr_type)
    : Expr(passkey, expr_type),
      reduction_op_type_(reduction_op_type),
      init_(init),
      out_(out),
      in_(in),
      is_allreduce_(is_allreduce) {
  TORCH_CHECK(
      out->getValType().value() == ValType::TensorView ||
      out->getValType().value() == ValType::TensorIndex);

  TORCH_INTERNAL_ASSERT(
      (in->getValType() == ValType::TensorView &&
       out->getValType() == ValType::TensorView) ||
          (in->getValType() == ValType::TensorIndex &&
           out->getValType() == ValType::TensorIndex),
      "Reduction operation was created that does not have tensor inputs and outputs.");

  if (in->isA<TensorView>()) {
    TORCH_INTERNAL_ASSERT(
        TensorDomain::noReductions(
            in->as<TensorView>()->getMaybeRFactorDomain())
                .size() == out->as<TensorView>()->getRootDomain().size(),
        "Reduction operation created with mismatched domains.");
  }
  TORCH_INTERNAL_ASSERT(
      init->isConstScalar(),
      "Tried to create a reduction operation with an initial value that isn't a constant.");

  addOutput(out);
  addInput(in);
}

ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      reduction_op_type_(src->reduction_op_type_),
      init_(ir_cloner->clone(src->init_)),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      is_allreduce_(src->is_allreduce_) {}

Expr* ReductionOp::shallowCopy() const {
  auto result = IrBuilder::create<ReductionOp>(
      reduction_op_type_, init_, out_, in_, is_allreduce_, etype());
  result->copyPredicatesFrom(this);
  return result;
}

bool ReductionOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<ReductionOp>()) {
    return false;
  }
  const auto other_op = other->as<ReductionOp>();
  // Note that init is not part of the input vals, so it must be checked
  // separately.
  return (
      Expr::sameAs(other) &&
      getReductionOpType() == other_op->getReductionOpType() &&
      init()->sameAs(other_op->init()));
}

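// GroupedReductionOp horizontally fuses multiple independent reductions
// into one expression; the i-th output, input, and init value describe the
// i-th grouped reduction.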
GroupedReductionOp::GroupedReductionOp(
    IrBuilderPasskey passkey,
    std::vector<BinaryOpType> reduction_op_types,
    std::vector<Val*> init_vals,
    std::vector<Val*> outputs,
    std::vector<Val*> inputs,
    bool is_fused,
    ExprType expr_type)
    : Expr(passkey, expr_type),
      reduction_op_types_(std::move(reduction_op_types)),
      init_vals_(std::move(init_vals)),
      is_allreduce_(is_fused) {
  for (auto out : outputs) {
    addOutput(out);
  }

  for (auto in : inputs) {
    addInput(in);
  }
}

GroupedReductionOp::GroupedReductionOp(
    const GroupedReductionOp* src,
    IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      reduction_op_types_(src->reduction_op_types_),
      init_vals_(ir_cloner->clone(src->init_vals_)),
      is_allreduce_(src->is_allreduce_) {}

Expr* GroupedReductionOp::shallowCopy() const {
  auto result = IrBuilder::create<GroupedReductionOp>(
      reduction_op_types_,
      init_vals_,
      outputs(),
      inputs(),
      is_allreduce_,
      etype());
  result->copyPredicatesFrom(this);
  return result;
}

int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const {
  auto it = std::find(outputs().begin(), outputs().end(), output_val);
  if (it != outputs().end()) {
    return std::distance(outputs().begin(), it);
  }

  TORCH_INTERNAL_ASSERT(
      false, "Not an output, ", output_val->toString(), ", of ", toString());
}

bool GroupedReductionOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }

  auto grouped_rop = dynamic_cast<const GroupedReductionOp*>(other);
  if (grouped_rop == nullptr) {
    return false;
  }

  if (!Expr::sameAs(other) ||
      getReductionOpTypes() != grouped_rop->getReductionOpTypes()) {
    return false;
  }

  for (const auto i : c10::irange(numExprs())) {
    if (!initVal(i)->sameAs(grouped_rop->initVal(i))) {
      return false;
    }
  }

  return true;
}

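// WelfordOp computes mean and variance in one pass using Welford's online
// algorithm. Each operand is an (avg, var, N) triplet: the running
// average, the running (unnormalized) variance accumulator, and the
// element count.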
WelfordOp::WelfordOp(
    IrBuilderPasskey passkey,
    const WelfordTriplet& output,
    const WelfordTriplet& input,
    const WelfordTriplet& init,
    bool is_fused)
    : Expr(passkey, ExprType::WelfordOp),
      output_(output),
      input_(input),
      init_(init),
      is_allreduce_(is_fused) {
  // Previously, nullptr was accepted and implicitly replaced by
  // default values. Looks like we always pass some non-null values,
  // so removed the implicit default behavior for code simplicity.
  TORCH_INTERNAL_ASSERT(output.avg() != nullptr);
  TORCH_INTERNAL_ASSERT(output.var() != nullptr);
  TORCH_INTERNAL_ASSERT(output.N() != nullptr);
  TORCH_INTERNAL_ASSERT(init.avg() != nullptr);
  TORCH_INTERNAL_ASSERT(init.var() != nullptr);
  TORCH_INTERNAL_ASSERT(init.N() != nullptr);
  TORCH_INTERNAL_ASSERT(input.avg() != nullptr);
  TORCH_INTERNAL_ASSERT(input.var() != nullptr);
  TORCH_INTERNAL_ASSERT(input.N() != nullptr);

  // Check output type
  TORCH_INTERNAL_ASSERT(
      output.avg()->getValType().value() == ValType::TensorView ||
      output.avg()->getValType().value() == ValType::TensorIndex);
  TORCH_INTERNAL_ASSERT(
      output.var()->getValType().value() == ValType::TensorView ||
      output.var()->getValType().value() == ValType::TensorIndex);
  TORCH_INTERNAL_ASSERT(
      output.N()->getValType().value() == ValType::TensorView ||
      output.N()->getValType().value() == ValType::TensorIndex);
  TORCH_INTERNAL_ASSERT(isIntegralType(output.N()->dtype()));

  // Check initial value
  TORCH_INTERNAL_ASSERT(init.N()->getValType().value() == ValType::Scalar);
  TORCH_INTERNAL_ASSERT(isIntegralType(init.N()->dtype()));
  if (!init.N()->isZeroInt()) {
    // When the initial count is zero, no initial variance or average is
    // needed. An initial value with a count of 1 is uncommon enough that
    // we push the responsibility of creating all-zero var tensors to the
    // user.
    TORCH_INTERNAL_ASSERT(
        init_.avg()->getValType().value() == ValType::TensorView ||
        init_.avg()->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(
        init_.var()->getValType().value() == ValType::TensorView ||
            init_.var()->getValType().value() == ValType::TensorIndex,
        "Invalid initial var: ",
        init_.var()->toString());
  }

  // Check input
  TORCH_INTERNAL_ASSERT(
      input_.avg()->getValType().value() == ValType::TensorView ||
          input_.avg()->getValType().value() == ValType::TensorIndex,
      input_.avg()->getValType().value());
  TORCH_INTERNAL_ASSERT(
      input_.N()->getValType().value() == ValType::Scalar ||
      input_.N()->getValType().value() == ValType::TensorView ||
      input_.N()->getValType().value() == ValType::TensorIndex);
  TORCH_INTERNAL_ASSERT(isIntegralType(input_.N()->dtype()));
  if (!input_.N()->isOneInt()) {
    TORCH_INTERNAL_ASSERT(
        input_.var()->getValType().value() == ValType::TensorView ||
        input_.var()->getValType().value() == ValType::TensorIndex);
  } else {
    // When the input is a single value, only the value is required through
    // the avg input; the var part is implicitly 0 and codegen will handle
    // that.
    TORCH_INTERNAL_ASSERT(
        input_.var() == nullptr || input_.var()->isZeroInt(),
        "Invalid var input, which must be either nullptr or scalar zero when the N input is one.");
  }

  addOutput(output_.avg());
  addOutput(output_.var());
  addOutput(output_.N());

  addInput(input_.avg());
  addInput(input_.var());
  addInput(input_.N());
}

c10::optional<WelfordTriplet::ValName> WelfordTriplet::getNameOf(
    Val* val) const {
  auto it = std::find(begin(), end(), val);
  if (it != end()) {
    return indexToValName(std::distance(begin(), it));
  }

  return c10::optional<WelfordTriplet::ValName>();
}

bool WelfordTriplet::sameAs(const WelfordTriplet& other) const {
  return this == &other ||
      (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) &&
       N()->sameAs(other.N()));
}

WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const {
  return transform([&](const Val* val) { return ir_cloner->clone<Val>(val); });
}

std::vector<WelfordTriplet> WelfordTriplet::clone(
    const std::vector<WelfordTriplet>& src,
    IrCloner* ir_cloner) {
  std::vector<WelfordTriplet> cloned;
  for (const auto& triplet : src) {
    cloned.emplace_back(triplet.clone(ir_cloner));
  }
  return cloned;
}

WelfordOp::WelfordOp(
    IrBuilderPasskey passkey,
    Val* out_avg,
    Val* out_var,
    Val* out_N,
    Val* in_avg,
    Val* in_var,
    Val* in_N,
    Val* init_avg,
    Val* init_var,
    Val* init_N,
    bool is_fused)
    : WelfordOp(
          passkey,
          WelfordTriplet(out_avg, out_var, out_N),
          WelfordTriplet(in_avg, in_var, in_N),
          WelfordTriplet(init_avg, init_var, init_N),
          is_fused) {}

WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      output_(src->output_.clone(ir_cloner)),
      input_(src->input_.clone(ir_cloner)),
      init_(src->init_.clone(ir_cloner)),
      is_allreduce_(src->is_allreduce_) {}

Expr* WelfordOp::shallowCopy() const {
  auto result =
      IrBuilder::create<WelfordOp>(output_, input_, init_, is_allreduce_);
  result->copyPredicatesFrom(this);
  return result;
}

Val* WelfordOp::getInitValOfOutput(Val* output_val) const {
  auto val_name = output().getNameOf(output_val);

  TORCH_INTERNAL_ASSERT(
      val_name.has_value(),
      "Not an output val ",
      output_val->toString(),
      " of ",
      toString());

  return init().get(*val_name);
}

bool WelfordOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (auto other_wop = dynamic_cast<const WelfordOp*>(other)) {
    return input_.sameAs(other_wop->input_) && init_.sameAs(other_wop->init_);
  }
  return false;
}

std::vector<Val*> WelfordOp::getInitVals() const {
  std::vector<Val*> init_vals({init_.avg(), init_.var(), init_.N()});
  return init_vals;
}

GroupedWelfordOp::GroupedWelfordOp(
    IrBuilderPasskey passkey,
    std::vector<WelfordTriplet> output_vals,
    std::vector<WelfordTriplet> input_vals,
    std::vector<WelfordTriplet> init_vals,
    bool is_allreduce,
    ExprType expr_type)
    : Expr(passkey, expr_type),
      output_vals_(std::move(output_vals)),
      input_vals_(std::move(input_vals)),
      init_vals_(std::move(init_vals)),
      is_allreduce_(is_allreduce) {
  const auto num_grouped_ops = output_vals_.size();

  TORCH_INTERNAL_ASSERT(
      input_vals_.size() == num_grouped_ops,
      "Invalid number of input arguments. Expected: ",
      num_grouped_ops,
      ", Given: ",
      input_vals_.size());
  TORCH_INTERNAL_ASSERT(
      init_vals_.size() == num_grouped_ops,
      "Invalid number of init arguments. Expected: ",
      num_grouped_ops,
      ", Given: ",
      init_vals_.size());

  for (const auto i : c10::irange(num_grouped_ops)) {
    // Check output type
    TORCH_INTERNAL_ASSERT(
        output_vals_[i].avg()->getValType().value() == ValType::TensorView ||
        output_vals_[i].avg()->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(
        output_vals_[i].var()->getValType().value() == ValType::TensorView ||
        output_vals_[i].var()->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(
        output_vals_[i].N()->getValType().value() == ValType::TensorView ||
        output_vals_[i].N()->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(isIntegralType(output_vals_[i].N()->dtype()));

    // Check initial value
    auto init_avg = init_vals_[i].avg();
    auto init_var = init_vals_[i].var();
    auto init_N = init_vals_[i].N();
    TORCH_INTERNAL_ASSERT(
        init_avg != nullptr && init_var != nullptr && init_N != nullptr,
        "nullptr init vals are not allowed");
    TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar);
    TORCH_INTERNAL_ASSERT(isIntegralType(init_N->dtype()));
    TORCH_INTERNAL_ASSERT(
        init_avg->getValType().value() == ValType::TensorView ||
            init_avg->getValType().value() == ValType::TensorIndex ||
            (init_N->isZeroInt() &&
             init_avg->getValType().value() == ValType::Scalar),
        "Initial avg must be a tensor, or can be a scalar if initial N is zero.",
        " Initial avg: ",
        init_avg->toString(),
        ". Initial N: ",
        init_N->toString());
    TORCH_INTERNAL_ASSERT(
        init_var->getValType().value() == ValType::TensorView ||
            init_var->getValType().value() == ValType::TensorIndex ||
            (init_N->isZeroInt() &&
             init_var->getValType().value() == ValType::Scalar),
        "Initial var must be a tensor, or can be a scalar if initial N is zero: ",
        init_var->toString());

    // Check input
    auto in_avg = input_vals_[i].avg();
    auto in_var = input_vals_[i].var();
    auto in_N = input_vals_[i].N();
    TORCH_INTERNAL_ASSERT(
        in_avg != nullptr && in_var != nullptr && in_N != nullptr,
        "nullptr input vals are not allowed");
    TORCH_INTERNAL_ASSERT(
        in_N->getValType().value() == ValType::Scalar ||
        in_N->getValType().value() == ValType::TensorView ||
        in_N->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(isIntegralType(in_N->dtype()));
    TORCH_INTERNAL_ASSERT(
        in_avg->getValType().value() == ValType::TensorView ||
            in_avg->getValType().value() == ValType::TensorIndex,
        "Invalid input avg argument type: ",
        in_avg->getValType().value());

    if (in_N->isOneInt()) {
      // When the input is a single value, only the value is required
      // through the avg input; the var part must be implicitly 0.
      TORCH_INTERNAL_ASSERT(
          in_var->isZeroInt(),
          "Invalid var input, which must be scalar zero when the N input is one: ",
          in_var->toString());
    } else {
      TORCH_INTERNAL_ASSERT(
          in_var->getValType().value() == ValType::TensorView ||
              in_var->getValType().value() == ValType::TensorIndex,
          in_var->getValType().value(),
          ", ",
          in_N->toString());
    }
  }

  for (const auto i : c10::irange(num_grouped_ops)) {
    addOutput(output_vals_[i].avg());
    addOutput(output_vals_[i].var());
    addOutput(output_vals_[i].N());
    addInput(input_vals_[i].avg());
    addInput(input_vals_[i].var());
    addInput(input_vals_[i].N());
  }
}

GroupedWelfordOp::GroupedWelfordOp(
    const GroupedWelfordOp* src,
    IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      output_vals_(WelfordTriplet::clone(src->output_vals_, ir_cloner)),
      input_vals_(WelfordTriplet::clone(src->input_vals_, ir_cloner)),
      init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)),
      is_allreduce_(src->is_allreduce_) {}

Expr* GroupedWelfordOp::shallowCopy() const {
  auto result = IrBuilder::create<GroupedWelfordOp>(
      output_vals_, input_vals_, init_vals_, is_allreduce_, etype());
  result->copyPredicatesFrom(this);
  return result;
}

bool GroupedWelfordOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }

  auto grouped_op = dynamic_cast<const GroupedWelfordOp*>(other);
  if (grouped_op == nullptr) {
    return false;
  }

  if (!Expr::sameAs(other)) {
    return false;
  }

  for (const auto i : c10::irange(numExprs())) {
    if (!initAvg(i)->sameAs(grouped_op->initAvg(i)) ||
        !initVar(i)->sameAs(grouped_op->initVar(i)) ||
        !initN(i)->sameAs(grouped_op->initN(i))) {
      return false;
    }
  }

  return true;
}

int GroupedWelfordOp::getExprIndexOfOutput(Val* output_val) const {
  for (const auto expr_idx : c10::irange(numExprs())) {
    if (outputVals().at(expr_idx).getNameOf(output_val).has_value()) {
      return expr_idx;
    }
  }

  TORCH_INTERNAL_ASSERT(
      false, "Not an output, ", output_val->toString(), ", of ", toString());
}

Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const {
  auto expr_index = getExprIndexOfOutput(output_val);

  auto val_name = outputVals().at(expr_index).getNameOf(output_val).value();

  return initVals().at(expr_index).get(val_name);
}

MmaOp::MmaOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in_a,
    Val* in_b,
    Val* init)
    : Expr(passkey, ExprType::MmaOp),
      out_(out),
      in_a_(in_a),
      in_b_(in_b),
      init_(init) {
  // Check output type
  TORCH_INTERNAL_ASSERT(
      out->getValType().value() == ValType::TensorView ||
      out->getValType().value() == ValType::TensorIndex);

  TORCH_INTERNAL_ASSERT(
      in_a->getValType().value() == ValType::TensorView ||
          in_a->getValType().value() == ValType::TensorIndex,
      in_a->getValType().value());

  TORCH_INTERNAL_ASSERT(
      in_b->getValType().value() == ValType::TensorView ||
          in_b->getValType().value() == ValType::TensorIndex,
      in_b->getValType().value());

  addOutput(out);
  addInput(in_a);
  addInput(in_b);
}

MmaOp::MmaOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in_a,
    Val* in_b,
    Val* init,
    OptionsInMma options)
    : MmaOp(passkey, out, in_a, in_b, init) {
  options_ = options;
}

MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_a_(ir_cloner->clone(src->in_a_)),
      in_b_(ir_cloner->clone(src->in_b_)),
      init_(ir_cloner->clone(src->init_)),
      options_(src->options_) {}

Expr* MmaOp::shallowCopy() const {
  auto result = IrBuilder::create<MmaOp>(out_, in_a_, in_b_, init_);
  result->options_ = options_;
  result->copyPredicatesFrom(this);
  return result;
}

bool MmaOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (auto other_mma = dynamic_cast<const MmaOp*>(other)) {
    return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) &&
        in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) &&
        options_ == other_mma->options_;
  }
  return false;
}

TransposeOp::TransposeOp(
    IrBuilderPasskey passkey,
    TensorView* out,
    TensorView* in,
    std::vector<int64_t> new2old)
    : Expr(passkey, ExprType::TransposeOp),
      out_(out),
      in_(in),
      new2old_(std::move(new2old)) {
  // Sanity check of the input parameters. Maybe not necessary as they
  // should already be checked by the transpose function.

  TORCH_INTERNAL_ASSERT(
      TensorDomain::noReductions(in->getMaybeRFactorDomain()).size() ==
      out->getMaybeRFactorDomain().size());

  TORCH_INTERNAL_ASSERT(
      new2old_.size() == out->getMaybeRFactorDomain().size());

  // Make sure the entries of new2old are unique and range from 0 to
  // N-1, where N == new2old.size().
  std::set<int64_t> old_positions(new2old_.begin(), new2old_.end());
  TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size());
  // old_positions is sorted, so the first entry must be 0.
  TORCH_INTERNAL_ASSERT(
      *(old_positions.begin()) == 0,
      "Invalid new2old vector detected: ",
      new2old_);
  // The last entry must be N-1, since old_positions is sorted, starts
  // with 0, and its length is N.
  TORCH_INTERNAL_ASSERT(
      *(old_positions.rbegin()) == (int)(new2old_.size() - 1),
      "Invalid new2old vector detected: ",
      new2old_);

  addOutput(out);
  addInput(in);
}

TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      new2old_(src->new2old_) {}

Expr* TransposeOp::shallowCopy() const {
  auto result = IrBuilder::create<TransposeOp>(out_, in_, new2old_);
  result->copyPredicatesFrom(this);
  return result;
}

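// Returns the inverse permutation of new2old_. For example, if
// new2old = {2, 0, 1} (new axis 0 reads from old axis 2, and so on), the
// result is old2new = {1, 2, 0}.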
std::vector<int64_t> TransposeOp::old2new() const {
  std::vector<int64_t> old2new(new2old_.size());
  for (auto new_axis : c10::irange(new2old_.size())) {
    auto old_axis = new2old_.at(new_axis);
    old2new[old_axis] = new_axis;
  }
  return old2new;
}

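// ExpandOp materializes a broadcast to explicit extents (analogous to
// Tensor::expand); each expanded extent is an Int input to the op.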
ExpandOp::ExpandOp(
    IrBuilderPasskey passkey,
    TensorView* out,
    TensorView* in,
    std::vector<Val*> _expanded_extents)
    : Expr(passkey, ExprType::ExpandOp),
      out_(out),
      in_(in),
      expanded_extents_(std::move(_expanded_extents)) {
  addOutput(out);
  addInput(in);
  for (auto expanded_extent : expanded_extents_) {
    TORCH_INTERNAL_ASSERT(expanded_extent != nullptr);
    TORCH_INTERNAL_ASSERT(
        expanded_extent->dtype() == DataType::Int,
        "Expanded extents must be of Int type.");
    addInput(expanded_extent);
  }
}

ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)) {
  expanded_extents_.reserve(src->expanded_extents_.size());
  for (const auto expanded_extent : src->expanded_extents_) {
    expanded_extents_.push_back(ir_cloner->clone(expanded_extent));
  }
}

Expr* ExpandOp::shallowCopy() const {
  auto result = IrBuilder::create<ExpandOp>(out_, in_, expanded_extents_);
  result->copyPredicatesFrom(this);
  return result;
}

ShiftOp::ShiftOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in,
    std::vector<int> offsets,
    std::vector<int> pad_width)
    : Expr(passkey, ExprType::ShiftOp),
      out_(out),
      in_(in),
      offsets_(std::move(offsets)),
      pad_width_(std::move(pad_width)) {
  // clang-tidy complains that out_ may be null.
  TORCH_INTERNAL_ASSERT(out_ != nullptr);
  TORCH_INTERNAL_ASSERT(in_ != nullptr);

  auto out_type = out->getValType().value();
  auto in_type = in->getValType().value();

  TORCH_INTERNAL_ASSERT(
      out_type == ValType::TensorView && in_type == ValType::TensorView,
      "Cannot shift a non-tensor object.");

  TORCH_INTERNAL_ASSERT(
      offsets_.size() ==
          TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
              .size(),
      "Invalid offset vector: ",
      offsets_);

  TORCH_INTERNAL_ASSERT(
      pad_width_.size() ==
          TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
              .size(),
      "Invalid padding width vector: ",
      pad_width_);

  addOutput(out);
  addInput(in);
}

ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      offsets_(src->offsets_),
      pad_width_(src->pad_width_) {}

Expr* ShiftOp::shallowCopy() const {
  auto result = IrBuilder::create<ShiftOp>(out_, in_, offsets_, pad_width_);
  result->copyPredicatesFrom(this);
  return result;
}

bool ShiftOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<ShiftOp>()) {
    return false;
  }
  const auto other_op = other->as<ShiftOp>();
  if (offsets() != other_op->offsets()) {
    return false;
  }
  return Expr::sameAs(other);
}

GatherOp::GatherOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in,
    std::vector<int> window_shape,
    std::vector<std::vector<int>> pad_width)
    : Expr(passkey, ExprType::GatherOp),
      out_(out),
      in_(in),
      window_shape_(std::move(window_shape)),
      pad_width_(std::move(pad_width)) {
  // clang-tidy complains that out_ may be null.
  TORCH_INTERNAL_ASSERT(out_ != nullptr);
  TORCH_INTERNAL_ASSERT(in_ != nullptr);

  auto out_type = out->getValType().value();
  auto in_type = in->getValType().value();

  TORCH_INTERNAL_ASSERT(
      out_type == ValType::TensorView && in_type == ValType::TensorView,
      "Cannot gather a non-tensor object.");

  const auto ndims =
      TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain()).size();

  TORCH_INTERNAL_ASSERT(
      window_shape_.size() == ndims,
      "Invalid window_shape vector: ",
      window_shape_);
  TORCH_INTERNAL_ASSERT(
      pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_);

  for (const auto& pad : pad_width_) {
    TORCH_INTERNAL_ASSERT(
        pad.size() == 2, "Padding size for each axis must have two Int vals.");
  }

  addOutput(out);
  addInput(in);
}

GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      window_shape_(src->window_shape_),
      pad_width_(src->pad_width_) {}

Expr* GatherOp::shallowCopy() const {
  auto result =
      IrBuilder::create<GatherOp>(out_, in_, window_shape_, pad_width_);
  result->copyPredicatesFrom(this);
  return result;
}

bool GatherOp::sameAs(const Statement* other) const {
  if (this == other) {
    return true;
  }
  if (!other->isA<GatherOp>()) {
    return false;
  }
  const auto other_op = other->as<GatherOp>();
  if (windowShape() != other_op->windowShape() ||
      padWidth() != other_op->padWidth()) {
    return false;
  }
  return Expr::sameAs(other);
}

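// Returns the position of the window axis corresponding to the given input
// axis. The output root domain of a GatherOp is the input dims followed by
// the window dims, so with a 2-D input, gatherAxis(1) == 3.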
int GatherOp::gatherAxis(int axis) const {
  if (axis < 0) {
    axis += out()->as<TensorView>()->nDims();
  }
  TORCH_INTERNAL_ASSERT(
      axis >= 0 && axis < (int)windowShape().size(), "Invalid axis: ", axis);
  return int(windowShape().size()) + axis;
}

ViewAsScalar::ViewAsScalar(
    IrBuilderPasskey passkey,
    Val* out,
    Val* in,
    IterDomain* vector_id,
    Val* index)
    : Expr(passkey, ExprType::ViewAsScalar),
      out_(out),
      in_(in),
      vector_id_(vector_id),
      index_(index) {
  addOutput(out);
  addInput(in);
}

ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)),
      vector_id_(ir_cloner->clone(src->vector_id_)),
      index_(ir_cloner->clone(src->index_)) {}

Expr* ViewAsScalar::shallowCopy() const {
  auto result = IrBuilder::create<ViewAsScalar>(out_, in_, vector_id_, index_);
  result->copyPredicatesFrom(this);
  return result;
}

ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
    : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) {
  addOutput(out);
  addInput(in);
}

ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)) {}

Expr* ViewOp::shallowCopy() const {
  auto result = IrBuilder::create<ViewOp>(out_, in_);
  result->copyPredicatesFrom(this);
  return result;
}

LoadStoreOp::LoadStoreOp(
    IrBuilderPasskey passkey,
    LoadStoreOpType op_type,
    Val* out,
    Val* in)
    : Expr(passkey, ExprType::LoadStoreOp),
      load_store_type_(op_type),
      out_(out),
      in_(in) {
  addOutput(out);
  addInput(in);
}

LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner)
    : Expr(src, ir_cloner),
      load_store_type_(src->load_store_type_),
      out_(ir_cloner->clone(src->out_)),
      in_(ir_cloner->clone(src->in_)) {}

Expr* LoadStoreOp::shallowCopy() const {
  auto result = IrBuilder::create<LoadStoreOp>(load_store_type_, out_, in_);
  result->copyPredicatesFrom(this);
  return result;
}

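// IterDomainBuilder accumulates IterDomain parameters and creates the
// domain on build(). A minimal usage sketch (assuming `zero` and `n` are
// Int* vals owned by the current container):
//
//   IterDomain* id = IterDomainBuilder(zero, n)
//                        .iter_type(IterType::Iteration)
//                        .parallel_type(ParallelType::Serial)
//                        .build();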
IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent)
    : start_(_start), extent_(_extent) {
  TORCH_INTERNAL_ASSERT(
      start_ != nullptr && extent_ != nullptr,
      "Start and extent are required to build an iter domain.");
}

IterDomainBuilder::IterDomainBuilder(const IterDomain* id)
    : start_(id->start()),
      extent_(id->extent()),
      expanded_extent_(
          id->hasExpandedExtent() ? id->expandedExtent() : nullptr),
      stop_offset_(id->stopOffset()),
      parallel_type_(id->getParallelType()),
      iter_type_(id->getIterType()),
      is_rfactor_domain_(id->isRFactorProduct()),
      is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
      padded_to_size_(id->getMaybeSizeAfterPadding()),
      is_mma_swizzled_(id->isMmaSwizzled()) {}

IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() {
  parallel_type_ = ParallelType::Serial;
  is_rfactor_domain_ = false;
  is_padded_dimension_ = false;
  padded_to_size_ = c10::nullopt;
  is_mma_swizzled_ = false;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::resetRfactor() {
  return is_rfactor_domain(false);
}

IterDomainBuilder& IterDomainBuilder::start(Val* _start) {
  start_ = _start;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) {
  extent_ = _extent;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) {
  expanded_extent_ = _expanded_extent;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) {
  stop_offset_ = _stop_offset;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::parallel_type(
    ParallelType _parallel_type) {
  parallel_type_ = _parallel_type;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) {
  iter_type_ = _iter_type;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::is_rfactor_domain(
    bool _is_rfactor_domain) {
  is_rfactor_domain_ = _is_rfactor_domain;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::is_padded_dimension(
    bool _is_padded_dimension) {
  is_padded_dimension_ = _is_padded_dimension;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::padded_to_size(
    c10::optional<int64_t> _padded_to_size) {
  padded_to_size_ = _padded_to_size;
  return *this;
}

IterDomainBuilder& IterDomainBuilder::is_mma_swizzled(bool _is_mma_swizzled) {
  is_mma_swizzled_ = _is_mma_swizzled;
  return *this;
}

IterDomain* IterDomainBuilder::build() const {
  TORCH_INTERNAL_ASSERT(
      start_ != nullptr && extent_ != nullptr,
      "Start and extent are required to build an iter domain.");
  return IrBuilder::create<IterDomain>(start_->container(), *this);
}

IterDomain::IterDomain(
    IrBuilderPasskey passkey,
    Val* start,
    Val* extent,
    Val* expanded_extent,
    Val* stop_offset,
    ParallelType parallel_type,
    IterType iter_type,
    bool is_rfactor_domain,
    bool is_padded_dimension,
    c10::optional<int64_t> padded_to_size,
    bool is_mma_swizzled)
    : Val(passkey, ValType::IterDomain, DataType::Int),
      start_(start),
      extent_(extent),
      expanded_extent_(expanded_extent),
      stop_offset_(
          stop_offset == nullptr ? passkey.ir_container_->zeroVal()
                                 : stop_offset),
      parallel_type_(parallel_type),
      iter_type_(iter_type),
      is_rfactor_domain_(is_rfactor_domain),
      is_padded_dimension_(is_padded_dimension),
      padded_to_size_(padded_to_size),
      is_mma_swizzled_(is_mma_swizzled) {
  TORCH_CHECK(
      !(isRFactorProduct() && isBroadcast()),
      "IterDomain cannot be both a broadcast and rfactor domain.");

  TORCH_INTERNAL_ASSERT(
      extent->isAnInt(),
      "Cannot create an iter domain over an extent that is not an int but received ",
      extent,
      " .");

  TORCH_INTERNAL_ASSERT(
      start->isAnInt(),
      "Cannot create an iter domain with a start that is not an int but received ",
      start,
      " .");
}

IterDomain::IterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args)
    : IterDomain(
          passkey,
          args.start_,
          args.extent_,
          args.expanded_extent_,
          args.stop_offset_,
          args.parallel_type_,
          args.iter_type_,
          args.is_rfactor_domain_,
          args.is_padded_dimension_,
          args.padded_to_size_,
          args.is_mma_swizzled_) {}

IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
    : Val(src, ir_cloner),
      start_(ir_cloner->clone(src->start_)),
      extent_(ir_cloner->clone(src->extent_)),
      expanded_extent_(
          src->hasExpandedExtent() ? ir_cloner->clone(src->expandedExtent())
                                   : nullptr),
      stop_offset_(ir_cloner->clone(src->stop_offset_)),
      parallel_type_(src->parallel_type_),
      iter_type_(src->iter_type_),
      is_rfactor_domain_(src->is_rfactor_domain_),
      is_padded_dimension_(src->is_padded_dimension_),
      padded_to_size_(src->padded_to_size_),
      is_mma_swizzled_(src->is_mma_swizzled_) {}

bool IterDomain::sameAs(const Statement* other) const {
  if (other == this) {
    return true;
  }

  if (!other->isA<IterDomain>()) {
    return false;
  }

  const IterDomain* other_id = other->as<IterDomain>();

  bool is_same = isReduction() == other_id->isReduction() &&
      getParallelType() == other_id->getParallelType() &&
      isVectorComponent() == other_id->isVectorComponent();
  is_same = is_same && ScalarCheck::sameAs(extent(), other_id->extent());
  is_same = is_same && ScalarCheck::sameAs(start(), other_id->start());
  is_same =
      is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset());
  is_same = is_same && (hasExpandedExtent() == other_id->hasExpandedExtent());
  if (is_same && hasExpandedExtent()) {
    is_same = ScalarCheck::sameAs(expandedExtent(), other_id->expandedExtent());
  }

  return is_same;
}

// Returns a new IterDomain matching the properties of this one, except for
// is_rfactor_domain_.
IterDomain* IterDomain::cloneWithoutRFactor() const {
  auto cloned = IterDomainBuilder(this).resetRfactor().build();

  return cloned;
}

bool IterDomain::isTrivialReduction() const {
  if (!isReduction()) {
    return false;
  }

  if (extent()->isOneInt()) {
    return true;
  }

  // If this domain is an output of an expression, i.e., not a root
  // domain, check if all root domains are trivial reductions. This is
  // almost the same as the analysis done in TrivialReductionInfo, but
  // is limited to a single tensor, whereas TrivialReductionInfo does
  // more expensive analysis, potentially traversing through rfactor
  // domains.
  if (definition()) {
    // Note: There's no const version of IterVisitor.
    auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this));
    if (std::all_of(
            ir_utils::filterByType<IterDomain>(id_inputs).begin(),
            ir_utils::filterByType<IterDomain>(id_inputs).end(),
            [](IterDomain* root_id) {
              return root_id->isReduction() && root_id->extent()->isOneInt();
            })) {
      return true;
    }
  }

  return false;
}

std::vector<IterDomain*> IterDomain::clone(
    const std::vector<IterDomain*>& domains) {
  std::vector<IterDomain*> cloned_domains;
  std::transform(
      domains.begin(),
      domains.end(),
      std::back_inserter(cloned_domains),
      [](auto id) { return id->cloneWithoutRFactor(); });
  return cloned_domains;
}

// Merging does not propagate the start and stop values of the input
// domains to the merged output domain. The actual range of the
// domains is enforced by predicates. Note that since only root
// domains have valid start and stop, it's not possible to use
// contiguous predication.
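// For example, merging I0{4} (outer) with I1{8} (inner) yields a single
// domain of extent 4 * 8 = 32, with the inner domain varying fastest.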
1770IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
1771 TORCH_CHECK(
1772 !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
1773 "Merging IterDomains with ending values that are 0 is not supported at this time.");
1774 TORCH_CHECK(
1775 outer->isReduction() == inner->isReduction() ||
1776 (!outer->isReduction() && inner->isTrivialReduction()) ||
1777 (outer->isTrivialReduction() && !inner->isReduction()),
1778 "Merging IterDomains requires that their iteration types match. ",
1779 "Outer: ",
1780 outer->toString(),
1781 ", Inner: ",
1782 inner->toString());
1783 TORCH_CHECK(
1784 (outer->isGather() && inner->isGather()) ||
1785 (!outer->isGather() && !inner->isGather()),
1786 "Merging gather and non-gather domains is not supported.");
1787
1788 TORCH_CHECK(
1789 !outer->isStride() && !inner->isStride(),
1790 "No support for merging stride domains");
1791
1792 Val* merged_id_size = mul(outer->extent(), inner->extent());
1793
1794 IterType itype = outer->getIterType();
1795
1796 if (outer->isBroadcast() && inner->isBroadcast()) {
1797 itype = IterType::Broadcast;
1798 }
1799
1800 if ((outer->isBroadcast() || inner->isBroadcast()) &&
1801 (outer->getIterType() == IterType::Iteration ||
1802 inner->getIterType() == IterType::Iteration)) {
1803 itype = IterType::Iteration;
1804 }
1805
1806 // Merging trivial reduction with iter domain, that's fine, just make it an
1807 // iter domain.
1808 if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
1809 (outer->getIterType() == IterType::Iteration ||
1810 inner->getIterType() == IterType::Iteration)) {
1811 itype = IterType::Iteration;
1812 }
1813
1814 // Merging trivial reduction with broadcasting, that's fine, just make it a
1815 // broadcasting.
1816 if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
1817 (outer->isBroadcast() || inner->isBroadcast())) {
1818 itype = IterType::Broadcast;
1819 }
1820
1821 Val* expanded_extent = nullptr;
1822 if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
1823 if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
1824 expanded_extent = mul(outer->expandedExtent(), inner->expandedExtent());
1825 } else if (outer->hasExpandedExtent() && !inner->hasExpandedExtent()) {
1826 if (inner->isBroadcast()) {
1827 expanded_extent = outer->expandedExtent();
1828 } else {
1829 expanded_extent = mul(outer->expandedExtent(), inner->extent());
1830 }
1831 } else if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
1832 if (outer->isBroadcast()) {
1833 expanded_extent = inner->expandedExtent();
1834 } else {
1835 expanded_extent = mul(outer->extent(), inner->expandedExtent());
1836 }
1837 }
1838 }
1839
1840 IterDomain* merged_id =
1841 IterDomainBuilder(
1842 outer->container()->zeroVal(), merged_id_size->as<Int>())
1843 .parallel_type(outer->getParallelType())
1844 .expanded_extent(expanded_extent)
1845 .iter_type(itype)
1846 .build();
1847
1848 IrBuilder::create<Merge>(outer->container(), merged_id, outer, inner);
1849
1850 return merged_id;
1851}
1852
// Neither the outer nor the inner domain inherits the start and stop
// values of the input, as those values cannot be meaningfully divided
// between the two outputs. The access range is enforced by predicates.
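//
// Illustrative sketch (tv and I0 are hypothetical): an inner split by a
// factor of 4,
//
//   tv->split(0, 4); // [I0] -> [ceilDiv(I0, 4), 4]
//
// produces an outer domain of extent ceilDiv(I0, 4) and an inner domain of
// extent 4; with inner_split = false the factor goes to the outer domain
// instead.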
1856std::pair<IterDomain*, IterDomain*> IterDomain::split(
1857 IterDomain* in,
1858 Val* factor,
1859 bool inner_split,
1860 Val* start_offset,
1861 Val* stop_offset) {
1862 TORCH_CHECK(
1863 !in->extent()->isZeroInt(),
1864 "Splitting IterDomains with ending values that are 0 is not supported at this time.");
1865
1866 TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value ", factor);
1867
1868 if (factor->getValType() == ValType::Scalar) {
1869 TORCH_CHECK(
1870 factor->isConstScalar() ||
1871 (FusionGuard::getCurFusion() == factor->fusion() &&
1872 factor->isFusionInput()),
1873 factor,
        " is neither a constant nor an input. It must be one or the other to be used in a split.",
1875 " If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);");
1876 } else if (factor->getValType() == ValType::NamedScalar) {
1877 TORCH_CHECK(
1878 factor->as<NamedScalar>()->getParallelDim() != c10::nullopt,
1879 "Splitting a dimension by a named scalar is only supported on block or grid dimensions but received ",
1880 factor);
1881 }
1882
1883 // outer loop size
1884 Val* remainder =
1885 ceilDiv(Split::extent(in->extent(), start_offset, stop_offset), factor);
1886 Val* expanded_remainder = nullptr;
1887 if (in->hasExpandedExtent()) {
1888 expanded_remainder = ceilDiv(
1889 Split::extent(in->expandedExtent(), start_offset, stop_offset), factor);
1890 }
1891
1892 if ((start_offset != nullptr && !start_offset->isZeroInt()) ||
1893 (stop_offset != nullptr && !stop_offset->isZeroInt())) {
1894 TORCH_INTERNAL_ASSERT(
1895 in->definition() == nullptr,
1896 "Partial split is only allowed with root domains");
1897 }
1898 // outer loop IterDomain
1899 IterDomain* ido =
1900 IterDomainBuilder(
1901 in->container()->zeroVal(),
1902 inner_split ? remainder->as<Int>() : factor)
1903 .expanded_extent(
1904 in->hasExpandedExtent() && inner_split ? expanded_remainder
1905 : nullptr)
1906 .parallel_type(in->getParallelType())
1907 .iter_type(in->getIterType())
1908 .build();
1909
1910 // inner loop IterDomain
1911 IterDomain* idi =
1912 IterDomainBuilder(
1913 in->container()->zeroVal(),
1914 inner_split ? factor : remainder->as<Int>())
1915 .expanded_extent(
1916 in->hasExpandedExtent() && !inner_split ? expanded_remainder
1917 : nullptr)
1918 .parallel_type(in->getParallelType())
1919 .iter_type(in->getIterType())
1920 .build();
1921
1922 IrBuilder::create<Split>(
1923 in->container(),
1924 ido,
1925 idi,
1926 in,
1927 factor,
1928 inner_split,
1929 start_offset,
1930 stop_offset);
1931 return {ido, idi};
1932}
1933
1934std::pair<IterDomain*, IterDomain*> IterDomain::split(
1935 IterDomain* in,
1936 Val* factor,
1937 bool inner_split,
1938 bool trim_out_of_bounds) {
1939 auto start_offset = trim_out_of_bounds ? in->start() : nullptr;
1940 auto stop_offset = trim_out_of_bounds ? in->stopOffset() : nullptr;
1941 return IterDomain::split(in, factor, inner_split, start_offset, stop_offset);
1942}
1943
1944std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
1945 // Use partial split so that only valid values are retained
1946 auto split_out = IterDomain::split(
1947 this, IrBuilder::create<Int>(container(), factor), true, true);
1948
1949 split_out.second->iter_type_ = IterType::Stride;
1950 split_out.first->is_rfactor_domain_ = true;
1951 split_out.second->is_rfactor_domain_ = true;
1952 return split_out;
1953}
1954
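// Illustrative sketch (tv is hypothetical): swizzling a pair of axes, e.g.
//
//   tv->swizzle(Swizzle2DType::ZShape, 0, 1);
//
// replaces the two input IterDomains with two swizzled outputs connected to
// the inputs by a Swizzle2D expression, as constructed below.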
1955std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
1956 Swizzle2DType swizzle_type,
1957 IterDomain* in_x,
1958 IterDomain* in_y,
1959 SwizzleMode swizzle_mode) {
1960 TORCH_CHECK(
1961 !in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(),
      "Invalid swizzling of an empty dimension.");
1963
1964 // TODO: reduction check on swizzle:
1965 TORCH_CHECK(
1966 !in_x->isReduction() && !in_y->isReduction(),
1967 "swizzled reduction not yet supported");
1968
1969 for (auto input : InputsOf::outputs(in_x->fusion(), {in_x, in_y})) {
1970 TORCH_CHECK(
1971 !input->as<IterDomain>()->isBroadcast(),
1972 "swizzling broadcast axes not yet supported");
1973 }
1974
1975 // TODO: gather and shift check on swizzle
1976 TORCH_INTERNAL_ASSERT(
1977 !in_x->isGather() && !in_y->isGather(),
1978 "Swizzled gather not yet supported");
1979
1980 IterDomain* out_x = IterDomainBuilder(in_x).build();
1981
1982 IterDomain* out_y = IterDomainBuilder(in_y).build();
1983
1984 IrBuilder::create<Swizzle2D>(
1985 in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode);
1986
1987 return std::make_pair(out_x, out_y);
1988}
1989
// TODO: We should change the parallelize interface to be on TensorView, or
// at least vectorize should be done on TensorView. This would let us check
// that we don't vectorize to the left of the computeAt domain, and could
// allow some simple validation of vectorize, as its inputs are rightmost
// and contiguous.
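//
// Illustrative sketch (tv is hypothetical): binding loop axes to parallel
// resources after scheduling,
//
//   tv->axis(0)->parallelize(ParallelType::BIDx);
//   tv->axis(1)->parallelize(ParallelType::TIDx);
//
// maps the outer loop to CUDA blocks and the inner loop to threads.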
1994void IterDomain::parallelize(ParallelType t) {
1995 if (parallel_type_ == t) {
1996 // No op, don't do any more checks, it was already set to this value.
1997 return;
1998 }
1999
2000 if (t == ParallelType::Unroll || isParallelTypeVectorize(t) ||
2001 t == ParallelType::Group) {
2002 TORCH_CHECK(
2003 start()->isZeroInt() && extent()->isConstScalar(),
2004 "Vectorization, unrolling, unswitching and grouping are only supported with start = 0 and extent as a const int, but got ",
2005 "a start of ",
2006 start(),
2007 " and extent ",
2008 extent(),
        ".");
2010 }
2011
2012 if (t == ParallelType::Group) {
2013 TORCH_CHECK(
2014 getIterType() == IterType::Iteration,
        "Grouping an IterDomain of non-Iteration type is not allowed. ",
2016 getIterType());
2017 }
2018
2019 if (isMmaSwizzled()) {
2020 // Mma swizzled axes represent data representation within a warp
2021 // so only allow updates that keep the parallelization within
2022 // a warp.
    // Note && TODO: this check is actually used to allow the indexing path
    // to make copies of the iter domains. We might eventually just want to
    // lock these parallel types and not allow any changes once they are
    // swizzled.
2027 TORCH_CHECK(
2028 t == ParallelType::Vectorize || t == ParallelType::TIDx ||
2029 t == ParallelType::Serial,
        "Parallel types other than Serial, TIDx, and Vectorize are not allowed for mma-swizzled ids");
2031 }
2032
2033 parallel_type_ = t;
2034}
2035
2036bool IterDomain::maybePartial() const {
2037 return !start()->isZeroInt() || !stopOffset()->isZeroInt();
2038}
2039
2040Val* IterDomain::stopOffset() const {
2041 return stop_offset_;
2042}
2043
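// Returns the (exclusive) end of the domain's range: extent() minus
// stopOffset(), or just extent() when the stop offset is zero.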
2044Val* IterDomain::stop() const {
2045 if (stopOffset()->isZeroInt()) {
2046 return extent();
2047 }
2048
2049 return sub(extent(), stopOffset());
2050}
2051
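// A TensorDomain tracks a root domain, an optional rfactor domain, the
// current (scheduled) domain, and per-axis contiguity. When no contiguity
// vector is given, all axes default to non-contiguous (false), as the
// constructors below implement.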
2052TensorDomain::TensorDomain(
2053 IrBuilderPasskey passkey,
2054 std::vector<IterDomain*> root_domain,
2055 std::vector<bool> contiguity)
2056 : Val(passkey, ValType::TensorDomain, DataType::Null),
2057 root_domain_(std::move(root_domain)),
2058 contiguity_(
2059 contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
2060 : std::move(contiguity)) {
2061 TORCH_CHECK(
2062 contiguity_.size() == getMaybeRFactorDomain().size(),
2063 "Invalid contiguity information provided, incorrect size. Received vector of size ",
2064 contiguity_.size(),
2065 " but needed one of size ",
2066 root_domain_.size());
2067
  // Set only to satisfy clang-tidy; the correct value is computed in
  // resetDomains().
2069 has_nontrivial_reduction_ = false;
2070 domain_ = root_domain_;
2071 resetDomains();
2072}
2073
2074TensorDomain::TensorDomain(
2075 IrBuilderPasskey passkey,
2076 std::vector<IterDomain*> root_domain,
2077 std::vector<IterDomain*> domain,
2078 std::vector<bool> contiguity)
2079 : Val(passkey, ValType::TensorDomain, DataType::Null),
2080 root_domain_(std::move(root_domain)),
2081 domain_(std::move(domain)),
2082 contiguity_(
2083 contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
2084 : std::move(contiguity)) {
2085 TORCH_CHECK(
2086 contiguity_.size() == getMaybeRFactorDomain().size(),
2087 "Invalid contiguity information provided, incorrect size. Received vector of size ",
2088 contiguity_.size(),
2089 " but needed one of size ",
2090 root_domain_.size());
2091
2092 std::vector<Val*> domain_vals(domain_.begin(), domain_.end());
2093 auto inps = IterVisitor::getInputsTo(domain_vals);
2094
2095 // Validate that the root domain consists of all inputs to domain
2096 // Uncertain if this will hold for RFactor
2097
2098 std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
2099 std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
2100 TORCH_INTERNAL_ASSERT(
2101 root_vals.find(inp) != root_vals.end(),
2102 "Invalid tensor domain, ",
2103 inp,
2104 " is an input of domain, but it is not found in the root domain.");
2105 });
2106
  // Set only to satisfy clang-tidy; the correct value is computed in
  // resetDomains().
2108 has_nontrivial_reduction_ = false;
2109 resetDomains();
2110}
2111
2112TensorDomain::TensorDomain(
2113 IrBuilderPasskey passkey,
2114 std::vector<IterDomain*> root_domain,
2115 std::vector<IterDomain*> rfactor_domain,
2116 std::vector<IterDomain*> domain,
2117 std::vector<bool> contiguity)
2118 : Val(passkey, ValType::TensorDomain, DataType::Null),
2119 root_domain_(std::move(root_domain)),
2120 domain_(std::move(domain)),
2121 rfactor_domain_(std::move(rfactor_domain)),
2122 contiguity_(
2123 contiguity.empty() ? std::vector<bool>(rfactor_domain_.size(), false)
2124 : std::move(contiguity)) {
2125 TORCH_CHECK(
2126 contiguity_.size() == getMaybeRFactorDomain().size(),
2127 "Invalid contiguity information provided, incorrect size. Received vector of size ",
2128 contiguity_.size(),
2129 " but needed one of size ",
2130 getMaybeRFactorDomain().size());
2131
2132 auto inps = IterVisitor::getInputsTo(
2133 std::vector<Val*>(domain_.begin(), domain_.end()));
2134
2135 // Validate that the root domain consists of all inputs to domain
2136 // Uncertain if this will hold for RFactor
2137
2138 std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
2139 std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
2140 TORCH_INTERNAL_ASSERT(
2141 root_vals.find(inp) != root_vals.end(),
2142 "Invalid tensor domain, ",
2143 inp,
2144 " is an input of domain, but it is not found in the root domain.");
2145 });
2146
2147 inps = IterVisitor::getInputsTo(
2148 std::vector<Val*>(rfactor_domain_.begin(), rfactor_domain_.end()));
2149 std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
2150 TORCH_INTERNAL_ASSERT(
2151 root_vals.find(inp) != root_vals.end(),
2152 "Invalid tensor domain, ",
2153 inp,
2154 " is an input of the rfactor domain, but it is not found in the root domain.");
2155 });
2156
  // Set only to satisfy clang-tidy; the correct value is computed in
  // resetDomains().
2158 has_nontrivial_reduction_ = false;
2159 resetDomains();
2160}
2161
2162TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
2163 : Val(src, ir_cloner),
2164 root_domain_(ir_cloner->clone(src->root_domain_)),
2165 domain_(ir_cloner->clone(src->domain_)),
2166 no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
2167 no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
2168 rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)),
2169 contiguity_(src->contiguity()),
2170 has_nontrivial_reduction_(src->has_nontrivial_reduction_) {}
2171
2172bool TensorDomain::hasBlockBroadcast() const {
2173 return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
2174 return id->isBroadcast() && id->isThreadDim();
2175 });
2176}
2177
2178bool TensorDomain::hasGridBroadcast() const {
2179 return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
2180 return id->isBroadcast() && id->isBlockDim();
2181 });
2182}
2183
2184bool TensorDomain::operator==(const TensorDomain& other) const {
2185 // Checks equality of each class field. Should not be necessary to
2186 // check no_bcast_domain_ and no_reduction_domain_ as they are just
2187 // derived from domain_.
2188 return root_domain_ == other.root_domain_ && domain_ == other.domain_ &&
2189 rfactor_domain_ == other.rfactor_domain_ &&
2190 contiguity_ == other.contiguity_;
2191}
2192
2193bool TensorDomain::sameAs(const Statement* const other) const {
2194 if (this == other) {
2195 return true;
2196 }
2197
2198 if (!other->isA<TensorDomain>()) {
2199 return false;
2200 }
2201
2202 const TensorDomain* other_td = other->as<TensorDomain>();
2203
2204 if (nDims() != other_td->nDims()) {
2205 return false;
2206 }
2207 if (getRootDomain().size() != other_td->getRootDomain().size()) {
2208 return false;
2209 }
2210 if (getRFactorDomain().size() != other_td->getRFactorDomain().size()) {
2211 return false;
2212 }
2213
2214 for (const auto i : c10::irange(nDims())) {
2215 if (!(axis(i)->sameAs(other_td->axis(i)))) {
2216 return false;
2217 }
2218 }
2219
2220 for (const auto i : c10::irange(getRootDomain().size())) {
2221 if (!(getRootDomain()[i]->sameAs(other_td->getRootDomain()[i]))) {
2222 return false;
2223 }
2224 }
2225
2226 for (const auto i : c10::irange(getRFactorDomain().size())) {
2227 if (!(getRFactorDomain()[i]->sameAs(other_td->getRFactorDomain()[i]))) {
2228 return false;
2229 }
2230 }
2231
2232 return true;
2233}
2234
2235bool TensorDomain::sameAs(
2236 const std::vector<IterDomain*>& lhs,
2237 const std::vector<IterDomain*>& rhs) {
2238 if (lhs.size() != rhs.size())
2239 return false;
2240 size_t i = 0;
2241 for (auto td_lhs : lhs) {
2242 if (!td_lhs->sameAs(rhs[i++]))
2243 return false;
2244 }
2245 return true;
2246}
2247
2248void TensorDomain::setContiguity(const std::vector<bool>& contig) {
2249 TORCH_INTERNAL_ASSERT(
2250 getMaybeRFactorDomain().size() == contig.size(),
2251 "Invalid contiguity vector: ",
2252 contig);
2253
2254 contiguity_ = contig;
2255}
2256
2257bool TensorDomain::hasReduction() const {
2258 return has_nontrivial_reduction_;
2259}
2260
2261bool TensorDomain::hasBlockReduction() const {
2262 return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
2263 return id->isReduction() && id->isThreadDim();
2264 });
2265}
2266
2267bool TensorDomain::hasGridReduction() const {
2268 return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
2269 return id->isReduction() && id->isBlockDim();
2270 });
2271}
2272
2273bool TensorDomain::hasBroadcast() const {
2274 return no_bcast_domain_.size() != domain_.size();
2275}
2276
2277bool TensorDomain::hasRFactor() const {
2278 return !rfactor_domain_.empty();
2279}
2280
2281bool TensorDomain::hasViewLikeRFactor() const {
2282 if (!hasRFactor()) {
2283 // Can't have view like rfactor if there is no rfactor domain
2284 return false;
2285 }
2286
  // If there's an rfactor domain and no rfactor product is a reduction,
  // this is a view-like rfactor.
2289 return std::none_of(
2290 getMaybeRFactorDomain().begin(),
2291 getMaybeRFactorDomain().end(),
2292 [](IterDomain* id) {
2293 return id->isReduction() && id->isRFactorProduct();
2294 });
2295}
2296
2297bool TensorDomain::hasVectorize() const {
2298 return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
2299 return id->getParallelType() == ParallelType::Vectorize ||
2300 id->getParallelType() == ParallelType::MisalignedVectorize;
2301 });
2302}
2303
2304c10::optional<unsigned int> TensorDomain::getReductionAxis() const {
2305 auto it = std::find_if(domain_.begin(), domain_.end(), [](const auto& id) {
2306 return id->isReduction();
2307 });
2308 if (it == domain_.end()) {
2309 return c10::optional<unsigned int>();
2310 } else {
2311 return c10::optional<unsigned int>(std::distance(domain_.begin(), it));
2312 }
2313}
2314
// i is an int here because we want to accept negative values, while
// ::size_type may be unsigned.
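// For example, axis(-1) returns the innermost (last) axis of the domain.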
2317IterDomain* TensorDomain::axis(int i) const {
2318 TORCH_INTERNAL_ASSERT(
2319 nDims() > 0, "Tried to access an axis in a 0-dim domain");
2320 if (i < 0)
2321 i += nDims();
2322 TORCH_CHECK(
2323 i >= 0 && (unsigned int)i < nDims(),
2324 "Tried to access axis ",
2325 i,
2326 " in domain ",
2327 this);
2328 return domain_[i];
2329}
2330
2331size_t TensorDomain::posOf(IterDomain* id) const {
2332 TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to find an axis in a 0-dim domain");
2333 size_t i = 0;
2334 while (i < domain_.size()) {
2335 if (domain_[i] == id)
2336 return i;
2337 i++;
2338 }
2339 TORCH_CHECK(false, "Provided id is not part of this domain.");
2340}
2341
2342size_t TensorDomain::rootPosOf(IterDomain* id) const {
2343 TORCH_INTERNAL_ASSERT(
2344 root_domain_.size() > 0, "Tried to find an axis in a 0-dim root domain");
2345 auto it = std::find(root_domain_.begin(), root_domain_.end(), id);
2346 TORCH_INTERNAL_ASSERT(
2347 it != root_domain_.end(), "Provided id is not part of root domain.");
2348 return std::distance(root_domain_.begin(), it);
2349}
2350
2351void TensorDomain::split(
2352 int axis_,
2353 Val* factor,
2354 bool inner_split,
2355 bool trim_out_of_bounds) {
2356 TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
2357 if (axis_ < 0)
2358 axis_ += nDims();
2359
2360 TORCH_INTERNAL_ASSERT(
2361 axis_ >= 0 && (unsigned int)axis_ < nDims(),
2362 "Tried to split on axis outside TensorDomain's range.");
2363
2364 IterDomain* id = axis(axis_);
2365
2366 // partial split is only allowed with root domains
2367 if (trim_out_of_bounds) {
2368 TORCH_INTERNAL_ASSERT(
2369 std::find(getRootDomain().begin(), getRootDomain().end(), id) !=
2370 getRootDomain().end(),
2371 "Partial split is only allowed with root domains");
2372 }
2373
2374 TORCH_INTERNAL_ASSERT(
2375 !id->isMmaSwizzled(),
      "Further transformations on warp-mapped ids are not allowed.");
2377
2378 auto split_ids =
2379 IterDomain::split(id, factor, inner_split, trim_out_of_bounds);
2380 domain_.erase(domain_.begin() + axis_);
2381 domain_.insert(domain_.begin() + axis_, split_ids.second);
2382 domain_.insert(domain_.begin() + axis_, split_ids.first);
2383 resetDomains();
2384}
2385
2386// Merge "axis_o" and "axis_i" into 1 dimension
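// For example, with domain_ == [I0, I1, I2], merge(0, 1) replaces the first
// two axes with a single merged axis, giving [I0*I1, I2].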
2387void TensorDomain::merge(int axis_o, int axis_i) {
2388 TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
2389 if (axis_o < 0)
2390 axis_o += nDims();
2391
2392 if (axis_i < 0)
2393 axis_i += nDims();
2394
2395 TORCH_CHECK(
2396 axis_o >= 0 && (unsigned int)axis_o < nDims() && axis_i >= 0 &&
2397 (unsigned int)axis_i < nDims(),
2398 "Invalid merge detected, either one or both axes are outside of TensorView's range.");
2399
2400 TORCH_CHECK(
2401 axis_o != axis_i,
2402 "Invalid merge detected, axes provided are the same axis.");
2403
  if (axis_o > axis_i) {
    std::swap(axis_o, axis_i);
  }
2409
2410 IterDomain* first = axis(axis_o);
2411 IterDomain* second = axis(axis_i);
2412
2413 TORCH_INTERNAL_ASSERT(
2414 !first->isMmaSwizzled() && !second->isMmaSwizzled(),
      "Further transformations on warp-mapped ids are not allowed.");
2416
2417 IterDomain* merged_id = IterDomain::merge(first, second);
2418
2419 domain_.erase(domain_.begin() + axis_i);
2420 domain_.erase(domain_.begin() + axis_o);
2421 domain_.insert(domain_.begin() + axis_o, merged_id);
2422 resetDomains();
2423}
2424
2425// Reorder axes according to map[old_pos] = new_pos
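// For example, reorder({{0, 1}, {1, 0}}) swaps the first two axes; positions
// not mentioned in the map keep their relative order.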
2426void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) {
2427 TORCH_INTERNAL_ASSERT(
2428 !(nDims() == 0 && old2new_.size() > 0),
2429 "Tried to reorder a 0-dim domain");
2430 domain_ = orderedAs(domain_, old2new_);
2431 resetDomains();
2432}
2433
2434std::vector<IterDomain*> TensorDomain::orderedAs(
2435 const std::vector<IterDomain*>& dom,
2436 const std::unordered_map<int, int>& old2new_) {
2437 TORCH_INTERNAL_ASSERT(
2438 !(dom.size() == 0 && old2new_.size() > 0),
2439 "Tried to reorder a 0-dim domain");
2440
  // Even though these checks are already done in TensorView, we want to redo
  // them here, as this function can be reached from places other than
  // TensorView.
2443
2444 auto new2old = ir_utils::normalizeOld2New(old2new_, dom.size());
2445
2446 std::vector<IterDomain*> reordered_domain;
2447 std::transform(
2448 new2old.begin(),
2449 new2old.end(),
2450 std::back_inserter(reordered_domain),
2451 [dom](int i) -> IterDomain* { return dom[i]; });
2452
2453 return reordered_domain;
2454}
2455
2456void TensorDomain::swizzle(
2457 Swizzle2DType swizzle_type,
2458 int x,
2459 int y,
2460 SwizzleMode swizzle_mode) {
  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do swizzle on a 0-dim domain");
2462
2463 TORCH_CHECK(
2464 x >= 0 && (unsigned int)x < nDims(),
2465 "Invalid swizzle detected, either one or both axes are outside of TensorView's range.");
2466
2467 TORCH_CHECK(
2468 y >= 0 && (unsigned int)y < nDims(),
2469 "Invalid swizzle detected, either one or both axes are outside of TensorView's range.");
2470
2471 IterDomain* axis_x = axis(x);
2472 IterDomain* axis_y = axis(y);
2473
2474 IterDomain* axis_out_x = nullptr;
2475 IterDomain* axis_out_y = nullptr;
2476
2477 std::tie(axis_out_x, axis_out_y) =
2478 IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode);
2479
2480 domain_.erase(domain_.begin() + x);
2481 domain_.insert(domain_.begin() + x, axis_out_x);
2482
2483 domain_.erase(domain_.begin() + y);
2484 domain_.insert(domain_.begin() + y, axis_out_y);
2485
2486 resetDomains();
2487}
2488
std::vector<IterDomain*> TensorDomain::noReductions(
    const std::vector<IterDomain*>& td) {
  // Keep every id that is neither a reduction nor a stride domain.
  std::vector<IterDomain*> noReductionDomain;
  std::copy_if(
      td.begin(),
      td.end(),
      std::back_inserter(noReductionDomain),
      [](IterDomain* id) { return !id->isReduction() && !id->isStride(); });
  return noReductionDomain;
}
2508
std::vector<IterDomain*> TensorDomain::noBroadcasts(
    const std::vector<IterDomain*>& td) {
  // Keep every id that is not a broadcast domain.
  std::vector<IterDomain*> noBroadcastDomain;
  std::copy_if(
      td.begin(),
      td.end(),
      std::back_inserter(noBroadcastDomain),
      [](IterDomain* id) { return !id->isBroadcast(); });
  return noBroadcastDomain;
}
2524
bool TensorDomain::hasBroadcast(const std::vector<IterDomain*>& td) {
  return std::any_of(td.begin(), td.end(), [](IterDomain* id) {
    return id->isBroadcast();
  });
}

bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
  return std::any_of(td.begin(), td.end(), [](IterDomain* id) {
    return id->isReduction();
  });
}
2538
2539bool TensorDomain::hasNontrivialReduction(const std::vector<IterDomain*>& td) {
2540 for (auto id : td) {
2541 if (id->isReduction() && !id->isTrivialReduction()) {
2542 return true;
2543 }
2544 }
2545 return false;
2546}
2547
2548TensorDomain* TensorDomain::view(const AnalyzeViewResult& view_analysis) {
2549 TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to view transform a 0-dim domain");
2550 return transformView(this, view_analysis);
2551}
2552
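// Flattens the dimensions in [start_dim, end_dim] into a single rfactor
// output dimension. Illustrative sketch (shapes are hypothetical): given a
// maybe-rfactor domain [I0, I1, I2, I3], flatten(1, 2) produces the root
// domain [I0, I1, I2, I3] and the rfactor domain [I0, I1*I2, I3], with the
// merged axis marked as an rfactor product.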
2553TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
2554 auto inp_domain = noReductions(getMaybeRFactorDomain());
2555
2556 if (start_dim < 0) {
2557 start_dim += inp_domain.size();
2558 }
2559 if (end_dim < 0) {
2560 end_dim += inp_domain.size();
2561 }
2562 TORCH_CHECK(
2563 start_dim >= 0 && start_dim < int64_t(inp_domain.size()),
2564 "Invalid start_dim ",
2565 start_dim);
2566 TORCH_CHECK(
2567 end_dim >= 0 && end_dim < int64_t(inp_domain.size()),
2568 "Invalid end_dim ",
2569 end_dim);
2570 TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim");
2571
2572 std::vector<IterDomain*> new_root_domain;
2573 new_root_domain.reserve(inp_domain.size());
2574 for (auto i : c10::irange(inp_domain.size())) {
2575 bool is_rfactor_dim = i >= size_t(start_dim) && i <= size_t(end_dim);
2576 auto inp_id = inp_domain[i];
2577 auto out_id = IterDomainBuilder(inp_id)
2578 .is_rfactor_domain(is_rfactor_dim)
2579 .extent(
2580 (is_rfactor_dim && inp_id->hasExpandedExtent())
2581 ? inp_id->expandedExtent()
2582 : inp_id->extent())
2583 .iter_type(
2584 (is_rfactor_dim && inp_id->isBroadcast())
2585 ? IterType::Iteration
2586 : inp_id->getIterType())
2587 .build();
2588 new_root_domain.push_back(out_id);
2589 }
2590
2591 std::vector<IterDomain*> rfactor_domain;
2592 rfactor_domain.reserve(new_root_domain.size() - (end_dim - start_dim));
2593 for (auto i : c10::irange(start_dim)) {
2594 rfactor_domain.push_back(new_root_domain[i]);
2595 }
2596
2597 IterDomain* merged_id = new_root_domain[start_dim];
2598 for (auto i : c10::irange(start_dim + 1, end_dim + 1)) {
2599 IterDomain* new_merged_id =
2600 IterDomainBuilder(
2601 merged_id->container()->zeroVal(),
2602 mul(merged_id->extent(), new_root_domain[i]->extent()))
2603 .is_rfactor_domain(true)
2604 .build();
2605 IrBuilder::create<Merge>(new_merged_id, merged_id, new_root_domain[i]);
2606 merged_id = new_merged_id;
2607 }
2608 rfactor_domain.push_back(merged_id);
2609
2610 for (auto i : c10::irange(end_dim + 1, inp_domain.size())) {
2611 rfactor_domain.push_back(new_root_domain[i]);
2612 }
2613
2614 return IrBuilder::create<TensorDomain>(
2615 new_root_domain,
2616 rfactor_domain,
2617 rfactor_domain,
2618 std::vector<bool>(rfactor_domain.size(), true));
2619}
2620
2621// TODO: Rfactor a Welford
2622
// The returned pair is ordered such that the second domain is the consumer
// of the first.
2624std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
2625 const std::vector<int>& axes_) {
2626 return TransformRFactor::runReplay(this, axes_);
2627}
2628
2629Split::Split(
2630 IrBuilderPasskey passkey,
2631 IterDomain* outer,
2632 IterDomain* inner,
2633 IterDomain* in,
2634 Val* factor,
2635 bool inner_split,
2636 Val* start_offset,
2637 Val* stop_offset)
2638 : Expr(passkey, ExprType::Split),
2639 outer_{outer},
2640 inner_{inner},
2641 in_{in},
2642 factor_{factor},
2643 inner_split_{inner_split},
2644 start_offset_{
2645 start_offset != nullptr ? start_offset
2646 : passkey.ir_container_->zeroVal()},
2647 stop_offset_{
2648 stop_offset != nullptr ? stop_offset
2649 : passkey.ir_container_->zeroVal()} {
2650 TORCH_INTERNAL_ASSERT(
2651 factor_->isAnInt(),
2652 "Attempted to create a Split node with a non-integer factor.");
2653 addOutput(outer);
2654 addOutput(inner);
2655 addInput(in);
2656 // TODO add factor as an input, need to check Split::Split during validation
2657 // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor);
2658}
2659
2660Split::Split(const Split* src, IrCloner* ir_cloner)
2661 : Expr(src, ir_cloner),
2662 outer_(ir_cloner->clone(src->outer_)),
2663 inner_(ir_cloner->clone(src->inner_)),
2664 in_(ir_cloner->clone(src->in_)),
2665 factor_(ir_cloner->clone(src->factor_)),
2666 inner_split_(src->inner_split_),
2667 start_offset_(ir_cloner->clone(src->start_offset_)),
2668 stop_offset_(ir_cloner->clone(src->stop_offset_)) {}
2669
2670Expr* Split::shallowCopy() const {
2671 auto result = IrBuilder::create<Split>(
2672 outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_);
2673 result->copyPredicatesFrom(this);
2674 return result;
2675}
2676
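// Computes the effective extent covered by a (possibly partial) split: the
// input extent minus the start and stop offsets when they are non-null and
// non-zero. For a full split both offsets are zero, and the input extent is
// returned unchanged.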
2677Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) {
2678 TORCH_INTERNAL_ASSERT(in_extent != nullptr);
2679
2680 if (start_offset != nullptr && !start_offset->isZeroInt()) {
2681 in_extent = sub(in_extent, start_offset);
2682 }
2683
2684 if (stop_offset != nullptr && !stop_offset->isZeroInt()) {
2685 in_extent = sub(in_extent, stop_offset);
2686 }
2687
2688 return in_extent;
2689}
2690
2691bool Split::sameAs(const Statement* other) const {
2692 if (this == other) {
2693 return true;
2694 }
2695 if (!other->isA<Split>()) {
2696 return false;
2697 }
2698 return Expr::sameAs(other) &&
2699 factor()->sameAs(other->as<Split>()->factor()) &&
2700 innerSplit() == other->as<Split>()->innerSplit() &&
2701 startOffset()->sameAs(other->as<Split>()->startOffset()) &&
2702 stopOffset()->sameAs(other->as<Split>()->stopOffset());
2703}
2704
2705Merge::Merge(
2706 IrBuilderPasskey passkey,
2707 IterDomain* out,
2708 IterDomain* outer,
2709 IterDomain* inner)
2710 : Expr(passkey, ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} {
2711 addOutput(out);
2712 addInput(outer);
2713 addInput(inner);
2714}
2715
2716Merge::Merge(const Merge* src, IrCloner* ir_cloner)
2717 : Expr(src, ir_cloner),
2718 out_(ir_cloner->clone(src->out_)),
2719 outer_(ir_cloner->clone(src->outer_)),
2720 inner_(ir_cloner->clone(src->inner_)) {}
2721
2722Expr* Merge::shallowCopy() const {
2723 auto result = IrBuilder::create<Merge>(out_, outer_, inner_);
2724 result->copyPredicatesFrom(this);
2725 return result;
2726}
2727
2728bool Merge::sameAs(const Statement* other) const {
2729 if (this == other) {
2730 return true;
2731 }
2732 if (!other->isA<Merge>()) {
2733 return false;
2734 }
2735 return Expr::sameAs(other);
2736}
2737
2738Swizzle2D::Swizzle2D(
2739 IrBuilderPasskey passkey,
2740 IterDomain* out_x,
2741 IterDomain* out_y,
2742 IterDomain* in_x,
2743 IterDomain* in_y,
2744 Swizzle2DType swizzle_type,
2745 SwizzleMode swizzle_mode)
2746 : Expr(passkey, ExprType::Swizzle2D),
2747 out_x_{out_x},
2748 out_y_{out_y},
2749 in_x_{in_x},
2750 in_y_{in_y},
2751 swizzle_type_(swizzle_type),
2752 swizzle_mode_(swizzle_mode) {
2753 addOutput(out_x);
2754 addOutput(out_y);
2755 addInput(in_x);
2756 addInput(in_y);
2757}
2758
2759Expr* Swizzle2D::shallowCopy() const {
2760 auto result = IrBuilder::create<Swizzle2D>(
2761 out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_);
2762 result->copyPredicatesFrom(this);
2763 return result;
2764}
2765
2766bool Swizzle2D::sameAs(const Statement* other) const {
2767 if (this == other) {
2768 return true;
2769 }
2770 if (!other->isA<Swizzle2D>()) {
2771 return false;
2772 }
2773 if (!(swizzle_type_ == other->as<Swizzle2D>()->swizzle_type_)) {
2774 return false;
2775 }
2776 return Expr::sameAs(other);
2777}
2778
2779Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner)
2780 : Expr(src, ir_cloner),
2781 out_x_(ir_cloner->clone(src->out_x_)),
2782 out_y_(ir_cloner->clone(src->out_y_)),
2783 in_x_(ir_cloner->clone(src->in_x_)),
2784 in_y_(ir_cloner->clone(src->in_y_)),
2785 swizzle_type_(src->swizzle_type_),
2786 swizzle_mode_(src->swizzle_mode_) {}
2787
2788NamedScalar::NamedScalar(
2789 IrBuilderPasskey passkey,
2790 std::string name,
2791 DataType dtype)
2792 : Val(passkey, ValType::NamedScalar, dtype), name_(std::move(name)) {}
2793
2794NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
2795 : Val(src, ir_cloner), name_(src->name_) {}
2796
2797bool NamedScalar::sameAs(const Statement* other) const {
2798 if (this == other) {
2799 return true;
2800 }
2801 if (!other->isA<NamedScalar>()) {
2802 return false;
2803 }
2804 return other->as<NamedScalar>()->name().compare(name()) == 0;
2805}
2806
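// Returns the NamedScalar holding the launch-bound extent of a thread
// parallel type. For example (assuming the usual CUDA stringification),
// getParallelDim(ParallelType::TIDx) yields a NamedScalar named
// "blockDim.x", while getParallelIndex(ParallelType::TIDx), defined below,
// yields "threadIdx.x".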
2807NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
2808 TORCH_INTERNAL_ASSERT(
2809 isParallelTypeThread(p_type),
2810 "Cannot get parallel dim of non thread type, received: ",
2811 p_type);
2812 TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
2813 std::string parallel_dim = stringifyThreadSize(p_type);
2814 return IrBuilder::create<NamedScalar>(parallel_dim, DataType::Int);
2815}
2816
2817NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) {
2818 TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
2819 std::string parallel_ind = stringifyThread(p_type);
2820 return IrBuilder::create<NamedScalar>(parallel_ind, DataType::Int);
2821}
2822
2823c10::optional<ParallelType> NamedScalar::getParallelDim() const {
2824 if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) {
2825 return c10::optional<ParallelType>(ParallelType::TIDx);
2826 } else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) {
2827 return c10::optional<ParallelType>(ParallelType::TIDy);
2828 } else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) {
2829 return c10::optional<ParallelType>(ParallelType::TIDz);
2830 } else if (stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) {
2831 return c10::optional<ParallelType>(ParallelType::BIDx);
2832 } else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) {
2833 return c10::optional<ParallelType>(ParallelType::BIDy);
2834 } else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) {
2835 return c10::optional<ParallelType>(ParallelType::BIDz);
2836 }
2837 return c10::nullopt;
2838}
2839
2840c10::optional<ParallelType> NamedScalar::getParallelIndex() const {
2841 if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) {
2842 return c10::optional<ParallelType>(ParallelType::TIDx);
2843 } else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) {
2844 return c10::optional<ParallelType>(ParallelType::TIDy);
2845 } else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) {
2846 return c10::optional<ParallelType>(ParallelType::TIDz);
2847 } else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) {
2848 return c10::optional<ParallelType>(ParallelType::BIDx);
2849 } else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) {
2850 return c10::optional<ParallelType>(ParallelType::BIDy);
2851 } else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) {
2852 return c10::optional<ParallelType>(ParallelType::BIDz);
2853 }
2854 return c10::nullopt;
2855}
2856
2857} // namespace cuda
2858} // namespace fuser
2859} // namespace jit
2860} // namespace torch
2861