1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/jit/python/pybind_utils.h> |
6 | #include <memory> |
7 | #include <string> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
// Which kind of iterable container (if any) a module is treated as.
// NONE means the module is not iterable. The other kinds presumably map to
// nn.ModuleList / nn.ModuleDict / nn.ParameterList / nn.ParameterDict
// respectively -- NOTE(review): confirm against the Python frontend.
enum class IterableModuleKind {
  NONE = 0,
  LIST = 1,
  DICT = 2,
  PARAMLIST = 3,
  PARAMDICT = 4,
};

// Forward declaration: ConcreteModuleTypeBuilder refers to ConcreteModuleType
// before the latter is defined further down in this header.
class ConcreteModuleType;
15 | |
16 | // You can think of an nn.Module as a template that corresponds to a family of |
17 | // JIT types. The template "arguments" are things like the constant values. |
18 | // e.g. |
19 | // class M(nn.Module): |
20 | // __constants__ = ["const"] |
21 | // ... |
22 | // |
23 | // Is similar to writing the following in C++: |
24 | // |
// template <typename TConst>
// class M {
//   ...
// };
29 | // |
30 | // We need to consider each different member of the type family a different JIT |
31 | // type because, e.g. different constant values lead to different versions of |
32 | // the same method. |
33 | // |
34 | // ConcreteModuleType corresponds to a single member of the type family, with |
35 | // all template arguments fully specified. Two Modules that share a |
36 | // ConcreteModuleType can share a JIT type, and vice versa. |
37 | // |
// Why not just use a JIT type to represent concrete types? Because constants,
// function attributes, etc. are currently not representable in the type system,
// so this acts as a non-first-class way of tracking concrete types.
41 | // |
// ConcreteModuleType is also the source of truth for servicing all
// ModuleValue::attr calls. This is so we can guarantee that if two Modules
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
// way when you access attributes on them.
46 | |
47 | // ConcreteModuleType has two phases. |
48 | // 1. Creation: First we build it up, during the ScriptModule conversion |
49 | // process. This is represented by ConcreteModuleTypeBuilder. |
//    ...then the converter calls ConcreteModuleTypeBuilder::build(), producing
//    a ConcreteModuleType ready for querying.
53 | // 2. Querying: We use ConcreteModuleType as a source of truth for |
54 | // ModuleValue::attr calls during method compilation. |
55 | |
56 | // Represents a concrete type during in the process for construction. We use |
57 | // this to decide whether we can share types between modules. |
58 | class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { |
59 | public: |
60 | explicit ConcreteModuleTypeBuilder(py::object pyClass) { |
61 | TORCH_INTERNAL_ASSERT(pyClass); |
62 | pyClass_ = std::move(pyClass); |
63 | } |
64 | |
65 | void addConstant(std::string name, py::object value); |
66 | void addConstant(std::string name, IValue value); |
67 | void addAttribute( |
68 | std::string name, |
69 | const TypePtr& type, |
70 | bool isParameter, |
71 | bool isBuffer); |
72 | void addFunctionAttribute( |
73 | std::string name, |
74 | const TypePtr& type, |
75 | py::object pyFunction); |
76 | |
77 | void addModule(std::string name, std::shared_ptr<ConcreteModuleType> meta); |
78 | |
79 | void addForwardHook(py::object hook); |
80 | void addForwardPreHook(py::object pre_hook); |
81 | |
82 | void addOverload( |
83 | std::string methodName, |
84 | std::vector<std::string> overloadedMethodNames); |
85 | void addBuiltinFunction(std::string name, const std::string& symbol_name); |
86 | void addFailedAttribute(std::string name, std::string failureReason); |
87 | void addIgnoredAttribute(std::string name); |
88 | void setIterableModuleKind(IterableModuleKind kind); |
89 | |
90 | // If a ConcreteModuleType is poisoned, it will never compare equal to any |
91 | // other concrete type |
92 | void setPoisoned(); |
93 | |
94 | std::shared_ptr<ConcreteModuleType> build() const { |
95 | return std::make_shared<ConcreteModuleType>(*this); |
96 | } |
97 | |
98 | // This determines whether two modules can share a type. The container structs |
99 | // used by ConcreteModuleType have been defined such that operator== |
100 | // implements a meaningful comparison in that context. |
101 | bool equals(const ConcreteModuleTypeBuilder& other) const; |
102 | |
103 | struct FunctionAttribute { |
104 | FunctionTypePtr function_; |
105 | py::object pyFunction_; |
106 | |
107 | friend bool operator==( |
108 | const FunctionAttribute& lhs, |
109 | const FunctionAttribute& rhs) { |
110 | // Functions are not first class, so we can't do type comparison like a |
111 | // regular attribute. So we do a pointer equality check on the actual |
112 | // Python function object. |
113 | return lhs.pyFunction_.is(rhs.pyFunction_); |
114 | } |
115 | }; |
116 | |
117 | struct Attribute { |
118 | Attribute(TypePtr type, bool isParam, bool isBuffer) |
119 | : type_(std::move(type)), isParam_(isParam), isBuffer_(isBuffer) {} |
120 | |
121 | friend bool operator==(const Attribute& lhs, const Attribute& rhs) { |
122 | return *(lhs.type_) == *(rhs.type_) && lhs.isParam_ == rhs.isParam_; |
123 | } |
124 | TypePtr type_; |
125 | bool isParam_; |
126 | bool isBuffer_; |
127 | }; |
128 | |
129 | struct ModuleInfo { |
130 | ModuleInfo(std::string name, std::shared_ptr<ConcreteModuleType> meta) |
131 | : name_(std::move(name)), meta_(std::move(meta)) {} |
132 | |
133 | friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs); |
134 | |
135 | std::string name_; |
136 | std::shared_ptr<ConcreteModuleType> meta_; |
137 | }; |
138 | |
139 | private: |
140 | ConcreteModuleTypeBuilder() = default; |
141 | ClassTypePtr createTypeFromThis() const; |
142 | |
143 | // If true, this type will never compare equally to anything else. This is |
144 | // used if we want to ensure that this type is not shared (for example, if it |
145 | // came from a traced module) |
146 | bool isPoisoned_ = false; |
147 | |
148 | // The value of any constants defined by the module. |
149 | std::unordered_map<std::string, IValue> constants_; |
150 | // The types of any attributes |
151 | OrderedDict<std::string, Attribute> attributes_; |
152 | // Overloads, in the same format as `__overloads__` in Python |
153 | std::unordered_map<std::string, std::vector<std::string>> overloads_; |
154 | // Any attributes we failed to convert to TorchScript, along with a hint as to |
155 | // why |
156 | std::unordered_map<std::string, std::string> failedAttributes_; |
157 | // Any attributes that were marked as ignored. They cannot be used in |
158 | // TorchScript but can still be used in ignored function in Python. |
159 | std::unordered_set<std::string> ignoredAttributes_; |
160 | // Any function attributes. These are special right now because functions are |
161 | // not first-class in the type system. |
162 | std::unordered_map<std::string, FunctionAttribute> functionAttributes_; |
163 | // Function attributes that are calls to builtin functions. These get |
164 | // de-sugared directly into the corresponding aten:: call. The map is |
165 | // attribute name -> aten symbol name |
166 | std::unordered_map<std::string, c10::Symbol> builtinFunctions_; |
167 | // The concrete types of any submodules |
168 | std::vector<ModuleInfo> modules_; |
169 | // Hooks to be called before/after forward when the module |
170 | // is called directly. Used to ensure modules have different types |
171 | // when they have different python hooks |
172 | // Actual hooks are added to ClassType directly during compilation |
173 | std::vector<py::object> forwardHooks_; |
174 | std::vector<py::object> forwardPreHooks_; |
175 | |
176 | // If something is a ModuleDict/ModuleList, it means: |
177 | // 1. The order of the submodules matters for comparing the type |
178 | // 2. The compiler is allowed to treat it like a dict/tuple |
179 | IterableModuleKind iterableModuleKind_ = IterableModuleKind::NONE; |
180 | |
181 | // The original `nn.Module` class that we derived this ScriptModule from. |
182 | py::object pyClass_; |
183 | |
184 | // NOTE: If you ever add any more state to this struct, you need to make sure |
185 | // operator== still makes sense! |
186 | friend ConcreteModuleType; |
187 | }; |
188 | |
189 | // Represents a finalized concrete type, used to service ModuleValue::attr calls |
190 | // during method compilation. |
191 | class VISIBILITY_HIDDEN ConcreteModuleType { |
192 | public: |
193 | explicit ConcreteModuleType(ConcreteModuleTypeBuilder data); |
194 | |
195 | static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type); |
196 | |
197 | TypePtr getJitType() const; |
198 | c10::optional<py::object> getPyClass() const; |
199 | IterableModuleKind getIterableModuleKind() const; |
200 | c10::optional<std::vector<std::string>> findOverloads( |
201 | const std::string& name) const; |
202 | c10::optional<Function*> findFunctionAttribute(const std::string& name) const; |
203 | c10::optional<c10::Symbol> findBuiltinFunction(const std::string& name) const; |
204 | std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType( |
205 | const std::string& name) const; |
206 | c10::optional<std::string> findFailedAttribute(const std::string& name) const; |
207 | bool isIgnoredAttribute(const std::string& name) const; |
208 | |
209 | // These getters are only here to return things as types that can be |
210 | // automatically converted by pybind. |
211 | std::unordered_map<std::string, py::object> getConstantsPy() const; |
212 | std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy() |
213 | const; |
214 | std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>> |
215 | getModulesPy() const; |
216 | |
217 | bool equals(const ConcreteModuleType& other) const { |
218 | if (jitType_ == other.jitType_) { |
219 | // If the computed types are the same, these modules can (obviously) share |
220 | // a type. |
221 | return true; |
222 | } |
223 | |
224 | return data_.equals(other.data_); |
225 | } |
226 | bool equals(const ConcreteModuleTypeBuilder& other) const { |
227 | return data_.equals(other); |
228 | } |
229 | |
230 | void dump() const; |
231 | |
232 | private: |
233 | ConcreteModuleType() = default; |
234 | |
235 | // The JIT type derived from this ConcreteModuleType. |
236 | ConcreteModuleTypeBuilder data_; |
237 | TypePtr jitType_; |
238 | }; |
239 | |
240 | } // namespace jit |
241 | } // namespace torch |
242 | |