#include <iter_visitor.h>

#include <fusion.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <type.h>

#include <algorithm>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | /* ITER VISITOR */ |
15 | |
16 | namespace { |
17 | |
18 | // Remove any stmt in stmts that is in visited |
19 | void remove_visited( |
20 | std::vector<Statement*>& stmts, |
21 | const std::unordered_set<Statement*>& visited) { |
22 | std::deque<std::vector<Statement*>::iterator> to_erase; |
23 | for (auto it = stmts.begin(); it != stmts.end(); it++) { |
24 | if (visited.find(*it) != visited.end()) { |
25 | to_erase.push_back(it); |
26 | } |
27 | } |
28 | |
29 | while (!to_erase.empty()) { |
30 | stmts.erase(to_erase.back()); |
31 | to_erase.pop_back(); |
32 | } |
33 | } |
34 | |
35 | class MemberStatements : public OptOutDispatch { |
36 | public: |
37 | // Return all members of the stmt if it's a Val. For expressions it returns |
38 | // nothing. |
39 | static std::vector<Statement*> next(Statement* stmt) { |
40 | MemberStatements find_next(stmt); |
41 | return find_next.next_stmts_; |
42 | } |
43 | |
44 | private: |
45 | MemberStatements() = default; |
46 | |
47 | MemberStatements(Statement* stmt) { |
48 | handle(stmt); |
49 | } |
50 | |
51 | using OptOutDispatch::handle; |
52 | |
53 | void handle(Val* val) final { |
54 | FusionGuard::getCurFusion()->assertInContainer( |
55 | val, |
56 | "IterVisitor.cpp::MemberStatements::handle(Val*) Cannot traverse val, " ); |
57 | OptOutDispatch::handle(val); |
58 | } |
59 | |
60 | void handle(IterDomain* stmt) final { |
61 | next_stmts_.push_back(stmt->start()); |
62 | next_stmts_.push_back(stmt->extent()); |
63 | next_stmts_.push_back(stmt->stopOffset()); |
64 | } |
65 | |
66 | void handle(TensorDomain* stmt) final { |
67 | next_stmts_.insert( |
68 | next_stmts_.end(), stmt->domain().begin(), stmt->domain().end()); |
69 | } |
70 | |
71 | void handle(TensorView* tv) final { |
72 | next_stmts_.push_back(tv->domain()); |
73 | } |
74 | |
75 | std::vector<Statement*> next_stmts_; |
76 | }; |
77 | |
78 | } // namespace |
79 | |
80 | std::vector<Statement*> IterVisitor::next(Statement* stmt) { |
81 | if (stmt->isVal()) { |
82 | return next(stmt->as<Val>()); |
83 | } else { |
84 | return next(stmt->as<Expr>()); |
85 | } |
86 | } |
87 | |
88 | std::vector<Statement*> IterVisitor::next(Val* v) { |
89 | FusionGuard::getCurFusion()->assertInContainer(v, "Cannot traverse val, " ); |
90 | if (v->definition() != nullptr) { |
91 | return {v->definition()}; |
92 | } |
93 | return {}; |
94 | } |
95 | |
96 | std::vector<Statement*> IterVisitor::next(Expr* expr) { |
97 | FusionGuard::getCurFusion()->assertInContainer( |
98 | expr, "Cannot traverse expr, " ); |
99 | std::vector<Statement*> next_stmts{ |
100 | expr->inputs().begin(), expr->inputs().end()}; |
101 | return next_stmts; |
102 | } |
103 | |
104 | // This handle functions is called on every Statement* in topological order, |
105 | // starting from outputs to inputs. |
106 | void IterVisitor::handle(Statement* s) { |
107 | OptOutDispatch::handle(s); |
108 | } |
109 | |
110 | // This handle functions is called on every Expr* in topological order, |
111 | // starting from outputs to inputs. |
112 | void IterVisitor::handle(Expr* e) { |
113 | OptOutDispatch::handle(e); |
114 | } |
115 | |
116 | // This handle functions is called on every Val* in topological order, |
117 | // starting from outputs to inputs. |
118 | void IterVisitor::handle(Val* v) { |
119 | OptOutDispatch::handle(v); |
120 | } |
121 | |
// Implementation details:
// We start with an entry in stmt_stack that holds the outputs we want to
// process (`to`, reversed, so that to[0] sits at the back and is processed
// first). We cannot process these outputs until all Stmts in their history
// have been processed, as those Stmts contain all dependencies to produce
// these values. So we traverse towards inputs (via next()) until we hit a
// leaf node, i.e. a Stmt with nothing left to traverse into. A leaf is
// visited and then taken off the stack. Once a stack entry (level) becomes
// empty, we know everything needed to visit stmt_stack.back().back() has been
// visited, so we visit that node, mark it as visited and remove it from the
// stack.
//
// next() is never called on a Val in `from`, so traversal does not go into
// the definition of those Vals. When traverse_into_members is set, member
// statements (see MemberStatements) are traversed as well.
//
// To prevent traversing all paths through a DAG (unless traverse_all_paths is
// set) we remove already visited nodes before they are re-added to the stack
// (remove_visited), and never call handle() twice on the same Stmt.
void IterVisitor::traverseBetween(
    Fusion* fusion,
    const std::unordered_set<Val*>& from,
    const std::vector<Val*>& to,
    bool traverse_all_paths,
    bool traverse_into_members) {
  FusionGuard fg(fusion);

  std::unordered_set<Statement*> visited;

  stmt_stack.clear();
  stmt_stack.emplace_back(to.rbegin(), to.rend());

  // True when everything the Stmt at stmt_stack.back().back() depends on has
  // already been visited, i.e. it is ready to be handled.
  bool all_inputs_visited = false;

  while (!stmt_stack.empty()) {
    auto& current_inputs = stmt_stack.back();

    // If current_inputs is empty, pop a level of the stmt_stack and mark the
    // level we pop to as having all inputs processed: the level we just
    // finished held all the inputs required by that level's last Stmt.
    if (current_inputs.empty()) {
      stmt_stack.pop_back();
      all_inputs_visited = true;
      continue;
    }

    // Get the very last entry in the stack to process
    const auto& stmt = current_inputs.back();

    // If we just popped a stmt_stack level, we can finally visit it!
    if (all_inputs_visited) {
      // stmt may have already been visited through another path.
      if (traverse_all_paths || visited.find(stmt) == visited.end()) {
        // Mark visited
        visited.insert(stmt);

        // Actually visit stmt
        handle(stmt);
      }

      // Remove the value just visited
      current_inputs.pop_back();

      // Mark that we need to visit a new Stmt.
      all_inputs_visited = false;
    } else {
      // We're not ready to process this node, so add all its inputs to the
      // stack to be visited first.
      std::vector<Statement*> next_stmts;

      // Stop at Vals in `from`: their definitions are not traversed.
      if ((stmt->isVal() && from.find(stmt->asVal()) == from.end()) ||
          stmt->isExpr()) {
        next_stmts = next(stmt);
      }

      if (traverse_into_members) {
        auto members = MemberStatements::next(stmt);
        next_stmts.insert(next_stmts.end(), members.begin(), members.end());
      }

      // We may want to retraverse nodes, in that case revisit everything!
      if (!traverse_all_paths) {
        // If we don't want to retraverse, remove nodes we already visited.
        remove_visited(next_stmts, visited);
      }
      if (next_stmts.empty()) {
        // If there's nothing to visit because it was all already visited,
        // mark stmt as ready to process.
        all_inputs_visited = true;
      } else {
        // Add all these new stmts to visit to the stack. Note: this may
        // invalidate `current_inputs`/`stmt`, which are not used afterwards.
        stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
        // We have new things to visit.
        all_inputs_visited = false;
      }
    }
  }
}
214 | |
215 | void IterVisitor::traverseTo( |
216 | Fusion* fusion, |
217 | const std::vector<Val*>& to, |
218 | bool traverse_all_paths, |
219 | bool traverse_into_members) { |
220 | traverseBetween(fusion, {}, to, traverse_all_paths, traverse_into_members); |
221 | } |
222 | |
223 | void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { |
224 | FusionGuard fg(fusion); |
225 | |
226 | auto term_val_outs = fusion->getTerminatingOutputs(); |
227 | if (!term_val_outs.empty()) { |
228 | traverseTo(fusion, term_val_outs, traverse_all_paths); |
229 | } |
230 | } |
231 | |
232 | void IterVisitor::traverse(Fusion* fusion) { |
233 | traverseHelper(fusion, false); |
234 | } |
235 | |
236 | void IterVisitor::traverseAllPaths(Fusion* fusion) { |
237 | traverseHelper(fusion, true); |
238 | } |
239 | |
240 | namespace { |
241 | |
242 | // TODO: Also have InputsOf should pick one and remove the other. |
243 | class Inputs : public IterVisitor { |
244 | private: |
245 | //! Optional list of input vals. While traversing to inputs if a value in the |
246 | //! all_inputs list is found, that value will be added to the inputs_ and |
247 | //! traversal will not go into its definition. Otherwise traversal follows |
248 | //! definition paths until hitting a definition that is a nullptr (i.e. a |
249 | //! terminating input). |
250 | const std::vector<Val*>& all_inputs_; |
251 | std::vector<Val*> inputs_; |
252 | |
253 | Inputs(const std::vector<Val*>& all_inputs) : all_inputs_(all_inputs) {} |
254 | |
255 | std::vector<Statement*> next(Val* v) override { |
256 | if (std::find(inputs_.begin(), inputs_.end(), v) != inputs_.end()) { |
257 | return {}; |
258 | } |
259 | return IterVisitor::next(v); |
260 | } |
261 | |
262 | void handle(Val* val) override { |
263 | // If there's no definition to val, or val is created inside the fusion, or |
264 | // val is within the provided inputs |
265 | if (val->definition() == nullptr || val->definition()->inputs().empty() || |
266 | std::find(all_inputs_.begin(), all_inputs_.end(), val) != |
267 | all_inputs_.end()) { |
268 | // if not already placed in the inputs |
269 | if (std::find(inputs_.begin(), inputs_.end(), val) == inputs_.end()) { |
270 | inputs_.push_back(val); |
271 | } |
272 | } |
273 | } |
274 | |
275 | public: |
276 | static std::vector<Val*> getInputs( |
277 | const std::vector<Val*>& of, |
278 | const std::vector<Val*>& all_inputs) { |
279 | if (of.empty()) { |
280 | return {}; |
281 | } |
282 | Inputs inps(all_inputs); |
283 | inps.traverseTo(of[0]->fusion(), of); |
284 | return inps.inputs_; |
285 | } |
286 | }; |
287 | |
288 | } // namespace |
289 | |
290 | std::vector<Val*> IterVisitor::getInputsTo( |
291 | const std::vector<Val*>& vals, |
292 | const std::vector<Val*>& inputs) { |
293 | return Inputs::getInputs(vals, inputs); |
294 | } |
295 | |
296 | namespace { |
297 | |
298 | class AllVals : public IterVisitor { |
299 | private: |
300 | std::unordered_set<Val*> vals; |
301 | |
302 | void handle(Val* val) final { |
303 | vals.emplace(val); |
304 | } |
305 | |
306 | public: |
307 | // Return all values in history of all values in from |
308 | static std::unordered_set<Val*> get( |
309 | Fusion* fusion, |
310 | const std::vector<Val*>& from) { |
311 | AllVals av; |
312 | av.traverseTo(fusion, from, false); |
313 | return av.vals; |
314 | } |
315 | }; |
316 | |
317 | } // namespace |
318 | |
319 | /* BACKWARDS VISITOR */ |
320 | |
321 | std::vector<Statement*> BackwardVisitor::next(Statement* stmt) { |
322 | if (stmt->isVal()) { |
323 | return next(stmt->as<Val>()); |
324 | } else if (stmt->isExpr()) { |
325 | return next(stmt->as<Expr>()); |
326 | } else { |
327 | TORCH_INTERNAL_ASSERT( |
328 | false, "BackwardVisitor could not detect type in next_dispatch." ); |
329 | } |
330 | } |
331 | |
332 | std::vector<Statement*> BackwardVisitor::next(Expr* expr) { |
333 | return std::vector<Statement*>( |
334 | expr->outputs().begin(), expr->outputs().end()); |
335 | } |
336 | |
337 | std::vector<Statement*> BackwardVisitor::next(Val* val) { |
338 | // Going to sort based on relative topological position |
339 | std::map<size_t, Statement*> exprs; |
340 | |
341 | for (auto expr : FusionGuard::getCurFusion()->unordered_uses(val)) { |
342 | // Make sure it's an expr we can traverse |
343 | if (traversal_exprs_.find(expr) != traversal_exprs_.end()) { |
344 | exprs[traversal_exprs_[expr]] = expr; |
345 | } |
346 | } |
347 | |
348 | std::vector<Statement*> next_stmts(exprs.size()); |
349 | std::transform( |
350 | exprs.begin(), |
351 | exprs.end(), |
352 | next_stmts.begin(), |
353 | [](std::pair<size_t, Statement*> pair) { return pair.second; }); |
354 | |
355 | return next_stmts; |
356 | } |
357 | |
358 | void BackwardVisitor::handle(Statement* stmt) { |
359 | OptOutDispatch::handle(stmt); |
360 | } |
361 | |
362 | void BackwardVisitor::handle(Expr* expr) { |
363 | OptOutDispatch::handle(expr); |
364 | } |
365 | |
366 | void BackwardVisitor::handle(Val* val) { |
367 | OptOutDispatch::handle(val); |
368 | } |
369 | |
370 | void BackwardVisitor::traverseTo( |
371 | Fusion* fusion, |
372 | const std::vector<Val*>& from, |
373 | bool traverseAllPaths) { |
374 | FusionGuard fg(fusion); |
375 | |
376 | // Reset members |
377 | stmt_stack_.clear(); |
378 | traversal_exprs_.clear(); |
379 | |
380 | if (from.empty()) { |
381 | return; |
382 | } |
383 | |
384 | auto vals = AllVals::get(fusion, from); |
385 | auto exprs = StmtSort::getExprs(fusion, from); |
386 | |
387 | { |
388 | size_t pos = 0; |
389 | for (auto expr : exprs) |
390 | traversal_exprs_[expr] = pos++; |
391 | } |
392 | |
393 | // All stmts we've called handle on |
394 | std::unordered_set<Statement*> visited_stmts_; |
395 | |
396 | if (must_cover_all_expr_outputs_) { |
397 | for (auto traversal_pair : traversal_exprs_) { |
398 | for (auto out : traversal_pair.first->outputs()) { |
399 | TORCH_INTERNAL_ASSERT( |
400 | vals.find(out) != vals.end(), |
401 | "Invalid backward traversal found. Some output paths were not provided:" , |
402 | out); |
403 | } |
404 | } |
405 | } |
406 | |
407 | auto inputs = InputsOf::getInputsTo(from); |
408 | stmt_stack_.emplace_back(inputs.begin(), inputs.end()); |
409 | |
410 | // The rest is basically copy-pasted from IterVitor: |
411 | while (!stmt_stack_.empty()) { |
412 | auto next_stmts = next(stmt_stack_.back().back()); |
413 | |
414 | // Remove statements we already visited if we're not traversing all paths |
415 | if (!traverseAllPaths) { |
416 | remove_visited(next_stmts, visited_stmts_); |
417 | } |
418 | |
419 | // Traverse down until we get to a leaf |
420 | while (!next_stmts.empty()) { |
421 | stmt_stack_.emplace_back(next_stmts.rbegin(), next_stmts.rend()); |
422 | next_stmts = next(stmt_stack_.back().back()); |
423 | // Remove statements we already visited if we're not traversing all paths |
424 | if (!traverseAllPaths) { |
425 | remove_visited(next_stmts, visited_stmts_); |
426 | } |
427 | } |
428 | |
429 | // Traverse back up |
430 | // Mark visited |
431 | visited_stmts_.emplace(stmt_stack_.back().back()); |
432 | // Handle |
433 | handle(stmt_stack_.back().back()); |
434 | // Remove |
435 | stmt_stack_.back().pop_back(); |
436 | |
437 | while (!stmt_stack_.empty() && stmt_stack_.back().empty()) { |
438 | stmt_stack_.pop_back(); |
439 | if (!stmt_stack_.empty()) { |
440 | // Mark visited |
441 | visited_stmts_.emplace(stmt_stack_.back().back()); |
442 | // Handle |
443 | handle(stmt_stack_.back().back()); |
444 | // Remove |
445 | stmt_stack_.back().pop_back(); |
446 | } |
447 | } |
448 | } |
449 | } |
450 | |
451 | /* DEPENDENCY CHECKING */ |
452 | |
453 | namespace { |
454 | |
455 | // Looks for and returns all values in between dependencies and vals, including |
456 | // them. |
457 | struct Dependencies : public IterVisitor { |
458 | private: |
459 | //! A given set of dependency Vals |
460 | const std::unordered_set<Val*> dependencies_; |
461 | //! Vals that are found between dependencies_ and of. Topologically |
462 | //! ordered. |
463 | std::vector<Val*> vals_; |
464 | //! Exprs that are found between dependencies_ and of. Topologically |
465 | //! ordered. |
466 | std::vector<Expr*> exprs_; |
467 | //! A set version of vals_ |
468 | std::unordered_set<Val*> dependent_vals_; |
469 | //! A set version of exprs_ |
470 | std::unordered_set<Expr*> dependent_exprs_; |
471 | |
472 | private: |
473 | std::vector<Statement*> next(Val* v) override { |
474 | if (dependencies_.find(v) != dependencies_.end()) { |
475 | return std::vector<Statement*>(); |
476 | } |
477 | return IterVisitor::next(v); |
478 | } |
479 | |
480 | void handle(Val* val) override { |
481 | // val is included if: |
482 | // 1. it is one of the dependencies, or |
483 | // 2. its defining expression is included in the dependent expr set |
484 | if (dependencies_.find(val) != dependencies_.end()) { |
485 | TORCH_INTERNAL_ASSERT( |
486 | dependent_vals_.find(val) == dependent_vals_.end(), |
487 | "Trying to add already added val: " , |
488 | val); |
489 | vals_.push_back(val); |
490 | dependent_vals_.insert(val); |
491 | } else { |
492 | auto def = val->definition(); |
493 | if (def != nullptr && |
494 | dependent_exprs_.find(def) != dependent_exprs_.end()) { |
495 | TORCH_INTERNAL_ASSERT( |
496 | dependent_vals_.find(val) == dependent_vals_.end(), |
497 | "Trying to add already added val: " , |
498 | val); |
499 | vals_.push_back(val); |
500 | dependent_vals_.insert(val); |
501 | } |
502 | } |
503 | } |
504 | |
505 | void handle(Expr* expr) override { |
506 | // Track which expr is depedent on the dependencies_ exprs. |
507 | if (std::any_of( |
508 | expr->inputs().begin(), expr->inputs().end(), [&](Val* input_val) { |
509 | return dependent_vals_.find(input_val) != dependent_vals_.end(); |
510 | })) { |
511 | if (!dependent_exprs_.count(expr)) { |
512 | exprs_.push_back(expr); |
513 | dependent_exprs_.insert(expr); |
514 | } |
515 | } |
516 | } |
517 | |
518 | Dependencies( |
519 | std::unordered_set<Val*> _dependencies, |
520 | const std::vector<Val*>& of) |
521 | : dependencies_(std::move(_dependencies)) { |
522 | traverseTo(of[0]->fusion(), of, false); |
523 | }; |
524 | |
525 | public: |
526 | static std::vector<Val*> getAllVals( |
527 | const std::unordered_set<Val*>& dependencies, |
528 | const std::vector<Val*>& of) { |
529 | if (of.empty()) { |
530 | return {}; |
531 | } |
532 | |
533 | Dependencies deps(dependencies, of); |
534 | return deps.vals_; |
535 | } |
536 | |
537 | static std::vector<Expr*> getAllExprs( |
538 | const std::unordered_set<Val*>& dependencies, |
539 | const std::vector<Val*>& of) { |
540 | if (of.empty()) { |
541 | return {}; |
542 | } |
543 | |
544 | Dependencies deps(dependencies, of); |
545 | return deps.exprs_; |
546 | } |
547 | }; |
548 | |
549 | // Looks for and returns all output values with dependencies on `of`. |
550 | struct FindOutputs : public IterVisitor { |
551 | const std::unordered_set<Val*>& of_; |
552 | std::unordered_set<Val*> outs_; |
553 | |
554 | void handle(Val* val) override { |
555 | if (of_.find(val) != of_.end()) { |
556 | Statement* out_stmt = stmt_stack.front().back(); |
557 | TORCH_INTERNAL_ASSERT(out_stmt->isVal()); |
558 | auto out_val = out_stmt->as<Val>(); |
559 | if (of_.find(out_val) == of_.end()) { |
560 | outs_.emplace(out_val); |
561 | } |
562 | } |
563 | } |
564 | |
565 | // TODO: Simply traverse through uses from of. Would be a lot faster than |
566 | // tracing all paths like this. |
567 | FindOutputs(const std::unordered_set<Val*>& _of) : of_(_of) { |
568 | auto fusion = (*of_.begin())->fusion(); |
569 | traverseTo(fusion, fusion->outputs(), true); |
570 | }; |
571 | |
572 | static std::unordered_set<Val*> getAllOutputsOf( |
573 | const std::unordered_set<Val*>& of) { |
574 | if (of.empty()) { |
575 | return std::unordered_set<Val*>(); |
576 | } |
577 | |
578 | FindOutputs finder(of); |
579 | return finder.outs_; |
580 | } |
581 | }; |
582 | |
583 | // Looks for and returns all values that depends on `of`. |
584 | class DependentVals : public IterVisitor { |
585 | private: |
586 | // Which nodes to find dependencies of |
587 | const std::unordered_set<Val*>& of_; |
588 | |
589 | // Dependencies we have so far |
590 | std::unordered_set<Val*> outs_; |
591 | |
592 | // Boundary where we want to stop searching beyond |
593 | // TODO: Based on the todo below, shouldn't we stop just at the definition of? |
594 | // If we really wanted to make this traverse left, wouldn't we first check |
595 | // which outputs are outputs dependent on of? |
596 | std::unordered_set<Val*> boundary_; |
597 | |
598 | std::vector<Statement*> next(Val* v) override { |
599 | if (boundary_.find(v) != boundary_.end()) |
600 | return std::vector<Statement*>(); |
601 | return IterVisitor::next(v); |
602 | } |
603 | |
604 | void handle(Val* val) override { |
605 | if (val->isFusionInput() || val->definition() == nullptr || |
606 | of_.count(val) || outs_.count(val)) { |
607 | return; |
608 | } |
609 | |
610 | for (auto v : val->definition()->inputs()) { |
611 | if (of_.count(v) || outs_.count(v)) { |
612 | outs_.emplace(val); |
613 | return; |
614 | } |
615 | } |
616 | } |
617 | |
618 | // optimization to limit search path |
619 | // TODO: Is this valid? Couldn't something like: |
620 | // out0 = of + val0 |
621 | // out1 = out0 + val1 |
622 | // out2 = TernaryOp(out1, val0, of) |
623 | // Hide the dep of out1 on of? |
624 | void createBoundary() { |
625 | for (auto v_of : of_) { |
626 | for (auto v_expr : v_of->uses()) { |
627 | for (auto v_in : v_expr->inputs()) { |
628 | boundary_.emplace(v_in); |
629 | } |
630 | } |
631 | } |
632 | } |
633 | |
634 | DependentVals(const std::unordered_set<Val*>& _of) : of_(_of) { |
635 | createBoundary(); |
636 | auto fusion = (*of_.begin())->fusion(); |
637 | traverseTo(fusion, fusion->outputs(), false); |
638 | }; |
639 | |
640 | public: |
641 | static std::unordered_set<Val*> getAllDependentVals( |
642 | const std::unordered_set<Val*>& of) { |
643 | if (of.empty()) { |
644 | return std::unordered_set<Val*>(); |
645 | } |
646 | DependentVals dependencies(of); |
647 | return dependencies.outs_; |
648 | } |
649 | }; |
650 | |
651 | class DependencyChains : public IterVisitor { |
652 | public: |
653 | std::deque<std::deque<Val*>> dep_chains; |
654 | bool is_dependency = false; |
655 | std::unordered_set<Val*> dependencies_; |
656 | |
657 | void handle(Val* val) override { |
658 | if (dependencies_.find(val) != dependencies_.end()) { |
659 | is_dependency = true; |
660 | std::deque<Val*> deps; |
661 | for (auto stack : stmt_stack) { |
662 | if (stack.back()->isVal()) { |
663 | deps.push_back(stack.back()->as<Val>()); |
664 | } |
665 | } |
666 | // Order as dependency -> of |
667 | dep_chains.emplace_back(deps.rbegin(), deps.rend()); |
668 | } |
669 | } |
670 | |
671 | DependencyChains(Val* _dependency, Val* _of, bool all_chains_ = false) |
672 | : dependencies_({_dependency}) { |
673 | traverseTo(_of->fusion(), {_of}, all_chains_); |
674 | } |
675 | |
676 | DependencyChains(Val* _dependency, bool all_chains_ = false) |
677 | : dependencies_({_dependency}) { |
678 | if (all_chains_) { |
679 | traverseAllPaths(_dependency->fusion()); |
680 | } else { |
681 | traverse(_dependency->fusion()); |
682 | } |
683 | } |
684 | |
685 | DependencyChains( |
686 | std::unordered_set<Val*> _dependencies, |
687 | bool all_chains_ = false) |
688 | : dependencies_(std::move(_dependencies)) { |
689 | if (dependencies_.empty()) { |
690 | return; |
691 | } |
692 | |
693 | if (all_chains_) { |
694 | traverseAllPaths((*dependencies_.begin())->fusion()); |
695 | } else { |
696 | traverse((*dependencies_.begin())->fusion()); |
697 | } |
698 | } |
699 | |
700 | static std::deque<Val*> getDependencyChain(Val* dependency, Val* of) { |
701 | DependencyChains dp(dependency, of, false); |
702 | if (dp.dep_chains.empty()) { |
703 | return std::deque<Val*>(); |
704 | } |
705 | return dp.dep_chains[0]; |
706 | } |
707 | |
708 | // I don't think this is actually hooked up, but leaving for now. |
709 | static std::deque<std::deque<Val*>> getDependencyChains( |
710 | Val* dependency, |
711 | Val* of) { |
712 | DependencyChains dp(dependency, of, true); |
713 | if (dp.dep_chains.empty()) { |
714 | return std::deque<std::deque<Val*>>(); |
715 | } |
716 | return dp.dep_chains; |
717 | } |
718 | |
719 | static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency) { |
720 | DependencyChains dp(dependency, true); |
721 | if (dp.dep_chains.empty()) { |
722 | return std::deque<std::deque<Val*>>(); |
723 | } |
724 | return dp.dep_chains; |
725 | } |
726 | |
727 | static std::deque<std::deque<Val*>> getAllUseChains( |
728 | const std::unordered_set<Val*>& dependencies) { |
729 | DependencyChains dp(dependencies, true); |
730 | if (dp.dep_chains.empty()) { |
731 | return std::deque<std::deque<Val*>>(); |
732 | } |
733 | return dp.dep_chains; |
734 | } |
735 | }; |
736 | |
737 | } // namespace |
738 | |
739 | bool DependencyCheck::isDependencyOf(Val* dependency, Val* of) { |
740 | return !DependencyChains::getDependencyChain(dependency, of).empty(); |
741 | } |
742 | |
743 | std::deque<Val*> DependencyCheck::getSingleDependencyChain( |
744 | Val* dependency, |
745 | Val* of) { |
746 | return DependencyChains::getDependencyChain(dependency, of); |
747 | } |
748 | |
749 | std::deque<std::deque<Val*>> DependencyCheck::getAllDependencyChains( |
750 | Val* dependency, |
751 | Val* of) { |
752 | return DependencyChains::getDependencyChains(dependency, of); |
753 | } |
754 | |
755 | std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) { |
756 | return DependencyChains::getAllUseChains(producer); |
757 | } |
758 | |
759 | std::vector<Val*> DependencyCheck::getAllValsBetween( |
760 | const std::unordered_set<Val*>& dependencies, |
761 | const std::vector<Val*>& of) { |
762 | return Dependencies::getAllVals(dependencies, of); |
763 | } |
764 | |
765 | std::vector<Expr*> DependencyCheck::getAllExprsBetween( |
766 | const std::unordered_set<Val*>& dependencies, |
767 | const std::vector<Val*>& of) { |
768 | return Dependencies::getAllExprs(dependencies, of); |
769 | } |
770 | |
771 | std::unordered_set<Val*> DependencyCheck::getAllOutputsOf( |
772 | const std::unordered_set<Val*>& of) { |
773 | if (of.empty()) { |
774 | return std::unordered_set<Val*>(); |
775 | } |
776 | FusionGuard fg((*of.begin())->fusion()); |
777 | return FindOutputs::getAllOutputsOf(of); |
778 | } |
779 | |
780 | std::unordered_set<Val*> DependencyCheck::getAllDependentVals( |
781 | const std::unordered_set<Val*>& of) { |
782 | if (of.empty()) { |
783 | return std::unordered_set<Val*>(); |
784 | } |
785 | FusionGuard fg((*of.begin())->fusion()); |
786 | return DependentVals::getAllDependentVals(of); |
787 | } |
788 | |
789 | void StmtSort::handle(Statement* stmt) { |
790 | stmts.push_back(stmt); |
791 | } |
792 | |
793 | std::vector<Expr*> StmtSort::getExprs(Fusion* fusion, bool traverse_members) { |
794 | auto terminating_outputs = fusion->getTerminatingOutputs(); |
795 | return StmtSort::getExprs(fusion, terminating_outputs, traverse_members); |
796 | } |
797 | |
798 | std::vector<Expr*> StmtSort::getExprs( |
799 | Fusion* fusion, |
800 | const std::vector<Val*>& to, |
801 | bool traverse_members) { |
802 | auto stmts = StmtSort::getStmts(fusion, to, traverse_members); |
803 | auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end()); |
804 | std::vector<Expr*> exprs(filter.begin(), filter.end()); |
805 | return exprs; |
806 | } |
807 | |
808 | std::vector<Expr*> StmtSort::getExprsBetween( |
809 | Fusion* fusion, |
810 | const std::vector<Val*>& from, |
811 | const std::vector<Val*>& to, |
812 | bool traverse_members) { |
813 | auto stmts = StmtSort::getStmtsBetween(fusion, from, to, traverse_members); |
814 | auto filter = ir_utils::filterByType<Expr>(stmts.begin(), stmts.end()); |
815 | std::vector<Expr*> exprs(filter.begin(), filter.end()); |
816 | return exprs; |
817 | } |
818 | |
819 | std::vector<Statement*> StmtSort::getStmts( |
820 | Fusion* fusion, |
821 | bool traverse_members) { |
822 | auto terminating_outputs = fusion->getTerminatingOutputs(); |
823 | return StmtSort::getStmts(fusion, terminating_outputs, traverse_members); |
824 | } |
825 | |
826 | std::vector<Statement*> StmtSort::getStmts( |
827 | Fusion* fusion, |
828 | const std::vector<Val*>& to, |
829 | bool traverse_members) { |
830 | StmtSort es; |
831 | es.traverseTo(fusion, to, false, traverse_members); |
832 | return es.stmts; |
833 | } |
834 | |
835 | std::vector<Statement*> StmtSort::getStmtsBetween( |
836 | Fusion* fusion, |
837 | const std::vector<Val*>& from, |
838 | const std::vector<Val*>& to, |
839 | bool traverse_members) { |
840 | StmtSort es; |
841 | es.traverseBetween( |
842 | fusion, {from.begin(), from.end()}, to, false, traverse_members); |
843 | return es.stmts; |
844 | } |
845 | |
846 | void InputsOf::handle(Val* v) { |
847 | if (v->definition() == nullptr || v->definition()->inputs().empty()) { |
848 | if (grabbed_inputs.emplace(v).second) { |
849 | ordered_inputs.push_back(v); |
850 | } |
851 | } |
852 | } |
853 | |
854 | std::vector<Val*> InputsOf::output(Fusion* fusion, Val* output_) { |
855 | return outputs(fusion, {output_}); |
856 | } |
857 | |
858 | std::vector<Val*> InputsOf::outputs( |
859 | Fusion* fusion, |
860 | const std::vector<Val*>& outputs_) { |
861 | InputsOf io; |
862 | io.traverseTo(fusion, outputs_, false); |
863 | return io.ordered_inputs; |
864 | } |
865 | |
866 | } // namespace cuda |
867 | } // namespace fuser |
868 | } // namespace jit |
869 | } // namespace torch |
870 | |