1#pragma once
2
3#include <fusion.h>
4#include <ir_base_nodes.h>
5#include <kernel_cache.h>
6#include <scheduler/all_schedulers.h>
7#include <scheduler/registry.h>
8#include <utils.h>
9
10#include <deque>
11#include <list>
12#include <unordered_set>
13#include <vector>
14
15namespace torch {
16namespace jit {
17namespace fuser {
18namespace cuda {
19
20class SegmentedGroup;
21class SegmentCandidateFinder;
22
23// A directed edge on DAG,
24// Wrapper for values, edges between segmented groups which are made up
25// of Exprs. Multiple edges can exist between segmented groups.
struct SegmentedEdge {
  //! Builds an edge carrying `val` from group `from` to group `to`.
  //! All pointers are non-owning; ownership lives in SegmentedFusion.
  SegmentedEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val)
      : from(from), to(to), val(val) {}

  //! Source group of the edge (towards fusion inputs)
  SegmentedGroup* from;
  //! Destination group of the edge (towards fusion outputs)
  SegmentedGroup* to;
  //! The value crossing the group boundary along this edge
  Val* val;

  //! Debug print of this edge
  void print() const;
};
36
37std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
38
39//! Groups together expressions which create a segmented group
40//! Can be used to produce fusions
41class TORCH_CUDA_CU_API SegmentedGroup {
42 public:
43 SegmentedGroup(SegmentedFusion* segmented_fusion)
44 : segmented_fusion_(segmented_fusion) {}
45
46 SegmentedGroup(Expr* expr, SegmentedFusion* segmented_fusion)
47 : segmented_fusion_(segmented_fusion) {
48 exprs_.push_back(expr);
49 }
50
51 //! Checks if this group takes original fusion's input
52 bool isInputGroup() {
53 return !input_vals.empty();
54 };
55
56 //! Checks if this group is used any where in the segmented fusion
57 bool isConnected() const {
58 return !producer_edges.empty() || !consumer_edges.empty() ||
59 !output_vals.empty();
60 }
61
62 //! returns the id assigned by segment pass
63 int groupId() const {
64 return group_id_;
65 }
66
67 //! Returns inputs that this group shares with the original fusion
68 const auto& inputs() const {
69 return input_vals;
70 }
71
72 //! Returns outputs that this group shares with the original fusion
73 const auto& outputs() const {
74 return output_vals;
75 }
76
77 //! Returns the schedule heuristic associated with this group
78 ScheduleHeuristic heuristic() const {
79 return heuristic_;
80 }
81
82 //! Returns the exprs that make up this group
83 const auto& exprs() const {
84 return exprs_;
85 }
86
87 //! Debug print function
88 void print() const;
89
90 //! Returns the segmented fusion that this group is in
91 SegmentedFusion* segmentedFusion() const {
92 return segmented_fusion_;
93 }
94
95 //! Utility to re-collect the operators included in this
96 //! segmented group after updating the group boundary.
97 void resetExprList();
98
99 //! Try to get a scheduler entry for this group with
100 //! the given runtime info.
101 //! Returns a new scheduler with the same heuristics
102 //! for this group if possible.
103 //! Note that the schedule params can be different.
104 //! Returns a nullopt if this group cannot be scheduled
105 //! with the same heuristics.
106 c10::optional<std::unique_ptr<SchedulerEntry>> getMaybeSchedulerEntry(
107 SchedulerRuntimeInfo& runtime_info);
108
109 public:
110 //! "Ancestor nodes", towards inputs of segmentedDAG
111 std::vector<SegmentedEdge*> producer_edges;
112
113 //! "Descendent nodes", towards outputs of segmentedDAG
114 std::vector<SegmentedEdge*> consumer_edges;
115
116 //! Composite Fusion inputs in this group
117 std::vector<Val*> input_vals;
118
119 //! Composite Fusion outputs in this group
120 std::vector<Val*> output_vals;
121
122 private:
123 friend class SegmentCandidateFinder;
124 friend class SegmentedFusion;
125 friend class FusionKernelRuntime;
126 friend class TranslateApplicableWelford;
127
128 //! unique identifier of group in the segmented fusion
129 int group_id_ = -1;
130
131 //! The scheduler to use for compiling this group
132 ScheduleHeuristic heuristic_ = ScheduleHeuristic::None;
133
134 //! Exprs that make up the group
135 std::vector<Expr*> exprs_;
136
137 //! Maximum path distance from an input segmented group required for
138 //! Theorem 4.2
139 int level_ = -1;
140
141 //! traversal marker, has this node already been processed
142 bool visited_ = false;
143
144 //! Did we select another group to merge with
145 SegmentedGroup* merge_with_ = nullptr;
146
147 //! if we selected another group to merge, which edge is to be contracted
148 SegmentedEdge* merge_through_ = nullptr;
149
150 //! Has this node been merged?
151 bool merged_ = false;
152
153 private:
154 //! Utility to convert edge vector to value vector
155 std::vector<Val*> edgesToVals(const std::vector<SegmentedEdge*>& se_v);
156
157 //! Reset method to call at begining of each
158 //! merge node iteration
159 void clearTraversalInfo();
160
161 //! To be called at the very end of segment fusion
162 //! no more segment merging should be done beyond
163 void finalize();
164
165 //! Return all segmented groups connected with *this
166 std::vector<SegmentedGroup*> getNeighbors();
167
168 //! Utility struct to represent a group connection
169 //! both the group to connect with and the edge
170 //! to connect through
171 struct NeighborGroup {
172 NeighborGroup(SegmentedGroup* g, SegmentedEdge* e) : group(g), edge(e) {}
173 SegmentedGroup* group;
174 SegmentedEdge* edge;
175 };
176
177 //! TODO: May want to sort this based on size of connections between this and
178 //! neighbors as well as if the connection is an output of the fusion (has to
179 //! be saved to gmem anyways)
180 std::vector<NeighborGroup> getNeighborGroups();
181
182 //! Look at all neighbors of this and return who this could merge with based
183 //! on level values of this, neighbors, and merged neighbors of neighbors
184 std::vector<NeighborGroup> getMergeCandidates();
185
186 //! Assign schedule heuristic to this group
187 void setHeuristic(ScheduleHeuristic sh) {
188 heuristic_ = sh;
189 }
190
191 //! Assign Id for this group
192 void setID(int id) {
193 TORCH_INTERNAL_ASSERT(group_id_ == -1);
194 group_id_ = id;
195 }
196
197 //! SegmentedFusion this group belongs to
198 SegmentedFusion* segmented_fusion_;
199};
200
201std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group);
202
203//! Auxiliary class for storing heuristics. The managed data is either
204//! a single scheduler entry for complete fusion,
205//! or a vector of schedulers, one for each segment, for segmented fusion.
206class TORCH_CUDA_CU_API FusionHeuristics {
207 using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;
208
209 public:
210 //! Constructor for segmented fusion case. Created with empty list and
211 //! uses emplaceBack for inserting heuristics in order
212 explicit FusionHeuristics() = default;
213
214 //! Constructor for complete fusion case, generates the scheduler entry
215 //! for the fusion owning the given expression
216 explicit FusionHeuristics(
217 ScheduleHeuristic schedule_heuristic,
218 SchedulerRuntimeInfo& runtime_info,
219 HeuristicSummary* data_cache = nullptr) {
220 heuristics_.emplace_back(SchedulerEntry::makeEntry(
221 schedule_heuristic, runtime_info.fusion(), runtime_info, data_cache));
222 is_segmented_ = false;
223 }
224
225 FusionHeuristics(const FusionHeuristics&) = delete;
226 FusionHeuristics& operator=(const FusionHeuristics&) = delete;
227
228 //! Place a scheduler entry on the list. Applies to segmented fusion only.
229 void emplaceBack(SchedulerEntryOwningPtr&& pt) {
230 TORCH_INTERNAL_ASSERT(is_segmented_);
231 heuristics_.emplace_back(std::move(pt));
232 }
233
234 //! Returns list of schedulers for a segmneted fusion.
235 const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
236 return heuristics_;
237 }
238
239 //! Returns the single scheduler for a complete fusion.
240 SchedulerEntry* singleKernelHeuristics() {
241 TORCH_INTERNAL_ASSERT(!is_segmented_);
242 return heuristics_.begin()->get();
243 }
244
245 private:
246 std::vector<SchedulerEntryOwningPtr> heuristics_;
247 bool is_segmented_ = true;
248};
249
250//! Exported Interface for representing segmented fusion graph
251//! this class owns the segmented groups
252class TORCH_CUDA_CU_API SegmentedFusion {
253 public:
254 explicit SegmentedFusion(std::unique_ptr<Fusion> fusion);
255
256 //! Factory function for the un-segmented case, directly
257 //! constructs a "SegmentedFusion", with the given Fusion
258 //! as the only group.
259 static std::unique_ptr<SegmentedFusion> fromCompleteFusion(
260 std::unique_ptr<Fusion> fusion,
261 ScheduleHeuristic heuristic);
262
263 //! Is the fusion segmented?
264 bool isSegmented() const {
265 return !groups_.empty();
266 }
267
268 std::vector<SegmentedGroup*>& groups() {
269 return groups_;
270 }
271
272 std::vector<SegmentedEdge*>& edges() {
273 return edges_;
274 }
275
276 const std::vector<SegmentedGroup*>& cgroups() const {
277 return groups_;
278 }
279
280 const std::vector<SegmentedEdge*>& cedges() const {
281 return edges_;
282 }
283
284 //! Returns the original un-segmented fusion
285 Fusion* completeFusion() const {
286 return complete_fusion_.get();
287 }
288
289 const auto& inputs() const {
290 return complete_fusion_->inputs();
291 }
292
293 const auto& outputs() const {
294 return complete_fusion_->outputs();
295 }
296
297 Val* findAlias(Val* val) const {
298 auto alias_it = complete_fusion_->ioAlias().find(val);
299 if (alias_it != complete_fusion_->ioAlias().end()) {
300 return alias_it->second;
301 }
302 return nullptr;
303 }
304
305 //! Make a clone of the group and convert to fusion
306 std::unique_ptr<Fusion> makeFusion(SegmentedGroup* sg);
307
308 //! Make heuristics for all groups in this segmented fusion
309 std::unique_ptr<FusionHeuristics> makeInitialHeuristics(
310 const KernelArgumentHolder& inputs);
311
312 //! Inline Debug print for segmented fusion
313 std::string toString(int verbosity) const;
314
315 //! Debug drawing for graphviz
316 void draw();
317
318 //! Debug print for segmented fusions
319 void print() const;
320
321 //! API for adding groups
322 SegmentedGroup* newGroup();
323
324 //! API shortcut for adding a singleton group
325 SegmentedGroup* newGroup(Expr* expr);
326
327 //! API for adding edges
328 SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
329
330 HeuristicSummary* getCachedHeuristicDataFor(SegmentedGroup* group);
331
332 private:
333 //! Unique name for segmented fusion
334 int segmented_fusion_name_;
335
336 //! States representing segmentation
337 std::vector<SegmentedEdge*> edges_;
338 std::vector<SegmentedGroup*> groups_;
339
340 //! Owning object to explicitly manage groups and edges
341 class Impl {
342 public:
343 explicit Impl(SegmentedFusion* sf) : owning_fusion_(sf) {}
344
345 SegmentedGroup* makeGroup();
346 SegmentedGroup* makeGroup(Expr*);
347 SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
348 void cleanUnused();
349
350 private:
351 using GroupPtr = std::unique_ptr<SegmentedGroup>;
352 using EdgePtr = std::unique_ptr<SegmentedEdge>;
353 std::vector<GroupPtr> groups_;
354 std::vector<EdgePtr> edges_;
355 SegmentedFusion* owning_fusion_;
356 };
357 Impl impl_;
358
359 //! A Copy of original full fusion
360 std::unique_ptr<Fusion> complete_fusion_;
361
362 //! A set of intermediate tensors that need to be cast to fp16
363 std::unordered_set<TensorView*> force_fp16_tv_set_;
364
365 DataType force_half_precision_type_;
366
367 //! Static traversal information to be used for fast heuristics lookup
368 std::unordered_map<SegmentedGroup*, std::unique_ptr<HeuristicSummary>>
369 heuristic_summary_cache_;
370
371 // TODO: this class needs cleanup
372 protected:
373 friend class SegmentCandidateFinder;
374 //! Make a heuristics entry for a group and parameters
375 std::unique_ptr<SchedulerEntry> makeInitialSchedulerEntry(
376 SegmentedGroup* sg,
377 SchedulerRuntimeInfo& runtime_info);
378
379 //! Cleanup function to be call at the end of fusion
380 //! segment pass
381 void finalize();
382
383 //! Collect all the intermediate tensors between segmented
384 //! groups that will cast to fp16
385 void annotateFP16IntermediateTensors();
386
387 //! Keep heuristic checking intermediate data
388 void setCachedHeuristicDataFor(
389 SegmentedGroup* group,
390 std::unique_ptr<HeuristicSummary> data);
391
392 //! Utility to give unique name for each segmented fusion
393 static size_t segmentedFusionName() {
394 static size_t counter = 0;
395 return counter++;
396 }
397};
398
399//! This is a base class for segmenter analysis
400//! provides the minimal implementation on header so that
401//! a unique_ptr can use this base class
402//! actual implementations of analyses are in the .cpp files
403//! TODO: In the next refactor PR, should put segment candidate
404//! finder in .cpp file completely since API doesn't require these
405//! details
//! Empty polymorphic base so a unique_ptr can hold any analysis
//! (actual analysis implementations live in the .cpp files).
class SegmenterAnalysis : public PolymorphicBase {};
//! Dependency analysis between segmented groups; built lazily via
//! SegmentCandidateFinder::getGroupDependency().
class GroupDependencyAnalysis;

// Manual node merging passes
//! Pass that merges compatible reduction groups; a friend of
//! SegmentCandidateFinder (see run_combine_reductions option).
class CombineReductions;
411
//! Options to configure/debug candidate finder
struct TORCH_CUDA_CU_API SegmentCandidateFinderOptions {
  //! Run the welford translation pre-pass
  //! (see SegmentCandidateFinder::TranslateWelfordInFusion)
  bool run_translate_welford = true;
  //! Run the CombineReductions manual merging pass
  bool run_combine_reductions = true;
  //! Run the Herrmann et al. merge iteration (Theorem 4.2 based)
  bool run_herrmann_merge = true;
  //! Run the brute-force clean-up merge iteration
  //! (see SegmentCandidateFinder::finalMerge)
  bool run_final_merge = true;
};
419
420//! SegmentCandidateFinder
421//! Responsible for going through DAG and proposing things we could try to
422//! fuse together, calls "canGenerateCode" on these proposed segments to see
423//! if they are valid and we can generate code for them.
424//! FusionSegment
425//! A group of exprs that are segmented together
426//! FusionSegmentConnections
427//! Holds vals and what they connect. In other words it's a val that is an
428//! output of a FusionSegment "from" and an input of FusionSegment "to".
//! Nothing prevents the same val from connecting two segments more than once.
430//! TODO: make sure there's nothing wrong with segmentation on nodes that
431//! have the same value input twice. i.e. (B = A*A)
//! Selecting segments to propose is based on Theorem 4.2 in the paper, which
//! guarantees that the graph remains a DAG after segmentation (assuming the
//! original Fusion is already a DAG). The segmentation code relies on this
//! DAG-ness assumption throughout, meaning any proposed merging of groups
//! must maintain the DAG property of the graph.
437//!
438//! Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
439//! Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
440//! SIAM Journal on Scientific Computing, Society for Industrial and Applied
//! Mathematics, 2019, 41 (4), pp. A2117-A2145. DOI: 10.1137/18M1176865.
//! HAL: hal-02306566.
443class TORCH_CUDA_CU_API SegmentCandidateFinder {
444 public:
445 // Perform segmentation on a copy of the given fusion
446 static std::unique_ptr<SegmentedFusion> segment(
447 const Fusion* fusion,
448 const KernelArgumentHolder& inputs,
449 SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
450 auto fusion_copy = std::make_unique<Fusion>(*fusion);
451 if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
452 std::cout << "Segment the fusion (Original Fusion Un-modified): "
453 << std::endl;
454 fusion_copy->printMath();
455 }
456 SegmentCandidateFinder scf(std::move(fusion_copy), inputs, options);
457 return std::move(scf.segmented_fusion_);
458 }
459
460 // Perform segmentation on and take ownership of the given fusion
461 static std::unique_ptr<SegmentedFusion> segment(
462 std::unique_ptr<Fusion> fusion,
463 const KernelArgumentHolder& inputs,
464 SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
465 SegmentCandidateFinder scf(std::move(fusion), inputs, options);
466 if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
467 std::cout << "Segment the fusion (Original Fusion Un-modified): "
468 << std::endl;
469 scf.completeFusion()->printMath();
470 }
471 return std::move(scf.segmented_fusion_);
472 }
473
474 static bool TranslateWelfordInFusion(
475 Fusion* fusion,
476 const KernelArgumentHolder& runtime_inputs);
477
478 private:
479 // Perform segmentation on and take ownership of the given fusion
480 SegmentCandidateFinder(
481 std::unique_ptr<Fusion> fusion,
482 const KernelArgumentHolder& inputs,
483 SegmentCandidateFinderOptions options);
484
485 void resetTraversal();
486
487 void resetLevels();
488
489 SegmentedGroup* mergeNodes();
490
491 bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2);
492
493 void findSegments();
494
495 std::unordered_set<SegmentedEdge*> disconnectGroup(SegmentedGroup* group);
496
497 std::vector<SegmentedGroup*>& groups() {
498 TORCH_INTERNAL_ASSERT(
499 segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
500 return segmented_fusion_->groups();
501 }
502
503 std::vector<SegmentedEdge*>& edges() {
504 TORCH_INTERNAL_ASSERT(
505 segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
506 return segmented_fusion_->edges();
507 }
508
509 Fusion* completeFusion() {
510 TORCH_INTERNAL_ASSERT(
511 segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
512 return segmented_fusion_->completeFusion();
513 }
514
515 SchedulerRuntimeInfo& runtimeInfo() {
516 return runtime_info_;
517 }
518
519 ExpressionEvaluator& expressionEvaluator() {
520 return runtime_info_.expressionEvaluator();
521 }
522
523 //! Additional merging iteration, clean up the rest of
524 //! the merging opportunities
525 //! Herrmann et al. is a fast and safe algorithm for finding merge candidates
526 //! but can become too conservative in our use cases because we place
527 //! additional qualifiers on valid merges other than having to generate DAGs,
528 //! i.e. canSchedule. So we need a bruteforce final merging iteration as a
529 //! clean up pass. Cost isn't expected to be high since the graph at this
530 //! stage is already quite merged. Example cf. test_gpu.cpp:
531 //! FusionDAGMerging_CUDA
532 //!
533 //! This merging algorithm is based on Theorem 4.1 of Herrmann et al.,
534 //! to check if a producer-consumer pair can be merged into one group,
535 //! it's enough to check if any other consumer of the producer also
536 //! produces the consumer.
537 void finalMerge();
538
539 //! Duplicate and add all exprs producing the used
540 //! scalar values in group
541 void resolveScalarsInGroup(SegmentedGroup* group);
542
543 //! Duplicate and add all exprs from "inputs" in the group, to complete
544 //! inputs. These expressions are simply unary ops of inputs that we want to
545 //! recompute for each segment, instead of computing and producing a segmented
546 //! val. For example if we have:
547 //! tv1 = tv0 * 2;
548 //! tv3 = tv1 + tv2;
549 //! tv4 = tv1 + tv4
550 //! If we segmented on tv1, we would be producing an output for tv1 for 2
551 //! groups that have tv3 or tv4, instead we could easily recompute tv1 from
552 //! tv0.
553 void resolveInputsInGroup(SegmentedGroup* group);
554
555 //! Remove all scalar edges in group
556 //! (TODO: need structure better so we don't have to do this)
557 void removeScalarEdges();
558
559 //! Utility function to merge a vector of groups in one step,
560 //! need to check for DAG condition before using this method
561 SegmentedGroup* mergeAllGivenGroups(
562 const std::vector<SegmentedGroup*>& groups);
563
564 //! Utility to remove a group and corresponding edges
565 //! TODO: remove inline versions of this as much as possible
566 void eraseGroups(std::unordered_set<SegmentedGroup*>& groups_to_erase);
567
568 void finalize();
569
570 //! Return the resulting heuristic corresponding to the merged
571 //! group built by merging the two groups connected by edge
572 ScheduleHeuristic deriveHeuristic(SegmentedGroup* edge);
573
574 GroupDependencyAnalysis* getGroupDependency();
575
576 protected:
577 //! These are the merge node heuristic passes, should
578 //! eventually should have a dedicated interface
579 //! instead of keeping adding friends
580 friend class CombineReductions;
581
582 //! options to configure and debug the segment process
583 SegmentCandidateFinderOptions options_;
584
585 std::deque<SegmentedGroup*> to_visit_;
586 std::vector<SegmentedGroup*> next_to_visit_;
587
588 std::unordered_set<SegmentedGroup*> clean_up_groups_;
589 std::unordered_set<SegmentedEdge*> clean_up_edges_;
590
591 std::vector<SegmentedGroup*> to_merge_;
592
593 std::unique_ptr<SegmentedFusion> segmented_fusion_;
594
595 std::unique_ptr<SegmenterAnalysis> group_dependency_;
596
597 SchedulerRuntimeInfo runtime_info_;
598
599 //! Note:
600 //! Segmenter should eventually rely only on runtime_info_ for
601 //! safe caching. runtime_inputs_ is only used in translateWelford
602 //! to initialize expression evaluators on copies of the original
603 //! fusion, which doesn't use any un-cached info and is safe.
604 //!
605 //! Directly using runtime_inputs_ in other cases is in general
606 //! risky.
607 //!
608 //! To get rid of runtime_inputs_ we need mechanisms
609 //! to copy expression evaluator values from fusion
610 //! to a copy, or even better to a copy of a
611 //! sub-graph of original fusion.
612 //! TODO:
613 //! implement the expression evaluator transfer and
614 //! remove runtime_inputs_ in a follow up.
615 const KernelArgumentHolder& runtime_inputs_;
616};
617
618// TODO: Make as member functions on classes instead of global scope
619TORCH_CUDA_CU_API std::string toString(const SegmentedGroup* group);
620TORCH_CUDA_CU_API std::string toString(const SegmentedEdge* edge);
621TORCH_CUDA_CU_API std::string toString(const SegmentedFusion* segmented_fusion);
622TORCH_CUDA_CU_API std::string toString(
623 const SegmentCandidateFinderOptions& segment_options);
624
625} // namespace cuda
626} // namespace fuser
627} // namespace jit
628} // namespace torch
629