1 | #include <kernel_ir.h> |
2 | #include <kernel_ir_dispatch.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | namespace fuser { |
7 | namespace cuda { |
8 | namespace kir { |
9 | std::vector<Expr*> IrVisitor::handle(const std::vector<Expr*>& exprs) { |
10 | exprs_ = std::vector<Expr*>(exprs); |
11 | for (auto expr : exprs) { |
12 | handle(expr); |
13 | } |
14 | return exprs_; |
15 | } |
16 | |
17 | void IrVisitor::handle(ForLoop* fl) { |
18 | for_loops_.push_back(fl); |
19 | scope_.push_back(&fl->body()); |
20 | scope_exprs_.push_back(fl); |
21 | auto body_exprs = std::vector<Expr*>(fl->body().exprs()); |
22 | for (auto expr : body_exprs) { |
23 | handle(expr); |
24 | } |
25 | scope_exprs_.pop_back(); |
26 | scope_.pop_back(); |
27 | for_loops_.pop_back(); |
28 | } |
29 | |
30 | void IrVisitor::handle(IfThenElse* ite) { |
31 | scope_exprs_.push_back(ite); |
32 | scope_.push_back(&ite->thenBody()); |
33 | auto then_exprs = std::vector<Expr*>(ite->thenBody().exprs()); |
34 | for (auto expr : then_exprs) { |
35 | handle(expr); |
36 | } |
37 | scope_.pop_back(); |
38 | |
39 | scope_.push_back(&ite->elseBody()); |
40 | auto else_exprs = std::vector<Expr*>(ite->elseBody().exprs()); |
41 | for (auto expr : else_exprs) { |
42 | handle(expr); |
43 | } |
44 | scope_.pop_back(); |
45 | scope_exprs_.pop_back(); |
46 | } |
47 | |
48 | std::vector<const Expr*> ConstIrVisitor::handle( |
49 | const std::vector<const Expr*>& exprs) { |
50 | exprs_ = exprs; |
51 | for (auto expr : exprs) { |
52 | handle(expr); |
53 | } |
54 | return exprs_; |
55 | } |
56 | |
57 | void ConstIrVisitor::handle(const ForLoop* fl) { |
58 | for_loops_.push_back(fl); |
59 | scope_.push_back(&fl->body()); |
60 | scope_exprs_.push_back(fl); |
61 | auto body_exprs = fl->body().exprs(); |
62 | for (auto expr : body_exprs) { |
63 | handle(expr); |
64 | } |
65 | scope_exprs_.pop_back(); |
66 | scope_.pop_back(); |
67 | for_loops_.pop_back(); |
68 | } |
69 | |
70 | void ConstIrVisitor::handle(const IfThenElse* ite) { |
71 | scope_exprs_.push_back(ite); |
72 | scope_.push_back(&ite->thenBody()); |
73 | auto then_exprs = ite->thenBody().exprs(); |
74 | for (auto expr : then_exprs) { |
75 | handle(expr); |
76 | } |
77 | scope_.pop_back(); |
78 | |
79 | scope_.push_back(&ite->elseBody()); |
80 | auto else_exprs = ite->elseBody().exprs(); |
81 | for (auto expr : else_exprs) { |
82 | handle(expr); |
83 | } |
84 | scope_.pop_back(); |
85 | scope_exprs_.pop_back(); |
86 | } |
87 | |
88 | std::vector<Expr*> ExprMutator::mutate(bool reverse_order) { |
89 | if (insertions_.empty() && replacements_.empty() && removal_.empty()) { |
90 | return exprs_; |
91 | } |
92 | |
93 | auto run_insertion = [&](MutationInformation info) { |
94 | if (info.scope == nullptr) { |
95 | // If reference is nullptr and there are no expressions, simply insert the |
96 | // expr |
97 | if (exprs_.empty() && info.reference == nullptr) { |
98 | exprs_.push_back(info.new_expr); |
99 | return; |
100 | } |
101 | auto pos_it = std::find(exprs_.begin(), exprs_.end(), info.reference); |
102 | TORCH_INTERNAL_ASSERT( |
103 | pos_it != exprs_.end(), |
104 | "Issue finding reference expression for insertion." ); |
105 | if (info.mode == MutationMode::BEFORE) { |
106 | exprs_.insert(pos_it, info.new_expr); |
107 | } else { |
108 | exprs_.insert(pos_it + 1, info.new_expr); |
109 | } |
110 | } else { |
111 | // If reference is nullptr and there are no expressions, simply insert the |
112 | // expr |
113 | if (info.scope->exprs().empty() && info.reference == nullptr) { |
114 | info.scope->push_back(info.new_expr); |
115 | return; |
116 | } |
117 | if (info.mode == MutationMode::BEFORE) { |
118 | info.scope->insert_before(info.reference, info.new_expr); |
119 | } else { |
120 | info.scope->insert_after(info.reference, info.new_expr); |
121 | } |
122 | } |
123 | }; |
124 | |
125 | if (reverse_order) { |
126 | for (auto it = insertions_.rbegin(); it != insertions_.rend(); ++it) { |
127 | run_insertion(*it); |
128 | } |
129 | } else { |
130 | for (auto insertion_info : insertions_) { |
131 | run_insertion(insertion_info); |
132 | } |
133 | } |
134 | |
135 | for (auto replacement_info : replacements_) { |
136 | if (replacement_info.scope == nullptr) { |
137 | auto pos_it = |
138 | std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); |
139 | TORCH_INTERNAL_ASSERT( |
140 | pos_it != exprs_.end(), |
141 | "Issue finding reference expression for replacement." ); |
142 | exprs_.insert(pos_it, replacement_info.new_expr); |
143 | // iterator can be invalidated from insertion |
144 | pos_it = |
145 | std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); |
146 | exprs_.erase(pos_it); |
147 | } else { |
148 | replacement_info.scope->insert_before( |
149 | replacement_info.reference, replacement_info.new_expr); |
150 | replacement_info.scope->erase(replacement_info.reference); |
151 | } |
152 | } |
153 | |
154 | for (auto removal_info : removal_) { |
155 | if (removal_info.scope == nullptr) { |
156 | auto pos_it = |
157 | std::find(exprs_.begin(), exprs_.end(), removal_info.reference); |
158 | TORCH_INTERNAL_ASSERT( |
159 | pos_it != exprs_.end(), "Issue finding expression to remove." ); |
160 | exprs_.erase(pos_it); |
161 | } else { |
162 | TORCH_INTERNAL_ASSERT( |
163 | removal_info.scope->contains(removal_info.reference), |
164 | "Expression to remove is not found in the given scope: " , |
165 | removal_info.reference->toString()); |
166 | removal_info.scope->erase(removal_info.reference); |
167 | } |
168 | } |
169 | |
170 | insertions_.clear(); |
171 | replacements_.clear(); |
172 | |
173 | return exprs_; |
174 | } |
175 | |
176 | std::vector<Expr*> ExprMutator::traverseAndInsert( |
177 | const std::vector<Expr*>& exprs, |
178 | bool reverse_order) { |
179 | IrVisitor::handle(exprs); |
180 | return mutate(reverse_order); |
181 | } |
182 | |
183 | void ExprMutator::registerMutation( |
184 | Expr* reference, |
185 | Expr* new_expr, |
186 | Scope* scope, |
187 | MutationMode mode) { |
188 | MutationInformation mutation; |
189 | mutation.reference = reference; |
190 | mutation.new_expr = new_expr; |
191 | mutation.scope = scope; |
192 | mutation.mode = mode; |
193 | if (mode == MutationMode::BEFORE || mode == MutationMode::AFTER) { |
194 | insertions_.push_back(mutation); |
195 | } else if (mode == MutationMode::REPLACE) { |
196 | replacements_.push_back(mutation); |
197 | } else if (mode == MutationMode::REMOVE) { |
198 | removal_.push_back(mutation); |
199 | } else { |
200 | TORCH_INTERNAL_ASSERT(false, "Invalid mutation type" ); |
201 | } |
202 | } |
203 | |
204 | void ExprMutator::registerInsertBefore( |
205 | Expr* reference, |
206 | Expr* new_expr, |
207 | Scope* scope) { |
208 | registerMutation(reference, new_expr, scope, MutationMode::BEFORE); |
209 | } |
210 | |
211 | void ExprMutator::registerInsertAfter( |
212 | Expr* reference, |
213 | Expr* new_expr, |
214 | Scope* scope) { |
215 | registerMutation(reference, new_expr, scope, MutationMode::AFTER); |
216 | } |
217 | |
218 | void ExprMutator::registerReplace( |
219 | Expr* reference, |
220 | Expr* new_expr, |
221 | Scope* scope) { |
222 | registerMutation(reference, new_expr, scope, MutationMode::REPLACE); |
223 | } |
224 | |
225 | void ExprMutator::registerRemove(Expr* expr_to_remove, Scope* scope) { |
226 | registerMutation(expr_to_remove, nullptr, scope, MutationMode::REMOVE); |
227 | } |
228 | |
229 | void ExprMutator::registerInsertBefore(Expr* reference, Expr* new_expr) { |
230 | Scope* scope = scope_.empty() ? nullptr : scope_.back(); |
231 | registerInsertBefore(reference, new_expr, scope); |
232 | } |
233 | |
234 | void ExprMutator::registerInsertAfter(Expr* reference, Expr* new_expr) { |
235 | Scope* scope = scope_.empty() ? nullptr : scope_.back(); |
236 | registerInsertAfter(reference, new_expr, scope); |
237 | } |
238 | |
239 | void ExprMutator::registerReplace(Expr* reference, Expr* new_expr) { |
240 | Scope* scope = scope_.empty() ? nullptr : scope_.back(); |
241 | registerReplace(reference, new_expr, scope); |
242 | } |
243 | |
244 | void ExprMutator::registerRemove(Expr* expr_to_remove) { |
245 | Scope* scope = scope_.empty() ? nullptr : scope_.back(); |
246 | registerRemove(expr_to_remove, scope); |
247 | } |
248 | |
249 | } // namespace kir |
250 | } // namespace cuda |
251 | } // namespace fuser |
252 | } // namespace jit |
253 | } // namespace torch |
254 | |