1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/frontend/concrete_module_type.h> |
5 | #include <torch/csrc/jit/frontend/sugared_value.h> |
6 | #include <torch/csrc/jit/python/pybind_utils.h> |
7 | #include <memory> |
8 | #include <sstream> |
9 | #include <string> |
10 | #include <utility> |
11 | #include <vector> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
// Presumably returns a printable description of the Python type of `h`
// (defined in the corresponding .cpp; TODO confirm exact format).
std::string typeString(py::handle h);
17 | |
18 | inline std::shared_ptr<SugaredValue> toSimple(Value* v) { |
19 | return std::make_shared<SimpleValue>(v); |
20 | } |
21 | |
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
//
// Parameters (meanings inferred from names; confirm against the definition):
//   obj         - the Python object to convert.
//   m           - the GraphFunction being compiled (presumably the one the
//                 resulting value is emitted into).
//   loc         - source location, presumably used for error reporting.
//   is_constant - defaults to false; presumably requests constant treatment
//                 of `obj` — TODO confirm.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    GraphFunction& m,
    const SourceRange& loc,
    bool is_constant = false);
30 | |
// Presumably returns a StrongFunctionPtr when `obj` corresponds to a compiled
// TorchScript function, and c10::nullopt otherwise — confirm against the
// definition.
c10::optional<StrongFunctionPtr> as_function(const py::object& obj);
32 | |
33 | struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { |
34 | PythonValue( |
35 | py::object the_self, |
36 | c10::optional<py::object> rcb = c10::nullopt, |
37 | Value* module_self = nullptr) |
38 | : self(std::move(the_self)), |
39 | rcb(std::move(rcb)), |
40 | moduleSelf_(module_self) {} |
41 | |
42 | FunctionSchema getSchema( |
43 | const size_t n_args, |
44 | const size_t n_binders, |
45 | const SourceRange& loc); |
46 | |
47 | // call it like a function, e.g. `outputs = this(inputs)` |
48 | std::shared_ptr<SugaredValue> call( |
49 | const SourceRange& loc, |
50 | GraphFunction& m, |
51 | at::ArrayRef<NamedValue> args, |
52 | at::ArrayRef<NamedValue> kwargs, |
53 | size_t n_binders) override; |
54 | |
55 | std::string kind() const override; |
56 | |
57 | std::vector<std::shared_ptr<SugaredValue>> asTuple( |
58 | const SourceRange& loc, |
59 | GraphFunction& m, |
60 | const c10::optional<size_t>& size_hint = {}) override; |
61 | |
62 | std::shared_ptr<SugaredValue> attr( |
63 | const SourceRange& loc, |
64 | GraphFunction& m, |
65 | const std::string& field) override; |
66 | |
67 | Value* asValue(const SourceRange& loc, GraphFunction& m) override { |
68 | throw ErrorReport(loc) |
69 | << kind() << " cannot be used as a value. " |
70 | << "Perhaps it is a closed over global variable? If so, please " |
71 | << "consider passing it in as an argument or use a local varible " |
72 | << "instead." ; |
73 | } |
74 | |
75 | protected: |
76 | py::object getattr(const SourceRange& loc, const std::string& name); |
77 | |
78 | void checkForAddToConstantsError(std::stringstream& ss); |
79 | |
80 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
81 | py::object self; |
82 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
83 | c10::optional<py::object> rcb; |
84 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
85 | Value* moduleSelf_ = nullptr; |
86 | }; |
87 | |
88 | struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { |
89 | explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} |
90 | |
91 | std::shared_ptr<SugaredValue> attr( |
92 | const SourceRange& loc, |
93 | GraphFunction& m, |
94 | const std::string& field) override; |
95 | }; |
96 | |
97 | // Used for desugaring uses of the torch.cuda module. All the CUDA APIs with |
98 | // torch.cuda.* are resolved using CUDAPythonModuleValue. |
99 | struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { |
100 | explicit CUDAPythonModuleValue(py::object mod) |
101 | : PythonValue(std::move(mod)) {} |
102 | |
103 | std::shared_ptr<SugaredValue> attr( |
104 | const SourceRange& loc, |
105 | GraphFunction& m, |
106 | const std::string& field) override; |
107 | }; |
108 | |
109 | // Represents all the parameters of a module as a List[Tensor] |
110 | struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { |
111 | ConstantParameterList(Value* the_list) : the_list_(the_list) {} |
112 | std::string kind() const override { |
113 | return "constant parameter list" ; |
114 | } |
115 | std::shared_ptr<SugaredValue> call( |
116 | const SourceRange& loc, |
117 | GraphFunction& caller, |
118 | at::ArrayRef<NamedValue> args, |
119 | at::ArrayRef<NamedValue> kwargs, |
120 | size_t n_binders) override { |
121 | return toSimple(the_list_); |
122 | } |
123 | |
124 | private: |
125 | Value* the_list_; |
126 | }; |
127 | |
128 | struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { |
129 | explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name) |
130 | : iterable_(std::move(iterable)), name_(std::move(name)){}; |
131 | |
132 | std::string kind() const override { |
133 | return name_; |
134 | } |
135 | |
136 | std::shared_ptr<SugaredValue> call( |
137 | const SourceRange& loc, |
138 | GraphFunction& f, |
139 | at::ArrayRef<NamedValue> args, |
140 | at::ArrayRef<NamedValue> kwargs, |
141 | size_t n_binders) override { |
142 | if (!args.empty() || !kwargs.empty()) { |
143 | throw ErrorReport(loc) |
144 | << name_ << " method does not accept any arguments" ; |
145 | } |
146 | return iterable_; |
147 | } |
148 | |
149 | SugaredValuePtr iterable_; |
150 | const std::string name_; |
151 | }; |
152 | |
153 | struct SugaredDict; |
154 | |
155 | // defines how modules/methods behave inside the script subset. |
156 | // for now this does not have any interaction with python. |
157 | // in the future, we will add the ability to resolve `self.foo` to python |
158 | // {functions, modules, constants} so this SugaredValue is defined here |
159 | // anticipating we will eventually need to replace Module with a py::object |
160 | // holding the actual nn.Module class. |
161 | |
162 | struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { |
163 | ModuleValue(Value* self, std::shared_ptr<ConcreteModuleType> concreteType) |
164 | : self_(self), concreteType_(std::move(concreteType)) {} |
165 | |
166 | std::string kind() const override { |
167 | return "module" ; |
168 | } |
169 | |
170 | Value* asValue(const SourceRange& loc, GraphFunction& m) override; |
171 | |
172 | SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m) |
173 | override; |
174 | |
175 | // select an attribute on it, e.g. `this.field` |
176 | std::shared_ptr<SugaredValue> tryGetAttr( |
177 | const SourceRange& loc, |
178 | GraphFunction& m, |
179 | const std::string& field); |
180 | |
181 | // select an attribute on it, e.g. `this.field` |
182 | std::shared_ptr<SugaredValue> attr( |
183 | const SourceRange& loc, |
184 | GraphFunction& m, |
185 | const std::string& field) override; |
186 | |
187 | // select an attribute on it, e.g. `this.field` |
188 | bool hasAttr( |
189 | const SourceRange& loc, |
190 | GraphFunction& m, |
191 | const std::string& field) override; |
192 | |
193 | // call module.forward with pre_hooks and hooks |
194 | std::shared_ptr<SugaredValue> call( |
195 | const SourceRange& loc, |
196 | GraphFunction& caller, |
197 | at::ArrayRef<NamedValue> args, |
198 | at::ArrayRef<NamedValue> kwargs, |
199 | size_t n_binders) override; |
200 | |
201 | std::shared_ptr<SugaredDict> getSugaredDict( |
202 | const SourceRange& loc, |
203 | GraphFunction& m); |
204 | |
205 | std::shared_ptr<SugaredDict> getSugaredNamedBufferDict( |
206 | const SourceRange& loc, |
207 | GraphFunction& m); |
208 | |
209 | std::shared_ptr<SugaredDict> getSugaredNamedParameterList( |
210 | const SourceRange& loc, |
211 | GraphFunction& m); |
212 | |
213 | std::shared_ptr<SugaredDict> getSugaredNamedParameterDict( |
214 | const SourceRange& loc, |
215 | GraphFunction& m); |
216 | |
217 | void setAttr( |
218 | const SourceRange& loc, |
219 | GraphFunction& m, |
220 | const std::string& field, |
221 | Value* newValue) override; |
222 | |
223 | SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override; |
224 | |
225 | std::shared_ptr<SugaredValue> getitem( |
226 | const SourceRange& loc, |
227 | GraphFunction& m, |
228 | Value* idx, |
229 | TypePtr type_hint) override; |
230 | |
231 | private: |
232 | // Check that the type of all submodules is a subtype of ty. If the function |
233 | // returns false, more information about why it returns false (e.g. which |
234 | // submodule's type is not a subtype of ty) is printed it why_not if it is not |
235 | // null. |
236 | bool areAllSubmodulesSubtypeOf( |
237 | const TypePtr& ty, |
238 | std::ostream* why_not = nullptr) const; |
239 | |
240 | Value* self_; |
241 | std::shared_ptr<ConcreteModuleType> concreteType_; |
242 | }; |
243 | |
// Presumably returns true when `obj` is a Python NamedTuple class — confirm
// against the definition.
bool isNamedTupleClass(const py::object& obj);
// Presumably creates (or looks up) the TorchScript type for the NamedTuple
// class `obj` and returns it; `loc` is presumably used for error reporting.
// TODO confirm against the definition.
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
246 | |
// From the signature, presumably walks the submodule of `self` named `field`
// recursively, appending each module's name (qualified by `prefix`) to `keys`
// and the matching module value to `values`. Both vectors are out-parameters.
// NOTE(review): confirm traversal order and naming scheme in the .cpp.
void recurseThroughNestedModules(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<SugaredValuePtr>& keys,
    std::vector<SugaredValuePtr>& values,
    std::shared_ptr<ModuleValue>& self,
    const std::string& prefix,
    const std::string& field);
255 | |
256 | // Used to support named_modules() |
257 | struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue { |
258 | explicit SugaredDict( |
259 | std::shared_ptr<ModuleValue> self, |
260 | std::shared_ptr<SugaredTupleValue> keys, |
261 | std::shared_ptr<SugaredTupleValue> modules) |
262 | : self_(std::move(self)), |
263 | keys_(std::move(keys)), |
264 | modules_(std::move(modules)) {} |
265 | |
266 | std::string kind() const override { |
267 | return "ModuleDict" ; |
268 | } |
269 | |
270 | std::shared_ptr<SugaredTupleValue> getKeys() { |
271 | return keys_; |
272 | } |
273 | |
274 | std::shared_ptr<SugaredTupleValue> getModules() { |
275 | return modules_; |
276 | } |
277 | |
278 | std::shared_ptr<SugaredValue> attr( |
279 | const SourceRange& loc, |
280 | GraphFunction& m, |
281 | const std::string& field) override; |
282 | |
283 | SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override { |
284 | return keys_; |
285 | }; |
286 | |
287 | std::shared_ptr<ModuleValue> self_; |
288 | std::shared_ptr<SugaredTupleValue> keys_; |
289 | std::shared_ptr<SugaredTupleValue> modules_; |
290 | }; |
291 | |
292 | struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { |
293 | BooleanDispatchValue(py::dict dispatched_fn) |
294 | : dispatched_fn_(std::move(dispatched_fn)) {} |
295 | |
296 | std::string kind() const override { |
297 | return "boolean dispatch" ; |
298 | } |
299 | |
300 | std::shared_ptr<SugaredValue> call( |
301 | const SourceRange& loc, |
302 | GraphFunction& caller, |
303 | at::ArrayRef<NamedValue> args, |
304 | at::ArrayRef<NamedValue> kwargs, |
305 | size_t n_binders) override; |
306 | |
307 | private: |
308 | py::dict dispatched_fn_; |
309 | }; |
310 | |
311 | struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue { |
312 | PythonClassValue(ClassTypePtr type, py::object py_type) |
313 | : ClassValue(std::move(type)), py_type_(std::move(py_type)) {} |
314 | |
315 | std::string kind() const override { |
316 | return "Python type" ; |
317 | } |
318 | |
319 | std::shared_ptr<SugaredValue> attr( |
320 | const SourceRange& loc, |
321 | GraphFunction& m, |
322 | const std::string& field) override; |
323 | |
324 | bool hasAttr( |
325 | const SourceRange& loc, |
326 | GraphFunction& m, |
327 | const std::string& field) override; |
328 | |
329 | private: |
330 | py::object py_type_; |
331 | }; |
332 | |
333 | struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { |
334 | explicit PythonExceptionValue(const py::object& exception_class) |
335 | : ExceptionValue( |
336 | py::str(py::getattr(exception_class, "__name__" , py::str("" )))), |
337 | exception_class_qualified_name_( |
338 | py::str(py::module::import("torch._jit_internal" ) |
339 | .attr("_qualified_name" )( |
340 | exception_class, |
341 | /*mangle_name=*/false))) {} |
342 | |
343 | std::string kind() const override { |
344 | return "Python exception" ; |
345 | } |
346 | |
347 | std::shared_ptr<SugaredValue> call( |
348 | const SourceRange& loc, |
349 | GraphFunction& caller, |
350 | at::ArrayRef<NamedValue> args, |
351 | at::ArrayRef<NamedValue> kwargs, |
352 | size_t n_binders) override; |
353 | |
354 | private: |
355 | std::string exception_class_qualified_name_; |
356 | }; |
357 | |
358 | // Python Slice class. |
359 | struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { |
360 | explicit PythonSliceClass() = default; |
361 | |
362 | std::string kind() const override { |
363 | return "Python slice class" ; |
364 | } |
365 | |
366 | std::shared_ptr<SugaredValue> call( |
367 | const SourceRange& loc, |
368 | GraphFunction& caller, |
369 | at::ArrayRef<NamedValue> args, |
370 | at::ArrayRef<NamedValue> kwargs, |
371 | size_t n_binders) override; |
372 | }; |
373 | |
374 | } // namespace jit |
375 | } // namespace torch |
376 | |