1
2#include <instrumentation.h>
3#include <kernel_expr_evaluator.h>
4
5#include <iostream>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11namespace kir {
12
13namespace {
14
15template <typename T>
16c10::optional<IntOrDouble> toOptionalIntOrDouble(c10::optional<T> i) {
17 if (!i) {
18 return c10::nullopt;
19 }
20 return IntOrDouble(i.value());
21}
22
23} // namespace
24
// Binds a concrete value to a free scalar (Int or Double) so later
// evaluate() calls can resolve expressions that depend on it.
// Only leaf values may be bound: constants and values computed by a
// kernel-IR definition are rejected, since binding them could
// contradict what the IR itself says they are.
void ExpressionEvaluator::bind(const Val* value, IntOrDouble concrete_value) {
  TORCH_CHECK(value->isScalar());
  TORCH_CHECK(
      value->dtype() == DataType::Int || value->dtype() == DataType::Double);
  TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value");
  TORCH_CHECK(
      value->definition() == nullptr,
      "Tried to bind to a value that is computed in the kernel IR: ",
      value->toInlineString(),
      " with ",
      concrete_value);
  known_values_[value] = concrete_value;
}
38
39void ExpressionEvaluator::bind(
40 ParallelType pt,
41 Int::ScalarType concrete_value) {
42 TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt));
43 if (precomputed_values_) {
44 // Need to bind the thread value to integer machine
45 // in pre-computed mode.
46 precomputed_values_->bindConcreteParallelTypeValue(pt, concrete_value);
47 } else {
48 known_parallel_dimensions_[pt] = concrete_value;
49 }
50}
51
52c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(const Val* value) {
53 if (precomputed_values_ && precomputed_values_->ready()) {
54 if (precomputed_values_->getMaybeValueFor(value).has_value()) {
55 return toOptionalIntOrDouble(
56 precomputed_values_->getMaybeValueFor(value));
57 }
58 }
59
60 if (value->isScalar() && value->isConst()) {
61 if (value->isADouble()) {
62 return toOptionalIntOrDouble(value->as<Double>()->value());
63 }
64 return toOptionalIntOrDouble(value->as<Int>()->value());
65 } else {
66 FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate");
67
68 TORCH_CHECK(value->isScalar(), value->toString());
69 TORCH_CHECK(
70 value->dtype() == DataType::Int || value->dtype() == DataType::Double,
71 value->toString());
72
73 // Is the value known (either explicit binding or memoized)?
74 const auto pre_eval_it = known_values_.find(value);
75 if (pre_eval_it != known_values_.end()) {
76 return pre_eval_it->second;
77 }
78
79 OptOutConstDispatch::handle(value);
80
81 const auto post_eval_it = known_values_.find(value);
82 return post_eval_it != known_values_.end()
83 ? c10::optional<IntOrDouble>(post_eval_it->second)
84 : c10::nullopt;
85 }
86 return c10::nullopt;
87}
88
89bool ExpressionEvaluator::isConst(const Val* value) {
90 return ExpressionEvaluator().evaluate(value).has_value();
91}
92
93void ExpressionEvaluator::print() const {
94 std::cout << "\nEvaluation context\n";
95 std::cout << "--------------------\n";
96 for (const auto& kv : known_values_) {
97 std::cout << kv.first->toString() << " = " << kv.second << "\n";
98 }
99 std::cout << "\nPre-computed Values\n";
100 if (precomputed_values_ != nullptr) {
101 precomputed_values_->print();
102 }
103 std::cout << "--------------------\n\n";
104}
105
106void ExpressionEvaluator::handle(const Int* value) {
107 TORCH_INTERNAL_ASSERT(!value->isConst());
108 if (auto def = value->definition()) {
109 OptOutConstDispatch::handle(def);
110 }
111}
112
113void ExpressionEvaluator::handle(const Double* value) {
114 TORCH_INTERNAL_ASSERT(!value->isConst());
115 if (auto def = value->definition()) {
116 OptOutConstDispatch::handle(def);
117 }
118}
119
120void ExpressionEvaluator::handle(const NamedScalar* named_scalar) {
121 const auto& name = named_scalar->name();
122 for (auto pt : kParallelTypeThreads) {
123 auto pt_val_it = known_parallel_dimensions_.find(pt);
124 if (pt_val_it == known_parallel_dimensions_.end()) {
125 continue;
126 }
127 if (name == stringifyThreadSize(pt)) {
128 known_values_[named_scalar] = pt_val_it->second;
129 return;
130 }
131 }
132}
133
134void ExpressionEvaluator::handle(const UnaryOp* unary_op) {
135 using namespace IntOrDouble_functions;
136 const auto in = evaluate(unary_op->in());
137 if (in.has_value()) {
138 switch (unary_op->getUnaryOpType()) {
139 case UnaryOpType::Neg:
140 known_values_[unary_op->out()] = -*in;
141 break;
142 case UnaryOpType::Set:
143 known_values_[unary_op->out()] = *in;
144 break;
145 case UnaryOpType::Cast:
146 if (unary_op->out()->getDataType() == DataType::Int) {
147 known_values_[unary_op->out()] = in->cast<int64_t>();
148 } else if (unary_op->out()->getDataType() == DataType::Double) {
149 known_values_[unary_op->out()] = in->cast<double>();
150 } else {
151 TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
152 }
153 break;
154 case UnaryOpType::Abs:
155 known_values_[unary_op->out()] = abs(*in);
156 break;
157 default:
158 TORCH_CHECK(
159 false,
160 "Unexpected operator type ",
161 unary_op->getUnaryOpType(),
162 " in ",
163 unary_op->toString());
164 }
165 }
166}
167
168void ExpressionEvaluator::handle(const BinaryOp* binary_op) {
169 using namespace IntOrDouble_functions;
170 const auto lhs = evaluate(binary_op->lhs());
171 const auto rhs = evaluate(binary_op->rhs());
172 if (lhs.has_value() && rhs.has_value()) {
173 switch (binary_op->getBinaryOpType()) {
174 case BinaryOpType::Add:
175 known_values_[binary_op->out()] = *lhs + *rhs;
176 break;
177 case BinaryOpType::Sub:
178 known_values_[binary_op->out()] = *lhs - *rhs;
179 break;
180 case BinaryOpType::Mul:
181 known_values_[binary_op->out()] = *lhs * *rhs;
182 break;
183 case BinaryOpType::Div:
184 TORCH_CHECK(*rhs != 0);
185 known_values_[binary_op->out()] = *lhs / *rhs;
186 break;
187 case BinaryOpType::Mod:
188 TORCH_CHECK(*rhs != 0);
189 known_values_[binary_op->out()] = *lhs % *rhs;
190 break;
191 case BinaryOpType::CeilDiv:
192 TORCH_CHECK(*rhs != 0);
193 known_values_[binary_op->out()] = ceildiv(*lhs, *rhs);
194 break;
195 case BinaryOpType::And:
196 known_values_[binary_op->out()] = Int::ScalarType(*lhs && *rhs);
197 break;
198 default:
199 TORCH_CHECK(!"Unexpected operator type");
200 }
201 }
202}
203
204} // namespace kir
205} // namespace cuda
206} // namespace fuser
207} // namespace jit
208} // namespace torch
209