#include <lower_predicate_elimination.h>

#include <arith.h>
#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_shift.h>
#include <lower_utils.h>
#include <predicate_compute.h>
#include <transform_iter.h>
#include <transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Warp primitives are currently limited to un-predicated usage;
// predicating these ops would require extra steps to ensure that
// the whole warp gets the same value.
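// For example (illustrative): ldmatrix distributes a loaded tile across all
// lanes of a warp, and mma accumulates into registers owned by every lane;
// if a predicate masked off some lanes, the remaining lanes would receive or
// compute undefined fragments instead of a warp-uniform result.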
void assertOnWarpOps(const Expr* expr) {
  TORCH_INTERNAL_ASSERT(
      !ir_utils::isLdMatrixOp(expr),
      "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims",
      expr->toString());
  TORCH_INTERNAL_ASSERT(
      !expr->isA<MmaOp>(),
      "Mma op: cannot eliminate predicate for mma op, tiling not valid. ",
      expr->toString());
}

} // namespace

namespace {

// Utility to check if the scheduled domain of the given
// TensorView represents an exact shared mem access, meaning
// that all the thread parallel dimensions on the leaf nodes
// are exact, so that the shared mem read/write would not
// run out of bounds because of thread over-subscription.
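// For example (illustrative): if an axis of extent 30 is bound to TIDx here
// while another tensor in the fusion binds an axis of extent 32 to TIDx, the
// kernel is launched with blockDim.x = 32 and TIDx is inexact; the last two
// threads would index past the 30-element shared memory buffer unless the
// access is predicated.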
bool isExactParallelSharedMemAccess(TensorView* tv) {
  auto& parallel_dimension_map = GpuLower::current()->parallelDimensionMap();
  for (auto id : tv->domain()->domain()) {
    if (id->isThreadDim()) {
      auto ptype = id->getParallelType();
      // Need to predicate to avoid out-of-bound access
      // because of an over-subscribed block size.
      if (!parallel_dimension_map.isExact(ptype)) {
        return false;
      }
    }
  }
  return true;
}

class PredicateAnalyzer : public OptOutDispatch {
 public:
  //! Checks if a predicate is needed to avoid out-of-bound accesses.
  //!
  //! Due to the way we allocate local-memory tensors, there should
  //! never be out-of-bound accesses with consumer tensors when allocated on
  //! local memory. However, accessing producer tensors may still
  //! result in out-of-bound accesses as they are replayed as consumers.
  static bool needsPredicate(TensorView* producer, TensorView* consumer) {
    // Both tensors must be on local or shared memory. Global tensors must be
    // predicated as allocation is done based on root domains. Smem
    // and local tensors are allocated based on leaf domains.
    // However, when a smem tensor is parallelized, which is highly likely,
    // the size of the parallelized axis is the actual size of the axis, not
    // the number of threads. This is currently actively checked to avoid
    // out-of-bound shared mem accesses by out-of-bound threads.
    if (producer->getMemoryType() == MemoryType::Global ||
        consumer->getMemoryType() == MemoryType::Global) {
      return true;
    }

    auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
    auto c2p =
        BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
            .getReplay();

    PredicateAnalyzer analyzer(c2p);

    for (auto id : consumer->domain()->domain()) {
      if (analyzer.needsPredicate(id)) {
        return true;
      }
    }

    return false;
  }

 private:
  PredicateAnalyzer(const std::unordered_map<IterDomain*, IterDomain*>& c2p_map)
      : c2p_map_(c2p_map) {}

  // Returns true if out-of-bound accesses could occur with the
  // producer, i.e., a predicate is needed
  bool needsPredicate(IterDomain* consumer_id) {
    needs_predicate_ = false;
    handle(consumer_id);
    return needs_predicate_;
  }

  void handle(IterDomain* consumer_id) override {
    // The traversal should have ended if needs_predicate_ was true
    TORCH_INTERNAL_ASSERT(!needs_predicate_);

    // If consumer_id is not going to be materialized as a loop (e.g.,
    // broadcast), no need to predicate
    if (consumer_id->isBroadcast() ||
        GpuLower::current()->trivialReductionInfo().isDerived(consumer_id)) {
      return;
    }

    // If the producer has a matching domain, it should not cause
    // out-of-bound accesses
    if (c2p_map_.find(consumer_id) != c2p_map_.end()) {
      return;
    }

    // If no definition exists, stop traversing
    if (consumer_id->definition() == nullptr) {
      return;
    }

    OptOutDispatch::handle(consumer_id->definition());
  }

  // If the split divides the input axis evenly, proceed to check the input
  // axis. Otherwise, we can't skip predication as it might cause
  // out-of-bound accesses with the producer tensor
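  // For example (illustrative): splitting an axis of extent 10 by a factor
  // of 4 yields ceilDiv(10, 4) * 4 = 12 iterations; the last two iterations
  // fall outside the producer's allocation, so the predicate must be kept.
  // A split of 12 by 4 is exact and can be traversed further.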
  void handle(Split* split) override {
    auto factor = split->factor()->getInt();
    if (!factor.has_value()) {
      needs_predicate_ = true;
      return;
    }

    ExpressionEvaluator ee(split->fusion());
    const auto in_extent = ee.evaluate(split->in()->extent());

    if (!in_extent.has_value() || ((in_extent.value() % factor.value()) != 0)) {
      needs_predicate_ = true;
      return;
    }

    handle(split->in());
  }

  void handle(Merge* merge) override {
    handle(merge->inner());
    if (needs_predicate_) {
      return;
    }
    handle(merge->outer());
  }

 private:
  //! BestEffort map from consumer IDs to producer IDs
  const std::unordered_map<IterDomain*, IterDomain*>& c2p_map_;
  bool needs_predicate_ = false;
};

class PredicateChcker : public IterVisitor {
 public:
  static bool needsPredicate(
      Expr* expr,
      const std::unordered_set<const Expr*>& non_predicated_exprs) {
    if (!ir_utils::isTvOp(expr)) {
      return false;
    }

    PredicateChcker checker(non_predicated_exprs);
    checker.handle(expr);
    return checker.needs_predicate_;
  }

 private:
  PredicateChcker(const std::unordered_set<const Expr*>& non_predicated_exprs)
      : non_predicated_exprs_(non_predicated_exprs) {}

  using IterVisitor::handle;

  void handle(Expr* expr) final {
    needs_predicate_ = predicateIntDiv(expr) ||
        predicateMisalignedVectorize(expr) || predicateShift(expr) ||
        predicateSharedMemAccess(expr) || predicateProducerConsumerPair(expr) ||
        predicateNonDivisibleRootDomains(expr) ||
        predicateNonDivisibleSplit(expr) || predicateExpandReduce(expr);

    // A cp.async op would need a predicate for either the global
    // input or its shared mem output, or both.
    // Due to the WAR discussed in [Predicate Inversion for CpAsync],
    // we currently cannot support use cases where both the gmem read
    // and the smem write need to be predicated.
    // Adding a check here makes the exclusion of such cases as precise as
    // possible and avoids duplicating the predicateSharedMemAccess
    // logic. But this part, along with [Predicate Inversion for CpAsync],
    // should be cleaned up altogether when we extend predicate/masking
    // logic to cover this usage.
    TORCH_INTERNAL_ASSERT(
        !(ir_utils::isCpAsyncOp(expr) && predicateSharedMemAccess(expr)),
        "predicate removal: unsupported use case of cp.async");

    if (needs_predicate_) {
      return;
    }

    // Check ExprType-specific conditions
    IterVisitor::handle(expr);
  }

  // All "predicateXYZ" functions return true if an expr needs to be
  // predicated.

  // Always predicate integer division and related ops as we don't
  // know what values are in the out-of-bound region and they may
  // cause exceptions
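  // For example (illustrative): in t2 = t0 / t1 with Int32 tensors, the
  // out-of-bound elements of t1 are uninitialized and could be zero, which
  // would trap or produce undefined results, so the division is always
  // guarded.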
  bool predicateIntDiv(Expr* expr) const {
    auto dt = expr->outputs()[0]->getDataType().value();
    return (
        (dt == DataType::Int || dt == DataType::Int32) &&
        expr->isA<BinaryOp>() &&
        (expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Div ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Mod ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Remainder ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::CeilDiv));
  }

  // If we're reducing an expanded domain, we need to be careful to predicate
  // it or we could end up reducing a broadcasted value too many times.
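  // For example (illustrative): if a broadcast domain expanded from 1 to N
  // is reduced, every iteration re-reads the same underlying element;
  // out-of-bound iterations introduced by the schedule would accumulate that
  // element extra times and corrupt the reduction result.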
  bool predicateExpandReduce(Expr* expr) const {
    if (!ir_utils::isReductionOp(expr)) {
      return false;
    }
    auto tv_inputs = ir_utils::getTvs(expr->inputs());
    TORCH_INTERNAL_ASSERT(
        tv_inputs.size() > 0,
        "Should never have a reduction op without a tensor view input.");
    bool found_expand = false;
    for (auto tv_input : tv_inputs) {
      found_expand |= std::any_of(
          tv_input->getMaybeRFactorDomain().begin(),
          tv_input->getMaybeRFactorDomain().end(),
          [](IterDomain* id) { return id->hasExpandedExtent(); });
    }

    if (!found_expand) {
      return false;
    }

    auto tv_outputs = ir_utils::getTvs(expr->outputs());
    if (expr->isA<WelfordOp>() && tv_inputs.size() != tv_outputs.size()) {
      tv_outputs = std::vector<TensorView*>(tv_inputs.size(), tv_outputs[0]);
    }

    TORCH_INTERNAL_ASSERT(
        tv_outputs.size() == tv_inputs.size(),
        "Was expecting matching number of inputs and outputs for expression: ",
        expr->toString());

    for (auto i : c10::irange(tv_inputs.size())) {
      const auto root_p2c =
          PairwiseRootDomainMap(tv_inputs[i], tv_outputs[i])
              .mapProducerToConsumer(
                  tv_inputs[i]->domain(), tv_outputs[i]->domain());
      for (auto entry : root_p2c) {
        auto p_id = entry.first;
        auto c_id = entry.second;
        if (p_id->hasExpandedExtent() && c_id->isReduction()) {
          return true;
        }
      }
    }
    return false;
  }

  // Skip if MisalignedVectorize is involved for now. This could be
  // relaxed.
  bool predicateMisalignedVectorize(Expr* expr) const {
    std::vector<const std::vector<Val*>*> inputs_and_outputs = {
        &(expr->inputs()), &(expr->outputs())};
    for (const auto& inputs_or_outputs : inputs_and_outputs) {
      for (auto tv : ir_utils::filterByType<TensorView>(*inputs_or_outputs)) {
        if (std::any_of(
                tv->domain()->domain().begin(),
                tv->domain()->domain().end(),
                [](IterDomain* axis) {
                  return axis->getParallelType() ==
                      ParallelType::MisalignedVectorize;
                })) {
          return true;
        }
      }
    }
    return false;
  }

  // Shift is not supported yet.
  bool predicateShift(Expr* expr) const {
    auto halo_info = GpuLower::current()->haloInfo();
    auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
    return halo_info->needsShiftPredicate(expr) ||
        std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
          return input_tv->definition() != nullptr &&
              halo_info->needsShiftPredicate(input_tv->definition());
        });
  }

  // Predicates the expression if any producer-consumer pair of the
  // expression needs to be predicated
  bool predicateProducerConsumerPair(Expr* expr) const {
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
        if (PredicateAnalyzer::needsPredicate(input, output)) {
          return true;
        }
      }
    }
    return false;
  }

  bool predicateSharedMemAccess(Expr* expr) const {
    // This is an initial step to gradually remove predicates around
    // shared mem accesses in suitable situations.
    // needSharedMemPredicate checks the conditions under which the
    // predicate around a shared mem access cannot be removed.
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
        if (producer->getMemoryType() == MemoryType::Shared ||
            consumer->getMemoryType() == MemoryType::Shared) {
          if (needSharedMemPredicate(producer, consumer)) {
            return true;
          }
        }
      }
    }

    return false;
  }

  // Check for conditions where the predicate cannot be removed
  // when either the producer or the consumer is in shared memory.
  bool needSharedMemPredicate(TensorView* producer, TensorView* consumer)
      const {
    // Indexing is based on consumer leaf ids, so check the consumer.

    // If the consumer schedule contains inexact thread parallel
    // dimensions, we need to predicate against out-of-bound
    // shared memory accesses by out-of-bound threads.
    if (!isExactParallelSharedMemAccess(consumer)) {
      return true;
    }

    // TODO: This is a directed WAR on FusionPersistentNormLocalShared.
    // This use case, along with other previous issues, motivates a
    // joint optimization of predicate removal and buffer reuse.
    // In this particular case:
    //   __shared__ T0[10], T1[10];
    //   for i in ...
    //     if (pred)
    //       T1[i] = T0[i] + ...  // exp0
    //   T2 = 0;  // init for exp1
    //   if (pred)
    //     T2 = T1 ...  // exp1
    // If we remove pred around exp1, the way the pred removal
    // pass is set up, the init for exp1 will be pushed up to
    // initialize T1 instead.
    // However, if we initialize T1, the code will look like:
    //   for i in ...
    //     T1[i] = 0;
    //   for i in ...
    //     if (pred)
    //       T1[i] = T0[i] + ...
    // Note that we'd be able to reuse the buffer of T0 for T1, but
    // if we initialize T1 we cannot do that, and thus the
    // kernel would not fit on smaller devices.
    if (producer->getMemoryType() == MemoryType::Shared) {
      if (auto producer_def = producer->definition()) {
        if (std::any_of(
                producer_def->inputs().begin(),
                producer_def->inputs().end(),
                [](Val* val) {
                  if (auto tv = ir_utils::getTv(val)) {
                    return tv->getMemoryType() == MemoryType::Shared;
                  }
                  return false;
                })) {
          // Disable predicate removal for a shared memory producer that is
          // itself a consumer of another shared memory tensor. The
          // initialization would break a potential opportunity to reuse the
          // shared mem buffer.
          return true;
        }
      }
    }

    for (auto id : consumer->domain()->domain()) {
      // TODO: (Enable in a follow up)
      // Smem predicate removal with init would break unroll and unswitch,
      // e.g. as in issue 1133, so this removal pattern is disabled for now.
      if (id->getParallelType() == ParallelType::Unroll ||
          id->getParallelType() == ParallelType::Unswitch) {
        return true;
      }

      // TODO: (Enable in a follow up)
      // This cannot yet be removed since smem initialization needs to be
      // handled specially, e.g. as in the smem_reduce test. We will be able
      // to lift this one once the generic pred removal pass with fusion
      // traversal is ready.
      auto consumer_def = consumer->definition();
      if (ir_utils::isReductionOp(consumer_def)) {
        if (producer->getMemoryType() == MemoryType::Shared) {
          return true;
        }
      }
    }

    return false;
  }

  // Utility to find the leaf iterdomains of the given
  // tensor view that will be treated as "zero loops"
  // in the indexing pass.
  // For details on zero loops, see indexMapFromTV in
  // the lower index pass.
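  // For example (illustrative): for a Local tensor with leaf domains
  // [I0, I1, TIDx, Vect] and compute-at position 2, I0 and I1 (left of the
  // CA point) and TIDx (thread-parallel on local memory) do not contribute
  // to the allocation index and are returned here as zero leaf ids.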
  std::vector<Val*> getZeroLeafIds(const TensorView* tv) const {
    TORCH_INTERNAL_ASSERT(
        tv->getMemoryType() == MemoryType::Local ||
            tv->getMemoryType() == MemoryType::Shared,
        "Local or shared memory tensor is assumed: ",
        tv->toString());
    bool is_shared_mem = tv->getMemoryType() == MemoryType::Shared;
    std::vector<Val*> zero_leaf_ids;
    for (const auto i : c10::irange(tv->nDims())) {
      auto leaf_id = tv->axis(i);
      if (is_shared_mem && leaf_id->isThreadDim()) {
        // Thread parallel axes on shared mem are never
        // zero loops as each thread owns its share
        // of the shared mem space.
        continue;
      }
      if (
          // Non-thread parallel dimensions on the left
          // of the CA axes are zero loops.
          i < tv->getComputeAtPosition() ||
          // Parallel axes on local mem are zero loops.
          // Grid axes on shared mem are zero loops.
          leaf_id->isThread() ||
          // Mma axes, similar to vectorization, are
          // implicit in hardware intrinsics, and thus
          // will be treated as zero loops.
          leaf_id->isMma()) {
        zero_leaf_ids.push_back(leaf_id);
      }
    }

    return zero_leaf_ids;
  }

  // An index can exceed the logical extent of the indexed domain if
  // it's split. It can cause a reduction op to reduce the same value
  // multiple times. Even a pointwise op can be a problem if the
  // consumer is an alias of the producer. This check excludes such
  // expressions from predicate elimination.
  //
  // This is not an issue if the index includes a zero domain (as defined in
  // index_compute.cpp): the extent is calculated by multiplying the
  // split output domains, so it never crosses the domain boundary.
  // So, if a root domain is split and none of its descendants is a
  // zero domain, the expr needs to be predicated. See
  // FusionPredicateElimination6 for a concrete example.
  //
  // It would also be possible to avoid register aliasing instead of
  // giving up predicate elimination. Since this condition should be
  // rather uncommon, either would be fine as long as correctness is
  // provided.
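  // For example (illustrative): a root domain of extent 10 split by 3
  // produces 4 * 3 = 12 loop iterations; if none of the split outputs ends
  // up as a zero domain, iterations 10 and 11 index past the logical extent
  // and could, e.g., reduce the same element twice, so the predicate must
  // stay.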
  bool predicateNonDivisibleRootDomains(Expr* expr) const {
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      const auto all_exprs = DependencyCheck::getAllExprsBetween(
          {output->getMaybeRFactorDomain().begin(),
           output->getMaybeRFactorDomain().end()},
          {output->domain()->domain().begin(),
           output->domain()->domain().end()});
      std::unordered_set<Val*> split_root;
      std::copy_if(
          output->getMaybeRFactorDomain().begin(),
          output->getMaybeRFactorDomain().end(),
          std::inserter(split_root, split_root.end()),
          [&](auto rf_root) {
            if (rf_root->isBroadcast() ||
                GpuLower::current()->trivialReductionInfo().isDerived(
                    rf_root)) {
              return false;
            }
            for (Expr* use : rf_root->uses()) {
              if (std::find(all_exprs.begin(), all_exprs.end(), use) ==
                  all_exprs.end()) {
                continue;
              }
              return use->isA<Split>();
            }
            return false;
          });
      // If no root domain is split, no need to predicate
      if (split_root.empty()) {
        continue;
      }
      const auto zero_leaf_ids = getZeroLeafIds(output);
      if (zero_leaf_ids.empty()) {
        return true;
      }
      const auto vals =
          DependencyCheck::getAllValsBetween(split_root, zero_leaf_ids);
      if (std::any_of(
              split_root.begin(),
              split_root.end(),
              [&vals](auto split_root_id) {
                return std::find(vals.begin(), vals.end(), split_root_id) ==
                    vals.end();
              })) {
        return true;
      }
    }
    return false;
  }

  // Always predicate if a non-divisible split is found. It may be
  // possible to make this less conservative.
  // See FusionPredicateElimination7 for a concrete example.
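  // For example (illustrative): splitting a domain of extent 5 by 4 is
  // recorded by NonDivisibleSplitInfo; the 3 trailing iterations of the last
  // chunk must be masked, so any expression producing such an output keeps
  // its predicate.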
  bool predicateNonDivisibleSplit(Expr* expr) const {
    const auto& non_divisible_split_info =
        GpuLower::current()->nonDivisibleSplitInfo();
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      if (non_divisible_split_info.splitsToPredicate().find(output) !=
          non_divisible_split_info.splitsToPredicate().end()) {
        return true;
      }
    }
    return false;
  }

  // If this is a reduction, and if we omit the predicate for the
  // input, the input may have a garbage value, which must not be used
  // for this reduction. However, it is still legal to omit its
  // predicate when either: 1) the predicate of the input is not omitted,
  // or 2) the input can be initialized to the init value of this
  // reduction. When the input is the output of another reduction, the
  // input is initialized to the init value of that reduction, so the
  // two reductions must use the same init value.
  // See FusionPredicateElimination3 and FusionPredicateElimination4
  // for concrete examples.
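  // For example (illustrative):
  //   t1 = sum(t0);  // predicate omitted
  //   t2 = sum(t1);
  // When the first sum is not predicated, t0 is initialized to 0 (its init
  // value), so the out-of-bound elements of t1 also end up holding 0. Since
  // 0 is also the init value of the second sum, accumulating those elements
  // cannot change t2, and the second predicate may be omitted too.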
  void handle(ReductionOp* rop) final {
    auto input = rop->inputs()[0]->as<TensorView>();
    auto input_def = input->definition();
    // When input_def is null, input must be an input to the fusion,
    // so it must be allocated on global memory. Since we don't omit
    // predication for expressions involving global memory, this
    // should never occur.
    TORCH_INTERNAL_ASSERT(
        input_def != nullptr, "Inconsistent input found: ", input->toString());

    // The input needs to be initialized to the init value to omit
    // the predicate, so if the input has its own init value, i.e.,
    // is produced by another reduction, they must use the same init
    // value.
    Val* input_init = ir_utils::getReductionInitValOf(input);
    if (input_init != nullptr && !rop->init()->sameAs(input_init)) {
      needs_predicate_ = true;
      return;
    }

    // If the input is not predicated, its out-of-bound values may be
    // overwritten by garbage. However, it doesn't matter if
    // the input is also produced by another reduction. If the preceding
    // reduction omits the predicate, its input must be
    // initialized to its init value, so no predicate should be
    // needed in either of the two reduction ops if they use the same
    // init value, which is guaranteed by the above check, and the
    // same reduction op.
    if (auto input_def_rop = dynamic_cast<ReductionOp*>(input_def)) {
      if (rop->getReductionOpType() != input_def_rop->getReductionOpType() &&
          non_predicated_exprs_.find(input_def) !=
              non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    } else if (
        non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
      needs_predicate_ = true;
      return;
    }
  }

  // Welford. See FusionPredicateElimination5.
  void handle(WelfordOp* wop) final {
    for (const auto i : c10::irange(3)) {
      auto init = wop->getInitVals()[i];

      // A Welford input can be a scalar. A predicate is required unless
      // the scalar value is equal to the init value.
      auto input = wop->inputs().at(i);
      if (input->isScalar()) {
        if (!input->sameAs(init)) {
          needs_predicate_ = true;
          return;
        }
        continue;
      }

      auto input_tv = dynamic_cast<TensorView*>(input);
      TORCH_INTERNAL_ASSERT(input_tv != nullptr);

      auto input_def = input->definition();

      // When input_def is null, input must be an input to the fusion,
      // so it must be allocated on global memory. Since we don't omit
      // predication for expressions involving global memory, this
      // should never occur.
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      // The input needs to be initialized to the init value to omit
      // the predicate, so if the input has its own init value, i.e.,
      // is produced by another reduction, they must use the same init
      // value.
      Val* input_init = ir_utils::getReductionInitValOf(input_tv);
      if (input_init != nullptr && !init->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      // If the input is not predicated, its out-of-bound values may be
      // overwritten by garbage. However, it doesn't matter if
      // the input is also produced by another Welford op.
      if (!input_def->isA<WelfordOp>() && !input_def->isA<GroupedWelfordOp>() &&
          non_predicated_exprs_.find(input_def) !=
              non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    }
  }

  void handle(GroupedReductionOp* grouped_rop) final {
    for (const auto i : c10::irange(grouped_rop->numExprs())) {
      auto input = grouped_rop->input(i)->as<TensorView>();
      auto input_def = input->definition();
      // When input_def is null, input must be an input to the fusion,
      // so it must be allocated on global memory. Since we don't omit
      // predication for expressions involving global memory, this
      // should never occur.
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      // The input needs to be initialized to the init value to omit
      // the predicate, so if the input has its own init value, i.e.,
      // is produced by another reduction, they must use the same init
      // value.
      Val* input_init = ir_utils::getReductionInitValOf(input);
      if (input_init != nullptr &&
          !grouped_rop->initVal(i)->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      // If the input is not predicated, its out-of-bound values may be
      // overwritten by garbage. However, it doesn't matter if
      // the input is also produced by another reduction. If the preceding
      // reduction omits the predicate, its input must be
      // initialized to its init value, so no predicate should be
      // needed in either of the two reduction ops if they use the same
      // init value, which is guaranteed by the above check, and the
      // same reduction op.
      if (auto input_def_rop = dynamic_cast<ReductionOp*>(input_def)) {
        if (grouped_rop->getReductionOpType(i) !=
                input_def_rop->getReductionOpType() &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      } else if (
          auto input_def_grouped_rop =
              dynamic_cast<GroupedReductionOp*>(input_def)) {
        auto input_index_as_output =
            input_def_grouped_rop->getExprIndexOfOutput(input);
        if (grouped_rop->getReductionOpType(i) !=
                input_def_grouped_rop->getReductionOpType(
                    input_index_as_output) &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      } else if (
          non_predicated_exprs_.find(input_def) !=
          non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    }
  }

  void handle(GroupedWelfordOp* grouped_wop) final {
    for (const auto expr_idx : c10::irange(grouped_wop->numExprs())) {
      for (const auto val_idx : c10::irange(3)) {
        auto init = grouped_wop->initVals().at(expr_idx).get(val_idx);

        // A Welford input can be a scalar. A predicate is required unless
        // the scalar value is equal to the init value.
        auto input = grouped_wop->inputVals().at(expr_idx).get(val_idx);
        if (input->isScalar()) {
          if (!input->sameAs(init)) {
            needs_predicate_ = true;
            return;
          }
          continue;
        }

        auto input_tv = dynamic_cast<TensorView*>(input);
        TORCH_INTERNAL_ASSERT(input_tv != nullptr);

        auto input_def = input->definition();

        // When input_def is null, input must be an input to the fusion,
        // so it must be allocated on global memory. Since we don't omit
        // predication for expressions involving global memory, this
        // should never occur.
        TORCH_INTERNAL_ASSERT(
            input_def != nullptr,
            "Inconsistent input found: ",
            input->toString());

        // The input needs to be initialized to the init value to omit
        // the predicate, so if the input has its own init value, i.e.,
        // is produced by another reduction, they must use the same init
        // value.
        Val* input_init = ir_utils::getReductionInitValOf(input_tv);
        if (input_init != nullptr && !init->sameAs(input_init)) {
          needs_predicate_ = true;
          return;
        }

        // If the input is not predicated, its out-of-bound values may be
        // overwritten by garbage. However, it doesn't matter if
        // the input is also produced by another reduction op, as that
        // input must then be initialized and its initialization value has
        // already been found to be equal to the init value of this op.
        if (!input_def->isA<WelfordOp>() &&
            !input_def->isA<GroupedWelfordOp>() &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      }
    }
  }

  // Similar to the above reduction constraints but for MMA
  void handle(MmaOp* mma) final {
    for (auto input : ir_utils::filterByType<TensorView>(mma->inputs())) {
      auto input_def = input->definition();
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      Val* input_init = ir_utils::getReductionInitValOf(input);
      if (input_init != nullptr && !mma->init()->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      if (non_predicated_exprs_.find(input_def) !=
          non_predicated_exprs_.end()) {
        // If the producer of the mma is non-predicated and initialized
        // with the same value, the mma should not need a
        // predicate. In fact this is the only way we can
        // use mma at the moment, since we cannot predicate
        // mma ops without guaranteeing warp uniform results.
        auto input_init =
            GpuLower::current()->predicateElimination().getInitValue(input);

        // TODO:
        // Clean this up to support more generic prolog fusion.
        // Will need additional analysis passes on initialization
        // propagation and further predicate placement on top.
        // More TODO:
        // Even when the producer is initialized, it is still generally
        // not safe to remove the predicate around reduction ops if the
        // producer is not predicated.
        // On the other side, we do have patterns like ldmatrix->mma where
        // neither producer nor consumer can be safely predicated without
        // guaranteeing warp uniform results.
        // This is currently a WAR and relies on the validation pass to
        // exclude complex prolog patterns in mma based matmul kernels. We
        // will definitely need to revisit and build out a predicate and
        // initialization analysis pass to better handle this case.
        if (input_init != nullptr && !input_init->sameAs(mma->init())) {
          // This is a WAR at the moment. We would need to propagate
          // initialization information from the PredicateElimination
          // pass to most accurately detect if the input is
          // initialized correctly.
          // This could also be fixed when we have the traversal
          // based predicate elimination and initialization pass
          // ready. It would be easy to clean up this part at that point.
          needs_predicate_ = true;
          return;
        }
      }
    }
  }

 private:
  const std::unordered_set<const Expr*>& non_predicated_exprs_;
  bool needs_predicate_ = false;
};

} // namespace

bool PredicateElimination::needsPredicate(Expr* expr) const {
  return PredicateChcker::needsPredicate(expr, non_predicated_exprs_);
}

void PredicateElimination::handle(Expr* expr) {
  if (!ir_utils::isTvOp(expr)) {
    return;
  }

  if (needsPredicate(expr)) {
    assertOnWarpOps(expr);
    return;
  }

  non_predicated_exprs_.insert(expr);

  // Ensure all inputs have some value set in the out-of-bound
  // regions
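  // For example (illustrative): if this expr is t2 = sum(t1) and its
  // predicate is omitted, t1 is recorded in init_value_map_ with the sum's
  // init value (0) so that its out-of-bound elements are well defined;
  // inputs of non-reduction exprs only get a default (zero) init.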
  for (const auto i : c10::irange(expr->inputs().size())) {
    auto input = dynamic_cast<TensorView*>(expr->inputs()[i]);
    if (input == nullptr) {
      continue;
    }
    auto input_def = input->definition();
    // When input_def is null, input must be an input to the fusion,
    // so it must be allocated on global memory. Since we don't omit
    // predication for expressions involving global memory, this
    // should never occur.
    TORCH_INTERNAL_ASSERT(
        input_def != nullptr, "Inconsistent input found: ", input->toString());

    // If input is an output of a reduction, it should be fully
    // initialized as it's allocated on local memory.
    if (ir_utils::isReductionOp(input_def)) {
      continue;
    }

    if (expr->isA<ReductionOp>()) {
      setReductionInitValue(input, expr->as<ReductionOp>()->init());
      continue;
    } else if (expr->isA<GroupedReductionOp>()) {
      setReductionInitValue(input, expr->as<GroupedReductionOp>()->initVal(i));
      continue;
    } else if (auto wop = dynamic_cast<WelfordOp*>(expr)) {
      Val* init = wop->getInitVals().at(i);
      setReductionInitValue(input, init);
      continue;
    } else if (expr->isA<MmaOp>()) {
      setReductionInitValue(input, expr->as<MmaOp>()->init());
      continue;
    } else if (
        non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
      // If the input does not need a predicate either, then it should
      // already have some value, so no need to set a default value
      continue;
    } else {
      // Make sure the input is initialized
      setDefaultInitValue(input);
    }
  }
}

bool PredicateElimination::setDefaultInitValue(TensorView* tv) {
  auto it = init_value_map_.find(tv);
  // If there's already a mapping for tv, it should be mapped to a
  // zero val or a reduction init. In either case, there is no need to
  // modify the existing mapping.
  if (it == init_value_map_.end()) {
    init_value_map_.insert({tv, nullptr});
  }
  return true;
}

bool PredicateElimination::setReductionInitValue(
    TensorView* tv,
    Val* reduction_init) {
  TORCH_INTERNAL_ASSERT(tv != nullptr);

  auto it = init_value_map_.find(tv);
  if (it == init_value_map_.end()) {
    init_value_map_.insert({tv, reduction_init});
    return true;
  }

  auto existing_val = it->second;
  if (existing_val == nullptr) {
    // If the existing mapping returns nullptr, it means that a
    // default init was set before. Overwrite it with the reduction
    // init val.
    init_value_map_[tv] = reduction_init;
    return true;
  } else if (existing_val->sameAs(reduction_init)) {
    return true;
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Inconsistent setting of initialization value for t",
        tv->name(),
        ". Prev: ",
        existing_val->toString(),
        ", New: ",
        reduction_init->toString());
    return false;
  }
}

bool PredicateElimination::canOmitPredicate(const Expr* expr) const {
  // Predicate elimination can be disabled with
  // PYTORCH_NVFUSER_DISABLE=predicate_elimination
  if (isOptionDisabled(DisableOption::PredicateElimination)) {
    assertOnWarpOps(expr);
    return false;
  }

  TORCH_INTERNAL_ASSERT(expr != nullptr);
  const auto out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");

  if (ir_utils::isTensorScalarFillOp(expr)) {
    if (out_tv->getMemoryType() == MemoryType::Local) {
      // Filling a local tensor with a scalar shouldn't
      // need any predicate currently.
      return true;
    } else if (out_tv->getMemoryType() == MemoryType::Shared) {
      // A shared memory fill is handled the same way, except that
      // we'd need a predicate to guard against out-of-bound accesses
      // by out-of-bound threads when the parallel dimensions are
      // inexact.
      return isExactParallelSharedMemAccess(out_tv);
    }
  }

  if (non_predicated_exprs_.find(expr) != non_predicated_exprs_.end()) {
    return true;
  }

  assertOnWarpOps(expr);
  return false;
}

void PredicateElimination::propagateRemovalInfo(
    const Expr* from,
    const Expr* to) {
  if (non_predicated_exprs_.count(from)) {
    non_predicated_exprs_.insert(to);
  }
}

Val* PredicateElimination::getInitValue(TensorView* tv) const {
  auto it = init_value_map_.find(tv);
  if (it == init_value_map_.end()) {
    return nullptr;
  }
  auto init_val = it->second;
  if (init_val == nullptr) {
    // No reduction restriction. Just use zero.
    return GpuLower::current()->kernel()->zeroVal();
  } else {
    return init_val;
  }
}

void PredicateElimination::build(Fusion* fusion) {
  traverseTo(fusion, fusion->outputs());
}

std::string PredicateElimination::toString() const {
  std::stringstream ss;
  ss << "Tensors that do not need predication:";
  for (auto expr : non_predicated_exprs_) {
    for (auto out : expr->outputs()) {
      TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
      ss << " T" << out->name();
    }
  }
  ss << "\n";
  ss << "Init values:";
  for (auto kv : init_value_map_) {
    ss << " T" << kv.first->name() << "->";
    if (kv.second == nullptr) {
      ss << "<default(0)>";
    } else {
      ss << kv.second;
    }
  }
  ss << "\n";
  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch