1 | #pragma once |
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/name_mangler.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
25 | |
26 | namespace torch { |
27 | namespace jit { |
28 | |
// Forward declarations of types that this header only refers to by
// reference, pointer or smart pointer.
struct ClassDef;
struct Def;
struct Property;
struct Resolver;
struct SugaredValue;

// Shared-ownership handle to a Resolver.
using ResolverPtr = std::shared_ptr<Resolver>;
36 | struct Self { |
37 | virtual ~Self() = default; |
38 | virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0; |
39 | virtual ClassTypePtr getClassType() const = 0; |
40 | }; |
41 | |
42 | // A CompilationUnit is a list of named Functions |
43 | // with helper methods to iterate the list or invoke the function. |
44 | // Classes have a CompilationUnit holding the class methods, |
45 | // and Modules have a CompilationUnit holding the Functions that |
46 | // are used to implement their Methods |
47 | |
48 | struct TORCH_API CompilationUnit { |
49 | enum class FunctionType { Method, Hook, PreHook }; |
50 | // constructor that takes a set of functions to compile using the native |
51 | // resolver |
52 | explicit CompilationUnit(const std::string& source); |
53 | CompilationUnit() = default; |
54 | |
55 | CompilationUnit& operator=(CompilationUnit&&) = default; |
56 | CompilationUnit(CompilationUnit&&) = default; |
57 | CompilationUnit& operator=(const CompilationUnit&) = delete; |
58 | CompilationUnit(const CompilationUnit&) = delete; |
59 | |
60 | Function* find_function(const c10::QualifiedName& name) const { |
61 | auto it = dict_.find(name); |
62 | if (it == dict_.end()) { |
63 | return nullptr; |
64 | } |
65 | return functions_[it->second].get(); |
66 | } |
67 | |
68 | Function& get_function(const c10::QualifiedName& name) const { |
69 | if (auto r = find_function(name)) { |
70 | return *r; |
71 | } |
72 | TORCH_CHECK(false, "attempted to get undefined function " , name.name()); |
73 | } |
74 | |
75 | void set_optimized(bool o) { |
76 | TORCH_WARN( |
77 | "CompilationUnit::set_optimized() is deprecated and has no effect. " |
78 | "Please use setGraphExecutorOptimize()" ); |
79 | } |
80 | |
81 | bool is_optimized() const { |
82 | TORCH_WARN( |
83 | "CompilationUnit::is_optimized() is deprecated and always returns true. " |
84 | "Please use getGraphExecutorOptimize()" ); |
85 | return true; |
86 | } |
87 | |
88 | // for historic reasons, these are defined in ir_emitter.cpp |
89 | // Returns the list of Functions just defined. |
90 | std::vector<Function*> define( |
91 | const c10::optional<c10::QualifiedName>& prefix, |
92 | const std::vector<Property>& properties, |
93 | const std::vector<ResolverPtr>& propResolvers, |
94 | const std::vector<Def>& definitions, |
95 | const std::vector<ResolverPtr>& |
96 | defResolvers, /* determines how we handle free |
97 | variables in each definition*/ |
98 | // if non-null, the first argument to each def, is bound to this value |
99 | const Self* self, |
100 | // see [name mangling] |
101 | bool shouldMangle = false, |
102 | c10::optional<size_t> operator_set_version = c10::nullopt); |
103 | |
104 | void define_hooks( |
105 | const c10::optional<c10::QualifiedName>& prefix, |
106 | const std::vector<Def>& hookDefs, |
107 | const std::vector<ResolverPtr>& hookResolvers, |
108 | const std::vector<Def>& preHookDefs, |
109 | const std::vector<ResolverPtr>& preHookResolvers, |
110 | const Self* self, |
111 | bool shouldMangle = false); |
112 | |
113 | // same as above but parse the definitions from source |
114 | // Returns the list of Functions just defined. |
115 | std::vector<Function*> define( |
116 | // prefix namespace to put all the defined functions into |
117 | const c10::optional<c10::QualifiedName>& prefix, |
118 | const std::string& source, |
119 | const ResolverPtr& resolver, |
120 | const Self* self); |
121 | |
122 | void define_interface( |
123 | const c10::QualifiedName& qualifiedName, |
124 | const ClassDef& classDef, |
125 | ResolverPtr rcb, |
126 | bool is_module = false); |
127 | |
128 | Function* create_function( |
129 | c10::QualifiedName name, |
130 | std::shared_ptr<Graph> graph, |
131 | bool shouldMangle = false) { |
132 | if (shouldMangle) { |
133 | name = mangle(name); |
134 | } |
135 | auto fn = torch::make_unique<GraphFunction>( |
136 | std::move(name), std::move(graph), nullptr); |
137 | auto ret = fn.get(); |
138 | register_function(std::move(fn)); |
139 | return ret; |
140 | } |
141 | |
142 | std::vector<Function*> get_functions() const { |
143 | return fmap(functions_, [](const std::unique_ptr<Function>& fn) { |
144 | return fn.get(); |
145 | }); |
146 | } |
147 | |
148 | /// Run a method from this compilation. |
149 | /// |
150 | /// For example: |
151 | /// @code |
152 | /// IValue output = module->run("relu_script", a, b); |
153 | /// @endcode |
154 | /// |
155 | /// To get a compile a module from a source string, see torch::jit::compile |
156 | /// |
157 | /// @param method_name The name of the method to run |
158 | /// @param args Arguments to be passed to the method |
159 | /// @return An IValue containing the return value (or values if it is a tuple) |
160 | /// from the method |
161 | template <typename... Types> |
162 | IValue run_method(const c10::QualifiedName& method_name, Types&&... args) { |
163 | return get_function(method_name)({IValue(std::forward<Types>(args))...}); |
164 | } |
165 | |
166 | void drop_all_functions() { |
167 | dict_.clear(); |
168 | functions_.clear(); |
169 | } |
170 | |
171 | /** |
172 | * Register a class as being owned by this compilation unit. |
173 | */ |
174 | void register_type(c10::NamedTypePtr namedType) { |
175 | // TODO: class types cannot be redefined because we have no way right now |
176 | // of invalidating their methods. NamedTuples are fine though, since they |
177 | // don't have methods. |
178 | TORCH_CHECK( |
179 | 0 == classDict_.count(*namedType->name()), |
180 | "class '" , |
181 | namedType->name()->qualifiedName(), |
182 | "' already defined." ); |
183 | classes_.push_back(std::move(namedType)); |
184 | classDict_[*classes_.back()->name()] = classes_.size() - 1; |
185 | }; |
186 | |
187 | c10::ClassTypePtr get_class(const c10::QualifiedName& name) const { |
188 | auto type = get_type(name); |
189 | if (!type) { |
190 | return nullptr; |
191 | } |
192 | return type->cast<c10::ClassType>(); |
193 | } |
194 | |
195 | c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const { |
196 | auto type = get_type(name); |
197 | if (!type) { |
198 | return nullptr; |
199 | } |
200 | return type->cast<c10::InterfaceType>(); |
201 | } |
202 | |
203 | c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const { |
204 | for (const auto& cls : classes_) { |
205 | if (cls->name()->qualifiedName() == name.qualifiedName()) { |
206 | return cls->expect<TupleType>(); |
207 | } |
208 | } |
209 | return nullptr; |
210 | } |
211 | |
212 | c10::NamedTypePtr get_type(const c10::QualifiedName& name) const { |
213 | auto it = classDict_.find(name); |
214 | if (it == classDict_.end()) { |
215 | return nullptr; |
216 | } |
217 | return classes_[it->second]; |
218 | } |
219 | |
220 | // For testing: clear all Python-defined classes to ensure that unit tests |
221 | // have isolation. |
222 | void _clear_python_cu() { |
223 | // Delete all the associated class methods |
224 | for (const auto& type : classes_) { |
225 | if (auto cls = type->cast<ClassType>()) { |
226 | for (auto method : cls->methods()) { |
227 | // Tombstone the method in the compilation unit. |
228 | // Don't erase because the dict_ |
229 | auto it = dict_.find(method->qualname()); |
230 | if (it != dict_.end()) { |
231 | functions_[it->second] = nullptr; |
232 | // Erase in our big lookup table |
233 | dict_.erase(it); |
234 | } |
235 | } |
236 | // Classes can have multiple pointers to the same hook, |
237 | // need to make sure to not delete it twice |
238 | std::unordered_set<Function*> hooks_to_delete; |
239 | for (const auto& hook : cls->getForwardHooks()) { |
240 | hooks_to_delete.insert(hook); |
241 | } |
242 | for (const auto& pre_hook : cls->getForwardPreHooks()) { |
243 | hooks_to_delete.insert(pre_hook); |
244 | } |
245 | for (const auto& hook : hooks_to_delete) { |
246 | // Tombstone the hook in the compilation unit. |
247 | auto it = dict_.find(hook->qualname()); |
248 | if (it != dict_.end()) { |
249 | functions_[it->second] = nullptr; |
250 | // Erase in our big lookup table |
251 | dict_.erase(it); |
252 | } |
253 | } |
254 | } |
255 | } |
256 | classes_.clear(); |
257 | classDict_.clear(); |
258 | } |
259 | |
260 | // [Internal Only] Remove method. |
261 | // Note Used for freezing. |
262 | void unsafeRemoveMethod(const c10::QualifiedName& method_name) { |
263 | auto it = dict_.find(method_name); |
264 | TORCH_CHECK( |
265 | it != dict_.end(), |
266 | "method '" , |
267 | method_name.qualifiedName(), |
268 | "' does not exist." ); |
269 | functions_[it->second] = nullptr; |
270 | dict_.erase(it); |
271 | } |
272 | |
273 | // [name mangling] All code objects must have a unique qualified name in a |
274 | // CompilationUnit. In Python, sometimes functions won't have unique qualified |
275 | // name (for example, nested functions). So we mangle Python functions to |
276 | // ensure that they are uniquely named. |
277 | // |
278 | // We also use mangling to distinguish different Module instances. Since each |
279 | // Module is a singleton class instance, different instances of the same |
280 | // Python Module will have different types but the same qualified name. |
281 | c10::QualifiedName mangle(const c10::QualifiedName& name) const { |
282 | auto mangled = name; |
283 | while (get_type(mangled) || find_function(mangled)) { |
284 | mangled = mangler_.mangle(mangled); |
285 | } |
286 | return mangled; |
287 | } |
288 | |
289 | private: |
290 | std::unique_ptr<Function> define( |
291 | const c10::optional<c10::QualifiedName>& prefix, |
292 | const Def& def, |
293 | const ResolverPtr& resolver, |
294 | const Self* self, |
295 | const std::unordered_map<std::string, Function*>& function_table, |
296 | bool shouldMangle = false, |
297 | FunctionType type = FunctionType::Method, |
298 | c10::optional<size_t> version = c10::nullopt) const; |
299 | |
300 | // Define a property on \p self. |
301 | struct PropertyPair; |
302 | PropertyPair define_property( |
303 | const c10::optional<c10::QualifiedName>& prefix, |
304 | const Property& prop, |
305 | const ResolverPtr& resolver, |
306 | const Self* self, |
307 | const std::unordered_map<std::string, Function*>& function_table, |
308 | bool shouldMangle = false) const; |
309 | |
310 | Function& register_function(std::unique_ptr<Function> fn) { |
311 | TORCH_CHECK( |
312 | 0 == dict_.count(fn->qualname().qualifiedName()), |
313 | "method '" , |
314 | fn->qualname().qualifiedName(), |
315 | "' already defined." ); |
316 | functions_.emplace_back(std::move(fn)); |
317 | dict_[functions_.back()->qualname()] = functions_.size() - 1; |
318 | return *functions_.back(); |
319 | } |
320 | std::vector<std::unique_ptr<Function>> functions_; |
321 | // for fast lookup |
322 | std::unordered_map<c10::QualifiedName, size_t> dict_; |
323 | std::unordered_map<c10::QualifiedName, size_t> classDict_; |
324 | |
325 | // [class ownership] Right now there aree two relationships between classes |
326 | // and compilation units: |
327 | // 1. Classes have compilation units internally that hold their methods. |
328 | // 2. On load, the TypePtrs of any imported classes are owned by the main |
329 | // module's compilation unit. |
330 | std::vector<c10::NamedTypePtr> classes_; |
331 | |
332 | mutable NameMangler mangler_; |
333 | }; |
334 | |
// An owning pointer to a Function. Just a pair of a raw Function ptr and its
// owning CU. We need this because pybind requires a ref-counted way to refer to
// Functions.
338 | struct StrongFunctionPtr { |
339 | StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function) |
340 | : cu_(std::move(cu)), function_(function) { |
341 | TORCH_INTERNAL_ASSERT(cu_); |
342 | TORCH_INTERNAL_ASSERT(function_); |
343 | } |
344 | std::shared_ptr<CompilationUnit> cu_; |
345 | Function* function_; |
346 | }; |
347 | |
namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
// The alias names exactly the same type as ::torch::jit::CompilationUnit.
using CompilationUnit = ::torch::jit::CompilationUnit;
} // namespace script
353 | } // namespace jit |
354 | } // namespace torch |
355 | |