1#pragma once
2
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/api/method.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
9
10namespace torch {
11namespace jit {
12
// Forward declaration: this header only needs a shared-pointer handle to a
// Resolver (it is the optional argument of Object::define).
struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>;

// Intrusively reference-counted handle to the runtime object that an Object
// wraps (stored in Object::_ivalue_).
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
17
// Thrown in C++ land when `attr` cannot find the requested field. The Python
// binding code converts it into a Python AttributeError.
class ObjectAttributeError : public std::runtime_error {
 public:
  // explicit: an error message string should not silently convert into an
  // exception object.
  explicit ObjectAttributeError(const std::string& what)
      : std::runtime_error(what) {}
};
24
25// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
26struct TORCH_API Object {
27 Object() = default;
28 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
29 Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
30 Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
31 Object(
32 c10::QualifiedName,
33 std::shared_ptr<CompilationUnit> cu,
34 bool shouldMangle = false);
35
36 ObjectPtr _ivalue() const {
37 TORCH_INTERNAL_ASSERT(_ivalue_);
38 return _ivalue_;
39 }
40
41 c10::ClassTypePtr type() const {
42 return _ivalue()->type();
43 }
44
45 struct Property {
46 std::string name;
47 Method getter_func;
48 c10::optional<Method> setter_func;
49 };
50
51 void setattr(const std::string& name, c10::IValue v) {
52 if (_ivalue()->type()->hasConstant(name)) {
53 TORCH_CHECK(
54 false,
55 "Can't set constant '",
56 name,
57 "' which has value:",
58 _ivalue()->type()->getConstant(name));
59 } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
60 const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
61 TORCH_CHECK(
62 v.type()->isSubtypeOf(*expected),
63 "Expected a value of type '",
64 expected->repr_str(),
65 "' for field '",
66 name,
67 "', but found '",
68 v.type()->repr_str(),
69 "'");
70 _ivalue()->setSlot(*slot, std::move(v));
71 } else {
72 TORCH_CHECK(false, "Module has no attribute '", name, "'");
73 }
74 }
75
76 c10::IValue attr(const std::string& name) const {
77 if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
78 return _ivalue()->getSlot(*r);
79 }
80 if (auto r = _ivalue()->type()->findConstantSlot(name)) {
81 return _ivalue()->type()->getConstant(*r);
82 }
83 std::stringstream err;
84 err << _ivalue()->type()->repr_str() << " does not have a field with name '"
85 << name.c_str() << "'";
86 throw ObjectAttributeError(err.str());
87 }
88
89 c10::IValue attr(const std::string& name, c10::IValue or_else) const {
90 if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
91 return _ivalue()->getSlot(*r);
92 }
93 if (auto r = _ivalue()->type()->findConstantSlot(name)) {
94 return _ivalue()->type()->getConstant(*r);
95 }
96 return or_else;
97 }
98
99 bool hasattr(const std::string& name) const {
100 return _ivalue()->type()->hasAttribute(name) ||
101 _ivalue()->type()->hasConstant(name);
102 }
103
104 // each object owns its methods. The reference returned here
105 // is guaranteed to stay valid until this module has been destroyed
106 Method get_method(const std::string& name) const {
107 if (auto method = find_method(name)) {
108 return *method;
109 }
110 AT_ERROR("Method '", name, "' is not defined.");
111 }
112
113 const std::vector<Method> get_methods() const {
114 return c10::fmap(type()->methods(), [&](Function* func) {
115 return Method(_ivalue(), func);
116 });
117 }
118
119 bool has_property(const std::string& name) const {
120 for (const auto& prop : type()->properties()) {
121 if (prop.name == name) {
122 return true;
123 }
124 }
125 return false;
126 }
127
128 const Property get_property(const std::string& name) const {
129 for (const auto& prop : type()->properties()) {
130 if (prop.name == name) {
131 c10::optional<Method> setter = c10::nullopt;
132 if (prop.setter) {
133 setter = Method(_ivalue(), prop.setter);
134 }
135 return Property{
136 prop.name, Method(_ivalue(), prop.getter), std::move(setter)};
137 }
138 }
139 AT_ERROR("Property '", name, "' is not defined.");
140 }
141
142 const std::vector<Property> get_properties() const {
143 return c10::fmap(type()->properties(), [&](ClassType::Property prop) {
144 c10::optional<Method> setter = c10::nullopt;
145 if (prop.setter) {
146 setter = Method(_ivalue(), prop.setter);
147 }
148 return Property{
149 prop.name, Method(_ivalue(), prop.getter), std::move(setter)};
150 });
151 }
152
153 c10::optional<Method> find_method(const std::string& basename) const;
154
155 /// Run a method from this module.
156 ///
157 /// For example:
158 /// @code
159 /// IValue output = module->run("relu_script", a, b);
160 /// @endcode
161 ///
162 /// To get a compile a module from a source string, see torch::jit::compile
163 ///
164 /// @param method_name The name of the method to run
165 /// @param args Arguments to be passed to the method
166 /// @return An IValue containing the return value (or values if it is a tuple)
167 /// from the method
168 template <typename... Types>
169 IValue run_method(const std::string& method_name, Types&&... args) {
170 return get_method(method_name)({IValue(std::forward<Types>(args))...});
171 }
172
173 // so that C++ users can easily add methods
174 void define(const std::string& src, const ResolverPtr& resolver = nullptr);
175
176 size_t num_slots() const {
177 return _ivalue()->slots().size();
178 }
179
180 // shallow copy the object
181 Object copy() const;
182
183 // Copies all the attributes of the object recursively without creating new
184 // `ClassType`, including deepcopy of Tensors
185 Object deepcopy() const;
186
187 private:
188 // mutable be we lazily initialize in module_object.
189 mutable ObjectPtr _ivalue_;
190};
191
namespace script {
// We once had a `script::` namespace that was deleted. This alias only keeps
// the public API backward compatible: `script::Object` names the same type as
// `torch::jit::Object`. New code should not use this type alias.
using Object = ::torch::jit::Object;
} // namespace script
197} // namespace jit
198} // namespace torch
199