1 | #pragma once |
2 | |
3 | #include <memory> |
4 | |
5 | #include <ATen/core/ivalue.h> |
6 | #include <ATen/core/jit_type_base.h> |
7 | #include <c10/util/Optional.h> |
8 | |
// JIT types that this header refers to only through pointers, references or
// smart pointers; declaring them here avoids pulling in their full headers.
namespace torch {
namespace jit {
struct Function;
struct CompilationUnit;
} // namespace jit
} // namespace torch
15 | |
16 | namespace c10 { |
17 | |
// Forward declaration; ClassType only uses it by const reference in its
// forward-hook schema checks.
struct FunctionSchema;

// Classifies a class attribute as a buffer, a parameter, or neither.
// The kinds are mutually exclusive, and BUFFER / PARAMETER may only appear on
// attributes of module classes.
enum class AttributeKind {
  BUFFER = 0,            // module buffer
  PARAMETER = 1,         // module parameter
  REGULAR_ATTRIBUTE = 2  // anything that is neither a buffer nor a parameter
};
27 | |
28 | // This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr). |
29 | // Note: This structure does not represent the value of the attribute. |
30 | struct TORCH_API ClassAttribute { |
31 | public: |
32 | ClassAttribute(AttributeKind kind, |
33 | TypePtr attributeType, |
34 | std::string attributeName) : |
35 | kind_(kind), |
36 | attributeType_(std::move(attributeType)), |
37 | attributeName_(std::move(attributeName)) {} |
38 | |
39 | AttributeKind getKind() const { |
40 | return kind_; |
41 | } |
42 | |
43 | const TypePtr& getType() const { |
44 | return attributeType_; |
45 | } |
46 | |
47 | const std::string& getName() const { |
48 | return attributeName_; |
49 | } |
50 | |
51 | private: |
52 | AttributeKind kind_; |
53 | TypePtr attributeType_; |
54 | std::string attributeName_; |
55 | }; |
56 | |
57 | /** |
58 | * User Defined Types |
59 | */ |
60 | |
61 | struct ClassType; |
62 | using ClassTypePtr = std::shared_ptr<ClassType>; |
63 | using ::torch::jit::CompilationUnit; |
64 | |
65 | // This represents a class in TorchScript. |
66 | struct TORCH_API ClassType : public NamedType { |
67 | // This represents an attribute of a class; a name associated with an attribute, and a |
68 | // getter and (optional) setter for that attribute. |
69 | struct Property { |
70 | std::string name; |
71 | torch::jit::Function* getter; |
72 | torch::jit::Function* setter; |
73 | }; |
74 | |
75 | // Create a class type with name `name` and its methods stored in `cu`. |
76 | static ClassTypePtr create( |
77 | c10::optional<QualifiedName> qualifiedName, |
78 | std::weak_ptr<CompilationUnit> cu, |
79 | bool is_module = false, |
80 | std::string doc_string = "" , |
81 | std::vector<std::string> unresolved_class_attributes = {}); |
82 | |
83 | bool equals(const Type& rhs) const override { |
84 | if (this == &rhs) { |
85 | return true; |
86 | } |
87 | if (auto user_rhs = rhs.castRaw<ClassType>()) { |
88 | const auto& lhs_name = name().value(); |
89 | const auto& rhs_name = user_rhs->name().value(); |
90 | |
91 | return lhs_name == rhs_name && |
92 | this->compilation_unit() == user_rhs->compilation_unit(); |
93 | } |
94 | return false; |
95 | } |
96 | |
97 | std::string str() const override { |
98 | return annotation_str(); |
99 | } |
100 | |
101 | std::string repr_str() const override { |
102 | std::stringstream ss; |
103 | ss << str() |
104 | << " (of Python compilation unit at: " << compilation_unit().get() << ")" ; |
105 | return ss.str(); |
106 | } |
107 | |
108 | const std::vector<torch::jit::Function*>& methods() const; |
109 | |
110 | TypePtr findAttribute(const std::string& name) const { |
111 | size_t pos = 0; |
112 | for (const auto& attr : attributes_) { |
113 | if (name == attr.getName()) { |
114 | break; |
115 | } |
116 | ++pos; |
117 | } |
118 | |
119 | if (pos >= attributes_.size()) { |
120 | return nullptr; |
121 | } |
122 | return attributes_[pos].getType(); |
123 | } |
124 | |
125 | const TypePtr& getAttribute(const std::string& name) const { |
126 | auto slot = findAttributeSlot(name); |
127 | TORCH_CHECK( |
128 | slot, |
129 | repr_str(), |
130 | " does not have an attribute with name '" , |
131 | name, |
132 | "'" ); |
133 | return attributes_[*slot].getType(); |
134 | } |
135 | |
136 | size_t numAttributes() const { |
137 | return attributes_.size(); |
138 | } |
139 | |
140 | const TypePtr& getAttribute(size_t slot) const { |
141 | AT_ASSERT(slot < attributes_.size()); |
142 | return attributes_.at(slot).getType(); |
143 | } |
144 | |
145 | const std::string getAttributeName(size_t slot) const { |
146 | AT_ASSERT(slot < attributes_.size()); |
147 | return attributes_[slot].getName(); |
148 | } |
149 | |
150 | void checkNotExist(const std::string& name, const std::string& what) const; |
151 | |
152 | // Attributes are stored in a specific slot at runtime for effiency. |
153 | // When emitting instructions we specify the slot so that attribute access is |
154 | // a constant lookup |
155 | c10::optional<size_t> findAttributeSlot(const std::string& name) const { |
156 | size_t slot = 0; |
157 | for (const auto& attr : attributes_) { |
158 | if (name == attr.getName()) { |
159 | return slot; |
160 | } |
161 | slot++; |
162 | } |
163 | return c10::nullopt; |
164 | } |
165 | size_t getAttributeSlot(const std::string& name) const { |
166 | if (auto r = findAttributeSlot(name)) { |
167 | return *r; |
168 | } |
169 | TORCH_CHECK( |
170 | false, |
171 | repr_str(), |
172 | " does not have an attribute with name '" , |
173 | name, |
174 | "'" ); |
175 | } |
176 | |
177 | bool hasAttribute(const std::string& name) const { |
178 | return std::find_if( |
179 | attributes_.cbegin(), |
180 | attributes_.cend(), |
181 | [&](const ClassAttribute& attr) { return attr.getName() == name; }) != |
182 | attributes_.cend(); |
183 | } |
184 | |
185 | bool isUnresolvedClassAttribute(const std::string& name) const; |
186 | |
187 | at::ArrayRef<TypePtr> containedTypes() const override { |
188 | return attributeTypes_; |
189 | } |
190 | |
191 | size_t addAttribute( |
192 | const std::string& name, |
193 | TypePtr type, |
194 | bool is_parameter = false, |
195 | bool is_buffer = false); |
196 | |
197 | // [Internal Only] Remove attribute from the ClassType, |
198 | // caller is responsible to make sure the modification is safe: |
199 | // it is unsafe to having existing allocations |
200 | // of this object around anymore, and any code that works on |
201 | // the attribute is now invalid. Only newly created code is |
202 | // valid again. |
203 | void unsafeRemoveAttribute(const std::string& name); |
204 | |
205 | // [Internal Only] Change the type of an attribute of the ClassType, |
206 | // The caller is responsible to make sure the modification is safe: |
207 | // it is unsafe to maintain uses of the old type of the attribute, |
208 | // and any code that works on the attribute is now invalid. |
209 | // Only newly created code is valid again. |
210 | void unsafeChangeAttributeType(const std::string& name, TypePtr new_ty); |
211 | |
212 | // Add attribute \p NAME if it doesn't exist or verify that it has a |
213 | // compatible type otherwise. |
214 | size_t addOrCheckAttribute( |
215 | const std::string& name, |
216 | TypePtr ty, |
217 | bool is_parameter = false, |
218 | bool is_buffer = false) { |
219 | auto slot_idx = findAttributeSlot(name); |
220 | if (!slot_idx) { |
221 | return addAttribute(name, std::move(ty), is_parameter, is_buffer); |
222 | } |
223 | |
224 | TORCH_CHECK( |
225 | is_parameter == this->is_parameter(*slot_idx), |
226 | "Parameter field mismatch for the field '" , |
227 | name, |
228 | "'" ); |
229 | const TypePtr& atype = getAttribute(*slot_idx); |
230 | TORCH_CHECK( |
231 | ty->isSubtypeOf(*atype), |
232 | ty->repr_str(), |
233 | " is not compatible with the type " , |
234 | atype->repr_str(), |
235 | " for the field '" , |
236 | name, |
237 | "'" ); |
238 | return *slot_idx; |
239 | } |
240 | |
241 | // Get the property with the given \p name, if it exists on the class. |
242 | c10::optional<ClassType::Property> getProperty(const std::string& name); |
243 | // Add a property named \p name with \p getter and \p setter as its getter and setter. |
244 | void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter); |
245 | // Get a list of all properties. |
246 | const std::vector<Property>& properties() const { |
247 | return properties_; |
248 | } |
249 | |
250 | bool hasConstant(const std::string& name) const { |
251 | return std::find_if( |
252 | constantNames_.cbegin(), |
253 | constantNames_.cend(), |
254 | [&](const std::string& constant) { return constant == name; }) != |
255 | constantNames_.cend(); |
256 | } |
257 | |
258 | size_t addConstant(const std::string& name, const IValue& value); |
259 | |
260 | c10::optional<size_t> findConstantSlot(const std::string& name) const; |
261 | |
262 | size_t getConstantSlot(const std::string& name) const { |
263 | if (auto r = findConstantSlot(name)) { |
264 | return *r; |
265 | } |
266 | TORCH_CHECK( |
267 | false, |
268 | repr_str(), |
269 | " does not have constant field with the name '" , |
270 | name, |
271 | "'" ); |
272 | } |
273 | |
274 | const std::string& getConstantName(size_t slot) const; |
275 | |
276 | const std::string& doc_string() const { |
277 | return doc_string_; |
278 | } |
279 | |
280 | IValue getConstant(const std::string& name) const; |
281 | |
282 | IValue getConstant(size_t slot) const; |
283 | |
284 | c10::optional<IValue> findConstant(const std::string& name) const; |
285 | |
286 | size_t numConstants() const; |
287 | |
288 | at::ArrayRef<std::string> constantNames() const { |
289 | return constantNames_; |
290 | } |
291 | |
292 | at::ArrayRef<IValue> constantValues() const; |
293 | |
294 | // [Internal Only] Remove constant from the ClassType |
295 | // caller is responsible to make sure the modification is safe: |
296 | // it is unsafe to having existing allocations |
297 | // of this object around anymore, and any code that works on |
298 | // the attribute is now invalid. Only newly created code is |
299 | // valid again. |
300 | void unsafeRemoveConstant(const std::string& name); |
301 | |
302 | TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { |
303 | auto ptr = ClassType::create(name(), compilation_unit_, is_module()); |
304 | AT_ASSERT(numAttributes() == contained_types.size()); |
305 | for(size_t i = 0; i < attributes_.size(); ++i) { |
306 | AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i])); |
307 | ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i])); |
308 | } |
309 | // Copy methods over |
310 | for (const auto& method : methods()) { |
311 | ptr->addMethod(method); |
312 | } |
313 | return ptr; |
314 | } |
315 | |
316 | bool is_module() const override { |
317 | return isModule_; |
318 | } |
319 | |
320 | const std::vector<ClassAttribute>& getAttributes() const { |
321 | return attributes_; |
322 | } |
323 | |
324 | bool is_parameter(size_t slot) const { |
325 | TORCH_INTERNAL_ASSERT( |
326 | is_module(), "asking for parameterSlots of non-Module" ); |
327 | return attributes_.at(slot).getKind() == AttributeKind::PARAMETER; |
328 | } |
329 | |
330 | bool is_buffer(size_t slot) const { |
331 | TORCH_INTERNAL_ASSERT( |
332 | is_module(), "asking for bufferWrittenSlots of non-Module" ); |
333 | return attributes_.at(slot).getKind() == AttributeKind::BUFFER; |
334 | } |
335 | |
336 | void addForwardPreHook(torch::jit::Function* pre_hook_ptr); |
337 | void addForwardHook(torch::jit::Function* hook_ptr); |
338 | torch::jit::Function* findForwardPreHook(const std::string& name) const; |
339 | torch::jit::Function* findForwardHook(const std::string& name) const; |
340 | const std::vector<torch::jit::Function*>& getForwardHooks() const; |
341 | const std::vector<torch::jit::Function*>& getForwardPreHooks() const; |
342 | |
343 | void checkForwardPreHookSchema( |
344 | int pre_hook_idx, |
345 | const FunctionSchema& pre_hook_schema) const; |
346 | void checkForwardHookSchema( |
347 | int hook_idx, |
348 | const FunctionSchema& hook_schema) const; |
349 | |
350 | void addMethod(torch::jit::Function* method); |
351 | torch::jit::Function* findMethod(const std::string& name) const; |
352 | torch::jit::Function& getMethod(const std::string& name) const; |
353 | torch::jit::Function* findHook(const std::string& name) const; |
354 | torch::jit::Function& getHook(const std::string& name) const; |
355 | bool hasMethod(const std::string& name) const; |
356 | |
357 | torch::jit::Function* findStaticMethod(const std::string& name) const; |
358 | void addStaticMethod(torch::jit::Function* method); |
359 | |
360 | // [Internal Only] Remove method from the ClassType |
361 | // caller is responsible to make sure the modification is safe: |
362 | // it is unsafe to having existing allocations |
363 | // of this object around anymore, and any code that works on |
364 | // the attribute is now invalid. Only newly created code is |
365 | // valid again. |
366 | // Note this method is intended for freezing only. |
367 | void unsafeRemoveMethod(const std::string& name); |
368 | |
369 | std::shared_ptr<CompilationUnit> compilation_unit(); |
370 | |
371 | std::shared_ptr<const CompilationUnit> compilation_unit() const; |
372 | |
373 | // generate a refined version of this class. |
374 | // It has the same name but the slot Types are subtypes of |
375 | // the original slots. It is only valid to refine a class type in a context |
376 | // where it is know that there are not assignments to the objects slots |
377 | // that would invalidate the refinement. |
378 | // These variants are not registered in the global class table. |
379 | ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const; |
380 | |
381 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
382 | |
383 | static const TypeKind Kind = TypeKind::ClassType; |
384 | |
385 | private: |
386 | ClassType( |
387 | c10::optional<QualifiedName> name, |
388 | std::weak_ptr<CompilationUnit> cu, |
389 | bool is_module = false, |
390 | std::string doc_string = "" , |
391 | std::vector<std::string> unresolved_class_attributes = {}); |
392 | |
393 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
394 | (void)printer; // Suppress unused variable warning |
395 | const auto& n = name().value(); |
396 | return n.qualifiedName(); |
397 | } |
398 | |
399 | void addAttribute(ClassAttribute classAttribute); |
400 | std::string getForwardPreHookErrorMessage(int pre_hook_idx) const; |
401 | std::string getForwardHookErrorMessage(int hook_idx) const; |
402 | |
403 | // Mapping of attribute names -> their type. |
404 | // NOTE: this does not contain methods, which are stored in the module |
405 | // TODO: once modules support arbitrary ivalue attributes, we don't need this |
406 | // anymore. |
407 | // TODO: This is better represented as an OrderedDict, but alas it is not yet |
408 | // available from c10 |
409 | |
410 | // Mapping of constant names -> their value. |
411 | std::vector<std::string> constantNames_; |
412 | std::vector<IValue> constantValues_; |
413 | // Holds method attributes |
414 | std::weak_ptr<CompilationUnit> compilation_unit_; |
415 | |
416 | // Holds all atrributes, attribute details are found on ClassAttribute |
417 | std::vector<ClassAttribute> attributes_; |
418 | // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. |
419 | // Never fill this without using the appropriate provideNewClassAttribute method |
420 | std::vector<TypePtr> attributeTypes_; |
421 | |
422 | // List of methods associated with this class. |
423 | std::vector<torch::jit::Function*> methods_; |
424 | std::vector<torch::jit::Function*> staticmethods_; |
425 | |
426 | // List of hooks to be run before/after forward. |
427 | std::vector<torch::jit::Function*> forward_hooks_; |
428 | std::vector<torch::jit::Function*> forward_pre_hooks_; |
429 | |
430 | // List of properties exposed by this class. |
431 | std::vector<Property> properties_; |
432 | |
433 | bool isModule_ = false; |
434 | |
435 | // Doc string of class. |
436 | std::string doc_string_ = "" ; |
437 | |
438 | // For error reporting accesses to class level attributes. |
439 | std::vector<std::string> unresolved_class_attributes_; |
440 | }; |
441 | |
442 | } |
443 | |