1 | #pragma once |
2 | |
3 | #include <index_compute.h> |
4 | #include <kernel_ir.h> |
5 | #include <lower_thread_predicate.h> |
6 | #include <lower_utils.h> |
7 | #include <root_domain_map.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | class PredicateCompute { |
15 | public: |
16 | // ignore_internal_syncthread_ops will prevent creation of predicates on |
17 | // block/grid broadcast/reduce as these have syncthread calls within them |
18 | // so all threads need to execute the function. |
19 | static Bool* getInlinePredicate( |
20 | const Expr* expr, |
21 | const std::vector<kir::ForLoop*>& loops, |
22 | Bool* thread_pred, |
23 | PredicateType pred_type); |
24 | }; |
25 | |
26 | //! Parallelized domains may need to be predicated with threading |
27 | //! indices and IterDomain extents. For example, if a domain is |
28 | //! parallelized by TIDx, when TIDx is not exact, i.e., it can be |
29 | //! larger than the extents of domains parallelized by TIDx, |
30 | //! threadIdx.x may be larger than the IterDomain extent. This can be |
31 | //! harmless for Local tensors, however, for it would |
32 | //! result in out-of-bounds access for Shared tensors as they are |
33 | //! allocated based on tensor shapes rather than threading |
34 | //! dimensions. |
35 | class ParallelizedDomainPredicate { |
36 | public: |
37 | //! Predicate information for parallelized domains |
38 | class PredicateInfo { |
39 | public: |
40 | explicit PredicateInfo(ParallelType pt) : pt_(pt) {} |
41 | |
42 | //! Adds a domain that is parallized by the same paralell type |
43 | bool addDomain(IterDomain* id); |
44 | |
45 | const std::vector<IterDomain*>& ids() const { |
46 | return ids_; |
47 | } |
48 | |
49 | //! Generates a predicate Val from predicate information |
50 | Bool* getPredicate() const; |
51 | |
52 | private: |
53 | ParallelType pt_; |
54 | //! Domains parallelized by the same parallel type |
55 | std::vector<IterDomain*> ids_; |
56 | }; |
57 | |
58 | //! Returns a predicate Val for parallelied domains of an expression. |
59 | static Bool* getPredicate( |
60 | const Expr* expr, |
61 | const std::vector<kir::ForLoop*>& loops); |
62 | |
63 | //! Returns predicate information for parallelied domains of an |
64 | //! expression. |
65 | static std::unordered_map<ParallelType, PredicateInfo, TypeHash> |
66 | getPredicateMap( |
67 | const Expr* expr, |
68 | const std::vector<kir::ForLoop*>& loops, |
69 | kir::ForLoop* unswitched_loop = nullptr); |
70 | }; |
71 | |
72 | //! Keys to identify unique unswitch predicates. Just consists of a |
73 | //! predicated concrete domain if not parallelized. If parallelized, |
74 | //! pick one for each different parallelization. When the same |
75 | //! parallel type is used for different concrete domains, they are |
76 | //! considered different predicates and are included in the unswitch |
77 | //! condition lists. |
78 | class UnswitchPredicateKey { |
79 | public: |
80 | UnswitchPredicateKey(); |
81 | |
82 | UnswitchPredicateKey( |
83 | IterDomain* predicated_consumer_id, |
84 | TensorView* consumer_tv, |
85 | IterDomain* predicated_concrete_id); |
86 | |
87 | bool operator==(const UnswitchPredicateKey& other) const { |
88 | return predicated_concrete_id_ == other.predicated_concrete_id_ && |
89 | parallel_concrete_ids_ == other.parallel_concrete_ids_; |
90 | } |
91 | |
92 | const auto& predicatedId() const { |
93 | return predicated_concrete_id_; |
94 | } |
95 | |
96 | const auto& parallelConcreteIds() const { |
97 | return parallel_concrete_ids_; |
98 | } |
99 | |
100 | IterDomain* parallelId(ParallelType pt) const { |
101 | auto it = parallelConcreteIds().find(pt); |
102 | if (it == parallelConcreteIds().end()) { |
103 | return nullptr; |
104 | } else { |
105 | return it->second; |
106 | } |
107 | } |
108 | |
109 | std::string toString() const; |
110 | |
111 | private: |
112 | //! Predicated concrete domain |
113 | IterDomain* predicated_concrete_id_ = nullptr; |
114 | //! Store parallelized concrete domains |
115 | std::unordered_map<ParallelType, IterDomain*, TypeHash> |
116 | parallel_concrete_ids_; |
117 | }; |
118 | |
119 | struct UnswitchPredicateKeyHash { |
120 | std::size_t operator()(const UnswitchPredicateKey& key) const; |
121 | }; |
122 | |
123 | class TORCH_CUDA_CU_API UnswitchPredicate { |
124 | public: |
125 | static Bool* get( |
126 | const std::vector<kir::ForLoop*>& outer_loops, |
127 | kir::ForLoop* unrolled_loop); |
128 | |
129 | private: |
130 | //! Predicate information for each UnswitchPredicateKey. |
131 | struct MergedPredicates { |
132 | //! Predicate information for the start and stop predicates. |
133 | struct Info { |
134 | //! Most restrictive static predicate. Nullptr if no static |
135 | //! predicate found. |
136 | Bool* static_pred = nullptr; |
137 | //! The offset value of static_pred |
138 | int64_t static_offset = 0; |
139 | //! List of dynamic predicates. |
140 | std::vector<Bool*> dynamic_preds; |
141 | }; |
142 | UnswitchPredicateKey predicate_key; |
143 | Info start; |
144 | Info stop; |
145 | }; |
146 | |
147 | UnswitchPredicate( |
148 | std::vector<kir::ForLoop*> outer_loops, |
149 | kir::ForLoop* unrolled_loop); |
150 | |
151 | void predicateOn(Expr*); |
152 | |
153 | void openLoop(kir::ForLoop*); |
154 | |
155 | void openIte(kir::IfThenElse*); |
156 | |
157 | //! Generates the final predicates from the predicated_keys map |
158 | void finalize(); |
159 | |
160 | //! Merge predicates as much as possible. If a predicate offset is |
161 | //! static, only pick the most restrictive one, e.g., the one with the |
162 | //! minimum offset for the start predication. |
163 | void mergeUnswitchPredicateOffsets( |
164 | Bool* predicate, |
165 | Val* offset, |
166 | MergedPredicates::Info& merged_predicate_info, |
167 | bool is_start); |
168 | |
169 | //! Adds new predicates for parallelized domains |
170 | void addParallelizedDomainPredicates(Expr*); |
171 | |
172 | private: |
173 | //! Track which iter domains have been predicated |
174 | std::unordered_set<UnswitchPredicateKey, UnswitchPredicateKeyHash> |
175 | predicated_keys_; |
176 | |
177 | //! The predicates that have been recorded but not yet finalized |
178 | std::vector<MergedPredicates> pending_predicates_; |
179 | |
180 | //! Track which parallelized domains have been predicated |
181 | std::unordered_map< |
182 | ParallelType, |
183 | ParallelizedDomainPredicate::PredicateInfo, |
184 | TypeHash> |
185 | parallelized_dom_predicates_; |
186 | |
187 | //! The predicates that have been generated. |
188 | std::vector<Bool*> predicates_; |
189 | |
190 | std::vector<kir::ForLoop*> for_loops_; |
191 | |
192 | kir::ForLoop* unrolled_loop_; |
193 | }; |
194 | |
195 | } // namespace cuda |
196 | } // namespace fuser |
197 | } // namespace jit |
198 | } // namespace torch |
199 | |