#pragma once

#include <c10/macros/Export.h>

#include <compute_at_map.h>
#include <ir_all_nodes.h>
#include <kernel_ir.h>
#include <parallel_type_bitmap.h>

#include <bitset>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Provides utilities for dealing with nested ForLoop and IfThenElse scopes

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class ThreadPredicateMap;

using IterDomainMap = std::unordered_map<IterDomain*, IterDomain*>;

namespace scope_utils {

//! Create an **empty** ForLoop and copy the metadata.
kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop);

//! Create an **empty** IfThenElse and copy the metadata.
kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite);

} // namespace scope_utils

namespace ir_utils {

// Sometimes we want to temporarily view a TensorView with another
// TensorDomain. This isn't a permanent transformation, but in indexing we
// want to index producers with a consumer's set of indices, so we need to
// view the producer as if it were transformed like the consumer while we
// index. This guard sets the tv to use td for the life of the guard.
class TORCH_CUDA_CU_API TVDomainGuard {
 private:
  TensorView* tv_;
  TensorDomain* prev_domain_;

 public:
  explicit TVDomainGuard(TensorView* tv, TensorDomain* td);
  TVDomainGuard(const TVDomainGuard&) = delete;
  TVDomainGuard(TVDomainGuard&&);

  //! A utility to access the TensorDomain as it was before the temporary
  //! view. This is used to retrieve information, such as swizzle
  //! information, that can only be reliably kept at the original domain.
  const TensorDomain* prevDomain() const {
    return prev_domain_;
  }

  ~TVDomainGuard();
};
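
// Usage sketch (illustrative, not from this file): TVDomainGuard is an RAII
// guard, so the temporary view lasts until the guard is destroyed. The names
// `producer` and `consumer_like_td` below are hypothetical.
//
//   {
//     ir_utils::TVDomainGuard guard(producer, consumer_like_td);
//     // Within this scope, `producer` reports `consumer_like_td` as its
//     // domain, so it can be indexed with consumer indices.
//   } // destructor restores the original TensorDomain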

// Create a TVDomainGuard that temporarily views a TensorView with the
// specified all-true or all-false contiguity.
TORCH_CUDA_CU_API ir_utils::TVDomainGuard overrideContiguityGuard(
    TensorView* tv,
    bool contiguity);

//! Return inputs of the provided IterDomains that are themselves
//! IterDomains. A list of candidate inputs can optionally be given via
//! all_inputs; otherwise, IterDomains with no defining expression are
//! returned.
std::vector<IterDomain*> iterDomainInputsOf(
    const std::vector<IterDomain*>& input_ids,
    const std::vector<IterDomain*>& all_inputs = {});
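
// Example sketch (illustrative; `id0`, `id1`, `id2` are hypothetical): if
// id2 was produced by merging id0 and id1, and neither id0 nor id1 has a
// defining expression, then:
//
//   auto inputs = iterDomainInputsOf({id2});
//   // inputs == {id0, id1}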

// Return inputs of the provided IterDomains that are IterDomains, ordered as
// in the second provided vector.
std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
    const std::vector<IterDomain*>& of,
    const std::vector<IterDomain*>& order);

// Returns true if the Val is a TensorView or TensorIndex
TORCH_CUDA_CU_API bool isTV(const Val* const);

// Returns true if the Expr is a TensorView or TensorIndex Expr.
TORCH_CUDA_CU_API bool isTvOp(const Expr*);

// Returns the first output of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);

// Returns the first input of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*);

//! Returns the IterDomain that maps to the thread dimension grouped
//! to warps. Returns nullopt if the reduction is not to be lowered to
//! a warp reduction.
c10::optional<IterDomain*> getMaybeWarpReductionDim(
    const Val* output,
    const Val* input);
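
// Usage sketch (illustrative; `out` and `in` are hypothetical output/input
// values of a reduction expression): pick a lowering path based on whether a
// warp-mapped dimension exists.
//
//   if (auto warp_dim = getMaybeWarpReductionDim(out, in)) {
//     // lower as a warp reduction over *warp_dim
//   } else {
//     // fall back to the generic block/grid reduction path
//   }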

bool isScalarOp(const Expr*);

//! Get a TensorView, potentially via kir::TensorIndex. Returns nullptr if
//! the cast fails.
TensorView* getTv(Val*);
const TensorView* getTv(const Val*);

//! Get only TensorViews, potentially via kir::TensorIndex, from the given
//! vals; anything else is skipped.
std::vector<TensorView*> getTvs(const std::vector<Val*>& vals);

//! Return true if axis is derived from a root axis that is an input
//! to a CA leaf axis.
bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis);

std::unordered_map<ParallelType, IterDomain*, TypeHash> getParallelDomains(
    const Val* val);

//! Returns true if the expression will be lowered to
//! an ldmatrix intrinsic.
bool isLdMatrixOp(const Expr* expr);

//! Returns true if the expression will be lowered to
//! a cp.async intrinsic.
bool isCpAsyncOp(const Expr* expr);

//! Short-cut for detecting initialization for the cp.async op.
bool isCpAsyncInit(const Expr* expr);

//! Short-cut for matching a singleton expr in an if statement,
//! which likely becomes a predicated instruction in ptx, e.g.:
//!   if(...) {expr;}
//! Returns the expr if it matches this pattern.
//! Returns nullopt if the pattern doesn't match.
c10::optional<Expr*> getMaybePredicatedSingleton(Expr* expr);
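
// Usage sketch (illustrative): `expr` here would be an IfThenElse seen while
// scanning lowered exprs.
//
//   if (auto maybe_inner = getMaybePredicatedSingleton(expr)) {
//     Expr* inner = *maybe_inner;
//     // `inner` is the single expr inside the if-body; a candidate for a
//     // predicated instruction in the generated ptx.
//   }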

//! Short-cut for checking if the expression loads from global memory.
bool isGlobalLoad(const Expr* expr);

//! Short-cut for checking if the given expression initializes buffers
//! for a global memory load.
bool isGlobalLoadInit(const Expr* expr);

//! Returns true if the given expression fills the output
//! tensor with a single scalar.
bool isTensorScalarFillOp(const Expr* expr);

//! Flattens all the scoped exprs, i.e. ForLoop and IfThenElse,
//! and returns all the exprs in all scopes in the original
//! linear textual order.
TORCH_CUDA_CU_API std::vector<Expr*> flattenScopedExprs(
    const std::vector<Expr*>& loop_nests);
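
// Example (illustrative): for a lowered nest such as
//
//   FOR i0 { e0; FOR i1 { e1; } IF p { e2; } }
//
// flattenScopedExprs returns {e0, e1, e2}: the scope wrappers are dropped
// and the leaf exprs are kept in their original textual order.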

//! Returns all swizzle ops between the set of IterDomains
//! in `from` and `to`.
std::vector<Expr*> getAllSwizzlesBetween(
    std::vector<IterDomain*> from,
    std::vector<IterDomain*> to);

// Replace-value pass on Kernel IR.
// Replaces each use of any Val* that appears in the given `replacement_map`.
// Keeps the predicate carried by each expr.
//
// Warning: Blindly replaces all uses based on pointer equality.
// Warning: May invalidate indexing if replacing uses of allocated values.
std::vector<Expr*> replaceInputsInExpr(
    const std::vector<Expr*>& exprs,
    const std::unordered_map<Val*, Val*>& replacement_map);
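
// Usage sketch (illustrative; `old_val` and `new_val` are hypothetical):
//
//   std::unordered_map<Val*, Val*> replacement_map = {{old_val, new_val}};
//   auto rewritten = replaceInputsInExpr(exprs, replacement_map);
//   // Every expr in `rewritten` that consumed `old_val` now consumes
//   // `new_val`; predicates are carried over unchanged.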

// Go through all expressions and compute a local ordering of loops.
// operator< is implemented based on the concrete_id_dependencies analysis.
// If there's no dependency between two IDs then the order doesn't matter;
// otherwise we can tell which is innermost by checking for any dependency
// relationships.
//
// Dependency relationships in concrete_id_dependencies have a "global" view
// of the fusion, so ordering can be resolved by only looking at the IDs and
// the dependency map.
//
// For example, two expressions may have domains [I0] and [I1], yet we won't
// know the ordering unless we see a domain with [I0, I1]. This happened in
// the advancedIndexing9 test (also see AdvancedLowering6) when merging T5
// with the group containing T10 (cache of T5, which is the post broadcasted
// output) and T6 (pre broadcasted output).
// T5 had the domain [0, 1, 2, 3, 4] produce-at 3.
// T6 had the domain [0, 3, 4] compute-at 3.
// Merging [0, 1, 2] and [0, 3, 4] resulted in the domain [0, 3, 4, 1, 2].
//
// If IDs are not in the filter, we don't care about their ordering and
// ignore them. This is because we're only focused on loops we will have to
// merge across groups. If the domain is not in a produce-at position in the
// producer edges, or a compute-at position in the consumer edges, the
// expressions we look at may not have a unique ordering.

struct TORCH_CUDA_CU_API IterDomainDependencySorter {
  IterDomainDependencySorter(
      const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
          concrete_id_dependencies,
      std::shared_ptr<const ComputeAtMap> compute_at_map)
      : concrete_id_dependencies_(concrete_id_dependencies),
        compute_at_map_(compute_at_map) {}

  // Return true if id0 should be before id1.
  // Orders such that if x maps to {y}, x comes before y in the final
  // ordering.
  inline bool operator()(IterDomain* id0, IterDomain* id1) {
    auto concrete_id_0 =
        compute_at_map_->getConcreteMappedID(id0, IdMappingMode::LOOP);
    auto concrete_id_1 =
        compute_at_map_->getConcreteMappedID(id1, IdMappingMode::LOOP);

    if (concrete_id_dependencies_.find(concrete_id_0) !=
        concrete_id_dependencies_.end()) {
      const auto& dependencies_0 = concrete_id_dependencies_.at(concrete_id_0);
      // If id0 depends on id1, id1 is inside id0, so id0 < id1.
      if (dependencies_0.count(concrete_id_1)) {
        return true;
      }
    }

    return false;
  }

  const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
      concrete_id_dependencies_;
  const std::shared_ptr<const ComputeAtMap> compute_at_map_;
};
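
// Usage sketch (illustrative): the comparator only defines a partial order
// (it returns false for unrelated IDs), so a stable sort keeps the incoming
// order for IDs with no dependency relationship. Names are hypothetical.
//
//   std::vector<IterDomain*> loop_ids = ...;
//   std::stable_sort(
//       loop_ids.begin(),
//       loop_ids.end(),
//       IterDomainDependencySorter(concrete_id_dependencies, compute_at_map));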

} // namespace ir_utils

namespace lower_utils {

bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);

// Allocate a global buffer for grid communication calls, i.e. grid reduce,
// grid welford reduce, grid broadcast.
kir::Allocate* allocGlobalBufferForGridComm(
    Val* buffer_size,
    DataType dtype,
    bool zero_init);
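
// Usage sketch (illustrative; `work_buf_size` and `sync_flag_count` are
// hypothetical Val*s): a grid reduction typically needs a data work buffer
// plus a zero-initialized synchronization buffer.
//
//   auto work_buf =
//       allocGlobalBufferForGridComm(work_buf_size, DataType::Float, false);
//   auto sync_buf =
//       allocGlobalBufferForGridComm(sync_flag_count, DataType::Int, true);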

struct BasicAllocInfo {
  // The for loop that the initialization of this allocation must be
  // placed in, nullptr if not within a loop.
  kir::ForLoop* init_for_loop = nullptr;

  // Keep track of the actual allocation loop. This can be different from
  // init_for_loop only with unswitched shared memory allocations, which are
  // moved to outer loops to avoid duplicated allocations. This means that
  // the alloc position may be outside what's expected. Most applications
  // outside lower_allocation are likely looking for init_for_loop, which is
  // more directly related to how large an allocation is and how it's used.
  // (See issue #1133.)
  kir::ForLoop* alloc_for_loop = nullptr;

  // The allocation position relative to buffer IDs. It could be outside the
  // compute-at position if it's shared memory with a compute-at inside an
  // unswitch.
  size_t alloc_pos = 0;
};

// Fill the above allocation struct based on provided information. id_map is
// used if we're looking at a producer tensor but loops on a consumer tensor.
BasicAllocInfo getAllocInformation(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {},
    bool use_id_map = false);
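
// Usage sketch (illustrative; `tv` and `for_loops` are hypothetical): decide
// where an allocation and its initialization go relative to the current
// loop nest.
//
//   auto info = getAllocInformation(tv, for_loops);
//   // info.alloc_pos: position in tv's domain where the allocation occurs
//   // info.init_for_loop: loop that must contain the init expression
//   // info.alloc_for_loop: loop that must contain the allocation itself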

//! Returns true if the expression has a variant that takes a predicate
//! as an inline argument.
bool supportInlinePredicate(Expr* expr);

} // namespace lower_utils

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch