#include <contiguity.h>
#include <index_compute.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_index_compute.h>
#include <lower_magic_zero.h>
#include <lower_utils.h>
#include <lower_validation.h>
#include <transform_iter.h>

#include <utility>
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | IndexFromIdGraph::IndexFromIdGraph( |
17 | IndexCompute index_, |
18 | IndexCompute concrete_index_, |
19 | std::unordered_map<IterDomain*, Val*> initial_concrete_index_map_, |
20 | std::vector<IterDomain*> loop_domains_) |
21 | : index(index_), |
22 | concrete_index(concrete_index_), |
23 | initial_concrete_index_map(initial_concrete_index_map_), |
24 | resolved_loop_domains(loop_domains_) {} |
25 | |
26 | namespace { |
27 | |
28 | // Maps all producer domains to consumer with broadcast |
29 | // forwarding. Used to find the allocation position. |
30 | // TODO: should this be an ir_util ? Didn't seem to be |
31 | // used too much though. |
32 | std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer( |
33 | const TensorView* producer_tv, |
34 | const TensorView* consumer_tv) { |
35 | // This map has forwarded broadcast axes, it should only be used to compute |
36 | // the allocation position of the producer, and to figure out which producer |
37 | // indices are mapped to consumer trivial reductions. |
38 | std::unordered_map<IterDomain*, IterDomain*> p2c_alloc_map; |
39 | |
40 | // We want to replay producer as consumer instead of the other way around |
41 | // since consumer may have some broadcasted axes producer doesn't have |
42 | // merged into loops producer may use. If we did consumer as producer we |
43 | // wouldn't have this information in the mapping. |
44 | auto replay_PasC = BestEffortReplay::replayPasC( |
45 | producer_tv, |
46 | consumer_tv, |
47 | -1, |
48 | PairwiseRootDomainMap(producer_tv, consumer_tv)); |
49 | |
50 | // Grab consumer domain entries and reverse replay map. TODO: Maybe |
51 | // TransformReplay::replayPasC could return this map |
52 | for (auto id : consumer_tv->domain()->domain()) { |
53 | const auto& c2p_map = replay_PasC.getReplay(); |
54 | auto c2p_it = c2p_map.find(id); |
55 | if (c2p_it != c2p_map.end()) { |
56 | auto c_id = c2p_it->first; |
57 | auto p_id = c2p_it->second; |
58 | p2c_alloc_map[p_id] = c_id; |
59 | } |
60 | } |
61 | |
62 | return p2c_alloc_map; |
63 | } |
64 | |
65 | std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap( |
66 | const std::unordered_map<IterDomain*, IterDomain*>& map) { |
67 | std::unordered_map<IterDomain*, IterDomain*> inverted; |
68 | for (const auto& kv : map) { |
69 | bool inserted = inverted.emplace(kv.second, kv.first).second; |
70 | TORCH_INTERNAL_ASSERT( |
71 | inserted, |
72 | "Multiple mappings to the same value detected: " , |
73 | kv.second->toString()); |
74 | } |
75 | return inverted; |
76 | } |
77 | |
//! A struct to keep track of necessary parameters used in
//!  configuring index compute pass.
//! These parameters are needed to propagate the indexing from the leaf nodes of
//!  the TVs and loop nests to the TVs rfactor domain during
//!  index_compute.cpp::IndexCompute passes.
//! TODO:
//!  Would expect this list to become shorter over time,
//!  as more info can be determined holistically.
struct IndexingParameters {
  //! Initial binding of index math to concrete iterdomain ids,
  //!  from the loop nest analysis.
  //! Keyed by the exact-mapped concrete id of each loop domain.
  std::unordered_map<IterDomain*, Val*> initial_concrete_id_index;

  //! (Used in non-global indexing) the concrete iterdomains that
  //!  we want to skip or merge into contiguous indexing paths.
  std::unordered_set<IterDomain*> zero_domains;

  //! (Used in non-global indexing) the preferred path we would
  //!  be propagating contiguously merged indices backward.
  std::unordered_set<IterDomain*> preferred_concrete_ids;

  //! The inferred halo padded extents of the concrete iterdomains.
  std::unordered_map<IterDomain*, Val*> concrete_id_to_halo_extent;
};
102 | |
103 | // Initial loop index map for global producer or consumer case. |
104 | IndexingParameters getLinearIndexParameters( |
105 | const LoopIndexing& loop_indexing, |
106 | bool index_producer = false) { |
107 | IndexingParameters index_parameters; |
108 | |
109 | auto& loops = loop_indexing.loops(); |
110 | auto& loop_domain = loop_indexing.loopDomains(); |
111 | auto& loop_index_map = index_parameters.initial_concrete_id_index; |
112 | |
113 | for (auto loop_idx : c10::irange(loops.size())) { |
114 | auto loop = loops[loop_idx]; |
115 | auto index_domain = GpuLower::current()->caMap()->getConcreteMappedID( |
116 | loop_domain[loop_idx], IdMappingMode::EXACT); |
117 | if (loop->isTrivial()) { |
118 | // This is useful information in the case of |
119 | // MisalignedVectorize and double buffer epilog, etc. |
120 | loop_index_map[index_domain] = loop->start(); |
121 | } else { |
122 | // Default use pre-allocated integers for index |
123 | loop_index_map[index_domain] = loop->index(); |
124 | } |
125 | } |
126 | |
127 | // Derive the halo extents from the loop indexing result. |
128 | index_parameters.concrete_id_to_halo_extent = |
129 | GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( |
130 | loop_indexing); |
131 | |
132 | protectNonPredicateIndexWithMagicZero( |
133 | loops, |
134 | loop_indexing.loopDomains(), |
135 | index_parameters.initial_concrete_id_index); |
136 | |
137 | // Setup double buffer increment for producer case: |
138 | // TODO: could unify these double buffer index calculation |
139 | // in follow ups. |
140 | if (index_producer) { |
141 | auto double_buffer_loop = |
142 | GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( |
143 | loop_indexing.consumerTv(), loops, true); |
144 | |
145 | for (auto loop_idx : c10::irange(loops.size())) { |
146 | auto loop = loops[loop_idx]; |
147 | if (loop == double_buffer_loop) { |
148 | TORCH_INTERNAL_ASSERT( |
149 | !loop->isTrivial(), "The double buffer loop must be materialized" ); |
150 | |
151 | auto loop_id = loop_indexing.loopDomains()[loop_idx]; |
152 | |
153 | auto concrete_loop_id = |
154 | GpuLower::current()->caMap()->getConcreteMappedID( |
155 | loop_id, IdMappingMode::EXACT); |
156 | |
157 | auto stage_depth = |
158 | GpuLower::current()->doubleBufferInfo().getStageDepthFor( |
159 | loop->iter_domain()); |
160 | index_parameters.initial_concrete_id_index[concrete_loop_id] = |
161 | SimplifyingIrBuilder::addExpr( |
162 | index_parameters.initial_concrete_id_index[concrete_loop_id], |
163 | SimplifyingIrBuilder::create<Int>(stage_depth - 1)); |
164 | } |
165 | } |
166 | } |
167 | |
168 | return index_parameters; |
169 | } |
170 | |
171 | // Initial index parameters for shared and local case |
172 | IndexingParameters getNonGlobalInitialIndexParameters( |
173 | const LoopIndexing& loop_indexing, |
174 | const TensorView* consumer_tv, |
175 | bool index_producer = false, |
176 | const TensorView* producer_tv = nullptr, |
177 | std::unordered_map<IterDomain*, IterDomain*> p2c_map = {}) { |
178 | IndexingParameters index_parameters; |
179 | const auto& loops = loop_indexing.loops(); |
180 | const auto& loop_domains = loop_indexing.loopDomains(); |
181 | |
182 | // TODO: |
183 | // The non-global path should become shorter as we |
184 | // pull more info into id graph. |
185 | std::unordered_map<IterDomain*, IterDomain*> alloc_id_map; |
186 | |
187 | if (index_producer) { |
188 | alloc_id_map = mapAllProducerDomainsToConsumer(producer_tv, consumer_tv); |
189 | } |
190 | |
191 | auto alloc_tv = index_producer ? producer_tv : consumer_tv; |
192 | auto alloc_info = lower_utils::getAllocInformation( |
193 | alloc_tv, loops, alloc_id_map, index_producer); |
194 | |
195 | std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map; |
196 | std::unordered_set<kir::ForLoop*> zero_loops; |
197 | |
198 | kir::ForLoop* double_buffer_loop = nullptr; |
199 | |
200 | if (index_producer) { |
201 | double_buffer_loop = |
202 | GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( |
203 | consumer_tv, loops, true); |
204 | } |
205 | |
206 | std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( |
207 | alloc_tv, |
208 | loops, |
209 | alloc_info.init_for_loop, |
210 | !index_producer, |
211 | double_buffer_loop); |
212 | |
213 | ensureStaticIndexing(alloc_tv, alloc_info.init_for_loop, loops, alloc_id_map); |
214 | |
215 | TORCH_INTERNAL_ASSERT( |
216 | loops.size() <= loop_domains.size(), |
217 | "Loop domain didn't replay all loops" ); |
218 | |
219 | for (auto loop_idx : c10::irange(loops.size())) { |
220 | auto loop = loops[loop_idx]; |
221 | auto loop_domain = loop_domains[loop_idx]; |
222 | |
223 | auto concrete_loop_domain = |
224 | GpuLower::current()->caMap()->getConcreteMappedID( |
225 | loop_domain, IdMappingMode::EXACT); |
226 | |
227 | index_parameters.initial_concrete_id_index[concrete_loop_domain] = |
228 | loop_to_ind_map.at(loop); |
229 | |
230 | if (zero_loops.count(loop)) { |
231 | index_parameters.zero_domains.insert(concrete_loop_domain); |
232 | } |
233 | } |
234 | |
235 | // Derive preferred path from loop indexing result. |
236 | const TensorView* target_tv = index_producer ? producer_tv : consumer_tv; |
237 | index_parameters.preferred_concrete_ids = buildLoopIndexingPreferredPath( |
238 | target_tv, loop_indexing, index_producer, p2c_map); |
239 | |
240 | // Derive the halo extents from the loop indexing result. |
241 | index_parameters.concrete_id_to_halo_extent = |
242 | GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( |
243 | loop_indexing); |
244 | |
245 | return index_parameters; |
246 | } |
247 | |
248 | //! Initial index parameters for predicate, adjusts loop to indexing |
249 | //! may according to the information annotated on the loop nest. |
250 | //! |
251 | //! TODO: |
252 | //! This function is mostly copy pasted from previous implementation |
253 | //! at this step, further clean up is possible since: |
254 | //! 1. Much of the loop-to-ind adjustment will be issued from idgraph |
255 | //! 2. Much of the initial index logic could be shared across all |
256 | //! the 3 variants. |
257 | IndexingParameters getPredicateInitialIndexParameters( |
258 | const LoopIndexing& loop_indexing, |
259 | TensorView* consumer_tv, |
260 | kir::ForLoop* unswitch_or_vec_loop, |
261 | IterDomain* double_buffer_axis, |
262 | bool is_start_predicate) { |
263 | IndexingParameters index_parameters; |
264 | const auto& loops = loop_indexing.loops(); |
265 | const auto& loop_domains = loop_indexing.loopDomains(); |
266 | |
267 | // This shouldn't be needed. |
268 | TORCH_INTERNAL_ASSERT( |
269 | loops.size() <= loop_domains.size(), |
270 | "Loop domain didn't replay all loops" ); |
271 | |
272 | std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map; |
273 | |
274 | // Fill initial index with each forloop's index. |
275 | std::transform( |
276 | loops.begin(), |
277 | loops.end(), |
278 | std::inserter(loop_to_ind_map, loop_to_ind_map.begin()), |
279 | [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); }); |
280 | |
281 | // Generate unswitch loop to index map. |
282 | if (unswitch_or_vec_loop != nullptr) { |
283 | // Vectorized predicates are different from unswitch. Unswitch predicates |
284 | // all loops within the unswitch (the outer most unswitch) are generated |
285 | // with loop->extent-1 as the index. With vectorized predicates, only the |
286 | // vectorized loop should be like this. |
287 | bool vectorized_pred = |
288 | unswitch_or_vec_loop->iter_domain()->getParallelType() == |
289 | ParallelType::Vectorize; |
290 | |
291 | bool within_unswitch = false; |
292 | |
293 | for (const auto loop_i : c10::irange(loops.size())) { |
294 | auto loop = loops[loop_i]; |
295 | auto loop_id = loop->iter_domain(); |
296 | auto loop_pt = loop_id->getParallelType(); |
297 | auto ref_id = loop_domains.at(loop_i); |
298 | |
299 | if (loop == unswitch_or_vec_loop) { |
300 | within_unswitch = true; |
301 | } |
302 | |
303 | if (within_unswitch) { |
304 | // Rely on the reference to check broadcasting. The for loop could be |
305 | // broadcasted on a constant value from an unroll split. Since reference |
306 | // may convert this to an iter domain, that for loop could be valid to |
307 | // generate predication from. |
308 | |
309 | // Note that loop->stop() is not used below. Instead, |
310 | // loop->iter_domain()->extent() is used, which is uniform |
311 | // across the mapped domains irrespective of halo. Predicates are |
312 | // compared with each to pick the most restrictive ones. The |
313 | // comparison is done by only using the offset, which is the |
314 | // term added to the index. So, the index term must be the |
315 | // same among all predicates, otherwise the comparison would |
316 | // be invalid. The effect by halo is added to the offset |
317 | // term. See getUnswitchStopOffset. |
318 | |
319 | if (ref_id->isBroadcast()) { |
320 | // Ignore indexing into broadcasted dimensions. |
321 | continue; |
322 | } else if (loop_id->isThread()) { |
323 | // When parallelized, if the loop stop is the same as the |
324 | // extent of the associated IterDomain, i.e., no extra |
325 | // iterations for halo, predicating with the threading index |
326 | // is sufficient for both the start and stop |
327 | // predicates. That isn't the case if the loop has halo, and |
328 | // in the case either the minimum and maximum values of the |
329 | // iteration domain needs to be used. |
330 | // |
331 | // Note: Better performance was obtained if using |
332 | // threadIdx in unswitch predicates was avoided. More |
333 | // specifically, in the Hdiff stencil example, instead of |
334 | // predicating with threadIdx.x for both the start and stop |
335 | // predicates, using zero and (blockDim.x - 1) for the start |
336 | // and stop predicates, respectively, resulted in less |
337 | // register pressure. The alternative codegen can be done by |
338 | // adding this to the first if condition: |
339 | // loop_id->isBlockDim(). This would not be a concern if the |
340 | // else part could be omitted, so canOmitElseClause should |
341 | // be used as well. |
342 | if (loop->stop() == loop_id->extent()) { |
343 | loop_to_ind_map[loop] = loop->start(); |
344 | } else if (is_start_predicate) { |
345 | loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); |
346 | } else { |
347 | // Note that the parallel dimension is used rather than |
348 | // loop-stop(). See the above comment. |
349 | loop_to_ind_map[loop] = |
350 | GpuLower::current()->parallelDimensionMap().get(loop_pt); |
351 | } |
352 | } else if (is_start_predicate) { |
353 | loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); |
354 | } else { |
355 | // Similar to the above, loop_id()->extent() is |
356 | // used here instead of loop->stop(). See the above comment. |
357 | loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( |
358 | loop_id->extent(), GpuLower::current()->kernel()->oneVal()); |
359 | } |
360 | } |
361 | |
362 | // If a vectorized predicate, bail after the vectorized loop was found. |
363 | // Don't continue unswitching loops. |
364 | if (vectorized_pred && within_unswitch) { |
365 | break; |
366 | } |
367 | } |
368 | } |
369 | |
370 | // Modify trivial loops to use the loop start value. |
371 | // FIXME: eventually should be all lifted in idgraph. |
372 | for (const auto loop : loops) { |
373 | auto& idx = loop_to_ind_map.at(loop); |
374 | // If the loop is trivial, the loop index can only be the loop |
375 | // start value. |
376 | if (idx == loop->index() && loop->isTrivial()) { |
377 | idx = loop->start(); |
378 | } |
379 | } |
380 | |
381 | // Increment double buffer loop index |
382 | if (double_buffer_axis != nullptr) { |
383 | auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( |
384 | double_buffer_axis, loops, true); |
385 | if (db_loop != nullptr) { |
386 | auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop); |
387 | TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end()); |
388 | auto cur_index = loop_to_ind_map_it->second; |
389 | // if cur_index is not the same as the index of db_loop, it must |
390 | // be true that that index has been modified to support |
391 | // unswitch. In that case, it is not necessary to move ahead the |
392 | // index for double buffering. |
393 | auto stage_depth = |
394 | GpuLower::current()->doubleBufferInfo().getStageDepthFor( |
395 | db_loop->iter_domain()); |
396 | if (cur_index == db_loop->index()) { |
397 | loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr( |
398 | cur_index, SimplifyingIrBuilder::create<Int>(stage_depth - 1)); |
399 | } |
400 | } |
401 | } |
402 | |
403 | // Convert loop-to-ind map to concrete-to-ind map |
404 | for (int loop_idx : c10::irange(loops.size())) { |
405 | auto loop = loops.at(loop_idx); |
406 | auto concrete_loop_domain = |
407 | GpuLower::current()->caMap()->getConcreteMappedID( |
408 | loop_domains.at(loop_idx), IdMappingMode::EXACT); |
409 | index_parameters.initial_concrete_id_index[concrete_loop_domain] = |
410 | loop_to_ind_map.at(loop); |
411 | } |
412 | |
413 | // Note that, unlike non-predicate indexing, magic-zero insertion is |
414 | // not done at this point but is done individually for each indexed |
415 | // domain. See Index::getReferenceRootPredicates. |
416 | |
417 | // Derive the halo extents from the loop indexing result. |
418 | index_parameters.concrete_id_to_halo_extent = |
419 | GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( |
420 | loop_indexing); |
421 | |
422 | return index_parameters; |
423 | } |
424 | |
425 | } // namespace |
426 | |
//! Analysis pass that, given a loop nest and the consumer TensorView
//!  being indexed within it, selects the iterdomain expressions needed
//!  to compute indices for those loops and packages the result as a
//!  LoopIndexing. Traversal is done on exact-mapped concrete ids so
//!  each disjoint set is consumed/produced exactly once.
class LoopIndexingAnalysis {
 public:
  //! Entry point: runs the analysis on the given loops/consumer and
  //!  returns the populated LoopIndexing.
  static LoopIndexing fromLoopAndConsumer(
      const std::vector<kir::ForLoop*>& loops,
      const TensorView* consumer_tv) {
    LoopIndexingAnalysis analysis(loops, consumer_tv);
    return analysis.getLoopIndexing();
  }

 private:
  explicit LoopIndexingAnalysis(
      const std::vector<kir::ForLoop*>& loops,
      const TensorView* consumer_tv);

  //! Populate derived information into a LoopIndexing
  //!  data structure.
  LoopIndexing getLoopIndexing() {
    LoopIndexing indexing;
    indexing.loops_ = loops_;
    indexing.consumer_tv_ = consumer_tv_;
    indexing.loop_root_ = loop_root_domains_;
    indexing.loop_domains_ = loop_domains_.vector();
    indexing.index_exprs_ = replayed_exprs_;
    indexing.out_of_line_exprs_ = out_of_line_exprs_;
    return indexing;
  }

  //! Validates that the current loop structure is well formed, in the sense
  //!  that ca_map would not map any two loops in the loop nest together.
  void validateLoopStructure(const std::vector<kir::ForLoop*>& loops);

  //! Start at the loop iter domains, and traverse back into history on the
  //!  concrete IDs in the exact map calling "visitExpr" expressions through
  //!  the history.
  void traverseFromDomainVals();

  //! Concretize the given iterdomain and record the visit (in deterministic
  //!  order) in terms of the exact mapped concrete id. Marks the mapping of
  //!  the id to the concrete id in "concrete_to_original_id_" and returns the
  //!  concrete id.
  IterDomain* concretizeAndVisitId(IterDomain* id);

  //! If an equivalent expression has already been processed this function
  //!  simply returns. Otherwise puts the exact concrete IDs of inputs in
  //!  consumed_concrete_, and concrete IDs of outputs in produced_concrete_.
  //!  Then adds the expression to replayed_exprs_.
  void visitExpr(Expr* expr);

  //! Iterates through provided vals, calls concretizeAndVisitId on them, and
  //!  returns if any of the returned vals are in existing_ids. This is used to
  //!  check if inputs or outputs of ID expressions have already been
  //!  produced/consumed in the traversal. Indexing only needs to
  //!  consume/produce one IterDomain per exact disjoint set.
  bool visitIdsAndCheckDuplication(
      const std::vector<Val*>& vals,
      const std::unordered_set<IterDomain*>& existing_ids);

  //! Fills loop_domains_ with the corresponding replayed_concrete_id mapping
  //!  to the provided loops. Must be done after the exact iterdomain "replay"
  //!  (traverseFromDomainVals). loop_domains_ are the original_id not the
  //!  concrete_id (translated with concrete_to_original_id). These iter
  //!  domains are used to grab the history that will be replayed in
  //!  IndexCompute. We're looking for "new" root domains and subsequent
  //!  transformations, filling in any missing "outputs" (or inputs for
  //!  backward traversal). Then fills loop_domains_ with all of these iter
  //!  domains.
  void constructLoopDomains();

  //! Fills out_of_line_exprs_ by traversing the selected list of
  //!  expressions in reverse topological order and collect iterdomains
  //!  on the indexing paths that only involves leaf id's on the right
  //!  of consumer's ca axis.
  void collectOutOfLineExprs();

 private:
  //! Original loop nest input to derive info from.
  const std::vector<kir::ForLoop*>& loops_;

  //! Original consumer tv to derive view info from.
  const TensorView* consumer_tv_ = nullptr;

  // Exact concrete domains that has been used
  //  in the traversal connection.
  std::unordered_set<IterDomain*> produced_concrete_;
  std::unordered_set<IterDomain*> consumed_concrete_;

  //! Iterdomains that the corresponding loops are generated from.
  std::vector<IterDomain*> initial_loop_domain_ids_;

  //! All Id's in consumer's transform history
  std::vector<Val*> all_consumer_id_vals_;

  //! Concrete iterdomains visited in the domain traversal,
  //!  in the order they are visited in traverseFromDomainVals.
  VectorOfUniqueEntries<IterDomain*> replayed_concrete_ids_;

  //! Keeping track of the original visited id's before they
  //!  were concretized.
  std::unordered_map<IterDomain*, IterDomain*> concrete_to_original_id_;

  //! Map from concrete id to its single consumer on the selected
  //!  iterdomain expression list.
  std::unordered_map<IterDomain*, Expr*> concrete_id_to_consumer_;

  //! Source domains that all the Iterdomain transforms
  //!  in the loop nest originated from.
  std::vector<IterDomain*> loop_root_domains_;

  //! Leaf domains representing the original loop structure
  VectorOfUniqueEntries<IterDomain*> loop_domains_;

  //! Selected list of exprs that will produce and consume each
  //!  of the exact concrete ids from the loop nest exactly once.
  std::vector<Expr*> replayed_exprs_;

  //! Set of expressions from the selected list that can be
  //!  resolved from axes on the right of ca axes.
  std::vector<Expr*> out_of_line_exprs_;
};
545 | |
//! Runs the full analysis pipeline in order: validate the loop nest,
//!  collect consumer ids, replay the exact transform history, build the
//!  concrete-to-consumer map, reconstruct the loop domains, and collect
//!  out-of-line expressions. The stages depend on each other's results,
//!  so the call order below must be preserved.
LoopIndexingAnalysis::LoopIndexingAnalysis(
    const std::vector<kir::ForLoop*>& loops,
    const TensorView* consumer_tv)
    : loops_(loops), consumer_tv_(consumer_tv) {
  // Validate consistency in given loop nest
  validateLoopStructure(loops);

  // Populate initial loop iter domains.
  std::transform(
      loops.begin(),
      loops.end(),
      std::back_inserter(initial_loop_domain_ids_),
      [](kir::ForLoop* fl) { return fl->iter_domain(); });

  // Collect consumer id's for view rfactor traversal.
  // (all vals on any path from the consumer's root domain to its leaf
  //  domain)
  all_consumer_id_vals_ = DependencyCheck::getAllValsBetween(
      {consumer_tv->getRootDomain().begin(),
       consumer_tv->getRootDomain().end()},
      {consumer_tv->domain()->domain().begin(),
       consumer_tv->domain()->domain().end()});

  // Resolve definition of each exact concrete id's involved in the whole loop
  //  nest transform history
  traverseFromDomainVals();

  // Construct concrete to consumer map. The replayed exprs are guaranteed to
  //  consume each concrete id once so this map is well defined.
  for (auto expr : replayed_exprs_) {
    for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
      auto concrete_input_id =
          GpuLower::current()->caMap()->getConcreteMappedID(
              input_id, IdMappingMode::EXACT);
      concrete_id_to_consumer_[concrete_input_id] = expr;
    }
  }

  // Reconstruct the iterdomain view of the original loopnest after resolving
  //  the exact definition of each index.
  constructLoopDomains();

  //! Collect the set of indexing expressions that can be
  //!  resolved out of line.
  collectOutOfLineExprs();
}
590 | |
591 | void LoopIndexingAnalysis::validateLoopStructure( |
592 | const std::vector<kir::ForLoop*>& loops) { |
593 | // Throw an error when two loops are mapped with each other, which |
594 | // violates an assumption that unique mappings between concrete |
595 | // IterDomains and the IterDomains of the loop structure must be |
596 | // established. It should be a reasonable assumption, but fusions |
597 | // like below won't work: |
598 | // tv0 = [I0] |
599 | // tv1 = broadcast(tv0, {true, false}); |
600 | // tv2 = broadcast(tv0, {false, true}); |
601 | // tv3 = tv1 + tv2 |
602 | // Notice that the two axes of each of tv1, tv2 and tv3 are mapped |
603 | // with each other. We believe it is unlikely this limitation |
604 | // becomes a real concern in practice. |
605 | // Map concrete id to the original loop iter domain. |
606 | std::unordered_map<IterDomain*, IterDomain*> concrete_to_loop; |
607 | for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) { |
608 | // Largely duplicating original logic |
609 | auto loop_id = (*it_i)->iter_domain(); |
610 | auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( |
611 | loop_id, IdMappingMode::EXACT); |
612 | |
613 | TORCH_INTERNAL_ASSERT( |
614 | !concrete_to_loop.count(concrete_loop_id), |
615 | "Unsupported loop structure. Two loops are mapped together." , |
616 | loop_id->toString(), |
617 | " and " , |
618 | concrete_to_loop.at(concrete_loop_id)->toString()); |
619 | |
620 | concrete_to_loop[concrete_loop_id] = loop_id; |
621 | } |
622 | } |
623 | |
//! Backward traversal from the loop iter domains through their transform
//!  history, visiting each definition expression exactly once via
//!  visitExpr. Uses a deque processed front-first so outer loops (and
//!  rfactor ids) are resolved before inner ones.
void LoopIndexingAnalysis::traverseFromDomainVals() {
  // Order is really important here, start with outer most for loops in a
  // depth first manner. The outer most loops are topologically closer to the
  // outputs, so their broadcast dimensions are "more" resolved than those
  // towards the inner most loops.
  std::deque<IterDomain*> to_visit(
      initial_loop_domain_ids_.begin(), initial_loop_domain_ids_.end());
  std::unordered_set<Expr*> visited_exprs;
  std::unordered_set<IterDomain*> visited_ids;

  while (!to_visit.empty()) {
    auto out_id = to_visit.front();
    to_visit.pop_front();

    // Each id is processed at most once.
    if (!visited_ids.emplace(out_id).second) {
      continue;
    }
    auto expr = out_id->definition();

    // If this id maps to an rfactor id on the consumer, traverse that id
    //  first so view/rfactor history is resolved ahead of this id.
    if (auto rfactor_id =
            getRfactorIDToTraverse(out_id, all_consumer_id_vals_)) {
      to_visit.emplace_front(rfactor_id);
    }

    // ID's will be copied for the reference as we replay transformations. If
    // there was no transformations on an iteration domain, a copy of the
    // iteration domain for the reference is made here.
    if (expr == nullptr) {
      // Only record untransformed ids that came directly from the loop
      //  nest; other root ids are reached through their consumers.
      if (std::find(
              initial_loop_domain_ids_.begin(),
              initial_loop_domain_ids_.end(),
              out_id) != initial_loop_domain_ids_.end()) {
        concretizeAndVisitId(out_id);
      }
      continue;
    }

    // Each defining expression is visited at most once.
    if (!visited_exprs.emplace(expr).second) {
      continue;
    }

    visitExpr(expr);

    auto inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
    // Make sure to put at the beginning of the deque to maintain correct
    // ordering.
    to_visit.insert(to_visit.begin(), inp_ids.begin(), inp_ids.end());
  }
}
673 | |
674 | IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) { |
675 | auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( |
676 | id, IdMappingMode::EXACT); |
677 | if (replayed_concrete_ids_.pushBack(concrete_id)) { |
678 | concrete_to_original_id_[concrete_id] = id; |
679 | } |
680 | return concrete_id; |
681 | } |
682 | |
683 | namespace { |
684 | // Alias used for std::transform |
685 | IterDomain* exactConcreteId(IterDomain* id) { |
686 | return GpuLower::current()->caMap()->getConcreteMappedID( |
687 | id, IdMappingMode::EXACT); |
688 | } |
689 | } // namespace |
690 | |
691 | void LoopIndexingAnalysis::visitExpr(Expr* expr) { |
692 | if (auto swizzle2d = dynamic_cast<Swizzle2D*>(expr)) { |
693 | // Swizzle outputs are already forwarded through |
694 | // by exact CA map, so currently they are just |
695 | // ignored in the replay pass except |
696 | // that we want to note this node visited. |
697 | concretizeAndVisitId(swizzle2d->outX()); |
698 | concretizeAndVisitId(swizzle2d->outY()); |
699 | return; |
700 | } |
701 | |
702 | // Current implementation just tries to |
703 | // follow the exact behavior of reference replay |
704 | // except that no expr was actually "replayed". |
705 | |
706 | // Record all inputs, and stop if current expr |
707 | // duplicates id consumption or production. |
708 | if (visitIdsAndCheckDuplication(expr->inputs(), consumed_concrete_)) { |
709 | return; |
710 | } |
711 | if (visitIdsAndCheckDuplication(expr->outputs(), produced_concrete_)) { |
712 | return; |
713 | } |
714 | |
715 | // Record the expr if no duplication on input or output found |
716 | replayed_exprs_.push_back(expr); |
717 | |
718 | // Record the consumed and produced concrete ids by the newly |
719 | // recorded expression. |
720 | auto consumed_ids = ir_utils::filterByType<IterDomain>(expr->inputs()); |
721 | std::transform( |
722 | consumed_ids.begin(), |
723 | consumed_ids.end(), |
724 | std::inserter(consumed_concrete_, consumed_concrete_.end()), |
725 | exactConcreteId); |
726 | |
727 | auto produced_ids = ir_utils::filterByType<IterDomain>(expr->outputs()); |
728 | std::transform( |
729 | produced_ids.begin(), |
730 | produced_ids.end(), |
731 | std::inserter(produced_concrete_, produced_concrete_.end()), |
732 | exactConcreteId); |
733 | } |
734 | |
735 | bool LoopIndexingAnalysis::visitIdsAndCheckDuplication( |
736 | const std::vector<Val*>& vals, |
737 | const std::unordered_set<IterDomain*>& existing_ids) { |
738 | bool duplication = false; |
739 | for (auto id : ir_utils::filterByType<IterDomain>(vals)) { |
740 | duplication = duplication || existing_ids.count(concretizeAndVisitId(id)); |
741 | } |
742 | return duplication; |
743 | } |
744 | |
void LoopIndexingAnalysis::constructLoopDomains() {
  // For each id from the original loop nest, find the replayed
  //  concrete id that represents it, and record the original
  //  (pre-concretization) id it came from as a loop domain.
  for (auto loop_id : initial_loop_domain_ids_) {
    // Find the replayed_concrete_id mapping to the loop id.
    auto ref_id_it = std::find_if(
        replayed_concrete_ids_.vector().begin(),
        replayed_concrete_ids_.vector().end(),
        [&](IterDomain* concrete_id) {
          return
              // Make sure the replayed_concrete_id is a leaf ID
              !concrete_id_to_consumer_.count(concrete_id) &&
              // Use permissive map so the selected ID indeed represents the
              // loop.
              // Note: see PR https://github.com/csarofeen/pytorch/pull/1960
              // and issue https://github.com/csarofeen/pytorch/issues/1873
              //
              // This mapping look-up is part of a staged indexing scheme.
              // When we find a replayed exact id that exactly maps to the
              // loop id, the indexing involved in this loop can be resolved
              // "locally", i.e. with only the iterdomains on the given
              // consumer tv.
              //
              // When we cannot find an exact mapping, the permissive mapping
              // helps deferring the indexing resolution for this loop nest
              // level to other iterdomain expressions from tv's that are
              // further concretized, usually further down the consumer chain
              // of the given consumer tv.
              //
              // Intuitively, exact mapping of two iterdomains should imply
              // permissive mapping of them as well, and if that were the
              // case, only looking up the permissive mapping would be enough
              // to address both of the cases above.
              // FIXME: But currently exact mapping does not imply permissive
              // mapping (See issue:
              // https://github.com/csarofeen/pytorch/issues/1963)
              // Which means we should check both exact and permissive mapping
              // here.
              (GpuLower::current()->caMap()->areMapped(
                   concrete_id, loop_id, IdMappingMode::EXACT) ||
               GpuLower::current()->caMap()->areMapped(
                   concrete_id, loop_id, IdMappingMode::PERMISSIVE));
        });

    TORCH_INTERNAL_ASSERT(
        ref_id_it != replayed_concrete_ids_.vector().end(),
        "Could not find required iter domain in reference replay: " ,
        loop_id->toString());

    // Map the matched concrete id back to the original id it was
    //  concretized from, and record it as a loop domain.
    auto ref_id = *ref_id_it;
    loop_domains_.pushBack(concrete_to_original_id_.at(ref_id));
  }

  // Construct the root domain as the inputs of the replayed domain
  auto loops_replayed_domain_vals =
      ir_utils::filterByType<Val>(loop_domains_.vector());
  auto root_domain_vals = IterVisitor::getInputsTo(
      {loops_replayed_domain_vals.begin(), loops_replayed_domain_vals.end()});

  // Fill loop roots:
  auto root_domain_ids = ir_utils::filterByType<IterDomain>(root_domain_vals);
  loop_root_domains_ =
      std::vector<IterDomain*>(root_domain_ids.begin(), root_domain_ids.end());

  // The domain may have dangling iteration domains, i.e. the inner output of
  // a split but not the outer. Find which replayed vals are dependant on the
  // root domains.
  auto all_replayed_vals =
      ir_utils::filterByType<Val>(replayed_concrete_ids_.vector());
  auto all_ids_from_root = DependencyCheck::getAllValsBetween(
      {root_domain_vals.begin(), root_domain_vals.end()},
      {all_replayed_vals.begin(), all_replayed_vals.end()});

  // Fill all dangling outputs as otherwise backwards visitor in index compute
  // will complain for not having all outputs of the traversal.
  for (auto id : ir_utils::filterByType<IterDomain>(all_ids_from_root)) {
    if (id->uses().empty()) {
      loop_domains_.pushBack(GpuLower::current()->caMap()->getConcreteMappedID(
          id, IdMappingMode::EXACT));
    }
  }
}
827 | |
// Computes tensor indexing math from the loop nest structure.
// Indexes producer_tv when it is non-null, otherwise consumer_tv.
// Two backward traversals are run: the first generates index math on
// the concrete ids selected from the loops, the second maps that math
// back onto the target tensor's own iterdomains.
IndexFromIdGraph getTensorIndexFromIdGraph(
    const std::vector<kir::ForLoop*>& loops,
    const TensorView* consumer_tv,
    const TensorView* producer_tv,
    bool is_global,
    std::unordered_map<IterDomain*, IterDomain*> c2p_map) {
  // Producer indexing is requested by passing a non-null producer_tv.
  bool index_producer = producer_tv != nullptr;
  auto target_tv = index_producer ? producer_tv : consumer_tv;

  // Run replay pass on the loop nest to generate the deterministic
  //  traversal info from the loop structure.
  auto loop_indexing =
      LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);

  IndexingParameters index_parameters;

  std::unordered_map<IterDomain*, IterDomain*> p2c_map;

  // The p2c map is only needed when indexing producer
  //  as producer has replayed ids.
  if (index_producer) {
    p2c_map = invertOneToOneMap(c2p_map);
  }

  // Bind initial index variables; global tensors and other memory
  //  types use different parameter generation paths.
  if (is_global) {
    index_parameters = getLinearIndexParameters(loop_indexing, index_producer);
  } else {
    index_parameters = getNonGlobalInitialIndexParameters(
        loop_indexing, consumer_tv, index_producer, producer_tv, p2c_map);
  }

  IndexCompute indexing(
      index_parameters.initial_concrete_id_index,
      index_parameters.zero_domains,
      index_parameters.preferred_concrete_ids,
      index_parameters.concrete_id_to_halo_extent);

  // Run first backward traversal to generate
  //  loop nest based indexing math.
  indexing.run(loop_indexing);

  // Populate indexing through exact map from initial indexing.
  // NOTE(review): when indexing a producer, traversal starts from the
  //  consumer's root domain instead of its rfactor domain — presumably
  //  because the producer replay is rooted there; confirm against the
  //  replay logic.
  auto consumer_root = index_producer ? consumer_tv->getRootDomain()
                                      : consumer_tv->getMaybeRFactorDomain();

  // First collect all iterdomains in consumer transform history.
  auto all_consumer_vals = DependencyCheck::getAllValsBetween(
      {consumer_root.begin(), consumer_root.end()},
      {consumer_tv->domain()->domain().begin(),
       consumer_tv->domain()->domain().end()});

  // Indexable domains are the concrete id's we visited when
  //  traversing the "reference" indexing pass.
  // NOTE(review): initial_indexable_map is populated below but never
  //  read within this function — candidate for removal.
  std::unordered_map<IterDomain*, IterDomain*> initial_indexable_map;

  // Map the concrete id indexing back to the producer or consumer tv
  std::unordered_map<IterDomain*, IterDomain*> index_update_map;

  for (IterDomain* consumer_id :
       ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
    // Track the non-concrete id we were trying to bind index
    //  to, whether from producer or consumer.
    auto target_id = consumer_id;

    // Use the mapped producer id when indexing producer.
    if (index_producer) {
      auto target_id_it = c2p_map.find(consumer_id);
      if (target_id_it == c2p_map.end()) {
        // Consumer id not found in c2p map;
        //  skip binding for this id.
        continue;
      }
      target_id = target_id_it->second;
    }

    // Exact id will have to be pulled from consumer side as the
    //  producer side are replayed ids.
    auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
        consumer_id, IdMappingMode::EXACT);

    index_update_map[exact_concrete_id] = target_id;

    // Keep track of concrete id's that were used for indexing.
    if (indexing.indexMap().count(exact_concrete_id)) {
      initial_indexable_map[exact_concrete_id] = exact_concrete_id;
    }
  }

  // No contig indexing was done in reference indexing; compute
  //  contiguous-id info for the target tensor here.
  ContigIDs contig_finder(
      target_tv->domain()->domain(),
      target_tv->getMaybeRFactorDomain(),
      target_tv->domain()->contiguity(),
      {},
      indexing.indexMap(),
      GpuLower::current()->divisbleSplitSet(),
      GpuLower::current()->caMap(),
      GpuLower::current()->haloInfo(),
      GpuLower::current()->concretizedBroadcastDomains(),
      p2c_map);

  // Run second backward traversal to map the concrete-id index math
  //  back onto the target tensor's own domain.
  auto target_indexing = indexing.updateIndexCompute(
      target_tv->domain(), index_update_map, contig_finder);

  // Fill validation info.
  // TODO: cleanup seems possible.
  if (index_producer) {
    fillProducerVectorizedContigRootDomains(
        producer_tv, consumer_tv, c2p_map, contig_finder);
  } else {
    fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder);
  }

  return IndexFromIdGraph(
      target_indexing,
      indexing,
      index_parameters.initial_concrete_id_index,
      loop_indexing.loopDomains());
}
945 | |
// Computes indexing math used for predicate generation. Similar to
// getTensorIndexFromIdGraph, but the initial indices are adjusted for
// unswitch/vectorize loops and double buffering, and no contiguity
// information is used (contiguity is handled later by the predicate
// generation logic itself).
IndexFromIdGraph getPredicateIndexingFromIdGraph(
    const std::vector<kir::ForLoop*>& loops,
    TensorView* consumer_tv,
    kir::ForLoop* unswitch_or_vec_loop,
    IterDomain* double_buffer_axis,
    bool is_start_predicate) {
  // Run replay pass on the loop nest to generate the deterministic
  // traversal info from loop structure.
  auto loop_indexing =
      LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);

  // Bind initial index variables to the loop nodes and adjust
  // according to loop and unswitch info.
  auto index_parameters = getPredicateInitialIndexParameters(
      loop_indexing,
      consumer_tv,
      unswitch_or_vec_loop,
      double_buffer_axis,
      is_start_predicate);

  // Run first backward traversal to generate
  // loop nest based indexing math.
  IndexCompute indexing(
      index_parameters.initial_concrete_id_index,
      index_parameters.zero_domains,
      index_parameters.preferred_concrete_ids,
      index_parameters.concrete_id_to_halo_extent);

  indexing.run(loop_indexing);

  // Map the concrete id indexing back to consumer tv
  std::unordered_map<IterDomain*, IterDomain*> index_update_map;

  // First collect all iterdomains in consumer transform history.
  auto all_consumer_vals = DependencyCheck::getAllValsBetween(
      {consumer_tv->getMaybeRFactorDomain().begin(),
       consumer_tv->getMaybeRFactorDomain().end()},
      {consumer_tv->domain()->domain().begin(),
       consumer_tv->domain()->domain().end()});

  for (IterDomain* consumer_id :
       ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
    // Bind each exact concrete id back to the consumer's own id so
    //  the second traversal produces indices on consumer iterdomains.
    auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
        consumer_id, IdMappingMode::EXACT);
    index_update_map[exact_concrete_id] = consumer_id;
  }

  // No contiguity info is used in the predicate indexing pass, the predicate
  // generation logic that uses the index math generated here will take
  // contiguity into account. Send an empty ContigID class so nothing is marked
  // as contiguous.
  auto contig_finder = ContigIDs::getNonContigIDs();

  // Run second backward traversal to map back to the consumer_tv
  auto target_indexing = indexing.updateIndexCompute(
      consumer_tv->domain(), index_update_map, contig_finder);

  return IndexFromIdGraph(
      target_indexing,
      indexing,
      index_parameters.initial_concrete_id_index,
      loop_indexing.loopDomains());
}
1011 | |
1012 | namespace { |
1013 | |
// Topologically sorts a list of iterdomain expressions, either
// forward (an expr appears after the exprs defining its inputs) or
// backward, using the exact compute-at map to identify ids.
class LoopIndexingTraversal {
  enum class TraversalOrder { ForwardTopological, BackwardTopological };

 public:
  // Returns exprs sorted in forward topological order.
  static std::vector<Expr*> forwardTopologicalOrder(
      const std::vector<Expr*>& exprs) {
    LoopIndexingTraversal traversal(exprs, TraversalOrder::ForwardTopological);
    return traversal.getExprList();
  }

  // Returns exprs sorted in backward topological order.
  static std::vector<Expr*> backwardTopologicalOrder(
      const std::vector<Expr*>& exprs) {
    LoopIndexingTraversal traversal(exprs, TraversalOrder::BackwardTopological);
    return traversal.getExprList();
  }

 private:
  explicit LoopIndexingTraversal(
      const std::vector<Expr*>& exprs,
      TraversalOrder traversal_order);

  // Returns the vals following the expression in either
  // forward or backward order.
  const std::vector<Val*>& nextValsInTraversalOrder(Expr* expr);

  // Returns the vals that the expression follows in either
  // forward or backward order.
  const std::vector<Val*>& prevValsInTraversalOrder(Expr* expr);

  // Returns the sorted list according to the given traversal order.
  std::vector<Expr*> getExprList();

 private:
  // Reference to original un-sorted expression list.
  const std::vector<Expr*>& exprs_;

  // The traversal order in this pass.
  const TraversalOrder traversal_order_ = TraversalOrder::ForwardTopological;

  // Internal record of concrete id's and its corresponding
  // iterdomain expression that defines the exact index.
  std::unordered_map<IterDomain*, Expr*> concrete_id_to_dependency_;
};
1057 | |
1058 | LoopIndexingTraversal::LoopIndexingTraversal( |
1059 | const std::vector<Expr*>& exprs, |
1060 | TraversalOrder traversal_order) |
1061 | : exprs_(exprs), traversal_order_(traversal_order) { |
1062 | // Populate concrete id dependencies: |
1063 | for (auto expr : exprs_) { |
1064 | auto next_ids = |
1065 | ir_utils::filterByType<IterDomain>(nextValsInTraversalOrder(expr)); |
1066 | for (auto id : next_ids) { |
1067 | auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( |
1068 | id, IdMappingMode::EXACT); |
1069 | TORCH_INTERNAL_ASSERT( |
1070 | concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr)) |
1071 | .second, |
1072 | "Repeated dependency, invalid iterdomain traversal." ); |
1073 | } |
1074 | } |
1075 | } |
1076 | |
1077 | const std::vector<Val*>& LoopIndexingTraversal::nextValsInTraversalOrder( |
1078 | Expr* expr) { |
1079 | switch (traversal_order_) { |
1080 | case TraversalOrder::ForwardTopological: |
1081 | return expr->outputs(); |
1082 | break; |
1083 | case TraversalOrder::BackwardTopological: |
1084 | return expr->inputs(); |
1085 | break; |
1086 | |
1087 | default: |
1088 | TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order" ); |
1089 | } |
1090 | return expr->inputs(); |
1091 | } |
1092 | |
1093 | const std::vector<Val*>& LoopIndexingTraversal::prevValsInTraversalOrder( |
1094 | Expr* expr) { |
1095 | switch (traversal_order_) { |
1096 | case TraversalOrder::ForwardTopological: |
1097 | return expr->inputs(); |
1098 | break; |
1099 | case TraversalOrder::BackwardTopological: |
1100 | return expr->outputs(); |
1101 | break; |
1102 | |
1103 | default: |
1104 | TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order" ); |
1105 | } |
1106 | return expr->inputs(); |
1107 | } |
1108 | |
// Work-list topological sort: an expr is emitted once every expr that
// defines one of its prerequisite ids (in traversal order) has been
// emitted; unready exprs are deferred by pushing the prerequisite on
// the front of the work list.
std::vector<Expr*> LoopIndexingTraversal::getExprList() {
  std::deque<Expr*> to_visit(exprs_.begin(), exprs_.end());

  // Pre-allocate result space.
  std::vector<Expr*> result;
  result.reserve(exprs_.size());

  // Keeps track of visited and inserted expressions.
  // An expr is visited if it has been placed in result list.
  // An expr is inserted if the traversal has put the expr on
  // the top of the stack once. Repeated insertion of the same
  // expression would never be observed if the underlying
  // dependency of the expressions is cycle free.
  std::unordered_set<Expr*> visited, inserted;

  while (!to_visit.empty()) {
    auto top = to_visit.front();
    if (visited.count(top)) {
      to_visit.pop_front();
      continue;
    }

    bool ready = true;

    for (auto prev_id :
         ir_utils::filterByType<IterDomain>(prevValsInTraversalOrder(top))) {
      auto prev_expr_it = concrete_id_to_dependency_.find(
          GpuLower::current()->caMap()->getConcreteMappedID(
              prev_id, IdMappingMode::EXACT));
      if (prev_expr_it != concrete_id_to_dependency_.end()) {
        auto prev_expr = prev_expr_it->second;
        if (!visited.count(prev_expr)) {
          // A prerequisite has not been emitted yet: defer `top` and
          //  process the prerequisite first. The insert assert fires
          //  if the same expr gets deferred-to twice, which would
          //  indicate a dependency cycle.
          ready = false;
          to_visit.push_front(prev_expr);
          TORCH_INTERNAL_ASSERT(
              inserted.insert(prev_expr).second,
              "Circular dependency in loop index expressions." );
          break;
        }
      }
    }

    if (ready) {
      // All prerequisites emitted; place `top` in the result.
      visited.insert(top);
      result.emplace_back(top);
      to_visit.pop_front();
    }
  }

  return result;
}
1160 | |
1161 | } // namespace |
1162 | |
1163 | void LoopIndexingAnalysis::collectOutOfLineExprs() { |
1164 | // Keep track of all the id's that can be resolved without |
1165 | // iterdomains on the left of ca axes. |
1166 | std::unordered_set<IterDomain*> out_of_line_ids; |
1167 | |
1168 | // Start the set with all the leaf ids. |
1169 | std::transform( |
1170 | consumer_tv_->domain()->domain().begin() + |
1171 | consumer_tv_->getComputeAtPosition(), |
1172 | consumer_tv_->domain()->domain().end(), |
1173 | std::inserter(out_of_line_ids, out_of_line_ids.end()), |
1174 | exactConcreteId); |
1175 | |
1176 | // Get the original selected list of index expressions |
1177 | // in reverse topological order. |
1178 | auto backward_expr_list = |
1179 | LoopIndexingTraversal::backwardTopologicalOrder(replayed_exprs_); |
1180 | |
1181 | for (auto expr : backward_expr_list) { |
1182 | auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs()); |
1183 | if ( |
1184 | // Check that all of the outputs are out of line |
1185 | std::all_of( |
1186 | id_outputs.begin(), |
1187 | id_outputs.end(), |
1188 | [&out_of_line_ids](IterDomain* id) { |
1189 | return out_of_line_ids.count( |
1190 | GpuLower::current()->caMap()->getConcreteMappedID( |
1191 | id, IdMappingMode::EXACT)); |
1192 | })) { |
1193 | // Record out of line expression |
1194 | out_of_line_exprs_.push_back(expr); |
1195 | |
1196 | // Add all of the expression inputs as out of line id's. |
1197 | auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs()); |
1198 | std::transform( |
1199 | id_inputs.begin(), |
1200 | id_inputs.end(), |
1201 | std::inserter(out_of_line_ids, out_of_line_ids.end()), |
1202 | exactConcreteId); |
1203 | } |
1204 | } |
1205 | } |
1206 | |
1207 | std::vector<Expr*> LoopIndexing::getForwardExprList() const { |
1208 | return LoopIndexingTraversal::forwardTopologicalOrder(index_exprs_); |
1209 | } |
1210 | |
1211 | std::vector<Expr*> LoopIndexing::getBackwardExprList() const { |
1212 | return LoopIndexingTraversal::backwardTopologicalOrder(index_exprs_); |
1213 | } |
1214 | |
1215 | std::unordered_set<IterDomain*> LoopIndexing::getAllExactConcreteIdSet() const { |
1216 | std::unordered_set<IterDomain*> all_id_set; |
1217 | for (auto expr : index_exprs_) { |
1218 | auto out_ids = ir_utils::filterByType<IterDomain>(expr->outputs()); |
1219 | std::transform( |
1220 | out_ids.begin(), |
1221 | out_ids.end(), |
1222 | std::inserter(all_id_set, all_id_set.end()), |
1223 | exactConcreteId); |
1224 | |
1225 | auto in_ids = ir_utils::filterByType<IterDomain>(expr->inputs()); |
1226 | std::transform( |
1227 | in_ids.begin(), |
1228 | in_ids.end(), |
1229 | std::inserter(all_id_set, all_id_set.end()), |
1230 | exactConcreteId); |
1231 | } |
1232 | return all_id_set; |
1233 | } |
1234 | |
1235 | namespace { |
1236 | |
1237 | //! Returns true if id is mapped together with any id in |
1238 | //! the vector ids by permissive compute at map. |
1239 | bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) { |
1240 | return std::any_of(ids.begin(), ids.end(), [&](Val* val) { |
1241 | return val->isA<IterDomain>() && |
1242 | GpuLower::current()->caMap()->areMapped( |
1243 | id, val->as<IterDomain>(), IdMappingMode::PERMISSIVE); |
1244 | }); |
1245 | } |
1246 | |
1247 | class LoopIndexingPreferredPathCompute : public IterVisitor { |
1248 | public: |
1249 | static std::unordered_set<IterDomain*> compute( |
1250 | const TensorView* original_tv, |
1251 | const LoopIndexing& loop_indexing, |
1252 | bool use_replay_map, |
1253 | const std::unordered_map<IterDomain*, IterDomain*>& p2c_map) { |
1254 | LoopIndexingPreferredPathCompute compute; |
1255 | |
1256 | auto all_concrete_ids = loop_indexing.getAllExactConcreteIdSet(); |
1257 | |
1258 | // Annotate all ids |
1259 | auto all_original_ids = DependencyCheck::getAllValsBetween( |
1260 | {original_tv->getMaybeRFactorDomain().begin(), |
1261 | original_tv->getMaybeRFactorDomain().end()}, |
1262 | {original_tv->domain()->domain().begin(), |
1263 | original_tv->domain()->domain().end()}); |
1264 | |
1265 | for (auto original_id : |
1266 | ir_utils::filterByType<IterDomain>(all_original_ids)) { |
1267 | auto mapped_id = original_id; |
1268 | if (use_replay_map) { |
1269 | auto c_id_it = p2c_map.find(original_id); |
1270 | if (c_id_it == p2c_map.end()) { |
1271 | continue; |
1272 | } |
1273 | mapped_id = c_id_it->second; |
1274 | } |
1275 | auto concrete_original_id = |
1276 | GpuLower::current()->caMap()->getConcreteMappedID( |
1277 | mapped_id, IdMappingMode::EXACT); |
1278 | if (all_concrete_ids.count(concrete_original_id)) { |
1279 | if (original_id->isBroadcast() || original_id->isReduction() || |
1280 | original_id->isStride()) { |
1281 | continue; |
1282 | } |
1283 | compute.preferred_path_.insert(concrete_original_id); |
1284 | } |
1285 | } |
1286 | |
1287 | for (auto expr : loop_indexing.getForwardExprList()) { |
1288 | compute.handle(expr); |
1289 | } |
1290 | |
1291 | return compute.preferred_path_; |
1292 | } |
1293 | |
1294 | private: |
1295 | void handle(Expr* e) override { |
1296 | // If an input ID is marked, propagate the marking to outputs of the |
1297 | // expression |
1298 | auto all_iter_inputs = ir_utils::filterByType<IterDomain>(e->inputs()); |
1299 | if (std::any_of( |
1300 | all_iter_inputs.begin(), |
1301 | all_iter_inputs.end(), |
1302 | [&](IterDomain* inp_id) { |
1303 | return this->preferred_path_.find( |
1304 | GpuLower::current()->caMap()->getConcreteMappedID( |
1305 | inp_id, IdMappingMode::EXACT)) != |
1306 | this->preferred_path_.end(); |
1307 | })) { |
1308 | auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs()); |
1309 | |
1310 | std::transform( |
1311 | all_iter_outputs.begin(), |
1312 | all_iter_outputs.end(), |
1313 | std::inserter(preferred_path_, preferred_path_.end()), |
1314 | exactConcreteId); |
1315 | } |
1316 | } |
1317 | |
1318 | std::unordered_set<IterDomain*> preferred_path_; |
1319 | }; |
1320 | |
1321 | } // namespace |
1322 | |
1323 | // External interface for preferred path propagation. |
1324 | std::unordered_set<IterDomain*> buildLoopIndexingPreferredPath( |
1325 | const TensorView* original_tv, |
1326 | const LoopIndexing& loop_indexing, |
1327 | bool use_replay_map, |
1328 | std::unordered_map<IterDomain*, IterDomain*> p2c_map) { |
1329 | return LoopIndexingPreferredPathCompute::compute( |
1330 | original_tv, loop_indexing, use_replay_map, p2c_map); |
1331 | } |
1332 | |
1333 | // Get an rfactor IterDomain that is mapped with an IterDomain. If |
1334 | // multiple such IDs exist, select one whose input IDs are mapped with |
1335 | // the consumer IDs. This is to ensure the path from the leaf |
1336 | // IterDomains to the root matches with the consumer tensor. |
1337 | IterDomain* getRfactorIDToTraverse( |
1338 | IterDomain* id, |
1339 | const std::vector<Val*>& consumer_all_ids) { |
1340 | const auto& rfactor_ids = |
1341 | GpuLower::current()->caMap()->getViewRfactorDomainsOfIdGroup( |
1342 | id, IdMappingMode::PERMISSIVE); |
1343 | |
1344 | if (rfactor_ids.empty()) { |
1345 | return nullptr; |
1346 | } |
1347 | |
1348 | for (auto rfactor_id : rfactor_ids) { |
1349 | auto def = rfactor_id->definition(); |
1350 | if (def == nullptr) { |
1351 | continue; |
1352 | } |
1353 | |
1354 | auto rfactor_id_inputs = ir_utils::filterByType<IterDomain>(def->inputs()); |
1355 | if (std::all_of( |
1356 | rfactor_id_inputs.begin(), |
1357 | rfactor_id_inputs.end(), |
1358 | [&](IterDomain* rfactor_id_input) { |
1359 | return isPermissivelyMappedWithAny( |
1360 | rfactor_id_input, consumer_all_ids); |
1361 | })) { |
1362 | return rfactor_id; |
1363 | } |
1364 | } |
1365 | |
1366 | // No mapped ID found, which means the consumer is a post-view |
1367 | // tensor. In that case, it shouldn't matter which view path to |
1368 | // traverse, so just return the first one. |
1369 | return rfactor_ids.at(0); |
1370 | } |
1371 | |
1372 | } // namespace cuda |
1373 | } // namespace fuser |
1374 | } // namespace jit |
1375 | } // namespace torch |
1376 | |