1 | #pragma once |
2 | |
3 | #include <dispatch.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | namespace fuser { |
8 | namespace cuda { |
9 | |
10 | class Expr; |
11 | |
12 | namespace kir { |
13 | class Predicate; |
14 | class TensorIndex; |
15 | class ForLoop; |
16 | class IfThenElse; |
17 | class Scope; |
18 | |
19 | // Base visitor class that visits all nodes in provided vector<Expr*>. |
20 | // |
// Includes visiting through scopes like IfThenElse and ForLoop, and tracks
// them in scope_ and for_loops_.
23 | // |
24 | // Makes a copy of exprs at exprs_ which could be used to modify and return. |
25 | // |
// When traversing through IfThenElse/ForLoop nodes it will use a copy
// of the provided expressions to make it safe to insert/delete nodes.
28 | // |
29 | // Provides a simple base class to inherit from for typical lowering passes on |
30 | // Expr list |
31 | class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch { |
32 | public: |
33 | std::vector<Expr*> handle(const std::vector<Expr*>& expr); |
34 | |
35 | protected: |
36 | using OptOutDispatch::handle; |
37 | |
38 | virtual void handle(ForLoop*) override; |
39 | virtual void handle(IfThenElse*) override; |
40 | |
41 | protected: |
42 | std::vector<ForLoop*> for_loops_; |
43 | std::vector<Scope*> scope_; |
44 | std::vector<Expr*> scope_exprs_; |
45 | std::vector<Expr*> exprs_; |
46 | }; |
47 | |
48 | // Const version of IrVisitor |
49 | class TORCH_CUDA_CU_API ConstIrVisitor : public OptOutConstDispatch { |
50 | public: |
51 | std::vector<const Expr*> handle(const std::vector<const Expr*>& expr); |
52 | |
53 | protected: |
54 | using OptOutConstDispatch::handle; |
55 | |
56 | virtual void handle(const ForLoop*) override; |
57 | virtual void handle(const IfThenElse*) override; |
58 | |
59 | protected: |
60 | std::vector<const ForLoop*> for_loops_; |
61 | std::vector<const Scope*> scope_; |
62 | std::vector<const Expr*> scope_exprs_; |
63 | std::vector<const Expr*> exprs_; |
64 | }; |
65 | |
// Base Expr Mutator class that visits all nodes with IrVisitor, and then
// inserts new expressions, replaces expressions based on the registered
// insertions/replacements, or removes existing expressions. These registered
// mutations are expected to accumulate during an initial traversal; the
// mutator then applies them after the overloaded traversal.
71 | // |
72 | // Order of mutations may be important, mutations are ordered according to the |
73 | // following rules: |
74 | // Before/After insertions are ordered as registered when reverse_order == |
75 | // false, |
76 | // |
77 | // Before/After insertions are in reverse order as registered when |
78 | // reverse_order == true, |
79 | // |
80 | // Before/After insertions are done before Expr replacements, so reference for |
81 | // insertions must be on pre-replaced Exprs |
82 | // |
83 | // Removal of expressions is done after replacements. |
84 | // |
// To place in a scope that is empty, simply provide a nullptr reference.
// Since insertions are done in order, it's possible to insert an expression in
// an empty scope, and then use that inserted expression as a reference for
// subsequent mutations.
89 | class ExprMutator : public IrVisitor { |
90 | protected: |
91 | std::vector<Expr*> traverseAndInsert( |
92 | const std::vector<Expr*>& expr, |
93 | bool reverse_order = false); |
94 | |
95 | std::vector<Expr*> mutate(bool reverse_order = false); |
96 | |
97 | using IrVisitor::handle; |
98 | // Registration function which *don't* need to be called "in place" during |
99 | // visiting. |
100 | void registerInsertBefore(Expr* reference, Expr* new_expr, Scope* scope); |
101 | void registerInsertAfter(Expr* reference, Expr* new_expr, Scope* scope); |
102 | void registerReplace(Expr* reference, Expr* new_expr, Scope* scope); |
103 | void registerRemove(Expr* expr_to_remove, Scope* scope); |
104 | |
105 | // Registration function which need to be called "in place" during visiting. |
106 | // I.E. |
107 | // if you want to insert before/after or replace an Expr, you must register |
108 | // when in handle(Expr*) of that expr. |
109 | void registerInsertBefore(Expr* reference, Expr* new_expr); |
110 | void registerInsertAfter(Expr* reference, Expr* new_expr); |
111 | void registerReplace(Expr* reference, Expr* new_expr); |
112 | void registerRemove(Expr* expr_to_remove); |
113 | |
114 | private: |
115 | enum class MutationMode { BEFORE, AFTER, REPLACE, REMOVE }; |
116 | |
117 | void registerMutation( |
118 | Expr* ref, |
119 | Expr* new_expr, |
120 | Scope* scope, |
121 | MutationMode mode); |
122 | |
123 | struct MutationInformation { |
124 | Expr* reference = nullptr; |
125 | Expr* new_expr = nullptr; |
126 | Scope* scope = nullptr; |
127 | MutationMode mode = MutationMode::BEFORE; |
128 | }; |
129 | |
130 | // Track insertions as they're registered |
131 | std::vector<MutationInformation> insertions_; |
132 | |
133 | // Track replacements as they're registered |
134 | std::vector<MutationInformation> replacements_; |
135 | |
136 | // Track removal as they're registered |
137 | std::vector<MutationInformation> removal_; |
138 | }; |
139 | |
140 | } // namespace kir |
141 | } // namespace cuda |
142 | } // namespace fuser |
143 | } // namespace jit |
144 | } // namespace torch |
145 | |