1#pragma once
2
3#include <ir_all_nodes.h>
4#include <root_domain_map.h>
5
6#include <c10/macros/Export.h>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13//! Traverse and collect all concretized broadcast domains.
14//!
15//! The traversal first initializes the origin map with broadcast
16//! domains in input tensors. Then, a new entry is added to the origin
17//! map when a broadcast op is encountered during a forward traversal
18//! of the given fusion. For non-broadcast ops, mappings are just
19//! propagated forward using PairwiseRootDomainMap.
20//!
21//! When the mapped consumer domain is not broadcast, it means the
22//! producer broadcast domain is concretized, and its origin broadcast
23//! domains are marked as concretized.
24class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
25 public:
26 ConcretizedBroadcastDomains() = delete;
27 ConcretizedBroadcastDomains(Fusion* fusion);
28
29 //! Is a domain concretized?
30 bool isConcretized(IterDomain* id) const;
31
32 //! Is a domain concretized to a unique concrete domain?
33 bool isUniquelyConcretized(IterDomain* id) const;
34
35 //! Is a domain concretized to multiple concrete domains?
36 bool maybeNonUniquelyConcretized(IterDomain* id) const;
37
38 private:
39 using IterVisitor::handle;
40
41 void handle(BroadcastOp* bop) final;
42
43 void handle(Expr* expr) final;
44
45 void markAsConcretized(
46 IterDomain* broadcast_root_domain,
47 IterDomain* concrete_root_domain);
48
49 bool insertRootDomainToConcreteDomainSet(
50 IterDomain* new_root_id,
51 std::unordered_set<IterDomain*>& id_set);
52
53 private:
54 //! Maps each root broadcast domain to its original root broadcast
55 //! domains. Their can be multiple original domains due to, e.g.,
56 //! binary ops with broadcast domains in both inputs.
57 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
58 broadcast_origin_map_;
59 //! Map all broadcast domains to concrete root domains
60 std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
61 broadcast_to_concrete_map_;
62
63 std::unique_ptr<ExactRootDomainMap> exact_map_;
64};
65
66} // namespace cuda
67} // namespace fuser
68} // namespace jit
69} // namespace torch
70