1#include <inlining.h>
2#include <ir_utils.h>
3#include <root_domain_map.h>
4#include <transform_iter.h>
5
6#include <utility>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13MaxPosCalculator::MaxPosCalculator(
14 const std::unordered_set<IterDomain*>& uninlinable_ids)
15 : uninlinable_ids_(uninlinable_ids) {
16 buildUnmappableDims();
17}
18
19void MaxPosCalculator::buildUnmappableDims() {
20 ComputeAtRootDomainMap root_map;
21 root_map.build();
22 auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion());
23 for (auto tv : all_tvs) {
24 auto consumers = ir_utils::consumerTvsOf(tv);
25 for (auto consumer : consumers) {
26 // Grab dimensions in producer and consumer that are mappable to eachother
27 // based on the computeAtRootDomainMap. This will tell us which dimensions
28 // can be inlined based on avoiding trying to inline non-trivial
29 // reduction structures.
30 auto mappable_roots =
31 root_map.getMappableDims(tv->domain(), consumer->domain());
32 for (auto tv_root_id : tv->getMaybeRFactorDomain()) {
33 if (mappable_roots.find(tv_root_id) == mappable_roots.end() &&
34 !tv_root_id->isTrivialReduction()) {
35 unmappable_dims_.emplace(tv_root_id);
36 }
37 }
38 }
39 }
40}
41
42bool MaxPosCalculator::isAllowedID(
43 IterDomain* id,
44 TensorView* tv,
45 bool best_effort,
46 bool allow_reduction,
47 bool allow_vectorize,
48 bool allow_unmappable) const {
49 bool allowed = true;
50
51 if (!allow_reduction) {
52 allowed = allowed && !id->isReduction();
53 }
54
55 if (uninlinable_ids_.count(id)) {
56 return false;
57 }
58
59 if (!allow_vectorize) {
60 // Avoid inlining if marked as Vectorize or Group. In the case of
61 // BestEffort and MostInlined modes, avoid Unroll as well.
62 bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) ||
63 id->getParallelType() == ParallelType::Group ||
64 (best_effort && id->getParallelType() == ParallelType::Unroll);
65 allowed = allowed && !is_vectorize;
66 }
67
68 if (!allow_unmappable) {
69 auto root_dom = tv->getMaybeRFactorDomain();
70 std::unordered_set<Val*> root_dom_set(root_dom.begin(), root_dom.end());
71 auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id});
72 bool is_unmappable = false;
73 for (auto val : all_vals) {
74 auto id = val->as<IterDomain>();
75 if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) {
76 is_unmappable = true;
77 break;
78 }
79 }
80 allowed = allowed && !is_unmappable;
81 }
82
83 return allowed;
84}
85
86size_t MaxPosCalculator::getMaxPosSelf(
87 TensorView* tv,
88 bool best_effort,
89 bool allow_reduction,
90 bool allow_vectorize,
91 bool allow_unmappable) const {
92 auto dom = tv->domain()->domain();
93 auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) {
94 return !isAllowedID(
95 id,
96 tv,
97 best_effort,
98 allow_reduction,
99 allow_vectorize,
100 allow_unmappable);
101 });
102 return std::distance(dom.begin(), iter);
103}
104
105// Return the max position in producer that can be inlined to consumer
106// Cannot inline:
107// Vectorized dimensions in consumer
108// Unrolled dimensions in consumer
109size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
110 TensorView* producer,
111 TensorView* consumer,
112 bool best_effort) const {
113 auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
114 auto replay_CasP =
115 BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map);
116 auto p2c_replay_map = replay_CasP.getReplay();
117
118 for (size_t producer_pos = 0; producer_pos < producer->nDims();
119 producer_pos++) {
120 // If the producer position is mismatching with the consumer, then we can
121 // not inline into this position, otherwise the max producer position of
122 // the consumer will become invalid and expression sort will fail.
123 if (TransformReplay::getMatchedLeafPosWithoutReplayCasP(
124 consumer, producer, producer_pos + 1) < 0) {
125 return producer_pos;
126 }
127 auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
128 if (map_it != p2c_replay_map.end()) {
129 auto c_id = map_it->second;
130 if (!isAllowedID(c_id, consumer, best_effort, true, false, true)) {
131 return producer_pos;
132 }
133 }
134 }
135 return producer->nDims();
136}
137
138size_t MaxPosCalculator::getMaxPosAll(
139 TensorView* tv,
140 bool best_effort,
141 bool check_siblings) {
142 auto max_pos = getMaxPosSelf(tv, best_effort, false, false, false);
143 for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
144 max_pos = std::min<size_t>(
145 max_pos, getMaxProducerPosFromConsumer(tv, consumer_tv, best_effort));
146 }
147 if (check_siblings) {
148 for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) {
149 max_pos = std::min<size_t>(
150 max_pos, getMaxPosAll(sibling_tv, best_effort, false));
151 }
152 }
153 return max_pos;
154}
155
156void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
157 inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
158}
159
160void inlineMost(
161 const std::vector<TensorView*>& tvs,
162 const std::unordered_set<IterDomain*>& uninlinable_ids) {
163 if (tvs.empty()) {
164 return;
165 }
166 MaxPosCalculator calc(uninlinable_ids);
167 for (auto tv : tvs) {
168 tv->inlineAt(-1, true, &calc);
169 }
170}
171
172void inlineMost(
173 const std::unordered_set<TensorView*>& tvs,
174 const std::unordered_set<IterDomain*>& uninlinable_ids) {
175 if (tvs.empty()) {
176 return;
177 }
178 MaxPosCalculator calc(uninlinable_ids);
179 for (auto tv : tvs) {
180 tv->inlineAt(-1, true, &calc);
181 }
182}
183
184namespace {
185
186// Find the positions of `selected` tensors that is mapped to the given position
187// in the reference tensor.
188class FindMappedPositions : public MaxInfoSpanningTree::Propagator {
189 std::unordered_map<TensorView*, size_t>& output_;
190
191 public:
192 FindMappedPositions(
193 std::unordered_map<TensorView*, size_t>& output,
194 TensorView* reference,
195 int64_t reference_pos);
196
197 ~FindMappedPositions() = default;
198
199 virtual void propagateC2P(TensorView* from, TensorView* to) override;
200 virtual void propagateP2C(TensorView* from, TensorView* to) override;
201 virtual void propagateSibling(TensorView* from, TensorView* to) override;
202};
203
204FindMappedPositions::FindMappedPositions(
205 std::unordered_map<TensorView*, size_t>& output,
206 TensorView* reference,
207 int64_t reference_pos)
208 : output_(output) {
209 if (reference_pos < 0) {
210 reference_pos += int64_t(reference->nDims()) + 1;
211 }
212 TORCH_CHECK(
213 reference_pos >= 0 && reference_pos <= int64_t(reference->nDims()),
214 "Invalid axis received ",
215 reference_pos,
216 " but should be > -",
217 reference->nDims(),
218 " and <= ",
219 reference->nDims(),
220 ".");
221 output_[reference] = reference_pos;
222}
223
224void FindMappedPositions::propagateC2P(TensorView* from, TensorView* to) {
225 int from_pos = output_.at(from);
226 auto to_pos =
227 TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
228 // If there is no matching position found, we compute the highest matched
229 // position as the closest approximation
230 while (to_pos < 0) {
231 from_pos--;
232 to_pos =
233 TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
234 }
235 output_[to] = to_pos;
236}
237
238void FindMappedPositions::propagateP2C(TensorView* from, TensorView* to) {
239 int from_pos = output_.at(from);
240 auto to_pos =
241 TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
242 // If there is no matching position found, we compute the highest matched
243 // position as the closest approximation
244 while (to_pos < 0) {
245 from_pos--;
246 to_pos =
247 TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
248 }
249 output_[to] = to_pos;
250}
251
252void FindMappedPositions::propagateSibling(TensorView* from, TensorView* to) {
253 auto from_pos = output_.at(from);
254 TORCH_CHECK(
255 TransformReplay::fullSelfMatching(to, from),
256 "Transformations in siblings ",
257 from,
258 " and ",
259 to,
260 " does not match with each other.");
261 output_[to] = from_pos;
262}
263
264std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
265 TensorView* reference_tv,
266 int64_t reference_pos) {
267 std::unordered_map<TensorView*, size_t> mapped_positions;
268 MaxRootDomainInfoSpanningTree tree(reference_tv, reference_pos);
269 FindMappedPositions propagator(mapped_positions, reference_tv, reference_pos);
270 tree.traverse(&propagator);
271 return mapped_positions;
272}
273
274} // namespace
275
276void inlineAllAt(
277 TensorView* reference_tv,
278 int64_t reference_pos,
279 bool best_effort,
280 const std::unordered_set<IterDomain*>& uninlinable_ids) {
281 auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
282 MaxPosCalculator calc(uninlinable_ids);
283 for (auto pair : mapped_positions) {
284 pair.first->inlineAt(pair.second, best_effort, &calc);
285 }
286}
287
288void inlineSelectedAt(
289 const std::unordered_set<TensorView*>& selected,
290 TensorView* reference_tv,
291 int64_t reference_pos,
292 bool best_effort,
293 const std::unordered_set<IterDomain*>& uninlinable_ids) {
294 auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
295 MaxPosCalculator calc(uninlinable_ids);
296 for (auto pair : mapped_positions) {
297 if (selected.count(pair.first) > 0) {
298 pair.first->inlineAt(pair.second, best_effort, &calc);
299 }
300 }
301}
302
303} // namespace cuda
304} // namespace fuser
305} // namespace jit
306} // namespace torch
307