1 | #pragma once |
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <ATen/core/ivalue.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <torch/csrc/jit/api/method.h> |
7 | |
8 | #include <utility> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
// Only a forward declaration of Resolver is needed in this header; it is used
// exclusively through the shared-ownership handle below.
struct Resolver;
using ResolverPtr = ::std::shared_ptr<Resolver>;
15 | |
16 | using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>; |
17 | |
// Exception thrown in C++ when an attribute lookup (e.g. `Object::attr`)
// fails. The Python binding code translates it into a Python AttributeError.
class ObjectAttributeError : public std::runtime_error {
 public:
  ObjectAttributeError(const std::string& message)
      : std::runtime_error{message} {}
};
24 | |
25 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
26 | struct TORCH_API Object { |
27 | Object() = default; |
28 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
29 | Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} |
30 | Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type); |
31 | Object( |
32 | c10::QualifiedName, |
33 | std::shared_ptr<CompilationUnit> cu, |
34 | bool shouldMangle = false); |
35 | |
36 | ObjectPtr _ivalue() const { |
37 | TORCH_INTERNAL_ASSERT(_ivalue_); |
38 | return _ivalue_; |
39 | } |
40 | |
41 | c10::ClassTypePtr type() const { |
42 | return _ivalue()->type(); |
43 | } |
44 | |
45 | struct Property { |
46 | std::string name; |
47 | Method getter_func; |
48 | c10::optional<Method> setter_func; |
49 | }; |
50 | |
51 | void setattr(const std::string& name, c10::IValue v) { |
52 | if (_ivalue()->type()->hasConstant(name)) { |
53 | TORCH_CHECK( |
54 | false, |
55 | "Can't set constant '" , |
56 | name, |
57 | "' which has value:" , |
58 | _ivalue()->type()->getConstant(name)); |
59 | } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) { |
60 | const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot); |
61 | TORCH_CHECK( |
62 | v.type()->isSubtypeOf(*expected), |
63 | "Expected a value of type '" , |
64 | expected->repr_str(), |
65 | "' for field '" , |
66 | name, |
67 | "', but found '" , |
68 | v.type()->repr_str(), |
69 | "'" ); |
70 | _ivalue()->setSlot(*slot, std::move(v)); |
71 | } else { |
72 | TORCH_CHECK(false, "Module has no attribute '" , name, "'" ); |
73 | } |
74 | } |
75 | |
76 | c10::IValue attr(const std::string& name) const { |
77 | if (auto r = _ivalue()->type()->findAttributeSlot(name)) { |
78 | return _ivalue()->getSlot(*r); |
79 | } |
80 | if (auto r = _ivalue()->type()->findConstantSlot(name)) { |
81 | return _ivalue()->type()->getConstant(*r); |
82 | } |
83 | std::stringstream err; |
84 | err << _ivalue()->type()->repr_str() << " does not have a field with name '" |
85 | << name.c_str() << "'" ; |
86 | throw ObjectAttributeError(err.str()); |
87 | } |
88 | |
89 | c10::IValue attr(const std::string& name, c10::IValue or_else) const { |
90 | if (auto r = _ivalue()->type()->findAttributeSlot(name)) { |
91 | return _ivalue()->getSlot(*r); |
92 | } |
93 | if (auto r = _ivalue()->type()->findConstantSlot(name)) { |
94 | return _ivalue()->type()->getConstant(*r); |
95 | } |
96 | return or_else; |
97 | } |
98 | |
99 | bool hasattr(const std::string& name) const { |
100 | return _ivalue()->type()->hasAttribute(name) || |
101 | _ivalue()->type()->hasConstant(name); |
102 | } |
103 | |
104 | // each object owns its methods. The reference returned here |
105 | // is guaranteed to stay valid until this module has been destroyed |
106 | Method get_method(const std::string& name) const { |
107 | if (auto method = find_method(name)) { |
108 | return *method; |
109 | } |
110 | AT_ERROR("Method '" , name, "' is not defined." ); |
111 | } |
112 | |
113 | const std::vector<Method> get_methods() const { |
114 | return c10::fmap(type()->methods(), [&](Function* func) { |
115 | return Method(_ivalue(), func); |
116 | }); |
117 | } |
118 | |
119 | bool has_property(const std::string& name) const { |
120 | for (const auto& prop : type()->properties()) { |
121 | if (prop.name == name) { |
122 | return true; |
123 | } |
124 | } |
125 | return false; |
126 | } |
127 | |
128 | const Property get_property(const std::string& name) const { |
129 | for (const auto& prop : type()->properties()) { |
130 | if (prop.name == name) { |
131 | c10::optional<Method> setter = c10::nullopt; |
132 | if (prop.setter) { |
133 | setter = Method(_ivalue(), prop.setter); |
134 | } |
135 | return Property{ |
136 | prop.name, Method(_ivalue(), prop.getter), std::move(setter)}; |
137 | } |
138 | } |
139 | AT_ERROR("Property '" , name, "' is not defined." ); |
140 | } |
141 | |
142 | const std::vector<Property> get_properties() const { |
143 | return c10::fmap(type()->properties(), [&](ClassType::Property prop) { |
144 | c10::optional<Method> setter = c10::nullopt; |
145 | if (prop.setter) { |
146 | setter = Method(_ivalue(), prop.setter); |
147 | } |
148 | return Property{ |
149 | prop.name, Method(_ivalue(), prop.getter), std::move(setter)}; |
150 | }); |
151 | } |
152 | |
153 | c10::optional<Method> find_method(const std::string& basename) const; |
154 | |
155 | /// Run a method from this module. |
156 | /// |
157 | /// For example: |
158 | /// @code |
159 | /// IValue output = module->run("relu_script", a, b); |
160 | /// @endcode |
161 | /// |
162 | /// To get a compile a module from a source string, see torch::jit::compile |
163 | /// |
164 | /// @param method_name The name of the method to run |
165 | /// @param args Arguments to be passed to the method |
166 | /// @return An IValue containing the return value (or values if it is a tuple) |
167 | /// from the method |
168 | template <typename... Types> |
169 | IValue run_method(const std::string& method_name, Types&&... args) { |
170 | return get_method(method_name)({IValue(std::forward<Types>(args))...}); |
171 | } |
172 | |
173 | // so that C++ users can easily add methods |
174 | void define(const std::string& src, const ResolverPtr& resolver = nullptr); |
175 | |
176 | size_t num_slots() const { |
177 | return _ivalue()->slots().size(); |
178 | } |
179 | |
180 | // shallow copy the object |
181 | Object copy() const; |
182 | |
183 | // Copies all the attributes of the object recursively without creating new |
184 | // `ClassType`, including deepcopy of Tensors |
185 | Object deepcopy() const; |
186 | |
187 | private: |
188 | // mutable be we lazily initialize in module_object. |
189 | mutable ObjectPtr _ivalue_; |
190 | }; |
191 | |
192 | namespace script { |
193 | // We once had a `script::` namespace that was deleted. This is for backcompat |
194 | // of the public API; new code should not use this type alias. |
195 | using Object = ::torch::jit::Object; |
196 | } // namespace script |
197 | } // namespace jit |
198 | } // namespace torch |
199 | |