1 | #pragma once |
2 | |
3 | #include <ir_interface_nodes.h> |
4 | #include <maxinfo_propagator.h> |
5 | #include <transform_replay.h> |
6 | |
7 | #include <memory> |
8 | #include <unordered_set> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | class MaxPosCalculator { |
16 | // Root domains in producer that's unmappable to any of its consumers |
17 | std::unordered_set<IterDomain*> unmappable_dims_; |
18 | |
19 | // User set IterDomains to not inline, used in schedulers to avoid inlining |
20 | // trivial reductions |
21 | std::unordered_set<IterDomain*> uninlinable_ids_; |
22 | |
23 | // Iterate through all TVs and collect the dimensions of each TV that don't |
24 | // map to all its consumer TVs. |
25 | void buildUnmappableDims(); |
26 | |
27 | // Utility function to return if an id of tv is a valid iter domain to inline |
28 | // within. This is used in getMaxPos{PasC,CasP}. Different variations of the |
29 | // bool values are used if checking max position of PasC, CasP, or checking |
30 | // for a max "self" position. |
31 | bool isAllowedID( |
32 | IterDomain* id, |
33 | TensorView* tv, |
34 | bool best_effort, |
35 | bool allow_reduction, |
36 | bool allow_vectorize, |
37 | bool allow_unmappable) const; |
38 | |
39 | public: |
40 | // Returns the position at which tv can be inlined within. |
41 | size_t getMaxPosSelf( |
42 | TensorView* tv, |
43 | bool best_effort, |
44 | bool allow_reduction, |
45 | bool allow_vectorize, |
46 | bool allow_unmappable) const; |
47 | |
48 | // Returns the maximum position producer can be inlined based on consumer |
49 | // given the set ComputeAtMode |
50 | size_t getMaxProducerPosFromConsumer( |
51 | TensorView* producer, |
52 | TensorView* consumer, |
53 | bool best_effort) const; |
54 | |
55 | // Checks producers, consumers, and siblings to see what the maximum position |
56 | // in tv is that can be shared across both directions. |
57 | size_t getMaxPosAll( |
58 | TensorView* tv, |
59 | bool best_effort = false, |
60 | bool check_siblings = true); |
61 | |
62 | MaxPosCalculator(const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
63 | }; |
64 | |
65 | // Inline to the right most allowed position for all tensors in the current |
66 | // fusion. |
67 | TORCH_CUDA_CU_API void inlineMost( |
68 | const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
69 | // Inline to the right most allowed position for the selected tensors in the |
70 | // current fusion. |
71 | TORCH_CUDA_CU_API void inlineMost( |
72 | const std::vector<TensorView*>& tvs, |
73 | const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
74 | // Inline to the right most allowed position for the selected tensors in the |
75 | // current fusion. |
76 | TORCH_CUDA_CU_API void inlineMost( |
77 | const std::unordered_set<TensorView*>& tvs, |
78 | const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
79 | |
80 | // Inline to the position corresponding to the reference position in the |
81 | // reference tensor for all tensors in the current fusion. |
82 | TORCH_CUDA_CU_API void inlineAllAt( |
83 | TensorView* reference_tv, |
84 | int64_t reference_pos, |
85 | bool best_effort = false, |
86 | const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
87 | |
88 | // Inline to the position corresponding to the reference position in the |
89 | // reference tensor for selected tensors in the current fusion. |
90 | TORCH_CUDA_CU_API void inlineSelectedAt( |
91 | const std::unordered_set<TensorView*>& selected, |
92 | TensorView* reference_tv, |
93 | int64_t reference_pos, |
94 | bool best_effort = false, |
95 | const std::unordered_set<IterDomain*>& uninlinable_ids = {}); |
96 | |
97 | } // namespace cuda |
98 | } // namespace fuser |
99 | } // namespace jit |
100 | } // namespace torch |
101 | |