1 | #pragma once |
---|---|
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <dynamic_type.h> |
5 | #include <ir_interface_nodes.h> |
6 | #include <iter_visitor.h> |
7 | |
8 | #include <c10/util/Optional.h> |
9 | |
10 | #include <string> |
11 | #include <unordered_map> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | class FusionPrecomputedValues; |
19 | |
20 | //! Calculate Fusion IR expressions |
21 | class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { |
22 | public: |
23 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
24 | explicit ExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {} |
25 | |
26 | //! Returns the associated fusion object |
27 | Fusion* fusion() const { |
28 | return fusion_; |
29 | } |
30 | |
31 | //! Bind a concrete value to an IR variable |
32 | void bind(Val* value, const IntOrDouble& concrete_value); |
33 | |
34 | //! Bind a concrete value to a named scalar |
35 | void bind(const std::string& name, const IntOrDouble& concrete_value); |
36 | |
37 | //! Try to evaluate a Fusion IR value |
38 | c10::optional<IntOrDouble> evaluate(Val* value); |
39 | |
40 | //! Debugging helper, prints all the currently known values |
41 | void print() const; |
42 | |
43 | void bindPrecomputedValues(FusionPrecomputedValues* precomputed_values) { |
44 | evaluator_precomputed_values_ = precomputed_values; |
45 | } |
46 | |
47 | auto precomputedValues() { |
48 | return evaluator_precomputed_values_; |
49 | } |
50 | |
51 | private: |
52 | c10::optional<IntOrDouble> getValue(Val* value); |
53 | |
54 | void handle(UnaryOp*) final; |
55 | void handle(BinaryOp*) final; |
56 | // TODO: handle swizzle |
57 | |
58 | private: |
59 | std::unordered_map<const Val*, IntOrDouble> known_values_; |
60 | std::unordered_map<std::string, IntOrDouble> known_named_scalars_; |
61 | Fusion* fusion_ = nullptr; |
62 | FusionPrecomputedValues* evaluator_precomputed_values_ = nullptr; |
63 | }; |
64 | |
65 | } // namespace cuda |
66 | } // namespace fuser |
67 | } // namespace jit |
68 | } // namespace torch |
69 |