1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_base_nodes.h> |
6 | #include <utils.h> |
7 | |
8 | #include <deque> |
9 | #include <unordered_map> |
10 | #include <unordered_set> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | class IrBuilderPasskey; |
18 | class ExprPasskey; |
19 | class OptOutMutator; |
20 | |
21 | class Int; |
22 | class Bool; |
23 | class NamedScalar; |
24 | |
//! Passkey handed out by IrContainer when it registers names with statements.
//! Only IrContainer (the sole friend) is able to construct one.
class IrContainerPasskey {
 private:
  // Kept user-provided (`{}` rather than `= default`) so the type never
  // qualifies as an aggregate, which would let `IrContainerPasskey{}` sidestep
  // the private constructor under some language standards.
  explicit IrContainerPasskey() {}

  friend class IrContainer;
};
32 | |
33 | class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase { |
34 | public: |
35 | IrContainer(); |
36 | |
37 | IrContainer(const IrContainer& other); |
38 | IrContainer(IrContainer&& other) noexcept; |
39 | |
40 | IrContainer& operator=(const IrContainer& other); |
41 | IrContainer& operator=(IrContainer&& other) noexcept; |
42 | |
43 | virtual ~IrContainer(); |
44 | |
45 | bool inContainer(const Statement* stmt) const; |
46 | |
47 | void assertInContainer(const Statement* stmt, const std::string& msg) const { |
48 | TORCH_CHECK( |
49 | inContainer(stmt), msg, " it was not found in the active container." ); |
50 | } |
51 | |
52 | //! Return in insertion order |
53 | const std::deque<Val*> deterministic_vals() const noexcept { |
54 | std::deque<Val*> vals_deque; |
55 | std::transform( |
56 | vals_up_.begin(), |
57 | vals_up_.end(), |
58 | std::back_inserter(vals_deque), |
59 | [](const std::unique_ptr<Val>& val_up) { return val_up.get(); }); |
60 | return vals_deque; |
61 | } |
62 | |
63 | //! Register the Statement with this container |
64 | virtual void registerStmt(IrBuilderPasskey, Statement* stmt); |
65 | |
66 | //! Register the Val with this container |
67 | virtual void registerVal(IrBuilderPasskey, Val* val); |
68 | |
69 | //! Register expr with this container. |
70 | virtual void registerExpr(IrBuilderPasskey, Expr* expr); |
71 | |
72 | //! Allow expr's to register themselves with a container, this is only used |
73 | //! for broadcastOp so it can register itself in its constructor so root maps |
74 | //! can be built. |
75 | virtual void registerExpr(ExprPasskey, Expr* expr); |
76 | |
77 | //! Return the set of Exprs registered with this fusion. Warning: This will |
78 | //! return exprs outside inputs/outputs, so can be unsafe for use with |
79 | //! segmented fusions. |
80 | const std::unordered_set<Expr*>& unordered_exprs() const noexcept { |
81 | return exprs_; |
82 | } |
83 | |
84 | //! Return the set of Vals registered with this fusion |
85 | const std::unordered_set<Val*>& vals() const noexcept { |
86 | return vals_; |
87 | } |
88 | |
89 | // Shortcuts for frequently used vals |
90 | Int* zeroVal(); |
91 | Int* oneVal(); |
92 | Bool* falseVal(); |
93 | Bool* trueVal(); |
94 | NamedScalar* magicZeroVal(); |
95 | |
96 | protected: |
97 | static IrCloner copy(const IrContainer* from, IrContainer* to); |
98 | |
99 | friend void swap(IrContainer& a, IrContainer& b) noexcept; |
100 | |
101 | // Let mutator remove Exprs. |
102 | friend OptOutMutator; |
103 | |
104 | virtual void removeExpr(Expr* expr); |
105 | |
106 | //! Completely remove val from the fusion, break all dependencies associated |
107 | //! with it |
108 | virtual void removeVal(Val* val); |
109 | |
110 | //! Register the Val with this container |
111 | virtual void registerVal(Val* val); |
112 | |
113 | //! Register expr with this container. |
114 | virtual void registerExpr(Expr* expr); |
115 | |
116 | StmtNameType getValName(ValType vtype) { |
117 | if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { |
118 | val_type_name_map_[vtype] = 0; |
119 | } |
120 | return val_type_name_map_[vtype]++; |
121 | } |
122 | |
123 | StmtNameType getExprName() { |
124 | return expr_name_counter_++; |
125 | } |
126 | |
127 | void clear() noexcept; |
128 | |
129 | // Deque of unique pointer is the memory owning data structure |
130 | std::deque<std::unique_ptr<Val>> vals_up_; |
131 | |
132 | // A convenient set to return when we just need an unordered set to do |
133 | // something like check if a Val is in this container |
134 | std::unordered_set<Val*> vals_; |
135 | |
136 | // Deque of unique pointer is the memory owning data structure |
137 | std::deque<std::unique_ptr<Expr>> exprs_up_; |
138 | |
139 | // A convenient set to return when we just need an unordered set to do |
140 | // something like check if an Expr is in this container |
141 | std::unordered_set<Expr*> exprs_; |
142 | |
143 | // Used to implement a generic "inContainer" that can be passed an invalid |
144 | // pointer. Specifically a pointer to a Statement owned by another container |
145 | // that has been freed. We can't check normally with the unordered_sets we |
146 | // already have because it would require a const_cast from a constant |
147 | // expr/val, or a dynamic cast from a Statement. |
148 | std::unordered_set<void*> raw_ptrs_; |
149 | |
150 | // Values names counters |
151 | std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_; |
152 | |
153 | // Expression names counter |
154 | StmtNameType expr_name_counter_ = 0; |
155 | |
156 | // Manually store some persistent, frequently used nodes. It's very |
157 | // challenging to do this anything but manually as detecting when a container |
158 | // may or may not have one of these vals is tricky. Specifically because if |
159 | // the container doesn't own it, it's hard to understand from the outside if |
160 | // the node may have been removed then re-registered. It could also be tricky |
161 | // to know when we're using a different container as in FusionCopy_test |
162 | // demonstrates deleting then creating containers can result in the same |
163 | // pointer for the container. |
164 | std::unique_ptr<Bool> true_val_; |
165 | std::unique_ptr<Bool> false_val_; |
166 | std::unique_ptr<Int> one_val_; |
167 | std::unique_ptr<Int> zero_val_; |
168 | std::unique_ptr<NamedScalar> magic_zero_val_; |
169 | }; |
170 | |
171 | } // namespace cuda |
172 | } // namespace fuser |
173 | } // namespace jit |
174 | } // namespace torch |
175 | |