1 | #include <inlining.h> |
2 | #include <ir_utils.h> |
3 | #include <root_domain_map.h> |
4 | #include <transform_iter.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | MaxPosCalculator::MaxPosCalculator( |
14 | const std::unordered_set<IterDomain*>& uninlinable_ids) |
15 | : uninlinable_ids_(uninlinable_ids) { |
16 | buildUnmappableDims(); |
17 | } |
18 | |
19 | void MaxPosCalculator::buildUnmappableDims() { |
20 | ComputeAtRootDomainMap root_map; |
21 | root_map.build(); |
22 | auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); |
23 | for (auto tv : all_tvs) { |
24 | auto consumers = ir_utils::consumerTvsOf(tv); |
25 | for (auto consumer : consumers) { |
26 | // Grab dimensions in producer and consumer that are mappable to eachother |
27 | // based on the computeAtRootDomainMap. This will tell us which dimensions |
28 | // can be inlined based on avoiding trying to inline non-trivial |
29 | // reduction structures. |
30 | auto mappable_roots = |
31 | root_map.getMappableDims(tv->domain(), consumer->domain()); |
32 | for (auto tv_root_id : tv->getMaybeRFactorDomain()) { |
33 | if (mappable_roots.find(tv_root_id) == mappable_roots.end() && |
34 | !tv_root_id->isTrivialReduction()) { |
35 | unmappable_dims_.emplace(tv_root_id); |
36 | } |
37 | } |
38 | } |
39 | } |
40 | } |
41 | |
42 | bool MaxPosCalculator::isAllowedID( |
43 | IterDomain* id, |
44 | TensorView* tv, |
45 | bool best_effort, |
46 | bool allow_reduction, |
47 | bool allow_vectorize, |
48 | bool allow_unmappable) const { |
49 | bool allowed = true; |
50 | |
51 | if (!allow_reduction) { |
52 | allowed = allowed && !id->isReduction(); |
53 | } |
54 | |
55 | if (uninlinable_ids_.count(id)) { |
56 | return false; |
57 | } |
58 | |
59 | if (!allow_vectorize) { |
60 | // Avoid inlining if marked as Vectorize or Group. In the case of |
61 | // BestEffort and MostInlined modes, avoid Unroll as well. |
62 | bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || |
63 | id->getParallelType() == ParallelType::Group || |
64 | (best_effort && id->getParallelType() == ParallelType::Unroll); |
65 | allowed = allowed && !is_vectorize; |
66 | } |
67 | |
68 | if (!allow_unmappable) { |
69 | auto root_dom = tv->getMaybeRFactorDomain(); |
70 | std::unordered_set<Val*> root_dom_set(root_dom.begin(), root_dom.end()); |
71 | auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); |
72 | bool is_unmappable = false; |
73 | for (auto val : all_vals) { |
74 | auto id = val->as<IterDomain>(); |
75 | if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { |
76 | is_unmappable = true; |
77 | break; |
78 | } |
79 | } |
80 | allowed = allowed && !is_unmappable; |
81 | } |
82 | |
83 | return allowed; |
84 | } |
85 | |
86 | size_t MaxPosCalculator::getMaxPosSelf( |
87 | TensorView* tv, |
88 | bool best_effort, |
89 | bool allow_reduction, |
90 | bool allow_vectorize, |
91 | bool allow_unmappable) const { |
92 | auto dom = tv->domain()->domain(); |
93 | auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { |
94 | return !isAllowedID( |
95 | id, |
96 | tv, |
97 | best_effort, |
98 | allow_reduction, |
99 | allow_vectorize, |
100 | allow_unmappable); |
101 | }); |
102 | return std::distance(dom.begin(), iter); |
103 | } |
104 | |
105 | // Return the max position in producer that can be inlined to consumer |
106 | // Cannot inline: |
107 | // Vectorized dimensions in consumer |
108 | // Unrolled dimensions in consumer |
109 | size_t MaxPosCalculator::getMaxProducerPosFromConsumer( |
110 | TensorView* producer, |
111 | TensorView* consumer, |
112 | bool best_effort) const { |
113 | auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); |
114 | auto replay_CasP = |
115 | BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); |
116 | auto p2c_replay_map = replay_CasP.getReplay(); |
117 | |
118 | for (size_t producer_pos = 0; producer_pos < producer->nDims(); |
119 | producer_pos++) { |
120 | // If the producer position is mismatching with the consumer, then we can |
121 | // not inline into this position, otherwise the max producer position of |
122 | // the consumer will become invalid and expression sort will fail. |
123 | if (TransformReplay::getMatchedLeafPosWithoutReplayCasP( |
124 | consumer, producer, producer_pos + 1) < 0) { |
125 | return producer_pos; |
126 | } |
127 | auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); |
128 | if (map_it != p2c_replay_map.end()) { |
129 | auto c_id = map_it->second; |
130 | if (!isAllowedID(c_id, consumer, best_effort, true, false, true)) { |
131 | return producer_pos; |
132 | } |
133 | } |
134 | } |
135 | return producer->nDims(); |
136 | } |
137 | |
138 | size_t MaxPosCalculator::getMaxPosAll( |
139 | TensorView* tv, |
140 | bool best_effort, |
141 | bool check_siblings) { |
142 | auto max_pos = getMaxPosSelf(tv, best_effort, false, false, false); |
143 | for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { |
144 | max_pos = std::min<size_t>( |
145 | max_pos, getMaxProducerPosFromConsumer(tv, consumer_tv, best_effort)); |
146 | } |
147 | if (check_siblings) { |
148 | for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { |
149 | max_pos = std::min<size_t>( |
150 | max_pos, getMaxPosAll(sibling_tv, best_effort, false)); |
151 | } |
152 | } |
153 | return max_pos; |
154 | } |
155 | |
156 | void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) { |
157 | inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids); |
158 | } |
159 | |
160 | void inlineMost( |
161 | const std::vector<TensorView*>& tvs, |
162 | const std::unordered_set<IterDomain*>& uninlinable_ids) { |
163 | if (tvs.empty()) { |
164 | return; |
165 | } |
166 | MaxPosCalculator calc(uninlinable_ids); |
167 | for (auto tv : tvs) { |
168 | tv->inlineAt(-1, true, &calc); |
169 | } |
170 | } |
171 | |
172 | void inlineMost( |
173 | const std::unordered_set<TensorView*>& tvs, |
174 | const std::unordered_set<IterDomain*>& uninlinable_ids) { |
175 | if (tvs.empty()) { |
176 | return; |
177 | } |
178 | MaxPosCalculator calc(uninlinable_ids); |
179 | for (auto tv : tvs) { |
180 | tv->inlineAt(-1, true, &calc); |
181 | } |
182 | } |
183 | |
184 | namespace { |
185 | |
186 | // Find the positions of `selected` tensors that is mapped to the given position |
187 | // in the reference tensor. |
188 | class FindMappedPositions : public MaxInfoSpanningTree::Propagator { |
189 | std::unordered_map<TensorView*, size_t>& output_; |
190 | |
191 | public: |
192 | FindMappedPositions( |
193 | std::unordered_map<TensorView*, size_t>& output, |
194 | TensorView* reference, |
195 | int64_t reference_pos); |
196 | |
197 | ~FindMappedPositions() = default; |
198 | |
199 | virtual void propagateC2P(TensorView* from, TensorView* to) override; |
200 | virtual void propagateP2C(TensorView* from, TensorView* to) override; |
201 | virtual void propagateSibling(TensorView* from, TensorView* to) override; |
202 | }; |
203 | |
204 | FindMappedPositions::FindMappedPositions( |
205 | std::unordered_map<TensorView*, size_t>& output, |
206 | TensorView* reference, |
207 | int64_t reference_pos) |
208 | : output_(output) { |
209 | if (reference_pos < 0) { |
210 | reference_pos += int64_t(reference->nDims()) + 1; |
211 | } |
212 | TORCH_CHECK( |
213 | reference_pos >= 0 && reference_pos <= int64_t(reference->nDims()), |
214 | "Invalid axis received " , |
215 | reference_pos, |
216 | " but should be > -" , |
217 | reference->nDims(), |
218 | " and <= " , |
219 | reference->nDims(), |
220 | "." ); |
221 | output_[reference] = reference_pos; |
222 | } |
223 | |
224 | void FindMappedPositions::propagateC2P(TensorView* from, TensorView* to) { |
225 | int from_pos = output_.at(from); |
226 | auto to_pos = |
227 | TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); |
228 | // If there is no matching position found, we compute the highest matched |
229 | // position as the closest approximation |
230 | while (to_pos < 0) { |
231 | from_pos--; |
232 | to_pos = |
233 | TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); |
234 | } |
235 | output_[to] = to_pos; |
236 | } |
237 | |
238 | void FindMappedPositions::propagateP2C(TensorView* from, TensorView* to) { |
239 | int from_pos = output_.at(from); |
240 | auto to_pos = |
241 | TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); |
242 | // If there is no matching position found, we compute the highest matched |
243 | // position as the closest approximation |
244 | while (to_pos < 0) { |
245 | from_pos--; |
246 | to_pos = |
247 | TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); |
248 | } |
249 | output_[to] = to_pos; |
250 | } |
251 | |
252 | void FindMappedPositions::propagateSibling(TensorView* from, TensorView* to) { |
253 | auto from_pos = output_.at(from); |
254 | TORCH_CHECK( |
255 | TransformReplay::fullSelfMatching(to, from), |
256 | "Transformations in siblings " , |
257 | from, |
258 | " and " , |
259 | to, |
260 | " does not match with each other." ); |
261 | output_[to] = from_pos; |
262 | } |
263 | |
264 | std::unordered_map<TensorView*, size_t> getPositionsMappedTo( |
265 | TensorView* reference_tv, |
266 | int64_t reference_pos) { |
267 | std::unordered_map<TensorView*, size_t> mapped_positions; |
268 | MaxRootDomainInfoSpanningTree tree(reference_tv, reference_pos); |
269 | FindMappedPositions propagator(mapped_positions, reference_tv, reference_pos); |
270 | tree.traverse(&propagator); |
271 | return mapped_positions; |
272 | } |
273 | |
274 | } // namespace |
275 | |
276 | void inlineAllAt( |
277 | TensorView* reference_tv, |
278 | int64_t reference_pos, |
279 | bool best_effort, |
280 | const std::unordered_set<IterDomain*>& uninlinable_ids) { |
281 | auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); |
282 | MaxPosCalculator calc(uninlinable_ids); |
283 | for (auto pair : mapped_positions) { |
284 | pair.first->inlineAt(pair.second, best_effort, &calc); |
285 | } |
286 | } |
287 | |
288 | void inlineSelectedAt( |
289 | const std::unordered_set<TensorView*>& selected, |
290 | TensorView* reference_tv, |
291 | int64_t reference_pos, |
292 | bool best_effort, |
293 | const std::unordered_set<IterDomain*>& uninlinable_ids) { |
294 | auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); |
295 | MaxPosCalculator calc(uninlinable_ids); |
296 | for (auto pair : mapped_positions) { |
297 | if (selected.count(pair.first) > 0) { |
298 | pair.first->inlineAt(pair.second, best_effort, &calc); |
299 | } |
300 | } |
301 | } |
302 | |
303 | } // namespace cuda |
304 | } // namespace fuser |
305 | } // namespace jit |
306 | } // namespace torch |
307 | |