1 | |
2 | #include <evaluator_common.h> |
3 | #include <expr_evaluator.h> |
4 | #include <fusion.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_all_nodes.h> |
7 | #include <ir_iostream.h> |
8 | |
9 | #include <iostream> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
16 | namespace { |
17 | |
18 | bool equals(Val* value, const IntOrDouble& concrete_value) { |
19 | switch (value->getDataType().value()) { |
20 | case DataType::Int: { |
21 | if (!concrete_value.is_int()) { |
22 | return false; |
23 | } |
24 | auto val = value->getInt(); |
25 | return val.has_value() && val.value() == concrete_value.as<int64_t>(); |
26 | } |
27 | case DataType::Double: { |
28 | if (concrete_value.is_int()) { |
29 | return false; |
30 | } |
31 | auto val = value->getDouble(); |
32 | return val.has_value() && val.value() == concrete_value.as<double>(); |
33 | } |
34 | default: |
35 | TORCH_INTERNAL_ASSERT(false); |
36 | } |
37 | } |
38 | |
39 | template <typename T> |
40 | c10::optional<IntOrDouble> toOptionalIntOrDouble(c10::optional<T> i) { |
41 | if (!i) { |
42 | return c10::nullopt; |
43 | } |
44 | return IntOrDouble(i.value()); |
45 | } |
46 | |
47 | } // namespace |
48 | |
49 | void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) { |
50 | if (equals(value, concrete_value)) { |
51 | return; |
52 | } |
53 | TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value" ); |
54 | TORCH_CHECK( |
55 | value->definition() == nullptr, |
56 | "Tried to bind to a value that is computed in the fusion IR" ); |
57 | if (value->isA<NamedScalar>()) { |
58 | known_named_scalars_[value->as<NamedScalar>()->name()] = concrete_value; |
59 | } else { |
60 | known_values_[value] = concrete_value; |
61 | } |
62 | } |
63 | |
64 | void ExpressionEvaluator::bind( |
65 | const std::string& name, |
66 | const IntOrDouble& concrete_value) { |
67 | known_named_scalars_[name] = concrete_value; |
68 | } |
69 | |
70 | c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) { |
71 | if (evaluator_precomputed_values_ != nullptr) { |
72 | return toOptionalIntOrDouble( |
73 | evaluator_precomputed_values_->getMaybeValueFor(value)); |
74 | } else { |
75 | auto maybe_concrete_value = getValue(value); |
76 | if (!maybe_concrete_value.has_value()) { |
77 | if (value->definition() != nullptr) { |
78 | OptOutDispatch::handle(value->definition()); |
79 | maybe_concrete_value = getValue(value); |
80 | } |
81 | } |
82 | return maybe_concrete_value; |
83 | } |
84 | return c10::nullopt; |
85 | } |
86 | |
87 | void ExpressionEvaluator::print() const { |
88 | std::cout << "\nEvaluation context\n" ; |
89 | std::cout << "--------------------\n" ; |
90 | for (const auto& kv : known_values_) { |
91 | TORCH_INTERNAL_ASSERT(!kv.first->isConstScalar()); |
92 | std::cout << kv.first << " = " << kv.second << " ; " |
93 | << *kv.first->getValType() << "\n" ; |
94 | } |
95 | std::cout << "--------------------\n\n" ; |
96 | } |
97 | |
98 | c10::optional<IntOrDouble> ExpressionEvaluator::getValue(Val* value) { |
99 | TORCH_INTERNAL_ASSERT( |
100 | value->isAnInt() || value->isADouble(), |
101 | "Expression Evaluation does not support values other than integers/doubles at this time." ); |
102 | |
103 | if (value->getValType().value() == ValType::Scalar) { |
104 | if (value->isAnInt() && value->as<Int>()->value().has_value()) { |
105 | return toOptionalIntOrDouble(value->as<Int>()->value()); |
106 | } |
107 | if (value->isADouble() && value->as<Double>()->value().has_value()) { |
108 | return toOptionalIntOrDouble(value->as<Double>()->value()); |
109 | } |
110 | } |
111 | |
112 | if (value->isA<NamedScalar>()) { |
113 | const auto it = known_named_scalars_.find(value->as<NamedScalar>()->name()); |
114 | return it != known_named_scalars_.end() |
115 | ? c10::optional<IntOrDouble>(it->second) |
116 | : c10::nullopt; |
117 | } else { |
118 | const auto it = known_values_.find(value); |
119 | return it != known_values_.end() ? c10::optional<IntOrDouble>(it->second) |
120 | : c10::nullopt; |
121 | } |
122 | } |
123 | |
124 | void ExpressionEvaluator::handle(UnaryOp* uop) { |
125 | using namespace IntOrDouble_functions; |
126 | const auto in = evaluate(uop->in()); |
127 | if (in.has_value()) { |
128 | switch (uop->getUnaryOpType()) { |
129 | case UnaryOpType::Neg: |
130 | known_values_[uop->out()] = -*in; |
131 | break; |
132 | case UnaryOpType::Set: |
133 | known_values_[uop->out()] = *in; |
134 | break; |
135 | case UnaryOpType::Cast: |
136 | if (uop->out()->getDataType() == DataType::Int) { |
137 | known_values_[uop->out()] = in->cast<int64_t>(); |
138 | } else if (uop->out()->getDataType() == DataType::Double) { |
139 | known_values_[uop->out()] = in->cast<double>(); |
140 | } else { |
141 | TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator" ); |
142 | } |
143 | break; |
144 | case UnaryOpType::Abs: |
145 | known_values_[uop->out()] = abs(*in); |
146 | break; |
147 | default: |
148 | TORCH_CHECK( |
149 | !"Unexpected operator type " , |
150 | uop->getUnaryOpType(), |
151 | " in " , |
152 | uop->toString()); |
153 | } |
154 | } |
155 | } |
156 | |
157 | void ExpressionEvaluator::handle(BinaryOp* bop) { |
158 | using namespace IntOrDouble_functions; |
159 | const auto lhs = evaluate(bop->lhs()); |
160 | const auto rhs = evaluate(bop->rhs()); |
161 | if (lhs.has_value() && rhs.has_value()) { |
162 | switch (bop->getBinaryOpType()) { |
163 | case BinaryOpType::Add: |
164 | known_values_[bop->out()] = *lhs + *rhs; |
165 | break; |
166 | case BinaryOpType::Sub: |
167 | known_values_[bop->out()] = *lhs - *rhs; |
168 | break; |
169 | case BinaryOpType::Mul: |
170 | known_values_[bop->out()] = *lhs * *rhs; |
171 | break; |
172 | case BinaryOpType::Div: |
173 | TORCH_CHECK(*rhs != 0); |
174 | known_values_[bop->out()] = *lhs / *rhs; |
175 | break; |
176 | case BinaryOpType::Mod: |
177 | TORCH_CHECK(*rhs != 0); |
178 | known_values_[bop->out()] = *lhs % *rhs; |
179 | break; |
180 | case BinaryOpType::CeilDiv: |
181 | TORCH_CHECK(*rhs != 0); |
182 | known_values_[bop->out()] = ceildiv(*lhs, *rhs); |
183 | break; |
184 | case BinaryOpType::And: |
185 | known_values_[bop->out()] = *lhs && *rhs; |
186 | break; |
187 | case BinaryOpType::Max: |
188 | known_values_[bop->out()] = max(*lhs, *rhs); |
189 | break; |
190 | case BinaryOpType::Min: |
191 | known_values_[bop->out()] = min(*lhs, *rhs); |
192 | break; |
193 | default: |
194 | TORCH_CHECK(!"Unexpected operator type" ); |
195 | } |
196 | } |
197 | } |
198 | |
199 | } // namespace cuda |
200 | } // namespace fuser |
201 | } // namespace jit |
202 | } // namespace torch |
203 | |