1 | #pragma once |
2 | |
3 | #include <ir_all_nodes.h> |
4 | #include <root_domain_map.h> |
5 | |
6 | #include <c10/macros/Export.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | //! Traverse and collect all concretized broadcast domains. |
14 | //! |
15 | //! The traversal first initializes the origin map with broadcast |
16 | //! domains in input tensors. Then, a new entry is added to the origin |
17 | //! map when a broadcast op is encountered during a forward traversal |
18 | //! of the given fusion. For non-broadcast ops, mappings are just |
19 | //! propagated forward using PairwiseRootDomainMap. |
20 | //! |
21 | //! When the mapped consumer domain is not broadcast, it means the |
22 | //! producer broadcast domain is concretized, and its origin broadcast |
23 | //! domains are marked as concretized. |
24 | class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor { |
25 | public: |
26 | ConcretizedBroadcastDomains() = delete; |
27 | ConcretizedBroadcastDomains(Fusion* fusion); |
28 | |
29 | //! Is a domain concretized? |
30 | bool isConcretized(IterDomain* id) const; |
31 | |
32 | //! Is a domain concretized to a unique concrete domain? |
33 | bool isUniquelyConcretized(IterDomain* id) const; |
34 | |
35 | //! Is a domain concretized to multiple concrete domains? |
36 | bool maybeNonUniquelyConcretized(IterDomain* id) const; |
37 | |
38 | private: |
39 | using IterVisitor::handle; |
40 | |
41 | void handle(BroadcastOp* bop) final; |
42 | |
43 | void handle(Expr* expr) final; |
44 | |
45 | void markAsConcretized( |
46 | IterDomain* broadcast_root_domain, |
47 | IterDomain* concrete_root_domain); |
48 | |
49 | bool insertRootDomainToConcreteDomainSet( |
50 | IterDomain* new_root_id, |
51 | std::unordered_set<IterDomain*>& id_set); |
52 | |
53 | private: |
54 | //! Maps each root broadcast domain to its original root broadcast |
55 | //! domains. Their can be multiple original domains due to, e.g., |
56 | //! binary ops with broadcast domains in both inputs. |
57 | std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> |
58 | broadcast_origin_map_; |
59 | //! Map all broadcast domains to concrete root domains |
60 | std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> |
61 | broadcast_to_concrete_map_; |
62 | |
63 | std::unique_ptr<ExactRootDomainMap> exact_map_; |
64 | }; |
65 | |
66 | } // namespace cuda |
67 | } // namespace fuser |
68 | } // namespace jit |
69 | } // namespace torch |
70 | |