1 | #pragma once |
2 | |
3 | #include <disjoint_set.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <kernel_ir.h> |
6 | #include <lower_trivial_reductions.h> |
7 | |
8 | #include <deque> |
9 | #include <unordered_map> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
// There are three modes of these iter domain mappings, all uniquely important
// in the lowering process.
//
// For EXACT/PERMISSIVE mode consider:
//
// consumer[i0, b1] = producer[i0]
// consumer->merge(0) (consumer will now be [i0 * b1])
// When producer is replayed as consumer (the direction we use for mapping)
// with BestEffortReplay forward_bcast_mismatch = True the producer to
// consumer map will have both a mapping of consumer(i0) to producer(i0) as
// well as consumer(i0*b1) to producer(i0). This latter mapping is important
// for loop nest mappings as the consumer will generate a loop based on i0*b1
// and the producer may be computeAt inside this loop nest. However, for
// indexing we do not want these two maps as producer may be indexed as i0*i1
// depending on the loop nest structure and how it was built. Therefore we
// really need to carry (at least) two sets of maps around for lowering.
//
// LOOP mode is important if we have something like:
// consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt
// = 1) which can easily happen when using shared memory. We want to make sure
// that the iteration domain used for loop construction (concreteId) has the
// proper parallelization strategy. In parallel mode we do typical iteration
// domain mapping, however we remove from it any iteration domains outside the
// computeAt of producer when mapping. This guarantees we won't map
// IterDomains that could have different parallelization strategies. We also
// propagate the parallel strategy in parallel mode so all mapped IDs that
// must have the same parallel type, do.
//
// IdMappingMode::LOOP
//   Only maps leaf axes to left of compute at
//   Forward broadcast axes in replay
// IdMappingMode::PERMISSIVE
//   Forward broadcast axes in replay
//   Map all iteration domains
//   Always contain root mappings (otherwise they could have been forwarded in
//   broadcast)
// IdMappingMode::EXACT
//   Don't map any broadcast axes to non-broadcast axes
//   Do not forward through any broadcast IDs
// Graph of IterDomain relationships for a Fusion, kept as several disjoint-set
// partitions (one per mapping mode described above). Built once at
// construction and exposed read-only afterwards.
class TORCH_CUDA_CU_API IterDomainGraph {
 public:
  // Builds all of the maps below from the fusion. When allow_self_mapping is
  // false, construction may record a self mapping (a tensor whose own axes map
  // to each other); query it via hasSelfMapping().
  IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);

  // Disjoint sets for IdMappingMode::PERMISSIVE (broadcasts forwarded,
  // all iteration domains mapped).
  const DisjointSets<IterDomain*>& permissiveNodes() const {
    return permissive_nodes_;
  }
  // Disjoint sets for IdMappingMode::EXACT (no broadcast to non-broadcast
  // mapping, no forwarding through broadcasts).
  const DisjointSets<IterDomain*>& exactNodes() const {
    return exact_nodes_;
  }
  // Disjoint sets for IdMappingMode::LOOP (leaf axes left of compute at).
  const DisjointSets<IterDomain*>& loopNodes() const {
    return loop_nodes_;
  }

  // Consumers and producers are not symmetric like the other sets: these are
  // directed maps from an IterDomain to the unique IterDomains on its
  // consumer / producer side.
  const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
  consumers() const {
    return consumers_;
  }
  const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
  producers() const {
    return producers_;
  }

  // Disjoint sets of sibling IterDomains.
  // NOTE(review): the precise definition of "sibling" is established in the
  // .cpp (not visible here) — confirm before relying on it.
  const DisjointSets<IterDomain*>& siblings() const {
    return sibling_sets_;
  }

  // Every IterDomain registered in this graph.
  const VectorOfUniqueEntries<IterDomain*>& allIds() const {
    return all_ids_;
  }

  // IterDomains recorded as view-like rfactor IDs during construction.
  const std::unordered_set<IterDomain*>& viewRfactorIds() const {
    return view_rfactor_ids_;
  }

  // Returns if first and second are expressions through which the provided
  // id_map have matching inputs (if forward), or outputs (if not forward).
  // Returning true means the expressions are "the same", in terms they modify
  // matching original extents, by the same amount.
  static bool exprsMap(
      Expr* first,
      Expr* second,
      bool forward,
      const DisjointSets<IterDomain*>& id_map);

  // True when construction recorded a self mapping; the offending tensor and
  // axes are stored in self_mapping_info_.
  bool hasSelfMapping() const {
    return self_mapping_info_.has_value();
  }

 private:
  // Populates all maps/sets below from the fusion's expressions.
  void build(Fusion* fusion);

  // Registers id into the bookkeeping structures (all_ids_, the disjoint
  // sets, and view_rfactor_ids_ when applicable).
  void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);

  // Checks if exprsMap then if forward will map outputs else inputs in exact
  // and permissive map.
  void mapThroughExpr(Expr* first, Expr* second, bool forward);

  DisjointSets<IterDomain*> permissive_nodes_;
  DisjointSets<IterDomain*> exact_nodes_;
  DisjointSets<IterDomain*> loop_nodes_;

  // Consumers and producers are not symmetric like the other sets
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
      consumers_;
  std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
      producers_;

  DisjointSets<IterDomain*> sibling_sets_;

  VectorOfUniqueEntries<IterDomain*> all_ids_;

  std::unordered_set<IterDomain*> view_rfactor_ids_;

  // Populated when a self mapping is detected: (tensor, first axis, second
  // axis, description of which map they collided in). nullopt otherwise.
  c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
      self_mapping_info_ = c10::nullopt;
};
133 | |
// Forward declaration; the complete type is presumably provided by
// lower_trivial_reductions.h (included above), which ComputeAtMap needs since
// it holds a TrivialReductionInfo by value — TODO confirm.
class TrivialReductionInfo;

// Per-stage index variables for a double-buffered loop, keyed by the stage of
// the double buffer loop they belong to.
using DoubleBufferIndices = std::unordered_map<DoubleBufferLoopStage, Int*>;
137 | |
// Wraps an IterDomainGraph built from a Fusion and layers lowering services on
// top of it: concrete-ID resolution per mapping mode, parallel-type
// validation/propagation, and pre-allocation of loop index variables.
// Move-only (copying is deleted).
class TORCH_CUDA_CU_API ComputeAtMap {
 public:
  ComputeAtMap() = delete;
  ComputeAtMap(const ComputeAtMap&) = delete;
  ComputeAtMap& operator=(const ComputeAtMap&) = delete;
  ComputeAtMap(ComputeAtMap&&) = default;
  ComputeAtMap& operator=(ComputeAtMap&&) = default;
  // NOTE(review): single-argument constructor is implicit; consider marking
  // it explicit if no caller relies on the Fusion* -> ComputeAtMap conversion.
  ComputeAtMap(Fusion* fusion);

  //! Run through disjoint sets in the LOOP map, make sure there's only one
  //! non-serial parallel type in each disjoint set, set the parallel type of
  //! all IterDomains in the disjoint set to that PType.
  void validateAndPropagatePType();

  //! Run through disjoint sets in the LOOP map and allocate the index
  //! variable for the associated for loop that will be generated
  //! for each disjoint sets in the loop map. This pre-allocation makes
  //! 2 key assumptions about computeAt map that would very likely be
  //! long term invariant:
  //!  1. All kir::forloop created in the lowering pass should belong
  //!  to one of the disjoint sets in loop map.
  //!  2. The lowering pass will *never* create a loop nest with 2
  //!  different nesting levels mapped together, i.e. the case below
  //!  never occurs:
  //!   for i in IterDomain1
  //!    for j in IterDomain2
  //!     ...
  //!   With loop_map.areMapped(IterDomain1, IterDomain2) == true.
  //! Under this condition, we can pre-allocate all required index
  //! variable integers before creating any kir::forloop, and this
  //! would help optimizing the generated integer math for indexing.
  void allocateIndexVariables();

  //! Returns if id0 and id1 are mapped to each other with provided
  //! IdMappingMode
  bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const;

  //! Returns an iter domain that is the maximum expanded size of all iter
  //! domains the one provided maps to. Useful for opening loops to the correct
  //! iteration size. Not guaranteed to return the same ID every call, but is
  //! guaranteed to return iter domains in the same disjoint set.
  IterDomain* getConcreteMappedID(IterDomain* id, IdMappingMode mode) const;

  // Prints mapping information, forwards to an internal IterDomainGraph
  std::string toString() const;

  // Returns if the provided ID is a view like rfactor id
  bool isViewRfactor(IterDomain* ref_id) const;

  // Returns all view-like rfactor domains (see
  // IterDomainGraph::viewRfactorIds()) that are in the disjoint set of the
  // provided IterDomain under the given mode. This will be every view
  // like rfactor ID the provided ID "depends" on in the map.
  std::vector<IterDomain*> getViewRfactorDomainsOfIdGroup(
      IterDomain* ref_id,
      IdMappingMode mode) const;

  // Read-only access to the underlying ID graph built at construction.
  const IterDomainGraph& idGraph() const {
    return id_graph_;
  }

  //! Get the ID sets for a provided IdMappingMode
  const DisjointSets<IterDomain*>& getIdSets(IdMappingMode mode) const;

  // Returns if the ID actually has a disjoint set meaning it has been
  // processed in the creation of the compute at map.
  bool idExistsInMap(IterDomain* id) const;

  //! Returns the pre-allocated index variable integer used in
  //! the kir::ForLoop corresponding to the given IterDomain.
  //! This interface is only valid if the ID has a loop mapping;
  //! ca_map will throw exceptions if the given iterdomain doesn't
  //! have a loop map entry.
  Val* getIndexVariable(
      IterDomain* id,
      DoubleBufferLoopStage double_buffer_loop_stage =
          DoubleBufferLoopStage::NotApplicable) const;

 private:
  // Build id_graph_
  void build(Fusion* fusion);

  // Build a single entry in concrete_id_cache_
  IterDomain* computeConcreteId(IterDomain* id, IdMappingMode mode);
  // Build concrete_id_cache_
  void buildConcreteIds();

  // Produce the disjoint set containing provided id with mapping mode.
  const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& disjointSetOf(
      IterDomain* id,
      IdMappingMode mode) const;

  // Should be built once and never modified again.
  IterDomainGraph id_graph_;
  TrivialReductionInfo trivial_reduction_info_;

  // Prevent needing to recompute concrete_id's in compute at map.
  // The shared_ptr'd VectorOfUniqueEntries is unique across mapping modes, so
  // don't need to use mapping mode directly in this cache. const
  // VectorOfUniqueEntries<IterDomain*>& is what's returned by
  // ComputeAtMap::disjointSetOf which can be used directly.
  std::unordered_map<
      std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
      IterDomain*>
      concrete_id_cache_;

  //! Allocated loop index variables through the CA map;
  //! only valid for disjoint sets on the loop ca map.
  std::unordered_map<const VectorOfUniqueEntries<IterDomain*>*, Val*>
      loop_index_variable_map_;

  //! Allocated loop indices for double buffer loops;
  //! only valid for disjoint sets on the loop ca map
  //! that have double buffer-ed iterdomains.
  using DoubleBufferIndicesPtr = std::unique_ptr<DoubleBufferIndices>;
  std::unordered_map<
      const VectorOfUniqueEntries<IterDomain*>*,
      DoubleBufferIndicesPtr>
      double_buffered_loop_index_variable_map_;

  // Shortcut to access the fusion this computeAt map was
  // built from.
  Fusion* fusion_;
};
260 | |
261 | } // namespace cuda |
262 | } // namespace fuser |
263 | } // namespace jit |
264 | } // namespace torch |
265 | |