1 | #include <arith.h> |
2 | #include <disjoint_set.h> |
3 | #include <ir_cloner.h> |
4 | #include <ir_interface_nodes.h> |
5 | #include <ir_iostream.h> |
6 | #include <ir_utils.h> |
7 | #include <kernel.h> |
8 | #include <kernel_ir.h> |
9 | #include <lower2device.h> |
10 | #include <root_domain_map.h> |
11 | #include <transform_iter.h> |
12 | #include <transform_rfactor.h> |
13 | #include <transform_view.h> |
14 | |
15 | #include <c10/util/irange.h> |
16 | |
17 | #include <sstream> |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | namespace fuser { |
22 | namespace cuda { |
23 | |
24 | namespace { |
25 | |
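// Compares two scalar Vals of the same ValType and DataType by dispatching on
// the concrete scalar class (Bool, Double, Int, NamedScalar).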
26 | class ScalarCheck : OptInConstDispatch { |
27 | public: |
28 | static bool sameAs(const Val* v1, const Val* v2) { |
29 | if (v1 == v2) |
30 | return true; |
31 | |
32 | if (v1->getValType() != v2->getValType()) |
33 | return false; |
34 | |
35 | if (v1->getDataType() != v2->getDataType()) |
36 | return false; |
37 | |
38 | ScalarCheck sc(v1, v2); |
39 | return sc.same_; |
40 | } |
41 | |
42 | private: |
43 | void handle(const Bool* b) final { |
44 | same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>()); |
45 | } |
46 | |
47 | void handle(const Double* d) final { |
48 | same_ = v1_->as<Double>()->sameAs(v2_->as<Double>()); |
49 | } |
50 | |
51 | void handle(const Int* i) final { |
52 | same_ = v1_->as<Int>()->sameAs(v2_->as<Int>()); |
53 | } |
54 | |
55 | void handle(const NamedScalar* ns) final { |
56 | same_ = v1_->as<NamedScalar>()->sameAs(v2_->as<NamedScalar>()); |
57 | } |
58 | |
59 | ScalarCheck(const Val* _v1, const Val* _v2) : v1_(_v1), v2_(_v2) { |
60 | OptInConstDispatch::handle(v1_); |
61 | } |
62 | |
63 | private: |
64 | const Val* v1_ = nullptr; |
65 | const Val* v2_ = nullptr; |
66 | bool same_ = false; |
67 | }; |
68 | |
69 | } // namespace |
70 | |
71 | bool areEqualScalars(Val* v1, Val* v2) { |
72 | return ScalarCheck::sameAs(v1, v2); |
73 | } |
74 | |
75 | Bool::Bool(IrBuilderPasskey passkey) |
76 | : Val(passkey, ValType::Scalar, DataType::Bool), |
77 | maybe_value_{c10::nullopt} {} |
78 | |
79 | Bool::Bool(IrBuilderPasskey passkey, bool value) |
80 | : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} |
81 | |
82 | Bool::Bool(IrBuilderPasskey passkey, c10::optional<bool> value) |
83 | : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} |
84 | |
85 | Bool::Bool(const Bool* src, IrCloner* ir_cloner) |
86 | : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} |
87 | |
88 | bool Bool::sameAs(const Statement* other) const { |
89 | if (this == other) { |
90 | return true; |
91 | } |
92 | if (!other->isA<Bool>()) { |
93 | return false; |
94 | } |
95 | const auto other_bool = other->as<Bool>(); |
96 | if (isConst() && other_bool->isConst()) { |
97 | return *value() == *(other_bool->value()); |
98 | } |
99 | return false; |
100 | } |
101 | |
102 | Double::Double(IrBuilderPasskey passkey) |
103 | : Val(passkey, ValType::Scalar, DataType::Double), |
104 | maybe_value_{c10::nullopt} {} |
105 | |
106 | Double::Double(IrBuilderPasskey passkey, ScalarType value) |
107 | : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} |
108 | |
109 | Double::Double(IrBuilderPasskey passkey, c10::optional<ScalarType> value) |
110 | : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} |
111 | |
112 | Double::Double(const Double* src, IrCloner* ir_cloner) |
113 | : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} |
114 | |
115 | bool Double::sameAs(const Statement* other) const { |
116 | if (this == other) { |
117 | return true; |
118 | } |
119 | if (!other->isA<Double>()) { |
120 | return false; |
121 | } |
122 | const auto other_double = other->as<Double>(); |
123 | if (isConst() && other_double->isConst()) |
124 | return *value() == *(other_double->value()); |
125 | return false; |
126 | } |
127 | |
128 | Int::Int(IrBuilderPasskey passkey) |
129 | : Val(passkey, ValType::Scalar, DataType::Int), |
130 | maybe_value_{c10::nullopt} {} |
131 | |
132 | Int::Int(IrBuilderPasskey passkey, ScalarType value) |
133 | : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} |
134 | |
135 | Int::Int(IrBuilderPasskey passkey, c10::optional<ScalarType> value) |
136 | : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} |
137 | |
138 | Int::Int(const Int* src, IrCloner* ir_cloner) |
139 | : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} |
140 | |
141 | bool Int::sameAs(const Statement* other) const { |
142 | if (this == other) { |
143 | return true; |
144 | } |
145 | if (!other->isA<Int>()) { |
146 | return false; |
147 | } |
148 | const auto other_int = other->as<Int>(); |
149 | if (isConst() && other_int->isConst()) { |
150 | return *value() == *(other_int->value()); |
151 | } |
152 | return false; |
153 | } |
154 | |
155 | ComplexDouble::ComplexDouble(IrBuilderPasskey passkey) |
156 | : Val(passkey, ValType::Scalar, DataType::ComplexDouble), |
157 | maybe_value_{c10::nullopt} {} |
158 | |
159 | ComplexDouble::ComplexDouble(IrBuilderPasskey passkey, ScalarType value) |
160 | : Val(passkey, ValType::Scalar, DataType::ComplexDouble), |
161 | maybe_value_{value} {} |
162 | |
163 | ComplexDouble::ComplexDouble( |
164 | IrBuilderPasskey passkey, |
165 | c10::optional<ScalarType> value) |
166 | : Val(passkey, ValType::Scalar, DataType::ComplexDouble), |
167 | maybe_value_{value} {} |
168 | |
169 | ComplexDouble::ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner) |
170 | : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} |
171 | |
172 | bool ComplexDouble::sameAs(const Statement* other) const { |
173 | if (this == other) { |
174 | return true; |
175 | } |
176 | if (!other->isA<ComplexDouble>()) { |
177 | return false; |
178 | } |
179 | const auto other_complex = other->as<ComplexDouble>(); |
180 | if (isConst() && other_complex->isConst()) |
181 | return *value() == *(other_complex->value()); |
182 | return false; |
183 | } |
184 | |
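// FullOp fills an output tensor of the given dtype with fill_value, similar
// to at::full. The extents of the output root domain are registered as inputs
// so the output shape is tracked as a dependency.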
185 | FullOp::FullOp( |
186 | IrBuilderPasskey passkey, |
187 | Val* out, |
188 | Val* fill_value, |
189 | DataType dtype) |
190 | : Expr(passkey, ExprType::FullOp), dtype_(dtype), fill_value_(fill_value) { |
  if (out->isA<TensorView>()) {
    for (auto id : out->as<TensorView>()->getRootDomain()) {
      addInput(id->extent());
    }
  }
194 | addInput(fill_value); |
195 | addOutput(out); |
196 | } |
197 | |
198 | FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) |
199 | : Expr(src, ir_cloner), |
200 | dtype_(src->dtype()), |
201 | fill_value_(ir_cloner->clone(src->fill_value_)) {} |
202 | |
203 | Expr* FullOp::shallowCopy() const { |
204 | auto result = IrBuilder::create<FullOp>(output(0), fill_value_, dtype_); |
205 | result->copyPredicatesFrom(this); |
206 | return result; |
207 | } |
208 | |
209 | bool FullOp::sameAs(const Statement* other) const { |
210 | if (this == other) { |
211 | return true; |
212 | } |
213 | if (!other->isA<FullOp>()) { |
214 | return false; |
215 | } |
216 | const auto other_op = other->as<FullOp>(); |
217 | if (dtype_ != other_op->dtype_) { |
218 | return false; |
219 | } |
220 | return Expr::sameAs(other); |
221 | } |
222 | |
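// ARangeOp materializes the sequence [start, end) with the given step,
// similar to at::arange. linear_index is the linearized output index and is
// filled in when the op is indexed during lowering.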
223 | ARangeOp::ARangeOp( |
224 | IrBuilderPasskey passkey, |
225 | Val* out, |
226 | Val* start, |
227 | Val* end, |
228 | Val* step, |
229 | DataType dtype, |
230 | Val* linear_index) |
231 | : Expr(passkey, ExprType::ARangeOp), |
232 | dtype_(dtype), |
233 | start_(start), |
234 | end_(end), |
235 | step_(step), |
236 | linear_index_(linear_index) { |
237 | addInput(start); |
238 | addInput(end); |
239 | addInput(step); |
240 | addOutput(out); |
241 | } |
242 | |
243 | ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) |
244 | : Expr(src, ir_cloner), |
245 | dtype_(src->dtype()), |
246 | start_(ir_cloner->clone(src->start_)), |
247 | end_(ir_cloner->clone(src->end_)), |
248 | step_(ir_cloner->clone(src->step_)), |
249 | linear_index_(ir_cloner->clone(src->linear_index_)) {} |
250 | |
251 | Expr* ARangeOp::shallowCopy() const { |
252 | auto result = IrBuilder::create<ARangeOp>( |
253 | output(0), start_, end_, step_, dtype_, linear_index_); |
254 | result->copyPredicatesFrom(this); |
255 | return result; |
256 | } |
257 | |
258 | bool ARangeOp::sameAs(const Statement* other) const { |
259 | if (this == other) { |
260 | return true; |
261 | } |
262 | if (!other->isA<ARangeOp>()) { |
263 | return false; |
264 | } |
265 | const auto other_op = other->as<ARangeOp>(); |
266 | if (dtype_ != other_op->dtype_) { |
267 | return false; |
268 | } |
269 | if (!start_->sameAs(other_op->start_)) { |
270 | return false; |
271 | } |
272 | if (!end_->sameAs(other_op->end_)) { |
273 | return false; |
274 | } |
275 | if (!step_->sameAs(other_op->step_)) { |
276 | return false; |
277 | } |
278 | if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) { |
279 | return false; |
280 | } |
281 | if ((linear_index_ != nullptr) && |
282 | !linear_index_->sameAs(other_op->linear_index_)) { |
283 | return false; |
284 | } |
285 | return Expr::sameAs(other); |
286 | } |
287 | |
288 | EyeOp::EyeOp( |
289 | IrBuilderPasskey passkey, |
290 | Val* out, |
291 | DataType dtype, |
292 | Val* index1, |
293 | Val* index2) |
294 | : Expr(passkey, ExprType::EyeOp), |
295 | dtype_(dtype), |
296 | index1_(index1), |
297 | index2_(index2) { |
298 | if (out->isA<TensorView>()) { |
299 | addInput(out->as<TensorView>()->getRootDomain()[0]->extent()); |
300 | if (out->as<TensorView>()->getRootDomain()[1] != |
301 | out->as<TensorView>()->getRootDomain()[0]) { |
302 | addInput(out->as<TensorView>()->getRootDomain()[1]->extent()); |
303 | } |
304 | } |
305 | addOutput(out); |
306 | } |
307 | |
308 | EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) |
309 | : Expr(src, ir_cloner), |
310 | dtype_(src->dtype_), |
311 | index1_(ir_cloner->clone(src->index1_)), |
312 | index2_(ir_cloner->clone(src->index2_)) {} |
313 | |
314 | Expr* EyeOp::shallowCopy() const { |
315 | auto result = IrBuilder::create<EyeOp>(output(0), dtype_, index1_, index2_); |
316 | result->copyPredicatesFrom(this); |
317 | return result; |
318 | } |
319 | |
320 | bool EyeOp::sameAs(const Statement* other) const { |
321 | if (this == other) { |
322 | return true; |
323 | } |
324 | if (!other->isA<EyeOp>()) { |
325 | return false; |
326 | } |
327 | const auto other_op = other->as<EyeOp>(); |
328 | if (dtype_ != other_op->dtype_) { |
329 | return false; |
330 | } |
331 | if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) { |
332 | return false; |
333 | } |
334 | if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) { |
335 | return false; |
336 | } |
337 | if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) { |
338 | return false; |
339 | } |
340 | if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) { |
341 | return false; |
342 | } |
343 | return Expr::sameAs(other); |
344 | } |
345 | |
UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in)
    : Expr(passkey, ExprType::UnaryOp),
      unary_op_type_{type},
      out_{out},
      in_{in} {
  addOutput(out);
  addInput(in);
}
359 | |
360 | UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) |
361 | : Expr(src, ir_cloner), |
362 | unary_op_type_(src->unary_op_type_), |
363 | out_(ir_cloner->clone(src->out_)), |
364 | in_(ir_cloner->clone(src->in_)) {} |
365 | |
366 | Expr* UnaryOp::shallowCopy() const { |
367 | auto result = IrBuilder::create<UnaryOp>(unary_op_type_, out_, in_); |
368 | result->copyPredicatesFrom(this); |
369 | return result; |
370 | } |
371 | |
372 | bool UnaryOp::sameAs(const Statement* other) const { |
373 | if (this == other) { |
374 | return true; |
375 | } |
376 | if (!other->isA<UnaryOp>()) { |
377 | return false; |
378 | } |
379 | const auto other_op = other->as<UnaryOp>(); |
380 | if (getUnaryOpType() != other_op->getUnaryOpType()) { |
381 | return false; |
382 | } |
383 | return Expr::sameAs(other); |
384 | } |
385 | |
386 | BinaryOp::BinaryOp( |
387 | IrBuilderPasskey passkey, |
388 | BinaryOpType type, |
389 | Val* out, |
390 | Val* lhs, |
391 | Val* rhs) |
392 | : Expr(passkey, ExprType::BinaryOp), |
393 | binary_op_type_{type}, |
394 | out_{out}, |
395 | lhs_{lhs}, |
396 | rhs_{rhs} { |
397 | addOutput(out); |
398 | addInput(lhs); |
399 | addInput(rhs); |
400 | } |
401 | |
402 | BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) |
403 | : Expr(src, ir_cloner), |
404 | binary_op_type_(src->binary_op_type_), |
405 | out_(ir_cloner->clone(src->out_)), |
406 | lhs_(ir_cloner->clone(src->lhs_)), |
407 | rhs_(ir_cloner->clone(src->rhs_)) {} |
408 | |
409 | Expr* BinaryOp::shallowCopy() const { |
410 | auto result = IrBuilder::create<BinaryOp>(binary_op_type_, out_, lhs_, rhs_); |
411 | result->copyPredicatesFrom(this); |
412 | return result; |
413 | } |
414 | |
415 | bool BinaryOp::sameAs(const Statement* other) const { |
416 | if (this == other) { |
417 | return true; |
418 | } |
419 | if (!other->isA<BinaryOp>()) { |
420 | return false; |
421 | } |
422 | const auto other_op = other->as<BinaryOp>(); |
423 | if (getBinaryOpType() != other_op->getBinaryOpType()) { |
424 | return false; |
425 | } |
426 | return Expr::sameAs(other); |
427 | } |
428 | |
429 | TernaryOp::TernaryOp( |
430 | IrBuilderPasskey passkey, |
431 | TernaryOpType type, |
432 | Val* out, |
433 | Val* in1, |
434 | Val* in2, |
435 | Val* in3) |
436 | : Expr(passkey, ExprType::TernaryOp), |
437 | ternary_op_type_{type}, |
438 | out_{out}, |
439 | in1_{in1}, |
440 | in2_{in2}, |
441 | in3_{in3} { |
442 | addOutput(out); |
443 | addInput(in1); |
444 | addInput(in2); |
445 | addInput(in3); |
446 | } |
447 | |
448 | TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) |
449 | : Expr(src, ir_cloner), |
450 | ternary_op_type_(src->ternary_op_type_), |
451 | out_(ir_cloner->clone(src->out_)), |
452 | in1_(ir_cloner->clone(src->in1_)), |
453 | in2_(ir_cloner->clone(src->in2_)), |
454 | in3_(ir_cloner->clone(src->in3_)) {} |
455 | |
456 | Expr* TernaryOp::shallowCopy() const { |
457 | auto result = |
458 | IrBuilder::create<TernaryOp>(ternary_op_type_, out_, in1_, in2_, in3_); |
459 | result->copyPredicatesFrom(this); |
460 | return result; |
461 | } |
462 | |
463 | bool TernaryOp::sameAs(const Statement* other) const { |
464 | if (this == other) { |
465 | return true; |
466 | } |
467 | if (!other->isA<TernaryOp>()) { |
468 | return false; |
469 | } |
470 | const auto other_op = other->as<TernaryOp>(); |
471 | if (getTernaryOpType() != other_op->getTernaryOpType()) { |
472 | return false; |
473 | } |
474 | return Expr::sameAs(other); |
475 | } |
476 | |
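// RNGOp generates random values of the given dtype. The output shape is taken
// from the output tensor's root domain, rng_offset distinguishes RNG ops
// within a fusion, and philox_index is the per-element counter index set
// during lowering.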
477 | RNGOp::RNGOp( |
478 | IrBuilderPasskey passkey, |
479 | RNGOpType type, |
480 | Val* out, |
481 | DataType dtype, |
482 | std::vector<Val*> parameters, |
483 | int rng_offset, |
484 | Val* philox_index) |
485 | : Expr(passkey, ExprType::RNGOp), |
486 | rng_op_type_(type), |
487 | dtype_(dtype), |
488 | parameters_(std::move(parameters)), |
489 | rng_offset_(rng_offset), |
490 | philox_index_(philox_index) { |
491 | if (out->isA<TensorView>()) { |
492 | for (auto id : out->as<TensorView>()->getRootDomain()) { |
493 | shape_.emplace_back(id->extent()); |
494 | } |
495 | } |
496 | for (auto v : shape_) { |
497 | addInput(v); |
498 | } |
499 | for (auto v : parameters_) { |
500 | addInput(v); |
501 | } |
502 | addOutput(out); |
503 | } |
504 | |
505 | RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) |
506 | : Expr(src, ir_cloner), |
507 | rng_op_type_(src->rng_op_type_), |
508 | dtype_(src->dtype()), |
509 | parameters_(ir_cloner->clone(src->parameters_)), |
510 | rng_offset_(src->rng_offset_), |
511 | philox_index_(ir_cloner->clone(src->philox_index_)) {} |
512 | |
513 | Expr* RNGOp::shallowCopy() const { |
514 | auto result = IrBuilder::create<RNGOp>( |
515 | rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); |
516 | result->copyPredicatesFrom(this); |
517 | return result; |
518 | } |
519 | |
520 | bool RNGOp::sameAs(const Statement* other) const { |
521 | if (this == other) { |
522 | return true; |
523 | } |
524 | if (!other->isA<RNGOp>()) { |
525 | return false; |
526 | } |
527 | const auto other_op = other->as<RNGOp>(); |
528 | if (getRNGOpType() != other_op->getRNGOpType()) { |
529 | return false; |
530 | } |
531 | if (dtype_ != other_op->dtype_) { |
532 | return false; |
533 | } |
534 | if (parameters_.size() != other_op->parameters_.size()) { |
535 | return false; |
536 | } |
537 | for (auto i : c10::irange(parameters_.size())) { |
538 | if (!parameters_[i]->sameAs(other_op->parameters_[i])) { |
539 | return false; |
540 | } |
541 | } |
542 | if (getRNGOffset() != other_op->getRNGOffset()) { |
543 | return false; |
544 | } |
545 | if ((philox_index_ == nullptr) != (other_op->philox_index_ == nullptr)) { |
546 | return false; |
547 | } |
548 | if ((philox_index_ != nullptr) && |
549 | !philox_index_->sameAs(other_op->philox_index_)) { |
550 | return false; |
551 | } |
552 | return Expr::sameAs(other); |
553 | } |
554 | |
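// BroadcastOp inserts broadcast (size-1) iteration domains into the input.
// is_broadcast_dims holds one flag per output root domain, true where a new
// broadcast domain is introduced.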
555 | BroadcastOp::BroadcastOp( |
556 | IrBuilderPasskey passkey, |
557 | Val* out, |
558 | Val* in, |
559 | std::vector<bool> is_broadcast_dims) |
560 | : Expr(passkey, ExprType::BroadcastOp), |
561 | out_(out), |
562 | in_(in), |
563 | is_broadcast_dims_(std::move(is_broadcast_dims)) { |
  // clang-tidy complains that out_ may be null.
  TORCH_INTERNAL_ASSERT(out_ != nullptr);
  TORCH_INTERNAL_ASSERT(in_ != nullptr);

  auto out_type = out->getValType().value();
  auto in_type = in->getValType().value();

  TORCH_INTERNAL_ASSERT(
      (out_type == ValType::TensorView && in_type == ValType::TensorView) ||
          (out_type == ValType::TensorIndex && in_type == ValType::TensorIndex),
      "Cannot broadcast a non-tensor object.");
575 | |
576 | addOutput(out); |
577 | addInput(in); |
578 | |
579 | if (!out->isA<TensorView>() || !in->isA<TensorView>()) { |
580 | return; |
581 | } |
582 | |
583 | passkey.ir_container_->registerExpr(exprPasskey(), this); |
584 | |
585 | // This is a generic check that root dims of a consumer and producer match. |
586 | // Maybe we shouldn't relegate it to this constructor. |
587 | const auto c_tv = out_->as<TensorView>(); |
588 | const auto p_tv = in_->as<TensorView>(); |
589 | |
590 | const auto& c_root = c_tv->getRootDomain(); |
591 | const auto& p_root = p_tv->getMaybeRFactorDomain(); |
592 | |
593 | const auto root_p2c = |
594 | PairwiseRootDomainMap(p_tv, c_tv) |
595 | .mapProducerToConsumer(p_tv->domain(), c_tv->domain()); |
596 | |
597 | for (auto id : p_root) { |
598 | if (root_p2c.find(id) == root_p2c.end()) { |
      TORCH_INTERNAL_ASSERT(
          id->isReduction() || id->isStride(),
          "Invalid broadcast op: ",
          id,
          ". Non-reduction input dim doesn't map to any output dim.");
604 | } |
605 | } |
606 | |
607 | std::unordered_set<IterDomain*> c_mapped; |
608 | for (auto pair_entry : root_p2c) { |
609 | c_mapped.insert(pair_entry.second); |
610 | } |
611 | |
612 | for (const auto i : c10::irange(c_root.size())) { |
613 | const auto c_id = c_root[i]; |
614 | if (c_mapped.find(c_id) != c_mapped.end()) { |
615 | continue; |
616 | } |
    TORCH_INTERNAL_ASSERT(
        c_id->isBroadcast() && is_broadcast_dims_[i],
        "Invalid broadcast op: ",
        c_id,
        ". Non-broadcasted output dim isn't matched from input.");
622 | } |
623 | } |
624 | |
625 | BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) |
626 | : Expr(src, ir_cloner), |
627 | out_(ir_cloner->clone(src->out_)), |
628 | in_(ir_cloner->clone(src->in_)), |
629 | is_broadcast_dims_(src->is_broadcast_dims_) {} |
630 | |
631 | Expr* BroadcastOp::shallowCopy() const { |
632 | auto result = IrBuilder::create<BroadcastOp>(out_, in_, is_broadcast_dims_); |
633 | result->copyPredicatesFrom(this); |
634 | return result; |
635 | } |
636 | |
637 | bool BroadcastOp::sameAs(const Statement* other) const { |
638 | if (this == other) { |
639 | return true; |
640 | } |
641 | if (!other->isA<BroadcastOp>()) { |
642 | return false; |
643 | } |
644 | const auto other_op = other->as<BroadcastOp>(); |
645 | if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) { |
646 | return false; |
647 | } |
648 | return Expr::sameAs(other); |
649 | } |
650 | |
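// ReductionOp reduces the input tensor with the given binary op, starting
// from the constant init value. is_allreduce marks the fused all-reduce
// variant where all participating threads receive the result.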
651 | ReductionOp::ReductionOp( |
652 | IrBuilderPasskey passkey, |
653 | BinaryOpType reduction_op_type, |
654 | Val* init, |
655 | Val* out, |
656 | Val* in, |
657 | bool is_allreduce, |
658 | ExprType expr_type) |
659 | : Expr(passkey, expr_type), |
660 | reduction_op_type_(reduction_op_type), |
661 | init_(init), |
662 | out_(out), |
663 | in_(in), |
664 | is_allreduce_(is_allreduce) { |
665 | TORCH_CHECK( |
666 | out->getValType().value() == ValType::TensorView || |
667 | out->getValType().value() == ValType::TensorIndex); |
668 | |
  TORCH_INTERNAL_ASSERT(
      (in->getValType() == ValType::TensorView &&
       out->getValType() == ValType::TensorView) ||
          (in->getValType() == ValType::TensorIndex &&
           out->getValType() == ValType::TensorIndex),
      "Reduction operation was created that does not have tensor inputs and outputs.");
675 | |
676 | if (in->isA<TensorView>()) { |
    TORCH_INTERNAL_ASSERT(
        TensorDomain::noReductions(
            in->as<TensorView>()->getMaybeRFactorDomain())
                .size() == out->as<TensorView>()->getRootDomain().size(),
        "Reduction operation created with mismatched domains.");
  }
  TORCH_INTERNAL_ASSERT(
      init->isConstScalar(),
      "Tried to create a reduction operation with an initial value that isn't a constant.");
686 | |
687 | addOutput(out); |
688 | addInput(in); |
689 | } |
690 | |
691 | ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) |
692 | : Expr(src, ir_cloner), |
693 | reduction_op_type_(src->reduction_op_type_), |
694 | init_(ir_cloner->clone(src->init_)), |
695 | out_(ir_cloner->clone(src->out_)), |
696 | in_(ir_cloner->clone(src->in_)), |
697 | is_allreduce_(src->is_allreduce_) {} |
698 | |
699 | Expr* ReductionOp::shallowCopy() const { |
700 | auto result = IrBuilder::create<ReductionOp>( |
701 | reduction_op_type_, init_, out_, in_, is_allreduce_, etype()); |
702 | result->copyPredicatesFrom(this); |
703 | return result; |
704 | } |
705 | |
706 | bool ReductionOp::sameAs(const Statement* other) const { |
707 | if (this == other) { |
708 | return true; |
709 | } |
710 | if (!other->isA<ReductionOp>()) { |
711 | return false; |
712 | } |
713 | const auto other_op = other->as<ReductionOp>(); |
714 | // Note that init is not part of input vals, so it must be checked separately. |
715 | return ( |
716 | Expr::sameAs(other) && |
717 | getReductionOpType() == other_op->getReductionOpType() && |
718 | init()->sameAs(other_op->init())); |
719 | } |
720 | |
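// GroupedReductionOp fuses multiple reductions into a single expression. The
// i-th output is the reduction of the i-th input using the i-th op type and
// init value.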
721 | GroupedReductionOp::GroupedReductionOp( |
722 | IrBuilderPasskey passkey, |
723 | std::vector<BinaryOpType> reduction_op_types, |
724 | std::vector<Val*> init_vals, |
725 | std::vector<Val*> outputs, |
726 | std::vector<Val*> inputs, |
727 | bool is_fused, |
728 | ExprType expr_type) |
729 | : Expr(passkey, expr_type), |
730 | reduction_op_types_(std::move(reduction_op_types)), |
731 | init_vals_(std::move(init_vals)), |
732 | is_allreduce_(is_fused) { |
733 | for (auto out : outputs) { |
734 | addOutput(out); |
735 | } |
736 | |
737 | for (auto in : inputs) { |
738 | addInput(in); |
739 | } |
740 | } |
741 | |
742 | GroupedReductionOp::GroupedReductionOp( |
743 | const GroupedReductionOp* src, |
744 | IrCloner* ir_cloner) |
745 | : Expr(src, ir_cloner), |
746 | reduction_op_types_(src->reduction_op_types_), |
747 | init_vals_(ir_cloner->clone(src->init_vals_)), |
748 | is_allreduce_(src->is_allreduce_) {} |
749 | |
750 | Expr* GroupedReductionOp::shallowCopy() const { |
751 | auto result = IrBuilder::create<GroupedReductionOp>( |
752 | reduction_op_types_, |
753 | init_vals_, |
754 | outputs(), |
755 | inputs(), |
756 | is_allreduce_, |
757 | etype()); |
758 | result->copyPredicatesFrom(this); |
759 | return result; |
760 | } |
761 | |
762 | int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { |
763 | auto it = std::find(outputs().begin(), outputs().end(), output_val); |
764 | if (it != outputs().end()) { |
765 | return std::distance(outputs().begin(), it); |
766 | } |
767 | |
  TORCH_INTERNAL_ASSERT(
      false, "Not an output, ", output_val->toString(), ", of ", toString());
770 | } |
771 | |
772 | bool GroupedReductionOp::sameAs(const Statement* other) const { |
773 | if (this == other) { |
774 | return true; |
775 | } |
776 | |
777 | auto grouped_rop = dynamic_cast<const GroupedReductionOp*>(other); |
778 | if (grouped_rop == nullptr) { |
779 | return false; |
780 | } |
781 | |
782 | if (!Expr::sameAs(other) || |
783 | getReductionOpTypes() != grouped_rop->getReductionOpTypes()) { |
784 | return false; |
785 | } |
786 | |
787 | for (const auto i : c10::irange(numExprs())) { |
788 | if (!initVal(i)->sameAs(grouped_rop->initVal(i))) { |
789 | return false; |
790 | } |
791 | } |
792 | |
793 | return true; |
794 | } |
795 | |
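// WelfordOp computes an (avg, var, N) triplet using Welford's online variance
// algorithm. output, input, and init each bundle the three values as a
// WelfordTriplet.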
796 | WelfordOp::WelfordOp( |
797 | IrBuilderPasskey passkey, |
798 | const WelfordTriplet& output, |
799 | const WelfordTriplet& input, |
800 | const WelfordTriplet& init, |
801 | bool is_fused) |
802 | : Expr(passkey, ExprType::WelfordOp), |
803 | output_(output), |
804 | input_(input), |
805 | init_(init), |
806 | is_allreduce_(is_fused) { |
807 | // Previously, nullptr was accepted and implicitly replaced by |
808 | // default values. Looks like we always pass some non-null values, |
809 | // so removed the implicit default behavior for code simplicity. |
810 | TORCH_INTERNAL_ASSERT(output.avg() != nullptr); |
811 | TORCH_INTERNAL_ASSERT(output.var() != nullptr); |
812 | TORCH_INTERNAL_ASSERT(output.N() != nullptr); |
813 | TORCH_INTERNAL_ASSERT(init.avg() != nullptr); |
814 | TORCH_INTERNAL_ASSERT(init.var() != nullptr); |
815 | TORCH_INTERNAL_ASSERT(init.N() != nullptr); |
816 | TORCH_INTERNAL_ASSERT(input.avg() != nullptr); |
817 | TORCH_INTERNAL_ASSERT(input.var() != nullptr); |
818 | TORCH_INTERNAL_ASSERT(input.N() != nullptr); |
819 | |
820 | // Check output type |
821 | TORCH_INTERNAL_ASSERT( |
822 | output.avg()->getValType().value() == ValType::TensorView || |
823 | output.avg()->getValType().value() == ValType::TensorIndex); |
824 | TORCH_INTERNAL_ASSERT( |
825 | output.var()->getValType().value() == ValType::TensorView || |
826 | output.var()->getValType().value() == ValType::TensorIndex); |
827 | TORCH_INTERNAL_ASSERT( |
828 | output.N()->getValType().value() == ValType::TensorView || |
829 | output.N()->getValType().value() == ValType::TensorIndex); |
830 | TORCH_INTERNAL_ASSERT(isIntegralType(output.N()->dtype())); |
831 | |
832 | // check initial value |
833 | TORCH_INTERNAL_ASSERT(init.N()->getValType().value() == ValType::Scalar); |
834 | TORCH_INTERNAL_ASSERT(isIntegralType(init.N()->dtype())); |
835 | if (!init.N()->isZeroInt()) { |
    // When the initial count is zero, no initial variance or average is
    // needed. An initial value with a count of 1 is uncommon enough that we
    // push the responsibility of creating all-zero var tensors to the user.
    TORCH_INTERNAL_ASSERT(
        init_.avg()->getValType().value() == ValType::TensorView ||
        init_.avg()->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(
        init_.var()->getValType().value() == ValType::TensorView ||
            init_.var()->getValType().value() == ValType::TensorIndex,
        "Invalid initial var: ",
        init_.var()->toString());
847 | } |
848 | |
849 | // check input |
850 | TORCH_INTERNAL_ASSERT( |
851 | input_.avg()->getValType().value() == ValType::TensorView || |
852 | input_.avg()->getValType().value() == ValType::TensorIndex, |
853 | input_.avg()->getValType().value()); |
854 | TORCH_INTERNAL_ASSERT( |
855 | input_.N()->getValType().value() == ValType::Scalar || |
856 | input_.N()->getValType().value() == ValType::TensorView || |
857 | input_.N()->getValType().value() == ValType::TensorIndex); |
858 | TORCH_INTERNAL_ASSERT(isIntegralType(input_.N()->dtype())); |
  if (!input_.N()->isOneInt()) {
    // When the input N is not one, the var input must be a tensor as well.
    // (When the input is a single value, only the avg input carries the
    // value; the var part is implicitly 0 and codegen handles that.)
    TORCH_INTERNAL_ASSERT(
        input_.var()->getValType().value() == ValType::TensorView ||
        input_.var()->getValType().value() == ValType::TensorIndex);
  } else {
    TORCH_INTERNAL_ASSERT(
        input_.var() == nullptr || input_.var()->isZeroInt(),
        "Invalid var input, which must be either nullptr or scalar zero when the N input is one.");
869 | } |
870 | |
871 | addOutput(output_.avg()); |
872 | addOutput(output_.var()); |
873 | addOutput(output_.N()); |
874 | |
875 | addInput(input_.avg()); |
876 | addInput(input_.var()); |
877 | addInput(input_.N()); |
878 | } |
879 | |
880 | c10::optional<WelfordTriplet::ValName> WelfordTriplet::getNameOf( |
881 | Val* val) const { |
882 | auto it = std::find(begin(), end(), val); |
883 | if (it != end()) { |
884 | return indexToValName(std::distance(begin(), it)); |
885 | } |
886 | |
887 | return c10::optional<WelfordTriplet::ValName>(); |
888 | } |
889 | |
890 | bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { |
891 | return this == &other || |
892 | (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && |
893 | N()->sameAs(other.N())); |
894 | } |
895 | |
896 | WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { |
897 | return transform([&](const Val* val) { return ir_cloner->clone<Val>(val); }); |
898 | } |
899 | |
900 | std::vector<WelfordTriplet> WelfordTriplet::clone( |
901 | const std::vector<WelfordTriplet>& src, |
902 | IrCloner* ir_cloner) { |
903 | std::vector<WelfordTriplet> cloned; |
904 | for (const auto& triplet : src) { |
905 | cloned.emplace_back(triplet.clone(ir_cloner)); |
906 | } |
907 | return cloned; |
908 | } |
909 | |
910 | WelfordOp::WelfordOp( |
911 | IrBuilderPasskey passkey, |
912 | Val* out_avg, |
913 | Val* out_var, |
914 | Val* out_N, |
915 | Val* in_avg, |
916 | Val* in_var, |
917 | Val* in_N, |
918 | Val* init_avg, |
919 | Val* init_var, |
920 | Val* init_N, |
921 | bool is_fused) |
922 | : WelfordOp( |
923 | passkey, |
924 | WelfordTriplet(out_avg, out_var, out_N), |
925 | WelfordTriplet(in_avg, in_var, in_N), |
926 | WelfordTriplet(init_avg, init_var, init_N), |
927 | is_fused) {} |
928 | |
929 | WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) |
930 | : Expr(src, ir_cloner), |
931 | output_(src->output_.clone(ir_cloner)), |
932 | input_(src->input_.clone(ir_cloner)), |
933 | init_(src->init_.clone(ir_cloner)), |
934 | is_allreduce_(src->is_allreduce_) {} |
935 | |
936 | Expr* WelfordOp::shallowCopy() const { |
937 | auto result = |
938 | IrBuilder::create<WelfordOp>(output_, input_, init_, is_allreduce_); |
939 | result->copyPredicatesFrom(this); |
940 | return result; |
941 | } |
942 | |
943 | Val* WelfordOp::getInitValOfOutput(Val* output_val) const { |
944 | auto val_name = output().getNameOf(output_val); |
945 | |
  TORCH_INTERNAL_ASSERT(
      val_name.has_value(),
      "Not an output val ",
      output_val->toString(),
      " of ",
      toString());
952 | |
953 | return init().get(*val_name); |
954 | } |
955 | |
956 | bool WelfordOp::sameAs(const Statement* other) const { |
957 | if (this == other) { |
958 | return true; |
959 | } |
960 | if (auto other_wop = dynamic_cast<const WelfordOp*>(other)) { |
961 | return input_.sameAs(other_wop->input_) && init_.sameAs(other_wop->init_); |
962 | } |
963 | return false; |
964 | } |
965 | |
966 | std::vector<Val*> WelfordOp::getInitVals() const { |
967 | std::vector<Val*> init_vals({init_.avg(), init_.var(), init_.N()}); |
968 | return init_vals; |
969 | } |
970 | |
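// GroupedWelfordOp fuses multiple Welford reductions into a single
// expression, analogous to GroupedReductionOp.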
971 | GroupedWelfordOp::GroupedWelfordOp( |
972 | IrBuilderPasskey passkey, |
973 | std::vector<WelfordTriplet> output_vals, |
974 | std::vector<WelfordTriplet> input_vals, |
975 | std::vector<WelfordTriplet> init_vals, |
976 | bool is_allreduce, |
977 | ExprType expr_type) |
978 | : Expr(passkey, expr_type), |
979 | output_vals_(std::move(output_vals)), |
980 | input_vals_(std::move(input_vals)), |
981 | init_vals_(std::move(init_vals)), |
982 | is_allreduce_(is_allreduce) { |
983 | const auto num_grouped_ops = output_vals_.size(); |
984 | |
  TORCH_INTERNAL_ASSERT(
      input_vals_.size() == num_grouped_ops,
      "Invalid number of input arguments. Expected: ",
      num_grouped_ops,
      ", Given: ",
      input_vals_.size());
  TORCH_INTERNAL_ASSERT(
      init_vals_.size() == num_grouped_ops,
      "Invalid number of init arguments. Expected: ",
      num_grouped_ops,
      ", Given: ",
      init_vals_.size());
997 | |
998 | for (const auto i : c10::irange(num_grouped_ops)) { |
999 | // Check output type |
1000 | TORCH_INTERNAL_ASSERT( |
1001 | output_vals_[i].avg()->getValType().value() == ValType::TensorView || |
1002 | output_vals_[i].avg()->getValType().value() == ValType::TensorIndex); |
1003 | TORCH_INTERNAL_ASSERT( |
1004 | output_vals_[i].var()->getValType().value() == ValType::TensorView || |
1005 | output_vals_[i].var()->getValType().value() == ValType::TensorIndex); |
1006 | TORCH_INTERNAL_ASSERT( |
1007 | output_vals_[i].N()->getValType().value() == ValType::TensorView || |
1008 | output_vals_[i].N()->getValType().value() == ValType::TensorIndex); |
1009 | TORCH_INTERNAL_ASSERT(isIntegralType(output_vals_[i].N()->dtype())); |
1010 | |
1011 | // check initial value |
1012 | auto init_avg = init_vals_[i].avg(); |
1013 | auto init_var = init_vals_[i].var(); |
1014 | auto init_N = init_vals_[i].N(); |
    TORCH_INTERNAL_ASSERT(
        init_avg != nullptr && init_var != nullptr && init_N != nullptr,
        "nullptr init vals are not allowed");
    TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar);
    TORCH_INTERNAL_ASSERT(isIntegralType(init_N->dtype()));
    TORCH_INTERNAL_ASSERT(
        init_avg->getValType().value() == ValType::TensorView ||
            init_avg->getValType().value() == ValType::TensorIndex ||
            (init_N->isZeroInt() &&
             init_avg->getValType().value() == ValType::Scalar),
        "Initial avg must be a tensor, or a scalar if initial N is zero.",
        " Initial avg: ",
        init_avg->toString(),
        ". Initial N: ",
        init_N->toString());
    TORCH_INTERNAL_ASSERT(
        init_var->getValType().value() == ValType::TensorView ||
            init_var->getValType().value() == ValType::TensorIndex ||
            (init_N->isZeroInt() &&
             init_var->getValType().value() == ValType::Scalar),
        "Initial var must be a tensor, or a scalar if initial N is zero: ",
        init_var->toString());
1037 | |
1038 | // check input |
1039 | auto in_avg = input_vals_[i].avg(); |
1040 | auto in_var = input_vals_[i].var(); |
1041 | auto in_N = input_vals_[i].N(); |
    TORCH_INTERNAL_ASSERT(
        in_avg != nullptr && in_var != nullptr && in_N != nullptr,
        "nullptr input vals are not allowed");
    TORCH_INTERNAL_ASSERT(
        in_N->getValType().value() == ValType::Scalar ||
        in_N->getValType().value() == ValType::TensorView ||
        in_N->getValType().value() == ValType::TensorIndex);
    TORCH_INTERNAL_ASSERT(isIntegralType(in_N->dtype()));
    TORCH_INTERNAL_ASSERT(
        in_avg->getValType().value() == ValType::TensorView ||
            in_avg->getValType().value() == ValType::TensorIndex,
        "Invalid input avg argument type: ",
        in_avg->getValType().value());
1055 | |
    if (in_N->isOneInt()) {
      // When the input is a single value, only the avg input carries the
      // value; the var part must be implicitly 0.
      TORCH_INTERNAL_ASSERT(
          in_var->isZeroInt(),
          "Invalid var input, which must be scalar zero when the N input is one: ",
          in_var->toString());
    } else {
      TORCH_INTERNAL_ASSERT(
          in_var->getValType().value() == ValType::TensorView ||
              in_var->getValType().value() == ValType::TensorIndex,
          in_var->getValType().value(),
          ", ",
          in_N->toString());
    }
1071 | } |
1072 | |
1073 | for (const auto i : c10::irange(num_grouped_ops)) { |
1074 | addOutput(output_vals_[i].avg()); |
1075 | addOutput(output_vals_[i].var()); |
1076 | addOutput(output_vals_[i].N()); |
1077 | addInput(input_vals_[i].avg()); |
1078 | addInput(input_vals_[i].var()); |
1079 | addInput(input_vals_[i].N()); |
1080 | } |
1081 | } |
1082 | |
1083 | GroupedWelfordOp::GroupedWelfordOp( |
1084 | const GroupedWelfordOp* src, |
1085 | IrCloner* ir_cloner) |
1086 | : Expr(src, ir_cloner), |
1087 | output_vals_(WelfordTriplet::clone(src->output_vals_, ir_cloner)), |
1088 | input_vals_(WelfordTriplet::clone(src->input_vals_, ir_cloner)), |
1089 | init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), |
1090 | is_allreduce_(src->is_allreduce_) {} |
1091 | |
1092 | Expr* GroupedWelfordOp::shallowCopy() const { |
1093 | auto result = IrBuilder::create<GroupedWelfordOp>( |
1094 | output_vals_, input_vals_, init_vals_, is_allreduce_, etype()); |
1095 | result->copyPredicatesFrom(this); |
1096 | return result; |
1097 | } |
1098 | |
1099 | bool GroupedWelfordOp::sameAs(const Statement* other) const { |
1100 | if (this == other) { |
1101 | return true; |
1102 | } |
1103 | |
1104 | auto grouped_op = dynamic_cast<const GroupedWelfordOp*>(other); |
1105 | if (grouped_op == nullptr) { |
1106 | return false; |
1107 | } |
1108 | |
1109 | if (!Expr::sameAs(other)) { |
1110 | return false; |
1111 | } |
1112 | |
1113 | for (const auto i : c10::irange(numExprs())) { |
1114 | if (!initAvg(i)->sameAs(grouped_op->initAvg(i)) || |
1115 | !initVar(i)->sameAs(grouped_op->initVar(i)) || |
1116 | !initN(i)->sameAs(grouped_op->initN(i))) { |
1117 | return false; |
1118 | } |
1119 | } |
1120 | |
1121 | return true; |
1122 | } |
1123 | |
1124 | int GroupedWelfordOp::getExprIndexOfOutput(Val* output_val) const { |
1125 | for (const auto expr_idx : c10::irange(numExprs())) { |
1126 | if (outputVals().at(expr_idx).getNameOf(output_val).has_value()) { |
1127 | return expr_idx; |
1128 | } |
1129 | } |
1130 | |
  TORCH_INTERNAL_ASSERT(
      false, "Not an output, ", output_val->toString(), ", of ", toString());
1133 | } |
1134 | |
1135 | Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const { |
1136 | auto expr_index = getExprIndexOfOutput(output_val); |
1137 | |
1138 | auto val_name = outputVals().at(expr_index).getNameOf(output_val).value(); |
1139 | |
1140 | return initVals().at(expr_index).get(val_name); |
1141 | } |
1142 | |
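// MmaOp is a fused matrix-multiply-accumulate op targeting tensor core mma
// instructions. init provides the initial value of the accumulator.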
1143 | MmaOp::MmaOp( |
1144 | IrBuilderPasskey passkey, |
1145 | Val* out, |
1146 | Val* in_a, |
1147 | Val* in_b, |
1148 | Val* init) |
1149 | : Expr(passkey, ExprType::MmaOp), |
1150 | out_(out), |
1151 | in_a_(in_a), |
1152 | in_b_(in_b), |
1153 | init_(init) { |
1154 | // Check output type |
1155 | TORCH_INTERNAL_ASSERT( |
1156 | out->getValType().value() == ValType::TensorView || |
1157 | out->getValType().value() == ValType::TensorIndex); |
1158 | |
1159 | TORCH_INTERNAL_ASSERT( |
1160 | in_a->getValType().value() == ValType::TensorView || |
1161 | in_a->getValType().value() == ValType::TensorIndex, |
1162 | in_a->getValType().value()); |
1163 | |
1164 | TORCH_INTERNAL_ASSERT( |
1165 | in_b->getValType().value() == ValType::TensorView || |
1166 | in_b->getValType().value() == ValType::TensorIndex, |
1167 | in_b->getValType().value()); |
1168 | |
1169 | addOutput(out); |
1170 | addInput(in_a); |
1171 | addInput(in_b); |
1172 | } |
1173 | |
1174 | MmaOp::MmaOp( |
1175 | IrBuilderPasskey passkey, |
1176 | Val* out, |
1177 | Val* in_a, |
1178 | Val* in_b, |
1179 | Val* init, |
1180 | OptionsInMma options) |
1181 | : MmaOp(passkey, out, in_a, in_b, init) { |
1182 | options_ = options; |
1183 | } |
1184 | |
1185 | MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) |
1186 | : Expr(src, ir_cloner), |
1187 | out_(ir_cloner->clone(src->out_)), |
1188 | in_a_(ir_cloner->clone(src->in_a_)), |
1189 | in_b_(ir_cloner->clone(src->in_b_)), |
1190 | init_(ir_cloner->clone(src->init_)), |
1191 | options_(src->options_) {} |
1192 | |
1193 | Expr* MmaOp::shallowCopy() const { |
1194 | auto result = IrBuilder::create<MmaOp>(out_, in_a_, in_b_, init_); |
1195 | result->options_ = options_; |
1196 | result->copyPredicatesFrom(this); |
1197 | return result; |
1198 | } |
1199 | |
1200 | bool MmaOp::sameAs(const Statement* other) const { |
1201 | if (this == other) { |
1202 | return true; |
1203 | } |
1204 | if (auto other_mma = dynamic_cast<const MmaOp*>(other)) { |
1205 | return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) && |
1206 | in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) && |
1207 | options_ == other_mma->options_; |
1208 | } |
1209 | return false; |
1210 | } |
1211 | |
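// TransposeOp permutes the dimensions of the input. new2old[i] is the input
// (old) axis that maps to output (new) axis i.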
1212 | TransposeOp::TransposeOp( |
1213 | IrBuilderPasskey passkey, |
1214 | TensorView* out, |
1215 | TensorView* in, |
1216 | std::vector<int64_t> new2old) |
1217 | : Expr(passkey, ExprType::TransposeOp), |
1218 | out_(out), |
1219 | in_(in), |
1220 | new2old_(std::move(new2old)) { |
  // Sanity check of the input parameters. Maybe not strictly necessary, as
  // they should already be checked by the transpose() function.
1223 | |
1224 | TORCH_INTERNAL_ASSERT( |
1225 | TensorDomain::noReductions(in->getMaybeRFactorDomain()).size() == |
1226 | out->getMaybeRFactorDomain().size()); |
1227 | |
1228 | TORCH_INTERNAL_ASSERT(new2old_.size() == out->getMaybeRFactorDomain().size()); |
1229 | |
1230 | // Make sure the entries of new2old are unique and range from 0 to |
1231 | // N-1, where N == new2old.size(). |
1232 | std::set<int64_t> old_positions(new2old_.begin(), new2old_.end()); |
1233 | TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size()); |
1234 | // old_positions is sorted, so the first entry must be 0. |
  TORCH_INTERNAL_ASSERT(
      *(old_positions.begin()) == 0,
      "Invalid new2old vector detected: ",
      new2old_);
  // The last entry must be N-1, since old_positions is sorted, starts
  // with 0, and its length is N.
  TORCH_INTERNAL_ASSERT(
      *(old_positions.rbegin()) == (int)(new2old_.size() - 1),
      "Invalid new2old vector detected: ",
      new2old_);
1245 | |
1246 | addOutput(out); |
1247 | addInput(in); |
1248 | } |
1249 | |
1250 | TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) |
1251 | : Expr(src, ir_cloner), |
1252 | out_(ir_cloner->clone(src->out_)), |
1253 | in_(ir_cloner->clone(src->in_)), |
1254 | new2old_(src->new2old_) {} |
1255 | |
1256 | Expr* TransposeOp::shallowCopy() const { |
1257 | auto result = IrBuilder::create<TransposeOp>(out_, in_, new2old_); |
1258 | result->copyPredicatesFrom(this); |
1259 | return result; |
1260 | } |
1261 | |
1262 | std::vector<int64_t> TransposeOp::old2new() const { |
1263 | std::vector<int64_t> old2new(new2old_.size()); |
1264 | for (auto new_axis : c10::irange(new2old_.size())) { |
1265 | auto old_axis = new2old_.at(new_axis); |
1266 | old2new[old_axis] = new_axis; |
1267 | } |
1268 | return old2new; |
1269 | } |
1270 | |
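// ExpandOp expands broadcast dimensions of the input to the given extents.
// Each expanded extent is also registered as an input so the dependency is
// tracked.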
1271 | ExpandOp::ExpandOp( |
1272 | IrBuilderPasskey passkey, |
1273 | TensorView* out, |
1274 | TensorView* in, |
1275 | std::vector<Val*> _expanded_extents) |
1276 | : Expr(passkey, ExprType::ExpandOp), |
1277 | out_(out), |
1278 | in_(in), |
1279 | expanded_extents_(std::move(_expanded_extents)) { |
1280 | addOutput(out); |
1281 | addInput(in); |
1282 | for (auto expanded_extent : expanded_extents_) { |
1283 | TORCH_INTERNAL_ASSERT(expanded_extent != nullptr); |
    TORCH_INTERNAL_ASSERT(
        expanded_extent->dtype() == DataType::Int,
        "Expanded extents must be of Int type.");
1287 | addInput(expanded_extent); |
1288 | } |
1289 | } |
1290 | |
1291 | ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) |
1292 | : Expr(src, ir_cloner), |
1293 | out_(ir_cloner->clone(src->out_)), |
1294 | in_(ir_cloner->clone(src->in_)) { |
1295 | expanded_extents_.reserve(src->expanded_extents_.size()); |
1296 | for (const auto expanded_extent : src->expanded_extents_) { |
1297 | expanded_extents_.push_back(ir_cloner->clone(expanded_extent)); |
1298 | } |
1299 | } |
1300 | |
1301 | Expr* ExpandOp::shallowCopy() const { |
1302 | auto result = IrBuilder::create<ExpandOp>(out_, in_, expanded_extents_); |
1303 | result->copyPredicatesFrom(this); |
1304 | return result; |
1305 | } |
1306 | |
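// ShiftOp shifts the input tensor by the given per-axis offsets, with
// pad_width specifying how much padding is applied to the shifted-out region.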
1307 | ShiftOp::ShiftOp( |
1308 | IrBuilderPasskey passkey, |
1309 | Val* out, |
1310 | Val* in, |
1311 | std::vector<int> offsets, |
1312 | std::vector<int> pad_width) |
1313 | : Expr(passkey, ExprType::ShiftOp), |
1314 | out_(out), |
1315 | in_(in), |
1316 | offsets_(std::move(offsets)), |
1317 | pad_width_(std::move(pad_width)) { |
  // clang-tidy complains that out_ may be null.
1319 | TORCH_INTERNAL_ASSERT(out_ != nullptr); |
1320 | TORCH_INTERNAL_ASSERT(in_ != nullptr); |
1321 | |
1322 | auto out_type = out->getValType().value(); |
1323 | auto in_type = in->getValType().value(); |
1324 | |
  TORCH_INTERNAL_ASSERT(
      out_type == ValType::TensorView && in_type == ValType::TensorView,
      "Cannot shift a non-tensor object.");

  TORCH_INTERNAL_ASSERT(
      offsets_.size() ==
          TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
              .size(),
      "Invalid offset vector: ",
      offsets_);

  TORCH_INTERNAL_ASSERT(
      pad_width_.size() ==
          TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
              .size(),
      "Invalid padding width vector: ",
      pad_width_);
1342 | |
1343 | addOutput(out); |
1344 | addInput(in); |
1345 | } |
1346 | |
1347 | ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) |
1348 | : Expr(src, ir_cloner), |
1349 | out_(ir_cloner->clone(src->out_)), |
1350 | in_(ir_cloner->clone(src->in_)), |
1351 | offsets_(src->offsets_), |
1352 | pad_width_(src->pad_width_) {} |
1353 | |
1354 | Expr* ShiftOp::shallowCopy() const { |
1355 | auto result = IrBuilder::create<ShiftOp>(out_, in_, offsets_, pad_width_); |
1356 | result->copyPredicatesFrom(this); |
1357 | return result; |
1358 | } |
1359 | |
1360 | bool ShiftOp::sameAs(const Statement* other) const { |
1361 | if (this == other) { |
1362 | return true; |
1363 | } |
1364 | if (!other->isA<ShiftOp>()) { |
1365 | return false; |
1366 | } |
1367 | const auto other_op = other->as<ShiftOp>(); |
1368 | if (offsets() != other_op->offsets()) { |
1369 | return false; |
1370 | } |
1371 | return Expr::sameAs(other); |
1372 | } |
1373 | |
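// GatherOp gathers a window of neighboring elements for each input element.
// The window axes are appended after the original axes (see gatherAxis).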
1374 | GatherOp::GatherOp( |
1375 | IrBuilderPasskey passkey, |
1376 | Val* out, |
1377 | Val* in, |
1378 | std::vector<int> window_shape, |
1379 | std::vector<std::vector<int>> pad_width) |
1380 | : Expr(passkey, ExprType::GatherOp), |
1381 | out_(out), |
1382 | in_(in), |
1383 | window_shape_(std::move(window_shape)), |
1384 | pad_width_(std::move(pad_width)) { |
  // clang-tidy complains that out_ may be null.
1386 | TORCH_INTERNAL_ASSERT(out_ != nullptr); |
1387 | TORCH_INTERNAL_ASSERT(in_ != nullptr); |
1388 | |
1389 | auto out_type = out->getValType().value(); |
1390 | auto in_type = in->getValType().value(); |
1391 | |
  TORCH_INTERNAL_ASSERT(
      out_type == ValType::TensorView && in_type == ValType::TensorView,
      "Cannot gather a non-tensor object.");

  const auto ndims =
      TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain()).size();

  TORCH_INTERNAL_ASSERT(
      window_shape_.size() == ndims,
      "Invalid window_shape vector: ",
      window_shape_);
  TORCH_INTERNAL_ASSERT(
      pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_);

  for (const auto& pad : pad_width_) {
    TORCH_INTERNAL_ASSERT(
        pad.size() == 2, "Padding size for each axis must have two Int vals.");
1409 | } |
1410 | |
1411 | addOutput(out); |
1412 | addInput(in); |
1413 | } |
1414 | |
1415 | GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) |
1416 | : Expr(src, ir_cloner), |
1417 | out_(ir_cloner->clone(src->out_)), |
1418 | in_(ir_cloner->clone(src->in_)), |
1419 | window_shape_(src->window_shape_), |
1420 | pad_width_(src->pad_width_) {} |
1421 | |
1422 | Expr* GatherOp::shallowCopy() const { |
1423 | auto result = |
1424 | IrBuilder::create<GatherOp>(out_, in_, window_shape_, pad_width_); |
1425 | result->copyPredicatesFrom(this); |
1426 | return result; |
1427 | } |
1428 | |
1429 | bool GatherOp::sameAs(const Statement* other) const { |
1430 | if (this == other) { |
1431 | return true; |
1432 | } |
1433 | if (!other->isA<GatherOp>()) { |
1434 | return false; |
1435 | } |
1436 | const auto other_op = other->as<GatherOp>(); |
1437 | if (windowShape() != other_op->windowShape() || |
1438 | padWidth() != other_op->padWidth()) { |
1439 | return false; |
1440 | } |
1441 | return Expr::sameAs(other); |
1442 | } |
1443 | |
1444 | int GatherOp::gatherAxis(int axis) const { |
1445 | if (axis < 0) { |
1446 | axis += out()->as<TensorView>()->nDims(); |
1447 | } |
  TORCH_INTERNAL_ASSERT(
      axis >= 0 && axis < (int)windowShape().size(), "Invalid axis: ", axis);
1450 | return int(windowShape().size()) + axis; |
1451 | } |
1452 | |
1453 | ViewAsScalar::ViewAsScalar( |
1454 | IrBuilderPasskey passkey, |
1455 | Val* out, |
1456 | Val* in, |
1457 | IterDomain* vector_id, |
1458 | Val* index) |
1459 | : Expr(passkey, ExprType::ViewAsScalar), |
1460 | out_(out), |
1461 | in_(in), |
1462 | vector_id_(vector_id), |
1463 | index_(index) { |
1464 | addOutput(out); |
1465 | addInput(in); |
1466 | } |
1467 | |
1468 | ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) |
1469 | : Expr(src, ir_cloner), |
1470 | out_(ir_cloner->clone(src->out_)), |
1471 | in_(ir_cloner->clone(src->in_)), |
1472 | vector_id_(ir_cloner->clone(src->vector_id_)), |
1473 | index_(ir_cloner->clone(src->index_)) {} |
1474 | |
1475 | Expr* ViewAsScalar::shallowCopy() const { |
1476 | auto result = IrBuilder::create<ViewAsScalar>(out_, in_, vector_id_, index_); |
1477 | result->copyPredicatesFrom(this); |
1478 | return result; |
1479 | } |
1480 | |
1481 | ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) |
1482 | : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) { |
1483 | addOutput(out); |
1484 | addInput(in); |
1485 | } |
1486 | |
1487 | ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) |
1488 | : Expr(src, ir_cloner), |
1489 | out_(ir_cloner->clone(src->out_)), |
1490 | in_(ir_cloner->clone(src->in_)) {} |
1491 | |
1492 | Expr* ViewOp::shallowCopy() const { |
1493 | auto result = IrBuilder::create<ViewOp>(out_, in_); |
1494 | result->copyPredicatesFrom(this); |
1495 | return result; |
1496 | } |
1497 | |
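// LoadStoreOp is a pure data-movement op (such as ldmatrix or cp.async) that
// copies in to out without any computation.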
1498 | LoadStoreOp::LoadStoreOp( |
1499 | IrBuilderPasskey passkey, |
1500 | LoadStoreOpType op_type, |
1501 | Val* out, |
1502 | Val* in) |
1503 | : Expr(passkey, ExprType::LoadStoreOp), |
1504 | load_store_type_(op_type), |
1505 | out_(out), |
1506 | in_(in) { |
1507 | addOutput(out); |
1508 | addInput(in); |
1509 | } |
1510 | |
1511 | LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) |
1512 | : Expr(src, ir_cloner), |
1513 | load_store_type_(src->load_store_type_), |
1514 | out_(ir_cloner->clone(src->out_)), |
1515 | in_(ir_cloner->clone(src->in_)) {} |
1516 | |
1517 | Expr* LoadStoreOp::shallowCopy() const { |
1518 | auto result = IrBuilder::create<LoadStoreOp>(load_store_type_, out_, in_); |
1519 | result->copyPredicatesFrom(this); |
1520 | return result; |
1521 | } |
1522 | |
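// IterDomainBuilder is a fluent builder for IterDomains. Start and extent are
// required; all other fields default to a serial, non-rfactor domain.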
1523 | IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) |
1524 | : start_(_start), extent_(_extent) { |
  TORCH_INTERNAL_ASSERT(
      start_ != nullptr && extent_ != nullptr,
      "Start and extent are required to build an iter domain.");
1528 | } |
1529 | |
1530 | IterDomainBuilder::IterDomainBuilder(const IterDomain* id) |
1531 | : start_(id->start()), |
1532 | extent_(id->extent()), |
1533 | expanded_extent_( |
1534 | id->hasExpandedExtent() ? id->expandedExtent() : nullptr), |
1535 | stop_offset_(id->stopOffset()), |
1536 | parallel_type_(id->getParallelType()), |
1537 | iter_type_(id->getIterType()), |
1538 | is_rfactor_domain_(id->isRFactorProduct()), |
1539 | is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), |
1540 | padded_to_size_(id->getMaybeSizeAfterPadding()), |
1541 | is_mma_swizzled_(id->isMmaSwizzled()) {} |
1542 | |
1543 | IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { |
1544 | parallel_type_ = ParallelType::Serial; |
1545 | is_rfactor_domain_ = false; |
1546 | is_padded_dimension_ = false; |
1547 | padded_to_size_ = c10::nullopt; |
1548 | is_mma_swizzled_ = false; |
1549 | return *this; |
1550 | } |
1551 | |
1552 | IterDomainBuilder& IterDomainBuilder::resetRfactor() { |
1553 | return is_rfactor_domain(false); |
1554 | } |
1555 | |
1556 | IterDomainBuilder& IterDomainBuilder::start(Val* _start) { |
1557 | start_ = _start; |
1558 | return *this; |
1559 | } |
1560 | |
1561 | IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { |
1562 | extent_ = _extent; |
1563 | return *this; |
1564 | } |
1565 | |
1566 | IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { |
1567 | expanded_extent_ = _expanded_extent; |
1568 | return *this; |
1569 | } |
1570 | |
1571 | IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { |
1572 | stop_offset_ = _stop_offset; |
1573 | return *this; |
1574 | } |
1575 | |
1576 | IterDomainBuilder& IterDomainBuilder::parallel_type( |
1577 | ParallelType _parallel_type) { |
1578 | parallel_type_ = _parallel_type; |
1579 | return *this; |
1580 | } |
1581 | |
1582 | IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { |
1583 | iter_type_ = _iter_type; |
1584 | return *this; |
1585 | } |
1586 | |
1587 | IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( |
1588 | bool _is_rfactor_domain) { |
1589 | is_rfactor_domain_ = _is_rfactor_domain; |
1590 | return *this; |
1591 | } |
1592 | |
1593 | IterDomainBuilder& IterDomainBuilder::is_padded_dimension( |
1594 | bool _is_padded_dimension) { |
1595 | is_padded_dimension_ = _is_padded_dimension; |
1596 | return *this; |
1597 | } |
1598 | |
1599 | IterDomainBuilder& IterDomainBuilder::padded_to_size( |
1600 | c10::optional<int64_t> _padded_to_size) { |
1601 | padded_to_size_ = _padded_to_size; |
1602 | return *this; |
1603 | } |
1604 | |
1605 | IterDomainBuilder& IterDomainBuilder::is_mma_swizzled(bool _is_mma_swizzled) { |
1606 | is_mma_swizzled_ = _is_mma_swizzled; |
1607 | return *this; |
1608 | } |
1609 | |
1610 | IterDomain* IterDomainBuilder::build() const { |
  TORCH_INTERNAL_ASSERT(
      start_ != nullptr && extent_ != nullptr,
      "Start and extent are required to build an iter domain.");
1614 | return IrBuilder::create<IterDomain>(start_->container(), *this); |
1615 | } |
1616 | |
1617 | IterDomain::IterDomain( |
1618 | IrBuilderPasskey passkey, |
1619 | Val* start, |
1620 | Val* extent, |
1621 | Val* expanded_extent, |
1622 | Val* stop_offset, |
1623 | ParallelType parallel_type, |
1624 | IterType iter_type, |
1625 | bool is_rfactor_domain, |
1626 | bool is_padded_dimension, |
1627 | c10::optional<int64_t> padded_to_size, |
1628 | bool is_mma_swizzled) |
1629 | : Val(passkey, ValType::IterDomain, DataType::Int), |
1630 | start_(start), |
1631 | extent_(extent), |
1632 | expanded_extent_(expanded_extent), |
1633 | stop_offset_( |
1634 | stop_offset == nullptr ? passkey.ir_container_->zeroVal() |
1635 | : stop_offset), |
1636 | parallel_type_(parallel_type), |
1637 | iter_type_(iter_type), |
1638 | is_rfactor_domain_(is_rfactor_domain), |
1639 | is_padded_dimension_(is_padded_dimension), |
1640 | padded_to_size_(padded_to_size), |
1641 | is_mma_swizzled_(is_mma_swizzled) { |
  TORCH_CHECK(
      !(isRFactorProduct() && isBroadcast()),
      "IterDomain cannot be both a broadcast and rfactor domain.");

  TORCH_INTERNAL_ASSERT(
      extent->isAnInt(),
      "Cannot create an iter domain over an extent that is not an int but received ",
      extent,
      ".");

  TORCH_INTERNAL_ASSERT(
      start->isAnInt(),
      "Cannot create an iter domain with a start that is not an int but received ",
      start,
      ".");
1657 | } |
1658 | |
IterDomain::IterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args)
    : IterDomain(
1662 | passkey, |
1663 | args.start_, |
1664 | args.extent_, |
1665 | args.expanded_extent_, |
1666 | args.stop_offset_, |
1667 | args.parallel_type_, |
1668 | args.iter_type_, |
1669 | args.is_rfactor_domain_, |
1670 | args.is_padded_dimension_, |
1671 | args.padded_to_size_, |
1672 | args.is_mma_swizzled_) {} |
1673 | |
1674 | IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) |
1675 | : Val(src, ir_cloner), |
1676 | start_(ir_cloner->clone(src->start_)), |
1677 | extent_(ir_cloner->clone(src->extent_)), |
1678 | expanded_extent_( |
1679 | src->hasExpandedExtent() ? ir_cloner->clone(src->expandedExtent()) |
1680 | : nullptr), |
1681 | stop_offset_(ir_cloner->clone(src->stop_offset_)), |
1682 | parallel_type_(src->parallel_type_), |
1683 | iter_type_(src->iter_type_), |
1684 | is_rfactor_domain_(src->is_rfactor_domain_), |
1685 | is_padded_dimension_(src->is_padded_dimension_), |
1686 | padded_to_size_(src->padded_to_size_), |
1687 | is_mma_swizzled_(src->is_mma_swizzled_) {} |
1688 | |
1689 | bool IterDomain::sameAs(const Statement* other) const { |
1690 | if (other == this) { |
1691 | return true; |
1692 | } |
1693 | |
1694 | if (!other->isA<IterDomain>()) { |
1695 | return false; |
1696 | } |
1697 | |
1698 | const IterDomain* other_id = other->as<IterDomain>(); |
1699 | |
1700 | bool is_same = isReduction() == other_id->isReduction() && |
1701 | getParallelType() == other_id->getParallelType() && |
1702 | isVectorComponent() == other_id->isVectorComponent(); |
1703 | is_same = is_same && ScalarCheck::sameAs(extent(), other_id->extent()); |
1704 | is_same = is_same && ScalarCheck::sameAs(start(), other_id->start()); |
1705 | is_same = |
1706 | is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset()); |
1707 | is_same = is_same && (hasExpandedExtent() == other_id->hasExpandedExtent()); |
1708 | if (is_same && hasExpandedExtent()) { |
1709 | is_same = ScalarCheck::sameAs(expandedExtent(), other_id->expandedExtent()); |
1710 | } |
1711 | |
1712 | return is_same; |
1713 | } |
1714 | |
1715 | // Returns a new IterDomain matching properties of this except for |
1716 | // is_rfactor_domain_ |
1717 | IterDomain* IterDomain::cloneWithoutRFactor() const { |
1718 | auto cloned = IterDomainBuilder(this).resetRfactor().build(); |
1719 | |
1720 | return cloned; |
1721 | } |
1722 | |
1723 | bool IterDomain::isTrivialReduction() const { |
1724 | if (!isReduction()) { |
1725 | return false; |
1726 | } |
1727 | |
1728 | if (extent()->isOneInt()) { |
1729 | return true; |
1730 | } |
1731 | |
  // If this domain is an output of an expression, i.e., not a root
  // domain, check if all root domains are trivial reductions. This is
  // almost the same as the analysis done in TrivialReductionInfo, but
  // it is limited to a single tensor, whereas TrivialReductionInfo
  // does a more expensive analysis, potentially traversing through
  // rfactor domains.
1738 | if (definition()) { |
1739 | // Note: There's no const version of IterVisitor. |
1740 | auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this)); |
1741 | if (std::all_of( |
1742 | ir_utils::filterByType<IterDomain>(id_inputs).begin(), |
1743 | ir_utils::filterByType<IterDomain>(id_inputs).end(), |
1744 | [](IterDomain* root_id) { |
1745 | return root_id->isReduction() && root_id->extent()->isOneInt(); |
1746 | })) { |
1747 | return true; |
1748 | } |
1749 | } |
1750 | |
1751 | return false; |
1752 | } |
1753 | |
1754 | std::vector<IterDomain*> IterDomain::clone( |
1755 | const std::vector<IterDomain*>& domains) { |
1756 | std::vector<IterDomain*> cloned_domains; |
1757 | std::transform( |
1758 | domains.begin(), |
1759 | domains.end(), |
1760 | std::back_inserter(cloned_domains), |
1761 | [](auto id) { return id->cloneWithoutRFactor(); }); |
1762 | return cloned_domains; |
1763 | } |
1764 | |
1765 | // Merging does not propagate the start and stop values of the input |
1766 | // domains to the merged output domain. The actual range of the |
1767 | // domains is enforced by predicates. Note that since only root |
// domains have valid start and stop, it's not possible to use
// contiguous predication.
1770 | IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { |
1771 | TORCH_CHECK( |
1772 | !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(), |
1773 | "Merging IterDomains with ending values that are 0 is not supported at this time." ); |
1774 | TORCH_CHECK( |
1775 | outer->isReduction() == inner->isReduction() || |
1776 | (!outer->isReduction() && inner->isTrivialReduction()) || |
1777 | (outer->isTrivialReduction() && !inner->isReduction()), |
1778 | "Merging IterDomains requires that their iteration types match. " , |
1779 | "Outer: " , |
1780 | outer->toString(), |
1781 | ", Inner: " , |
1782 | inner->toString()); |
1783 | TORCH_CHECK( |
1784 | (outer->isGather() && inner->isGather()) || |
1785 | (!outer->isGather() && !inner->isGather()), |
1786 | "Merging gather and non-gather domains is not supported." ); |
1787 | |
1788 | TORCH_CHECK( |
1789 | !outer->isStride() && !inner->isStride(), |
1790 | "No support for merging stride domains" ); |
1791 | |
1792 | Val* merged_id_size = mul(outer->extent(), inner->extent()); |
1793 | |
1794 | IterType itype = outer->getIterType(); |
1795 | |
1796 | if (outer->isBroadcast() && inner->isBroadcast()) { |
1797 | itype = IterType::Broadcast; |
1798 | } |
1799 | |
1800 | if ((outer->isBroadcast() || inner->isBroadcast()) && |
1801 | (outer->getIterType() == IterType::Iteration || |
1802 | inner->getIterType() == IterType::Iteration)) { |
1803 | itype = IterType::Iteration; |
1804 | } |
1805 | |
  // Merging a trivial reduction with an iter domain is fine; just make the
  // result an iter domain.
1808 | if ((outer->isTrivialReduction() || inner->isTrivialReduction()) && |
1809 | (outer->getIterType() == IterType::Iteration || |
1810 | inner->getIterType() == IterType::Iteration)) { |
1811 | itype = IterType::Iteration; |
1812 | } |
1813 | |
  // Merging a trivial reduction with a broadcast is fine; just make the
  // result a broadcast.
1816 | if ((outer->isTrivialReduction() || inner->isTrivialReduction()) && |
1817 | (outer->isBroadcast() || inner->isBroadcast())) { |
1818 | itype = IterType::Broadcast; |
1819 | } |
1820 | |
1821 | Val* expanded_extent = nullptr; |
1822 | if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) { |
1823 | if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) { |
1824 | expanded_extent = mul(outer->expandedExtent(), inner->expandedExtent()); |
1825 | } else if (outer->hasExpandedExtent() && !inner->hasExpandedExtent()) { |
1826 | if (inner->isBroadcast()) { |
1827 | expanded_extent = outer->expandedExtent(); |
1828 | } else { |
1829 | expanded_extent = mul(outer->expandedExtent(), inner->extent()); |
1830 | } |
    } else if (!outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
1832 | if (outer->isBroadcast()) { |
1833 | expanded_extent = inner->expandedExtent(); |
1834 | } else { |
1835 | expanded_extent = mul(outer->extent(), inner->expandedExtent()); |
1836 | } |
1837 | } |
1838 | } |
1839 | |
1840 | IterDomain* merged_id = |
1841 | IterDomainBuilder( |
1842 | outer->container()->zeroVal(), merged_id_size->as<Int>()) |
1843 | .parallel_type(outer->getParallelType()) |
1844 | .expanded_extent(expanded_extent) |
1845 | .iter_type(itype) |
1846 | .build(); |
1847 | |
1848 | IrBuilder::create<Merge>(outer->container(), merged_id, outer, inner); |
1849 | |
1850 | return merged_id; |
1851 | } |
1852 | |
// Neither the outer nor the inner output domain inherits the start and stop
// values, as they cannot be split. The access range is enforced by
// predicates.
1856 | std::pair<IterDomain*, IterDomain*> IterDomain::split( |
1857 | IterDomain* in, |
1858 | Val* factor, |
1859 | bool inner_split, |
1860 | Val* start_offset, |
1861 | Val* stop_offset) { |
1862 | TORCH_CHECK( |
1863 | !in->extent()->isZeroInt(), |
1864 | "Splitting IterDomains with ending values that are 0 is not supported at this time." ); |
1865 | |
1866 | TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value " , factor); |
1867 | |
1868 | if (factor->getValType() == ValType::Scalar) { |
1869 | TORCH_CHECK( |
1870 | factor->isConstScalar() || |
1871 | (FusionGuard::getCurFusion() == factor->fusion() && |
1872 | factor->isFusionInput()), |
1873 | factor, |
1874 | " is not a constant nor an input. It must be one or the other to be used in a split." , |
1875 | " If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);" ); |
1876 | } else if (factor->getValType() == ValType::NamedScalar) { |
1877 | TORCH_CHECK( |
1878 | factor->as<NamedScalar>()->getParallelDim() != c10::nullopt, |
1879 | "Splitting a dimension by a named scalar is only supported on block or grid dimensions but received " , |
1880 | factor); |
1881 | } |
1882 | |
1883 | // outer loop size |
1884 | Val* remainder = |
1885 | ceilDiv(Split::extent(in->extent(), start_offset, stop_offset), factor); |
1886 | Val* expanded_remainder = nullptr; |
1887 | if (in->hasExpandedExtent()) { |
1888 | expanded_remainder = ceilDiv( |
1889 | Split::extent(in->expandedExtent(), start_offset, stop_offset), factor); |
1890 | } |
1891 | |
1892 | if ((start_offset != nullptr && !start_offset->isZeroInt()) || |
1893 | (stop_offset != nullptr && !stop_offset->isZeroInt())) { |
1894 | TORCH_INTERNAL_ASSERT( |
1895 | in->definition() == nullptr, |
1896 | "Partial split is only allowed with root domains" ); |
1897 | } |
1898 | // outer loop IterDomain |
1899 | IterDomain* ido = |
1900 | IterDomainBuilder( |
1901 | in->container()->zeroVal(), |
1902 | inner_split ? remainder->as<Int>() : factor) |
1903 | .expanded_extent( |
1904 | in->hasExpandedExtent() && inner_split ? expanded_remainder |
1905 | : nullptr) |
1906 | .parallel_type(in->getParallelType()) |
1907 | .iter_type(in->getIterType()) |
1908 | .build(); |
1909 | |
1910 | // inner loop IterDomain |
1911 | IterDomain* idi = |
1912 | IterDomainBuilder( |
1913 | in->container()->zeroVal(), |
1914 | inner_split ? factor : remainder->as<Int>()) |
1915 | .expanded_extent( |
1916 | in->hasExpandedExtent() && !inner_split ? expanded_remainder |
1917 | : nullptr) |
1918 | .parallel_type(in->getParallelType()) |
1919 | .iter_type(in->getIterType()) |
1920 | .build(); |
1921 | |
1922 | IrBuilder::create<Split>( |
1923 | in->container(), |
1924 | ido, |
1925 | idi, |
1926 | in, |
1927 | factor, |
1928 | inner_split, |
1929 | start_offset, |
1930 | stop_offset); |
1931 | return {ido, idi}; |
1932 | } |
1933 | |
1934 | std::pair<IterDomain*, IterDomain*> IterDomain::split( |
1935 | IterDomain* in, |
1936 | Val* factor, |
1937 | bool inner_split, |
1938 | bool trim_out_of_bounds) { |
1939 | auto start_offset = trim_out_of_bounds ? in->start() : nullptr; |
1940 | auto stop_offset = trim_out_of_bounds ? in->stopOffset() : nullptr; |
1941 | return IterDomain::split(in, factor, inner_split, start_offset, stop_offset); |
1942 | } |
1943 | |
1944 | std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) { |
1945 | // Use partial split so that only valid values are retained |
1946 | auto split_out = IterDomain::split( |
1947 | this, IrBuilder::create<Int>(container(), factor), true, true); |
1948 | |
1949 | split_out.second->iter_type_ = IterType::Stride; |
1950 | split_out.first->is_rfactor_domain_ = true; |
1951 | split_out.second->is_rfactor_domain_ = true; |
1952 | return split_out; |
1953 | } |
1954 | |
1955 | std::pair<IterDomain*, IterDomain*> IterDomain::swizzle( |
1956 | Swizzle2DType swizzle_type, |
1957 | IterDomain* in_x, |
1958 | IterDomain* in_y, |
1959 | SwizzleMode swizzle_mode) { |
1960 | TORCH_CHECK( |
1961 | !in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(), |
1962 | "Invalid swizzling of a empty dimension." ); |
1963 | |
1964 | // TODO: reduction check on swizzle: |
1965 | TORCH_CHECK( |
1966 | !in_x->isReduction() && !in_y->isReduction(), |
1967 | "swizzled reduction not yet supported" ); |
1968 | |
1969 | for (auto input : InputsOf::outputs(in_x->fusion(), {in_x, in_y})) { |
1970 | TORCH_CHECK( |
1971 | !input->as<IterDomain>()->isBroadcast(), |
1972 | "swizzling broadcast axes not yet supported" ); |
1973 | } |
1974 | |
1975 | // TODO: gather and shift check on swizzle |
1976 | TORCH_INTERNAL_ASSERT( |
1977 | !in_x->isGather() && !in_y->isGather(), |
1978 | "Swizzled gather not yet supported" ); |
1979 | |
1980 | IterDomain* out_x = IterDomainBuilder(in_x).build(); |
1981 | |
1982 | IterDomain* out_y = IterDomainBuilder(in_y).build(); |
1983 | |
1984 | IrBuilder::create<Swizzle2D>( |
1985 | in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode); |
1986 | |
1987 | return std::make_pair(out_x, out_y); |
1988 | } |
1989 | |
// TODO: We should change the parallelize interface to be on tensorview, or at
// least vectorize should be done on tensorview. This would let us check that
// we don't vectorize to the left of the computeAt domain, and could allow us
// to do some simple validation of vectorize, as its inputs are rightmost and
// contiguous.
1994 | void IterDomain::parallelize(ParallelType t) { |
1995 | if (parallel_type_ == t) { |
1996 | // No op, don't do any more checks, it was already set to this value. |
1997 | return; |
1998 | } |
1999 | |
2000 | if (t == ParallelType::Unroll || isParallelTypeVectorize(t) || |
2001 | t == ParallelType::Group) { |
2002 | TORCH_CHECK( |
2003 | start()->isZeroInt() && extent()->isConstScalar(), |
2004 | "Vectorization, unrolling, unswitching and grouping are only supported with start = 0 and extent as a const int, but got " , |
2005 | "a start of " , |
2006 | start(), |
2007 | " and extent " , |
2008 | extent(), |
2009 | " ." ); |
2010 | } |
2011 | |
2012 | if (t == ParallelType::Group) { |
2013 | TORCH_CHECK( |
2014 | getIterType() == IterType::Iteration, |
2015 | "Grouping IterDomain of non Iteration type is not allowed. " , |
2016 | getIterType()); |
2017 | } |
2018 | |
2019 | if (isMmaSwizzled()) { |
    // Mma swizzled axes describe the data layout within a warp,
    // so only allow updates that keep the parallelization within
    // a warp.
    // Note && TODO: this check is actually used to allow the indexing path
    // to make copies of the iterdomains. We might eventually just want
    // to lock these parallel types and not allow any changes once
    // they are swizzled.
2027 | TORCH_CHECK( |
2028 | t == ParallelType::Vectorize || t == ParallelType::TIDx || |
2029 | t == ParallelType::Serial, |
2030 | "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids" ); |
2031 | } |
2032 | |
2033 | parallel_type_ = t; |
2034 | } |
2035 | |
2036 | bool IterDomain::maybePartial() const { |
2037 | return !start()->isZeroInt() || !stopOffset()->isZeroInt(); |
2038 | } |
2039 | |
2040 | Val* IterDomain::stopOffset() const { |
2041 | return stop_offset_; |
2042 | } |
2043 | |
2044 | Val* IterDomain::stop() const { |
2045 | if (stopOffset()->isZeroInt()) { |
2046 | return extent(); |
2047 | } |
2048 | |
2049 | return sub(extent(), stopOffset()); |
2050 | } |
2051 | |
2052 | TensorDomain::TensorDomain( |
2053 | IrBuilderPasskey passkey, |
2054 | std::vector<IterDomain*> root_domain, |
2055 | std::vector<bool> contiguity) |
2056 | : Val(passkey, ValType::TensorDomain, DataType::Null), |
2057 | root_domain_(std::move(root_domain)), |
2058 | contiguity_( |
2059 | contiguity.empty() ? std::vector<bool>(root_domain_.size(), false) |
2060 | : std::move(contiguity)) { |
2061 | TORCH_CHECK( |
2062 | contiguity_.size() == getMaybeRFactorDomain().size(), |
2063 | "Invalid contiguity information provided, incorrect size. Received vector of size " , |
2064 | contiguity_.size(), |
2065 | " but needed one of size " , |
2066 | root_domain_.size()); |
2067 | |
  // Initialized here only to satisfy clang-tidy; the correct value is set in
  // resetDomains().
2069 | has_nontrivial_reduction_ = false; |
2070 | domain_ = root_domain_; |
2071 | resetDomains(); |
2072 | } |
2073 | |
2074 | TensorDomain::TensorDomain( |
2075 | IrBuilderPasskey passkey, |
2076 | std::vector<IterDomain*> root_domain, |
2077 | std::vector<IterDomain*> domain, |
2078 | std::vector<bool> contiguity) |
2079 | : Val(passkey, ValType::TensorDomain, DataType::Null), |
2080 | root_domain_(std::move(root_domain)), |
2081 | domain_(std::move(domain)), |
2082 | contiguity_( |
2083 | contiguity.empty() ? std::vector<bool>(root_domain_.size(), false) |
2084 | : std::move(contiguity)) { |
2085 | TORCH_CHECK( |
2086 | contiguity_.size() == getMaybeRFactorDomain().size(), |
2087 | "Invalid contiguity information provided, incorrect size. Received vector of size " , |
2088 | contiguity_.size(), |
2089 | " but needed one of size " , |
2090 | root_domain_.size()); |
2091 | |
2092 | std::vector<Val*> domain_vals(domain_.begin(), domain_.end()); |
2093 | auto inps = IterVisitor::getInputsTo(domain_vals); |
2094 | |
2095 | // Validate that the root domain consists of all inputs to domain |
2096 | // Uncertain if this will hold for RFactor |
2097 | |
2098 | std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end()); |
2099 | std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { |
2100 | TORCH_INTERNAL_ASSERT( |
2101 | root_vals.find(inp) != root_vals.end(), |
2102 | "Invalid tensor domain, " , |
2103 | inp, |
2104 | " is an input of domain, but it is not found in the root domain." ); |
2105 | }); |
2106 | |
  // Initialized here only to satisfy clang-tidy; the correct value is set in
  // resetDomains().
2108 | has_nontrivial_reduction_ = false; |
2109 | resetDomains(); |
2110 | } |
2111 | |
2112 | TensorDomain::TensorDomain( |
2113 | IrBuilderPasskey passkey, |
2114 | std::vector<IterDomain*> root_domain, |
2115 | std::vector<IterDomain*> rfactor_domain, |
2116 | std::vector<IterDomain*> domain, |
2117 | std::vector<bool> contiguity) |
2118 | : Val(passkey, ValType::TensorDomain, DataType::Null), |
2119 | root_domain_(std::move(root_domain)), |
2120 | domain_(std::move(domain)), |
2121 | rfactor_domain_(std::move(rfactor_domain)), |
2122 | contiguity_( |
2123 | contiguity.empty() ? std::vector<bool>(rfactor_domain_.size(), false) |
2124 | : std::move(contiguity)) { |
2125 | TORCH_CHECK( |
2126 | contiguity_.size() == getMaybeRFactorDomain().size(), |
2127 | "Invalid contiguity information provided, incorrect size. Received vector of size " , |
2128 | contiguity_.size(), |
2129 | " but needed one of size " , |
2130 | getMaybeRFactorDomain().size()); |
2131 | |
2132 | auto inps = IterVisitor::getInputsTo( |
2133 | std::vector<Val*>(domain_.begin(), domain_.end())); |
2134 | |
2135 | // Validate that the root domain consists of all inputs to domain |
2136 | // Uncertain if this will hold for RFactor |
2137 | |
2138 | std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end()); |
2139 | std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { |
2140 | TORCH_INTERNAL_ASSERT( |
2141 | root_vals.find(inp) != root_vals.end(), |
2142 | "Invalid tensor domain, " , |
2143 | inp, |
2144 | " is an input of domain, but it is not found in the root domain." ); |
2145 | }); |
2146 | |
2147 | inps = IterVisitor::getInputsTo( |
2148 | std::vector<Val*>(rfactor_domain_.begin(), rfactor_domain_.end())); |
2149 | std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { |
2150 | TORCH_INTERNAL_ASSERT( |
2151 | root_vals.find(inp) != root_vals.end(), |
2152 | "Invalid tensor domain, " , |
2153 | inp, |
2154 | " is an input of the rfactor domain, but it is not found in the root domain." ); |
2155 | }); |
2156 | |
  // Initialized here only to satisfy clang-tidy; the correct value is set in
  // resetDomains().
2158 | has_nontrivial_reduction_ = false; |
2159 | resetDomains(); |
2160 | } |
2161 | |
2162 | TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) |
2163 | : Val(src, ir_cloner), |
2164 | root_domain_(ir_cloner->clone(src->root_domain_)), |
2165 | domain_(ir_cloner->clone(src->domain_)), |
2166 | no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)), |
2167 | no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)), |
2168 | rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)), |
2169 | contiguity_(src->contiguity()), |
2170 | has_nontrivial_reduction_(src->has_nontrivial_reduction_) {} |
2171 | |
2172 | bool TensorDomain::hasBlockBroadcast() const { |
2173 | return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { |
2174 | return id->isBroadcast() && id->isThreadDim(); |
2175 | }); |
2176 | } |
2177 | |
2178 | bool TensorDomain::hasGridBroadcast() const { |
2179 | return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { |
2180 | return id->isBroadcast() && id->isBlockDim(); |
2181 | }); |
2182 | } |
2183 | |
2184 | bool TensorDomain::operator==(const TensorDomain& other) const { |
2185 | // Checks equality of each class field. Should not be necessary to |
2186 | // check no_bcast_domain_ and no_reduction_domain_ as they are just |
2187 | // derived from domain_. |
2188 | return root_domain_ == other.root_domain_ && domain_ == other.domain_ && |
2189 | rfactor_domain_ == other.rfactor_domain_ && |
2190 | contiguity_ == other.contiguity_; |
2191 | } |
2192 | |
2193 | bool TensorDomain::sameAs(const Statement* const other) const { |
2194 | if (this == other) { |
2195 | return true; |
2196 | } |
2197 | |
2198 | if (!other->isA<TensorDomain>()) { |
2199 | return false; |
2200 | } |
2201 | |
2202 | const TensorDomain* other_td = other->as<TensorDomain>(); |
2203 | |
2204 | if (nDims() != other_td->nDims()) { |
2205 | return false; |
2206 | } |
2207 | if (getRootDomain().size() != other_td->getRootDomain().size()) { |
2208 | return false; |
2209 | } |
2210 | if (getRFactorDomain().size() != other_td->getRFactorDomain().size()) { |
2211 | return false; |
2212 | } |
2213 | |
2214 | for (const auto i : c10::irange(nDims())) { |
2215 | if (!(axis(i)->sameAs(other_td->axis(i)))) { |
2216 | return false; |
2217 | } |
2218 | } |
2219 | |
2220 | for (const auto i : c10::irange(getRootDomain().size())) { |
2221 | if (!(getRootDomain()[i]->sameAs(other_td->getRootDomain()[i]))) { |
2222 | return false; |
2223 | } |
2224 | } |
2225 | |
2226 | for (const auto i : c10::irange(getRFactorDomain().size())) { |
2227 | if (!(getRFactorDomain()[i]->sameAs(other_td->getRFactorDomain()[i]))) { |
2228 | return false; |
2229 | } |
2230 | } |
2231 | |
2232 | return true; |
2233 | } |
2234 | |
2235 | bool TensorDomain::sameAs( |
2236 | const std::vector<IterDomain*>& lhs, |
2237 | const std::vector<IterDomain*>& rhs) { |
2238 | if (lhs.size() != rhs.size()) |
2239 | return false; |
2240 | size_t i = 0; |
2241 | for (auto td_lhs : lhs) { |
2242 | if (!td_lhs->sameAs(rhs[i++])) |
2243 | return false; |
2244 | } |
2245 | return true; |
2246 | } |
2247 | |
2248 | void TensorDomain::setContiguity(const std::vector<bool>& contig) { |
2249 | TORCH_INTERNAL_ASSERT( |
2250 | getMaybeRFactorDomain().size() == contig.size(), |
2251 | "Invalid contiguity vector: " , |
2252 | contig); |
2253 | |
2254 | contiguity_ = contig; |
2255 | } |
2256 | |
2257 | bool TensorDomain::hasReduction() const { |
2258 | return has_nontrivial_reduction_; |
2259 | } |
2260 | |
2261 | bool TensorDomain::hasBlockReduction() const { |
2262 | return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { |
2263 | return id->isReduction() && id->isThreadDim(); |
2264 | }); |
2265 | } |
2266 | |
2267 | bool TensorDomain::hasGridReduction() const { |
2268 | return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { |
2269 | return id->isReduction() && id->isBlockDim(); |
2270 | }); |
2271 | } |
2272 | |
2273 | bool TensorDomain::hasBroadcast() const { |
2274 | return no_bcast_domain_.size() != domain_.size(); |
2275 | } |
2276 | |
2277 | bool TensorDomain::hasRFactor() const { |
2278 | return !rfactor_domain_.empty(); |
2279 | } |
2280 | |
2281 | bool TensorDomain::hasViewLikeRFactor() const { |
2282 | if (!hasRFactor()) { |
2283 | // Can't have view like rfactor if there is no rfactor domain |
2284 | return false; |
2285 | } |
2286 | |
  // If there's an rfactor domain and no rfactor product is a reduction, this
  // is a view-like rfactor.
2289 | return std::none_of( |
2290 | getMaybeRFactorDomain().begin(), |
2291 | getMaybeRFactorDomain().end(), |
2292 | [](IterDomain* id) { |
2293 | return id->isReduction() && id->isRFactorProduct(); |
2294 | }); |
2295 | } |
2296 | |
2297 | bool TensorDomain::hasVectorize() const { |
2298 | return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { |
2299 | return id->getParallelType() == ParallelType::Vectorize || |
2300 | id->getParallelType() == ParallelType::MisalignedVectorize; |
2301 | }); |
2302 | } |
2303 | |
2304 | c10::optional<unsigned int> TensorDomain::getReductionAxis() const { |
2305 | auto it = std::find_if(domain_.begin(), domain_.end(), [](const auto& id) { |
2306 | return id->isReduction(); |
2307 | }); |
2308 | if (it == domain_.end()) { |
2309 | return c10::optional<unsigned int>(); |
2310 | } else { |
2311 | return c10::optional<unsigned int>(std::distance(domain_.begin(), it)); |
2312 | } |
2313 | } |
2314 | |
// i here is an int, as we want to accept negative values, and ::size_type can
// be a uint.
2317 | IterDomain* TensorDomain::axis(int i) const { |
2318 | TORCH_INTERNAL_ASSERT( |
2319 | nDims() > 0, "Tried to access an axis in a 0-dim domain" ); |
2320 | if (i < 0) |
2321 | i += nDims(); |
2322 | TORCH_CHECK( |
2323 | i >= 0 && (unsigned int)i < nDims(), |
2324 | "Tried to access axis " , |
2325 | i, |
2326 | " in domain " , |
2327 | this); |
2328 | return domain_[i]; |
2329 | } |
2330 | |
2331 | size_t TensorDomain::posOf(IterDomain* id) const { |
2332 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to find an axis in a 0-dim domain" ); |
2333 | size_t i = 0; |
2334 | while (i < domain_.size()) { |
2335 | if (domain_[i] == id) |
2336 | return i; |
2337 | i++; |
2338 | } |
2339 | TORCH_CHECK(false, "Provided id is not part of this domain." ); |
2340 | } |
2341 | |
2342 | size_t TensorDomain::rootPosOf(IterDomain* id) const { |
2343 | TORCH_INTERNAL_ASSERT( |
2344 | root_domain_.size() > 0, "Tried to find an axis in a 0-dim root domain" ); |
2345 | auto it = std::find(root_domain_.begin(), root_domain_.end(), id); |
2346 | TORCH_INTERNAL_ASSERT( |
2347 | it != root_domain_.end(), "Provided id is not part of root domain." ); |
2348 | return std::distance(root_domain_.begin(), it); |
2349 | } |
2350 | |
2351 | void TensorDomain::split( |
2352 | int axis_, |
2353 | Val* factor, |
2354 | bool inner_split, |
2355 | bool trim_out_of_bounds) { |
2356 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain" ); |
2357 | if (axis_ < 0) |
2358 | axis_ += nDims(); |
2359 | |
2360 | TORCH_INTERNAL_ASSERT( |
2361 | axis_ >= 0 && (unsigned int)axis_ < nDims(), |
2362 | "Tried to split on axis outside TensorDomain's range." ); |
2363 | |
2364 | IterDomain* id = axis(axis_); |
2365 | |
2366 | // partial split is only allowed with root domains |
2367 | if (trim_out_of_bounds) { |
2368 | TORCH_INTERNAL_ASSERT( |
2369 | std::find(getRootDomain().begin(), getRootDomain().end(), id) != |
2370 | getRootDomain().end(), |
2371 | "Partial split is only allowed with root domains" ); |
2372 | } |
2373 | |
2374 | TORCH_INTERNAL_ASSERT( |
2375 | !id->isMmaSwizzled(), |
2376 | "Further transformation on warp mapped id's not allowed." ); |
2377 | |
2378 | auto split_ids = |
2379 | IterDomain::split(id, factor, inner_split, trim_out_of_bounds); |
2380 | domain_.erase(domain_.begin() + axis_); |
2381 | domain_.insert(domain_.begin() + axis_, split_ids.second); |
2382 | domain_.insert(domain_.begin() + axis_, split_ids.first); |
2383 | resetDomains(); |
2384 | } |
2385 | |
2386 | // Merge "axis_o" and "axis_i" into 1 dimension |
2387 | void TensorDomain::merge(int axis_o, int axis_i) { |
2388 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain" ); |
2389 | if (axis_o < 0) |
2390 | axis_o += nDims(); |
2391 | |
2392 | if (axis_i < 0) |
2393 | axis_i += nDims(); |
2394 | |
2395 | TORCH_CHECK( |
2396 | axis_o >= 0 && (unsigned int)axis_o < nDims() && axis_i >= 0 && |
2397 | (unsigned int)axis_i < nDims(), |
2398 | "Invalid merge detected, either one or both axes are outside of TensorView's range." ); |
2399 | |
2400 | TORCH_CHECK( |
2401 | axis_o != axis_i, |
2402 | "Invalid merge detected, axes provided are the same axis." ); |
2403 | |
2404 | if (axis_o > axis_i) { |
2405 | auto tmp = axis_i; |
2406 | axis_i = axis_o; |
2407 | axis_o = tmp; |
2408 | } |
2409 | |
2410 | IterDomain* first = axis(axis_o); |
2411 | IterDomain* second = axis(axis_i); |
2412 | |
2413 | TORCH_INTERNAL_ASSERT( |
2414 | !first->isMmaSwizzled() && !second->isMmaSwizzled(), |
2415 | "Further transformation on warp mapped id's not allowed." ); |
2416 | |
2417 | IterDomain* merged_id = IterDomain::merge(first, second); |
2418 | |
2419 | domain_.erase(domain_.begin() + axis_i); |
2420 | domain_.erase(domain_.begin() + axis_o); |
2421 | domain_.insert(domain_.begin() + axis_o, merged_id); |
2422 | resetDomains(); |
2423 | } |
2424 | |
2425 | // Reorder axes according to map[old_pos] = new_pos |
2426 | void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) { |
2427 | TORCH_INTERNAL_ASSERT( |
2428 | !(nDims() == 0 && old2new_.size() > 0), |
2429 | "Tried to reorder a 0-dim domain" ); |
2430 | domain_ = orderedAs(domain_, old2new_); |
2431 | resetDomains(); |
2432 | } |
2433 | |
2434 | std::vector<IterDomain*> TensorDomain::orderedAs( |
2435 | const std::vector<IterDomain*>& dom, |
2436 | const std::unordered_map<int, int>& old2new_) { |
2437 | TORCH_INTERNAL_ASSERT( |
2438 | !(dom.size() == 0 && old2new_.size() > 0), |
2439 | "Tried to reorder a 0-dim domain" ); |
2440 | |
  // Even though these checks are already done in TensorView, we want to redo
  // them here since this function can be reached from places other than
  // TensorView.
2443 | |
2444 | auto new2old = ir_utils::normalizeOld2New(old2new_, dom.size()); |
2445 | |
2446 | std::vector<IterDomain*> reordered_domain; |
2447 | std::transform( |
2448 | new2old.begin(), |
2449 | new2old.end(), |
2450 | std::back_inserter(reordered_domain), |
2451 | [dom](int i) -> IterDomain* { return dom[i]; }); |
2452 | |
2453 | return reordered_domain; |
2454 | } |
2455 | |
2456 | void TensorDomain::swizzle( |
2457 | Swizzle2DType swizzle_type, |
2458 | int x, |
2459 | int y, |
2460 | SwizzleMode swizzle_mode) { |
2461 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain" ); |
2462 | |
2463 | TORCH_CHECK( |
2464 | x >= 0 && (unsigned int)x < nDims(), |
2465 | "Invalid swizzle detected, either one or both axes are outside of TensorView's range." ); |
2466 | |
2467 | TORCH_CHECK( |
2468 | y >= 0 && (unsigned int)y < nDims(), |
2469 | "Invalid swizzle detected, either one or both axes are outside of TensorView's range." ); |
2470 | |
2471 | IterDomain* axis_x = axis(x); |
2472 | IterDomain* axis_y = axis(y); |
2473 | |
2474 | IterDomain* axis_out_x = nullptr; |
2475 | IterDomain* axis_out_y = nullptr; |
2476 | |
2477 | std::tie(axis_out_x, axis_out_y) = |
2478 | IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode); |
2479 | |
2480 | domain_.erase(domain_.begin() + x); |
2481 | domain_.insert(domain_.begin() + x, axis_out_x); |
2482 | |
2483 | domain_.erase(domain_.begin() + y); |
2484 | domain_.insert(domain_.begin() + y, axis_out_y); |
2485 | |
2486 | resetDomains(); |
2487 | } |
2488 | |
2489 | std::vector<IterDomain*> TensorDomain::noReductions( |
2490 | const std::vector<IterDomain*>& td) { |
2491 | size_t size_out = 0; |
2492 | for (auto id : td) { |
2493 | if (!id->isReduction() && !id->isStride()) { |
2494 | size_out++; |
2495 | } |
2496 | } |
2497 | std::vector<IterDomain*> noReductionDomain(size_out); |
2498 | |
2499 | int it = 0; |
2500 | for (auto id : td) { |
2501 | if (!id->isReduction() && !id->isStride()) { |
2502 | noReductionDomain[it++] = id; |
2503 | } |
2504 | } |
2505 | |
2506 | return noReductionDomain; |
2507 | } |
2508 | |
2509 | std::vector<IterDomain*> TensorDomain::noBroadcasts( |
2510 | const std::vector<IterDomain*>& td) { |
2511 | size_t size_out = 0; |
2512 | for (auto id : td) |
2513 | if (!id->isBroadcast()) |
2514 | size_out++; |
2515 | std::vector<IterDomain*> noBroadcastDomain(size_out); |
2516 | |
2517 | int it = 0; |
2518 | for (auto id : td) |
2519 | if (!id->isBroadcast()) |
2520 | noBroadcastDomain[it++] = id; |
2521 | |
2522 | return noBroadcastDomain; |
2523 | } |
2524 | |
2525 | bool TensorDomain::hasBroadcast(const std::vector<IterDomain*>& td) { |
2526 | for (auto id : td) |
2527 | if (id->isBroadcast()) |
2528 | return true; |
2529 | return false; |
2530 | } |
2531 | |
2532 | bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) { |
2533 | for (auto id : td) |
2534 | if (id->isReduction()) |
2535 | return true; |
2536 | return false; |
2537 | } |
2538 | |
2539 | bool TensorDomain::hasNontrivialReduction(const std::vector<IterDomain*>& td) { |
2540 | for (auto id : td) { |
2541 | if (id->isReduction() && !id->isTrivialReduction()) { |
2542 | return true; |
2543 | } |
2544 | } |
2545 | return false; |
2546 | } |
2547 | |
2548 | TensorDomain* TensorDomain::view(const AnalyzeViewResult& view_analysis) { |
2549 | TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to view transform a 0-dim domain" ); |
2550 | return transformView(this, view_analysis); |
2551 | } |
2552 | |
2553 | TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) { |
2554 | auto inp_domain = noReductions(getMaybeRFactorDomain()); |
2555 | |
2556 | if (start_dim < 0) { |
2557 | start_dim += inp_domain.size(); |
2558 | } |
2559 | if (end_dim < 0) { |
2560 | end_dim += inp_domain.size(); |
2561 | } |
2562 | TORCH_CHECK( |
2563 | start_dim >= 0 && start_dim < int64_t(inp_domain.size()), |
2564 | "Invalid start_dim " , |
2565 | start_dim); |
2566 | TORCH_CHECK( |
2567 | end_dim >= 0 && end_dim < int64_t(inp_domain.size()), |
2568 | "Invalid end_dim " , |
2569 | end_dim); |
2570 | TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim" ); |
2571 | |
2572 | std::vector<IterDomain*> new_root_domain; |
2573 | new_root_domain.reserve(inp_domain.size()); |
2574 | for (auto i : c10::irange(inp_domain.size())) { |
2575 | bool is_rfactor_dim = i >= size_t(start_dim) && i <= size_t(end_dim); |
2576 | auto inp_id = inp_domain[i]; |
2577 | auto out_id = IterDomainBuilder(inp_id) |
2578 | .is_rfactor_domain(is_rfactor_dim) |
2579 | .extent( |
2580 | (is_rfactor_dim && inp_id->hasExpandedExtent()) |
2581 | ? inp_id->expandedExtent() |
2582 | : inp_id->extent()) |
2583 | .iter_type( |
2584 | (is_rfactor_dim && inp_id->isBroadcast()) |
2585 | ? IterType::Iteration |
2586 | : inp_id->getIterType()) |
2587 | .build(); |
2588 | new_root_domain.push_back(out_id); |
2589 | } |
2590 | |
2591 | std::vector<IterDomain*> rfactor_domain; |
2592 | rfactor_domain.reserve(new_root_domain.size() - (end_dim - start_dim)); |
2593 | for (auto i : c10::irange(start_dim)) { |
2594 | rfactor_domain.push_back(new_root_domain[i]); |
2595 | } |
2596 | |
2597 | IterDomain* merged_id = new_root_domain[start_dim]; |
2598 | for (auto i : c10::irange(start_dim + 1, end_dim + 1)) { |
2599 | IterDomain* new_merged_id = |
2600 | IterDomainBuilder( |
2601 | merged_id->container()->zeroVal(), |
2602 | mul(merged_id->extent(), new_root_domain[i]->extent())) |
2603 | .is_rfactor_domain(true) |
2604 | .build(); |
2605 | IrBuilder::create<Merge>(new_merged_id, merged_id, new_root_domain[i]); |
2606 | merged_id = new_merged_id; |
2607 | } |
2608 | rfactor_domain.push_back(merged_id); |
2609 | |
2610 | for (auto i : c10::irange(end_dim + 1, inp_domain.size())) { |
2611 | rfactor_domain.push_back(new_root_domain[i]); |
2612 | } |
2613 | |
2614 | return IrBuilder::create<TensorDomain>( |
2615 | new_root_domain, |
2616 | rfactor_domain, |
2617 | rfactor_domain, |
2618 | std::vector<bool>(rfactor_domain.size(), true)); |
2619 | } |
2620 | |
2621 | // TODO: Rfactor a Welford |
2622 | |
// The returned pair is ordered such that the second domain is the consumer
// of the first.
2624 | std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor( |
2625 | const std::vector<int>& axes_) { |
2626 | return TransformRFactor::runReplay(this, axes_); |
2627 | } |
2628 | |
2629 | Split::Split( |
2630 | IrBuilderPasskey passkey, |
2631 | IterDomain* outer, |
2632 | IterDomain* inner, |
2633 | IterDomain* in, |
2634 | Val* factor, |
2635 | bool inner_split, |
2636 | Val* start_offset, |
2637 | Val* stop_offset) |
2638 | : Expr(passkey, ExprType::Split), |
2639 | outer_{outer}, |
2640 | inner_{inner}, |
2641 | in_{in}, |
2642 | factor_{factor}, |
2643 | inner_split_{inner_split}, |
2644 | start_offset_{ |
2645 | start_offset != nullptr ? start_offset |
2646 | : passkey.ir_container_->zeroVal()}, |
2647 | stop_offset_{ |
2648 | stop_offset != nullptr ? stop_offset |
2649 | : passkey.ir_container_->zeroVal()} { |
2650 | TORCH_INTERNAL_ASSERT( |
2651 | factor_->isAnInt(), |
2652 | "Attempted to create a Split node with a non-integer factor." ); |
2653 | addOutput(outer); |
2654 | addOutput(inner); |
2655 | addInput(in); |
  // TODO: add factor as an input; need to check Split::Split during validation
  // and need to check BestEffortReplay::findFirstMismatchedID.
  // addInput(factor);
2658 | } |
2659 | |
2660 | Split::Split(const Split* src, IrCloner* ir_cloner) |
2661 | : Expr(src, ir_cloner), |
2662 | outer_(ir_cloner->clone(src->outer_)), |
2663 | inner_(ir_cloner->clone(src->inner_)), |
2664 | in_(ir_cloner->clone(src->in_)), |
2665 | factor_(ir_cloner->clone(src->factor_)), |
2666 | inner_split_(src->inner_split_), |
2667 | start_offset_(ir_cloner->clone(src->start_offset_)), |
2668 | stop_offset_(ir_cloner->clone(src->stop_offset_)) {} |
2669 | |
2670 | Expr* Split::shallowCopy() const { |
2671 | auto result = IrBuilder::create<Split>( |
2672 | outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); |
2673 | result->copyPredicatesFrom(this); |
2674 | return result; |
2675 | } |
2676 | |
2677 | Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { |
2678 | TORCH_INTERNAL_ASSERT(in_extent != nullptr); |
2679 | |
2680 | if (start_offset != nullptr && !start_offset->isZeroInt()) { |
2681 | in_extent = sub(in_extent, start_offset); |
2682 | } |
2683 | |
2684 | if (stop_offset != nullptr && !stop_offset->isZeroInt()) { |
2685 | in_extent = sub(in_extent, stop_offset); |
2686 | } |
2687 | |
2688 | return in_extent; |
2689 | } |
2690 | |
2691 | bool Split::sameAs(const Statement* other) const { |
2692 | if (this == other) { |
2693 | return true; |
2694 | } |
2695 | if (!other->isA<Split>()) { |
2696 | return false; |
2697 | } |
2698 | return Expr::sameAs(other) && |
2699 | factor()->sameAs(other->as<Split>()->factor()) && |
2700 | innerSplit() == other->as<Split>()->innerSplit() && |
2701 | startOffset()->sameAs(other->as<Split>()->startOffset()) && |
2702 | stopOffset()->sameAs(other->as<Split>()->stopOffset()); |
2703 | } |
2704 | |
2705 | Merge::Merge( |
2706 | IrBuilderPasskey passkey, |
2707 | IterDomain* out, |
2708 | IterDomain* outer, |
2709 | IterDomain* inner) |
2710 | : Expr(passkey, ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} { |
2711 | addOutput(out); |
2712 | addInput(outer); |
2713 | addInput(inner); |
2714 | } |
2715 | |
2716 | Merge::Merge(const Merge* src, IrCloner* ir_cloner) |
2717 | : Expr(src, ir_cloner), |
2718 | out_(ir_cloner->clone(src->out_)), |
2719 | outer_(ir_cloner->clone(src->outer_)), |
2720 | inner_(ir_cloner->clone(src->inner_)) {} |
2721 | |
2722 | Expr* Merge::shallowCopy() const { |
2723 | auto result = IrBuilder::create<Merge>(out_, outer_, inner_); |
2724 | result->copyPredicatesFrom(this); |
2725 | return result; |
2726 | } |
2727 | |
2728 | bool Merge::sameAs(const Statement* other) const { |
2729 | if (this == other) { |
2730 | return true; |
2731 | } |
2732 | if (!other->isA<Merge>()) { |
2733 | return false; |
2734 | } |
2735 | return Expr::sameAs(other); |
2736 | } |
2737 | |
2738 | Swizzle2D::Swizzle2D( |
2739 | IrBuilderPasskey passkey, |
2740 | IterDomain* out_x, |
2741 | IterDomain* out_y, |
2742 | IterDomain* in_x, |
2743 | IterDomain* in_y, |
2744 | Swizzle2DType swizzle_type, |
2745 | SwizzleMode swizzle_mode) |
2746 | : Expr(passkey, ExprType::Swizzle2D), |
2747 | out_x_{out_x}, |
2748 | out_y_{out_y}, |
2749 | in_x_{in_x}, |
2750 | in_y_{in_y}, |
2751 | swizzle_type_(swizzle_type), |
2752 | swizzle_mode_(swizzle_mode) { |
2753 | addOutput(out_x); |
2754 | addOutput(out_y); |
2755 | addInput(in_x); |
2756 | addInput(in_y); |
2757 | } |
2758 | |
2759 | Expr* Swizzle2D::shallowCopy() const { |
2760 | auto result = IrBuilder::create<Swizzle2D>( |
2761 | out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); |
2762 | result->copyPredicatesFrom(this); |
2763 | return result; |
2764 | } |
2765 | |
2766 | bool Swizzle2D::sameAs(const Statement* other) const { |
2767 | if (this == other) { |
2768 | return true; |
2769 | } |
2770 | if (!other->isA<Swizzle2D>()) { |
2771 | return false; |
2772 | } |
2773 | if (!(swizzle_type_ == other->as<Swizzle2D>()->swizzle_type_)) { |
2774 | return false; |
2775 | } |
2776 | return Expr::sameAs(other); |
2777 | } |
2778 | |
2779 | Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner) |
2780 | : Expr(src, ir_cloner), |
2781 | out_x_(ir_cloner->clone(src->out_x_)), |
2782 | out_y_(ir_cloner->clone(src->out_y_)), |
2783 | in_x_(ir_cloner->clone(src->in_x_)), |
2784 | in_y_(ir_cloner->clone(src->in_y_)), |
2785 | swizzle_type_(src->swizzle_type_), |
2786 | swizzle_mode_(src->swizzle_mode_) {} |
2787 | |
2788 | NamedScalar::NamedScalar( |
2789 | IrBuilderPasskey passkey, |
2790 | std::string name, |
2791 | DataType dtype) |
2792 | : Val(passkey, ValType::NamedScalar, dtype), name_(std::move(name)) {} |
2793 | |
2794 | NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) |
2795 | : Val(src, ir_cloner), name_(src->name_) {} |
2796 | |
2797 | bool NamedScalar::sameAs(const Statement* other) const { |
2798 | if (this == other) { |
2799 | return true; |
2800 | } |
2801 | if (!other->isA<NamedScalar>()) { |
2802 | return false; |
2803 | } |
2804 | return other->as<NamedScalar>()->name().compare(name()) == 0; |
2805 | } |
2806 | |
2807 | NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { |
2808 | TORCH_INTERNAL_ASSERT( |
2809 | isParallelTypeThread(p_type), |
2810 | "Cannot get parallel dim of non thread type, received: " , |
2811 | p_type); |
2812 | TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); |
2813 | std::string parallel_dim = stringifyThreadSize(p_type); |
2814 | return IrBuilder::create<NamedScalar>(parallel_dim, DataType::Int); |
2815 | } |
2816 | |
2817 | NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { |
2818 | TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); |
2819 | std::string parallel_ind = stringifyThread(p_type); |
2820 | return IrBuilder::create<NamedScalar>(parallel_ind, DataType::Int); |
2821 | } |
2822 | |
2823 | c10::optional<ParallelType> NamedScalar::getParallelDim() const { |
2824 | if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) { |
2825 | return c10::optional<ParallelType>(ParallelType::TIDx); |
2826 | } else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) { |
2827 | return c10::optional<ParallelType>(ParallelType::TIDy); |
2828 | } else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) { |
2829 | return c10::optional<ParallelType>(ParallelType::TIDz); |
2830 | } else if (stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) { |
2831 | return c10::optional<ParallelType>(ParallelType::BIDx); |
2832 | } else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) { |
2833 | return c10::optional<ParallelType>(ParallelType::BIDy); |
2834 | } else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) { |
2835 | return c10::optional<ParallelType>(ParallelType::BIDz); |
2836 | } |
2837 | return c10::nullopt; |
2838 | } |
2839 | |
2840 | c10::optional<ParallelType> NamedScalar::getParallelIndex() const { |
2841 | if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) { |
2842 | return c10::optional<ParallelType>(ParallelType::TIDx); |
2843 | } else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) { |
2844 | return c10::optional<ParallelType>(ParallelType::TIDy); |
2845 | } else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) { |
2846 | return c10::optional<ParallelType>(ParallelType::TIDz); |
2847 | } else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) { |
2848 | return c10::optional<ParallelType>(ParallelType::BIDx); |
2849 | } else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) { |
2850 | return c10::optional<ParallelType>(ParallelType::BIDy); |
2851 | } else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) { |
2852 | return c10::optional<ParallelType>(ParallelType::BIDz); |
2853 | } |
2854 | return c10::nullopt; |
2855 | } |
2856 | |
2857 | } // namespace cuda |
2858 | } // namespace fuser |
2859 | } // namespace jit |
2860 | } // namespace torch |
2861 | |