1 | |
2 | #pragma once |
3 | |
4 | #include <c10/macros/Export.h> |
5 | |
6 | #include <ir_all_nodes.h> |
7 | #include <lower_utils.h> |
8 | #include <parallel_type_bitmap.h> |
9 | |
10 | #include <unordered_map> |
11 | #include <unordered_set> |
12 | #include <utility> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | namespace fuser { |
17 | namespace cuda { |
18 | |
//! Maps TensorViews to their PredicateInfo
//!
//! Maps each TensorView to a bit set representing <BIDx, BIDy, BIDz,
//! TIDx, TIDy, TIDz>. If any dependency of a TV had a parallelized
//! reduction, we will track it here. This will be used for predicate
//! generation to prevent parallelization on that axis. This is
//! important if we have a reduction on, for example, TIDx, as the
//! reduced value is only valid on threadIdx.x == 0; therefore, if we
//! use that value later in the kernel, we need that predicate. If we
//! follow a reduction parallelized on TIDx with a broadcast on TIDx,
//! we no longer need the predicate and can reset the bit accordingly.
//!
//! In addition, if a parallel thread type is not used, it is
//! redundant to use all threads/blocks. That isn't a problem
//! generally although it can be inefficient, but when an aliased smem
//! buffer is used as an output, redundant writes can be invalid (see
//! issue #1110). PredicateInfo::redundant_types tracks which parallel
//! types are redundant for each tensor and is used to let only one
//! thread/block of a redundant type execute the expression for a
//! tensor.
class TORCH_CUDA_CU_API ThreadPredicateMap {
 public:
  //! Map from a ParallelType to the set of TensorViews associated
  //! with that type (e.g. the source tensors of a predicate bit).
  //! NOTE(review): not referenced by any member declared in this
  //! header; presumably consumed by the implementation — verify.
  using SourceMap = std::unordered_map<
      ParallelType,
      std::unordered_set<const TensorView*>,
      TypeHash>;

  //! Thread predicate information for each tensor
  struct PredicateInfo {
    // Parallel types where only one thread/block is valid.
    ParallelTypeBitmap limited_types;
    // Parallel types where only one thread/block is enough.
    ParallelTypeBitmap redundant_types;
    // Tracking use chain of redundant writes:
    // [Redundant use chain]
    // a parallel type is a `redundant_consumer_type` only
    // if all of its propagation use chains terminate with
    // a redundant write of this type.
    // A propagation use chain is currently either a reg-to-reg
    // chain for a shared mem tv, or a reg/smem-to-reg/smem chain
    // for a global tv.
    // This is complementary information to `redundant_types`.
    // If a tensor view is redundantly written and not redundantly
    // used by all consumers, see FusionRedundantPredSync3,
    // a RAW sync will need to be inserted before reading
    // this redundantly written tensor.
    ParallelTypeBitmap redundant_use_types;

    // Two infos are equal iff all three bitmaps match.
    bool operator==(const PredicateInfo& other) const {
      return limited_types == other.limited_types &&
          redundant_types == other.redundant_types &&
          redundant_use_types == other.redundant_use_types;
    }
  };

  //! Map from each TensorView to its PredicateInfo.
  using MapType = std::unordered_map<const TensorView*, PredicateInfo>;

  using const_iterator = MapType::const_iterator;

  //! Build a map from each tensor to PredicateInfo.
  void build(Fusion* fusion);

  //! Get a PredicateInfo for a given tensor. If it's an output of
  //! a parallel broadcast, unmask the limited_types_ bit of the
  //! corresponding parallel type since it must join the broadcast
  //! operation although the valid input is only available at one of
  //! the threads/blocks.
  PredicateInfo getPredicateInfo(const TensorView* tv) const;

  //! Returns a flag set that indicates which parallel types should be
  //! predicated.
  ParallelTypeBitmap getPredicatedParallelTypes(const TensorView* tv) const;

  //! Returns a Bool predicate for a given TensorView.
  Bool* getPredicate(const TensorView* tv) const;

  //! Returns a ParallelTypeBitmap representing which domain needs
  //! blockBroadcast.
  //!
  //! Even when a domain is broadcast and parallelized, it does not need
  //! blockBroadcast unless it is predicated by limited_types_
  ParallelTypeBitmap getParallelBroadcastDomains(const TensorView* tv) const;

  //! Mark tv as updated so that rebuilding the map should recompute
  //! its predicates and those of its dependents.
  void markAsUpdated(const TensorView* tv);

  //! Print the contents of the map (for debugging).
  void print() const;

  //! Generate a Bool value from PredicateInfo.
  static Bool* getPredicateFromPredicateInfo(
      const ThreadPredicateMap::PredicateInfo& pred_info);

  //! Get the redundant use types of the given expr, see [Redundant use chain]
  ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;

 private:
  // Update the thread_predicates bitset based on provided Expr
  void updateBitSet(const Expr*);

  // Lookup helpers over thread_predicates_.
  const_iterator find(const TensorView* tv) const;
  const_iterator end() const;

  // Direct access to a tensor's entry. NOTE(review): presumably
  // requires the entry to exist — confirm in the implementation.
  const PredicateInfo& at(const TensorView* tv) const;
  PredicateInfo& at(const TensorView* tv);

  //! Update a mapping. NOTE(review): the bool return presumably
  //! indicates whether the stored entry changed — confirm.
  bool update(
      const TensorView* tv,
      const ParallelTypeBitmap& limited_types,
      const ParallelTypeBitmap& redundant_types);

  //! Update a mapping (same return-value note as above overload).
  bool update(const TensorView* tv, const PredicateInfo& pred_and_src);

  //! Backward populate redundant use chain info once the redundant
  //! parallel writes have been identified.
  void populateRedundantUseMap(Fusion* fusion);

 private:
  //! Per-tensor predicate information computed by build().
  MapType thread_predicates_;
  //! Keep track of updated tensors that need predicates to be computed
  std::unordered_set<const TensorView*> updated_tvs_;
};
141 | |
142 | } // namespace cuda |
143 | } // namespace fuser |
144 | } // namespace jit |
145 | } // namespace torch |
146 | |