1#include <dispatch.h>
2#include <expr_evaluator.h>
3#include <fusion.h>
4#include <ir_all_nodes.h>
5#include <ir_builder.h>
6#include <ir_cloner.h>
7#include <ir_printer.h>
8#include <kernel.h>
9#include <kernel_ir.h>
10#include <kernel_ir_dispatch.h>
11#include <mutator.h>
12
13#include <torch/csrc/jit/ir/ir.h>
14
15#include <c10/util/Exception.h>
16#include <c10/util/irange.h>
17
18#include <iostream>
19#include <stdexcept>
20#include <string>
21#include <unordered_map>
22
23namespace torch {
24namespace jit {
25namespace fuser {
26namespace cuda {
27
28Statement::Statement(IrBuilderPasskey passkey) {
29 ir_container_ = passkey.ir_container_;
30}
31
32Statement::Statement(const Statement* src, IrCloner* ir_cloner) {
33 ir_container_ = ir_cloner->container();
34}
35
36void Statement::setName(IrContainerPasskey, StmtNameType name) {
37 name_ = name;
38}
39
40void Statement::setName(IrBuilderPasskey, StmtNameType name) {
41 name_ = name;
42}
43
44Val* Statement::asVal() {
45 TORCH_INTERNAL_ASSERT(isVal(), "Cannot cast to Val as this is not a Val.");
46 return this->as<Val>();
47}
48
49Expr* Statement::asExpr() {
50 TORCH_INTERNAL_ASSERT(isExpr(), "Cannot cast to Expr as this is not a Expr.");
51 return this->as<Expr>();
52}
53
54std::string Statement::toString() const {
55 std::stringstream ss;
56 IrPrinter ir_printer(ss);
57 ir_printer.handle(this);
58 return ss.str();
59}
60
61std::string Statement::toInlineString() const {
62 std::stringstream ss;
63 IrPrinter ir_printer(ss);
64 ir_printer.print_inline(this);
65 return ss.str();
66}
67
68Fusion* Statement::fusion() const {
69 TORCH_INTERNAL_ASSERT(
70 ir_container_->isA<Fusion>(), "Statement does not belong to a fusion.");
71 return ir_container_->as<Fusion>();
72}
73
74kir::Kernel* Statement::kernel() const {
75 TORCH_INTERNAL_ASSERT(
76 ir_container_->isA<kir::Kernel>(),
77 "Statement does not belong to a kernel.");
78 return ir_container_->as<kir::Kernel>();
79}
80
81// When we create a Val we immediately register them with the active fusion.
82Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype)
83 : Statement(passkey), vtype_(_vtype), dtype_(_dtype) {}
84
85// NOTE: we don't clone the definition_ and uses_ here
86// since they may introduce cloning cycles. Instead, we copy
87// the original pointers and we'll fix them up later part of the
88// Fusion copy. Neither definition_ nor uses_ are copied through
89// this constructor now leaving them to be resolved by later stages
90//
91Val::Val(const Val* src, IrCloner* ir_cloner)
92 : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}
93
94const std::vector<Expr*>& Val::uses() const {
95 if (vtype_ == ValType::TensorView) {
96 if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) {
97 fusion()->resetTvUses();
98 }
99 }
100 return uses_;
101}
102
103// Converts the data type of TensorView or Scalar representing index
104// values. The data type of the original input should be
105// DataType::Index, but DataType::Int is also allowed as it is used
106// for index expressions.
107void Val::resolveIndexDtype() {
108 TORCH_INTERNAL_ASSERT(
109 vtype_ == ValType::TensorView || vtype_ == ValType::Scalar,
110 "Resolving index type is currently only supported on tensor view or scalar values. "
111 "Value type: ",
112 vtype_);
113 TORCH_INTERNAL_ASSERT(
114 dtype_ == DataType::Index || dtype_ == DataType::Int,
115 "Can only resolve index type if a Val has an Index or Int DataType. ",
116 "Data type: ",
117 dtype_);
118 TORCH_INTERNAL_ASSERT(
119 container()->isA<kir::Kernel>(),
120 "Index type can only be resolved at compile time.");
121 dtype_ = container()->as<kir::Kernel>()->indexType();
122}
123
124namespace {
125
126// Traverse definition of all values involved in constructing the provided val.
127// Check if all values involved are constant values, meaning the provided
128// val is also a constant value.
129class ConstCheck : private OptOutConstDispatch {
130 private:
131 bool is_const_ = true;
132
133 // Returns true if all Val's in the hisotry of provided Val is an Int. Since
134 // our expression evaluator doesn't support any type besides int, it's
135 // important to check it is one.
136 bool is_int_ = true;
137
138 void handle(const Bool* b) final {
139 is_const_ = is_const_ && b->isConst();
140 }
141
142 void handle(const Double* d) final {
143 is_const_ = is_const_ && d->isConst();
144 }
145
146 void handle(const Int* i) final {
147 is_const_ = is_const_ && i->isConst();
148 }
149
150 void handle(const NamedScalar* ns) final {
151 is_const_ = is_const_ && false;
152 }
153
154 void handle(const Expr* expr) final {
155 for (auto inp : expr->inputs()) {
156 handle(inp);
157 }
158 }
159
160 void handle(const Val* val) final {
161 if (!val->isAnInt()) {
162 is_int_ = false;
163 }
164
165 if (val->definition() != nullptr) {
166 handle(val->definition());
167 } else {
168 OptOutConstDispatch::handle(val);
169 }
170 }
171
172 public:
173 static bool isConst(const Val* val) {
174 ConstCheck cc;
175 cc.handle(val);
176 return cc.is_const_;
177 }
178
179 static bool isConstInt(const Val* val) {
180 ConstCheck cc;
181 cc.handle(val);
182 return cc.is_const_ && cc.is_int_;
183 }
184};
185
186} // namespace
187
188bool Val::isConstScalar() const {
189 if (!isScalar()) {
190 return false;
191 }
192 return ConstCheck::isConst(this);
193}
194
195bool Val::isConstInt() const {
196 return ConstCheck::isConst(this) && isAnInt();
197}
198
199int64_t Val::evaluateInt() {
200 TORCH_INTERNAL_ASSERT(
201 ConstCheck::isConst(this),
202 "Cannot get Int of not const values through IR nodes, must use runtime ExpressionEvaluator.");
203
204 if (this->as<Int>()->value().has_value()) {
205 return this->as<Int>()->value().value();
206 }
207
208 ExpressionEvaluator ee(fusion());
209 auto evaluated_val = ee.evaluate(this);
210 TORCH_INTERNAL_ASSERT(
211 evaluated_val.has_value(),
212 "Detected a const integer but failed to infer its value.");
213 return evaluated_val->as<int64_t>();
214}
215
216double Val::evaluateDouble() {
217 TORCH_INTERNAL_ASSERT(
218 ConstCheck::isConst(this),
219 "Cannot get Double of not const doubles through IR nodes, must use runtime ExpressionEvaluator.");
220
221 if (this->as<Double>()->value().has_value()) {
222 return this->as<Double>()->value().value();
223 }
224
225 ExpressionEvaluator ee(fusion());
226 auto evaluated_val = ee.evaluate(this);
227 TORCH_INTERNAL_ASSERT(
228 evaluated_val.has_value(),
229 "Detected a const integer but failed to infer its value.");
230 return evaluated_val->as<double>();
231}
232
233c10::optional<int64_t> Val::getInt() const {
234 if (isConstScalar() && isAnInt()) {
235 if (this->getValType() == ValType::Scalar) {
236 if (this->isA<Int>()) {
237 return this->as<Int>()->value();
238 }
239 }
240 }
241 return c10::nullopt;
242}
243
244c10::optional<double> Val::getDouble() const {
245 if (isConstScalar() && isAnInt()) {
246 if (this->getValType() == ValType::Scalar) {
247 if (this->isA<Double>()) {
248 return this->as<Double>()->value();
249 }
250 }
251 }
252 return c10::nullopt;
253}
254
255bool Val::isZeroInt() const {
256 auto int_val = getInt();
257 return int_val.has_value() && int_val.value() == 0;
258}
259
260bool Val::isOneInt() const {
261 auto int_val = getInt();
262 return int_val.has_value() && int_val.value() == 1;
263}
264
265bool Val::isDefinitionType(ExprType expression_type) const {
266 if (definition() != nullptr) {
267 auto def_expr_type = definition()->getExprType();
268 if (def_expr_type.has_value() && def_expr_type.value() == expression_type) {
269 return true;
270 }
271 }
272 return false;
273}
274
275c10::optional<DataType> Val::getDataType() const {
276 TORCH_INTERNAL_ASSERT(
277 dtype_ != DataType::Null, "Value does not have a data type.");
278 return dtype_;
279}
280
281bool Val::isProducerOf(const Val* other) const {
282 TORCH_INTERNAL_ASSERT(other != nullptr);
283 TORCH_INTERNAL_ASSERT(container() == other->container());
284
285 if (definition() == nullptr) {
286 return false;
287 }
288 return std::any_of(
289 definition()->inputs().begin(),
290 definition()->inputs().end(),
291 [other](const Val* input) { return input == other; });
292}
293
294bool Val::isConsumerOf(const Val* other) const {
295 return other->isProducerOf(this);
296}
297
298// We don't register with the active fusion in Expr as this needs to be done
299// after inputs and outputs are registered with the Expr
300Expr::Expr(IrBuilderPasskey passkey, ExprType etype)
301 : Statement(passkey), etype_{etype} {}
302
303Expr::Expr(const Expr* src, IrCloner* ir_cloner)
304 : Statement(src, ir_cloner),
305 etype_(src->etype_),
306 inputs_(ir_cloner->clone(src->inputs_)),
307 outputs_(ir_cloner->clone(src->outputs_)) {}
308
309bool Expr::sameAs(const Statement* other) const {
310 if (this == other) {
311 return true;
312 }
313 if (!other->isA<Expr>()) {
314 return false;
315 }
316 const Expr* other_expr = other->as<Expr>();
317 if (getExprType() != other_expr->getExprType()) {
318 return false;
319 }
320 if (inputs().size() != other_expr->inputs().size() ||
321 outputs().size() != other_expr->outputs().size()) {
322 return false;
323 }
324 for (const auto i : c10::irange(inputs().size())) {
325 if (!input(i)->sameAs(other_expr->input(i))) {
326 return false;
327 }
328 }
329 return true;
330}
331
332kir::Predicate* Expr::predicate() const {
333 TORCH_INTERNAL_ASSERT(
334 container()->isA<kir::Kernel>(), "Function invalid for fusion.");
335 return predicate_;
336}
337
338void Expr::setPredicate(kir::Predicate* predicate) {
339 TORCH_INTERNAL_ASSERT(
340 container()->isA<kir::Kernel>(), "Function invalid for fusion.");
341 predicate_ = predicate;
342}
343
344Expr* Expr::withPredicate(kir::Predicate* predicate) {
345 auto result = shallowCopy();
346 result->setPredicate(predicate);
347 return result;
348}
349
350kir::Predicate* Expr::writePredicate() const {
351 TORCH_INTERNAL_ASSERT(
352 container()->isA<kir::Kernel>(), "Function invalid for fusion.");
353 return write_predicate_;
354}
355
356void Expr::setWritePredicate(kir::Predicate* write_predicate) {
357 TORCH_INTERNAL_ASSERT(
358 container()->isA<kir::Kernel>(), "Function invalid for fusion.");
359 write_predicate_ = write_predicate;
360}
361
362Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
363 auto result = shallowCopy();
364 result->setWritePredicate(predicate);
365 return result;
366}
367
368void Expr::copyPredicatesFrom(const Expr* expr) {
369 if (container()->isA<kir::Kernel>()) {
370 predicate_ = expr->predicate_;
371 write_predicate_ = expr->write_predicate_;
372 }
373}
374
375} // namespace cuda
376} // namespace fuser
377} // namespace jit
378} // namespace torch
379