1#pragma once
2
3#include <disjoint_set.h>
4#include <ir_all_nodes.h>
5#include <iter_visitor.h>
6#include <utils.h>
7
8#include <c10/macros/Export.h>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15//! Generic interface for mapping root domains of a producer-consumer pair.
16class TORCH_CUDA_CU_API RootDomainMap : public PolymorphicBase {
17 public:
18 //! Return a map from a producer TensorDomain to a consumer
19 //! TensorDomain
20 //!
21 //! \param producer A producer TensorDomain
22 //! \param consumer A consumer TensorDomain
23 //! \param root_dims_to_map Maps only producer root domains in this set
24 std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
25 const TensorDomain* producer,
26 const TensorDomain* consumer,
27 const std::unordered_set<IterDomain*>& root_dims_to_map) const;
28
29 //! Return a map from a producer TensorDomain to a consumer
30 //! TensorDomain
31 //!
32 //! \param producer A producer TensorDomain
33 //! \param consumer A consumer TensorDomain
34 std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
35 const TensorDomain* producer,
36 const TensorDomain* consumer) const;
37
38 //! Return a map from a consumer TensorDomain to a producer
39 //! TensorDomain
40 //!
41 //! \param consumer A consumer TensorDomain
42 //! \param producer A producer TensorDomain
43 //! \param root_dims_to_map Maps only consumer root domains in this set
44 std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
45 const TensorDomain* consumer,
46 const TensorDomain* producer,
47 const std::unordered_set<IterDomain*>& root_dims_to_map) const;
48
49 //! Return a map from a consumer TensorDomain to a producer
50 //! TensorDomain
51 //!
52 //! \param consumer A consumer TensorDomain
53 //! \param producer A producer TensorDomain
54 std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
55 const TensorDomain* consumer,
56 const TensorDomain* producer) const;
57
58 protected:
59 //! Return a map between root IterDomains of a producer-consumer
60 //! pair.
61 //!
62 //! \param producer A producer TensorDomain
63 //! \param consumer A consumer TensorDomain
64 //! \param root_dims_to_map Maps only from IterDomains in this set
65 //! \param producer_to_consumer Maps from producer to consumer if true
66 virtual std::unordered_map<IterDomain*, IterDomain*> map(
67 const TensorDomain* producer,
68 const TensorDomain* consumer,
69 const std::unordered_set<IterDomain*>& root_dims_to_map,
70 bool producer_to_consumer) const = 0;
71};
72
73//! Maps root domains of a producer-consumer pair. This class only
74//! looks at the given pair of TensorViews and does not take into
75//! consideration the constraints of the computeAt transformation,
76//! i.e., unable to compute the same tensors multiple times. This
77//! should not be used for transformations implementing computeAt, but
78//! should be valid otherwise.
79class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
80 public:
81 //! \param producer The producer tensor of a producer-consumer pair.
82 //! \param consumer The consumer tensor of a producer-consumer pair.
83 explicit PairwiseRootDomainMap(
84 const TensorView* producer,
85 const TensorView* consumer,
86 bool is_exact = false);
87
88 const TensorView* producer() const {
89 return producer_tv_;
90 }
91
92 const TensorView* consumer() const {
93 return consumer_tv_;
94 }
95
96 std::string toString() const;
97
98 protected:
99 std::unordered_map<IterDomain*, IterDomain*> map(
100 const TensorDomain* producer,
101 const TensorDomain* consumer,
102 const std::unordered_set<IterDomain*>& root_dims_to_map,
103 bool producer_to_consumer) const override;
104
105 std::unordered_map<IterDomain*, IterDomain*> mapTranspose(
106 const TensorDomain* producer,
107 const TensorDomain* consumer,
108 const std::unordered_set<IterDomain*>& root_dims_to_map,
109 bool producer_to_consumer) const;
110
111 private:
112 const TensorView* producer_tv_ = nullptr;
113 const TensorView* consumer_tv_ = nullptr;
114 //! If true, does not map broadcast IDs with non-broadcast IDs
115 const bool is_exact_ = false;
116};
117
118//! Represents an iteration domain of a TensorDomain. Only used for
119//! root domain mapping.
120//!
121//! Note that an IterDomain object may be reused
122//! across multiple TensorDomains, but an IterDomain in a
123//! TensorDomain may not be necessarily mappable to the same
124//! IterDomain used in a different TensorDomain. Thus, for the purpose
125//! of root domain mapping, an iteration domain needs to be identified
126//! with an IterDomain and its TensorDomain.
127class DomainKey {
128 public:
129 DomainKey() = default;
130 DomainKey(
131 const TensorDomain* td,
132 const IterDomain* id,
133 const IterDomain* concrete_id = nullptr)
134 : td_(td), id_(id), concrete_id_(concrete_id) {}
135 const TensorDomain* td() const {
136 return td_;
137 }
138 const IterDomain* id() const {
139 return id_;
140 }
141 const IterDomain* concreteId() const {
142 return concrete_id_;
143 }
144 bool operator==(const DomainKey& other) const {
145 return td() == other.td() && id() == other.id() &&
146 concreteId() == other.concreteId();
147 }
148 bool operator!=(const DomainKey& other) const {
149 return !(*this == other);
150 }
151
152 std::string toString() const;
153
154 private:
155 const TensorDomain* td_ = nullptr;
156 const IterDomain* id_ = nullptr;
157 const IterDomain* concrete_id_ = nullptr;
158};
159
160struct DomainKeyHash {
161 std::size_t operator()(const DomainKey& key) const {
162 return std::hash<const TensorDomain*>{}(key.td()) ^
163 std::hash<const IterDomain*>{}(key.id());
164 }
165};
166
167using DomainKeySet = std::unordered_set<DomainKey, DomainKeyHash>;
168
169template <typename Mapped>
170using DomainKeyMap = std::unordered_map<DomainKey, Mapped, DomainKeyHash>;
171
172class ComputeAtRootDomainMap;
173
174//! A helper class to find all DomainKeys that are consumers of
175//! reduction outputs. Such consumer IterDomains may not be mapped to
176//! the producer reduction domain since the corresponding reduction
177//! loop must be closed before any of the consumers can appear.
178class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
179 public:
180 UnmappableReductionDomains();
181 ~UnmappableReductionDomains() override = default;
182
183 //! Returns true when mapping consumer domains would cause a
184 //! reduction output domain to be mapped with a consumer domain of
185 //! the redution. It needs to be avoided as computing consumers of
186 //! reduction outputs within the corresponding reduction loop is not
187 //! possible. This routine is used to build root domain mappings.
188 bool isReductionOutputMapped(
189 const DomainKeySet& consumer_domains,
190 const ComputeAtRootDomainMap& root_map) const;
191
192 std::string toString() const;
193
194 private:
195 using IterVisitor::handle;
196 void handle(ReductionOp* op) override;
197 void handle(GroupedReductionOp* op) override;
198 void handle(WelfordOp* op) override;
199 void handle(MmaOp* op) override;
200
201 void handleReductionOutput(TensorView* out_tv);
202
203 private:
204 //! Map from Reduction output DomainKeys to consumer DomainKeys
205 DomainKeyMap<DomainKeySet> reduction_domains_;
206 //! Map from Reduction output DomainKeys to producer DomainKeys
207 DomainKeyMap<DomainKeySet> reduction_domain_inputs_;
208};
209
210//! Models root-domain mappings for computeAt
211//!
212//! Two iteration domains are mapped when computeAt of one iteration
213//! domain is possible at another iteration domain. Consider a simple
214//! example:
215//! T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
216//! This will create mappings between i0, i2 and i4.
217//!
218//! Note that with views, there can be multiple domains mapped with
219//! the same domain. Thus, obtaining one-to-one maps can
220//! fail. Currently, the only use of this class is getMappableDims,
221//! which just grabs any domain that is mappable, which works no
222//! matter view is used or not.
223class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap {
224 friend class ComputeAtRootDomainMapBuilder;
225
226 public:
227 //! Builds a mapping table by analyzing the current
228 //! fusion. Overwrite a previous table if any.
229 //!
230 //! \param map_through_reduction If set
231 //! true, will disable UnmappableReductionDomains check.
232 //! This is only for re-using logic in detecting
233 //! normalization fusions, which deviates slightly from
234 //! intended use of this class. Should always be true
235 //! in compute_at use cases.
236 void build(bool map_through_reduction = false);
237
238 //! Returns if key(td_a, id_a) and key(td_b, id_b) are mapped to eachother
239 //! (equivalent), or are the same key.
240 //!
241 //! \param td_a A TensorDomain
242 //! \param id_a An IterDomain in td_a
243 //! \param td_b Another TensorDomain
244 //! \param id_b An IterDomain in td_b
245 //! \returns Boolean representing if they are mapped
246 bool canMap(
247 const TensorDomain* td_a,
248 const IterDomain* id_a,
249 const TensorDomain* td_b,
250 const IterDomain* id_b) const;
251
252 //! Make a TensorDomain an alias of another TensorDomain
253 //!
254 //! This is for the computeAt transformation, where TensorViews are
255 //! updated with new TensorDomains. Since they keep using the same
256 //! root doamins, the root mapping remains valid but needs to
257 //! reflect the use of new TensorDomains as aliases of the existing
258 //! ones.
259 //!
260 //! \param td An existing TensorDomain
261 //! \param td_alias An alias of td
262 void setAlias(const TensorDomain* td, const TensorDomain* td_alias);
263
264 //! Return a map between TensorDomains
265 //!
266 //! Unlike the other map functions, two TensorDomains do not need to
267 //! be a producer-consumer pair. Since they may not be a
268 //! producer-consumer pair, this function requires proper root
269 //! domains, which may be root or rfactor domains. Also, no error
270 //! check is done as we do not assume producer-consumer
271 //! relationship.
272 //!
273 //! Note that an exception is thrown when a domain is found to be
274 //! mapped to multiple domains, which can happen with views.
275 //!
276 //! \param from_td A TensorDomain from which a map is created
277 //! \param from_root A root domain of from_td
278 //! \param to_td A TensorDomain to which a map is created
279 //! \param to_root A root domain of to_td
280 std::unordered_map<IterDomain*, IterDomain*> mapBestEffort(
281 const TensorDomain* from_td,
282 const std::vector<IterDomain*>& from_root,
283 const TensorDomain* to_td,
284 const std::vector<IterDomain*>& to_root) const;
285
286 // Returns an unordered set of all iter domains in producer and consumer that
287 // can map to eachother
288 std::unordered_set<IterDomain*> getMappableDims(
289 const TensorDomain* producer,
290 const TensorDomain* consumer) const;
291
292 std::string toString() const;
293
294 private:
295 //! Returns if key_a and key(td_b, id_b) are mapped to eachother (equivalent),
296 //! or are the same key.
297 //!
298 //! \param key_a A DomainKey
299 //! \param td_b Another TensorDomain
300 //! \param id_b An IterDomain in td_b
301 //! \returns Boolean representing if they are mapped
302 bool canMap(
303 const DomainKey& key_a,
304 const TensorDomain* td_b,
305 const IterDomain* id_b) const;
306
307 //! Returns if key_a and key_b are mapped to each other (equivalent), or are
308 //! the same key. Returns false if two keys are not known to be mapped.
309 bool canMap(const DomainKey& key_a, const DomainKey& key_b) const;
310
311 //! Returns the set of (non-broadcast) DomainKeys that id in td is
312 //! broadcasted to. Can result in more than one "concrete" DomainKey.
313 std::vector<DomainKey> getConcretizedKeys(
314 const TensorDomain* td,
315 const IterDomain* id) const;
316
317 //! Returns the set of (non-broadcast) iter domains that id in td is
318 //! broadcasted to. Can result in more than one "concrete" iter domain.
319 std::unordered_set<const IterDomain*>& getConcretizedDomains(
320 const TensorDomain* td,
321 const IterDomain* id);
322
323 //! Return a map between root IterDomains of a producer-consumer
324 //! pair.
325 //!
326 //! \param producer A producer TensorDomain
327 //! \param consumer A consumer TensorDomain
328 //! \param root_dims_to_map Maps only from IterDomains in this set
329 //! \param producer_to_consumer Maps from producer to consumer if true
330 std::unordered_map<IterDomain*, IterDomain*> map(
331 const TensorDomain* producer,
332 const TensorDomain* consumer,
333 const std::unordered_set<IterDomain*>& root_dims_to_map,
334 bool producer_to_consumer) const override;
335
336 private:
337 //! Disjoint set of all mapped <TD, ID> keys to determine axes equivalency
338 DisjointSets<DomainKey, DomainKeyHash> eq_set_;
339
340 //! All IterDomains in the mapping that are a broadcast ID
341 DomainKeyMap<std::unordered_set<const IterDomain*>> bcast_map_;
342
343 //! Broadcast iter domain that does not match dimensions in its produer,
344 //! meaning it is a brand new domain in its TensorDomain.
345 DomainKeySet new_broadcast_domains_;
346
347 //! Keep track of window axes so that the map function can ignore them.
348 std::unordered_set<IterDomain*> window_axes_;
349};
350
351//! Create a DisjointSets of root IterDomains by traversing the
352//! current fusion entirely. IterDomains that can be mapped each
353//! other with computeAt are grouped into the same subset in the
354//! DisjointSets.
355class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
356 : private BackwardVisitor {
357 public:
358 explicit ComputeAtRootDomainMapBuilder(
359 ComputeAtRootDomainMap& root_map,
360 bool map_through_reduction = false);
361
362 private:
363 //! Initialize the bcast map for fusion outputs
364 void initializeBcastMap(const TensorView* tv, const IterDomain* id);
365
366 //! Set a pair of producer-consumer domain keys as mappable
367 void setMapped(const DomainKey& producer, const DomainKey& consumer);
368
369 //! Records two domains are invalid to map
370 void setInvalid(const DomainKey& key1, const DomainKey& key2);
371
372 //! Check if no pair of domains is invalid to map
373 bool isInvalid(const DomainKeySet& domains) const;
374
375 //! Track a pair of producer-consumer domains as potentially mappable. Inserts
376 //! entries into pending_map_, but does not add anything into the root_map_
377 //! (added when handle is called on a TensorView). Maybe mapped will, however,
378 //! immediately propagate broadcast iter domains.
379 void setMaybeMapped(
380 const TensorDomain* producer_td,
381 const IterDomain* producer_id,
382 const TensorDomain* consumer_td,
383 const IterDomain* consumer_id);
384
385 void addToPendingList(const DomainKey& producer, const DomainKey& consumer);
386
387 //! Map pointwise IterDomains from inputs of expressions to outputs.
388 //! Do not map reduction IterDomains in inputs.
389 void mapPointwiseOrReductionOp(Expr* e);
390
391 using BackwardVisitor::handle;
392
393 void handle(Expr* e) override;
394
395 void handle(UnaryOp* uop) override {
396 mapPointwiseOrReductionOp(uop);
397 }
398
399 void handle(BinaryOp* bop) override {
400 mapPointwiseOrReductionOp(bop);
401 }
402
403 void handle(TernaryOp* top) override {
404 mapPointwiseOrReductionOp(top);
405 }
406
407 void handle(RNGOp* top) override;
408
409 void handle(ReductionOp* op) override {
410 mapPointwiseOrReductionOp(op);
411 }
412
413 void handle(GroupedReductionOp* op) override {
414 mapPointwiseOrReductionOp(op);
415 }
416
417 void handle(WelfordOp* wop) override {
418 mapPointwiseOrReductionOp(wop);
419 }
420
421 void handle(LoadStoreOp* ldst) override {
422 mapPointwiseOrReductionOp(ldst);
423 }
424
425 void handle(MmaOp* wop) override {
426 mapPointwiseOrReductionOp(wop);
427 }
428
429 void handle(ShiftOp* op) override {
430 mapPointwiseOrReductionOp(op);
431 }
432
433 void handle(ViewOp* op) override {
434 mapPointwiseOrReductionOp(op);
435 }
436
437 void handle(ViewAsScalar* op) override;
438
439 void handle(BroadcastOp* op) override;
440
441 void handle(TransposeOp* op) override;
442
443 void handle(ExpandOp* op) override {
444 mapPointwiseOrReductionOp(op);
445 }
446
447 void handle(GatherOp* op) override;
448
449 void handle(TensorView* tv) override;
450
451 //! Maps all pending mappings.
452 //! This is called for each of TensorViews in a backward traversal,
453 //! recursively building mappings from the output tensors to the
454 //! input tensors.
455 void mapAllPendingMappings(const DomainKey& key);
456
457 //! Maps all pending mappings for id of td. When id is a broadcast,
458 //! mapping is done separately for each concrete domain.
459 void mapAllPendingMappings(const TensorDomain* td, IterDomain* id);
460
461 bool safeToMap(const DomainKeySet& domains);
462
463 private:
464 ComputeAtRootDomainMap& root_map_;
465 //! Keep track of what we want to try and map
466 DomainKeyMap<DomainKeySet> pending_map_;
467 std::unordered_set<Expr*> visited_;
468 //! Helper class to find invalid mappings due to reductions
469 UnmappableReductionDomains incompatible_domains_;
470 //! Running vector of domain pairs that are invalid to map
471 std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_;
472
473 //! Disable UnmappableReductions check, should
474 //! always be false for compute_at use cases
475 bool map_through_reduction_ = false;
476};
477
478//! Maps root domains of an entire fusion. Does not map broadcast
479//! domains with non-broadcast domains.
480class TORCH_CUDA_CU_API ExactRootDomainMap : public RootDomainMap {
481 public:
482 ExactRootDomainMap(Fusion* fusion);
483
484 bool areMapped(const IterDomain* id_a, const IterDomain* id_b) const;
485
486 std::string toString() const;
487
488 protected:
489 std::unordered_map<IterDomain*, IterDomain*> map(
490 const TensorDomain* producer,
491 const TensorDomain* consumer,
492 const std::unordered_set<IterDomain*>& root_dims_to_map,
493 bool producer_to_consumer) const override;
494
495 private:
496 DisjointSets<const IterDomain*> eq_sets_;
497};
498
499} // namespace cuda
500} // namespace fuser
501} // namespace jit
502} // namespace torch
503