1 | #include <dispatch.h> |
2 | #include <expr_evaluator.h> |
3 | #include <fusion.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <ir_builder.h> |
6 | #include <ir_cloner.h> |
7 | #include <ir_printer.h> |
8 | #include <kernel.h> |
9 | #include <kernel_ir.h> |
10 | #include <kernel_ir_dispatch.h> |
11 | #include <mutator.h> |
12 | |
13 | #include <torch/csrc/jit/ir/ir.h> |
14 | |
15 | #include <c10/util/Exception.h> |
16 | #include <c10/util/irange.h> |
17 | |
18 | #include <iostream> |
19 | #include <stdexcept> |
20 | #include <string> |
21 | #include <unordered_map> |
22 | |
23 | namespace torch { |
24 | namespace jit { |
25 | namespace fuser { |
26 | namespace cuda { |
27 | |
28 | Statement::Statement(IrBuilderPasskey passkey) { |
29 | ir_container_ = passkey.ir_container_; |
30 | } |
31 | |
32 | Statement::Statement(const Statement* src, IrCloner* ir_cloner) { |
33 | ir_container_ = ir_cloner->container(); |
34 | } |
35 | |
36 | void Statement::setName(IrContainerPasskey, StmtNameType name) { |
37 | name_ = name; |
38 | } |
39 | |
40 | void Statement::setName(IrBuilderPasskey, StmtNameType name) { |
41 | name_ = name; |
42 | } |
43 | |
44 | Val* Statement::asVal() { |
45 | TORCH_INTERNAL_ASSERT(isVal(), "Cannot cast to Val as this is not a Val." ); |
46 | return this->as<Val>(); |
47 | } |
48 | |
49 | Expr* Statement::asExpr() { |
50 | TORCH_INTERNAL_ASSERT(isExpr(), "Cannot cast to Expr as this is not a Expr." ); |
51 | return this->as<Expr>(); |
52 | } |
53 | |
54 | std::string Statement::toString() const { |
55 | std::stringstream ss; |
56 | IrPrinter ir_printer(ss); |
57 | ir_printer.handle(this); |
58 | return ss.str(); |
59 | } |
60 | |
61 | std::string Statement::toInlineString() const { |
62 | std::stringstream ss; |
63 | IrPrinter ir_printer(ss); |
64 | ir_printer.print_inline(this); |
65 | return ss.str(); |
66 | } |
67 | |
68 | Fusion* Statement::fusion() const { |
69 | TORCH_INTERNAL_ASSERT( |
70 | ir_container_->isA<Fusion>(), "Statement does not belong to a fusion." ); |
71 | return ir_container_->as<Fusion>(); |
72 | } |
73 | |
74 | kir::Kernel* Statement::kernel() const { |
75 | TORCH_INTERNAL_ASSERT( |
76 | ir_container_->isA<kir::Kernel>(), |
77 | "Statement does not belong to a kernel." ); |
78 | return ir_container_->as<kir::Kernel>(); |
79 | } |
80 | |
81 | // When we create a Val we immediately register them with the active fusion. |
82 | Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype) |
83 | : Statement(passkey), vtype_(_vtype), dtype_(_dtype) {} |
84 | |
85 | // NOTE: we don't clone the definition_ and uses_ here |
86 | // since they may introduce cloning cycles. Instead, we copy |
87 | // the original pointers and we'll fix them up later part of the |
88 | // Fusion copy. Neither definition_ nor uses_ are copied through |
89 | // this constructor now leaving them to be resolved by later stages |
90 | // |
91 | Val::Val(const Val* src, IrCloner* ir_cloner) |
92 | : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {} |
93 | |
94 | const std::vector<Expr*>& Val::uses() const { |
95 | if (vtype_ == ValType::TensorView) { |
96 | if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { |
97 | fusion()->resetTvUses(); |
98 | } |
99 | } |
100 | return uses_; |
101 | } |
102 | |
103 | // Converts the data type of TensorView or Scalar representing index |
104 | // values. The data type of the original input should be |
105 | // DataType::Index, but DataType::Int is also allowed as it is used |
106 | // for index expressions. |
107 | void Val::resolveIndexDtype() { |
108 | TORCH_INTERNAL_ASSERT( |
109 | vtype_ == ValType::TensorView || vtype_ == ValType::Scalar, |
110 | "Resolving index type is currently only supported on tensor view or scalar values. " |
111 | "Value type: " , |
112 | vtype_); |
113 | TORCH_INTERNAL_ASSERT( |
114 | dtype_ == DataType::Index || dtype_ == DataType::Int, |
115 | "Can only resolve index type if a Val has an Index or Int DataType. " , |
116 | "Data type: " , |
117 | dtype_); |
118 | TORCH_INTERNAL_ASSERT( |
119 | container()->isA<kir::Kernel>(), |
120 | "Index type can only be resolved at compile time." ); |
121 | dtype_ = container()->as<kir::Kernel>()->indexType(); |
122 | } |
123 | |
124 | namespace { |
125 | |
126 | // Traverse definition of all values involved in constructing the provided val. |
127 | // Check if all values involved are constant values, meaning the provided |
128 | // val is also a constant value. |
129 | class ConstCheck : private OptOutConstDispatch { |
130 | private: |
131 | bool is_const_ = true; |
132 | |
133 | // Returns true if all Val's in the hisotry of provided Val is an Int. Since |
134 | // our expression evaluator doesn't support any type besides int, it's |
135 | // important to check it is one. |
136 | bool is_int_ = true; |
137 | |
138 | void handle(const Bool* b) final { |
139 | is_const_ = is_const_ && b->isConst(); |
140 | } |
141 | |
142 | void handle(const Double* d) final { |
143 | is_const_ = is_const_ && d->isConst(); |
144 | } |
145 | |
146 | void handle(const Int* i) final { |
147 | is_const_ = is_const_ && i->isConst(); |
148 | } |
149 | |
150 | void handle(const NamedScalar* ns) final { |
151 | is_const_ = is_const_ && false; |
152 | } |
153 | |
154 | void handle(const Expr* expr) final { |
155 | for (auto inp : expr->inputs()) { |
156 | handle(inp); |
157 | } |
158 | } |
159 | |
160 | void handle(const Val* val) final { |
161 | if (!val->isAnInt()) { |
162 | is_int_ = false; |
163 | } |
164 | |
165 | if (val->definition() != nullptr) { |
166 | handle(val->definition()); |
167 | } else { |
168 | OptOutConstDispatch::handle(val); |
169 | } |
170 | } |
171 | |
172 | public: |
173 | static bool isConst(const Val* val) { |
174 | ConstCheck cc; |
175 | cc.handle(val); |
176 | return cc.is_const_; |
177 | } |
178 | |
179 | static bool isConstInt(const Val* val) { |
180 | ConstCheck cc; |
181 | cc.handle(val); |
182 | return cc.is_const_ && cc.is_int_; |
183 | } |
184 | }; |
185 | |
186 | } // namespace |
187 | |
188 | bool Val::isConstScalar() const { |
189 | if (!isScalar()) { |
190 | return false; |
191 | } |
192 | return ConstCheck::isConst(this); |
193 | } |
194 | |
195 | bool Val::isConstInt() const { |
196 | return ConstCheck::isConst(this) && isAnInt(); |
197 | } |
198 | |
199 | int64_t Val::evaluateInt() { |
200 | TORCH_INTERNAL_ASSERT( |
201 | ConstCheck::isConst(this), |
202 | "Cannot get Int of not const values through IR nodes, must use runtime ExpressionEvaluator." ); |
203 | |
204 | if (this->as<Int>()->value().has_value()) { |
205 | return this->as<Int>()->value().value(); |
206 | } |
207 | |
208 | ExpressionEvaluator ee(fusion()); |
209 | auto evaluated_val = ee.evaluate(this); |
210 | TORCH_INTERNAL_ASSERT( |
211 | evaluated_val.has_value(), |
212 | "Detected a const integer but failed to infer its value." ); |
213 | return evaluated_val->as<int64_t>(); |
214 | } |
215 | |
216 | double Val::evaluateDouble() { |
217 | TORCH_INTERNAL_ASSERT( |
218 | ConstCheck::isConst(this), |
219 | "Cannot get Double of not const doubles through IR nodes, must use runtime ExpressionEvaluator." ); |
220 | |
221 | if (this->as<Double>()->value().has_value()) { |
222 | return this->as<Double>()->value().value(); |
223 | } |
224 | |
225 | ExpressionEvaluator ee(fusion()); |
226 | auto evaluated_val = ee.evaluate(this); |
227 | TORCH_INTERNAL_ASSERT( |
228 | evaluated_val.has_value(), |
229 | "Detected a const integer but failed to infer its value." ); |
230 | return evaluated_val->as<double>(); |
231 | } |
232 | |
233 | c10::optional<int64_t> Val::getInt() const { |
234 | if (isConstScalar() && isAnInt()) { |
235 | if (this->getValType() == ValType::Scalar) { |
236 | if (this->isA<Int>()) { |
237 | return this->as<Int>()->value(); |
238 | } |
239 | } |
240 | } |
241 | return c10::nullopt; |
242 | } |
243 | |
244 | c10::optional<double> Val::getDouble() const { |
245 | if (isConstScalar() && isAnInt()) { |
246 | if (this->getValType() == ValType::Scalar) { |
247 | if (this->isA<Double>()) { |
248 | return this->as<Double>()->value(); |
249 | } |
250 | } |
251 | } |
252 | return c10::nullopt; |
253 | } |
254 | |
255 | bool Val::isZeroInt() const { |
256 | auto int_val = getInt(); |
257 | return int_val.has_value() && int_val.value() == 0; |
258 | } |
259 | |
260 | bool Val::isOneInt() const { |
261 | auto int_val = getInt(); |
262 | return int_val.has_value() && int_val.value() == 1; |
263 | } |
264 | |
265 | bool Val::isDefinitionType(ExprType expression_type) const { |
266 | if (definition() != nullptr) { |
267 | auto def_expr_type = definition()->getExprType(); |
268 | if (def_expr_type.has_value() && def_expr_type.value() == expression_type) { |
269 | return true; |
270 | } |
271 | } |
272 | return false; |
273 | } |
274 | |
275 | c10::optional<DataType> Val::getDataType() const { |
276 | TORCH_INTERNAL_ASSERT( |
277 | dtype_ != DataType::Null, "Value does not have a data type." ); |
278 | return dtype_; |
279 | } |
280 | |
281 | bool Val::isProducerOf(const Val* other) const { |
282 | TORCH_INTERNAL_ASSERT(other != nullptr); |
283 | TORCH_INTERNAL_ASSERT(container() == other->container()); |
284 | |
285 | if (definition() == nullptr) { |
286 | return false; |
287 | } |
288 | return std::any_of( |
289 | definition()->inputs().begin(), |
290 | definition()->inputs().end(), |
291 | [other](const Val* input) { return input == other; }); |
292 | } |
293 | |
294 | bool Val::isConsumerOf(const Val* other) const { |
295 | return other->isProducerOf(this); |
296 | } |
297 | |
298 | // We don't register with the active fusion in Expr as this needs to be done |
299 | // after inputs and outputs are registered with the Expr |
300 | Expr::Expr(IrBuilderPasskey passkey, ExprType etype) |
301 | : Statement(passkey), etype_{etype} {} |
302 | |
303 | Expr::Expr(const Expr* src, IrCloner* ir_cloner) |
304 | : Statement(src, ir_cloner), |
305 | etype_(src->etype_), |
306 | inputs_(ir_cloner->clone(src->inputs_)), |
307 | outputs_(ir_cloner->clone(src->outputs_)) {} |
308 | |
309 | bool Expr::sameAs(const Statement* other) const { |
310 | if (this == other) { |
311 | return true; |
312 | } |
313 | if (!other->isA<Expr>()) { |
314 | return false; |
315 | } |
316 | const Expr* other_expr = other->as<Expr>(); |
317 | if (getExprType() != other_expr->getExprType()) { |
318 | return false; |
319 | } |
320 | if (inputs().size() != other_expr->inputs().size() || |
321 | outputs().size() != other_expr->outputs().size()) { |
322 | return false; |
323 | } |
324 | for (const auto i : c10::irange(inputs().size())) { |
325 | if (!input(i)->sameAs(other_expr->input(i))) { |
326 | return false; |
327 | } |
328 | } |
329 | return true; |
330 | } |
331 | |
332 | kir::Predicate* Expr::predicate() const { |
333 | TORCH_INTERNAL_ASSERT( |
334 | container()->isA<kir::Kernel>(), "Function invalid for fusion." ); |
335 | return predicate_; |
336 | } |
337 | |
338 | void Expr::setPredicate(kir::Predicate* predicate) { |
339 | TORCH_INTERNAL_ASSERT( |
340 | container()->isA<kir::Kernel>(), "Function invalid for fusion." ); |
341 | predicate_ = predicate; |
342 | } |
343 | |
344 | Expr* Expr::withPredicate(kir::Predicate* predicate) { |
345 | auto result = shallowCopy(); |
346 | result->setPredicate(predicate); |
347 | return result; |
348 | } |
349 | |
350 | kir::Predicate* Expr::writePredicate() const { |
351 | TORCH_INTERNAL_ASSERT( |
352 | container()->isA<kir::Kernel>(), "Function invalid for fusion." ); |
353 | return write_predicate_; |
354 | } |
355 | |
356 | void Expr::setWritePredicate(kir::Predicate* write_predicate) { |
357 | TORCH_INTERNAL_ASSERT( |
358 | container()->isA<kir::Kernel>(), "Function invalid for fusion." ); |
359 | write_predicate_ = write_predicate; |
360 | } |
361 | |
362 | Expr* Expr::withWritePredicate(kir::Predicate* predicate) { |
363 | auto result = shallowCopy(); |
364 | result->setWritePredicate(predicate); |
365 | return result; |
366 | } |
367 | |
368 | void Expr::copyPredicatesFrom(const Expr* expr) { |
369 | if (container()->isA<kir::Kernel>()) { |
370 | predicate_ = expr->predicate_; |
371 | write_predicate_ = expr->write_predicate_; |
372 | } |
373 | } |
374 | |
375 | } // namespace cuda |
376 | } // namespace fuser |
377 | } // namespace jit |
378 | } // namespace torch |
379 | |