1#include <iter_visitor.h>
2
3#include <fusion.h>
4#include <ir_all_nodes.h>
5#include <ir_iostream.h>
6#include <ir_utils.h>
7#include <type.h>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14/* ITER VISITOR */
15
16namespace {
17
18// Remove any stmt in stmts that is in visited
19void remove_visited(
20 std::vector<Statement*>& stmts,
21 const std::unordered_set<Statement*>& visited) {
22 std::deque<std::vector<Statement*>::iterator> to_erase;
23 for (auto it = stmts.begin(); it != stmts.end(); it++) {
24 if (visited.find(*it) != visited.end()) {
25 to_erase.push_back(it);
26 }
27 }
28
29 while (!to_erase.empty()) {
30 stmts.erase(to_erase.back());
31 to_erase.pop_back();
32 }
33}
34
35class MemberStatements : public OptOutDispatch {
36 public:
37 // Return all members of the stmt if it's a Val. For expressions it returns
38 // nothing.
39 static std::vector<Statement*> next(Statement* stmt) {
40 MemberStatements find_next(stmt);
41 return find_next.next_stmts_;
42 }
43
44 private:
45 MemberStatements() = default;
46
47 MemberStatements(Statement* stmt) {
48 handle(stmt);
49 }
50
51 using OptOutDispatch::handle;
52
53 void handle(Val* val) final {
54 FusionGuard::getCurFusion()->assertInContainer(
55 val,
56 "IterVisitor.cpp::MemberStatements::handle(Val*) Cannot traverse val, ");
57 OptOutDispatch::handle(val);
58 }
59
60 void handle(IterDomain* stmt) final {
61 next_stmts_.push_back(stmt->start());
62 next_stmts_.push_back(stmt->extent());
63 next_stmts_.push_back(stmt->stopOffset());
64 }
65
66 void handle(TensorDomain* stmt) final {
67 next_stmts_.insert(
68 next_stmts_.end(), stmt->domain().begin(), stmt->domain().end());
69 }
70
71 void handle(TensorView* tv) final {
72 next_stmts_.push_back(tv->domain());
73 }
74
75 std::vector<Statement*> next_stmts_;
76};
77
78} // namespace
79
80std::vector<Statement*> IterVisitor::next(Statement* stmt) {
81 if (stmt->isVal()) {
82 return next(stmt->as<Val>());
83 } else {
84 return next(stmt->as<Expr>());
85 }
86}
87
88std::vector<Statement*> IterVisitor::next(Val* v) {
89 FusionGuard::getCurFusion()->assertInContainer(v, "Cannot traverse val, ");
90 if (v->definition() != nullptr) {
91 return {v->definition()};
92 }
93 return {};
94}
95
96std::vector<Statement*> IterVisitor::next(Expr* expr) {
97 FusionGuard::getCurFusion()->assertInContainer(
98 expr, "Cannot traverse expr, ");
99 std::vector<Statement*> next_stmts{
100 expr->inputs().begin(), expr->inputs().end()};
101 return next_stmts;
102}
103
104// This handle functions is called on every Statement* in topological order,
105// starting from outputs to inputs.
106void IterVisitor::handle(Statement* s) {
107 OptOutDispatch::handle(s);
108}
109
110// This handle functions is called on every Expr* in topological order,
111// starting from outputs to inputs.
112void IterVisitor::handle(Expr* e) {
113 OptOutDispatch::handle(e);
114}
115
116// This handle functions is called on every Val* in topological order,
117// starting from outputs to inputs.
118void IterVisitor::handle(Val* v) {
119 OptOutDispatch::handle(v);
120}
121
// Implementation details:
// We start with an entry in stmt_stack that holds the outputs (`to`) we want
// to process. We cannot process these outputs until all Stmts in their
// history have been processed, as those Stmts contain all dependencies needed
// to produce these values. So we traverse towards the inputs until we hit a
// leaf node (a statement with nothing left to visit). Once we hit a leaf, that
// node is visited and taken off the stack. Once a stack entry (level) is
// empty, we know everything needed to visit stmt_stack.back().back() has been
// visited. We then visit that node, mark it as visited and remove it from the
// stack.
//
// The definitions of Vals in `from` are not followed, so traversal stops
// there. With traverse_into_members, their member statements are still
// followed.
//
// To prevent traversing all paths through a DAG (unless traverse_all_paths is
// set), already-visited nodes are removed before they can be re-added to the
// stack (remove_visited), and a node that was already handled is not handled
// again.
void IterVisitor::traverseBetween(
    Fusion* fusion,
    const std::unordered_set<Val*>& from,
    const std::vector<Val*>& to,
    bool traverse_all_paths,
    bool traverse_into_members) {
  FusionGuard fg(fusion);

  std::unordered_set<Statement*> visited;

  stmt_stack.clear();
  // Pushed reversed so that to[0] ends up at the back and is processed first.
  stmt_stack.emplace_back(to.rbegin(), to.rend());

  // True when the statement at stmt_stack.back().back() has had everything it
  // depends on processed and can be handled.
  bool all_inputs_visited = false;

  while (!stmt_stack.empty()) {
    auto& current_inputs = stmt_stack.back();

    // If current_inputs is empty, pop a level of the stmt_stack and mark the
    // level we pop to as having all inputs processed: the layer we just
    // finished held all the inputs required for that level's last Stmt.
    if (current_inputs.empty()) {
      stmt_stack.pop_back();
      all_inputs_visited = true;
      continue;
    }

    // Get the very last entry in the stack to process
    const auto& stmt = current_inputs.back();

    // If we just popped a stmt_stack level (or stmt had nothing left to
    // visit), we can finally visit it!
    if (all_inputs_visited) {
      // stmt may have already been visited.
      if (traverse_all_paths || visited.find(stmt) == visited.end()) {
        // Mark visited
        visited.insert(stmt);

        // Actually visit stmt
        handle(stmt);
      }

      // Remove last value just visited
      current_inputs.pop_back();

      // Mark that we need to visit a new Stmt.
      all_inputs_visited = false;
    } else {
      // We're not ready to process this node, so collect everything it
      // depends on and check those first.
      std::vector<Statement*> next_stmts;

      // Don't follow the definition of a Val that is in `from`; Exprs are
      // always expanded into their inputs.
      if ((stmt->isVal() && from.find(stmt->asVal()) == from.end()) ||
          stmt->isExpr()) {
        next_stmts = next(stmt);
      }

      if (traverse_into_members) {
        auto members = MemberStatements::next(stmt);
        next_stmts.insert(next_stmts.end(), members.begin(), members.end());
      }

      // We may want to retraverse nodes, in that case revisit everything!
      if (!traverse_all_paths) {
        // If we don't want to retraverse, remove nodes we already visited.
        remove_visited(next_stmts, visited);
      }
      if (next_stmts.empty()) {
        // If there's nothing to visit because it was all already visited, mark
        // to process
        all_inputs_visited = true;
      } else {
        // Add all these new stmts to visit to the stack (reversed, so that
        // next_stmts[0] is processed first).
        stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
        // We have new things to visit,
        all_inputs_visited = false;
      }
    }
  }
}
214
215void IterVisitor::traverseTo(
216 Fusion* fusion,
217 const std::vector<Val*>& to,
218 bool traverse_all_paths,
219 bool traverse_into_members) {
220 traverseBetween(fusion, {}, to, traverse_all_paths, traverse_into_members);
221}
222
223void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) {
224 FusionGuard fg(fusion);
225
226 auto term_val_outs = fusion->getTerminatingOutputs();
227 if (!term_val_outs.empty()) {
228 traverseTo(fusion, term_val_outs, traverse_all_paths);
229 }
230}
231
232void IterVisitor::traverse(Fusion* fusion) {
233 traverseHelper(fusion, false);
234}
235
236void IterVisitor::traverseAllPaths(Fusion* fusion) {
237 traverseHelper(fusion, true);
238}
239
240namespace {
241
242// TODO: Also have InputsOf should pick one and remove the other.
243class Inputs : public IterVisitor {
244 private:
245 //! Optional list of input vals. While traversing to inputs if a value in the
246 //! all_inputs list is found, that value will be added to the inputs_ and
247 //! traversal will not go into its definition. Otherwise traversal follows
248 //! definition paths until hitting a definition that is a nullptr (i.e. a
249 //! terminating input).
250 const std::vector<Val*>& all_inputs_;
251 std::vector<Val*> inputs_;
252
253 Inputs(const std::vector<Val*>& all_inputs) : all_inputs_(all_inputs) {}
254
255 std::vector<Statement*> next(Val* v) override {
256 if (std::find(inputs_.begin(), inputs_.end(), v) != inputs_.end()) {
257 return {};
258 }
259 return IterVisitor::next(v);
260 }
261
262 void handle(Val* val) override {
263 // If there's no definition to val, or val is created inside the fusion, or
264 // val is within the provided inputs
265 if (val->definition() == nullptr || val->definition()->inputs().empty() ||
266 std::find(all_inputs_.begin(), all_inputs_.end(), val) !=
267 all_inputs_.end()) {
268 // if not already placed in the inputs
269 if (std::find(inputs_.begin(), inputs_.end(), val) == inputs_.end()) {
270 inputs_.push_back(val);
271 }
272 }
273 }
274
275 public:
276 static std::vector<Val*> getInputs(
277 const std::vector<Val*>& of,
278 const std::vector<Val*>& all_inputs) {
279 if (of.empty()) {
280 return {};
281 }
282 Inputs inps(all_inputs);
283 inps.traverseTo(of[0]->fusion(), of);
284 return inps.inputs_;
285 }
286};
287
288} // namespace
289
290std::vector<Val*> IterVisitor::getInputsTo(
291 const std::vector<Val*>& vals,
292 const std::vector<Val*>& inputs) {
293 return Inputs::getInputs(vals, inputs);
294}
295
296namespace {
297
298class AllVals : public IterVisitor {
299 private:
300 std::unordered_set<Val*> vals;
301
302 void handle(Val* val) final {
303 vals.emplace(val);
304 }
305
306 public:
307 // Return all values in history of all values in from
308 static std::unordered_set<Val*> get(
309 Fusion* fusion,
310 const std::vector<Val*>& from) {
311 AllVals av;
312 av.traverseTo(fusion, from, false);
313 return av.vals;
314 }
315};
316
317} // namespace
318
319/* BACKWARDS VISITOR */
320
321std::vector<Statement*> BackwardVisitor::next(Statement* stmt) {
322 if (stmt->isVal()) {
323 return next(stmt->as<Val>());
324 } else if (stmt->isExpr()) {
325 return next(stmt->as<Expr>());
326 } else {
327 TORCH_INTERNAL_ASSERT(
328 false, "BackwardVisitor could not detect type in next_dispatch.");
329 }
330}
331
332std::vector<Statement*> BackwardVisitor::next(Expr* expr) {
333 return std::vector<Statement*>(
334 expr->outputs().begin(), expr->outputs().end());
335}
336
337std::vector<Statement*> BackwardVisitor::next(Val* val) {
338 // Going to sort based on relative topological position
339 std::map<size_t, Statement*> exprs;
340
341 for (auto expr : FusionGuard::getCurFusion()->unordered_uses(val)) {
342 // Make sure it's an expr we can traverse
343 if (traversal_exprs_.find(expr) != traversal_exprs_.end()) {
344 exprs[traversal_exprs_[expr]] = expr;
345 }
346 }
347
348 std::vector<Statement*> next_stmts(exprs.size());
349 std::transform(
350 exprs.begin(),
351 exprs.end(),
352 next_stmts.begin(),
353 [](std::pair<size_t, Statement*> pair) { return pair.second; });
354
355 return next_stmts;
356}
357
358void BackwardVisitor::handle(Statement* stmt) {
359 OptOutDispatch::handle(stmt);
360}
361
362void BackwardVisitor::handle(Expr* expr) {
363 OptOutDispatch::handle(expr);
364}
365
366void BackwardVisitor::handle(Val* val) {
367 OptOutDispatch::handle(val);
368}
369
370void BackwardVisitor::traverseTo(
371 Fusion* fusion,
372 const std::vector<Val*>& from,
373 bool traverseAllPaths) {
374 FusionGuard fg(fusion);
375
376 // Reset members
377 stmt_stack_.clear();
378 traversal_exprs_.clear();
379
380 if (from.empty()) {
381 return;
382 }
383
384 auto vals = AllVals::get(fusion, from);
385 auto exprs = StmtSort::getExprs(fusion, from);
386
387 {
388 size_t pos = 0;
389 for (auto expr : exprs)
390 traversal_exprs_[expr] = pos++;
391 }
392
393 // All stmts we've called handle on
394 std::unordered_set<Statement*> visited_stmts_;
395
396 if (must_cover_all_expr_outputs_) {
397 for (auto traversal_pair : traversal_exprs_) {
398 for (auto out : traversal_pair.first->outputs()) {
399 TORCH_INTERNAL_ASSERT(
400 vals.find(out) != vals.end(),
401 "Invalid backward traversal found. Some output paths were not provided:",
402 out);
403 }
404 }
405 }
406
407 auto inputs = InputsOf::getInputsTo(from);
408 stmt_stack_.emplace_back(inputs.begin(), inputs.end());
409
410 // The rest is basically copy-pasted from IterVitor:
411 while (!stmt_stack_.empty()) {
412 auto next_stmts = next(stmt_stack_.back().back());
413
414 // Remove statements we already visited if we're not traversing all paths
415 if (!traverseAllPaths) {
416 remove_visited(next_stmts, visited_stmts_);
417 }
418
419 // Traverse down until we get to a leaf
420 while (!next_stmts.empty()) {
421 stmt_stack_.emplace_back(next_stmts.rbegin(), next_stmts.rend());
422 next_stmts = next(stmt_stack_.back().back());
423 // Remove statements we already visited if we're not traversing all paths
424 if (!traverseAllPaths) {
425 remove_visited(next_stmts, visited_stmts_);
426 }
427 }
428
429 // Traverse back up
430 // Mark visited
431 visited_stmts_.emplace(stmt_stack_.back().back());
432 // Handle
433 handle(stmt_stack_.back().back());
434 // Remove
435 stmt_stack_.back().pop_back();
436
437 while (!stmt_stack_.empty() && stmt_stack_.back().empty()) {
438 stmt_stack_.pop_back();
439 if (!stmt_stack_.empty()) {
440 // Mark visited
441 visited_stmts_.emplace(stmt_stack_.back().back());
442 // Handle
443 handle(stmt_stack_.back().back());
444 // Remove
445 stmt_stack_.back().pop_back();
446 }
447 }
448 }
449}
450
451/* DEPENDENCY CHECKING */
452
453namespace {
454
455// Looks for and returns all values in between dependencies and vals, including
456// them.
457struct Dependencies : public IterVisitor {
458 private:
459 //! A given set of dependency Vals
460 const std::unordered_set<Val*> dependencies_;
461 //! Vals that are found between dependencies_ and of. Topologically
462 //! ordered.
463 std::vector<Val*> vals_;
464 //! Exprs that are found between dependencies_ and of. Topologically
465 //! ordered.
466 std::vector<Expr*> exprs_;
467 //! A set version of vals_
468 std::unordered_set<Val*> dependent_vals_;
469 //! A set version of exprs_
470 std::unordered_set<Expr*> dependent_exprs_;
471
472 private:
473 std::vector<Statement*> next(Val* v) override {
474 if (dependencies_.find(v) != dependencies_.end()) {
475 return std::vector<Statement*>();
476 }
477 return IterVisitor::next(v);
478 }
479
480 void handle(Val* val) override {
481 // val is included if:
482 // 1. it is one of the dependencies, or
483 // 2. its defining expression is included in the dependent expr set
484 if (dependencies_.find(val) != dependencies_.end()) {
485 TORCH_INTERNAL_ASSERT(
486 dependent_vals_.find(val) == dependent_vals_.end(),
487 "Trying to add already added val: ",
488 val);
489 vals_.push_back(val);
490 dependent_vals_.insert(val);
491 } else {
492 auto def = val->definition();
493 if (def != nullptr &&
494 dependent_exprs_.find(def) != dependent_exprs_.end()) {
495 TORCH_INTERNAL_ASSERT(
496 dependent_vals_.find(val) == dependent_vals_.end(),
497 "Trying to add already added val: ",
498 val);
499 vals_.push_back(val);
500 dependent_vals_.insert(val);
501 }
502 }
503 }
504
505 void handle(Expr* expr) override {
506 // Track which expr is depedent on the dependencies_ exprs.
507 if (std::any_of(
508 expr->inputs().begin(), expr->inputs().end(), [&](Val* input_val) {
509 return dependent_vals_.find(input_val) != dependent_vals_.end();
510 })) {
511 if (!dependent_exprs_.count(expr)) {
512 exprs_.push_back(expr);
513 dependent_exprs_.insert(expr);
514 }
515 }
516 }
517
518 Dependencies(
519 std::unordered_set<Val*> _dependencies,
520 const std::vector<Val*>& of)
521 : dependencies_(std::move(_dependencies)) {
522 traverseTo(of[0]->fusion(), of, false);
523 };
524
525 public:
526 static std::vector<Val*> getAllVals(
527 const std::unordered_set<Val*>& dependencies,
528 const std::vector<Val*>& of) {
529 if (of.empty()) {
530 return {};
531 }
532
533 Dependencies deps(dependencies, of);
534 return deps.vals_;
535 }
536
537 static std::vector<Expr*> getAllExprs(
538 const std::unordered_set<Val*>& dependencies,
539 const std::vector<Val*>& of) {
540 if (of.empty()) {
541 return {};
542 }
543
544 Dependencies deps(dependencies, of);
545 return deps.exprs_;
546 }
547};
548
549// Looks for and returns all output values with dependencies on `of`.
550struct FindOutputs : public IterVisitor {
551 const std::unordered_set<Val*>& of_;
552 std::unordered_set<Val*> outs_;
553
554 void handle(Val* val) override {
555 if (of_.find(val) != of_.end()) {
556 Statement* out_stmt = stmt_stack.front().back();
557 TORCH_INTERNAL_ASSERT(out_stmt->isVal());
558 auto out_val = out_stmt->as<Val>();
559 if (of_.find(out_val) == of_.end()) {
560 outs_.emplace(out_val);
561 }
562 }
563 }
564
565 // TODO: Simply traverse through uses from of. Would be a lot faster than
566 // tracing all paths like this.
567 FindOutputs(const std::unordered_set<Val*>& _of) : of_(_of) {
568 auto fusion = (*of_.begin())->fusion();
569 traverseTo(fusion, fusion->outputs(), true);
570 };
571
572 static std::unordered_set<Val*> getAllOutputsOf(
573 const std::unordered_set<Val*>& of) {
574 if (of.empty()) {
575 return std::unordered_set<Val*>();
576 }
577
578 FindOutputs finder(of);
579 return finder.outs_;
580 }
581};
582
583// Looks for and returns all values that depends on `of`.
584class DependentVals : public IterVisitor {
585 private:
586 // Which nodes to find dependencies of
587 const std::unordered_set<Val*>& of_;
588
589 // Dependencies we have so far
590 std::unordered_set<Val*> outs_;
591
592 // Boundary where we want to stop searching beyond
593 // TODO: Based on the todo below, shouldn't we stop just at the definition of?
594 // If we really wanted to make this traverse left, wouldn't we first check
595 // which outputs are outputs dependent on of?
596 std::unordered_set<Val*> boundary_;
597
598 std::vector<Statement*> next(Val* v) override {
599 if (boundary_.find(v) != boundary_.end())
600 return std::vector<Statement*>();
601 return IterVisitor::next(v);
602 }
603
604 void handle(Val* val) override {
605 if (val->isFusionInput() || val->definition() == nullptr ||
606 of_.count(val) || outs_.count(val)) {
607 return;
608 }
609
610 for (auto v : val->definition()->inputs()) {
611 if (of_.count(v) || outs_.count(v)) {
612 outs_.emplace(val);
613 return;
614 }
615 }
616 }
617
618 // optimization to limit search path
619 // TODO: Is this valid? Couldn't something like:
620 // out0 = of + val0
621 // out1 = out0 + val1
622 // out2 = TernaryOp(out1, val0, of)
623 // Hide the dep of out1 on of?
624 void createBoundary() {
625 for (auto v_of : of_) {
626 for (auto v_expr : v_of->uses()) {
627 for (auto v_in : v_expr->inputs()) {
628 boundary_.emplace(v_in);
629 }
630 }
631 }
632 }
633
634 DependentVals(const std::unordered_set<Val*>& _of) : of_(_of) {
635 createBoundary();
636 auto fusion = (*of_.begin())->fusion();
637 traverseTo(fusion, fusion->outputs(), false);
638 };
639
640 public:
641 static std::unordered_set<Val*> getAllDependentVals(
642 const std::unordered_set<Val*>& of) {
643 if (of.empty()) {
644 return std::unordered_set<Val*>();
645 }
646 DependentVals dependencies(of);
647 return dependencies.outs_;
648 }
649};
650
651class DependencyChains : public IterVisitor {
652 public:
653 std::deque<std::deque<Val*>> dep_chains;
654 bool is_dependency = false;
655 std::unordered_set<Val*> dependencies_;
656
657 void handle(Val* val) override {
658 if (dependencies_.find(val) != dependencies_.end()) {
659 is_dependency = true;
660 std::deque<Val*> deps;
661 for (auto stack : stmt_stack) {
662 if (stack.back()->isVal()) {
663 deps.push_back(stack.back()->as<Val>());
664 }
665 }
666 // Order as dependency -> of
667 dep_chains.emplace_back(deps.rbegin(), deps.rend());
668 }
669 }
670
671 DependencyChains(Val* _dependency, Val* _of, bool all_chains_ = false)
672 : dependencies_({_dependency}) {
673 traverseTo(_of->fusion(), {_of}, all_chains_);
674 }
675
676 DependencyChains(Val* _dependency, bool all_chains_ = false)
677 : dependencies_({_dependency}) {
678 if (all_chains_) {
679 traverseAllPaths(_dependency->fusion());
680 } else {
681 traverse(_dependency->fusion());
682 }
683 }
684
685 DependencyChains(
686 std::unordered_set<Val*> _dependencies,
687 bool all_chains_ = false)
688 : dependencies_(std::move(_dependencies)) {
689 if (dependencies_.empty()) {
690 return;
691 }
692
693 if (all_chains_) {
694 traverseAllPaths((*dependencies_.begin())->fusion());
695 } else {
696 traverse((*dependencies_.begin())->fusion());
697 }
698 }
699
700 static std::deque<Val*> getDependencyChain(Val* dependency, Val* of) {
701 DependencyChains dp(dependency, of, false);
702 if (dp.dep_chains.empty()) {
703 return std::deque<Val*>();
704 }
705 return dp.dep_chains[0];
706 }
707
708 // I don't think this is actually hooked up, but leaving for now.
709 static std::deque<std::deque<Val*>> getDependencyChains(
710 Val* dependency,
711 Val* of) {
712 DependencyChains dp(dependency, of, true);
713 if (dp.dep_chains.empty()) {
714 return std::deque<std::deque<Val*>>();
715 }
716 return dp.dep_chains;
717 }
718
719 static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency) {
720 DependencyChains dp(dependency, true);
721 if (dp.dep_chains.empty()) {
722 return std::deque<std::deque<Val*>>();
723 }
724 return dp.dep_chains;
725 }
726
727 static std::deque<std::deque<Val*>> getAllUseChains(
728 const std::unordered_set<Val*>& dependencies) {
729 DependencyChains dp(dependencies, true);
730 if (dp.dep_chains.empty()) {
731 return std::deque<std::deque<Val*>>();
732 }
733 return dp.dep_chains;
734 }
735};
736
737} // namespace
738
739bool DependencyCheck::isDependencyOf(Val* dependency, Val* of) {
740 return !DependencyChains::getDependencyChain(dependency, of).empty();
741}
742
743std::deque<Val*> DependencyCheck::getSingleDependencyChain(
744 Val* dependency,
745 Val* of) {
746 return DependencyChains::getDependencyChain(dependency, of);
747}
748
749std::deque<std::deque<Val*>> DependencyCheck::getAllDependencyChains(
750 Val* dependency,
751 Val* of) {
752 return DependencyChains::getDependencyChains(dependency, of);
753}
754
755std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) {
756 return DependencyChains::getAllUseChains(producer);
757}
758
759std::vector<Val*> DependencyCheck::getAllValsBetween(
760 const std::unordered_set<Val*>& dependencies,
761 const std::vector<Val*>& of) {
762 return Dependencies::getAllVals(dependencies, of);
763}
764
765std::vector<Expr*> DependencyCheck::getAllExprsBetween(
766 const std::unordered_set<Val*>& dependencies,
767 const std::vector<Val*>& of) {
768 return Dependencies::getAllExprs(dependencies, of);
769}
770
771std::unordered_set<Val*> DependencyCheck::getAllOutputsOf(
772 const std::unordered_set<Val*>& of) {
773 if (of.empty()) {
774 return std::unordered_set<Val*>();
775 }
776 FusionGuard fg((*of.begin())->fusion());
777 return FindOutputs::getAllOutputsOf(of);
778}
779
780std::unordered_set<Val*> DependencyCheck::getAllDependentVals(
781 const std::unordered_set<Val*>& of) {
782 if (of.empty()) {
783 return std::unordered_set<Val*>();
784 }
785 FusionGuard fg((*of.begin())->fusion());
786 return DependentVals::getAllDependentVals(of);
787}
788
789void StmtSort::handle(Statement* stmt) {
790 stmts.push_back(stmt);
791}
792
793std::vector<Expr*> StmtSort::getExprs(Fusion* fusion, bool traverse_members) {
794 auto terminating_outputs = fusion->getTerminatingOutputs();
795 return StmtSort::getExprs(fusion, terminating_outputs, traverse_members);
796}
797
798std::vector<Expr*> StmtSort::getExprs(
799 Fusion* fusion,
800 const std::vector<Val*>& to,
801 bool traverse_members) {
802 auto stmts = StmtSort::getStmts(fusion, to, traverse_members);
803 auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end());
804 std::vector<Expr*> exprs(filter.begin(), filter.end());
805 return exprs;
806}
807
808std::vector<Expr*> StmtSort::getExprsBetween(
809 Fusion* fusion,
810 const std::vector<Val*>& from,
811 const std::vector<Val*>& to,
812 bool traverse_members) {
813 auto stmts = StmtSort::getStmtsBetween(fusion, from, to, traverse_members);
814 auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end());
815 std::vector<Expr*> exprs(filter.begin(), filter.end());
816 return exprs;
817}
818
819std::vector<Statement*> StmtSort::getStmts(
820 Fusion* fusion,
821 bool traverse_members) {
822 auto terminating_outputs = fusion->getTerminatingOutputs();
823 return StmtSort::getStmts(fusion, terminating_outputs, traverse_members);
824}
825
826std::vector<Statement*> StmtSort::getStmts(
827 Fusion* fusion,
828 const std::vector<Val*>& to,
829 bool traverse_members) {
830 StmtSort es;
831 es.traverseTo(fusion, to, false, traverse_members);
832 return es.stmts;
833}
834
835std::vector<Statement*> StmtSort::getStmtsBetween(
836 Fusion* fusion,
837 const std::vector<Val*>& from,
838 const std::vector<Val*>& to,
839 bool traverse_members) {
840 StmtSort es;
841 es.traverseBetween(
842 fusion, {from.begin(), from.end()}, to, false, traverse_members);
843 return es.stmts;
844}
845
846void InputsOf::handle(Val* v) {
847 if (v->definition() == nullptr || v->definition()->inputs().empty()) {
848 if (grabbed_inputs.emplace(v).second) {
849 ordered_inputs.push_back(v);
850 }
851 }
852}
853
854std::vector<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
855 return outputs(fusion, {output_});
856}
857
858std::vector<Val*> InputsOf::outputs(
859 Fusion* fusion,
860 const std::vector<Val*>& outputs_) {
861 InputsOf io;
862 io.traverseTo(fusion, outputs_, false);
863 return io.ordered_inputs;
864}
865
866} // namespace cuda
867} // namespace fuser
868} // namespace jit
869} // namespace torch
870