#include <lower2device.h>

#include <ATen/cuda/CUDAContext.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower_alias_memory.h>
#include <lower_allocation.h>
#include <lower_divisible_split.h>
#include <lower_double_buffer.h>
#include <lower_expr_sort.h>
#include <lower_fusion_simplifier.h>
#include <lower_index.h>
#include <lower_insert_syncs.h>
#include <lower_instrument.h>
#include <lower_loops.h>
#include <lower_magic_zero.h>
#include <lower_misaligned_vectorization.h>
#include <lower_predicate.h>
#include <lower_replace_size.h>
#include <lower_shift.h>
#include <lower_trivial_reductions.h>
#include <lower_unroll.h>
#include <lower_utils.h>
#include <lower_validation.h>
#include <lower_warp_reduce.h>

#include <list>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

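// The GpuLower instance currently running on this thread. It is set for the
// duration of GpuLower::lower() and queried through GpuLower::current() /
// GpuLower::hasCurrent().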
thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
namespace {

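//! Removes no-op expressions from lowered loop nests: for-loops whose bodies
//! end up empty and if-then-else blocks whose reachable branches are empty are
//! dropped. An if-then-else with an empty then-block but a non-empty
//! else-block has its predicate inverted so that only a then-block remains.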
class KIRCleaner : public OptOutDispatch {
 public:
  //! Remove nop IR nodes
  static std::vector<Expr*> cleanUp(const std::vector<Expr*>& loop_nests) {
    KIRCleaner cleaner;
    std::vector<Expr*> out_loop_nests;
    for (auto loop_nest : loop_nests) {
      cleaner.handle(loop_nest);
      // No need to keep the loop nest if it's determined to be nop
      if (!cleaner.is_nop_) {
        out_loop_nests.push_back(loop_nest);
      }
    }
    return out_loop_nests;
  }

 private:
  using OptOutDispatch::handle;
  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      OptOutDispatch::handle(expr);
    } else {
      // Any non-scoping expr is not considered nop
      is_nop_ = false;
    }
  }

  void handle(kir::ForLoop* fl) final {
    auto exprs = fl->body().exprs();
    fl->body().clear();
    for (auto expr : exprs) {
      handle(expr);
      // Add the expr to the loop body only when the expr is not nop
      if (!is_nop_) {
        fl->body().push_back(expr);
      }
    }
    // The loop is nop when no expr exists in the body
    is_nop_ = fl->body().empty();
  }

  void handle(kir::IfThenElse* ite) final {
    const auto conditional = ite->predicate()->value();

    // Visit the then block
    auto then_exprs = ite->thenBody().exprs();
    ite->thenBody().clear();
    if (!conditional->isConst() || conditional->value().value()) {
      for (auto expr : then_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->thenBody().push_back(expr);
        }
      }
    }

    const bool then_nop = ite->thenBody().empty();

    // Visit the else block
    auto else_exprs = ite->elseBody().exprs();
    ite->elseBody().clear();
    if (!conditional->isConst() || !conditional->value().value()) {
      for (auto expr : else_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->elseBody().push_back(expr);
        }
      }
    }

    const bool else_nop = ite->elseBody().empty();

    // If the then block is nop but the else is not, invert the
    // conditional and move the exprs in the else block to the then
    // block.
    if (then_nop && !else_nop) {
      Bool* pred = ite->predicate()->value();
      Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as<Bool>();
      ite->predicate()->setValue(not_pred);
      for (auto expr : ite->elseBody().exprs()) {
        ite->thenBody().push_back(expr);
      }
      ite->elseBody().clear();
    }

    // This IfThenElse is nop if both the then and else blocks are nop
    is_nop_ = then_nop && else_nop;
  }

 private:
  //! True if the last visited expr is nop
  bool is_nop_ = false;
};

} // namespace

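// Record warp padding information for the fusion: whether any warp reduction
// is present, whether TIDx is padded to a multiple of the warp size, and
// whether every TIDx binding is known to be exactly a single warp (or padded
// to one) so that warp reductions may be lowered to their single-warp form.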
void GpuLower::collectPaddedParallelDims() {
  ExpressionEvaluator ee(fusion_);
  bool can_be_single_warp = true;

  auto warp_size = at::cuda::warp_size();

  auto used_vals = fusion_->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto id : tv->domain()->domain()) {
      if (tv->definition()) {
        // TODO: Support GroupedReductionOp
        if (auto reduction = dynamic_cast<ReductionOp*>(tv->definition())) {
          if (ir_utils::getMaybeWarpReductionDim(
                  reduction->out(), reduction->in())
                  .has_value()) {
            warp_pad_info_.has_warp_reduction = true;
          }
        }
      }

      // Check if TIDx is padded in this kernel
      if (id->hasPaddingToMultipleOfWarp()) {
        TORCH_INTERNAL_ASSERT(
            id->getParallelType() == ParallelType::TIDx,
            "Padded types supported only on TIDx");
        warp_pad_info_.is_tidx_padded = true;
      }

      // Check all possible bindings of TIDx to see
      // if TIDx will eventually be bound to a single warp.
      if (id->getParallelType() == ParallelType::TIDx) {
        auto eval_dim = ee.evaluate(id->extent());
        auto size_after_padding = id->getMaybeSizeAfterPadding();
        bool padding_to_single_warp = size_after_padding.has_value() &&
            size_after_padding.value() == warp_size;

        if ((!eval_dim.has_value() || eval_dim.value() > warp_size) &&
            !padding_to_single_warp) {
          // If we see any other TIDx binding that's larger than
          // a warp or unknown, we shouldn't lower warp reduce
          // to a single warp type.
          can_be_single_warp = false;
          warp_pad_info_.is_tidx_single_warp = false;
        } else if (can_be_single_warp) {
          if (padding_to_single_warp ||
              (eval_dim.has_value() && eval_dim.value() == warp_size)) {
            warp_pad_info_.is_tidx_single_warp = true;
          }
        }
      }
    }
  }
}

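// Assign each RNGOp in the fusion a distinct, increasing offset so that
// different random number operations draw from different positions of the RNG
// sequence.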
void assignRNGOffset(Fusion* fusion) {
  int counter = 0;
  for (auto expr : fusion->exprs()) {
    if (expr->isA<RNGOp>()) {
      auto rop = expr->as<RNGOp>();
      rop->setRNGOffset(counter++);
    }
  }
}

void GpuLower::lower(Fusion* fusion, DataType index_type) {
  FUSER_PERF_SCOPE("GpuLower::lower");
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower == nullptr, "Nested lowering passes are not supported");

  struct LowerGuard {
    LowerGuard(GpuLower* gpu_lower) {
      active_gpu_lower = gpu_lower;
    }
    ~LowerGuard() {
      active_gpu_lower = nullptr;
    }
  } lower_guard(this);
  // Copy fusion into a new kernel for processing
  kernel_ = std::make_unique<kir::Kernel>(fusion, index_type);
  // Alias the fusion the kernel carries around as a view of itself.
  fusion_ = kernel_.get();

  // Convert tensor views of DataType::Index type to either Int or Int32
  for (auto tv : ir_utils::allTvs(fusion_)) {
    if (tv->dtype() == DataType::Index) {
      tv->resolveIndexDtype();
    }
  }
  assignRNGOffset(fusion_);

  FusionGuard fg(fusion_);
  // prepare for lowering
  validateIr(fusion_);

  // Checks if any TIDx dim is marked as padded to a warp. Also checks if we can
  // determine the padding is explicitly a single warp.
  collectPaddedParallelDims();

  // Replaces integers that are tensor sizes with named scalars such as
  // "T0.size[0]"
  replaceSymbolicSizes(fusion_);

  // Traverse through reductions and determine if any iteration domains are
  // trivial reductions. Add these iteration domains to trivial_reduction_info_,
  // which simply holds a map of which axes are trivial and which are not.
  trivial_reduction_info_.build(fusion_);
  // Replaces trivial reduction expressions (all IDs being reduced are trivial)
  // with a set unary op
  trivialReductionReplacement(fusion_, trivial_reduction_info_);

  // Build what's referred to as the compute at map. This map contains the
  // mappings of all iteration domains across the fusion. There are three types
  // of mappings: Permissive, Exact, and Loop. See compute_at_map.h/cpp for more
  // information.
  compute_at_map_ = std::make_shared<ComputeAtMap>(fusion_);

  if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
    std::cout << compute_at_map_->toString() << std::endl;
  }

  compute_at_map_->validateAndPropagatePType();

  // Using compute_at_map, find all splits that are enforced to be divisible
  divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get());

  // Used in parallel dimension map
  concretized_broadcast_domains_ =
      std::make_shared<const ConcretizedBroadcastDomains>(fusion_);

  parallelDimensionMap().build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
    std::cout << "Parallel dimension map:" << std::endl;
    std::cout << parallel_dimension_map_.toString() << std::endl;
  }

  // Validate mma data format and compatibility if any on the fusion.
  validateMma(fusion_);

  // Validate swizzle usage on the fusion schedule.
  validateSwizzle(fusion_);

  // Compute thread predicates. Depends on parallel_dimension_map_
  thread_pred_map_.build(fusion_);

  // Fuse certain patterns of reductions, such as a grid reduction
  // followed by a grid broadcast. Only depends on parallelization and
  // thread predicate map.
  fuseReductionsAndBroadcasts(fusion_);

  // Scan the whole fusion and build mappings about halo extensions of
  // all IterDomains
  halo_info_ = std::make_shared<HaloInfo>(fusion_, compute_at_map_);

  // Want to run this after parallel map and halo info map are
  // created. vectorized_accesses_ and vectorized_set_info_ are filled.
  validateAndCollectVectorizeInfo(fusion_);

  // Depends on ComputeAtMap and HaloInfo.
  validateAndConvertIterDomainGrouping(fusion_);

  // Assumes all grouped reductions are converted to
  // GroupedReductionOp, which is done by
  // validateAndConvertIterDomainGrouping
  validateGroupedReductions(fusion_);

  // Depends on thread_pred_map_. Validates parallelization and collects which
  // tensor views need WAR or RAW syncs.
  sync_map_.build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) {
    std::cout << sync_map_.toString() << std::endl;
  }

  partialSplitMap().build(fusion_);

  validatePartialSplit(fusion_);

  nonDivisibleSplitInfo().build(fusion_);

  // Detects all expressions that don't need predicates. Depends on
  // nonDivisibleSplitInfo.
  predicateElimination().build(fusion_);

  doubleBufferInfo().build(fusion_);

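  // Allocate a loop index variable for each loop group of the compute at map.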
  compute_at_map_->allocateIndexVariables();
  // Run our passes keeping the lowered expressions and forwarding
  // them

  // Reorder expressions for loop-nest generation respecting computeAt
  // relationships
  const auto exprs_sorted = reorderExprsForComputeAt();

  // Generate loop-nests and place each expression at its
  // corresponding loop
  const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted);

  // Replace trivial reductions, Transpose, Shift, Gather, and View ops with
  // unary ops since they're not separately processed in lowering.
  const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered);

  // Insert allocations
  const auto exprs_alloced = insertAllocations(exprs_unary_replaced);

  // Insert read after write smem syncs
  const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced);

  // Reuse memory locations
  const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync);

  // Insert SyncThreads at end of for-loop to avoid WAR race condition
  const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem);

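  // Apply double buffering to tensors scheduled for it, so loading data for
  // the next iteration can be overlapped with computing the current one.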
  const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync);

  // This pass inserts predicates as well as branches in the code. Up until now
  // the code consists only of straight-line loop nests. Be careful in later
  // passes when doing any kind of insertion into the loop-nest structure, as
  // insertions could be on an if-then-else instead of directly on a for loop.
  const auto exprs_unrolled_loops =
      UnrollPass::runPass(fusion_, exprs_double_buffered);

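  // Lower accesses parallelized with MisalignedVectorize, generating the extra
  // code needed to handle elements that are not aligned to the vector width.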
  const auto exprs_unrolled_mv_loops =
      processMisalignedVectorization(exprs_unrolled_loops);

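  // Generate the index math for every tensor access and replace TensorView
  // operands with concretely indexed accesses.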
  const auto exprs_indexed_loops =
      IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops);

  // TODO: It seems this type of optimization would be far easier to implement
  // on fusion IR than kernel IR. We should likely refactor this to at least run
  // before allocation insertion.
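  // Replace eligible reduction + broadcast patterns with the fused warp
  // reduction primitive (see lower_warp_reduce).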
  const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops);

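  // Turn the predicate placeholders attached to expressions and loops into
  // concrete boolean conditionals.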
  const auto exprs_conditional_loops =
      generateConditionalFromPredicate(exprs_with_fused_broadcast);

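  // Hoist index expressions shared by multiple accesses into scalars that are
  // allocated and computed once instead of being recomputed at each use.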
  const auto exprs_common_index_allocated =
      allocateCommonIndices(exprs_conditional_loops);

  // Insert fake zero updates to make sure nvrtc doesn't blow out register use
  // on index and predicate reuse
  const auto exprs_register_adjusted =
      insertMagicZero(exprs_common_index_allocated);

  const auto exprs_cleaned_up_loops =
      KIRCleaner::cleanUp(exprs_register_adjusted);

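  // Insert kernel instrumentation where requested (see lower_instrument.h).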
  const auto exprs_instrumented = instrumentKernel(exprs_cleaned_up_loops);

  // We now have the lowered expressions; finalize the kernel IR. This function
  // will also copy over some relevant information for code generation from
  // GpuLower.
  kernel_->finalize(exprs_instrumented);
}

kir::Kernel* GpuLower::kernel() const {
  TORCH_CHECK(kernel_);
  return kernel_.get();
}

GpuLower* GpuLower::current() {
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower != nullptr, "No active GpuLower available");
  return active_gpu_lower;
}

bool GpuLower::hasCurrent() {
  return active_gpu_lower != nullptr;
}

void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
  pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch