1 | #include <ATen/cuda/CUDAContext.h> |
2 | #include <expr_evaluator.h> |
3 | #include <kernel_expr_evaluator.h> |
4 | #include <kernel_ir_dispatch.h> |
5 | #include <lower2device.h> |
6 | #include <lower_utils.h> |
7 | #include <lower_warp_reduce.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | namespace { |
15 | |
//! A helper class for EliminateDeadBroadcastAndAllocate. Removes the
//! dead Allocate and BroadcastOp exprs detected by
//! EliminateDeadBroadcastAndAllocate.
18 | class DeadTvEliminator : private kir::ExprMutator { |
19 | public: |
20 | static std::vector<Expr*> run( |
21 | const std::vector<Expr*>& exprs, |
22 | const std::unordered_set<TensorView*>& dead_tvs) { |
23 | return DeadTvEliminator(exprs, dead_tvs).exprs_; |
24 | } |
25 | |
26 | private: |
27 | DeadTvEliminator( |
28 | const std::vector<Expr*>& exprs, |
29 | const std::unordered_set<TensorView*>& dead_tvs) |
30 | : dead_tvs_(dead_tvs) { |
31 | traverseAndInsert(exprs); |
32 | } |
33 | |
34 | using kir::ExprMutator::handle; |
35 | |
36 | void handle(kir::Allocate* allocate) final { |
37 | if (auto buffer_tv = dynamic_cast<TensorView*>(allocate->buffer())) { |
38 | if (dead_tvs_.count(buffer_tv)) { |
39 | registerRemove(allocate); |
40 | } |
41 | } |
42 | } |
43 | |
44 | void handle(BroadcastOp* broadcast) final { |
45 | if (auto out_ti = dynamic_cast<kir::TensorIndex*>(broadcast->out())) { |
46 | if (dead_tvs_.count(out_ti->view())) { |
47 | registerRemove(broadcast); |
48 | } |
49 | } |
50 | } |
51 | |
52 | private: |
53 | const std::unordered_set<TensorView*>& dead_tvs_; |
54 | }; |
55 | |
//! A simple DCE pass for eliminating the parallel broadcasts that have
//! been fused into a warp reduction, along with their corresponding
//! allocations.
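//!
//! As an illustrative sketch (not actual kernel IR): once consumers of a
//! fused broadcast output T2 have been rewired to read the reduction
//! output T1 directly, the remaining
//!
//!   allocate T2
//!   T2 = block_broadcast(T1)
//!
//! pair has no uses left, so both the BroadcastOp and its Allocate are
//! removed here.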
59 | class EliminateDeadBroadcastAndAllocate { |
60 | public: |
61 | static std::vector<Expr*> run(const std::vector<Expr*>& exprs) { |
62 | EliminateDeadBroadcastAndAllocate dce(exprs); |
63 | return DeadTvEliminator::run(exprs, dce.dead_tvs_); |
64 | } |
65 | |
66 | private: |
67 | EliminateDeadBroadcastAndAllocate(const std::vector<Expr*>& exprs) { |
68 | findLiveTvs(exprs); |
69 | findDeadTvs(); |
70 | } |
71 | |
72 | void findLiveTvs(const std::vector<Expr*>& exprs) { |
73 | for (auto expr : exprs) { |
74 | if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
75 | findLiveTvs(for_loop->body().exprs()); |
76 | continue; |
77 | } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) { |
78 | findLiveTvs(ite->thenBody().exprs()); |
79 | findLiveTvs(ite->elseBody().exprs()); |
80 | continue; |
81 | } |
82 | |
83 | if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) { |
84 | if (allocate->memoryType() == MemoryType::Local) { |
85 | if (auto tv = dynamic_cast<TensorView*>(allocate->buffer())) { |
            // The only tvs we need to consider here are broadcast outputs
87 | if (tv->definition()->isA<BroadcastOp>()) { |
88 | candidate_tv_set_.insert(tv); |
89 | } |
90 | } |
91 | } |
92 | } |
93 | |
94 | for (auto inp : expr->inputs()) { |
95 | if (auto ti = dynamic_cast<kir::TensorIndex*>(inp)) { |
96 | if (candidate_tv_set_.count(ti->view())) { |
97 | live_tvs_.insert(ti->view()); |
98 | } |
99 | } |
100 | } |
101 | } |
102 | } |
103 | |
104 | void findDeadTvs() { |
105 | for (auto tv : candidate_tv_set_) { |
106 | if (!live_tvs_.count(tv)) { |
107 | dead_tvs_.insert(tv); |
108 | } |
109 | } |
110 | } |
111 | |
112 | private: |
113 | std::unordered_set<TensorView*> live_tvs_; |
114 | std::unordered_set<TensorView*> dead_tvs_; |
115 | std::unordered_set<TensorView*> candidate_tv_set_; |
116 | }; |
117 | |
//! A pass to eliminate redundant parallel broadcasts that are consumers
//! of warp reductions.
//! Detects the following pattern:
//!
//!   For ... (serial)
//!     For ... (serial)
//!       T1[0] = warp_reduce (T0[0])
//!       T2[0] = block_broadcast (T1[0])
//!
//! The block_broadcast can then be eliminated, given that both the warp
//! reduce and the broadcast are known at compile time to be parallelized
//! on a single warp only.
//!
//! Currently limited to size-1 buffers to avoid having to re-run
//! indexing.
//!
//! This pass operates in 3 phases:
//!  1. FuseBroadcastWithWarpReduce identifies the broadcasts that can
//!     be removed, and generates a replacement map from the broadcast
//!     output to the reduction output.
//!
//!  2. ir_utils::replaceInputsInExpr replaces applicable uses of
//!     the broadcast output with the corresponding reduction output.
//!
//!  3. EliminateDeadBroadcastAndAllocate removes the broadcast ops
//!     and corresponding allocations if they're unused after step 2.
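//!
//! As a rough before/after sketch (illustrative only, not actual
//! generated code):
//!
//!   // before
//!   T1[0] = warp_reduce(T0[0]);
//!   T2[0] = block_broadcast(T1[0]);
//!   T3[...] = T2[0] + ...;
//!
//!   // after replacement (step 2) and DCE (step 3)
//!   T1[0] = warp_reduce(T0[0]);
//!   T3[...] = T1[0] + ...;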
144 | class FuseBroadcastWithWarpReduce : private kir::IrVisitor { |
145 | public: |
146 | static std::vector<Expr*> fuse(const std::vector<Expr*>& exprs) { |
147 | FuseBroadcastWithWarpReduce fuse_broadcast_map(exprs); |
148 | const auto replaced_inputs = ir_utils::replaceInputsInExpr( |
149 | exprs, fuse_broadcast_map.val_replacement_map_); |
150 | return EliminateDeadBroadcastAndAllocate::run(replaced_inputs); |
151 | } |
152 | |
153 | private: |
154 | FuseBroadcastWithWarpReduce(const std::vector<Expr*>& exprs) { |
    // Open a stack frame for the global scope.
    // The scope stack for tv_to_allocate wouldn't be needed if each
    // allocation were guaranteed to be unique and appear exactly once;
    // that can currently be assumed, but this pass tries not to rely on
    // it.
160 | running_tv_to_allocate_map_.emplace_back( |
161 | std::make_unique<std::unordered_map<TensorView*, kir::Allocate*>>()); |
162 | running_visible_allocation_stack_.emplace_back( |
163 | std::make_unique<std::vector<kir::Allocate*>>()); |
164 | kir::IrVisitor::handle(exprs); |
165 | } |
166 | |
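  //! For tensor ops, any input TensorIndex whose tv has a recorded
  //! replacement is added to the final val replacement map before
  //! dispatching to the base visitor.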
167 | void handle(Expr* expr) final { |
168 | if (ir_utils::isTvOp(expr)) { |
      // Record replacements for this expr's inputs if needed
170 | for (auto inp : expr->inputs()) { |
171 | if (auto input_ti = dynamic_cast<kir::TensorIndex*>(inp)) { |
172 | auto replace = findMaybeReplacedTensorIndex(input_ti); |
173 | if (replace.has_value()) { |
174 | val_replacement_map_[input_ti] = replace.value(); |
175 | } |
176 | } |
177 | } |
178 | } |
179 | kir::IrVisitor::handle(expr); |
180 | } |
181 | |
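  //! Returns true if a loop over the given IterDomain opens a new
  //! serially-visible scope for tracking allocations. Thread-parallel
  //! and unswitched loops, as well as serial or unrolled broadcast
  //! loops, do not open a new scope since they do not correspond to a
  //! textual scope in the generated kernel.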
182 | bool openLoopNestLevel(IterDomain* id) { |
183 | if (id->isThread() || id->getParallelType() == ParallelType::Unswitch) { |
184 | return false; |
185 | } |
186 | if (id->getParallelType() == ParallelType::Serial || |
187 | id->getParallelType() == ParallelType::Unroll) { |
188 | return !id->isBroadcast(); |
189 | } |
190 | return true; |
191 | } |
192 | |
193 | void handle(kir::ForLoop* for_loop) final { |
194 | // Keep track of visible reduction outputs |
195 | bool open_nest_level = openLoopNestLevel(for_loop->iter_domain()); |
196 | if (open_nest_level) { |
197 | running_tv_to_allocate_map_.emplace_back( |
198 | std::make_unique<std::unordered_map<TensorView*, kir::Allocate*>>()); |
199 | running_visible_allocation_stack_.emplace_back( |
200 | std::make_unique<std::vector<kir::Allocate*>>()); |
201 | } |
202 | for (auto expr : for_loop->body().exprs()) { |
203 | handle(expr); |
204 | } |
205 | if (open_nest_level) { |
206 | running_tv_to_allocate_map_.pop_back(); |
207 | running_visible_allocation_stack_.pop_back(); |
208 | } |
209 | } |
210 | |
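  //! Each branch of an IfThenElse gets its own visible-allocation
  //! scope; allocations made in one branch are not visible in the
  //! other.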
211 | void handle(kir::IfThenElse* ite) final { |
212 | running_visible_allocation_stack_.emplace_back( |
213 | std::make_unique<std::vector<kir::Allocate*>>()); |
214 | for (auto expr : ite->thenBody().exprs()) { |
215 | handle(expr); |
216 | } |
217 | running_visible_allocation_stack_.pop_back(); |
218 | running_visible_allocation_stack_.emplace_back( |
219 | std::make_unique<std::vector<kir::Allocate*>>()); |
220 | for (auto expr : ite->elseBody().exprs()) { |
221 | handle(expr); |
222 | } |
223 | running_visible_allocation_stack_.pop_back(); |
224 | } |
225 | |
226 | //! Place this allocate on the list of currently visible allocations, |
227 | //! organized by loop nest level. |
228 | void handle(kir::Allocate* allocate) final { |
229 | if (allocate->memoryType() != MemoryType::Local) { |
230 | return; |
231 | } |
232 | if (auto tv = dynamic_cast<TensorView*>(allocate->buffer())) { |
233 | if (tv->definition()) { |
234 | if (tv->definition()->isA<ReductionOp>() || |
235 | tv->definition()->isA<BroadcastOp>()) { |
236 | running_visible_allocation_stack_.back()->push_back(allocate); |
237 | } |
238 | } |
239 | } |
240 | } |
241 | |
  //! Checks if the tv of the given TensorIndex has been replaced by
  //! broadcast fusion. Returns the replacement TensorIndex if so.
244 | c10::optional<kir::TensorIndex*> findMaybeReplacedTensorIndex( |
245 | kir::TensorIndex* tensor_index) { |
246 | auto tv = tensor_index->view(); |
247 | auto tensor_index_it = running_tv_replacement_map_.find(tv); |
248 | if (tensor_index_it != running_tv_replacement_map_.end()) { |
249 | return tensor_index_it->second; |
250 | } |
251 | return c10::nullopt; |
252 | } |
253 | |
254 | //! Iterate backwards on the currently visible loop scopes |
255 | //! and find the first allocation corresponding to the |
256 | //! given tv. |
257 | kir::Allocate* getActiveAllocateFor(TensorView* tv) { |
258 | for (auto frame_it = running_visible_allocation_stack_.rbegin(); |
259 | frame_it != running_visible_allocation_stack_.rend(); |
260 | frame_it++) { |
261 | for (auto allocate_it = (*frame_it)->rbegin(); |
262 | allocate_it != (*frame_it)->rend(); |
263 | allocate_it++) { |
264 | auto candidate_allocate = *allocate_it; |
265 | if (candidate_allocate->buffer() == tv) { |
266 | return candidate_allocate; |
267 | } |
268 | } |
269 | } |
    TORCH_INTERNAL_ASSERT(
        false, "lower_warp_reduce: cannot find allocation for this op");
272 | return nullptr; |
273 | } |
274 | |
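  //! Returns true if every TensorIndex input of the expr resides in
  //! registers (local memory).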
275 | bool isOpInputRegisterTV(Expr* expr) { |
276 | for (auto inp : expr->inputs()) { |
277 | if (auto inp_ti = dynamic_cast<kir::TensorIndex*>(inp)) { |
278 | if (inp_ti->view()->getMemoryType() != MemoryType::Local) { |
279 | return false; |
280 | } |
281 | } |
282 | } |
283 | |
284 | return true; |
285 | } |
286 | |
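  //! Returns true if every TensorIndex output of the expr resides in
  //! registers (local memory).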
287 | bool isOpOutputRegisterTV(Expr* expr) { |
288 | for (auto out : expr->outputs()) { |
289 | if (auto out_ti = dynamic_cast<kir::TensorIndex*>(out)) { |
290 | if (out_ti->view()->getMemoryType() != MemoryType::Local) { |
291 | return false; |
292 | } |
293 | } |
294 | } |
295 | |
296 | return true; |
297 | } |
298 | |
299 | //! Updates map of serially visible reduction tvs, see comment on |
300 | //! running_tv_to_allocate_map_. |
301 | void handle(ReductionOp* reduction) final { |
302 | if (!isOpOutputRegisterTV(reduction)) { |
303 | return; |
304 | } |
305 | auto reduction_ti_out = dynamic_cast<kir::TensorIndex*>(reduction->out()); |
    TORCH_INTERNAL_ASSERT(
        reduction_ti_out,
        "lower_warp_reduce: Pass needs to be run after indexing");
309 | |
310 | // keep track of which reduction buffer this expr writes into |
311 | auto reduction_allocate = getActiveAllocateFor(reduction_ti_out->view()); |
312 | running_tv_to_allocate_map_.back()->operator[](reduction_ti_out->view()) = |
313 | reduction_allocate; |
314 | } |
315 | |
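  //! Broadcasts whose inputs and outputs are all in registers are
  //! candidates for fusion with a producer warp reduction.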
316 | void handle(BroadcastOp* broadcast) final { |
317 | if (!isOpInputRegisterTV(broadcast) || !isOpOutputRegisterTV(broadcast)) { |
318 | return; |
319 | } |
320 | tryAddOutputToReplaceMap(broadcast); |
321 | } |
322 | |
  //! Detects if this broadcast can be fused with its producer reduction,
  //! and adds the broadcast output to the replacement map if all of the
  //! above-mentioned conditions hold.
326 | void tryAddOutputToReplaceMap(BroadcastOp* broadcast) { |
327 | if (auto in_ti = dynamic_cast<kir::TensorIndex*>(broadcast->in())) { |
328 | if (!in_ti->view()->definition()->isA<ReductionOp>()) { |
329 | return; |
330 | } |
331 | auto out_ti = broadcast->out()->as<kir::TensorIndex>(); |
332 | auto out_tv = out_ti->view(); |
333 | |
334 | // check reduction-broadcast mapping: |
335 | if (!canFuseBroadcastWithWarpReduction( |
336 | out_tv->definition()->as<BroadcastOp>())) { |
337 | return; |
338 | } |
339 | |
340 | // check buffers are size-1 |
341 | auto reduction_allocate_it = |
342 | running_tv_to_allocate_map_.back()->find(in_ti->view()); |
343 | if (reduction_allocate_it == running_tv_to_allocate_map_.back()->end()) { |
344 | // The producer reduction is not in the serially visible scope, |
345 | // as defined in openLoopNestLevel. There still could be some |
346 | // cases that we could fuse but disabled for simplicity. |
347 | return; |
348 | } |
349 | |
350 | kir::ExpressionEvaluator ee; |
351 | |
      // Cannot replace if either the reduction buffer or the broadcast
      // buffer does not have a size of 1, since that would require
      // re-indexing.
355 | auto reduction_allocation_size = |
356 | ee.evaluate(reduction_allocate_it->second->size()); |
357 | if (!reduction_allocation_size.has_value() || |
358 | reduction_allocation_size.value() != 1) { |
359 | return; |
360 | } |
361 | |
362 | auto broadcast_allocate = getActiveAllocateFor(out_tv); |
363 | auto broadcast_allocation_size = ee.evaluate(broadcast_allocate->size()); |
364 | if (!broadcast_allocation_size.has_value() || |
365 | broadcast_allocation_size.value() != 1) { |
366 | return; |
367 | } |
368 | |
      // Record the tv in the replacement map so that future uses of it
      // will have their TensorIndex's put into the actual replacement
      // map.
372 | running_tv_replacement_map_[out_tv] = in_ti; |
373 | } |
374 | } |
375 | |
  int warp_size = at::cuda::warp_size();

  // Checks if the given IterDomain is mapped to a single warp,
  // i.e. it is known at compile time to have a constant extent equal to
  // warp_size and is parallelized on TIDx.
380 | bool isSingleWarp(IterDomain* id) { |
381 | if (id->getParallelType() != ParallelType::TIDx) { |
382 | return false; |
383 | } |
384 | |
385 | if (!GpuLower::current()->getWarpPaddedParallelInfo().is_tidx_single_warp) { |
386 | return false; |
387 | } |
388 | |
389 | // Prioritize checking for padded dimension |
390 | if (id->getMaybeSizeAfterPadding().has_value()) { |
391 | return id->getMaybeSizeAfterPadding().value() == warp_size; |
392 | } |
393 | |
394 | if (id->extent()->isConstScalar()) { |
395 | ExpressionEvaluator evaluator(FusionGuard::getCurFusion()); |
396 | return evaluator.evaluate(id->extent()).value() == warp_size; |
397 | } |
398 | |
399 | return false; |
400 | } |
401 | |
402 | // Check if this broadcast can be fused with the producer reduction |
403 | // Assumes: |
404 | // 1. Already checked the producer of input is a reduction |
405 | // 2. Already checked the producer reduction is in the same loop nest |
406 | // Checks: |
407 | // 1. Reduction is only non-trivially parallel on TIDx as a single warp |
408 | // 2. Broadcast is only non-trivially parallel on TIDx as a single warp |
409 | bool canFuseBroadcastWithWarpReduction(BroadcastOp* broadcast) { |
410 | auto reduction_out_tv = broadcast->in()->as<TensorView>(); |
411 | auto broadcast_out_tv = broadcast->out()->as<TensorView>(); |
412 | |
413 | bool reduction_has_single_warp = false, broadcast_has_single_warp = false; |
414 | |
415 | for (auto id : reduction_out_tv->domain()->domain()) { |
416 | if (id->isReduction() && id->isThread() && !id->isTrivialReduction() && |
417 | !isSingleWarp(id)) { |
418 | return false; |
419 | } |
420 | if (id->isReduction() && isSingleWarp(id)) { |
421 | reduction_has_single_warp = true; |
422 | } |
423 | } |
424 | for (auto id : broadcast_out_tv->domain()->domain()) { |
425 | if (id->isBroadcast() && id->isThread() && !isSingleWarp(id)) { |
426 | return false; |
427 | } |
428 | if (id->isBroadcast() && isSingleWarp(id)) { |
429 | broadcast_has_single_warp = true; |
430 | } |
431 | } |
432 | return reduction_has_single_warp && broadcast_has_single_warp; |
433 | } |
434 | |
435 | private: |
  //! A naive record of kir tv's that will need replacement at each expr.
  //! Could need some extension for more precise scope-based analysis in
  //! the future, especially if we ever have more complex IfThenElse
  //! blocks than predicates and unrolls.
440 | std::unordered_map<TensorView*, kir::TensorIndex*> |
441 | running_tv_replacement_map_; |
442 | |
  //! Keeps track of the allocated buffers that the exprs will write/read
  //! at each expr. Each outer vector element records the allocations at
  //! one running scope level as this pass iterates through the loop nest.
446 | std::vector<std::unique_ptr<std::vector<kir::Allocate*>>> |
447 | running_visible_allocation_stack_; |
448 | |
  //! A different version of running_visible_allocation_stack_ constructed
  //! for convenience. The difference is that thread loops, serial
  //! broadcast loops, and IfThenElse's are not modeled as additional
  //! scopes, in order to model textual visibility in the generated
  //! kernel. The handling of IfThenElse assumes the only ITE's we have
  //! are predicates and unrolls, which might need to be made more
  //! precise.
456 | std::vector<std::unique_ptr<std::unordered_map<TensorView*, kir::Allocate*>>> |
457 | running_tv_to_allocate_map_; |
458 | |
  //! This map is the final output of this pass; a val replacement pass is
  //! run using it. All keys and values are TensorIndex's, and before this
  //! pass each TensorIndex is uniquely generated by the lower_index pass
  //! for each access of a tv.
464 | std::unordered_map<Val*, Val*> val_replacement_map_; |
465 | }; |
466 | |
467 | } // namespace |
468 | |
469 | std::vector<Expr*> fuseWarpReduce(const std::vector<Expr*> exprs) { |
470 | return FuseBroadcastWithWarpReduce::fuse(exprs); |
471 | } |
472 | |
473 | } // namespace cuda |
474 | } // namespace fuser |
475 | } // namespace jit |
476 | } // namespace torch |
477 | |