#include <ATen/cuda/CUDAContext.h>
#include <expr_evaluator.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <lower_warp_reduce.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! A helper class for EliminateDeadBroadcastAndAllocate. Removes the
//!  dead Allocate and BroadcastOp exprs detected by
//!  EliminateDeadBroadcastAndAllocate.
class DeadTvEliminator : private kir::ExprMutator {
 public:
  static std::vector<Expr*> run(
      const std::vector<Expr*>& exprs,
      const std::unordered_set<TensorView*>& dead_tvs) {
    return DeadTvEliminator(exprs, dead_tvs).exprs_;
  }

 private:
  DeadTvEliminator(
      const std::vector<Expr*>& exprs,
      const std::unordered_set<TensorView*>& dead_tvs)
      : dead_tvs_(dead_tvs) {
    traverseAndInsert(exprs);
  }

  using kir::ExprMutator::handle;

  void handle(kir::Allocate* allocate) final {
    if (auto buffer_tv = dynamic_cast<TensorView*>(allocate->buffer())) {
      if (dead_tvs_.count(buffer_tv)) {
        registerRemove(allocate);
      }
    }
  }

  void handle(BroadcastOp* broadcast) final {
    if (auto out_ti = dynamic_cast<kir::TensorIndex*>(broadcast->out())) {
      if (dead_tvs_.count(out_ti->view())) {
        registerRemove(broadcast);
      }
    }
  }

 private:
  const std::unordered_set<TensorView*>& dead_tvs_;
};

//! A simple DCE pass that eliminates parallel broadcasts that have been
//!  fused into warp reductions, along with their corresponding
//!  allocations.
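//!
//! Illustrative sketch only (not generated output), assuming the
//!  replacement step has already rewritten every read of T2 to read T1:
//!
//!    T2 = allocate(...)              // now dead
//!    T2[0] = block_broadcast(T1[0])  // now dead
//!    T3[0] = op(T1[0], ...)          // reads T1 directly after replacement
//!
//!  Both the Allocate and the BroadcastOp producing T2 are removed.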
class EliminateDeadBroadcastAndAllocate {
 public:
  static std::vector<Expr*> run(const std::vector<Expr*>& exprs) {
    EliminateDeadBroadcastAndAllocate dce(exprs);
    return DeadTvEliminator::run(exprs, dce.dead_tvs_);
  }

 private:
  EliminateDeadBroadcastAndAllocate(const std::vector<Expr*>& exprs) {
    findLiveTvs(exprs);
    findDeadTvs();
  }

  void findLiveTvs(const std::vector<Expr*>& exprs) {
    for (auto expr : exprs) {
      if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
        findLiveTvs(for_loop->body().exprs());
        continue;
      } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
        findLiveTvs(ite->thenBody().exprs());
        findLiveTvs(ite->elseBody().exprs());
        continue;
      }

      if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
        if (allocate->memoryType() == MemoryType::Local) {
          if (auto tv = dynamic_cast<TensorView*>(allocate->buffer())) {
            // The only tvs we need to consider are broadcast outputs
            if (tv->definition() && tv->definition()->isA<BroadcastOp>()) {
              candidate_tv_set_.insert(tv);
            }
          }
        }
      }

      for (auto inp : expr->inputs()) {
        if (auto ti = dynamic_cast<kir::TensorIndex*>(inp)) {
          if (candidate_tv_set_.count(ti->view())) {
            live_tvs_.insert(ti->view());
          }
        }
      }
    }
  }

  void findDeadTvs() {
    for (auto tv : candidate_tv_set_) {
      if (!live_tvs_.count(tv)) {
        dead_tvs_.insert(tv);
      }
    }
  }

 private:
  std::unordered_set<TensorView*> live_tvs_;
  std::unordered_set<TensorView*> dead_tvs_;
  std::unordered_set<TensorView*> candidate_tv_set_;
};

//! A pass to eliminate redundant parallel broadcasts that are consumers
//!  of warp reductions.
//! Detects the following pattern:
//!
//!  For ... (serial)
//!   For ... (serial)
//!    T1[0] = warp_reduce (T0[0])
//!    T2[0] = block_broadcast (T1[0])
//!
//!  The block_broadcast can be eliminated because both the warp reduction
//!  and the broadcast are known at compile time to be parallelized over a
//!  single warp only.
//!
//!  Currently limited to size-1 buffers to avoid having to re-run
//!  indexing.
//!
//! This pass operates in 3 phases:
//!  1. FuseBroadcastWithWarpReduce identifies the broadcasts that can
//!      be removed, and generates a replacement map from the broadcast
//!      outputs to the reduction outputs.
//!
//!  2. ir_utils::replaceInputsInExpr replaces applicable uses of
//!      the broadcast outputs with the corresponding reduction outputs.
//!
//!  3. EliminateDeadBroadcastAndAllocate removes the broadcast ops
//!      and the corresponding allocations if they are unused after step 2.
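//!
//! Illustrative sketch only (not generated output): after the three
//!  phases, the pattern above roughly becomes
//!
//!  For ... (serial)
//!   For ... (serial)
//!    T1[0] = warp_reduce (T0[0])
//!
//!  with every former consumer of T2[0] reading T1[0] instead, and the
//!  allocation of T2 removed.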
class FuseBroadcastWithWarpReduce : private kir::IrVisitor {
 public:
  static std::vector<Expr*> fuse(const std::vector<Expr*>& exprs) {
    FuseBroadcastWithWarpReduce fuse_broadcast_map(exprs);
    const auto replaced_inputs = ir_utils::replaceInputsInExpr(
        exprs, fuse_broadcast_map.val_replacement_map_);
    return EliminateDeadBroadcastAndAllocate::run(replaced_inputs);
  }

 private:
  FuseBroadcastWithWarpReduce(const std::vector<Expr*>& exprs) {
    // Open a stack frame for the global scope.
    // The scope stack for running_tv_to_allocate_map_ would not be needed
    //  if each allocation were guaranteed to be unique and appear only
    //  once, which can currently be assumed, but this pass tries not to
    //  rely on that assumption.
    running_tv_to_allocate_map_.emplace_back(
        std::make_unique<std::unordered_map<TensorView*, kir::Allocate*>>());
    running_visible_allocation_stack_.emplace_back(
        std::make_unique<std::vector<kir::Allocate*>>());
    kir::IrVisitor::handle(exprs);
  }

  void handle(Expr* expr) final {
    if (ir_utils::isTvOp(expr)) {
      // Replace expr inputs for which a replacement has been registered
      for (auto inp : expr->inputs()) {
        if (auto input_ti = dynamic_cast<kir::TensorIndex*>(inp)) {
          auto replace = findMaybeReplacedTensorIndex(input_ti);
          if (replace.has_value()) {
            val_replacement_map_[input_ti] = replace.value();
          }
        }
      }
    }
    kir::IrVisitor::handle(expr);
  }

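  //! Returns true if the loop over the given iter domain opens a new
  //!  "serially visible" scope for this pass, i.e. a new frame is pushed
  //!  onto running_tv_to_allocate_map_ and
  //!  running_visible_allocation_stack_. Thread-parallel and unswitched
  //!  loops do not open a new scope; serial and unrolled loops do, unless
  //!  they iterate over a broadcast domain; all other loop types do.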
  bool openLoopNestLevel(IterDomain* id) {
    if (id->isThread() || id->getParallelType() == ParallelType::Unswitch) {
      return false;
    }
    if (id->getParallelType() == ParallelType::Serial ||
        id->getParallelType() == ParallelType::Unroll) {
      return !id->isBroadcast();
    }
    return true;
  }

  void handle(kir::ForLoop* for_loop) final {
    // Keep track of visible reduction outputs
    bool open_nest_level = openLoopNestLevel(for_loop->iter_domain());
    if (open_nest_level) {
      running_tv_to_allocate_map_.emplace_back(
          std::make_unique<std::unordered_map<TensorView*, kir::Allocate*>>());
      running_visible_allocation_stack_.emplace_back(
          std::make_unique<std::vector<kir::Allocate*>>());
    }
    for (auto expr : for_loop->body().exprs()) {
      handle(expr);
    }
    if (open_nest_level) {
      running_tv_to_allocate_map_.pop_back();
      running_visible_allocation_stack_.pop_back();
    }
  }

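  //! Each branch of an IfThenElse gets its own visible-allocation frame.
  //!  No frame is pushed onto running_tv_to_allocate_map_, since the only
  //!  IfThenElse's expected here are predicates and unroll guards (see
  //!  the comment on running_tv_to_allocate_map_).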
  void handle(kir::IfThenElse* ite) final {
    running_visible_allocation_stack_.emplace_back(
        std::make_unique<std::vector<kir::Allocate*>>());
    for (auto expr : ite->thenBody().exprs()) {
      handle(expr);
    }
    running_visible_allocation_stack_.pop_back();
    running_visible_allocation_stack_.emplace_back(
        std::make_unique<std::vector<kir::Allocate*>>());
    for (auto expr : ite->elseBody().exprs()) {
      handle(expr);
    }
    running_visible_allocation_stack_.pop_back();
  }

  //! Place this allocate on the list of currently visible allocations,
  //!  organized by loop nest level.
  void handle(kir::Allocate* allocate) final {
    if (allocate->memoryType() != MemoryType::Local) {
      return;
    }
    if (auto tv = dynamic_cast<TensorView*>(allocate->buffer())) {
      if (tv->definition()) {
        if (tv->definition()->isA<ReductionOp>() ||
            tv->definition()->isA<BroadcastOp>()) {
          running_visible_allocation_stack_.back()->push_back(allocate);
        }
      }
    }
  }

  //! Checks if the given tv has been replaced by broadcast fusion.
  //!  Returns the replacement TensorIndex if so.
  c10::optional<kir::TensorIndex*> findMaybeReplacedTensorIndex(
      kir::TensorIndex* tensor_index) {
    auto tv = tensor_index->view();
    auto tensor_index_it = running_tv_replacement_map_.find(tv);
    if (tensor_index_it != running_tv_replacement_map_.end()) {
      return tensor_index_it->second;
    }
    return c10::nullopt;
  }

  //! Iterates backwards over the currently visible loop scopes and
  //!  returns the first allocation corresponding to the given tv.
  kir::Allocate* getActiveAllocateFor(TensorView* tv) {
    for (auto frame_it = running_visible_allocation_stack_.rbegin();
         frame_it != running_visible_allocation_stack_.rend();
         frame_it++) {
      for (auto allocate_it = (*frame_it)->rbegin();
           allocate_it != (*frame_it)->rend();
           allocate_it++) {
        auto candidate_allocate = *allocate_it;
        if (candidate_allocate->buffer() == tv) {
          return candidate_allocate;
        }
      }
    }
    TORCH_INTERNAL_ASSERT(
        false, "lower_warp_reduce: cannot find allocation for this op");
    return nullptr;
  }

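  //! Returns true if every TensorIndex input of the expr is backed by a
  //!  register (local memory) buffer.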
  bool isOpInputRegisterTV(Expr* expr) {
    for (auto inp : expr->inputs()) {
      if (auto inp_ti = dynamic_cast<kir::TensorIndex*>(inp)) {
        if (inp_ti->view()->getMemoryType() != MemoryType::Local) {
          return false;
        }
      }
    }

    return true;
  }

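  //! Returns true if every TensorIndex output of the expr is backed by a
  //!  register (local memory) buffer.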
  bool isOpOutputRegisterTV(Expr* expr) {
    for (auto out : expr->outputs()) {
      if (auto out_ti = dynamic_cast<kir::TensorIndex*>(out)) {
        if (out_ti->view()->getMemoryType() != MemoryType::Local) {
          return false;
        }
      }
    }

    return true;
  }

  //! Updates the map of serially visible reduction tvs; see the comment
  //!  on running_tv_to_allocate_map_.
  void handle(ReductionOp* reduction) final {
    if (!isOpOutputRegisterTV(reduction)) {
      return;
    }
    auto reduction_ti_out = dynamic_cast<kir::TensorIndex*>(reduction->out());
    TORCH_INTERNAL_ASSERT(
        reduction_ti_out,
        "lower_warp_reduce: Pass needs to be run after indexing");

    // Keep track of which reduction buffer this expr writes into
    auto reduction_allocate = getActiveAllocateFor(reduction_ti_out->view());
    running_tv_to_allocate_map_.back()->operator[](reduction_ti_out->view()) =
        reduction_allocate;
  }

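  //! Attempts to fuse this broadcast with its producer warp reduction.
  //!  Only broadcasts that read from and write to registers are
  //!  considered.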
  void handle(BroadcastOp* broadcast) final {
    if (!isOpInputRegisterTV(broadcast) || !isOpOutputRegisterTV(broadcast)) {
      return;
    }
    tryAddOutputToReplaceMap(broadcast);
  }

  //! Detects if this broadcast can be fused with its producer reduction.
  //!  Adds the broadcast output to the replacement map if all of the
  //!  fusion conditions hold.
  void tryAddOutputToReplaceMap(BroadcastOp* broadcast) {
    if (auto in_ti = dynamic_cast<kir::TensorIndex*>(broadcast->in())) {
      if (in_ti->view()->definition() == nullptr ||
          !in_ti->view()->definition()->isA<ReductionOp>()) {
        return;
      }
      auto out_ti = broadcast->out()->as<kir::TensorIndex>();
      auto out_tv = out_ti->view();

      // Check the reduction-broadcast parallel mapping
      if (!canFuseBroadcastWithWarpReduction(
              out_tv->definition()->as<BroadcastOp>())) {
        return;
      }

      // Check that the buffers are size-1
      auto reduction_allocate_it =
          running_tv_to_allocate_map_.back()->find(in_ti->view());
      if (reduction_allocate_it == running_tv_to_allocate_map_.back()->end()) {
        // The producer reduction is not in the serially visible scope,
        //  as defined by openLoopNestLevel. There could still be cases
        //  that we could fuse, but they are disabled for simplicity.
        return;
      }

      kir::ExpressionEvaluator ee;

      // Cannot replace if either the reduction buffer or the broadcast
      //  buffer does not have a size of 1, since that would require
      //  re-indexing.
      auto reduction_allocation_size =
          ee.evaluate(reduction_allocate_it->second->size());
      if (!reduction_allocation_size.has_value() ||
          reduction_allocation_size.value() != 1) {
        return;
      }

      auto broadcast_allocate = getActiveAllocateFor(out_tv);
      auto broadcast_allocation_size = ee.evaluate(broadcast_allocate->size());
      if (!broadcast_allocation_size.has_value() ||
          broadcast_allocation_size.value() != 1) {
        return;
      }

      // Record the tv in the replacement map so that future uses of this
      //  tv register their TensorIndex in the actual replacement map.
      running_tv_replacement_map_[out_tv] = in_ti;
    }
  }

  int warp_size = at::cuda::warp_size();

  // Checks if the given IterDomain is mapped to a single warp,
  //  i.e. it is known at compile time to have a constant extent equal to
  //  warp_size and it is parallelized on TIDx.
  bool isSingleWarp(IterDomain* id) {
    if (id->getParallelType() != ParallelType::TIDx) {
      return false;
    }

    if (!GpuLower::current()->getWarpPaddedParallelInfo().is_tidx_single_warp) {
      return false;
    }

    // Prioritize checking for the padded dimension
    if (id->getMaybeSizeAfterPadding().has_value()) {
      return id->getMaybeSizeAfterPadding().value() == warp_size;
    }

    if (id->extent()->isConstScalar()) {
      ExpressionEvaluator evaluator(FusionGuard::getCurFusion());
      return evaluator.evaluate(id->extent()).value() == warp_size;
    }

    return false;
  }

  // Check if this broadcast can be fused with the producer reduction.
  // Assumes:
  //  1. Already checked the producer of the input is a reduction
  //  2. Already checked the producer reduction is in the same loop nest
  // Checks:
  //  1. Reduction is only non-trivially parallel on TIDx as a single warp
  //  2. Broadcast is only non-trivially parallel on TIDx as a single warp
  bool canFuseBroadcastWithWarpReduction(BroadcastOp* broadcast) {
    auto reduction_out_tv = broadcast->in()->as<TensorView>();
    auto broadcast_out_tv = broadcast->out()->as<TensorView>();

    bool reduction_has_single_warp = false, broadcast_has_single_warp = false;

    for (auto id : reduction_out_tv->domain()->domain()) {
      if (id->isReduction() && id->isThread() && !id->isTrivialReduction() &&
          !isSingleWarp(id)) {
        return false;
      }
      if (id->isReduction() && isSingleWarp(id)) {
        reduction_has_single_warp = true;
      }
    }
    for (auto id : broadcast_out_tv->domain()->domain()) {
      if (id->isBroadcast() && id->isThread() && !isSingleWarp(id)) {
        return false;
      }
      if (id->isBroadcast() && isSingleWarp(id)) {
        broadcast_has_single_warp = true;
      }
    }
    return reduction_has_single_warp && broadcast_has_single_warp;
  }

 private:
  //! A naive record of kir tvs that need replacement at each expr. This
  //!  could need extension for more precise scope-based analysis in the
  //!  future, especially if we have more complex IfThenElse blocks than
  //!  predicates and unrolls.
  std::unordered_map<TensorView*, kir::TensorIndex*>
      running_tv_replacement_map_;

  //! Keeps track of the allocated buffers that the exprs will write/read
  //!  at each expr. Each outer vector element records the allocations at
  //!  one running scope level as this pass iterates through the loop
  //!  nest.
  std::vector<std::unique_ptr<std::vector<kir::Allocate*>>>
      running_visible_allocation_stack_;

  //! A variant of running_visible_allocation_stack_ kept for convenience.
  //!  The difference is that thread loops, serial broadcast loops, and
  //!  IfThenElse's are not modeled as separate scopes, in order to model
  //!  the textual visibility in the generated kernel. The modeling of
  //!  IfThenElse assumes the only ITE's we have are predicates and
  //!  unrolls, which might need to be made more precise.
  std::vector<std::unique_ptr<std::unordered_map<TensorView*, kir::Allocate*>>>
      running_tv_to_allocate_map_;

  //! This map is the final output of this pass; a val replacement is run
  //!  using it. All keys and values are TensorIndex's, and before this
  //!  pass each TensorIndex is uniquely generated by the lower_index pass
  //!  for each access of a tv.
  std::unordered_map<Val*, Val*> val_replacement_map_;
};

} // namespace

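//! Entry point of this lowering pass. It runs on kernel IR exprs after
//!  indexing (assumed: invoked as one of the expression-level passes in
//!  GpuLower; see lower2device.cpp).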
std::vector<Expr*> fuseWarpReduce(const std::vector<Expr*> exprs) {
  return FuseBroadcastWithWarpReduce::fuse(exprs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch