1 | #pragma once |
2 | |
3 | #include <fusion.h> |
4 | #include <index_compute.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace fuser { |
9 | namespace cuda { |
10 | |
11 | // Struct to hold useful information from an index pass on iterdomain graph. |
12 | // Used to return the IndexCompute structure back to the indexing calls in |
// index_compute.cpp. Other structures are required to resolve the actual
14 | // indexing math there. |
struct IndexFromIdGraph {
  // Resolved indexing math for the tensor's own (non-concrete) domains.
  IndexCompute index;
  // Indexing math expressed on concrete-mapped IterDomains
  // (presumably the ComputeAt exact-map concrete ids — verify in
  // index_compute.cpp where this struct is consumed).
  IndexCompute concrete_index;
  // Starting index Vals bound to each concrete IterDomain before the
  // index traversal ran (assumption from the name — confirm at call sites).
  std::unordered_map<IterDomain*, Val*> initial_concrete_index_map;
  // Loop IterDomains the indexing was resolved against; populated from the
  // `loop_domains` constructor argument.
  std::vector<IterDomain*> resolved_loop_domains;

  // Declared here, defined in the corresponding .cpp file.
  explicit IndexFromIdGraph(
      IndexCompute index,
      IndexCompute concrete_index,
      std::unordered_map<IterDomain*, Val*> initial_concrete_index_map,
      std::vector<IterDomain*> loop_domains);
};
27 | |
28 | //! Indexing interface, returns IndexFromIdGraph which the IndexCompute object |
29 | //! can be queried from directly for the produced indexing. If producer_tv != |
30 | //! nullptr producer will be indexed, if producer_tv == nullptr consumer will be |
31 | //! indexed. If is_global global indexing will be done, else shared memory or |
32 | //! local indexing will be performed. |
IndexFromIdGraph getTensorIndexFromIdGraph(
    // Loop nest surrounding the tensor access being indexed.
    const std::vector<kir::ForLoop*>& loops,
    // Consumer of the expression; always required.
    const TensorView* consumer_tv,
    // When non-null, the producer is the tensor actually indexed
    // (see the interface comment above this declaration).
    const TensorView* producer_tv = nullptr,
    // True: global-memory indexing; false: shared-memory or local indexing.
    bool is_global = true,
    // Optional consumer-to-producer IterDomain map, used when indexing a
    // producer (assumption from the name — confirm at call sites).
    std::unordered_map<IterDomain*, IterDomain*> c2p_map = {});
39 | |
40 | //! Indexing interface for calculating predicate index returns IndexFromIdGraph |
41 | //! which the IndexCompute object can be queried from directly for the produced |
//! indexing. If is_start_predicate, will produce indexing math for the start
43 | //! predicates. |
IndexFromIdGraph getPredicateIndexingFromIdGraph(
    // Loop nest the predicate is generated within.
    const std::vector<kir::ForLoop*>& loops,
    // Consumer tensor the predicate protects.
    TensorView* consumer_tv,
    // The unswitched/unrolled or vectorized loop, if any
    // (nullptr otherwise — TODO confirm at call sites).
    kir::ForLoop* unswitch_or_vec_loop,
    // Double-buffering axis of the tensor, if any
    // (nullptr otherwise — TODO confirm at call sites).
    IterDomain* double_buffer_axis,
    // True: produce indexing math for start predicates; false: stop
    // predicates (see the interface comment above this declaration).
    bool is_start_predicate);
50 | |
51 | //! getTensorIndexFromIdGraph is the function that index_compute will call very |
52 | //! straightforwardly. However, for implementing the new indexing logic that |
53 | //! starts to abstract some of the indexing away from index_compute we need to |
54 | //! move quite a bit of the intertwined indexing logic away from the |
55 | //! index_compute file and the index_reference_replay file. This is because we |
56 | //! want to separate out what has to be done on the fly, from what analysis we |
57 | //! can do early on with the iter domain graph and associated properties. |
58 | //! |
59 | //! getTensorIndexFromIdGraph places this analysis internally in |
60 | //! LoopIndexingAnalysis. LoopIndexingAnalysis though has to communicate to: |
61 | //! 1) index_compute.cpp::IndexCompute to tell IndexCompute which expressions |
62 | //! it needs to traverse to compute the indexing math. |
63 | //! 2) lower_shift.cpp::HaloInfo::buildConcreteHaloExtentMap to build the halo |
64 | //! extent map used in indexing. |
65 | //! |
66 | //! LoopIndexing is nothing but a mechanism for this communication. |
67 | //! |
68 | //! Holds information needed to produce indexing math. In the current version of |
69 | //! indexing pass, the iter domains combined with the loop nests are the source |
70 | //! of truth in terms of resolving the actual integer indexing math from the |
71 | //! sequence of iterdomain transforms. |
72 | //! |
//! This information is critical in resolving indexing associated with complex
74 | //! broadcast patterns. Check FusionComplexBCast* test cases as well as |
75 | //! FusionAdvancedIndexing* for examples where resolving indices from IterDomain |
76 | //! transformations can be challenging. |
77 | //! |
//! The source of this challenge is due to inlining patterns where the
//! IterDomains
79 | //! responsible for control flow are not local to a particular TensorView. |
80 | //! Broadcast, operations like view/reshape, and gather/shift can make indexing |
81 | //! local buffers complex because of the complex effects inlining into other |
82 | //! TensorViews produce. |
83 | //! |
84 | //! TODO: |
85 | //! The first iteration tries to match the semantics of reference |
86 | //! replay without any new logic. In a follow up iteration will |
87 | //! need to revisit a few further pathological patterns. |
88 | //! |
89 | //! Note: |
90 | //! The current implementation of loop indexing pass works on |
//!   equivalence classes defined by the ComputeAt exact map. The
//!  list of expressions stored in this class form a "reference" graph of
93 | //! iterdomain expressions when all of their inputs and outputs are replaced |
94 | //! with their exact concrete mapped id's. |
95 | //! |
96 | //! Here an invariant in a graph of iterdomain expressions is that |
97 | //! each iterdomain is produced exactly once and is either a leaf domain |
98 | //! or has been consumed exactly once by another expression. This makes sure |
99 | //! that a well defined indexing can be generated for each of the concrete ids |
100 | //! whenever we either forward or backward traverse the graph. |
class LoopIndexing {
 public:
  //! Returns the original loop nest.
  const auto& loops() const {
    return loops_;
  }

  //! Returns the vector of Iterdomains
  //!  that match the original loop pattern.
  const auto& loopDomains() const {
    return loop_domains_;
  }

  //! Returns the consumer tv that the view info
  //!  was derived from.
  auto consumerTv() const {
    return consumer_tv_;
  }

  //! Returns the set of Iterdomain transforms that
  //!  define the correct indexing path, in forward
  //!  topological order.
  std::vector<Expr*> getForwardExprList() const;

  //! Returns the set of Iterdomain transforms that
  //!  define the correct indexing path, in backward
  //!  topological order.
  std::vector<Expr*> getBackwardExprList() const;

  //! Returns the set of out of line expressions in
  //!  reverse topological order.
  const std::vector<Expr*>& getBackwardOutOfLineExprList() const {
    return out_of_line_exprs_;
  }

  //! Returns all exact concrete id's that were produced
  //!  or consumed in the selected indexing expressions
  std::unordered_set<IterDomain*> getAllExactConcreteIdSet() const;

 private:
  // LoopIndexingAnalysis is the only declared friend; it is expected to
  // populate the private fields below (see the class comment above).
  friend class LoopIndexingAnalysis;

  //! The loop nest that this loop indexing is derived from.
  std::vector<kir::ForLoop*> loops_;

  //! Consumer tv, where the view related info was derived from.
  const TensorView* consumer_tv_;

  //! The source iterdomains that all the Iterdomain transforms
  //!  in this loop nest originated from.
  std::vector<IterDomain*> loop_root_;

  //! The leaf iterdomains that the original loop nests correspond
  //!  to. May be longer than loops_ with the dangling iterdomains
  //!  appended towards the end.
  std::vector<IterDomain*> loop_domains_;

  //! The selected sequence of expressions that should represent
  //!  the correct indexing math from the given loop nest.
  std::vector<Expr*> index_exprs_;

  //! The subset of sequence of expressions that can be resolved
  //!  with only the iterdomains on the right of consumer tv's ca
  //!  axis.
  //! Expressions are ordered in reverse topological order.
  std::vector<Expr*> out_of_line_exprs_;
};
168 | |
169 | // When indexing there are sometimes an option to propagate an index down |
170 | // multiple paths. This will return the IterDomains in the history of the |
171 | // reference domain and mark which paths should be taken (if there's a |
172 | // preference) to reach the roots provided in preferred_roots. |
std::unordered_set<IterDomain*> buildLoopIndexingPreferredPath(
    // Tensor whose IterDomain history is inspected for preferred paths.
    const TensorView* original_tv,
    // Selected indexing expressions / loop info to traverse.
    const LoopIndexing& loop_indexing,
    // When true, use the replay map while building the path
    // (assumption from the name — confirm in the implementation).
    bool use_replay_map = false,
    // Optional producer-to-consumer IterDomain map
    // (assumption from the name — confirm at call sites).
    std::unordered_map<IterDomain*, IterDomain*> p2c_map = {});
178 | |
179 | // Get an rfactor IterDomain that is mapped with an IterDomain. If |
180 | // multiple such IDs exist, select one whose input IDs are mapped with |
181 | // the consumer IDs. This is to ensure the path from the leaf |
182 | // IterDomains to the root matches with the consumer tensor. |
IterDomain* getRfactorIDToTraverse(
    // IterDomain to find an rfactor-mapped counterpart for.
    IterDomain* id,
    // All ids of the consumer tensor, used to disambiguate when multiple
    // rfactor candidates exist (see the comment above this declaration).
    const std::vector<Val*>& consumer_all_ids);
186 | |
187 | } // namespace cuda |
188 | } // namespace fuser |
189 | } // namespace jit |
190 | } // namespace torch |
191 | |