1 | #include <arith.h> |
2 | #include <fusion.h> |
3 | #include <fusion_segmenter.h> |
4 | #include <instrumentation.h> |
5 | #include <ir_all_nodes.h> |
6 | #include <ir_cloner.h> |
7 | #include <ir_graphviz.h> |
8 | #include <ir_iostream.h> |
9 | #include <ir_utils.h> |
10 | #include <scheduler/debug_utils.h> |
11 | |
12 | #include <sstream> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | namespace fuser { |
17 | namespace cuda { |
18 | |
19 | namespace { |
20 | |
21 | using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>; |
22 | |
23 | } // namespace |
24 | |
25 | std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() { |
26 | std::vector<NeighborGroup> neighbors; |
27 | for (auto inp : producer_edges) { |
28 | if (inp->val->isFusionOutput()) { |
29 | // Don't fuse across output nodes, would need to find another path. |
30 | continue; |
31 | } |
32 | neighbors.emplace_back(inp->from, inp); |
33 | } |
34 | for (auto out : consumer_edges) { |
35 | if (out->val->isFusionOutput()) { |
36 | // Don't fuse across output nodes, would need to find another path. |
37 | continue; |
38 | } |
39 | neighbors.emplace_back(out->to, out); |
40 | } |
41 | return neighbors; |
42 | } |
43 | |
44 | std::vector<SegmentedGroup*> SegmentedGroup::getNeighbors() { |
45 | std::vector<SegmentedGroup*> neighbors; |
46 | auto neighbors_pair = getNeighborGroups(); |
47 | |
48 | std::transform( |
49 | neighbors_pair.begin(), |
50 | neighbors_pair.end(), |
51 | std::back_inserter(neighbors), |
52 | [](auto& neighbor_group) { return neighbor_group.group; }); |
53 | return neighbors; |
54 | } |
55 | |
std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
    getMergeCandidates() {
  // Returns the neighbor groups this group could be merged with, filtered by
  // the level-based heuristic below. `level_` is a traversal level assigned
  // elsewhere; merges are only considered between groups whose levels differ
  // by exactly the right amount to keep the merge pattern conflict-free.

  // Don't look for candidates if already merged
  if (merged_) {
    return {};
  }

  std::vector<NeighborGroup> neighbors = getNeighborGroups();

  // Can this node be merged with another? Check if neighbors are merged, if
  // so and merged neighbor is within 1 level or node merged with neighbor is
  // within 1 level, can't merge this node with anything else.
  bool can_merge_this = true;
  for (auto& neighbor : neighbors) {
    if (!neighbor.group->merged_) {
      continue;
    }
    // A merged neighbor (or its merge partner) within 1 level blocks this
    // group from participating in any merge this round.
    if (std::abs(neighbor.group->level_ - level_) <= 1) {
      can_merge_this = false;
    }
    if (std::abs(neighbor.group->merge_with_->level_ - level_) <= 1) {
      can_merge_this = false;
    }
  }
  if (!can_merge_this) {
    return {};
  }

  // Per-neighbor eligibility flags, pruned by the checks below.
  std::vector<bool> can_merge(neighbors.size(), true);

  // Find neighbors with a level that is only 1 different than this group's
  // level
  for (const auto i : c10::irange(neighbors.size())) {
    if (std::abs(neighbors[i].group->level_ - level_) > 1) {
      can_merge[i] = false;
    }
  }

  // Check neighbor of neighbors we're considering, if any of them are merged
  // with another node, make sure the resulting edge wouldn't have a level
  // difference of 1
  for (const auto i : c10::irange(neighbors.size())) {
    if (!can_merge[i]) {
      continue;
    }

    for (auto neighbor_neighbor : neighbors[i].group->getNeighbors()) {
      // Don't check self
      if (neighbor_neighbor == neighbors[i].group) {
        continue;
      }
      if (neighbor_neighbor->merged_) {
        // Check neighbor_neighbor's level against both this group and the
        // candidate neighbor.
        if (std::abs(neighbor_neighbor->level_ - level_) <= 1) {
          can_merge[i] = false;
        }
        if (std::abs(neighbor_neighbor->level_ - neighbors[i].group->level_) <=
            1) {
          can_merge[i] = false;
        }

        // Same checks against neighbor_neighbor's merge partner's level.
        if (std::abs(neighbor_neighbor->merge_with_->level_ - level_) <= 1) {
          can_merge[i] = false;
        }
        if (std::abs(
                neighbor_neighbor->merge_with_->level_ -
                neighbors[i].group->level_) <= 1) {
          can_merge[i] = false;
        }
      }
    }
  }

  // Collect the surviving candidates, preserving neighbor order.
  std::vector<NeighborGroup> merge_candidates;
  for (const auto i : c10::irange(neighbors.size())) {
    if (can_merge[i]) {
      merge_candidates.push_back(neighbors[i]);
    }
  }
  return merge_candidates;
}
137 | |
// Reset all per-traversal bookkeeping so the group can participate in a
// fresh level computation / merge pass.
void SegmentedGroup::clearTraversalInfo() {
  level_ = -1;
  visited_ = false;
  merge_with_ = nullptr;
  merge_through_ = nullptr;
  merged_ = false;
}
145 | |
146 | std::vector<Val*> SegmentedGroup::edgesToVals( |
147 | const std::vector<SegmentedEdge*>& se_v) { |
148 | std::vector<Val*> ret_v; |
149 | ret_v.reserve(se_v.size()); |
150 | |
151 | std::transform( |
152 | se_v.cbegin(), |
153 | se_v.cend(), |
154 | std::back_inserter(ret_v), |
155 | [](SegmentedEdge* se) { return se->val; }); |
156 | return ret_v; |
157 | } |
158 | |
159 | template <typename PREDICATE> |
160 | void insertUniquePredicated( |
161 | std::vector<Val*>& v, |
162 | const std::vector<SegmentedEdge*>& e, |
163 | PREDICATE pred) { |
164 | VectorOfUniqueEntries<Val*> to_add; |
165 | for (auto edge : e) { |
166 | to_add.pushBack(edge->val); |
167 | } |
168 | |
169 | std::copy_if( |
170 | to_add.vector().begin(), |
171 | to_add.vector().end(), |
172 | std::back_inserter(v), |
173 | [pred](Val* val) { return pred(val); }); |
174 | } |
175 | |
// Populate this group's explicit input/output value lists from its edges,
// plus scalar leaves and aliased fusion inputs the group depends on.
void SegmentedGroup::finalize() {
  // Promote producer-edge values to group inputs, except values that are
  // already inputs of the complete fusion.
  insertUniquePredicated(
      input_vals, producer_edges, [](Val* v) { return !v->isFusionInput(); });

  std::unordered_set<Val*> input_set(input_vals.begin(), input_vals.end());

  // Integer scalars used by this group's expressions that have no definition,
  // are not constants, and are not fusion inputs must be passed in explicitly.
  for (auto expr : exprs_) {
    for (auto i : expr->inputs()) {
      if (i->isAnInt() && i->definition() == nullptr && !i->isConstScalar() &&
          !i->isFusionInput() && !input_set.count(i)) {
        input_set.insert(i);
        input_vals.push_back(i);
      }
    }
  }

  // Promote consumer-edge values to group outputs, except values that are
  // already outputs of the complete fusion.
  insertUniquePredicated(
      output_vals, consumer_edges, [](Val* v) { return !v->isFusionOutput(); });

  // alias aware segmentation. we add inputs that are aliased by output
  // generated in this SegmentedGroup
  for (auto output : output_vals) {
    if (auto aliased_input = segmented_fusion_->findAlias(output)) {
      // aliasing currently only supported as output to input
      TORCH_INTERNAL_ASSERT(
          aliased_input->isFusionInput(),
          "aliased input is not found in the complete fusion");
      if (!input_set.count(aliased_input)) {
        input_set.insert(aliased_input);
        input_vals.push_back(aliased_input);
      }
    }
  }
}
213 | |
214 | std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group) { |
215 | os << "g{" ; |
216 | auto expr_to_print = group->exprs(); |
217 | std::sort( |
218 | expr_to_print.begin(), |
219 | expr_to_print.end(), |
220 | [](auto expr_a, auto expr_b) -> bool { |
221 | return expr_a->name() < expr_b->name(); |
222 | }); |
223 | for (const auto i : c10::irange(expr_to_print.size())) { |
224 | os << expr_to_print[i]->name(); |
225 | if (i + 1 != expr_to_print.size()) |
226 | os << ", " ; |
227 | } |
228 | os << "}\n" ; |
229 | return os; |
230 | } |
231 | |
// Convenience: dump this group to stdout via operator<<(SegmentedGroup*).
void SegmentedGroup::print() const {
  std::cout << this << "\n";
}
235 | |
236 | std::string toString(const SegmentedGroup* group) { |
237 | std::stringstream ss; |
238 | ss << group; |
239 | return ss.str(); |
240 | } |
241 | |
// Dump of an edge: "e{ <from group> -> <to group>(<edge value>) }".
std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge) {
  os << "e{ " << edge->from << " -> " << edge->to << "(";
  // IrPrinter renders the edge value through the IR printing machinery.
  IrPrinter irp(os);
  irp.handle(edge->val);
  os << ") }\n";
  return os;
}
249 | |
// Convenience: dump this edge to stdout via operator<<(SegmentedEdge*).
void SegmentedEdge::print() const {
  std::cout << this << "\n";
}
253 | |
254 | std::string toString(const SegmentedEdge* edge) { |
255 | std::stringstream ss; |
256 | ss << edge; |
257 | return ss.str(); |
258 | } |
259 | |
// Wrap a complete (unsegmented) fusion as a SegmentedFusion containing a
// single group, so it can reuse the segmented-fusion runtime path.
std::unique_ptr<SegmentedFusion> SegmentedFusion::fromCompleteFusion(
    std::unique_ptr<Fusion> fusion_ptr,
    ScheduleHeuristic heuristic) {
  // Keep a raw handle before ownership moves into the SegmentedFusion.
  auto fusion = fusion_ptr.get();

  auto segmented_fusion_ptr =
      std::make_unique<SegmentedFusion>(std::move(fusion_ptr));

  // Make a group for the single fusion
  auto single_group = segmented_fusion_ptr->newGroup();

  // Add input and output vals
  single_group->input_vals = fusion->inputs();
  single_group->output_vals = fusion->outputs();

  // Get ordered expression list
  single_group->resetExprList();

  // Assign heuristics and id for the complete fusion
  // to share the runtime path of segmented fusion.
  single_group->setHeuristic(heuristic);
  single_group->setID(0);

  return segmented_fusion_ptr;
}
285 | |
// Take ownership of the complete fusion, assign a unique name for this
// segmented fusion, and pre-compute which intermediate tensors should be
// kept in half precision.
SegmentedFusion::SegmentedFusion(std::unique_ptr<Fusion> fusion)
    : impl_(this), complete_fusion_(std::move(fusion)) {
  segmented_fusion_name_ = segmentedFusionName();
  annotateFP16IntermediateTensors();
}
291 | |
// Allocate an empty group owned by this Impl; returns a non-owning pointer.
SegmentedGroup* SegmentedFusion::Impl::makeGroup() {
  groups_.emplace_back(std::make_unique<SegmentedGroup>(owning_fusion_));
  return groups_.back().get();
}
296 | |
// Allocate a group seeded with a single expression; returns a non-owning
// pointer, ownership stays with this Impl.
SegmentedGroup* SegmentedFusion::Impl::makeGroup(Expr* expr) {
  groups_.emplace_back(std::make_unique<SegmentedGroup>(expr, owning_fusion_));
  return groups_.back().get();
}
301 | |
// Allocate an edge carrying `val` from group `from` to group `to`;
// returns a non-owning pointer, ownership stays with this Impl.
SegmentedEdge* SegmentedFusion::Impl::makeEdge(
    SegmentedGroup* from,
    SegmentedGroup* to,
    Val* val) {
  edges_.emplace_back(std::make_unique<SegmentedEdge>(from, to, val));
  return edges_.back().get();
}
309 | |
310 | void SegmentedFusion::Impl::cleanUnused() { |
311 | std::unordered_set<SegmentedGroup*> g_used( |
312 | owning_fusion_->groups().begin(), owning_fusion_->groups().end()); |
313 | std::unordered_set<SegmentedEdge*> e_used( |
314 | owning_fusion_->edges().begin(), owning_fusion_->edges().end()); |
315 | |
316 | groups_.erase( |
317 | std::remove_if( |
318 | groups_.begin(), |
319 | groups_.end(), |
320 | [&g_used](auto& g) { return g_used.count(g.get()) == 0; }), |
321 | groups_.end()); |
322 | |
323 | edges_.erase( |
324 | std::remove_if( |
325 | edges_.begin(), |
326 | edges_.end(), |
327 | [&e_used](auto& e) { return e_used.count(e.get()) == 0; }), |
328 | edges_.end()); |
329 | } |
330 | |
331 | SegmentedGroup* SegmentedFusion::newGroup() { |
332 | SegmentedGroup* g = impl_.makeGroup(); |
333 | groups_.push_back(g); |
334 | return g; |
335 | } |
336 | |
337 | SegmentedGroup* SegmentedFusion::newGroup(Expr* expr) { |
338 | SegmentedGroup* g = impl_.makeGroup(expr); |
339 | groups_.push_back(g); |
340 | return g; |
341 | } |
342 | |
343 | SegmentedEdge* SegmentedFusion::newEdge( |
344 | SegmentedGroup* from, |
345 | SegmentedGroup* to, |
346 | Val* val) { |
347 | SegmentedEdge* e = impl_.makeEdge(from, to, val); |
348 | edges_.push_back(e); |
349 | return e; |
350 | } |
351 | |
// Emit a graphviz .dot file of the complete fusion, coloring each
// tensor-producing expression by the index of the group it belongs to.
void SegmentedFusion::draw() {
  size_t group_index = 0;
  std::unordered_map<const Expr*, size_t> expr_color_map;

  // Only TensorView ops get a color; other exprs are left to the default.
  for (auto group : groups()) {
    for (auto expr : group->exprs()) {
      if (ir_utils::isTvOp(expr)) {
        expr_color_map[expr] = group_index;
      }
    }
    group_index++;
  }

  // File name is unique per segmented fusion, e.g. "segmented_fusion3.dot".
  std::stringstream sstream;
  sstream << "segmented_fusion" << segmented_fusion_name_ << ".dot";
  auto filename = sstream.str();

  IrGraphGenerator::print(
      completeFusion(),
      filename.c_str(),
      IrGraphGenerator::DetailLevel::ComputeOnly,
      &expr_color_map);
}
375 | |
376 | namespace { |
377 | |
378 | std::vector<Val*> uniqueValConcat( |
379 | const std::vector<std::vector<Val*>>& val_vecs) { |
380 | std::vector<Val*> unique_vals; |
381 | std::unordered_set<Val*> added; |
382 | for (const auto& vec : val_vecs) { |
383 | for (auto val : vec) { |
384 | if (added.find(val) == added.end()) { |
385 | unique_vals.push_back(val); |
386 | added.emplace(val); |
387 | } |
388 | } |
389 | } |
390 | return unique_vals; |
391 | } |
392 | |
393 | // Concat's producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2 |
394 | std::vector<SegmentedEdge*> getMergedProducerEdges( |
395 | const SegmentedGroup* sg1, |
396 | const SegmentedGroup* sg2) { |
397 | TORCH_INTERNAL_ASSERT( |
398 | sg1 != nullptr && sg2 != nullptr, |
399 | "This function doesn't handle trivial." ); |
400 | |
401 | auto producer_edges = sg1->producer_edges; |
402 | |
403 | producer_edges.insert( |
404 | producer_edges.end(), |
405 | sg2->producer_edges.begin(), |
406 | sg2->producer_edges.end()); |
407 | |
408 | // Register producers into sg2 |
409 | std::unordered_set<Val*> sg2_vals; |
410 | for (auto se : sg2->producer_edges) { |
411 | sg2_vals.emplace(se->val); |
412 | } |
413 | |
414 | producer_edges.erase( |
415 | std::remove_if( |
416 | producer_edges.begin(), |
417 | producer_edges.end(), |
418 | [&sg1, &sg2, &sg2_vals](SegmentedEdge* se) { |
419 | // remove edges in between the groups and common uses |
420 | return (se->to == sg1 && se->from == sg2) || |
421 | (se->to == sg2 && se->from == sg1) || |
422 | (se->to == sg1 && sg2_vals.count(se->val)); |
423 | }), |
424 | producer_edges.end()); |
425 | |
426 | // Remove Duplicate Edges |
427 | |
428 | return producer_edges; |
429 | } |
430 | |
431 | // Concat's consumer edges of sg1 and sg2, but removes any edges from/to sg1/sg2 |
// Concat's consumer edges of sg1 and sg2, but removes any edges from/to
// sg1/sg2, i.e. the consumer edges the merged group (sg1+sg2) would have.
std::vector<SegmentedEdge*> getMergedConsumerEdges(
    const SegmentedGroup* sg1,
    const SegmentedGroup* sg2) {
  TORCH_INTERNAL_ASSERT(
      sg1 != nullptr && sg2 != nullptr,
      "This function doesn't handle trivial.");

  auto consumer_edges = sg1->consumer_edges;
  consumer_edges.insert(
      consumer_edges.end(),
      sg2->consumer_edges.begin(),
      sg2->consumer_edges.end());

  // Drop edges that run between the two groups — they become internal to
  // the merged group.
  consumer_edges.erase(
      std::remove_if(
          consumer_edges.begin(),
          consumer_edges.end(),
          [&sg1, &sg2](SegmentedEdge* se) {
            return (se->to == sg1 && se->from == sg2) ||
                (se->to == sg2 && se->from == sg1);
          }),
      consumer_edges.end());

  return consumer_edges;
}
457 | |
458 | // Returns a determinstic, unique set of inputs of the segment group, sg1, or |
459 | // the combined group sg1 + sg2 |
// Returns a deterministic, unique set of inputs of the segment group, sg1,
// or the combined group sg1 + sg2. Either argument may be null; the non-null
// group's producer edges are used on their own in that case.
std::vector<Val*> getAllInputs(
    const SegmentedGroup* sg1,
    const SegmentedGroup* sg2 = nullptr) {
  std::vector<SegmentedEdge*> merged_producer_edges;

  if (sg1 != nullptr && sg2 != nullptr) {
    merged_producer_edges = getMergedProducerEdges(sg1, sg2);
  } else if (sg1 != nullptr) {
    merged_producer_edges = sg1->producer_edges;
  } else if (sg2 != nullptr) {
    merged_producer_edges = sg2->producer_edges;
  }

  // Values flowing in over producer edges.
  std::vector<Val*> producer_edge_vals;

  std::transform(
      merged_producer_edges.begin(),
      merged_producer_edges.end(),
      std::back_inserter(producer_edge_vals),
      [](SegmentedEdge* se) { return se->val; });

  // Global fusion inputs of both groups first, then edge values, with
  // duplicates removed in first-seen order.
  return uniqueValConcat(
      {sg1 == nullptr ? std::vector<Val*>() : sg1->input_vals,
       sg2 == nullptr ? std::vector<Val*>() : sg2->input_vals,
       producer_edge_vals});
}
486 | |
487 | // Returns a determinstic, unique set of outputs of the segment group, sg1, or |
488 | // the combined group sg1 + sg2 |
// Returns a deterministic, unique set of outputs of the segment group, sg1,
// or the combined group sg1 + sg2. Either argument may be null; the non-null
// group's consumer edges are used on their own in that case.
std::vector<Val*> getAllOutputs(
    const SegmentedGroup* sg1,
    const SegmentedGroup* sg2 = nullptr) {
  std::vector<SegmentedEdge*> merged_consumer_edges;

  if (sg1 != nullptr && sg2 != nullptr) {
    merged_consumer_edges = getMergedConsumerEdges(sg1, sg2);
  } else if (sg1 != nullptr) {
    merged_consumer_edges = sg1->consumer_edges;
  } else if (sg2 != nullptr) {
    merged_consumer_edges = sg2->consumer_edges;
  }

  // Values flowing out over consumer edges.
  std::vector<Val*> consumer_edge_vals;

  std::transform(
      merged_consumer_edges.begin(),
      merged_consumer_edges.end(),
      std::back_inserter(consumer_edge_vals),
      [](SegmentedEdge* se) { return se->val; });

  // Global fusion outputs of both groups first, then edge values, with
  // duplicates removed in first-seen order.
  auto output_vals = uniqueValConcat(
      {sg1 == nullptr ? std::vector<Val*>() : sg1->output_vals,
       sg2 == nullptr ? std::vector<Val*>() : sg2->output_vals,
       consumer_edge_vals});

  return output_vals;
}
517 | |
518 | // Set version of getting merged input or output if segmented_groups were |
519 | // merged |
520 | // outputs respects order in segmented_groups for deterministic |
521 | // merge trace |
522 | // will get input if get_inputs otherwise will get ouputs |
523 | // TODO: merge with the binary counter parts |
// Set version of getting merged input or output if segmented_groups were
// merged
// outputs respects order in segmented_groups for deterministic
// merge trace
// will get input if get_inputs otherwise will get outputs
// TODO: merge with the binary counter parts
std::vector<Val*> allInputsIfTrueElseOutputs(
    const std::vector<SegmentedGroup*>& segmented_groups,
    bool get_inputs = true) {
  // Helper to distinguish if we are getting inputs or outputs
  using EdgeVec = std::vector<SegmentedEdge*>;
  using ValVec = std::vector<Val*>;

  // Get producer edges to get inputs, consumer edges to get outputs
  auto edges_to_process_from_or_to_group =
      [get_inputs](SegmentedGroup* group) -> EdgeVec& {
    return get_inputs ? group->producer_edges : group->consumer_edges;
  };

  // Get the group-level input or output values of the given group
  auto global_vals_from_or_to_group =
      [get_inputs](SegmentedGroup* group) -> ValVec& {
    return get_inputs ? group->input_vals : group->output_vals;
  };

  // Get the group that is connected to current group by given edge
  auto opposite_end_of_edge = [get_inputs](SegmentedEdge* edge) {
    return get_inputs ? edge->from : edge->to;
  };

  // Keep track of value and order to ensure deterministic result
  std::vector<Val*> merged_vals;
  std::unordered_set<Val*> merged_vals_set;

  // Put groups in a set for quick look up
  std::unordered_set<SegmentedGroup*> segmented_groups_set(
      segmented_groups.begin(), segmented_groups.end());

  // Collect vals associated with edges
  for (auto group : segmented_groups) {
    for (auto edge : edges_to_process_from_or_to_group(group)) {
      if (
          // Need to de-duplicate values so we don't get multiple of any input
          !merged_vals_set.count(edge->val) &&
          // One side of this edge will be `group`, if the other end is
          // also in segmented_groups, then this is an internal edge
          // that we don't want.
          !segmented_groups_set.count(opposite_end_of_edge(edge))) {
        merged_vals.push_back(edge->val);
        merged_vals_set.insert(edge->val);
      }
    }
  }

  // Collect original fusion's inputs/outputs and append at the end
  for (auto group : segmented_groups) {
    for (auto global_val : global_vals_from_or_to_group(group)) {
      // de-duplicate
      if (!merged_vals_set.count(global_val)) {
        merged_vals.push_back(global_val);
        merged_vals_set.insert(global_val);
      }
    }
  }

  return merged_vals;
}
585 | |
586 | // A sorting utility used for debug printing only |
587 | // sorts the given vector of expressions in topological |
588 | // order, with equal cases respecting the original order |
589 | // in the vector. |
// A sorting utility used for debug printing only
// sorts the given vector of expressions in topological
// order, with equal cases respecting the original order
// in the vector. O(n^2) worst case, acceptable for debug output.
std::vector<Expr*> groupExprPrintSorting(const std::vector<Expr*>& exprs) {
  std::vector<Expr*> exprs_to_print(exprs.begin(), exprs.end());
  std::unordered_set<Expr*> exprs_to_print_set(exprs.begin(), exprs.end());
  std::unordered_set<Expr*> exprs_visited;
  std::vector<Expr*> sorted_list;
  // Repeatedly scan for the first not-yet-visited expression whose in-group
  // producers have all been visited, until every expression is placed.
  while (!std::all_of(
      exprs_to_print.begin(),
      exprs_to_print.end(),
      [&exprs_visited](auto expr) { return exprs_visited.count(expr); })) {
    bool expr_added_to_sorted_list = false;
    for (auto expr : exprs_to_print) {
      if (!exprs_visited.count(expr)) {
        bool add_this_expr = true;
        // Check if any of the inputs of current
        // expression within the group
        // hasn't been visited
        for (auto input : expr->inputs()) {
          if (input->definition() &&
              exprs_to_print_set.count(input->definition()) &&
              !exprs_visited.count(input->definition())) {
            add_this_expr = false;
            break;
          }
        }

        // Append the current expression to sorted list
        // and mark visited
        if (add_this_expr) {
          expr_added_to_sorted_list = true;
          exprs_visited.insert(expr);
          sorted_list.push_back(expr);
          break;
        }
      }
    }
    // A full scan with no progress means a dependency cycle.
    TORCH_INTERNAL_ASSERT(
        expr_added_to_sorted_list,
        "group debug print failed, exprs within given vector not a DAG");
  }
  return sorted_list;
}
631 | |
632 | // Utility function to list all expressions in a group |
// Utility function to list all expressions in a group: heuristic tag,
// sorted inputs/outputs with data types, then a topologically sorted
// printout of the group's expressions.
void detailGroupPrint(std::ostream& os, const SegmentedGroup* group) {
  IrPrinter irp(os);

  // Sort values by name for a deterministic listing.
  auto sort_val_by_name = [](std::vector<Val*> vals_to_sort) {
    std::sort(vals_to_sort.begin(), vals_to_sort.end(), [](Val* a, Val* b) {
      return a->name() < b->name();
    });
    return vals_to_sort;
  };

  os << "g{"
     << "(" << toString(group->heuristic()) << ")\n";
  os << "inputs: \n";
  for (auto input : sort_val_by_name(getAllInputs(group))) {
    os << input << " " << input->getDataType().value() << "\n";
  }
  os << "outputs: \n";
  for (auto output : sort_val_by_name(getAllOutputs(group))) {
    os << output << " " << output->getDataType().value() << "\n";
  }

  os << "\n\n";

  // Topologically ordered for readability (debug-only sorting).
  auto expr_to_print = groupExprPrintSorting(group->exprs());

  for (const auto i : c10::irange(expr_to_print.size())) {
    irp.handle(expr_to_print[i]);
  }
  os << "}\n\n";
}
663 | |
664 | //! Insert casts for an intermediate tensorview, i.e. ones |
665 | //! that are in segmentedEdges. The insertion is done on |
666 | //! the complete fusion, which should be owned by a segmented |
667 | //! fusion so that only one segmented fusion will be affected. |
668 | //! The replacement pattern is: |
669 | //! TV0 |
670 | //! replaced as: |
671 | //! fp16_tv = cast(TV0) |
672 | //! fp32_tv = cast(fp16_tv) |
673 | //! |
674 | //! All segmented groups that take TV0 as input will then |
675 | //! take fp16_tv or bf16_tv instead and the cast to fp32 will be |
676 | //! automatically included in each of the groups. |
677 | TensorView* castIntermediateValueInCompleteFusion( |
678 | Fusion* fusion, |
679 | TensorView* original_tv, |
680 | std::unordered_set<Expr*> edge_from_group_uses, |
681 | DataType dtype) { |
682 | FusionGuard fg(fusion); |
683 | |
684 | // A utility lambda that creates consumer tensordomain of |
685 | // the given tv and create a new tensorview around the |
686 | // new tensordomain with the given data type. |
687 | auto make_consumer_tv = [&](TensorView* from, DataType data_type) { |
688 | // Keep broadcast axes and remove reduction axes |
689 | size_t i = 0; |
690 | auto no_reduction_root_domain = |
691 | TensorDomain::noReductions(original_tv->getMaybeRFactorDomain()); |
692 | std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size()); |
693 | for (const auto& dom : no_reduction_root_domain) { |
694 | new_root_domain[i++] = dom->cloneWithoutRFactor(); |
695 | } |
696 | |
697 | // Create the actual domain and tv. |
698 | return IrBuilder::create<TensorView>( |
699 | IrBuilder::create<TensorDomain>( |
700 | new_root_domain, std::vector<bool>(new_root_domain.size(), true)), |
701 | data_type); |
702 | }; |
703 | |
704 | // create the tv's to cast |
705 | auto half_precision_tv = make_consumer_tv(original_tv, dtype); |
706 | |
707 | auto fp32_tv = make_consumer_tv(original_tv, DataType::Float); |
708 | |
709 | // replace uses of original tv with fp32_tv in the complete |
710 | // fusion |
711 | for (auto expr : fusion->unordered_uses(original_tv)) { |
712 | // Don't modify internal uses of buffers, only cast for outputs. |
713 | if (edge_from_group_uses.find(expr) == edge_from_group_uses.end()) { |
714 | ir_utils::replaceValInExpr(expr, original_tv, fp32_tv); |
715 | } |
716 | } |
717 | |
718 | // Insert the cast ops. |
719 | IrBuilder::create<UnaryOp>(UnaryOpType::Cast, half_precision_tv, original_tv); |
720 | IrBuilder::create<UnaryOp>(UnaryOpType::Cast, fp32_tv, half_precision_tv); |
721 | |
722 | // Return the new tv to replace original tv with |
723 | // on the segmented edges. |
724 | return half_precision_tv; |
725 | } |
726 | } // namespace |
727 | |
void SegmentedFusion::finalize() {
  // Drop groups/edges no longer referenced after segmentation.
  impl_.cleanUnused();
  // Insert casts for the tensorviews that are on
  // segmented edges and also on the force_to_fp16 list
  //
  // Note:
  //  The cast is inserted after the segmenter canSchedule check, which
  //  shouldn't cause problem short-term. The reason we put the cast here
  //  is we don't want to keep making copies of the original fusion
  //  during segmentation. Could consider making the cast insertion
  //  reversible if we do have to test canSchedule with the casts inserted
  //  during segmentation process in the future.

  // Keep track of groups that need to update expr list,
  //  including both the producer and consumer of the selected tv's that
  //  we cast to fp16.
  std::unordered_set<SegmentedGroup*> affected_group_set;
  // A map to keep track of the tv's that have been inserted cast
  //  and its fp16 version.
  std::unordered_map<TensorView*, TensorView*> fp32_to_half_cast_map;

  // Go through all edges of the segmented fusion.
  for (auto edge : edges()) {
    TORCH_INTERNAL_ASSERT(edge->val->isA<TensorView>());
    auto edge_tv = edge->val->as<TensorView>();

    // Uses of the edge value within the from group should not be replaced.
    // This will cause the group to have an intermediate tensor
    //   tv -> float2half -> output
    //     \  -> half2float -> other uses in group
    // The conversion back and forth from half precision can hurt numerics.
    // Collect expressions that use the edge value of concern within the from
    // group to avoid replacing with the cast tensor.
    std::unordered_set<Expr*> uses_in_from_group;

    // All expressions in the from group of the edge
    std::unordered_set<Expr*> from_group_exprs(
        edge->from->exprs().begin(), edge->from->exprs().end());

    // All uses of the edge val
    for (auto edge_val_use_expr : edge_tv->uses()) {
      if (from_group_exprs.count(edge_val_use_expr)) {
        // Find uses in the from group of the val
        uses_in_from_group.emplace(edge_val_use_expr);
      }
    }

    // Only look at ones that need to cast to fp16 or bf16
    if ((force_fp16_tv_set_.count(edge_tv) > 0)) {
      auto cast_tv_it = fp32_to_half_cast_map.find(edge->val->as<TensorView>());
      TensorView* cast_tv = nullptr;
      // Insert cast ops for this tv if we haven't done so.
      if (cast_tv_it == fp32_to_half_cast_map.end()) {
        cast_tv = castIntermediateValueInCompleteFusion(
            complete_fusion_.get(),
            edge_tv,
            uses_in_from_group,
            force_half_precision_type_);
        fp32_to_half_cast_map[edge->val->as<TensorView>()] = cast_tv;
      } else {
        cast_tv = cast_tv_it->second;
      }

      // Update the edge to use the fp16 version
      edge->val = cast_tv;

      // Mark the groups for update later
      affected_group_set.insert(edge->from);
      affected_group_set.insert(edge->to);

      // The expr pointers on the group's expr list might have been freed
      //  by now after `ir_utils::replaceValInExpr`.
      // Need a valid expression list to continue. Update from and to group.
      edge->from->resetExprList();
      edge->to->resetExprList();
    }
  }
}
806 | |
807 | //! An utility class to compute and maintain the "producers of" |
808 | //! relationship in a segmented graph. Space heavy and should |
809 | //! avoid use on very large graphs. |
810 | //! |
811 | //! Currently trying to move as far as possible with only a |
812 | //! producer map, without transposing it to make a consumer map. |
813 | //! Making it NonCopyable because we should never need to |
814 | //! copy an instance of this class. |
815 | //! TODO: Space efficiency of this class will be important, |
816 | //! because we need it in the pre-merging of segmentedGroups, |
817 | //! currently O(n^2). O(nlogn) would be a reasonable |
818 | //! goal to achieve. |
819 | class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis { |
820 | using GroupSetOwningPtr = std::unique_ptr<GroupSet>; |
821 | using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>; |
822 | |
823 | public: |
824 | //! Populate producers of all groups in segmented fusion |
825 | explicit GroupDependencyAnalysis(const SegmentedFusion* segmented_fusion) |
826 | : segmented_fusion_(segmented_fusion) { |
827 | computeAllProducers(); |
828 | } |
829 | |
830 | //! Checks if group is consumer of any group in groups_to_check |
831 | //! TODO: refactor this similar to isConsumerOf |
832 | bool isConsumerOfAny( |
833 | SegmentedGroup* group, |
834 | const std::vector<SegmentedGroup*>& groups_to_check) { |
835 | auto& producers_of_group = getAllKnownProducersSet(group); |
836 | for (const auto& potential_producer : groups_to_check) { |
837 | if (producers_of_group->has(potential_producer)) { |
838 | return true; |
839 | } |
840 | } |
841 | return false; |
842 | } |
843 | |
844 | bool isConsumerOf(SegmentedGroup* a, SegmentedGroup* b) { |
845 | auto it = known_producers_of_.find(a); |
846 | if (it == known_producers_of_.end()) { |
847 | return false; |
848 | } |
849 | return it->second->has(b); |
850 | } |
851 | |
852 | bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) { |
853 | return isConsumerOf(b, a); |
854 | } |
855 | |
856 | //! Finds the common producers of given set of groups |
857 | GroupSet getCommonProducersOf(std::vector<SegmentedGroup*> groups); |
858 | |
859 | //! Update the map when the given two groups have been merged to create `ab` |
860 | //! this method is for book keeping and query only, doesn't implicitly check |
861 | //! for DAG |
862 | void mergeGroups(SegmentedGroup* a, SegmentedGroup* b, SegmentedGroup* ab); |
863 | |
864 | //! Update the map when the given two groups have been merged to create |
865 | //! `merged` this method is for book keeping and query only, doesn't |
866 | //! implicitly check |
867 | //! for DAG |
868 | void mergeGroups(const GroupSet& groups, SegmentedGroup* merged); |
869 | |
870 | //! Populate all values that is on a path from producer to consumer |
871 | //! efficiency can be important here. (TODO) |
872 | GroupSet valuesBetween(SegmentedGroup* producer, SegmentedGroup* consumer) { |
873 | if (producer == consumer) { |
874 | return {}; |
875 | } |
876 | |
877 | GroupSet values_between; |
878 | auto& all_producers_of_consumer = known_producers_of_.at(consumer); |
879 | TORCH_INTERNAL_ASSERT( |
880 | all_producers_of_consumer->has(producer), |
881 | "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs" ); |
882 | |
883 | for (auto producer_of_consumer : *all_producers_of_consumer) { |
884 | if (known_producers_of_.at(producer_of_consumer)->has(producer)) { |
885 | values_between.pushBack(producer_of_consumer); |
886 | } |
887 | } |
888 | |
889 | return values_between; |
890 | } |
891 | |
892 | //! Checks if the segmented fusion this class tracks is still a DAG |
893 | //! used for generating assertions after transforms |
894 | bool isproducerMapDAG() const { |
895 | for (auto& it : known_producers_of_) { |
896 | if (it.second->has(it.first)) { |
897 | return false; |
898 | } |
899 | } |
900 | return true; |
901 | } |
902 | |
903 | private: |
904 | //! Collect initial producer info using |
905 | //! a work list algorithm through forward traversal |
906 | //! a backward DFS would do the same |
907 | void computeAllProducers(); |
908 | |
909 | //! Add all consumers of `producer` to `to_visit` |
910 | void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) { |
911 | for (auto e : producer->consumer_edges) { |
912 | // A consumer wouldn't have been worked before any of its producer |
913 | to_visit.pushBack(e->to); |
914 | } |
915 | } |
916 | |
917 | //! Propagate all known producers of `from` into `into`, used to keep track |
918 | //! of: |
919 | //! 1. `from` is a producer of `into` |
920 | //! 2. `from` has been merged with other group to create `into` |
921 | void mergeAllKnownProducersIntoFrom( |
922 | SegmentedGroup* into, |
923 | SegmentedGroup* from) { |
924 | auto& producer_set_to_merge = *getAllKnownProducersSet(from); |
925 | for (auto group : producer_set_to_merge) { |
926 | getAllKnownProducersSet(into)->pushBack(group); |
927 | } |
928 | } |
929 | |
930 | //! Utility to access known producers of a group so far |
931 | GroupSetOwningPtr& (SegmentedGroup* group) { |
932 | auto& producer_set_ptr = known_producers_of_[group]; |
933 | if (!producer_set_ptr) { |
934 | producer_set_ptr = std::make_unique<GroupSet>(); |
935 | } |
936 | return producer_set_ptr; |
937 | } |
938 | |
939 | // utility to compute the set intersection of group sets a,b |
940 | GroupSet groupSetIntersection(const GroupSet& a, const GroupSet& b) { |
941 | bool a_is_smaller = a.size() < b.size(); |
942 | const auto& smaller_group_set = a_is_smaller ? a : b; |
943 | const auto& bigger_group_set = a_is_smaller ? b : a; |
944 | |
945 | GroupSet intersection; |
946 | for (auto group : smaller_group_set) { |
947 | if (bigger_group_set.has(group)) { |
948 | intersection.pushBack(group); |
949 | } |
950 | } |
951 | return intersection; |
952 | } |
953 | |
954 | private: |
955 | const SegmentedFusion* segmented_fusion_; |
956 | DependencyMap known_producers_of_; |
957 | }; |
958 | |
959 | //! Finds the common producers of given set of groups |
960 | GroupSet GroupDependencyAnalysis::getCommonProducersOf( |
961 | std::vector<SegmentedGroup*> groups) { |
962 | if (groups.empty()) { |
963 | return {}; |
964 | } |
965 | |
966 | // Optimization: start with the smallest producer set |
967 | std::sort( |
968 | groups.begin(), |
969 | groups.end(), |
970 | [this](SegmentedGroup* a, SegmentedGroup* b) { |
971 | return known_producers_of_.at(a)->size() < |
972 | known_producers_of_.at(b)->size(); |
973 | }); |
974 | |
975 | // Get intersection of producers |
976 | GroupSet common_producers = *(known_producers_of_.at(groups[0])); |
977 | for (const auto i : c10::irange(1, groups.size())) { |
978 | common_producers = groupSetIntersection( |
979 | common_producers, *(known_producers_of_.at(groups[i]))); |
980 | } |
981 | |
982 | return common_producers; |
983 | } |
984 | |
//! Update the map when the given two groups have been merged to create `ab`
//! this method is for book keeping and query only, doesn't implicitly check
//! for DAG
void GroupDependencyAnalysis::mergeGroups(
    SegmentedGroup* a,
    SegmentedGroup* b,
    SegmentedGroup* ab) {
  // Access/Create the producer set of ab
  auto& ab_set = getAllKnownProducersSet(ab);

  // propagate a's and b's known producers into ab
  mergeAllKnownProducersIntoFrom(ab, a);
  mergeAllKnownProducersIntoFrom(ab, b);

  // a, b are now merged, so no longer exist
  // (they may have been producers of each other, so erase both from
  //  ab's freshly populated set)
  ab_set->erase(a);
  ab_set->erase(b);

  // a, b no longer exist, remove their producer sets
  known_producers_of_.erase(a);
  known_producers_of_.erase(b);

  // update producer maps of other groups
  // note: erasing a and b above first means this loop never visits them
  for (auto& it : known_producers_of_) {
    // for all groups that are produced by either a or b
    if (it.second->has(a) || it.second->has(b)) {
      // insert ab as the new producer
      it.second->pushBack(ab);
      // all producers of both a and b are now producers of `it`
      mergeAllKnownProducersIntoFrom(it.first, ab);
    }
    // a, b no longer exist, remove them from `it`
    it.second->erase(a);
    it.second->erase(b);
  }
}
1021 | |
//! Update the map when the given two groups have been merged to create
//! `merged` this method is for book keeping and query only, doesn't
//! implicitly check
//! for DAG
void GroupDependencyAnalysis::mergeGroups(
    const GroupSet& groups,
    SegmentedGroup* merged) {
  // Access/Create the producer set of merged
  auto& merged_set = getAllKnownProducersSet(merged);

  // Populate all producers of groups and
  // write into producer map of merged
  std::for_each(
      groups.begin(), groups.end(), [this, merged](SegmentedGroup* group) {
        mergeAllKnownProducersIntoFrom(merged, group);
      });

  // Erase all groups that was merged from producer map
  std::for_each(
      groups.begin(), groups.end(), [this, &merged_set](SegmentedGroup* group) {
        // erase inter dependencies
        merged_set->erase(group);
        // erase producer map tracking merged entries
        known_producers_of_.erase(group);
      });

  // Update producer relationships with other groups in producer map
  // (merged members were erased above, so this loop only visits
  //  surviving groups)
  for (auto& it : known_producers_of_) {
    auto producer_intersection = groupSetIntersection(*(it.second), groups);
    // if current node has any producer that was merged
    if (producer_intersection.size() > 0) {
      for (auto merged_producer : producer_intersection) {
        // delete all disappearing producers
        it.second->erase(merged_producer);
      }
      // insert the new group as producer
      it.second->pushBack(merged);
    }
  }
}
1062 | |
//! Collect initial producer info using
//! a work list algorithm through forward traversal
//! a backward DFS would do the same
void GroupDependencyAnalysis::computeAllProducers() {
  GroupSet visited;
  GroupSet to_visit;

  // Collect source nodes, with no producers we are guaranteed
  // a source node on a DAG
  for (auto group : segmented_fusion_->cgroups()) {
    if (group->producer_edges.empty()) {
      visited.pushBack(group);
    }
  }

  // visited now only contain source nodes
  // they can go backward to nowhere
  for (auto group : visited) {
    addConsumersToWorkList(group, to_visit);
  }

  // Each iteration picks one group from the work list whose producers
  // have all been visited, finalizes its producer set, and enqueues
  // its consumers. If no such group exists the graph has a cycle.
  while (!to_visit.empty()) {
    SegmentedGroup* to_update = nullptr;
    for (auto visiting_group : to_visit) {
      if (std::all_of(
              visiting_group->producer_edges.begin(),
              visiting_group->producer_edges.end(),
              [&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
        // filter multi-edges
        // (two edges may connect the same pair of groups; collapse
        //  them into a unique set of direct producers)
        GroupSet producers_of_visiting_group;
        for (auto edge : visiting_group->producer_edges) {
          producers_of_visiting_group.pushBack(edge->from);
        }

        // populate all possible paths
        // from producer backward, including
        // the producer
        for (auto producer : producers_of_visiting_group) {
          getAllKnownProducersSet(visiting_group)->pushBack(producer);
          mergeAllKnownProducersIntoFrom(visiting_group, producer);
        }
        to_update = visiting_group;
        break;
      }
    }
    if (to_update) {
      addConsumersToWorkList(to_update, to_visit);
      to_visit.erase(to_update);
      visited.pushBack(to_update);
    } else {
      TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG" );
    }
  }
}
1117 | |
//! Prints a segmented fusion: groups and edges in topological
//! (producer-before-consumer) order, followed by per-group details.
std::ostream& operator<<(
    std::ostream& os,
    const SegmentedFusion* segmented_fusion) {
  // Topologically sort groups
  GroupDependencyAnalysis dependency(segmented_fusion);
  std::vector<SegmentedGroup*> groups_to_print(
      segmented_fusion->cgroups().begin(), segmented_fusion->cgroups().end());
  std::vector<SegmentedGroup*> sorted_groups_to_print;

  // Sort groups topologically from producer to consumer before printing
  // Selection-style: each pass scans the remaining groups, repeatedly
  // replacing the current candidate with any group that produces it,
  // ending at a group with no unprinted producers.
  while (!groups_to_print.empty()) {
    auto group_it_to_append = groups_to_print.begin();
    for (auto group_it_to_compare = groups_to_print.begin();
         group_it_to_compare != groups_to_print.end();
         group_it_to_compare++) {
      if (dependency.isProducerOf(*group_it_to_compare, *group_it_to_append)) {
        group_it_to_append = group_it_to_compare;
      }
    }
    sorted_groups_to_print.push_back(*group_it_to_append);
    groups_to_print.erase(group_it_to_append);
  }

  // Do a reverse look up to check the order of sorted groups
  std::unordered_map<SegmentedGroup*, size_t> group_order;
  for (const auto i : c10::irange(sorted_groups_to_print.size())) {
    group_order[sorted_groups_to_print[i]] = i;
  }

  // Sort edges to print
  // (ordered by the topological position of their source group)
  std::vector<SegmentedEdge*> sorted_edges_to_print(
      segmented_fusion->cedges().begin(), segmented_fusion->cedges().end());
  std::sort(
      sorted_edges_to_print.begin(),
      sorted_edges_to_print.end(),
      [&group_order](SegmentedEdge* edge_a, SegmentedEdge* edge_b) {
        return group_order.at(edge_a->from) < group_order.at(edge_b->from);
      });

  os << "Segmented_Fusion Dump: -- fusion segments:\n";
  os << "Segmented_Fusion{ \n";
  os << "groups: \n";
  for (const auto g : sorted_groups_to_print) {
    os << g << "\n";
  }
  os << "edges: \n";
  for (const auto e : sorted_edges_to_print) {
    os << e << "\n";
  }
  os << "\ngroup details:\n";
  for (const auto g : sorted_groups_to_print) {
    detailGroupPrint(os, g);
  }
  os << "} //Segmented_Fusion\n";
  return os;
}
1174 | |
1175 | void SegmentedFusion::print() const { |
1176 | std::cout << "Segmented_Fusion Dump: -- Re-written complete fusion:{\n" ; |
1177 | completeFusion()->printMath(); |
1178 | std::cout << "} // {Re-written complete fusion}\n" ; |
1179 | std::cout << this << "\n" ; |
1180 | } |
1181 | |
1182 | std::string toString(SegmentedFusion* segmented_fusion) { |
1183 | std::stringstream ss; |
1184 | ss << segmented_fusion; |
1185 | return ss.str(); |
1186 | } |
1187 | |
//! Builds a standalone Fusion for the given segmented group: clones the
//! complete fusion, strips its original I/O, and installs only this
//! group's inputs/outputs as the new fusion boundary.
std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
  std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();

  // Clone the whole fusion; the returned cloner maps complete-fusion
  // vals to their clones in the segment copy.
  auto complete_to_segment_map =
      Fusion::copy(completeFusion(), fusion_segment.get());

  // Remove all inherited inputs (snapshot first: removeInput mutates
  // the list being iterated otherwise).
  std::vector<Val*> input_list(
      fusion_segment->inputs().begin(), fusion_segment->inputs().end());
  for (auto inp : input_list) {
    fusion_segment->removeInput(inp);
  }

  // Remove all inherited outputs, same snapshot pattern.
  std::vector<Val*> output_list(
      fusion_segment->outputs().begin(), fusion_segment->outputs().end());
  for (auto out : output_list) {
    fusion_segment->removeOutput(out);
  }

  // Register the group's inputs on the segment fusion. Inputs defined
  // by a view op are collected so their rfactor domains can be
  // converted after the boundary is fully set up.
  std::vector<TensorView*> view_tvs;
  for (auto inp : getAllInputs(sg)) {
    auto clone_tv = complete_to_segment_map.clone(inp);
    fusion_segment->addInput(clone_tv);
    if (inp->isDefinitionType(ExprType::ViewOp)) {
      TORCH_INTERNAL_ASSERT(clone_tv != nullptr && clone_tv->isA<TensorView>());
      view_tvs.push_back(clone_tv->as<TensorView>());
    }
  }

  // Register the group's outputs on the segment fusion.
  for (auto out : getAllOutputs(sg)) {
    fusion_segment->addOutput(complete_to_segment_map.clone(out));
  }

  // View-produced inputs: their rfactor domain becomes the root domain
  // now that the view's producer is outside this segment.
  for (auto tv : view_tvs) {
    tv->convertRfactorToRootDomain();
  }

  return fusion_segment;
}
1226 | |
1227 | void SegmentCandidateFinder::resetTraversal() { |
1228 | for (auto group : groups()) { |
1229 | // Start traversal at input groups |
1230 | if (group->producer_edges.empty()) { |
1231 | to_visit_.push_back(group); |
1232 | } |
1233 | group->visited_ = false; |
1234 | group->level_ = 0; |
1235 | } |
1236 | } |
1237 | |
//! Recomputes each group's topological level (longest distance from a
//! source group) via a work-list traversal. Groups whose producers are
//! not all visited yet are parked in next_to_visit_ and retried.
void SegmentCandidateFinder::resetLevels() {
  while (!to_visit_.empty()) {
    auto visit = to_visit_.front();
    to_visit_.pop_front();

    // All inputs processed?
    bool ready = true;
    if (!visit->producer_edges.empty()) {
      ready = std::all_of(
          visit->producer_edges.begin(),
          visit->producer_edges.end(),
          [&](SegmentedEdge* dep) { return dep->from->visited_; });
    }

    if (!ready) {
      // In case traversal doesn't complete because there's an error in the
      // DAG topology.
      next_to_visit_.push_back(visit);
      continue;
    }

    visit->visited_ = true;

    // Re-queue parked groups now that one more group is visited
    to_visit_.insert(
        to_visit_.end(), next_to_visit_.begin(), next_to_visit_.end());
    next_to_visit_.clear();

    for (auto out : visit->consumer_edges) {
      to_visit_.push_back(out->to);
    }

    // level = 1 + max level over direct producers (0 for sources)
    visit->level_ = 0;
    for (auto inp : visit->producer_edges) {
      visit->level_ = std::max(visit->level_, inp->from->level_ + 1);
    }
  }
  // Anything still parked means some producer was never visited,
  // i.e. the graph contains a cycle.
  TORCH_INTERNAL_ASSERT(
      next_to_visit_.empty(), "Error in graph, is not a DAG." );
}
1277 | |
1278 | // Disconect group from neighbors, and return edges that were disconnected |
1279 | std::unordered_set<SegmentedEdge*> SegmentCandidateFinder::disconnectGroup( |
1280 | SegmentedGroup* group) { |
1281 | std::unordered_set<SegmentedEdge*> removed_edges( |
1282 | group->producer_edges.begin(), group->producer_edges.end()); |
1283 | |
1284 | for (auto edge : group->producer_edges) { |
1285 | auto from = edge->from; |
1286 | auto& from_edges = from->consumer_edges; |
1287 | auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge); |
1288 | TORCH_INTERNAL_ASSERT( |
1289 | from_edge_it != from_edges.end(), "Could not find edge to remove." ); |
1290 | from_edges.erase(from_edge_it); |
1291 | } |
1292 | |
1293 | for (auto edge : group->consumer_edges) { |
1294 | removed_edges.insert(edge); |
1295 | auto to = edge->to; |
1296 | auto& to_edges = to->producer_edges; |
1297 | auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge); |
1298 | TORCH_INTERNAL_ASSERT( |
1299 | to_edge_it != to_edges.end(), "Could not find edge to remove." ); |
1300 | to_edges.erase(to_edge_it); |
1301 | } |
1302 | |
1303 | group->producer_edges.clear(); |
1304 | group->consumer_edges.clear(); |
1305 | |
1306 | return removed_edges; |
1307 | } |
1308 | |
1309 | void SegmentCandidateFinder::eraseGroups( |
1310 | std::unordered_set<SegmentedGroup*>& groups_to_erase) { |
1311 | std::unordered_set<SegmentedEdge*> edges_to_erase; |
1312 | for (auto group : groups_to_erase) { |
1313 | auto disconnected_edges = disconnectGroup(group); |
1314 | edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end()); |
1315 | } |
1316 | |
1317 | edges().erase( |
1318 | std::remove_if( |
1319 | edges().begin(), |
1320 | edges().end(), |
1321 | [&edges_to_erase](SegmentedEdge* edge) { |
1322 | if (edges_to_erase.find(edge) != edges_to_erase.end()) { |
1323 | return true; |
1324 | }; |
1325 | return false; |
1326 | }), |
1327 | edges().end()); |
1328 | |
1329 | groups().erase( |
1330 | std::remove_if( |
1331 | groups().begin(), |
1332 | groups().end(), |
1333 | [&groups_to_erase](SegmentedGroup* group) { |
1334 | if (groups_to_erase.find(group) != groups_to_erase.end()) { |
1335 | return true; |
1336 | }; |
1337 | return false; |
1338 | }), |
1339 | groups().end()); |
1340 | } |
1341 | |
//! Merges each pending pair in to_merge_ into a new joined group,
//! rewires edges, and cleans up the replaced groups/edges.
//! Returns the last group created (nullptr if to_merge_ was empty).
SegmentedGroup* SegmentCandidateFinder::mergeNodes() {
  SegmentedGroup* last_merged = nullptr;
  auto it = to_merge_.begin();
  // to_merge_ holds flat (group1, group2) pairs
  TORCH_INTERNAL_ASSERT(to_merge_.size() % 2 == 0);
  while (it != to_merge_.end()) {
    auto group1 = *it++;
    auto group2 = *it++;

    clean_up_groups_.emplace(group1);
    clean_up_groups_.emplace(group2);

    // Make the new joined node
    auto joined_group = segmented_fusion_->newGroup();

    joined_group->input_vals =
        uniqueValConcat({group1->input_vals, group2->input_vals});

    joined_group->output_vals =
        uniqueValConcat({group1->output_vals, group2->output_vals});

    joined_group->exprs_ = group1->exprs_;
    joined_group->exprs_.insert(
        joined_group->exprs_.end(),
        group2->exprs_.begin(),
        group2->exprs_.end());

    // External producer edges of the pair (internal edges excluded)
    auto producer_edges = getMergedProducerEdges(group1, group2);
    // Connect joined group to resulting neighbors
    for (auto edge : producer_edges) {
      auto from = edge->from;
      auto val = edge->val;

      auto new_edge = segmented_fusion_->newEdge(from, joined_group, val);
      joined_group->producer_edges.push_back(new_edge);
      from->consumer_edges.push_back(new_edge);
    }

    auto consumer_edges = getMergedConsumerEdges(group1, group2);

    for (auto edge : consumer_edges) {
      auto to = edge->to;
      auto val = edge->val;

      auto new_edge = segmented_fusion_->newEdge(joined_group, to, val);
      joined_group->consumer_edges.push_back(new_edge);
      edge->to->producer_edges.push_back(new_edge);
    }

    joined_group->setHeuristic(deriveHeuristic(joined_group));
    // Need to maintain the group dependency data if it has been initialized
    // by previous merging
    if (group_dependency_) {
      group_dependency_->as<GroupDependencyAnalysis>()->mergeGroups(
          group1, group2, joined_group);
    }
    last_merged = joined_group;
  }

  to_merge_.clear();
  // Disconnect the now-replaced groups; their old edges (including the
  // internal ones between each pair) go to clean_up_edges_.
  for (auto group : clean_up_groups_) {
    auto disconnected_edges = disconnectGroup(group);
    clean_up_edges_.insert(
        disconnected_edges.begin(), disconnected_edges.end());
  }

  edges().erase(
      std::remove_if(
          edges().begin(),
          edges().end(),
          [this](SegmentedEdge* edge) {
            if (this->clean_up_edges_.find(edge) !=
                this->clean_up_edges_.end()) {
              return true;
            };
            return false;
          }),
      edges().end());

  groups().erase(
      std::remove_if(
          groups().begin(),
          groups().end(),
          [this](SegmentedGroup* group) {
            if (this->clean_up_groups_.find(group) !=
                this->clean_up_groups_.end()) {
              return true;
            };
            return false;
          }),
      groups().end());

  clean_up_edges_.clear();
  clean_up_groups_.clear();

  return last_merged;
}
1438 | |
// Logic largely parallels mergeNodes, but they are used
// in different phases of segmentation. Should consider
// a clean up and share the implementations.
//! Merges all the given groups into a single new group, de-duplicating
//! inputs, edge vals, and exprs; internal edges between the merged
//! groups disappear. Returns the new group.
SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups(
    const std::vector<SegmentedGroup*>& groups_to_merge) {
  TORCH_INTERNAL_ASSERT(
      !groups_to_merge.empty(),
      "fusion segment :(mergeAllGivenGroups) tried to merge no groups" )

  // Make a set to detect internal edges
  std::unordered_set<SegmentedGroup*> group_set(
      groups_to_merge.begin(), groups_to_merge.end());

  // Sets to de-duplicate multiple uses of
  // input/edge values and re-computations of exprs
  std::unordered_set<Val*> used_edge_vals_set;
  std::unordered_set<Val*> used_input_vals_set;
  std::unordered_set<Expr*> exprs_set;

  // Create new group
  auto joined_group = segmented_fusion_->newGroup();

  // Populate edges, exprs, global vals
  // from each of the groups
  for (auto group : groups_to_merge) {
    // Populate complete fusion inputs to the group
    for (auto input_val : group->input_vals) {
      if (!used_input_vals_set.count(input_val)) {
        used_input_vals_set.insert(input_val);
        joined_group->input_vals.push_back(input_val);
      }
    }

    // Populate complete fusion outputs from the group
    // (no de-dup here; outputs are assumed unique per group)
    for (auto output_val : group->output_vals) {
      joined_group->output_vals.push_back(output_val);
    }

    // Populate producer edges to the group
    for (auto edge : group->producer_edges) {
      if (
          // Check this is not internal edge
          !group_set.count(edge->from) &&
          // Check this val has been added or not
          !used_edge_vals_set.count(edge->val)) {
        used_edge_vals_set.insert(edge->val);
        auto new_producer_edge =
            segmented_fusion_->newEdge(edge->from, joined_group, edge->val);
        joined_group->producer_edges.push_back(new_producer_edge);
        edge->from->consumer_edges.push_back(new_producer_edge);
      }
    }

    // Populate consumer edges from the group
    for (auto edge : group->consumer_edges) {
      if (
          // Check this is not internal edge
          !group_set.count(edge->to)) {
        auto new_consumer_edge =
            segmented_fusion_->newEdge(joined_group, edge->to, edge->val);
        joined_group->consumer_edges.push_back(new_consumer_edge);
        edge->to->producer_edges.push_back(new_consumer_edge);
      }
    }

    // Populate exprs
    for (auto expr : group->exprs_) {
      if (!exprs_set.count(expr)) {
        joined_group->exprs_.push_back(expr);
        exprs_set.insert(expr);
      }
    }
  }

  // Clean up original groups from segmented fusion
  for (auto group : groups_to_merge) {
    auto disconnected_edges = disconnectGroup(group);
    clean_up_edges_.insert(
        disconnected_edges.begin(), disconnected_edges.end());
  }

  edges().erase(
      std::remove_if(
          edges().begin(),
          edges().end(),
          [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }),
      edges().end());

  groups().erase(
      std::remove_if(
          groups().begin(),
          groups().end(),
          [&group_set](SegmentedGroup* group) -> bool {
            return group_set.count(group);
          }),
      groups().end());

  clean_up_edges_.clear();

  joined_group->setHeuristic(deriveHeuristic(joined_group));
  return joined_group;
}
1541 | namespace { |
1542 | |
1543 | // Guard to temporarily change the inputs and outputs of a fusion. On |
1544 | // destruction will return fusion to original state. |
1545 | // Not used temporarily but will be useful when adding more mergin heuristics |
1546 | class FusionSegmentGuard : public NonCopyable { |
1547 | public: |
1548 | FusionSegmentGuard() = delete; |
1549 | |
1550 | FusionSegmentGuard( |
1551 | Fusion* fusion, |
1552 | std::vector<Val*> inputs, |
1553 | std::vector<Val*> outputs) |
1554 | : fusion_(fusion), |
1555 | old_inputs_(fusion->inputs()), |
1556 | old_outputs_(fusion->outputs()), |
1557 | new_inputs_(std::move(inputs)), |
1558 | new_outputs_(std::move(outputs)) { |
1559 | FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard" ); |
1560 | TORCH_INTERNAL_ASSERT(fusion_ != nullptr); |
1561 | for (auto old_inp : old_inputs_) { |
1562 | fusion_->removeInput(old_inp); |
1563 | } |
1564 | |
1565 | for (auto old_out : old_outputs_) { |
1566 | fusion_->removeOutput(old_out); |
1567 | } |
1568 | |
1569 | for (auto new_inp : new_inputs_) { |
1570 | fusion_->addInput(new_inp); |
1571 | } |
1572 | |
1573 | for (auto new_out : new_outputs_) { |
1574 | fusion_->addOutput(new_out); |
1575 | } |
1576 | } |
1577 | |
1578 | ~FusionSegmentGuard() { |
1579 | FUSER_PERF_SCOPE("~Segmenter::FusionSegmentGuard" ); |
1580 | |
1581 | if (fusion_ == nullptr) { |
1582 | return; |
1583 | } |
1584 | for (auto new_inp : new_inputs_) { |
1585 | fusion_->removeInput(new_inp); |
1586 | } |
1587 | |
1588 | for (auto new_out : new_outputs_) { |
1589 | fusion_->removeOutput(new_out); |
1590 | } |
1591 | |
1592 | for (auto old_inp : old_inputs_) { |
1593 | fusion_->addInput(old_inp); |
1594 | } |
1595 | |
1596 | for (auto old_out : old_outputs_) { |
1597 | fusion_->addOutput(old_out); |
1598 | } |
1599 | } |
1600 | |
1601 | private: |
1602 | Fusion* const fusion_ = nullptr; |
1603 | const std::vector<Val*> old_inputs_; |
1604 | const std::vector<Val*> old_outputs_; |
1605 | const std::vector<Val*> new_inputs_; |
1606 | const std::vector<Val*> new_outputs_; |
1607 | }; |
1608 | |
1609 | c10::optional<ScheduleHeuristic> tryMerge( |
1610 | Fusion* fusion, |
1611 | SchedulerRuntimeInfo& runtime_info, |
1612 | SegmentedGroup* a, |
1613 | SegmentedGroup* b = nullptr) { |
1614 | FusionSegmentGuard fsg(fusion, getAllInputs(a, b), getAllOutputs(a, b)); |
1615 | |
1616 | scheduler_debug_utils::canScheduleMessage( |
1617 | "\n**Segmenter** Considering fusion:\n" , fusion); |
1618 | return SchedulerEntry::proposeHeuristics(fusion, runtime_info); |
1619 | } |
1620 | |
1621 | c10::optional<ScheduleHeuristic> tryMerge( |
1622 | Fusion* fusion, |
1623 | SchedulerRuntimeInfo& runtime_info, |
1624 | const std::vector<SegmentedGroup*>& segmented_groups) { |
1625 | FusionSegmentGuard fsg( |
1626 | fusion, |
1627 | allInputsIfTrueElseOutputs(segmented_groups, true), |
1628 | allInputsIfTrueElseOutputs(segmented_groups, false)); |
1629 | scheduler_debug_utils::canScheduleMessage( |
1630 | "\n**Segmenter** Considering fusion:\n" , fusion); |
1631 | return SchedulerEntry::proposeHeuristics(fusion, runtime_info); |
1632 | } |
1633 | |
1634 | // This function is for cleanup and |
1635 | // easier debugging. It shouldn't affect functionality |
1636 | // since segmented fusions are compiled with fusion |
1637 | // guard on the edges instead of actually looking |
1638 | // at the exprs. |
1639 | void deDuplicateScalarExprs(std::vector<Expr*>& exprs) { |
1640 | // Exprs in SegmentedGroup are not ordered |
1641 | // so it is ok to insert them from unordered |
1642 | // set |
1643 | std::unordered_set<Expr*> scalar_expr_set; |
1644 | |
1645 | std::copy_if( |
1646 | exprs.begin(), |
1647 | exprs.end(), |
1648 | std::inserter(scalar_expr_set, scalar_expr_set.end()), |
1649 | [](Expr* expr) { return ir_utils::isScalarOp(expr); }); |
1650 | |
1651 | if (!scalar_expr_set.empty()) { |
1652 | exprs.erase( |
1653 | std::remove_if( |
1654 | exprs.begin(), |
1655 | exprs.end(), |
1656 | [&scalar_expr_set](Expr* expr) { |
1657 | return scalar_expr_set.count(expr); |
1658 | }), |
1659 | exprs.end()); |
1660 | exprs.insert(exprs.end(), scalar_expr_set.begin(), scalar_expr_set.end()); |
1661 | } |
1662 | } |
1663 | |
1664 | } // namespace |
1665 | |
1666 | c10::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup:: |
1667 | getMaybeSchedulerEntry(SchedulerRuntimeInfo& runtime_info) { |
1668 | FUSER_PERF_SCOPE("SegmentedGroup::getMaybeSchedulerEntry" ); |
1669 | auto fusion = segmented_fusion_->completeFusion(); |
1670 | auto data_cache = segmented_fusion_->getCachedHeuristicDataFor(this); |
1671 | FusionSegmentGuard fsg(fusion, getAllInputs(this), getAllOutputs(this)); |
1672 | if (!SchedulerEntry::canSchedule( |
1673 | heuristic(), fusion, runtime_info, data_cache)) { |
1674 | return c10::nullopt; |
1675 | } |
1676 | return SchedulerEntry::makeEntry( |
1677 | heuristic(), fusion, runtime_info, data_cache); |
1678 | } |
1679 | |
1680 | void SegmentedGroup::resetExprList() { |
1681 | auto input_group_vec = getAllInputs(this); |
1682 | std::unordered_set<Val*> input_group_set( |
1683 | input_group_vec.begin(), input_group_vec.end()); |
1684 | auto expr_set = |
1685 | DependencyCheck::getAllExprsBetween(input_group_set, getAllOutputs(this)); |
1686 | exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end()); |
1687 | } |
1688 | |
1689 | // Custom merge node passes: |
1690 | // These passes are added at the beginning or the end of |
1691 | // the node merging process to direct the heuristics of |
1692 | // node merging process |
1693 | // |
1694 | // Should consider generalization and make a proper interface |
1695 | // if we have more merge node heuristics like this |
1696 | |
//! Translate Welford
//!
//! This pass can be inserted at any stages of segmentation,
//! and it tries to replace welford ops with persistent
//! mean and var ops.
//!
//! The checking of feasibility of persistent kernels
//! is through normalization schedulers. The general idea
//! is to first try to translate on a copy, and see if
//! normalization scheduler is willing to produce a
//! persistent kernel.
//!
//! For complete fusion this pass checks if all the
//! welford ops can be translated simultaneously to
//! produce a persistent normalization kernel and
//! will perform translation if checks pass.
//!
//! For segmented fusion, same check is performed within
//! each segmented group to collect applicable welford ops,
//! and actual translations are performed on the complete
//! fusion after all the checks are done.
class TranslateApplicableWelford {
 public:
  //! Try translation on each segmented group of
  //! given segmented fusion
  //! returns true if any welford has been translated
  static bool run(
      SegmentedFusion* segmented_fusion,
      const KernelArgumentHolder& runtime_inputs) {
    // All the work happens in the constructor; this wrapper only
    // exposes whether anything changed.
    TranslateApplicableWelford translate_welford(
        segmented_fusion, runtime_inputs);
    return translate_welford.translated_any_welford_;
  }

  //! Try translation on complete fusion,
  //! returns true if any welford has been translated
  static bool run(Fusion* fusion, const KernelArgumentHolder& runtime_inputs) {
    TranslateApplicableWelford translate_welford(fusion, runtime_inputs);
    return translate_welford.translated_any_welford_;
  }

 private:
  // Segmented-fusion entry point: checks/translates per group.
  explicit TranslateApplicableWelford(
      SegmentedFusion* segmented_fusion,
      const KernelArgumentHolder& runtime_inputs);

  // Complete-fusion entry point: checks/translates all welfords at once.
  explicit TranslateApplicableWelford(
      Fusion* fusion,
      const KernelArgumentHolder& runtime_inputs);

  //! Given vector of welford ops from the same fusion,
  //! checks if translating all of them result in a
  //! persistent normalization kernel by try-runs on
  //! a test copy of the original fusion.
  //!
  //! Supported use cases are either un-segmented fusion,
  //! or all the given welfords are within the same
  //! segmented group. In the latter case, the segmented
  //! group containing all the welford ops needs to be
  //! provided.
  bool wouldTranslateToPersistent(
      const std::vector<WelfordOp*>& orignal_welfords,
      SegmentedGroup* group = nullptr);

  //! Translate the given welford op into separate
  //! average and standard deviation calculation.
  void translateSingleWelford(WelfordOp* welford);

  //! Utility to test if a translated fusion
  //! gives a persistent kernel. Uses normalization
  //! scheduler to do the test.
  bool isValidPersistentFusion(
      Fusion* translated_fusion,
      SchedulerRuntimeInfo& runtime_info);

 private:
  //! Indicates any translation happened.
  bool translated_any_welford_ = false;

  //! a reference to global fusion runtime inputs
  const KernelArgumentHolder& runtime_inputs_;

  //! For translation within group only,
  //! group boundary at test copy
  //! (see wouldTranslateToPersistent implementation )
  std::vector<Val*> test_group_inputs_;
  std::vector<Val*> test_group_outputs_;
};
1785 | |
1786 | TranslateApplicableWelford::TranslateApplicableWelford( |
1787 | Fusion* fusion, |
1788 | const KernelArgumentHolder& runtime_inputs) |
1789 | : runtime_inputs_(runtime_inputs) { |
1790 | auto exprs = fusion->exprs(); |
1791 | std::vector<WelfordOp*> orignal_welfords( |
1792 | ir_utils::filterByType<WelfordOp>(exprs).begin(), |
1793 | ir_utils::filterByType<WelfordOp>(exprs).end()); |
1794 | |
1795 | if (wouldTranslateToPersistent(orignal_welfords)) { |
1796 | for (auto welford : orignal_welfords) { |
1797 | translateSingleWelford(welford); |
1798 | } |
1799 | translated_any_welford_ = true; |
1800 | } |
1801 | } |
1802 | |
1803 | TranslateApplicableWelford::TranslateApplicableWelford( |
1804 | SegmentedFusion* segmented_fusion, |
1805 | const KernelArgumentHolder& runtime_inputs) |
1806 | : runtime_inputs_(runtime_inputs) { |
1807 | std::vector<SegmentedGroup*> translated_groups; |
1808 | std::vector<WelfordOp*> welford_to_translate; |
1809 | // Find welfords that can be translated in each group |
1810 | for (auto group : segmented_fusion->groups()) { |
1811 | std::vector<WelfordOp*> welford_in_group( |
1812 | ir_utils::filterByType<WelfordOp>(group->exprs()).begin(), |
1813 | ir_utils::filterByType<WelfordOp>(group->exprs()).end()); |
1814 | |
1815 | if (wouldTranslateToPersistent(welford_in_group, group)) { |
1816 | translated_groups.push_back(group); |
1817 | welford_to_translate.insert( |
1818 | welford_to_translate.end(), |
1819 | welford_in_group.begin(), |
1820 | welford_in_group.end()); |
1821 | } |
1822 | } |
1823 | |
1824 | // Actually translate the welford ops |
1825 | // and record all the vals that have been |
1826 | // replaced by the translation. |
1827 | for (auto welford : welford_to_translate) { |
1828 | translateSingleWelford(welford); |
1829 | } |
1830 | |
1831 | for (auto translated_group : translated_groups) { |
1832 | // Update heuristics and expr list of translated groups |
1833 | translated_group->heuristic_ = ScheduleHeuristic::Persistent; |
1834 | translated_group->resetExprList(); |
1835 | } |
1836 | } |
1837 | |
1838 | bool TranslateApplicableWelford::isValidPersistentFusion( |
1839 | Fusion* translated_fusion, |
1840 | SchedulerRuntimeInfo& runtime_info) { |
1841 | if (!SchedulerEntry::canSchedule( |
1842 | ScheduleHeuristic::Persistent, translated_fusion, runtime_info)) { |
1843 | return false; |
1844 | } |
1845 | |
1846 | auto scheduler = SchedulerEntry::makeEntry( |
1847 | ScheduleHeuristic::Persistent, translated_fusion, runtime_info); |
1848 | |
1849 | return scheduler->reductionParams().persistent_kernel; |
1850 | } |
1851 | |
bool TranslateApplicableWelford::wouldTranslateToPersistent(
    const std::vector<WelfordOp*>& orignal_welfords,
    SegmentedGroup* group) {
  // Nothing to translate.
  if (orignal_welfords.empty()) {
    return false;
  }

  // Make sure all welford ops come from the same complete fusion
  auto fusion = orignal_welfords[0]->fusion();
  TORCH_INTERNAL_ASSERT(
      std::all_of(
          orignal_welfords.begin(),
          orignal_welfords.end(),
          [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }),
      "Welfords in given vector not in the same fusion" );

  // Make initial `in-progress copy`. The whole check runs on this copy so
  // the original fusion is never mutated by a failed attempt.
  auto test_copy = std::make_unique<Fusion>();
  auto original_to_test_map = Fusion::copy(fusion, test_copy.get());

  // Map the original welford ops to their clones inside the test copy.
  std::vector<WelfordOp*> copied_welfords;
  std::transform(
      orignal_welfords.begin(),
      orignal_welfords.end(),
      std::back_inserter(copied_welfords),
      [&original_to_test_map](auto welford) {
        return original_to_test_map.clone(welford);
      });
  // Copied welfords will be invalidated on translation, but Vals will be
  // reused, keep a reference to them.
  std::vector<Val*> welford_avgs;
  std::vector<Val*> welford_vars;
  for (auto welford : copied_welfords) {
    welford_avgs.push_back(welford->outAvg());
    welford_vars.push_back(welford->outVar());
  }

  // Translate the welford ops on the test copy only.
  for (auto welford_to_translate : copied_welfords) {
    translateSingleWelford(welford_to_translate);
  }

  SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_, true);
  // If we are looking at a segment of fusion,
  // we maintain the segmented group boundary,
  // one set for in_progress copy and one set
  // for `test copy`
  if (group != nullptr) {
    auto original_inputs = getAllInputs(group);
    auto original_outputs = getAllOutputs(group);
    test_group_inputs_.clear();
    test_group_outputs_.clear();
    // Clone the group's boundary vals into the test copy so the guard below
    // can temporarily treat the group as a standalone fusion.
    std::transform(
        original_inputs.begin(),
        original_inputs.end(),
        std::back_inserter(test_group_inputs_),
        [&original_to_test_map](Val* in) {
          return original_to_test_map.clone(in);
        });
    std::transform(
        original_outputs.begin(),
        original_outputs.end(),
        std::back_inserter(test_group_outputs_),
        [&original_to_test_map](Val* out) {
          return original_to_test_map.clone(out);
        });

    // If only average is used from welford, we should still translate, but we
    // might not detect persistence if variance isn't actually used/marked as an
    // output in the test. Any unused avg/var val is therefore added as an
    // extra output of the localized test fusion.
    for (auto outs_i : c10::irange(welford_avgs.size())) {
      auto avg = welford_avgs[outs_i];
      auto var = welford_vars[outs_i];
      if (avg->uses().empty()) {
        test_group_outputs_.push_back(avg);
      }

      if (var->uses().empty()) {
        test_group_outputs_.push_back(var);
      }
    }

    // Temporarily localize test copy around
    // the group boundary (RAII guard restores the
    // original inputs/outputs on scope exit)
    FusionSegmentGuard fsg(
        test_copy.get(), test_group_inputs_, test_group_outputs_);

    // Test if the translated copy is persistent
    return isValidPersistentFusion(test_copy.get(), runtime_info);
  }
  // In the case where we work on un-segmented
  // fusion, no group boundary logic, just
  // translate and test.
  return isValidPersistentFusion(test_copy.get(), runtime_info);
}
1947 | |
void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) {
  auto fusion = welford->fusion();
  FusionGuard fg(fusion);
  // Only support translation of welford ops that
  // doesn't take inputs that are already statistics,
  // i.e. an r-factor product.
  // This translation works on un-scheduled fusions so
  // shouldn't expect to see this.
  TORCH_INTERNAL_ASSERT(welford->inN()->isOneInt());

  // Grab the inputs and outputs of the welford. The output Vals are kept
  // and re-targeted by the replacement expressions created below.
  auto in_val = welford->in()->as<TensorView>();
  auto out_avg = welford->outAvg()->as<TensorView>();
  auto out_var = welford->outVar()->as<TensorView>();
  auto out_N = welford->outN()->as<TensorView>();

  fusion->removeExpr(welford);
  // Not safe to use welford anymore
  welford = nullptr;

  // Create normalization based welford graph
  // largely taken from batchnorm cpp benchmark
  const auto& in_root =
      TensorDomain::noReductions(in_val->getMaybeRFactorDomain());
  const auto& out_root = out_avg->getRootDomain();
  std::vector<int> red_axes;

  TORCH_INTERNAL_ASSERT(
      in_root.size() == out_root.size(),
      "Invalid root domains of Welford input and output." ,
      " Input: " ,
      ir_utils::toString(in_root),
      ". Output: " ,
      ir_utils::toString(out_root));

  // Create scalar version of the feature element
  // counting: the product of extents over all reduced axes.
  Val* num_features = IrBuilder::create<Double>(1);
  std::vector<bool> broadcast_mask(in_root.size(), false);
  for (const auto i : c10::irange(in_root.size())) {
    if (out_root.at(i)->isReduction()) {
      red_axes.push_back(i);
      broadcast_mask[i] = true;
      num_features = mul(num_features, out_root.at(i)->extent());
    }
  }

  // Build a normalization expression group that is
  // equivalent to a welford operation.
  auto x_sum = sum(in_val, red_axes);
  // out_avg = x_sum / num_features, writing into the original avg output.
  IrBuilder::create<BinaryOp>(BinaryOpType::Div, out_avg, x_sum, num_features);
  // welford.avg may be broadcast. Reuse it if found.
  TensorView* x_avg_bcast = nullptr;
  for (auto& use_expr : out_avg->uses()) {
    if (auto bcast = dynamic_cast<BroadcastOp*>(use_expr)) {
      if (bcast->getBroadcastDimFlags() == broadcast_mask) {
        // Same broadcast found.
        x_avg_bcast = bcast->out()->as<TensorView>();
        break;
      }
    }
  }

  // x_mean_sub may already exist. Reuse it if found.
  TensorView* x_mean_sub = nullptr;
  if (x_avg_bcast != nullptr) {
    for (auto& use_expr : x_avg_bcast->uses()) {
      if (auto bop = dynamic_cast<BinaryOp*>(use_expr)) {
        if (bop->getBinaryOpType() == BinaryOpType::Sub) {
          if (bop->lhs() == in_val && bop->rhs() == x_avg_bcast) {
            x_mean_sub = bop->out()->as<TensorView>();
          }
        }
      }
    }
  }

  if (x_avg_bcast == nullptr) {
    x_avg_bcast = broadcast(out_avg, broadcast_mask);
  }

  if (x_mean_sub == nullptr) {
    x_mean_sub = sub(in_val, x_avg_bcast);
  }

  // out_var = sum((x - mean)^2) over the reduced axes, writing into the
  // original var output.
  auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub);
  IrBuilder::create<ReductionOp>(
      BinaryOpType::Add,
      IrBuilder::create<Double>(0.0),
      out_var,
      x_mean_sub_pow);
  // out_N now simply holds the feature count.
  IrBuilder::create<UnaryOp>(UnaryOpType::Set, out_N, num_features);

  // out_avg, out_N are now outputs of pointwise ops and we
  // need to clear out their reduction domains.
  out_avg->clearReductionIterDomains();
  out_N->clearReductionIterDomains();
}
2046 | |
2047 | bool SegmentCandidateFinder::TranslateWelfordInFusion( |
2048 | Fusion* fusion, |
2049 | const KernelArgumentHolder& runtime_inputs) { |
2050 | return TranslateApplicableWelford::run(fusion, runtime_inputs); |
2051 | } |
2052 | |
2053 | //! CombineReductions: |
2054 | //! This pass works before the main merge node process |
2055 | //! It identifies reduction operations that can be combined |
2056 | //! together to form a normalization kernel. |
2057 | //! Two reductions are considered the same type if they have |
//! the same root domain length, and the reduction axes are the same.
2059 | //! This pass tries to merge nodes with the same reduction type based |
2060 | //! on the graph structure. |
2061 | class CombineReductions { |
2062 | using GroupVec = std::vector<SegmentedGroup*>; |
2063 | class ReductionSignature; |
2064 | |
2065 | public: |
2066 | static void run(SegmentCandidateFinder* segment_candidate_finder) { |
2067 | CombineReductions combine_reductions(segment_candidate_finder); |
2068 | } |
2069 | static bool shouldRun(SegmentCandidateFinder* segment_candidate_finder); |
2070 | |
2071 | private: |
  CombineReductions(SegmentCandidateFinder* segment_candidate_finder)
      : segment_candidate_finder_(segment_candidate_finder) {
    // Run pass over the segments

    // Collect segmented groups with reductions in them,
    // Assuming running before any merge happened, so
    // should see exactly one non-trivial reduction in each group
    for (auto group : segment_candidate_finder_->groups()) {
      if (auto rop_signature =
              ReductionSignature::makeReductionSignature(group)) {
        // Ignore pure squeeze operations in this analysis
        if (!rop_signature->hasNonTrivialReduction()) {
          continue;
        }

        groups_with_reductions_.push_back(group);
        // Check if this reduction signature is one that we have seen before
        auto signature_match_it = std::find_if(
            known_reduction_signatures_.begin(),
            known_reduction_signatures_.end(),
            [&rop_signature](auto& know_signature) {
              return know_signature->sameAs(rop_signature.get());
            });
        // Unmatched: Create a new signature entry if not known
        if (signature_match_it == known_reduction_signatures_.end()) {
          group_reduction_signature_map_[group] = rop_signature.get();
          known_reduction_signatures_.emplace_back(std::move(rop_signature));
        } else {
          // Matched known signature: Mark that this groups belongs to know
          // signature
          group_reduction_signature_map_[group] = signature_match_it->get();
        }
      }
    }

    // Keep trying to merge groups with compatible reductions and compatible
    // paths
    // until no more merge opportunity can be identified
    bool merged_groups = true;
    while (merged_groups) {
      merged_groups = false;

      // Merge one pair of reduction groups at a time, and need
      // the pass to update dependency info along the way to avoid cycles
      for (const auto first_group_index :
           c10::irange(groups_with_reductions_.size())) {
        if (merged_groups) {
          // Need to break and re-enter this loop because
          // groups_with_reductions_ will be updated
          break;
        }

        // Select one of the group to merge and get its reduction signature
        auto first_group = groups_with_reductions_[first_group_index];
        auto first_group_signature =
            group_reduction_signature_map_.at(first_group);

        // Only consider groups after first_group to avoid checking each
        // unordered pair twice.
        for (const auto second_group_index : c10::irange(
                 first_group_index + 1, groups_with_reductions_.size())) {
          if (merged_groups) {
            // Need to break and re-enter this loop because
            // groups_with_reductions_ will be updated
            break;
          }
          auto second_group = groups_with_reductions_[second_group_index];
          auto second_group_signature =
              group_reduction_signature_map_.at(second_group);

          // Cannot merge if their signatures are not the same
          if (!first_group_signature->sameAs(second_group_signature)) {
            continue;
          }

          // first try a vertical merge
          merged_groups =
              verticalReductionMerge(first_group, second_group) != nullptr;
          if (!merged_groups) {
            // vertical merge didn't happen, try a horizontal merge
            merged_groups =
                horizontalReductionMerge(first_group, second_group) != nullptr;
          }
        }
      }
    }
  }
2157 | |
  //! Merge a vertical pair of producers and consumers,
  //! the resulting group will include all nodes that are
  //! also consumers of producer and producers of consumer,
  //! i.e. values between the given producer-consumer pair.
  //! Can be proven that:
  //!  1. Including all of these nodes will be cycle-free
  //!  2. These nodes are the minimal set of nodes to include if
  //!  for producer-consumer pair to be in the same group cycle-free
  //!
  //! Returns nullptr if such merge cannot be achieved.
  //! Reasons for not merging will include:
  //!  1. Given groups do not form producer-consumer pair
  //!  2. Merge will create cycle on the graph
  //!  3. The merged joined group cannot be scheduled
  SegmentedGroup* verticalReductionMerge(
      SegmentedGroup* first_group,
      SegmentedGroup* second_group) {
    // This is part of ReductionCombine pass, and we should only call this
    // function on a pair of reduction/normalization groups
    TORCH_INTERNAL_ASSERT(
        group_reduction_signature_map_.at(first_group)
            ->sameAs(group_reduction_signature_map_.at(second_group)));
    TORCH_INTERNAL_ASSERT(first_group != second_group);
    // Get the group dependency data from segment finder
    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();

    // Check producer-consumer relationship; the arguments may come in
    // either order.
    SegmentedGroup* producer = nullptr;
    SegmentedGroup* consumer = nullptr;
    if (dependency_analysis->isConsumerOf(first_group, second_group)) {
      producer = second_group;
      consumer = first_group;
    } else if (dependency_analysis->isProducerOf(first_group, second_group)) {
      producer = first_group;
      consumer = second_group;
    } else {
      // Given groups aren't producer-consumer pair, won't merge
      return nullptr;
    }

    // Collect all groups that we need to merge along with the producer and
    // consumer
    auto all_groups_to_merge =
        getValidMinVerticalMergedGroupSet(producer, consumer);

    if (all_groups_to_merge.empty()) {
      // The vertical paths from producer to consumer have in-compatible
      // reductions
      //  so this vertical merge cannot be done.
      return nullptr;
    }

    // TODO: this step would not be deterministic, because valuesBetween isn't
    //       could fix this by a topological order
    std::vector<SegmentedGroup*> all_groups_to_merge_vec(
        all_groups_to_merge.begin(), all_groups_to_merge.end());

    // Final sanity check: the merged group can actually be scheduled
    Fusion* fusion =
        segment_candidate_finder_->segmented_fusion_->completeFusion();
    if (!tryMerge(
            fusion,
            segment_candidate_finder_->runtimeInfo(),
            all_groups_to_merge_vec)) {
      return nullptr;
    }

    // Merge this group
    auto joined_group =
        segment_candidate_finder_->mergeAllGivenGroups(all_groups_to_merge_vec);

    // Update dependency analysis
    dependency_analysis->mergeGroups(all_groups_to_merge, joined_group);

    // Update the reduction groups that are merged: the joined group carries
    // the common signature, and all constituent groups are removed from the
    // candidate list.
    groups_with_reductions_.push_back(joined_group);
    group_reduction_signature_map_[joined_group] =
        group_reduction_signature_map_.at(first_group);
    groups_with_reductions_.erase(
        std::remove_if(
            groups_with_reductions_.begin(),
            groups_with_reductions_.end(),
            [&all_groups_to_merge](SegmentedGroup* group) {
              return all_groups_to_merge.has(group);
            }),
        groups_with_reductions_.end());

    return joined_group;
  }
2247 | |
  //! Horizontal reduction merging:
  //!  merge two horizontal groups with reduction expressions to make a joined
  //!  normalization group. A pair of horizontal groups are ones that are not
  //!  a producer-consumer pair, and share either a common producer or a common
  //!  consumer.
  //!
  //!  TODO: This implementation looks at common producers only, since common
  //!  consumers are not computed easily with current dependency analysis.
  SegmentedGroup* horizontalReductionMerge(
      SegmentedGroup* first_group,
      SegmentedGroup* second_group) {
    // This is part of ReductionCombine pass, and we should only call this
    // function on a pair of
    //  reduction/normalization groups
    TORCH_INTERNAL_ASSERT(
        group_reduction_signature_map_.at(first_group)
            ->sameAs(group_reduction_signature_map_.at(second_group)));
    TORCH_INTERNAL_ASSERT(first_group != second_group);

    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();

    // Check that the two groups are not producer-consumer's
    if (dependency_analysis->isConsumerOf(first_group, second_group) ||
        dependency_analysis->isProducerOf(first_group, second_group)) {
      // This merge pass will not handle producer-consumer pairs
      return nullptr;
    }

    // Get common producers of the two group
    auto common_producers_set =
        dependency_analysis->getCommonProducersOf({first_group, second_group});
    if (common_producers_set.empty()) {
      // The given pair doesn't have a common producer.
      //  Either they have a common consumer, which we don't handle for now,
      //  or maybe the two given groups are not connected.
      return nullptr;
    }

    // We are looking for a very specific patterns here. The cases that this
    //  pattern will not capture are ones that reductions of different
    //  signatures are so interleaved that we cannot find a clear cut as
    //  explained below, without graph rewriting. Some graph re-writing on the
    //  segmented groups level could provide extra merging opportunities for
    //  free, which could be part of next step.
    //
    // The specific pattern we look for contains a common producer P with
    //  immediate consumers C1, C2 such that all paths from C1 to first_group
    //  and all paths from C2 to second_group won't hit a reduction with a
    //  different signature.

    // Topologically sort the common producers and start with the topologically
    //  minimal,
    //  i.e. one that are closest to the two groups. This will cut the search
    //  space. A producer is "minimal" here if no other common producer is
    //  downstream of it.
    std::vector<SegmentedGroup*> common_producers;
    for (auto producer : common_producers_set) {
      if (!std::any_of(
              common_producers_set.begin(),
              common_producers_set.end(),
              [dependency_analysis, producer](SegmentedGroup* group) {
                return dependency_analysis->isProducerOf(producer, group);
              })) {
        common_producers.push_back(producer);
      }
    }

    // Visit the common producers found, starting from topologically minimum,
    // i.e. the ones closer to the groups
    for (auto common_producer : common_producers) {
      // Visit this common producer
      // Use a double loop in case the schedulers like some patterns
      //  better than the other
      for (auto first_consumer_edge : common_producer->consumer_edges) {
        auto producer_of_first_group = first_consumer_edge->to;
        auto to_merge_with_first_group = getValidMinVerticalMergedGroupSet(
            producer_of_first_group, first_group);
        if (to_merge_with_first_group.empty()) {
          // There's no valid merge path from this consumer of common producer,
          //  either due to a conflicting reduction signature, or simply
          //  there's no path to first group
          continue;
        }
        TORCH_INTERNAL_ASSERT(!dependency_analysis->isProducerOf(
            producer_of_first_group, second_group));
        for (auto second_consumer_edge : common_producer->consumer_edges) {
          auto producer_of_second_group = second_consumer_edge->to;
          auto to_merge_with_second_group = getValidMinVerticalMergedGroupSet(
              producer_of_second_group, second_group);
          if (to_merge_with_second_group.empty()) {
            // There's no valid merge path from this consumer of common
            //  producer,
            //  either due to a conflicting reduction signature, or simply
            //  there's no path to second group
            continue;
          }
          TORCH_INTERNAL_ASSERT(!dependency_analysis->isProducerOf(
              producer_of_second_group, first_group));
          // At this point we should have a pair of valid candidates,final check
          //  is to see if the combined group
          //  can be scheduled by schedulers
          // merge the two paths and de-duplicate,
          //  re-using container here with to_merge_with_second_group
          auto& groups_to_merge_set = to_merge_with_second_group;
          groups_to_merge_set.insert(
              to_merge_with_first_group.begin(),
              to_merge_with_first_group.end());
          std::vector<SegmentedGroup*> groups_to_merge_vec(
              groups_to_merge_set.begin(), groups_to_merge_set.end());
          Fusion* fusion =
              segment_candidate_finder_->segmented_fusion_->completeFusion();
          if (tryMerge(
                  fusion,
                  segment_candidate_finder_->runtimeInfo(),
                  groups_to_merge_vec)) {
            // Found a valid horizontal merge, want to proceed with merging here
            auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(
                groups_to_merge_vec);
            dependency_analysis->mergeGroups(groups_to_merge_set, joined_group);

            // Bookkeeping mirrors verticalReductionMerge: the joined group
            //  inherits the common signature, constituents are dropped.
            groups_with_reductions_.push_back(joined_group);
            group_reduction_signature_map_[joined_group] =
                group_reduction_signature_map_.at(first_group);
            groups_with_reductions_.erase(
                std::remove_if(
                    groups_with_reductions_.begin(),
                    groups_with_reductions_.end(),
                    [&groups_to_merge_set](SegmentedGroup* group) {
                      return groups_to_merge_set.has(group);
                    }),
                groups_with_reductions_.end());

            return joined_group;
          }
        }
      }
    }

    // Searched all possibilities and there is no valid horizontal merge pattern
    //  found.
    return nullptr;
  }
2389 | |
  //! This is a utility method that is used in both vertical merging and
  //!  horizontal merging.
  //!  It is used to identify the smallest set of groups to merge vertically
  //!  involving the
  //!  two given nodes.
  //!  Given a pair of nodes this utility distinguishes 3 cases:
  //!   1. if maybe_producer is the same as maybe_consumer, then returns
  //!    {maybe_producer}
  //!   2. if maybe_producer is actually a producer of consumer, returns a set
  //!    containing
  //!     the smallest merged group that would contain producer and consumer
  //!     and would not introduce a cycle. Returns empty set if such group has
  //!     a conflicting reduction signature.
  //!   3. returns empty set if neither conditions above apply.
  GroupSet getValidMinVerticalMergedGroupSet(
      SegmentedGroup* maybe_producer,
      SegmentedGroup* maybe_consumer) {
    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
    if (maybe_consumer == maybe_producer) {
      // maybe producer is the same as maybe_consumer
      return {maybe_consumer};
    } else if (dependency_analysis->isConsumerOf(
                   maybe_consumer, maybe_producer)) {
      // Candidate set: everything between producer and consumer, inclusive.
      auto groups_to_check =
          dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
      groups_to_check.pushBack(maybe_producer);
      groups_to_check.pushBack(maybe_consumer);

      // Check that either no group has a reduction or all groups have the same
      //  reduction signature
      ReductionSignature* reduction_signature = nullptr;

      // Iterate through the minimal group set to see if any conflicts
      for (auto group : groups_to_check) {
        // Check that this group does not involve a output edge contraction
        //  This pass is intended to be a pre-merging pass. Since contracting an
        //  output edge does not generate much saving of global memory access
        //  we want to postpone merging these edges till the very final pass
        for (auto producer_edge_of_group : group->producer_edges) {
          if (groups_to_check.has(producer_edge_of_group->from) &&
              producer_edge_of_group->val->isFusionOutput()) {
            return {};
          }
        }
        for (auto consumer_edge_of_group : group->consumer_edges) {
          if (groups_to_check.has(consumer_edge_of_group->to) &&
              consumer_edge_of_group->val->isFusionOutput()) {
            return {};
          }
        }

        // Check that this group does not have a conflicting reduction
        //  signature. The first signature seen becomes the reference.
        if (group_reduction_signature_map_.count(group)) {
          if (reduction_signature != nullptr) {
            if (!group_reduction_signature_map_.at(group)->sameAs(
                    reduction_signature)) {
              // Found a conflict in reduction signature, cannot do a vertical
              //  merge
              return {};
            }
          } else {
            reduction_signature = group_reduction_signature_map_.at(group);
          }
        }
      }
      return groups_to_check;
    }
    // maybe producer is not a producer of maybe consumer
    return {};
  }
2460 | |
2461 | private: |
2462 | SegmentCandidateFinder* segment_candidate_finder_; |
2463 | |
2464 | // Wrapper class for reduction type |
2465 | // Assuming there wouldn't be too many of them |
2466 | // so won't need to create a hash |
2467 | // TODO: |
2468 | // Want to reconsider this for transpose operations, |
2469 | // need refactoring to handle reduction fusions across a transpose operation |
2470 | class ReductionSignature { |
2471 | public: |
2472 | bool sameAs(const ReductionSignature* reduction_signature) { |
2473 | if (reduction_signature == this) { |
2474 | return true; |
2475 | } |
2476 | |
2477 | if (root_domain_size_ != reduction_signature->root_domain_size_ || |
2478 | has_nontrivial_reduction_ != |
2479 | reduction_signature->has_nontrivial_reduction_ || |
2480 | reduction_axes_.size() != |
2481 | reduction_signature->reduction_axes_.size()) { |
2482 | return false; |
2483 | } |
2484 | |
2485 | for (const auto i : c10::irange(reduction_axes_.size())) { |
2486 | if (reduction_axes_[i] != reduction_signature->reduction_axes_[i]) { |
2487 | return false; |
2488 | } |
2489 | } |
2490 | |
2491 | return true; |
2492 | } |
2493 | |
2494 | bool sameAs(const ReductionSignature& reduction_signature) { |
2495 | return sameAs(&reduction_signature); |
2496 | } |
2497 | |
2498 | bool hasNonTrivialReduction() const { |
2499 | return has_nontrivial_reduction_; |
2500 | } |
2501 | |
2502 | static std::unique_ptr<ReductionSignature> makeReductionSignature( |
2503 | SegmentedGroup* group) { |
2504 | std::unique_ptr<ReductionSignature> signature = nullptr; |
2505 | |
2506 | for (auto expr : group->exprs()) { |
2507 | std::unique_ptr<ReductionSignature> new_signature = nullptr; |
2508 | |
2509 | if (auto rop = dynamic_cast<ReductionOp*>(expr)) { |
2510 | new_signature = std::make_unique<ReductionSignature>(rop); |
2511 | } |
2512 | if (auto wop = dynamic_cast<WelfordOp*>(expr)) { |
2513 | new_signature = std::make_unique<ReductionSignature>(wop); |
2514 | } |
2515 | |
2516 | if (new_signature != nullptr) { |
2517 | TORCH_INTERNAL_ASSERT( |
2518 | signature == nullptr || !signature->has_nontrivial_reduction_ || |
2519 | !new_signature->has_nontrivial_reduction_ || |
2520 | signature->sameAs(new_signature.get()), |
2521 | "Conflicting signature found in this group" ); |
2522 | signature = std::move(new_signature); |
2523 | } |
2524 | } |
2525 | return signature; |
2526 | } |
2527 | |
2528 | template <typename REDUCTION = ReductionOp> |
2529 | ReductionSignature(REDUCTION* rop) { |
2530 | auto out_tv = rop->out()->template as<TensorView>(); |
2531 | has_nontrivial_reduction_ = out_tv->hasReduction(); |
2532 | TORCH_INTERNAL_ASSERT(out_tv != nullptr); |
2533 | auto& root_domain = out_tv->getRootDomain(); |
2534 | root_domain_size_ = root_domain.size(); |
2535 | |
2536 | // Trivial reduction i.e. squeeze is tricky here: |
2537 | // this pass doesn't want to touch any pure squeeze, i.e.: |
2538 | // T0 [R(1), I(i0), I(i1)] |
2539 | // meanwhile, for two reductions having |
2540 | // squeezes, we do require they have squeeze at the |
2541 | // same position so that they can be easily root domain mapped |
2542 | // So T0 and T1 are the same signature, |
2543 | // T0 [R(1), R(i0), I(i1)] |
2544 | // T1 [R(1), R(i0), I(i1)] |
2545 | // but T2 and T3 below are not |
2546 | // T0 [R(1), R(1), R(i0), I(i1)] |
2547 | // T1 [R(1), R(i0), I(i1)] |
2548 | for (const auto i : c10::irange(root_domain_size_)) { |
2549 | if (root_domain[i]->isReduction()) { |
2550 | reduction_axes_.push_back(i); |
2551 | } |
2552 | if (!root_domain[i]->isTrivialReduction()) { |
2553 | has_nontrivial_reduction_ = true; |
2554 | } |
2555 | } |
2556 | } |
2557 | |
2558 | private: |
2559 | size_t root_domain_size_ = 0; |
2560 | std::vector<int> reduction_axes_; |
2561 | bool has_nontrivial_reduction_ = false; |
2562 | }; |
2563 | |
2564 | //! Keeps track of groups with reduction expressions, |
2565 | //! using a vector here to maintain a deterministic ordering |
2566 | GroupVec groups_with_reductions_; |
2567 | |
2568 | //! Maps groups to their corresponding signature type |
2569 | std::unordered_map<SegmentedGroup*, ReductionSignature*> |
2570 | group_reduction_signature_map_; |
2571 | |
2572 | //! Maintains all reduction signatures seen in the segmented fusion |
2573 | std::vector<std::unique_ptr<ReductionSignature>> known_reduction_signatures_; |
2574 | }; |
2575 | |
//! Quick pre-scan: returns true when at least two initial groups carry the
//! same non-trivial reduction signature, i.e. when running this pass could
//! actually combine something.
2577 | bool CombineReductions::shouldRun( |
2578 | SegmentCandidateFinder* segment_candidate_finder) { |
2579 | std::vector<std::unique_ptr<ReductionSignature>> known_reductions; |
2580 | // Iterate over group segments we have before segment candidate finder |
2581 | // tries to merge any groups |
2582 | for (auto group : segment_candidate_finder->groups()) { |
2583 | if (auto reduction_signature = |
2584 | ReductionSignature::makeReductionSignature(group)) { |
2585 | if (reduction_signature->hasNonTrivialReduction() && |
2586 | std::any_of( |
2587 | known_reductions.begin(), |
2588 | known_reductions.end(), |
2589 | [&reduction_signature](auto& know_signature) { |
2590 | return know_signature->sameAs(reduction_signature.get()); |
2591 | })) { |
2592 | // Found two reductions with the same signature, run pass |
2593 | return true; |
2594 | } |
2595 | known_reductions.emplace_back(std::move(reduction_signature)); |
2596 | } |
2597 | } |
2598 | return false; |
2599 | } |
2600 | |
2601 | namespace { |
2602 | |
2603 | //! Returns true if group1 and group2 are an immediate producer-consumer pair. |
2604 | bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) { |
2605 | // Check if group1 is a immediate consumer of group2 |
2606 | if (std::any_of( |
2607 | group1->producer_edges.begin(), |
2608 | group1->producer_edges.end(), |
2609 | [group2](SegmentedEdge* edge) { return edge->from == group2; })) { |
2610 | return true; |
2611 | } |
2612 | |
2613 | // Check if group1 is a immediate producer of group2 |
2614 | if (std::any_of( |
2615 | group1->consumer_edges.begin(), |
2616 | group1->consumer_edges.end(), |
2617 | [group2](SegmentedEdge* edge) { return edge->to == group2; })) { |
2618 | return true; |
2619 | } |
2620 | |
2621 | return false; |
2622 | } |
2623 | |
2624 | } // namespace |
2625 | |
2626 | bool SegmentCandidateFinder::codeGenSupportedMerge( |
2627 | SegmentedGroup* group1, |
2628 | SegmentedGroup* group2) { |
2629 | TORCH_INTERNAL_ASSERT( |
2630 | areDirectlyConnected(group1, group2), |
2631 | "only support testing immediate producer-consumer groups" ); |
2632 | Fusion* fusion = segmented_fusion_->completeFusion(); |
2633 | auto h = tryMerge(fusion, runtime_info_, group1, group2); |
2634 | return h.has_value(); |
2635 | } |
2636 | |
2637 | // TODO: consider caching the heuristics value so tryMerge doesn't have to be |
2638 | // called twice |
2639 | ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic( |
2640 | SegmentedGroup* group) { |
2641 | Fusion* fusion = segmented_fusion_->completeFusion(); |
2642 | auto h = tryMerge(fusion, runtime_info_, group); |
2643 | TORCH_INTERNAL_ASSERT(h.has_value()); |
2644 | return h.value(); |
2645 | } |
2646 | |
// Takes ownership of the complete fusion, wraps it in a SegmentedFusion,
// and immediately runs the full segmentation algorithm.
//
// Note on ordering: runtime_info_ is constructed in the member-init list
// from fusion.get() *before* the body moves `fusion` into the
// SegmentedFusion, so the raw pointer it captures stays valid.
SegmentCandidateFinder::SegmentCandidateFinder(
    std::unique_ptr<Fusion> fusion,
    const KernelArgumentHolder& inputs,
    SegmentCandidateFinderOptions options)
    : options_(options),
      runtime_info_(fusion.get(), inputs, true),
      runtime_inputs_(inputs) {
  segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
  findSegments();
}
2657 | |
// Main driver of the segmentation algorithm:
//  1. build an initial DAG with one SegmentedGroup per non-scalar expr,
//  2. forward single-use unary chains off the fusion inputs so they are
//     recomputed instead of segmented,
//  3. run pre-merge passes (welford translation, reduction combining),
//  4. iteratively merge groups (herrmann merge, then a brute-force final
//     merge) as long as the combined groups remain schedulable.
void SegmentCandidateFinder::findSegments() {
  FUSER_PERF_SCOPE("Finding valid fusion segment solutions");
  // TODO: Make traversal items local to this function.

  // Need this for initialization of the DAG that is process
  std::unordered_map<Expr*, SegmentedGroup*> expr2group;

  // Keep track of complete fusion input use
  std::unordered_map<Val*, SegmentedGroup*> input2group;

  // Initialize DAG, convert each expr to a segment group
  auto exprs = completeFusion()->exprs();
  for (auto expr : exprs) {
    // Scalar ops get no group of their own; they are duplicated into the
    // groups that need them later (see resolveScalarsInGroup).
    if (!ir_utils::isScalarOp(expr)) {
      auto new_group = segmented_fusion_->newGroup(expr);
      expr2group.insert(std::make_pair(expr, new_group));
    }
  }

  // Find all expresions that are simply unary ops from inputs. Don't segment
  // these as they're easy targets for recomputation. Only go until the first
  // expression that has multiple uses. We could continue, but the logic of
  // hacking the fusion "inputs" logic gets a bit more complicated.

  // Expressions to exclude from segmentation because they're just derived from
  // unary ops on inputs to the complete fusion
  VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;

  // "Terminating" outputs from the excluded input unary exprs, these will be
  // treated as complete fusion inputs.
  VectorOfUniqueEntries<Val*> forwarded_inputs;
  {
    std::deque<Expr*> to_visit;
    // Seed with fusion inputs whose uses are exclusively unary ops.
    for (auto inp : completeFusion()->inputs()) {
      if (std::all_of(inp->uses().begin(), inp->uses().end(), [](Expr* expr) {
            return expr->getExprType().value() == ExprType::UnaryOp;
          })) {
        to_visit.insert(to_visit.end(), inp->uses().begin(), inp->uses().end());
      }
    }

    // Walk down each chain until it fans out, emits a fusion output, or
    // reaches a non-unary op.
    while (!to_visit.empty()) {
      auto expr = to_visit.front();
      to_visit.pop_front();
      if (expr->getExprType().value() != ExprType::UnaryOp ||
          expr->output(0)->isFusionOutput()) {
        continue;
      }

      if (expr->output(0)->uses().size() > 1) {
        // Fan-out point: stop forwarding and treat this output as a
        // segmentation-time fusion input.
        excluded_inp_unary_exprs.pushBack(expr);
        forwarded_inputs.pushBack(expr->output(0));
        continue;
      }

      // NOTE(review): single-use chain assumed here; if output(0) had zero
      // uses this indexing would be out of bounds -- confirm such dead
      // unary outputs cannot occur at this point.
      to_visit.emplace_back(expr->output(0)->uses()[0]);
    }
  }

  // Everything producible from the original inputs feeding forwarded
  // chains; these are dropped from the segmentation input list below.
  auto excluded_fusion_inputs = IterVisitor::getInputsTo(
      {forwarded_inputs.begin(), forwarded_inputs.end()});

  // List of vals to treat as complete fusion inputs for segmentation
  auto forwarded_fusion_inputs = completeFusion()->inputs();

  forwarded_fusion_inputs.erase(
      std::remove_if(
          forwarded_fusion_inputs.begin(),
          forwarded_fusion_inputs.end(),
          [&excluded_fusion_inputs](Val* inp) {
            return std::find(
                       excluded_fusion_inputs.begin(),
                       excluded_fusion_inputs.end(),
                       inp) != excluded_fusion_inputs.end();
          }),
      forwarded_fusion_inputs.end());

  forwarded_fusion_inputs.insert(
      forwarded_fusion_inputs.end(),
      forwarded_inputs.begin(),
      forwarded_inputs.end());

  // Membership test against the adjusted input list (linear scan; the
  // list is expected to be small).
  auto isFusionInput = [&forwarded_fusion_inputs](Val* val) -> bool {
    return std::find(
               forwarded_fusion_inputs.begin(),
               forwarded_fusion_inputs.end(),
               val) != forwarded_fusion_inputs.end();
  };

  // Insert auxiliary groups to use group dependency on inputs as well
  // TODO: these groups should never merged into any other groups, but are
  // just there to support the dependency analysis. Later re-factor should
  // avoid introducing them explicitly on the segmented fusion.
  for (auto input : forwarded_fusion_inputs) {
    // These groups are used to represent input as a common
    // producer in horizontal merges, and should never be
    // seen as a candidate for vertical merge
    auto new_group = segmented_fusion_->newGroup();
    input2group.insert({input, new_group});
  }

  // Create edges between the Exprs. Mark inputs and outputs of the fusion.
  for (auto expr : exprs) {
    // No group created for scalar ops
    if (ir_utils::isScalarOp(expr)) {
      continue;
    }

    // Forwarded unary chains do not participate in segmentation.
    if (excluded_inp_unary_exprs.has(expr)) {
      continue;
    }

    auto expr_group = expr2group.at(expr);
    for (auto inp : expr->inputs()) {
      if (isFusionInput(inp)) {
        // Wire the consumer to the auxiliary input group so the input acts
        // as a common producer in the dependency analysis.
        expr_group->input_vals.push_back(inp);
        auto aux_group = input2group.at(inp);
        auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp);
        expr_group->producer_edges.push_back(new_edge);
        aux_group->consumer_edges.push_back(new_edge);
        continue;
      }

      // Could be something like a constant scalar, definition is nullptr, but
      // isn't an "input" to the fusion. At least not one provided by an
      // external source.
      if (inp->definition() == nullptr) {
        continue;
      }

      // No group created for scalar ops since they may need to be duplicated
      // to avoid scalar edges. They are handled in resolveScalarsInGroup
      if (inp->isScalar()) {
        continue;
      }

      auto def_group = expr2group.at(inp->definition());
      auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp);
      expr_group->producer_edges.push_back(new_edge);
      def_group->consumer_edges.push_back(new_edge);
    }
    for (auto out : expr->outputs()) {
      if (out->isFusionOutput()) {
        expr_group->output_vals.push_back(out);
      }
    }
  }

  auto reduction_ops = ir_utils::getReductionOps(
      segmented_fusion_->completeFusion(), true /* ignore_trivial */);
  auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);

  // Optionally rewrite applicable welford ops before any merge decisions.
  if (options_.run_translate_welford &&
      (welford_ops.begin() != welford_ops.end())) {
    TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_);
  }

  for (auto group : groups()) {
    if (!group->outputs().empty()) {
      // Set heuristics in case single reduction kernels were left out
      group->setHeuristic(deriveHeuristic(group));
    }
  }

  // Remove all scalar edges since they do not represent actual
  // dependency among segmented groups.
  removeScalarEdges();

  // Run pre-merge heuristics
  if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
    CombineReductions::run(this);
  }

  // All merges will be vertical beyond this point for now, so
  // we can remove the input auxiliary groups. Should make the vertical
  // merges avoid auxiliary group once we start general horizontal merges
  std::unordered_set<SegmentedGroup*> input_groups;
  for (auto input : forwarded_fusion_inputs) {
    input_groups.insert(input2group.at(input));
  }
  eraseGroups(input_groups);

  if (options_.run_herrmann_merge) {
    bool merged_nodes = true;
    // Initial merge iteration
    while (merged_nodes) {
      // Reset stateful traversal details in SegmentedGroups
      resetTraversal();

      resetLevels();

      for (auto& group : groups()) {
        if (group->merged_) {
          continue;
        }
        auto candidates = group->getMergeCandidates();
        if (candidates.empty()) {
          continue;
        }

        // Take the first candidate whose combined segment the schedulers
        // accept.
        auto candidate_it = candidates.begin();
        while (candidate_it != candidates.end() &&
               !codeGenSupportedMerge(group, candidate_it->group)) {
          candidate_it++;
        }
        if (candidate_it == candidates.end()) {
          continue;
        }

        to_merge_.emplace_back(group);
        to_merge_.emplace_back(candidate_it->group);

        // Mark both sides so neither is paired again this round.
        group->merged_ = true;
        group->merge_with_ = candidate_it->group;
        group->merge_through_ = candidate_it->edge;

        candidate_it->group->merged_ = true;
        candidate_it->group->merge_with_ = group;
        candidate_it->group->merge_through_ = candidate_it->edge;
      }

      if (to_merge_.empty()) {
        merged_nodes = false;
      }

      mergeNodes();
    }
  }

  if (options_.run_final_merge) {
    // TODO: consider interleaving herrmman merge and bruteforce merge, as
    // bruteforce merge can introduce opportunities for more herrmann merge
    finalMerge();
  }

  finalize();

  if (isDebugDumpEnabled(DebugDumpOption::FusionSegmentsDrawing)) {
    segmented_fusion_->draw();
  }
}
2899 | |
// Brute-force merge phase: repeatedly merges exactly one
// producer-consumer pair per outer iteration, as long as the merged pair
// is schedulable and merging would not create a cycle among the
// producer's other consumers.
void SegmentCandidateFinder::finalMerge() {
  auto producer_check = getGroupDependency();

  bool merged_nodes = true;
  while (merged_nodes) {
    // Iterate all groups and check if a group
    // can merge with one of its consumers
    for (auto producer_group : groups()) {
      // Populate consumers and their corresponding consumer edges
      std::unordered_map<SegmentedGroup*, SegmentedEdge*> consumer_edge_map;
      std::vector<SegmentedGroup*> all_consumers_of_producer_group;
      for (auto consumer : producer_group->consumer_edges) {
        // Since this is the last fusion pass, we can enable fusion through
        // outputs. Priority of this was decreased because if the only
        // connection between groups is an output node, best case scenario we
        // can save a single pass in memory. Where if it wasn't an output it
        // would be two passes.
        consumer_edge_map.insert({consumer->to, consumer});
      }
      // Populate all consumers from the map to avoid duplicate
      // NOTE(review): unordered_map iteration makes the consumer order
      // (and thus which pair merges first) unspecified -- confirm
      // determinism is not required at this stage.
      std::transform(
          consumer_edge_map.begin(),
          consumer_edge_map.end(),
          std::back_inserter(all_consumers_of_producer_group),
          [](auto& it) { return it.first; });

      for (auto consumer : all_consumers_of_producer_group) {
        // Legal only if no sibling consumer depends on `consumer` (that
        // would create a cycle) and the merged pair is schedulable.
        if (!producer_check->isConsumerOfAny(
                consumer, all_consumers_of_producer_group) &&
            codeGenSupportedMerge(producer_group, consumer)) {
          to_merge_.emplace_back(producer_group);
          to_merge_.emplace_back(consumer);
          producer_group->merged_ = true;
          producer_group->merge_with_ = consumer;
          producer_group->merge_through_ = consumer_edge_map.at(consumer);
          consumer->merged_ = true;
          consumer->merge_with_ = producer_group;
          consumer->merge_through_ = producer_group->merge_through_;
          break;
        }
      }

      // Only want to merge one pair at a time so break if found any
      if (!to_merge_.empty()) {
        break;
      }
    }

    if (to_merge_.empty()) {
      merged_nodes = false;
    } else {
      TORCH_INTERNAL_ASSERT(
          to_merge_.size() == 2, "merging more than 2 nodes in final iter");
      mergeNodes();
    }
  }
}
2957 | |
2958 | void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) { |
2959 | std::vector<Val*> to_visit; |
2960 | std::unordered_set<Val*> visited; |
2961 | |
2962 | // Collect all scalar uses in the group |
2963 | for (auto expr : group->exprs()) { |
2964 | for (auto input : expr->inputs()) { |
2965 | if (input->isScalar()) { |
2966 | to_visit.push_back(input); |
2967 | } |
2968 | } |
2969 | } |
2970 | |
2971 | // Keep track of composite fusion inputs used in this group |
2972 | std::unordered_set<Val*> input_set( |
2973 | group->input_vals.begin(), group->input_vals.end()); |
2974 | |
2975 | // Record and append all missing scalar exprs at the end. |
2976 | std::vector<Expr*> exprs_to_add; |
2977 | |
2978 | // Do a stack based traversal of the scalar ops to avoid |
2979 | // combinatorial duplication of exprs. |
2980 | while (!to_visit.empty()) { |
2981 | auto stack_top_val = to_visit.back(); |
2982 | if (visited.count(stack_top_val)) { |
2983 | to_visit.pop_back(); |
2984 | } else if (stack_top_val->definition() == nullptr) { |
2985 | // A scalar without def can be a scalar, a tensor dim, |
2986 | // or a composite fusion input |
2987 | // The first two cases are handled in finalize(), |
2988 | // the last case needs to add new input_val to this group. |
2989 | visited.insert(stack_top_val); |
2990 | // If this is a composite fusion scalar input, make sure this group has it |
2991 | if (stack_top_val->isFusionInput() && !input_set.count(stack_top_val)) { |
2992 | group->input_vals.push_back(stack_top_val); |
2993 | input_set.insert(stack_top_val); |
2994 | } |
2995 | to_visit.pop_back(); |
2996 | } else { |
2997 | // A scalar with an actual definition |
2998 | auto definition_expr = stack_top_val->definition(); |
2999 | bool all_inputs_visited = true; |
3000 | // If any of the inputs are not visited, visit them first |
3001 | for (auto input : definition_expr->inputs()) { |
3002 | if (!visited.count(input)) { |
3003 | all_inputs_visited = false; |
3004 | to_visit.push_back(input); |
3005 | } |
3006 | } |
3007 | // This node is ready to be visited |
3008 | if (all_inputs_visited) { |
3009 | // Collect the defining expr to insert into group |
3010 | exprs_to_add.push_back(definition_expr); |
3011 | visited.insert(stack_top_val); |
3012 | to_visit.pop_back(); |
3013 | } |
3014 | } |
3015 | } |
3016 | |
3017 | // Add all the defining expr to the group |
3018 | for (auto expr : exprs_to_add) { |
3019 | group->exprs_.push_back(expr); |
3020 | } |
3021 | } |
3022 | |
3023 | void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { |
3024 | std::vector<Val*> to_visit; |
3025 | std::unordered_set<Val*> visited; |
3026 | |
3027 | // Collect all inputs to group that are not inputs of fusion |
3028 | for (auto input : group->inputs()) { |
3029 | if (!input->isFusionInput()) { |
3030 | to_visit.push_back(input); |
3031 | } |
3032 | } |
3033 | |
3034 | // Reset group inputs to real inputs |
3035 | group->input_vals = IterVisitor::getInputsTo(group->inputs()); |
3036 | |
3037 | // Grab all expressions needed to produce to_visit |
3038 | auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit); |
3039 | |
3040 | // Insert those expressions at the beginning of the group |
3041 | group->exprs_.insert( |
3042 | group->exprs_.begin(), input_exprs.begin(), input_exprs.end()); |
3043 | } |
3044 | |
3045 | void SegmentCandidateFinder::removeScalarEdges() { |
3046 | // Remove all scalar edges between groups |
3047 | // They may have been created by welford |
3048 | // translation. |
3049 | // we will not need them after scalar |
3050 | // resolution |
3051 | auto remove_scalar_edges_from_vec = [](std::vector<SegmentedEdge*>& edges) { |
3052 | edges.erase( |
3053 | std::remove_if( |
3054 | edges.begin(), |
3055 | edges.end(), |
3056 | [](SegmentedEdge* segmented_edge) { |
3057 | return segmented_edge->val->isScalar(); |
3058 | }), |
3059 | edges.end()); |
3060 | }; |
3061 | |
3062 | remove_scalar_edges_from_vec(edges()); |
3063 | for (auto group : groups()) { |
3064 | remove_scalar_edges_from_vec(group->producer_edges); |
3065 | remove_scalar_edges_from_vec(group->consumer_edges); |
3066 | } |
3067 | } |
3068 | |
3069 | void SegmentCandidateFinder::finalize() { |
3070 | // Remove unconnected groups |
3071 | groups().erase( |
3072 | std::remove_if( |
3073 | groups().begin(), |
3074 | groups().end(), |
3075 | [](SegmentedGroup* sg) { return !sg->isConnected(); }), |
3076 | groups().end()); |
3077 | |
3078 | // Add group labeling |
3079 | int i = 0; |
3080 | for (auto it = groups().begin(); it != groups().end(); it++, i++) { |
3081 | deDuplicateScalarExprs((*it)->exprs_); |
3082 | (*it)->setID(i); |
3083 | } |
3084 | |
3085 | // TODO: too many things are currently abstracted under the term |
3086 | // finalize. Need to re-structure in a follow up. |
3087 | |
3088 | // Finalize connections between segmented groups |
3089 | segmented_fusion_->finalize(); |
3090 | |
3091 | // Resolve all the scalar expressions needed in each group |
3092 | for (auto group : segmented_fusion_->groups()) { |
3093 | resolveScalarsInGroup(group); |
3094 | } |
3095 | |
3096 | // Resolve all the scalar expressions needed in each group |
3097 | for (auto group : segmented_fusion_->groups()) { |
3098 | resolveInputsInGroup(group); |
3099 | } |
3100 | |
3101 | // Finalize each group, fill in the missing inputs, i.e. tensor dims. |
3102 | for (auto g : groups()) { |
3103 | g->setHeuristic(deriveHeuristic(g)); |
3104 | g->finalize(); |
3105 | } |
3106 | } |
3107 | |
3108 | GroupDependencyAnalysis* SegmentCandidateFinder::getGroupDependency() { |
3109 | if (!group_dependency_) { |
3110 | group_dependency_ = |
3111 | std::make_unique<GroupDependencyAnalysis>(segmented_fusion_.get()); |
3112 | } |
3113 | return group_dependency_->as<GroupDependencyAnalysis>(); |
3114 | } |
3115 | |
3116 | FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion:: |
3117 | makeInitialSchedulerEntry( |
3118 | SegmentedGroup* sg, |
3119 | SchedulerRuntimeInfo& runtime_info) { |
3120 | auto local_fusion = completeFusion(); |
3121 | FusionSegmentGuard fsg(local_fusion, getAllInputs(sg), getAllOutputs(sg)); |
3122 | // This will be the first time each group is scheduled. So we'd want to |
3123 | // construct the cache data here. |
3124 | auto data_cache_ptr = std::make_unique<HeuristicSummary>( |
3125 | local_fusion, sg->heuristic(), runtime_info); |
3126 | auto data_cache = data_cache_ptr.get(); |
3127 | setCachedHeuristicDataFor(sg, std::move(data_cache_ptr)); |
3128 | return SchedulerEntry::makeEntry( |
3129 | sg->heuristic(), local_fusion, runtime_info, data_cache); |
3130 | } |
3131 | |
3132 | std::unique_ptr<FusionHeuristics> SegmentedFusion::makeInitialHeuristics( |
3133 | const KernelArgumentHolder& inputs) { |
3134 | auto ret = std::make_unique<FusionHeuristics>(); |
3135 | SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true); |
3136 | for (auto g : groups()) { |
3137 | ret->emplaceBack(makeInitialSchedulerEntry(g, runtime_info)); |
3138 | } |
3139 | return ret; |
3140 | } |
3141 | |
3142 | HeuristicSummary* SegmentedFusion::getCachedHeuristicDataFor( |
3143 | SegmentedGroup* group) { |
3144 | auto data_it = heuristic_summary_cache_.find(group); |
3145 | if (data_it == heuristic_summary_cache_.end()) { |
3146 | return nullptr; |
3147 | } |
3148 | return data_it->second.get(); |
3149 | } |
3150 | |
3151 | void SegmentedFusion::setCachedHeuristicDataFor( |
3152 | SegmentedGroup* group, |
3153 | std::unique_ptr<HeuristicSummary> data) { |
3154 | TORCH_INTERNAL_ASSERT(!heuristic_summary_cache_.count(group)); |
3155 | heuristic_summary_cache_[group] = std::move(data); |
3156 | } |
3157 | |
3158 | namespace { |
3159 | |
3160 | //! A thin traversal class that collects all the tensorviews |
3161 | //! that could cast to fp16 or bf16 if they were segmented edges. |
3162 | //! The selected values are currently defined as all the |
3163 | //! tensorviews that |
3164 | //! 1. are not complete fusion input/output, |
3165 | //! 2. have a use chain that ends with a fp16 |
3166 | //! complete fusion output |
3167 | //! 3. are fp32 datatype |
3168 | class ForceHalfAnnotation : public IterVisitor { |
3169 | public: |
3170 | static std::unordered_set<TensorView*> getFP16AnnotatedSet(Fusion* fusion) { |
3171 | ForceHalfAnnotation annotation; |
3172 | std::vector<Val*> fp16_outputs; |
3173 | auto& cast_to_type = annotation.cast_to_type_; |
3174 | auto other_half_type = |
3175 | cast_to_type == DataType::Half ? DataType::BFloat16 : DataType::Half; |
3176 | std::copy_if( |
3177 | fusion->outputs().begin(), |
3178 | fusion->outputs().end(), |
3179 | std::back_inserter(fp16_outputs), |
3180 | [&cast_to_type, &other_half_type](auto* val) { |
3181 | auto dtype = val->getDataType().value(); |
3182 | if (cast_to_type) { |
3183 | TORCH_INTERNAL_ASSERT( |
3184 | other_half_type != dtype, |
3185 | "Mix of BFloat16 and Float16 in the same graph is not supported." ); |
3186 | } |
3187 | return val->template isA<TensorView>() && |
3188 | val->getDataType().has_value() && |
3189 | (val->getDataType().value() == DataType::Half || |
3190 | val->getDataType().value() == DataType::BFloat16); |
3191 | }); |
3192 | |
3193 | annotation.traverseTo(fusion, fp16_outputs); |
3194 | return annotation.force_fp16_tv_set_; |
3195 | } |
3196 | |
3197 | private: |
3198 | using IterVisitor::handle; |
3199 | |
3200 | void handle(TensorView* tv) override { |
3201 | auto dtype = tv->getDataType(); |
3202 | if (dtype.has_value() && dtype.value() == DataType::Float && |
3203 | !tv->isFusionOutput() && !tv->isFusionInput()) { |
3204 | force_fp16_tv_set_.insert(tv); |
3205 | } |
3206 | } |
3207 | |
3208 | std::unordered_set<TensorView*> force_fp16_tv_set_; |
3209 | c10::optional<DataType> cast_to_type_ = c10::nullopt; |
3210 | }; |
3211 | |
3212 | } // namespace |
3213 | |
3214 | void SegmentedFusion::annotateFP16IntermediateTensors() { |
3215 | force_fp16_tv_set_ = |
3216 | ForceHalfAnnotation::getFP16AnnotatedSet(complete_fusion_.get()); |
3217 | for (auto out_tv : |
3218 | ir_utils::filterByType<TensorView>(complete_fusion_->outputs())) { |
3219 | if (out_tv) { |
3220 | auto dtype = out_tv->getDataType().value(); |
3221 | if (dtype == DataType::Half || dtype == DataType::BFloat16) { |
3222 | force_half_precision_type_ = dtype; |
3223 | } |
3224 | } |
3225 | } |
3226 | } |
3227 | |
3228 | std::string toString(const SegmentCandidateFinderOptions& segment_options) { |
3229 | std::stringstream ss; |
3230 | ss << "segmentation phases {\n" ; |
3231 | if (segment_options.run_combine_reductions) { |
3232 | ss << "combine reductions\n" ; |
3233 | } |
3234 | if (segment_options.run_herrmann_merge) { |
3235 | ss << "herrmann merging\n" ; |
3236 | } |
3237 | if (segment_options.run_final_merge) { |
3238 | ss << "final merging\n" ; |
3239 | } |
3240 | ss << "\n}\n" ; |
3241 | return ss.str(); |
3242 | } |
3243 | |
3244 | } // namespace cuda |
3245 | } // namespace fuser |
3246 | } // namespace jit |
3247 | } // namespace torch |
3248 | |