1#include <compute_at_map.h>
2#include <fusion.h>
3#include <instrumentation.h>
4#include <ir_all_nodes.h>
5#include <ir_iostream.h>
6#include <ir_utils.h>
7#include <lower2device.h>
8#include <lower_expr_sort.h>
9#include <lower_utils.h>
10
11#include <deque>
12#include <list>
13#include <sstream>
14#include <unordered_map>
15#include <unordered_set>
16#include <vector>
17
18namespace torch {
19namespace jit {
20namespace fuser {
21namespace cuda {
22
23namespace {
24
// TODO: Review const model, and objects
//  ExprSegmentationSorter
//    Responsible for going through the DAG and proposing things we could try
//    to merge together. Calls "supportedMerge" on these proposed groups to see
//    if they should be merged together, then merges them if so.
//  ExprGroup
//    A group of exprs that are grouped together based on their loop nest
//    structures.
//  ExprGroupConnections
//    Holds vals and what they connect. In other words, it's a val that is an
//    output of an ExprGroup "from" and an input of an ExprGroup "to". Nothing
//    prevents the same val from connecting two groups more than once.
//  TODO: make sure there's nothing wrong with grouping of nodes that
//    have the same value input twice, i.e. (B = A*A)
40
// Selecting segments to propose is based on Theorem 4.2 in the paper below,
// which makes sure that, after merging, the segmented graph is still a DAG
// (assuming the Fusion is already a DAG). The segmentation code relies on
// DAG-ness during segmentation, meaning any proposed merging of groups must
// maintain the DAG property of the graph.
//
// Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
// Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
// SIAM Journal on Scientific Computing, Society for Industrial and Applied
// Mathematics, 2019, 41 (4), pp.A2117-A2145. doi:10.1137/18M1176865.
// hal-02306566
52
53class ExprGroup;
54struct ExprGroupConnections;
55class ExprSegmentationSorter;
56
57// Debug printing disabled due to clang tidy, see below for definitions
58// std::ostream& operator<<(std::ostream& os, const ExprGroup* group);
59
60// Wrapper for values, these are edges between expr groups. Multiple edges can
61// exist between expr groups, and the same Val can show up more than once in
62// multiple edges.
63struct ExprGroupConnections {
64 ExprGroupConnections(
65 ExprGroup* group_from,
66 ExprGroup* group_to,
67 Val* producer_val,
68 Val* consumer_val)
69 : from(group_from),
70 to(group_to),
71 producer_val_(producer_val),
72 consumer_val_(consumer_val) {}
73 // Producer group from which the edge starts
74 ExprGroup* from;
75
76 // Consumer group from which the edge ends
77 ExprGroup* to;
78
79 // The value from the producer group connecting the groups
80 // This value helps us resolve the compute at position of expr groups
81
82 Val* producer_val_;
83
84 // The value that the producer val gets used to create on this edge
85 // This value helps us resolve the produce at position of expr groups
86 Val* consumer_val_;
87};
88
89struct ExprSortPayload : public PolymorphicBase {
90 // Need to track compute at domains as well as produce at domains. Produce at
91 // domains will be matched to producers compute at domains. Track the active
92 // domains that will be matched from inner most dim to outer most.
93 std::vector<IterDomain*> ca_domains_;
94 std::vector<IterDomain*> pa_domains_;
95
96 // Maximum path distance from an input expr group required for
97 // Theorem 4.2
98 int level = -1;
99
100 // Traversal marker, marks if this group has been visited by current pass
101 bool visited = false;
102
103 // Marks if this group is already selected to merge with another group, marks
104 // which group to merge with
105 ExprGroup* merge_with = nullptr;
106
107 // Marks if this group is already selected to merge with another group
108 bool merged = false;
109};
110
111// Groups together expressions which create a expr group
112class ExprGroup {
113 public:
114 ExprGroup() : payload_(std::make_unique<ExprSortPayload>()) {}
115
116 ExprGroup(Expr* expr) : payload_(std::make_unique<ExprSortPayload>()) {
117 exprs_.push_back(expr);
118 }
119
120 ExprGroup(const ExprGroup& other)
121 : payload_(new ExprSortPayload(*(other.payload_))) {}
122
123 ExprGroup& operator=(const ExprGroup& other) {
124 *payload_ = *other.payload_;
125 exprs_ = other.exprs_;
126 return *this;
127 }
128
129 // Clears the traversal information in the payload
130 void clearTraversalInfo();
131
132 // Returns all neighbors, producers and consumers
133 std::vector<ExprGroup*> getNeighbors();
134
135 // Return neighbors of this proven to be safe nodes to merge with in regards
136 // to maining an acyclic graph. This looks at, neighbors if merged, neighbors
137 // level, and merged neighbors of neighbors level. If fallback_mode_enabled
138 // will return the inverse set of ExprGroups that are proven to be safe
139 // merges.
140 std::vector<ExprGroup*> getMergeCandidates(
141 bool fallback_mode_enabled = false);
142
143 std::unique_ptr<ExprSortPayload>& payload() {
144 return payload_;
145 }
146
147 const auto& producerEdges() const {
148 return producer_edges_;
149 }
150
151 void addProducerEdge(ExprGroupConnections* edge) {
152 addEdge(producer_edges_, edge);
153 }
154
155 void removeProducerEdge(ExprGroupConnections* edge) {
156 removeEdge(producer_edges_, edge);
157 }
158
159 void clearProducerEdges() {
160 producer_edges_.clear();
161 }
162
163 const auto& consumerEdges() const {
164 return consumer_edges_;
165 }
166
167 void addConsumerEdge(ExprGroupConnections* edge) {
168 addEdge(consumer_edges_, edge);
169 }
170
171 void removeConsumerEdge(ExprGroupConnections* edge) {
172 removeEdge(consumer_edges_, edge);
173 }
174
175 void clearConsumerEdges() {
176 consumer_edges_.clear();
177 }
178
179 auto& exprs() {
180 return exprs_;
181 }
182
183 const auto& exprs() const {
184 return exprs_;
185 }
186
187 private:
188 static void addEdge(
189 std::vector<ExprGroupConnections*>& edges,
190 ExprGroupConnections* edge_to_add) {
191 edges.push_back(edge_to_add);
192 }
193
194 static void removeEdge(
195 std::vector<ExprGroupConnections*>& edges,
196 ExprGroupConnections* edge_to_remove) {
197 auto it = std::find(edges.begin(), edges.end(), edge_to_remove);
198 TORCH_INTERNAL_ASSERT(it != edges.end(), "Could not find edge to remove.");
199 edges.erase(it);
200 }
201
202 private:
203 // "Ancestor nodes", towards inputs of segmentedDAG
204 std::vector<ExprGroupConnections*> producer_edges_;
205
206 // "Descendent nodes", towards outputs of segmentedDAG
207 std::vector<ExprGroupConnections*> consumer_edges_;
208
209 // Exprs that make up the group
210 std::vector<Expr*> exprs_;
211
212 // Stateful traversal information
213 std::unique_ptr<ExprSortPayload> payload_;
214};
215
216// This class sorts expressions guarantees two things, 1) Tensors are produced
217// before they're consumed 2) If the production of two tensors are supposed to
218// share a for loop, they're in an order where they can. (1) is pretty standard
219// of ordering a DAG. (2) is where things get a bit complicated and why we do
220// this sorting through segmentation. Consider a section of a DAG: T4 = T3 + T2.
221// Where T2 and T3 are not inputs to the fusion, all tensors are 3D, and we want
222// the production of T3 to share the inner most loop of T4 and we want the
223// production of T2 to share the middle loop with T4. i.e. we're looking for
224// For(i:I){
225// For(j: J){
226// For(k: K){
227// T2[i, j, k] = ...
228// }
229// For(k: K){
230// T3[i, j, k] = ...
231// T4[i, j, k] = T2[i, j, k] + T3[i, j, k]
232// }
233// }
234// }
235// The only valid ordering of expressions is producing T2, then T3, then T4. If
236// we swapped T3 and T2, then T3 and T4 couldn't share their inner most loop,
237// because T2 has its own inner most loop. If we swapped either tensor with T4,
238// then we'd try to be using T2 or T3 without producing them (back to gaurantee
239// 1).
240class ExprSegmentationSorter {
241 public:
242 ExprSegmentationSorter(Fusion* fusion) : fusion_(fusion) {}
243
244 void sort();
245
246 std::string toString(int verbosity = 0) const;
247
248 //! Returns a flattened list of sorted exprs
249 std::vector<Expr*> getExprs() const;
250
251 private:
252 // Allocate an empty expr group and return it
253 ExprGroup* makeEmptyGroup();
254
255 // Allocate an expr group with the provided expr and return it. Also requires
256 // information on if this expression is a terminating expression (none of its
257 // outputs are used in other expressions being sorted).
258 ExprGroup* makeEmptyGroup(Expr*, bool terminating_expr);
259
260 // Returns if sg1 and sg2 should be merged together, is called if they can
261 // based on the current status of the DAG.
262 bool supportedMerge(ExprGroup* sg1, ExprGroup* sg2);
263
264 // Returns true if the graph will remain an acyclic graph after merging sg1
265 // and sg2
266 bool testStillDag(ExprGroup* sg1, ExprGroup* sg2);
267
268 // Merges two ExprGroups and returns the new ExprGroup
269 ExprGroup* makeMergedNode(ExprGroup* sg1, ExprGroup* sg2);
270
271 // This is called once no more groups can be merged together. This will lower
272 // the compute at position of a segment group if the last dimension of the
273 // segment group doesn't map to any of the dimensions of its neighbors.
274 bool interIterUpdate();
275
276 // Reset the ExprSortPayload of the groups so we can traverse and identify
277 // merge candidates.
278 void resetTraversal();
279
280 // Reset the set levels of each group. This is what's used to identify which
281 // nodes can be merged together.
282 void resetLevels();
283
284 // Go through groups that are marked as to merge and merge them.
285 void mergeNodes();
286
287 // Initialize concrete_id_dependencies
288 void initializeForLoopDependencies();
289
290 // Checks if the for loop associated with the concrete ID is ready to be
291 // resolved in sorting.
292 bool loopReady(IterDomain* concrete_id);
293
294 // Disconnect the edges connecting group to the rest of the graph, and return
295 // all the edges that were disconnected
296 std::unordered_set<ExprGroupConnections*> disconnectGroup(ExprGroup* group);
297
298 private:
299 // Track how many groups we have from iteration to iteration so we can track
300 // when we've stopped merging nodes.
301 size_t n_groups_ = 0;
302
303 // Lifetime of the graph view of the fusion and segmentation. Use list to not
304 // invalidate any entries on insertion/deletion.
305 std::list<std::unique_ptr<ExprGroupConnections>> edges_;
306 std::list<std::unique_ptr<ExprGroup>> groups_;
307
308 std::deque<ExprGroup*> to_visit_;
309
310 std::vector<std::pair<ExprGroup*, ExprGroup*>> to_merge_;
311
312 Fusion* fusion_;
313
314 // We use a theorem out of a paper mentioned in other comments. This theorem
315 // is good at identifying multiple expr groups to merge during a single
316 // iteration without producing a cyclic graph from an acyclic graph. This
317 // theorem is not guaranteed to find all possible nodes that can be merged
318 // together. We need to be able to group all disjoint groups of exprs or
319 // we fail to generate code. Therefore, if we can't find anything to make
320 // forward progress based on the theorem we fallback to manually looking if we
321 // can segmenet all combinations we haven't previously looked at.
322 bool fallback_mode_enabled_ = false;
323
324 // We need to track ID resolution, see AdvancedLowering6. For loops need
325 // to be resolved from inner most to outer most. We therefore track
326 // loop dependencies where inner most loops need to be fully resolved before
327 // we can resolve the next outer loop. We track this by looking at all tensor
328 // views, and each iteration domain. An iter domain in the outer most position
329 // has dependencies on all inner dimensions. This tracking is done on concrete
330 // id's in the loop map, this is because IDs may exist in some TVs but not
331 // others, however, we need a "global" view to track these dependencies.
332 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
333 concrete_id_dependencies;
334};
335
336// // Debug printing, disabled due to clang-tidy see above for declarations.
337// std::ostream& operator<<(std::ostream& os, ExprGroup* group) {
338// os << "Group Start{\n ca, pa ("
339// << group->payload()->ca_domains_.size() << ", "
340// << group->payload()->pa_domains_.size() << ")";
341// os << " ca_ids {";
342// for (size_t i = 0; i < group->payload()->ca_domains_.size(); i++) {
343// os << group->payload()->ca_domains_[i];
344// if (i + 1 != group->payload()->ca_domains_.size())
345// os << ", ";
346// }
347// os << "} pa_ids {";
348// for (size_t i = 0; i < group->payload()->pa_domains_.size(); i++) {
349// os << group->payload()->pa_domains_[i];
350// if (i + 1 != group->payload()->pa_domains_.size())
351// os << ", ";
352// }
353// os << "}";
354// os << "\nExprs {\n";
355// for(auto expr : group->exprs()){
356// os << expr;
357// }
358// os << "}Group End\n";
359// return os;
360// }
361
362std::vector<ExprGroup*> ExprGroup::getNeighbors() {
363 std::vector<ExprGroup*> neighbors;
364 for (auto inp : producer_edges_) {
365 neighbors.push_back(inp->from);
366 }
367 for (auto out : consumerEdges()) {
368 neighbors.push_back(out->to);
369 }
370 return neighbors;
371}
372
373std::vector<ExprGroup*> ExprGroup::getMergeCandidates(
374 bool fallback_mode_enabled) {
375 std::vector<ExprGroup*> neighbors = getNeighbors();
376
377 // Don't look for candidates if already merged
378 if (payload()->merged) {
379 return {};
380 }
381
382 // Can this node be merged with another? Check if neighbors are merged, if
383 // so and merged neighbor is within 1 level or node merged with neighbor is
384 // within 1 level, can't merge this node with anything else.
385 bool can_merge_this = true;
386 bool neighbor_merged = false;
387 for (auto neighbor : neighbors) {
388 if (!neighbor->payload()->merged) {
389 continue;
390 }
391 neighbor_merged = true;
392 if (std::abs(neighbor->payload()->level - payload()->level) <= 1) {
393 can_merge_this = false;
394 }
395 if (std::abs(
396 neighbor->payload()->merge_with->payload()->level -
397 payload()->level) <= 1) {
398 can_merge_this = false;
399 }
400 }
401
402 // If something prevents us from merging this node, and we're not in fallback
403 // mode, return empty set.
404 if (!can_merge_this && !fallback_mode_enabled) {
405 return {};
406 }
407
408 // If fallback mode already detected a merge somewhere, we shouldn't still be
409 // traversing.
410 if (fallback_mode_enabled) {
411 TORCH_INTERNAL_ASSERT(
412 !neighbor_merged,
413 "Shouldn't still be traversing in fallback mode if a merge was found.");
414 }
415
416 std::vector<bool> can_merge(neighbors.size(), true);
417
418 // Find neighbors with a level that is only 1 different than this group's
419 // level
420 for (const auto i : c10::irange(neighbors.size())) {
421 if (std::abs(neighbors[i]->payload()->level - payload()->level) > 1) {
422 can_merge[i] = false;
423 }
424 }
425
426 // Check neighbor of neighbors we're considering, if any of them are merged
427 // with another node, make sure the resulting edge wouldn't have a level
428 // difference of 1
429 for (const auto i : c10::irange(neighbors.size())) {
430 if (!can_merge[i]) {
431 continue;
432 }
433
434 for (auto neighbor_neighbor : neighbors[i]->getNeighbors()) {
435 // Don't check self
436 if (neighbor_neighbor == neighbors[i]) {
437 continue;
438 }
439 if (neighbor_neighbor->payload()->merged) {
440 // check neighbor_neighbor level
441 if (std::abs(neighbor_neighbor->payload()->level - payload()->level) <=
442 1) {
443 can_merge[i] = false;
444 }
445 if (std::abs(
446 neighbor_neighbor->payload()->level -
447 neighbors[i]->payload()->level) <= 1) {
448 can_merge[i] = false;
449 }
450
451 // check neighbor_neighber->merged->level
452 if (std::abs(
453 neighbor_neighbor->payload()->merge_with->payload()->level -
454 payload()->level) <= 1) {
455 can_merge[i] = false;
456 }
457 if (std::abs(
458 neighbor_neighbor->payload()->merge_with->payload()->level -
459 neighbors[i]->payload()->level) <= 1) {
460 can_merge[i] = false;
461 }
462 }
463 }
464 }
465
466 std::vector<ExprGroup*> merge_candidates;
467 for (const auto i : c10::irange(neighbors.size())) {
468 if ((can_merge[i] && !fallback_mode_enabled) ||
469 (!can_merge[i] && fallback_mode_enabled)) {
470 merge_candidates.push_back(neighbors[i]);
471 }
472 }
473 return merge_candidates;
474}
475
476void ExprGroup::clearTraversalInfo() {
477 payload()->level = -1;
478 payload()->visited = false;
479 payload()->merge_with = nullptr;
480 payload()->merged = false;
481}
482
483void ExprSegmentationSorter::resetTraversal() {
484 for (auto& group : groups_) {
485 // Start traversal at input groups
486 if (group->producerEdges().empty()) {
487 to_visit_.push_back(group.get());
488 }
489 group->clearTraversalInfo();
490 }
491}
492
493// Level is maximum distance from inputs. It's the metric used to select what
494// nodes can be merged while maintaining a DAG
495void ExprSegmentationSorter::resetLevels() {
496 std::vector<ExprGroup*> next_to_visit;
497
498 while (!to_visit_.empty()) {
499 auto visit = to_visit_.front();
500 to_visit_.pop_front();
501
502 // All inputs processed?
503 bool ready = true;
504 if (!visit->producerEdges().empty()) {
505 ready = std::all_of(
506 visit->producerEdges().begin(),
507 visit->producerEdges().end(),
508 [&](ExprGroupConnections* dep) {
509 return dep->from->payload()->visited;
510 });
511 }
512
513 if (!ready) {
514 // In case traversal doesn't complete because there's an error in the
515 // DAG topology.
516 next_to_visit.push_back(visit);
517 continue;
518 }
519
520 visit->payload()->visited = true;
521
522 to_visit_.insert(
523 to_visit_.end(), next_to_visit.begin(), next_to_visit.end());
524 next_to_visit.clear();
525
526 for (auto out : visit->consumerEdges()) {
527 to_visit_.push_back(out->to);
528 }
529
530 visit->payload()->level = 0;
531 for (auto inp : visit->producerEdges()) {
532 visit->payload()->level =
533 std::max(visit->payload()->level, inp->from->payload()->level + 1);
534 }
535 }
536 TORCH_INTERNAL_ASSERT(next_to_visit.empty(), "Error in graph, is not a DAG.");
537}
538
539ExprGroup* ExprSegmentationSorter::makeEmptyGroup() {
540 groups_.push_back(std::make_unique<ExprGroup>());
541 return groups_.back().get();
542}
543
544ExprGroup* ExprSegmentationSorter::makeEmptyGroup(
545 Expr* expr,
546 bool terminating_expr) {
547 auto group = makeEmptyGroup();
548 group->exprs().push_back(expr);
549 if (ir_utils::isTvOp(expr)) {
550 auto out_tv = expr->outputs()[0]->as<TensorView>();
551 // Grab all id's that are shared with other tensors.
552 // If not connected to consumers, doesn't mater what compute at is set to
553 if (!terminating_expr) {
554 for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) {
555 group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
556 }
557 }
558 for (const auto tv_i : c10::irange(out_tv->getMaxProducerPosition())) {
559 group->payload()->pa_domains_.push_back(out_tv->axis(tv_i));
560 }
561 }
562 return group;
563}
564
565// Debug function that prints the current state of the sorter.
566//
567// Uncomment if needed.
568// std::string ExprSegmentationSorter::toString(int verbosity) const {
569// std::stringstream ss;
570// ss << "{\n";
571// for (auto& group : groups_) {
572// ss << " " << group.get() << "\n";
573
574// if (verbosity > 1) {
575// if (group->producerEdges().size() > 0) {
576// ss << "Produced by groups with edges: { \n";
577// for (auto producer_edge : group->producerEdges()) {
578// ss << producer_edge->producer_val_ << " -> "
579// << producer_edge->consumer_val_ << "\n";
580// }
581// ss << " }"
582// << "\n";
583// }
584// }
585
586// if (verbosity > 1) {
587// if (group->consumerEdges().size() > 0) {
588// ss << "Consumed by groups with edges: { \n";
589// for (auto consumer_edge : group->consumerEdges()) {
590// ss << consumer_edge->producer_val_ << " -> "
591// << consumer_edge->consumer_val_ << "\n";
592// }
593// ss << " }"
594// << "\n";
595// }
596// }
597// }
598// ss << "}\n";
599// return ss.str();
600// }
601
602namespace {
603
604// Concat's edges of sg1 and sg2, but removes any edges from/to sg1/sg2
605std::vector<ExprGroupConnections*> getMergedEdges(
606 const ExprGroup* sg1,
607 const std::vector<ExprGroupConnections*>& edges1,
608 const ExprGroup* sg2,
609 const std::vector<ExprGroupConnections*>& edges2) {
610 TORCH_INTERNAL_ASSERT(
611 sg1 != nullptr && sg2 != nullptr,
612 "This function doesn't handle trivial.");
613
614 auto merged_edges = edges1;
615 merged_edges.insert(merged_edges.end(), edges2.begin(), edges2.end());
616
617 // Remove intra edges
618 merged_edges.erase(
619 std::remove_if(
620 merged_edges.begin(),
621 merged_edges.end(),
622 [&sg1, &sg2](ExprGroupConnections* se) {
623 return (se->to == sg1 && se->from == sg2) ||
624 (se->to == sg2 && se->from == sg1);
625 }),
626 merged_edges.end());
627
628 return merged_edges;
629}
630
631// Concat's producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
632std::vector<ExprGroupConnections*> getMergedProducerEdges(
633 const ExprGroup* sg1,
634 const ExprGroup* sg2) {
635 return getMergedEdges(sg1, sg1->producerEdges(), sg2, sg2->producerEdges());
636}
637
638// Concat's consumer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
639std::vector<ExprGroupConnections*> getMergedConsumerEdges(
640 const ExprGroup* sg1,
641 const ExprGroup* sg2) {
642 return getMergedEdges(sg1, sg1->consumerEdges(), sg2, sg2->consumerEdges());
643}
644
645// Assuming sg1 and sg2 are connected, figure out which is the consumer
646ExprGroup* getProducer(ExprGroup* sg1, ExprGroup* sg2) {
647 for (auto producer_edge : sg1->producerEdges()) {
648 if (producer_edge->from == sg2) {
649 return sg2;
650 }
651 }
652
653 for (auto consumer_edge : sg1->consumerEdges()) {
654 if (consumer_edge->to == sg2) {
655 return sg1;
656 }
657 }
658
659 return nullptr;
660}
661
662std::vector<IterDomain*> getLocalDomainOrdering(
663 const std::vector<Expr*>& exprs,
664 const std::unordered_set<IterDomain*> filter,
665 const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
666 concrete_id_dependencies) {
667 if (exprs.empty()) {
668 return std::vector<IterDomain*>();
669 }
670
671 const auto& ca_map = GpuLower::current()->caMap();
672
673 std::unordered_set<IterDomain*> domains;
674
675 for (auto expr : exprs) {
676 if (!ir_utils::isTvOp(expr)) {
677 continue;
678 }
679
680 auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
681 for (auto tv_input : tv_inputs) {
682 std::vector<IterDomain*> domain;
683
684 std::transform(
685 tv_input->domain()->domain().begin(),
686 tv_input->domain()->domain().begin() +
687 std::max(
688 tv_input->getComputeAtPosition(),
689 tv_input->getMaxProducerPosition()),
690 std::back_inserter(domain),
691 [&ca_map](IterDomain* id) {
692 return ca_map->getConcreteMappedID(id, IdMappingMode::LOOP);
693 });
694
695 domain.erase(
696 std::remove_if(
697 domain.begin(),
698 domain.end(),
699 [&filter, &ca_map](IterDomain* id) {
700 return filter.find(ca_map->getConcreteMappedID(
701 id, IdMappingMode::LOOP)) == filter.end();
702 }),
703 domain.end());
704
705 domains.insert(domain.begin(), domain.end());
706 }
707 }
708
709 std::vector<IterDomain*> merged_domain(domains.begin(), domains.end());
710 std::sort(
711 merged_domain.begin(),
712 merged_domain.end(),
713 ir_utils::IterDomainDependencySorter(
714 concrete_id_dependencies, GpuLower::current()->caMap()));
715 return merged_domain;
716}
717} // namespace
718
719// Disconect group from neighbors, and return edges that were disconnected
720std::unordered_set<ExprGroupConnections*> ExprSegmentationSorter::
721 disconnectGroup(ExprGroup* group) {
722 std::unordered_set<ExprGroupConnections*> removed_edges(
723 group->producerEdges().begin(), group->producerEdges().end());
724
725 for (auto edge : group->producerEdges()) {
726 edge->from->removeConsumerEdge(edge);
727 }
728
729 for (auto edge : group->consumerEdges()) {
730 edge->to->removeProducerEdge(edge);
731 }
732
733 group->clearProducerEdges();
734 group->clearConsumerEdges();
735
736 return removed_edges;
737}
738
739// TODO: This function may be sub optimial. If we find that an iteration domain
740// matches later in the other domain, we will hold all other iteration domains
741// until that one matches. There may be cases where duplicating that iteration
742// domain, and moving on could be more efficient.
743ExprGroup* ExprSegmentationSorter::makeMergedNode(
744 ExprGroup* sg1,
745 ExprGroup* sg2) {
746 // Keep Expr's sorted in topological order.
747 const auto producer = getProducer(sg1, sg2);
748 const auto consumer = sg1 == producer ? sg2 : sg1;
749
750 // Make the new joined node
751 auto joined_groups = makeEmptyGroup();
752
753 TORCH_INTERNAL_ASSERT(
754 producer != nullptr,
755 "Tried to merge expr's together that aren't neighbors.");
756
757 joined_groups->exprs() = producer->exprs();
758 joined_groups->exprs().insert(
759 joined_groups->exprs().end(),
760 consumer->exprs().begin(),
761 consumer->exprs().end());
762
763 auto producer_edges = getMergedProducerEdges(sg1, sg2);
764 // Connect joined group to resulting neighbors
765 for (auto& edge : producer_edges) {
766 auto from = edge->from;
767 auto producer_val = edge->producer_val_;
768 auto consumer_val = edge->consumer_val_;
769
770 edges_.push_back(std::make_unique<ExprGroupConnections>(
771 from, joined_groups, producer_val, consumer_val));
772
773 joined_groups->addProducerEdge(edges_.back().get());
774 from->addConsumerEdge(edges_.back().get());
775 }
776
777 auto consumer_edges = getMergedConsumerEdges(sg1, sg2);
778
779 for (auto& edge : consumer_edges) {
780 auto to = edge->to;
781 auto producer_val = edge->producer_val_;
782 auto consumer_val = edge->consumer_val_;
783
784 edges_.push_back(std::make_unique<ExprGroupConnections>(
785 joined_groups, to, producer_val, consumer_val));
786 joined_groups->addConsumerEdge(edges_.back().get());
787 edge->to->addProducerEdge(edges_.back().get());
788 }
789
790 // Merge the compute at domain of all edges going out from the newly joined
791 // group. The val's we're looking for are from our consumer edges, but we want
792 // to grab the producer val as that's the one we generate.
793 std::unordered_set<IterDomain*> ca_ids;
794 for (auto consumer_group_edge : joined_groups->consumerEdges()) {
795 auto producer_of_consumer_edge = consumer_group_edge->producer_val_;
796 if (producer_of_consumer_edge->isA<TensorView>()) {
797 auto tv = producer_of_consumer_edge->as<TensorView>();
798 for (const auto tv_i : c10::irange(tv->getComputeAtPosition())) {
799 ca_ids.emplace(GpuLower::current()->caMap()->getConcreteMappedID(
800 tv->axis(tv_i), IdMappingMode::LOOP));
801 }
802 }
803 }
804
805 // Merge the produce at domain of all edges coming into the newly joined
806 // group. The val's we're looking for are from our producer edges, but we want
807 // to grab the consumer val as that's the one we generate.
808 std::unordered_set<IterDomain*> pa_ids;
809 for (auto producer_group_edge : joined_groups->producerEdges()) {
810 auto consumer_of_producer_edge = producer_group_edge->consumer_val_;
811 if (consumer_of_producer_edge->isA<TensorView>()) {
812 auto tv = consumer_of_producer_edge->as<TensorView>();
813 for (const auto tv_i : c10::irange(tv->getMaxProducerPosition())) {
814 pa_ids.emplace(GpuLower::current()->caMap()->getConcreteMappedID(
815 tv->axis(tv_i), IdMappingMode::LOOP));
816 }
817 }
818 }
819
820 auto all_ca_pa_ids = ca_ids;
821 all_ca_pa_ids.insert(pa_ids.begin(), pa_ids.end());
822
823 auto ordered_ids = getLocalDomainOrdering(
824 joined_groups->exprs(), all_ca_pa_ids, concrete_id_dependencies);
825
826 for (auto id : ordered_ids) {
827 if (ca_ids.count(id)) {
828 joined_groups->payload()->ca_domains_.emplace_back(id);
829 }
830 if (pa_ids.count(id)) {
831 joined_groups->payload()->pa_domains_.emplace_back(id);
832 }
833 }
834
835 return joined_groups;
836}
837
838bool canReducePA(ExprGroup* group) {
839 if (group->payload()->pa_domains_.empty()) {
840 return false;
841 }
842
843 IterDomain* group_pa_last_id = group->payload()->pa_domains_.back();
844
845 // Look through producer edges to see if we can reduce our produce at domain
846 for (auto producer_edge : group->producerEdges()) {
847 auto producer_val = producer_edge->producer_val_;
848 auto consumer_val = producer_edge->consumer_val_;
849
850 // If producer isn't a tensor view it can't be mapped into a producer dim of
851 // this group
852 if (!(consumer_val->isA<TensorView>() && producer_val->isA<TensorView>())) {
853 continue;
854 }
855
856 // If the compute at domains of the producer group is empty, it can't map to
857 // the produce at domains of this group
858 auto producer_group = producer_edge->from;
859 if (producer_group->payload()->ca_domains_.empty()) {
860 continue;
861 }
862
863 auto producer_tv = producer_val->as<TensorView>();
864 auto consumer_tv = consumer_val->as<TensorView>();
865
866 // If this consumer_tv doesn't map to the last producer domain of this group
867 // it can't decide if it can be reduced
868 bool has_matching_pa = false;
869 for (const auto i : c10::irange(consumer_tv->getMaxProducerPosition())) {
870 if (GpuLower::current()->caMap()->areMapped(
871 consumer_tv->axis(i), group_pa_last_id, IdMappingMode::LOOP)) {
872 has_matching_pa = true;
873 break;
874 }
875 }
876
877 if (!has_matching_pa) {
878 continue;
879 }
880
881 // If any compute at positions of producers directly map to the last produce
882 // at position it can't be lowered.
883 for (int producer_pos_i =
884 static_cast<int>(producer_tv->getComputeAtPosition());
885 producer_pos_i > 0;
886 producer_pos_i--) {
887 if (GpuLower::current()->caMap()->areMapped(
888 producer_tv->axis(producer_pos_i - 1),
889 group_pa_last_id,
890 IdMappingMode::LOOP)) {
891 return false;
892 }
893 }
894 }
895
896 return true;
897}
898
899// Update in between attempts to segment. This is called once no more groups
900// can be merged together. Typically we will want to remove compute at groups
901// that have finished being grouped together. However if no groups have been
902// merged after we've done this, we may need to stop as we could have multiple
903// disjoint groups that won't be merged.
904bool ExprSegmentationSorter::interIterUpdate() {
905 // Go through groups and lower either pa or ca domain return if anything was
906 // lowered
907 bool lowered_a_domain = false;
908 for (auto& group : groups_) {
909 if (canReducePA(group.get())) {
910 group->payload()->pa_domains_.pop_back();
911 lowered_a_domain = true;
912 }
913 }
914
915 // If we couldn't lower compute at domain any further, and we haven't merged
916 // any new groups after fallback_mode_enabled_ has been turned on, make sure
917 // we've finished successfully
918 if (!lowered_a_domain && n_groups_ == groups_.size()) {
919 // Make sure none of the groups are still connected, as that would mean we
920 // should have been able to merge them.
921 bool successfully_finished = std::all_of(
922 groups_.begin(), groups_.end(), [](std::unique_ptr<ExprGroup>& sg) {
923 return sg->producerEdges().empty() && sg->consumerEdges().empty();
924 });
925 if (successfully_finished) {
926 return false;
927 }
928 // If we didn't finish and we tried the fallback, throw.
929 TORCH_INTERNAL_ASSERT(
930 !fallback_mode_enabled_,
931 "Couldn't successfully sort out the fusion expressions. ",
932 "There are remaining connections of the hierarchical segmentation which should have been ",
933 "flattened to a single ordered group, or disjoint ordered groups.");
934 // We didn't finish, but we haven't tried the fallback, try again with that.
935 fallback_mode_enabled_ = true;
936 }
937
938 n_groups_ = groups_.size();
939 // Not done, continue.
940 return true;
941}
942
943void ExprSegmentationSorter::mergeNodes() {
944 std::unordered_set<ExprGroup*> clean_up_groups;
945 std::unordered_set<ExprGroupConnections*> clean_up_edges;
946
947 while (!to_merge_.empty()) {
948 ExprGroup *group1 = nullptr, *group2 = nullptr;
949 std::tie(group1, group2) = to_merge_.back();
950 to_merge_.pop_back();
951 TORCH_INTERNAL_ASSERT(
952 group2 == group1->payload()->merge_with,
953 "Expression Sorter: inconsistent to_merge packing");
954 clean_up_groups.emplace(group1);
955 clean_up_groups.emplace(group2);
956 makeMergedNode(group1, group2);
957 }
958
959 for (auto group : clean_up_groups) {
960 auto disconnected_edges = disconnectGroup(group);
961 clean_up_edges.insert(disconnected_edges.begin(), disconnected_edges.end());
962 }
963
964 edges_.remove_if([&](std::unique_ptr<ExprGroupConnections>& edge) {
965 return clean_up_edges.find(edge.get()) != clean_up_edges.end();
966 });
967
968 groups_.remove_if([&](std::unique_ptr<ExprGroup>& group) {
969 return clean_up_groups.find(group.get()) != clean_up_groups.end();
970 });
971}
972
973// Initialize concrete_id_dependencies and concrete_id_to_all_ids
974void ExprSegmentationSorter::initializeForLoopDependencies() {
975 TORCH_INTERNAL_ASSERT(
976 concrete_id_dependencies.empty(),
977 "For loop dependencies have already been initialized.");
978
979 for (auto tv : ir_utils::allTvs(fusion_)) {
980 std::unordered_set<IterDomain*> dependencies;
981 for (size_t tv_id_i =
982 std::max(tv->getMaxProducerPosition(), tv->getComputeAtPosition());
983 tv_id_i > 0;
984 tv_id_i--) {
985 auto tv_id = tv->axis((int)(tv_id_i - 1));
986 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
987 tv_id, IdMappingMode::LOOP);
988
989 if (concrete_id_dependencies.find(concrete_id) ==
990 concrete_id_dependencies.end()) {
991 concrete_id_dependencies[concrete_id] = dependencies;
992 } else {
993 concrete_id_dependencies[concrete_id].insert(
994 dependencies.begin(), dependencies.end());
995 }
996
997 // Loops after tv_id are dependent on tv_id
998 dependencies.emplace(GpuLower::current()->caMap()->getConcreteMappedID(
999 tv_id, IdMappingMode::LOOP));
1000 }
1001 }
1002
1003 // Fill out dependencies as IDs will have local dependency information, but
1004 // it's still not guaranteed to be global.
1005
1006 // If loop structure is something like:
1007 // T0 [I0]
1008 // T1 [I0, I1]
1009 // T2 [I1, I2]
1010 //
1011 // I1 will be marked as a dependency of I0
1012 // I2 will be marked as a dependency of I1
1013 //
1014 // However, I2 will not be marked as a dep of I0, so we need to fill out the
1015 // dependency analysis. This is done by iterating through IterDomains filling
1016 // out all the dependencies of dependencies recursively.
1017
1018 std::deque<IterDomain*> to_visit;
1019 std::unordered_set<IterDomain*> visited;
1020
1021 std::transform(
1022 concrete_id_dependencies.begin(),
1023 concrete_id_dependencies.end(),
1024 std::back_inserter(to_visit),
1025 [](const auto& concrete_dep_entry) { return concrete_dep_entry.first; });
1026
1027 size_t inf_loop_counter = to_visit.size();
1028 bool failed = false;
1029
1030 while (!to_visit.empty()) {
1031 auto id = to_visit.front();
1032 to_visit.pop_front();
1033
1034 if (inf_loop_counter-- == 0) {
1035 failed = true;
1036 break;
1037 }
1038
1039 auto& dependencies = concrete_id_dependencies.at(id);
1040 bool ready = dependencies.empty() ||
1041 std::all_of(dependencies.begin(),
1042 dependencies.end(),
1043 [&visited](IterDomain* id) { return visited.count(id); });
1044
1045 if (!ready) {
1046 to_visit.push_back(id);
1047 continue;
1048 }
1049
1050 inf_loop_counter = to_visit.size();
1051
1052 for (auto dependency : dependencies) {
1053 auto dep_of_dep = concrete_id_dependencies.at(dependency);
1054 dependencies.insert(dep_of_dep.begin(), dep_of_dep.end());
1055 }
1056 visited.emplace(id);
1057 }
1058 if (failed) {
1059 std::cerr
1060 << "ERROR: Iteration domain sorting has failed, infinite loop detected."
1061 << std::endl;
1062 std::cerr << "Failed to sort out: " << std::endl;
1063 for (auto entry : to_visit) {
1064 std::cerr << entry->toString();
1065 if (entry != to_visit.back()) {
1066 std::cerr << ", ";
1067 }
1068 }
1069
1070 std::cerr << "Dependencies: " << std::endl;
1071 for (const auto& dep_entry : concrete_id_dependencies) {
1072 std::cerr << " Deps of " << dep_entry.first->toString() << std::endl
1073 << " ";
1074
1075 for (auto dep : dep_entry.second) {
1076 std::cerr << dep->toString() << ", ";
1077 }
1078 std::cerr << std::endl;
1079 }
1080
1081 TORCH_INTERNAL_ASSERT(false);
1082 }
1083}
1084
1085// Checks if the for loop associated with the concrete ID is ready to be
1086// resolved in sorting. This could be done more efficiently with some
1087// additional tracking, however we recreate ca_domain_ when we merge groups,
1088// so it's hard to track what is no longer needed.
1089bool ExprSegmentationSorter::loopReady(IterDomain* concrete_id) {
1090 const auto& dependencies = concrete_id_dependencies[concrete_id];
1091 for (auto& group : groups_) {
1092 // Only need to check compute at domain here, because if there's an entry in
1093 // produce at, that has no matching entry in compute at, then that ID can be
1094 // removed as in canReducePA
1095 for (auto ca_domain : group->payload()->ca_domains_) {
1096 if (dependencies.count(ca_domain)) {
1097 return false;
1098 }
1099 }
1100 }
1101 return true;
1102}
1103
1104// Two expression groups can be merged together if there's a value produced by
1105// producer group, consumed by consumer group, where the compute at position
1106// maps to the inner most compute at domain of the producer group and maps to
1107// the inner most produce at domain of the consumer. If this value doesn't exist
1108// we can't be certain these domains share the "next" inner most loop.
1109//
1110// We're looking for this because we're starting at the inner most loops of all
1111// expressions, and looking for neighboring expressions that share inner loops.
1112// Once we've found all the inner most loops that expressions share, we merge
1113// them together, then look at the next inner most loop of the group and figure
1114// out which other groups share this next inner most loop.
1115bool ExprSegmentationSorter::supportedMerge(ExprGroup* sg1, ExprGroup* sg2) {
1116 auto producer_group = getProducer(sg1, sg2);
1117 auto consumer_group = sg1 == producer_group ? sg2 : sg1;
1118
1119 if (producer_group->payload()->ca_domains_.size() <
1120 producer_group->payload()->pa_domains_.size()) {
1121 return false;
1122 }
1123
1124 if (consumer_group->payload()->pa_domains_.size() <
1125 consumer_group->payload()->ca_domains_.size()) {
1126 return false;
1127 }
1128
1129 const auto& producer_ca_domain = producer_group->payload()->ca_domains_;
1130 const auto& consumer_pa_domain = consumer_group->payload()->pa_domains_;
1131
1132 if (producer_ca_domain.empty() && consumer_pa_domain.empty()) {
1133 return true;
1134 }
1135
1136 if (producer_ca_domain.empty() || consumer_pa_domain.empty()) {
1137 return false;
1138 }
1139
1140 // If inner loop dependencies have not been resolved, cannot merge.
1141 if (!loopReady(producer_ca_domain.back()) ||
1142 !loopReady(consumer_pa_domain.back())) {
1143 return false;
1144 }
1145
1146 for (auto edge : producer_group->consumerEdges()) {
1147 if (edge->to != consumer_group) {
1148 continue;
1149 }
1150 auto producer_val = edge->producer_val_;
1151 auto consumer_val = edge->consumer_val_;
1152
1153 if (!producer_val->isA<TensorView>()) {
1154 continue;
1155 }
1156
1157 TORCH_INTERNAL_ASSERT(
1158 consumer_val->isA<TensorView>(),
1159 "Mismatched tensorview to non-tensorview in expression sorting. ",
1160 producer_val,
1161 " is consumed by ",
1162 consumer_val);
1163
1164 auto producer_tv = producer_val->as<TensorView>();
1165
1166 auto compute_at_pos = producer_tv->getComputeAtPosition();
1167 auto compute_at_dim = compute_at_pos > 0
1168 ? producer_tv->axis((int)producer_tv->getComputeAtPosition() - 1)
1169 : nullptr;
1170
1171 if (compute_at_dim == nullptr) {
1172 continue;
1173 }
1174
1175 if (!GpuLower::current()->caMap()->areMapped(
1176 compute_at_dim, producer_ca_domain.back(), IdMappingMode::LOOP)) {
1177 continue;
1178 }
1179
1180 if (GpuLower::current()->caMap()->areMapped(
1181 compute_at_dim, consumer_pa_domain.back(), IdMappingMode::LOOP)) {
1182 return true;
1183 }
1184 }
1185 return false;
1186}
1187
1188bool ExprSegmentationSorter::testStillDag(ExprGroup* sg1, ExprGroup* sg2) {
1189 std::deque<ExprGroup*> to_visit;
1190 std::unordered_set<ExprGroup*> visited;
1191 // Add consumers of sg1 if not sg2
1192 for (auto sg1_consumer_edge : sg1->consumerEdges()) {
1193 if (sg1_consumer_edge->to != sg2) {
1194 to_visit.emplace_back(sg1_consumer_edge->to);
1195 }
1196 }
1197
1198 // Add consumers of sg2 if not sg1
1199 for (auto sg2_consumer_edge : sg2->consumerEdges()) {
1200 if (sg2_consumer_edge->to != sg1) {
1201 to_visit.emplace_back(sg2_consumer_edge->to);
1202 }
1203 }
1204
1205 while (to_visit.size() > 0) {
1206 auto group = to_visit.front();
1207 // Arrived back at one of the original groups, merging these two groups
1208 // would generate a cycle
1209 if (group == sg1 || group == sg2) {
1210 return false;
1211 }
1212 to_visit.pop_front();
1213 if (visited.find(group) != visited.end()) {
1214 continue;
1215 }
1216 visited.emplace(group);
1217 for (auto consumer_edge : group->consumerEdges()) {
1218 to_visit.emplace_back(consumer_edge->to);
1219 }
1220 }
1221
1222 // No cycles found, we're good.
1223 return true;
1224}
1225
1226void ExprSegmentationSorter::sort() {
1227 // Need this for initialization of the DAG that is processed
1228 std::unordered_map<Expr*, ExprGroup*> expr2group;
1229
1230 auto all_exprs = fusion_->exprs();
1231
1232 // Figure out all the values used as inputs to the expressions we're sorting
1233 // (to find terminating expressions). There could be branches of expressions
1234 // not used to produce outputs, so can't simply check val->uses() to figure
1235 // out if it's actually used in the expressions we're sorting.
1236 std::unordered_set<Val*> used_vals;
1237 for (auto expr : all_exprs) {
1238 used_vals.insert(expr->inputs().begin(), expr->inputs().end());
1239 }
1240
1241 // Initialize DAG, convert each expr to a segment group
1242 for (auto expr : all_exprs) {
1243 bool is_terminating_expr = std::none_of(
1244 expr->outputs().begin(),
1245 expr->outputs().end(),
1246 [&used_vals](Val* output) { return used_vals.count(output) > 0; });
1247 auto group = makeEmptyGroup(expr, is_terminating_expr);
1248 expr2group.insert(std::make_pair(expr, group));
1249 }
1250
1251 // Create edges between the Exprs. Mark inputs and outputs of the fusion.
1252 for (auto expr : fusion_->exprs()) {
1253 auto expr_group = expr2group.at(expr);
1254 auto out = expr->outputs()[0];
1255 for (auto inp : expr->inputs()) {
1256 if (inp->isFusionInput()) {
1257 continue;
1258 }
1259
1260 // Could be something like a constant scalar, definition is nullptr, but
1261 // isn't an "input" to the fusion. At least not one provided by an
1262 // external source.
1263 if (inp->definition() == nullptr) {
1264 continue;
1265 }
1266
1267 auto inp_def_group = expr2group.at(inp->definition());
1268 edges_.push_back(std::make_unique<ExprGroupConnections>(
1269 inp_def_group, expr_group, inp, out));
1270 expr_group->addProducerEdge(edges_.back().get());
1271 inp_def_group->addConsumerEdge(edges_.back().get());
1272 }
1273 }
1274
1275 // Initialize loop dependency maps
1276 initializeForLoopDependencies();
1277
1278 bool inter_iter_update = true;
1279 while (inter_iter_update) {
1280 // If we didn't do any update, stop traversal, we're done.
1281 if (!fallback_mode_enabled_) {
1282 // Merge expressions in sorted order
1283 bool merged_nodes = true;
1284 while (merged_nodes) {
1285 // Reset stateful traversal details in ExprGroups
1286 resetTraversal();
1287 resetLevels();
1288
1289 for (auto& group : groups_) {
1290 if (group->payload()->merged) {
1291 continue;
1292 }
1293 auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
1294 if (candidates.empty()) {
1295 continue;
1296 }
1297
1298 auto candidate_it = candidates.begin();
1299 while (candidate_it != candidates.end() &&
1300 !supportedMerge(group.get(), *candidate_it)) {
1301 candidate_it++;
1302 }
1303 if (candidate_it == candidates.end()) {
1304 continue;
1305 }
1306
1307 to_merge_.emplace_back(std::make_pair(group.get(), *candidate_it));
1308
1309 group->payload()->merged = true;
1310 group->payload()->merge_with = *candidate_it;
1311
1312 (*candidate_it)->payload()->merged = true;
1313 (*candidate_it)->payload()->merge_with = group.get();
1314 }
1315
1316 if (to_merge_.empty()) {
1317 merged_nodes = false;
1318 }
1319
1320 mergeNodes();
1321
1322 // Move compute at axes left
1323 inter_iter_update = interIterUpdate();
1324 }
1325 } else {
1326 // fallback_mode_enabled = true
1327 // Reset stateful traversal details in ExprGroups as we'll exclude merge
1328 // options that were already ruled out and therefore need traversal and
1329 // levels reset.
1330 resetTraversal();
1331 resetLevels();
1332
1333 for (auto& group : groups_) {
1334 if (group->payload()->merged) {
1335 continue;
1336 }
1337 // Get merge candidates that weren't proven safe to merge with default
1338 // algorithm.
1339 auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
1340 if (candidates.empty()) {
1341 continue;
1342 }
1343
1344 auto candidate_it = candidates.begin();
1345
1346 while (candidate_it != candidates.end()) {
1347 while (candidate_it != candidates.end() &&
1348 !supportedMerge(group.get(), *candidate_it)) {
1349 candidate_it++;
1350 }
1351
1352 if (candidate_it == candidates.end()) {
1353 break;
1354 }
1355
1356 if (testStillDag(group.get(), *candidate_it)) {
1357 // Mark in same style as default algorithm for convenience even
1358 // though we will only merge once with the fallback
1359 to_merge_.emplace_back(std::make_pair(group.get(), *candidate_it));
1360
1361 group->payload()->merged = true;
1362 group->payload()->merge_with = *candidate_it;
1363
1364 (*candidate_it)->payload()->merged = true;
1365 (*candidate_it)->payload()->merge_with = group.get();
1366 break;
1367 }
1368
1369 candidate_it++;
1370 }
1371
1372 if (to_merge_.size() > 0) {
1373 break;
1374 }
1375 }
1376
1377 // If we can merge something, merge it, disable fallback, and bail
1378 if (to_merge_.size() > 0) {
1379 mergeNodes();
1380 }
1381
1382 // Move compute at axes left
1383 // If fallback didn't work, interIterUpdate will catch that we failed.
1384 inter_iter_update = interIterUpdate();
1385 fallback_mode_enabled_ = false;
1386 }
1387 }
1388}
1389
1390std::vector<Expr*> ExprSegmentationSorter::getExprs() const {
1391 std::vector<Expr*> exprs;
1392 for (auto& group : groups_) {
1393 exprs.insert(exprs.end(), group->exprs().begin(), group->exprs().end());
1394 }
1395 return exprs;
1396}
1397
1398} // namespace
1399
1400std::vector<Expr*> reorderExprsForComputeAt() {
1401 auto fusion = FusionGuard::getCurFusion();
1402 if (fusion->exprs().empty()) {
1403 return {};
1404 }
1405 TORCH_INTERNAL_ASSERT(fusion != nullptr);
1406 ExprSegmentationSorter sorter(fusion);
1407 sorter.sort();
1408 auto sorted_exprs = sorter.getExprs();
1409 TORCH_INTERNAL_ASSERT(
1410 sorted_exprs.size() > 0,
1411 "Error during expression sorting, no expressions produced.");
1412 return sorted_exprs;
1413}
1414
1415} // namespace cuda
1416} // namespace fuser
1417} // namespace jit
1418} // namespace torch
1419