1 | |
2 | #include <instrumentation.h> |
3 | #include <kernel_expr_evaluator.h> |
4 | |
5 | #include <iostream> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | namespace kir { |
12 | |
13 | namespace { |
14 | |
15 | template <typename T> |
16 | c10::optional<IntOrDouble> toOptionalIntOrDouble(c10::optional<T> i) { |
17 | if (!i) { |
18 | return c10::nullopt; |
19 | } |
20 | return IntOrDouble(i.value()); |
21 | } |
22 | |
23 | } // namespace |
24 | |
25 | void ExpressionEvaluator::bind(const Val* value, IntOrDouble concrete_value) { |
26 | TORCH_CHECK(value->isScalar()); |
27 | TORCH_CHECK( |
28 | value->dtype() == DataType::Int || value->dtype() == DataType::Double); |
29 | TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value" ); |
30 | TORCH_CHECK( |
31 | value->definition() == nullptr, |
32 | "Tried to bind to a value that is computed in the kernel IR: " , |
33 | value->toInlineString(), |
34 | " with " , |
35 | concrete_value); |
36 | known_values_[value] = concrete_value; |
37 | } |
38 | |
39 | void ExpressionEvaluator::bind( |
40 | ParallelType pt, |
41 | Int::ScalarType concrete_value) { |
42 | TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt)); |
43 | if (precomputed_values_) { |
44 | // Need to bind the thread value to integer machine |
45 | // in pre-computed mode. |
46 | precomputed_values_->bindConcreteParallelTypeValue(pt, concrete_value); |
47 | } else { |
48 | known_parallel_dimensions_[pt] = concrete_value; |
49 | } |
50 | } |
51 | |
52 | c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(const Val* value) { |
53 | if (precomputed_values_ && precomputed_values_->ready()) { |
54 | if (precomputed_values_->getMaybeValueFor(value).has_value()) { |
55 | return toOptionalIntOrDouble( |
56 | precomputed_values_->getMaybeValueFor(value)); |
57 | } |
58 | } |
59 | |
60 | if (value->isScalar() && value->isConst()) { |
61 | if (value->isADouble()) { |
62 | return toOptionalIntOrDouble(value->as<Double>()->value()); |
63 | } |
64 | return toOptionalIntOrDouble(value->as<Int>()->value()); |
65 | } else { |
66 | FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate" ); |
67 | |
68 | TORCH_CHECK(value->isScalar(), value->toString()); |
69 | TORCH_CHECK( |
70 | value->dtype() == DataType::Int || value->dtype() == DataType::Double, |
71 | value->toString()); |
72 | |
73 | // Is the value known (either explicit binding or memoized)? |
74 | const auto pre_eval_it = known_values_.find(value); |
75 | if (pre_eval_it != known_values_.end()) { |
76 | return pre_eval_it->second; |
77 | } |
78 | |
79 | OptOutConstDispatch::handle(value); |
80 | |
81 | const auto post_eval_it = known_values_.find(value); |
82 | return post_eval_it != known_values_.end() |
83 | ? c10::optional<IntOrDouble>(post_eval_it->second) |
84 | : c10::nullopt; |
85 | } |
86 | return c10::nullopt; |
87 | } |
88 | |
89 | bool ExpressionEvaluator::isConst(const Val* value) { |
90 | return ExpressionEvaluator().evaluate(value).has_value(); |
91 | } |
92 | |
93 | void ExpressionEvaluator::print() const { |
94 | std::cout << "\nEvaluation context\n" ; |
95 | std::cout << "--------------------\n" ; |
96 | for (const auto& kv : known_values_) { |
97 | std::cout << kv.first->toString() << " = " << kv.second << "\n" ; |
98 | } |
99 | std::cout << "\nPre-computed Values\n" ; |
100 | if (precomputed_values_ != nullptr) { |
101 | precomputed_values_->print(); |
102 | } |
103 | std::cout << "--------------------\n\n" ; |
104 | } |
105 | |
106 | void ExpressionEvaluator::handle(const Int* value) { |
107 | TORCH_INTERNAL_ASSERT(!value->isConst()); |
108 | if (auto def = value->definition()) { |
109 | OptOutConstDispatch::handle(def); |
110 | } |
111 | } |
112 | |
113 | void ExpressionEvaluator::handle(const Double* value) { |
114 | TORCH_INTERNAL_ASSERT(!value->isConst()); |
115 | if (auto def = value->definition()) { |
116 | OptOutConstDispatch::handle(def); |
117 | } |
118 | } |
119 | |
120 | void ExpressionEvaluator::handle(const NamedScalar* named_scalar) { |
121 | const auto& name = named_scalar->name(); |
122 | for (auto pt : kParallelTypeThreads) { |
123 | auto pt_val_it = known_parallel_dimensions_.find(pt); |
124 | if (pt_val_it == known_parallel_dimensions_.end()) { |
125 | continue; |
126 | } |
127 | if (name == stringifyThreadSize(pt)) { |
128 | known_values_[named_scalar] = pt_val_it->second; |
129 | return; |
130 | } |
131 | } |
132 | } |
133 | |
134 | void ExpressionEvaluator::handle(const UnaryOp* unary_op) { |
135 | using namespace IntOrDouble_functions; |
136 | const auto in = evaluate(unary_op->in()); |
137 | if (in.has_value()) { |
138 | switch (unary_op->getUnaryOpType()) { |
139 | case UnaryOpType::Neg: |
140 | known_values_[unary_op->out()] = -*in; |
141 | break; |
142 | case UnaryOpType::Set: |
143 | known_values_[unary_op->out()] = *in; |
144 | break; |
145 | case UnaryOpType::Cast: |
146 | if (unary_op->out()->getDataType() == DataType::Int) { |
147 | known_values_[unary_op->out()] = in->cast<int64_t>(); |
148 | } else if (unary_op->out()->getDataType() == DataType::Double) { |
149 | known_values_[unary_op->out()] = in->cast<double>(); |
150 | } else { |
151 | TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator" ); |
152 | } |
153 | break; |
154 | case UnaryOpType::Abs: |
155 | known_values_[unary_op->out()] = abs(*in); |
156 | break; |
157 | default: |
158 | TORCH_CHECK( |
159 | false, |
160 | "Unexpected operator type " , |
161 | unary_op->getUnaryOpType(), |
162 | " in " , |
163 | unary_op->toString()); |
164 | } |
165 | } |
166 | } |
167 | |
168 | void ExpressionEvaluator::handle(const BinaryOp* binary_op) { |
169 | using namespace IntOrDouble_functions; |
170 | const auto lhs = evaluate(binary_op->lhs()); |
171 | const auto rhs = evaluate(binary_op->rhs()); |
172 | if (lhs.has_value() && rhs.has_value()) { |
173 | switch (binary_op->getBinaryOpType()) { |
174 | case BinaryOpType::Add: |
175 | known_values_[binary_op->out()] = *lhs + *rhs; |
176 | break; |
177 | case BinaryOpType::Sub: |
178 | known_values_[binary_op->out()] = *lhs - *rhs; |
179 | break; |
180 | case BinaryOpType::Mul: |
181 | known_values_[binary_op->out()] = *lhs * *rhs; |
182 | break; |
183 | case BinaryOpType::Div: |
184 | TORCH_CHECK(*rhs != 0); |
185 | known_values_[binary_op->out()] = *lhs / *rhs; |
186 | break; |
187 | case BinaryOpType::Mod: |
188 | TORCH_CHECK(*rhs != 0); |
189 | known_values_[binary_op->out()] = *lhs % *rhs; |
190 | break; |
191 | case BinaryOpType::CeilDiv: |
192 | TORCH_CHECK(*rhs != 0); |
193 | known_values_[binary_op->out()] = ceildiv(*lhs, *rhs); |
194 | break; |
195 | case BinaryOpType::And: |
196 | known_values_[binary_op->out()] = Int::ScalarType(*lhs && *rhs); |
197 | break; |
198 | default: |
199 | TORCH_CHECK(!"Unexpected operator type" ); |
200 | } |
201 | } |
202 | } |
203 | |
204 | } // namespace kir |
205 | } // namespace cuda |
206 | } // namespace fuser |
207 | } // namespace jit |
208 | } // namespace torch |
209 | |