1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <fusion.h> |
6 | #include <ir_base_nodes.h> |
7 | #include <mma_type.h> |
#include <parallel_type_bitmap.h>

#include <algorithm>
#include <array>
#include <sstream>
#include <unordered_map>
#include <vector>
9 | |
//! The nodes defined here should generally not be used directly by users.
//! They operate behind the scenes, and users should not need to be aware of
//! them in order to use the code generator.
13 | //! |
14 | //! \todo improve implementation bool IterDomain::sameAs(const IterDomain*) |
15 | //! \todo Add testing of sameAs functions for these nodes |
16 | //! |
17 | |
18 | namespace torch { |
19 | namespace jit { |
20 | namespace fuser { |
21 | namespace cuda { |
22 | |
23 | class ViewTransform; |
24 | class Scope; |
25 | class IrCloner; |
26 | struct AnalyzeViewResult; |
27 | |
//! Returns true if v1 and v2 are scalars of the same type and the sameAs
//! call dispatched to the derived Val type returns true, e.g. if both vals
//! are `Int` this dispatches to v1->as<Int>()->sameAs(v2->as<Int>())
31 | bool areEqualScalars(Val* v1, Val* v2); |
32 | |
33 | class TORCH_CUDA_CU_API FullOp : public Expr { |
34 | public: |
35 | FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); |
36 | |
37 | FullOp(const FullOp* src, IrCloner* ir_cloner); |
38 | |
39 | Expr* shallowCopy() const override; |
40 | |
41 | bool sameAs(const Statement* other) const override; |
42 | |
43 | DataType dtype() const { |
44 | return dtype_; |
45 | } |
46 | |
47 | Val* getFillValue() const { |
48 | return fill_value_; |
49 | } |
50 | |
51 | private: |
52 | const DataType dtype_; |
53 | Val* fill_value_; |
54 | }; |
55 | |
56 | class TORCH_CUDA_CU_API ARangeOp : public Expr { |
57 | public: |
58 | ARangeOp( |
59 | IrBuilderPasskey, |
60 | Val* out, |
61 | Val* start, |
62 | Val* end, |
63 | Val* step, |
64 | DataType dtype, |
65 | Val* linear_index = nullptr); |
66 | |
67 | ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); |
68 | |
69 | Expr* shallowCopy() const override; |
70 | |
71 | bool sameAs(const Statement* other) const override; |
72 | |
73 | DataType dtype() const { |
74 | return dtype_; |
75 | } |
76 | |
77 | Val* start() const { |
78 | return start_; |
79 | } |
80 | |
81 | Val* end() const { |
82 | return end_; |
83 | } |
84 | |
85 | Val* step() const { |
86 | return step_; |
87 | } |
88 | |
89 | Val* getLinearLogicalIndex() const { |
90 | return linear_index_; |
91 | } |
92 | |
93 | void setLinearIndex(Val* index) { |
94 | linear_index_ = index; |
95 | } |
96 | |
97 | private: |
98 | const DataType dtype_; |
99 | Val* start_; |
100 | Val* end_; |
101 | Val* step_; |
102 | Val* linear_index_ = nullptr; |
103 | }; |
104 | |
105 | // Tensor factory for generating identity matrices like |
106 | // |
107 | // [[1, 0, 0], |
108 | // [0, 1, 0], |
109 | // [0, 0, 1]] |
110 | // |
111 | // or |
112 | // |
113 | // [[1, 0, 0], |
114 | // [0, 1, 0], |
115 | // [0, 0, 1], |
116 | // [0, 0, 0]] |
117 | // |
118 | // or |
119 | // |
120 | // [[1, 0, 0, 0], |
121 | // [0, 1, 0, 0], |
122 | // [0, 0, 1, 0]] |
123 | class TORCH_CUDA_CU_API EyeOp : public Expr { |
124 | public: |
125 | EyeOp( |
126 | IrBuilderPasskey, |
127 | Val* out, |
128 | DataType dtype, |
129 | Val* index1 = nullptr, |
130 | Val* index2 = nullptr); |
131 | |
132 | EyeOp(const EyeOp* src, IrCloner* ir_cloner); |
133 | |
134 | Expr* shallowCopy() const override; |
135 | |
136 | bool sameAs(const Statement* other) const override; |
137 | |
138 | DataType dtype() const { |
139 | return dtype_; |
140 | } |
141 | |
142 | Val* getIndex1() const { |
143 | return index1_; |
144 | } |
145 | |
146 | void setIndex1(Val* index) { |
147 | index1_ = index; |
148 | } |
149 | |
150 | Val* getIndex2() const { |
151 | return index2_; |
152 | } |
153 | |
154 | void setIndex2(Val* index) { |
155 | index2_ = index; |
156 | } |
157 | |
158 | private: |
159 | const DataType dtype_; |
160 | Val* index1_ = nullptr; |
161 | Val* index2_ = nullptr; |
162 | }; |
163 | |
164 | //! A specialization for Unary operations. Unary operations take in a single |
165 | //! input and produce a single output. Examples include: |
//! 1) Casting operation, e.g. float(a_val)
//! 2) Negation, e.g. val * -1
//! 3) Reduction across a dimension, e.g. val.sum(axis=2)
169 | //! 4) split/merge |
170 | class TORCH_CUDA_CU_API UnaryOp : public Expr { |
171 | public: |
172 | UnaryOp( |
173 | IrBuilderPasskey, |
174 | UnaryOpType type, |
175 | Val* out, |
176 | Val* in, |
177 | int rng_offset = -1); |
178 | |
179 | UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); |
180 | |
181 | Expr* shallowCopy() const override; |
182 | |
183 | Val* out() const { |
184 | return out_; |
185 | } |
186 | Val* in() const { |
187 | return in_; |
188 | } |
189 | |
190 | UnaryOpType getUnaryOpType() const { |
191 | return unary_op_type_; |
192 | } |
193 | |
194 | bool sameAs(const Statement* other) const override; |
195 | |
196 | private: |
197 | const UnaryOpType unary_op_type_; |
198 | Val* const out_ = nullptr; |
199 | Val* const in_ = nullptr; |
200 | }; |
201 | |
202 | //! A specialization for Binary operations. Binary operations take in two inputs |
203 | //! and produce a single output. Examples include: |
204 | //! 1) Add/mul/div/mod/sub (A * B) |
205 | //! 2) LT (A < B) |
206 | class TORCH_CUDA_CU_API BinaryOp : public Expr { |
207 | public: |
208 | BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); |
209 | |
210 | BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); |
211 | |
212 | Expr* shallowCopy() const override; |
213 | |
214 | Val* out() const { |
215 | return out_; |
216 | } |
217 | Val* lhs() const { |
218 | return lhs_; |
219 | } |
220 | Val* rhs() const { |
221 | return rhs_; |
222 | } |
223 | |
224 | BinaryOpType getBinaryOpType() const { |
225 | return binary_op_type_; |
226 | } |
227 | |
228 | bool sameAs(const Statement* other) const override; |
229 | |
230 | private: |
231 | const BinaryOpType binary_op_type_; |
232 | Val* const out_ = nullptr; |
233 | Val* const lhs_ = nullptr; |
234 | Val* const rhs_ = nullptr; |
235 | }; |
236 | |
237 | //! A specialization for random number generator (RNG) operations. RNG |
238 | //! operations take in no tensor input and produce a single output. |
239 | class TORCH_CUDA_CU_API RNGOp : public Expr { |
240 | public: |
241 | RNGOp( |
242 | IrBuilderPasskey, |
243 | RNGOpType type, |
244 | Val* out, |
245 | DataType dtype, |
246 | std::vector<Val*> parameters = {}, |
247 | int rng_offset = 0, |
248 | Val* philox_index = nullptr); |
249 | |
250 | RNGOp(const RNGOp* src, IrCloner* ir_cloner); |
251 | |
252 | Expr* shallowCopy() const override; |
253 | |
254 | RNGOpType getRNGOpType() const { |
255 | return rng_op_type_; |
256 | } |
257 | |
258 | DataType dtype() const { |
259 | return dtype_; |
260 | } |
261 | |
262 | int getRNGOffset() const { |
263 | return rng_offset_; |
264 | } |
265 | |
266 | void setRNGOffset(int val) { |
267 | rng_offset_ = val; |
268 | } |
269 | |
270 | const std::vector<Val*>& getParameters() const { |
271 | return parameters_; |
272 | } |
273 | |
274 | const std::vector<Val*>& getShape() const { |
275 | return shape_; |
276 | } |
277 | |
278 | Val* getPhiloxIndex() const { |
279 | return philox_index_; |
280 | } |
281 | |
282 | void setPhiloxIndex(Val* index) { |
283 | philox_index_ = index; |
284 | } |
285 | |
286 | bool sameAs(const Statement* other) const override; |
287 | |
288 | private: |
289 | const RNGOpType rng_op_type_; |
290 | const DataType dtype_; |
291 | std::vector<Val*> parameters_; |
292 | std::vector<Val*> shape_; |
293 | int rng_offset_ = -1; |
294 | // The index used to feed philox's subsequence and component |
295 | Val* philox_index_ = nullptr; |
296 | }; |
297 | |
//! Broadcast in to match out. is_broadcast_dims are relative to out, with
//! is_broadcast_dims.size() == out->nDims().
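//!
//! A hedged illustration of the flag convention: broadcasting a 1D tensor
//! [I0] to a 2D output [I0, B1], where B1 is the newly created broadcast
//! domain, would use is_broadcast_dims = {false, true}.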
300 | class TORCH_CUDA_CU_API BroadcastOp : public Expr { |
301 | public: |
302 | //! \param out The output tensor |
303 | //! \param in The input tensor |
304 | //! \param is_broadcast_dims True when output dim is a new broadcast domain |
305 | BroadcastOp( |
306 | IrBuilderPasskey, |
307 | Val* out, |
308 | Val* in, |
309 | std::vector<bool> is_broadcast_dims); |
310 | |
311 | BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); |
312 | |
313 | Expr* shallowCopy() const override; |
314 | |
315 | Val* out() const { |
316 | return out_; |
317 | } |
318 | Val* in() const { |
319 | return in_; |
320 | } |
321 | |
322 | bool isBroadcastDim(size_t dim) const { |
323 | return is_broadcast_dims_.at(dim); |
324 | } |
325 | |
326 | const std::vector<bool>& getBroadcastDimFlags() const { |
327 | return is_broadcast_dims_; |
328 | } |
329 | |
330 | bool sameAs(const Statement* other) const override; |
331 | |
332 | private: |
333 | Val* const out_ = nullptr; |
334 | Val* const in_ = nullptr; |
335 | |
336 | //! The same list passed to the broadcast arithmetic op. Each |
337 | //! element corresponds to an IterDomain of the output tensor and is |
338 | //! true when the IterDomain is a new broadcast domain. Note |
339 | //! that the output tensor may have other broadcast domains whose |
340 | //! flags are false because the input tensor may already have |
341 | //! broadcast domains. |
342 | const std::vector<bool> is_broadcast_dims_; |
343 | }; |
344 | |
//! Reduction operation. Out is first initialized to init. Then
//! reduction_op_type is used to update out as out = reductionOp(out, in).
//! The output's axes marked as reduction will be reduced to produce an output
//! tensor. The output tensor's size will be the size of all
//! non-reduction/non-broadcast dimensions.
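//!
//! A hedged sketch of a sum reduction in these terms: with
//! reduction_op_type = BinaryOpType::Add and init = 0, the op repeatedly
//! performs out = out + in over the axes of out marked as reduction.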
350 | class TORCH_CUDA_CU_API ReductionOp : public Expr { |
351 | public: |
352 | ReductionOp( |
353 | IrBuilderPasskey, |
354 | BinaryOpType reduction_op_type, |
355 | Val* init, |
356 | Val* out, |
357 | Val* in, |
358 | bool is_allreduce = false, |
359 | ExprType expr_type = ExprType::ReductionOp); |
360 | |
361 | ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); |
362 | |
363 | Expr* shallowCopy() const override; |
364 | |
365 | Val* out() const { |
366 | return out_; |
367 | } |
368 | Val* in() const { |
369 | return in_; |
370 | } |
371 | Val* init() const { |
372 | return init_; |
373 | } |
374 | |
375 | BinaryOpType getReductionOpType() const { |
376 | return reduction_op_type_; |
377 | } |
378 | |
379 | bool isAllreduce() const { |
380 | return is_allreduce_; |
381 | } |
382 | |
383 | bool sameAs(const Statement* other) const override; |
384 | |
385 | private: |
386 | const BinaryOpType reduction_op_type_; |
387 | Val* const init_ = nullptr; |
388 | Val* const out_ = nullptr; |
389 | Val* const in_ = nullptr; |
390 | //! True if broadcast is fused |
391 | bool is_allreduce_ = false; |
392 | }; |
393 | |
//! Grouped reduction operation for horizontal fusions. It works like
//! batched GEMMs in the sense that multiple independent reductions are
//! performed together. The main benefit is that when reducing tensors across
//! thread blocks, a single grid sync can be done for all individual
//! reductions. As grid syncs are very expensive, this can be a
//! significant performance benefit.
400 | class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { |
401 | public: |
402 | GroupedReductionOp( |
403 | IrBuilderPasskey, |
404 | std::vector<BinaryOpType> reduction_op_type, |
405 | std::vector<Val*> init, |
406 | std::vector<Val*> out, |
407 | std::vector<Val*> in, |
408 | bool is_allreduce = false, |
409 | ExprType expr_type = ExprType::GroupedReductionOp); |
410 | |
411 | GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); |
412 | |
413 | Expr* shallowCopy() const override; |
414 | |
415 | //! Number of expressions grouped horizontally. It does not reflect |
416 | //! iteration grouping. |
417 | size_t numExprs() const { |
418 | return reduction_op_types_.size(); |
419 | } |
420 | |
421 | const std::vector<Val*>& initVals() const { |
422 | return init_vals_; |
423 | } |
424 | |
425 | Val* initVal(size_t index) const { |
426 | return init_vals_.at(index); |
427 | } |
428 | |
429 | const std::vector<BinaryOpType>& getReductionOpTypes() const { |
430 | return reduction_op_types_; |
431 | } |
432 | |
433 | BinaryOpType getReductionOpType(size_t index) const { |
434 | return reduction_op_types_.at(index); |
435 | } |
436 | |
437 | bool isAllreduce() const { |
438 | return is_allreduce_; |
439 | } |
440 | |
441 | //! Return the index of the corresponding reduction expression for |
442 | //! a given output val. |
443 | int getExprIndexOfOutput(Val* output_val) const; |
444 | |
445 | bool sameAs(const Statement* other) const override; |
446 | |
447 | private: |
448 | //! Reduction ops of grouped reductions |
449 | const std::vector<BinaryOpType> reduction_op_types_; |
450 | //! Initial values of grouped reductions |
451 | const std::vector<Val*> init_vals_; |
452 | //! True if using the fused reduction kernel |
453 | bool is_allreduce_ = false; |
454 | }; |
455 | |
456 | //! Average, variance and N (count) vals for Welford |
457 | class TORCH_CUDA_CU_API WelfordTriplet { |
458 | public: |
459 | //! Names of the Welford triplet vals |
460 | enum class ValName { Avg, Var, N }; |
461 | |
462 | WelfordTriplet() = default; |
463 | |
464 | WelfordTriplet(Val* avg, Val* var, Val* N) : vals_({avg, var, N}) {} |
465 | |
466 | Val* const& avg() const { |
467 | return get(ValName::Avg); |
468 | } |
469 | |
470 | Val*& avg() { |
471 | return get(ValName::Avg); |
472 | } |
473 | |
474 | TensorView* avgTv() const { |
475 | TORCH_INTERNAL_ASSERT(avg()->isA<TensorView>()); |
476 | return avg()->as<TensorView>(); |
477 | } |
478 | |
479 | Val* const& var() const { |
480 | return get(ValName::Var); |
481 | } |
482 | |
483 | Val*& var() { |
484 | return get(ValName::Var); |
485 | } |
486 | |
487 | TensorView* varTv() const { |
488 | TORCH_INTERNAL_ASSERT(var()->isA<TensorView>()); |
489 | return var()->as<TensorView>(); |
490 | } |
491 | |
492 | Val* const& N() const { |
493 | return get(ValName::N); |
494 | } |
495 | |
496 | Val*& N() { |
497 | return get(ValName::N); |
498 | } |
499 | |
500 | TensorView* NTv() const { |
501 | TORCH_INTERNAL_ASSERT(N()->isA<TensorView>()); |
502 | return N()->as<TensorView>(); |
503 | } |
504 | |
505 | //! Get the i-th val. Ordering is defined by ValName. |
506 | Val* const& get(int i) const { |
507 | return vals_.at(i); |
508 | } |
509 | |
510 | //! Get the i-th val. Ordering is defined by ValName. |
511 | Val*& get(int i) { |
512 | return vals_.at(i); |
513 | } |
514 | |
515 | Val* const& get(ValName name) const { |
516 | return get(valNameToIndex(name)); |
517 | } |
518 | |
519 | Val*& get(ValName name) { |
520 | return get(valNameToIndex(name)); |
521 | } |
522 | |
  //! Get the name of a given val in this triplet. c10::nullopt is returned
  //! if not found.
525 | c10::optional<ValName> getNameOf(Val* val) const; |
526 | |
  //! Return a new triplet with outputs produced by a function applied
  //! to each val of this triplet
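  //!
  //! A minimal usage sketch (the cloning lambda is hypothetical here):
  //!   WelfordTriplet copy =
  //!       triplet.transform([&](Val* v) { return ir_cloner->clone(v); });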
529 | template <typename Func> |
530 | WelfordTriplet transform(Func func) const { |
531 | return WelfordTriplet(func(avg()), func(var()), func(N())); |
532 | } |
533 | |
534 | bool sameAs(const WelfordTriplet& other) const; |
535 | |
536 | WelfordTriplet clone(IrCloner* ir_cloner) const; |
537 | |
538 | //! Clone a vector of triplets |
539 | static std::vector<WelfordTriplet> clone( |
540 | const std::vector<WelfordTriplet>& src, |
541 | IrCloner* ir_cloner); |
542 | |
543 | auto begin() { |
544 | return vals_.begin(); |
545 | } |
546 | |
547 | auto begin() const { |
548 | return vals_.begin(); |
549 | } |
550 | |
551 | auto end() { |
552 | return vals_.end(); |
553 | } |
554 | |
555 | auto end() const { |
556 | return vals_.end(); |
557 | } |
558 | |
559 | private: |
560 | //! Convert a given val name to an index |
561 | static int valNameToIndex(ValName name) { |
562 | return static_cast<int>(name); |
563 | } |
564 | |
565 | //! Convert a given index to a name |
566 | static ValName indexToValName(int index) { |
    TORCH_INTERNAL_ASSERT(index >= 0 && index < 3, "Invalid index: ", index);
568 | return static_cast<ValName>(index); |
569 | } |
570 | |
571 | private: |
572 | //! Holds avg, var and N in this order |
573 | std::array<Val*, 3> vals_ = {{nullptr, nullptr, nullptr}}; |
574 | }; |
575 | |
//! Welford operation, computing the mean (avg), variance and count (N) of
//! its input in a single pass.
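//!
//! As background, a hedged sketch of the standard parallel Welford update
//! used to combine two partial results a and b (writing M2 for the running
//! sum of squared differences; the exact storage convention of the var
//! fields is not asserted here):
//!   N     = N_a + N_b
//!   delta = avg_b - avg_a
//!   avg   = avg_a + delta * N_b / N
//!   M2    = M2_a + M2_b + delta * delta * N_a * N_b / N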
577 | class TORCH_CUDA_CU_API WelfordOp : public Expr { |
578 | public: |
579 | WelfordOp( |
580 | IrBuilderPasskey, |
581 | const WelfordTriplet& output, |
582 | const WelfordTriplet& input, |
583 | const WelfordTriplet& init, |
584 | bool is_fused = false); |
585 | |
586 | WelfordOp( |
587 | IrBuilderPasskey, |
588 | Val* out_avg, |
589 | Val* out_var, |
590 | Val* out_N, |
591 | Val* in_avg, |
592 | Val* in_var, |
593 | Val* in_N, |
594 | Val* init_avg, |
595 | Val* init_var, |
596 | Val* init_N, |
597 | bool is_fused = false); |
598 | |
599 | WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); |
600 | |
601 | Expr* shallowCopy() const override; |
602 | |
603 | Val* out() const { |
604 | return output().avg(); |
605 | } |
606 | |
607 | Val* in() const { |
608 | return input().avg(); |
609 | } |
610 | |
611 | bool sameAs(const Statement* const other) const override; |
612 | |
613 | const WelfordTriplet& output() const { |
614 | return output_; |
615 | } |
616 | |
617 | Val* outAvg() const { |
618 | return output().avg(); |
619 | } |
620 | |
621 | Val* outVar() const { |
622 | return output().var(); |
623 | } |
624 | |
625 | Val* outN() const { |
626 | return output().N(); |
627 | } |
628 | |
629 | const WelfordTriplet& input() const { |
630 | return input_; |
631 | } |
632 | |
633 | Val* inAvg() const { |
634 | return input().avg(); |
635 | } |
636 | |
637 | Val* inVar() const { |
638 | return input().var(); |
639 | } |
640 | |
641 | Val* inN() const { |
642 | return input().N(); |
643 | } |
644 | |
645 | const WelfordTriplet& init() const { |
646 | return init_; |
647 | } |
648 | |
649 | Val* initAvg() const { |
650 | return init().avg(); |
651 | } |
652 | |
653 | Val* initVar() const { |
654 | return init().var(); |
655 | } |
656 | |
657 | Val* initN() const { |
658 | return init().N(); |
659 | } |
660 | |
661 | bool singleValue() const { |
662 | return inN()->isOneInt(); |
663 | } |
664 | |
665 | bool hasInit() const { |
666 | return !initN()->isZeroInt(); |
667 | } |
668 | |
669 | bool isAllreduce() const { |
670 | return is_allreduce_; |
671 | } |
672 | |
673 | std::vector<Val*> getInitVals() const; |
674 | |
675 | //! Return the init val for an output val |
676 | Val* getInitValOfOutput(Val* output_val) const; |
677 | |
678 | private: |
679 | const WelfordTriplet output_; |
680 | const WelfordTriplet input_; |
681 | const WelfordTriplet init_; |
682 | //! True if using the fused reduction kernel (not implemented yet) |
683 | bool is_allreduce_ = false; |
684 | }; |
685 | |
686 | class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { |
687 | public: |
688 | GroupedWelfordOp( |
689 | IrBuilderPasskey, |
690 | std::vector<WelfordTriplet> output_vals, |
691 | std::vector<WelfordTriplet> input_vals, |
692 | std::vector<WelfordTriplet> init_vals, |
693 | bool is_allreduce = false, |
694 | ExprType expr_type = ExprType::GroupedWelfordOp); |
695 | |
696 | GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); |
697 | |
698 | Expr* shallowCopy() const override; |
699 | |
700 | //! Number of expressions grouped horizontally. It does not reflect |
701 | //! iteration grouping. As horizontal grouping is not supported, |
702 | //! this always returns 1. |
703 | size_t numExprs() const { |
704 | return 1; |
705 | } |
706 | |
707 | Val* out(size_t index) const { |
708 | return outAvg(index); |
709 | } |
710 | |
711 | Val* in(size_t index) const { |
712 | return inAvg(index); |
713 | } |
714 | |
715 | bool sameAs(const Statement* const other) const override; |
716 | |
717 | const std::vector<WelfordTriplet>& outputVals() const { |
718 | return output_vals_; |
719 | } |
720 | |
721 | const std::vector<WelfordTriplet>& inputVals() const { |
722 | return input_vals_; |
723 | } |
724 | |
725 | const std::vector<WelfordTriplet>& initVals() const { |
726 | return init_vals_; |
727 | } |
728 | |
729 | Val* outAvg(size_t index) const { |
730 | return outputVals().at(index).avg(); |
731 | } |
732 | |
733 | Val* outVar(size_t index) const { |
734 | return outputVals().at(index).var(); |
735 | } |
736 | |
737 | Val* outN(size_t index) const { |
738 | return outputVals().at(index).N(); |
739 | } |
740 | |
741 | Val* inAvg(size_t index) const { |
742 | return inputVals().at(index).avg(); |
743 | } |
744 | |
745 | Val* inVar(size_t index) const { |
746 | return inputVals().at(index).var(); |
747 | } |
748 | |
749 | Val* inN(size_t index) const { |
750 | return inputVals().at(index).N(); |
751 | } |
752 | |
753 | Val* initAvg(size_t index) const { |
754 | return initVals().at(index).avg(); |
755 | } |
756 | |
757 | Val* initVar(size_t index) const { |
758 | return initVals().at(index).var(); |
759 | } |
760 | |
761 | Val* initN(size_t index) const { |
762 | return initVals().at(index).N(); |
763 | } |
764 | |
765 | //! Return the index of the corresponding welford expression for |
766 | //! a given output val |
767 | int getExprIndexOfOutput(Val* output_val) const; |
768 | |
769 | //! Return the init val for an output val |
770 | Val* getInitValOfOutput(Val* output_val) const; |
771 | |
772 | bool singleValue(size_t index) const { |
773 | return inN(index)->isOneInt(); |
774 | } |
775 | |
776 | bool hasInit(size_t index) const { |
777 | return !initN(index)->isZeroInt(); |
778 | } |
779 | |
780 | bool isAllreduce() const { |
781 | return is_allreduce_; |
782 | } |
783 | |
784 | private: |
785 | const std::vector<WelfordTriplet> output_vals_; |
786 | const std::vector<WelfordTriplet> input_vals_; |
787 | const std::vector<WelfordTriplet> init_vals_; |
788 | //! True if using the fused reduction kernel |
789 | bool is_allreduce_ = false; |
790 | }; |
791 | |
792 | //! Fused Matmul operation |
793 | class TORCH_CUDA_CU_API MmaOp : public Expr { |
794 | public: |
  // This is a temporary data structure for the
  // scheduling-specific parameters that we still need
  // to store on an mma node. Eventually only
  // the mma macro type will stay on the IR node
  // after additional clean-ups.
800 | struct OptionsInMma { |
801 | MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA; |
802 | MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT; |
803 | int accumulator_stride = 0; |
804 | |
805 | bool operator==(const OptionsInMma& other) const { |
806 | return macro == other.macro && operand_layout == other.operand_layout && |
807 | accumulator_stride == other.accumulator_stride; |
808 | } |
809 | }; |
810 | |
811 | MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); |
812 | |
813 | MmaOp( |
814 | IrBuilderPasskey, |
815 | Val* out, |
816 | Val* in_a, |
817 | Val* in_b, |
818 | Val* init, |
819 | OptionsInMma options); |
820 | |
821 | MmaOp(const MmaOp* src, IrCloner* ir_cloner); |
822 | |
823 | Expr* shallowCopy() const override; |
824 | |
825 | Val* out() const { |
826 | return out_; |
827 | } |
828 | |
829 | Val* inA() const { |
830 | return in_a_; |
831 | } |
832 | |
833 | Val* inB() const { |
834 | return in_b_; |
835 | } |
836 | |
837 | Val* init() const { |
838 | return init_; |
839 | } |
840 | |
841 | const auto& options() const { |
    TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this);
843 | return options_.value(); |
844 | } |
845 | |
846 | bool sameAs(const Statement* const other) const override; |
847 | |
848 | auto accStride() const { |
    TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this);
850 | return options_->accumulator_stride; |
851 | } |
852 | |
853 | void configureOptions(MmaOptions options) { |
854 | options_ = OptionsInMma(); |
    TORCH_INTERNAL_ASSERT(
        options.macro != MmaOptions::MacroType::NoMMA,
        "Un-configured mma type from options.");
    TORCH_INTERNAL_ASSERT(
        options.accumulator_stride > 0, "Un-configured accumulator stride.");
860 | options_->accumulator_stride = options.accumulator_stride; |
861 | options_->macro = options.macro; |
862 | options_->operand_layout = options.operand_layout; |
863 | } |
864 | |
865 | private: |
866 | Val* const out_ = nullptr; |
867 | Val* const in_a_ = nullptr; |
868 | Val* const in_b_ = nullptr; |
869 | Val* const init_ = nullptr; |
870 | c10::optional<OptionsInMma> options_ = c10::nullopt; |
871 | }; |
872 | |
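//! Transpose (permute) the axes of a tensor. A hedged reading of the
//! convention suggested by the names: new2old maps each output (new) axis
//! position to the input (old) axis it comes from, e.g. new2old = {1, 0}
//! swaps the first two axes, and old2new() returns the inverse permutation.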
873 | class TORCH_CUDA_CU_API TransposeOp : public Expr { |
874 | public: |
875 | TransposeOp( |
876 | IrBuilderPasskey, |
877 | TensorView* out, |
878 | TensorView* in, |
879 | std::vector<int64_t> new2old); |
880 | |
881 | TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); |
882 | |
883 | Expr* shallowCopy() const override; |
884 | |
885 | TensorView* out() const { |
886 | return out_; |
887 | } |
888 | |
889 | TensorView* in() const { |
890 | return in_; |
891 | } |
892 | |
893 | const std::vector<int64_t>& new2old() const { |
894 | return new2old_; |
895 | } |
896 | |
897 | std::vector<int64_t> old2new() const; |
898 | |
899 | private: |
900 | TensorView* const out_ = nullptr; |
901 | TensorView* const in_ = nullptr; |
902 | const std::vector<int64_t> new2old_; |
903 | }; |
904 | |
905 | class TORCH_CUDA_CU_API ExpandOp : public Expr { |
906 | public: |
907 | ExpandOp( |
908 | IrBuilderPasskey, |
909 | TensorView* out, |
910 | TensorView* in, |
911 | std::vector<Val*> _expanded_extents); |
912 | |
913 | ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); |
914 | |
915 | Expr* shallowCopy() const override; |
916 | |
917 | TensorView* out() const { |
918 | return out_; |
919 | } |
920 | |
921 | TensorView* in() const { |
922 | return in_; |
923 | } |
924 | |
925 | const std::vector<Val*>& expanded_extents() const { |
926 | return expanded_extents_; |
927 | } |
928 | |
929 | private: |
930 | TensorView* const out_ = nullptr; |
931 | TensorView* const in_ = nullptr; |
932 | std::vector<Val*> expanded_extents_; |
933 | }; |
934 | |
935 | class TORCH_CUDA_CU_API TernaryOp : public Expr { |
936 | public: |
937 | TernaryOp( |
938 | IrBuilderPasskey, |
939 | TernaryOpType type, |
940 | Val* out, |
941 | Val* in1, |
942 | Val* in2, |
943 | Val* in3); |
944 | |
945 | TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); |
946 | |
947 | Expr* shallowCopy() const override; |
948 | |
949 | Val* out() const { |
950 | return out_; |
951 | } |
952 | |
953 | Val* in1() const { |
954 | return in1_; |
955 | } |
956 | Val* in2() const { |
957 | return in2_; |
958 | } |
959 | Val* in3() const { |
960 | return in3_; |
961 | } |
962 | |
963 | TernaryOpType getTernaryOpType() const { |
964 | return ternary_op_type_; |
965 | } |
966 | |
967 | bool sameAs(const Statement* other) const override; |
968 | |
969 | private: |
970 | const TernaryOpType ternary_op_type_; |
971 | Val* const out_ = nullptr; |
972 | Val* const in1_ = nullptr; |
973 | Val* const in2_ = nullptr; |
974 | Val* const in3_ = nullptr; |
975 | }; |
976 | |
977 | //! Shift |
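//!
//! A hedged sketch of the semantics: with offsets = {1} on a 1D tensor,
//! out[i] = in[i - 1], with accesses that fall outside the input covered by
//! the padding described by pad_width.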
978 | class TORCH_CUDA_CU_API ShiftOp : public Expr { |
979 | public: |
  //! \param out Output tensor
  //! \param in Input tensor
  //! \param offsets Shift offset of each root axis of in; the sign of each
  //! value indicates the direction of the shift
  //! \param pad_width Padding width of each axis
983 | ShiftOp( |
984 | IrBuilderPasskey, |
985 | Val* out, |
986 | Val* in, |
987 | std::vector<int> offsets, |
988 | std::vector<int> pad_width); |
989 | |
990 | ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); |
991 | |
992 | Expr* shallowCopy() const override; |
993 | |
994 | Val* out() const { |
995 | return out_; |
996 | } |
997 | Val* in() const { |
998 | return in_; |
999 | } |
1000 | |
1001 | int offset(size_t dim) const { |
1002 | return offsets_.at(dim); |
1003 | } |
1004 | |
1005 | const std::vector<int>& offsets() const { |
1006 | return offsets_; |
1007 | } |
1008 | |
1009 | const std::vector<int>& padWidth() const { |
1010 | return pad_width_; |
1011 | } |
1012 | |
1013 | bool hasPadding() const { |
1014 | return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) { |
1015 | return p > 0; |
1016 | }); |
1017 | } |
1018 | |
1019 | bool sameAs(const Statement* other) const override; |
1020 | |
1021 | private: |
1022 | Val* const out_ = nullptr; |
1023 | Val* const in_ = nullptr; |
1024 | //! Each of the root axes is shifted by the corresponding value of |
1025 | //! offsets_. The sign of each value indicates the direction of |
1026 | //! shifting. |
1027 | const std::vector<int> offsets_; |
1028 | const std::vector<int> pad_width_; |
1029 | }; |
1030 | |
1031 | //! Gather a window around each element. |
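//!
//! A hedged sketch: with window_shape = {3} on a 1D input, each output
//! element gathers a window of 3 neighboring input elements, so the output
//! gains an extra window axis of extent 3 (see gatherAxis below); pad_width
//! gives the left/right padding of each input axis so the window stays in
//! bounds at the borders.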
1032 | class TORCH_CUDA_CU_API GatherOp : public Expr { |
1033 | public: |
1034 | GatherOp( |
1035 | IrBuilderPasskey, |
1036 | Val* out, |
1037 | Val* in, |
1038 | std::vector<int> window_shape, |
1039 | std::vector<std::vector<int>> pad_width); |
1040 | |
1041 | GatherOp(const GatherOp* src, IrCloner* ir_cloner); |
1042 | |
1043 | Expr* shallowCopy() const override; |
1044 | |
1045 | Val* out() const { |
1046 | return out_; |
1047 | } |
1048 | Val* in() const { |
1049 | return in_; |
1050 | } |
1051 | |
1052 | const auto& windowShape() const { |
1053 | return window_shape_; |
1054 | } |
1055 | |
1056 | //! Returns the gather axis that corresponds to an input axis |
1057 | int gatherAxis(int axis) const; |
1058 | |
1059 | const auto& padWidth() const { |
1060 | return pad_width_; |
1061 | } |
1062 | |
1063 | bool hasPadding() const { |
1064 | return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) { |
1065 | return p[0] > 0 || p[1] > 0; |
1066 | }); |
1067 | } |
1068 | |
1069 | bool sameAs(const Statement* other) const override; |
1070 | |
1071 | private: |
1072 | Val* const out_ = nullptr; |
1073 | Val* const in_ = nullptr; |
1074 | //! Shape of a window gathered for each element. |
1075 | std::vector<int> window_shape_; |
1076 | //! The size of zero-padding of each axis. |
1077 | std::vector<std::vector<int>> pad_width_; |
1078 | }; |
1079 | |
1080 | class TORCH_CUDA_CU_API ViewAsScalar : public Expr { |
1081 | public: |
1082 | ViewAsScalar( |
1083 | IrBuilderPasskey, |
1084 | Val* out, |
1085 | Val* in, |
1086 | IterDomain* vector_id, |
1087 | Val* index = nullptr); |
1088 | |
1089 | ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); |
1090 | |
1091 | Expr* shallowCopy() const override; |
1092 | |
1093 | Val* out() const { |
1094 | return out_; |
1095 | } |
1096 | |
1097 | Val* in() const { |
1098 | return in_; |
1099 | } |
1100 | |
1101 | IterDomain* vector_id() const { |
1102 | return vector_id_; |
1103 | } |
1104 | |
1105 | Val* index() const { |
1106 | return index_; |
1107 | } |
1108 | |
1109 | private: |
1110 | Val* const out_ = nullptr; |
1111 | Val* const in_ = nullptr; |
1112 | |
1113 | // The IterDomain of type VectorComponent newly appended to the output |
1114 | IterDomain* vector_id_ = nullptr; |
1115 | |
1116 | // The index that vector_id_ is lowered into |
1117 | Val* index_ = nullptr; |
1118 | }; |
1119 | |
1120 | class TORCH_CUDA_CU_API ViewOp : public Expr { |
1121 | public: |
1122 | ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in); |
1123 | |
1124 | ViewOp(const ViewOp* src, IrCloner* ir_cloner); |
1125 | |
1126 | Expr* shallowCopy() const override; |
1127 | |
1128 | TensorView* out() const { |
1129 | return out_; |
1130 | } |
1131 | |
1132 | TensorView* in() const { |
1133 | return in_; |
1134 | } |
1135 | |
1136 | private: |
1137 | TensorView* const out_ = nullptr; |
1138 | TensorView* const in_ = nullptr; |
1139 | }; |
1140 | |
//! This operator explicitly models data movement between
//! state spaces on the GPU. Currently the modeled state spaces include
//! global memory, shared memory and registers.
//!
//! The main usage of this op is to facilitate generation of hardware
//! accelerated memory ops, e.g. ldmatrix, cp.async and more to come.
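//!
//! As a hedged illustration: a LoadStoreOp with op_type LdMatrix models an
//! ldmatrix load from shared memory into registers, while a cp.async-style
//! op type would model an asynchronous copy from global to shared memory.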
1147 | class TORCH_CUDA_CU_API LoadStoreOp : public Expr { |
1148 | public: |
1149 | LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in); |
1150 | |
1151 | LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); |
1152 | |
1153 | Expr* shallowCopy() const override; |
1154 | |
1155 | Val* out() const { |
1156 | return out_; |
1157 | } |
1158 | |
1159 | Val* in() const { |
1160 | return in_; |
1161 | } |
1162 | |
1163 | LoadStoreOpType opType() const { |
1164 | return load_store_type_; |
1165 | } |
1166 | |
1167 | private: |
1168 | LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix; |
1169 | Val* const out_ = nullptr; |
1170 | Val* const in_ = nullptr; |
1171 | }; |
1172 | |
// Convenience utility to initialize IterDomains without having to sort through
// all the default values. Intended to be used with
// IterDomain::IterDomain(IrBuilderPasskey, const IterDomainBuilder&)
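//
// A hedged usage sketch (assuming a Fusion/IrBuilder context where the
// start and extent Vals already exist; the names below are hypothetical):
//
//   IterDomain* id = IterDomainBuilder(start_val, extent_val)
//                        .parallel_type(ParallelType::TIDx)
//                        .iter_type(IterType::Iteration)
//                        .build();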
1176 | class TORCH_CUDA_CU_API IterDomainBuilder { |
1177 | public: |
1178 | // Match legacy constructor |
1179 | IterDomainBuilder(Val* _start, Val* _extent); |
1180 | |
1181 | // Grab all the parameters from id to set the IterDomainBuilder |
1182 | IterDomainBuilder(const IterDomain* id); |
1183 | |
1184 | // Resets defaults for rfactor, is padded dim, padded to size, and is mma |
1185 | // swizzle which should only be set during scheduling. |
1186 | IterDomainBuilder& resetSchedulingParams(); |
1187 | |
1188 | // Resets is_rfactor_domain |
1189 | IterDomainBuilder& resetRfactor(); |
1190 | |
1191 | IterDomainBuilder& start(Val* _start); |
1192 | IterDomainBuilder& extent(Val* _extent); |
1193 | IterDomainBuilder& expanded_extent(Val* _expanded_extent); |
1194 | IterDomainBuilder& stop_offset(Val* _stop_offset); |
1195 | IterDomainBuilder& parallel_type(ParallelType _parallel_type); |
1196 | IterDomainBuilder& iter_type(IterType _iter_type); |
1197 | IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); |
1198 | IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); |
1199 | IterDomainBuilder& padded_to_size(c10::optional<int64_t> _padded_to_size); |
1200 | IterDomainBuilder& is_mma_swizzled(bool _is_mma_swizzled); |
1201 | |
1202 | IterDomain* build() const; |
1203 | |
1204 | // Must have start and extent at least |
1205 | IterDomainBuilder() = delete; |
1206 | |
1207 | Val* start_ = nullptr; |
1208 | Val* extent_ = nullptr; |
1209 | Val* expanded_extent_ = nullptr; |
1210 | Val* stop_offset_ = nullptr; |
1211 | ParallelType parallel_type_ = ParallelType::Serial; |
1212 | IterType iter_type_ = IterType::Iteration; |
1213 | |
1214 | // Only relevant at scheduling time or compile time. |
1215 | bool is_rfactor_domain_ = false; |
1216 | bool is_padded_dimension_ = false; |
1217 | c10::optional<int64_t> padded_to_size_ = c10::nullopt; |
1218 | bool is_mma_swizzled_ = false; |
1219 | }; |
1220 | |
1221 | // Friends for direct access to split |
1222 | class TensorDomain; |
1223 | class ReplayTransformations; |
1224 | class IndexReferenceReplay; |
//! Simply a representation of an annotated 1D iterable from start to extent.
//! TensorDomains, which represent how to iterate over a tensor, are made up of
//! IterDomains to form an ND iterable. We directly set parallelization
//! strategies on IterDomains.
1229 | class TORCH_CUDA_CU_API IterDomain : public Val { |
1230 | public: |
1231 | IterDomain(IrBuilderPasskey, const IterDomainBuilder& args); |
1232 | |
  // Legacy constructor. TODO: should start moving to use the IterDomainBuilder
  // constructor. Same as the above but can set the offset of the stop point.
1235 | IterDomain( |
1236 | IrBuilderPasskey, |
1237 | Val* start, |
1238 | Val* extent, |
1239 | Val* expanded_extent, |
1240 | Val* stop_offset, |
1241 | ParallelType parallel_type, |
1242 | IterType iter_type, |
1243 | bool is_rfactor_domain, |
1244 | bool is_padded_dimension, |
1245 | c10::optional<int64_t> padded_to_size_, |
1246 | bool is_mma_swizzled); |
1247 | |
1248 | IterDomain(const IterDomain* src, IrCloner* ir_cloner); |
1249 | |
1250 | bool sameAs(const Statement* other) const override; |
1251 | |
1252 | //! Returns a new IterDomain matching properties of this |
1253 | //! |
1254 | //! This does NOT copy the is_rfactor_domain flag. |
1255 | IterDomain* cloneWithoutRFactor() const; |
1256 | |
  //! Clone a vector of domains
1258 | static std::vector<IterDomain*> clone( |
1259 | const std::vector<IterDomain*>& domains); |
1260 | |
1261 | static IterDomain* merge(IterDomain* outer, IterDomain* inner); |
1262 | |
1263 | //! start_offset and stop_offset defines partial split. Only root |
1264 | //! domains are allowed to have non-zero start and stop offsets. |
1265 | static std::pair<IterDomain*, IterDomain*> split( |
1266 | IterDomain* in, |
1267 | Val* factor, |
1268 | bool inner_split, |
1269 | Val* start_offset = nullptr, |
1270 | Val* stop_offset = nullptr); |
1271 | |
1272 | //! trim_out_of_bounds controls how the values outside start and stop |
1273 | //! positions are treated. The option is only valid with root |
1274 | //! domains as non-root domains do not have valid start and stop |
1275 | //! positions. |
1276 | //! |
1277 | //! \param trim_out_of_bounds Trims [0, start_] and [-stop_offset_, extent_] |
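  //!
  //! As a hedged example mirroring TensorDomain::split below: splitting
  //! id{extent} by factor 4 with inner_split = true yields the pair
  //! (id{ceilDiv(extent, 4)}, id{4}) as (outer, inner).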
1278 | static std::pair<IterDomain*, IterDomain*> split( |
1279 | IterDomain* in, |
1280 | Val* factor, |
1281 | bool inner_split, |
1282 | bool trim_out_of_bounds); |
1283 | |
1284 | bool isReduction() const { |
1285 | return getIterType() == IterType::Reduction; |
1286 | } |
1287 | |
1288 | bool isRFactorProduct() const { |
1289 | return is_rfactor_domain_; |
1290 | } |
1291 | |
1292 | bool isBroadcast() const { |
1293 | return getIterType() == IterType::Broadcast; |
1294 | } |
1295 | |
1296 | bool isGather() const { |
1297 | return getIterType() == IterType::Gather; |
1298 | } |
1299 | |
1300 | bool isStride() const { |
1301 | return getIterType() == IterType::Stride; |
1302 | } |
1303 | |
1304 | bool isVectorComponent() const { |
1305 | return getIterType() == IterType::VectorComponent; |
1306 | } |
1307 | |
1308 | bool isParallelized() const { |
1309 | return getParallelType() != ParallelType::Serial; |
1310 | } |
1311 | |
1312 | //! Return if this iter domain is mapped to a grid dimension |
1313 | bool isBlockDim() const { |
1314 | return isParallelTypeBlockDim(getParallelType()); |
1315 | } |
1316 | |
1317 | //! Return if this iter domain is mapped to a block dimension |
1318 | bool isThreadDim() const { |
1319 | return isParallelTypeThreadDim(getParallelType()); |
1320 | } |
1321 | |
1322 | //! Return if this iter domain is either mapped to a block or grid dimension |
1323 | bool isThread() const { |
1324 | return (isBlockDim() || isThreadDim()); |
1325 | } |
1326 | |
1327 | void parallelize(ParallelType t); |
1328 | |
1329 | ParallelType getParallelType() const { |
1330 | return parallel_type_; |
1331 | } |
1332 | |
1333 | IterType getIterType() const { |
1334 | return iter_type_; |
1335 | } |
1336 | |
1337 | Val* start() const { |
1338 | return start_; |
1339 | } |
1340 | |
1341 | Val* stop() const; |
1342 | |
1343 | Val* stopOffset() const; |
1344 | |
1345 | Val* extent() const { |
1346 | TORCH_INTERNAL_ASSERT(extent_ != nullptr); |
1347 | return extent_; |
1348 | } |
1349 | |
1350 | bool hasExpandedExtent() const { |
1351 | return expanded_extent_ != nullptr; |
1352 | } |
1353 | |
1354 | // Returns the expanded extent of a strided broadcast entry. |
1355 | Val* expandedExtent() const { |
    TORCH_INTERNAL_ASSERT(
        hasExpandedExtent(),
        "Requested expanded extent, but none found on this dimension.");
1359 | return expanded_extent_; |
1360 | } |
1361 | |
1362 | Val* getMaybeExpandedExtent() const { |
1363 | if (hasExpandedExtent()) { |
1364 | return expandedExtent(); |
1365 | } |
1366 | return extent(); |
1367 | } |
1368 | |
1369 | //! Dimension padding interface: |
1370 | //! 2 modes are currently supported: |
1371 | //! |
1372 | //! - mode 1: if to_size is given as a positive number, |
1373 | //! the dimension will be padded to the size so that |
1374 | //! this iterdomain will be compile-time constant |
1375 | //! size and it is the scheduler's responsibility |
1376 | //! to ensure no input larger than the padded size |
1377 | //! will be observed |
1378 | //! |
1379 | //! - mode 2: if no to_size is given, this dimension |
1380 | //! is "dynamically" padded to next smallest multiple |
1381 | //! of a warp size, i.e. 17 padded to 32, 33 padded to 64 |
1382 | //! based on the given input. |
1383 | void padToMultipleOfWarp(c10::optional<int64_t> maybe_to_size = {}) { |
1384 | // Currently only restricted to TIDx to generate warp reduce |
    TORCH_CHECK(
        parallel_type_ == ParallelType::TIDx,
        "padToMultipleOfWarp : warp padding only supported on TIDx parallel dimension");
1388 | is_padded_dimension_ = true; |
1389 | if (maybe_to_size.has_value()) { |
1390 | if (maybe_to_size.value() > 0) { |
1391 | padded_to_size_ = maybe_to_size.value(); |
1392 | } |
1393 | } |
1394 | } |
1395 | |
  //! Indicates whether this IterDomain has padding,
  //! dynamic or static
1398 | bool hasPaddingToMultipleOfWarp() const { |
1399 | return is_padded_dimension_; |
1400 | } |
1401 | |
  //! Returns a concrete value if this IterDomain
  //! has been padded to a static size.
1404 | c10::optional<int64_t> getMaybeSizeAfterPadding() const { |
1405 | return padded_to_size_; |
1406 | } |
1407 | |
1408 | //! True if range of iteration domain isn't across the full extent |
1409 | bool maybePartial() const; |
1410 | |
1411 | //! Check if IterDomain is a broadcast axis with compile-time |
1412 | //! known extent. This is the case with all size-1 IterDomains on |
1413 | //! a TensorView's root domain when the TensorView is created. |
1414 | bool isImplicitBroadcast() const { |
1415 | return isBroadcast() && extent()->isOneInt(); |
1416 | } |
1417 | |
1418 | //! Check if IterDomain is a reduction axis with size of 1, i.e. |
1419 | //! a "squeeze" operator, or solely derived from such axes. |
1420 | bool isTrivialReduction() const; |
1421 | |
1422 | //! Split for stride by a given factor. It effectively does an inner |
1423 | //! split by the factor and sets the inner domain as a Stride |
1424 | //! domain. |
1425 | std::pair<IterDomain*, IterDomain*> stridedSplit(int factor); |
1426 | |
1427 | // TODO: Remove |
1428 | bool isSimple() const { |
1429 | return definition() == nullptr; |
1430 | } |
1431 | |
  //! Returns true if this id represents an
  //! instruction loop, mma use only.
  //!
  //! An instruction loop can be considered a generalization of
  //! vectorization. It also represents a loop that's implemented
  //! by an instruction, should not be realized by codegen, and
  //! cannot be inlined with other loops.
1439 | //! As an example, if a mma macro, call it mma_eg implements: |
1440 | //! for m in M |
1441 | //! for n in N |
1442 | //! for k in K |
1443 | //! C[m,n] += A[m,k]*B[k,n], |
1444 | //! But the generated code should simply be: |
1445 | //! mma_eg(C,A,B) |
1446 | //! without the 3 level loopnest, i.e. they're instruction loops. |
1447 | //! |
1448 | //! In the actual mma macros, the loopnests it implements is a |
1449 | //! transformed version of above to match the mma swizzle. |
1450 | //! So it's different implicit loopnest for different macros. |
1451 | //! WarpMmaSwizzler will label the instruction loops case-by-case. |
1452 | bool isMma() const { |
1453 | return parallel_type_ == ParallelType::Mma; |
1454 | } |
1455 | |
1456 | //! Applies 2D swizzle on a rectangular tile defined by |
1457 | //! a pair of iterdomains. |
1458 | static std::pair<IterDomain*, IterDomain*> swizzle( |
1459 | Swizzle2DType swizzle_type, |
1460 | IterDomain* in_x, |
1461 | IterDomain* in_y, |
1462 | SwizzleMode swizzle_mode = SwizzleMode::Data); |
1463 | |
1464 | bool isMmaSwizzled() const { |
1465 | return is_mma_swizzled_; |
1466 | } |
1467 | |
  //! Used by WarpMmaSwizzler, this is a utility for WarpMmaSwizzler
  //! to lock down the thread-swizzled iterdomains.
  //! Only true for the iterdomains produced by WarpMmaSwizzler.
  //! Mma ops require specific swizzle patterns,
  //! and this label utility is to prevent any further transform on the
  //! iterdomains involved in the swizzle so that the pattern remains correct in
  //! generated code.
  //!
  //! Note:
  //!   Used only through WarpMmaSwizzler, and mma validation relies on this
  //!   flag being set on the correct iterdomains.
1480 | void toMmaSwizzled() { |
1481 | is_mma_swizzled_ = true; |
1482 | } |
1483 | |
1484 | protected: |
1485 | friend TensorDomain; |
1486 | friend ReplayTransformations; |
1487 | friend IndexReferenceReplay; |
1488 | |
1489 | private: |
1490 | //! Valid range is defined as [start:-stop_offset] |
1491 | Val* const start_ = nullptr; |
1492 | Val* const extent_ = nullptr; |
1493 | |
  // Broadcast dimensions are assumed to be size 1 for the sake of code
  // generation. If a user then calls `expand` on a tensor, that dimension is
  // still considered a broadcast dimension. However, if we ever output that
  // dimension it should have a size dictated by the `expand` operation, and
  // have a stride of zero. Since this extent is important to track, but not
  // necessarily to generate code for (we still want loops on broadcast
  // dimensions to be of size 1), we simply store it separately from extent_.
  // Having an expanded_extent_ is only allowed with broadcast dimensions.
  // Only in this instance does it make sense to have an expanded_extent_,
  // because it is used when users expect returned tensors to have a physical
  // domain. If a user simply "broadcasts" an operation without expanding it,
  // no expanded extent is set.
1505 | Val* const expanded_extent_ = nullptr; |
1506 | |
1507 | //! Distance of stop from the end |
1508 | Val* const stop_offset_ = nullptr; |
1509 | ParallelType parallel_type_ = ParallelType::Serial; |
1510 | IterType iter_type_ = IterType::Iteration; |
1511 | bool is_rfactor_domain_ = false; |
1512 | bool is_padded_dimension_ = false; |
1513 | c10::optional<int64_t> padded_to_size_ = c10::nullopt; |
1514 | |
  // TODO: Remove; only used in kernel IR because IterDomains don't maintain
  // definitions of split/merge.
1517 | bool is_simple_ = true; |
1518 | |
1519 | //! Tracks if this id represents a thread swizzled loop or |
1520 | //! models an implicit loop within instructions. Should not make |
1521 | //! any changes once an id is warp mapped. |
1522 | bool is_mma_swizzled_ = false; |
1523 | }; |
1524 | |
1525 | //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every |
1526 | //! logical axis in its associated tensor. TensorDomain does not directly hold |
1527 | //! the Tensor it is associated with, and in theory could be associated with |
1528 | //! multiple tensors. TensorDomain's primary responsibility is to provide a |
1529 | //! mechanism to access history of transformations that were used to generate |
1530 | //! it. This is done through the normal interaction of Expr/Val in Fusion. i.e. |
1531 | //! if we want to know the previous operation generating a particular |
1532 | //! TensorDomain we can simply call: |
1533 | //! |
1534 | //! FusionGuard::getCurFusion()->definition(a_tensor_domain) |
1535 | //! |
1536 | //! which should give us an operation in the list [split, merge] or similar |
//! operations that take in a TensorDomain, apply a transformation and output
//! a tensor domain.
1539 | class TORCH_CUDA_CU_API TensorDomain : public Val { |
1540 | public: |
1541 | explicit TensorDomain( |
1542 | IrBuilderPasskey, |
1543 | std::vector<IterDomain*> root_domain, |
1544 | std::vector<bool> contiguity = std::vector<bool>()); |
1545 | |
1546 | TensorDomain( |
1547 | IrBuilderPasskey, |
1548 | std::vector<IterDomain*> root_domain, |
1549 | std::vector<IterDomain*> domain, |
1550 | std::vector<bool> contiguity = std::vector<bool>()); |
1551 | |
1552 | TensorDomain( |
1553 | IrBuilderPasskey, |
1554 | std::vector<IterDomain*> root_domain, |
1555 | std::vector<IterDomain*> rfactor_domain, |
1556 | std::vector<IterDomain*> domain, |
1557 | std::vector<bool> contiguity = std::vector<bool>()); |
1558 | |
1559 | TensorDomain(const TensorDomain* src, IrCloner* ir_cloner); |
1560 | |
1561 | bool operator==(const TensorDomain& other) const; |
1562 | bool operator!=(const TensorDomain& other) const { |
1563 | return !(*this == other); |
1564 | } |
1565 | |
1566 | std::vector<IterDomain*>::size_type nDims() const { |
1567 | return domain_.size(); |
1568 | } |
1569 | |
1570 | bool sameAs(const Statement* other) const override; |
1571 | |
1572 | static bool sameAs( |
1573 | const std::vector<IterDomain*>& lhs, |
1574 | const std::vector<IterDomain*>& rhs); |
1575 | |
1576 | const std::vector<IterDomain*>& domain() const { |
1577 | return domain_; |
1578 | } |
1579 | |
1580 | const std::vector<bool>& contiguity() const { |
1581 | return contiguity_; |
1582 | } |
1583 | |
1584 | void setContiguity(const std::vector<bool>& contig); |
1585 | |
1586 | std::string getContiguityString() const { |
1587 | std::stringstream ss; |
1588 | for (auto b : contiguity()) { |
      ss << (b ? "t" : "f");
1590 | } |
1591 | return ss.str(); |
1592 | } |
1593 | |
1594 | bool hasReduction() const; |
1595 | bool hasBlockReduction() const; |
1596 | bool hasGridReduction() const; |
1597 | bool hasBlockBroadcast() const; |
1598 | bool hasGridBroadcast() const; |
1599 | bool hasBroadcast() const; |
1600 | bool hasRFactor() const; |
1601 | |
  // Returns true if the rfactor domain only consists of ids of iteration type.
1603 | bool hasViewLikeRFactor() const; |
1604 | |
1605 | bool hasVectorize() const; |
1606 | |
1607 | c10::optional<unsigned int> getReductionAxis() const; |
1608 | |
1609 | const std::vector<IterDomain*>& noReductions() const { |
1610 | return no_reduction_domain_; |
1611 | } |
1612 | |
1613 | const std::vector<IterDomain*>& noBroadcasts() const { |
1614 | return no_bcast_domain_; |
1615 | } |
1616 | |
1617 | const std::vector<IterDomain*>& getRootDomain() const { |
1618 | return root_domain_; |
1619 | }; |
1620 | |
1621 | const std::vector<IterDomain*>& getRFactorDomain() const { |
1622 | return rfactor_domain_; |
1623 | }; |
1624 | |
1625 | // If rfactor domain exists in domain() return it, otherwise return root |
1626 | // domain. |
1627 | const std::vector<IterDomain*>& getMaybeRFactorDomain() const { |
1628 | return hasRFactor() ? getRFactorDomain() : getRootDomain(); |
1629 | } |
1630 | |
1631 | void resetDomains() { |
1632 | no_reduction_domain_ = noReductions(domain_); |
1633 | no_bcast_domain_ = noBroadcasts(domain_); |
1634 | has_nontrivial_reduction_ = hasNontrivialReduction(domain_); |
1635 | } |
1636 | |
  // i here is int, as we want to accept negative values and ::size_type can be
  // unsigned.
1639 | IterDomain* axis(int i) const; |
1640 | |
1641 | size_t posOf(IterDomain* id) const; |
1642 | |
1643 | //! Returns a position of a root domain |
1644 | size_t rootPosOf(IterDomain* id) const; |
1645 | |
  //! Split "axis" into 2 axes.
  //! inner_split dictates if the factor section of the split should be inside
  //! the remainder or outside.
  //! e.g. split(0, 4, inner_split = true) will result in:
  //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
  //! e.g. split(0, 4, inner_split = false) will result in:
  //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
1654 | void split( |
1655 | int axis_, |
1656 | Val* factor, |
1657 | bool inner_split, |
1658 | bool trim_out_of_bounds = false); |
1659 | |
1660 | // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting |
1661 | // axis is by default placed at original position axis_o |
1662 | void merge(int axis_o, int axis_i); |
1663 | |
1664 | // Reorder axes according to map[old_pos] = new_pos |
1665 | void reorder(const std::unordered_map<int, int>& old2new); |
1666 | |
1667 | //! Applies 2D swizzle on a rectangular tile defined by |
1668 | //! a pair of iterdomains contained in this domain. |
1669 | void swizzle( |
1670 | Swizzle2DType swizzle_type, |
1671 | int x, |
1672 | int y, |
1673 | SwizzleMode swizzle_mode = SwizzleMode::Data); |
1674 | |
1675 | // Transform TensorView according to merge and split transformations |
1676 | TensorDomain* view(const AnalyzeViewResult& view_analysis); |
1677 | |
1678 | TensorDomain* flatten(int64_t start_dim, int64_t end_dim); |
1679 | |
1680 | static std::vector<IterDomain*> orderedAs( |
1681 | const std::vector<IterDomain*>& td, |
1682 | const std::unordered_map<int, int>& old2new); |
1683 | |
1684 | static std::vector<IterDomain*> noReductions(const std::vector<IterDomain*>&); |
1685 | static std::vector<IterDomain*> noBroadcasts(const std::vector<IterDomain*>&); |
1686 | |
1687 | static bool hasBroadcast(const std::vector<IterDomain*>&); |
1688 | static bool hasReduction(const std::vector<IterDomain*>&); |
1689 | static bool hasNontrivialReduction(const std::vector<IterDomain*>&); |
1690 | |
1691 | // pair is in order where second is the consumer of first |
1692 | std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes); |
1693 | |
1694 | private: |
1695 | const std::vector<IterDomain*> root_domain_; |
1696 | std::vector<IterDomain*> domain_; |
1697 | std::vector<IterDomain*> no_bcast_domain_; |
1698 | std::vector<IterDomain*> no_reduction_domain_; |
1699 | const std::vector<IterDomain*> rfactor_domain_; |
1700 | std::vector<bool> contiguity_; |
1701 | bool has_nontrivial_reduction_; |
1702 | }; |
1703 | |
//! Representation of a split of an IterDomain by "factor".
//! inner_split dictates if the factor section of the split should be inside the
//! remainder or outside.
1707 | class TORCH_CUDA_CU_API Split : public Expr { |
1708 | public: |
1709 | // start_offset and stop_offset are used to express partial |
1710 | // split. Only the partial domain from start_offset to stop_offset |
1711 | // is split and the outer sub-regions are ignored. Note that both |
  // start_offset and stop_offset are distances from the left and
  // right ends, respectively.
1714 | Split( |
1715 | IrBuilderPasskey, |
1716 | IterDomain* outer, |
1717 | IterDomain* inner, |
1718 | IterDomain* in, |
1719 | Val* factor, |
1720 | bool inner_split = true, |
1721 | Val* start_offset = nullptr, |
1722 | Val* stop_offset = nullptr); |
1723 | |
1724 | Split(const Split* src, IrCloner* ir_cloner); |
1725 | |
1726 | Expr* shallowCopy() const override; |
1727 | |
1728 | IterDomain* outer() const { |
1729 | return outer_; |
1730 | } |
1731 | IterDomain* inner() const { |
1732 | return inner_; |
1733 | } |
1734 | IterDomain* in() const { |
1735 | return in_; |
1736 | } |
1737 | Val* factor() const { |
1738 | return factor_; |
1739 | } |
1740 | |
1741 | bool innerSplit() const { |
1742 | return inner_split_; |
1743 | } |
1744 | |
1745 | Val* startOffset() const { |
1746 | TORCH_INTERNAL_ASSERT(start_offset_ != nullptr); |
1747 | return start_offset_; |
1748 | } |
1749 | |
1750 | Val* stopOffset() const { |
1751 | TORCH_INTERNAL_ASSERT(stop_offset_ != nullptr); |
1752 | return stop_offset_; |
1753 | } |
1754 | |
1755 | //! Utility function to compute the split extent. |
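  //! Presumably, for a partial split this amounts to
  //!   in_extent - start_offset - stop_offset,
  //! i.e. only the region between the two offsets is split.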
1756 | static Val* extent(Val* in_extent, Val* start_offset, Val* stop_offset); |
1757 | |
1758 | bool sameAs(const Statement* other) const override; |
1759 | |
1760 | private: |
1761 | IterDomain* const outer_ = nullptr; |
1762 | IterDomain* const inner_ = nullptr; |
1763 | IterDomain* const in_ = nullptr; |
1764 | Val* const factor_ = nullptr; |
1765 | bool inner_split_ = true; |
1766 | //! Start position of the input domain. Non-zero means partial |
1767 | //! split. Elements until this offset are ignored. |
1768 | Val* const start_offset_ = nullptr; |
1769 | //! Offset from extent of the input domain. Non-zero means partial |
1770 | //! split. Elements after this offset are ignored. |
1771 | Val* const stop_offset_ = nullptr; |
1772 | }; |
1773 | |
1774 | //! Merge the IterDomains outer and inner into one domain, outer and inner |
1775 | //! dictate which will be traversed first (inner). Both IterDomains must be of |
1776 | //! the same iter or reduction type, as well as the same parallelization |
1777 | //! strategy if there is one |
1778 | class TORCH_CUDA_CU_API Merge : public Expr { |
1779 | public: |
1780 | Merge( |
1781 | IrBuilderPasskey, |
1782 | IterDomain* out, |
1783 | IterDomain* outer, |
1784 | IterDomain* inner); |
1785 | |
1786 | Merge(const Merge* src, IrCloner* ir_cloner); |
1787 | |
1788 | Expr* shallowCopy() const override; |
1789 | |
1790 | IterDomain* out() const { |
1791 | return out_; |
1792 | } |
1793 | IterDomain* outer() const { |
1794 | return outer_; |
1795 | } |
1796 | IterDomain* inner() const { |
1797 | return inner_; |
1798 | } |
1799 | |
1800 | bool sameAs(const Statement* other) const override; |
1801 | |
1802 | private: |
1803 | IterDomain* const out_ = nullptr; |
1804 | IterDomain* const outer_ = nullptr; |
1805 | IterDomain* const inner_ = nullptr; |
1806 | }; |
1807 | |
1808 | //! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains. |
1809 | class TORCH_CUDA_CU_API Swizzle2D : public Expr { |
1810 | public: |
1811 | Swizzle2D( |
1812 | IrBuilderPasskey, |
1813 | IterDomain* out_x, |
1814 | IterDomain* out_y, |
1815 | IterDomain* in_x, |
1816 | IterDomain* in_y, |
1817 | Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle, |
1818 | SwizzleMode swizzle_mode = SwizzleMode::Data); |
1819 | |
1820 | Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); |
1821 | |
1822 | Expr* shallowCopy() const override; |
1823 | |
1824 | IterDomain* outX() const { |
1825 | return out_x_; |
1826 | } |
1827 | |
1828 | IterDomain* outY() const { |
1829 | return out_y_; |
1830 | } |
1831 | |
1832 | IterDomain* inX() const { |
1833 | return in_x_; |
1834 | } |
1835 | |
1836 | IterDomain* inY() const { |
1837 | return in_y_; |
1838 | } |
1839 | |
1840 | auto swizzleType() const { |
1841 | return swizzle_type_; |
1842 | } |
1843 | |
1844 | auto swizzleMode() const { |
1845 | return swizzle_mode_; |
1846 | } |
1847 | |
1848 | bool sameAs(const Statement* other) const override; |
1849 | |
1850 | private: |
1851 | // Output iterdomain pair corresponding |
1852 | // to the original input iterdomain pair. |
1853 | IterDomain* const out_x_ = nullptr; |
1854 | IterDomain* const out_y_ = nullptr; |
1855 | |
1856 | // Input iterdomain pair. |
1857 | IterDomain* const in_x_ = nullptr; |
1858 | IterDomain* const in_y_ = nullptr; |
1859 | |
1860 | // The type of predefined 1-to-1 functions |
1861 | // used for swizzling math. |
1862 | Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle; |
1863 | |
1864 | // Swizzle mode of this swizzle instance. |
1865 | // [Note on swizzle mode] |
  // In the current implementation we support two modes of
  // swizzle math, namely, data mode and loop mode.
  // `Data` mode swizzling is a swizzle that will change the
  // data layout in shared memory, and likely in global memory buffers
  // as well in the future. See also IndexSwizzle in index_compute.cpp.
1871 | // |
1872 | // Most important use cases are transpose bank conflict removal, and mma |
1873 | // swizzled shared memory layout. Example illustrated in 1D case: |
1874 | // |
1875 | // for (int i = 0; i<I; i++){ |
1876 | // # This is a `Data` mode swizzle. |
1877 | // Tshared [swizzled(i)] = Tin[i]; |
1878 | // } |
1879 | // # Now Tshared holds swizzled data, i.e. the data layout of |
1880 | // Tshared does not map to Tin with affine relationships. |
1881 | // |
1882 | // for(int i=0;i<I;i++){ |
1883 | // Tout = Tshared[swizzled(i)]; |
1884 | // } |
1885 | // |
1886 | // `Loop` mode swizzling does not affect the data layout of any buffer |
  // but only permutes the iteration order of a serial or parallel loop.
1888 | // This is useful when we want to designate non-affine mapping of thread |
1889 | // to data or we want to generate non-affine loops. |
  // Example illustrated in 1D case:
1891 | // for (int i = 0; i<I; i++){ |
1892 | // # This is a `Loop` mode swizzle |
1893 | // Tshared [swizzled(i)] = Tin[swizzled(i)]; |
1894 | // } |
1895 | // # Now Tshared holds normal data, i.e. it still has |
1896 | // the same data layout as if the swizzle wasn't there. |
1897 | // |
  // # Consumers of Tshared do not need to know about the
  //  loop swizzle at the previous op if not inlined.
1900 | // for(int i=0;i<I;i++){ |
1901 | // Tout = Tshared[i]; |
1902 | // } |
1903 | // TODO: Loop swizzles eventually will be piped through in all mappings |
1904 | // and replay of the fusion IR infrastructure. |
1905 | SwizzleMode swizzle_mode_ = SwizzleMode::Data; |
1906 | }; |
1907 | |
1908 | //! Integer value which has a special name |
1909 | //! |
1910 | //! These could be: |
1911 | //! - threadIdx.x |
1912 | //! - blockIdx.y |
1913 | //! - blockDim.z |
1914 | //! - T3.stride[2] |
1915 | //! |
1916 | class TORCH_CUDA_CU_API NamedScalar : public Val { |
1917 | public: |
1918 | NamedScalar(IrBuilderPasskey passkey, std::string name, DataType dtype); |
1919 | |
1920 | NamedScalar(const NamedScalar* src, IrCloner* ir_cloner); |
1921 | |
1922 | const std::string& name() const { |
1923 | return name_; |
1924 | } |
1925 | |
1926 | bool sameAs(const Statement* other) const override; |
1927 | |
1928 | //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x) |
1929 | //! WARNING: Only works with Fusion container at the moment |
1930 | static NamedScalar* getParallelDim(ParallelType p_type); |
1931 | |
1932 | //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x) |
1933 | //! WARNING: Only works with Fusion container at the moment |
1934 | static NamedScalar* getParallelIndex(ParallelType p_type); |
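  //! For example, following the usual CUDA naming convention,
  //! getParallelDim(ParallelType::TIDx) would return the named scalar
  //! "blockDim.x" while getParallelIndex(ParallelType::TIDx) would return
  //! "threadIdx.x".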
1935 | |
1936 | //! Return the parallel type of this NamedScalar if it is an extent of a |
1937 | //! parallel dimension |
1938 | c10::optional<ParallelType> getParallelDim() const; |
1939 | |
1940 | //! Return the parallel type of this NamedScalar if it is an index of a |
1941 | //! parallel dimension |
1942 | c10::optional<ParallelType> getParallelIndex() const; |
1943 | |
1944 | private: |
1945 | std::string name_; |
1946 | }; |
1947 | |
1948 | } // namespace cuda |
1949 | } // namespace fuser |
1950 | } // namespace jit |
1951 | } // namespace torch |
1952 | |