#pragma once

#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
#include <lower_utils.h>
#include <parallel_type_bitmap.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Maps TensorViews to a { ParallelTypeBitmap, SourceMap } pair
//!
//! Map from TensorView to a bit set representing <BIDx, BIDy, BIDz, TIDx,
//! TIDy, TIDz>. If any dependency of a TV had a parallelized reduction, we
//! track it here. This is used during predicate generation to prevent
//! parallelization on that axis. This matters when, for example, a
//! reduction is parallelized on TIDx: the reduced value is only valid on
//! threadIdx.x == 0, so any later use of that value in the kernel must be
//! guarded by that predicate. If a reduction parallelized on TIDx is
//! followed by a broadcast on TIDx, the predicate is no longer needed and
//! the bit can be reset accordingly.
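//!
//! A minimal sketch of the reduction case (hypothetical fusion, not part
//! of this header):
//!
//!   tv1 = sum(tv0, {0});   // reduction axis bound to TIDx
//!   tv2 = add(tv1, tv1);   // consumes the reduced value
//!
//! tv1 is only valid on threadIdx.x == 0, so the expression producing tv2
//! is predicated with threadIdx.x == 0. If tv1 were instead fed through a
//! broadcast parallelized on TIDx, every thread would hold a valid copy
//! and the predicate could be dropped.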
//!
//! In addition, if a parallel thread type is not used by a tensor, it is
//! redundant for all threads/blocks of that type to compute it. That is
//! generally just inefficient rather than incorrect, but when an aliased
//! smem buffer is used as an output, redundant writes can be invalid (see
//! issue #1110). PredicateInfo::redundant_types tracks which parallel
//! types are redundant for each tensor and is used to let only one
//! thread/block of a redundant type execute the expression for that
//! tensor.
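//!
//! A minimal sketch of a redundant write (hypothetical fusion, not part
//! of this header):
//!
//!   tv1 = add(tv0, tv0);   // no axis of tv1 is bound to TIDy, but the
//!                          // kernel launches with blockDim.y > 1
//!
//! Every TIDy thread would compute identical values of tv1, so TIDy is
//! recorded in redundant_types and only threadIdx.y == 0 needs to execute
//! the write.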
class TORCH_CUDA_CU_API ThreadPredicateMap {
 public:
  using SourceMap = std::unordered_map<
      ParallelType,
      std::unordered_set<const TensorView*>,
      TypeHash>;

  //! Thread predicate information for each tensor
  struct PredicateInfo {
    // Parallel types where only one thread/block is valid.
    ParallelTypeBitmap limited_types;
    // Parallel types where only one thread/block is enough.
    ParallelTypeBitmap redundant_types;
    // Tracking use chain of redundant writes:
    // [Redundant use chain]
    // A parallel type is a `redundant_consumer_type` only
    // if all of its propagation use chains terminate with
    // a redundant write of this type.
    // A propagation use chain is currently either a reg-to-reg
    // chain for a shared mem tv, or a reg/smem-to-reg/smem chain
    // for a global tv.
    // This is complementary information to `redundant_types`.
    // If a tensor view is redundantly written but not redundantly
    // used by all consumers (see test FusionRedundantPredSync3),
    // a RAW sync will need to be inserted before reading
    // this redundantly written tensor.
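    //
    // Illustrative sketch (hypothetical, not part of this header): a smem
    // tv is written redundantly on TIDy, but one consumer reads it from
    // all TIDy threads. The write can then only be limited to
    // threadIdx.y == 0 if a RAW sync (e.g. __syncthreads()) is inserted
    // before the read, so the value becomes visible to the other TIDy
    // threads.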
    ParallelTypeBitmap redundant_use_types;
    bool operator==(const PredicateInfo& other) const {
      return limited_types == other.limited_types &&
          redundant_types == other.redundant_types &&
          redundant_use_types == other.redundant_use_types;
    }
  };

  using MapType = std::unordered_map<const TensorView*, PredicateInfo>;

  using const_iterator = MapType::const_iterator;

  //! Build a map from each tensor to PredicateInfo.
  void build(Fusion* fusion);

  //! Get the PredicateInfo for a given tensor. If the tensor is an
  //! output of a parallel broadcast, the limited_types_ bit of the
  //! corresponding parallel type is unmasked, since all threads/blocks
  //! must join the broadcast operation even though the valid input is
  //! only available on one of them.
  PredicateInfo getPredicateInfo(const TensorView* tv) const;

  //! Returns a flag set that indicates which parallel types should be
  //! predicated.
  ParallelTypeBitmap getPredicatedParallelTypes(const TensorView* tv) const;

  //! Returns a Bool predicate for a given TensorView.
  Bool* getPredicate(const TensorView* tv) const;

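  // Illustrative sketch (hypothetical values, not part of this header):
  // for a tensor whose limited_types contains TIDx and BIDy, the returned
  // predicate is conceptually
  //   threadIdx.x == 0 && blockIdx.y == 0
  // and is attached to the expression when code is generated for it.
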
  //! Returns a ParallelTypeBitmap representing which domains need
  //! blockBroadcast.
  //!
  //! Even when a domain is broadcast and parallelized, it does not need
  //! blockBroadcast unless it is predicated by limited_types_.
  ParallelTypeBitmap getParallelBroadcastDomains(const TensorView* tv) const;

  //! Mark tv as updated so that rebuilding the map recomputes
  //! its predicates and those of its dependents.
  void markAsUpdated(const TensorView* tv);

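  // Typical usage, as a minimal sketch (`fusion` and `tv` are assumed to
  // exist elsewhere):
  //   ThreadPredicateMap pred_map;
  //   pred_map.build(fusion);       // initial analysis
  //   pred_map.markAsUpdated(tv);   // tv's scheduling changed
  //   pred_map.build(fusion);       // recomputes tv and its dependents
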
  void print() const;

  //! Generate a Bool value from PredicateInfo.
  static Bool* getPredicateFromPredicateInfo(
      const ThreadPredicateMap::PredicateInfo& pred_info);

  //! Get the redundant use types of the given expr; see [Redundant use
  //! chain].
  ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;

 private:
  // Update the thread_predicates bitset based on the provided Expr
  void updateBitSet(const Expr*);

  const_iterator find(const TensorView* tv) const;
  const_iterator end() const;

  const PredicateInfo& at(const TensorView* tv) const;
  PredicateInfo& at(const TensorView* tv);

  //! Update a mapping
  bool update(
      const TensorView* tv,
      const ParallelTypeBitmap& limited_types,
      const ParallelTypeBitmap& redundant_types);

  //! Update a mapping
  bool update(const TensorView* tv, const PredicateInfo& pred_and_src);

  //! Propagate redundant use chain info backward once the redundant
  //! parallel writes have been identified.
  void populateRedundantUseMap(Fusion* fusion);

 private:
  MapType thread_predicates_;
  //! Keep track of updated tensors whose predicates need to be recomputed
  std::unordered_set<const TensorView*> updated_tvs_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch