1 | #pragma once |
2 | |
3 | #include <disjoint_set.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <iter_visitor.h> |
6 | #include <utils.h> |
7 | |
8 | #include <c10/macros/Export.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | //! Generic interface for mapping root domains of a producer-consumer pair. |
16 | class TORCH_CUDA_CU_API RootDomainMap : public PolymorphicBase { |
17 | public: |
18 | //! Return a map from a producer TensorDomain to a consumer |
19 | //! TensorDomain |
20 | //! |
21 | //! \param producer A producer TensorDomain |
22 | //! \param consumer A consumer TensorDomain |
23 | //! \param root_dims_to_map Maps only producer root domains in this set |
24 | std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer( |
25 | const TensorDomain* producer, |
26 | const TensorDomain* consumer, |
27 | const std::unordered_set<IterDomain*>& root_dims_to_map) const; |
28 | |
29 | //! Return a map from a producer TensorDomain to a consumer |
30 | //! TensorDomain |
31 | //! |
32 | //! \param producer A producer TensorDomain |
33 | //! \param consumer A consumer TensorDomain |
34 | std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer( |
35 | const TensorDomain* producer, |
36 | const TensorDomain* consumer) const; |
37 | |
38 | //! Return a map from a consumer TensorDomain to a producer |
39 | //! TensorDomain |
40 | //! |
41 | //! \param consumer A consumer TensorDomain |
42 | //! \param producer A producer TensorDomain |
43 | //! \param root_dims_to_map Maps only consumer root domains in this set |
44 | std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer( |
45 | const TensorDomain* consumer, |
46 | const TensorDomain* producer, |
47 | const std::unordered_set<IterDomain*>& root_dims_to_map) const; |
48 | |
49 | //! Return a map from a consumer TensorDomain to a producer |
50 | //! TensorDomain |
51 | //! |
52 | //! \param consumer A consumer TensorDomain |
53 | //! \param producer A producer TensorDomain |
54 | std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer( |
55 | const TensorDomain* consumer, |
56 | const TensorDomain* producer) const; |
57 | |
58 | protected: |
59 | //! Return a map between root IterDomains of a producer-consumer |
60 | //! pair. |
61 | //! |
62 | //! \param producer A producer TensorDomain |
63 | //! \param consumer A consumer TensorDomain |
64 | //! \param root_dims_to_map Maps only from IterDomains in this set |
65 | //! \param producer_to_consumer Maps from producer to consumer if true |
66 | virtual std::unordered_map<IterDomain*, IterDomain*> map( |
67 | const TensorDomain* producer, |
68 | const TensorDomain* consumer, |
69 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
70 | bool producer_to_consumer) const = 0; |
71 | }; |
72 | |
73 | //! Maps root domains of a producer-consumer pair. This class only |
74 | //! looks at the given pair of TensorViews and does not take into |
75 | //! consideration the constraints of the computeAt transformation, |
76 | //! i.e., unable to compute the same tensors multiple times. This |
77 | //! should not be used for transformations implementing computeAt, but |
78 | //! should be valid otherwise. |
79 | class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap { |
80 | public: |
81 | //! \param producer The producer tensor of a producer-consumer pair. |
82 | //! \param consumer The consumer tensor of a producer-consumer pair. |
83 | explicit PairwiseRootDomainMap( |
84 | const TensorView* producer, |
85 | const TensorView* consumer, |
86 | bool is_exact = false); |
87 | |
88 | const TensorView* producer() const { |
89 | return producer_tv_; |
90 | } |
91 | |
92 | const TensorView* consumer() const { |
93 | return consumer_tv_; |
94 | } |
95 | |
96 | std::string toString() const; |
97 | |
98 | protected: |
99 | std::unordered_map<IterDomain*, IterDomain*> map( |
100 | const TensorDomain* producer, |
101 | const TensorDomain* consumer, |
102 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
103 | bool producer_to_consumer) const override; |
104 | |
105 | std::unordered_map<IterDomain*, IterDomain*> mapTranspose( |
106 | const TensorDomain* producer, |
107 | const TensorDomain* consumer, |
108 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
109 | bool producer_to_consumer) const; |
110 | |
111 | private: |
112 | const TensorView* producer_tv_ = nullptr; |
113 | const TensorView* consumer_tv_ = nullptr; |
114 | //! If true, does not map broadcast IDs with non-broadcast IDs |
115 | const bool is_exact_ = false; |
116 | }; |
117 | |
118 | //! Represents an iteration domain of a TensorDomain. Only used for |
119 | //! root domain mapping. |
120 | //! |
121 | //! Note that an IterDomain object may be reused |
122 | //! across multiple TensorDomains, but an IterDomain in a |
123 | //! TensorDomain may not be necessarily mappable to the same |
124 | //! IterDomain used in a different TensorDomain. Thus, for the purpose |
125 | //! of root domain mapping, an iteration domain needs to be identified |
126 | //! with an IterDomain and its TensorDomain. |
127 | class DomainKey { |
128 | public: |
129 | DomainKey() = default; |
130 | DomainKey( |
131 | const TensorDomain* td, |
132 | const IterDomain* id, |
133 | const IterDomain* concrete_id = nullptr) |
134 | : td_(td), id_(id), concrete_id_(concrete_id) {} |
135 | const TensorDomain* td() const { |
136 | return td_; |
137 | } |
138 | const IterDomain* id() const { |
139 | return id_; |
140 | } |
141 | const IterDomain* concreteId() const { |
142 | return concrete_id_; |
143 | } |
144 | bool operator==(const DomainKey& other) const { |
145 | return td() == other.td() && id() == other.id() && |
146 | concreteId() == other.concreteId(); |
147 | } |
148 | bool operator!=(const DomainKey& other) const { |
149 | return !(*this == other); |
150 | } |
151 | |
152 | std::string toString() const; |
153 | |
154 | private: |
155 | const TensorDomain* td_ = nullptr; |
156 | const IterDomain* id_ = nullptr; |
157 | const IterDomain* concrete_id_ = nullptr; |
158 | }; |
159 | |
160 | struct DomainKeyHash { |
161 | std::size_t operator()(const DomainKey& key) const { |
162 | return std::hash<const TensorDomain*>{}(key.td()) ^ |
163 | std::hash<const IterDomain*>{}(key.id()); |
164 | } |
165 | }; |
166 | |
167 | using DomainKeySet = std::unordered_set<DomainKey, DomainKeyHash>; |
168 | |
169 | template <typename Mapped> |
170 | using DomainKeyMap = std::unordered_map<DomainKey, Mapped, DomainKeyHash>; |
171 | |
172 | class ComputeAtRootDomainMap; |
173 | |
174 | //! A helper class to find all DomainKeys that are consumers of |
175 | //! reduction outputs. Such consumer IterDomains may not be mapped to |
176 | //! the producer reduction domain since the corresponding reduction |
177 | //! loop must be closed before any of the consumers can appear. |
178 | class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { |
179 | public: |
180 | UnmappableReductionDomains(); |
181 | ~UnmappableReductionDomains() override = default; |
182 | |
183 | //! Returns true when mapping consumer domains would cause a |
184 | //! reduction output domain to be mapped with a consumer domain of |
185 | //! the redution. It needs to be avoided as computing consumers of |
186 | //! reduction outputs within the corresponding reduction loop is not |
187 | //! possible. This routine is used to build root domain mappings. |
188 | bool isReductionOutputMapped( |
189 | const DomainKeySet& consumer_domains, |
190 | const ComputeAtRootDomainMap& root_map) const; |
191 | |
192 | std::string toString() const; |
193 | |
194 | private: |
195 | using IterVisitor::handle; |
196 | void handle(ReductionOp* op) override; |
197 | void handle(GroupedReductionOp* op) override; |
198 | void handle(WelfordOp* op) override; |
199 | void handle(MmaOp* op) override; |
200 | |
201 | void handleReductionOutput(TensorView* out_tv); |
202 | |
203 | private: |
204 | //! Map from Reduction output DomainKeys to consumer DomainKeys |
205 | DomainKeyMap<DomainKeySet> reduction_domains_; |
206 | //! Map from Reduction output DomainKeys to producer DomainKeys |
207 | DomainKeyMap<DomainKeySet> reduction_domain_inputs_; |
208 | }; |
209 | |
210 | //! Models root-domain mappings for computeAt |
211 | //! |
212 | //! Two iteration domains are mapped when computeAt of one iteration |
213 | //! domain is possible at another iteration domain. Consider a simple |
214 | //! example: |
215 | //! T2 [i0,i1] = T1[i2,i3] + T0[i4,i5] |
216 | //! This will create mappings between i0, i2 and i4. |
217 | //! |
218 | //! Note that with views, there can be multiple domains mapped with |
219 | //! the same domain. Thus, obtaining one-to-one maps can |
220 | //! fail. Currently, the only use of this class is getMappableDims, |
221 | //! which just grabs any domain that is mappable, which works no |
222 | //! matter view is used or not. |
223 | class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { |
224 | friend class ComputeAtRootDomainMapBuilder; |
225 | |
226 | public: |
227 | //! Builds a mapping table by analyzing the current |
228 | //! fusion. Overwrite a previous table if any. |
229 | //! |
230 | //! \param map_through_reduction If set |
231 | //! true, will disable UnmappableReductionDomains check. |
232 | //! This is only for re-using logic in detecting |
233 | //! normalization fusions, which deviates slightly from |
234 | //! intended use of this class. Should always be true |
235 | //! in compute_at use cases. |
236 | void build(bool map_through_reduction = false); |
237 | |
238 | //! Returns if key(td_a, id_a) and key(td_b, id_b) are mapped to eachother |
239 | //! (equivalent), or are the same key. |
240 | //! |
241 | //! \param td_a A TensorDomain |
242 | //! \param id_a An IterDomain in td_a |
243 | //! \param td_b Another TensorDomain |
244 | //! \param id_b An IterDomain in td_b |
245 | //! \returns Boolean representing if they are mapped |
246 | bool canMap( |
247 | const TensorDomain* td_a, |
248 | const IterDomain* id_a, |
249 | const TensorDomain* td_b, |
250 | const IterDomain* id_b) const; |
251 | |
252 | //! Make a TensorDomain an alias of another TensorDomain |
253 | //! |
254 | //! This is for the computeAt transformation, where TensorViews are |
255 | //! updated with new TensorDomains. Since they keep using the same |
256 | //! root doamins, the root mapping remains valid but needs to |
257 | //! reflect the use of new TensorDomains as aliases of the existing |
258 | //! ones. |
259 | //! |
260 | //! \param td An existing TensorDomain |
261 | //! \param td_alias An alias of td |
262 | void setAlias(const TensorDomain* td, const TensorDomain* td_alias); |
263 | |
264 | //! Return a map between TensorDomains |
265 | //! |
266 | //! Unlike the other map functions, two TensorDomains do not need to |
267 | //! be a producer-consumer pair. Since they may not be a |
268 | //! producer-consumer pair, this function requires proper root |
269 | //! domains, which may be root or rfactor domains. Also, no error |
270 | //! check is done as we do not assume producer-consumer |
271 | //! relationship. |
272 | //! |
273 | //! Note that an exception is thrown when a domain is found to be |
274 | //! mapped to multiple domains, which can happen with views. |
275 | //! |
276 | //! \param from_td A TensorDomain from which a map is created |
277 | //! \param from_root A root domain of from_td |
278 | //! \param to_td A TensorDomain to which a map is created |
279 | //! \param to_root A root domain of to_td |
280 | std::unordered_map<IterDomain*, IterDomain*> mapBestEffort( |
281 | const TensorDomain* from_td, |
282 | const std::vector<IterDomain*>& from_root, |
283 | const TensorDomain* to_td, |
284 | const std::vector<IterDomain*>& to_root) const; |
285 | |
286 | // Returns an unordered set of all iter domains in producer and consumer that |
287 | // can map to eachother |
288 | std::unordered_set<IterDomain*> getMappableDims( |
289 | const TensorDomain* producer, |
290 | const TensorDomain* consumer) const; |
291 | |
292 | std::string toString() const; |
293 | |
294 | private: |
295 | //! Returns if key_a and key(td_b, id_b) are mapped to eachother (equivalent), |
296 | //! or are the same key. |
297 | //! |
298 | //! \param key_a A DomainKey |
299 | //! \param td_b Another TensorDomain |
300 | //! \param id_b An IterDomain in td_b |
301 | //! \returns Boolean representing if they are mapped |
302 | bool canMap( |
303 | const DomainKey& key_a, |
304 | const TensorDomain* td_b, |
305 | const IterDomain* id_b) const; |
306 | |
307 | //! Returns if key_a and key_b are mapped to each other (equivalent), or are |
308 | //! the same key. Returns false if two keys are not known to be mapped. |
309 | bool canMap(const DomainKey& key_a, const DomainKey& key_b) const; |
310 | |
311 | //! Returns the set of (non-broadcast) DomainKeys that id in td is |
312 | //! broadcasted to. Can result in more than one "concrete" DomainKey. |
313 | std::vector<DomainKey> getConcretizedKeys( |
314 | const TensorDomain* td, |
315 | const IterDomain* id) const; |
316 | |
317 | //! Returns the set of (non-broadcast) iter domains that id in td is |
318 | //! broadcasted to. Can result in more than one "concrete" iter domain. |
319 | std::unordered_set<const IterDomain*>& getConcretizedDomains( |
320 | const TensorDomain* td, |
321 | const IterDomain* id); |
322 | |
323 | //! Return a map between root IterDomains of a producer-consumer |
324 | //! pair. |
325 | //! |
326 | //! \param producer A producer TensorDomain |
327 | //! \param consumer A consumer TensorDomain |
328 | //! \param root_dims_to_map Maps only from IterDomains in this set |
329 | //! \param producer_to_consumer Maps from producer to consumer if true |
330 | std::unordered_map<IterDomain*, IterDomain*> map( |
331 | const TensorDomain* producer, |
332 | const TensorDomain* consumer, |
333 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
334 | bool producer_to_consumer) const override; |
335 | |
336 | private: |
337 | //! Disjoint set of all mapped <TD, ID> keys to determine axes equivalency |
338 | DisjointSets<DomainKey, DomainKeyHash> eq_set_; |
339 | |
340 | //! All IterDomains in the mapping that are a broadcast ID |
341 | DomainKeyMap<std::unordered_set<const IterDomain*>> bcast_map_; |
342 | |
343 | //! Broadcast iter domain that does not match dimensions in its produer, |
344 | //! meaning it is a brand new domain in its TensorDomain. |
345 | DomainKeySet new_broadcast_domains_; |
346 | |
347 | //! Keep track of window axes so that the map function can ignore them. |
348 | std::unordered_set<IterDomain*> window_axes_; |
349 | }; |
350 | |
351 | //! Create a DisjointSets of root IterDomains by traversing the |
352 | //! current fusion entirely. IterDomains that can be mapped each |
353 | //! other with computeAt are grouped into the same subset in the |
354 | //! DisjointSets. |
355 | class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder |
356 | : private BackwardVisitor { |
357 | public: |
358 | explicit ComputeAtRootDomainMapBuilder( |
359 | ComputeAtRootDomainMap& root_map, |
360 | bool map_through_reduction = false); |
361 | |
362 | private: |
363 | //! Initialize the bcast map for fusion outputs |
364 | void initializeBcastMap(const TensorView* tv, const IterDomain* id); |
365 | |
366 | //! Set a pair of producer-consumer domain keys as mappable |
367 | void setMapped(const DomainKey& producer, const DomainKey& consumer); |
368 | |
369 | //! Records two domains are invalid to map |
370 | void setInvalid(const DomainKey& key1, const DomainKey& key2); |
371 | |
372 | //! Check if no pair of domains is invalid to map |
373 | bool isInvalid(const DomainKeySet& domains) const; |
374 | |
375 | //! Track a pair of producer-consumer domains as potentially mappable. Inserts |
376 | //! entries into pending_map_, but does not add anything into the root_map_ |
377 | //! (added when handle is called on a TensorView). Maybe mapped will, however, |
378 | //! immediately propagate broadcast iter domains. |
379 | void setMaybeMapped( |
380 | const TensorDomain* producer_td, |
381 | const IterDomain* producer_id, |
382 | const TensorDomain* consumer_td, |
383 | const IterDomain* consumer_id); |
384 | |
385 | void addToPendingList(const DomainKey& producer, const DomainKey& consumer); |
386 | |
387 | //! Map pointwise IterDomains from inputs of expressions to outputs. |
388 | //! Do not map reduction IterDomains in inputs. |
389 | void mapPointwiseOrReductionOp(Expr* e); |
390 | |
391 | using BackwardVisitor::handle; |
392 | |
393 | void handle(Expr* e) override; |
394 | |
395 | void handle(UnaryOp* uop) override { |
396 | mapPointwiseOrReductionOp(uop); |
397 | } |
398 | |
399 | void handle(BinaryOp* bop) override { |
400 | mapPointwiseOrReductionOp(bop); |
401 | } |
402 | |
403 | void handle(TernaryOp* top) override { |
404 | mapPointwiseOrReductionOp(top); |
405 | } |
406 | |
407 | void handle(RNGOp* top) override; |
408 | |
409 | void handle(ReductionOp* op) override { |
410 | mapPointwiseOrReductionOp(op); |
411 | } |
412 | |
413 | void handle(GroupedReductionOp* op) override { |
414 | mapPointwiseOrReductionOp(op); |
415 | } |
416 | |
417 | void handle(WelfordOp* wop) override { |
418 | mapPointwiseOrReductionOp(wop); |
419 | } |
420 | |
421 | void handle(LoadStoreOp* ldst) override { |
422 | mapPointwiseOrReductionOp(ldst); |
423 | } |
424 | |
425 | void handle(MmaOp* wop) override { |
426 | mapPointwiseOrReductionOp(wop); |
427 | } |
428 | |
429 | void handle(ShiftOp* op) override { |
430 | mapPointwiseOrReductionOp(op); |
431 | } |
432 | |
433 | void handle(ViewOp* op) override { |
434 | mapPointwiseOrReductionOp(op); |
435 | } |
436 | |
437 | void handle(ViewAsScalar* op) override; |
438 | |
439 | void handle(BroadcastOp* op) override; |
440 | |
441 | void handle(TransposeOp* op) override; |
442 | |
443 | void handle(ExpandOp* op) override { |
444 | mapPointwiseOrReductionOp(op); |
445 | } |
446 | |
447 | void handle(GatherOp* op) override; |
448 | |
449 | void handle(TensorView* tv) override; |
450 | |
451 | //! Maps all pending mappings. |
452 | //! This is called for each of TensorViews in a backward traversal, |
453 | //! recursively building mappings from the output tensors to the |
454 | //! input tensors. |
455 | void mapAllPendingMappings(const DomainKey& key); |
456 | |
457 | //! Maps all pending mappings for id of td. When id is a broadcast, |
458 | //! mapping is done separately for each concrete domain. |
459 | void mapAllPendingMappings(const TensorDomain* td, IterDomain* id); |
460 | |
461 | bool safeToMap(const DomainKeySet& domains); |
462 | |
463 | private: |
464 | ComputeAtRootDomainMap& root_map_; |
465 | //! Keep track of what we want to try and map |
466 | DomainKeyMap<DomainKeySet> pending_map_; |
467 | std::unordered_set<Expr*> visited_; |
468 | //! Helper class to find invalid mappings due to reductions |
469 | UnmappableReductionDomains incompatible_domains_; |
470 | //! Running vector of domain pairs that are invalid to map |
471 | std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_; |
472 | |
473 | //! Disable UnmappableReductions check, should |
474 | //! always be false for compute_at use cases |
475 | bool map_through_reduction_ = false; |
476 | }; |
477 | |
478 | //! Maps root domains of an entire fusion. Does not map broadcast |
479 | //! domains with non-broadcast domains. |
480 | class TORCH_CUDA_CU_API ExactRootDomainMap : public RootDomainMap { |
481 | public: |
482 | ExactRootDomainMap(Fusion* fusion); |
483 | |
484 | bool areMapped(const IterDomain* id_a, const IterDomain* id_b) const; |
485 | |
486 | std::string toString() const; |
487 | |
488 | protected: |
489 | std::unordered_map<IterDomain*, IterDomain*> map( |
490 | const TensorDomain* producer, |
491 | const TensorDomain* consumer, |
492 | const std::unordered_set<IterDomain*>& root_dims_to_map, |
493 | bool producer_to_consumer) const override; |
494 | |
495 | private: |
496 | DisjointSets<const IterDomain*> eq_sets_; |
497 | }; |
498 | |
499 | } // namespace cuda |
500 | } // namespace fuser |
501 | } // namespace jit |
502 | } // namespace torch |
503 | |