1#include <maxinfo_propagator.h>
2#include <root_domain_map.h>
3
4namespace torch {
5namespace jit {
6namespace fuser {
7namespace cuda {
8
9bool MaxInfoSpanningTree::Information::operator>(const Information& r) const {
10 return r < *this;
11}
12
13bool MaxInfoSpanningTree::Information::operator==(const Information& r) const {
14 return !(r < *this) && !(*this < r);
15}
16
17// Prim's algorithm
18MaxInfoSpanningTree::MaxInfoSpanningTree(
19 TensorView* reference,
20 std::shared_ptr<Information> reference_info,
21 Selector* selector)
22 : reference_(reference),
23 reference_info_(reference_info),
24 selector_(selector) {}
25
26void MaxInfoSpanningTree::compute_spanning_tree() {
27 // A set that allows us to quickly tell if a tensor has been replayed. If yes,
28 // then we will not bother computing if a new path to this tensor is worth
29 // taking (because the answer is always not worth)
30 std::unordered_set<TensorView*> replayed;
31
32 // A sorted list of possible next steps. The list is sorted in the order of
33 // ascending amount of preserved information about the reference tensor. The
34 // back of the list preserves the most amount of information about the
35 // reference tensor, and should always be the next step to take. We use
36 // std::list instead of std::priority_queue because C++'s
37 // std::priority_queue does not support increase-key, and might not be
38 // deterministic either.
39 std::list<NextHopWithInfo> candidates(1);
40 candidates.back().next_hop.from = nullptr;
41 candidates.back().next_hop.to = reference_;
42 candidates.back().info_to = reference_info_;
43
44 // Insert the given next hop the correct position in `candidates`. If there
45 // is an existing next hop that preserves more information, then we will just
46 // discard `info`.
47 auto insertNextHop = [&](const NextHopWithInfo& info) {
48 if (!*(info.info_from)) {
49 // When there is no more information about the starting tensor,
50 // we are not interested in continuing the path-finding.
51 return;
52 }
53 // Find if there is already a path to the dest tensor
54 auto existing = std::find_if(
55 candidates.begin(), candidates.end(), [&](const NextHopWithInfo& i) {
56 return i.next_hop.to == info.next_hop.to;
57 });
58 // Only insert if there is no existing path to the dest tensor, or the new
59 // path preserves more information about the starting tensor.
60 if (existing == candidates.end() || *existing < info) {
61 if (existing != candidates.end()) {
62 candidates.erase(existing);
63 }
64 auto pos = std::upper_bound(candidates.begin(), candidates.end(), info);
65 candidates.insert(pos, info);
66 }
67 };
68
69 auto allowC2P = [this](TensorView* from, TensorView* to) {
70 if (selector_ == nullptr) {
71 return true;
72 }
73 return selector_->allowC2P(from, to);
74 };
75
76 auto allowP2C = [this](TensorView* from, TensorView* to) {
77 if (selector_ == nullptr) {
78 return true;
79 }
80 return selector_->allowP2C(from, to);
81 };
82
83 auto allowSibling = [this](TensorView* from, TensorView* to) {
84 if (selector_ == nullptr) {
85 return true;
86 }
87 return selector_->allowSibling(from, to);
88 };
89
90 while (!candidates.empty()) {
91 const auto next_hop_info = candidates.back();
92 const auto& next_hop = next_hop_info.next_hop;
93 candidates.pop_back();
94
95 if (next_hop.from != nullptr) {
96 // nullptr used to start from reference
97 path_.push_back(next_hop);
98 }
99 replayed.emplace(next_hop.to);
100
101 for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) {
102 if (replayed.count(sibling_tv) ||
103 !allowSibling(next_hop.to, sibling_tv)) {
104 continue;
105 }
106 insertNextHop(NextHopWithInfo(
107 NextHop(NextHopType::SIBLING, next_hop.to, sibling_tv),
108 next_hop_info.info_to,
109 computeInfoSibling(next_hop.to, sibling_tv, next_hop_info.info_to)));
110 }
111
112 for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
113 if (replayed.count(consumer_tv) || !allowP2C(next_hop.to, consumer_tv)) {
114 continue;
115 }
116 insertNextHop(NextHopWithInfo(
117 NextHop(NextHopType::C_AS_P, next_hop.to, consumer_tv),
118 next_hop_info.info_to,
119 computeInfoCasP(next_hop.to, consumer_tv, next_hop_info.info_to)));
120 }
121
122 for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
123 if (replayed.count(producer_tv) || !allowC2P(next_hop.to, producer_tv)) {
124 continue;
125 }
126 insertNextHop(NextHopWithInfo(
127 NextHop(NextHopType::P_AS_C, next_hop.to, producer_tv),
128 next_hop_info.info_to,
129 computeInfoPasC(next_hop.to, producer_tv, next_hop_info.info_to)));
130 }
131 }
132}
133
134void MaxInfoSpanningTree::traverse(Propagator* propagator) {
135 if (path_.empty()) {
136 compute_spanning_tree();
137 }
138 propagator->setUp();
139 for (const auto& next_hop : path_) {
140 switch (next_hop.type) {
141 case NextHopType::SIBLING:
142 propagator->propagateSibling(next_hop.from, next_hop.to);
143 break;
144 case NextHopType::C_AS_P:
145 propagator->propagateP2C(next_hop.from, next_hop.to);
146 break;
147 case NextHopType::P_AS_C:
148 propagator->propagateC2P(next_hop.from, next_hop.to);
149 break;
150 }
151 }
152 propagator->tearDown();
153}
154
155MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
156 return !info.empty();
157}
158
159bool MaxRootDomainInfoSpanningTree::RootDomainInfo::operator<(
160 const Information& r) const {
161 auto rr = dynamic_cast<const RootDomainInfo&>(r);
162 if (info.size() != rr.info.size()) {
163 return info.size() < rr.info.size();
164 }
165 size_t l_complete =
166 std::count_if(info.begin(), info.end(), [](const RootIDInfo& i) {
167 return i.is_complete;
168 });
169 size_t r_complete =
170 std::count_if(rr.info.begin(), rr.info.end(), [](const RootIDInfo& i) {
171 return i.is_complete;
172 });
173 return l_complete < r_complete;
174}
175
176namespace {
177
178// Given `root_ids`, a list of IDs in the root domain of `tv`, find their
179// corresponding IDs in the rfactor domain of `tv`.
180std::unordered_set<IterDomain*> mapRootToRFactor(
181 TensorView* tv,
182 const std::unordered_set<IterDomain*>& root_ids) {
183 std::unordered_set<IterDomain*> mapped_rfactor_ids;
184 const auto& rfactor_dom = tv->getMaybeRFactorDomain();
185 for (auto id : rfactor_dom) {
186 if (root_ids.count(id) > 0) {
187 mapped_rfactor_ids.emplace(id);
188 continue;
189 }
190 for (auto root_id : root_ids) {
191 if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) {
192 mapped_rfactor_ids.emplace(id);
193 break;
194 }
195 }
196 }
197 return mapped_rfactor_ids;
198}
199
200// Given `rfactor_ids`, a list of IDs in the rfactor domain of `tv`, find their
201// corresponding IDs in the root domain of `tv`.
202std::unordered_set<IterDomain*> mapRFactorToRoot(
203 TensorView* tv,
204 const std::unordered_set<IterDomain*>& rfactor_ids) {
205 std::unordered_set<IterDomain*> mapped_root_ids;
206 for (auto id : tv->getRootDomain()) {
207 if (rfactor_ids.count(id) > 0) {
208 mapped_root_ids.emplace(id);
209 continue;
210 }
211 for (auto rfactor_id : rfactor_ids) {
212 if (DependencyCheck::isDependencyOf(id, rfactor_id)) {
213 mapped_root_ids.emplace(id);
214 break;
215 }
216 }
217 }
218 return mapped_root_ids;
219}
220
221} // namespace
222
// Given the preserved reference root ID info of a producer, compute
// the corresponding info in consumer. The given info may be represented by
// producer's root domain, or rfactor domain, depending on how we reached the
// producer during path-finding. If the given info is already represented with
// producer's rfactor domain, then we directly map it to the consumer's root
// domain. If the given info is represented with producer's root domain, we need
// to first map it to the rfactor domain of the producer, then we can map it to
// the consumer's root domain. The computed info will be represented by root
// domain as root domain contains the raw information.
std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
    computeInfoCasP(
        TensorView* from,
        TensorView* to,
        std::shared_ptr<Information> from_info) const {
  RootDomainInfo result;

  // This hop goes producer(from) -> consumer(to).
  TensorView* producer = from;
  TensorView* consumer = to;
  // NOTE(review): assumes `from_info` is always a RootDomainInfo; the cast
  // yields nullptr (and `->info` would crash) otherwise — confirm callers.
  const auto& producer_root_id_info =
      std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info;

  // Exact producer->consumer iteration-domain mapping for this expression.
  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  auto p2c_map = pairwise_map.mapProducerToConsumer(
      producer->domain(), consumer->domain());

  for (auto& info : producer_root_id_info) {
    RootIDInfo consumer_info;
    consumer_info.is_complete = info.is_complete;
    // The result is expressed in the consumer's root domain (see the comment
    // above the function).
    consumer_info.is_rfactor = false;

    // mapped root ids in producer -> mapped rfactor ids in producer
    std::unordered_set<IterDomain*> producer_mapped_rfactor_ids;
    if (producer->hasRFactor() && !info.is_rfactor) {
      producer_mapped_rfactor_ids = mapRootToRFactor(producer, info.mapped_ids);
    } else {
      // Already in rfactor representation (or no rfactor domain exists).
      producer_mapped_rfactor_ids = info.mapped_ids;
    }

    // mapped rfactor ids in producer -> mapped root ids in consumer
    for (auto producer_id : producer_mapped_rfactor_ids) {
      auto it = p2c_map.find(producer_id);
      if (it != p2c_map.end()) {
        consumer_info.mapped_ids.insert(it->second);
      } else {
        // An ID with no consumer counterpart means part of this root ID's
        // information is lost crossing the expression.
        consumer_info.is_complete = false;
      }
    }

    // If at least one root id in the consumer contains information
    // of this starting root id, then keep this record
    if (!consumer_info.mapped_ids.empty()) {
      result.info.push_back(consumer_info);
    }
  }
  return std::make_shared<RootDomainInfo>(std::move(result));
}
279
// Given the preserved reference root ID info of a consumer, compute
// the corresponding info in producer. The given info may be represented by
// consumer's root domain, or rfactor domain, depending on how we reached the
// consumer during path-finding. If the given info is already represented with
// consumer's root domain, then we directly map it to the producer's rfactor
// domain. If the given info is represented with consumer's rfactor domain, we
// need to first map it to the root domain of the consumer, then we can map it
// to the producer's rfactor domain. The computed info will be represented by
// rfactor domain as rfactor domain contains the raw information.
std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
    computeInfoPasC(
        TensorView* from,
        TensorView* to,
        std::shared_ptr<Information> from_info) const {
  RootDomainInfo result;

  // This hop goes consumer(from) -> producer(to).
  TensorView* producer = to;
  TensorView* consumer = from;
  // NOTE(review): assumes `from_info` is always a RootDomainInfo; the cast
  // yields nullptr (and `->info` would crash) otherwise — confirm callers.
  const auto& consumer_root_id_info =
      std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info;

  // Exact consumer->producer iteration-domain mapping for this expression.
  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
  auto c2p_map = pairwise_map.mapConsumerToProducer(
      consumer->domain(), producer->domain());

  for (auto& info : consumer_root_id_info) {
    RootIDInfo producer_info;
    producer_info.is_complete = info.is_complete;
    // The result is expressed in the producer's rfactor domain (see the
    // comment above the function and the note below).
    producer_info.is_rfactor = true;

    // mapped rfactor ids in consumer -> mapped root ids in consumer
    std::unordered_set<IterDomain*> consumer_mapped_root_ids;
    if (info.is_rfactor && consumer->hasRFactor()) {
      consumer_mapped_root_ids = mapRFactorToRoot(consumer, info.mapped_ids);
    } else {
      // Already in root representation (or no rfactor domain exists).
      consumer_mapped_root_ids = info.mapped_ids;
    }

    // mapped root ids in consumer -> mapped rfactor ids in producer
    for (auto consumer_id : consumer_mapped_root_ids) {
      auto it = c2p_map.find(consumer_id);
      if (it != c2p_map.end()) {
        producer_info.mapped_ids.insert(it->second);
      } else {
        // An ID with no producer counterpart means part of this root ID's
        // information is lost crossing the expression.
        producer_info.is_complete = false;
      }
    }

    // We will stop at the rfactor ids in producer, and will not further map
    // them into root ids in producer. This means, we only keep the unprocessed
    // raw information of a tensor. This behavior is important to make sure that
    // info is as accurate as possible throughout the path-finding.
    //
    // For example, in a C->P->C' path, we want to do
    //   C(root) -> P(rfactor) -> C'(root)
    // instead of
    //   C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
    //
    // and the above two paths do lead to different results:
    //
    // For example if you have a producer tensor
    //   root domain: [I1, I2]
    //   rfactor domain: [I3, I5]
    // where I3, I4 = split(I1), I5 = merge(I4, I2)
    // Then the P(rfactor) -> P(root) -> P(rfactor) could lead to
    // P(rfactor: {I5}) -> P(root: {I1, I2}) -> P(rfactor: {I3, I5})
    // which is not correct

    // If at least one root id in the producer contains information
    // of this starting root id, then keep this record
    if (!producer_info.mapped_ids.empty()) {
      result.info.push_back(producer_info);
    }
  }
  return std::make_shared<RootDomainInfo>(std::move(result));
}
356
357std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
358MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(TensorView* tv) {
359 RootDomainInfo result;
360 const auto& root_domain = tv->getRootDomain();
361 result.info.reserve(root_domain.size());
362 for (auto id : root_domain) {
363 result.info.emplace_back(RootIDInfo{{id}, true, false});
364 }
365 return std::make_shared<RootDomainInfo>(std::move(result));
366}
367
368std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
369MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(
370 TensorView* tv,
371 int64_t leaf_pos) {
372 if (leaf_pos < 0) {
373 leaf_pos += int64_t(tv->nDims()) + 1;
374 }
375 TORCH_CHECK(
376 leaf_pos >= 0 && leaf_pos <= int64_t(tv->nDims()),
377 "MaxRootDomainInfoSpanningTree called on an leaf_pos outside valid range.");
378 RootDomainInfo result;
379 const auto& root_domain = tv->getMaybeRFactorDomain();
380 const auto& leaf_domain = tv->domain()->domain();
381 std::unordered_set<IterDomain*> selected_leaves(
382 leaf_domain.begin(), leaf_domain.begin() + leaf_pos);
383 for (auto id : root_domain) {
384 if (selected_leaves.count(id) > 0) {
385 result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
386 continue;
387 }
388 for (auto selected_leaf_id : selected_leaves) {
389 if (DependencyCheck::isDependencyOf(id, selected_leaf_id)) {
390 result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
391 break;
392 }
393 }
394 }
395 return std::make_shared<RootDomainInfo>(std::move(result));
396}
397
398// Given the preserved reference root ID info of a tensor, compute
399// the corresponding info in its sibling. Since info has nothing to do with
400// replay state, so sibling info is always identical by definition.
401std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
402 computeInfoSibling(
403 TensorView* from,
404 TensorView* to,
405 std::shared_ptr<Information> from_info) const {
406 return from_info;
407}
408
409void SpanningTreePrinter::propagateC2P(TensorView* from, TensorView* to) {
410 stream_ << "propagateC2P" << std::endl;
411 stream_ << " from: " << from->toString() << std::endl;
412 stream_ << " to: " << to->toString() << std::endl;
413}
414
415void SpanningTreePrinter::propagateP2C(TensorView* from, TensorView* to) {
416 stream_ << "propagateP2C" << std::endl;
417 stream_ << " from: " << from->toString() << std::endl;
418 stream_ << " to: " << to->toString() << std::endl;
419}
420
421void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) {
422 stream_ << "propagateSibling" << std::endl;
423 stream_ << " from: " << from->toString() << std::endl;
424 stream_ << " to: " << to->toString() << std::endl;
425}
426
427bool SetSelector::allowC2P(TensorView* from, TensorView* to) {
428 return selected_.count(to) > 0;
429}
430
431bool SetSelector::allowP2C(TensorView* from, TensorView* to) {
432 return selected_.count(to) > 0;
433}
434
bool SetSelector::allowSibling(TensorView* from, TensorView* to) {
  // Sibling propagation is unconditionally allowed, even for tensors outside
  // the selected set (unlike allowC2P/allowP2C above).
  return true;
}
438
439} // namespace cuda
440} // namespace fuser
441} // namespace jit
442} // namespace torch
443