1#pragma once
2
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <c10/macros/Export.h>

#include <compute_at_map.h>
#include <disjoint_set.h>
#include <ir_all_nodes.h>
#include <lower_shift.h>
#include <lower_trivial_broadcast.h>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16// Goes through the transformations associated with a series of ids and root
17// ids. Checks the ordering of the iteration domains through these operations to
18// pick out which operations are consistently ordered. For example:
19// [i0, i1, i2]
20// ->split(0, 4)->merge(1)->merge(1)->merge(0)
21// are consistently ordered from largest to smallest extents, but
22// ->split(0, 4)->merge(1)->merge(0, 2)->merge(0) is not consistently ordered
23// with the roots.
24//
25// This property is important to understand the contiguity of dimensions through
26// complex transformations.
27class OrderedIdInformation : public OptInDispatch {
28 public:
29 OrderedIdInformation() = delete;
30
31 OrderedIdInformation(
32 const std::vector<IterDomain*>& ids,
33 const std::vector<IterDomain*>& root_domain,
34 std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info);
35
36 const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
37 idToRootIds() const {
38 return id_to_root_ids_;
39 }
40
41 bool isConsistentlyOrdered(IterDomain* id) const {
42 return consistently_ordered_ids_.find(id) !=
43 consistently_ordered_ids_.end();
44 }
45
46 bool exclusivelyConsumesRoots(IterDomain* id) const {
47 return exclusively_consumes_roots_.find(id) !=
48 exclusively_consumes_roots_.end();
49 }
50
51 private:
52 // Returns if the id in active_ids should be in exclusively_consumes_roots_
53 bool checkExclusivelyConsumesRoots(IterDomain* id);
54
55 void handle(Split*) override;
56
57 void handle(Merge* merge) override;
58
59 void handle(Swizzle2D* swizzle) override;
60
61 // Track which root ids were used to generate each iter domain
62 std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
63 id_to_root_ids_;
64
65 // Track all IterDomains that have correct ordered transforms for contiguity.
66 // i.e. if we have:
67 //
68 // root = [i0, i1, i2]
69 // i3 = merge(i0, i2)
70 // would not be consistently ordered transformed
71 //
72 // root = [i0, i1, i2]
73 // i4, i5 = spit(merge(merge(i0, i1), i2), 4)
74 // would be consistently ordered transforms
75 //
76 // root = [i0, i1, i2, i3]
77 // i4 = merge(i1, i2) would also be consistently ordered transformed
78 std::unordered_set<IterDomain*> consistently_ordered_ids_;
79
80 // Active series of IterDomains that are updated while we're processing the
81 // domain. Helps us identify which ids are consistently_ordered_ids_. Used
82 // for intermediate storage, not to return.
83 std::vector<IterDomain*> active_ids_;
84
85 // IterDomains in this set exclusively consume all the uses of their roots.
86 // For example:
87 // [i0, i1] split(0, f)->merge(1)
88 // [ceilDiv(i0, f), f*i1]
89 // neither iter domains exclusively consume the roots. With another:
90 // merge(0) -> [ceilDiv(i0, f)*f*i1]
91 // The resulting iter domain does exclusively consume the roots.
92 //
93 // Also:
94 // [i0, i1, i2, i3] merge(1)->merge(1)
95 // ->[i0, i1*i2*i3]
96 // both resulting iter domains do exclusively consume their roots
97 std::unordered_set<IterDomain*> exclusively_consumes_roots_;
98
99 // Broadcast domains that are concretized cannot be considered contiguously
100 // indexable.
101 // TODO: This constraint is more conservative than necessary as it's only if
102 // the domain is concretized within the local indexing, not in the entire
103 // fusion.
104 std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;
105};
106
107// Based on provided divisible split set, goes through expressions and marks all
108// IterDomains that are dependent on a non-divisible split.
109class NonDivisibleSplitDependencies : public OptInDispatch {
110 public:
111 NonDivisibleSplitDependencies() = delete;
112
113 NonDivisibleSplitDependencies(
114 const std::vector<IterDomain*>& ids,
115 const std::vector<IterDomain*>& root_domain,
116 const std::unordered_set<Split*>& divisible_splits);
117
118 bool dependsOnNonDivisibleSplit(IterDomain* id) const {
119 return depends_on_non_divisible_split.find(id) !=
120 depends_on_non_divisible_split.end();
121 }
122
123 private:
124 std::unordered_set<IterDomain*> depends_on_non_divisible_split;
125};
126
127// A merge is contiguous if:
128// Inputs of outer are to the left in the root domain of the inputs of RHS.
129// All inputs are contiguous in the root domain:
130// - All marked as contiguous
131// - Only gaps between inputs are broadcast or reductoin dims
132// There are no split transformations performed on outer or inner
133// All transformations on outer or inner are contiguous merges
134// If this criteria holds, then we can index the input root domains of this
135// merge with the indexing provided to the output of the merge in the backward
136// index pass
137
138class ContigIDs : public OptInDispatch {
139 public:
140 //! Check through the history of ids whose inputs map to root_domain with
141 //! contiguity root_contiguity. Return unordered_set of all merges that are
142 //! contiguous. Ignore root order is primarily used for predicate generation.
143 //! In this case we can linearize indexing of any ID that only consists of
144 //! merge operations.
145 //!
146 //! Mapping information from CA Index concrete to reference domains
147 //! is used to find if merged output domains can be indexed. If there's
148 //! no mapping to a reference domain, there's no corresponding
149 //! index, so it isn't marked as conting merge.
150 //!
151 //! p2c_id_map can be used when replayed producer domains are
152 //! analyzed, in which case producer-to-consumer maps should be
153 //! passed.
154 //!
155 //! If ignore_indexability and ignore_halo_constraint are true,
156 //! ignore the constraint on indexing and halo, respectively. It is
157 //! the caller that is responsible for its correctness.
158 //! Not really sure why but clang-tidy only complains about
159 //! std::unordered_map if passed as a const reference.
160 ContigIDs(
161 const std::vector<IterDomain*>& ids,
162 const std::vector<IterDomain*>& root_domain,
163 const std::vector<bool>& root_contiguity,
164 const std::unordered_set<IterDomain*>& final_ids,
165 const std::unordered_map<IterDomain*, Val*>& index_map,
166 const std::unordered_set<Split*>& divisible_splits,
167 std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
168 bool ignore_indexability = false,
169 bool ignore_consistent_ordering = false);
170
171 //! \param ids IterDomains on the leaves of the domain we're looking for
172 //! contiguous indexing into.
173 //! \param root_domain the root domain of the domain we're looking for
174 //! contiguous indexing into.
175 //! \param root_contiguity the contiguity of the root_domain.
176 //! \param concrete_to_ref concrete ids of the exact map that the reference
177 //! index is using for indexing.
178 //! \param divisible_splits a set of all splits in the fusion that are
179 //! divisible.
180 //! \param ca_map compute at map of the fusion.
181 //! \param halo_info halo information of the fusion.
182 //! \param concrete_info concretized broadcast information of the fusion.
183 //! \param p2c_id_map map from producer to consumer ids used for indexing
184 //! producer tensors.
185 //! \param ignore_consistent_ordering true for actual indexing into tensors
186 //! but false for predicate analysis. Ordering of merges don't matter for
187 //! predicate generation as they don't map to a physical address.
188 //! \param ignore_indexability can only be true if providing a real
189 //! concrete_to_ref map. As what it's checking is if the index is actually
190 //! indexable based on the reference.
191 ContigIDs(
192 const std::vector<IterDomain*>& ids,
193 const std::vector<IterDomain*>& root_domain,
194 const std::vector<bool>& root_contiguity,
195 const std::unordered_set<IterDomain*>& final_ids,
196 const std::unordered_map<IterDomain*, Val*>& index_map,
197 const std::unordered_set<Split*>& divisible_splits,
198 std::shared_ptr<const ComputeAtMap> ca_map,
199 std::shared_ptr<const HaloInfo> halo_info,
200 std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info,
201 std::unordered_map<IterDomain*, IterDomain*> p2c_id_map = {},
202 bool ignore_indexability = false,
203 bool ignore_consistent_ordering = false);
204
205 //! Return an empty ContigIDs with no contiguous ID
206 static ContigIDs getNonContigIDs();
207
208 const std::unordered_set<IterDomain*>& contigIDs() const {
209 return contig_ids_;
210 }
211
212 const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
213 withinContigIDs() const {
214 return within_contig_ids_;
215 }
216
217 const std::unordered_map<IterDomain*, IterDomain*>& rootToIndexedID() const {
218 return root_to_indexed_id_;
219 }
220
221 VectorOfUniqueEntries<IterDomain*> indexedRootIDs(IterDomain* id) const {
222 auto root_ids_it = consistent_transform_info_->idToRootIds().find(id);
223 if (root_ids_it == consistent_transform_info_->idToRootIds().end()) {
224 return {};
225 }
226 return root_ids_it->second;
227 }
228
229 private:
230 using OptInDispatch::handle;
231
232 bool inRoot(const std::vector<IterDomain*>& ids) {
233 return std::all_of(ids.begin(), ids.end(), [this](IterDomain* id) {
234 return is_contig_root_.find(id) != is_contig_root_.end();
235 });
236 }
237
238 bool isContig(IterDomain* id) {
239 return contig_ids_.find(id) != contig_ids_.end();
240 }
241
242 // Split outputs are not contiguous, don't need to do anything.
243 void handle(Split*) override {}
244
245 void handle(Merge* merge) override;
246
247 // TODO:
248 // Currently not propagating any contiguity information
249 // as contiguity is generally not preserved after swizzles.
250 // But in follow ups we could gradually add back a few special
251 // cases, depending on specific swizzle type and axes.
252 void handle(Swizzle2D* swizzle) override {}
253
254 IterDomain* getCAIndexConcreteId(IterDomain* id) const;
255
256 //! True if an ID is indexable.
257 //! E.g., a merged domain with broadcast may not be indexable when
258 //! its corresponding reference tensor has non-broadcast domains.
259 bool isIndexable(IterDomain* id) const;
260
261 //! Return an ID mapped with id_map_ or itself
262 IterDomain* getMappedId(IterDomain* id) const;
263
264 private:
265 void build(const std::vector<IterDomain*>& ids);
266
267 //! Root domains to analyze contiguity
268 const std::vector<IterDomain*>& root_domain_;
269 //! Contiguity of root_domain_
270 const std::vector<bool>& root_contiguity_;
271 //! Domains where indexing/predicates cannot be done with their
272 //! consumers domains
273 const std::unordered_set<IterDomain*>& final_ids_;
274 //! Mapping of concrete domains to indices. Just used to check if
275 //! there's an index for an IterDomain.
276 const std::unordered_map<IterDomain*, Val*> index_map_;
277 // Divisible split information as we can still consider iter domains
278 // contiguous through divisible splits.
279 const std::unordered_set<Split*>& divisible_splits_;
280
281 std::shared_ptr<const ComputeAtMap> ca_map_;
282 std::shared_ptr<const HaloInfo> halo_info_;
283 std::shared_ptr<const ConcretizedBroadcastDomains> concrete_info_;
284
285 //! Producer-to-consumer index map in the case of analyzing replayed
286 //! producer tensors
287 const std::unordered_map<IterDomain*, IterDomain*> p2c_id_map_;
288
289 const bool ignore_indexability_ = false;
290 const bool ignore_consistent_ordering_ = false;
291
292 //! Mapping of root domain to bool indicating contiguity
293 std::unordered_map<IterDomain*, bool> is_contig_root_;
294 // Mark if ids are result of contigous merges
295 std::unordered_set<IterDomain*> contig_ids_;
296 // Given contiguous domain, return all iter domains within its history.
297 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
298 within_contig_ids_;
299 //! Mapping of root domain to the actual indexed domain, which can
300 //! be itself or a contig merged domain if found.
301 std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_;
302
303 std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;
304
305 NonDivisibleSplitDependencies non_divisible_id_info_;
306};
307
308} // namespace cuda
309} // namespace fuser
310} // namespace jit
311} // namespace torch
312