1#pragma once
2
3#include <c10/macros/Export.h>
4#include <dynamic_type.h>
5#include <ir_interface_nodes.h>
6#include <iter_visitor.h>
7
8#include <c10/util/Optional.h>
9
10#include <string>
11#include <unordered_map>
12
13namespace torch {
14namespace jit {
15namespace fuser {
16namespace cuda {
17
18class FusionPrecomputedValues;
19
20//! Calculate Fusion IR expressions
21class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
22 public:
23 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
24 explicit ExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {}
25
26 //! Returns the associated fusion object
27 Fusion* fusion() const {
28 return fusion_;
29 }
30
31 //! Bind a concrete value to an IR variable
32 void bind(Val* value, const IntOrDouble& concrete_value);
33
34 //! Bind a concrete value to a named scalar
35 void bind(const std::string& name, const IntOrDouble& concrete_value);
36
37 //! Try to evaluate a Fusion IR value
38 c10::optional<IntOrDouble> evaluate(Val* value);
39
40 //! Debugging helper, prints all the currently known values
41 void print() const;
42
43 void bindPrecomputedValues(FusionPrecomputedValues* precomputed_values) {
44 evaluator_precomputed_values_ = precomputed_values;
45 }
46
47 auto precomputedValues() {
48 return evaluator_precomputed_values_;
49 }
50
51 private:
52 c10::optional<IntOrDouble> getValue(Val* value);
53
54 void handle(UnaryOp*) final;
55 void handle(BinaryOp*) final;
56 // TODO: handle swizzle
57
58 private:
59 std::unordered_map<const Val*, IntOrDouble> known_values_;
60 std::unordered_map<std::string, IntOrDouble> known_named_scalars_;
61 Fusion* fusion_ = nullptr;
62 FusionPrecomputedValues* evaluator_precomputed_values_ = nullptr;
63};
64
65} // namespace cuda
66} // namespace fuser
67} // namespace jit
68} // namespace torch
69