#include <lower_thread_predicate.h>

#include <arith.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_utils.h>

#include <c10/util/irange.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

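// Build the predicate for a single parallel type. A rough sketch of what is
// generated (assuming the usual mapping of parallel types to CUDA index
// variables, e.g. BIDx -> blockIdx.x / gridDim.x):
// - If pt is unused or proven to have extent 1, returns true.
// - If pt is a block dimension in pred_info.limited_types, returns something
//   like "blockIdx.x == gridDim.x - 1", as only the last block holds the
//   result of a grid reduction.
// - Otherwise, returns something like "threadIdx.x == 0".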
Bool* getPredicatePerParallelType(
    ParallelType pt,
    const ThreadPredicateMap::PredicateInfo& pred_info) {
  auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);

  // If pt is not used or is proven to be one, no need to predicate.
  if (pt_dim == nullptr || pt_dim->isOneInt()) {
    return GpuLower::current()->kernel()->trueVal();
  }
  // When BID needs to be predicated, that means it's an output of a grid
  // reduction and only the last block index in that dimension has the right
  // value from the grid reduce.
  if (isParallelTypeBlockDim(pt) && pred_info.limited_types.get(pt)) {
    return SimplifyingIrBuilder::eqExpr(
               NamedScalar::getParallelIndex(pt),
               SimplifyingIrBuilder::subExpr(
                   NamedScalar::getParallelDim(pt),
                   GpuLower::current()->kernel()->oneVal()))
        ->as<Bool>();
  }

  // Otherwise, only thread of index 0 executes the computation
  return SimplifyingIrBuilder::eqExpr(
             NamedScalar::getParallelIndex(pt),
             GpuLower::current()->kernel()->zeroVal())
      ->as<Bool>();
}

} // namespace

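// Returns the conjunction of the per-parallel-type predicates over all of
// pred_info's limited and redundant types. As an illustrative sketch
// (assuming both types are actually used with extent > 1), limited_types =
// {TIDx} and redundant_types = {BIDy} would produce a predicate roughly
// equivalent to "threadIdx.x == 0 && blockIdx.y == 0".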
Bool* ThreadPredicateMap::getPredicateFromPredicateInfo(
    const ThreadPredicateMap::PredicateInfo& pred_info) {
  const auto pred_types = pred_info.limited_types | pred_info.redundant_types;

  if (pred_types.none()) {
    return GpuLower::current()->kernel()->trueVal();
  }

  Bool* pred = nullptr;
  for (const auto pt : pred_types) {
    const auto tp = getPredicatePerParallelType(pt, pred_info);
    pred = SimplifyingIrBuilder::andExpr(pred, tp)->as<Bool>();
  }
  TORCH_INTERNAL_ASSERT(pred != nullptr);

  return pred;
}

namespace {

// Build redundant predicate flags. Will be stored as
// PredicateInfo.redundant_types for the given tensor.
ParallelTypeBitmap avoidRedundantWrites(const TensorView* out_tv) {
  // If the memory type is Local, it's always fine to write into it as
  // it's thread local. If it's Global, it's also fine to let each
  // thread do its own write, unless out_tv is an output of a
  // reduction. Standard reductions (setting gridReduce aside for the sake of
  // this argument) that write directly to global memory buffers accumulate
  // into the global memory buffer. If this is done redundantly, it could lead
  // to incorrect results. Correctness issues here can come from smem
  // aliasing, smem reductions or gmem reductions because the reduction itself
  // performs an update to a value, not just a set. Omitting the redundant
  // writes to gmem or smem is safe and good for performance; this comment
  // just notes that doing so is not always merely a performance optimization,
  // but can also be a correctness requirement.
  //
  // For now this is enabled for shared memory buffers, global memory buffers
  // undergoing a reduction, and global memory buffers that are terminating
  // outputs (i.e., have no uses in the kernel). This could be extended to all
  // global memory buffer transactions, but in the test AdvancedIndexing11
  // there's a case where an intermediate global buffer is set and used to
  // perform a broadcast. At the moment a grid sync is not being inserted
  // here, and it's generally safe since it's just a set. We could enable this
  // more generally for global memory buffers, but would have to insert a sync
  // or a grid broadcast in that example. For now the approach is to only do
  // this on a grid buffer (not undergoing a reduction) if there are no other
  // uses in the kernel.
  //
  // TODO: Revisit if something like AdvancedIndexing11 could be happening at
  // the same time as a global reduction in a way that could produce an
  // incorrect result.
  const bool is_reduction = ir_utils::isReductionOp(out_tv->definition());
  if (!(out_tv->getMemoryType() == MemoryType::Shared ||
        (out_tv->getMemoryType() == MemoryType::Global && is_reduction) ||
        (out_tv->getMemoryType() == MemoryType::Global &&
         out_tv->uses().empty()))) {
    return ParallelTypeBitmap();
  }

  ParallelTypeBitmap pred;
  // Track which TID types are not used to find redundant parallel
  // types. Only TID types are checked if the tensor is on shared
  // memory; otherwise, on global memory, all TID and BID types are checked.
  ParallelTypeBitmap unused_types;
  // Initially all types are conservatively assumed to not be used.
  unused_types = ~unused_types;
  for (auto out_tv_id : out_tv->domain()->domain()) {
    auto pt = out_tv_id->getParallelType();
    if (!isParallelTypeThread(pt)) {
      continue;
    }
    // If the axis is a broadcast domain and is parallelized by TID,
    // it is sufficient to use just one thread since the tensor is on
    // shared memory.
    if ((out_tv->getMemoryType() == MemoryType::Shared &&
         out_tv_id->isBroadcast() && isParallelTypeThreadDim(pt)) ||
        // Protect against the global memory and is_reduction case, as we
        // don't want to predicate grid dimensions; codegen will complain that
        // predication on block dimensions is not allowed in grid reductions.
        // The old grid reduction runtime kernel does not differentiate
        // non-reduction and predicated parallel types, so the sync
        // integer buffer would need to be expanded even for
        // predicated parallel types, which is not what
        // getGridSyncBufferSize does. The right thing here is either:
        // retire the old grid reduction kernel, or update the kernel
        // to properly ignore predicated types. The new kernel is
        // significantly complex and has not been tested, so the
        // latter option seems more reasonable for now. See #1671.
        (!is_reduction && out_tv->getMemoryType() == MemoryType::Global &&
         out_tv_id->isBroadcast() && isParallelTypeThread(pt))) {
      pred.set(pt);
    }
    unused_types.clear(pt);
  }

  const auto& par_dim_map = GpuLower::current()->parallelDimensionMap();

  for (const auto pt : unused_types) {
    // For shared memory tensors, unused BID isn't redundant
    if (isParallelTypeBlockDim(pt) &&
        out_tv->getMemoryType() == MemoryType::Shared) {
      continue;
    }
    // If the pt is not used or is proven to be one, it is not
    // really redundant.
    auto pt_dim = par_dim_map.get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    pred.set(pt);
  }

  return pred;
}

// If tv is an output of a reduction with unused parallel types, those
// unused parallel types need to be predicated if the tensor is on
// global memory.
ParallelTypeBitmap getReductionPredicateForUnusedParallelTypes(
    const TensorView* tv,
    const ThreadPredicateMap::PredicateInfo& pred_info) {
  auto tv_def = tv->definition();
  if (!(tv_def && ir_utils::isReductionOp(tv_def) &&
        tv->getMemoryType() == MemoryType::Global)) {
    return {};
  }

  // Unused types are set as redundant types of tv
  return pred_info.redundant_types;
}

} // namespace

// Update the thread predicate map based on the provided Expr
void ThreadPredicateMap::updateBitSet(const Expr* expr) {
  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap::updateBitSet");

  // If none of the inputs have been updated and all of the outputs
  // already have mappings, don't do anything
  if (std::all_of(
          ir_utils::filterByType<TensorView>(expr->inputs()).begin(),
          ir_utils::filterByType<TensorView>(expr->inputs()).end(),
          [this](TensorView* tv) {
            return updated_tvs_.find(tv) == updated_tvs_.end();
          }) &&
      std::all_of(
          ir_utils::filterByType<TensorView>(expr->outputs()).begin(),
          ir_utils::filterByType<TensorView>(expr->outputs()).end(),
          [this](TensorView* tv) { return find(tv) != end(); })) {
    return;
  }

  // Which predicates were set for the inputs
  ParallelTypeBitmap input_preds;

  // Which dims are reductions in inputs
  ParallelTypeBitmap input_reductions;

  // Run through inputs and update bitsets
  for (const auto* inp : expr->inputs()) {
    if (!ir_utils::isTV(inp))
      continue;

    auto tv_inp = inp->as<TensorView>();

    // If tv_inp was an output of a multi-output expression, just change it to
    // a consistent sibling to use a single predicate name.
    if (auto tv_def = tv_inp->definition()) {
      if (tv_def->outputs().size() > 1) {
        tv_inp = ir_utils::getTvOutput(tv_def);
      }
    }

    TORCH_INTERNAL_ASSERT(
        thread_predicates_.find(tv_inp) != thread_predicates_.end(),
        "Thread predicate map was not initialized, couldn't find ",
        inp);

    const auto& pred_info = at(tv_inp);

    ParallelTypeBitmap id_reductions;
    ParallelTypeBitmap id_bcasts;
    ParallelTypeBitmap id_ptypes;

    for (auto id : tv_inp->domain()->domain()) {
      if (id->isThread()) {
        id_ptypes.set(id->getParallelType());
        if (id->isReduction() &&
            !GpuLower::current()->fusedReductionInfo().isAllreduce(id)) {
          id_reductions.set(id->getParallelType());
        }
        if (id->isBroadcast() &&
            GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
                id)) {
          id_bcasts.set(id->getParallelType());
        }
      }
    }

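    // At this point id_ptypes holds the thread parallel types used by
    // tv_inp's domain, id_reductions those used by (non-allreduce) reduction
    // axes, and id_bcasts those used by concretized broadcast axes.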
    // Validate the combination of ptypes, reductions, bcasts
    for (const auto i : c10::irange(ParallelTypeBitmap::kNumParallelTypes)) {
      if (input_reductions[i]) {
        if (id_ptypes[i]) {
          TORCH_INTERNAL_ASSERT(
              id_reductions[i],
              "Mismatched parallelized reductions found on inputs of expr: ",
              expr);
          TORCH_CHECK(
              !id_bcasts[i],
              "Invalid broadcast and reduction combination, tried to parallelize both with the same thread dim: ",
              inp);
        }
      }
    }

    // Figure out which dims the broadcast wants to reset
    auto this_input_preds = pred_info.limited_types;
    const auto bcast_reset_mask = ~(this_input_preds & id_bcasts);
    this_input_preds &= bcast_reset_mask;

    input_preds |= this_input_preds;

    id_reductions |=
        getReductionPredicateForUnusedParallelTypes(tv_inp, at(tv_inp));

    // Accumulate
    input_reductions |= id_reductions;
  }

  // Compute the output predicates: any parallel reduction types found on the
  // inputs are added to the accumulated input predicates.
  auto output_preds = input_preds | input_reductions;

  // Run through outputs and set bitset predicates
  for (auto* out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
    auto redundant_types = avoidRedundantWrites(out_tv);
    update(out_tv, output_preds, redundant_types);
  }
}

namespace {

//! A simple backward data flow pass:
//! This pass propagates information backward to annotate "redundant use
//! chains".
//! The reason this is needed is that, for example, if we have a chain
//! of register-to-register ops that begins with a redundant shared mem write
//! and ends with an op that non-redundantly uses the result, we'd need to
//! insert a sync at the beginning of the register-to-register chain.
//!
//! The same mechanism also applies in the case of a register/sharedmem chain
//! that starts and ends with global memory read/write.
//!
//! The propagation rule is summarized as follows:
//!
//! Shared TV val:
//!   Reset all block redundant info to its own redundant write info
//!   Backpropagate grid redundant info
//! Global TV val:
//!   Reset all redundant info to its own redundant write info
//! Local TV val:
//!   Backpropagate all redundant info
//! Exprs:
//!   Propagate redundant info backwards from outputs to inputs:
//!   For each parallel type,
//!   the parallel type is redundantly used in the expr input
//!   only if all of the outputs redundantly use the same type.
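//!
//! Illustrative sketch (hypothetical tensors, assuming TIDx parallelization):
//!   T1_local = T0_shared[...]   // T0 written redundantly (predicated) on TIDx
//!   T2_local = T1_local + 1
//!   T3_global[...] = T2_local   // every thread writes its own element
//! Since T3 does not redundantly use TIDx, the redundant-use bit for TIDx is
//! cleared back through T2 and T1, so a block sync is still needed between
//! the T0_shared write and the start of the register chain.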
class RedundantUseAnalysis : BackwardVisitor {
 public:
  RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
      : fusion_(fusion), pred_map_(pred_map) {
    traverseTo(fusion, fusion->terminatingMathVals());
  }

  //! Returns a bit map signifying the parallel dimensions
  //! on which the given tv is redundantly used. On these
  //! dimensions not all threads/blocks are required to
  //! hold a valid value for their dependent computations.
  ParallelTypeBitmap getRedundantUseBitMap(const TensorView* tv) {
    // Since all of tv's consumers have been visited at this point, we
    // can aggregate the final redundant use info for this tv.
    if (fusion_->unordered_uses(tv).empty()) {
      // Base case: an unused tv is also not redundantly used
      return ParallelTypeBitmap();
    } else {
      // Aggregate redundant use as a conjunction of all
      // consumers' redundant use info propagated
      // backward from their consumer chains.
      ParallelTypeBitmap redundant_use;
      redundant_use.setAllBID();
      redundant_use.setAllTID();
      for (auto expr : fusion_->unordered_uses(tv)) {
        redundant_use &= redundant_expr_use_map_.at(expr);
      }

      return redundant_use;
    }
  }

 private:
  using BackwardVisitor::handle;

  void handle(TensorView* tv) final {
    auto redundant_tv_map = pred_map_.getPredicateInfo(tv).redundant_types;

    // Setup the info to propagate backward for the producer tv's and
    // expressions.
    ParallelTypeBitmap& redundant_consumer_map =
        redundant_consumer_parallel_type_map_[tv];

    // Initialize the use map to the redundant pred result
    redundant_consumer_map = redundant_tv_map;

    if (tv->getMemoryType() == MemoryType::Shared) {
      backPropagateRedundantUse(
          redundant_consumer_map,
          tv,
          false, // do not propagate TID redundant use for shared tv
          true // propagate BID redundant use
      );

    } else if (tv->getMemoryType() == MemoryType::Local) {
      backPropagateRedundantUse(
          redundant_consumer_map,
          tv,
          true, // propagate TID redundant use
          true // propagate BID redundant use
      );
    }
  }

  void backPropagateRedundantUse(
      ParallelTypeBitmap& use_map,
      TensorView* tv,
      bool propagate_tid,
      bool propagate_bid) {
    // Clear the propagated part of the original result
    if (propagate_bid) {
      use_map.setAllBID();
    }
    if (propagate_tid) {
      use_map.setAllTID();
    }

    for (auto expr : fusion_->unordered_uses(tv)) {
      // Assuming all consumer expressions have been
      // visited at this point since we are traversing
      // backward.
      auto expr_use_map = redundant_expr_use_map_.at(expr);
      // Clear the part of expression use map that does not
      // need to be propagated.
      if (!propagate_bid) {
        expr_use_map.setAllBID();
      }
      if (!propagate_tid) {
        expr_use_map.setAllTID();
      }

      // Accumulate expression redundant usage
      // This implements the `only if all` part in
      // the discussion above.
      use_map &= expr_use_map;
    }
  }

  void handle(Expr* expr) final {
    if (ir_utils::isTvOp(expr)) {
      // Initialize redundant info for current expr
      c10::optional<ParallelTypeBitmap> maybe_expr_pred_map;

      for (auto consumer_tv :
           ir_utils::filterByType<TensorView>(expr->outputs())) {
        auto tv_redundant_bitmap =
            redundant_consumer_parallel_type_map_.at(consumer_tv);

        if (maybe_expr_pred_map.has_value()) {
          // Accumulate redundant info of this tv output.
          maybe_expr_pred_map.value() &= tv_redundant_bitmap;
        } else {
          // Copy the tv's redundant info as the first valid case.
          maybe_expr_pred_map = tv_redundant_bitmap;
        }
      }

      TORCH_INTERNAL_ASSERT(
          maybe_expr_pred_map.has_value(), "TV op not having a tv output");
      redundant_expr_use_map_[expr] = maybe_expr_pred_map.value();
    }
  }

 private:
  // Populated redundant use information on the used tv's.
  // This map records, for each tv, the parallel dimensions on which it does
  // not require valid data from its producer.
  // For example:
  //   T1_local = T0_shared[...]
  //   if(tid.x == 0)
  //     T2_shared[...] = T1_local[...]
  // Then tidx would be a redundant consumer parallel type
  // for T1: as T1 is a local tensor, only threads satisfying
  // tidx == 0 need to provide valid data.
  // In this case, not all threads need to read correct data
  // from T0_shared, which helps remove some syncs.
  std::unordered_map<const TensorView*, ParallelTypeBitmap>
      redundant_consumer_parallel_type_map_;

  // Populated redundant use information on the used tv expressions.
  std::unordered_map<const Expr*, ParallelTypeBitmap> redundant_expr_use_map_;

  // Shortcut to the owning fusion of this analysis.
  Fusion* fusion_ = nullptr;

  // Shortcut to the active pred map analysis this pass is running as part of.
  const ThreadPredicateMap& pred_map_;
};

} // namespace

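// Build the map in two phases: a forward pass over all expressions to
// populate limited_types and redundant_types for each tensor, followed by a
// backward redundant-use analysis to populate redundant_use_types.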
void ThreadPredicateMap::build(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");

  // Initialize mapping for input tensors
  for (auto inp : fusion->inputs()) {
    if (auto tv = dynamic_cast<const TensorView*>(inp)) {
      update(tv, ParallelTypeBitmap(), ParallelTypeBitmap());
    }
  }
  for (auto expr : fusion->exprs()) {
    updateBitSet(expr);
  }
  updated_tvs_.clear();
  populateRedundantUseMap(fusion);
}

void ThreadPredicateMap::populateRedundantUseMap(Fusion* fusion) {
  RedundantUseAnalysis redundant_use(fusion, *this);
  for (auto& it : thread_predicates_) {
    it.second.redundant_use_types =
        redundant_use.getRedundantUseBitMap(it.first);
  }
}

ThreadPredicateMap::const_iterator ThreadPredicateMap::find(
    const TensorView* tv) const {
  return thread_predicates_.find(tv);
}

ThreadPredicateMap::const_iterator ThreadPredicateMap::end() const {
  return thread_predicates_.end();
}

const ThreadPredicateMap::PredicateInfo& ThreadPredicateMap::at(
    const TensorView* tv) const {
  return thread_predicates_.at(tv);
}

ThreadPredicateMap::PredicateInfo& ThreadPredicateMap::at(
    const TensorView* tv) {
  return thread_predicates_.at(tv);
}

ThreadPredicateMap::PredicateInfo ThreadPredicateMap::getPredicateInfo(
    const TensorView* tv) const {
  auto pred_info = thread_predicates_.at(tv);
  // Do not predicate a parallel type if it is a parallel bcast domain
  if (dynamic_cast<BroadcastOp*>(tv->definition())) {
    auto parallel_bcast = getParallelBroadcastDomains(tv);
    pred_info.limited_types ^= parallel_bcast;
  }
  return pred_info;
}

ParallelTypeBitmap ThreadPredicateMap::getPredicatedParallelTypes(
    const TensorView* tv) const {
  auto pred_info = getPredicateInfo(tv);
  return pred_info.limited_types | pred_info.redundant_types;
}

bool ThreadPredicateMap::update(
    const TensorView* tv,
    const ParallelTypeBitmap& limited_types,
    const ParallelTypeBitmap& redundant_types) {
  return update(tv, {limited_types, redundant_types});
}

bool ThreadPredicateMap::update(
    const TensorView* tv,
    const PredicateInfo& pred_info) {
  auto existing_mapping_it = thread_predicates_.find(tv);
  if (existing_mapping_it != end()) {
    PredicateInfo& existing_info = existing_mapping_it->second;
    if (existing_info == pred_info) {
      return false;
    } else {
      existing_info = pred_info;
      markAsUpdated(tv);
      return true;
    }
  } else {
    thread_predicates_.insert({tv, pred_info});
    markAsUpdated(tv);
    return true;
  }
}

Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const {
  TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv);
  auto pred_info = getPredicateInfo(tv);
  return getPredicateFromPredicateInfo(pred_info);
}

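// Returns the parallel types of tv's concretized broadcast domains that
// require an actual parallel broadcast: block dimensions always, and thread
// dimensions only when tv is not in shared memory, restricted to the types
// tv is predicated (limited) on.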
ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
    const TensorView* tv) const {
  // If no pred is found for tv, no predicate is necessary
  if (find(tv) == end()) {
    return ParallelTypeBitmap();
  }

  ParallelTypeBitmap parallel_broadcast;

  const auto& iter_domains = tv->domain()->domain();

  // If the output is on shared memory, assume that all subsequent
  // reads from all threads in its CTA can be done with no parallel
  // broadcast. Only one thread will write to shared memory followed
  // by a proper __syncthreads.
  const bool output_smem = tv->getMemoryType() == MemoryType::Shared;

  for (auto id : iter_domains) {
    if (!id->isBroadcast() ||
        !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
            id)) {
      continue;
    }
    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
      parallel_broadcast.set(id->getParallelType());
    }
  }

  return parallel_broadcast & at(tv).limited_types;
}

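// Returns the parallel types on which every tv output of expr is redundantly
// used, i.e., the conjunction of the outputs' redundant_use_types.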
ParallelTypeBitmap ThreadPredicateMap::getRedundantConsumerType(
    Expr* expr) const {
  c10::optional<ParallelTypeBitmap> result;
  for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
    auto out_tv_redundant_map = getPredicateInfo(out_tv).redundant_use_types;
    if (!result.has_value()) {
      result = out_tv_redundant_map;
    } else {
      result.value() &= out_tv_redundant_map;
    }
  }

  TORCH_INTERNAL_ASSERT(
      result.has_value(), "ThreadPredicateMap : TV op assumed");
  return result.value();
}

void ThreadPredicateMap::markAsUpdated(const TensorView* tv) {
  updated_tvs_.insert(tv);
}

void ThreadPredicateMap::print() const {
  std::cout << "\nThreadPredicateMap\n";
  std::cout << "--------------------------------\n";
  for (const auto& kv : thread_predicates_) {
    std::cout << "T" << kv.first->name();
    std::cout << " {" << kv.second.limited_types.toString() << "}\n";
    std::cout << "{" << kv.second.redundant_types.toString() << "}\n";
    std::cout << "{" << kv.second.redundant_use_types.toString() << "}\n";
  }
  std::cout << "--------------------------------\n\n";
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch