#include <lower2device.h>

#include <ATen/cuda/CUDAContext.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower_alias_memory.h>
#include <lower_allocation.h>
#include <lower_divisible_split.h>
#include <lower_double_buffer.h>
#include <lower_expr_sort.h>
#include <lower_fusion_simplifier.h>
#include <lower_index.h>
#include <lower_insert_syncs.h>
#include <lower_instrument.h>
#include <lower_loops.h>
#include <lower_magic_zero.h>
#include <lower_misaligned_vectorization.h>
#include <lower_predicate.h>
#include <lower_replace_size.h>
#include <lower_shift.h>
#include <lower_trivial_reductions.h>
#include <lower_unroll.h>
#include <lower_utils.h>
#include <lower_validation.h>
#include <lower_warp_reduce.h>

#include <list>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
namespace {
class KIRCleaner : public OptOutDispatch {
 public:
  //! Remove nop IR nodes
  static std::vector<Expr*> cleanUp(const std::vector<Expr*>& loop_nests) {
    KIRCleaner cleaner;
    std::vector<Expr*> out_loop_nests;
    for (auto loop_nest : loop_nests) {
      cleaner.handle(loop_nest);
      // No need to keep the loop nest if it's determined to be nop
      if (!cleaner.is_nop_) {
        out_loop_nests.push_back(loop_nest);
      }
    }
    return out_loop_nests;
  }

 private:
  using OptOutDispatch::handle;
  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      OptOutDispatch::handle(expr);
    } else {
      // Any non-scoping expr is not considered nop
      is_nop_ = false;
    }
  }

  void handle(kir::ForLoop* fl) final {
    auto exprs = fl->body().exprs();
    fl->body().clear();
    for (auto expr : exprs) {
      handle(expr);
      // Add the expr to the loop body only when the expr is not nop
      if (!is_nop_) {
        fl->body().push_back(expr);
      }
    }
    // The loop is nop when no expr exists in the body
    is_nop_ = fl->body().empty();
  }

  void handle(kir::IfThenElse* ite) final {
    const auto conditional = ite->predicate()->value();

    // Visit the then block
    auto then_exprs = ite->thenBody().exprs();
    ite->thenBody().clear();
    if (!conditional->isConst() || conditional->value().value()) {
      for (auto expr : then_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->thenBody().push_back(expr);
        }
      }
    }

    const bool then_nop = ite->thenBody().empty();

    // Visit the else block
    auto else_exprs = ite->elseBody().exprs();
    ite->elseBody().clear();
    if (!conditional->isConst() || !conditional->value().value()) {
      for (auto expr : else_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->elseBody().push_back(expr);
        }
      }
    }

    const bool else_nop = ite->elseBody().empty();

    // If the then block is nop but the else is not, invert the
    // conditional and move the exprs in the else block to the then
    // block.
    if (then_nop && !else_nop) {
      Bool* pred = ite->predicate()->value();
      Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as<Bool>();
      ite->predicate()->setValue(not_pred);
      for (auto expr : ite->elseBody().exprs()) {
        ite->thenBody().push_back(expr);
      }
      ite->elseBody().clear();
    }

    // This IfThenElse is nop if both the then and else blocks are nop
    is_nop_ = then_nop && else_nop;
  }

 private:
  //! True if the last visited expr is nop
  bool is_nop_ = false;
};

} // namespace

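// Scan the fusion for TIDx parallelization that is padded to a multiple of a
// warp, record whether a warp reduction is present, and determine whether
// TIDx can be treated as a single warp.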
void GpuLower::collectPaddedParallelDims() {
  ExpressionEvaluator ee(fusion_);
  bool can_be_single_warp = true;

  auto warp_size = at::cuda::warp_size();

  auto used_vals = fusion_->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto id : tv->domain()->domain()) {
      if (tv->definition()) {
        // TODO: Support GroupedReductionOp
        if (auto reduction = dynamic_cast<ReductionOp*>(tv->definition())) {
          if (ir_utils::getMaybeWarpReductionDim(
                  reduction->out(), reduction->in())
                  .has_value()) {
            warp_pad_info_.has_warp_reduction = true;
          }
        }
      }

      // Check if TIDx is padded in this kernel
      if (id->hasPaddingToMultipleOfWarp()) {
        TORCH_INTERNAL_ASSERT(
            id->getParallelType() == ParallelType::TIDx,
            "Padded types supported only on TIDx");
        warp_pad_info_.is_tidx_padded = true;
      }

      // Check all possible bindings of TIDx to see
      // if TIDx will eventually be bound to a single warp.
      if (id->getParallelType() == ParallelType::TIDx) {
        auto eval_dim = ee.evaluate(id->extent());
        auto size_after_padding = id->getMaybeSizeAfterPadding();
        bool padding_to_single_warp = size_after_padding.has_value() &&
            size_after_padding.value() == warp_size;

        if ((!eval_dim.has_value() || eval_dim.value() > warp_size) &&
            !padding_to_single_warp) {
          // If we see any other TIDx binding that's larger than
          // a warp or unknown, we shouldn't lower warp reduce
          // to a single warp type.
          can_be_single_warp = false;
          warp_pad_info_.is_tidx_single_warp = false;
        } else if (can_be_single_warp) {
          if (padding_to_single_warp ||
              (eval_dim.has_value() && eval_dim.value() == warp_size)) {
            warp_pad_info_.is_tidx_single_warp = true;
          }
        }
      }
    }
  }
}

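// Give each RNGOp in the fusion a distinct offset so that separate RNG calls
// in the generated kernel do not draw from the same portion of the random
// sequence.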
void assignRNGOffset(Fusion* fusion) {
  int counter = 0;
  for (auto expr : fusion->exprs()) {
    if (expr->isA<RNGOp>()) {
      auto rop = expr->as<RNGOp>();
      rop->setRNGOffset(counter++);
    }
  }
}

void GpuLower::lower(Fusion* fusion, DataType index_type) {
  FUSER_PERF_SCOPE("GpuLower::lower");
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower == nullptr, "Nested lowering passes are not supported");

  struct LowerGuard {
    LowerGuard(GpuLower* gpu_lower) {
      active_gpu_lower = gpu_lower;
    }
    ~LowerGuard() {
      active_gpu_lower = nullptr;
    }
  } lower_guard(this);
  // Copy fusion into a new kernel for processing
  kernel_ = std::make_unique<kir::Kernel>(fusion, index_type);
  // Alias the fusion that the kernel carries around as a view of itself.
  fusion_ = kernel_.get();

  // Convert tensor views of DataType::Index type to either Int or Int32
  for (auto tv : ir_utils::allTvs(fusion_)) {
    if (tv->dtype() == DataType::Index) {
      tv->resolveIndexDtype();
    }
  }
  assignRNGOffset(fusion_);

  FusionGuard fg(fusion_);
  // prepare for lowering
  validateIr(fusion_);

  // Checks if any TIDx dim is marked as padded to a warp. Also checks if we
  // can determine the padding is explicitly a single warp.
  collectPaddedParallelDims();

  // Replaces integers that are tensor sizes with named scalars such as
  // "T0.size[0]"
  replaceSymbolicSizes(fusion_);

  // Traverse through reductions and determine if any iteration domains are
  // trivial reductions. Add these iteration domains to trivial_reduction_info_,
  // which simply holds a map of which axes are trivial and which are not.
  trivial_reduction_info_.build(fusion_);
  // Replaces trivial reduction expressions (all ids being reduced are trivial)
  // with a set unary op
  trivialReductionReplacement(fusion_, trivial_reduction_info_);

  // Build what's referred to as the compute-at map. This map contains the
  // mappings of all iteration domains across the fusion. There are three types
  // of mappings: Permissive, Exact, and Loop; see compute_at_map.h/cpp for more
  // information.
  compute_at_map_ = std::make_shared<ComputeAtMap>(fusion_);

  if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
    std::cout << compute_at_map_->toString() << std::endl;
  }

  compute_at_map_->validateAndPropagatePType();

  // Using the compute-at map, find all splits that are enforced to be divisible
  divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get());

  // Used in parallel dimension map
  concretized_broadcast_domains_ =
      std::make_shared<const ConcretizedBroadcastDomains>(fusion_);

  parallelDimensionMap().build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
    std::cout << "Parallel dimension map:" << std::endl;
    std::cout << parallel_dimension_map_.toString() << std::endl;
  }

  // Validate mma data format and compatibility if any on the fusion.
  validateMma(fusion_);

  // Validate swizzle usage on the fusion schedule.
  validateSwizzle(fusion_);

  // Compute thread predicates. Depends on parallel_dimension_map_
  thread_pred_map_.build(fusion_);

  // Fuse certain patterns of reductions, such as a grid reduction
  // followed by a grid broadcast. Only depends on parallelization and
  // thread predicate map.
  fuseReductionsAndBroadcasts(fusion_);

  // Scan the whole fusion and build mappings about halo extensions of
  // all IterDomains
  halo_info_ = std::make_shared<HaloInfo>(fusion_, compute_at_map_);

  // Want to run this after parallel map and halo info map are
  // created. vectorized_accesses_ and vectorized_set_info_ are filled.
  validateAndCollectVectorizeInfo(fusion_);

  // Depends on ComputeAtMap and HaloInfo.
  validateAndConvertIterDomainGrouping(fusion_);

  // Assumes all grouped reductions are converted to
  // GroupedReductionOp, which is done by
  // validateAndConvertIterDomainGrouping
  validateGroupedReductions(fusion_);

  // Depends on thread_pred_map_; validates parallelization and collects which
  // tensor views need WAR or RAW syncs
  sync_map_.build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) {
    std::cout << sync_map_.toString() << std::endl;
  }

  partialSplitMap().build(fusion_);

  validatePartialSplit(fusion_);

  nonDivisibleSplitInfo().build(fusion_);

  // Detects all expressions that don't need predicates. Depends on
  // nonDivisibleSplitInfo.
  predicateElimination().build(fusion_);

  doubleBufferInfo().build(fusion_);

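  // Allocate the loop index variables used by the loop groups of the
  // compute-at map.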
  compute_at_map_->allocateIndexVariables();

  // Run our passes, keeping the lowered expressions and forwarding them from
  // one pass to the next.

  // Reorder expressions for loop-nest generation respecting computeAt
  // relationships
  const auto exprs_sorted = reorderExprsForComputeAt();

  // Generate loop-nests and place each expression at its
  // corresponding loop
  const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted);

  // Replace trivial reductions, Transpose, Shift, Gather, and View ops with
  // unary ops since they're not separately processed in lowering.
  const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered);

  // Insert allocations
  const auto exprs_alloced = insertAllocations(exprs_unary_replaced);

  // Insert read-after-write smem syncs
  const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced);

  // Reuse memory locations
  const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync);

  // Insert SyncThreads at end of for-loop to avoid WAR race condition
  const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem);

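  // Apply the double-buffering transformation to tensors scheduled as double
  // buffered so that their loads can be overlapped with computation.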
  const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync);

  // This pass inserts predicates as well as branches in the code. Up until
  // now the code is explicitly single-shot, for-loop based. Need to be careful
  // in later passes when doing any kind of insertion into the loop nest
  // structure, as insertions could land in the then or else branch of an
  // IfThenElse instead of directly in a for loop.
  const auto exprs_unrolled_loops =
      UnrollPass::runPass(fusion_, exprs_double_buffered);

  const auto exprs_unrolled_mv_loops =
      processMisalignedVectorization(exprs_unrolled_loops);

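  // Lower indexing: replace abstract tensor accesses with concrete index
  // arithmetic for each producer/consumer access.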
  const auto exprs_indexed_loops =
      IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops);

  // TODO: It seems this type of optimization would be far easier to implement
  // on fusion ir than kernel ir. We should likely refactor this to at least run
  // before allocation insertion.
  const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops);

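  // Replace the predicate placeholders attached to expressions with the
  // actual boolean conditionals they stand for.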
  const auto exprs_conditional_loops =
      generateConditionalFromPredicate(exprs_with_fused_broadcast);

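  // Allocate index expressions that are common across uses once, so the same
  // index computation is not repeated (common index hoisting).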
  const auto exprs_common_index_allocated =
      allocateCommonIndices(exprs_conditional_loops);

  // Insert fake zero updates to make sure nvrtc doesn't blow up register use
  // due to index and predicate reuse
  const auto exprs_register_adjusted =
      insertMagicZero(exprs_common_index_allocated);

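  // Drop for-loops and if-then-else blocks that ended up with empty bodies
  // after the preceding passes.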
  const auto exprs_cleaned_up_loops =
      KIRCleaner::cleanUp(exprs_register_adjusted);

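  // Add kernel instrumentation (see lower_instrument) to the lowered
  // expressions.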
  const auto exprs_instrumented = instrumentKernel(exprs_cleaned_up_loops);

  // We now have the lowered expressions; finalize the kernel IR. This function
  // will also copy over some relevant information for code generation from
  // GpuLower.
  kernel_->finalize(exprs_instrumented);
}

kir::Kernel* GpuLower::kernel() const {
  TORCH_CHECK(kernel_);
  return kernel_.get();
}

GpuLower* GpuLower::current() {
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower != nullptr, "No active GpuLower available");
  return active_gpu_lower;
}

bool GpuLower::hasCurrent() {
  return active_gpu_lower != nullptr;
}

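// When an expression is replaced during lowering, carry the information
// recorded for the old expression (currently predicate-elimination removal
// info) over to its replacement.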
void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
  pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch