#include <predicate_compute.h>

#include <arith.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <index_compute.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <transform_iter.h>

#include <c10/util/irange.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

bool isTensorIndexOp(Expr* expr) {
  const auto& outputs = expr->outputs();
  return !outputs.empty() && outputs[0]->isA<kir::TensorIndex>();
}

bool isOutputLocal(const Expr* expr) {
  return std::all_of(
      expr->outputs().begin(), expr->outputs().end(), [](const Val* output) {
        return !output->isA<TensorView>() ||
            output->as<TensorView>()->getMemoryType() == MemoryType::Local;
      });
}

} // namespace

bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) {
  auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::EXACT);
  if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) {
    ids_.push_back(concrete_id);
    return true;
  } else {
    return false;
  }
}

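// A sketch of the output (illustrative, with hypothetical extents): if
// pt_ is ParallelType::TIDx and ids() holds two concrete domains with
// extents ext0 and ext1, the returned predicate is equivalent to
//
//   threadIdx.x < ext0 && threadIdx.x < ext1
//
// i.e., the thread index is tested against the extent of every
// parallelized domain recorded for this parallel type.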
Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const {
  Bool* pred = nullptr;

  auto index = SimplifyingIrBuilder::create<NamedScalar>(
      stringifyThread(pt_), DataType::Int);

  for (const auto& pred_id : ids()) {
    // Just sanity check that pred_id is concrete
    TORCH_INTERNAL_ASSERT(
        pred_id ==
        GpuLower::current()->caMap()->getConcreteMappedID(
            pred_id, IdMappingMode::EXACT));
    auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent());
    pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as<Bool>();
  }

  return pred;
}

namespace {

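// Returns the set of concrete root domains that the non-unswitched
// loops depend on. For example (a hypothetical loop nest): with loops
// [i0, i1, i2] and the unswitch starting at index 1, only i0 is a
// non-unswitched leaf, so the result holds the concrete root domains
// reachable from i0.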
std::unordered_set<Val*> getNonUnswitchedRootDomains(
    const std::vector<kir::ForLoop*>& loops,
    size_t unswitched_loop_index) {
  std::vector<Val*> non_unswitched_leaf_domains;
  std::transform(
      loops.begin(),
      loops.begin() + unswitched_loop_index,
      std::back_inserter(non_unswitched_leaf_domains),
      [&](kir::ForLoop* loop) { return loop->iter_domain(); });

  auto non_unswitched_inputs =
      IterVisitor::getInputsTo(non_unswitched_leaf_domains);

  auto non_unswitched_root_doms =
      ir_utils::filterByType<IterDomain>(non_unswitched_inputs);

  std::unordered_set<Val*> non_unswitched_concrete_root_domains;

  std::transform(
      non_unswitched_root_doms.begin(),
      non_unswitched_root_doms.end(),
      std::inserter(
          non_unswitched_concrete_root_domains,
          non_unswitched_concrete_root_domains.end()),
      [&](auto root_dom) {
        return GpuLower::current()->caMap()->getConcreteMappedID(
            root_dom, IdMappingMode::EXACT);
      });

  return non_unswitched_concrete_root_domains;
}
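// A loop domain is considered fully unswitched when none of its root
// domains are shared with the non-unswitched loops, meaning the
// unswitch predicate already covers its entire extent.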
bool isFullyUnswitched(
    IterDomain* loop_id,
    const std::unordered_set<Val*>& non_unswitched_root_domains) {
  auto root_vals = IterVisitor::getInputsTo({loop_id});

  auto root_domains = ir_utils::filterByType<IterDomain>(root_vals);

  return std::none_of(
      root_domains.begin(), root_domains.end(), [&](auto root_dom) {
        auto concrete_root_dom =
            GpuLower::current()->caMap()->getConcreteMappedID(
                root_dom, IdMappingMode::EXACT);
        return non_unswitched_root_domains.count(concrete_root_dom) > 0;
      });
}

} // namespace

std::unordered_map<
    ParallelType,
    ParallelizedDomainPredicate::PredicateInfo,
    TypeHash>
ParallelizedDomainPredicate::getPredicateMap(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    kir::ForLoop* unswitched_loop) {
  const auto gpu_lower = GpuLower::current();
  auto output_tvs = ir_utils::getTvs(expr->outputs());

  if (output_tvs.empty()) {
    return {};
  }

  // Initialize a map with empty predicate info
  std::unordered_map<ParallelType, PredicateInfo, TypeHash> map;
  for (auto pt : kParallelTypeThreads) {
    map.insert({pt, PredicateInfo(pt)});
  }

  // For each loop, check if it's parallelized by a non-exact
  // threading dimension. If so, and the domain is used in the given
  // expr, it needs to be protected by a predicate on the thread/block
  // index.

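  // Hypothetical example: if one tensor binds a domain of extent 128 to
  // TIDx while another binds a domain of extent 32, blockDim.x is
  // launched as 128 and TIDx is non-exact, so expressions using the
  // smaller domain need a threadIdx.x < 32 guard.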
  bool within_unswitch = false;
  std::unordered_set<Val*> non_unswitched_root_domains;

  for (const auto i : c10::irange(loops.size())) {
    auto loop = loops[i];

    // Parallel dimensions need not be predicated if fully unswitched.
    if (loop == unswitched_loop) {
      within_unswitch = true;
      non_unswitched_root_domains = getNonUnswitchedRootDomains(loops, i);
    }

    auto loop_id = loop->iter_domain();
    auto loop_ptype = loop_id->getParallelType();

    // No predicate is necessary if the parallel type is exact
    if (!isParallelTypeThread(loop_ptype) ||
        gpu_lower->parallelDimensionMap().isExact(loop_ptype)) {
      continue;
    }

    // Parallel dimensions need not be predicated if fully unswitched.
    if (within_unswitch &&
        isFullyUnswitched(loop_id, non_unswitched_root_domains)) {
      continue;
    }

    for (auto tv : output_tvs) {
      // Check if the loop domain is used by the output tensor
      auto it = std::find_if(
          tv->domain()->domain().begin(),
          tv->domain()->domain().end(),
          [&](auto tv_id) {
            return gpu_lower->caMap()->areMapped(
                loop_id, tv_id, IdMappingMode::EXACT);
          });
      if (it == tv->domain()->domain().end()) {
        continue;
      }

      IterDomain* tv_id = *it;

      // If the corresponding domain is a broadcast, it's not really used.
      if (tv_id->isBroadcast()) {
        continue;
      }

      // If it's a root domain, it should be covered by the root
      // predicates, so no extra predicate is required.
      if (std::find(
              tv->domain()->getRootDomain().begin(),
              tv->domain()->getRootDomain().end(),
              tv_id) != tv->domain()->getRootDomain().end()) {
        continue;
      }

      // tv_id needs to be predicated. Add it to the PredicateInfo map.
      auto& info = map.at(loop_ptype);
      info.addDomain(tv_id);
    }
  }

  return map;
}

Bool* ParallelizedDomainPredicate::getPredicate(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops) {
  auto pred_map = getPredicateMap(expr, loops);

  Val* pred = GpuLower::current()->kernel()->trueVal();

  for (auto pt : kParallelTypeThreads) {
    auto pred_info_it = pred_map.find(pt);
    if (pred_info_it != pred_map.end()) {
      const auto& pred_info = pred_info_it->second;
      auto tid_pred = pred_info.getPredicate();
      pred = SimplifyingIrBuilder::andExpr(pred, tid_pred);
    }
  }

  TORCH_INTERNAL_ASSERT(pred != nullptr);
  return pred->as<Bool>();
}

UnswitchPredicateKey::UnswitchPredicateKey()
    : predicated_concrete_id_(nullptr) {
  for (auto pt : kParallelTypeThreads) {
    parallel_concrete_ids_.insert({pt, nullptr});
  }
}

// For a predicated concrete domain, id, find which thread parallel
// types are used. For each used parallel type, find the concrete
// domain that the parallel type is associated with. The parallelized
// concrete domains are used to uniquely collect all necessary
// unswitch predicates.
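//
// For instance (a hypothetical schedule): if predicated_consumer_id is
// split and its two leaf outputs are parallelized with TIDx and TIDy,
// the key records the concrete domains of those two leaves under TIDx
// and TIDy, and the entries for all other parallel types stay nullptr.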
UnswitchPredicateKey::UnswitchPredicateKey(
    IterDomain* predicated_consumer_id,
    TensorView* consumer_tv,
    IterDomain* predicated_concrete_id)
    : predicated_concrete_id_(predicated_concrete_id) {
  // Initialize the parallelized domain map
  for (auto pt : kParallelTypeThreads) {
    parallel_concrete_ids_.insert({pt, nullptr});
  }

  std::vector<Val*> all_parallelized_consumer_leaf_ids;
  std::copy_if(
      consumer_tv->domain()->domain().begin(),
      consumer_tv->domain()->domain().end(),
      std::back_inserter(all_parallelized_consumer_leaf_ids),
      [](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); });

  // If the consumer domains are not parallelized at all, no need to
  // differentiate keys based on how the predicated id is parallelized
  if (all_parallelized_consumer_leaf_ids.empty()) {
    return;
  }

  // All domains that are parallelized descendants of predicated_consumer_id
  auto all_parallelized_consumer_ids = DependencyCheck::getAllValsBetween(
      {predicated_consumer_id}, all_parallelized_consumer_leaf_ids);
  // Just pick leaf domains
  std::vector<IterDomain*> parallelized_consumer_leaf_ids;
  std::copy_if(
      consumer_tv->domain()->domain().begin(),
      consumer_tv->domain()->domain().end(),
      std::back_inserter(parallelized_consumer_leaf_ids),
      [&](IterDomain* x) {
        return std::find(
                   all_parallelized_consumer_ids.begin(),
                   all_parallelized_consumer_ids.end(),
                   x) != all_parallelized_consumer_ids.end();
      });

  if (parallelized_consumer_leaf_ids.empty()) {
    // None of the parallelized leaf domains are derived from
    // predicated_consumer_id
    return;
  }

  // Find the corresponding concrete id for each parallel type
  for (auto consumer_leaf : parallelized_consumer_leaf_ids) {
    auto pt = consumer_leaf->getParallelType();
    auto concrete_leaf = GpuLower::current()->caMap()->getConcreteMappedID(
        consumer_leaf, IdMappingMode::EXACT);
    parallel_concrete_ids_.at(pt) = concrete_leaf;
  }
}

std::string UnswitchPredicateKey::toString() const {
  std::stringstream ss;
  ss << "Predicated domain: ";
  if (predicatedId() != nullptr) {
    ss << predicatedId();
  } else {
    ss << "null";
  }
  for (auto pt : kParallelTypeThreads) {
    auto pid = parallelId(pt);
    ss << ", " << pt << ": ";
    if (pid) {
      ss << pid;
    } else {
      ss << "null";
    }
  }
  return ss.str();
}

std::size_t UnswitchPredicateKeyHash::operator()(
    const UnswitchPredicateKey& key) const {
  auto h = std::hash<const IterDomain*>{}(key.predicatedId());
  for (auto pt : kParallelTypeThreads) {
    h = h ^ std::hash<const IterDomain*>{}(key.parallelId(pt));
  }
  return h;
}

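// The inline predicate is the conjunction of (1) the thread predicate,
// which is dropped when all outputs live in registers, (2) the
// parallelized domain predicates for non-exact parallel types, and (3)
// the start and stop predicates of the reference root domains, where
// (3) is skipped when predicate elimination proves it unnecessary.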
Bool* PredicateCompute::getInlinePredicate(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    Bool* thread_pred,
    PredicateType pred_type) {
  FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate");

  const auto gpu_lower = GpuLower::current();

  // If outputs are registers, no need to predicate for threads
  if (isOutputLocal(expr)) {
    thread_pred = gpu_lower->kernel()->trueVal();
  }

  if (loops.empty()) {
    TORCH_INTERNAL_ASSERT(thread_pred != nullptr);
    return thread_pred;
  }

  auto out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  // Predicates for non-exact parallel dimensions must be used even
  // when PredicateElimination::canOmitPredicate is true.
  auto parallel_dom_pred =
      ParallelizedDomainPredicate::getPredicate(expr, loops);
  TORCH_INTERNAL_ASSERT(parallel_dom_pred != nullptr);

  if (gpu_lower->predicateElimination().canOmitPredicate(expr)) {
    return SimplifyingIrBuilder::andExpr(thread_pred, parallel_dom_pred)
        ->as<Bool>();
  }

  auto pred_info_vec = Index::getReferenceRootPredicates(
      out_tv, loops, nullptr, pred_type == PredicateType::Padding);

  std::vector<Bool*> preds;

  // When pred_type is ReductionWrite, filter out predicates for
  // reduction axes. For blockReduce, this is necessary when reduction
  // axes start at non-zero offsets and are parallelized with TID,
  // since blockReduce returns a valid output only at offset-zero
  // threads. Similarly, for gridReduce, the last block to store the
  // output may be predicated out with the read predicate, so the
  // write predicate needs to ignore the reduction axes.
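  //
  // For instance (a sketch): with a blockReduce over an axis bound to
  // TIDx whose start offset is non-zero, only the offset-zero threads
  // hold a valid result, so the write predicate must not test that
  // reduction axis the way the read predicate does.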
  bool non_zero_start_found = false;
  for (const auto& pred_info : pred_info_vec) {
    if (pred_type == PredicateType::ReductionWrite) {
      const auto& consumer_ids = pred_info.rootIds();
      bool pred_for_reduction_axis = false;
      for (auto consumer_id : consumer_ids) {
        if (consumer_id->isReduction()) {
          if (!consumer_id->start()->isZeroInt()) {
            non_zero_start_found = true;
          }
          pred_for_reduction_axis = true;
          break;
        }
      }
      // Don't add the predicate if it corresponds to a reduction axis
      if (pred_for_reduction_axis) {
        continue;
      }
    }
    preds.push_back(pred_info.startPredicate());
    preds.push_back(pred_info.stopPredicate());
  }

  // When generating a predicate for blockReduce writes and not for
  // gridReduce, if all reduction axes start at zero, the read
  // predicate can be reused as is, so nullptr is returned.
  if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found &&
      !out_tv->domain()->hasGridReduction()) {
    return nullptr;
  }

  preds.push_back(parallel_dom_pred);

  if (thread_pred != nullptr) {
    preds.push_back(thread_pred);
  }

  if (preds.empty()) {
    return GpuLower::current()->kernel()->trueVal();
  }

  Val* cond = preds[0];
  for (const auto i : c10::irange(1, preds.size())) {
    cond = SimplifyingIrBuilder::andExpr(cond, preds[i]);
  }

  return cond->as<Bool>();
}

Bool* UnswitchPredicate::get(
    const std::vector<kir::ForLoop*>& outer_loops,
    kir::ForLoop* unrolled_loop) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get");

  UnswitchPredicate up(outer_loops, unrolled_loop);

  Val* unswitch_pred = GpuLower::current()->kernel()->trueVal();
  for (auto pred : up.predicates_) {
    unswitch_pred = SimplifyingIrBuilder::andExpr(unswitch_pred, pred);
  }

  return unswitch_pred->as<Bool>();
}

void UnswitchPredicate::predicateOn(Expr* tv_expr) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn");

  if (for_loops_.empty()) {
    return;
  }

  const auto gpu_lower = GpuLower::current();

  // FIXME: The predicate of a cp.async initialization must be kept so
  // that the inverted predicate can be derived from it; see [Predicate
  // Inversion for CpAsync]. In a follow-up, both this part and
  // [Predicate Inversion for CpAsync] should be cleaned up together.
  if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr) &&
      !ir_utils::isCpAsyncInit(tv_expr)) {
    addParallelizedDomainPredicates(tv_expr);
    return;
  }

  auto out_tv = ir_utils::getTvOutput(tv_expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  auto ref_pred_info = Index::getReferenceRootPredicates(
      out_tv, for_loops_, unrolled_loop_, false);

  // If RootPredicateInfo has a static predicate that is more
  // restrictive than the current one, replace the current with the
  // new one. If it has a dynamic predicate, add it to the dynamic
  // predicate list. Since the final static predicate can't be
  // determined until all expressions are analyzed, predicates are
  // temporarily placed in the predicated_keys map and the final
  // predicates are generated in the finalize function.

  for (const auto& pred_info : ref_pred_info) {
    TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr);
    TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr);

    const auto& root_ids = pred_info.rootIds();

    bool add_pred = false;

    // Used to find a matching existing MergedPredicates
    UnswitchPredicateKey first_key;
    bool first_key_set = false;

    for (auto root_id : root_ids) {
      auto concrete_root_id = gpu_lower->caMap()->getConcreteMappedID(
          root_id, IdMappingMode::EXACT);

      if (root_id->isBroadcast()) {
        continue;
      }

      UnswitchPredicateKey key(root_id, out_tv, concrete_root_id);
      auto inserted = predicated_keys_.insert(key).second;
      add_pred = add_pred || inserted;

      if (!first_key_set) {
        first_key = key;
        first_key_set = true;
      }
    }

    if (!first_key_set) {
      // No predicate generated
      continue;
    }

    // The start and stop offsets may need to be merged to avoid
    // redundant predicates. When these offsets are zero, nothing is
    // done. When non-zero, find the corresponding MergedPredicates
    // and merge both the start and stop offsets. Note that when the
    // offsets are non-zero, the predicates must be generated at a
    // root domain, so root_ids.size() must be one. That unique root
    // domain is used as a key to find the corresponding
    // MergedPredicate.
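    //
    // Hypothetical example: if two unswitched expressions predicate the
    // same root domain with stop conditions (i + 1 < N) and (i + 2 < N),
    // only the latter needs to be kept, since its larger stop offset is
    // more restrictive; mergeUnswitchPredicateOffsets does that
    // comparison.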

    // Initialize with an invalid iterator to signal no corresponding
    // MergedPredicates is found yet.
    auto merged_pred_it = pending_predicates_.end();

    if (add_pred) {
      // This is a new predicate for the root domain. Initialize a new
      // MergedPredicates and add it to the pending list.
      UnswitchPredicate::MergedPredicates merged_pred;

      // To look up this MergedPredicates for other predicates
      // generated for the same predicate key
      if (root_ids.size() == 1) {
        merged_pred.predicate_key = first_key;
      }

      pending_predicates_.push_back(merged_pred);

      merged_pred_it = pending_predicates_.end() - 1;
    } else if (root_ids.size() == 1) {
      // If not new, try to find a corresponding MergedPredicates.
      merged_pred_it = std::find_if(
          pending_predicates_.begin(),
          pending_predicates_.end(),
          [&first_key](const auto& merged_predicates) {
            return merged_predicates.predicate_key == first_key;
          });
      // Note: It is possible that no matching merged predicate info
      // is found. Since add_pred is false here, the root domain is
      // already predicated. It must mean that the root domain
      // is included in a contiguous merged domain, which means there
      // must be no halo-extended domain involved.
    }

    // If a corresponding MergedPredicates is found, merge both the
    // start and stop offsets.
    if (merged_pred_it != pending_predicates_.end()) {
      mergeUnswitchPredicateOffsets(
          pred_info.startPredicate(),
          pred_info.startOffset(),
          merged_pred_it->start,
          true);

      mergeUnswitchPredicateOffsets(
          pred_info.stopPredicate(),
          pred_info.stopOffset(),
          merged_pred_it->stop,
          false);
    }
  }

  addParallelizedDomainPredicates(tv_expr);
}

void UnswitchPredicate::addParallelizedDomainPredicates(Expr* tv_expr) {
  auto pred_map = ParallelizedDomainPredicate::getPredicateMap(
      tv_expr, for_loops_, unrolled_loop_);
  for (auto pt : kParallelTypeThreads) {
    auto pred_info_it = pred_map.find(pt);
    if (pred_info_it == pred_map.end()) {
      continue;
    }
    const auto& new_info = pred_info_it->second;
    auto& predicated =
        parallelized_dom_predicates_
            .insert({pt, ParallelizedDomainPredicate::PredicateInfo{pt}})
            .first->second;
    for (auto id : new_info.ids()) {
      if (predicated.addDomain(id)) {
        predicates_.push_back(new_info.getPredicate());
      }
    }
  }
}

void UnswitchPredicate::openLoop(kir::ForLoop* fl) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openLoop");

  for_loops_.push_back(fl);

  for (auto expr : fl->body().exprs()) {
    if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) {
      predicateOn(expr);
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      openIte(ite);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      openLoop(for_loop);
    }
  }

  for_loops_.pop_back();
}

void UnswitchPredicate::openIte(kir::IfThenElse* ite) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openIte");

  // Only expand the then-body of the IfThenElse
  for (auto expr : ite->thenBody().exprs()) {
    if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) {
      predicateOn(expr);
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      openIte(ite);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      openLoop(for_loop);
    }
  }
}

void UnswitchPredicate::finalize() {
  for (const auto& merged_pred : pending_predicates_) {
    const auto& start_info = merged_pred.start;
    if (start_info.static_pred) {
      predicates_.push_back(start_info.static_pred);
    }
    for (auto dynamic_pred : start_info.dynamic_preds) {
      predicates_.push_back(dynamic_pred);
    }
    const auto& stop_info = merged_pred.stop;
    if (stop_info.static_pred) {
      predicates_.push_back(stop_info.static_pred);
    }
    for (auto dynamic_pred : stop_info.dynamic_preds) {
      predicates_.push_back(dynamic_pred);
    }
  }
}

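// Merge a new start or stop predicate into merged_predicate_info. Per
// the comparison below, among start offsets the smaller value is
// treated as more restrictive, while among stop offsets the larger one
// is; only the most restrictive static predicate is kept, whereas all
// dynamic predicates are accumulated.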
void UnswitchPredicate::mergeUnswitchPredicateOffsets(
    Bool* predicate,
    Val* offset,
    MergedPredicates::Info& merged_predicate_info,
    bool is_start) {
  auto is_more_restrictive = [&is_start](int64_t new_val, int64_t current_val) {
    if (is_start) {
      return new_val < current_val;
    } else {
      return new_val > current_val;
    }
  };

  auto offset_int = dynamic_cast<Int*>(offset);
  // If it's a static predicate, replace the current one if the new
  // one is more restrictive. If it's dynamic, just add it to the
  // dynamic predicate list.
  if (offset_int && offset_int->isConst()) {
    auto offset_const = offset_int->value().value();
    auto& static_pred = merged_predicate_info.static_pred;
    auto& static_offset = merged_predicate_info.static_offset;
    if (static_pred == nullptr ||
        is_more_restrictive(offset_const, static_offset)) {
      static_pred = predicate;
      static_offset = offset_const;
    }
  } else {
    merged_predicate_info.dynamic_preds.push_back(predicate);
  }
}

UnswitchPredicate::UnswitchPredicate(
    std::vector<kir::ForLoop*> outer_loops,
    kir::ForLoop* unrolled_loop)
    : for_loops_(std::move(outer_loops)), unrolled_loop_(unrolled_loop) {
  openLoop(unrolled_loop);
  finalize();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch