1 | #include <maxinfo_propagator.h> |
2 | #include <root_domain_map.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | namespace fuser { |
7 | namespace cuda { |
8 | |
9 | bool MaxInfoSpanningTree::Information::operator>(const Information& r) const { |
10 | return r < *this; |
11 | } |
12 | |
13 | bool MaxInfoSpanningTree::Information::operator==(const Information& r) const { |
14 | return !(r < *this) && !(*this < r); |
15 | } |
16 | |
// Spanning-tree builder based on Prim's algorithm: starting from `reference`,
// the tree is grown edge-by-edge, always taking the edge that preserves the
// most information about the reference tensor.
//
// `reference`      : the tensor whose information the traversal tries to
//                    preserve.
// `reference_info` : the information known about `reference` itself (the
//                    starting "full" information).
// `selector`       : optional filter deciding which edges may be traversed;
//                    nullptr means all edges are allowed.
//
// Note: the tree itself is computed lazily by compute_spanning_tree(), on
// the first call to traverse().
MaxInfoSpanningTree::MaxInfoSpanningTree(
    TensorView* reference,
    std::shared_ptr<Information> reference_info,
    Selector* selector)
    : reference_(reference),
      reference_info_(reference_info),
      selector_(selector) {}
25 | |
// Computes `path_`, the sequence of hops over which information will later be
// propagated. This is Prim's algorithm over the tensor graph: grow a spanning
// tree from `reference_`, always extending along the frontier edge that
// preserves the most information about the reference tensor.
void MaxInfoSpanningTree::compute_spanning_tree() {
  // A set that allows us to quickly tell if a tensor has been replayed. If yes,
  // then we will not bother computing if a new path to this tensor is worth
  // taking (because the answer is always not worth)
  std::unordered_set<TensorView*> replayed;

  // A sorted list of possible next steps. The list is sorted in the order of
  // ascending amount of preserved information about the reference tensor. The
  // back of the list preserves the most amount of information about the
  // reference tensor, and should always be the next step to take. We use
  // std::list instead of std::priority_queue because C++'s
  // std::priority_queue does not support increase-key, and might not be
  // deterministic either.
  std::list<NextHopWithInfo> candidates(1);
  // Seed the frontier with the reference tensor itself. `from == nullptr`
  // marks this virtual starting hop (see the check in the main loop below,
  // which excludes it from `path_`).
  candidates.back().next_hop.from = nullptr;
  candidates.back().next_hop.to = reference_;
  candidates.back().info_to = reference_info_;

  // Insert the given next hop the correct position in `candidates`. If there
  // is an existing next hop that preserves more information, then we will just
  // discard `info`.
  auto insertNextHop = [&](const NextHopWithInfo& info) {
    if (!*(info.info_from)) {
      // When there is no more information about the starting tensor,
      // we are not interested in continuing the path-finding.
      return;
    }
    // Find if there is already a path to the dest tensor
    auto existing = std::find_if(
        candidates.begin(), candidates.end(), [&](const NextHopWithInfo& i) {
          return i.next_hop.to == info.next_hop.to;
        });
    // Only insert if there is no existing path to the dest tensor, or the new
    // path preserves more information about the starting tensor.
    if (existing == candidates.end() || *existing < info) {
      if (existing != candidates.end()) {
        // This erase + re-insert pair is the "increase-key" operation a
        // std::priority_queue could not express.
        candidates.erase(existing);
      }
      // Keep `candidates` sorted ascending by preserved information.
      auto pos = std::upper_bound(candidates.begin(), candidates.end(), info);
      candidates.insert(pos, info);
    }
  };

  // The three lambdas below consult the optional selector; a null selector
  // means every edge of the corresponding kind may be traversed.
  auto allowC2P = [this](TensorView* from, TensorView* to) {
    if (selector_ == nullptr) {
      return true;
    }
    return selector_->allowC2P(from, to);
  };

  auto allowP2C = [this](TensorView* from, TensorView* to) {
    if (selector_ == nullptr) {
      return true;
    }
    return selector_->allowP2C(from, to);
  };

  auto allowSibling = [this](TensorView* from, TensorView* to) {
    if (selector_ == nullptr) {
      return true;
    }
    return selector_->allowSibling(from, to);
  };

  // Main Prim loop: repeatedly commit the best frontier edge (the back of
  // `candidates`) and expand the frontier from its destination tensor.
  while (!candidates.empty()) {
    const auto next_hop_info = candidates.back();
    const auto& next_hop = next_hop_info.next_hop;
    candidates.pop_back();

    if (next_hop.from != nullptr) {
      // nullptr used to start from reference
      path_.push_back(next_hop);
    }
    replayed.emplace(next_hop.to);

    // Expand to siblings, consumers, and producers of the newly-added
    // tensor, computing how much information survives each hop.
    for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) {
      if (replayed.count(sibling_tv) ||
          !allowSibling(next_hop.to, sibling_tv)) {
        continue;
      }
      insertNextHop(NextHopWithInfo(
          NextHop(NextHopType::SIBLING, next_hop.to, sibling_tv),
          next_hop_info.info_to,
          computeInfoSibling(next_hop.to, sibling_tv, next_hop_info.info_to)));
    }

    for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
      if (replayed.count(consumer_tv) || !allowP2C(next_hop.to, consumer_tv)) {
        continue;
      }
      insertNextHop(NextHopWithInfo(
          NextHop(NextHopType::C_AS_P, next_hop.to, consumer_tv),
          next_hop_info.info_to,
          computeInfoCasP(next_hop.to, consumer_tv, next_hop_info.info_to)));
    }

    for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
      if (replayed.count(producer_tv) || !allowC2P(next_hop.to, producer_tv)) {
        continue;
      }
      insertNextHop(NextHopWithInfo(
          NextHop(NextHopType::P_AS_C, next_hop.to, producer_tv),
          next_hop_info.info_to,
          computeInfoPasC(next_hop.to, producer_tv, next_hop_info.info_to)));
    }
  }
}
133 | |
134 | void MaxInfoSpanningTree::traverse(Propagator* propagator) { |
135 | if (path_.empty()) { |
136 | compute_spanning_tree(); |
137 | } |
138 | propagator->setUp(); |
139 | for (const auto& next_hop : path_) { |
140 | switch (next_hop.type) { |
141 | case NextHopType::SIBLING: |
142 | propagator->propagateSibling(next_hop.from, next_hop.to); |
143 | break; |
144 | case NextHopType::C_AS_P: |
145 | propagator->propagateP2C(next_hop.from, next_hop.to); |
146 | break; |
147 | case NextHopType::P_AS_C: |
148 | propagator->propagateC2P(next_hop.from, next_hop.to); |
149 | break; |
150 | } |
151 | } |
152 | propagator->tearDown(); |
153 | } |
154 | |
// True while this record still carries any information about the reference
// tensor's root domain; path-finding stops extending a path once its
// information becomes empty (see insertNextHop).
MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
  return !info.empty();
}
158 | |
159 | bool MaxRootDomainInfoSpanningTree::RootDomainInfo::operator<( |
160 | const Information& r) const { |
161 | auto rr = dynamic_cast<const RootDomainInfo&>(r); |
162 | if (info.size() != rr.info.size()) { |
163 | return info.size() < rr.info.size(); |
164 | } |
165 | size_t l_complete = |
166 | std::count_if(info.begin(), info.end(), [](const RootIDInfo& i) { |
167 | return i.is_complete; |
168 | }); |
169 | size_t r_complete = |
170 | std::count_if(rr.info.begin(), rr.info.end(), [](const RootIDInfo& i) { |
171 | return i.is_complete; |
172 | }); |
173 | return l_complete < r_complete; |
174 | } |
175 | |
176 | namespace { |
177 | |
178 | // Given `root_ids`, a list of IDs in the root domain of `tv`, find their |
179 | // corresponding IDs in the rfactor domain of `tv`. |
180 | std::unordered_set<IterDomain*> mapRootToRFactor( |
181 | TensorView* tv, |
182 | const std::unordered_set<IterDomain*>& root_ids) { |
183 | std::unordered_set<IterDomain*> mapped_rfactor_ids; |
184 | const auto& rfactor_dom = tv->getMaybeRFactorDomain(); |
185 | for (auto id : rfactor_dom) { |
186 | if (root_ids.count(id) > 0) { |
187 | mapped_rfactor_ids.emplace(id); |
188 | continue; |
189 | } |
190 | for (auto root_id : root_ids) { |
191 | if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) { |
192 | mapped_rfactor_ids.emplace(id); |
193 | break; |
194 | } |
195 | } |
196 | } |
197 | return mapped_rfactor_ids; |
198 | } |
199 | |
200 | // Given `rfactor_ids`, a list of IDs in the rfactor domain of `tv`, find their |
201 | // corresponding IDs in the root domain of `tv`. |
202 | std::unordered_set<IterDomain*> mapRFactorToRoot( |
203 | TensorView* tv, |
204 | const std::unordered_set<IterDomain*>& rfactor_ids) { |
205 | std::unordered_set<IterDomain*> mapped_root_ids; |
206 | for (auto id : tv->getRootDomain()) { |
207 | if (rfactor_ids.count(id) > 0) { |
208 | mapped_root_ids.emplace(id); |
209 | continue; |
210 | } |
211 | for (auto rfactor_id : rfactor_ids) { |
212 | if (DependencyCheck::isDependencyOf(id, rfactor_id)) { |
213 | mapped_root_ids.emplace(id); |
214 | break; |
215 | } |
216 | } |
217 | } |
218 | return mapped_root_ids; |
219 | } |
220 | |
221 | } // namespace |
222 | |
223 | // Given the preserved reference root ID info of a producer, compute |
224 | // the corresponding info in consumer. The given info may be represented by |
225 | // producer's root domain, or rfactor domain, depending on how we reached the |
226 | // producer during path-finding. If the given info is already represented with |
227 | // producer's rfactor domain, then we directly map it to the consumer's root |
228 | // domain. If the given info is represented with producer's root domain, we need |
229 | // to first map it to the rfactor domain of the producer, then we can map it to |
230 | // the consumer's root domain. The computed info will be represented by root |
231 | // domain as root domain contains the raw information. |
232 | std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree:: |
233 | computeInfoCasP( |
234 | TensorView* from, |
235 | TensorView* to, |
236 | std::shared_ptr<Information> from_info) const { |
237 | RootDomainInfo result; |
238 | |
239 | TensorView* producer = from; |
240 | TensorView* consumer = to; |
241 | const auto& producer_root_id_info = |
242 | std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info; |
243 | |
244 | auto pairwise_map = PairwiseRootDomainMap(producer, consumer); |
245 | auto p2c_map = pairwise_map.mapProducerToConsumer( |
246 | producer->domain(), consumer->domain()); |
247 | |
248 | for (auto& info : producer_root_id_info) { |
249 | RootIDInfo consumer_info; |
250 | consumer_info.is_complete = info.is_complete; |
251 | consumer_info.is_rfactor = false; |
252 | |
253 | // mapped root ids in producer -> mapped rfactor ids in producer |
254 | std::unordered_set<IterDomain*> producer_mapped_rfactor_ids; |
255 | if (producer->hasRFactor() && !info.is_rfactor) { |
256 | producer_mapped_rfactor_ids = mapRootToRFactor(producer, info.mapped_ids); |
257 | } else { |
258 | producer_mapped_rfactor_ids = info.mapped_ids; |
259 | } |
260 | |
261 | // mapped rfactor ids in producer -> mapped root ids in consumer |
262 | for (auto producer_id : producer_mapped_rfactor_ids) { |
263 | auto it = p2c_map.find(producer_id); |
264 | if (it != p2c_map.end()) { |
265 | consumer_info.mapped_ids.insert(it->second); |
266 | } else { |
267 | consumer_info.is_complete = false; |
268 | } |
269 | } |
270 | |
271 | // If at least one root id in the consumer contains information |
272 | // of this starting root id, then keep this record |
273 | if (!consumer_info.mapped_ids.empty()) { |
274 | result.info.push_back(consumer_info); |
275 | } |
276 | } |
277 | return std::make_shared<RootDomainInfo>(std::move(result)); |
278 | } |
279 | |
280 | // Given the preserved reference root ID info of a consumer, compute |
281 | // the corresponding info in producer. The given info may be represented by |
282 | // consumer's root domain, or rfactor domain, depending on how we reached the |
283 | // consumer during path-finding. If the given info is already represented with |
284 | // consumer's root domain, then we directly map it to the producer's rfactor |
285 | // domain. If the given info is represented with consumer's rfactor domain, we |
286 | // need to first map it to the root domain of the consumer, then we can map it |
287 | // to the producer's rfactor domain. The computed info will be represented by |
288 | // rfactor domain as rfactor domain contains the raw information. |
289 | std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree:: |
290 | computeInfoPasC( |
291 | TensorView* from, |
292 | TensorView* to, |
293 | std::shared_ptr<Information> from_info) const { |
294 | RootDomainInfo result; |
295 | |
296 | TensorView* producer = to; |
297 | TensorView* consumer = from; |
298 | const auto& consumer_root_id_info = |
299 | std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info; |
300 | |
301 | auto pairwise_map = PairwiseRootDomainMap(producer, consumer); |
302 | auto c2p_map = pairwise_map.mapConsumerToProducer( |
303 | consumer->domain(), producer->domain()); |
304 | |
305 | for (auto& info : consumer_root_id_info) { |
306 | RootIDInfo producer_info; |
307 | producer_info.is_complete = info.is_complete; |
308 | producer_info.is_rfactor = true; |
309 | |
310 | // mapped rfactor ids in consumer -> mapped root ids in consumer |
311 | std::unordered_set<IterDomain*> consumer_mapped_root_ids; |
312 | if (info.is_rfactor && consumer->hasRFactor()) { |
313 | consumer_mapped_root_ids = mapRFactorToRoot(consumer, info.mapped_ids); |
314 | } else { |
315 | consumer_mapped_root_ids = info.mapped_ids; |
316 | } |
317 | |
318 | // mapped root ids in consumer -> mapped rfactor ids in producer |
319 | for (auto consumer_id : consumer_mapped_root_ids) { |
320 | auto it = c2p_map.find(consumer_id); |
321 | if (it != c2p_map.end()) { |
322 | producer_info.mapped_ids.insert(it->second); |
323 | } else { |
324 | producer_info.is_complete = false; |
325 | } |
326 | } |
327 | |
328 | // We will stop at the rfactor ids in producer, and will not further map |
329 | // them into root ids in producer. This means, we only keep the unprocessed |
330 | // raw information of a tensor. This behavior is important to make sure that |
331 | // info is as accurate as possible throughout the path-finding. |
332 | // |
333 | // For example, in a C->P->C' path, we want to do |
334 | // C(root) -> P(rfactor) -> C'(root) |
335 | // instead of |
336 | // C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root) |
337 | // |
338 | // and the above two paths do lead to different results: |
339 | // |
340 | // For example if you have a producer tensor |
341 | // root domain: [I1, I2] |
342 | // rfactor domain: [I3, I5] |
343 | // where I3, I4 = split(I1), I5 = merge(I4, I2) |
344 | // Then the P(rfactor) -> P(root) -> P(rfactor) could lead to |
345 | // P(rfactor: {I5}) -> P(root: {I1, I2}) -> P(rfactor: {I3, I5}) |
346 | // which is not correct |
347 | |
348 | // If at least one root id in the producer contains information |
349 | // of this starting root id, then keep this record |
350 | if (!producer_info.mapped_ids.empty()) { |
351 | result.info.push_back(producer_info); |
352 | } |
353 | } |
354 | return std::make_shared<RootDomainInfo>(std::move(result)); |
355 | } |
356 | |
357 | std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo> |
358 | MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(TensorView* tv) { |
359 | RootDomainInfo result; |
360 | const auto& root_domain = tv->getRootDomain(); |
361 | result.info.reserve(root_domain.size()); |
362 | for (auto id : root_domain) { |
363 | result.info.emplace_back(RootIDInfo{{id}, true, false}); |
364 | } |
365 | return std::make_shared<RootDomainInfo>(std::move(result)); |
366 | } |
367 | |
368 | std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo> |
369 | MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo( |
370 | TensorView* tv, |
371 | int64_t leaf_pos) { |
372 | if (leaf_pos < 0) { |
373 | leaf_pos += int64_t(tv->nDims()) + 1; |
374 | } |
375 | TORCH_CHECK( |
376 | leaf_pos >= 0 && leaf_pos <= int64_t(tv->nDims()), |
377 | "MaxRootDomainInfoSpanningTree called on an leaf_pos outside valid range." ); |
378 | RootDomainInfo result; |
379 | const auto& root_domain = tv->getMaybeRFactorDomain(); |
380 | const auto& leaf_domain = tv->domain()->domain(); |
381 | std::unordered_set<IterDomain*> selected_leaves( |
382 | leaf_domain.begin(), leaf_domain.begin() + leaf_pos); |
383 | for (auto id : root_domain) { |
384 | if (selected_leaves.count(id) > 0) { |
385 | result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); |
386 | continue; |
387 | } |
388 | for (auto selected_leaf_id : selected_leaves) { |
389 | if (DependencyCheck::isDependencyOf(id, selected_leaf_id)) { |
390 | result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); |
391 | break; |
392 | } |
393 | } |
394 | } |
395 | return std::make_shared<RootDomainInfo>(std::move(result)); |
396 | } |
397 | |
// Given the preserved reference root ID info of a tensor, compute
// the corresponding info in its sibling. Since info has nothing to do with
// replay state, so sibling info is always identical by definition.
//
// `from`/`to` are intentionally unused: siblings share the exact same
// domains, so the information is passed through untouched.
std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
    computeInfoSibling(
        TensorView* from,
        TensorView* to,
        std::shared_ptr<Information> from_info) const {
  return from_info;
}
408 | |
409 | void SpanningTreePrinter::propagateC2P(TensorView* from, TensorView* to) { |
410 | stream_ << "propagateC2P" << std::endl; |
411 | stream_ << " from: " << from->toString() << std::endl; |
412 | stream_ << " to: " << to->toString() << std::endl; |
413 | } |
414 | |
415 | void SpanningTreePrinter::propagateP2C(TensorView* from, TensorView* to) { |
416 | stream_ << "propagateP2C" << std::endl; |
417 | stream_ << " from: " << from->toString() << std::endl; |
418 | stream_ << " to: " << to->toString() << std::endl; |
419 | } |
420 | |
421 | void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) { |
422 | stream_ << "propagateSibling" << std::endl; |
423 | stream_ << " from: " << from->toString() << std::endl; |
424 | stream_ << " to: " << to->toString() << std::endl; |
425 | } |
426 | |
427 | bool SetSelector::allowC2P(TensorView* from, TensorView* to) { |
428 | return selected_.count(to) > 0; |
429 | } |
430 | |
431 | bool SetSelector::allowP2C(TensorView* from, TensorView* to) { |
432 | return selected_.count(to) > 0; |
433 | } |
434 | |
// Sibling edges are never filtered: siblings share the same domains, so
// propagating to them is always safe regardless of the selected set.
bool SetSelector::allowSibling(TensorView* from, TensorView* to) {
  return true;
}
438 | |
439 | } // namespace cuda |
440 | } // namespace fuser |
441 | } // namespace jit |
442 | } // namespace torch |
443 | |