1#include <arith.h>
2#include <fusion.h>
3#include <fusion_segmenter.h>
4#include <instrumentation.h>
5#include <ir_all_nodes.h>
6#include <ir_cloner.h>
7#include <ir_graphviz.h>
8#include <ir_iostream.h>
9#include <ir_utils.h>
10#include <scheduler/debug_utils.h>
11
12#include <sstream>
13
14namespace torch {
15namespace jit {
16namespace fuser {
17namespace cuda {
18
namespace {

// Ordered, duplicate-free collection of segmented groups; used throughout
// the dependency analysis below for deterministic iteration.
using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>;

} // namespace
24
25std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
26 std::vector<NeighborGroup> neighbors;
27 for (auto inp : producer_edges) {
28 if (inp->val->isFusionOutput()) {
29 // Don't fuse across output nodes, would need to find another path.
30 continue;
31 }
32 neighbors.emplace_back(inp->from, inp);
33 }
34 for (auto out : consumer_edges) {
35 if (out->val->isFusionOutput()) {
36 // Don't fuse across output nodes, would need to find another path.
37 continue;
38 }
39 neighbors.emplace_back(out->to, out);
40 }
41 return neighbors;
42}
43
44std::vector<SegmentedGroup*> SegmentedGroup::getNeighbors() {
45 std::vector<SegmentedGroup*> neighbors;
46 auto neighbors_pair = getNeighborGroups();
47
48 std::transform(
49 neighbors_pair.begin(),
50 neighbors_pair.end(),
51 std::back_inserter(neighbors),
52 [](auto& neighbor_group) { return neighbor_group.group; });
53 return neighbors;
54}
55
56std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
57 getMergeCandidates() {
58 // Don't look for candidates if already merged
59 if (merged_) {
60 return {};
61 }
62
63 std::vector<NeighborGroup> neighbors = getNeighborGroups();
64
65 // Can this node be merged with another? Check if neighbors are merged, if
66 // so and merged neighbor is within 1 level or node merged with neighbor is
67 // within 1 level, can't merge this node with anything else.
68 bool can_merge_this = true;
69 for (auto& neighbor : neighbors) {
70 if (!neighbor.group->merged_) {
71 continue;
72 }
73 if (std::abs(neighbor.group->level_ - level_) <= 1) {
74 can_merge_this = false;
75 }
76 if (std::abs(neighbor.group->merge_with_->level_ - level_) <= 1) {
77 can_merge_this = false;
78 }
79 }
80 if (!can_merge_this) {
81 return {};
82 }
83
84 std::vector<bool> can_merge(neighbors.size(), true);
85
86 // Find neighbors with a level that is only 1 differant than this groups level
87 for (const auto i : c10::irange(neighbors.size())) {
88 if (std::abs(neighbors[i].group->level_ - level_) > 1) {
89 can_merge[i] = false;
90 }
91 }
92
93 // Check neighbor of neighbors we're considering, if any of them are merged
94 // with another node, make sure the resulting edge wouldn't have a level
95 // difference of 1
96 for (const auto i : c10::irange(neighbors.size())) {
97 if (!can_merge[i]) {
98 continue;
99 }
100
101 for (auto neighbor_neighbor : neighbors[i].group->getNeighbors()) {
102 // Don't check self
103 if (neighbor_neighbor == neighbors[i].group) {
104 continue;
105 }
106 if (neighbor_neighbor->merged_) {
107 // check neighbor_neighbor level
108 if (std::abs(neighbor_neighbor->level_ - level_) <= 1) {
109 can_merge[i] = false;
110 }
111 if (std::abs(neighbor_neighbor->level_ - neighbors[i].group->level_) <=
112 1) {
113 can_merge[i] = false;
114 }
115
116 // check neighbor_neighber->merged_->level_
117 if (std::abs(neighbor_neighbor->merge_with_->level_ - level_) <= 1) {
118 can_merge[i] = false;
119 }
120 if (std::abs(
121 neighbor_neighbor->merge_with_->level_ -
122 neighbors[i].group->level_) <= 1) {
123 can_merge[i] = false;
124 }
125 }
126 }
127 }
128
129 std::vector<NeighborGroup> merge_candidates;
130 for (const auto i : c10::irange(neighbors.size())) {
131 if (can_merge[i]) {
132 merge_candidates.push_back(neighbors[i]);
133 }
134 }
135 return merge_candidates;
136}
137
138void SegmentedGroup::clearTraversalInfo() {
139 level_ = -1;
140 visited_ = false;
141 merge_with_ = nullptr;
142 merge_through_ = nullptr;
143 merged_ = false;
144}
145
146std::vector<Val*> SegmentedGroup::edgesToVals(
147 const std::vector<SegmentedEdge*>& se_v) {
148 std::vector<Val*> ret_v;
149 ret_v.reserve(se_v.size());
150
151 std::transform(
152 se_v.cbegin(),
153 se_v.cend(),
154 std::back_inserter(ret_v),
155 [](SegmentedEdge* se) { return se->val; });
156 return ret_v;
157}
158
159template <typename PREDICATE>
160void insertUniquePredicated(
161 std::vector<Val*>& v,
162 const std::vector<SegmentedEdge*>& e,
163 PREDICATE pred) {
164 VectorOfUniqueEntries<Val*> to_add;
165 for (auto edge : e) {
166 to_add.pushBack(edge->val);
167 }
168
169 std::copy_if(
170 to_add.vector().begin(),
171 to_add.vector().end(),
172 std::back_inserter(v),
173 [pred](Val* val) { return pred(val); });
174}
175
// Finalize this group's interface: derive its input/output value lists from
// the producer/consumer edges, promote free scalar inputs, and account for
// output->input aliasing. Called once segmentation has settled.
void SegmentedGroup::finalize() {
  // Move all the edges to group input/output
  // Inputs: any producer-edge value that is not already a complete-fusion
  // input becomes a group input.
  insertUniquePredicated(
      input_vals, producer_edges, [](Val* v) { return !v->isFusionInput(); });

  std::unordered_set<Val*> input_set(input_vals.begin(), input_vals.end());

  // Free-standing integer scalars consumed by this group's expressions
  // (no definition, not constant, not a fusion input) must be passed in as
  // group inputs as well — e.g. runtime extents.
  for (auto expr : exprs_) {
    for (auto i : expr->inputs()) {
      if (i->isAnInt() && i->definition() == nullptr && !i->isConstScalar() &&
          !i->isFusionInput() && !input_set.count(i)) {
        input_set.insert(i);
        input_vals.push_back(i);
      }
    }
  }

  // Outputs: any consumer-edge value that is not already a complete-fusion
  // output becomes a group output.
  insertUniquePredicated(
      output_vals, consumer_edges, [](Val* v) { return !v->isFusionOutput(); });

  // alias aware segmentation. we add inputs that are aliased by output
  // generated in this SegmentedGroup
  for (auto output : output_vals) {
    if (auto aliased_input = segmented_fusion_->findAlias(output)) {
      // aliasing currently only supported as output to input
      TORCH_INTERNAL_ASSERT(
          aliased_input->isFusionInput(),
          "aliased input is not found in the complete fusion");
      if (!input_set.count(aliased_input)) {
        input_set.insert(aliased_input);
        input_vals.push_back(aliased_input);
      }
    }
  }
}
213
214std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group) {
215 os << "g{";
216 auto expr_to_print = group->exprs();
217 std::sort(
218 expr_to_print.begin(),
219 expr_to_print.end(),
220 [](auto expr_a, auto expr_b) -> bool {
221 return expr_a->name() < expr_b->name();
222 });
223 for (const auto i : c10::irange(expr_to_print.size())) {
224 os << expr_to_print[i]->name();
225 if (i + 1 != expr_to_print.size())
226 os << ", ";
227 }
228 os << "}\n";
229 return os;
230}
231
// Debug helper: dump this group to stdout via operator<<.
void SegmentedGroup::print() const {
  std::cout << this << "\n";
}
235
236std::string toString(const SegmentedGroup* group) {
237 std::stringstream ss;
238 ss << group;
239 return ss.str();
240}
241
// Textual form of an edge: "e{ <from group> -> <to group>(<val>) }",
// printing the carried value through IrPrinter.
std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge) {
  os << "e{ " << edge->from << " -> " << edge->to << "(";
  IrPrinter irp(os);
  irp.handle(edge->val);
  os << ") }\n";
  return os;
}
249
// Debug helper: dump this edge to stdout via operator<<.
void SegmentedEdge::print() const {
  std::cout << this << "\n";
}
253
254std::string toString(const SegmentedEdge* edge) {
255 std::stringstream ss;
256 ss << edge;
257 return ss.str();
258}
259
// Wrap an unsegmented fusion as a SegmentedFusion with exactly one group,
// so the complete fusion can reuse the segmented-fusion runtime path.
// The single group receives the given heuristic and id 0.
std::unique_ptr<SegmentedFusion> SegmentedFusion::fromCompleteFusion(
    std::unique_ptr<Fusion> fusion_ptr,
    ScheduleHeuristic heuristic) {
  // Keep a raw handle before ownership moves into the SegmentedFusion.
  auto fusion = fusion_ptr.get();

  auto segmented_fusion_ptr =
      std::make_unique<SegmentedFusion>(std::move(fusion_ptr));

  // Make a group for the single fusion
  auto single_group = segmented_fusion_ptr->newGroup();

  // Add input and output vals
  single_group->input_vals = fusion->inputs();
  single_group->output_vals = fusion->outputs();

  // Get ordered expression list
  single_group->resetExprList();

  // Assign heuristics and id for the complete fusion
  // to share the runtime path of segmented fusion.
  single_group->setHeuristic(heuristic);
  single_group->setID(0);

  return segmented_fusion_ptr;
}
285
// Takes ownership of the complete fusion, assigns a unique name for debug
// artifacts (e.g. graphviz dumps), and records which fp16/bf16 intermediate
// tensors may need cast insertion at finalize() time.
SegmentedFusion::SegmentedFusion(std::unique_ptr<Fusion> fusion)
    : impl_(this), complete_fusion_(std::move(fusion)) {
  segmented_fusion_name_ = segmentedFusionName();
  annotateFP16IntermediateTensors();
}
291
292SegmentedGroup* SegmentedFusion::Impl::makeGroup() {
293 groups_.emplace_back(std::make_unique<SegmentedGroup>(owning_fusion_));
294 return groups_.back().get();
295}
296
297SegmentedGroup* SegmentedFusion::Impl::makeGroup(Expr* expr) {
298 groups_.emplace_back(std::make_unique<SegmentedGroup>(expr, owning_fusion_));
299 return groups_.back().get();
300}
301
302SegmentedEdge* SegmentedFusion::Impl::makeEdge(
303 SegmentedGroup* from,
304 SegmentedGroup* to,
305 Val* val) {
306 edges_.emplace_back(std::make_unique<SegmentedEdge>(from, to, val));
307 return edges_.back().get();
308}
309
310void SegmentedFusion::Impl::cleanUnused() {
311 std::unordered_set<SegmentedGroup*> g_used(
312 owning_fusion_->groups().begin(), owning_fusion_->groups().end());
313 std::unordered_set<SegmentedEdge*> e_used(
314 owning_fusion_->edges().begin(), owning_fusion_->edges().end());
315
316 groups_.erase(
317 std::remove_if(
318 groups_.begin(),
319 groups_.end(),
320 [&g_used](auto& g) { return g_used.count(g.get()) == 0; }),
321 groups_.end());
322
323 edges_.erase(
324 std::remove_if(
325 edges_.begin(),
326 edges_.end(),
327 [&e_used](auto& e) { return e_used.count(e.get()) == 0; }),
328 edges_.end());
329}
330
331SegmentedGroup* SegmentedFusion::newGroup() {
332 SegmentedGroup* g = impl_.makeGroup();
333 groups_.push_back(g);
334 return g;
335}
336
337SegmentedGroup* SegmentedFusion::newGroup(Expr* expr) {
338 SegmentedGroup* g = impl_.makeGroup(expr);
339 groups_.push_back(g);
340 return g;
341}
342
343SegmentedEdge* SegmentedFusion::newEdge(
344 SegmentedGroup* from,
345 SegmentedGroup* to,
346 Val* val) {
347 SegmentedEdge* e = impl_.makeEdge(from, to, val);
348 edges_.push_back(e);
349 return e;
350}
351
352void SegmentedFusion::draw() {
353 size_t group_index = 0;
354 std::unordered_map<const Expr*, size_t> expr_color_map;
355
356 for (auto group : groups()) {
357 for (auto expr : group->exprs()) {
358 if (ir_utils::isTvOp(expr)) {
359 expr_color_map[expr] = group_index;
360 }
361 }
362 group_index++;
363 }
364
365 std::stringstream sstream;
366 sstream << "segmented_fusion" << segmented_fusion_name_ << ".dot";
367 auto filename = sstream.str();
368
369 IrGraphGenerator::print(
370 completeFusion(),
371 filename.c_str(),
372 IrGraphGenerator::DetailLevel::ComputeOnly,
373 &expr_color_map);
374}
375
376namespace {
377
378std::vector<Val*> uniqueValConcat(
379 const std::vector<std::vector<Val*>>& val_vecs) {
380 std::vector<Val*> unique_vals;
381 std::unordered_set<Val*> added;
382 for (const auto& vec : val_vecs) {
383 for (auto val : vec) {
384 if (added.find(val) == added.end()) {
385 unique_vals.push_back(val);
386 added.emplace(val);
387 }
388 }
389 }
390 return unique_vals;
391}
392
393// Concat's producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
394std::vector<SegmentedEdge*> getMergedProducerEdges(
395 const SegmentedGroup* sg1,
396 const SegmentedGroup* sg2) {
397 TORCH_INTERNAL_ASSERT(
398 sg1 != nullptr && sg2 != nullptr,
399 "This function doesn't handle trivial.");
400
401 auto producer_edges = sg1->producer_edges;
402
403 producer_edges.insert(
404 producer_edges.end(),
405 sg2->producer_edges.begin(),
406 sg2->producer_edges.end());
407
408 // Register producers into sg2
409 std::unordered_set<Val*> sg2_vals;
410 for (auto se : sg2->producer_edges) {
411 sg2_vals.emplace(se->val);
412 }
413
414 producer_edges.erase(
415 std::remove_if(
416 producer_edges.begin(),
417 producer_edges.end(),
418 [&sg1, &sg2, &sg2_vals](SegmentedEdge* se) {
419 // remove edges in between the groups and common uses
420 return (se->to == sg1 && se->from == sg2) ||
421 (se->to == sg2 && se->from == sg1) ||
422 (se->to == sg1 && sg2_vals.count(se->val));
423 }),
424 producer_edges.end());
425
426 // Remove Duplicate Edges
427
428 return producer_edges;
429}
430
431// Concat's consumer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
432std::vector<SegmentedEdge*> getMergedConsumerEdges(
433 const SegmentedGroup* sg1,
434 const SegmentedGroup* sg2) {
435 TORCH_INTERNAL_ASSERT(
436 sg1 != nullptr && sg2 != nullptr,
437 "This function doesn't handle trivial.");
438
439 auto consumer_edges = sg1->consumer_edges;
440 consumer_edges.insert(
441 consumer_edges.end(),
442 sg2->consumer_edges.begin(),
443 sg2->consumer_edges.end());
444
445 consumer_edges.erase(
446 std::remove_if(
447 consumer_edges.begin(),
448 consumer_edges.end(),
449 [&sg1, &sg2](SegmentedEdge* se) {
450 return (se->to == sg1 && se->from == sg2) ||
451 (se->to == sg2 && se->from == sg1);
452 }),
453 consumer_edges.end());
454
455 return consumer_edges;
456}
457
458// Returns a determinstic, unique set of inputs of the segment group, sg1, or
459// the combined group sg1 + sg2
460std::vector<Val*> getAllInputs(
461 const SegmentedGroup* sg1,
462 const SegmentedGroup* sg2 = nullptr) {
463 std::vector<SegmentedEdge*> merged_producer_edges;
464
465 if (sg1 != nullptr && sg2 != nullptr) {
466 merged_producer_edges = getMergedProducerEdges(sg1, sg2);
467 } else if (sg1 != nullptr) {
468 merged_producer_edges = sg1->producer_edges;
469 } else if (sg2 != nullptr) {
470 merged_producer_edges = sg2->producer_edges;
471 }
472
473 std::vector<Val*> producer_edge_vals;
474
475 std::transform(
476 merged_producer_edges.begin(),
477 merged_producer_edges.end(),
478 std::back_inserter(producer_edge_vals),
479 [](SegmentedEdge* se) { return se->val; });
480
481 return uniqueValConcat(
482 {sg1 == nullptr ? std::vector<Val*>() : sg1->input_vals,
483 sg2 == nullptr ? std::vector<Val*>() : sg2->input_vals,
484 producer_edge_vals});
485}
486
487// Returns a determinstic, unique set of outputs of the segment group, sg1, or
488// the combined group sg1 + sg2
489std::vector<Val*> getAllOutputs(
490 const SegmentedGroup* sg1,
491 const SegmentedGroup* sg2 = nullptr) {
492 std::vector<SegmentedEdge*> merged_consumer_edges;
493
494 if (sg1 != nullptr && sg2 != nullptr) {
495 merged_consumer_edges = getMergedConsumerEdges(sg1, sg2);
496 } else if (sg1 != nullptr) {
497 merged_consumer_edges = sg1->consumer_edges;
498 } else if (sg2 != nullptr) {
499 merged_consumer_edges = sg2->consumer_edges;
500 }
501
502 std::vector<Val*> consumer_edge_vals;
503
504 std::transform(
505 merged_consumer_edges.begin(),
506 merged_consumer_edges.end(),
507 std::back_inserter(consumer_edge_vals),
508 [](SegmentedEdge* se) { return se->val; });
509
510 auto output_vals = uniqueValConcat(
511 {sg1 == nullptr ? std::vector<Val*>() : sg1->output_vals,
512 sg2 == nullptr ? std::vector<Val*>() : sg2->output_vals,
513 consumer_edge_vals});
514
515 return output_vals;
516}
517
// Set version of getting merged inputs or outputs if segmented_groups were
// merged. Output order respects the order in segmented_groups for a
// deterministic merge trace. Gets inputs if get_inputs is true, otherwise
// gets outputs.
// TODO: merge with the binary counterparts
std::vector<Val*> allInputsIfTrueElseOutputs(
    const std::vector<SegmentedGroup*>& segmented_groups,
    bool get_inputs = true) {
  // Helper to distinguish if we are getting inputs or outputs
  using EdgeVec = std::vector<SegmentedEdge*>;
  using ValVec = std::vector<Val*>;

  // Get producer edges to get inputs, consumer edges to get outputs
  auto edges_to_process_from_or_to_group =
      [get_inputs](SegmentedGroup* group) -> EdgeVec& {
    return get_inputs ? group->producer_edges : group->consumer_edges;
  };

  // Get the fusion-level input_vals (resp. output_vals) of a group
  auto global_vals_from_or_to_group =
      [get_inputs](SegmentedGroup* group) -> ValVec& {
    return get_inputs ? group->input_vals : group->output_vals;
  };

  // Get the group that is connected to current group by given edge
  auto opposite_end_of_edge = [get_inputs](SegmentedEdge* edge) {
    return get_inputs ? edge->from : edge->to;
  };

  // Keep track of value and order to ensure deterministic result
  std::vector<Val*> merged_vals;
  std::unordered_set<Val*> merged_vals_set;

  // Put groups in a set for quick look up
  std::unordered_set<SegmentedGroup*> segmented_groups_set(
      segmented_groups.begin(), segmented_groups.end());

  // Collect vals associated with edges
  for (auto group : segmented_groups) {
    for (auto edge : edges_to_process_from_or_to_group(group)) {
      if (
          // Need to de-duplicate values so we don't get multiple of any input
          !merged_vals_set.count(edge->val) &&
          // One side of this edge will be `group`, if the other end is
          // also in segmented_groups, then this is an internal edge
          // that we don't want.
          !segmented_groups_set.count(opposite_end_of_edge(edge))) {
        merged_vals.push_back(edge->val);
        merged_vals_set.insert(edge->val);
      }
    }
  }

  // Collect original fusion's inputs/outputs and append at the end
  for (auto group : segmented_groups) {
    for (auto global_val : global_vals_from_or_to_group(group)) {
      // de-duplicate
      if (!merged_vals_set.count(global_val)) {
        merged_vals.push_back(global_val);
        merged_vals_set.insert(global_val);
      }
    }
  }

  return merged_vals;
}
585
// A sorting utility used for debug printing only.
// Sorts the given vector of expressions in topological order, with equal
// cases respecting the original order in the vector.
// NOTE: intentionally simple and quadratic; only used on small per-group
// expression lists when printing.
std::vector<Expr*> groupExprPrintSorting(const std::vector<Expr*>& exprs) {
  std::vector<Expr*> exprs_to_print(exprs.begin(), exprs.end());
  std::unordered_set<Expr*> exprs_to_print_set(exprs.begin(), exprs.end());
  std::unordered_set<Expr*> exprs_visited;
  std::vector<Expr*> sorted_list;
  // Repeat until every expression has been placed in sorted_list.
  while (!std::all_of(
      exprs_to_print.begin(),
      exprs_to_print.end(),
      [&exprs_visited](auto expr) { return exprs_visited.count(expr); })) {
    bool expr_added_to_sorted_list = false;
    for (auto expr : exprs_to_print) {
      if (!exprs_visited.count(expr)) {
        bool add_this_expr = true;
        // Check if any of the inputs of current
        // expression within the group
        // hasn't been visited
        for (auto input : expr->inputs()) {
          if (input->definition() &&
              exprs_to_print_set.count(input->definition()) &&
              !exprs_visited.count(input->definition())) {
            add_this_expr = false;
            break;
          }
        }

        // Append the current group to sorted list
        // and mark visited
        if (add_this_expr) {
          expr_added_to_sorted_list = true;
          exprs_visited.insert(expr);
          sorted_list.push_back(expr);
          break;
        }
      }
    }
    // If a full pass places nothing, there is a cycle among the exprs.
    TORCH_INTERNAL_ASSERT(
        expr_added_to_sorted_list,
        "group debug print failed, exprs within given vector not a DAG");
  }
  return sorted_list;
}
631
632// Utility function to list all expressions in a group
633void detailGroupPrint(std::ostream& os, const SegmentedGroup* group) {
634 IrPrinter irp(os);
635
636 auto sort_val_by_name = [](std::vector<Val*> vals_to_sort) {
637 std::sort(vals_to_sort.begin(), vals_to_sort.end(), [](Val* a, Val* b) {
638 return a->name() < b->name();
639 });
640 return vals_to_sort;
641 };
642
643 os << "g{"
644 << "(" << toString(group->heuristic()) << ")\n";
645 os << "inputs: \n";
646 for (auto input : sort_val_by_name(getAllInputs(group))) {
647 os << input << " " << input->getDataType().value() << "\n";
648 }
649 os << "outputs: \n";
650 for (auto output : sort_val_by_name(getAllOutputs(group))) {
651 os << output << " " << output->getDataType().value() << "\n";
652 }
653
654 os << "\n\n";
655
656 auto expr_to_print = groupExprPrintSorting(group->exprs());
657
658 for (const auto i : c10::irange(expr_to_print.size())) {
659 irp.handle(expr_to_print[i]);
660 }
661 os << "}\n\n";
662}
663
664//! Insert casts for an intermediate tensorview, i.e. ones
665//! that are in segmentedEdges. The insertion is done on
666//! the complete fusion, which should be owned by a segmented
667//! fusion so that only one segmented fusion will be affected.
668//! The replacement pattern is:
669//! TV0
670//! replaced as:
671//! fp16_tv = cast(TV0)
672//! fp32_tv = cast(fp16_tv)
673//!
674//! All segmented groups that take TV0 as input will then
675//! take fp16_tv or bf16_tv instead and the cast to fp32 will be
676//! automatically included in each of the groups.
677TensorView* castIntermediateValueInCompleteFusion(
678 Fusion* fusion,
679 TensorView* original_tv,
680 std::unordered_set<Expr*> edge_from_group_uses,
681 DataType dtype) {
682 FusionGuard fg(fusion);
683
684 // A utility lambda that creates consumer tensordomain of
685 // the given tv and create a new tensorview around the
686 // new tensordomain with the given data type.
687 auto make_consumer_tv = [&](TensorView* from, DataType data_type) {
688 // Keep broadcast axes and remove reduction axes
689 size_t i = 0;
690 auto no_reduction_root_domain =
691 TensorDomain::noReductions(original_tv->getMaybeRFactorDomain());
692 std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
693 for (const auto& dom : no_reduction_root_domain) {
694 new_root_domain[i++] = dom->cloneWithoutRFactor();
695 }
696
697 // Create the actual domain and tv.
698 return IrBuilder::create<TensorView>(
699 IrBuilder::create<TensorDomain>(
700 new_root_domain, std::vector<bool>(new_root_domain.size(), true)),
701 data_type);
702 };
703
704 // create the tv's to cast
705 auto half_precision_tv = make_consumer_tv(original_tv, dtype);
706
707 auto fp32_tv = make_consumer_tv(original_tv, DataType::Float);
708
709 // replace uses of original tv with fp32_tv in the complete
710 // fusion
711 for (auto expr : fusion->unordered_uses(original_tv)) {
712 // Don't modify internal uses of buffers, only cast for outputs.
713 if (edge_from_group_uses.find(expr) == edge_from_group_uses.end()) {
714 ir_utils::replaceValInExpr(expr, original_tv, fp32_tv);
715 }
716 }
717
718 // Insert the cast ops.
719 IrBuilder::create<UnaryOp>(UnaryOpType::Cast, half_precision_tv, original_tv);
720 IrBuilder::create<UnaryOp>(UnaryOpType::Cast, fp32_tv, half_precision_tv);
721
722 // Return the new tv to replace original tv with
723 // on the segmented edges.
724 return half_precision_tv;
725}
726} // namespace
727
// Finalize the segmented fusion: drop unused groups/edges and insert
// fp16/bf16 casts on edge tensors that were flagged for reduced-precision
// transfer between groups.
void SegmentedFusion::finalize() {
  impl_.cleanUnused();
  // Insert casts for the tensorviews that are on
  // segmented edges and also on the force_to_fp16 list
  //
  // Note:
  // The cast is inserted after the segmenter canSchedule check, which
  // shouldn't cause problem short-term. The reason we put the cast here
  // is we don't want to keep making copies of the original fusion
  // during segmentation. Could consider making the cast insertion
  // reversible if we do have to test canSchedule with the casts inserted
  // during segmentation process in the future.

  // Keep track of groups that need to update expr list,
  // including both the producer and consumer of the selected tv's that
  // we cast to fp16.
  std::unordered_set<SegmentedGroup*> affected_group_set;
  // A map to keep track of the tv's that have been inserted cast
  // and its fp16 version.
  std::unordered_map<TensorView*, TensorView*> fp32_to_half_cast_map;

  // Go through all edges of the segmented fusion.
  for (auto edge : edges()) {
    TORCH_INTERNAL_ASSERT(edge->val->isA<TensorView>());
    auto edge_tv = edge->val->as<TensorView>();

    // Uses of the edge value within the from group should not be replaced. This
    // will cause the group to have an intermediate tensor
    // tv -> float2half -> output
    //  \ -> half2float -> other uses in group
    // The conversion back and forth from half precision can hurt numerics.
    // Collect expressions that use the edge value of concern within the from
    // group to avoid replacing with the cast tensor.
    std::unordered_set<Expr*> uses_in_from_group;

    // All expressions in the from group of the edge
    std::unordered_set<Expr*> from_group_exprs(
        edge->from->exprs().begin(), edge->from->exprs().end());

    // All uses of the edge val
    for (auto edge_val_use_expr : edge_tv->uses()) {
      if (from_group_exprs.count(edge_val_use_expr)) {
        // Find uses in the from group of the val
        uses_in_from_group.emplace(edge_val_use_expr);
      }
    }

    // Only look at ones that need to cast to fp16 or bf16
    if ((force_fp16_tv_set_.count(edge_tv) > 0)) {
      auto cast_tv_it = fp32_to_half_cast_map.find(edge->val->as<TensorView>());
      TensorView* cast_tv = nullptr;
      // Insert cast ops for this tv if we haven't done so.
      if (cast_tv_it == fp32_to_half_cast_map.end()) {
        cast_tv = castIntermediateValueInCompleteFusion(
            complete_fusion_.get(),
            edge_tv,
            uses_in_from_group,
            force_half_precision_type_);
        fp32_to_half_cast_map[edge->val->as<TensorView>()] = cast_tv;
      } else {
        cast_tv = cast_tv_it->second;
      }

      // Update the edge to use the fp16 version
      edge->val = cast_tv;

      // Mark the groups for update later
      affected_group_set.insert(edge->from);
      affected_group_set.insert(edge->to);

      // The expr pointers on the group's expr list might have been freed
      // by now after `ir_utils::replaceValInExpr`.
      // Need a valid expression list to continue. Update from and to group.
      edge->from->resetExprList();
      edge->to->resetExprList();
    }
  }
}
806
//! An utility class to compute and maintain the "producers of"
//! relationship in a segmented graph. Space heavy and should
//! avoid use on very large graphs.
//!
//! Currently trying to move as far as possible with only a
//! producer map, without transposing it to make a consumer map.
//! Making it NonCopyable because we should never need to
//! copy an instance of this class.
//! TODO: Space efficiency of this class will be important,
//!       because we need it in the pre-merging of segmentedGroups,
//!       currently O(n^2). O(nlogn) would be a reasonable
//!       goal to achieve.
class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
  using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
  using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;

 public:
  //! Populate producers of all groups in segmented fusion
  explicit GroupDependencyAnalysis(const SegmentedFusion* segmented_fusion)
      : segmented_fusion_(segmented_fusion) {
    computeAllProducers();
  }

  //! Checks if group is consumer of any group in groups_to_check
  //! TODO: refactor this similar to isConsumerOf
  bool isConsumerOfAny(
      SegmentedGroup* group,
      const std::vector<SegmentedGroup*>& groups_to_check) {
    auto& producers_of_group = getAllKnownProducersSet(group);
    for (const auto& potential_producer : groups_to_check) {
      if (producers_of_group->has(potential_producer)) {
        return true;
      }
    }
    return false;
  }

  //! Checks if a transitively consumes b (i.e. b is a known producer of a)
  bool isConsumerOf(SegmentedGroup* a, SegmentedGroup* b) {
    auto it = known_producers_of_.find(a);
    if (it == known_producers_of_.end()) {
      return false;
    }
    return it->second->has(b);
  }

  //! Checks if a transitively produces b
  bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
    return isConsumerOf(b, a);
  }

  //! Finds the common producers of given set of groups
  GroupSet getCommonProducersOf(std::vector<SegmentedGroup*> groups);

  //! Update the map when the given two groups have been merged to create `ab`
  //! this method is for book keeping and query only, doesn't implicitly check
  //!  for DAG
  void mergeGroups(SegmentedGroup* a, SegmentedGroup* b, SegmentedGroup* ab);

  //! Update the map when the given two groups have been merged to create
  //! `merged` this method is for book keeping and query only, doesn't
  //! implicitly check
  //!  for DAG
  void mergeGroups(const GroupSet& groups, SegmentedGroup* merged);

  //! Populate all values that is on a path from producer to consumer
  //!  efficiency can be important here. (TODO)
  GroupSet valuesBetween(SegmentedGroup* producer, SegmentedGroup* consumer) {
    if (producer == consumer) {
      return {};
    }

    GroupSet values_between;
    auto& all_producers_of_consumer = known_producers_of_.at(consumer);
    TORCH_INTERNAL_ASSERT(
        all_producers_of_consumer->has(producer),
        "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");

    // A group lies on a producer->consumer path iff it produces the
    // consumer and is itself produced by the producer.
    for (auto producer_of_consumer : *all_producers_of_consumer) {
      if (known_producers_of_.at(producer_of_consumer)->has(producer)) {
        values_between.pushBack(producer_of_consumer);
      }
    }

    return values_between;
  }

  //! Checks if the segmented fusion this class tracks is still a DAG
  //!  used for generating assertions after transforms
  bool isproducerMapDAG() const {
    // A cycle exists iff some group is (transitively) its own producer.
    for (auto& it : known_producers_of_) {
      if (it.second->has(it.first)) {
        return false;
      }
    }
    return true;
  }

 private:
  //! Collect initial producer info using
  //!  a work list algorithm through forward traversal
  //!  a backward DFS would do the same
  void computeAllProducers();

  //! Add all consumers of `producer` to `to_visit`
  void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
    for (auto e : producer->consumer_edges) {
      // A consumer wouldn't have been worked before any of its producer
      to_visit.pushBack(e->to);
    }
  }

  //! Propagate all known producers of `from` into `into`, used to keep track
  //! of:
  //!  1. `from` is a producer of `into`
  //!  2. `from` has been merged with other group to create `into`
  void mergeAllKnownProducersIntoFrom(
      SegmentedGroup* into,
      SegmentedGroup* from) {
    auto& producer_set_to_merge = *getAllKnownProducersSet(from);
    for (auto group : producer_set_to_merge) {
      getAllKnownProducersSet(into)->pushBack(group);
    }
  }

  //! Utility to access known producers of a group so far; lazily creates
  //! an empty set for groups not seen before.
  GroupSetOwningPtr& getAllKnownProducersSet(SegmentedGroup* group) {
    auto& producer_set_ptr = known_producers_of_[group];
    if (!producer_set_ptr) {
      producer_set_ptr = std::make_unique<GroupSet>();
    }
    return producer_set_ptr;
  }

  // utility to compute the set intersection of group sets a,b
  GroupSet groupSetIntersection(const GroupSet& a, const GroupSet& b) {
    // Iterate the smaller set, probe the bigger one.
    bool a_is_smaller = a.size() < b.size();
    const auto& smaller_group_set = a_is_smaller ? a : b;
    const auto& bigger_group_set = a_is_smaller ? b : a;

    GroupSet intersection;
    for (auto group : smaller_group_set) {
      if (bigger_group_set.has(group)) {
        intersection.pushBack(group);
      }
    }
    return intersection;
  }

 private:
  const SegmentedFusion* segmented_fusion_;
  DependencyMap known_producers_of_;
};
958
959//! Finds the common producers of given set of groups
960GroupSet GroupDependencyAnalysis::getCommonProducersOf(
961 std::vector<SegmentedGroup*> groups) {
962 if (groups.empty()) {
963 return {};
964 }
965
966 // Optimization: start with the smallest producer set
967 std::sort(
968 groups.begin(),
969 groups.end(),
970 [this](SegmentedGroup* a, SegmentedGroup* b) {
971 return known_producers_of_.at(a)->size() <
972 known_producers_of_.at(b)->size();
973 });
974
975 // Get intersection of producers
976 GroupSet common_producers = *(known_producers_of_.at(groups[0]));
977 for (const auto i : c10::irange(1, groups.size())) {
978 common_producers = groupSetIntersection(
979 common_producers, *(known_producers_of_.at(groups[i])));
980 }
981
982 return common_producers;
983}
984
//! Update the map when the given two groups have been merged to create `ab`
//! this method is for book keeping and query only, doesn't implicitly check
//!  for DAG
void GroupDependencyAnalysis::mergeGroups(
    SegmentedGroup* a,
    SegmentedGroup* b,
    SegmentedGroup* ab) {
  // Access/Create the producer set of ab
  auto& ab_set = getAllKnownProducersSet(ab);

  // propagate a's and b's known producers into ab
  mergeAllKnownProducersIntoFrom(ab, a);
  mergeAllKnownProducersIntoFrom(ab, b);

  // a, b are now merged, so no longer exist
  // (remove them from ab's own producer set to avoid a self-cycle)
  ab_set->erase(a);
  ab_set->erase(b);

  // a, b no longer exist, remove their producer sets
  known_producers_of_.erase(a);
  known_producers_of_.erase(b);

  // update producer maps of other groups
  for (auto& it : known_producers_of_) {
    // for all groups that are produced by either a or b
    if (it.second->has(a) || it.second->has(b)) {
      // insert ab as the new producer
      it.second->pushBack(ab);
      // all producers of both a and b are now producers of `it`
      mergeAllKnownProducersIntoFrom(it.first, ab);
    }
    // a, b no longer exist, remove them from `it`
    it.second->erase(a);
    it.second->erase(b);
  }
}
1021
//! Update the map when the given set of groups has been merged to create
//! `merged`. This method is for bookkeeping and query only; it doesn't
//! implicitly check
//! that the merged graph is still a DAG.
void GroupDependencyAnalysis::mergeGroups(
    const GroupSet& groups,
    SegmentedGroup* merged) {
  // Access/Create the producer set of merged
  auto& merged_set = getAllKnownProducersSet(merged);

  // Populate all producers of groups and
  // write into producer map of merged
  std::for_each(
      groups.begin(), groups.end(), [this, merged](SegmentedGroup* group) {
        mergeAllKnownProducersIntoFrom(merged, group);
      });

  // Erase all groups that were merged from the producer map
  std::for_each(
      groups.begin(), groups.end(), [this, &merged_set](SegmentedGroup* group) {
        // erase inter dependencies
        merged_set->erase(group);
        // erase producer map tracking merged entries
        known_producers_of_.erase(group);
      });

  // Update producer relationships with other groups in producer map
  // NOTE(review): unlike the two-group overload above, `merged`'s producers
  // are not propagated into `it` here — presumably the producer sets are
  // already transitively closed at this point; verify with callers.
  for (auto& it : known_producers_of_) {
    auto producer_intersection = groupSetIntersection(*(it.second), groups);
    // if current node has any producer that was merged
    if (producer_intersection.size() > 0) {
      for (auto merged_producer : producer_intersection) {
        // delete all disappearing producers
        it.second->erase(merged_producer);
      }
      // insert the new group as producer
      it.second->pushBack(merged);
    }
  }
}
1062
//! Collect initial producer info using
//! a work list algorithm through forward traversal;
//! a backward DFS would do the same
void GroupDependencyAnalysis::computeAllProducers() {
  GroupSet visited;
  GroupSet to_visit;

  // Collect source nodes. With no producers we are guaranteed
  // at least one source node on a DAG.
  for (auto group : segmented_fusion_->cgroups()) {
    if (group->producer_edges.empty()) {
      visited.pushBack(group);
    }
  }

  // visited now only contains source nodes;
  // they can go backward to nowhere
  for (auto group : visited) {
    addConsumersToWorkList(group, to_visit);
  }

  while (!to_visit.empty()) {
    // Find one group on the worklist whose producers have all been
    // visited; its producer set can then be finalized.
    SegmentedGroup* to_update = nullptr;
    for (auto visiting_group : to_visit) {
      if (std::all_of(
              visiting_group->producer_edges.begin(),
              visiting_group->producer_edges.end(),
              [&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
        // filter multi-edges (two edges from the same producer collapse
        // into one entry of the unique set)
        GroupSet producers_of_visiting_group;
        for (auto edge : visiting_group->producer_edges) {
          producers_of_visiting_group.pushBack(edge->from);
        }

        // populate all possible paths
        // from producer backward, including
        // the producer
        for (auto producer : producers_of_visiting_group) {
          getAllKnownProducersSet(visiting_group)->pushBack(producer);
          mergeAllKnownProducersIntoFrom(visiting_group, producer);
        }
        to_update = visiting_group;
        break;
      }
    }
    if (to_update) {
      addConsumersToWorkList(to_update, to_visit);
      to_visit.erase(to_update);
      visited.pushBack(to_update);
    } else {
      // No group on the worklist was ready; on a DAG this cannot happen.
      TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
    }
  }
}
1117
//! Prints the segmented fusion: groups and edges in topological
//! (producer-before-consumer) order, followed by per-group details.
std::ostream& operator<<(
    std::ostream& os,
    const SegmentedFusion* segmented_fusion) {
  // Topologically sort groups
  GroupDependencyAnalysis dependency(segmented_fusion);
  std::vector<SegmentedGroup*> groups_to_print(
      segmented_fusion->cgroups().begin(), segmented_fusion->cgroups().end());
  std::vector<SegmentedGroup*> sorted_groups_to_print;

  // Sort groups topologically from producer to consumer before printing.
  // Selection-style O(n^2) pass: each iteration picks a group with no
  // remaining producer among the candidates (the single forward scan
  // appears to rely on isProducerOf being transitive) and removes it.
  while (!groups_to_print.empty()) {
    auto group_it_to_append = groups_to_print.begin();
    for (auto group_it_to_compare = groups_to_print.begin();
         group_it_to_compare != groups_to_print.end();
         group_it_to_compare++) {
      // If the compared group produces the current candidate, it must be
      // printed first, so it becomes the new candidate.
      if (dependency.isProducerOf(*group_it_to_compare, *group_it_to_append)) {
        group_it_to_append = group_it_to_compare;
      }
    }
    sorted_groups_to_print.push_back(*group_it_to_append);
    groups_to_print.erase(group_it_to_append);
  }

  // Do a reverse look up (group -> topological position) used to order
  // the edges below.
  std::unordered_map<SegmentedGroup*, size_t> group_order;
  for (const auto i : c10::irange(sorted_groups_to_print.size())) {
    group_order[sorted_groups_to_print[i]] = i;
  }

  // Sort edges to print by the topological position of their source group
  std::vector<SegmentedEdge*> sorted_edges_to_print(
      segmented_fusion->cedges().begin(), segmented_fusion->cedges().end());
  std::sort(
      sorted_edges_to_print.begin(),
      sorted_edges_to_print.end(),
      [&group_order](SegmentedEdge* edge_a, SegmentedEdge* edge_b) {
        return group_order.at(edge_a->from) < group_order.at(edge_b->from);
      });

  os << "Segmented_Fusion Dump: -- fusion segments:\n";
  os << "Segmented_Fusion{ \n";
  os << "groups: \n";
  for (const auto g : sorted_groups_to_print) {
    os << g << "\n";
  }
  os << "edges: \n";
  for (const auto e : sorted_edges_to_print) {
    os << e << "\n";
  }
  os << "\ngroup details:\n";
  for (const auto g : sorted_groups_to_print) {
    detailGroupPrint(os, g);
  }
  os << "} //Segmented_Fusion\n";
  return os;
}
1174
1175void SegmentedFusion::print() const {
1176 std::cout << "Segmented_Fusion Dump: -- Re-written complete fusion:{\n";
1177 completeFusion()->printMath();
1178 std::cout << "} // {Re-written complete fusion}\n";
1179 std::cout << this << "\n";
1180}
1181
1182std::string toString(SegmentedFusion* segmented_fusion) {
1183 std::stringstream ss;
1184 ss << segmented_fusion;
1185 return ss.str();
1186}
1187
//! Builds a standalone Fusion for the given segmented group by cloning the
//! complete fusion and re-pointing its inputs/outputs to the group boundary.
std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
  std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();

  // Copy the complete fusion; the returned cloner maps complete-fusion
  // vals to their clones in the copy.
  auto complete_to_segment_map =
      Fusion::copy(completeFusion(), fusion_segment.get());

  // Detach the complete fusion's inputs from the copy; the segment
  // boundary becomes the new I/O below.
  std::vector<Val*> input_list(
      fusion_segment->inputs().begin(), fusion_segment->inputs().end());
  for (auto inp : input_list) {
    fusion_segment->removeInput(inp);
  }

  // Same for the complete fusion's outputs.
  std::vector<Val*> output_list(
      fusion_segment->outputs().begin(), fusion_segment->outputs().end());
  for (auto out : output_list) {
    fusion_segment->removeOutput(out);
  }

  // Register the group's boundary vals as segment inputs, remembering any
  // input that was produced by a view op in the complete fusion.
  std::vector<TensorView*> view_tvs;
  for (auto inp : getAllInputs(sg)) {
    auto clone_tv = complete_to_segment_map.clone(inp);
    fusion_segment->addInput(clone_tv);
    if (inp->isDefinitionType(ExprType::ViewOp)) {
      TORCH_INTERNAL_ASSERT(clone_tv != nullptr && clone_tv->isA<TensorView>());
      view_tvs.push_back(clone_tv->as<TensorView>());
    }
  }

  for (auto out : getAllOutputs(sg)) {
    fusion_segment->addOutput(complete_to_segment_map.clone(out));
  }

  // View-produced tensors are now segment inputs with no defining view op,
  // so their rfactor domain is promoted to the root domain.
  for (auto tv : view_tvs) {
    tv->convertRfactorToRootDomain();
  }

  return fusion_segment;
}
1226
1227void SegmentCandidateFinder::resetTraversal() {
1228 for (auto group : groups()) {
1229 // Start traversal at input groups
1230 if (group->producer_edges.empty()) {
1231 to_visit_.push_back(group);
1232 }
1233 group->visited_ = false;
1234 group->level_ = 0;
1235 }
1236}
1237
//! Recomputes each group's level_ as the longest producer-chain distance
//! from a source group. Assumes resetTraversal() seeded to_visit_ with the
//! source groups and cleared visited_/level_.
void SegmentCandidateFinder::resetLevels() {
  while (!to_visit_.empty()) {
    auto visit = to_visit_.front();
    to_visit_.pop_front();

    // All inputs processed?
    bool ready = true;
    if (!visit->producer_edges.empty()) {
      ready = std::all_of(
          visit->producer_edges.begin(),
          visit->producer_edges.end(),
          [&](SegmentedEdge* dep) { return dep->from->visited_; });
    }

    if (!ready) {
      // Defer this group; it is re-queued once another group is visited.
      // In case traversal doesn't complete because there's an error in the
      // DAG topology, the assert below fires.
      next_to_visit_.push_back(visit);
      continue;
    }

    visit->visited_ = true;

    // A group was just visited, so deferred groups may be ready now.
    to_visit_.insert(
        to_visit_.end(), next_to_visit_.begin(), next_to_visit_.end());
    next_to_visit_.clear();

    for (auto out : visit->consumer_edges) {
      to_visit_.push_back(out->to);
    }

    // Level is one more than the deepest producer (0 for sources).
    visit->level_ = 0;
    for (auto inp : visit->producer_edges) {
      visit->level_ = std::max(visit->level_, inp->from->level_ + 1);
    }
  }
  TORCH_INTERNAL_ASSERT(
      next_to_visit_.empty(), "Error in graph, is not a DAG.");
}
1277
// Disconnect group from neighbors, and return edges that were disconnected
1279std::unordered_set<SegmentedEdge*> SegmentCandidateFinder::disconnectGroup(
1280 SegmentedGroup* group) {
1281 std::unordered_set<SegmentedEdge*> removed_edges(
1282 group->producer_edges.begin(), group->producer_edges.end());
1283
1284 for (auto edge : group->producer_edges) {
1285 auto from = edge->from;
1286 auto& from_edges = from->consumer_edges;
1287 auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge);
1288 TORCH_INTERNAL_ASSERT(
1289 from_edge_it != from_edges.end(), "Could not find edge to remove.");
1290 from_edges.erase(from_edge_it);
1291 }
1292
1293 for (auto edge : group->consumer_edges) {
1294 removed_edges.insert(edge);
1295 auto to = edge->to;
1296 auto& to_edges = to->producer_edges;
1297 auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge);
1298 TORCH_INTERNAL_ASSERT(
1299 to_edge_it != to_edges.end(), "Could not find edge to remove.");
1300 to_edges.erase(to_edge_it);
1301 }
1302
1303 group->producer_edges.clear();
1304 group->consumer_edges.clear();
1305
1306 return removed_edges;
1307}
1308
1309void SegmentCandidateFinder::eraseGroups(
1310 std::unordered_set<SegmentedGroup*>& groups_to_erase) {
1311 std::unordered_set<SegmentedEdge*> edges_to_erase;
1312 for (auto group : groups_to_erase) {
1313 auto disconnected_edges = disconnectGroup(group);
1314 edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end());
1315 }
1316
1317 edges().erase(
1318 std::remove_if(
1319 edges().begin(),
1320 edges().end(),
1321 [&edges_to_erase](SegmentedEdge* edge) {
1322 if (edges_to_erase.find(edge) != edges_to_erase.end()) {
1323 return true;
1324 };
1325 return false;
1326 }),
1327 edges().end());
1328
1329 groups().erase(
1330 std::remove_if(
1331 groups().begin(),
1332 groups().end(),
1333 [&groups_to_erase](SegmentedGroup* group) {
1334 if (groups_to_erase.find(group) != groups_to_erase.end()) {
1335 return true;
1336 };
1337 return false;
1338 }),
1339 groups().end());
1340}
1341
1342SegmentedGroup* SegmentCandidateFinder::mergeNodes() {
1343 SegmentedGroup* last_merged = nullptr;
1344 auto it = to_merge_.begin();
1345 TORCH_INTERNAL_ASSERT(to_merge_.size() % 2 == 0);
1346 while (it != to_merge_.end()) {
1347 auto group1 = *it++;
1348 auto group2 = *it++;
1349
1350 clean_up_groups_.emplace(group1);
1351 clean_up_groups_.emplace(group2);
1352
1353 // Make the new joined node
1354 auto joined_group = segmented_fusion_->newGroup();
1355
1356 joined_group->input_vals =
1357 uniqueValConcat({group1->input_vals, group2->input_vals});
1358
1359 joined_group->output_vals =
1360 uniqueValConcat({group1->output_vals, group2->output_vals});
1361
1362 joined_group->exprs_ = group1->exprs_;
1363 joined_group->exprs_.insert(
1364 joined_group->exprs_.end(),
1365 group2->exprs_.begin(),
1366 group2->exprs_.end());
1367
1368 auto producer_edges = getMergedProducerEdges(group1, group2);
1369 // Connect joined group to resulting neighbors
1370 for (auto edge : producer_edges) {
1371 auto from = edge->from;
1372 auto val = edge->val;
1373
1374 auto new_edge = segmented_fusion_->newEdge(from, joined_group, val);
1375 joined_group->producer_edges.push_back(new_edge);
1376 from->consumer_edges.push_back(new_edge);
1377 }
1378
1379 auto consumer_edges = getMergedConsumerEdges(group1, group2);
1380
1381 for (auto edge : consumer_edges) {
1382 auto to = edge->to;
1383 auto val = edge->val;
1384
1385 auto new_edge = segmented_fusion_->newEdge(joined_group, to, val);
1386 joined_group->consumer_edges.push_back(new_edge);
1387 edge->to->producer_edges.push_back(new_edge);
1388 }
1389
1390 joined_group->setHeuristic(deriveHeuristic(joined_group));
1391 // Need to maintain the group dependency data if it has been intialized
1392 // by previous merging
1393 if (group_dependency_) {
1394 group_dependency_->as<GroupDependencyAnalysis>()->mergeGroups(
1395 group1, group2, joined_group);
1396 }
1397 last_merged = joined_group;
1398 }
1399
1400 to_merge_.clear();
1401 for (auto group : clean_up_groups_) {
1402 auto disconnected_edges = disconnectGroup(group);
1403 clean_up_edges_.insert(
1404 disconnected_edges.begin(), disconnected_edges.end());
1405 }
1406
1407 edges().erase(
1408 std::remove_if(
1409 edges().begin(),
1410 edges().end(),
1411 [this](SegmentedEdge* edge) {
1412 if (this->clean_up_edges_.find(edge) !=
1413 this->clean_up_edges_.end()) {
1414 return true;
1415 };
1416 return false;
1417 }),
1418 edges().end());
1419
1420 groups().erase(
1421 std::remove_if(
1422 groups().begin(),
1423 groups().end(),
1424 [this](SegmentedGroup* group) {
1425 if (this->clean_up_groups_.find(group) !=
1426 this->clean_up_groups_.end()) {
1427 return true;
1428 };
1429 return false;
1430 }),
1431 groups().end());
1432
1433 clean_up_edges_.clear();
1434 clean_up_groups_.clear();
1435
1436 return last_merged;
1437}
1438
1439// Logic largely parallels mergeNodes, but they are used
1440// in different phases of segmentation. Should consider
1441// a clean up and share the implementations.
SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups(
    const std::vector<SegmentedGroup*>& groups_to_merge) {
  TORCH_INTERNAL_ASSERT(
      !groups_to_merge.empty(),
      "fusion segment :(mergeAllGivenGroups) tried to merge no groups")

  // Make a set to detect internal edges (edges between two groups that are
  // both being merged; they do not survive the merge)
  std::unordered_set<SegmentedGroup*> group_set(
      groups_to_merge.begin(), groups_to_merge.end());

  // Sets to de-duplicate multiple uses of
  // input/edge values and re-computations of exprs
  std::unordered_set<Val*> used_edge_vals_set;
  std::unordered_set<Val*> used_input_vals_set;
  std::unordered_set<Expr*> exprs_set;

  // Create new group
  auto joined_group = segmented_fusion_->newGroup();

  // Populate edges, exprs, global vals
  // from each of the groups
  for (auto group : groups_to_merge) {
    // Populate complete fusion inputs to the group
    for (auto input_val : group->input_vals) {
      if (!used_input_vals_set.count(input_val)) {
        used_input_vals_set.insert(input_val);
        joined_group->input_vals.push_back(input_val);
      }
    }

    // Populate complete fusion outputs from the group
    // NOTE(review): outputs are not de-duplicated here, unlike inputs —
    // presumably distinct groups never share a complete-fusion output;
    // verify.
    for (auto output_val : group->output_vals) {
      joined_group->output_vals.push_back(output_val);
    }

    // Populate producer edges to the group
    for (auto edge : group->producer_edges) {
      if (
          // Check this is not internal edge
          !group_set.count(edge->from) &&
          // Check this val has been added or not
          !used_edge_vals_set.count(edge->val)) {
        used_edge_vals_set.insert(edge->val);
        auto new_producer_edge =
            segmented_fusion_->newEdge(edge->from, joined_group, edge->val);
        joined_group->producer_edges.push_back(new_producer_edge);
        edge->from->consumer_edges.push_back(new_producer_edge);
      }
    }

    // Populate consumer edges from the group
    for (auto edge : group->consumer_edges) {
      if (
          // Check this is not internal edge
          !group_set.count(edge->to)) {
        auto new_consumer_edge =
            segmented_fusion_->newEdge(joined_group, edge->to, edge->val);
        joined_group->consumer_edges.push_back(new_consumer_edge);
        edge->to->producer_edges.push_back(new_consumer_edge);
      }
    }

    // Populate exprs, skipping duplicates shared across merged groups
    for (auto expr : group->exprs_) {
      if (!exprs_set.count(expr)) {
        joined_group->exprs_.push_back(expr);
        exprs_set.insert(expr);
      }
    }
  }

  // Clean up original groups from segmented fusion
  for (auto group : groups_to_merge) {
    auto disconnected_edges = disconnectGroup(group);
    clean_up_edges_.insert(
        disconnected_edges.begin(), disconnected_edges.end());
  }

  edges().erase(
      std::remove_if(
          edges().begin(),
          edges().end(),
          [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }),
      edges().end());

  groups().erase(
      std::remove_if(
          groups().begin(),
          groups().end(),
          [&group_set](SegmentedGroup* group) -> bool {
            return group_set.count(group);
          }),
      groups().end());

  clean_up_edges_.clear();

  joined_group->setHeuristic(deriveHeuristic(joined_group));
  return joined_group;
}
1541namespace {
1542
// Guard to temporarily change the inputs and outputs of a fusion. On
// destruction it will return the fusion to its original state.
// Will become even more useful when adding more merging heuristics.
class FusionSegmentGuard : public NonCopyable {
 public:
  FusionSegmentGuard() = delete;

  //! Records the fusion's current inputs/outputs, then swaps them out for
  //! the given `inputs`/`outputs` (the segment boundary).
  FusionSegmentGuard(
      Fusion* fusion,
      std::vector<Val*> inputs,
      std::vector<Val*> outputs)
      : fusion_(fusion),
        old_inputs_(fusion->inputs()),
        old_outputs_(fusion->outputs()),
        new_inputs_(std::move(inputs)),
        new_outputs_(std::move(outputs)) {
    FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard");
    TORCH_INTERNAL_ASSERT(fusion_ != nullptr);
    for (auto old_inp : old_inputs_) {
      fusion_->removeInput(old_inp);
    }

    for (auto old_out : old_outputs_) {
      fusion_->removeOutput(old_out);
    }

    for (auto new_inp : new_inputs_) {
      fusion_->addInput(new_inp);
    }

    for (auto new_out : new_outputs_) {
      fusion_->addOutput(new_out);
    }
  }

  //! Reverses the constructor: removes the segment-boundary I/O and
  //! reinstates the fusion's original inputs/outputs.
  ~FusionSegmentGuard() {
    FUSER_PERF_SCOPE("~Segmenter::FusionSegmentGuard");

    if (fusion_ == nullptr) {
      return;
    }
    for (auto new_inp : new_inputs_) {
      fusion_->removeInput(new_inp);
    }

    for (auto new_out : new_outputs_) {
      fusion_->removeOutput(new_out);
    }

    for (auto old_inp : old_inputs_) {
      fusion_->addInput(old_inp);
    }

    for (auto old_out : old_outputs_) {
      fusion_->addOutput(old_out);
    }
  }

 private:
  // Guarded fusion; never null after construction.
  Fusion* const fusion_ = nullptr;
  const std::vector<Val*> old_inputs_;
  const std::vector<Val*> old_outputs_;
  const std::vector<Val*> new_inputs_;
  const std::vector<Val*> new_outputs_;
};
1608
1609c10::optional<ScheduleHeuristic> tryMerge(
1610 Fusion* fusion,
1611 SchedulerRuntimeInfo& runtime_info,
1612 SegmentedGroup* a,
1613 SegmentedGroup* b = nullptr) {
1614 FusionSegmentGuard fsg(fusion, getAllInputs(a, b), getAllOutputs(a, b));
1615
1616 scheduler_debug_utils::canScheduleMessage(
1617 "\n**Segmenter** Considering fusion:\n", fusion);
1618 return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
1619}
1620
1621c10::optional<ScheduleHeuristic> tryMerge(
1622 Fusion* fusion,
1623 SchedulerRuntimeInfo& runtime_info,
1624 const std::vector<SegmentedGroup*>& segmented_groups) {
1625 FusionSegmentGuard fsg(
1626 fusion,
1627 allInputsIfTrueElseOutputs(segmented_groups, true),
1628 allInputsIfTrueElseOutputs(segmented_groups, false));
1629 scheduler_debug_utils::canScheduleMessage(
1630 "\n**Segmenter** Considering fusion:\n", fusion);
1631 return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
1632}
1633
1634// This function is for cleanup and
1635// easier debugging. It shouldn't affect functionality
1636// since segmented fusions are compiled with fusion
1637// guard on the edges instead of actually looking
1638// at the exprs.
1639void deDuplicateScalarExprs(std::vector<Expr*>& exprs) {
1640 // Exprs in SegmentedGroup are not ordered
1641 // so it is ok to insert them from unordered
1642 // set
1643 std::unordered_set<Expr*> scalar_expr_set;
1644
1645 std::copy_if(
1646 exprs.begin(),
1647 exprs.end(),
1648 std::inserter(scalar_expr_set, scalar_expr_set.end()),
1649 [](Expr* expr) { return ir_utils::isScalarOp(expr); });
1650
1651 if (!scalar_expr_set.empty()) {
1652 exprs.erase(
1653 std::remove_if(
1654 exprs.begin(),
1655 exprs.end(),
1656 [&scalar_expr_set](Expr* expr) {
1657 return scalar_expr_set.count(expr);
1658 }),
1659 exprs.end());
1660 exprs.insert(exprs.end(), scalar_expr_set.begin(), scalar_expr_set.end());
1661 }
1662}
1663
1664} // namespace
1665
1666c10::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup::
1667 getMaybeSchedulerEntry(SchedulerRuntimeInfo& runtime_info) {
1668 FUSER_PERF_SCOPE("SegmentedGroup::getMaybeSchedulerEntry");
1669 auto fusion = segmented_fusion_->completeFusion();
1670 auto data_cache = segmented_fusion_->getCachedHeuristicDataFor(this);
1671 FusionSegmentGuard fsg(fusion, getAllInputs(this), getAllOutputs(this));
1672 if (!SchedulerEntry::canSchedule(
1673 heuristic(), fusion, runtime_info, data_cache)) {
1674 return c10::nullopt;
1675 }
1676 return SchedulerEntry::makeEntry(
1677 heuristic(), fusion, runtime_info, data_cache);
1678}
1679
1680void SegmentedGroup::resetExprList() {
1681 auto input_group_vec = getAllInputs(this);
1682 std::unordered_set<Val*> input_group_set(
1683 input_group_vec.begin(), input_group_vec.end());
1684 auto expr_set =
1685 DependencyCheck::getAllExprsBetween(input_group_set, getAllOutputs(this));
1686 exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
1687}
1688
1689// Custom merge node passes:
1690// These passes are added at the beginning or the end of
1691// the node merging process to direct the heuristics of
1692// node merging process
1693//
1694// Should consider generalization and make a proper interface
1695// if we have more merge node heuristics like this
1696
1697//! Translate Welford
1698//!
1699//! This pass can be inserted at any stages of segmentation,
1700//! and it tries to replace welford ops with persistent
1701//! mean and var ops.
1702//!
1703//! The checking of feasibility of persistent kernels
1704//! is through normalization schedulers. The general idea
1705//! is to first try to translate on a copy, and see if
1706//! normalization scheduler is willing to produce a
1707//! persistent kernel.
1708//!
1709//! For complete fusion this pass checks if all the
1710//! welford ops can be translated simultaneously to
1711//! produce a persistent normalization kernel and
1712//! will perform translation if checks pass.
1713//!
1714//! For segmented fusion, same check is performed within
1715//! each segmented group to collect applicable welford ops,
1716//! and actual translations are performed on the complete
1717//! fusion after all the checks are done.
class TranslateApplicableWelford {
 public:
  //! Try translation on each segmented group of
  //! given segmented fusion
  //! returns true if any welford has been translated
  static bool run(
      SegmentedFusion* segmented_fusion,
      const KernelArgumentHolder& runtime_inputs) {
    TranslateApplicableWelford translate_welford(
        segmented_fusion, runtime_inputs);
    return translate_welford.translated_any_welford_;
  }

  //! Try translation on complete fusion,
  //! returns true if any welford has been translated
  static bool run(Fusion* fusion, const KernelArgumentHolder& runtime_inputs) {
    TranslateApplicableWelford translate_welford(fusion, runtime_inputs);
    return translate_welford.translated_any_welford_;
  }

 private:
  // All work happens in the constructors; the static run() helpers only
  // expose the resulting translated_any_welford_ flag.
  explicit TranslateApplicableWelford(
      SegmentedFusion* segmented_fusion,
      const KernelArgumentHolder& runtime_inputs);

  explicit TranslateApplicableWelford(
      Fusion* fusion,
      const KernelArgumentHolder& runtime_inputs);

  //! Given vector of welford ops from the same fusion,
  //! checks if translating all of them result in a
  //! persistent normalization kernel by try-runs on
  //! a test copy of the original fusion.
  //!
  //! Supported use cases are either un-segmented fusion,
  //! or all the given welfords are within the same
  //! segmented group. In the latter case, the segmented
  //! group containing all the welford ops needs to be
  //! provided.
  bool wouldTranslateToPersistent(
      const std::vector<WelfordOp*>& orignal_welfords,
      SegmentedGroup* group = nullptr);

  //! Translate the given welford op into separate
  //! average and standard deviation calculation.
  void translateSingleWelford(WelfordOp* welford);

  //! Utility to test if a translated fusion
  //! gives a persistent kernel. Uses normalization
  //! scheduler to do the test.
  bool isValidPersistentFusion(
      Fusion* translated_fusion,
      SchedulerRuntimeInfo& runtime_info);

 private:
  //! Indicates any translation happened.
  bool translated_any_welford_ = false;

  //! a reference to global fusion runtime inputs;
  //! the referenced holder must outlive this object
  const KernelArgumentHolder& runtime_inputs_;

  //! For translation within group only,
  //! group boundary at test copy
  //! (see wouldTranslateToPersistent implementation )
  std::vector<Val*> test_group_inputs_;
  std::vector<Val*> test_group_outputs_;
};
1785
1786TranslateApplicableWelford::TranslateApplicableWelford(
1787 Fusion* fusion,
1788 const KernelArgumentHolder& runtime_inputs)
1789 : runtime_inputs_(runtime_inputs) {
1790 auto exprs = fusion->exprs();
1791 std::vector<WelfordOp*> orignal_welfords(
1792 ir_utils::filterByType<WelfordOp>(exprs).begin(),
1793 ir_utils::filterByType<WelfordOp>(exprs).end());
1794
1795 if (wouldTranslateToPersistent(orignal_welfords)) {
1796 for (auto welford : orignal_welfords) {
1797 translateSingleWelford(welford);
1798 }
1799 translated_any_welford_ = true;
1800 }
1801}
1802
1803TranslateApplicableWelford::TranslateApplicableWelford(
1804 SegmentedFusion* segmented_fusion,
1805 const KernelArgumentHolder& runtime_inputs)
1806 : runtime_inputs_(runtime_inputs) {
1807 std::vector<SegmentedGroup*> translated_groups;
1808 std::vector<WelfordOp*> welford_to_translate;
1809 // Find welfords that can be translated in each group
1810 for (auto group : segmented_fusion->groups()) {
1811 std::vector<WelfordOp*> welford_in_group(
1812 ir_utils::filterByType<WelfordOp>(group->exprs()).begin(),
1813 ir_utils::filterByType<WelfordOp>(group->exprs()).end());
1814
1815 if (wouldTranslateToPersistent(welford_in_group, group)) {
1816 translated_groups.push_back(group);
1817 welford_to_translate.insert(
1818 welford_to_translate.end(),
1819 welford_in_group.begin(),
1820 welford_in_group.end());
1821 }
1822 }
1823
1824 // Actually translate the welford ops
1825 // and record all the vals that have been
1826 // replaced by the translation.
1827 for (auto welford : welford_to_translate) {
1828 translateSingleWelford(welford);
1829 }
1830
1831 for (auto translated_group : translated_groups) {
1832 // Update heuristics and expr list of translated groups
1833 translated_group->heuristic_ = ScheduleHeuristic::Persistent;
1834 translated_group->resetExprList();
1835 }
1836}
1837
1838bool TranslateApplicableWelford::isValidPersistentFusion(
1839 Fusion* translated_fusion,
1840 SchedulerRuntimeInfo& runtime_info) {
1841 if (!SchedulerEntry::canSchedule(
1842 ScheduleHeuristic::Persistent, translated_fusion, runtime_info)) {
1843 return false;
1844 }
1845
1846 auto scheduler = SchedulerEntry::makeEntry(
1847 ScheduleHeuristic::Persistent, translated_fusion, runtime_info);
1848
1849 return scheduler->reductionParams().persistent_kernel;
1850}
1851
//! Checks, on a throwaway copy of the fusion, whether translating all the
//! given welford ops would yield a persistent normalization kernel.
bool TranslateApplicableWelford::wouldTranslateToPersistent(
    const std::vector<WelfordOp*>& orignal_welfords,
    SegmentedGroup* group) {
  if (orignal_welfords.empty()) {
    return false;
  }

  // Make sure all welford ops come from the same complete fusion
  auto fusion = orignal_welfords[0]->fusion();
  TORCH_INTERNAL_ASSERT(
      std::all_of(
          orignal_welfords.begin(),
          orignal_welfords.end(),
          [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }),
      "Welfords in given vector not in the same fusion");

  // Make initial `in-progress copy`
  auto test_copy = std::make_unique<Fusion>();
  auto original_to_test_map = Fusion::copy(fusion, test_copy.get());

  // Map each welford op into the test copy.
  std::vector<WelfordOp*> copied_welfords;
  std::transform(
      orignal_welfords.begin(),
      orignal_welfords.end(),
      std::back_inserter(copied_welfords),
      [&original_to_test_map](auto welford) {
        return original_to_test_map.clone(welford);
      });
  // Copied welfords will be invalidated on translation, but Vals will be
  // reused, keep a reference to them.
  std::vector<Val*> welford_avgs;
  std::vector<Val*> welford_vars;
  for (auto welford : copied_welfords) {
    welford_avgs.push_back(welford->outAvg());
    welford_vars.push_back(welford->outVar());
  }

  // Translate the welford ops on the test copy only; the original fusion
  // is untouched by this check.
  for (auto welford_to_translate : copied_welfords) {
    translateSingleWelford(welford_to_translate);
  }

  SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_, true);
  // If we are looking at a segment of fusion,
  // we maintain the segmented group boundary,
  // one set for in_progress copy and one set
  // for `test copy`
  if (group != nullptr) {
    auto original_inputs = getAllInputs(group);
    auto original_outputs = getAllOutputs(group);
    test_group_inputs_.clear();
    test_group_outputs_.clear();
    std::transform(
        original_inputs.begin(),
        original_inputs.end(),
        std::back_inserter(test_group_inputs_),
        [&original_to_test_map](Val* in) {
          return original_to_test_map.clone(in);
        });
    std::transform(
        original_outputs.begin(),
        original_outputs.end(),
        std::back_inserter(test_group_outputs_),
        [&original_to_test_map](Val* out) {
          return original_to_test_map.clone(out);
        });

    // If only average is used from welford, we should still translate, but we
    // might not detect persistence if variance isn't actually used/marked as an
    // output in the test.
    // NOTE(review): an output is appended when the val has no uses in the
    // test copy — presumably so the unused reduction result is not treated
    // as dead code by the scheduler; verify.
    for (auto outs_i : c10::irange(welford_avgs.size())) {
      auto avg = welford_avgs[outs_i];
      auto var = welford_vars[outs_i];
      if (avg->uses().empty()) {
        test_group_outputs_.push_back(avg);
      }

      if (var->uses().empty()) {
        test_group_outputs_.push_back(var);
      }
    }

    // Temporarily localize test copy around
    // the group boundary
    FusionSegmentGuard fsg(
        test_copy.get(), test_group_inputs_, test_group_outputs_);

    // Test if the translated copy is persistent
    return isValidPersistentFusion(test_copy.get(), runtime_info);
  }
  // In the case where we work on un-segmented
  // fusion, no group boundary logic, just
  // translate and test.
  return isValidPersistentFusion(test_copy.get(), runtime_info);
}
1947
// Replaces a single WelfordOp with an equivalent group of normalization-style
// expressions (sum -> divide -> broadcast -> subtract -> square -> sum-of-
// squares -> set), writing into the welford's original output tensors so that
// downstream consumers are unaffected. The WelfordOp itself is removed from
// the fusion. Pre-existing broadcast/subtract expressions of the average
// output are re-used when they match, to avoid duplicating work.
void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) {
  auto fusion = welford->fusion();
  FusionGuard fg(fusion);
  // Only support translation of welford ops that
  // doesn't take inputs that are already statistics,
  // i.e. an r-factor product.
  // This translation works on un-scheduled fusions so
  // shouldn't expect to see this.
  TORCH_INTERNAL_ASSERT(welford->inN()->isOneInt());

  // Grab the inputs and outputs of the welford
  auto in_val = welford->in()->as<TensorView>();
  auto out_avg = welford->outAvg()->as<TensorView>();
  auto out_var = welford->outVar()->as<TensorView>();
  auto out_N = welford->outN()->as<TensorView>();

  fusion->removeExpr(welford);
  // Not safe to use welford anymore
  welford = nullptr;

  // Create normalization based welford graph
  // largely taken from batchnorm cpp benchmark
  const auto& in_root =
      TensorDomain::noReductions(in_val->getMaybeRFactorDomain());
  const auto& out_root = out_avg->getRootDomain();
  std::vector<int> red_axes;

  TORCH_INTERNAL_ASSERT(
      in_root.size() == out_root.size(),
      "Invalid root domains of Welford input and output.",
      " Input: ",
      ir_utils::toString(in_root),
      ". Output: ",
      ir_utils::toString(out_root));

  // Create scalar version of the feature element
  // counting. num_features accumulates the product of the extents of all
  // reduced axes; broadcast_mask marks those axes for re-broadcasting.
  Val* num_features = IrBuilder::create<Double>(1);
  std::vector<bool> broadcast_mask(in_root.size(), false);
  for (const auto i : c10::irange(in_root.size())) {
    if (out_root.at(i)->isReduction()) {
      red_axes.push_back(i);
      broadcast_mask[i] = true;
      num_features = mul(num_features, out_root.at(i)->extent());
    }
  }

  // Build a normalization expression group that is
  // equivalent to a welford operation.
  // avg = sum(in) / num_features, written directly into out_avg.
  auto x_sum = sum(in_val, red_axes);
  IrBuilder::create<BinaryOp>(BinaryOpType::Div, out_avg, x_sum, num_features);
  // welford.avg may be broadcast. Reuse it if found.
  TensorView* x_avg_bcast = nullptr;
  for (auto& use_expr : out_avg->uses()) {
    if (auto bcast = dynamic_cast<BroadcastOp*>(use_expr)) {
      if (bcast->getBroadcastDimFlags() == broadcast_mask) {
        // Same broadcast found.
        x_avg_bcast = bcast->out()->as<TensorView>();
        break;
      }
    }
  }

  // x_mean_sub may already exist. Reuse it if found.
  TensorView* x_mean_sub = nullptr;
  if (x_avg_bcast != nullptr) {
    for (auto& use_expr : x_avg_bcast->uses()) {
      if (auto bop = dynamic_cast<BinaryOp*>(use_expr)) {
        if (bop->getBinaryOpType() == BinaryOpType::Sub) {
          // Only reuse the exact (in - broadcast(avg)) form; other operand
          // orders would compute a different value.
          if (bop->lhs() == in_val && bop->rhs() == x_avg_bcast) {
            x_mean_sub = bop->out()->as<TensorView>();
          }
        }
      }
    }
  }

  if (x_avg_bcast == nullptr) {
    x_avg_bcast = broadcast(out_avg, broadcast_mask);
  }

  if (x_mean_sub == nullptr) {
    x_mean_sub = sub(in_val, x_avg_bcast);
  }

  // var = sum((in - avg)^2), written directly into out_var.
  auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub);
  IrBuilder::create<ReductionOp>(
      BinaryOpType::Add,
      IrBuilder::create<Double>(0.0),
      out_var,
      x_mean_sub_pow);
  // N output becomes a plain copy of the scalar feature count.
  IrBuilder::create<UnaryOp>(UnaryOpType::Set, out_N, num_features);

  // out_avg, out_N are now outputs of a pointwise ops and we
  // need to clear out its reduction domains.
  out_avg->clearReductionIterDomains();
  out_N->clearReductionIterDomains();
}
2046
2047bool SegmentCandidateFinder::TranslateWelfordInFusion(
2048 Fusion* fusion,
2049 const KernelArgumentHolder& runtime_inputs) {
2050 return TranslateApplicableWelford::run(fusion, runtime_inputs);
2051}
2052
2053//! CombineReductions:
2054//! This pass works before the main merge node process
2055//! It identifies reduction operations that can be combined
2056//! together to form a normalization kernel.
2057//! Two reductions are considered the same type if they have
2058//! the same root domain length, and the reduction axis are the same.
2059//! This pass tries to merge nodes with the same reduction type based
2060//! on the graph structure.
2061class CombineReductions {
2062 using GroupVec = std::vector<SegmentedGroup*>;
2063 class ReductionSignature;
2064
2065 public:
2066 static void run(SegmentCandidateFinder* segment_candidate_finder) {
2067 CombineReductions combine_reductions(segment_candidate_finder);
2068 }
2069 static bool shouldRun(SegmentCandidateFinder* segment_candidate_finder);
2070
2071 private:
2072 CombineReductions(SegmentCandidateFinder* segment_candidate_finder)
2073 : segment_candidate_finder_(segment_candidate_finder) {
2074 // Run pass over the segments
2075
2076 // Collect segmented groups with reductions in them,
2077 // Assuming running before any merge happened, so
2078 // should see exactly one non-trivial reduction in each group
2079 for (auto group : segment_candidate_finder_->groups()) {
2080 if (auto rop_signature =
2081 ReductionSignature::makeReductionSignature(group)) {
2082 // Ignore pure squeeze operations in this analysis
2083 if (!rop_signature->hasNonTrivialReduction()) {
2084 continue;
2085 }
2086
2087 groups_with_reductions_.push_back(group);
2088 // Check if this reduction signature is one that we have seen before
2089 auto signature_match_it = std::find_if(
2090 known_reduction_signatures_.begin(),
2091 known_reduction_signatures_.end(),
2092 [&rop_signature](auto& know_signature) {
2093 return know_signature->sameAs(rop_signature.get());
2094 });
2095 // Unmatched: Create a new signature entry if not known
2096 if (signature_match_it == known_reduction_signatures_.end()) {
2097 group_reduction_signature_map_[group] = rop_signature.get();
2098 known_reduction_signatures_.emplace_back(std::move(rop_signature));
2099 } else {
2100 // Matched known signature: Mark that this groups belongs to know
2101 // signature
2102 group_reduction_signature_map_[group] = signature_match_it->get();
2103 }
2104 }
2105 }
2106
2107 // Keep trying to merge groups with compatible reductions and compatible
2108 // paths
2109 // until no more merge opportunity can be identified
2110 bool merged_groups = true;
2111 while (merged_groups) {
2112 merged_groups = false;
2113
2114 // Merge one pair of reduction groups at a time, and need
2115 // the pass to update dependency info along the way to avoid cycles
2116 for (const auto first_group_index :
2117 c10::irange(groups_with_reductions_.size())) {
2118 if (merged_groups) {
2119 // Need to break and re-enter this loop because
2120 // groups_with_reductions_ will be updated
2121 break;
2122 }
2123
2124 // Select one of the group to merge and get its reduction signature
2125 auto first_group = groups_with_reductions_[first_group_index];
2126 auto first_group_signature =
2127 group_reduction_signature_map_.at(first_group);
2128
2129 for (const auto second_group_index : c10::irange(
2130 first_group_index + 1, groups_with_reductions_.size())) {
2131 if (merged_groups) {
2132 // Need to break and re-enter this loop because
2133 // groups_with_reductions_ will be updated
2134 break;
2135 }
2136 auto second_group = groups_with_reductions_[second_group_index];
2137 auto second_group_signature =
2138 group_reduction_signature_map_.at(second_group);
2139
2140 // Cannot merge if their signatures are not the same
2141 if (!first_group_signature->sameAs(second_group_signature)) {
2142 continue;
2143 }
2144
2145 // first try a vertical merge
2146 merged_groups =
2147 verticalReductionMerge(first_group, second_group) != nullptr;
2148 if (!merged_groups) {
2149 // vertical merge didn't happen, try a horizontal merge
2150 merged_groups =
2151 horizontalReductionMerge(first_group, second_group) != nullptr;
2152 }
2153 }
2154 }
2155 }
2156 }
2157
2158 //! Merge a vertical pair of producers and consumers,
2159 //! the resulting group will include all nodes that are
2160 //! also consumers of producer and producers of consumer,
2161 //! i.e. values between the given producer-consumer pair.
2162 //! Can be proven that:
2163 //! 1. Including all of these nodes will be cycle-free
2164 //! 2. These nodes are the minimal set of nodes to include if
2165 //! for producer-consumer pair to be in the same group cycle-free
2166 //!
2167 //! Returns nullptr if such merge cannot be achieved.
2168 //! Reasons for not merging will include:
2169 //! 1. Given groups do not form producer-consumer pair
2170 //! 2. Merge will create cycle on the graph
2171 //! 3. The merged joined group cannot be scheduled
2172 SegmentedGroup* verticalReductionMerge(
2173 SegmentedGroup* first_group,
2174 SegmentedGroup* second_group) {
2175 // This is part of ReductionCombine pass, and we should only call this
2176 // function on a pair of reduction/normalization groups
2177 TORCH_INTERNAL_ASSERT(
2178 group_reduction_signature_map_.at(first_group)
2179 ->sameAs(group_reduction_signature_map_.at(second_group)));
2180 TORCH_INTERNAL_ASSERT(first_group != second_group);
2181 // Get the group dependency data from segment finder
2182 auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
2183
2184 // Check producer-consumer relationship
2185 SegmentedGroup* producer = nullptr;
2186 SegmentedGroup* consumer = nullptr;
2187 if (dependency_analysis->isConsumerOf(first_group, second_group)) {
2188 producer = second_group;
2189 consumer = first_group;
2190 } else if (dependency_analysis->isProducerOf(first_group, second_group)) {
2191 producer = first_group;
2192 consumer = second_group;
2193 } else {
2194 // Given groups aren't producer-consumer pair, won't merge
2195 return nullptr;
2196 }
2197
2198 // Collect all groups that we need to merge along with the producer and
2199 // consumer
2200 auto all_groups_to_merge =
2201 getValidMinVerticalMergedGroupSet(producer, consumer);
2202
2203 if (all_groups_to_merge.empty()) {
2204 // The vertical paths from producer to consumer have in-compatible
2205 // reductions
2206 // so this vertical merge cannot be done.
2207 return nullptr;
2208 }
2209
2210 // TODO: this step would not be deterministic, because valuesBetween isn't
2211 // could fix this by a topological order
2212 std::vector<SegmentedGroup*> all_groups_to_merge_vec(
2213 all_groups_to_merge.begin(), all_groups_to_merge.end());
2214
2215 // Final sanity check: the merged group can actually be scheduled
2216 Fusion* fusion =
2217 segment_candidate_finder_->segmented_fusion_->completeFusion();
2218 if (!tryMerge(
2219 fusion,
2220 segment_candidate_finder_->runtimeInfo(),
2221 all_groups_to_merge_vec)) {
2222 return nullptr;
2223 }
2224
2225 // Merge this group
2226 auto joined_group =
2227 segment_candidate_finder_->mergeAllGivenGroups(all_groups_to_merge_vec);
2228
2229 // Update dependency analysis
2230 dependency_analysis->mergeGroups(all_groups_to_merge, joined_group);
2231
2232 // Update the reduction groups that are merged
2233 groups_with_reductions_.push_back(joined_group);
2234 group_reduction_signature_map_[joined_group] =
2235 group_reduction_signature_map_.at(first_group);
2236 groups_with_reductions_.erase(
2237 std::remove_if(
2238 groups_with_reductions_.begin(),
2239 groups_with_reductions_.end(),
2240 [&all_groups_to_merge](SegmentedGroup* group) {
2241 return all_groups_to_merge.has(group);
2242 }),
2243 groups_with_reductions_.end());
2244
2245 return joined_group;
2246 }
2247
2248 //! Horizontal reduction merging:
2249 //! merge two horizontal groups with reduction expressions to make a joined
2250 //! normalization group. A pair of horizontal groups are ones that are not
2251 //! a producer-consumer pair, and share either a common producer or a common
2252 //! consumer.
2253 //!
2254 //! TODO: This implementation looks at common producers only, since common
2255 //! consumers are not computed easily with current dependency analysis.
2256 SegmentedGroup* horizontalReductionMerge(
2257 SegmentedGroup* first_group,
2258 SegmentedGroup* second_group) {
2259 // This is part of ReductionCombine pass, and we should only call this
2260 // function on a pair of
2261 // reduction/normalization groups
2262 TORCH_INTERNAL_ASSERT(
2263 group_reduction_signature_map_.at(first_group)
2264 ->sameAs(group_reduction_signature_map_.at(second_group)));
2265 TORCH_INTERNAL_ASSERT(first_group != second_group);
2266
2267 auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
2268
2269 // Check that the two groups are not producer-consumer's
2270 if (dependency_analysis->isConsumerOf(first_group, second_group) ||
2271 dependency_analysis->isProducerOf(first_group, second_group)) {
2272 // This merge pass will not handle producer-consumer pairs
2273 return nullptr;
2274 }
2275
2276 // Get common producers of the two group
2277 auto common_producers_set =
2278 dependency_analysis->getCommonProducersOf({first_group, second_group});
2279 if (common_producers_set.empty()) {
2280 // The given pair doesn't have a common producer.
2281 // Either they have a common consumer, which we don't handle for now,
2282 // or maybe the two given groups are not connected.
2283 return nullptr;
2284 }
2285
2286 // We are looking for a very specific patterns here. The cases that this
2287 // pattern will not capture are ones that reductions of different
2288 // signatures are so interleaved that we cannot find a clear cut as
2289 // explained below, without graph rewriting. Some graph re-writing on the
2290 // segmented groups level could provide extra merging opportunities for
2291 // free, which could be part of next step.
2292 //
2293 // The specific pattern we look for contains a common producer P with
2294 // immediate consumers C1, C2 such that all paths from C1 to first_group and
2295 // all paths from C2 to second_group won't hit a reduction with a different
2296 // signature.
2297
2298 // Topologically sort the common producers and start with the topologically
2299 // minimal,
2300 // i.e. one that are closest to the two groups. This will cut the search
2301 // space.
2302 std::vector<SegmentedGroup*> common_producers;
2303 for (auto producer : common_producers_set) {
2304 if (!std::any_of(
2305 common_producers_set.begin(),
2306 common_producers_set.end(),
2307 [dependency_analysis, producer](SegmentedGroup* group) {
2308 return dependency_analysis->isProducerOf(producer, group);
2309 })) {
2310 common_producers.push_back(producer);
2311 }
2312 }
2313
2314 // Visit the common producers found, starting from topologically minimum,
2315 // i.e. the ones closer to the groups
2316 for (auto common_producer : common_producers) {
2317 // Visit this common producer
2318 // Use a double loop in case the schedulers like some patterns
2319 // better than the other
2320 for (auto first_consumer_edge : common_producer->consumer_edges) {
2321 auto producer_of_first_group = first_consumer_edge->to;
2322 auto to_merge_with_first_group = getValidMinVerticalMergedGroupSet(
2323 producer_of_first_group, first_group);
2324 if (to_merge_with_first_group.empty()) {
2325 // There's no valid merge path from this consumer of common producer,
2326 // either due to a conflicting reduction signature, or simply there's
2327 // no path to first group
2328 continue;
2329 }
2330 TORCH_INTERNAL_ASSERT(!dependency_analysis->isProducerOf(
2331 producer_of_first_group, second_group));
2332 for (auto second_consumer_edge : common_producer->consumer_edges) {
2333 auto producer_of_second_group = second_consumer_edge->to;
2334 auto to_merge_with_second_group = getValidMinVerticalMergedGroupSet(
2335 producer_of_second_group, second_group);
2336 if (to_merge_with_second_group.empty()) {
2337 // There's no valid merge path from this consumer of common
2338 // producer,
2339 // either due to a conflicting reduction signature, or simply
2340 // there's no path to second group
2341 continue;
2342 }
2343 TORCH_INTERNAL_ASSERT(!dependency_analysis->isProducerOf(
2344 producer_of_second_group, first_group));
2345 // At this point we should have a pair of valid candidates,final check
2346 // is to see if the combined group
2347 // can be scheduled by schedulers
2348 // merge the two paths and de-duplicate,
2349 // re-using container here with to_merge_with_second_group
2350 auto& groups_to_merge_set = to_merge_with_second_group;
2351 groups_to_merge_set.insert(
2352 to_merge_with_first_group.begin(),
2353 to_merge_with_first_group.end());
2354 std::vector<SegmentedGroup*> groups_to_merge_vec(
2355 groups_to_merge_set.begin(), groups_to_merge_set.end());
2356 Fusion* fusion =
2357 segment_candidate_finder_->segmented_fusion_->completeFusion();
2358 if (tryMerge(
2359 fusion,
2360 segment_candidate_finder_->runtimeInfo(),
2361 groups_to_merge_vec)) {
2362 // Found a valid horizontal merge, want to proceed with merging here
2363 auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(
2364 groups_to_merge_vec);
2365 dependency_analysis->mergeGroups(groups_to_merge_set, joined_group);
2366
2367 groups_with_reductions_.push_back(joined_group);
2368 group_reduction_signature_map_[joined_group] =
2369 group_reduction_signature_map_.at(first_group);
2370 groups_with_reductions_.erase(
2371 std::remove_if(
2372 groups_with_reductions_.begin(),
2373 groups_with_reductions_.end(),
2374 [&groups_to_merge_set](SegmentedGroup* group) {
2375 return groups_to_merge_set.has(group);
2376 }),
2377 groups_with_reductions_.end());
2378
2379 return joined_group;
2380 }
2381 }
2382 }
2383 }
2384
2385 // Searched all possibilities and there is no valid horizontal merge pattern
2386 // found.
2387 return nullptr;
2388 }
2389
2390 //! This is a utility method that is used in both vertical merging and
2391 //! horizontal merging.
2392 //! It is used to identify the smallest set of groups to merge vertically
2393 //! involving the
2394 //! two given nodes.
2395 //! Given a pair of nodes this utility distinguishes 3 cases:
2396 //! 1. if maybe_producer is the same as maybe_consumer, then returns
2397 //! {maybe_producer}
2398 //! 2. if maybe_producer is actually a producer of consumer, returns a set
2399 //! containing
2400 //! the smallest merged group that would contain producer and consumer and
2401 //! would not introduce a cycle. Returns empty set if such group has
2402 //! a conflicting reduction signature.
2403 //! 3. returns empty set if neither conditions above apply.
2404 GroupSet getValidMinVerticalMergedGroupSet(
2405 SegmentedGroup* maybe_producer,
2406 SegmentedGroup* maybe_consumer) {
2407 auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
2408 if (maybe_consumer == maybe_producer) {
2409 // maybe producer is the same as maybe_consumer
2410 return {maybe_consumer};
2411 } else if (dependency_analysis->isConsumerOf(
2412 maybe_consumer, maybe_producer)) {
2413 auto groups_to_check =
2414 dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
2415 groups_to_check.pushBack(maybe_producer);
2416 groups_to_check.pushBack(maybe_consumer);
2417
2418 // Check that either no group has a reduction or all groups have the same
2419 // reduction signature
2420 ReductionSignature* reduction_signature = nullptr;
2421
2422 // Iterate through the minimal group set to see if any conflicts
2423 for (auto group : groups_to_check) {
2424 // Check that this group does not involve a output edge contraction
2425 // This pass is intended to be a pre-merging pass. Since contracting an
2426 // output edge does not generate much saving of global memory access
2427 // we want to postpone merging these edges till the very final pass
2428 for (auto producer_edge_of_group : group->producer_edges) {
2429 if (groups_to_check.has(producer_edge_of_group->from) &&
2430 producer_edge_of_group->val->isFusionOutput()) {
2431 return {};
2432 }
2433 }
2434 for (auto consumer_edge_of_group : group->consumer_edges) {
2435 if (groups_to_check.has(consumer_edge_of_group->to) &&
2436 consumer_edge_of_group->val->isFusionOutput()) {
2437 return {};
2438 }
2439 }
2440
2441 // Check that this group does not have a conflicting reduction signature
2442 if (group_reduction_signature_map_.count(group)) {
2443 if (reduction_signature != nullptr) {
2444 if (!group_reduction_signature_map_.at(group)->sameAs(
2445 reduction_signature)) {
2446 // Found a conflict in reduction signature, cannot do a vertical
2447 // merge
2448 return {};
2449 }
2450 } else {
2451 reduction_signature = group_reduction_signature_map_.at(group);
2452 }
2453 }
2454 }
2455 return groups_to_check;
2456 }
2457 // maybe producer is not a producer of maybe consumer
2458 return {};
2459 }
2460
2461 private:
2462 SegmentCandidateFinder* segment_candidate_finder_;
2463
2464 // Wrapper class for reduction type
2465 // Assuming there wouldn't be too many of them
2466 // so won't need to create a hash
2467 // TODO:
2468 // Want to reconsider this for transpose operations,
2469 // need refactoring to handle reduction fusions across a transpose operation
2470 class ReductionSignature {
2471 public:
2472 bool sameAs(const ReductionSignature* reduction_signature) {
2473 if (reduction_signature == this) {
2474 return true;
2475 }
2476
2477 if (root_domain_size_ != reduction_signature->root_domain_size_ ||
2478 has_nontrivial_reduction_ !=
2479 reduction_signature->has_nontrivial_reduction_ ||
2480 reduction_axes_.size() !=
2481 reduction_signature->reduction_axes_.size()) {
2482 return false;
2483 }
2484
2485 for (const auto i : c10::irange(reduction_axes_.size())) {
2486 if (reduction_axes_[i] != reduction_signature->reduction_axes_[i]) {
2487 return false;
2488 }
2489 }
2490
2491 return true;
2492 }
2493
2494 bool sameAs(const ReductionSignature& reduction_signature) {
2495 return sameAs(&reduction_signature);
2496 }
2497
2498 bool hasNonTrivialReduction() const {
2499 return has_nontrivial_reduction_;
2500 }
2501
2502 static std::unique_ptr<ReductionSignature> makeReductionSignature(
2503 SegmentedGroup* group) {
2504 std::unique_ptr<ReductionSignature> signature = nullptr;
2505
2506 for (auto expr : group->exprs()) {
2507 std::unique_ptr<ReductionSignature> new_signature = nullptr;
2508
2509 if (auto rop = dynamic_cast<ReductionOp*>(expr)) {
2510 new_signature = std::make_unique<ReductionSignature>(rop);
2511 }
2512 if (auto wop = dynamic_cast<WelfordOp*>(expr)) {
2513 new_signature = std::make_unique<ReductionSignature>(wop);
2514 }
2515
2516 if (new_signature != nullptr) {
2517 TORCH_INTERNAL_ASSERT(
2518 signature == nullptr || !signature->has_nontrivial_reduction_ ||
2519 !new_signature->has_nontrivial_reduction_ ||
2520 signature->sameAs(new_signature.get()),
2521 "Conflicting signature found in this group");
2522 signature = std::move(new_signature);
2523 }
2524 }
2525 return signature;
2526 }
2527
2528 template <typename REDUCTION = ReductionOp>
2529 ReductionSignature(REDUCTION* rop) {
2530 auto out_tv = rop->out()->template as<TensorView>();
2531 has_nontrivial_reduction_ = out_tv->hasReduction();
2532 TORCH_INTERNAL_ASSERT(out_tv != nullptr);
2533 auto& root_domain = out_tv->getRootDomain();
2534 root_domain_size_ = root_domain.size();
2535
2536 // Trivial reduction i.e. squeeze is tricky here:
2537 // this pass doesn't want to touch any pure squeeze, i.e.:
2538 // T0 [R(1), I(i0), I(i1)]
2539 // meanwhile, for two reductions having
2540 // squeezes, we do require they have squeeze at the
2541 // same position so that they can be easily root domain mapped
2542 // So T0 and T1 are the same signature,
2543 // T0 [R(1), R(i0), I(i1)]
2544 // T1 [R(1), R(i0), I(i1)]
2545 // but T2 and T3 below are not
2546 // T0 [R(1), R(1), R(i0), I(i1)]
2547 // T1 [R(1), R(i0), I(i1)]
2548 for (const auto i : c10::irange(root_domain_size_)) {
2549 if (root_domain[i]->isReduction()) {
2550 reduction_axes_.push_back(i);
2551 }
2552 if (!root_domain[i]->isTrivialReduction()) {
2553 has_nontrivial_reduction_ = true;
2554 }
2555 }
2556 }
2557
2558 private:
2559 size_t root_domain_size_ = 0;
2560 std::vector<int> reduction_axes_;
2561 bool has_nontrivial_reduction_ = false;
2562 };
2563
2564 //! Keeps track of groups with reduction expressions,
2565 //! using a vector here to maintain a deterministic ordering
2566 GroupVec groups_with_reductions_;
2567
2568 //! Maps groups to their corresponding signature type
2569 std::unordered_map<SegmentedGroup*, ReductionSignature*>
2570 group_reduction_signature_map_;
2571
2572 //! Maintains all reduction signatures seen in the segmented fusion
2573 std::vector<std::unique_ptr<ReductionSignature>> known_reduction_signatures_;
2574};
2575
2576//! This is to be checked
2577bool CombineReductions::shouldRun(
2578 SegmentCandidateFinder* segment_candidate_finder) {
2579 std::vector<std::unique_ptr<ReductionSignature>> known_reductions;
2580 // Iterate over group segments we have before segment candidate finder
2581 // tries to merge any groups
2582 for (auto group : segment_candidate_finder->groups()) {
2583 if (auto reduction_signature =
2584 ReductionSignature::makeReductionSignature(group)) {
2585 if (reduction_signature->hasNonTrivialReduction() &&
2586 std::any_of(
2587 known_reductions.begin(),
2588 known_reductions.end(),
2589 [&reduction_signature](auto& know_signature) {
2590 return know_signature->sameAs(reduction_signature.get());
2591 })) {
2592 // Found two reductions with the same signature, run pass
2593 return true;
2594 }
2595 known_reductions.emplace_back(std::move(reduction_signature));
2596 }
2597 }
2598 return false;
2599}
2600
2601namespace {
2602
2603//! Returns true if group1 and group2 are an immediate producer-consumer pair.
2604bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) {
2605 // Check if group1 is a immediate consumer of group2
2606 if (std::any_of(
2607 group1->producer_edges.begin(),
2608 group1->producer_edges.end(),
2609 [group2](SegmentedEdge* edge) { return edge->from == group2; })) {
2610 return true;
2611 }
2612
2613 // Check if group1 is a immediate producer of group2
2614 if (std::any_of(
2615 group1->consumer_edges.begin(),
2616 group1->consumer_edges.end(),
2617 [group2](SegmentedEdge* edge) { return edge->to == group2; })) {
2618 return true;
2619 }
2620
2621 return false;
2622}
2623
2624} // namespace
2625
2626bool SegmentCandidateFinder::codeGenSupportedMerge(
2627 SegmentedGroup* group1,
2628 SegmentedGroup* group2) {
2629 TORCH_INTERNAL_ASSERT(
2630 areDirectlyConnected(group1, group2),
2631 "only support testing immediate producer-consumer groups");
2632 Fusion* fusion = segmented_fusion_->completeFusion();
2633 auto h = tryMerge(fusion, runtime_info_, group1, group2);
2634 return h.has_value();
2635}
2636
2637// TODO: consider caching the heuristics value so tryMerge doesn't have to be
2638// called twice
2639ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
2640 SegmentedGroup* group) {
2641 Fusion* fusion = segmented_fusion_->completeFusion();
2642 auto h = tryMerge(fusion, runtime_info_, group);
2643 TORCH_INTERNAL_ASSERT(h.has_value());
2644 return h.value();
2645}
2646
// Takes ownership of the complete fusion and immediately runs segmentation.
// NOTE: member initializers run before the constructor body, so
// runtime_info_ is built from fusion.get() while `fusion` is still owned by
// the parameter; only afterwards is `fusion` moved into segmented_fusion_.
// Do not reorder these without preserving that sequencing.
SegmentCandidateFinder::SegmentCandidateFinder(
    std::unique_ptr<Fusion> fusion,
    const KernelArgumentHolder& inputs,
    SegmentCandidateFinderOptions options)
    : options_(options),
      runtime_info_(fusion.get(), inputs, true),
      runtime_inputs_(inputs) {
  segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
  findSegments();
}
2657
2658void SegmentCandidateFinder::findSegments() {
2659 FUSER_PERF_SCOPE("Finding valid fusion segment solutions");
2660 // TODO: Make traversal items local to this function.
2661
2662 // Need this for initialization of the DAG that is process
2663 std::unordered_map<Expr*, SegmentedGroup*> expr2group;
2664
2665 // Keep track of complete fusion input use
2666 std::unordered_map<Val*, SegmentedGroup*> input2group;
2667
2668 // Initialize DAG, convert each expr to a segment group
2669 auto exprs = completeFusion()->exprs();
2670 for (auto expr : exprs) {
2671 if (!ir_utils::isScalarOp(expr)) {
2672 auto new_group = segmented_fusion_->newGroup(expr);
2673 expr2group.insert(std::make_pair(expr, new_group));
2674 }
2675 }
2676
2677 // Find all expresions that are simply unary ops from inputs. Don't segment
2678 // these as they're easy targets for recomputation. Only go until the first
2679 // expression that has multiple uses. We could continue, but the logic of
2680 // hacking the fusion "inputs" logic gets a bit more complicated.
2681
2682 // Expressions to exclude from segmentation because they're just derived from
2683 // unary ops on inputs to the complete fusion
2684 VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;
2685
2686 // "Terminating" outputs from the excluded input unary exprs, these will be
2687 // treated as complete fusion inputs.
2688 VectorOfUniqueEntries<Val*> forwarded_inputs;
2689 {
2690 std::deque<Expr*> to_visit;
2691 for (auto inp : completeFusion()->inputs()) {
2692 if (std::all_of(inp->uses().begin(), inp->uses().end(), [](Expr* expr) {
2693 return expr->getExprType().value() == ExprType::UnaryOp;
2694 })) {
2695 to_visit.insert(to_visit.end(), inp->uses().begin(), inp->uses().end());
2696 }
2697 }
2698
2699 while (!to_visit.empty()) {
2700 auto expr = to_visit.front();
2701 to_visit.pop_front();
2702 if (expr->getExprType().value() != ExprType::UnaryOp ||
2703 expr->output(0)->isFusionOutput()) {
2704 continue;
2705 }
2706
2707 if (expr->output(0)->uses().size() > 1) {
2708 excluded_inp_unary_exprs.pushBack(expr);
2709 forwarded_inputs.pushBack(expr->output(0));
2710 continue;
2711 }
2712
2713 to_visit.emplace_back(expr->output(0)->uses()[0]);
2714 }
2715 }
2716
2717 auto excluded_fusion_inputs = IterVisitor::getInputsTo(
2718 {forwarded_inputs.begin(), forwarded_inputs.end()});
2719
2720 // List of vals to treat as complete fusion inputs for segmentation
2721 auto forwarded_fusion_inputs = completeFusion()->inputs();
2722
2723 forwarded_fusion_inputs.erase(
2724 std::remove_if(
2725 forwarded_fusion_inputs.begin(),
2726 forwarded_fusion_inputs.end(),
2727 [&excluded_fusion_inputs](Val* inp) {
2728 return std::find(
2729 excluded_fusion_inputs.begin(),
2730 excluded_fusion_inputs.end(),
2731 inp) != excluded_fusion_inputs.end();
2732 }),
2733 forwarded_fusion_inputs.end());
2734
2735 forwarded_fusion_inputs.insert(
2736 forwarded_fusion_inputs.end(),
2737 forwarded_inputs.begin(),
2738 forwarded_inputs.end());
2739
2740 auto isFusionInput = [&forwarded_fusion_inputs](Val* val) -> bool {
2741 return std::find(
2742 forwarded_fusion_inputs.begin(),
2743 forwarded_fusion_inputs.end(),
2744 val) != forwarded_fusion_inputs.end();
2745 };
2746
2747 // Insert auxiliary groups to use group dependency on inputs as well
2748 // TODO: these groups should never merged into any other groups, but are
2749 // just there to support the dependency analysis. Later re-factor should
2750 // avoid introducing them explicitly on the segmented fusion.
2751 for (auto input : forwarded_fusion_inputs) {
2752 // These groups are used to represent input as a common
2753 // producer in horizontal merges, and should never be
2754 // seen as a candidate for vertical merge
2755 auto new_group = segmented_fusion_->newGroup();
2756 input2group.insert({input, new_group});
2757 }
2758
2759 // Create edges between the Exprs. Mark inputs and outputs of the fusion.
2760 for (auto expr : exprs) {
2761 // No group created for scalar ops
2762 if (ir_utils::isScalarOp(expr)) {
2763 continue;
2764 }
2765
2766 if (excluded_inp_unary_exprs.has(expr)) {
2767 continue;
2768 }
2769
2770 auto expr_group = expr2group.at(expr);
2771 for (auto inp : expr->inputs()) {
2772 if (isFusionInput(inp)) {
2773 expr_group->input_vals.push_back(inp);
2774 auto aux_group = input2group.at(inp);
2775 auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp);
2776 expr_group->producer_edges.push_back(new_edge);
2777 aux_group->consumer_edges.push_back(new_edge);
2778 continue;
2779 }
2780
2781 // Could be something like a constant scalar, definition is nullptr, but
2782 // isn't an "input" to the fusion. At least not one provided by an
2783 // external source.
2784 if (inp->definition() == nullptr) {
2785 continue;
2786 }
2787
2788 // No group created for scalar ops since they may need to be duplicated
2789 // to avoid scalar edges. They are handled in resolveScalarsInGroup
2790 if (inp->isScalar()) {
2791 continue;
2792 }
2793
2794 auto def_group = expr2group.at(inp->definition());
2795 auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp);
2796 expr_group->producer_edges.push_back(new_edge);
2797 def_group->consumer_edges.push_back(new_edge);
2798 }
2799 for (auto out : expr->outputs()) {
2800 if (out->isFusionOutput()) {
2801 expr_group->output_vals.push_back(out);
2802 }
2803 }
2804 }
2805
2806 auto reduction_ops = ir_utils::getReductionOps(
2807 segmented_fusion_->completeFusion(), true /* ignore_trivial */);
2808 auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
2809
2810 if (options_.run_translate_welford &&
2811 (welford_ops.begin() != welford_ops.end())) {
2812 TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_);
2813 }
2814
2815 for (auto group : groups()) {
2816 if (!group->outputs().empty()) {
2817 // Set heuristics in case single reduction kernels were left out
2818 group->setHeuristic(deriveHeuristic(group));
2819 }
2820 }
2821
2822 // Remove all scalar edges since they do not represent actual
2823 // dependency among segmented groups.
2824 removeScalarEdges();
2825
2826 // Run pre-merge heuristics
2827 if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
2828 CombineReductions::run(this);
2829 }
2830
2831 // All merges will be vertical beyond this point for now, so
2832 // we can remove the input auxiliary groups. Should make the vertical
2833 // merges avoid auxiliary group once we start general horizontal merges
2834 std::unordered_set<SegmentedGroup*> input_groups;
2835 for (auto input : forwarded_fusion_inputs) {
2836 input_groups.insert(input2group.at(input));
2837 }
2838 eraseGroups(input_groups);
2839
2840 if (options_.run_herrmann_merge) {
2841 bool merged_nodes = true;
2842 // Initial merge iteration
2843 while (merged_nodes) {
2844 // Reset stateful traversal details in SegmentedGroups
2845 resetTraversal();
2846
2847 resetLevels();
2848
2849 for (auto& group : groups()) {
2850 if (group->merged_) {
2851 continue;
2852 }
2853 auto candidates = group->getMergeCandidates();
2854 if (candidates.empty()) {
2855 continue;
2856 }
2857
2858 auto candidate_it = candidates.begin();
2859 while (candidate_it != candidates.end() &&
2860 !codeGenSupportedMerge(group, candidate_it->group)) {
2861 candidate_it++;
2862 }
2863 if (candidate_it == candidates.end()) {
2864 continue;
2865 }
2866
2867 to_merge_.emplace_back(group);
2868 to_merge_.emplace_back(candidate_it->group);
2869
2870 group->merged_ = true;
2871 group->merge_with_ = candidate_it->group;
2872 group->merge_through_ = candidate_it->edge;
2873
2874 candidate_it->group->merged_ = true;
2875 candidate_it->group->merge_with_ = group;
2876 candidate_it->group->merge_through_ = candidate_it->edge;
2877 }
2878
2879 if (to_merge_.empty()) {
2880 merged_nodes = false;
2881 }
2882
2883 mergeNodes();
2884 }
2885 }
2886
2887 if (options_.run_final_merge) {
2888 // TODO: consider interleaving herrmman merge and bruteforce merge, as
2889 // bruteforce merge can introduce opportunities for more herrmann merge
2890 finalMerge();
2891 }
2892
2893 finalize();
2894
2895 if (isDebugDumpEnabled(DebugDumpOption::FusionSegmentsDrawing)) {
2896 segmented_fusion_->draw();
2897 }
2898}
2899
//! Final brute-force merge pass: repeatedly merge one producer group with
//! one of its consumer groups until no supported merge remains. Unlike the
//! earlier herrmann pass, merging through fusion outputs is allowed here.
void SegmentCandidateFinder::finalMerge() {
  // Dependency analysis used to avoid creating cycles: a producer may only
  // merge with a consumer that is not downstream of any of its siblings.
  auto producer_check = getGroupDependency();

  bool merged_nodes = true;
  while (merged_nodes) {
    // Iterate all groups and check if a group
    // can merge with one of its consumers
    for (auto producer_group : groups()) {
      // Populate consumers and their corresponding consumer edges
      std::unordered_map<SegmentedGroup*, SegmentedEdge*> consumer_edge_map;
      std::vector<SegmentedGroup*> all_consumers_of_producer_group;
      for (auto consumer : producer_group->consumer_edges) {
        // Since this is the last fusion pass, we can enable fusion through
        // outputs. Priority of this was decreased because if the only
        // connection between groups is an output node, best case scenario we
        // can save a single pass in memory. Where if it wasn't an output it
        // would be two passes.
        consumer_edge_map.insert({consumer->to, consumer});
      }
      // Populate all consumers from the map to avoid duplicate
      // NOTE(review): unordered_map iteration order is unspecified, so which
      // consumer gets merged first may vary across runs/platforms — confirm
      // segmentation determinism is not required here.
      std::transform(
          consumer_edge_map.begin(),
          consumer_edge_map.end(),
          std::back_inserter(all_consumers_of_producer_group),
          [](auto& it) { return it.first; });

      for (auto consumer : all_consumers_of_producer_group) {
        // Merging into `consumer` is safe only if no sibling consumer
        // depends on it (that would create a cycle in the group graph),
        // and only if the scheduler can generate code for the merged group.
        if (!producer_check->isConsumerOfAny(
                consumer, all_consumers_of_producer_group) &&
            codeGenSupportedMerge(producer_group, consumer)) {
          // Record the pair for mergeNodes() and mark both groups merged,
          // remembering the edge the merge goes through.
          to_merge_.emplace_back(producer_group);
          to_merge_.emplace_back(consumer);
          producer_group->merged_ = true;
          producer_group->merge_with_ = consumer;
          producer_group->merge_through_ = consumer_edge_map.at(consumer);
          consumer->merged_ = true;
          consumer->merge_with_ = producer_group;
          consumer->merge_through_ = producer_group->merge_through_;
          break;
        }
      }

      // Only want to merge one pair at a time so break if found any
      if (!to_merge_.empty()) {
        break;
      }
    }

    if (to_merge_.empty()) {
      // No mergeable pair found anywhere: fixed point reached.
      merged_nodes = false;
    } else {
      TORCH_INTERNAL_ASSERT(
          to_merge_.size() == 2, "merging more than 2 nodes in final iter");
      mergeNodes();
    }
  }
}
2957
//! Make `group` self-contained w.r.t. scalar computation: every scalar value
//! used by the group's exprs is either a fusion input registered in
//! `group->input_vals` or has its defining scalar exprs appended to the
//! group (in dependency order) so they can be re-materialized per group.
void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) {
  // DFS stack of scalar values still to process, and the set already done.
  std::vector<Val*> to_visit;
  std::unordered_set<Val*> visited;

  // Collect all scalar uses in the group
  for (auto expr : group->exprs()) {
    for (auto input : expr->inputs()) {
      if (input->isScalar()) {
        to_visit.push_back(input);
      }
    }
  }

  // Keep track of composite fusion inputs used in this group
  std::unordered_set<Val*> input_set(
      group->input_vals.begin(), group->input_vals.end());

  // Record and append all missing scalar exprs at the end.
  std::vector<Expr*> exprs_to_add;

  // Do a stack based traversal of the scalar ops to avoid
  // combinatorial duplication of exprs.
  while (!to_visit.empty()) {
    auto stack_top_val = to_visit.back();
    if (visited.count(stack_top_val)) {
      // Already handled (possibly pushed multiple times): just pop.
      to_visit.pop_back();
    } else if (stack_top_val->definition() == nullptr) {
      // A scalar without def can be a scalar, a tensor dim,
      // or a composite fusion input
      // The first two cases are handled in finalize(),
      // the last case needs to add new input_val to this group.
      visited.insert(stack_top_val);
      // If this is a composite fusion scalar input, make sure this group has it
      if (stack_top_val->isFusionInput() && !input_set.count(stack_top_val)) {
        group->input_vals.push_back(stack_top_val);
        input_set.insert(stack_top_val);
      }
      to_visit.pop_back();
    } else {
      // A scalar with an actual definition
      auto definition_expr = stack_top_val->definition();
      bool all_inputs_visited = true;
      // If any of the inputs are not visited, visit them first
      // (leave the current value on the stack so it is revisited after
      // its inputs — this yields a post-order, i.e. dependency order).
      for (auto input : definition_expr->inputs()) {
        if (!visited.count(input)) {
          all_inputs_visited = false;
          to_visit.push_back(input);
        }
      }
      // This node is ready to be visited
      if (all_inputs_visited) {
        // Collect the defining expr to insert into group
        exprs_to_add.push_back(definition_expr);
        visited.insert(stack_top_val);
        to_visit.pop_back();
      }
    }
  }

  // Add all the defining expr to the group
  // (already topologically ordered by the post-order traversal above).
  for (auto expr : exprs_to_add) {
    group->exprs_.push_back(expr);
  }
}
3022
3023void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) {
3024 std::vector<Val*> to_visit;
3025 std::unordered_set<Val*> visited;
3026
3027 // Collect all inputs to group that are not inputs of fusion
3028 for (auto input : group->inputs()) {
3029 if (!input->isFusionInput()) {
3030 to_visit.push_back(input);
3031 }
3032 }
3033
3034 // Reset group inputs to real inputs
3035 group->input_vals = IterVisitor::getInputsTo(group->inputs());
3036
3037 // Grab all expressions needed to produce to_visit
3038 auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit);
3039
3040 // Insert those expressions at the beginning of the group
3041 group->exprs_.insert(
3042 group->exprs_.begin(), input_exprs.begin(), input_exprs.end());
3043}
3044
3045void SegmentCandidateFinder::removeScalarEdges() {
3046 // Remove all scalar edges between groups
3047 // They may have been created by welford
3048 // translation.
3049 // we will not need them after scalar
3050 // resolution
3051 auto remove_scalar_edges_from_vec = [](std::vector<SegmentedEdge*>& edges) {
3052 edges.erase(
3053 std::remove_if(
3054 edges.begin(),
3055 edges.end(),
3056 [](SegmentedEdge* segmented_edge) {
3057 return segmented_edge->val->isScalar();
3058 }),
3059 edges.end());
3060 };
3061
3062 remove_scalar_edges_from_vec(edges());
3063 for (auto group : groups()) {
3064 remove_scalar_edges_from_vec(group->producer_edges);
3065 remove_scalar_edges_from_vec(group->consumer_edges);
3066 }
3067}
3068
3069void SegmentCandidateFinder::finalize() {
3070 // Remove unconnected groups
3071 groups().erase(
3072 std::remove_if(
3073 groups().begin(),
3074 groups().end(),
3075 [](SegmentedGroup* sg) { return !sg->isConnected(); }),
3076 groups().end());
3077
3078 // Add group labeling
3079 int i = 0;
3080 for (auto it = groups().begin(); it != groups().end(); it++, i++) {
3081 deDuplicateScalarExprs((*it)->exprs_);
3082 (*it)->setID(i);
3083 }
3084
3085 // TODO: too many things are currently abstracted under the term
3086 // finalize. Need to re-structure in a follow up.
3087
3088 // Finalize connections between segmented groups
3089 segmented_fusion_->finalize();
3090
3091 // Resolve all the scalar expressions needed in each group
3092 for (auto group : segmented_fusion_->groups()) {
3093 resolveScalarsInGroup(group);
3094 }
3095
3096 // Resolve all the scalar expressions needed in each group
3097 for (auto group : segmented_fusion_->groups()) {
3098 resolveInputsInGroup(group);
3099 }
3100
3101 // Finalize each group, fill in the missing inputs, i.e. tensor dims.
3102 for (auto g : groups()) {
3103 g->setHeuristic(deriveHeuristic(g));
3104 g->finalize();
3105 }
3106}
3107
3108GroupDependencyAnalysis* SegmentCandidateFinder::getGroupDependency() {
3109 if (!group_dependency_) {
3110 group_dependency_ =
3111 std::make_unique<GroupDependencyAnalysis>(segmented_fusion_.get());
3112 }
3113 return group_dependency_->as<GroupDependencyAnalysis>();
3114}
3115
3116FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion::
3117 makeInitialSchedulerEntry(
3118 SegmentedGroup* sg,
3119 SchedulerRuntimeInfo& runtime_info) {
3120 auto local_fusion = completeFusion();
3121 FusionSegmentGuard fsg(local_fusion, getAllInputs(sg), getAllOutputs(sg));
3122 // This will be the first time each group is scheduled. So we'd want to
3123 // construct the cache data here.
3124 auto data_cache_ptr = std::make_unique<HeuristicSummary>(
3125 local_fusion, sg->heuristic(), runtime_info);
3126 auto data_cache = data_cache_ptr.get();
3127 setCachedHeuristicDataFor(sg, std::move(data_cache_ptr));
3128 return SchedulerEntry::makeEntry(
3129 sg->heuristic(), local_fusion, runtime_info, data_cache);
3130}
3131
3132std::unique_ptr<FusionHeuristics> SegmentedFusion::makeInitialHeuristics(
3133 const KernelArgumentHolder& inputs) {
3134 auto ret = std::make_unique<FusionHeuristics>();
3135 SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true);
3136 for (auto g : groups()) {
3137 ret->emplaceBack(makeInitialSchedulerEntry(g, runtime_info));
3138 }
3139 return ret;
3140}
3141
3142HeuristicSummary* SegmentedFusion::getCachedHeuristicDataFor(
3143 SegmentedGroup* group) {
3144 auto data_it = heuristic_summary_cache_.find(group);
3145 if (data_it == heuristic_summary_cache_.end()) {
3146 return nullptr;
3147 }
3148 return data_it->second.get();
3149}
3150
3151void SegmentedFusion::setCachedHeuristicDataFor(
3152 SegmentedGroup* group,
3153 std::unique_ptr<HeuristicSummary> data) {
3154 TORCH_INTERNAL_ASSERT(!heuristic_summary_cache_.count(group));
3155 heuristic_summary_cache_[group] = std::move(data);
3156}
3157
3158namespace {
3159
3160//! A thin traversal class that collects all the tensorviews
3161//! that could cast to fp16 or bf16 if they were segmented edges.
3162//! The selected values are currently defined as all the
3163//! tensorviews that
3164//! 1. are not complete fusion input/output,
3165//! 2. have a use chain that ends with a fp16
3166//! complete fusion output
3167//! 3. are fp32 datatype
3168class ForceHalfAnnotation : public IterVisitor {
3169 public:
3170 static std::unordered_set<TensorView*> getFP16AnnotatedSet(Fusion* fusion) {
3171 ForceHalfAnnotation annotation;
3172 std::vector<Val*> fp16_outputs;
3173 auto& cast_to_type = annotation.cast_to_type_;
3174 auto other_half_type =
3175 cast_to_type == DataType::Half ? DataType::BFloat16 : DataType::Half;
3176 std::copy_if(
3177 fusion->outputs().begin(),
3178 fusion->outputs().end(),
3179 std::back_inserter(fp16_outputs),
3180 [&cast_to_type, &other_half_type](auto* val) {
3181 auto dtype = val->getDataType().value();
3182 if (cast_to_type) {
3183 TORCH_INTERNAL_ASSERT(
3184 other_half_type != dtype,
3185 "Mix of BFloat16 and Float16 in the same graph is not supported.");
3186 }
3187 return val->template isA<TensorView>() &&
3188 val->getDataType().has_value() &&
3189 (val->getDataType().value() == DataType::Half ||
3190 val->getDataType().value() == DataType::BFloat16);
3191 });
3192
3193 annotation.traverseTo(fusion, fp16_outputs);
3194 return annotation.force_fp16_tv_set_;
3195 }
3196
3197 private:
3198 using IterVisitor::handle;
3199
3200 void handle(TensorView* tv) override {
3201 auto dtype = tv->getDataType();
3202 if (dtype.has_value() && dtype.value() == DataType::Float &&
3203 !tv->isFusionOutput() && !tv->isFusionInput()) {
3204 force_fp16_tv_set_.insert(tv);
3205 }
3206 }
3207
3208 std::unordered_set<TensorView*> force_fp16_tv_set_;
3209 c10::optional<DataType> cast_to_type_ = c10::nullopt;
3210};
3211
3212} // namespace
3213
3214void SegmentedFusion::annotateFP16IntermediateTensors() {
3215 force_fp16_tv_set_ =
3216 ForceHalfAnnotation::getFP16AnnotatedSet(complete_fusion_.get());
3217 for (auto out_tv :
3218 ir_utils::filterByType<TensorView>(complete_fusion_->outputs())) {
3219 if (out_tv) {
3220 auto dtype = out_tv->getDataType().value();
3221 if (dtype == DataType::Half || dtype == DataType::BFloat16) {
3222 force_half_precision_type_ = dtype;
3223 }
3224 }
3225 }
3226}
3227
3228std::string toString(const SegmentCandidateFinderOptions& segment_options) {
3229 std::stringstream ss;
3230 ss << "segmentation phases {\n";
3231 if (segment_options.run_combine_reductions) {
3232 ss << "combine reductions\n";
3233 }
3234 if (segment_options.run_herrmann_merge) {
3235 ss << "herrmann merging\n";
3236 }
3237 if (segment_options.run_final_merge) {
3238 ss << "final merging\n";
3239 }
3240 ss << "\n}\n";
3241 return ss.str();
3242}
3243
3244} // namespace cuda
3245} // namespace fuser
3246} // namespace jit
3247} // namespace torch
3248