#include <expr_evaluator.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <non_divisible_split.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

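// Traverse the domains of each non-input tensor, collecting splits
// that may be non-divisible, i.e., splits whose input extent is not
// known to be a multiple of the split factor. Such splits are recorded
// as needing either a predicate or a run-time divisibility check.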
void NonDivisibleSplitInfo::build(Fusion* fusion) {
  const auto vals = fusion->usedMathVals();
  auto tvs = ir_utils::filterByType<TensorView>(vals);

  // Find all non-divisible splits
  for (auto tv : tvs) {
    if (tv->isFusionInput()) {
      continue;
    }
    const std::vector<Val*> domain_vals(
        tv->domain()->domain().begin(), tv->domain()->domain().end());
    current_tv_ = tv;
    clearReachability();
    traverseTo(fusion, domain_vals);
    current_tv_ = nullptr;
  }

  if (GpuLower::current() != nullptr) {
    removeRedundancy();
  }
}

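// Decide whether a split needs protection and, if so, record it as
// either a split to validate at run time (when vectorized) or a split
// to predicate. Example with illustrative sizes: splitting an extent
// of 10 by a factor of 3 yields ceilDiv(10, 3) = 4 outer iterations,
// so the outputs cover 4 * 3 = 12 indices, 2 of which are out of
// bounds and must be guarded against.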
void NonDivisibleSplitInfo::handle(Split* split) {
  if (split->in()->isBroadcast()) {
    return;
  }

  // Indicates whether this split is going to be either predicated or
  // validated at run time
  bool is_protected = false;

  if (isReachableFromInnerDomains(split->in())) {
    // Check if this split may be non-divisible
    auto maybe_non_divisible_extent = getMaybeNonDivisibleExtent(split);
    if (maybe_non_divisible_extent) {
      // If the outputs are vectorized, predication isn't
      // sufficient; the split must be divisible.
      TORCH_INTERNAL_ASSERT(
          split->outer()->getParallelType() != ParallelType::Vectorize);
      if (split->inner()->getParallelType() == ParallelType::Vectorize) {
        splits_to_validate_.insert(split);
      } else {
        // Not proven to be a divisible split
        auto gpu_lower = GpuLower::current();
        TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);

        // A split is known to be divisible only if it is validated as
        // above, exactly matches a split validated as above, or exactly
        // matches a split from a view transformation, which must be
        // divisible.
        if (gpu_lower->divisbleSplitSet().find(split) ==
            gpu_lower->divisbleSplitSet().end()) {
          splits_to_predicate_[current_tv_].push_back(split);
        }
      }

      is_protected = true;
    }
  }

  propagateReachability(split, is_protected);
}

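// Reachability tracks the inner outputs of visited splits and their
// descendants (see propagateReachability below); a further split of
// such a domain may need its own protection.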
bool NonDivisibleSplitInfo::isReachableFromInnerDomains(IterDomain* id) const {
  return inner_domains_.find(id) != inner_domains_.end();
}

void NonDivisibleSplitInfo::clearReachability() {
  inner_domains_.clear();
}

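// Example with illustrative domains: for a split I0 -> (I1, I2) with
// I2 as the inner output, I2 is always tracked. The outer output I1 is
// tracked only when I0 is itself reachable and the split is not
// already protected by a predicate or a run-time check.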
void NonDivisibleSplitInfo::propagateReachability(
    Split* split,
    bool is_protected) {
  // Propagate down the reachability information. Descendants of the
  // inner domain must be tracked.
  inner_domains_.insert(split->inner());

  // If the input of this split is reachable, propagate the
  // reachability to the outer output as well. However, if this split
  // is protected, i.e., either predicated or validated, any potential
  // effect by descendants of the outer domain is taken care of by the
  // predicate or run-time check of this split, so checking outer
  // descendants isn't required.
  if (isReachableFromInnerDomains(split->in()) && !is_protected) {
    inner_domains_.insert(split->outer());
  }
}

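// Return nullptr when the split is provably divisible, otherwise the
// extent of the ceil-divided output. E.g., with illustrative sizes, an
// extent of 12 split by a factor of 4 is divisible (12 % 4 == 0) and
// returns nullptr, whereas an extent of 10 split by 3 returns the
// extent of the output computed as ceilDiv(10, 3).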
Val* NonDivisibleSplitInfo::getMaybeNonDivisibleExtent(Split* split) const {
  ExpressionEvaluator ee(split->fusion());
  auto in_extent = ee.evaluate(split->in()->extent());
  auto factor = ee.evaluate(split->factor());

  if (in_extent.has_value() && factor.has_value() &&
      in_extent.value() % factor.value() == 0) {
    return nullptr;
  }

  // Even if the input extent is unknown, if the factor is known to
  // be 1, the split is always divisible
  if (factor.has_value() && factor.value() == 1) {
    return nullptr;
  }

  // Not proven to be divisible; return the extent of the output that
  // is computed with ceil division
  auto ceildiv_dom = split->innerSplit() ? split->outer() : split->inner();
  return ceildiv_dom->extent();
}

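// Example with illustrative sizes: merging I1 (extent 4, outer) and I2
// (extent 3, inner) produces an output indexed as i1 * 3 + i2, and
// indexing back gives i2 = i % 3 and i1 = i / 3. The remainder i2 can
// never exceed the extent of I2, but i1 can exceed the extent of I1,
// so reachability propagates only from the outer input.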
void NonDivisibleSplitInfo::handle(Merge* merge) {
  propagateReachability(merge);
}

void NonDivisibleSplitInfo::propagateReachability(Merge* merge) {
  // The index of the inner input never exceeds its extent as it's
  // computed as a remainder. The index of the outer input may.
  if (isReachableFromInnerDomains(merge->outer())) {
    inner_domains_.insert(merge->out());
  }
}

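// Drop redundant entries: only one run-time check is kept per exactly
// mapped outer domain, and a predicate is dropped when its split's
// outer output is exactly mapped to that of a split that is already
// validated at run time.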
void NonDivisibleSplitInfo::removeRedundancy() {
  auto gpu_lower = GpuLower::current();
  TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);

  std::unordered_set<IterDomain*> split_to_validate_outer;
  for (auto it = splits_to_validate_.begin();
       it != splits_to_validate_.end();) {
    auto outer_concrete = gpu_lower->caMap()->getConcreteMappedID(
        (*it)->outer(), IdMappingMode::EXACT);
    auto new_domain = split_to_validate_outer.insert(outer_concrete).second;
    if (!new_domain) {
      it = splits_to_validate_.erase(it);
    } else {
      ++it;
    }
  }

  // If validated by runtime checks, no need to predicate
  for (auto& kv : splits_to_predicate_) {
    auto& splits = kv.second;
    for (auto it = splits.begin(); it != splits.end();) {
      // If the outer domain is mapped with the outer domain of any
      // validated split, it is safe to omit the predicate for the
      // split.
      Split* split_to_predicate = *it;
      if (std::any_of(
              splits_to_validate_.begin(),
              splits_to_validate_.end(),
              [&](Split* split_to_validate) {
                return gpu_lower->caMap()->areMapped(
                    split_to_validate->outer(),
                    split_to_predicate->outer(),
                    IdMappingMode::EXACT);
              })) {
        it = splits.erase(it);
      } else {
        ++it;
      }
    }
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch