1 | #include <expr_evaluator.h> |
2 | #include <ir_iostream.h> |
3 | #include <ir_utils.h> |
4 | #include <lower2device.h> |
5 | #include <lower_utils.h> |
6 | #include <non_divisible_split.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | void NonDivisibleSplitInfo::build(Fusion* fusion) { |
14 | const auto vals = fusion->usedMathVals(); |
15 | auto tvs = ir_utils::filterByType<TensorView>(vals); |
16 | |
17 | // Find all non-divisible splits |
18 | for (auto tv : tvs) { |
19 | if (tv->isFusionInput()) { |
20 | continue; |
21 | } |
22 | const std::vector<Val*> domain_vals( |
23 | tv->domain()->domain().begin(), tv->domain()->domain().end()); |
24 | current_tv_ = tv; |
25 | clearReachability(); |
26 | traverseTo(fusion, domain_vals); |
27 | current_tv_ = nullptr; |
28 | } |
29 | |
30 | if (GpuLower::current() != nullptr) { |
31 | removeRedundancy(); |
32 | } |
33 | } |
34 | |
35 | void NonDivisibleSplitInfo::handle(Split* split) { |
36 | if (split->in()->isBroadcast()) { |
37 | return; |
38 | } |
39 | |
40 | // Indicates if this split is going to be either predicated or |
41 | // validated at run time |
42 | bool is_protected = false; |
43 | |
44 | if (isReachableFromInnerDomains(split->in())) { |
45 | // check if this split may be non-divisible |
46 | auto maybe_non_divisible_extent = getMaybeNonDivisibleExtent(split); |
47 | if (maybe_non_divisible_extent) { |
48 | // If the outputs are vectorized, predication isn't |
49 | // sufficient, it must be divisible. |
50 | TORCH_INTERNAL_ASSERT( |
51 | split->outer()->getParallelType() != ParallelType::Vectorize); |
52 | if (split->inner()->getParallelType() == ParallelType::Vectorize) { |
53 | splits_to_validate_.insert(split); |
54 | } else { |
55 | // Not proven to be a divisible split |
56 | auto gpu_lower = GpuLower::current(); |
57 | TORCH_INTERNAL_ASSERT(gpu_lower != nullptr); |
58 | |
59 | // If we know this split must be divisible, it's either validated as |
60 | // above, exact matches to a case matching the above, or exact matches |
61 | // to a transformation from view which must be divisible. |
62 | if (gpu_lower->divisbleSplitSet().find(split) == |
63 | gpu_lower->divisbleSplitSet().end()) { |
64 | splits_to_predicate_[current_tv_].push_back(split); |
65 | } |
66 | } |
67 | |
68 | is_protected = true; |
69 | } |
70 | } |
71 | |
72 | propagateReachability(split, is_protected); |
73 | } |
74 | |
75 | bool NonDivisibleSplitInfo::isReachableFromInnerDomains(IterDomain* id) const { |
76 | return inner_domains_.find(id) != inner_domains_.end(); |
77 | } |
78 | |
79 | void NonDivisibleSplitInfo::clearReachability() { |
80 | inner_domains_.clear(); |
81 | } |
82 | |
83 | void NonDivisibleSplitInfo::propagateReachability( |
84 | Split* split, |
85 | bool is_protected) { |
86 | // Propagate down the reachability information. Descendants of the |
87 | // inner domain must be tracked. |
88 | inner_domains_.insert(split->inner()); |
89 | |
90 | // If this split itself is reachable, propagate the reachability to |
91 | // the outer output as well. However, if this split is protected, |
92 | // i.e., either predicated or validated, any potential effect by |
93 | // descendants of the outer domain is taken care by the predicate or |
94 | // run-time check of this split, so checking outer descendants isn't |
95 | // required. |
96 | if (isReachableFromInnerDomains(split->in()) && !is_protected) { |
97 | inner_domains_.insert(split->outer()); |
98 | } |
99 | } |
100 | |
101 | Val* NonDivisibleSplitInfo::getMaybeNonDivisibleExtent(Split* split) const { |
102 | ExpressionEvaluator ee(split->fusion()); |
103 | auto in_extent = ee.evaluate(split->in()->extent()); |
104 | auto factor = ee.evaluate(split->factor()); |
105 | |
106 | if (in_extent.has_value() && factor.has_value() && |
107 | in_extent.value() % factor.value() == 0) { |
108 | return nullptr; |
109 | } |
110 | |
111 | // even if the extent size is unknown, if the factor is known to |
112 | // be 1, it's always divisible |
113 | if (factor.has_value() && factor.value() == 1) { |
114 | return nullptr; |
115 | } |
116 | |
117 | auto ceildiv_dom = split->innerSplit() ? split->outer() : split->inner(); |
118 | return ceildiv_dom->extent(); |
119 | } |
120 | |
121 | void NonDivisibleSplitInfo::handle(Merge* merge) { |
122 | propagateReachability(merge); |
123 | } |
124 | |
125 | void NonDivisibleSplitInfo::propagateReachability(Merge* merge) { |
126 | // Inner input index never exceeds its extent as it's computed as an |
127 | // remainder. Outer may do. |
128 | if (isReachableFromInnerDomains(merge->outer())) { |
129 | inner_domains_.insert(merge->out()); |
130 | } |
131 | } |
132 | |
133 | void NonDivisibleSplitInfo::removeRedundancy() { |
134 | auto gpu_lower = GpuLower::current(); |
135 | TORCH_INTERNAL_ASSERT(gpu_lower != nullptr); |
136 | |
137 | std::unordered_set<IterDomain*> split_to_validate_outer; |
138 | for (auto it = splits_to_validate_.begin(); |
139 | it != splits_to_validate_.end();) { |
140 | auto outer_concrete = gpu_lower->caMap()->getConcreteMappedID( |
141 | (*it)->outer(), IdMappingMode::EXACT); |
142 | auto new_domain = split_to_validate_outer.insert(outer_concrete).second; |
143 | if (!new_domain) { |
144 | it = splits_to_validate_.erase(it); |
145 | } else { |
146 | ++it; |
147 | } |
148 | } |
149 | |
150 | // If validated by runtime checks, no need to predicate |
151 | for (auto& kv : splits_to_predicate_) { |
152 | auto& splits = kv.second; |
153 | for (auto it = splits.begin(); it != splits.end();) { |
154 | // If the outer domain is mapped with the outer domain of any |
155 | // validated domain, it is safe to omit the predicate for the |
156 | // split. |
157 | Split* split_to_predicate = *it; |
158 | if (std::any_of( |
159 | splits_to_validate_.begin(), |
160 | splits_to_validate_.end(), |
161 | [&](Split* split_to_validate) { |
162 | return gpu_lower->caMap()->areMapped( |
163 | split_to_validate->outer(), |
164 | split_to_predicate->outer(), |
165 | IdMappingMode::EXACT); |
166 | })) { |
167 | it = splits.erase(it); |
168 | } else { |
169 | ++it; |
170 | } |
171 | } |
172 | } |
173 | } |
174 | |
175 | } // namespace cuda |
176 | } // namespace fuser |
177 | } // namespace jit |
178 | } // namespace torch |
179 | |