#pragma once

#include <ir_all_nodes.h>
#include <ir_base_nodes.h>
#include <parallel_type_bitmap.h>
#include <type.h>
#include <utils.h>

#include <c10/macros/Export.h>
#include <c10/util/Optional.h>

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class IrBuilderPasskey;

// Abstract nodes
class Val;
class Expr;

// Values
class Bool;
class Double;
class Int;
class NamedScalar;

class IterDomain;
class TensorDomain;
class TensorView;

// Expressions
class UnaryOp;
class BinaryOp;
class TernaryOp;
class RNGOp;
class ReductionOp;
class WelfordOp;
class BroadcastOp;

namespace kir {
class Kernel;

// Values
class Predicate;
class TensorIndex;

// Expressions
class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class InitMagicZero;
class UpdateMagicZero;
class ForLoop;
class IfThenElse;
class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;

// Expr container
class Scope;

class TORCH_CUDA_CU_API Predicate final : public Val {
 public:
  explicit Predicate(
      IrBuilderPasskey passkey,
      PredicateType ptype,
      const Expr* expr = nullptr,
      Bool* thread_pred = nullptr);

  explicit Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop);

  explicit Predicate(IrBuilderPasskey passkey, Bool* value);

  PredicateType predicate_type() const {
    return ptype_;
  }

  const Expr* expr() const {
    TORCH_INTERNAL_ASSERT(
        ptype_ != PredicateType::Unswitch &&
        ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual);
    return expr_;
  }

  Bool* thread_pred() const {
    TORCH_INTERNAL_ASSERT(
        ptype_ == PredicateType::Inline ||
        ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift ||
        ptype_ == PredicateType::Padding ||
        ptype_ == PredicateType::ReductionWrite);
    return thread_pred_;
  }

  ForLoop* unrolled_loop() const {
    TORCH_INTERNAL_ASSERT(ptype_ == PredicateType::Unswitch);
    return unrolled_loop_;
  }

  bool hasValue() const {
    return value_ != nullptr;
  }

  Bool* value() const {
    TORCH_INTERNAL_ASSERT(
        value_ != nullptr,
        "The conditional expression for this Predicate is invalid.");
    return value_;
  }

  void setValue(Bool* value) {
    TORCH_INTERNAL_ASSERT(value != nullptr, "The Bool expression is invalid.");
    value_ = value;
  }

  bool isConst() const final {
    return hasValue() && value_->isConst();
  }

 private:
  PredicateType ptype_ = PredicateType::Manual;

  // For PredicateCompute::getInlinePredicate,
  // ShiftPredicateInserter::getShiftPredicate and getPaddingPredicate
  const Expr* expr_ = nullptr;

  // For PredicateCompute::getInlinePredicate
  Bool* thread_pred_ = nullptr;

  // For ParallelType::Unswitch - UnswitchPredicate::get
  ForLoop* unrolled_loop_ = nullptr;

  // The Bool conditional value. The value remains nullptr until the
  // lower_predicate pass runs.
  Bool* value_ = nullptr;
};
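
// Typical lifecycle (a sketch; the IrBuilder::create call is an assumption
// about the builder API, not declared in this header):
//
//   // Created during lowering with only its type and source expression:
//   auto* pred = IrBuilder::create<Predicate>(
//       PredicateType::Inline, expr, thread_pred);
//   pred->hasValue(); // false: value_ is unset until the lower_predicate pass
//
//   // The predicate lowering pass later materializes the conditional:
//   pred->setValue(cond); // cond is a Bool*
//   pred->value();        // now safe to query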

class TORCH_CUDA_CU_API TensorIndex final : public Val {
 public:
  TensorIndex(
      IrBuilderPasskey,
      const TensorView* view,
      std::vector<Val*> indices);

  std::vector<Val*>::size_type nDims() const {
    return indices_.size();
  }

  Val* index(int i) const;

  const std::vector<Val*>& indices() const {
    return indices_;
  }

  TensorView* view() const {
    TORCH_INTERNAL_ASSERT(view_ != nullptr);
    return const_cast<TensorView*>(view_); // NOLINT
  }

 private:
  const TensorView* view_ = nullptr;
  std::vector<Val*> indices_;
};
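
// A minimal access sketch using the accessors declared above: walk the
// per-dimension index Vals of a TensorIndex node (tensor_index is an
// illustrative name).
//
//   for (int i = 0; i < (int)tensor_index->nDims(); ++i) {
//     Val* idx = tensor_index->index(i); // index Val for dimension i
//   }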

//! Allocate is a lower-level node that describes a buffer of memory required
//! as an intermediate within a kernel. The extent is the expression for the
//! size of the buffer, derived from the TensorView that describes the output
//! of an operation.
class TORCH_CUDA_CU_API Allocate final : public Expr {
 public:
  //! Allocation of a multi-dimensional buffer
  //!
  //! param shape Size of each dimension
  explicit Allocate(
      IrBuilderPasskey passkey,
      Val* buffer,
      MemoryType memory_type,
      std::vector<Val*> shape = {},
      bool zero_init = false);

  //! Allocation of a non-dimensional buffer
  //!
  //! param size Size of allocation
  explicit Allocate(
      IrBuilderPasskey passkey,
      Val* buffer,
      MemoryType memory_type,
      Val* size,
      bool zero_init = false);

  Expr* shallowCopy() const override;

  Val* buffer() const {
    return buffer_;
  }

  MemoryType memoryType() const {
    return memory_type_;
  }

  Val* size() const {
    return size_;
  }

  const std::vector<Val*>& shape() const {
    return shape_;
  }

  bool zeroInit() const {
    return zero_init_;
  }

  const Allocate* alias() const {
    return alias_;
  }

  void setAlias(const Allocate* alias) {
    TORCH_INTERNAL_ASSERT(alias != this);
    TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_);
    alias_ = alias;
  }

 private:
  Val* buffer_ = nullptr;
  MemoryType memory_type_ = MemoryType::Local;
  //! Size of each dimension
  std::vector<Val*> shape_;
  bool zero_init_ = false;
  //! Total size
  Val* size_ = nullptr;

  // This alias tracks the next Allocate node in a linked chain of aliases.
  // If alias_ is nullptr, this Allocate node allocates its own memory in the
  // kernel.
  const Allocate* alias_ = nullptr;
};
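
// Aliasing sketch: two allocations with disjoint live ranges can share
// storage through setAlias(); both nodes must have the same MemoryType
// (alloc_a and alloc_b are illustrative names).
//
//   alloc_b->setAlias(alloc_a); // alloc_b reuses alloc_a's memory
//   alloc_b->alias();           // == alloc_a, so no separate buffer is needed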

// BlockSync represents the __syncthreads() barrier for block-level
// coordination.
//
// TODO(kir): change name to SyncThreads as we could have other barriers.
//
class TORCH_CUDA_CU_API BlockSync final : public Expr {
 public:
  explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false);

  Expr* shallowCopy() const override;

  bool isWarHazardSync() const {
    return war_sync_;
  }

 private:
  // TODO: war_sync_ is only used for testing/validation purposes.
  bool war_sync_ = false;
};

// CpAsyncWait represents wait intrinsics for cp.async
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
 public:
  explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);

  Expr* shallowCopy() const override;

  //! Returns the remaining number of stages that are not synchronized
  //! after this op.
  unsigned int keepStages() const {
    return keep_stages_;
  }

 private:
  //! Number of stages to leave un-synchronized by this op.
  unsigned int keep_stages_ = 0;
};
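
// For example, keep_stages == 2 waits until at most the two most recently
// committed transaction groups are still in flight (cf. the PTX instruction
// cp.async.wait_group 2), while keep_stages == 0 waits for all pending
// groups. This is what lets a multi-stage circular buffer overlap
// asynchronous loads with compute. (The mapping to cp.async.wait_group is an
// assumption from the op's name, not stated in this header.)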

// CpAsyncCommit represents commit intrinsics for cp.async.
// A commit intrinsic communicates the delimiters of transaction groups
// to the async load hardware. For example usage, see [Circular buffer].
class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
 public:
  explicit CpAsyncCommit(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

// Synchronizes all blocks in the device; implies a cooperative group launch
// is required.
class TORCH_CUDA_CU_API GridSync final : public Expr {
 public:
  explicit GridSync(
      IrBuilderPasskey passkey,
      ParallelTypeBitmap sync_dims,
      Val* sync_buffer);

  Expr* shallowCopy() const override;

  ParallelTypeBitmap syncDims() const {
    return sync_dims_;
  }

  Val* syncBuffer() const {
    return sync_buffer_;
  }

 private:
  ParallelTypeBitmap sync_dims_;
  Val* sync_buffer_ = nullptr;
};

// Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero
// in helpers.cu
class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
 public:
  explicit InitMagicZero(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

// Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
// in helpers.cu
class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
 public:
  explicit UpdateMagicZero(IrBuilderPasskey passkey);

  Expr* shallowCopy() const override;
};

// TODO(kir): promote to IR node
class TORCH_CUDA_CU_API Scope {
 public:
  explicit Scope(Expr* owner) : owner_(owner) {}

  const std::vector<Expr*>& exprs() const {
    return exprs_;
  }

  bool empty() const {
    return exprs_.empty();
  }

  auto size() const {
    return exprs_.size();
  }

  auto& operator[](size_t i) {
    return exprs_[i];
  }

  auto& operator[](size_t i) const {
    return exprs_[i];
  }

  // Insert expr before expression at pos
  void insert(size_t pos, Expr* expr);

  // Insert expr before ref
  void insert_before(Expr* ref, Expr* expr);

  // Insert expr after ref
  void insert_after(Expr* ref, Expr* expr);

  void push_back(Expr* e) {
    exprs_.push_back(e);
  }

  // Erase expr at pos
  void erase(size_t pos);

  // Erase expr ref
  void erase(Expr* ref);

  bool contains(Expr* expr) const;

  void clear();

  Expr* owner() const {
    return owner_;
  }

 private:
  // Insert expr before pos
  void insert(std::vector<Expr*>::const_iterator pos, Expr* expr);

  // Erase expr at pos
  void erase(std::vector<Expr*>::const_iterator pos);

 private:
  std::vector<Expr*> exprs_;

  //! Owner expression of this scope, e.g., IfThenElse
  Expr* owner_ = nullptr;
};
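
// Mutation sketch using the API declared above: splice a new expression in
// before a reference expression and later remove it again (scope, ref_expr,
// and sync_expr are illustrative names).
//
//   scope.insert_before(ref_expr, sync_expr); // sync_expr precedes ref_expr
//   scope.contains(sync_expr);                // true
//   scope.erase(sync_expr);                   // removed again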

//! ForLoop provides scoping around an integer iterator from 0 to range. Exprs
//! placed in its body are considered inside the scope of the for loop. In the
//! future the implementation should look quite different so that we can do
//! proper dependency analysis like in Fusion.
//!
//! TODO(kir): this is not a real expression
//!
//! A ForLoop may represent a part of an iteration domain represented
//! by iter_domain_. In that case, the loop extent field, extent_, may
//! be smaller than the extent of iter_domain_.
class TORCH_CUDA_CU_API ForLoop final : public Expr {
 public:
  //! By default, start and stop are the same as those of iter_domain.
  //! Step is one by default.
  //!
  //! TODO: cleaner way to set options?
  ForLoop(
      IrBuilderPasskey passkey,
      IterDomain* iter_domain,
      Val* index,
      Val* start,
      Val* stop,
      Val* step,
      bool vectorize,
      Val* vectorize_shift,
      bool unroll_required,
      DoubleBufferLoopStage double_buffer_loop_stage);

  ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain);

  ForLoop(IrBuilderPasskey passkey, const ForLoop* other);

  Expr* shallowCopy() const override;

  Val* index() const {
    return index_;
  }

  Val* start() const;

  Val* stop() const;

  Val* step() const;

  Val* vectorize_shift() const {
    return vectorize_shift_;
  }

  IterDomain* iter_domain() const {
    return iter_domain_;
  }

  // TODO: Return pointer instead of reference to be more consistent
  Scope& body() {
    return body_;
  }

  const Scope& body() const {
    return body_;
  }

  bool vectorize() const {
    return vectorize_;
  }

  //! True if unrolled (i.e., "#pragma unroll" is attached)
  bool isUnrolled() const;

  //! True if unrolling is required
  bool isUnrollRequired() const {
    return unroll_required_;
  }

  //! Set unrolling required
  void requireUnroll() {
    unroll_required_ = true;
  }

  //! True if no actual for-loop is materialized
  bool isTrivial() const;

  //! Returns the stage of a double buffered iterdomain
  //! that this for loop materializes.
  auto doubleBufferLoopStage() const {
    return double_buffer_loop_stage_;
  }

 private:
  //! Returns true if the loop can be unrolled.
  bool isUnrollable() const;

 private:
  IterDomain* const iter_domain_ = nullptr;

  Val* index_ = nullptr;
  Val* start_ = nullptr;
  Val* stop_ = nullptr;
  Val* step_ = nullptr;

  // vectorize_ is true when the for-loop contains a vectorized set;
  // the flag is used to omit the for-loop from the kernel.
  bool vectorize_ = false;
  // [pre | vectorize | post] <= inner-most, merged root domain
  // shift_ is applied to the vectorize and post sections.
  Val* vectorize_shift_ = nullptr;

  //! True if unroll is required for avoiding stack allocation
  bool unroll_required_ = false;

  Scope body_;

  //! Tracks if this for loop is implementing a stage of
  //! a double buffered iterdomain.
  DoubleBufferLoopStage double_buffer_loop_stage_ =
      DoubleBufferLoopStage::NotApplicable;
};
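
// Codegen sketch (illustrative shape only, not the exact emitted text): a
// materialized, non-trivial ForLoop lowers to roughly
//
//   for (auto i = start; i < stop; i += step) {
//     // ...exprs in body()...
//   }
//
// while a trivial loop (see isTrivial()) emits only its body, with the loop
// index bound directly, e.g. to a constant or a parallel thread index.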

//! IfThenElse provides scoping for a boolean condition. Exprs placed in its
//! body are considered inside the scope of the if statement. In the future the
//! implementation should look quite different so that we can do proper
//! dependency analysis like in Fusion.
//!
//! TODO(kir): this is not a real expression
//!
class TORCH_CUDA_CU_API IfThenElse final : public Expr {
 public:
  explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond);

  Expr* shallowCopy() const override;

  Scope& thenBody() {
    return then_body_;
  }
  const Scope& thenBody() const {
    return then_body_;
  }

  Scope& elseBody() {
    return else_body_;
  }

  const Scope& elseBody() const {
    return else_body_;
  }

  bool hasElse() const {
    return !else_body_.empty();
  }

 private:
  Scope then_body_;
  Scope else_body_;
};
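
// Construction sketch (the IrBuilder::create call is an assumption about the
// builder API; the body manipulation uses Scope as declared above):
//
//   auto* ite = IrBuilder::create<IfThenElse>(pred);
//   ite->thenBody().push_back(guarded_expr);
//   ite->hasElse(); // false until an expr lands in elseBody()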

//! Grid reduction operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! reduction and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! reduction and sync buffers.
class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
 public:
  GridReduction(
      IrBuilderPasskey passkey,
      BinaryOpType reduction_op_type,
      Val* init,
      Val* out,
      Val* in,
      Allocate* reduction_buffer,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  Allocate* reduction_buffer() const {
    return reduction_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which entrance of this grid reduction does this iteration correspond to?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // Total number of times this grid reduction will be entered
  Val* entrances() const {
    return entrances_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GridReduction* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GridReduction>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  Allocate* reduction_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
};
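
// Note: withThreadPredicate() is copy-on-write: it returns a shallowCopy()
// with the new thread predicate attached instead of mutating this node. The
// grouped and Welford grid ops below follow the same pattern.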

class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
 public:
  GroupedGridReduction(
      IrBuilderPasskey passkey,
      std::vector<BinaryOpType> reduction_op_type,
      std::vector<Val*> init,
      std::vector<Val*> out,
      std::vector<Val*> in,
      std::vector<Allocate*> reduction_buffers,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      Val* buffer_stride,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  const std::vector<Allocate*>& reduction_buffers() const {
    return reduction_buffers_;
  }

  Allocate* reduction_buffer(size_t i) const {
    return reduction_buffers_.at(i);
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which entrance of this grid reduction does this iteration correspond to?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // Total number of times this grid reduction will be entered
  Val* entrances() const {
    return entrances_;
  }

  Val* buffer_stride() const {
    return buffer_stride_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GroupedGridReduction* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GroupedGridReduction>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  std::vector<Allocate*> reduction_buffers_;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // Stride of reduction buffers
  Val* buffer_stride_ = nullptr;
};

//! Grid broadcast operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! broadcast and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! broadcast and sync buffers.
class TORCH_CUDA_CU_API GridBroadcast final : public Expr {
 public:
  GridBroadcast(
      IrBuilderPasskey passkey,
      BroadcastOp* broadcast_op,
      Allocate* broadcast_buffer,
      Allocate* sync_buffer);

  Expr* shallowCopy() const override;

  BroadcastOp* broadcast_op() const {
    return broadcast_op_;
  }

  Allocate* broadcast_buffer() const {
    return broadcast_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

 private:
  BroadcastOp* broadcast_op_ = nullptr;
  Allocate* broadcast_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
};

//! Grid welford operation
//!
//! This node is used only after lowering a fusion to explicitly mark a grid
//! reduction and the buffer allocation needed to do it.
//!
//! This node provides FusionExecutor the information it needs to allocate the
//! reduction and sync buffers.
//!
//! TODO: Make this a subclass of WelfordOp
class TORCH_CUDA_CU_API GridWelford final : public Expr {
 public:
  GridWelford(
      IrBuilderPasskey passkey,
      WelfordOp* welford_op,
      Allocate* var_buffer,
      Allocate* avg_buffer,
      Allocate* n_buffer,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances);

  Expr* shallowCopy() const override;

  WelfordOp* welford_op() const {
    return welford_op_;
  }

  Allocate* var_buffer() const {
    return var_buffer_;
  }

  Allocate* avg_buffer() const {
    return avg_buffer_;
  }

  Allocate* N_buffer() const {
    return n_buffer_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which entrance of this grid reduction does this iteration correspond to?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // Total number of times this grid reduction will be entered
  Val* entrances() const {
    return entrances_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GridWelford>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  WelfordOp* welford_op_ = nullptr;
  Allocate* var_buffer_ = nullptr;
  Allocate* avg_buffer_ = nullptr;
  Allocate* n_buffer_ = nullptr;
  Allocate* sync_buffer_ = nullptr;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
};

class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp {
 public:
  // input, output and init vals are vectors of triplets
  GroupedGridWelford(
      IrBuilderPasskey passkey,
      std::vector<WelfordTriplet> output_vals,
      std::vector<WelfordTriplet> input_vals,
      std::vector<WelfordTriplet> init_vals,
      std::array<std::vector<Allocate*>, 3> reduction_buffers,
      Allocate* sync_buffer,
      Val* entrance_index,
      Val* entrances,
      Val* buffer_stride,
      bool is_allreduce = false);

  Expr* shallowCopy() const override;

  const std::array<std::vector<Allocate*>, 3>& reduction_buffers() const {
    return reduction_buffers_;
  }

  Allocate* sync_buffer() const {
    return sync_buffer_;
  }

  // Which entrance of this grid reduction does this iteration correspond to?
  Val* entrance_index() const {
    return entrance_index_;
  }

  // Total number of times this grid reduction will be entered
  Val* entrances() const {
    return entrances_;
  }

  Val* buffer_stride() const {
    return buffer_stride_;
  }

  const ParallelTypeBitmap& threadPredicate() const {
    return thread_predicate_;
  }

  GroupedGridWelford* withThreadPredicate(
      const ParallelTypeBitmap& thread_predicate) {
    auto result = shallowCopy()->as<GroupedGridWelford>();
    result->thread_predicate_ = thread_predicate;
    return result;
  }

 private:
  std::array<std::vector<Allocate*>, 3> reduction_buffers_;
  Allocate* sync_buffer_ = nullptr;
  // gridReduce has template flags for thread predicates. In order to
  // use them, the thread predicate is held here separately from
  // Expr::predicate_.
  ParallelTypeBitmap thread_predicate_;
  Val* entrance_index_ = nullptr;
  Val* entrances_ = nullptr;
  // Stride of reduction buffers
  Val* buffer_stride_ = nullptr;
};

// Allocate an instance of the fused reduction class.
class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr {
 public:
  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GridReduction* grid_reduction);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GridWelford* grid_welford);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GroupedGridReduction* grouped_grid_reduction);

  explicit AllocateFusedReduction(
      IrBuilderPasskey passkey,
      GroupedGridWelford* grouped_grid_welford);

  Expr* shallowCopy() const override;

  Expr* gridExpr() const {
    return grid_expr_;
  }

  TensorIndex* out() const;

  const ParallelTypeBitmap& threadPredicate() const;

 private:
  //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford
  Expr* grid_expr_ = nullptr;
};

//! An IR node consisting of a pair of integers
//! to facilitate the definition of 2D swizzle operators.
//! All 2D swizzle ops take two inputs and output
//! an integer pair.
//! TODO:
//!   Currently this IR node is only allowed as input
//!   to the new PairSelect node. Follow-ups may
//!   build this out to support out-of-line
//!   definition of the pair alone.
class TORCH_CUDA_CU_API IntPair : public Val {
 public:
  IntPair(IrBuilderPasskey passkey);
};

//! An IR node marking selection of the first or second
//! value from a pair of integers, e.g.:
//! Pair(X, Y) -> X or Y.
//! This IR node is used to facilitate generation
//! of inline 2D swizzle math.
class TORCH_CUDA_CU_API PairSelect : public Expr {
 public:
  //! Indicates which value from the input
  //! integer pair to output.
  enum class Selection { X = 0, Y };

  PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection);

  Expr* shallowCopy() const override;

  Val* out() const {
    return out_;
  }

  IntPair* in() const {
    return in_;
  }

  auto selection() const {
    return selection_;
  }

 private:
  Val* const out_ = nullptr;
  IntPair* const in_ = nullptr;
  Selection selection_;
};

//! An integer IR node that will be generated
//! using the custom integer swizzle functions
//! from the CUDA runtime.
//! Most supported swizzle functions require
//! the size of each dimension to be defined, so
//! all operators take the extents as inputs.
class TORCH_CUDA_CU_API Swizzle2DInt : public Expr {
 public:
  Swizzle2DInt(
      IrBuilderPasskey,
      IntPair* out,
      Val* in_x,
      Val* in_y,
      Val* extent_x,
      Val* extent_y,
      Swizzle2DType swizzle_type);

  Expr* shallowCopy() const override;

  IntPair* out() const {
    return out_;
  }

  Val* inX() const {
    return in_x_;
  }

  Val* inY() const {
    return in_y_;
  }

  Val* extentX() const {
    return extent_x_;
  }

  Val* extentY() const {
    return extent_y_;
  }

  const auto& swizzleType() const {
    return swizzle_type_;
  }

 private:
  IntPair* const out_ = nullptr;

  Val* const in_x_ = nullptr;
  Val* const in_y_ = nullptr;
  Val* const extent_x_ = nullptr;
  Val* const extent_y_ = nullptr;
  Swizzle2DType swizzle_type_;
};
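
// Usage sketch (the IrBuilder::create calls are assumptions about the builder
// API; the node wiring follows the constructors declared above). A 2D swizzle
// writes an IntPair, which PairSelect then splits back into scalar indices:
//
//   auto* pair = IrBuilder::create<IntPair>();
//   IrBuilder::create<Swizzle2DInt>(pair, in_x, in_y, ext_x, ext_y, type);
//   IrBuilder::create<PairSelect>(out_x, pair, PairSelect::Selection::X);
//   IrBuilder::create<PairSelect>(out_y, pair, PairSelect::Selection::Y);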

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch