1 | #pragma once |
2 | |
#include <ir_all_nodes.h>

#include <unordered_set>
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | namespace fuser { |
8 | namespace cuda { |
9 | |
10 | //! Keep track of certain patterns of reductions. |
11 | //! |
12 | //! - Allreduce IterDomain: reduced and broadcast domain. |
13 | class FusedReductionInfo { |
14 | public: |
15 | void markAsAllreduce(IterDomain* id); |
16 | |
17 | bool isAllreduce(IterDomain* id) const; |
18 | |
19 | private: |
20 | // Reduction IterDomains that are also broadcast |
21 | std::unordered_set<IterDomain*> allreduce_ids_; |
22 | }; |
23 | |
24 | //! Detect reductions and broadcasts that are eligible for the fused |
25 | //! reduction kernel. When found, the predicate flags of the broadcast |
26 | //! is unset, which effectively makes the broadcast just a unary set |
27 | //! op. |
28 | //! TODO: Consider moving the warp-based fused reduction here. |
29 | void fuseReductionsAndBroadcasts(Fusion*); |
30 | |
31 | } // namespace cuda |
32 | } // namespace fuser |
33 | } // namespace jit |
34 | } // namespace torch |
35 | |