1#include <kernel_ir.h>
2#include <kernel_ir_dispatch.h>
3
4namespace torch {
5namespace jit {
6namespace fuser {
7namespace cuda {
8namespace kir {
9std::vector<Expr*> IrVisitor::handle(const std::vector<Expr*>& exprs) {
10 exprs_ = std::vector<Expr*>(exprs);
11 for (auto expr : exprs) {
12 handle(expr);
13 }
14 return exprs_;
15}
16
17void IrVisitor::handle(ForLoop* fl) {
18 for_loops_.push_back(fl);
19 scope_.push_back(&fl->body());
20 scope_exprs_.push_back(fl);
21 auto body_exprs = std::vector<Expr*>(fl->body().exprs());
22 for (auto expr : body_exprs) {
23 handle(expr);
24 }
25 scope_exprs_.pop_back();
26 scope_.pop_back();
27 for_loops_.pop_back();
28}
29
30void IrVisitor::handle(IfThenElse* ite) {
31 scope_exprs_.push_back(ite);
32 scope_.push_back(&ite->thenBody());
33 auto then_exprs = std::vector<Expr*>(ite->thenBody().exprs());
34 for (auto expr : then_exprs) {
35 handle(expr);
36 }
37 scope_.pop_back();
38
39 scope_.push_back(&ite->elseBody());
40 auto else_exprs = std::vector<Expr*>(ite->elseBody().exprs());
41 for (auto expr : else_exprs) {
42 handle(expr);
43 }
44 scope_.pop_back();
45 scope_exprs_.pop_back();
46}
47
48std::vector<const Expr*> ConstIrVisitor::handle(
49 const std::vector<const Expr*>& exprs) {
50 exprs_ = exprs;
51 for (auto expr : exprs) {
52 handle(expr);
53 }
54 return exprs_;
55}
56
57void ConstIrVisitor::handle(const ForLoop* fl) {
58 for_loops_.push_back(fl);
59 scope_.push_back(&fl->body());
60 scope_exprs_.push_back(fl);
61 auto body_exprs = fl->body().exprs();
62 for (auto expr : body_exprs) {
63 handle(expr);
64 }
65 scope_exprs_.pop_back();
66 scope_.pop_back();
67 for_loops_.pop_back();
68}
69
70void ConstIrVisitor::handle(const IfThenElse* ite) {
71 scope_exprs_.push_back(ite);
72 scope_.push_back(&ite->thenBody());
73 auto then_exprs = ite->thenBody().exprs();
74 for (auto expr : then_exprs) {
75 handle(expr);
76 }
77 scope_.pop_back();
78
79 scope_.push_back(&ite->elseBody());
80 auto else_exprs = ite->elseBody().exprs();
81 for (auto expr : else_exprs) {
82 handle(expr);
83 }
84 scope_.pop_back();
85 scope_exprs_.pop_back();
86}
87
88std::vector<Expr*> ExprMutator::mutate(bool reverse_order) {
89 if (insertions_.empty() && replacements_.empty() && removal_.empty()) {
90 return exprs_;
91 }
92
93 auto run_insertion = [&](MutationInformation info) {
94 if (info.scope == nullptr) {
95 // If reference is nullptr and there are no expressions, simply insert the
96 // expr
97 if (exprs_.empty() && info.reference == nullptr) {
98 exprs_.push_back(info.new_expr);
99 return;
100 }
101 auto pos_it = std::find(exprs_.begin(), exprs_.end(), info.reference);
102 TORCH_INTERNAL_ASSERT(
103 pos_it != exprs_.end(),
104 "Issue finding reference expression for insertion.");
105 if (info.mode == MutationMode::BEFORE) {
106 exprs_.insert(pos_it, info.new_expr);
107 } else {
108 exprs_.insert(pos_it + 1, info.new_expr);
109 }
110 } else {
111 // If reference is nullptr and there are no expressions, simply insert the
112 // expr
113 if (info.scope->exprs().empty() && info.reference == nullptr) {
114 info.scope->push_back(info.new_expr);
115 return;
116 }
117 if (info.mode == MutationMode::BEFORE) {
118 info.scope->insert_before(info.reference, info.new_expr);
119 } else {
120 info.scope->insert_after(info.reference, info.new_expr);
121 }
122 }
123 };
124
125 if (reverse_order) {
126 for (auto it = insertions_.rbegin(); it != insertions_.rend(); ++it) {
127 run_insertion(*it);
128 }
129 } else {
130 for (auto insertion_info : insertions_) {
131 run_insertion(insertion_info);
132 }
133 }
134
135 for (auto replacement_info : replacements_) {
136 if (replacement_info.scope == nullptr) {
137 auto pos_it =
138 std::find(exprs_.begin(), exprs_.end(), replacement_info.reference);
139 TORCH_INTERNAL_ASSERT(
140 pos_it != exprs_.end(),
141 "Issue finding reference expression for replacement.");
142 exprs_.insert(pos_it, replacement_info.new_expr);
143 // iterator can be invalidated from insertion
144 pos_it =
145 std::find(exprs_.begin(), exprs_.end(), replacement_info.reference);
146 exprs_.erase(pos_it);
147 } else {
148 replacement_info.scope->insert_before(
149 replacement_info.reference, replacement_info.new_expr);
150 replacement_info.scope->erase(replacement_info.reference);
151 }
152 }
153
154 for (auto removal_info : removal_) {
155 if (removal_info.scope == nullptr) {
156 auto pos_it =
157 std::find(exprs_.begin(), exprs_.end(), removal_info.reference);
158 TORCH_INTERNAL_ASSERT(
159 pos_it != exprs_.end(), "Issue finding expression to remove.");
160 exprs_.erase(pos_it);
161 } else {
162 TORCH_INTERNAL_ASSERT(
163 removal_info.scope->contains(removal_info.reference),
164 "Expression to remove is not found in the given scope: ",
165 removal_info.reference->toString());
166 removal_info.scope->erase(removal_info.reference);
167 }
168 }
169
170 insertions_.clear();
171 replacements_.clear();
172
173 return exprs_;
174}
175
176std::vector<Expr*> ExprMutator::traverseAndInsert(
177 const std::vector<Expr*>& exprs,
178 bool reverse_order) {
179 IrVisitor::handle(exprs);
180 return mutate(reverse_order);
181}
182
183void ExprMutator::registerMutation(
184 Expr* reference,
185 Expr* new_expr,
186 Scope* scope,
187 MutationMode mode) {
188 MutationInformation mutation;
189 mutation.reference = reference;
190 mutation.new_expr = new_expr;
191 mutation.scope = scope;
192 mutation.mode = mode;
193 if (mode == MutationMode::BEFORE || mode == MutationMode::AFTER) {
194 insertions_.push_back(mutation);
195 } else if (mode == MutationMode::REPLACE) {
196 replacements_.push_back(mutation);
197 } else if (mode == MutationMode::REMOVE) {
198 removal_.push_back(mutation);
199 } else {
200 TORCH_INTERNAL_ASSERT(false, "Invalid mutation type");
201 }
202}
203
204void ExprMutator::registerInsertBefore(
205 Expr* reference,
206 Expr* new_expr,
207 Scope* scope) {
208 registerMutation(reference, new_expr, scope, MutationMode::BEFORE);
209}
210
211void ExprMutator::registerInsertAfter(
212 Expr* reference,
213 Expr* new_expr,
214 Scope* scope) {
215 registerMutation(reference, new_expr, scope, MutationMode::AFTER);
216}
217
218void ExprMutator::registerReplace(
219 Expr* reference,
220 Expr* new_expr,
221 Scope* scope) {
222 registerMutation(reference, new_expr, scope, MutationMode::REPLACE);
223}
224
225void ExprMutator::registerRemove(Expr* expr_to_remove, Scope* scope) {
226 registerMutation(expr_to_remove, nullptr, scope, MutationMode::REMOVE);
227}
228
229void ExprMutator::registerInsertBefore(Expr* reference, Expr* new_expr) {
230 Scope* scope = scope_.empty() ? nullptr : scope_.back();
231 registerInsertBefore(reference, new_expr, scope);
232}
233
234void ExprMutator::registerInsertAfter(Expr* reference, Expr* new_expr) {
235 Scope* scope = scope_.empty() ? nullptr : scope_.back();
236 registerInsertAfter(reference, new_expr, scope);
237}
238
239void ExprMutator::registerReplace(Expr* reference, Expr* new_expr) {
240 Scope* scope = scope_.empty() ? nullptr : scope_.back();
241 registerReplace(reference, new_expr, scope);
242}
243
244void ExprMutator::registerRemove(Expr* expr_to_remove) {
245 Scope* scope = scope_.empty() ? nullptr : scope_.back();
246 registerRemove(expr_to_remove, scope);
247}
248
249} // namespace kir
250} // namespace cuda
251} // namespace fuser
252} // namespace jit
253} // namespace torch
254