1#pragma once
2
3#include <ATen/core/ivalue.h>
4#include <torch/csrc/jit/api/module.h>
5#include <torch/csrc/jit/python/pybind_utils.h>
6#include <memory>
7#include <string>
8#include <vector>
9
10namespace torch {
11namespace jit {
12
// Classifies whether a module is one of the iterable container kinds. Per the
// note on ConcreteModuleTypeBuilder::iterableModuleKind_, for a
// ModuleDict/ModuleList this means the order of submodules matters when
// comparing types and the compiler may treat the module like a dict/tuple.
enum class IterableModuleKind {
  NONE, // default value used by ConcreteModuleTypeBuilder
  LIST, // presumably nn.ModuleList -- confirm against the Python side
  DICT, // presumably nn.ModuleDict -- confirm against the Python side
  PARAMLIST, // presumably nn.ParameterList -- confirm against the Python side
  PARAMDICT, // presumably nn.ParameterDict -- confirm against the Python side
};

// Forward declaration: ConcreteModuleTypeBuilder refers to ConcreteModuleType
// (through std::shared_ptr and a friend declaration) before the class itself
// is defined further down in this header.
class ConcreteModuleType;
15
16// You can think of an nn.Module as a template that corresponds to a family of
17// JIT types. The template "arguments" are things like the constant values.
// e.g.
//   class M(nn.Module):
//       __constants__ = ["const"]
//       ...
//
// Is similar to writing the following in C++:
//
//   template<TConst>
//   class M {
//     ...
//   }
29//
30// We need to consider each different member of the type family a different JIT
31// type because, e.g. different constant values lead to different versions of
32// the same method.
33//
34// ConcreteModuleType corresponds to a single member of the type family, with
35// all template arguments fully specified. Two Modules that share a
36// ConcreteModuleType can share a JIT type, and vice versa.
37//
// Why not just use a JIT type to represent concrete types? Because constants,
// function attributes, etc. are currently not representable in the type system,
// so this acts as a non-first-class way of tracking concrete types.
41//
// ConcreteModuleType is also the source of truth for servicing all
// ModuleValue::attr calls. This is so we can guarantee that if two Modules
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
// way when you access attributes on them.
46
// ConcreteModuleType has two phases.
// 1. Creation: First we build it up, during the ScriptModule conversion
//    process. This is represented by ConcreteModuleTypeBuilder.
//    ...then the converter calls ConcreteModuleTypeBuilder::build(), producing
//    a ConcreteModuleType ready for querying.
// 2. Querying: We use ConcreteModuleType as a source of truth for
//    ModuleValue::attr calls during method compilation.
55
56// Represents a concrete type during in the process for construction. We use
57// this to decide whether we can share types between modules.
58class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
59 public:
60 explicit ConcreteModuleTypeBuilder(py::object pyClass) {
61 TORCH_INTERNAL_ASSERT(pyClass);
62 pyClass_ = std::move(pyClass);
63 }
64
65 void addConstant(std::string name, py::object value);
66 void addConstant(std::string name, IValue value);
67 void addAttribute(
68 std::string name,
69 const TypePtr& type,
70 bool isParameter,
71 bool isBuffer);
72 void addFunctionAttribute(
73 std::string name,
74 const TypePtr& type,
75 py::object pyFunction);
76
77 void addModule(std::string name, std::shared_ptr<ConcreteModuleType> meta);
78
79 void addForwardHook(py::object hook);
80 void addForwardPreHook(py::object pre_hook);
81
82 void addOverload(
83 std::string methodName,
84 std::vector<std::string> overloadedMethodNames);
85 void addBuiltinFunction(std::string name, const std::string& symbol_name);
86 void addFailedAttribute(std::string name, std::string failureReason);
87 void addIgnoredAttribute(std::string name);
88 void setIterableModuleKind(IterableModuleKind kind);
89
90 // If a ConcreteModuleType is poisoned, it will never compare equal to any
91 // other concrete type
92 void setPoisoned();
93
94 std::shared_ptr<ConcreteModuleType> build() const {
95 return std::make_shared<ConcreteModuleType>(*this);
96 }
97
98 // This determines whether two modules can share a type. The container structs
99 // used by ConcreteModuleType have been defined such that operator==
100 // implements a meaningful comparison in that context.
101 bool equals(const ConcreteModuleTypeBuilder& other) const;
102
103 struct FunctionAttribute {
104 FunctionTypePtr function_;
105 py::object pyFunction_;
106
107 friend bool operator==(
108 const FunctionAttribute& lhs,
109 const FunctionAttribute& rhs) {
110 // Functions are not first class, so we can't do type comparison like a
111 // regular attribute. So we do a pointer equality check on the actual
112 // Python function object.
113 return lhs.pyFunction_.is(rhs.pyFunction_);
114 }
115 };
116
117 struct Attribute {
118 Attribute(TypePtr type, bool isParam, bool isBuffer)
119 : type_(std::move(type)), isParam_(isParam), isBuffer_(isBuffer) {}
120
121 friend bool operator==(const Attribute& lhs, const Attribute& rhs) {
122 return *(lhs.type_) == *(rhs.type_) && lhs.isParam_ == rhs.isParam_;
123 }
124 TypePtr type_;
125 bool isParam_;
126 bool isBuffer_;
127 };
128
129 struct ModuleInfo {
130 ModuleInfo(std::string name, std::shared_ptr<ConcreteModuleType> meta)
131 : name_(std::move(name)), meta_(std::move(meta)) {}
132
133 friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
134
135 std::string name_;
136 std::shared_ptr<ConcreteModuleType> meta_;
137 };
138
139 private:
140 ConcreteModuleTypeBuilder() = default;
141 ClassTypePtr createTypeFromThis() const;
142
143 // If true, this type will never compare equally to anything else. This is
144 // used if we want to ensure that this type is not shared (for example, if it
145 // came from a traced module)
146 bool isPoisoned_ = false;
147
148 // The value of any constants defined by the module.
149 std::unordered_map<std::string, IValue> constants_;
150 // The types of any attributes
151 OrderedDict<std::string, Attribute> attributes_;
152 // Overloads, in the same format as `__overloads__` in Python
153 std::unordered_map<std::string, std::vector<std::string>> overloads_;
154 // Any attributes we failed to convert to TorchScript, along with a hint as to
155 // why
156 std::unordered_map<std::string, std::string> failedAttributes_;
157 // Any attributes that were marked as ignored. They cannot be used in
158 // TorchScript but can still be used in ignored function in Python.
159 std::unordered_set<std::string> ignoredAttributes_;
160 // Any function attributes. These are special right now because functions are
161 // not first-class in the type system.
162 std::unordered_map<std::string, FunctionAttribute> functionAttributes_;
163 // Function attributes that are calls to builtin functions. These get
164 // de-sugared directly into the corresponding aten:: call. The map is
165 // attribute name -> aten symbol name
166 std::unordered_map<std::string, c10::Symbol> builtinFunctions_;
167 // The concrete types of any submodules
168 std::vector<ModuleInfo> modules_;
169 // Hooks to be called before/after forward when the module
170 // is called directly. Used to ensure modules have different types
171 // when they have different python hooks
172 // Actual hooks are added to ClassType directly during compilation
173 std::vector<py::object> forwardHooks_;
174 std::vector<py::object> forwardPreHooks_;
175
176 // If something is a ModuleDict/ModuleList, it means:
177 // 1. The order of the submodules matters for comparing the type
178 // 2. The compiler is allowed to treat it like a dict/tuple
179 IterableModuleKind iterableModuleKind_ = IterableModuleKind::NONE;
180
181 // The original `nn.Module` class that we derived this ScriptModule from.
182 py::object pyClass_;
183
184 // NOTE: If you ever add any more state to this struct, you need to make sure
185 // operator== still makes sense!
186 friend ConcreteModuleType;
187};
188
189// Represents a finalized concrete type, used to service ModuleValue::attr calls
190// during method compilation.
191class VISIBILITY_HIDDEN ConcreteModuleType {
192 public:
193 explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
194
195 static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
196
197 TypePtr getJitType() const;
198 c10::optional<py::object> getPyClass() const;
199 IterableModuleKind getIterableModuleKind() const;
200 c10::optional<std::vector<std::string>> findOverloads(
201 const std::string& name) const;
202 c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
203 c10::optional<c10::Symbol> findBuiltinFunction(const std::string& name) const;
204 std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
205 const std::string& name) const;
206 c10::optional<std::string> findFailedAttribute(const std::string& name) const;
207 bool isIgnoredAttribute(const std::string& name) const;
208
209 // These getters are only here to return things as types that can be
210 // automatically converted by pybind.
211 std::unordered_map<std::string, py::object> getConstantsPy() const;
212 std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
213 const;
214 std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>>
215 getModulesPy() const;
216
217 bool equals(const ConcreteModuleType& other) const {
218 if (jitType_ == other.jitType_) {
219 // If the computed types are the same, these modules can (obviously) share
220 // a type.
221 return true;
222 }
223
224 return data_.equals(other.data_);
225 }
226 bool equals(const ConcreteModuleTypeBuilder& other) const {
227 return data_.equals(other);
228 }
229
230 void dump() const;
231
232 private:
233 ConcreteModuleType() = default;
234
235 // The JIT type derived from this ConcreteModuleType.
236 ConcreteModuleTypeBuilder data_;
237 TypePtr jitType_;
238};
239
240} // namespace jit
241} // namespace torch
242