1#pragma once
#include <string>
#include <type_traits>
#include <utility>

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/utils/variadic.h>
6
7namespace torch {
8namespace jit {
9
10struct Value;
11
12/**
13 * A value with optional extra name and location information. Used during
14 * schema matching to provide extra error information and resolve kwargs.
15 */
16struct NamedValue {
17 NamedValue(const SourceRange& loc, const std::string& name, Value* value)
18 : loc_(loc), name_(name), value_(value) {}
19 NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
20
21 /* implicit */ NamedValue(Value* value) : value_(value) {}
22 NamedValue(const std::string& name, Value* value)
23 : name_(name), value_(value) {}
24
25 /* implicit */ NamedValue(IValue value)
26 : value_(nullptr), ivalue_(std::move(value)) {}
27
28 NamedValue(const std::string& name, IValue value)
29 : name_(name), ivalue_(std::move(value)) {}
30
31 template <
32 typename T,
33 typename = enable_if_t<
34 (!std::is_same<decay_t<T>, NamedValue>::value &&
35 !std::is_same<decay_t<T>, Value*>::value &&
36 !std::is_same<decay_t<T>, IValue>::value)>>
37 // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
38 NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
39
40 template <
41 typename T,
42 typename = enable_if_t<
43 (!std::is_same<decay_t<T>, Value*>::value &&
44 !std::is_same<decay_t<T>, IValue>::value)>>
45 NamedValue(const std::string& name, T&& t)
46 : NamedValue(name, IValue(std::forward<T>(t))) {}
47
48 SourceRange locOr(const SourceRange& backup_location) const {
49 if (!loc_)
50 return backup_location;
51 return loc();
52 }
53
54 // note: this will insert a constant node into the graph at the current
55 // insert point if this NamedValue is actually a constant
56 Value* value(Graph& g) const {
57 if (!value_)
58 return insertConstant(
59 g, ivalue_); // use insertConstant to remove need to include ir.h here
60 return value_;
61 }
62
63 const std::string& name() const {
64 AT_ASSERT(name_);
65 return *name_;
66 }
67
68 const SourceRange& loc() const {
69 AT_ASSERT(loc_);
70 return *loc_;
71 }
72
73 at::TypePtr type() const;
74
75 private:
76 c10::optional<SourceRange> loc_;
77 c10::optional<std::string> name_;
78 Value* value_{nullptr};
79 // only valid if value_ == nullptr;
80 IValue ivalue_;
81};
82
83} // namespace jit
84} // namespace torch
85