1 | #include <lower_unroll.h> |
2 | |
3 | #include <arith.h> |
4 | #include <index_compute.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_iostream.h> |
7 | #include <ir_utils.h> |
8 | #include <kernel_expr_evaluator.h> |
9 | #include <lower2device.h> |
10 | #include <lower_misaligned_vectorization.h> |
11 | #include <lower_utils.h> |
12 | #include <predicate_compute.h> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | namespace fuser { |
17 | namespace cuda { |
18 | |
19 | namespace { |
20 | |
21 | // Provide a new for loop matching the one provided |
22 | kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) { |
23 | const auto new_loop = IrBuilder::create<kir::ForLoop>(for_loop); |
24 | for (auto expr : for_loop->body().exprs()) { |
25 | if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
26 | expr = cloneLoopNest(nested_for_loop); |
27 | } |
28 | new_loop->body().push_back(expr); |
29 | } |
30 | return new_loop; |
31 | } |
32 | |
33 | // Returns true if expr is an expression that initializes a reduction |
34 | // buffer. |
35 | bool isReductionInitExpr(const Expr* expr) { |
36 | // False if its output isn't a TensorView |
37 | if (!ir_utils::isTvOp(expr)) { |
38 | return false; |
39 | } |
40 | // False if it doesn't have any reduction axis |
41 | const auto out_tv = expr->outputs()[0]->as<TensorView>(); |
42 | if (!out_tv->domain()->hasReduction()) { |
43 | return false; |
44 | } |
45 | // False if it has have TensorView inputs as initialization should |
46 | // never use TensorViews |
47 | const auto tv_filter_inp_view = |
48 | ir_utils::filterByType<TensorView>(expr->inputs()); |
49 | if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) { |
50 | return false; |
51 | } |
52 | return true; |
53 | } |
54 | |
55 | } // namespace |
56 | |
// Replaces `reference` with `new_expr` in `scope` via the base
// ExprMutator, then forwards the old->new mapping to GpuLower so
// per-expression lowering info stays attached to the new expression.
void UnrollPass::registerReplace(
    Expr* reference,
    Expr* new_expr,
    kir::Scope* scope) {
  kir::ExprMutator::registerReplace(reference, new_expr, scope);
  GpuLower::current()->propagateExprInfo(reference, new_expr);
}
64 | |
65 | void UnrollPass::handle(Expr* expr) { |
66 | if (ir_utils::isTvOp(expr)) { |
67 | // If tv op, predicate it |
68 | const auto out_tv = ir_utils::getTvOutput(expr); |
69 | const bool should_predicate = !for_loops_.empty() || |
70 | out_tv->getMemoryType() == MemoryType::Global || |
71 | out_tv->getMemoryType() == MemoryType::Shared; |
72 | if (!should_predicate) { |
73 | return; |
74 | } |
75 | |
76 | const auto thread_pred = isReductionInitExpr(expr) |
77 | ? GpuLower::current()->kernel()->trueVal() |
78 | : GpuLower::current()->threadPredMap().getPredicate(out_tv); |
79 | |
80 | // When this expr is in an unswitched block, only attach the |
81 | // thread predicate to the expr as thread predicates are not |
82 | // grouped to the unswitch predicate. |
83 | kir::Predicate* thread_pred_expr = nullptr; |
84 | if (unswitched_loop_) { |
85 | thread_pred_expr = IrBuilder::create<kir::Predicate>(thread_pred); |
86 | } |
87 | |
88 | non_trivial_pred_found_ = true; |
89 | |
90 | Expr* expr_with_predicate = expr; |
91 | |
92 | // When a predicate needs to account for ShiftOp, it is currently |
93 | // taken care by its own function. |
94 | if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { |
95 | expr_with_predicate = ShiftPredicateInserter::insert( |
96 | expr, for_loops_, thread_pred, unswitched_loop_); |
97 | if (expr_with_predicate != expr) { |
98 | registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); |
99 | } |
100 | return; |
101 | } |
102 | |
103 | // Reduction may need a separate predicate for writes. |
104 | if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) { |
105 | const auto write_pred = unswitched_loop_ |
106 | ? thread_pred_expr |
107 | : IrBuilder::create<kir::Predicate>( |
108 | PredicateType::ReductionWrite, expr, thread_pred); |
109 | expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred); |
110 | } |
111 | |
112 | // For expr calling a device func with block sync, don't create |
113 | // if-then-else but pass the predicate to the device func |
114 | if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { |
115 | const auto pred = unswitched_loop_ |
116 | ? thread_pred_expr |
117 | : IrBuilder::create<kir::Predicate>( |
118 | PredicateType::Inline, expr, thread_pred); |
119 | expr_with_predicate = expr_with_predicate->withPredicate(pred); |
120 | registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); |
121 | return; |
122 | } |
123 | |
124 | // Vectorized expressions should never use inline predicates |
125 | kir::Predicate* pred = nullptr; |
126 | if (!unswitched_loop_ && |
127 | std::any_of( |
128 | for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) { |
129 | return fl->iter_domain()->getParallelType() == |
130 | ParallelType::Vectorize; |
131 | })) { |
132 | pred = IrBuilder::create<kir::Predicate>(PredicateType::Vectorize); |
133 | } |
134 | |
135 | if (pred == nullptr) { |
136 | pred = unswitched_loop_ ? thread_pred_expr |
137 | : IrBuilder::create<kir::Predicate>( |
138 | PredicateType::Inline, expr, thread_pred); |
139 | } |
140 | |
141 | if (lower_utils::supportInlinePredicate(expr)) { |
142 | expr_with_predicate = expr_with_predicate->withPredicate(pred); |
143 | registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); |
144 | return; |
145 | } |
146 | |
147 | // If we need a predicate, put expr inside an if then else |
148 | kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred); |
149 | if (for_loops_.empty()) { |
150 | // Special handling for top level output expressions that still |
151 | // need predicates. One motivating example is a reduction op that |
152 | // reduces to a scalar (issue #491) |
153 | kir::ExprMutator::registerReplace(expr, inline_ite, nullptr); |
154 | } else { |
155 | kir::ExprMutator::registerReplace( |
156 | expr, inline_ite, &for_loops_.back()->body()); |
157 | } |
158 | if (expr != expr_with_predicate) { |
159 | GpuLower::current()->propagateExprInfo(expr, expr_with_predicate); |
160 | } |
161 | inline_ite->thenBody().push_back(expr_with_predicate); |
162 | } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
163 | handle(for_loop); |
164 | } |
165 | } |
166 | |
// TODO: We should factor out the actual predicate generation from
// unrolling, e.g. by inserting IR nodes "unroll_pred" or
// "inline_pred", then generating those later.
void UnrollPass::handle(kir::ForLoop* fl) {
  // Unroll and Unswitch loops both get the unswitch treatment below:
  // the loop nest is duplicated into a copy guarded by a single
  // hoisted predicate and an inline-predicated fallback copy.
  const bool is_unroll =
      fl->iter_domain()->getParallelType() == ParallelType::Unroll ||
      fl->iter_domain()->getParallelType() == ParallelType::Unswitch;

  // If we're not looking for an unroll loop, or didn't find one, process as
  // normal.
  if (!is_unroll || !look_for_unroll_) {
    for_loops_.push_back(fl);

    // Make copy of exprs because we replace them inplace in fl
    const auto exprs_copy = fl->body().exprs();

    // Skip Misaligned Vectorization For-Loops here; they are handled
    // by the misaligned vectorization lowering (see
    // lower_misaligned_vectorization.h).
    if (!containsAnyDirectChildMisalignedVectorize(fl)) {
      for (auto expr : exprs_copy) {
        handle(expr);
      }
    }

    for_loops_.pop_back();
    return;
  }

  // Single predicate guarding the whole unswitched loop nest.
  auto unroll_pred = IrBuilder::create<kir::Predicate>(fl);

  kir::IfThenElse* unroll_ite = IrBuilder::create<kir::IfThenElse>(unroll_pred);

  // Get the loop nest for the unrolled path
  kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);

  // Thread predicates are not removed from the expressions. Visit
  // each expression to attach kir::Predicate.
  unswitched_loop_ = true;
  look_for_unroll_ = false;
  handle(unrolled_loop_nest);
  unswitched_loop_ = false;
  look_for_unroll_ = true;

  unroll_ite->thenBody().push_back(unrolled_loop_nest);

  // Loop nest for inlined path
  kir::ForLoop* inlined_loop = cloneLoopNest(fl);

  // Add inline predicates for inlined loop nest
  look_for_unroll_ = false;
  non_trivial_pred_found_ = false;
  handle(inlined_loop);
  look_for_unroll_ = true;
  if (!non_trivial_pred_found_) {
    // No expression required a non-trivial predicate, so the
    // if-then-else wrapper is unnecessary; replace the original loop
    // with the inlined copy directly.
    kir::ExprMutator::registerReplace(
        fl,
        inlined_loop,
        for_loops_.empty() ? nullptr : &for_loops_.back()->body());
  } else {
    // The else clause (the fallback inlined copy) can be dropped when
    // the unswitch predicate alone is provably sufficient.
    if (!canOmitElseClause(fl)) {
      unroll_ite->elseBody().push_back(inlined_loop);
    }
    kir::ExprMutator::registerReplace(
        fl,
        unroll_ite,
        for_loops_.empty() ? nullptr : &for_loops_.back()->body());
  }
}
234 | |
235 | bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { |
236 | kir::ExpressionEvaluator eval; |
237 | std::vector<kir::ForLoop*> loops({fl}); |
238 | |
239 | const auto& pred_map = GpuLower::current()->threadPredMap(); |
240 | |
241 | while (loops.size() > 0) { |
242 | auto loop = loops.back(); |
243 | loops.pop_back(); |
244 | |
245 | // If there's any expression that requires barrier |
246 | // synchronization, the else part can't be omitted |
247 | for (auto expr : loop->body().exprs()) { |
248 | if (lower_utils::hasBlockSync(expr, pred_map)) { |
249 | return false; |
250 | } |
251 | } |
252 | // If the number of visits of the loop body per thread is one, the |
253 | // unswitch predicate is sufficient. |
254 | // When the loop stop is the same as the extent of its IterDomain, |
255 | // the per-thread visit count is guaranteed to be one at most (see |
256 | // CudaKernelGenerator::handle(kir::ForLoop*) as well. Also, when a |
257 | // loop is vectorized (not misaligned), the count must be one at |
258 | // most. Even if not parallelized nor vectoirzed, it is also |
259 | // sufficient if the loop stop is in fact one. |
260 | bool visit_once = false; |
261 | auto id = loop->iter_domain(); |
262 | if ((id->isThread() && (loop->stop() == id->extent())) || |
263 | id->getParallelType() == ParallelType::Vectorize) { |
264 | visit_once = true; |
265 | } |
266 | if (!visit_once) { |
267 | const auto result = eval.evaluate(loop->stop()); |
268 | if (result.has_value() && result.value() == 1) { |
269 | visit_once = true; |
270 | } |
271 | } |
272 | |
273 | // The visit count is not guaranteed to be one, so the else part |
274 | // must be created. |
275 | if (!visit_once) { |
276 | return false; |
277 | } |
278 | |
279 | // The unswitch predicate is sufficient for this loop. Proceed to |
280 | // nested loops. |
281 | for (auto nested_loop : |
282 | ir_utils::filterByType<kir::ForLoop>(loop->body().exprs())) { |
283 | loops.push_back(nested_loop); |
284 | } |
285 | } |
286 | |
287 | return true; |
288 | } |
289 | |
// Constructing the pass runs it immediately: the traversal visits
// the given expressions and records the transformed list in exprs_
// (inherited from kir::ExprMutator), which runPass then returns.
UnrollPass::UnrollPass(const std::vector<Expr*>& exprs) {
  kir::ExprMutator::traverseAndInsert(exprs);
}
293 | |
294 | std::vector<Expr*> UnrollPass::runPass( |
295 | Fusion* fusion, |
296 | const std::vector<Expr*>& exprs) { |
297 | FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass" ); |
298 | |
299 | UnrollPass unroll_pass(exprs); |
300 | return unroll_pass.exprs_; |
301 | } |
302 | |
303 | } // namespace cuda |
304 | } // namespace fuser |
305 | } // namespace jit |
306 | } // namespace torch |
307 | |