1#include <ir_utils.h>
2#include <iter_visitor.h>
3#include <root_domain_map.h>
4
5#include <lower_trivial_broadcast.h>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11
12ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) {
13 exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);
14
15 // Initialize the origin map with input broadcast domains
16 auto inputs = fusion->inputsAndCreated();
17 for (const auto fusion_input_tv :
18 ir_utils::filterByType<TensorView>(inputs)) {
19 for (auto root_id : fusion_input_tv->getRootDomain()) {
20 if (root_id->isBroadcast()) {
21 broadcast_origin_map_.emplace(
22 root_id, std::unordered_set<IterDomain*>({root_id}));
23 }
24 }
25 }
26 traverse(fusion);
27}
28
29bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const {
30 auto it = broadcast_to_concrete_map_.find(id);
31 return it != broadcast_to_concrete_map_.end();
32}
33
34bool ConcretizedBroadcastDomains::isUniquelyConcretized(IterDomain* id) const {
35 auto it = broadcast_to_concrete_map_.find(id);
36 return it != broadcast_to_concrete_map_.end() && it->second.size() == 1;
37}
38
39bool ConcretizedBroadcastDomains::maybeNonUniquelyConcretized(
40 IterDomain* id) const {
41 auto it = broadcast_to_concrete_map_.find(id);
42 return it != broadcast_to_concrete_map_.end() && it->second.size() > 1;
43}
44
45void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
46 // Create a new entry for each of new broadcast domains
47 auto out = bop->out()->as<TensorView>();
48 for (const auto i : c10::irange(out->getRootDomain().size())) {
49 if (bop->getBroadcastDimFlags().at(i)) {
50 auto new_bcast_id = out->getRootDomain().at(i);
51 broadcast_origin_map_.emplace(
52 new_bcast_id, std::unordered_set<IterDomain*>({new_bcast_id}));
53 }
54 }
55}
56
57void ConcretizedBroadcastDomains::handle(Expr* expr) {
58 IterVisitor::handle(expr);
59
60 // Propagate broadcast origin info from producers to consumers
61 for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
62 std::unordered_set<IterDomain*> producer_broadcasts;
63 // This assumes there's no merged broadcast axes between root and rfactor
64 // domains which is not possible at the moment. If this assumption is ever
65 // invalidated we would need to manaually propagate root IDs to rfactor IDs.
66 for (auto producer_id : producer->getMaybeRFactorDomain()) {
67 if (producer_id->isBroadcast()) {
68 producer_broadcasts.insert(producer_id);
69 }
70 }
71 if (producer_broadcasts.empty()) {
72 continue;
73 }
74
75 for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
76 auto p2c_map =
77 PairwiseRootDomainMap(producer, consumer)
78 .mapProducerToConsumer(
79 producer->domain(), consumer->domain(), producer_broadcasts);
80 for (const auto& kv : p2c_map) {
81 auto p_id = kv.first;
82 auto c_id = kv.second;
83 // If the consumer ID is a reduction (i.e., a trivial
84 // reduction), do not consider it's concretized.
85 const bool is_concretized =
86 !c_id->isBroadcast() && !c_id->isReduction();
87 auto it = broadcast_origin_map_.find(p_id);
88 TORCH_INTERNAL_ASSERT(
89 it != broadcast_origin_map_.end(),
90 "Broadcast origin info not found for producer broadcast domain: ",
91 p_id->toString(),
92 " of ",
93 producer->toString());
94 const auto& producer_origins = it->second;
95 if (is_concretized) {
96 // Keep track of all the origin domains as concretized
97 for (auto origin : producer_origins) {
98 markAsConcretized(origin, c_id);
99 }
100 } else {
101 // Not concretized yet. Propagate forward the origin info.
102 auto& consumer_origins = broadcast_origin_map_[c_id];
103 for (auto origin : producer_origins) {
104 consumer_origins.insert(origin);
105 }
106 consumer_origins.insert(c_id);
107 }
108 }
109 }
110 }
111}
112
113void ConcretizedBroadcastDomains::markAsConcretized(
114 IterDomain* broadcast_root_domain,
115 IterDomain* concrete_root_domain) {
116 std::deque<IterDomain*> child_domains({broadcast_root_domain});
117 while (!child_domains.empty()) {
118 auto child = child_domains.front();
119 child_domains.pop_front();
120 auto& concrete_ids = broadcast_to_concrete_map_[child];
121 auto inserted =
122 insertRootDomainToConcreteDomainSet(concrete_root_domain, concrete_ids);
123 if (!inserted) {
124 continue;
125 }
126 const auto& child_uses = child->uses();
127 for (auto child_use : child_uses) {
128 for (auto out_id :
129 ir_utils::filterByType<IterDomain>(child_use->outputs())) {
130 child_domains.push_back(out_id);
131 }
132 }
133 }
134}
135
136bool ConcretizedBroadcastDomains::insertRootDomainToConcreteDomainSet(
137 IterDomain* new_root_id,
138 std::unordered_set<IterDomain*>& id_set) {
139 auto has_exactly_mapped_id =
140 std::any_of(id_set.begin(), id_set.end(), [&](IterDomain* existing_id) {
141 return exact_map_->areMapped(new_root_id, existing_id);
142 });
143 if (has_exactly_mapped_id) {
144 return false;
145 } else {
146 id_set.emplace(new_root_id);
147 return true;
148 }
149}
150
151} // namespace cuda
152} // namespace fuser
153} // namespace jit
154} // namespace torch
155