1#pragma once
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/name_mangler.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
25
26namespace torch {
27namespace jit {
28
// Forward declarations. This header only names these types in declarations
// (by pointer, by reference, or as template arguments), so their full
// definitions are not pulled in here.
struct Def;
struct Property;
struct ClassDef;
struct SugaredValue;
struct Resolver;

// Shared-ownership handle to a Resolver.
using ResolverPtr = std::shared_ptr<Resolver>;
// Abstract interface for the `self` object that definitions are compiled
// against. CompilationUnit::define and define_hooks accept a `const Self*`;
// when it is non-null, the first argument of each definition is bound to it.
struct Self {
  // Virtual so that implementations can be destroyed through a Self pointer.
  virtual ~Self() = default;
  // Produces the SugaredValue that stands for the graph value `v`.
  // NOTE(review): presumably `v` is the `self` value inside the function being
  // compiled -- confirm against the implementers.
  virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0;
  // Returns the ClassType associated with this object.
  virtual ClassTypePtr getClassType() const = 0;
};
41
42// A CompilationUnit is a list of named Functions
43// with helper methods to iterate the list or invoke the function.
44// Classes have a CompilationUnit holding the class methods,
45// and Modules have a CompilationUnit holding the Functions that
46// are used to implement their Methods
47
48struct TORCH_API CompilationUnit {
49 enum class FunctionType { Method, Hook, PreHook };
50 // constructor that takes a set of functions to compile using the native
51 // resolver
52 explicit CompilationUnit(const std::string& source);
53 CompilationUnit() = default;
54
55 CompilationUnit& operator=(CompilationUnit&&) = default;
56 CompilationUnit(CompilationUnit&&) = default;
57 CompilationUnit& operator=(const CompilationUnit&) = delete;
58 CompilationUnit(const CompilationUnit&) = delete;
59
60 Function* find_function(const c10::QualifiedName& name) const {
61 auto it = dict_.find(name);
62 if (it == dict_.end()) {
63 return nullptr;
64 }
65 return functions_[it->second].get();
66 }
67
68 Function& get_function(const c10::QualifiedName& name) const {
69 if (auto r = find_function(name)) {
70 return *r;
71 }
72 TORCH_CHECK(false, "attempted to get undefined function ", name.name());
73 }
74
75 void set_optimized(bool o) {
76 TORCH_WARN(
77 "CompilationUnit::set_optimized() is deprecated and has no effect. "
78 "Please use setGraphExecutorOptimize()");
79 }
80
81 bool is_optimized() const {
82 TORCH_WARN(
83 "CompilationUnit::is_optimized() is deprecated and always returns true. "
84 "Please use getGraphExecutorOptimize()");
85 return true;
86 }
87
88 // for historic reasons, these are defined in ir_emitter.cpp
89 // Returns the list of Functions just defined.
90 std::vector<Function*> define(
91 const c10::optional<c10::QualifiedName>& prefix,
92 const std::vector<Property>& properties,
93 const std::vector<ResolverPtr>& propResolvers,
94 const std::vector<Def>& definitions,
95 const std::vector<ResolverPtr>&
96 defResolvers, /* determines how we handle free
97 variables in each definition*/
98 // if non-null, the first argument to each def, is bound to this value
99 const Self* self,
100 // see [name mangling]
101 bool shouldMangle = false,
102 c10::optional<size_t> operator_set_version = c10::nullopt);
103
104 void define_hooks(
105 const c10::optional<c10::QualifiedName>& prefix,
106 const std::vector<Def>& hookDefs,
107 const std::vector<ResolverPtr>& hookResolvers,
108 const std::vector<Def>& preHookDefs,
109 const std::vector<ResolverPtr>& preHookResolvers,
110 const Self* self,
111 bool shouldMangle = false);
112
113 // same as above but parse the definitions from source
114 // Returns the list of Functions just defined.
115 std::vector<Function*> define(
116 // prefix namespace to put all the defined functions into
117 const c10::optional<c10::QualifiedName>& prefix,
118 const std::string& source,
119 const ResolverPtr& resolver,
120 const Self* self);
121
122 void define_interface(
123 const c10::QualifiedName& qualifiedName,
124 const ClassDef& classDef,
125 ResolverPtr rcb,
126 bool is_module = false);
127
128 Function* create_function(
129 c10::QualifiedName name,
130 std::shared_ptr<Graph> graph,
131 bool shouldMangle = false) {
132 if (shouldMangle) {
133 name = mangle(name);
134 }
135 auto fn = torch::make_unique<GraphFunction>(
136 std::move(name), std::move(graph), nullptr);
137 auto ret = fn.get();
138 register_function(std::move(fn));
139 return ret;
140 }
141
142 std::vector<Function*> get_functions() const {
143 return fmap(functions_, [](const std::unique_ptr<Function>& fn) {
144 return fn.get();
145 });
146 }
147
148 /// Run a method from this compilation.
149 ///
150 /// For example:
151 /// @code
152 /// IValue output = module->run("relu_script", a, b);
153 /// @endcode
154 ///
155 /// To get a compile a module from a source string, see torch::jit::compile
156 ///
157 /// @param method_name The name of the method to run
158 /// @param args Arguments to be passed to the method
159 /// @return An IValue containing the return value (or values if it is a tuple)
160 /// from the method
161 template <typename... Types>
162 IValue run_method(const c10::QualifiedName& method_name, Types&&... args) {
163 return get_function(method_name)({IValue(std::forward<Types>(args))...});
164 }
165
166 void drop_all_functions() {
167 dict_.clear();
168 functions_.clear();
169 }
170
171 /**
172 * Register a class as being owned by this compilation unit.
173 */
174 void register_type(c10::NamedTypePtr namedType) {
175 // TODO: class types cannot be redefined because we have no way right now
176 // of invalidating their methods. NamedTuples are fine though, since they
177 // don't have methods.
178 TORCH_CHECK(
179 0 == classDict_.count(*namedType->name()),
180 "class '",
181 namedType->name()->qualifiedName(),
182 "' already defined.");
183 classes_.push_back(std::move(namedType));
184 classDict_[*classes_.back()->name()] = classes_.size() - 1;
185 };
186
187 c10::ClassTypePtr get_class(const c10::QualifiedName& name) const {
188 auto type = get_type(name);
189 if (!type) {
190 return nullptr;
191 }
192 return type->cast<c10::ClassType>();
193 }
194
195 c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const {
196 auto type = get_type(name);
197 if (!type) {
198 return nullptr;
199 }
200 return type->cast<c10::InterfaceType>();
201 }
202
203 c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const {
204 for (const auto& cls : classes_) {
205 if (cls->name()->qualifiedName() == name.qualifiedName()) {
206 return cls->expect<TupleType>();
207 }
208 }
209 return nullptr;
210 }
211
212 c10::NamedTypePtr get_type(const c10::QualifiedName& name) const {
213 auto it = classDict_.find(name);
214 if (it == classDict_.end()) {
215 return nullptr;
216 }
217 return classes_[it->second];
218 }
219
220 // For testing: clear all Python-defined classes to ensure that unit tests
221 // have isolation.
222 void _clear_python_cu() {
223 // Delete all the associated class methods
224 for (const auto& type : classes_) {
225 if (auto cls = type->cast<ClassType>()) {
226 for (auto method : cls->methods()) {
227 // Tombstone the method in the compilation unit.
228 // Don't erase because the dict_
229 auto it = dict_.find(method->qualname());
230 if (it != dict_.end()) {
231 functions_[it->second] = nullptr;
232 // Erase in our big lookup table
233 dict_.erase(it);
234 }
235 }
236 // Classes can have multiple pointers to the same hook,
237 // need to make sure to not delete it twice
238 std::unordered_set<Function*> hooks_to_delete;
239 for (const auto& hook : cls->getForwardHooks()) {
240 hooks_to_delete.insert(hook);
241 }
242 for (const auto& pre_hook : cls->getForwardPreHooks()) {
243 hooks_to_delete.insert(pre_hook);
244 }
245 for (const auto& hook : hooks_to_delete) {
246 // Tombstone the hook in the compilation unit.
247 auto it = dict_.find(hook->qualname());
248 if (it != dict_.end()) {
249 functions_[it->second] = nullptr;
250 // Erase in our big lookup table
251 dict_.erase(it);
252 }
253 }
254 }
255 }
256 classes_.clear();
257 classDict_.clear();
258 }
259
260 // [Internal Only] Remove method.
261 // Note Used for freezing.
262 void unsafeRemoveMethod(const c10::QualifiedName& method_name) {
263 auto it = dict_.find(method_name);
264 TORCH_CHECK(
265 it != dict_.end(),
266 "method '",
267 method_name.qualifiedName(),
268 "' does not exist.");
269 functions_[it->second] = nullptr;
270 dict_.erase(it);
271 }
272
273 // [name mangling] All code objects must have a unique qualified name in a
274 // CompilationUnit. In Python, sometimes functions won't have unique qualified
275 // name (for example, nested functions). So we mangle Python functions to
276 // ensure that they are uniquely named.
277 //
278 // We also use mangling to distinguish different Module instances. Since each
279 // Module is a singleton class instance, different instances of the same
280 // Python Module will have different types but the same qualified name.
281 c10::QualifiedName mangle(const c10::QualifiedName& name) const {
282 auto mangled = name;
283 while (get_type(mangled) || find_function(mangled)) {
284 mangled = mangler_.mangle(mangled);
285 }
286 return mangled;
287 }
288
289 private:
290 std::unique_ptr<Function> define(
291 const c10::optional<c10::QualifiedName>& prefix,
292 const Def& def,
293 const ResolverPtr& resolver,
294 const Self* self,
295 const std::unordered_map<std::string, Function*>& function_table,
296 bool shouldMangle = false,
297 FunctionType type = FunctionType::Method,
298 c10::optional<size_t> version = c10::nullopt) const;
299
300 // Define a property on \p self.
301 struct PropertyPair;
302 PropertyPair define_property(
303 const c10::optional<c10::QualifiedName>& prefix,
304 const Property& prop,
305 const ResolverPtr& resolver,
306 const Self* self,
307 const std::unordered_map<std::string, Function*>& function_table,
308 bool shouldMangle = false) const;
309
310 Function& register_function(std::unique_ptr<Function> fn) {
311 TORCH_CHECK(
312 0 == dict_.count(fn->qualname().qualifiedName()),
313 "method '",
314 fn->qualname().qualifiedName(),
315 "' already defined.");
316 functions_.emplace_back(std::move(fn));
317 dict_[functions_.back()->qualname()] = functions_.size() - 1;
318 return *functions_.back();
319 }
320 std::vector<std::unique_ptr<Function>> functions_;
321 // for fast lookup
322 std::unordered_map<c10::QualifiedName, size_t> dict_;
323 std::unordered_map<c10::QualifiedName, size_t> classDict_;
324
325 // [class ownership] Right now there aree two relationships between classes
326 // and compilation units:
327 // 1. Classes have compilation units internally that hold their methods.
328 // 2. On load, the TypePtrs of any imported classes are owned by the main
329 // module's compilation unit.
330 std::vector<c10::NamedTypePtr> classes_;
331
332 mutable NameMangler mangler_;
333};
334
// An owning pointer to a Function. Just a pair of a raw Function ptr and its
// owning CU. We need this because pybind requires a ref-counted way to refer to
// Functions.
338struct StrongFunctionPtr {
339 StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
340 : cu_(std::move(cu)), function_(function) {
341 TORCH_INTERNAL_ASSERT(cu_);
342 TORCH_INTERNAL_ASSERT(function_);
343 }
344 std::shared_ptr<CompilationUnit> cu_;
345 Function* function_;
346};
347
namespace script {
// Backward-compatibility shim: the former `script::` namespace has been
// removed, and this alias keeps the old public spelling
// `torch::jit::script::CompilationUnit` compiling. New code should refer to
// `torch::jit::CompilationUnit` directly.
using CompilationUnit = ::torch::jit::CompilationUnit;
} // namespace script
353} // namespace jit
354} // namespace torch
355