1 | #pragma once |
2 | #include <ATen/core/ivalue.h> |
3 | #include <torch/csrc/jit/frontend/source_range.h> |
4 | #include <torch/csrc/jit/ir/constants.h> |
5 | #include <torch/csrc/utils/variadic.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | struct Value; |
11 | |
12 | /** |
13 | * A value with optional extra name and location information. Used during |
14 | * schema matching to provide extra error information and resolve kwargs. |
15 | */ |
16 | struct NamedValue { |
17 | NamedValue(const SourceRange& loc, const std::string& name, Value* value) |
18 | : loc_(loc), name_(name), value_(value) {} |
19 | NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {} |
20 | |
21 | /* implicit */ NamedValue(Value* value) : value_(value) {} |
22 | NamedValue(const std::string& name, Value* value) |
23 | : name_(name), value_(value) {} |
24 | |
25 | /* implicit */ NamedValue(IValue value) |
26 | : value_(nullptr), ivalue_(std::move(value)) {} |
27 | |
28 | NamedValue(const std::string& name, IValue value) |
29 | : name_(name), ivalue_(std::move(value)) {} |
30 | |
31 | template < |
32 | typename T, |
33 | typename = enable_if_t< |
34 | (!std::is_same<decay_t<T>, NamedValue>::value && |
35 | !std::is_same<decay_t<T>, Value*>::value && |
36 | !std::is_same<decay_t<T>, IValue>::value)>> |
37 | // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) |
38 | NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {} |
39 | |
40 | template < |
41 | typename T, |
42 | typename = enable_if_t< |
43 | (!std::is_same<decay_t<T>, Value*>::value && |
44 | !std::is_same<decay_t<T>, IValue>::value)>> |
45 | NamedValue(const std::string& name, T&& t) |
46 | : NamedValue(name, IValue(std::forward<T>(t))) {} |
47 | |
48 | SourceRange locOr(const SourceRange& backup_location) const { |
49 | if (!loc_) |
50 | return backup_location; |
51 | return loc(); |
52 | } |
53 | |
54 | // note: this will insert a constant node into the graph at the current |
55 | // insert point if this NamedValue is actually a constant |
56 | Value* value(Graph& g) const { |
57 | if (!value_) |
58 | return insertConstant( |
59 | g, ivalue_); // use insertConstant to remove need to include ir.h here |
60 | return value_; |
61 | } |
62 | |
63 | const std::string& name() const { |
64 | AT_ASSERT(name_); |
65 | return *name_; |
66 | } |
67 | |
68 | const SourceRange& loc() const { |
69 | AT_ASSERT(loc_); |
70 | return *loc_; |
71 | } |
72 | |
73 | at::TypePtr type() const; |
74 | |
75 | private: |
76 | c10::optional<SourceRange> loc_; |
77 | c10::optional<std::string> name_; |
78 | Value* value_{nullptr}; |
79 | // only valid if value_ == nullptr; |
80 | IValue ivalue_; |
81 | }; |
82 | |
83 | } // namespace jit |
84 | } // namespace torch |
85 | |