1 | |
2 | #pragma once |
3 | |
4 | #include <c10/macros/Export.h> |
5 | |
6 | #include <dispatch.h> |
7 | #include <dynamic_type.h> |
8 | #include <evaluator_common.h> |
9 | #include <kernel_ir.h> |
10 | |
11 | #include <c10/util/Optional.h> |
12 | |
13 | #include <unordered_map> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
20 | class GpuLower; |
21 | |
22 | namespace kir { |
23 | |
24 | //! Calculate Kernel IR expressions |
25 | //! |
26 | //! How to evaluate Kernel IR expressions: |
27 | //! |
28 | //! ```cpp |
29 | //! kir::ExpressionEvaluator eval; |
30 | //! eval.bind(symbolic_value, concrete_value); |
31 | //! ... bind more values ... |
32 | //! const auto result = eval.evaluate(interesting_value); |
33 | //! if (result.has_value()) { |
34 | //! ... we have successfully calculated the result ... |
35 | //! } else { |
36 | //! ... expression can't be evaluated ... |
37 | //! } |
38 | //! ``` |
39 | //! |
40 | class TORCH_CUDA_CU_API ExpressionEvaluator : private OptInConstDispatch { |
41 | public: |
42 | //! Set a concrete value for a symbolic value |
43 | void bind(const Val* value, IntOrDouble concrete_value); |
44 | |
45 | //! Set a concrete value for a parallel dimension |
46 | void bind(ParallelType pt, Int::ScalarType concrete_value); |
47 | |
48 | //! Try to evaluate a Kernel IR value |
49 | c10::optional<IntOrDouble> evaluate(const Val* value); |
50 | |
51 | //! Returns true if `value` is known before binding kernel inputs |
52 | static bool isConst(const Val* value); |
53 | |
54 | //! Debugging helper, prints all the currently known values |
55 | void print() const; |
56 | |
57 | auto& precomputedValues() { |
58 | return precomputed_values_; |
59 | } |
60 | |
61 | private: |
62 | void handle(const Int* value) final; |
63 | void handle(const Double* value) final; |
64 | void handle(const NamedScalar* named_scalar) final; |
65 | void handle(const UnaryOp* unary_op) final; |
66 | void handle(const BinaryOp* binary_op) final; |
67 | |
68 | private: |
69 | std::unordered_map<const Val*, IntOrDouble> known_values_; |
70 | KernelPrecomputedValues* precomputed_values_ = nullptr; |
71 | std::unordered_map<ParallelType, Int::ScalarType, TypeHash> |
72 | known_parallel_dimensions_; |
73 | }; |
74 | |
75 | } // namespace kir |
76 | } // namespace cuda |
77 | } // namespace fuser |
78 | } // namespace jit |
79 | } // namespace torch |
80 | |