#include <lower_unroll.h>

#include <arith.h>
#include <index_compute.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>
#include <lower_misaligned_vectorization.h>
#include <lower_utils.h>
#include <predicate_compute.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Returns a copy of the given for-loop, recursively cloning any nested
// for-loops.
kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) {
  const auto new_loop = IrBuilder::create<kir::ForLoop>(for_loop);
  for (auto expr : for_loop->body().exprs()) {
    if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      expr = cloneLoopNest(nested_for_loop);
    }
    new_loop->body().push_back(expr);
  }
  return new_loop;
}

// Returns true if expr is an expression that initializes a reduction
// buffer.
bool isReductionInitExpr(const Expr* expr) {
  // False if its output isn't a TensorView
  if (!ir_utils::isTvOp(expr)) {
    return false;
  }
  // False if it doesn't have any reduction axis
  const auto out_tv = expr->outputs()[0]->as<TensorView>();
  if (!out_tv->domain()->hasReduction()) {
    return false;
  }
  // False if it has TensorView inputs, as initialization should
  // never use TensorViews
  const auto tv_filter_inp_view =
      ir_utils::filterByType<TensorView>(expr->inputs());
  if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) {
    return false;
  }
  return true;
}

} // namespace

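// Replaces "reference" with "new_expr" in "scope", and propagates any
// lowering info tracked for the original expression to its replacement.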
void UnrollPass::registerReplace(
    Expr* reference,
    Expr* new_expr,
    kir::Scope* scope) {
  kir::ExprMutator::registerReplace(reference, new_expr, scope);
  GpuLower::current()->propagateExprInfo(reference, new_expr);
}

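// Generates a predicate for a tensor op. Depending on the expression and the
// surrounding loop nest, the predicate is attached inline to the expression
// (and, for reductions, as a separate write predicate), or the expression is
// wrapped in a predicated if-then-else. Nested for-loops are recursed into.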
void UnrollPass::handle(Expr* expr) {
  if (ir_utils::isTvOp(expr)) {
    // If tv op, predicate it
    const auto out_tv = ir_utils::getTvOutput(expr);
    const bool should_predicate = !for_loops_.empty() ||
        out_tv->getMemoryType() == MemoryType::Global ||
        out_tv->getMemoryType() == MemoryType::Shared;
    if (!should_predicate) {
      return;
    }

    const auto thread_pred = isReductionInitExpr(expr)
        ? GpuLower::current()->kernel()->trueVal()
        : GpuLower::current()->threadPredMap().getPredicate(out_tv);

    // When this expr is in an unswitched block, only attach the
    // thread predicate to the expr as thread predicates are not
    // grouped into the unswitch predicate.
    kir::Predicate* thread_pred_expr = nullptr;
    if (unswitched_loop_) {
      thread_pred_expr = IrBuilder::create<kir::Predicate>(thread_pred);
    }

    non_trivial_pred_found_ = true;

    Expr* expr_with_predicate = expr;

    // When a predicate needs to account for ShiftOp, it is currently
    // taken care of by its own function.
    if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) {
      expr_with_predicate = ShiftPredicateInserter::insert(
          expr, for_loops_, thread_pred, unswitched_loop_);
      if (expr_with_predicate != expr) {
        registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
      }
      return;
    }

    // Reduction may need a separate predicate for writes.
    if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) {
      const auto write_pred = unswitched_loop_
          ? thread_pred_expr
          : IrBuilder::create<kir::Predicate>(
                PredicateType::ReductionWrite, expr, thread_pred);
      expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred);
    }

    // If the expr calls a device function with a block sync, don't create an
    // if-then-else; pass the predicate to the device function instead.
    if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) {
      const auto pred = unswitched_loop_
          ? thread_pred_expr
          : IrBuilder::create<kir::Predicate>(
                PredicateType::Inline, expr, thread_pred);
      expr_with_predicate = expr_with_predicate->withPredicate(pred);
      registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
      return;
    }

    // Vectorized expressions should never use inline predicates
    kir::Predicate* pred = nullptr;
    if (!unswitched_loop_ &&
        std::any_of(
            for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) {
              return fl->iter_domain()->getParallelType() ==
                  ParallelType::Vectorize;
            })) {
      pred = IrBuilder::create<kir::Predicate>(PredicateType::Vectorize);
    }

    if (pred == nullptr) {
      pred = unswitched_loop_ ? thread_pred_expr
                              : IrBuilder::create<kir::Predicate>(
                                    PredicateType::Inline, expr, thread_pred);
    }

    if (lower_utils::supportInlinePredicate(expr)) {
      expr_with_predicate = expr_with_predicate->withPredicate(pred);
      registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
      return;
    }

    // If we need a predicate, put expr inside an if then else
    kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
    if (for_loops_.empty()) {
      // Special handling for top level output expressions that still
      // need predicates. One motivating example is a reduction op that
      // reduces to a scalar (issue #491)
      kir::ExprMutator::registerReplace(expr, inline_ite, nullptr);
    } else {
      kir::ExprMutator::registerReplace(
          expr, inline_ite, &for_loops_.back()->body());
    }
    if (expr != expr_with_predicate) {
      GpuLower::current()->propagateExprInfo(expr, expr_with_predicate);
    }
    inline_ite->thenBody().push_back(expr_with_predicate);
  } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
    handle(for_loop);
  }
}

// We should factor the actual predicate generation out of unrolling by
// inserting IR nodes such as "unroll_pred" or "inline_pred", and then
// generating those predicates later.
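// For an unswitched or unrolled loop, the loop nest is cloned twice: one copy
// goes into the then-branch guarded by the unswitch predicate with only thread
// predicates attached, and, unless it can be omitted, a second copy with
// inline predicates goes into the else-branch.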
void UnrollPass::handle(kir::ForLoop* fl) {
  // Setup for loop scoping
  const bool is_unroll =
      fl->iter_domain()->getParallelType() == ParallelType::Unroll ||
      fl->iter_domain()->getParallelType() == ParallelType::Unswitch;

  // If we're not looking for an unroll loop, or didn't find one, process as
  // normal.
  if (!is_unroll || !look_for_unroll_) {
    for_loops_.push_back(fl);

    // Make a copy of the exprs because we replace them in place in fl
    const auto exprs_copy = fl->body().exprs();

    // Skip Misaligned Vectorization For-Loops here
    if (!containsAnyDirectChildMisalignedVectorize(fl)) {
      for (auto expr : exprs_copy) {
        handle(expr);
      }
    }

    for_loops_.pop_back();
    return;
  }

  auto unroll_pred = IrBuilder::create<kir::Predicate>(fl);

  kir::IfThenElse* unroll_ite = IrBuilder::create<kir::IfThenElse>(unroll_pred);

  // Get the loop nest for the unrolled path
  kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);

  // Thread predicates are not removed from the expressions. Visit
  // each expression to attach kir::Predicate.
  unswitched_loop_ = true;
  look_for_unroll_ = false;
  handle(unrolled_loop_nest);
  unswitched_loop_ = false;
  look_for_unroll_ = true;

  unroll_ite->thenBody().push_back(unrolled_loop_nest);

  // Loop nest for inlined path
  kir::ForLoop* inlined_loop = cloneLoopNest(fl);

  // Add inline predicates for inlined loop nest
  look_for_unroll_ = false;
  non_trivial_pred_found_ = false;
  handle(inlined_loop);
  look_for_unroll_ = true;
  if (!non_trivial_pred_found_) {
    kir::ExprMutator::registerReplace(
        fl,
        inlined_loop,
        for_loops_.empty() ? nullptr : &for_loops_.back()->body());
  } else {
    if (!canOmitElseClause(fl)) {
      unroll_ite->elseBody().push_back(inlined_loop);
    }
    kir::ExprMutator::registerReplace(
        fl,
        unroll_ite,
        for_loops_.empty() ? nullptr : &for_loops_.back()->body());
  }
}

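// Returns true if the else clause of the unswitch if-then-else can be omitted,
// i.e., when every loop in the nest is visited at most once per thread and no
// expression in the nest requires a block synchronization.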
bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) {
  kir::ExpressionEvaluator eval;
  std::vector<kir::ForLoop*> loops({fl});

  const auto& pred_map = GpuLower::current()->threadPredMap();

  while (loops.size() > 0) {
    auto loop = loops.back();
    loops.pop_back();

    // If there's any expression that requires barrier
    // synchronization, the else part can't be omitted
    for (auto expr : loop->body().exprs()) {
      if (lower_utils::hasBlockSync(expr, pred_map)) {
        return false;
      }
    }
    // If the number of visits of the loop body per thread is one, the
    // unswitch predicate is sufficient.
    // When the loop stop is the same as the extent of its IterDomain,
    // the per-thread visit count is guaranteed to be at most one (see
    // CudaKernelGenerator::handle(kir::ForLoop*) as well). Also, when a
    // loop is vectorized (not misaligned), the count must be at most
    // one. Even when neither parallelized nor vectorized, it is
    // sufficient if the loop stop is in fact one.
    bool visit_once = false;
    auto id = loop->iter_domain();
    if ((id->isThread() && (loop->stop() == id->extent())) ||
        id->getParallelType() == ParallelType::Vectorize) {
      visit_once = true;
    }
    if (!visit_once) {
      const auto result = eval.evaluate(loop->stop());
      if (result.has_value() && result.value() == 1) {
        visit_once = true;
      }
    }

    // The visit count is not guaranteed to be one, so the else part
    // must be created.
    if (!visit_once) {
      return false;
    }

    // The unswitch predicate is sufficient for this loop. Proceed to
    // nested loops.
    for (auto nested_loop :
         ir_utils::filterByType<kir::ForLoop>(loop->body().exprs())) {
      loops.push_back(nested_loop);
    }
  }

  return true;
}

UnrollPass::UnrollPass(const std::vector<Expr*>& exprs) {
  kir::ExprMutator::traverseAndInsert(exprs);
}

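// Entry point of the pass: runs the predication/unroll transformation over the
// given expressions and returns the mutated expression list.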
std::vector<Expr*> UnrollPass::runPass(
    Fusion* fusion,
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass");

  UnrollPass unroll_pass(exprs);
  return unroll_pass.exprs_;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch