#pragma once

#include <iter_visitor.h>
#include <root_domain_map.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its root due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on input indices) that match the root dimension.
 *
 * For example, with a GLOBAL tensor:
 * TV[I, K]
 * TV[Io, Ii{4}, K] = TV.split(I, factor=4)
 * ALLOC: NONE
 * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
 * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
 * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
 *
 *
 * For example, with a SHARED tensor:
 *
 * global_TV[I, K]
 * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
 * smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
 *
 * ALLOC: alloc(smem_TV, 4 x K)
 * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
 * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
 * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
 * global
 *
 *
 * For example, with a LOCAL tensor:
 * global_TV[I, K, L]
 * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
 * reg_TV.compute_at(global_TV, 2)
 * global_TV.parallelize(1, threadIdx.x)
 * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
 * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
 *
 * ALLOC: alloc(reg_TV, K x L)
 * INDEX: {k, l} -> {k, l}
 * FLATTENED_INDEX: {k, l} -> {k * L + l}
 * PREDICATE: i * 4 + j < I && k < K && l < L // Same as if global
 *
 * These indices can then be flattened later based on strides.
 */
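
/*
 * A minimal standalone sketch (illustrative only, not nvfuser code) of the
 * arithmetic in the GLOBAL example above: the split indices {i, j} collapse
 * into the root index i * 4 + j, which is then flattened with the root
 * strides {K, 1} and predicated against the root extent I.
 *
 *   bool index_and_predicate(
 *       long i, long j, long k, long I, long K, long* flattened) {
 *     const long root_i = i * 4 + j; // undo split(I, factor=4)
 *     *flattened = root_i * K + k;   // FLATTENED_INDEX
 *     return root_i < I;             // PREDICATE
 *   }
 */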

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class ContigIDs;
class LoopIndexing;
struct IndexFromIdGraph;

class IndexCompute : public BackwardVisitor {
 protected:
  using BackwardVisitor::handle;

  void handle(Split*) override;
  void handle(Merge*) override;
  void handle(Expr*) override;
  void handle(Swizzle2D*) override;
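
  // A rough sketch (not the literal implementation) of the backward index
  // math the Split/Merge handlers perform; the real handlers additionally
  // deal with zero domains, contiguous IDs, swizzles, and halo:
  //
  //   Split: in -> (outer, inner)
  //     index(in)    = index(outer) * extent(inner) + index(inner)
  //   Merge: (outer, inner) -> out
  //     index(outer) = index(out) / extent(inner)
  //     index(inner) = index(out) % extent(inner)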

  // Returns extent_map_[id] if it exists, otherwise returns id->extent()
  Val* getExtent(IterDomain* id) const;

  //! True if a domain is not used to index
  bool isZero(IterDomain* id) const;
  //! True if any dependent of a domain is not used to index
  bool hasZeroMerged(IterDomain* id) const;

  //! Returns the concrete ID from the compute at EXACT mode map if
  //! concrete_id_pass == true, otherwise returns the id passed in.
  //! Helps unify the expr handling logic in reference domain and concrete id
  //! based traversal.
  IterDomain* maybeGetExactMapConcreteID(IterDomain* id);

  //! (Concrete indexing pass only)
  //! Collect permissive index bindings from the given expression.
  //! See also permissive_index_map_ and
  //! LoopIndexing::getBackwardOutOfLineExprList.
  void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);

  //! (Concrete indexing pass only)
  //! Iterate through id_expr's inputs and pull index vals from the permissive
  //! map when both of the following are true:
  //!  1. the output id is missing in index_map_.
  //!  2. the output id is found in the permissive map.
  void updateIndexMapFromPermissiveMap(const Expr* id_expr);

  // Tensor domain we're mapping back to root
  const TensorDomain* td_; // NOLINT

  // Map we update as we propagate backward, containing all IDs in the
  // propagation. Initial indices are mapped with this map at tv->domain()
  // and are back propagated to tv->getRootDomain(). This index_map_ keeps the
  // indices at intermediate IterDomains in that back propagation.
  std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT

  // Map from IterDomain to their broadcasted extent. If a TV has I0*I1 but its
  // producer has B0*I1 this map will contain a mapping from the ID{B0*I1} to
  // the extent I0*I1. Also contains updated extents if we merge in a 0 index.
  // See zero_merged_in_.
  std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT

  // Keeps track of domains that do not contribute to indexing
  std::unordered_set<IterDomain*> zero_domains_; // NOLINT

  // This set keeps track of IterDomains that have had a zero index merged into
  // them. This happens if we do something like tv->axis(0)->split(4) and then
  // tv->computeAt(1, ...). If this tensor is in smem or lmem, the backward
  // indexing would start as (0, i); when we do the backward computation, that
  // zero and i would attempt to be merged together. We handle indices like
  // these specially.
  std::unordered_set<IterDomain*> zero_merged_in_;

  // IDs that are a result of contiguous merges
  std::unordered_set<IterDomain*> contig_ids_;

  // Map from root to indexed domains
  std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_;

  // Indicates whether we should prefer to propagate an index down a particular
  // IterDomain path when there is a choice
  std::unordered_set<IterDomain*> preferred_paths_;

  // Map from IterDomains to halo-extended extents
  std::unordered_map<IterDomain*, Val*> halo_extent_map_;

  // Temporary flag which tells IndexCompute to use concrete IDs from the exact
  // map rather than the actual IDs used in the ID expressions.
  bool concrete_id_pass_ = false;

  // Mode of swizzle that is activated in this index compute
  // instance. Swizzles of a different mode are treated as no-ops.
  // Currently, data mode swizzles are handled the same as before in the
  // IndexSwizzle pass, while loop mode swizzles are handled early on in the
  // concrete indexing pass. See also [Note on swizzle mode].
  SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

  // (Concrete id pass only)
  // Contains the indexing math that could be resolved with only the
  // iterdomains on the right of the consumer_tv's ca axis, i.e. the
  // ones corresponding to the loops that consumer_tv would not
  // share with any of its consumers.
  // These indexing vals should be kept separate from index_map_ and
  // should only be used when the indexing traversal follows the
  // order defined in LoopIndexingAnalysis::traverseFromDomainVals.
  std::unordered_map<IterDomain*, Val*> permissive_index_map_;

 public:
  const std::unordered_map<IterDomain*, Val*>& indexMap() const {
    return index_map_;
  }

  const std::unordered_map<IterDomain*, Val*>& extentMap() const {
    return extent_map_;
  }

  const std::unordered_set<IterDomain*>& zeroDomains() const {
    return zero_domains_;
  }

  const std::unordered_set<IterDomain*>& zeroMergedIn() const {
    return zero_merged_in_;
  }

  const std::unordered_map<IterDomain*, IterDomain*>& rootToContigID() const {
    return root_to_indexed_id_;
  }

  // Propagate back from _td using initial_index_map
  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_map<IterDomain*, Val*> halo_extent_map = {});

  IndexCompute(
      const TensorDomain* _td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> _extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> _zero_merged_in,
      const ContigIDs& contig_finder,
      std::unordered_set<IterDomain*> preferred_paths = {},
      std::unordered_map<IterDomain*, Val*> halo_extent_map = {});

  // Entry point for concrete-id-based traversal. This traversal is
  // assumed to start at the leaf IDs provided by initial_index_map.
  IndexCompute(
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> preferred_paths,
      std::unordered_map<IterDomain*, Val*> concrete_halo_extent_map);

  // Updates index_map, extent_map, and zero_merged_in based on id_map and
  // returns a new IndexCompute ready to be used.
  IndexCompute updateIndexCompute(
      const TensorDomain* new_td,
      const std::unordered_map<IterDomain*, IterDomain*>& id_map,
      const ContigIDs& contig_finder) const;

  // Interface to run the index traversal through a loop indexing analysis
  // result; to be used with the entry point for concrete-id-based traversal.
  void run(const LoopIndexing& loop_indexing);

  virtual void run();
};
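
// A minimal usage sketch (illustrative only): given a hypothetical
// `consumer_td` (a TensorDomain*) and an `initial_index_map` that binds each
// leaf IterDomain of `consumer_td` to its loop index Val, backward-propagate
// to root indices. Empty extent / zero-domain / zero-merged sets are passed
// for brevity; real callers populate them.
//
//   IndexCompute index_compute(
//       consumer_td,
//       std::move(initial_index_map),
//       /*_extent_map=*/{},
//       /*zero_domains=*/{},
//       /*_zero_merged_in=*/{});
//   index_compute.run();
//   const auto& root_index_map = index_compute.indexMap();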

//! Apply swizzle and update root indices accordingly
class IndexSwizzle : public IndexCompute {
 public:
  IndexSwizzle(
      const TensorView* tv,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  IndexSwizzle(
      const TensorView* tv,
      const TensorDomain* domain,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map,
      std::unordered_set<IterDomain*> zero_domains,
      std::unordered_set<IterDomain*> zero_merged_in);

  void run() override;

 protected:
  using IndexCompute::handle;

  void handle(Expr* e) override;

  void handle(Swizzle2D* swizzle_2d) override;

 private:
  const TensorView* tv_ = nullptr;
  SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
  std::vector<IterDomain*> ids_to_swizzle_;
  std::unordered_set<IterDomain*> swizzled_ids_;
};

//! Predicate information of a root or contiguous merged domain
class RootPredicateInfo {
  friend class Index;

 public:
  const auto& startPredicate() const {
    return start_predicate_;
  }

  auto& startPredicate() {
    return start_predicate_;
  }

  const auto& startOffset() const {
    return start_offset_;
  }

  const auto& stopPredicate() const {
    return stop_predicate_;
  }

  const auto& stopOffset() const {
    return stop_offset_;
  }

  const auto& rootIds() const {
    return root_ids_;
  }

  //! Return a false RootPredicateInfo, i.e., both start and stop
  //! predicates are false.
  static RootPredicateInfo getFalseInfo();

 private:
  // Predicate for the lower end
  Bool* start_predicate_ = nullptr;
  // Predicate for the upper end
  Bool* stop_predicate_ = nullptr;
  // Offset of the start predicate
  Val* start_offset_ = nullptr;
  // Offset of the stop predicate
  Val* stop_offset_ = nullptr;
  // Track which roots have been handled by the generated predicates
  std::unordered_set<IterDomain*> root_ids_;
};

// Simple interface for IndexCompute
// If getComputeAtAxis, and more generally the TensorView const model, is
// fixed, we can make the below TensorViews const.
class Index {
 private:
  // Producer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in shared or local memory
  static std::vector<Val*> getNonGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Producer indexing if it's in global memory
  static std::vector<Val*> getGlobalProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer indexing if it's in global memory
  static std::vector<Val*> getGlobalConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Get the strides of a tensor used for index lowering
  static std::vector<Val*> getStrides(const TensorView* tv);

  // Get the root indices of a tensor used for index lowering
  static std::vector<Val*> getRootIndices(
      const TensorView* tv,
      const std::vector<kir::ForLoop*>& loops,
      const IndexFromIdGraph& index_from_id_graph);

 public:
  // Indexing functions
  // Consumer = Producer
  // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
  // Producer indexing dispatch
  static kir::TensorIndex* getProducerIndex(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  // Consumer index dispatch
  static kir::TensorIndex* getConsumerIndex(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);
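
  // A minimal usage sketch (illustrative only): inside index lowering, when
  // visiting an expression whose output TensorView is `out_tv` and whose
  // input TensorView is `in_tv` (hypothetical names), under the current loop
  // nest `for_loops`:
  //
  //   kir::TensorIndex* in_index =
  //       Index::getProducerIndex(in_tv, out_tv, for_loops);
  //   kir::TensorIndex* out_index =
  //       Index::getConsumerIndex(out_tv, for_loops);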

  //! Returns a vector of strided indices mapped onto the (rfactor)
  //! root domain of a producer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing root domain.
  static std::vector<Val*> getProducerStridedIndices(
      TensorView* producer,
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns a vector of strided indices mapped onto the (rfactor)
  //! root domain of a consumer tensor. The size of the returned
  //! vector is guaranteed to be equal to the number of axes of the
  //! indexing root domain.
  static std::vector<Val*> getConsumerStridedIndices(
      const TensorView* consumer,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns the logical index linearized from a multi-dimensional address
  //! into a linear memory address of a consumer tensor. The returned index is
  //! intended to be used for the computation of some tensor factories, such
  //! as: arange and rand (for Philox pseudo random sequences)
  static std::vector<Val*> getLinearLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns a vector of logical indices mapped onto the (rfactor)
  //! root domain of a consumer tensor. The returned index is intended
  //! to be used for the computation of some tensor factories, such as:
  //! eye
  static std::vector<Val*> getPerDimLogicalIndex(
      TensorView* consumer_tv,
      const std::vector<kir::ForLoop*>& loops);

  //! Takes a consumer TensorView and a loop nest and generates predicates
  //! associated with the concrete roots of the loop nest. Returns a list of
  //! predicates and a list of concrete roots they're associated with. It
  //! is assumed that no predicate is required if index[i] is an index
  //! taken directly from a for loop. This will not catch all cases if we
  //! actually have static size information, for example:
  //!
  //! TV[I].split(4)
  //! would produce the code:
  //! for(i : I/4)
  //!   for(j : 4)
  //!     if( i * 4 + j < TV.size(0))
  //!       TV[i * 4 + j]...
  //!
  //! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
  //! need the predicate. This will be caught by canOmitPredicate in the
  //! predicate lowering.
  //!
  //! unswitch_or_vec_loop is the for loop from which to start the
  //! unswitch-like predicate. This is not a bool value because, if we have an
  //! unswitch loop with a vectorized loop inside, we only want to base the
  //! "unswitch"-like predicate on the vectorized loop.
  static std::vector<RootPredicateInfo> getReferenceRootPredicates(
      TensorView* consumer_tv,
      const std::vector<kir::ForLoop*>& loops,
      kir::ForLoop* unswitch_or_vec_loop,
      bool padding_predicate);
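
  // A minimal usage sketch (illustrative only): for a hypothetical consumer
  // `out_tv` under the loop nest `for_loops`, visit the start and stop
  // predicates of each root (or contiguously merged) domain. Combining them
  // into the guarding condition is left as a comment since it goes through
  // IR builder utilities not declared in this header.
  //
  //   const auto pred_info_vec = Index::getReferenceRootPredicates(
  //       out_tv, for_loops, /*unswitch_or_vec_loop=*/nullptr,
  //       /*padding_predicate=*/false);
  //   for (const auto& info : pred_info_vec) {
  //     // AND info.startPredicate() and info.stopPredicate() into the
  //     // predicate guarding the expression ...
  //   }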
};

// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
// TODO: could be cleaned up further.
std::pair<
    std::unordered_map<kir::ForLoop*, Val*>,
    std::unordered_set<kir::ForLoop*>>
indexMapFromTV(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    kir::ForLoop* alloc_loop,
    bool as_consumer,
    kir::ForLoop* double_buffer_loop = nullptr);
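
// A minimal usage sketch (illustrative only): `tv`, `loops`, and `alloc_loop`
// are hypothetical values available at the call site.
//
//   auto loop_index_info = indexMapFromTV(
//       tv, loops, alloc_loop, /*as_consumer=*/true);
//   const auto& loop_to_index_map = loop_index_info.first;
//   const auto& zero_loops = loop_index_info.second;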

//! Set "pragma unroll" as required for loops that the indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
    const TensorView* tv,
    kir::ForLoop* alloc_loop,
    const std::vector<kir::ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {});

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch