1 | #pragma once |
2 | |
3 | #include <fusion.h> |
4 | #include <ir_base_nodes.h> |
5 | #include <kernel_cache.h> |
6 | #include <scheduler/all_schedulers.h> |
7 | #include <scheduler/registry.h> |
8 | #include <utils.h> |
9 | |
10 | #include <deque> |
11 | #include <list> |
12 | #include <unordered_set> |
13 | #include <vector> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
20 | class SegmentedGroup; |
21 | class SegmentCandidateFinder; |
22 | |
// A directed edge in the segmented DAG.
// Wraps a Val passed between segmented groups (which are made up of
// Exprs). Multiple edges can exist between the same pair of groups.
26 | struct SegmentedEdge { |
27 | SegmentedEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val) |
28 | : from(from), to(to), val(val) {} |
29 | |
30 | SegmentedGroup* from; |
31 | SegmentedGroup* to; |
32 | Val* val; |
33 | |
34 | void print() const; |
35 | }; |
36 | |
37 | std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge); |
38 | |
39 | //! Groups together expressions which create a segmented group |
40 | //! Can be used to produce fusions |
41 | class TORCH_CUDA_CU_API SegmentedGroup { |
42 | public: |
43 | SegmentedGroup(SegmentedFusion* segmented_fusion) |
44 | : segmented_fusion_(segmented_fusion) {} |
45 | |
46 | SegmentedGroup(Expr* expr, SegmentedFusion* segmented_fusion) |
47 | : segmented_fusion_(segmented_fusion) { |
48 | exprs_.push_back(expr); |
49 | } |
50 | |
51 | //! Checks if this group takes original fusion's input |
52 | bool isInputGroup() { |
53 | return !input_vals.empty(); |
54 | }; |
55 | |
56 | //! Checks if this group is used any where in the segmented fusion |
57 | bool isConnected() const { |
58 | return !producer_edges.empty() || !consumer_edges.empty() || |
59 | !output_vals.empty(); |
60 | } |
61 | |
62 | //! returns the id assigned by segment pass |
63 | int groupId() const { |
64 | return group_id_; |
65 | } |
66 | |
67 | //! Returns inputs that this group shares with the original fusion |
68 | const auto& inputs() const { |
69 | return input_vals; |
70 | } |
71 | |
72 | //! Returns outputs that this group shares with the original fusion |
73 | const auto& outputs() const { |
74 | return output_vals; |
75 | } |
76 | |
77 | //! Returns the schedule heuristic associated with this group |
78 | ScheduleHeuristic heuristic() const { |
79 | return heuristic_; |
80 | } |
81 | |
82 | //! Returns the exprs that make up this group |
83 | const auto& exprs() const { |
84 | return exprs_; |
85 | } |
86 | |
87 | //! Debug print function |
88 | void print() const; |
89 | |
90 | //! Returns the segmented fusion that this group is in |
91 | SegmentedFusion* segmentedFusion() const { |
92 | return segmented_fusion_; |
93 | } |
94 | |
95 | //! Utility to re-collect the operators included in this |
96 | //! segmented group after updating the group boundary. |
97 | void resetExprList(); |
98 | |
99 | //! Try to get a scheduler entry for this group with |
100 | //! the given runtime info. |
101 | //! Returns a new scheduler with the same heuristics |
102 | //! for this group if possible. |
103 | //! Note that the schedule params can be different. |
104 | //! Returns a nullopt if this group cannot be scheduled |
105 | //! with the same heuristics. |
106 | c10::optional<std::unique_ptr<SchedulerEntry>> getMaybeSchedulerEntry( |
107 | SchedulerRuntimeInfo& runtime_info); |
108 | |
109 | public: |
110 | //! "Ancestor nodes", towards inputs of segmentedDAG |
111 | std::vector<SegmentedEdge*> producer_edges; |
112 | |
113 | //! "Descendent nodes", towards outputs of segmentedDAG |
114 | std::vector<SegmentedEdge*> consumer_edges; |
115 | |
116 | //! Composite Fusion inputs in this group |
117 | std::vector<Val*> input_vals; |
118 | |
119 | //! Composite Fusion outputs in this group |
120 | std::vector<Val*> output_vals; |
121 | |
122 | private: |
123 | friend class SegmentCandidateFinder; |
124 | friend class SegmentedFusion; |
125 | friend class FusionKernelRuntime; |
126 | friend class TranslateApplicableWelford; |
127 | |
128 | //! unique identifier of group in the segmented fusion |
129 | int group_id_ = -1; |
130 | |
131 | //! The scheduler to use for compiling this group |
132 | ScheduleHeuristic heuristic_ = ScheduleHeuristic::None; |
133 | |
134 | //! Exprs that make up the group |
135 | std::vector<Expr*> exprs_; |
136 | |
137 | //! Maximum path distance from an input segmented group required for |
138 | //! Theorem 4.2 |
139 | int level_ = -1; |
140 | |
141 | //! traversal marker, has this node already been processed |
142 | bool visited_ = false; |
143 | |
144 | //! Did we select another group to merge with |
145 | SegmentedGroup* merge_with_ = nullptr; |
146 | |
147 | //! if we selected another group to merge, which edge is to be contracted |
148 | SegmentedEdge* merge_through_ = nullptr; |
149 | |
150 | //! Has this node been merged? |
151 | bool merged_ = false; |
152 | |
153 | private: |
154 | //! Utility to convert edge vector to value vector |
155 | std::vector<Val*> edgesToVals(const std::vector<SegmentedEdge*>& se_v); |
156 | |
157 | //! Reset method to call at begining of each |
158 | //! merge node iteration |
159 | void clearTraversalInfo(); |
160 | |
161 | //! To be called at the very end of segment fusion |
162 | //! no more segment merging should be done beyond |
163 | void finalize(); |
164 | |
165 | //! Return all segmented groups connected with *this |
166 | std::vector<SegmentedGroup*> getNeighbors(); |
167 | |
168 | //! Utility struct to represent a group connection |
169 | //! both the group to connect with and the edge |
170 | //! to connect through |
171 | struct NeighborGroup { |
172 | NeighborGroup(SegmentedGroup* g, SegmentedEdge* e) : group(g), edge(e) {} |
173 | SegmentedGroup* group; |
174 | SegmentedEdge* edge; |
175 | }; |
176 | |
177 | //! TODO: May want to sort this based on size of connections between this and |
178 | //! neighbors as well as if the connection is an output of the fusion (has to |
179 | //! be saved to gmem anyways) |
180 | std::vector<NeighborGroup> getNeighborGroups(); |
181 | |
182 | //! Look at all neighbors of this and return who this could merge with based |
183 | //! on level values of this, neighbors, and merged neighbors of neighbors |
184 | std::vector<NeighborGroup> getMergeCandidates(); |
185 | |
186 | //! Assign schedule heuristic to this group |
187 | void setHeuristic(ScheduleHeuristic sh) { |
188 | heuristic_ = sh; |
189 | } |
190 | |
191 | //! Assign Id for this group |
192 | void setID(int id) { |
193 | TORCH_INTERNAL_ASSERT(group_id_ == -1); |
194 | group_id_ = id; |
195 | } |
196 | |
197 | //! SegmentedFusion this group belongs to |
198 | SegmentedFusion* segmented_fusion_; |
199 | }; |
200 | |
201 | std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group); |
202 | |
203 | //! Auxiliary class for storing heuristics. The managed data is either |
204 | //! a single scheduler entry for complete fusion, |
205 | //! or a vector of schedulers, one for each segment, for segmented fusion. |
class TORCH_CUDA_CU_API FusionHeuristics {
  using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;

 public:
  //! Constructor for segmented fusion case. Created with empty list and
  //! uses emplaceBack for inserting heuristics in order
  explicit FusionHeuristics() = default;

  //! Constructor for complete fusion case, generates the scheduler entry
  //! for the fusion owning the given expression
  explicit FusionHeuristics(
      ScheduleHeuristic schedule_heuristic,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr) {
    heuristics_.emplace_back(SchedulerEntry::makeEntry(
        schedule_heuristic, runtime_info.fusion(), runtime_info, data_cache));
    is_segmented_ = false;
  }

  // Non-copyable: exclusively owns its scheduler entries.
  FusionHeuristics(const FusionHeuristics&) = delete;
  FusionHeuristics& operator=(const FusionHeuristics&) = delete;

  //! Place a scheduler entry on the list. Applies to segmented fusion only.
  void emplaceBack(SchedulerEntryOwningPtr&& pt) {
    TORCH_INTERNAL_ASSERT(is_segmented_);
    heuristics_.emplace_back(std::move(pt));
  }

  //! Returns list of schedulers for a segmented fusion.
  const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
    return heuristics_;
  }

  //! Returns the single scheduler for a complete fusion.
  SchedulerEntry* singleKernelHeuristics() {
    TORCH_INTERNAL_ASSERT(!is_segmented_);
    return heuristics_.begin()->get();
  }

 private:
  //! One entry per segment (segmented case), or a single entry
  //! (complete-fusion case).
  std::vector<SchedulerEntryOwningPtr> heuristics_;
  //! Defaults to segmented; only the complete-fusion constructor clears it.
  bool is_segmented_ = true;
};
249 | |
250 | //! Exported Interface for representing segmented fusion graph |
251 | //! this class owns the segmented groups |
252 | class TORCH_CUDA_CU_API SegmentedFusion { |
253 | public: |
254 | explicit SegmentedFusion(std::unique_ptr<Fusion> fusion); |
255 | |
256 | //! Factory function for the un-segmented case, directly |
257 | //! constructs a "SegmentedFusion", with the given Fusion |
258 | //! as the only group. |
259 | static std::unique_ptr<SegmentedFusion> fromCompleteFusion( |
260 | std::unique_ptr<Fusion> fusion, |
261 | ScheduleHeuristic heuristic); |
262 | |
263 | //! Is the fusion segmented? |
264 | bool isSegmented() const { |
265 | return !groups_.empty(); |
266 | } |
267 | |
268 | std::vector<SegmentedGroup*>& groups() { |
269 | return groups_; |
270 | } |
271 | |
272 | std::vector<SegmentedEdge*>& edges() { |
273 | return edges_; |
274 | } |
275 | |
276 | const std::vector<SegmentedGroup*>& cgroups() const { |
277 | return groups_; |
278 | } |
279 | |
280 | const std::vector<SegmentedEdge*>& cedges() const { |
281 | return edges_; |
282 | } |
283 | |
284 | //! Returns the original un-segmented fusion |
285 | Fusion* completeFusion() const { |
286 | return complete_fusion_.get(); |
287 | } |
288 | |
289 | const auto& inputs() const { |
290 | return complete_fusion_->inputs(); |
291 | } |
292 | |
293 | const auto& outputs() const { |
294 | return complete_fusion_->outputs(); |
295 | } |
296 | |
297 | Val* findAlias(Val* val) const { |
298 | auto alias_it = complete_fusion_->ioAlias().find(val); |
299 | if (alias_it != complete_fusion_->ioAlias().end()) { |
300 | return alias_it->second; |
301 | } |
302 | return nullptr; |
303 | } |
304 | |
305 | //! Make a clone of the group and convert to fusion |
306 | std::unique_ptr<Fusion> makeFusion(SegmentedGroup* sg); |
307 | |
308 | //! Make heuristics for all groups in this segmented fusion |
309 | std::unique_ptr<FusionHeuristics> makeInitialHeuristics( |
310 | const KernelArgumentHolder& inputs); |
311 | |
312 | //! Inline Debug print for segmented fusion |
313 | std::string toString(int verbosity) const; |
314 | |
315 | //! Debug drawing for graphviz |
316 | void draw(); |
317 | |
318 | //! Debug print for segmented fusions |
319 | void print() const; |
320 | |
321 | //! API for adding groups |
322 | SegmentedGroup* newGroup(); |
323 | |
324 | //! API shortcut for adding a singleton group |
325 | SegmentedGroup* newGroup(Expr* expr); |
326 | |
327 | //! API for adding edges |
328 | SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); |
329 | |
330 | HeuristicSummary* getCachedHeuristicDataFor(SegmentedGroup* group); |
331 | |
332 | private: |
333 | //! Unique name for segmented fusion |
334 | int segmented_fusion_name_; |
335 | |
336 | //! States representing segmentation |
337 | std::vector<SegmentedEdge*> edges_; |
338 | std::vector<SegmentedGroup*> groups_; |
339 | |
340 | //! Owning object to explicitly manage groups and edges |
341 | class Impl { |
342 | public: |
343 | explicit Impl(SegmentedFusion* sf) : owning_fusion_(sf) {} |
344 | |
345 | SegmentedGroup* makeGroup(); |
346 | SegmentedGroup* makeGroup(Expr*); |
347 | SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); |
348 | void cleanUnused(); |
349 | |
350 | private: |
351 | using GroupPtr = std::unique_ptr<SegmentedGroup>; |
352 | using EdgePtr = std::unique_ptr<SegmentedEdge>; |
353 | std::vector<GroupPtr> groups_; |
354 | std::vector<EdgePtr> edges_; |
355 | SegmentedFusion* owning_fusion_; |
356 | }; |
357 | Impl impl_; |
358 | |
359 | //! A Copy of original full fusion |
360 | std::unique_ptr<Fusion> complete_fusion_; |
361 | |
362 | //! A set of intermediate tensors that need to be cast to fp16 |
363 | std::unordered_set<TensorView*> force_fp16_tv_set_; |
364 | |
365 | DataType force_half_precision_type_; |
366 | |
367 | //! Static traversal information to be used for fast heuristics lookup |
368 | std::unordered_map<SegmentedGroup*, std::unique_ptr<HeuristicSummary>> |
369 | heuristic_summary_cache_; |
370 | |
371 | // TODO: this class needs cleanup |
372 | protected: |
373 | friend class SegmentCandidateFinder; |
374 | //! Make a heuristics entry for a group and parameters |
375 | std::unique_ptr<SchedulerEntry> makeInitialSchedulerEntry( |
376 | SegmentedGroup* sg, |
377 | SchedulerRuntimeInfo& runtime_info); |
378 | |
379 | //! Cleanup function to be call at the end of fusion |
380 | //! segment pass |
381 | void finalize(); |
382 | |
383 | //! Collect all the intermediate tensors between segmented |
384 | //! groups that will cast to fp16 |
385 | void annotateFP16IntermediateTensors(); |
386 | |
387 | //! Keep heuristic checking intermediate data |
388 | void setCachedHeuristicDataFor( |
389 | SegmentedGroup* group, |
390 | std::unique_ptr<HeuristicSummary> data); |
391 | |
392 | //! Utility to give unique name for each segmented fusion |
393 | static size_t segmentedFusionName() { |
394 | static size_t counter = 0; |
395 | return counter++; |
396 | } |
397 | }; |
398 | |
//! This is a base class for segmenter analysis
//! provides the minimal implementation on header so that
//! a unique_ptr can use this base class
//! actual implementations of analyses are in the .cpp files
//! TODO: In the next refactor PR, should put segment candidate
//! finder in .cpp file completely since API doesn't require these
//! details
class SegmenterAnalysis : public PolymorphicBase {};
//! Analysis tracking dependencies between groups (implemented in .cpp)
class GroupDependencyAnalysis;

// Manual node merging passes
class CombineReductions;
411 | |
412 | //! Options to configure/debug candidate finder |
//! Options to configure/debug candidate finder
struct TORCH_CUDA_CU_API SegmentCandidateFinderOptions {
  //! Run the Welford translation pass
  //! (see SegmentCandidateFinder::TranslateWelfordInFusion)
  bool run_translate_welford = true;
  //! Run the CombineReductions manual merging pass
  bool run_combine_reductions = true;
  //! Run the Herrmann et al. based merge iteration
  bool run_herrmann_merge = true;
  //! Run the brute-force clean-up merge iteration
  //! (see SegmentCandidateFinder::finalMerge)
  bool run_final_merge = true;
};
419 | |
420 | //! SegmentCandidateFinder |
421 | //! Responsible for going through DAG and proposing things we could try to |
422 | //! fuse together, calls "canGenerateCode" on these proposed segments to see |
423 | //! if they are valid and we can generate code for them. |
424 | //! FusionSegment |
425 | //! A group of exprs that are segmented together |
426 | //! FusionSegmentConnections |
427 | //! Holds vals and what they connect. In other words it's a val that is an |
428 | //! output of a FusionSegment "from" and an input of FusionSegment "to". |
429 | //! There's nothing preventing from a val being between segments twice. |
430 | //! TODO: make sure there's nothing wrong with segmentation on nodes that |
431 | //! have the same value input twice. i.e. (B = A*A) |
//! Selecting segments to propose is based on Theorem 4.2 of the paper below,
//! which guarantees that the segmented graph remains a DAG after segmentation
//! (assuming the Fusion is already a DAG). The segmentation code relies on
//! this DAG-ness assumption throughout, meaning any proposed merging of
//! groups must maintain the DAG property of the graph.
437 | //! |
438 | //! Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek. |
439 | //! Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs. |
440 | //! SIAM Journal on Scientific Computing, Society for Industrial and Applied |
441 | //! Mathematics, 2019, 41 (4), pp.A2117-A2145. ff10.1137/18M1176865ff. |
442 | //! ffhal02306566f |
443 | class TORCH_CUDA_CU_API SegmentCandidateFinder { |
444 | public: |
445 | // Perform segmentation on a copy of the given fusion |
446 | static std::unique_ptr<SegmentedFusion> segment( |
447 | const Fusion* fusion, |
448 | const KernelArgumentHolder& inputs, |
449 | SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { |
450 | auto fusion_copy = std::make_unique<Fusion>(*fusion); |
451 | if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { |
452 | std::cout << "Segment the fusion (Original Fusion Un-modified): " |
453 | << std::endl; |
454 | fusion_copy->printMath(); |
455 | } |
456 | SegmentCandidateFinder scf(std::move(fusion_copy), inputs, options); |
457 | return std::move(scf.segmented_fusion_); |
458 | } |
459 | |
460 | // Perform segmentation on and take ownership of the given fusion |
461 | static std::unique_ptr<SegmentedFusion> segment( |
462 | std::unique_ptr<Fusion> fusion, |
463 | const KernelArgumentHolder& inputs, |
464 | SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) { |
465 | SegmentCandidateFinder scf(std::move(fusion), inputs, options); |
466 | if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { |
467 | std::cout << "Segment the fusion (Original Fusion Un-modified): " |
468 | << std::endl; |
469 | scf.completeFusion()->printMath(); |
470 | } |
471 | return std::move(scf.segmented_fusion_); |
472 | } |
473 | |
474 | static bool TranslateWelfordInFusion( |
475 | Fusion* fusion, |
476 | const KernelArgumentHolder& runtime_inputs); |
477 | |
478 | private: |
479 | // Perform segmentation on and take ownership of the given fusion |
480 | SegmentCandidateFinder( |
481 | std::unique_ptr<Fusion> fusion, |
482 | const KernelArgumentHolder& inputs, |
483 | SegmentCandidateFinderOptions options); |
484 | |
485 | void resetTraversal(); |
486 | |
487 | void resetLevels(); |
488 | |
489 | SegmentedGroup* mergeNodes(); |
490 | |
491 | bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2); |
492 | |
493 | void findSegments(); |
494 | |
495 | std::unordered_set<SegmentedEdge*> disconnectGroup(SegmentedGroup* group); |
496 | |
497 | std::vector<SegmentedGroup*>& groups() { |
498 | TORCH_INTERNAL_ASSERT( |
499 | segmented_fusion_ != nullptr, "Segment finder not owinging any fusion" ); |
500 | return segmented_fusion_->groups(); |
501 | } |
502 | |
503 | std::vector<SegmentedEdge*>& edges() { |
504 | TORCH_INTERNAL_ASSERT( |
505 | segmented_fusion_ != nullptr, "Segment finder not owinging any fusion" ); |
506 | return segmented_fusion_->edges(); |
507 | } |
508 | |
509 | Fusion* completeFusion() { |
510 | TORCH_INTERNAL_ASSERT( |
511 | segmented_fusion_ != nullptr, "Segment finder not owinging any fusion" ); |
512 | return segmented_fusion_->completeFusion(); |
513 | } |
514 | |
515 | SchedulerRuntimeInfo& runtimeInfo() { |
516 | return runtime_info_; |
517 | } |
518 | |
519 | ExpressionEvaluator& expressionEvaluator() { |
520 | return runtime_info_.expressionEvaluator(); |
521 | } |
522 | |
523 | //! Additional merging iteration, clean up the rest of |
524 | //! the merging opportunities |
525 | //! Herrmann et al. is a fast and safe algorithm for finding merge candidates |
526 | //! but can become too conservative in our use cases because we place |
527 | //! additional qualifiers on valid merges other than having to generate DAGs, |
528 | //! i.e. canSchedule. So we need a bruteforce final merging iteration as a |
529 | //! clean up pass. Cost isn't expected to be high since the graph at this |
530 | //! stage is already quite merged. Example cf. test_gpu.cpp: |
531 | //! FusionDAGMerging_CUDA |
532 | //! |
533 | //! This merging algorithm is based on Theorem 4.1 of Herrmann et al., |
534 | //! to check if a producer-consumer pair can be merged into one group, |
535 | //! it's enough to check if any other consumer of the producer also |
536 | //! produces the consumer. |
537 | void finalMerge(); |
538 | |
539 | //! Duplicate and add all exprs producing the used |
540 | //! scalar values in group |
541 | void resolveScalarsInGroup(SegmentedGroup* group); |
542 | |
543 | //! Duplicate and add all exprs from "inputs" in the group, to complete |
544 | //! inputs. These expressions are simply unary ops of inputs that we want to |
545 | //! recompute for each segment, instead of computing and producing a segmented |
546 | //! val. For example if we have: |
547 | //! tv1 = tv0 * 2; |
548 | //! tv3 = tv1 + tv2; |
549 | //! tv4 = tv1 + tv4 |
550 | //! If we segmented on tv1, we would be producing an output for tv1 for 2 |
551 | //! groups that have tv3 or tv4, instead we could easily recompute tv1 from |
552 | //! tv0. |
553 | void resolveInputsInGroup(SegmentedGroup* group); |
554 | |
555 | //! Remove all scalar edges in group |
556 | //! (TODO: need structure better so we don't have to do this) |
557 | void removeScalarEdges(); |
558 | |
559 | //! Utility function to merge a vector of groups in one step, |
560 | //! need to check for DAG condition before using this method |
561 | SegmentedGroup* mergeAllGivenGroups( |
562 | const std::vector<SegmentedGroup*>& groups); |
563 | |
564 | //! Utility to remove a group and corresponding edges |
565 | //! TODO: remove inline versions of this as much as possible |
566 | void eraseGroups(std::unordered_set<SegmentedGroup*>& groups_to_erase); |
567 | |
568 | void finalize(); |
569 | |
570 | //! Return the resulting heuristic corresponding to the merged |
571 | //! group built by merging the two groups connected by edge |
572 | ScheduleHeuristic deriveHeuristic(SegmentedGroup* edge); |
573 | |
574 | GroupDependencyAnalysis* getGroupDependency(); |
575 | |
576 | protected: |
577 | //! These are the merge node heuristic passes, should |
578 | //! eventually should have a dedicated interface |
579 | //! instead of keeping adding friends |
580 | friend class CombineReductions; |
581 | |
582 | //! options to configure and debug the segment process |
583 | SegmentCandidateFinderOptions options_; |
584 | |
585 | std::deque<SegmentedGroup*> to_visit_; |
586 | std::vector<SegmentedGroup*> next_to_visit_; |
587 | |
588 | std::unordered_set<SegmentedGroup*> clean_up_groups_; |
589 | std::unordered_set<SegmentedEdge*> clean_up_edges_; |
590 | |
591 | std::vector<SegmentedGroup*> to_merge_; |
592 | |
593 | std::unique_ptr<SegmentedFusion> segmented_fusion_; |
594 | |
595 | std::unique_ptr<SegmenterAnalysis> group_dependency_; |
596 | |
597 | SchedulerRuntimeInfo runtime_info_; |
598 | |
599 | //! Note: |
600 | //! Segmenter should eventually rely only on runtime_info_ for |
601 | //! safe caching. runtime_inputs_ is only used in translateWelford |
602 | //! to initialize expression evaluators on copies of the original |
603 | //! fusion, which doesn't use any un-cached info and is safe. |
604 | //! |
605 | //! Directly using runtime_inputs_ in other cases is in general |
606 | //! risky. |
607 | //! |
608 | //! To get rid of runtime_inputs_ we need mechanisms |
609 | //! to copy expression evaluator values from fusion |
610 | //! to a copy, or even better to a copy of a |
611 | //! sub-graph of original fusion. |
612 | //! TODO: |
613 | //! implement the expression evaluator transfer and |
614 | //! remove runtime_inputs_ in a follow up. |
615 | const KernelArgumentHolder& runtime_inputs_; |
616 | }; |
617 | |
618 | // TODO: Make as member functions on classes instead of global scope |
619 | TORCH_CUDA_CU_API std::string toString(const SegmentedGroup* group); |
620 | TORCH_CUDA_CU_API std::string toString(const SegmentedEdge* edge); |
621 | TORCH_CUDA_CU_API std::string toString(const SegmentedFusion* segmented_fusion); |
622 | TORCH_CUDA_CU_API std::string toString( |
623 | const SegmentCandidateFinderOptions& segment_options); |
624 | |
625 | } // namespace cuda |
626 | } // namespace fuser |
627 | } // namespace jit |
628 | } // namespace torch |
629 | |