#include <evaluator_common.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>

#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

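// Returns true if `value` is a constant scalar holding the same concrete
// value as `concrete_value`, with matching Int/Double type.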
bool equals(Val* value, const IntOrDouble& concrete_value) {
  switch (value->getDataType().value()) {
    case DataType::Int: {
      if (!concrete_value.is_int()) {
        return false;
      }
      auto val = value->getInt();
      return val.has_value() && val.value() == concrete_value.as<int64_t>();
    }
    case DataType::Double: {
      if (concrete_value.is_int()) {
        return false;
      }
      auto val = value->getDouble();
      return val.has_value() && val.value() == concrete_value.as<double>();
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected data type");
  }
}

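// Converts an optional int64_t/double into an optional IntOrDouble,
// preserving nullopt.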
template <typename T>
c10::optional<IntOrDouble> toOptionalIntOrDouble(c10::optional<T> i) {
  if (!i) {
    return c10::nullopt;
  }
  return IntOrDouble(i.value());
}

} // namespace

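// Binds a concrete value to a fusion IR value. Binding a constant to a
// different value, or binding a value computed by an expression in the
// fusion IR, is an error.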
void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) {
  if (equals(value, concrete_value)) {
    return;
  }
  TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value");
  TORCH_CHECK(
      value->definition() == nullptr,
      "Tried to bind to a value that is computed in the fusion IR");
  if (value->isA<NamedScalar>()) {
    known_named_scalars_[value->as<NamedScalar>()->name()] = concrete_value;
  } else {
    known_values_[value] = concrete_value;
  }
}

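// Binds a concrete value to a named scalar (e.g. "blockDim.x") by its string
// name.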
void ExpressionEvaluator::bind(
    const std::string& name,
    const IntOrDouble& concrete_value) {
  known_named_scalars_[name] = concrete_value;
}

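// Evaluates a value, either by looking it up in the precomputed values (when
// available) or by dispatching to its defining expression and consulting the
// known bindings. Returns nullopt if the value cannot be resolved.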
c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) {
  if (evaluator_precomputed_values_ != nullptr) {
    return toOptionalIntOrDouble(
        evaluator_precomputed_values_->getMaybeValueFor(value));
  } else {
    auto maybe_concrete_value = getValue(value);
    if (!maybe_concrete_value.has_value()) {
      if (value->definition() != nullptr) {
        OptOutDispatch::handle(value->definition());
        maybe_concrete_value = getValue(value);
      }
    }
    return maybe_concrete_value;
  }
}

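// Debug helper: prints all currently known (value, concrete value) bindings.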
void ExpressionEvaluator::print() const {
  std::cout << "\nEvaluation context\n";
  std::cout << "--------------------\n";
  for (const auto& kv : known_values_) {
    TORCH_INTERNAL_ASSERT(!kv.first->isConstScalar());
    std::cout << kv.first << " = " << kv.second << " ; "
              << *kv.first->getValType() << "\n";
  }
  std::cout << "--------------------\n\n";
}

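// Looks up a value without evaluating its definition: constants are returned
// directly, otherwise the known bindings (by name for named scalars) are
// consulted.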
c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) {
  TORCH_INTERNAL_ASSERT(
      value->isAnInt() || value->isADouble(),
      "Expression Evaluation does not support values other than integers/doubles at this time.");

  if (value->getValType().value() == ValType::Scalar) {
    if (value->isAnInt() && value->as<Int>()->value().has_value()) {
      return toOptionalIntOrDouble(value->as<Int>()->value());
    }
    if (value->isADouble() && value->as<Double>()->value().has_value()) {
      return toOptionalIntOrDouble(value->as<Double>()->value());
    }
  }

  if (value->isA<NamedScalar>()) {
    const auto it = known_named_scalars_.find(value->as<NamedScalar>()->name());
    return it != known_named_scalars_.end()
        ? c10::optional<IntOrDouble>(it->second)
        : c10::nullopt;
  } else {
    const auto it = known_values_.find(value);
    return it != known_values_.end() ? c10::optional<IntOrDouble>(it->second)
                                     : c10::nullopt;
  }
}

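// Evaluates a unary op once its input is known and records the result for the
// op's output.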
void ExpressionEvaluator::handle(UnaryOp* uop) {
  using namespace IntOrDouble_functions;
  const auto in = evaluate(uop->in());
  if (in.has_value()) {
    switch (uop->getUnaryOpType()) {
      case UnaryOpType::Neg:
        known_values_[uop->out()] = -*in;
        break;
      case UnaryOpType::Set:
        known_values_[uop->out()] = *in;
        break;
      case UnaryOpType::Cast:
        if (uop->out()->getDataType() == DataType::Int) {
          known_values_[uop->out()] = in->cast<int64_t>();
        } else if (uop->out()->getDataType() == DataType::Double) {
          known_values_[uop->out()] = in->cast<double>();
        } else {
          TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
        }
        break;
      case UnaryOpType::Abs:
        known_values_[uop->out()] = abs(*in);
        break;
      default:
        TORCH_CHECK(
            false,
            "Unexpected operator type ",
            uop->getUnaryOpType(),
            " in ",
            uop->toString());
    }
  }
}

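// Evaluates a binary op once both operands are known and records the result
// for the op's output.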
void ExpressionEvaluator::handle(BinaryOp* bop) {
  using namespace IntOrDouble_functions;
  const auto lhs = evaluate(bop->lhs());
  const auto rhs = evaluate(bop->rhs());
  if (lhs.has_value() && rhs.has_value()) {
    switch (bop->getBinaryOpType()) {
      case BinaryOpType::Add:
        known_values_[bop->out()] = *lhs + *rhs;
        break;
      case BinaryOpType::Sub:
        known_values_[bop->out()] = *lhs - *rhs;
        break;
      case BinaryOpType::Mul:
        known_values_[bop->out()] = *lhs * *rhs;
        break;
      case BinaryOpType::Div:
        TORCH_CHECK(*rhs != 0);
        known_values_[bop->out()] = *lhs / *rhs;
        break;
      case BinaryOpType::Mod:
        TORCH_CHECK(*rhs != 0);
        known_values_[bop->out()] = *lhs % *rhs;
        break;
      case BinaryOpType::CeilDiv:
        TORCH_CHECK(*rhs != 0);
        known_values_[bop->out()] = ceildiv(*lhs, *rhs);
        break;
      case BinaryOpType::And:
        known_values_[bop->out()] = *lhs && *rhs;
        break;
      case BinaryOpType::Max:
        known_values_[bop->out()] = max(*lhs, *rhs);
        break;
      case BinaryOpType::Min:
        known_values_[bop->out()] = min(*lhs, *rhs);
        break;
      default:
        TORCH_CHECK(false, "Unexpected operator type");
    }
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch