1 | #include <compute_at_map.h> |
2 | #include <fusion.h> |
3 | #include <instrumentation.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <ir_iostream.h> |
6 | #include <ir_utils.h> |
7 | #include <lower2device.h> |
8 | #include <lower_expr_sort.h> |
9 | #include <lower_utils.h> |
10 | |
11 | #include <deque> |
12 | #include <list> |
13 | #include <sstream> |
14 | #include <unordered_map> |
15 | #include <unordered_set> |
16 | #include <vector> |
17 | |
18 | namespace torch { |
19 | namespace jit { |
20 | namespace fuser { |
21 | namespace cuda { |
22 | |
23 | namespace { |
24 | |
25 | // TODO: Review const model, and objects |
26 | // ExprSegmentationSorter |
27 | // Responsible for going through DAG and proposing things we could try to |
28 | // merge together, calls "supportedMerge" on these proposed groups to see |
29 | // if they should be merged together, then merges them if so. |
30 | // ExprGroup |
31 | // A group of exprs that are grouped together based on their loop nest |
32 | // structures. |
33 | // ExprGroupConnections |
34 | // Holds vals and what they connect. In other words it's a val that is an |
35 | // output of a ExprSegmentationSorter "from" and an input of |
36 | // ExprSegmentationSorter "to". There's nothing preventing from a val being |
37 | // between groups twice. |
38 | // TODO: make sure there's nothing wrong with grouping of nodes that |
39 | // have the same value input twice. i.e. (B = A*A) |
40 | |
41 | // Selecting segments to propose is based on the theorem 4.2 in the paper which |
// makes sure that, when segmenting, the segmented graph will be a DAG (assumes
// Fusion is
43 | // already a DAG). The segmentation code relies on assumptions of DAG-ness |
44 | // during segmentation, meaning proposed merging of groups must maintain the DAG |
45 | // property of the graph. |
46 | // |
47 | // Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek. |
48 | // Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs. |
49 | // SIAM Journal on Scientific Computing, Society for Industrial and Applied |
50 | // Mathematics, 2019, 41 (4), pp.A2117-A2145. ff10.1137/18M1176865ff. |
51 | // ffhal02306566f |
52 | |
53 | class ExprGroup; |
54 | struct ExprGroupConnections; |
55 | class ExprSegmentationSorter; |
56 | |
57 | // Debug printing disabled due to clang tidy, see below for definitions |
58 | // std::ostream& operator<<(std::ostream& os, const ExprGroup* group); |
59 | |
60 | // Wrapper for values, these are edges between expr groups. Multiple edges can |
61 | // exist between expr groups, and the same Val can show up more than once in |
62 | // multiple edges. |
struct ExprGroupConnections {
  ExprGroupConnections(
      ExprGroup* group_from,
      ExprGroup* group_to,
      Val* producer_val,
      Val* consumer_val)
      : from(group_from),
        to(group_to),
        producer_val_(producer_val),
        consumer_val_(consumer_val) {}
  // Producer group from which the edge starts
  ExprGroup* from;

  // Consumer group at which the edge ends
  ExprGroup* to;

  // The value from the producer group connecting the groups
  // This value helps us resolve the compute at position of expr groups
  Val* producer_val_;

  // The value that the producer val gets used to create on this edge
  // This value helps us resolve the produce at position of expr groups
  Val* consumer_val_;
};
88 | |
// Per-group mutable state used by ExprSegmentationSorter while traversing and
// merging; reset between iterations via ExprGroup::clearTraversalInfo().
struct ExprSortPayload : public PolymorphicBase {
  // Need to track compute at domains as well as produce at domains. Produce at
  // domains will be matched to producers compute at domains. Track the active
  // domains that will be matched from inner most dim to outer most.
  std::vector<IterDomain*> ca_domains_;
  std::vector<IterDomain*> pa_domains_;

  // Maximum path distance from an input expr group required for
  // Theorem 4.2; -1 means "not yet computed by resetLevels()".
  int level = -1;

  // Traversal marker, marks if this group has been visited by current pass
  bool visited = false;

  // Marks if this group is already selected to merge with another group, marks
  // which group to merge with (nullptr when not selected)
  ExprGroup* merge_with = nullptr;

  // Marks if this group is already selected to merge with another group
  bool merged = false;
};
110 | |
// Groups together expressions which create an expr group
112 | class ExprGroup { |
113 | public: |
114 | ExprGroup() : payload_(std::make_unique<ExprSortPayload>()) {} |
115 | |
116 | ExprGroup(Expr* expr) : payload_(std::make_unique<ExprSortPayload>()) { |
117 | exprs_.push_back(expr); |
118 | } |
119 | |
120 | ExprGroup(const ExprGroup& other) |
121 | : payload_(new ExprSortPayload(*(other.payload_))) {} |
122 | |
123 | ExprGroup& operator=(const ExprGroup& other) { |
124 | *payload_ = *other.payload_; |
125 | exprs_ = other.exprs_; |
126 | return *this; |
127 | } |
128 | |
129 | // Clears the traversal information in the payload |
130 | void clearTraversalInfo(); |
131 | |
132 | // Returns all neighbors, producers and consumers |
133 | std::vector<ExprGroup*> getNeighbors(); |
134 | |
135 | // Return neighbors of this proven to be safe nodes to merge with in regards |
136 | // to maining an acyclic graph. This looks at, neighbors if merged, neighbors |
137 | // level, and merged neighbors of neighbors level. If fallback_mode_enabled |
138 | // will return the inverse set of ExprGroups that are proven to be safe |
139 | // merges. |
140 | std::vector<ExprGroup*> getMergeCandidates( |
141 | bool fallback_mode_enabled = false); |
142 | |
143 | std::unique_ptr<ExprSortPayload>& payload() { |
144 | return payload_; |
145 | } |
146 | |
147 | const auto& producerEdges() const { |
148 | return producer_edges_; |
149 | } |
150 | |
151 | void addProducerEdge(ExprGroupConnections* edge) { |
152 | addEdge(producer_edges_, edge); |
153 | } |
154 | |
155 | void removeProducerEdge(ExprGroupConnections* edge) { |
156 | removeEdge(producer_edges_, edge); |
157 | } |
158 | |
159 | void clearProducerEdges() { |
160 | producer_edges_.clear(); |
161 | } |
162 | |
163 | const auto& consumerEdges() const { |
164 | return consumer_edges_; |
165 | } |
166 | |
167 | void addConsumerEdge(ExprGroupConnections* edge) { |
168 | addEdge(consumer_edges_, edge); |
169 | } |
170 | |
171 | void removeConsumerEdge(ExprGroupConnections* edge) { |
172 | removeEdge(consumer_edges_, edge); |
173 | } |
174 | |
175 | void clearConsumerEdges() { |
176 | consumer_edges_.clear(); |
177 | } |
178 | |
179 | auto& exprs() { |
180 | return exprs_; |
181 | } |
182 | |
183 | const auto& exprs() const { |
184 | return exprs_; |
185 | } |
186 | |
187 | private: |
188 | static void addEdge( |
189 | std::vector<ExprGroupConnections*>& edges, |
190 | ExprGroupConnections* edge_to_add) { |
191 | edges.push_back(edge_to_add); |
192 | } |
193 | |
194 | static void removeEdge( |
195 | std::vector<ExprGroupConnections*>& edges, |
196 | ExprGroupConnections* edge_to_remove) { |
197 | auto it = std::find(edges.begin(), edges.end(), edge_to_remove); |
198 | TORCH_INTERNAL_ASSERT(it != edges.end(), "Could not find edge to remove." ); |
199 | edges.erase(it); |
200 | } |
201 | |
202 | private: |
203 | // "Ancestor nodes", towards inputs of segmentedDAG |
204 | std::vector<ExprGroupConnections*> producer_edges_; |
205 | |
206 | // "Descendent nodes", towards outputs of segmentedDAG |
207 | std::vector<ExprGroupConnections*> consumer_edges_; |
208 | |
209 | // Exprs that make up the group |
210 | std::vector<Expr*> exprs_; |
211 | |
212 | // Stateful traversal information |
213 | std::unique_ptr<ExprSortPayload> payload_; |
214 | }; |
215 | |
// This class sorts expressions and guarantees two things: 1) Tensors are
// produced
217 | // before they're consumed 2) If the production of two tensors are supposed to |
218 | // share a for loop, they're in an order where they can. (1) is pretty standard |
219 | // of ordering a DAG. (2) is where things get a bit complicated and why we do |
220 | // this sorting through segmentation. Consider a section of a DAG: T4 = T3 + T2. |
221 | // Where T2 and T3 are not inputs to the fusion, all tensors are 3D, and we want |
222 | // the production of T3 to share the inner most loop of T4 and we want the |
223 | // production of T2 to share the middle loop with T4. i.e. we're looking for |
224 | // For(i:I){ |
225 | // For(j: J){ |
226 | // For(k: K){ |
227 | // T2[i, j, k] = ... |
228 | // } |
229 | // For(k: K){ |
230 | // T3[i, j, k] = ... |
231 | // T4[i, j, k] = T2[i, j, k] + T3[i, j, k] |
232 | // } |
233 | // } |
234 | // } |
235 | // The only valid ordering of expressions is producing T2, then T3, then T4. If |
236 | // we swapped T3 and T2, then T3 and T4 couldn't share their inner most loop, |
237 | // because T2 has its own inner most loop. If we swapped either tensor with T4, |
// then we'd try to be using T2 or T3 without producing them (back to guarantee
239 | // 1). |
class ExprSegmentationSorter {
 public:
  ExprSegmentationSorter(Fusion* fusion) : fusion_(fusion) {}

  // Runs the iterative merge process until no further merging is possible.
  void sort();

  // Debug helper (definition currently commented out below for clang-tidy).
  std::string toString(int verbosity = 0) const;

  //! Returns a flattened list of sorted exprs
  std::vector<Expr*> getExprs() const;

 private:
  // Allocate an empty expr group and return it
  ExprGroup* makeEmptyGroup();

  // Allocate an expr group with the provided expr and return it. Also requires
  // information on if this expression is a terminating expression (none of its
  // outputs are used in other expressions being sorted).
  ExprGroup* makeEmptyGroup(Expr*, bool terminating_expr);

  // Returns if sg1 and sg2 should be merged together, is called if they can
  // based on the current status of the DAG.
  bool supportedMerge(ExprGroup* sg1, ExprGroup* sg2);

  // Returns true if the graph will remain an acyclic graph after merging sg1
  // and sg2
  bool testStillDag(ExprGroup* sg1, ExprGroup* sg2);

  // Merges two ExprGroups and returns the new ExprGroup
  ExprGroup* makeMergedNode(ExprGroup* sg1, ExprGroup* sg2);

  // This is called once no more groups can be merged together. This will lower
  // the compute at position of a segment group if the last dimension of the
  // segment group doesn't map to any of the dimensions of its neighbors.
  bool interIterUpdate();

  // Reset the ExprSortPayload of the groups so we can traverse and identify
  // merge candidates.
  void resetTraversal();

  // Reset the set levels of each group. This is what's used to identify which
  // nodes can be merged together.
  void resetLevels();

  // Go through groups that are marked as to merge and merge them.
  void mergeNodes();

  // Initialize concrete_id_dependencies
  void initializeForLoopDependencies();

  // Checks if the for loop associated with the concrete ID is ready to be
  // resolved in sorting.
  bool loopReady(IterDomain* concrete_id);

  // Disconnect the edges connecting group to the rest of the graph, and return
  // all the edges that were disconnected
  std::unordered_set<ExprGroupConnections*> disconnectGroup(ExprGroup* group);

 private:
  // Track how many groups we have from iteration to iteration so we can track
  // when we've stopped merging nodes.
  size_t n_groups_ = 0;

  // Lifetime of the graph view of the fusion and segmentation. Use list to not
  // invalidate any entries on insertion/deletion.
  std::list<std::unique_ptr<ExprGroupConnections>> edges_;
  std::list<std::unique_ptr<ExprGroup>> groups_;

  // Work queue for the level-reset traversal (see resetLevels()).
  std::deque<ExprGroup*> to_visit_;

  // Pairs of groups selected for merging in the current iteration.
  std::vector<std::pair<ExprGroup*, ExprGroup*>> to_merge_;

  // The fusion being sorted; not owned.
  Fusion* fusion_;

  // We use a theorem out of a paper mentioned in other comments. This theorem
  // is good at identifying multiple expr groups to merge during a single
  // iteration without producing a cyclic graph from an acyclic graph. This
  // theorem is not guaranteed to find all possible nodes that can be merged
  // together. We need to be able to group all disjoint groups of exprs or
  // we fail to generate code. Therefore, if we can't find anything to make
  // forward progress based on the theorem we fallback to manually looking if we
  // can segment all combinations we haven't previously looked at.
  bool fallback_mode_enabled_ = false;

  // We need to track ID resolution, see AdvancedLowering6. For loops need
  // to be resolved from inner most to outer most. We therefore track
  // loop dependencies where inner most loops need to be fully resolved before
  // we can resolve the next outer loop. We track this by looking at all tensor
  // views, and each iteration domain. An iter domain in the outer most position
  // has dependencies on all inner dimensions. This tracking is done on concrete
  // id's in the loop map, this is because IDs may exist in some TVs but not
  // others, however, we need a "global" view to track these dependencies.
  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
      concrete_id_dependencies;
};
335 | |
336 | // // Debug printing, disabled due to clang-tidy see above for declarations. |
337 | // std::ostream& operator<<(std::ostream& os, ExprGroup* group) { |
338 | // os << "Group Start{\n ca, pa (" |
339 | // << group->payload()->ca_domains_.size() << ", " |
340 | // << group->payload()->pa_domains_.size() << ")"; |
341 | // os << " ca_ids {"; |
342 | // for (size_t i = 0; i < group->payload()->ca_domains_.size(); i++) { |
343 | // os << group->payload()->ca_domains_[i]; |
344 | // if (i + 1 != group->payload()->ca_domains_.size()) |
345 | // os << ", "; |
346 | // } |
347 | // os << "} pa_ids {"; |
348 | // for (size_t i = 0; i < group->payload()->pa_domains_.size(); i++) { |
349 | // os << group->payload()->pa_domains_[i]; |
350 | // if (i + 1 != group->payload()->pa_domains_.size()) |
351 | // os << ", "; |
352 | // } |
353 | // os << "}"; |
354 | // os << "\nExprs {\n"; |
355 | // for(auto expr : group->exprs()){ |
356 | // os << expr; |
357 | // } |
358 | // os << "}Group End\n"; |
359 | // return os; |
360 | // } |
361 | |
362 | std::vector<ExprGroup*> ExprGroup::getNeighbors() { |
363 | std::vector<ExprGroup*> neighbors; |
364 | for (auto inp : producer_edges_) { |
365 | neighbors.push_back(inp->from); |
366 | } |
367 | for (auto out : consumerEdges()) { |
368 | neighbors.push_back(out->to); |
369 | } |
370 | return neighbors; |
371 | } |
372 | |
373 | std::vector<ExprGroup*> ExprGroup::getMergeCandidates( |
374 | bool fallback_mode_enabled) { |
375 | std::vector<ExprGroup*> neighbors = getNeighbors(); |
376 | |
377 | // Don't look for candidates if already merged |
378 | if (payload()->merged) { |
379 | return {}; |
380 | } |
381 | |
382 | // Can this node be merged with another? Check if neighbors are merged, if |
383 | // so and merged neighbor is within 1 level or node merged with neighbor is |
384 | // within 1 level, can't merge this node with anything else. |
385 | bool can_merge_this = true; |
386 | bool neighbor_merged = false; |
387 | for (auto neighbor : neighbors) { |
388 | if (!neighbor->payload()->merged) { |
389 | continue; |
390 | } |
391 | neighbor_merged = true; |
392 | if (std::abs(neighbor->payload()->level - payload()->level) <= 1) { |
393 | can_merge_this = false; |
394 | } |
395 | if (std::abs( |
396 | neighbor->payload()->merge_with->payload()->level - |
397 | payload()->level) <= 1) { |
398 | can_merge_this = false; |
399 | } |
400 | } |
401 | |
402 | // If something prevents us from merging this node, and we're not in fallback |
403 | // mode, return empty set. |
404 | if (!can_merge_this && !fallback_mode_enabled) { |
405 | return {}; |
406 | } |
407 | |
408 | // If fallback mode already detected a merge somewhere, we shouldn't still be |
409 | // traversing. |
410 | if (fallback_mode_enabled) { |
411 | TORCH_INTERNAL_ASSERT( |
412 | !neighbor_merged, |
413 | "Shouldn't still be traversing in fallback mode if a merge was found." ); |
414 | } |
415 | |
416 | std::vector<bool> can_merge(neighbors.size(), true); |
417 | |
418 | // Find neighbors with a level that is only 1 different than this group's |
419 | // level |
420 | for (const auto i : c10::irange(neighbors.size())) { |
421 | if (std::abs(neighbors[i]->payload()->level - payload()->level) > 1) { |
422 | can_merge[i] = false; |
423 | } |
424 | } |
425 | |
426 | // Check neighbor of neighbors we're considering, if any of them are merged |
427 | // with another node, make sure the resulting edge wouldn't have a level |
428 | // difference of 1 |
429 | for (const auto i : c10::irange(neighbors.size())) { |
430 | if (!can_merge[i]) { |
431 | continue; |
432 | } |
433 | |
434 | for (auto neighbor_neighbor : neighbors[i]->getNeighbors()) { |
435 | // Don't check self |
436 | if (neighbor_neighbor == neighbors[i]) { |
437 | continue; |
438 | } |
439 | if (neighbor_neighbor->payload()->merged) { |
440 | // check neighbor_neighbor level |
441 | if (std::abs(neighbor_neighbor->payload()->level - payload()->level) <= |
442 | 1) { |
443 | can_merge[i] = false; |
444 | } |
445 | if (std::abs( |
446 | neighbor_neighbor->payload()->level - |
447 | neighbors[i]->payload()->level) <= 1) { |
448 | can_merge[i] = false; |
449 | } |
450 | |
451 | // check neighbor_neighber->merged->level |
452 | if (std::abs( |
453 | neighbor_neighbor->payload()->merge_with->payload()->level - |
454 | payload()->level) <= 1) { |
455 | can_merge[i] = false; |
456 | } |
457 | if (std::abs( |
458 | neighbor_neighbor->payload()->merge_with->payload()->level - |
459 | neighbors[i]->payload()->level) <= 1) { |
460 | can_merge[i] = false; |
461 | } |
462 | } |
463 | } |
464 | } |
465 | |
466 | std::vector<ExprGroup*> merge_candidates; |
467 | for (const auto i : c10::irange(neighbors.size())) { |
468 | if ((can_merge[i] && !fallback_mode_enabled) || |
469 | (!can_merge[i] && fallback_mode_enabled)) { |
470 | merge_candidates.push_back(neighbors[i]); |
471 | } |
472 | } |
473 | return merge_candidates; |
474 | } |
475 | |
476 | void ExprGroup::clearTraversalInfo() { |
477 | payload()->level = -1; |
478 | payload()->visited = false; |
479 | payload()->merge_with = nullptr; |
480 | payload()->merged = false; |
481 | } |
482 | |
483 | void ExprSegmentationSorter::resetTraversal() { |
484 | for (auto& group : groups_) { |
485 | // Start traversal at input groups |
486 | if (group->producerEdges().empty()) { |
487 | to_visit_.push_back(group.get()); |
488 | } |
489 | group->clearTraversalInfo(); |
490 | } |
491 | } |
492 | |
// Level is maximum distance from inputs. It's the metric used to select what
// nodes can be merged while maintaining a DAG
void ExprSegmentationSorter::resetLevels() {
  // Groups popped before all of their producers were visited; requeued once
  // forward progress is made.
  std::vector<ExprGroup*> next_to_visit;

  while (!to_visit_.empty()) {
    auto visit = to_visit_.front();
    to_visit_.pop_front();

    // All inputs processed?
    bool ready = true;
    if (!visit->producerEdges().empty()) {
      ready = std::all_of(
          visit->producerEdges().begin(),
          visit->producerEdges().end(),
          [&](ExprGroupConnections* dep) {
            return dep->from->payload()->visited;
          });
    }

    if (!ready) {
      // In case traversal doesn't complete because there's an error in the
      // DAG topology.
      next_to_visit.push_back(visit);
      continue;
    }

    visit->payload()->visited = true;

    // Visiting this group may have unblocked previously deferred groups, so
    // put them back on the queue.
    to_visit_.insert(
        to_visit_.end(), next_to_visit.begin(), next_to_visit.end());
    next_to_visit.clear();

    for (auto out : visit->consumerEdges()) {
      to_visit_.push_back(out->to);
    }

    // Level is one more than the deepest producer; input groups get level 0.
    visit->payload()->level = 0;
    for (auto inp : visit->producerEdges()) {
      visit->payload()->level =
          std::max(visit->payload()->level, inp->from->payload()->level + 1);
    }
  }
  // Anything still deferred means some producer was never visited, i.e. the
  // graph contains a cycle.
  TORCH_INTERNAL_ASSERT(next_to_visit.empty(), "Error in graph, is not a DAG.");
}
538 | |
539 | ExprGroup* ExprSegmentationSorter::makeEmptyGroup() { |
540 | groups_.push_back(std::make_unique<ExprGroup>()); |
541 | return groups_.back().get(); |
542 | } |
543 | |
544 | ExprGroup* ExprSegmentationSorter::makeEmptyGroup( |
545 | Expr* expr, |
546 | bool terminating_expr) { |
547 | auto group = makeEmptyGroup(); |
548 | group->exprs().push_back(expr); |
549 | if (ir_utils::isTvOp(expr)) { |
550 | auto out_tv = expr->outputs()[0]->as<TensorView>(); |
551 | // Grab all id's that are shared with other tensors. |
552 | // If not connected to consumers, doesn't mater what compute at is set to |
553 | if (!terminating_expr) { |
554 | for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) { |
555 | group->payload()->ca_domains_.push_back(out_tv->axis(tv_i)); |
556 | } |
557 | } |
558 | for (const auto tv_i : c10::irange(out_tv->getMaxProducerPosition())) { |
559 | group->payload()->pa_domains_.push_back(out_tv->axis(tv_i)); |
560 | } |
561 | } |
562 | return group; |
563 | } |
564 | |
565 | // Debug function that prints the current state of the sorter. |
566 | // |
567 | // Uncomment if needed. |
568 | // std::string ExprSegmentationSorter::toString(int verbosity) const { |
569 | // std::stringstream ss; |
570 | // ss << "{\n"; |
571 | // for (auto& group : groups_) { |
572 | // ss << " " << group.get() << "\n"; |
573 | |
574 | // if (verbosity > 1) { |
575 | // if (group->producerEdges().size() > 0) { |
576 | // ss << "Produced by groups with edges: { \n"; |
577 | // for (auto producer_edge : group->producerEdges()) { |
578 | // ss << producer_edge->producer_val_ << " -> " |
579 | // << producer_edge->consumer_val_ << "\n"; |
580 | // } |
581 | // ss << " }" |
582 | // << "\n"; |
583 | // } |
584 | // } |
585 | |
586 | // if (verbosity > 1) { |
587 | // if (group->consumerEdges().size() > 0) { |
588 | // ss << "Consumed by groups with edges: { \n"; |
589 | // for (auto consumer_edge : group->consumerEdges()) { |
590 | // ss << consumer_edge->producer_val_ << " -> " |
591 | // << consumer_edge->consumer_val_ << "\n"; |
592 | // } |
593 | // ss << " }" |
594 | // << "\n"; |
595 | // } |
596 | // } |
597 | // } |
598 | // ss << "}\n"; |
599 | // return ss.str(); |
600 | // } |
601 | |
602 | namespace { |
603 | |
604 | // Concat's edges of sg1 and sg2, but removes any edges from/to sg1/sg2 |
605 | std::vector<ExprGroupConnections*> getMergedEdges( |
606 | const ExprGroup* sg1, |
607 | const std::vector<ExprGroupConnections*>& edges1, |
608 | const ExprGroup* sg2, |
609 | const std::vector<ExprGroupConnections*>& edges2) { |
610 | TORCH_INTERNAL_ASSERT( |
611 | sg1 != nullptr && sg2 != nullptr, |
612 | "This function doesn't handle trivial." ); |
613 | |
614 | auto merged_edges = edges1; |
615 | merged_edges.insert(merged_edges.end(), edges2.begin(), edges2.end()); |
616 | |
617 | // Remove intra edges |
618 | merged_edges.erase( |
619 | std::remove_if( |
620 | merged_edges.begin(), |
621 | merged_edges.end(), |
622 | [&sg1, &sg2](ExprGroupConnections* se) { |
623 | return (se->to == sg1 && se->from == sg2) || |
624 | (se->to == sg2 && se->from == sg1); |
625 | }), |
626 | merged_edges.end()); |
627 | |
628 | return merged_edges; |
629 | } |
630 | |
// Concat's producer edges of sg1 and sg2, but removes any edges from/to
// sg1/sg2 (those edges become internal to the merged group).
std::vector<ExprGroupConnections*> getMergedProducerEdges(
    const ExprGroup* sg1,
    const ExprGroup* sg2) {
  return getMergedEdges(sg1, sg1->producerEdges(), sg2, sg2->producerEdges());
}
637 | |
// Concat's consumer edges of sg1 and sg2, but removes any edges from/to
// sg1/sg2 (those edges become internal to the merged group).
std::vector<ExprGroupConnections*> getMergedConsumerEdges(
    const ExprGroup* sg1,
    const ExprGroup* sg2) {
  return getMergedEdges(sg1, sg1->consumerEdges(), sg2, sg2->consumerEdges());
}
644 | |
645 | // Assuming sg1 and sg2 are connected, figure out which is the consumer |
646 | ExprGroup* getProducer(ExprGroup* sg1, ExprGroup* sg2) { |
647 | for (auto producer_edge : sg1->producerEdges()) { |
648 | if (producer_edge->from == sg2) { |
649 | return sg2; |
650 | } |
651 | } |
652 | |
653 | for (auto consumer_edge : sg1->consumerEdges()) { |
654 | if (consumer_edge->to == sg2) { |
655 | return sg1; |
656 | } |
657 | } |
658 | |
659 | return nullptr; |
660 | } |
661 | |
662 | std::vector<IterDomain*> getLocalDomainOrdering( |
663 | const std::vector<Expr*>& exprs, |
664 | const std::unordered_set<IterDomain*> filter, |
665 | const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>& |
666 | concrete_id_dependencies) { |
667 | if (exprs.empty()) { |
668 | return std::vector<IterDomain*>(); |
669 | } |
670 | |
671 | const auto& ca_map = GpuLower::current()->caMap(); |
672 | |
673 | std::unordered_set<IterDomain*> domains; |
674 | |
675 | for (auto expr : exprs) { |
676 | if (!ir_utils::isTvOp(expr)) { |
677 | continue; |
678 | } |
679 | |
680 | auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs()); |
681 | for (auto tv_input : tv_inputs) { |
682 | std::vector<IterDomain*> domain; |
683 | |
684 | std::transform( |
685 | tv_input->domain()->domain().begin(), |
686 | tv_input->domain()->domain().begin() + |
687 | std::max( |
688 | tv_input->getComputeAtPosition(), |
689 | tv_input->getMaxProducerPosition()), |
690 | std::back_inserter(domain), |
691 | [&ca_map](IterDomain* id) { |
692 | return ca_map->getConcreteMappedID(id, IdMappingMode::LOOP); |
693 | }); |
694 | |
695 | domain.erase( |
696 | std::remove_if( |
697 | domain.begin(), |
698 | domain.end(), |
699 | [&filter, &ca_map](IterDomain* id) { |
700 | return filter.find(ca_map->getConcreteMappedID( |
701 | id, IdMappingMode::LOOP)) == filter.end(); |
702 | }), |
703 | domain.end()); |
704 | |
705 | domains.insert(domain.begin(), domain.end()); |
706 | } |
707 | } |
708 | |
709 | std::vector<IterDomain*> merged_domain(domains.begin(), domains.end()); |
710 | std::sort( |
711 | merged_domain.begin(), |
712 | merged_domain.end(), |
713 | ir_utils::IterDomainDependencySorter( |
714 | concrete_id_dependencies, GpuLower::current()->caMap())); |
715 | return merged_domain; |
716 | } |
717 | } // namespace |
718 | |
719 | // Disconect group from neighbors, and return edges that were disconnected |
720 | std::unordered_set<ExprGroupConnections*> ExprSegmentationSorter:: |
721 | disconnectGroup(ExprGroup* group) { |
722 | std::unordered_set<ExprGroupConnections*> removed_edges( |
723 | group->producerEdges().begin(), group->producerEdges().end()); |
724 | |
725 | for (auto edge : group->producerEdges()) { |
726 | edge->from->removeConsumerEdge(edge); |
727 | } |
728 | |
729 | for (auto edge : group->consumerEdges()) { |
730 | edge->to->removeProducerEdge(edge); |
731 | } |
732 | |
733 | group->clearProducerEdges(); |
734 | group->clearConsumerEdges(); |
735 | |
736 | return removed_edges; |
737 | } |
738 | |
// TODO: This function may be sub optimial. If we find that an iteration domain
// matches later in the other domain, we will hold all other iteration domains
// until that one matches. There may be cases where duplicating that iteration
// domain, and moving on could be more efficient.
//
// Builds a new group containing the exprs of both inputs (producer's exprs
// first, preserving topological order), rewires all external edges to the new
// group, and recomputes the merged group's compute-at / produce-at domain
// lists ordered by loop dependencies.
ExprGroup* ExprSegmentationSorter::makeMergedNode(
    ExprGroup* sg1,
    ExprGroup* sg2) {
  // Keep Expr's sorted in topological order.
  const auto producer = getProducer(sg1, sg2);
  const auto consumer = sg1 == producer ? sg2 : sg1;

  // Make the new joined node
  auto joined_groups = makeEmptyGroup();

  // getProducer returns nullptr when the two groups share no edge.
  TORCH_INTERNAL_ASSERT(
      producer != nullptr,
      "Tried to merge expr's together that aren't neighbors.");

  joined_groups->exprs() = producer->exprs();
  joined_groups->exprs().insert(
      joined_groups->exprs().end(),
      consumer->exprs().begin(),
      consumer->exprs().end());

  auto producer_edges = getMergedProducerEdges(sg1, sg2);
  // Connect joined group to resulting neighbors
  for (auto& edge : producer_edges) {
    auto from = edge->from;
    auto producer_val = edge->producer_val_;
    auto consumer_val = edge->consumer_val_;

    // New edge objects are created (owned by edges_); the old ones are left
    // attached to the original groups.
    edges_.push_back(std::make_unique<ExprGroupConnections>(
        from, joined_groups, producer_val, consumer_val));

    joined_groups->addProducerEdge(edges_.back().get());
    from->addConsumerEdge(edges_.back().get());
  }

  auto consumer_edges = getMergedConsumerEdges(sg1, sg2);

  for (auto& edge : consumer_edges) {
    auto to = edge->to;
    auto producer_val = edge->producer_val_;
    auto consumer_val = edge->consumer_val_;

    edges_.push_back(std::make_unique<ExprGroupConnections>(
        joined_groups, to, producer_val, consumer_val));
    joined_groups->addConsumerEdge(edges_.back().get());
    edge->to->addProducerEdge(edges_.back().get());
  }

  // Merge the compute at domain of all edges going out from the newly joined
  // group. The val's we're looking for are from our consumer edges, but we want
  // to grab the producer val as that's the one we generate.
  std::unordered_set<IterDomain*> ca_ids;
  for (auto consumer_group_edge : joined_groups->consumerEdges()) {
    auto producer_of_consumer_edge = consumer_group_edge->producer_val_;
    if (producer_of_consumer_edge->isA<TensorView>()) {
      auto tv = producer_of_consumer_edge->as<TensorView>();
      for (const auto tv_i : c10::irange(tv->getComputeAtPosition())) {
        // Work on concrete IDs from the loop map so IDs from different TVs
        // that share a loop compare equal.
        ca_ids.emplace(GpuLower::current()->caMap()->getConcreteMappedID(
            tv->axis(tv_i), IdMappingMode::LOOP));
      }
    }
  }

  // Merge the produce at domain of all edges coming into the newly joined
  // group. The val's we're looking for are from our producer edges, but we want
  // to grab the consumer val as that's the one we generate.
  std::unordered_set<IterDomain*> pa_ids;
  for (auto producer_group_edge : joined_groups->producerEdges()) {
    auto consumer_of_producer_edge = producer_group_edge->consumer_val_;
    if (consumer_of_producer_edge->isA<TensorView>()) {
      auto tv = consumer_of_producer_edge->as<TensorView>();
      for (const auto tv_i : c10::irange(tv->getMaxProducerPosition())) {
        pa_ids.emplace(GpuLower::current()->caMap()->getConcreteMappedID(
            tv->axis(tv_i), IdMappingMode::LOOP));
      }
    }
  }

  auto all_ca_pa_ids = ca_ids;
  all_ca_pa_ids.insert(pa_ids.begin(), pa_ids.end());

  // Order the union of CA/PA ids by loop dependencies, then split back into
  // the group's ca/pa lists (an id can land in both).
  auto ordered_ids = getLocalDomainOrdering(
      joined_groups->exprs(), all_ca_pa_ids, concrete_id_dependencies);

  for (auto id : ordered_ids) {
    if (ca_ids.count(id)) {
      joined_groups->payload()->ca_domains_.emplace_back(id);
    }
    if (pa_ids.count(id)) {
      joined_groups->payload()->pa_domains_.emplace_back(id);
    }
  }

  return joined_groups;
}
837 | |
// Returns true when the group's last (inner-most tracked) produce-at domain
// can be dropped, i.e. no producer's compute-at position still maps to it in
// the loop map. Returns false if there is nothing to reduce or a producer
// still pins that domain.
bool canReducePA(ExprGroup* group) {
  if (group->payload()->pa_domains_.empty()) {
    return false;
  }

  IterDomain* group_pa_last_id = group->payload()->pa_domains_.back();

  // Look through producer edges to see if we can reduce our produce at domain
  for (auto producer_edge : group->producerEdges()) {
    auto producer_val = producer_edge->producer_val_;
    auto consumer_val = producer_edge->consumer_val_;

    // If producer isn't a tensor view it can't be mapped into a producer dim of
    // this group
    if (!(consumer_val->isA<TensorView>() && producer_val->isA<TensorView>())) {
      continue;
    }

    // If the compute at domains of the producer group is empty, it can't map to
    // the produce at domains of this group
    auto producer_group = producer_edge->from;
    if (producer_group->payload()->ca_domains_.empty()) {
      continue;
    }

    auto producer_tv = producer_val->as<TensorView>();
    auto consumer_tv = consumer_val->as<TensorView>();

    // If this consumer_tv doesn't map to the last producer domain of this group
    // it can't decide if it can be reduced
    bool has_matching_pa = false;
    for (const auto i : c10::irange(consumer_tv->getMaxProducerPosition())) {
      if (GpuLower::current()->caMap()->areMapped(
              consumer_tv->axis(i), group_pa_last_id, IdMappingMode::LOOP)) {
        has_matching_pa = true;
        break;
      }
    }

    if (!has_matching_pa) {
      continue;
    }

    // If any compute at positions of producers directly map to the last produce
    // at position it can't be lowered.
    // (Iterates from the producer's compute-at position down to axis 0.)
    for (int producer_pos_i =
             static_cast<int>(producer_tv->getComputeAtPosition());
         producer_pos_i > 0;
         producer_pos_i--) {
      if (GpuLower::current()->caMap()->areMapped(
              producer_tv->axis(producer_pos_i - 1),
              group_pa_last_id,
              IdMappingMode::LOOP)) {
        return false;
      }
    }
  }

  return true;
}
898 | |
899 | // Update in between attempts to segment. This is called once no more groups |
900 | // can be merged together. Typically we will want to remove compute at groups |
901 | // that have finished being grouped together. However if no groups have been |
902 | // merged after we've done this, we may need to stop as we could have multiple |
903 | // disjoint groups that won't be merged. |
904 | bool ExprSegmentationSorter::interIterUpdate() { |
905 | // Go through groups and lower either pa or ca domain return if anything was |
906 | // lowered |
907 | bool lowered_a_domain = false; |
908 | for (auto& group : groups_) { |
909 | if (canReducePA(group.get())) { |
910 | group->payload()->pa_domains_.pop_back(); |
911 | lowered_a_domain = true; |
912 | } |
913 | } |
914 | |
915 | // If we couldn't lower compute at domain any further, and we haven't merged |
916 | // any new groups after fallback_mode_enabled_ has been turned on, make sure |
917 | // we've finished successfully |
918 | if (!lowered_a_domain && n_groups_ == groups_.size()) { |
919 | // Make sure none of the groups are still connected, as that would mean we |
920 | // should have been able to merge them. |
921 | bool successfully_finished = std::all_of( |
922 | groups_.begin(), groups_.end(), [](std::unique_ptr<ExprGroup>& sg) { |
923 | return sg->producerEdges().empty() && sg->consumerEdges().empty(); |
924 | }); |
925 | if (successfully_finished) { |
926 | return false; |
927 | } |
928 | // If we didn't finish and we tried the fallback, throw. |
929 | TORCH_INTERNAL_ASSERT( |
930 | !fallback_mode_enabled_, |
931 | "Couldn't successfully sort out the fusion expressions. " , |
932 | "There are remaining connections of the hierarchical segmentation which should have been " , |
933 | "flattened to a single ordered group, or disjoint ordered groups." ); |
934 | // We didn't finish, but we haven't tried the fallback, try again with that. |
935 | fallback_mode_enabled_ = true; |
936 | } |
937 | |
938 | n_groups_ = groups_.size(); |
939 | // Not done, continue. |
940 | return true; |
941 | } |
942 | |
943 | void ExprSegmentationSorter::mergeNodes() { |
944 | std::unordered_set<ExprGroup*> clean_up_groups; |
945 | std::unordered_set<ExprGroupConnections*> clean_up_edges; |
946 | |
947 | while (!to_merge_.empty()) { |
948 | ExprGroup *group1 = nullptr, *group2 = nullptr; |
949 | std::tie(group1, group2) = to_merge_.back(); |
950 | to_merge_.pop_back(); |
951 | TORCH_INTERNAL_ASSERT( |
952 | group2 == group1->payload()->merge_with, |
953 | "Expression Sorter: inconsistent to_merge packing" ); |
954 | clean_up_groups.emplace(group1); |
955 | clean_up_groups.emplace(group2); |
956 | makeMergedNode(group1, group2); |
957 | } |
958 | |
959 | for (auto group : clean_up_groups) { |
960 | auto disconnected_edges = disconnectGroup(group); |
961 | clean_up_edges.insert(disconnected_edges.begin(), disconnected_edges.end()); |
962 | } |
963 | |
964 | edges_.remove_if([&](std::unique_ptr<ExprGroupConnections>& edge) { |
965 | return clean_up_edges.find(edge.get()) != clean_up_edges.end(); |
966 | }); |
967 | |
968 | groups_.remove_if([&](std::unique_ptr<ExprGroup>& group) { |
969 | return clean_up_groups.find(group.get()) != clean_up_groups.end(); |
970 | }); |
971 | } |
972 | |
973 | // Initialize concrete_id_dependencies and concrete_id_to_all_ids |
974 | void ExprSegmentationSorter::initializeForLoopDependencies() { |
975 | TORCH_INTERNAL_ASSERT( |
976 | concrete_id_dependencies.empty(), |
977 | "For loop dependencies have already been initialized." ); |
978 | |
979 | for (auto tv : ir_utils::allTvs(fusion_)) { |
980 | std::unordered_set<IterDomain*> dependencies; |
981 | for (size_t tv_id_i = |
982 | std::max(tv->getMaxProducerPosition(), tv->getComputeAtPosition()); |
983 | tv_id_i > 0; |
984 | tv_id_i--) { |
985 | auto tv_id = tv->axis((int)(tv_id_i - 1)); |
986 | auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( |
987 | tv_id, IdMappingMode::LOOP); |
988 | |
989 | if (concrete_id_dependencies.find(concrete_id) == |
990 | concrete_id_dependencies.end()) { |
991 | concrete_id_dependencies[concrete_id] = dependencies; |
992 | } else { |
993 | concrete_id_dependencies[concrete_id].insert( |
994 | dependencies.begin(), dependencies.end()); |
995 | } |
996 | |
997 | // Loops after tv_id are dependent on tv_id |
998 | dependencies.emplace(GpuLower::current()->caMap()->getConcreteMappedID( |
999 | tv_id, IdMappingMode::LOOP)); |
1000 | } |
1001 | } |
1002 | |
1003 | // Fill out dependencies as IDs will have local dependency information, but |
1004 | // it's still not guaranteed to be global. |
1005 | |
1006 | // If loop structure is something like: |
1007 | // T0 [I0] |
1008 | // T1 [I0, I1] |
1009 | // T2 [I1, I2] |
1010 | // |
1011 | // I1 will be marked as a dependency of I0 |
1012 | // I2 will be marked as a dependency of I1 |
1013 | // |
1014 | // However, I2 will not be marked as a dep of I0, so we need to fill out the |
1015 | // dependency analysis. This is done by iterating through IterDomains filling |
1016 | // out all the dependencies of dependencies recursively. |
1017 | |
1018 | std::deque<IterDomain*> to_visit; |
1019 | std::unordered_set<IterDomain*> visited; |
1020 | |
1021 | std::transform( |
1022 | concrete_id_dependencies.begin(), |
1023 | concrete_id_dependencies.end(), |
1024 | std::back_inserter(to_visit), |
1025 | [](const auto& concrete_dep_entry) { return concrete_dep_entry.first; }); |
1026 | |
1027 | size_t inf_loop_counter = to_visit.size(); |
1028 | bool failed = false; |
1029 | |
1030 | while (!to_visit.empty()) { |
1031 | auto id = to_visit.front(); |
1032 | to_visit.pop_front(); |
1033 | |
1034 | if (inf_loop_counter-- == 0) { |
1035 | failed = true; |
1036 | break; |
1037 | } |
1038 | |
1039 | auto& dependencies = concrete_id_dependencies.at(id); |
1040 | bool ready = dependencies.empty() || |
1041 | std::all_of(dependencies.begin(), |
1042 | dependencies.end(), |
1043 | [&visited](IterDomain* id) { return visited.count(id); }); |
1044 | |
1045 | if (!ready) { |
1046 | to_visit.push_back(id); |
1047 | continue; |
1048 | } |
1049 | |
1050 | inf_loop_counter = to_visit.size(); |
1051 | |
1052 | for (auto dependency : dependencies) { |
1053 | auto dep_of_dep = concrete_id_dependencies.at(dependency); |
1054 | dependencies.insert(dep_of_dep.begin(), dep_of_dep.end()); |
1055 | } |
1056 | visited.emplace(id); |
1057 | } |
1058 | if (failed) { |
1059 | std::cerr |
1060 | << "ERROR: Iteration domain sorting has failed, infinite loop detected." |
1061 | << std::endl; |
1062 | std::cerr << "Failed to sort out: " << std::endl; |
1063 | for (auto entry : to_visit) { |
1064 | std::cerr << entry->toString(); |
1065 | if (entry != to_visit.back()) { |
1066 | std::cerr << ", " ; |
1067 | } |
1068 | } |
1069 | |
1070 | std::cerr << "Dependencies: " << std::endl; |
1071 | for (const auto& dep_entry : concrete_id_dependencies) { |
1072 | std::cerr << " Deps of " << dep_entry.first->toString() << std::endl |
1073 | << " " ; |
1074 | |
1075 | for (auto dep : dep_entry.second) { |
1076 | std::cerr << dep->toString() << ", " ; |
1077 | } |
1078 | std::cerr << std::endl; |
1079 | } |
1080 | |
1081 | TORCH_INTERNAL_ASSERT(false); |
1082 | } |
1083 | } |
1084 | |
1085 | // Checks if the for loop associated with the concrete ID is ready to be |
1086 | // resolved in sorting. This could be done more efficiently with some |
1087 | // additional tracking, however we recreate ca_domain_ when we merge groups, |
1088 | // so it's hard to track what is no longer needed. |
1089 | bool ExprSegmentationSorter::loopReady(IterDomain* concrete_id) { |
1090 | const auto& dependencies = concrete_id_dependencies[concrete_id]; |
1091 | for (auto& group : groups_) { |
1092 | // Only need to check compute at domain here, because if there's an entry in |
1093 | // produce at, that has no matching entry in compute at, then that ID can be |
1094 | // removed as in canReducePA |
1095 | for (auto ca_domain : group->payload()->ca_domains_) { |
1096 | if (dependencies.count(ca_domain)) { |
1097 | return false; |
1098 | } |
1099 | } |
1100 | } |
1101 | return true; |
1102 | } |
1103 | |
1104 | // Two expression groups can be merged together if there's a value produced by |
1105 | // producer group, consumed by consumer group, where the compute at position |
1106 | // maps to the inner most compute at domain of the producer group and maps to |
1107 | // the inner most produce at domain of the consumer. If this value doesn't exist |
1108 | // we can't be certain these domains share the "next" inner most loop. |
1109 | // |
1110 | // We're looking for this because we're starting at the inner most loops of all |
1111 | // expressions, and looking for neighboring expressions that share inner loops. |
1112 | // Once we've found all the inner most loops that expressions share, we merge |
1113 | // them together, then look at the next inner most loop of the group and figure |
1114 | // out which other groups share this next inner most loop. |
1115 | bool ExprSegmentationSorter::supportedMerge(ExprGroup* sg1, ExprGroup* sg2) { |
1116 | auto producer_group = getProducer(sg1, sg2); |
1117 | auto consumer_group = sg1 == producer_group ? sg2 : sg1; |
1118 | |
1119 | if (producer_group->payload()->ca_domains_.size() < |
1120 | producer_group->payload()->pa_domains_.size()) { |
1121 | return false; |
1122 | } |
1123 | |
1124 | if (consumer_group->payload()->pa_domains_.size() < |
1125 | consumer_group->payload()->ca_domains_.size()) { |
1126 | return false; |
1127 | } |
1128 | |
1129 | const auto& producer_ca_domain = producer_group->payload()->ca_domains_; |
1130 | const auto& consumer_pa_domain = consumer_group->payload()->pa_domains_; |
1131 | |
1132 | if (producer_ca_domain.empty() && consumer_pa_domain.empty()) { |
1133 | return true; |
1134 | } |
1135 | |
1136 | if (producer_ca_domain.empty() || consumer_pa_domain.empty()) { |
1137 | return false; |
1138 | } |
1139 | |
1140 | // If inner loop dependencies have not been resolved, cannot merge. |
1141 | if (!loopReady(producer_ca_domain.back()) || |
1142 | !loopReady(consumer_pa_domain.back())) { |
1143 | return false; |
1144 | } |
1145 | |
1146 | for (auto edge : producer_group->consumerEdges()) { |
1147 | if (edge->to != consumer_group) { |
1148 | continue; |
1149 | } |
1150 | auto producer_val = edge->producer_val_; |
1151 | auto consumer_val = edge->consumer_val_; |
1152 | |
1153 | if (!producer_val->isA<TensorView>()) { |
1154 | continue; |
1155 | } |
1156 | |
1157 | TORCH_INTERNAL_ASSERT( |
1158 | consumer_val->isA<TensorView>(), |
1159 | "Mismatched tensorview to non-tensorview in expression sorting. " , |
1160 | producer_val, |
1161 | " is consumed by " , |
1162 | consumer_val); |
1163 | |
1164 | auto producer_tv = producer_val->as<TensorView>(); |
1165 | |
1166 | auto compute_at_pos = producer_tv->getComputeAtPosition(); |
1167 | auto compute_at_dim = compute_at_pos > 0 |
1168 | ? producer_tv->axis((int)producer_tv->getComputeAtPosition() - 1) |
1169 | : nullptr; |
1170 | |
1171 | if (compute_at_dim == nullptr) { |
1172 | continue; |
1173 | } |
1174 | |
1175 | if (!GpuLower::current()->caMap()->areMapped( |
1176 | compute_at_dim, producer_ca_domain.back(), IdMappingMode::LOOP)) { |
1177 | continue; |
1178 | } |
1179 | |
1180 | if (GpuLower::current()->caMap()->areMapped( |
1181 | compute_at_dim, consumer_pa_domain.back(), IdMappingMode::LOOP)) { |
1182 | return true; |
1183 | } |
1184 | } |
1185 | return false; |
1186 | } |
1187 | |
1188 | bool ExprSegmentationSorter::testStillDag(ExprGroup* sg1, ExprGroup* sg2) { |
1189 | std::deque<ExprGroup*> to_visit; |
1190 | std::unordered_set<ExprGroup*> visited; |
1191 | // Add consumers of sg1 if not sg2 |
1192 | for (auto sg1_consumer_edge : sg1->consumerEdges()) { |
1193 | if (sg1_consumer_edge->to != sg2) { |
1194 | to_visit.emplace_back(sg1_consumer_edge->to); |
1195 | } |
1196 | } |
1197 | |
1198 | // Add consumers of sg2 if not sg1 |
1199 | for (auto sg2_consumer_edge : sg2->consumerEdges()) { |
1200 | if (sg2_consumer_edge->to != sg1) { |
1201 | to_visit.emplace_back(sg2_consumer_edge->to); |
1202 | } |
1203 | } |
1204 | |
1205 | while (to_visit.size() > 0) { |
1206 | auto group = to_visit.front(); |
1207 | // Arrived back at one of the original groups, merging these two groups |
1208 | // would generate a cycle |
1209 | if (group == sg1 || group == sg2) { |
1210 | return false; |
1211 | } |
1212 | to_visit.pop_front(); |
1213 | if (visited.find(group) != visited.end()) { |
1214 | continue; |
1215 | } |
1216 | visited.emplace(group); |
1217 | for (auto consumer_edge : group->consumerEdges()) { |
1218 | to_visit.emplace_back(consumer_edge->to); |
1219 | } |
1220 | } |
1221 | |
1222 | // No cycles found, we're good. |
1223 | return true; |
1224 | } |
1225 | |
void ExprSegmentationSorter::sort() {
  // Maps each expression to the singleton group it starts in; only needed
  // while building the initial DAG.
  std::unordered_map<Expr*, ExprGroup*> expr2group;

  auto all_exprs = fusion_->exprs();

  // Figure out all the values used as inputs to the expressions we're sorting
  // (to find terminating expressions). There could be branches of expressions
  // not used to produce outputs, so can't simply check val->uses() to figure
  // out if it's actually used in the expressions we're sorting.
  std::unordered_set<Val*> used_vals;
  for (auto expr : all_exprs) {
    used_vals.insert(expr->inputs().begin(), expr->inputs().end());
  }

  // Initialize DAG, convert each expr to a segment group
  for (auto expr : all_exprs) {
    // An expression terminates if none of its outputs feed another expression
    // being sorted.
    bool is_terminating_expr = std::none_of(
        expr->outputs().begin(),
        expr->outputs().end(),
        [&used_vals](Val* output) { return used_vals.count(output) > 0; });
    auto group = makeEmptyGroup(expr, is_terminating_expr);
    expr2group.insert(std::make_pair(expr, group));
  }

  // Create edges between the Exprs. Mark inputs and outputs of the fusion.
  for (auto expr : fusion_->exprs()) {
    auto expr_group = expr2group.at(expr);
    auto out = expr->outputs()[0];
    for (auto inp : expr->inputs()) {
      // Fusion inputs have no producing expression, so there is no edge to
      // create for them.
      if (inp->isFusionInput()) {
        continue;
      }

      // Could be something like a constant scalar, definition is nullptr, but
      // isn't an "input" to the fusion. At least not one provided by an
      // external source.
      if (inp->definition() == nullptr) {
        continue;
      }

      // Connect the defining expression's group (producer) to this
      // expression's group (consumer) through the value inp.
      auto inp_def_group = expr2group.at(inp->definition());
      edges_.push_back(std::make_unique<ExprGroupConnections>(
          inp_def_group, expr_group, inp, out));
      expr_group->addProducerEdge(edges_.back().get());
      inp_def_group->addConsumerEdge(edges_.back().get());
    }
  }

  // Initialize loop dependency maps
  initializeForLoopDependencies();

  // Outer driver: keep iterating while interIterUpdate() reports progress or
  // requests a fallback retry. interIterUpdate() returning false ends the
  // sort.
  bool inter_iter_update = true;
  while (inter_iter_update) {
    // Default algorithm: greedily propose and apply all supported merges
    // until a full pass makes no progress.
    if (!fallback_mode_enabled_) {
      // Merge expressions in sorted order
      bool merged_nodes = true;
      while (merged_nodes) {
        // Reset stateful traversal details in ExprGroups
        resetTraversal();
        resetLevels();

        for (auto& group : groups_) {
          // Skip groups already scheduled for merging in this pass.
          if (group->payload()->merged) {
            continue;
          }
          auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
          if (candidates.empty()) {
            continue;
          }

          // Take the first candidate that passes the supportedMerge check.
          auto candidate_it = candidates.begin();
          while (candidate_it != candidates.end() &&
                 !supportedMerge(group.get(), *candidate_it)) {
            candidate_it++;
          }
          if (candidate_it == candidates.end()) {
            continue;
          }

          to_merge_.emplace_back(std::make_pair(group.get(), *candidate_it));

          // Mark both sides so neither is proposed again this pass.
          group->payload()->merged = true;
          group->payload()->merge_with = *candidate_it;

          (*candidate_it)->payload()->merged = true;
          (*candidate_it)->payload()->merge_with = group.get();
        }

        // No proposals this pass: stop the inner merge loop after the
        // (no-op) mergeNodes() and update below.
        if (to_merge_.empty()) {
          merged_nodes = false;
        }

        mergeNodes();

        // Move compute at axes left
        inter_iter_update = interIterUpdate();
      }
    } else {
      // fallback_mode_enabled = true
      // Reset stateful traversal details in ExprGroups as we'll exclude merge
      // options that were already ruled out and therefore need traversal and
      // levels reset.
      resetTraversal();
      resetLevels();

      for (auto& group : groups_) {
        if (group->payload()->merged) {
          continue;
        }
        // Get merge candidates that weren't proven safe to merge with default
        // algorithm.
        auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
        if (candidates.empty()) {
          continue;
        }

        auto candidate_it = candidates.begin();

        while (candidate_it != candidates.end()) {
          // Advance to the next candidate that passes supportedMerge.
          while (candidate_it != candidates.end() &&
                 !supportedMerge(group.get(), *candidate_it)) {
            candidate_it++;
          }

          if (candidate_it == candidates.end()) {
            break;
          }

          // The fallback additionally requires an explicit acyclicity check
          // before accepting the merge.
          if (testStillDag(group.get(), *candidate_it)) {
            // Mark in same style as default algorithm for convenience even
            // though we will only merge once with the fallback
            to_merge_.emplace_back(std::make_pair(group.get(), *candidate_it));

            group->payload()->merged = true;
            group->payload()->merge_with = *candidate_it;

            (*candidate_it)->payload()->merged = true;
            (*candidate_it)->payload()->merge_with = group.get();
            break;
          }

          candidate_it++;
        }

        // The fallback merges at most one pair per iteration.
        if (to_merge_.size() > 0) {
          break;
        }
      }

      // If we can merge something, merge it, disable fallback, and bail
      if (to_merge_.size() > 0) {
        mergeNodes();
      }

      // Move compute at axes left
      // If fallback didn't work, interIterUpdate will catch that we failed.
      inter_iter_update = interIterUpdate();
      fallback_mode_enabled_ = false;
    }
  }
}
1389 | |
1390 | std::vector<Expr*> ExprSegmentationSorter::getExprs() const { |
1391 | std::vector<Expr*> exprs; |
1392 | for (auto& group : groups_) { |
1393 | exprs.insert(exprs.end(), group->exprs().begin(), group->exprs().end()); |
1394 | } |
1395 | return exprs; |
1396 | } |
1397 | |
1398 | } // namespace |
1399 | |
1400 | std::vector<Expr*> reorderExprsForComputeAt() { |
1401 | auto fusion = FusionGuard::getCurFusion(); |
1402 | if (fusion->exprs().empty()) { |
1403 | return {}; |
1404 | } |
1405 | TORCH_INTERNAL_ASSERT(fusion != nullptr); |
1406 | ExprSegmentationSorter sorter(fusion); |
1407 | sorter.sort(); |
1408 | auto sorted_exprs = sorter.getExprs(); |
1409 | TORCH_INTERNAL_ASSERT( |
1410 | sorted_exprs.size() > 0, |
1411 | "Error during expression sorting, no expressions produced." ); |
1412 | return sorted_exprs; |
1413 | } |
1414 | |
1415 | } // namespace cuda |
1416 | } // namespace fuser |
1417 | } // namespace jit |
1418 | } // namespace torch |
1419 | |