1#pragma once
2
3#include <memory>
4
5#include <ATen/core/ivalue.h>
6#include <ATen/core/jit_type_base.h>
7#include <c10/util/Optional.h>
8
9namespace torch {
10namespace jit {
11struct CompilationUnit;
12struct Function;
13} // namespace jit
14} // namespace torch
15
16namespace c10 {
17
18struct FunctionSchema;
19
// Classifies a class attribute as a buffer, a parameter, or neither.
// The three states are mutually exclusive. Buffers and parameters can only
// appear on modules.
enum class AttributeKind {
  // The attribute is a module buffer.
  BUFFER,
  // The attribute is a module parameter.
  PARAMETER,
  // The attribute is neither a buffer nor a parameter.
  REGULAR_ATTRIBUTE,
};
27
28// This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr).
29// Note: This structure does not represent the value of the attribute.
30struct TORCH_API ClassAttribute {
31 public:
32 ClassAttribute(AttributeKind kind,
33 TypePtr attributeType,
34 std::string attributeName) :
35 kind_(kind),
36 attributeType_(std::move(attributeType)),
37 attributeName_(std::move(attributeName)) {}
38
39 AttributeKind getKind() const {
40 return kind_;
41 }
42
43 const TypePtr& getType() const {
44 return attributeType_;
45 }
46
47 const std::string& getName() const {
48 return attributeName_;
49 }
50
51 private:
52 AttributeKind kind_;
53 TypePtr attributeType_;
54 std::string attributeName_;
55};
56
57/**
58 * User Defined Types
59 */
60
61struct ClassType;
62using ClassTypePtr = std::shared_ptr<ClassType>;
63using ::torch::jit::CompilationUnit;
64
65// This represents a class in TorchScript.
66struct TORCH_API ClassType : public NamedType {
67 // This represents an attribute of a class; a name associated with an attribute, and a
68 // getter and (optional) setter for that attribute.
69 struct Property {
70 std::string name;
71 torch::jit::Function* getter;
72 torch::jit::Function* setter;
73 };
74
75 // Create a class type with name `name` and its methods stored in `cu`.
76 static ClassTypePtr create(
77 c10::optional<QualifiedName> qualifiedName,
78 std::weak_ptr<CompilationUnit> cu,
79 bool is_module = false,
80 std::string doc_string = "",
81 std::vector<std::string> unresolved_class_attributes = {});
82
83 bool equals(const Type& rhs) const override {
84 if (this == &rhs) {
85 return true;
86 }
87 if (auto user_rhs = rhs.castRaw<ClassType>()) {
88 const auto& lhs_name = name().value();
89 const auto& rhs_name = user_rhs->name().value();
90
91 return lhs_name == rhs_name &&
92 this->compilation_unit() == user_rhs->compilation_unit();
93 }
94 return false;
95 }
96
97 std::string str() const override {
98 return annotation_str();
99 }
100
101 std::string repr_str() const override {
102 std::stringstream ss;
103 ss << str()
104 << " (of Python compilation unit at: " << compilation_unit().get() << ")";
105 return ss.str();
106 }
107
108 const std::vector<torch::jit::Function*>& methods() const;
109
110 TypePtr findAttribute(const std::string& name) const {
111 size_t pos = 0;
112 for (const auto& attr : attributes_) {
113 if (name == attr.getName()) {
114 break;
115 }
116 ++pos;
117 }
118
119 if (pos >= attributes_.size()) {
120 return nullptr;
121 }
122 return attributes_[pos].getType();
123 }
124
125 const TypePtr& getAttribute(const std::string& name) const {
126 auto slot = findAttributeSlot(name);
127 TORCH_CHECK(
128 slot,
129 repr_str(),
130 " does not have an attribute with name '",
131 name,
132 "'");
133 return attributes_[*slot].getType();
134 }
135
136 size_t numAttributes() const {
137 return attributes_.size();
138 }
139
140 const TypePtr& getAttribute(size_t slot) const {
141 AT_ASSERT(slot < attributes_.size());
142 return attributes_.at(slot).getType();
143 }
144
145 const std::string getAttributeName(size_t slot) const {
146 AT_ASSERT(slot < attributes_.size());
147 return attributes_[slot].getName();
148 }
149
150 void checkNotExist(const std::string& name, const std::string& what) const;
151
152 // Attributes are stored in a specific slot at runtime for effiency.
153 // When emitting instructions we specify the slot so that attribute access is
154 // a constant lookup
155 c10::optional<size_t> findAttributeSlot(const std::string& name) const {
156 size_t slot = 0;
157 for (const auto& attr : attributes_) {
158 if (name == attr.getName()) {
159 return slot;
160 }
161 slot++;
162 }
163 return c10::nullopt;
164 }
165 size_t getAttributeSlot(const std::string& name) const {
166 if (auto r = findAttributeSlot(name)) {
167 return *r;
168 }
169 TORCH_CHECK(
170 false,
171 repr_str(),
172 " does not have an attribute with name '",
173 name,
174 "'");
175 }
176
177 bool hasAttribute(const std::string& name) const {
178 return std::find_if(
179 attributes_.cbegin(),
180 attributes_.cend(),
181 [&](const ClassAttribute& attr) { return attr.getName() == name; }) !=
182 attributes_.cend();
183 }
184
185 bool isUnresolvedClassAttribute(const std::string& name) const;
186
187 at::ArrayRef<TypePtr> containedTypes() const override {
188 return attributeTypes_;
189 }
190
191 size_t addAttribute(
192 const std::string& name,
193 TypePtr type,
194 bool is_parameter = false,
195 bool is_buffer = false);
196
197 // [Internal Only] Remove attribute from the ClassType,
198 // caller is responsible to make sure the modification is safe:
199 // it is unsafe to having existing allocations
200 // of this object around anymore, and any code that works on
201 // the attribute is now invalid. Only newly created code is
202 // valid again.
203 void unsafeRemoveAttribute(const std::string& name);
204
205 // [Internal Only] Change the type of an attribute of the ClassType,
206 // The caller is responsible to make sure the modification is safe:
207 // it is unsafe to maintain uses of the old type of the attribute,
208 // and any code that works on the attribute is now invalid.
209 // Only newly created code is valid again.
210 void unsafeChangeAttributeType(const std::string& name, TypePtr new_ty);
211
212 // Add attribute \p NAME if it doesn't exist or verify that it has a
213 // compatible type otherwise.
214 size_t addOrCheckAttribute(
215 const std::string& name,
216 TypePtr ty,
217 bool is_parameter = false,
218 bool is_buffer = false) {
219 auto slot_idx = findAttributeSlot(name);
220 if (!slot_idx) {
221 return addAttribute(name, std::move(ty), is_parameter, is_buffer);
222 }
223
224 TORCH_CHECK(
225 is_parameter == this->is_parameter(*slot_idx),
226 "Parameter field mismatch for the field '",
227 name,
228 "'");
229 const TypePtr& atype = getAttribute(*slot_idx);
230 TORCH_CHECK(
231 ty->isSubtypeOf(*atype),
232 ty->repr_str(),
233 " is not compatible with the type ",
234 atype->repr_str(),
235 " for the field '",
236 name,
237 "'");
238 return *slot_idx;
239 }
240
241 // Get the property with the given \p name, if it exists on the class.
242 c10::optional<ClassType::Property> getProperty(const std::string& name);
243 // Add a property named \p name with \p getter and \p setter as its getter and setter.
244 void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter);
245 // Get a list of all properties.
246 const std::vector<Property>& properties() const {
247 return properties_;
248 }
249
250 bool hasConstant(const std::string& name) const {
251 return std::find_if(
252 constantNames_.cbegin(),
253 constantNames_.cend(),
254 [&](const std::string& constant) { return constant == name; }) !=
255 constantNames_.cend();
256 }
257
258 size_t addConstant(const std::string& name, const IValue& value);
259
260 c10::optional<size_t> findConstantSlot(const std::string& name) const;
261
262 size_t getConstantSlot(const std::string& name) const {
263 if (auto r = findConstantSlot(name)) {
264 return *r;
265 }
266 TORCH_CHECK(
267 false,
268 repr_str(),
269 " does not have constant field with the name '",
270 name,
271 "'");
272 }
273
274 const std::string& getConstantName(size_t slot) const;
275
276 const std::string& doc_string() const {
277 return doc_string_;
278 }
279
280 IValue getConstant(const std::string& name) const;
281
282 IValue getConstant(size_t slot) const;
283
284 c10::optional<IValue> findConstant(const std::string& name) const;
285
286 size_t numConstants() const;
287
288 at::ArrayRef<std::string> constantNames() const {
289 return constantNames_;
290 }
291
292 at::ArrayRef<IValue> constantValues() const;
293
294 // [Internal Only] Remove constant from the ClassType
295 // caller is responsible to make sure the modification is safe:
296 // it is unsafe to having existing allocations
297 // of this object around anymore, and any code that works on
298 // the attribute is now invalid. Only newly created code is
299 // valid again.
300 void unsafeRemoveConstant(const std::string& name);
301
302 TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
303 auto ptr = ClassType::create(name(), compilation_unit_, is_module());
304 AT_ASSERT(numAttributes() == contained_types.size());
305 for(size_t i = 0; i < attributes_.size(); ++i) {
306 AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i]));
307 ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i]));
308 }
309 // Copy methods over
310 for (const auto& method : methods()) {
311 ptr->addMethod(method);
312 }
313 return ptr;
314 }
315
316 bool is_module() const override {
317 return isModule_;
318 }
319
320 const std::vector<ClassAttribute>& getAttributes() const {
321 return attributes_;
322 }
323
324 bool is_parameter(size_t slot) const {
325 TORCH_INTERNAL_ASSERT(
326 is_module(), "asking for parameterSlots of non-Module");
327 return attributes_.at(slot).getKind() == AttributeKind::PARAMETER;
328 }
329
330 bool is_buffer(size_t slot) const {
331 TORCH_INTERNAL_ASSERT(
332 is_module(), "asking for bufferWrittenSlots of non-Module");
333 return attributes_.at(slot).getKind() == AttributeKind::BUFFER;
334 }
335
336 void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
337 void addForwardHook(torch::jit::Function* hook_ptr);
338 torch::jit::Function* findForwardPreHook(const std::string& name) const;
339 torch::jit::Function* findForwardHook(const std::string& name) const;
340 const std::vector<torch::jit::Function*>& getForwardHooks() const;
341 const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
342
343 void checkForwardPreHookSchema(
344 int pre_hook_idx,
345 const FunctionSchema& pre_hook_schema) const;
346 void checkForwardHookSchema(
347 int hook_idx,
348 const FunctionSchema& hook_schema) const;
349
350 void addMethod(torch::jit::Function* method);
351 torch::jit::Function* findMethod(const std::string& name) const;
352 torch::jit::Function& getMethod(const std::string& name) const;
353 torch::jit::Function* findHook(const std::string& name) const;
354 torch::jit::Function& getHook(const std::string& name) const;
355 bool hasMethod(const std::string& name) const;
356
357 torch::jit::Function* findStaticMethod(const std::string& name) const;
358 void addStaticMethod(torch::jit::Function* method);
359
360 // [Internal Only] Remove method from the ClassType
361 // caller is responsible to make sure the modification is safe:
362 // it is unsafe to having existing allocations
363 // of this object around anymore, and any code that works on
364 // the attribute is now invalid. Only newly created code is
365 // valid again.
366 // Note this method is intended for freezing only.
367 void unsafeRemoveMethod(const std::string& name);
368
369 std::shared_ptr<CompilationUnit> compilation_unit();
370
371 std::shared_ptr<const CompilationUnit> compilation_unit() const;
372
373 // generate a refined version of this class.
374 // It has the same name but the slot Types are subtypes of
375 // the original slots. It is only valid to refine a class type in a context
376 // where it is know that there are not assignments to the objects slots
377 // that would invalidate the refinement.
378 // These variants are not registered in the global class table.
379 ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
380
381 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
382
383 static const TypeKind Kind = TypeKind::ClassType;
384
385 private:
386 ClassType(
387 c10::optional<QualifiedName> name,
388 std::weak_ptr<CompilationUnit> cu,
389 bool is_module = false,
390 std::string doc_string = "",
391 std::vector<std::string> unresolved_class_attributes = {});
392
393 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
394 (void)printer; // Suppress unused variable warning
395 const auto& n = name().value();
396 return n.qualifiedName();
397 }
398
399 void addAttribute(ClassAttribute classAttribute);
400 std::string getForwardPreHookErrorMessage(int pre_hook_idx) const;
401 std::string getForwardHookErrorMessage(int hook_idx) const;
402
403 // Mapping of attribute names -> their type.
404 // NOTE: this does not contain methods, which are stored in the module
405 // TODO: once modules support arbitrary ivalue attributes, we don't need this
406 // anymore.
407 // TODO: This is better represented as an OrderedDict, but alas it is not yet
408 // available from c10
409
410 // Mapping of constant names -> their value.
411 std::vector<std::string> constantNames_;
412 std::vector<IValue> constantValues_;
413 // Holds method attributes
414 std::weak_ptr<CompilationUnit> compilation_unit_;
415
416 // Holds all atrributes, attribute details are found on ClassAttribute
417 std::vector<ClassAttribute> attributes_;
418 // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
419 // Never fill this without using the appropriate provideNewClassAttribute method
420 std::vector<TypePtr> attributeTypes_;
421
422 // List of methods associated with this class.
423 std::vector<torch::jit::Function*> methods_;
424 std::vector<torch::jit::Function*> staticmethods_;
425
426 // List of hooks to be run before/after forward.
427 std::vector<torch::jit::Function*> forward_hooks_;
428 std::vector<torch::jit::Function*> forward_pre_hooks_;
429
430 // List of properties exposed by this class.
431 std::vector<Property> properties_;
432
433 bool isModule_ = false;
434
435 // Doc string of class.
436 std::string doc_string_ = "";
437
438 // For error reporting accesses to class level attributes.
439 std::vector<std::string> unresolved_class_attributes_;
440};
441
442}
443