#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>

#include <evaluator_common.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

template <typename VALTYPE>
std::vector<VALTYPE*> getImmediateProducers(VALTYPE* val) {
  if (val->definition()) {
    auto expr = val->definition();
    return expr->inputs();
  } else {
    return {};
  }
}
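
// Example (illustrative; 'a', 'b', 'c' are hypothetical values): for a
// value defined as c = add(a, b), getImmediateProducers(c) returns
// {a, b}; a fusion input, which has no definition, yields {}.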

//! IR-generic utility: collects all the producers required for the
//! given list of IR values and returns them, along with the original
//! list, in topological order.
template <typename VALTYPE>
std::vector<VALTYPE*> makeSortedEvaluationList(std::vector<VALTYPE*> input) {
  // Deduplicate
  std::vector<VALTYPE*> to_sort;
  std::unordered_set<VALTYPE*> visited;
  for (auto val : input) {
    if (!visited.count(val)) {
      to_sort.push_back(val);
      visited.insert(val);
    }
  }

  std::vector<VALTYPE*> sorted;
  visited.clear();
  // Topological sort
  // Note: producers that are not in the original list are not explicitly
  // excluded. This should be acceptable for the intended use.
  while (!to_sort.empty()) {
    auto top_val = to_sort.back();
    if (visited.count(top_val)) {
      to_sort.pop_back();
    } else {
      bool ready_to_pop = true;
      for (auto producer : getImmediateProducers(top_val)) {
        if (!visited.count(producer)) {
          ready_to_pop = false;
          to_sort.push_back(producer);
        }
      }
      if (ready_to_pop) {
        visited.insert(top_val);
        sorted.push_back(top_val);
        to_sort.pop_back();
      }
    }
  }

  return sorted;
}
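
// Usage sketch (illustrative; the values are hypothetical). Given
//
//   c = a + b
//   d = c * a
//
// makeSortedEvaluationList({d}) walks back through the definitions and
// returns {a, b, c, d} (producer order within a level may vary): every
// producer appears before its consumers, so one forward pass over the
// result can evaluate every entry.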

//! Kernel IR utility, collects all the symbolic values
//! used in allocation nodes.
void collectBufferSizes(
    std::vector<Val*>& into,
    const std::vector<Expr*>& exprs) {
  for (auto expr : exprs) {
    if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
      into.push_back(allocate->size());
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      collectBufferSizes(into, for_loop->body().exprs());
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      collectBufferSizes(into, ite->thenBody().exprs());
      collectBufferSizes(into, ite->elseBody().exprs());
    }
  }
}
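
// Traversal sketch (illustrative): the recursion mirrors the nesting of
// the kernel IR, e.g.
//
//   kir::ForLoop i
//     kir::IfThenElse p
//       kir::Allocate T1, size s   // 's' is collected here
//
// Only ForLoop and IfThenElse bodies are descended into; all other
// expression kinds are skipped.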

//! Kernel IR utility, collects all the kernel symbolic
//! values we will need at runtime, i.e. after the
//! generated cuda kernel has already been compiled.
//! The values are to be used for runtime logic, like
//! `computeLaunchParams`.
std::vector<Val*> collectRuntimeUsedValues(kir::Kernel* kernel) {
  std::vector<Val*> ret;
  auto all_tvs = ir_utils::allTvs(kernel);
  // Collect extents and scalar inputs
  for (auto tv : all_tvs) {
    for (auto id : tv->domain()->domain()) {
      ret.push_back(id->extent());
    }
  }
  for (auto inp : kernel->inputs()) {
    if (inp->isA<Int>() || inp->isA<Double>()) {
      ret.push_back(inp);
    }
  }
  // Collect allocation sizes:
  collectBufferSizes(ret, kernel->topLevelExprs());
  return makeSortedEvaluationList(ret);
}

//! Fusion IR version of the above. Note that allocation sizes are not
//! collected here, since allocation nodes only exist in kernel IR.
std::vector<Val*> collectRuntimeUsedValues(Fusion* fusion) {
  std::vector<Val*> ret;
  auto all_tvs = ir_utils::allTvs(fusion);
  // Collect extents and scalar inputs
  for (auto tv : all_tvs) {
    for (auto id : tv->domain()->domain()) {
      ret.push_back(id->extent());
    }
  }
  for (auto inp : fusion->inputs()) {
    if (inp->isA<Int>() || inp->isA<Double>()) {
      ret.push_back(inp);
    }
  }
  return makeSortedEvaluationList(ret);
}

} // namespace

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::initializeValueList(
    typename IRContext::EVALUATOR_TYPE& const_evaluator,
    const std::vector<Val*>& sorted_value_list) {
  // Initialize workspace
  num_of_values_ = sorted_value_list.size();
  defined_ = std::vector<bool>(num_of_values_, false);
  is_constant_ = std::vector<bool>(num_of_values_, false);
  values_ = std::vector<IntOrDouble>(num_of_values_, -1);

  // Fill in constants and assign evaluator indices
  for (const auto i : c10::irange(num_of_values_)) {
    // Use an expression evaluator to test if value is const
    auto const_val = const_evaluator.evaluate(sorted_value_list[i]);
    if (const_val.has_value()) {
      is_constant_[i] = true;
      values_[i] = const_val.value();
    }
    sorted_value_list[i]->setEvaluatorIndex(i);
  }
}
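
// Workspace sketch (illustrative): for a sorted list {4, i0, i0 * 4},
// where 4 is a constant and i0 is a runtime input, initialization yields
//
//   index        :  0      1      2
//   is_constant_ :  true   false  false
//   defined_     :  false  false  false
//   values_      :  4      -1     -1
//
// Each Val remembers its slot via setEvaluatorIndex(), so later lookups
// and instruction operands reduce to plain integer indexing.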

template <typename IRContext>
c10::optional<IntOrDouble> PrecomputedValuesBase<IRContext>::getMaybeValueFor(
    const Val* val) {
  auto index = val->evaluatorIndex();
  if (index < 0) {
    return c10::nullopt;
  }
  if (!defined_[index] && !is_constant_[index]) {
    return c10::nullopt;
  }
  return values_[index];
}
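
// Usage sketch (hypothetical caller code): after binding and evaluation,
// a stored extent can be queried without re-walking the IR:
//
//   auto maybe_extent = precomputed.getMaybeValueFor(id->extent());
//   if (maybe_extent.has_value()) { /* use *maybe_extent */ }
//
// Note that c10::nullopt is returned both for values that were never
// loaded (index < 0) and for values not computed in the current round.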

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::print() const {
  std::cout << "Precomputed Values:\n";
  for (auto i : c10::irange(symbols_.size())) {
    if (defined_[i]) {
      std::cout << symbols_[i]->toInlineString() << " = " << values_[i]
                << std::endl;
    }
  }
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::evaluate() {
  FUSER_PERF_SCOPE("PrecomputedValues::Evaluate");
  value_machine_->run();
  validate();
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::invalidate() {
  // clear binding values
  binding_log_.clear();

  // invalidate value entries
  std::fill(defined_.begin(), defined_.end(), false);

  // invalidate flag
  has_valid_values_ = false;
}

template <typename IRContext>
void PrecomputedValuesBase<IRContext>::validate() {
  FUSER_PERF_SCOPE("PrecomputedValues::Validate");
  for (auto it : binding_log_) {
    TORCH_INTERNAL_ASSERT(
        values_[it.first] == it.second,
        "Precomputed values failed to validate.",
        "\nSomething unexpected changed between the compilation and execution.\n",
        values_[it.first],
        " != ",
        it.second);
  }
  has_valid_values_ = true;
}

template <typename IRContext>
NaiveValueMachine<IRContext>::NaiveValueMachine(
    PrecomputedValuesBase<IRContext>& precomputed_values)
    : precomputed_values_(precomputed_values) {
  num_of_instructions_ = 0;
  for (auto val : precomputed_values_.symbols_) {
    auto def = val->definition();
    if (def) {
      if (auto uop = dynamic_cast<UnaryOp*>(def)) {
        makeUnaryOp(uop);
      } else if (auto bop = dynamic_cast<BinaryOp*>(def)) {
        makeBinaryOp(bop);
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unsupported expr");
      }
    }
  }
}
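
// Instruction-generation sketch (illustrative values): for the symbols
//
//   e0 = i0 * 4
//   e1 = e0 + i1
//
// the constructor emits, in symbol order, something like
//
//   inst 0: BINARY_OP Mul  src0=idx(i0) src1=idx(4)  dest=idx(e0)
//   inst 1: BINARY_OP Add  src0=idx(e0) src1=idx(i1) dest=idx(e1)
//
// Values without a definition (inputs, constants) emit no instruction;
// their slots are filled by binding or by the const evaluator instead.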

template <typename IRContext>
void NaiveValueMachine<IRContext>::run() {
  for (const auto i : c10::irange(num_of_instructions_)) {
    // Skip this instruction if the dest location
    // has already been computed or is constant.
    if (precomputed_values_.defined_[dest_[i]] ||
        precomputed_values_.is_constant_[dest_[i]]) {
      continue;
    }
    runInstruction(i);
  }
}
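
// Note: the symbols were loaded in topological order (see
// makeSortedEvaluationList), so every instruction's sources are produced
// by earlier instructions or bound externally. A single forward pass is
// therefore sufficient; instructions whose sources are still unbound
// simply leave their destination undefined.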

template <typename IRContext>
void NaiveValueMachine<IRContext>::makeUnaryOp(UnaryOp* uop) {
  int in = uop->inputs()[0]->evaluatorIndex();
  int out = uop->outputs()[0]->evaluatorIndex();
  TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", uop);
  TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", uop);

  int index = makeInstructionEntry();
  inst_type_[index] = InstructionType::UNARY_OP;
  uop_type_[index] = IRContext::getOpType(uop);
  if (uop_type_[index] == UnaryOpType::Cast) {
    data_type_[index] = uop->out()->getDataType().value();
  }
  src0_[index] = in;
  dest_[index] = out;
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::makeBinaryOp(BinaryOp* bop) {
  int in0 = bop->inputs()[0]->evaluatorIndex();
  int in1 = bop->inputs()[1]->evaluatorIndex();
  int out = bop->outputs()[0]->evaluatorIndex();

  TORCH_INTERNAL_ASSERT(in0 >= 0, "Integer Machine: unknown lhs: ", bop);
  TORCH_INTERNAL_ASSERT(in1 >= 0, "Integer Machine: unknown rhs: ", bop);
  TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", bop);

  int index = makeInstructionEntry();
  inst_type_[index] = InstructionType::BINARY_OP;
  bop_type_[index] = IRContext::getOpType(bop);
  src0_[index] = in0;
  src1_[index] = in1;
  dest_[index] = out;
}

template <typename IRContext>
int NaiveValueMachine<IRContext>::makeInstructionEntry() {
  int index = num_of_instructions_++;
  inst_type_.push_back(InstructionType::UNARY_OP);
  uop_type_.push_back(UnaryOpType::Abs);
  bop_type_.push_back(BinaryOpType::Add);
  data_type_.push_back(DataType::Null);
  src0_.push_back(-1);
  src1_.push_back(-1);
  dest_.push_back(-1);
  return index;
}
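
// The instruction stream is stored struct-of-arrays style: each parallel
// vector above holds one field per instruction, indexed by the value
// returned here. Fields that do not apply keep their placeholder
// defaults (e.g. bop_type_ stays Add for a unary instruction), which is
// why every vector is pushed for every entry.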

template <typename IRContext>
void NaiveValueMachine<IRContext>::runInstruction(int index) {
  switch (inst_type_[index]) {
    case InstructionType::UNARY_OP:
      runUnaryOp(index);
      break;
    case InstructionType::BINARY_OP:
      runBinaryOp(index);
      break;
  }
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::runUnaryOp(int index) {
  using namespace IntOrDouble_functions;
  int src_index = src0_[index];
  bool src_defined = precomputed_values_.defined_[src_index];
  bool src_is_const = precomputed_values_.is_constant_[src_index];
  if (!src_defined && !src_is_const) {
    return;
  }

  int dest_index = dest_[index];

  auto& src = precomputed_values_.values_[src_index];
  auto& dest = precomputed_values_.values_[dest_index];

  switch (uop_type_[index]) {
    case UnaryOpType::Neg:
      dest = -src;
      break;
    case UnaryOpType::Set:
      dest = src;
      break;
    case UnaryOpType::Cast:
      if (data_type_[index] == DataType::Double) {
        dest = src.template cast<double>();
      } else if (data_type_[index] == DataType::Int) {
        dest = src.template cast<int64_t>();
      } else {
        TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator");
      }
      break;
    case UnaryOpType::Abs:
      dest = abs(src);
      break;
    default:
      TORCH_CHECK(false, "Unexpected operator type ", uop_type_[index]);
  }

  precomputed_values_.defined_[dest_index] = true;
}

template <typename IRContext>
void NaiveValueMachine<IRContext>::runBinaryOp(int index) {
  using namespace IntOrDouble_functions;
  int src0_index = src0_[index];
  int src1_index = src1_[index];
  bool src0_is_const = precomputed_values_.is_constant_[src0_index];
  bool src1_is_const = precomputed_values_.is_constant_[src1_index];

  bool src_defined =
      (precomputed_values_.defined_[src0_index] || src0_is_const) &&
      (precomputed_values_.defined_[src1_index] || src1_is_const);

  if (!src_defined) {
    return;
  }
  int dest_index = dest_[index];

  auto& lhs = precomputed_values_.values_[src0_index];
  auto& rhs = precomputed_values_.values_[src1_index];
  auto& dest = precomputed_values_.values_[dest_index];

  switch (bop_type_[index]) {
    case BinaryOpType::Add:
      dest = lhs + rhs;
      break;
    case BinaryOpType::Sub:
      dest = lhs - rhs;
      break;
    case BinaryOpType::Mul:
      dest = lhs * rhs;
      break;
    case BinaryOpType::Div:
      TORCH_CHECK(rhs != 0);
      dest = lhs / rhs;
      break;
    case BinaryOpType::Mod:
      TORCH_CHECK(rhs != 0);
      dest = lhs % rhs;
      break;
    case BinaryOpType::CeilDiv:
      TORCH_CHECK(rhs != 0);
      dest = ceildiv(lhs, rhs);
      break;
    case BinaryOpType::And:
      dest = Int::ScalarType(lhs && rhs);
      break;
    case BinaryOpType::Max:
      dest = lhs > rhs ? lhs : rhs;
      break;
    case BinaryOpType::Min:
      dest = lhs < rhs ? lhs : rhs;
      break;
    default:
      TORCH_CHECK(false, "Unexpected operator type");
  }

  precomputed_values_.defined_[dest_index] = true;
}

KernelPrecomputedValues::KernelPrecomputedValues(kir::Kernel* kernel) {
  loadSymbols(collectRuntimeUsedValues(kernel));
  kir::ExpressionEvaluator evaluator;
  initializeValueList(evaluator, symbols());
  initializeNamedScalars();
  initializeIntegerMachine();
}
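
// Typical lifecycle (hedged sketch; 'kernel', 'args', and 'launch_params'
// are hypothetical caller-side objects, and bdimx() stands in for however
// the chosen blockDim.x value is obtained):
//
//   KernelPrecomputedValues precomputed(kernel);
//   precomputed.bindKernelInputs(kernel, args); // runtime tensor sizes
//   precomputed.bindConcreteParallelTypeValue(
//       ParallelType::TIDx, launch_params.bdimx()); // launch dimensions
//   precomputed.evaluate(); // run the value machine
//   auto v = precomputed.getMaybeValueFor(some_extent);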

// TODO: move this to the base class
void KernelPrecomputedValues::bindTensorMetaData(
    TensorView* tv,
    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->domain()->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
    auto value = tensor_arg_abstract->getSize(dim);
    bindValue(extent->evaluatorIndex(), value);
  }
}

namespace {

//! Compares the name of a given scalar with the thread size strings
//! and returns the corresponding parallel type if a match is found.
c10::optional<ParallelType> getMaybeThreadSizeParallelType(
    NamedScalar* named_scalar) {
  auto& var_name = named_scalar->name();
  for (auto ptype : kParallelTypeThreads) {
    if (var_name == stringifyThreadSize(ptype)) {
      return ptype;
    }
  }
  return c10::nullopt;
}
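
// Example (assuming stringifyThreadSize(ParallelType::TIDx) yields
// "blockDim.x"): a NamedScalar named "blockDim.x" maps to TIDx, while a
// scalar named, say, "T0.size[0]" matches nothing and yields nullopt.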

} // namespace

void KernelPrecomputedValues::initializeNamedScalars() {
  for (auto val : symbols()) {
    if (auto named_scalar = dynamic_cast<NamedScalar*>(val)) {
      auto maybe_parallel_type = getMaybeThreadSizeParallelType(named_scalar);
      if (maybe_parallel_type.has_value()) {
        auto& index_list =
            thread_dim_value_indices_[maybe_parallel_type.value()];
        if (!index_list) {
          index_list = std::make_unique<std::vector<int>>();
        }
        index_list->push_back(val->evaluatorIndex());
      }
    }
  }
}

// TODO: merge this with the version above.
void KernelPrecomputedValues::bindKernelInputs(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    invalidate();
  }

  const auto& inputs = kernel->inputs();
  TORCH_INTERNAL_ASSERT(
      args.size() == inputs.size(), "kernel inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
    auto arg = args[i];
    const auto input = inputs[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
      if (const auto& tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(arg)) {
        bindTensorMetaData(tensor_input, tensor_arg_abstract);
      } else {
        // TODO: a CPU scalar tensor of int type should be bound as a scalar
        // int as well
        TORCH_CHECK(
            arg->isType(ArgType::CpuScalarTensor),
            "binding input to TensorView expects input arg to be of tensor type");
      }
    } else if (input->isScalar()) {
      if (input->dtype() == DataType::Int) {
        TORCH_CHECK(
            arg->isType(ArgType::Long),
            "binding input to integer type expects input arg to be a scalar of Long type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
      } else if (input->dtype() == DataType::Double) {
        TORCH_CHECK(
            arg->isType(ArgType::Double),
            "binding input to double type expects input arg to be a scalar of Double type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const double*>(arg->arg()));
      }
    }
  }
}

void KernelPrecomputedValues::bindParallelExtents(
    const ParallelExtentMap& parallel_extents,
    const LaunchParams& launch_constraint) {
  // Bind the extents of parallelized iterdomains from
  // launch_constraint when applicable.
  // Consistency will be checked at validate().
  for (const auto& it : parallel_extents) {
    auto raw_val = launch_constraint.getRawVal(it.first);
    if (raw_val > 0) {
      for (auto extent : it.second) {
        bindValue(extent->evaluatorIndex(), raw_val);
      }
    }
  }
}

void KernelPrecomputedValues::bindConcreteParallelTypeValue(
    ParallelType pt,
    int64_t value) {
  auto index_list_it = thread_dim_value_indices_.find(pt);
  if (index_list_it != thread_dim_value_indices_.end()) {
    for (auto index : *(index_list_it->second)) {
      bindValue(index, value);
    }
  }
}
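
// Usage sketch (hypothetical value): once launch dimensions are decided,
// e.g. blockDim.x = 128, every NamedScalar slot recorded by
// initializeNamedScalars() for that parallel type is bound in one call:
//
//   precomputed.bindConcreteParallelTypeValue(ParallelType::TIDx, 128);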

FusionPrecomputedValues::FusionPrecomputedValues(Fusion* fusion)
    : fusion_(fusion) {
  loadSymbols(collectRuntimeUsedValues(fusion));
  ExpressionEvaluator evaluator(fusion);
  initializeValueList(evaluator, symbols());
  initializeIntegerMachine();
}

// TODO: move this to the base class
void FusionPrecomputedValues::bindTensorMetaData(
    TensorView* tv,
    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
    auto value = tensor_arg_abstract->getSize(dim);
    precomputedValuesBaseType::bindValue(extent->evaluatorIndex(), value);
  }
}

void FusionPrecomputedValues::bindFusionInputs(
    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    precomputedValuesBaseType::invalidate();
  }

  const auto& inputs = fusion_->inputs();
  TORCH_INTERNAL_ASSERT(
      args.size() == inputs.size(), "fusion inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
    const auto input = inputs[i];
    const ArgAbstract* arg = args[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
      if (const auto& tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(arg)) {
        bindTensorMetaData(tensor_input, tensor_arg_abstract);
      } else {
        TORCH_CHECK(
            arg->isType(ArgType::CpuScalarTensor),
            "binding input to TensorView expects input arg to be of tensor type");
      }
    } else if (input->isScalar()) {
      if (input->getDataType() == DataType::Int) {
        TORCH_CHECK(
            arg->isType(ArgType::Long),
            "binding input to integer type expects input arg to be a scalar of Long type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
      } else if (input->getDataType() == DataType::Double) {
        TORCH_CHECK(
            arg->isType(ArgType::Double),
            "binding input to double type expects input arg to be a scalar of Double type");
        precomputedValuesBaseType::bindValue(
            input->evaluatorIndex(), *static_cast<const double*>(arg->arg()));
      }
    }
  }
}

template class PrecomputedValuesBase<FusionIRContext>;
template class PrecomputedValuesBase<KernelIRContext>;

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch