1#include <instrumentation.h>
2#include <ir_builder.h>
3#include <ir_cloner.h>
4#include <ir_container.h>
5
6namespace torch {
7namespace jit {
8namespace fuser {
9namespace cuda {
10
11void swap(IrContainer& a, IrContainer& b) noexcept {
12 FUSER_PERF_SCOPE("Fusion swap");
13
14 using std::swap;
15
16 // Swap the content
17 swap(a.vals_up_, b.vals_up_);
18 swap(a.vals_, b.vals_);
19
20 swap(a.exprs_up_, b.exprs_up_);
21 swap(a.exprs_, b.exprs_);
22
23 swap(a.raw_ptrs_, b.raw_ptrs_);
24
25 swap(a.val_type_name_map_, b.val_type_name_map_);
26 swap(a.expr_name_counter_, b.expr_name_counter_);
27
28 // Fixup the Statement::fusion_ links for a
29 for (auto val : a.vals_) {
30 val->ir_container_ = &a;
31 }
32 for (auto expr : a.exprs_) {
33 expr->ir_container_ = &a;
34 }
35
36 // Fixup the Statement::fusion_ links for b
37 for (auto val : b.vals_) {
38 val->ir_container_ = &a;
39 }
40 for (auto expr : b.exprs_) {
41 expr->ir_container_ = &a;
42 }
43}
44
45IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
46 to->clear();
47 IrCloner ir_cloner(to);
48
49 for (auto val : from->vals_) {
50 to->vals_.insert(ir_cloner.clone(val));
51 }
52
53 for (auto expr : from->exprs_) {
54 to->exprs_.insert(ir_cloner.clone(expr));
55 }
56
57 to->val_type_name_map_ = from->val_type_name_map_;
58 to->expr_name_counter_ = from->expr_name_counter_;
59
60 return ir_cloner;
61}
62
//! Default constructor: every member takes its default (or in-class)
//! initializer; no statements are created here.
IrContainer::IrContainer() = default;
64
65IrContainer::IrContainer(const IrContainer& other) {
66 FUSER_PERF_SCOPE("IrContainer copy");
67 IrContainer::copy(&other, this);
68}
69
70IrContainer::IrContainer(IrContainer&& other) noexcept {
71 FUSER_PERF_SCOPE("IrContainer move");
72 swap(*this, other);
73}
74
75IrContainer& IrContainer::operator=(const IrContainer& other) {
76 FUSER_PERF_SCOPE("IrContainer copy assign");
77 IrContainer copy(other);
78 clear();
79 swap(*this, copy);
80 return *this;
81}
82
83IrContainer& IrContainer::operator=(IrContainer&& other) noexcept {
84 FUSER_PERF_SCOPE("IrContainer move assign");
85 clear();
86 swap(*this, other);
87 return *this;
88}
89
90IrContainer::~IrContainer() {
91 clear();
92}
93
94//! Register the Statement with this container
95void IrContainer::registerStmt(IrBuilderPasskey, Statement* stmt) {
96 if (stmt->isVal()) {
97 registerVal(stmt->asVal());
98 } else {
99 registerExpr(stmt->asExpr());
100 }
101}
102
103//! Register the Val with this container
104void IrContainer::registerVal(IrBuilderPasskey, Val* val) {
105 registerVal(val);
106}
107
108//! Register expr with this container.
109void IrContainer::registerExpr(IrBuilderPasskey, Expr* expr) {
110 registerExpr(expr);
111}
112
113void IrContainer::registerExpr(ExprPasskey, Expr* expr) {
114 registerExpr(expr);
115}
116
117void IrContainer::removeExpr(Expr* expr) {
118 TORCH_INTERNAL_ASSERT(
119 exprs_.find(expr) != exprs_.end(),
120 "Wanted to remove an expression but it doesn't exist in this container.");
121 auto expr_in_deque = std::find_if(
122 exprs_up_.begin(),
123 exprs_up_.end(),
124 [expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
125
126 TORCH_INTERNAL_ASSERT(
127 expr_in_deque != exprs_up_.end(),
128 "Wanted to remove an expression but its unique ptr is missing.");
129
130 exprs_.erase(expr);
131 exprs_up_.erase(expr_in_deque);
132 raw_ptrs_.erase((void*)expr);
133}
134
135//! Completely remove val from the fusion, break all dependencies associated
136//! with it
137void IrContainer::removeVal(Val* val) {
138 // Don't remove shortcuts
139 if (val == true_val_.get() || val == false_val_.get() ||
140 val == one_val_.get() || val == zero_val_.get() ||
141 val == magic_zero_val_.get()) {
142 return;
143 }
144
145 TORCH_INTERNAL_ASSERT(
146 vals_.find(val) != vals_.end(),
147 "Wanted to remove a value but it doesn't exist in this container.");
148 auto val_in_deque = std::find_if(
149 vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr<Val>& val_up) {
150 return val_up.get() == val;
151 });
152
153 TORCH_INTERNAL_ASSERT(
154 val_in_deque != vals_up_.end(),
155 "Wanted to remove a value but its unique ptr is missing.");
156
157 vals_.erase(val);
158 vals_up_.erase(val_in_deque);
159 raw_ptrs_.erase((void*)val);
160}
161
162//! Register the Val with this container
163void IrContainer::registerVal(Val* val) {
164 if (inContainer(val)) {
165 return;
166 }
167
168 vals_up_.emplace_back(std::unique_ptr<Val>(val));
169 vals_.emplace(vals_up_.back().get());
170 val->setName(IrContainerPasskey(), getValName(vals_up_.back()->vtype()));
171 raw_ptrs_.emplace((void*)vals_up_.back().get());
172}
173
174//! Register expr with this container.
175void IrContainer::registerExpr(Expr* expr) {
176 if (inContainer(expr)) {
177 return;
178 }
179 exprs_up_.emplace_back(std::unique_ptr<Expr>(expr));
180 exprs_.emplace(exprs_up_.back().get());
181 expr->setName(IrContainerPasskey(), getExprName());
182 raw_ptrs_.emplace((void*)exprs_up_.back().get());
183}
184
185void IrContainer::clear() noexcept {
186 FUSER_PERF_SCOPE("IrContainer clear");
187 vals_.clear();
188 vals_up_.clear();
189 exprs_.clear();
190 exprs_up_.clear();
191 raw_ptrs_.clear();
192
193 val_type_name_map_.clear();
194 expr_name_counter_ = 0;
195}
196
197bool IrContainer::inContainer(const Statement* stmt) const {
198 const void* const_void = (const void*)(stmt);
199 void* nonconst_void = const_cast<void*>(const_void); // NOLINT
200 if (raw_ptrs_.find(nonconst_void) == raw_ptrs_.end()) {
201 return false;
202 }
203
204 TORCH_INTERNAL_ASSERT(
205 stmt->container() == this,
206 "Container claims to own stmt, but stmt disagrees.");
207
208 Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
209 if (stmt->isExpr()) {
210 TORCH_INTERNAL_ASSERT(
211 exprs_.find(nonconst_stmt->as<Expr>()) != exprs_.end(),
212 "Somehow container claims to and not to own an Expr.");
213 }
214 if (stmt->isVal()) {
215 TORCH_INTERNAL_ASSERT(
216 vals_.find(nonconst_stmt->as<Val>()) != vals_.end(),
217 "Somehow container claims to and not to own an Val.");
218 }
219
220 return true;
221}
222
223// Shortcuts for frequently used vals
224Int* IrContainer::zeroVal() {
225 if (!zero_val_) {
226 auto zero_val = IrBuilder::create<Int>(this, 0);
227 TORCH_INTERNAL_ASSERT(vals_up_.back().get() == zero_val);
228 zero_val_ = std::unique_ptr<Int>(vals_up_.back().release()->as<Int>());
229 vals_up_.pop_back();
230 }
231 return zero_val_.get();
232}
233
234Int* IrContainer::oneVal() {
235 if (!one_val_) {
236 auto one_val = IrBuilder::create<Int>(this, 1);
237 TORCH_INTERNAL_ASSERT(vals_up_.back().get() == one_val);
238 one_val_ = std::unique_ptr<Int>(vals_up_.back().release()->as<Int>());
239 vals_up_.pop_back();
240 }
241 return one_val_.get();
242}
243
244Bool* IrContainer::falseVal() {
245 if (!false_val_) {
246 auto false_val = IrBuilder::create<Bool>(this, false);
247 TORCH_INTERNAL_ASSERT(vals_up_.back().get() == false_val);
248 false_val_ = std::unique_ptr<Bool>(vals_up_.back().release()->as<Bool>());
249 vals_up_.pop_back();
250 }
251 return false_val_.get();
252}
253
254Bool* IrContainer::trueVal() {
255 if (!true_val_) {
256 auto true_val = IrBuilder::create<Bool>(this, true);
257 TORCH_INTERNAL_ASSERT(vals_up_.back().get() == true_val);
258 true_val_ = std::unique_ptr<Bool>(vals_up_.back().release()->as<Bool>());
259 vals_up_.pop_back();
260 }
261 return true_val_.get();
262}
263
264NamedScalar* IrContainer::magicZeroVal() {
265 if (!magic_zero_val_) {
266 auto magic_zero =
267 IrBuilder::create<NamedScalar>(kMagicZeroName, DataType::Int);
268 TORCH_INTERNAL_ASSERT(vals_up_.back().get() == magic_zero);
269 magic_zero_val_ = std::unique_ptr<NamedScalar>(
270 vals_up_.back().release()->as<NamedScalar>());
271 vals_up_.pop_back();
272 }
273 return magic_zero_val_.get();
274}
275
276} // namespace cuda
277} // namespace fuser
278} // namespace jit
279} // namespace torch
280