1 | #include <ir_utils.h> |
2 | #include <iter_visitor.h> |
3 | #include <root_domain_map.h> |
4 | |
5 | #include <lower_trivial_broadcast.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
12 | ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { |
13 | exact_map_ = std::make_unique<ExactRootDomainMap>(fusion); |
14 | |
15 | // Initialize the origin map with input broadcast domains |
16 | auto inputs = fusion->inputsAndCreated(); |
17 | for (const auto fusion_input_tv : |
18 | ir_utils::filterByType<TensorView>(inputs)) { |
19 | for (auto root_id : fusion_input_tv->getRootDomain()) { |
20 | if (root_id->isBroadcast()) { |
21 | broadcast_origin_map_.emplace( |
22 | root_id, std::unordered_set<IterDomain*>({root_id})); |
23 | } |
24 | } |
25 | } |
26 | traverse(fusion); |
27 | } |
28 | |
29 | bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const { |
30 | auto it = broadcast_to_concrete_map_.find(id); |
31 | return it != broadcast_to_concrete_map_.end(); |
32 | } |
33 | |
34 | bool ConcretizedBroadcastDomains::isUniquelyConcretized(IterDomain* id) const { |
35 | auto it = broadcast_to_concrete_map_.find(id); |
36 | return it != broadcast_to_concrete_map_.end() && it->second.size() == 1; |
37 | } |
38 | |
39 | bool ConcretizedBroadcastDomains::maybeNonUniquelyConcretized( |
40 | IterDomain* id) const { |
41 | auto it = broadcast_to_concrete_map_.find(id); |
42 | return it != broadcast_to_concrete_map_.end() && it->second.size() > 1; |
43 | } |
44 | |
45 | void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) { |
46 | // Create a new entry for each of new broadcast domains |
47 | auto out = bop->out()->as<TensorView>(); |
48 | for (const auto i : c10::irange(out->getRootDomain().size())) { |
49 | if (bop->getBroadcastDimFlags().at(i)) { |
50 | auto new_bcast_id = out->getRootDomain().at(i); |
51 | broadcast_origin_map_.emplace( |
52 | new_bcast_id, std::unordered_set<IterDomain*>({new_bcast_id})); |
53 | } |
54 | } |
55 | } |
56 | |
57 | void ConcretizedBroadcastDomains::handle(Expr* expr) { |
58 | IterVisitor::handle(expr); |
59 | |
60 | // Propagate broadcast origin info from producers to consumers |
61 | for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) { |
62 | std::unordered_set<IterDomain*> producer_broadcasts; |
63 | // This assumes there's no merged broadcast axes between root and rfactor |
64 | // domains which is not possible at the moment. If this assumption is ever |
65 | // invalidated we would need to manaually propagate root IDs to rfactor IDs. |
66 | for (auto producer_id : producer->getMaybeRFactorDomain()) { |
67 | if (producer_id->isBroadcast()) { |
68 | producer_broadcasts.insert(producer_id); |
69 | } |
70 | } |
71 | if (producer_broadcasts.empty()) { |
72 | continue; |
73 | } |
74 | |
75 | for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) { |
76 | auto p2c_map = |
77 | PairwiseRootDomainMap(producer, consumer) |
78 | .mapProducerToConsumer( |
79 | producer->domain(), consumer->domain(), producer_broadcasts); |
80 | for (const auto& kv : p2c_map) { |
81 | auto p_id = kv.first; |
82 | auto c_id = kv.second; |
83 | // If the consumer ID is a reduction (i.e., a trivial |
84 | // reduction), do not consider it's concretized. |
85 | const bool is_concretized = |
86 | !c_id->isBroadcast() && !c_id->isReduction(); |
87 | auto it = broadcast_origin_map_.find(p_id); |
88 | TORCH_INTERNAL_ASSERT( |
89 | it != broadcast_origin_map_.end(), |
90 | "Broadcast origin info not found for producer broadcast domain: " , |
91 | p_id->toString(), |
92 | " of " , |
93 | producer->toString()); |
94 | const auto& producer_origins = it->second; |
95 | if (is_concretized) { |
96 | // Keep track of all the origin domains as concretized |
97 | for (auto origin : producer_origins) { |
98 | markAsConcretized(origin, c_id); |
99 | } |
100 | } else { |
101 | // Not concretized yet. Propagate forward the origin info. |
102 | auto& consumer_origins = broadcast_origin_map_[c_id]; |
103 | for (auto origin : producer_origins) { |
104 | consumer_origins.insert(origin); |
105 | } |
106 | consumer_origins.insert(c_id); |
107 | } |
108 | } |
109 | } |
110 | } |
111 | } |
112 | |
113 | void ConcretizedBroadcastDomains::markAsConcretized( |
114 | IterDomain* broadcast_root_domain, |
115 | IterDomain* concrete_root_domain) { |
116 | std::deque<IterDomain*> child_domains({broadcast_root_domain}); |
117 | while (!child_domains.empty()) { |
118 | auto child = child_domains.front(); |
119 | child_domains.pop_front(); |
120 | auto& concrete_ids = broadcast_to_concrete_map_[child]; |
121 | auto inserted = |
122 | insertRootDomainToConcreteDomainSet(concrete_root_domain, concrete_ids); |
123 | if (!inserted) { |
124 | continue; |
125 | } |
126 | const auto& child_uses = child->uses(); |
127 | for (auto child_use : child_uses) { |
128 | for (auto out_id : |
129 | ir_utils::filterByType<IterDomain>(child_use->outputs())) { |
130 | child_domains.push_back(out_id); |
131 | } |
132 | } |
133 | } |
134 | } |
135 | |
136 | bool ConcretizedBroadcastDomains::insertRootDomainToConcreteDomainSet( |
137 | IterDomain* new_root_id, |
138 | std::unordered_set<IterDomain*>& id_set) { |
139 | auto has_exactly_mapped_id = |
140 | std::any_of(id_set.begin(), id_set.end(), [&](IterDomain* existing_id) { |
141 | return exact_map_->areMapped(new_root_id, existing_id); |
142 | }); |
143 | if (has_exactly_mapped_id) { |
144 | return false; |
145 | } else { |
146 | id_set.emplace(new_root_id); |
147 | return true; |
148 | } |
149 | } |
150 | |
151 | } // namespace cuda |
152 | } // namespace fuser |
153 | } // namespace jit |
154 | } // namespace torch |
155 | |