1 | #include <instrumentation.h> |
2 | #include <ir_builder.h> |
3 | #include <ir_cloner.h> |
4 | #include <ir_container.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | namespace fuser { |
9 | namespace cuda { |
10 | |
11 | void swap(IrContainer& a, IrContainer& b) noexcept { |
12 | FUSER_PERF_SCOPE("Fusion swap" ); |
13 | |
14 | using std::swap; |
15 | |
16 | // Swap the content |
17 | swap(a.vals_up_, b.vals_up_); |
18 | swap(a.vals_, b.vals_); |
19 | |
20 | swap(a.exprs_up_, b.exprs_up_); |
21 | swap(a.exprs_, b.exprs_); |
22 | |
23 | swap(a.raw_ptrs_, b.raw_ptrs_); |
24 | |
25 | swap(a.val_type_name_map_, b.val_type_name_map_); |
26 | swap(a.expr_name_counter_, b.expr_name_counter_); |
27 | |
28 | // Fixup the Statement::fusion_ links for a |
29 | for (auto val : a.vals_) { |
30 | val->ir_container_ = &a; |
31 | } |
32 | for (auto expr : a.exprs_) { |
33 | expr->ir_container_ = &a; |
34 | } |
35 | |
36 | // Fixup the Statement::fusion_ links for b |
37 | for (auto val : b.vals_) { |
38 | val->ir_container_ = &a; |
39 | } |
40 | for (auto expr : b.exprs_) { |
41 | expr->ir_container_ = &a; |
42 | } |
43 | } |
44 | |
45 | IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { |
46 | to->clear(); |
47 | IrCloner ir_cloner(to); |
48 | |
49 | for (auto val : from->vals_) { |
50 | to->vals_.insert(ir_cloner.clone(val)); |
51 | } |
52 | |
53 | for (auto expr : from->exprs_) { |
54 | to->exprs_.insert(ir_cloner.clone(expr)); |
55 | } |
56 | |
57 | to->val_type_name_map_ = from->val_type_name_map_; |
58 | to->expr_name_counter_ = from->expr_name_counter_; |
59 | |
60 | return ir_cloner; |
61 | } |
62 | |
//! Default constructor; member initialization is left entirely to the
//! compiler-generated implementation.
IrContainer::IrContainer() = default;
64 | |
//! Copy constructor: fills this container with clones of the statements in
//! `other` by delegating to IrContainer::copy. The IrCloner returned by
//! copy() is discarded.
IrContainer::IrContainer(const IrContainer& other) {
  FUSER_PERF_SCOPE("IrContainer copy" );
  IrContainer::copy(&other, this);
}
69 | |
//! Move constructor: exchanges the contents of this newly constructed
//! container with those of `other` through swap(IrContainer&, IrContainer&).
IrContainer::IrContainer(IrContainer&& other) noexcept {
  FUSER_PERF_SCOPE("IrContainer move" );
  swap(*this, other);
}
74 | |
75 | IrContainer& IrContainer::operator=(const IrContainer& other) { |
76 | FUSER_PERF_SCOPE("IrContainer copy assign" ); |
77 | IrContainer copy(other); |
78 | clear(); |
79 | swap(*this, copy); |
80 | return *this; |
81 | } |
82 | |
83 | IrContainer& IrContainer::operator=(IrContainer&& other) noexcept { |
84 | FUSER_PERF_SCOPE("IrContainer move assign" ); |
85 | clear(); |
86 | swap(*this, other); |
87 | return *this; |
88 | } |
89 | |
//! Destructor: calls clear() to empty the Val/Expr collections and reset the
//! name counters before the members themselves are destroyed.
IrContainer::~IrContainer() {
  clear();
}
93 | |
94 | //! Register the Statement with this container |
95 | void IrContainer::registerStmt(IrBuilderPasskey, Statement* stmt) { |
96 | if (stmt->isVal()) { |
97 | registerVal(stmt->asVal()); |
98 | } else { |
99 | registerExpr(stmt->asExpr()); |
100 | } |
101 | } |
102 | |
//! Register the Val with this container.
//! The IrBuilderPasskey parameter is unnamed and not used in the body; the
//! call simply forwards to registerVal(Val*).
void IrContainer::registerVal(IrBuilderPasskey, Val* val) {
  registerVal(val);
}
107 | |
//! Register expr with this container.
//! The IrBuilderPasskey parameter is unnamed and not used in the body; the
//! call simply forwards to registerExpr(Expr*).
void IrContainer::registerExpr(IrBuilderPasskey, Expr* expr) {
  registerExpr(expr);
}
112 | |
//! Register expr with this container.
//! Same as the IrBuilderPasskey overload but taking an ExprPasskey; the
//! passkey parameter is unnamed and not used in the body, and the call
//! forwards to registerExpr(Expr*).
void IrContainer::registerExpr(ExprPasskey, Expr* expr) {
  registerExpr(expr);
}
116 | |
117 | void IrContainer::removeExpr(Expr* expr) { |
118 | TORCH_INTERNAL_ASSERT( |
119 | exprs_.find(expr) != exprs_.end(), |
120 | "Wanted to remove an expression but it doesn't exist in this container." ); |
121 | auto expr_in_deque = std::find_if( |
122 | exprs_up_.begin(), |
123 | exprs_up_.end(), |
124 | [expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; }); |
125 | |
126 | TORCH_INTERNAL_ASSERT( |
127 | expr_in_deque != exprs_up_.end(), |
128 | "Wanted to remove an expression but its unique ptr is missing." ); |
129 | |
130 | exprs_.erase(expr); |
131 | exprs_up_.erase(expr_in_deque); |
132 | raw_ptrs_.erase((void*)expr); |
133 | } |
134 | |
135 | //! Completely remove val from the fusion, break all dependencies associated |
136 | //! with it |
137 | void IrContainer::removeVal(Val* val) { |
138 | // Don't remove shortcuts |
139 | if (val == true_val_.get() || val == false_val_.get() || |
140 | val == one_val_.get() || val == zero_val_.get() || |
141 | val == magic_zero_val_.get()) { |
142 | return; |
143 | } |
144 | |
145 | TORCH_INTERNAL_ASSERT( |
146 | vals_.find(val) != vals_.end(), |
147 | "Wanted to remove a value but it doesn't exist in this container." ); |
148 | auto val_in_deque = std::find_if( |
149 | vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr<Val>& val_up) { |
150 | return val_up.get() == val; |
151 | }); |
152 | |
153 | TORCH_INTERNAL_ASSERT( |
154 | val_in_deque != vals_up_.end(), |
155 | "Wanted to remove a value but its unique ptr is missing." ); |
156 | |
157 | vals_.erase(val); |
158 | vals_up_.erase(val_in_deque); |
159 | raw_ptrs_.erase((void*)val); |
160 | } |
161 | |
162 | //! Register the Val with this container |
163 | void IrContainer::registerVal(Val* val) { |
164 | if (inContainer(val)) { |
165 | return; |
166 | } |
167 | |
168 | vals_up_.emplace_back(std::unique_ptr<Val>(val)); |
169 | vals_.emplace(vals_up_.back().get()); |
170 | val->setName(IrContainerPasskey(), getValName(vals_up_.back()->vtype())); |
171 | raw_ptrs_.emplace((void*)vals_up_.back().get()); |
172 | } |
173 | |
174 | //! Register expr with this container. |
175 | void IrContainer::registerExpr(Expr* expr) { |
176 | if (inContainer(expr)) { |
177 | return; |
178 | } |
179 | exprs_up_.emplace_back(std::unique_ptr<Expr>(expr)); |
180 | exprs_.emplace(exprs_up_.back().get()); |
181 | expr->setName(IrContainerPasskey(), getExprName()); |
182 | raw_ptrs_.emplace((void*)exprs_up_.back().get()); |
183 | } |
184 | |
//! Empty the container.
//!
//! Clears the Val and Expr lookup sets and their owning collections, clears
//! raw_ptrs_, and resets the name counters.
//!
//! NOTE(review): the shortcut members (zero_val_, one_val_, true_val_,
//! false_val_, magic_zero_val_) are not touched here — confirm whether they
//! are expected to survive a clear().
void IrContainer::clear() noexcept {
  FUSER_PERF_SCOPE("IrContainer clear" );
  // Each lookup set is emptied before the collection that owns its elements.
  vals_.clear();
  vals_up_.clear();
  exprs_.clear();
  exprs_up_.clear();
  raw_ptrs_.clear();

  // Restart naming from scratch.
  val_type_name_map_.clear();
  expr_name_counter_ = 0;
}
196 | |
197 | bool IrContainer::inContainer(const Statement* stmt) const { |
198 | const void* const_void = (const void*)(stmt); |
199 | void* nonconst_void = const_cast<void*>(const_void); // NOLINT |
200 | if (raw_ptrs_.find(nonconst_void) == raw_ptrs_.end()) { |
201 | return false; |
202 | } |
203 | |
204 | TORCH_INTERNAL_ASSERT( |
205 | stmt->container() == this, |
206 | "Container claims to own stmt, but stmt disagrees." ); |
207 | |
208 | Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT |
209 | if (stmt->isExpr()) { |
210 | TORCH_INTERNAL_ASSERT( |
211 | exprs_.find(nonconst_stmt->as<Expr>()) != exprs_.end(), |
212 | "Somehow container claims to and not to own an Expr." ); |
213 | } |
214 | if (stmt->isVal()) { |
215 | TORCH_INTERNAL_ASSERT( |
216 | vals_.find(nonconst_stmt->as<Val>()) != vals_.end(), |
217 | "Somehow container claims to and not to own an Val." ); |
218 | } |
219 | |
220 | return true; |
221 | } |
222 | |
223 | // Shortcuts for frequently used vals |
224 | Int* IrContainer::zeroVal() { |
225 | if (!zero_val_) { |
226 | auto zero_val = IrBuilder::create<Int>(this, 0); |
227 | TORCH_INTERNAL_ASSERT(vals_up_.back().get() == zero_val); |
228 | zero_val_ = std::unique_ptr<Int>(vals_up_.back().release()->as<Int>()); |
229 | vals_up_.pop_back(); |
230 | } |
231 | return zero_val_.get(); |
232 | } |
233 | |
234 | Int* IrContainer::oneVal() { |
235 | if (!one_val_) { |
236 | auto one_val = IrBuilder::create<Int>(this, 1); |
237 | TORCH_INTERNAL_ASSERT(vals_up_.back().get() == one_val); |
238 | one_val_ = std::unique_ptr<Int>(vals_up_.back().release()->as<Int>()); |
239 | vals_up_.pop_back(); |
240 | } |
241 | return one_val_.get(); |
242 | } |
243 | |
244 | Bool* IrContainer::falseVal() { |
245 | if (!false_val_) { |
246 | auto false_val = IrBuilder::create<Bool>(this, false); |
247 | TORCH_INTERNAL_ASSERT(vals_up_.back().get() == false_val); |
248 | false_val_ = std::unique_ptr<Bool>(vals_up_.back().release()->as<Bool>()); |
249 | vals_up_.pop_back(); |
250 | } |
251 | return false_val_.get(); |
252 | } |
253 | |
254 | Bool* IrContainer::trueVal() { |
255 | if (!true_val_) { |
256 | auto true_val = IrBuilder::create<Bool>(this, true); |
257 | TORCH_INTERNAL_ASSERT(vals_up_.back().get() == true_val); |
258 | true_val_ = std::unique_ptr<Bool>(vals_up_.back().release()->as<Bool>()); |
259 | vals_up_.pop_back(); |
260 | } |
261 | return true_val_.get(); |
262 | } |
263 | |
264 | NamedScalar* IrContainer::magicZeroVal() { |
265 | if (!magic_zero_val_) { |
266 | auto magic_zero = |
267 | IrBuilder::create<NamedScalar>(kMagicZeroName, DataType::Int); |
268 | TORCH_INTERNAL_ASSERT(vals_up_.back().get() == magic_zero); |
269 | magic_zero_val_ = std::unique_ptr<NamedScalar>( |
270 | vals_up_.back().release()->as<NamedScalar>()); |
271 | vals_up_.pop_back(); |
272 | } |
273 | return magic_zero_val_.get(); |
274 | } |
275 | |
276 | } // namespace cuda |
277 | } // namespace fuser |
278 | } // namespace jit |
279 | } // namespace torch |
280 | |