1 | #pragma once |
2 | |
3 | #include <iter_visitor.h> |
4 | #include <root_domain_map.h> |
5 | |
6 | #include <unordered_map> |
7 | #include <unordered_set> |
8 | #include <vector> |
9 | |
10 | /* |
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have more or
 * fewer dimensions than its root due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on the input indices) that match the root dimensions.
17 | * |
18 | * For example with GLOBAL tensor: |
19 | * TV[I, K] |
20 | * TV[Io, Ii{4}, K] = TV.split(I, factor=4) |
21 | * ALLOC: NONE |
22 | * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k} |
23 | * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k} |
24 | * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I |
25 | * |
26 | * |
27 | * For example with SHARED tensor: |
28 | * |
29 | * global_TV[I, K] |
30 | * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4) |
31 | * smem_TV.compute_at(global_TV, 1) |
 * global_TV.parallelize(1, threadIdx.x)
33 | * |
34 | * ALLOC: alloc(smem_TV, 4 x K) |
35 | * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k} |
36 | * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k} |
37 | * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if |
38 | * global |
39 | * |
40 | * |
41 | * For example with LOCAL tensor: |
42 | * global_TV[I, K, L] |
43 | * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4) |
44 | * reg_TV.compute_at(global_TV, 2) |
 * global_TV.parallelize(1, threadIdx.x)
46 | * global_TV{i, j, k, l} -> { i * 4 + j, k, l } |
47 | * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l} |
48 | * |
49 | * ALLOC: alloc(reg_TV, K x L) |
50 | * INDEX: {k, l} -> {k, l} |
51 | * FLATTENED_INDEX: {k, l} -> {k * L + l} |
 * PREDICATE: {i * 4 + j, k, l} -> i * 4 + j < I && k < K && l < L // Same
 * as if global
53 | * |
54 | * These indices can then be flattened later based on strides. |
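 *
 * Schematically, the backward replay rules illustrated above are (a sketch
 * of the index arithmetic, not the literal implementation):
 *
 *   Split: I -> Io, Ii{f}
 *     index(I)  = index(Io) * f + index(Ii)
 *   Merge: Io, Ii -> I
 *     index(Io) = index(I) / extent(Ii)
 *     index(Ii) = index(I) % extent(Ii)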
55 | */ |
56 | |
57 | namespace torch { |
58 | namespace jit { |
59 | namespace fuser { |
60 | namespace cuda { |
61 | |
62 | class ContigIDs; |
63 | class LoopIndexing; |
64 | struct IndexFromIdGraph; |
65 | |
66 | class IndexCompute : public BackwardVisitor { |
67 | protected: |
68 | using BackwardVisitor::handle; |
69 | |
70 | void handle(Split*) override; |
71 | void handle(Merge*) override; |
72 | void handle(Expr*) override; |
73 | void handle(Swizzle2D*) override; |
74 | |
  // Returns extent_map_[id] if it exists, otherwise returns id->extent()
76 | Val* getExtent(IterDomain* id) const; |
77 | |
78 | //! True if a domain is not used to index |
79 | bool isZero(IterDomain* id) const; |
80 | //! True if any dependent of a domain is not used to index |
81 | bool hasZeroMerged(IterDomain* id) const; |
82 | |
  //! Returns the concrete ID from the compute at EXACT mode map if
  //! concrete_id_pass_ == true, otherwise returns the id passed in.
85 | //! Helps unify the expr handling logic in reference domain and concrete id |
86 | //! based traversal. |
87 | IterDomain* maybeGetExactMapConcreteID(IterDomain* id); |
88 | |
89 | //! (Concrete indexing pass only) |
90 | //! Collect permissive index binding from the given expression. |
91 | //! See also permissive_map_ and LoopIndexing::getBackwardOutOfLineExprList. |
92 | void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing); |
93 | |
94 | //! (Concrete indexing pass only) |
  //! Iterate through id_expr's inputs and pull index vals from the
  //! permissive map when both of the following are true:
97 | //! 1. the output id is missing in index_map_. |
98 | //! 2. the output id is found in permissive map. |
99 | void updateIndexMapFromPermissiveMap(const Expr* id_expr); |
100 | |
101 | // Tensor domain we're mapping back to root |
102 | const TensorDomain* td_; // NOLINT |
103 | |
104 | // Map we update as we propagate backward, containing all IDs in the |
105 | // propagation. Initial indices are mapped with this map at tv->domain() |
106 | // and are back propagated to tv->getRootDomain(). This index_map_ keeps the |
  // indices at intermediate IterDomains in that back propagation.
108 | std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT |
109 | |
  // Map from IterDomains to their broadcasted extents. If a TV has I0*I1 but
  // its producer has B0*I1, this map will contain a mapping from the ID
  // {B0*I1} to the extent I0*I1. It also contains updated extents if we merge
  // in a 0 index. See zero_merged_in_.
114 | std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT |
115 | |
116 | // Keeps track of domains that do not contribute to indexing |
117 | std::unordered_set<IterDomain*> zero_domains_; // NOLINT |
118 | |
  // This set keeps track of IterDomains that have had a zero index merged
  // into them. This happens with patterns like tv->split(0, 4) followed by
  // tv->computeAt(1, ...): if this tensor is in smem or lmem, the backward
  // indexing starts from (0, i), and the backward computation would then
  // attempt to merge that zero and i together. We handle indices like these
  // specially.
125 | std::unordered_set<IterDomain*> zero_merged_in_; |
126 | |
127 | // IDs that are a result of contiguous merges |
128 | std::unordered_set<IterDomain*> contig_ids_; |
129 | |
130 | // Map from root to indexed domains |
131 | std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_; |
132 | |
  // Indicates whether we should propagate an index down a particular
  // IterDomain path when there is an option
135 | std::unordered_set<IterDomain*> preferred_paths_; |
136 | |
137 | // Map from IterDomains to halo-extended extents |
138 | std::unordered_map<IterDomain*, Val*> halo_extent_map_; |
139 | |
  // Temporary flag which tells IndexCompute to use concrete IDs from the
  // exact map rather than the actual IDs used in the ID expressions.
142 | bool concrete_id_pass_ = false; |
143 | |
  // Mode of swizzle that is activated in this index compute
  // instance. Swizzles of other modes are treated as no-ops.
  // Currently, data mode swizzles are handled the same as before in the
  // IndexSwizzle pass, while loop mode swizzles are handled early on in the
  // concrete indexing pass. See also [Note on swizzle mode].
149 | SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle; |
150 | |
151 | // (Concrete id pass only) |
  // Contains the indexing math that can be resolved with only the
  // IterDomains to the right of consumer_tv's ca axis, i.e. the
  // ones corresponding to the loops that consumer_tv would not
  // share with any of its consumers.
156 | // These indexing vals should be kept separate from index_map_ and |
157 | // should only be used when the indexing traversal follows the |
158 | // order defined in LoopIndexingAnalysis::traverseFromDomainVals. |
159 | std::unordered_map<IterDomain*, Val*> permissive_index_map_; |
160 | |
161 | public: |
162 | const std::unordered_map<IterDomain*, Val*>& indexMap() const { |
163 | return index_map_; |
164 | } |
165 | |
166 | const std::unordered_map<IterDomain*, Val*>& extentMap() const { |
167 | return extent_map_; |
168 | } |
169 | |
170 | const std::unordered_set<IterDomain*>& zeroDomains() const { |
171 | return zero_domains_; |
172 | } |
173 | |
174 | const std::unordered_set<IterDomain*>& zeroMergedIn() const { |
175 | return zero_merged_in_; |
176 | } |
177 | |
178 | const std::unordered_map<IterDomain*, IterDomain*>& rootToContigID() const { |
179 | return root_to_indexed_id_; |
180 | } |
181 | |
182 | // Propagate back from _td using initial_index_map |
183 | IndexCompute( |
184 | const TensorDomain* _td, |
185 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
186 | std::unordered_map<IterDomain*, Val*> _extent_map, |
187 | std::unordered_set<IterDomain*> zero_domains, |
188 | std::unordered_set<IterDomain*> _zero_merged_in, |
189 | std::unordered_set<IterDomain*> preferred_paths = {}, |
190 | std::unordered_map<IterDomain*, Val*> halo_extent_map = {}); |
191 | |
192 | IndexCompute( |
193 | const TensorDomain* _td, |
194 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
195 | std::unordered_map<IterDomain*, Val*> _extent_map, |
196 | std::unordered_set<IterDomain*> zero_domains, |
197 | std::unordered_set<IterDomain*> _zero_merged_in, |
198 | const ContigIDs& contig_finder, |
199 | std::unordered_set<IterDomain*> preferred_paths = {}, |
200 | std::unordered_map<IterDomain*, Val*> halo_extent_map = {}); |
201 | |
  // Entry point for concrete id based traversal. This traversal is
  // assumed to start at leaf IDs provided by initial_index_map.
204 | IndexCompute( |
205 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
206 | std::unordered_set<IterDomain*> zero_domains, |
207 | std::unordered_set<IterDomain*> preferred_paths, |
208 | std::unordered_map<IterDomain*, Val*> concrete_halo_extent_map); |
209 | |
210 | // Updates index_map, extent_map, and zero_merged_in based on id_map and |
211 | // returns a new IndexCompute ready to be used. |
212 | IndexCompute updateIndexCompute( |
213 | const TensorDomain* new_td, |
214 | const std::unordered_map<IterDomain*, IterDomain*>& id_map, |
215 | const ContigIDs& contig_finder) const; |
216 | |
  // Interface to run the index traversal through a loop indexing analysis
  // result, to be used with the entry point for concrete id based traversal.
219 | void run(const LoopIndexing& loop_indexing); |
220 | |
221 | virtual void run(); |
222 | }; |
223 | |
224 | //! Apply swizzle and update root indices accordingly |
225 | class IndexSwizzle : public IndexCompute { |
226 | public: |
227 | IndexSwizzle( |
228 | const TensorView* tv, |
229 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
230 | std::unordered_map<IterDomain*, Val*> extent_map, |
231 | std::unordered_set<IterDomain*> zero_domains, |
232 | std::unordered_set<IterDomain*> zero_merged_in); |
233 | |
234 | IndexSwizzle( |
235 | const TensorView* tv, |
236 | const TensorDomain* domain, |
237 | std::unordered_map<IterDomain*, Val*> initial_index_map, |
238 | std::unordered_map<IterDomain*, Val*> extent_map, |
239 | std::unordered_set<IterDomain*> zero_domains, |
240 | std::unordered_set<IterDomain*> zero_merged_in); |
241 | |
242 | void run() override; |
243 | |
244 | protected: |
245 | using IndexCompute::handle; |
246 | |
247 | void handle(Expr* e) override; |
248 | |
249 | void handle(Swizzle2D* swizzle_2d) override; |
250 | |
251 | private: |
252 | const TensorView* tv_ = nullptr; |
253 | SwizzleType swizzle_type_ = SwizzleType::NoSwizzle; |
254 | std::vector<IterDomain*> ids_to_swizzle_; |
255 | std::unordered_set<IterDomain*> swizzled_ids_; |
256 | }; |
257 | |
258 | //! Predicate information of a root or contiguous merged domain |
259 | class RootPredicateInfo { |
260 | friend class Index; |
261 | |
262 | public: |
263 | const auto& startPredicate() const { |
264 | return start_predicate_; |
265 | } |
266 | |
267 | auto& startPredicate() { |
268 | return start_predicate_; |
269 | } |
270 | |
271 | const auto& startOffset() const { |
272 | return start_offset_; |
273 | } |
274 | |
275 | const auto& stopPredicate() const { |
276 | return stop_predicate_; |
277 | } |
278 | |
279 | const auto& stopOffset() const { |
280 | return stop_offset_; |
281 | } |
282 | |
283 | const auto& rootIds() const { |
284 | return root_ids_; |
285 | } |
286 | |
287 | //! Return a false RootPredicateInfo, i.e., both start and stop |
288 | //! predicates are false. |
289 | static RootPredicateInfo getFalseInfo(); |
290 | |
291 | private: |
  // Predicate for the lower end
  Bool* start_predicate_ = nullptr;
  // Predicate for the upper end
295 | Bool* stop_predicate_ = nullptr; |
296 | // Offset of the start predicate |
297 | Val* start_offset_ = nullptr; |
298 | // Offset of the stop predicate |
299 | Val* stop_offset_ = nullptr; |
300 | // Track which roots have been handled by the generated predicates |
301 | std::unordered_set<IterDomain*> root_ids_; |
302 | }; |
303 | |
// Simple interface for IndexCompute
// If getComputeAtAxis and, more generally, the TensorView const model are
// fixed, we can make the below TensorViews const.
307 | class Index { |
308 | private: |
309 | // Producer indexing if it's in shared or local memory |
310 | static std::vector<Val*> getNonGlobalProducerStridedIndices( |
311 | TensorView* producer, |
312 | const TensorView* consumer, |
313 | const std::vector<kir::ForLoop*>& loops); |
314 | |
315 | // Consumer indexing if it's in shared or local memory |
316 | static std::vector<Val*> getNonGlobalConsumerStridedIndices( |
317 | const TensorView* consumer, |
318 | const std::vector<kir::ForLoop*>& loops); |
319 | |
320 | // Producer if it's in global memory |
321 | static std::vector<Val*> getGlobalProducerStridedIndices( |
322 | TensorView* producer, |
323 | const TensorView* consumer, |
324 | const std::vector<kir::ForLoop*>& loops); |
325 | |
326 | // Consumer indexing if it's in global memory |
327 | static std::vector<Val*> getGlobalConsumerStridedIndices( |
328 | const TensorView* consumer, |
329 | const std::vector<kir::ForLoop*>& loops); |
330 | |
  // Get the strides of a tensor used for index lowering
332 | static std::vector<Val*> getStrides(const TensorView* tv); |
333 | |
  // Get the root indices of a tensor used for index lowering
335 | static std::vector<Val*> getRootIndices( |
336 | const TensorView* tv, |
337 | const std::vector<kir::ForLoop*>& loops, |
338 | const IndexFromIdGraph& index_from_id_graph); |
339 | |
340 | public: |
341 | // Indexing functions |
342 | // Consumer = Producer |
343 | // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer |
344 | // Producer indexing dispatch |
345 | static kir::TensorIndex* getProducerIndex( |
346 | TensorView* producer, |
347 | const TensorView* consumer, |
348 | const std::vector<kir::ForLoop*>& loops); |
349 | |
350 | // Consumer index dispatch |
351 | static kir::TensorIndex* getConsumerIndex( |
352 | const TensorView* consumer, |
353 | const std::vector<kir::ForLoop*>& loops); |
354 | |
355 | //! Returns a vector of strided indices mapped onto the (rfactor) |
356 | //! root domain of a producer tensor. The size of the returned |
357 | //! vector is guaranteed to be equal to the number of axes of the |
358 | //! indexing root domain. |
359 | static std::vector<Val*> getProducerStridedIndices( |
360 | TensorView* producer, |
361 | const TensorView* consumer, |
362 | const std::vector<kir::ForLoop*>& loops); |
363 | |
364 | //! Returns a vector of strided indices mapped onto the (rfactor) |
365 | //! root domain of a consumer tensor. The size of the returned |
366 | //! vector is guaranteed to be equal to the number of axes of the |
367 | //! indexing root domain. |
368 | static std::vector<Val*> getConsumerStridedIndices( |
369 | const TensorView* consumer, |
370 | const std::vector<kir::ForLoop*>& loops); |
371 | |
  //! Returns the logical index linearized from a multi-dimensional address
  //! into a linear memory address of a consumer tensor. The returned index is
  //! intended to be used for the computation of some tensor factories, such
  //! as arange and rand (for Philox pseudo-random sequences)
376 | static std::vector<Val*> getLinearLogicalIndex( |
377 | TensorView* consumer_tv, |
378 | const std::vector<kir::ForLoop*>& loops); |
379 | |
380 | //! Returns a vector of logical indices mapped onto the (rfactor) |
381 | //! root domain of a consumer tensor. The returned index is intended |
  //! to be used for the computation of some tensor factories, such as
  //! eye
384 | static std::vector<Val*> getPerDimLogicalIndex( |
385 | TensorView* consumer_tv, |
386 | const std::vector<kir::ForLoop*>& loops); |
387 | |
  //! Takes a consumer TensorView and loop nest and generates predicates
  //! associated with the concrete roots of the loop nest. Returns a list of
  //! predicates and a list of concrete roots they're associated with. It
  //! is assumed that no predicate is required if index[i] is an index
  //! taken directly from a for loop. This will not catch all cases if we
  //! actually have static size information, for example:
394 | //! |
395 | //! TV[I].split(4) |
396 | //! would produce the code: |
397 | //! for(i : I/4) |
398 | //! for(j : 4) |
399 | //! if( i * 4 + j < TV.size(0)) |
400 | //! TV[i * 4 + j]... |
401 | //! |
  //! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
  //! need the predicate. This case is caught by canOmitPredicate in
  //! predicate lowering
405 | //! |
  //! unswitch_or_vec_loop is the for loop from which to start the
  //! unswitch-like predicate. It is a loop rather than a bool because, if we
  //! have an unswitch loop with a vectorized loop inside, we only want to
  //! base the "unswitch"-like predicate on the vectorized loop.
410 | static std::vector<RootPredicateInfo> getReferenceRootPredicates( |
411 | TensorView* consumer_tv, |
412 | const std::vector<kir::ForLoop*>& loops, |
413 | kir::ForLoop* unswitch_or_vec_loop, |
414 | bool padding_predicate); |
415 | }; |
416 | |
417 | // Used for local and shared index mapping. Returns a map from loops |
418 | // to loop indices as well as a set of loops that do not contribute to |
419 | // indexing. |
420 | // TODO: could be cleaned up further. |
421 | std::pair< |
422 | std::unordered_map<kir::ForLoop*, Val*>, |
423 | std::unordered_set<kir::ForLoop*>> |
424 | indexMapFromTV( |
425 | const TensorView* tv, |
426 | const std::vector<kir::ForLoop*>& loops, |
427 | kir::ForLoop* alloc_loop, |
428 | bool as_consumer, |
429 | kir::ForLoop* double_buffer_loop = nullptr); |
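
// For example (a sketch; C++17 structured bindings with illustrative names):
//
//   auto [loop_to_index_map, zero_loops] =
//       indexMapFromTV(tv, loops, alloc_loop, /*as_consumer=*/true);
//
// Loops in `zero_loops` do not contribute to tv's indexing and are indexed
// as zero.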
430 | |
431 | //! Set "pragma unroll" required for loops that indexing of Local |
432 | //! tensors depends on. |
433 | //! |
434 | //! \param tv Indexed tensor |
435 | //! \param alloc_loop Allocation loop of tv |
436 | //! \param loops The current loop structure |
437 | //! \param id_map Producer-to-consumer map in case of indexing as producer |
438 | void ensureStaticIndexing( |
439 | const TensorView* tv, |
440 | kir::ForLoop* alloc_loop, |
441 | const std::vector<kir::ForLoop*>& loops, |
442 | const std::unordered_map<IterDomain*, IterDomain*>& id_map = {}); |
443 | |
444 | } // namespace cuda |
445 | } // namespace fuser |
446 | } // namespace jit |
447 | } // namespace torch |
448 | |