#include <compute_at.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower_utils.h>
#include <root_domain_map.h>
#include <transform_iter.h>

#include <c10/util/irange.h>

#include <vector>
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | // Simple selector that only propagates across tensor views in the provided |
18 | // unordered_set. Will also propagate to all consumers of those tensors, and the |
19 | // siblings of those tensors. |
20 | class ComputeAtSelector : public MaxInfoSpanningTree::Selector { |
21 | std::unordered_set<TensorView*> selected_; |
22 | |
23 | public: |
24 | virtual bool allowC2P(TensorView* from, TensorView* to) override { |
25 | return selected_.count(to) > 0; |
26 | } |
27 | |
28 | virtual bool allowP2C(TensorView* from, TensorView* to) override { |
29 | // If the producer is in the selected set, then the consumer must also be |
30 | // replayed to obtain a compatible loop structure so that this producer |
31 | // can be consumed in this loop. |
32 | return selected_.count(from) > 0 || selected_.count(to) > 0; |
33 | } |
34 | |
35 | virtual bool allowSibling(TensorView* from, TensorView* to) override { |
36 | return true; |
37 | } |
38 | |
39 | ComputeAtSelector(std::unordered_set<TensorView*> selected) |
40 | : selected_(std::move(selected)) {} |
41 | const std::unordered_set<TensorView*>& selected() const { |
42 | return selected_; |
43 | } |
44 | }; |
45 | |
46 | namespace { |
47 | |
48 | // Wrapper around set_intersection |
// Wrapper around std::set_intersection: returns the elements present in both
// input sets as a new std::set.
template <typename T>
std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
  std::set<T> result;
  std::set_intersection(
      set1.begin(),
      set1.end(),
      set2.begin(),
      set2.end(),
      std::inserter(result, result.end()));
  return result;
}
60 | |
61 | std::deque<std::deque<TensorView*>> tvChains( |
62 | std::deque<std::deque<Val*>> val_chains) { |
63 | std::deque<std::deque<TensorView*>> tv_chains(val_chains.size()); |
64 | for (const auto i : c10::irange(val_chains.size())) { |
65 | auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]); |
66 | tv_chains[i] = |
67 | std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end()); |
68 | } |
69 | return tv_chains; |
70 | } |
71 | |
72 | std::unordered_set<TensorView*> getAllTVsBetween( |
73 | TensorView* producer, |
74 | TensorView* consumer) { |
75 | TORCH_CHECK( |
76 | DependencyCheck::isDependencyOf(producer, consumer), |
77 | "Compute At expects " , |
78 | producer->name(), |
79 | " is a dependency of " , |
80 | consumer->name(), |
81 | ", however it is not." ); |
82 | auto between_vals = |
83 | DependencyCheck::getAllValsBetween({producer}, {consumer}); |
84 | auto between_tvs = ir_utils::filterByType<TensorView>(between_vals); |
85 | std::unordered_set<TensorView*> result( |
86 | between_tvs.begin(), between_tvs.end()); |
87 | result.erase(consumer); |
88 | return result; |
89 | } |
90 | |
// Finds the first TensorView (in producer's first use chain order) that lies
// on *every* use chain of producer and is at or after consumer. Returns
// nullptr when no such common consumer exists. Errors out if consumer does
// not depend on producer.
TensorView* getCommonConsumer(TensorView* producer, TensorView* consumer) {
  FUSER_PERF_SCOPE("ComputeAt::setCommonConsumer");
  auto producer_use_chains_ =
      tvChains(DependencyCheck::getAllUseChains(producer));

  // Convert the first chain to a set.
  std::set<TensorView*> common_consumers(
      producer_use_chains_.front().begin(), producer_use_chains_.front().end());

  // Run through all use chains of producer, and intersect them to find common
  // TVs
  for (auto tv_chain : producer_use_chains_) {
    common_consumers = set_intersection(
        common_consumers,
        std::set<TensorView*>(tv_chain.begin(), tv_chain.end()));
  }

  auto all_chains =
      tvChains(DependencyCheck::getAllDependencyChains(producer, consumer));

  // Right now we only support compute at if at some point in the graph consumer
  // is dependent on producer.
  TORCH_CHECK(
      !all_chains.empty(),
      "Compute At expects ",
      producer->name(),
      " is a dependency of ",
      consumer->name(),
      ", however it is not.");

  // Remove all TVs from producer to consumer as common consumer must be at or
  // after consumer
  for (const auto& tv_chain : all_chains) {
    for (auto tv : tv_chain) {
      if (tv != consumer)
        common_consumers.erase(tv);
    }
  }

  // If there is a common consumer, grab the first one at or after consumer
  TensorView* common_consumer = nullptr;
  if (!common_consumers.empty()) {
    // Walk the first use chain in order; the earliest surviving entry is the
    // closest common consumer.
    for (auto tv : producer_use_chains_.front()) {
      if (common_consumers.find(tv) != common_consumers.end()) {
        common_consumer = tv;
        break;
      }
    }
    // Every surviving TV came from a use chain, so one of them must appear in
    // the first chain; failing to find one indicates internal corruption.
    TORCH_INTERNAL_ASSERT(
        common_consumer != nullptr,
        "Hit a logical inconsistency in the computeAt pass.");
  }
  return common_consumer;
}
145 | |
146 | void pullInSiblings(std::unordered_set<TensorView*>& s) { |
147 | for (auto tv : s) { |
148 | for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { |
149 | if (sibling_tv == tv) { |
150 | continue; |
151 | } |
152 | s.emplace(sibling_tv); |
153 | } |
154 | } |
155 | } |
156 | |
// I am just trying to get the same set of tensors being transformed matching
// the previous behavior of ComputeAt. The algorithm to compute this set is
// horrible, but I don't care because I will eventually completely remove
// ComputeAt, and this algorithm is not worse than the previous ComputeAt. :)
//
// Computes the set of TensorViews that transform propagation should touch for
// a computeAt(producer, consumer) call:
// - If all use chains of producer meet at a common consumer, the result is
//   every TV between producer and that common consumer, plus siblings.
// - Otherwise, the result is producer plus every TV dependent on producer
//   that still has uses (terminal TVs are skipped), plus siblings.
std::unordered_set<TensorView*> getPropagationSubgraph(
    TensorView* producer,
    TensorView* consumer) {
  TORCH_CHECK(
      DependencyCheck::isDependencyOf(producer, consumer),
      "Compute At expects ",
      producer->name(),
      " is a dependency of ",
      consumer->name(),
      ", however it is not.");
  TensorView* common_consumer = getCommonConsumer(producer, consumer);
  if (common_consumer != nullptr) {
    auto result = getAllTVsBetween(producer, common_consumer);
    pullInSiblings(result);
    return result;
  }
  auto result_vals = DependencyCheck::getAllDependentVals({producer});
  result_vals.emplace(producer);
  auto result_tvs = ir_utils::filterByType<TensorView>(result_vals);
  std::unordered_set<TensorView*> result;
  // Keep only TVs that have uses: there is nothing downstream of a terminal
  // TV to propagate to.
  std::copy_if(
      result_tvs.begin(),
      result_tvs.end(),
      std::inserter(result, result.begin()),
      [](TensorView* tv) { return !tv->uses().empty(); });
  pullInSiblings(result);
  return result;
}
189 | |
190 | } // namespace |
191 | |
192 | void ComputeAt::runAt( |
193 | TensorView* producer, |
194 | TensorView* consumer, |
195 | int64_t consumer_position, |
196 | ComputeAtMode mode) { |
197 | FUSER_PERF_SCOPE("ComputeAt::runAt" ); |
198 | |
199 | // Make sure the correct fusion is setup between this and consumer. |
200 | TORCH_CHECK( |
201 | producer->fusion() == consumer->fusion(), |
202 | producer, |
203 | " and " , |
204 | consumer, |
205 | " are not in the same fusion." ); |
206 | |
207 | if (mode == ComputeAtMode::MostInlined) { |
208 | consumer_position = -1; |
209 | } |
210 | |
211 | FusionGuard fg(producer->fusion()); |
212 | |
213 | auto selected = getPropagationSubgraph(producer, consumer); |
214 | ComputeAtSelector selector(selected); |
215 | |
216 | MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); |
217 | |
218 | if (mode == ComputeAtMode::MostInlined) { |
219 | MostInlinedTransformPropagator propagator; |
220 | path.traverse(&propagator); |
221 | inlineMost(selected); |
222 | } else { |
223 | TransformPropagator propagator(consumer, consumer_position); |
224 | path.traverse(&propagator); |
225 | inlineSelectedAt( |
226 | selected, |
227 | consumer, |
228 | consumer_position, |
229 | mode == ComputeAtMode::BestEffort); |
230 | } |
231 | } |
232 | |
// Mirror of runAt, but anchored on the producer side: propagates the
// producer's transformations at producer_position through the relevant
// subgraph, then inlines.
void ComputeAt::runWith(
    TensorView* producer,
    TensorView* consumer,
    int64_t producer_position,
    ComputeAtMode mode) {
  FUSER_PERF_SCOPE("ComputeAt::runWith");

  // Make sure the correct fusion is setup between this and consumer.
  TORCH_CHECK(
      producer->fusion() == consumer->fusion(),
      producer,
      " and ",
      consumer,
      " are not in the same fusion.");

  // MostInlined ignores the requested position and inlines as deep as
  // possible.
  if (mode == ComputeAtMode::MostInlined) {
    producer_position = -1;
  }

  FusionGuard fg(producer->fusion());

  auto selected = getPropagationSubgraph(producer, consumer);
  ComputeAtSelector selector(selected);

  // Note: unlike runAt, the spanning tree is rooted at the producer.
  MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);

  if (mode == ComputeAtMode::MostInlined) {
    MostInlinedTransformPropagator propagator;
    path.traverse(&propagator);
    inlineMost(selected);
  } else {
    TransformPropagator propagator(producer, producer_position);
    path.traverse(&propagator);
    inlineSelectedAt(
        selected,
        producer,
        producer_position,
        mode == ComputeAtMode::BestEffort);
  }
}
273 | |
274 | } // namespace cuda |
275 | } // namespace fuser |
276 | } // namespace jit |
277 | } // namespace torch |
278 | |