1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <dispatch.h> |
6 | #include <ir_all_nodes.h> |
7 | #include <kernel_ir.h> |
8 | |
9 | #include <unordered_set> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | //! Detect almost all IterDomains that are derived from trivial |
17 | //! reductons. |
18 | class TORCH_CUDA_CU_API TrivialReductionInfo { |
19 | public: |
20 | void build(Fusion* fusion); |
21 | |
22 | bool isDerived(IterDomain* id) const; |
23 | |
24 | private: |
25 | //! IterDomains that are derived only from trivial |
26 | //! reductons. Included domains are not limited to reduction axes as |
27 | //! rfactor can make reductions to normal axes. |
28 | //! |
29 | //! Note that the set should cover almost all cases but there can be |
30 | //! undetected trivial domains. For example, split by one creates a |
31 | //! trivial reduction domain, which is detected. However, if it is |
32 | //! further split, both of the two resulting axes are also trivial, |
33 | //! however, only the inner axis is recognized as trivial. While this |
34 | //! is a limitation, it would have very little practical |
35 | //! implication. |
36 | std::unordered_set<IterDomain*> domains_; |
37 | //! Subset of domains_, whose input root axes are all derived from |
38 | //! trivial reductions. These domains do not need to manifest as |
39 | //! for-loops. |
40 | std::unordered_set<IterDomain*> domains_derived_from_root_; |
41 | }; |
42 | |
43 | } // namespace cuda |
44 | } // namespace fuser |
45 | } // namespace jit |
46 | } // namespace torch |
47 | |