1
2#pragma once
3
4#include <c10/macros/Export.h>
5
6#include <dispatch.h>
7#include <dynamic_type.h>
8#include <evaluator_common.h>
9#include <kernel_ir.h>
10
11#include <c10/util/Optional.h>
12
13#include <unordered_map>
14
15namespace torch {
16namespace jit {
17namespace fuser {
18namespace cuda {
19
20class GpuLower;
21
22namespace kir {
23
24//! Calculate Kernel IR expressions
25//!
26//! How to evaluate Kernel IR expressions:
27//!
28//! ```cpp
29//! kir::ExpressionEvaluator eval;
30//! eval.bind(symbolic_value, concrete_value);
31//! ... bind more values ...
32//! const auto result = eval.evaluate(interesting_value);
33//! if (result.has_value()) {
34//! ... we have successfully calculated the result ...
35//! } else {
36//! ... expression can't be evaluated ...
37//! }
38//! ```
39//!
40class TORCH_CUDA_CU_API ExpressionEvaluator : private OptInConstDispatch {
41 public:
42 //! Set a concrete value for a symbolic value
43 void bind(const Val* value, IntOrDouble concrete_value);
44
45 //! Set a concrete value for a parallel dimension
46 void bind(ParallelType pt, Int::ScalarType concrete_value);
47
48 //! Try to evaluate a Kernel IR value
49 c10::optional<IntOrDouble> evaluate(const Val* value);
50
51 //! Returns true if `value` is known before binding kernel inputs
52 static bool isConst(const Val* value);
53
54 //! Debugging helper, prints all the currently known values
55 void print() const;
56
57 auto& precomputedValues() {
58 return precomputed_values_;
59 }
60
61 private:
62 void handle(const Int* value) final;
63 void handle(const Double* value) final;
64 void handle(const NamedScalar* named_scalar) final;
65 void handle(const UnaryOp* unary_op) final;
66 void handle(const BinaryOp* binary_op) final;
67
68 private:
69 std::unordered_map<const Val*, IntOrDouble> known_values_;
70 KernelPrecomputedValues* precomputed_values_ = nullptr;
71 std::unordered_map<ParallelType, Int::ScalarType, TypeHash>
72 known_parallel_dimensions_;
73};
74
75} // namespace kir
76} // namespace cuda
77} // namespace fuser
78} // namespace jit
79} // namespace torch
80