1#include <compute_at.h>
2#include <instrumentation.h>
3#include <ir_all_nodes.h>
4#include <ir_iostream.h>
5#include <ir_utils.h>
6#include <lower_utils.h>
7#include <root_domain_map.h>
8#include <transform_iter.h>
9
10#include <c10/util/irange.h>
11
12namespace torch {
13namespace jit {
14namespace fuser {
15namespace cuda {
16
17// Simple selector that only propagates across tensor views in the provided
18// unordered_set. Will also propagate to all consumers of those tensors, and the
19// siblings of those tensors.
20class ComputeAtSelector : public MaxInfoSpanningTree::Selector {
21 std::unordered_set<TensorView*> selected_;
22
23 public:
24 virtual bool allowC2P(TensorView* from, TensorView* to) override {
25 return selected_.count(to) > 0;
26 }
27
28 virtual bool allowP2C(TensorView* from, TensorView* to) override {
29 // If the producer is in the selected set, then the consumer must also be
30 // replayed to obtain a compatible loop structure so that this producer
31 // can be consumed in this loop.
32 return selected_.count(from) > 0 || selected_.count(to) > 0;
33 }
34
35 virtual bool allowSibling(TensorView* from, TensorView* to) override {
36 return true;
37 }
38
39 ComputeAtSelector(std::unordered_set<TensorView*> selected)
40 : selected_(std::move(selected)) {}
41 const std::unordered_set<TensorView*>& selected() const {
42 return selected_;
43 }
44};
45
46namespace {
47
// Returns the elements present in both `set1` and `set2`.
//
// Thin wrapper around std::set_intersection. The sets' own comparator
// (key_comp()) is forwarded to the algorithm so the merge uses exactly the
// ordering the sets are sorted by. Without it, std::set_intersection falls
// back to the built-in operator<, which (a) differs from the set ordering for
// custom comparators and (b) for pointer keys, such as TensorView*, is
// unspecified between unrelated objects, whereas std::less guarantees a strict
// total order.
//
// The comparator is a deduced template parameter, so existing calls with
// default-ordered std::set<T> keep working unchanged.
template <typename T, typename Compare>
std::set<T, Compare> set_intersection(
    const std::set<T, Compare>& set1,
    const std::set<T, Compare>& set2) {
  std::set<T, Compare> intersection(set1.key_comp());
  std::set_intersection(
      set1.begin(),
      set1.end(),
      set2.begin(),
      set2.end(),
      std::inserter(intersection, intersection.end()),
      set1.key_comp());
  return intersection;
}
60
61std::deque<std::deque<TensorView*>> tvChains(
62 std::deque<std::deque<Val*>> val_chains) {
63 std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
64 for (const auto i : c10::irange(val_chains.size())) {
65 auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
66 tv_chains[i] =
67 std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());
68 }
69 return tv_chains;
70}
71
72std::unordered_set<TensorView*> getAllTVsBetween(
73 TensorView* producer,
74 TensorView* consumer) {
75 TORCH_CHECK(
76 DependencyCheck::isDependencyOf(producer, consumer),
77 "Compute At expects ",
78 producer->name(),
79 " is a dependency of ",
80 consumer->name(),
81 ", however it is not.");
82 auto between_vals =
83 DependencyCheck::getAllValsBetween({producer}, {consumer});
84 auto between_tvs = ir_utils::filterByType<TensorView>(between_vals);
85 std::unordered_set<TensorView*> result(
86 between_tvs.begin(), between_tvs.end());
87 result.erase(consumer);
88 return result;
89}
90
91TensorView* getCommonConsumer(TensorView* producer, TensorView* consumer) {
92 FUSER_PERF_SCOPE("ComputeAt::setCommonConsumer");
93 auto producer_use_chains_ =
94 tvChains(DependencyCheck::getAllUseChains(producer));
95
96 // Convert the first chain to a set.
97 std::set<TensorView*> common_consumers(
98 producer_use_chains_.front().begin(), producer_use_chains_.front().end());
99
100 // Run through all use chains of producer, and intersect them to find common
101 // TVs
102 for (auto tv_chain : producer_use_chains_) {
103 common_consumers = set_intersection(
104 common_consumers,
105 std::set<TensorView*>(tv_chain.begin(), tv_chain.end()));
106 }
107
108 auto all_chains =
109 tvChains(DependencyCheck::getAllDependencyChains(producer, consumer));
110
111 // Right now we only support compute at if at some point in the graph consumer
112 // is dependent on producer.
113 TORCH_CHECK(
114 !all_chains.empty(),
115 "Compute At expects ",
116 producer->name(),
117 " is a dependency of ",
118 consumer->name(),
119 ", however it is not.");
120
121 // Remove all TVs from producer to consumer as common consumer must be at or
122 // after consumer
123 for (const auto& tv_chain : all_chains) {
124 for (auto tv : tv_chain) {
125 if (tv != consumer)
126 common_consumers.erase(tv);
127 }
128 }
129
130 // If there is a common consumer, grab the first one at or after consumer
131 TensorView* common_consumer = nullptr;
132 if (!common_consumers.empty()) {
133 for (auto tv : producer_use_chains_.front()) {
134 if (common_consumers.find(tv) != common_consumers.end()) {
135 common_consumer = tv;
136 break;
137 }
138 }
139 TORCH_INTERNAL_ASSERT(
140 common_consumer != nullptr,
141 "Hit a logical inconsistency in the computeAt pass.");
142 }
143 return common_consumer;
144}
145
146void pullInSiblings(std::unordered_set<TensorView*>& s) {
147 for (auto tv : s) {
148 for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) {
149 if (sibling_tv == tv) {
150 continue;
151 }
152 s.emplace(sibling_tv);
153 }
154 }
155}
156
157// I am just trying to get the same set of tensors being transformed matching
158// the previous behavior of ComputeAt. The algorithm to compute this set is
159// horrible, but I don't care because I will eventually completely remove
160// ComputeAt, and this algorihtm is not worse than the pervious ComputeAt. :)
161std::unordered_set<TensorView*> getPropagationSubgraph(
162 TensorView* producer,
163 TensorView* consumer) {
164 TORCH_CHECK(
165 DependencyCheck::isDependencyOf(producer, consumer),
166 "Compute At expects ",
167 producer->name(),
168 " is a dependency of ",
169 consumer->name(),
170 ", however it is not.");
171 TensorView* common_consumer = getCommonConsumer(producer, consumer);
172 if (common_consumer != nullptr) {
173 auto result = getAllTVsBetween(producer, common_consumer);
174 pullInSiblings(result);
175 return result;
176 }
177 auto result_vals = DependencyCheck::getAllDependentVals({producer});
178 result_vals.emplace(producer);
179 auto result_tvs = ir_utils::filterByType<TensorView>(result_vals);
180 std::unordered_set<TensorView*> result;
181 std::copy_if(
182 result_tvs.begin(),
183 result_tvs.end(),
184 std::inserter(result, result.begin()),
185 [](TensorView* tv) { return !tv->uses().empty(); });
186 pullInSiblings(result);
187 return result;
188}
189
190} // namespace
191
192void ComputeAt::runAt(
193 TensorView* producer,
194 TensorView* consumer,
195 int64_t consumer_position,
196 ComputeAtMode mode) {
197 FUSER_PERF_SCOPE("ComputeAt::runAt");
198
199 // Make sure the correct fusion is setup between this and consumer.
200 TORCH_CHECK(
201 producer->fusion() == consumer->fusion(),
202 producer,
203 " and ",
204 consumer,
205 " are not in the same fusion.");
206
207 if (mode == ComputeAtMode::MostInlined) {
208 consumer_position = -1;
209 }
210
211 FusionGuard fg(producer->fusion());
212
213 auto selected = getPropagationSubgraph(producer, consumer);
214 ComputeAtSelector selector(selected);
215
216 MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
217
218 if (mode == ComputeAtMode::MostInlined) {
219 MostInlinedTransformPropagator propagator;
220 path.traverse(&propagator);
221 inlineMost(selected);
222 } else {
223 TransformPropagator propagator(consumer, consumer_position);
224 path.traverse(&propagator);
225 inlineSelectedAt(
226 selected,
227 consumer,
228 consumer_position,
229 mode == ComputeAtMode::BestEffort);
230 }
231}
232
233void ComputeAt::runWith(
234 TensorView* producer,
235 TensorView* consumer,
236 int64_t producer_position,
237 ComputeAtMode mode) {
238 FUSER_PERF_SCOPE("ComputeAt::runWith");
239
240 // Make sure the correct fusion is setup between this and consumer.
241 TORCH_CHECK(
242 producer->fusion() == consumer->fusion(),
243 producer,
244 " and ",
245 consumer,
246 " are not in the same fusion.");
247
248 if (mode == ComputeAtMode::MostInlined) {
249 producer_position = -1;
250 }
251
252 FusionGuard fg(producer->fusion());
253
254 auto selected = getPropagationSubgraph(producer, consumer);
255 ComputeAtSelector selector(selected);
256
257 MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
258
259 if (mode == ComputeAtMode::MostInlined) {
260 MostInlinedTransformPropagator propagator;
261 path.traverse(&propagator);
262 inlineMost(selected);
263 } else {
264 TransformPropagator propagator(producer, producer_position);
265 path.traverse(&propagator);
266 inlineSelectedAt(
267 selected,
268 producer,
269 producer_position,
270 mode == ComputeAtMode::BestEffort);
271 }
272}
273
274} // namespace cuda
275} // namespace fuser
276} // namespace jit
277} // namespace torch
278