1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <compute_at_map.h> |
6 | #include <disjoint_set.h> |
7 | #include <ir_all_nodes.h> |
8 | #include <lower_shift.h> |
9 | #include <lower_trivial_broadcast.h> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | // Goes through the transformations associated with a series of ids and root |
17 | // ids. Checks the ordering of the iteration domains through these operations to |
18 | // pick out which operations are consistently ordered. For example: |
19 | // [i0, i1, i2] |
20 | // ->split(0, 4)->merge(1)->merge(1)->merge(0) |
21 | // are consistently ordered from largest to smallest extents, but |
22 | // ->split(0, 4)->merge(1)->merge(0, 2)->merge(0) is not consistently ordered |
23 | // with the roots. |
24 | // |
25 | // This property is important to understand the contiguity of dimensions through |
26 | // complex transformations. |
27 | class OrderedIdInformation : public OptInDispatch { |
28 | public: |
29 | OrderedIdInformation() = delete; |
30 | |
31 | OrderedIdInformation( |
32 | const std::vector<IterDomain*>& ids, |
33 | const std::vector<IterDomain*>& root_domain, |
34 | std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info); |
35 | |
36 | const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>& |
37 | idToRootIds() const { |
38 | return id_to_root_ids_; |
39 | } |
40 | |
41 | bool isConsistentlyOrdered(IterDomain* id) const { |
42 | return consistently_ordered_ids_.find(id) != |
43 | consistently_ordered_ids_.end(); |
44 | } |
45 | |
46 | bool exclusivelyConsumesRoots(IterDomain* id) const { |
47 | return exclusively_consumes_roots_.find(id) != |
48 | exclusively_consumes_roots_.end(); |
49 | } |
50 | |
51 | private: |
52 | // Returns if the id in active_ids should be in exclusively_consumes_roots_ |
53 | bool checkExclusivelyConsumesRoots(IterDomain* id); |
54 | |
55 | void handle(Split*) override; |
56 | |
57 | void handle(Merge* merge) override; |
58 | |
59 | void handle(Swizzle2D* swizzle) override; |
60 | |
61 | // Track which root ids were used to generate each iter domain |
62 | std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>> |
63 | id_to_root_ids_; |
64 | |
65 | // Track all IterDomains that have correct ordered transforms for contiguity. |
66 | // i.e. if we have: |
67 | // |
68 | // root = [i0, i1, i2] |
69 | // i3 = merge(i0, i2) |
70 | // would not be consistently ordered transformed |
71 | // |
72 | // root = [i0, i1, i2] |
73 | // i4, i5 = spit(merge(merge(i0, i1), i2), 4) |
74 | // would be consistently ordered transforms |
75 | // |
76 | // root = [i0, i1, i2, i3] |
77 | // i4 = merge(i1, i2) would also be consistently ordered transformed |
78 | std::unordered_set<IterDomain*> consistently_ordered_ids_; |
79 | |
80 | // Active series of IterDomains that are updated while we're processing the |
81 | // domain. Helps us identify which ids are consistently_ordered_ids_. Used |
82 | // for intermediate storage, not to return. |
83 | std::vector<IterDomain*> active_ids_; |
84 | |
85 | // IterDomains in this set exclusively consume all the uses of their roots. |
86 | // For example: |
87 | // [i0, i1] split(0, f)->merge(1) |
88 | // [ceilDiv(i0, f), f*i1] |
89 | // neither iter domains exclusively consume the roots. With another: |
90 | // merge(0) -> [ceilDiv(i0, f)*f*i1] |
91 | // The resulting iter domain does exclusively consume the roots. |
92 | // |
93 | // Also: |
94 | // [i0, i1, i2, i3] merge(1)->merge(1) |
95 | // ->[i0, i1*i2*i3] |
96 | // both resulting iter domains do exclusively consume their roots |
97 | std::unordered_set<IterDomain*> exclusively_consumes_roots_; |
98 | |
99 | // Broadcast domains that are concretized cannot be considered contiguously |
100 | // indexable. |
101 | // TODO: This constraint is more conservative than necessary as it's only if |
102 | // the domain is concretized within the local indexing, not in the entire |
103 | // fusion. |
104 | std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_; |
105 | }; |
106 | |
107 | // Based on provided divisible split set, goes through expressions and marks all |
108 | // IterDomains that are dependent on a non-divisible split. |
109 | class NonDivisibleSplitDependencies : public OptInDispatch { |
110 | public: |
111 | NonDivisibleSplitDependencies() = delete; |
112 | |
113 | NonDivisibleSplitDependencies( |
114 | const std::vector<IterDomain*>& ids, |
115 | const std::vector<IterDomain*>& root_domain, |
116 | const std::unordered_set<Split*>& divisible_splits); |
117 | |
118 | bool dependsOnNonDivisibleSplit(IterDomain* id) const { |
119 | return depends_on_non_divisible_split.find(id) != |
120 | depends_on_non_divisible_split.end(); |
121 | } |
122 | |
123 | private: |
124 | std::unordered_set<IterDomain*> depends_on_non_divisible_split; |
125 | }; |
126 | |
127 | // A merge is contiguous if: |
128 | // Inputs of outer are to the left in the root domain of the inputs of RHS. |
129 | // All inputs are contiguous in the root domain: |
130 | // - All marked as contiguous |
131 | // - Only gaps between inputs are broadcast or reductoin dims |
132 | // There are no split transformations performed on outer or inner |
133 | // All transformations on outer or inner are contiguous merges |
134 | // If this criteria holds, then we can index the input root domains of this |
135 | // merge with the indexing provided to the output of the merge in the backward |
136 | // index pass |
137 | |
138 | class ContigIDs : public OptInDispatch { |
139 | public: |
140 | //! Check through the history of ids whose inputs map to root_domain with |
141 | //! contiguity root_contiguity. Return unordered_set of all merges that are |
142 | //! contiguous. Ignore root order is primarily used for predicate generation. |
143 | //! In this case we can linearize indexing of any ID that only consists of |
144 | //! merge operations. |
145 | //! |
146 | //! Mapping information from CA Index concrete to reference domains |
147 | //! is used to find if merged output domains can be indexed. If there's |
148 | //! no mapping to a reference domain, there's no corresponding |
149 | //! index, so it isn't marked as conting merge. |
150 | //! |
151 | //! p2c_id_map can be used when replayed producer domains are |
152 | //! analyzed, in which case producer-to-consumer maps should be |
153 | //! passed. |
154 | //! |
155 | //! If ignore_indexability and ignore_halo_constraint are true, |
156 | //! ignore the constraint on indexing and halo, respectively. It is |
157 | //! the caller that is responsible for its correctness. |
158 | //! Not really sure why but clang-tidy only complains about |
159 | //! std::unordered_map if passed as a const reference. |
160 | ContigIDs( |
161 | const std::vector<IterDomain*>& ids, |
162 | const std::vector<IterDomain*>& root_domain, |
163 | const std::vector<bool>& root_contiguity, |
164 | const std::unordered_set<IterDomain*>& final_ids, |
165 | const std::unordered_map<IterDomain*, Val*>& index_map, |
166 | const std::unordered_set<Split*>& divisible_splits, |
167 | std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {}, |
168 | bool ignore_indexability = false, |
169 | bool ignore_consistent_ordering = false); |
170 | |
171 | //! \param ids IterDomains on the leaves of the domain we're looking for |
172 | //! contiguous indexing into. |
173 | //! \param root_domain the root domain of the domain we're looking for |
174 | //! contiguous indexing into. |
175 | //! \param root_contiguity the contiguity of the root_domain. |
176 | //! \param concrete_to_ref concrete ids of the exact map that the reference |
177 | //! index is using for indexing. |
178 | //! \param divisible_splits a set of all splits in the fusion that are |
179 | //! divisible. |
180 | //! \param ca_map compute at map of the fusion. |
181 | //! \param halo_info halo information of the fusion. |
182 | //! \param concrete_info concretized broadcast information of the fusion. |
183 | //! \param p2c_id_map map from producer to consumer ids used for indexing |
184 | //! producer tensors. |
185 | //! \param ignore_consistent_ordering true for actual indexing into tensors |
186 | //! but false for predicate analysis. Ordering of merges don't matter for |
187 | //! predicate generation as they don't map to a physical address. |
188 | //! \param ignore_indexability can only be true if providing a real |
189 | //! concrete_to_ref map. As what it's checking is if the index is actually |
190 | //! indexable based on the reference. |
191 | ContigIDs( |
192 | const std::vector<IterDomain*>& ids, |
193 | const std::vector<IterDomain*>& root_domain, |
194 | const std::vector<bool>& root_contiguity, |
195 | const std::unordered_set<IterDomain*>& final_ids, |
196 | const std::unordered_map<IterDomain*, Val*>& index_map, |
197 | const std::unordered_set<Split*>& divisible_splits, |
198 | std::shared_ptr<const ComputeAtMap> ca_map, |
199 | std::shared_ptr<const HaloInfo> halo_info, |
200 | std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info, |
201 | std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {}, |
202 | bool ignore_indexability = false, |
203 | bool ignore_consistent_ordering = false); |
204 | |
205 | //! Return an empty ContigIDs with no contiguous ID |
206 | static ContigIDs getNonContigIDs(); |
207 | |
208 | const std::unordered_set<IterDomain*>& contigIDs() const { |
209 | return contig_ids_; |
210 | } |
211 | |
212 | const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>& |
213 | withinContigIDs() const { |
214 | return within_contig_ids_; |
215 | } |
216 | |
217 | const std::unordered_map<IterDomain*, IterDomain*>& rootToIndexedID() const { |
218 | return root_to_indexed_id_; |
219 | } |
220 | |
221 | VectorOfUniqueEntries<IterDomain*> indexedRootIDs(IterDomain* id) const { |
222 | auto root_ids_it = consistent_transform_info_->idToRootIds().find(id); |
223 | if (root_ids_it == consistent_transform_info_->idToRootIds().end()) { |
224 | return {}; |
225 | } |
226 | return root_ids_it->second; |
227 | } |
228 | |
229 | private: |
230 | using OptInDispatch::handle; |
231 | |
232 | bool inRoot(const std::vector<IterDomain*>& ids) { |
233 | return std::all_of(ids.begin(), ids.end(), [this](IterDomain* id) { |
234 | return is_contig_root_.find(id) != is_contig_root_.end(); |
235 | }); |
236 | } |
237 | |
238 | bool isContig(IterDomain* id) { |
239 | return contig_ids_.find(id) != contig_ids_.end(); |
240 | } |
241 | |
242 | // Split outputs are not contiguous, don't need to do anything. |
243 | void handle(Split*) override {} |
244 | |
245 | void handle(Merge* merge) override; |
246 | |
247 | // TODO: |
248 | // Currently not propagating any contiguity information |
249 | // as contiguity is generally not preserved after swizzles. |
250 | // But in follow ups we could gradually add back a few special |
251 | // cases, depending on specific swizzle type and axes. |
252 | void handle(Swizzle2D* swizzle) override {} |
253 | |
254 | IterDomain* getCAIndexConcreteId(IterDomain* id) const; |
255 | |
256 | //! True if an ID is indexable. |
257 | //! E.g., a merged domain with broadcast may not be indexable when |
258 | //! its corresponding reference tensor has non-broadcast domains. |
259 | bool isIndexable(IterDomain* id) const; |
260 | |
261 | //! Return an ID mapped with id_map_ or itself |
262 | IterDomain* getMappedId(IterDomain* id) const; |
263 | |
264 | private: |
265 | void build(const std::vector<IterDomain*>& ids); |
266 | |
267 | //! Root domains to analyze contiguity |
268 | const std::vector<IterDomain*>& root_domain_; |
269 | //! Contiguity of root_domain_ |
270 | const std::vector<bool>& root_contiguity_; |
271 | //! Domains where indexing/predicates cannot be done with their |
272 | //! consumers domains |
273 | const std::unordered_set<IterDomain*>& final_ids_; |
274 | //! Mapping of concrete domains to indices. Just used to check if |
275 | //! there's an index for an IterDomain. |
276 | const std::unordered_map<IterDomain*, Val*> index_map_; |
277 | // Divisible split information as we can still consider iter domains |
278 | // contiguous through divisible splits. |
279 | const std::unordered_set<Split*>& divisible_splits_; |
280 | |
281 | std::shared_ptr<const ComputeAtMap> ca_map_; |
282 | std::shared_ptr<const HaloInfo> halo_info_; |
283 | std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_; |
284 | |
285 | //! Producer-to-consumer index map in the case of analyzing replayed |
286 | //! producer tensors |
287 | const std::unordered_map<IterDomain*, IterDomain*> p2c_id_map_; |
288 | |
289 | const bool ignore_indexability_ = false; |
290 | const bool ignore_consistent_ordering_ = false; |
291 | |
292 | //! Mapping of root domain to bool indicating contiguity |
293 | std::unordered_map<IterDomain*, bool> is_contig_root_; |
294 | // Mark if ids are result of contigous merges |
295 | std::unordered_set<IterDomain*> contig_ids_; |
296 | // Given contiguous domain, return all iter domains within its history. |
297 | std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> |
298 | within_contig_ids_; |
299 | //! Mapping of root domain to the actual indexed domain, which can |
300 | //! be itself or a contig merged domain if found. |
301 | std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_; |
302 | |
303 | std::unique_ptr<const OrderedIdInformation> consistent_transform_info_; |
304 | |
305 | NonDivisibleSplitDependencies non_divisible_id_info_; |
306 | }; |
307 | |
308 | } // namespace cuda |
309 | } // namespace fuser |
310 | } // namespace jit |
311 | } // namespace torch |
312 | |