1#pragma once
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/frontend/concrete_module_type.h>
5#include <torch/csrc/jit/frontend/sugared_value.h>
6#include <torch/csrc/jit/python/pybind_utils.h>
7#include <memory>
8#include <sstream>
9#include <string>
10#include <utility>
11#include <vector>
12
13namespace torch {
14namespace jit {
15
16std::string typeString(py::handle h);
17
18inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
19 return std::make_shared<SimpleValue>(v);
20}
21
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
//
// @param obj          the Python object to convert.
// @param m            the GraphFunction being built (NOTE(review): presumably
//                     the function currently being compiled — confirm).
// @param loc          source location (presumably used for error reporting).
// @param is_constant  defaults to false. NOTE(review): presumably marks `obj`
//                     as a constant value — confirm in the implementation.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    GraphFunction& m,
    const SourceRange& loc,
    bool is_constant = false);

// NOTE(review): presumably returns the compiled function wrapped by `obj`,
// or an empty optional when `obj` is not one — confirm.
c10::optional<StrongFunctionPtr> as_function(const py::object& obj);
32
33struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
34 PythonValue(
35 py::object the_self,
36 c10::optional<py::object> rcb = c10::nullopt,
37 Value* module_self = nullptr)
38 : self(std::move(the_self)),
39 rcb(std::move(rcb)),
40 moduleSelf_(module_self) {}
41
42 FunctionSchema getSchema(
43 const size_t n_args,
44 const size_t n_binders,
45 const SourceRange& loc);
46
47 // call it like a function, e.g. `outputs = this(inputs)`
48 std::shared_ptr<SugaredValue> call(
49 const SourceRange& loc,
50 GraphFunction& m,
51 at::ArrayRef<NamedValue> args,
52 at::ArrayRef<NamedValue> kwargs,
53 size_t n_binders) override;
54
55 std::string kind() const override;
56
57 std::vector<std::shared_ptr<SugaredValue>> asTuple(
58 const SourceRange& loc,
59 GraphFunction& m,
60 const c10::optional<size_t>& size_hint = {}) override;
61
62 std::shared_ptr<SugaredValue> attr(
63 const SourceRange& loc,
64 GraphFunction& m,
65 const std::string& field) override;
66
67 Value* asValue(const SourceRange& loc, GraphFunction& m) override {
68 throw ErrorReport(loc)
69 << kind() << " cannot be used as a value. "
70 << "Perhaps it is a closed over global variable? If so, please "
71 << "consider passing it in as an argument or use a local varible "
72 << "instead.";
73 }
74
75 protected:
76 py::object getattr(const SourceRange& loc, const std::string& name);
77
78 void checkForAddToConstantsError(std::stringstream& ss);
79
80 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
81 py::object self;
82 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
83 c10::optional<py::object> rcb;
84 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
85 Value* moduleSelf_ = nullptr;
86};
87
88struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
89 explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}
90
91 std::shared_ptr<SugaredValue> attr(
92 const SourceRange& loc,
93 GraphFunction& m,
94 const std::string& field) override;
95};
96
97// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
98// torch.cuda.* are resolved using CUDAPythonModuleValue.
99struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
100 explicit CUDAPythonModuleValue(py::object mod)
101 : PythonValue(std::move(mod)) {}
102
103 std::shared_ptr<SugaredValue> attr(
104 const SourceRange& loc,
105 GraphFunction& m,
106 const std::string& field) override;
107};
108
109// Represents all the parameters of a module as a List[Tensor]
110struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
111 ConstantParameterList(Value* the_list) : the_list_(the_list) {}
112 std::string kind() const override {
113 return "constant parameter list";
114 }
115 std::shared_ptr<SugaredValue> call(
116 const SourceRange& loc,
117 GraphFunction& caller,
118 at::ArrayRef<NamedValue> args,
119 at::ArrayRef<NamedValue> kwargs,
120 size_t n_binders) override {
121 return toSimple(the_list_);
122 }
123
124 private:
125 Value* the_list_;
126};
127
128struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
129 explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name)
130 : iterable_(std::move(iterable)), name_(std::move(name)){};
131
132 std::string kind() const override {
133 return name_;
134 }
135
136 std::shared_ptr<SugaredValue> call(
137 const SourceRange& loc,
138 GraphFunction& f,
139 at::ArrayRef<NamedValue> args,
140 at::ArrayRef<NamedValue> kwargs,
141 size_t n_binders) override {
142 if (!args.empty() || !kwargs.empty()) {
143 throw ErrorReport(loc)
144 << name_ << " method does not accept any arguments";
145 }
146 return iterable_;
147 }
148
149 SugaredValuePtr iterable_;
150 const std::string name_;
151};
152
// Forward declaration: ModuleValue's getSugared*() helpers below return
// SugaredDict, which is defined later in this header.
struct SugaredDict;

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
// {functions, modules, constants} so this SugaredValue is defined here
// anticipating we will eventually need to replace Module with a py::object
// holding the actual nn.Module class.
// NOTE(review): the "for now" / "in the future" remarks above may be stale —
// confirm against the current implementation.

struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
  // @param self          IR value stored in `self_` (NOTE(review): presumably
  //                      the module instance — confirm).
  // @param concreteType  the module's ConcreteModuleType, stored in
  //                      `concreteType_`.
  ModuleValue(Value* self, std::shared_ptr<ConcreteModuleType> concreteType)
      : self_(self), concreteType_(std::move(concreteType)) {}

  std::string kind() const override {
    return "module";
  }

  // Implemented out of line. NOTE(review): presumably returns `self_` —
  // confirm.
  Value* asValue(const SourceRange& loc, GraphFunction& m) override;

  // Implemented out of line; converts the module to a tuple-like sugared
  // value.
  SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
      override;

  // select an attribute on it, e.g. `this.field`
  // NOTE(review): unlike attr(), presumably returns nullptr instead of
  // throwing when `field` cannot be resolved — confirm.
  std::shared_ptr<SugaredValue> tryGetAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field);

  // select an attribute on it, e.g. `this.field`
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // Reports whether the attribute `this.field` exists (it does not select
  // it). NOTE(review): presumably backs `hasattr(self, "field")` — confirm.
  bool hasAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // call module.forward with pre_hooks and hooks
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      GraphFunction& caller,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

  // The getSugared*() helpers below each build a SugaredDict view of this
  // module. NOTE(review): judging by their names, they presumably cover
  // submodules, named buffers and named parameters — confirm.
  // getSugaredNamedParameterList also returns a SugaredDict, despite its
  // name.
  std::shared_ptr<SugaredDict> getSugaredDict(
      const SourceRange& loc,
      GraphFunction& m);

  std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
      const SourceRange& loc,
      GraphFunction& m);

  std::shared_ptr<SugaredDict> getSugaredNamedParameterList(
      const SourceRange& loc,
      GraphFunction& m);

  std::shared_ptr<SugaredDict> getSugaredNamedParameterDict(
      const SourceRange& loc,
      GraphFunction& m);

  // Assigns `this.field = newValue`; implemented out of line.
  void setAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field,
      Value* newValue) override;

  // Iteration over the module (`for x in this`); implemented out of line.
  SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;

  // Indexing into the module (`this[idx]`). `type_hint` is forwarded by the
  // caller. NOTE(review): presumably used to type the result — confirm.
  std::shared_ptr<SugaredValue> getitem(
      const SourceRange& loc,
      GraphFunction& m,
      Value* idx,
      TypePtr type_hint) override;

 private:
  // Check that the type of all submodules is a subtype of ty. If the function
  // returns false, more information about why it returns false (e.g. which
  // submodule's type is not a subtype of ty) is printed to why_not if it is
  // not null.
  bool areAllSubmodulesSubtypeOf(
      const TypePtr& ty,
      std::ostream* why_not = nullptr) const;

  Value* self_;
  std::shared_ptr<ConcreteModuleType> concreteType_;
};
243
// NOTE(review): presumably returns true when `obj` is a Python NamedTuple
// class — confirm against the implementation.
bool isNamedTupleClass(const py::object& obj);
// NOTE(review): presumably registers the NamedTuple class `obj` with the
// TorchScript type system and returns its type, reporting errors at `loc` —
// confirm.
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);

// `keys` and `values` are taken by non-const reference and act as output
// accumulators. NOTE(review): presumably this walks the submodules of `self`
// reachable through `field`, prefixing names with `prefix`, and appends the
// results to them — confirm.
void recurseThroughNestedModules(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<SugaredValuePtr>& keys,
    std::vector<SugaredValuePtr>& values,
    std::shared_ptr<ModuleValue>& self,
    const std::string& prefix,
    const std::string& field);
255
256// Used to support named_modules()
257struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
258 explicit SugaredDict(
259 std::shared_ptr<ModuleValue> self,
260 std::shared_ptr<SugaredTupleValue> keys,
261 std::shared_ptr<SugaredTupleValue> modules)
262 : self_(std::move(self)),
263 keys_(std::move(keys)),
264 modules_(std::move(modules)) {}
265
266 std::string kind() const override {
267 return "ModuleDict";
268 }
269
270 std::shared_ptr<SugaredTupleValue> getKeys() {
271 return keys_;
272 }
273
274 std::shared_ptr<SugaredTupleValue> getModules() {
275 return modules_;
276 }
277
278 std::shared_ptr<SugaredValue> attr(
279 const SourceRange& loc,
280 GraphFunction& m,
281 const std::string& field) override;
282
283 SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
284 return keys_;
285 };
286
287 std::shared_ptr<ModuleValue> self_;
288 std::shared_ptr<SugaredTupleValue> keys_;
289 std::shared_ptr<SugaredTupleValue> modules_;
290};
291
292struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
293 BooleanDispatchValue(py::dict dispatched_fn)
294 : dispatched_fn_(std::move(dispatched_fn)) {}
295
296 std::string kind() const override {
297 return "boolean dispatch";
298 }
299
300 std::shared_ptr<SugaredValue> call(
301 const SourceRange& loc,
302 GraphFunction& caller,
303 at::ArrayRef<NamedValue> args,
304 at::ArrayRef<NamedValue> kwargs,
305 size_t n_binders) override;
306
307 private:
308 py::dict dispatched_fn_;
309};
310
311struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
312 PythonClassValue(ClassTypePtr type, py::object py_type)
313 : ClassValue(std::move(type)), py_type_(std::move(py_type)) {}
314
315 std::string kind() const override {
316 return "Python type";
317 }
318
319 std::shared_ptr<SugaredValue> attr(
320 const SourceRange& loc,
321 GraphFunction& m,
322 const std::string& field) override;
323
324 bool hasAttr(
325 const SourceRange& loc,
326 GraphFunction& m,
327 const std::string& field) override;
328
329 private:
330 py::object py_type_;
331};
332
333struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
334 explicit PythonExceptionValue(const py::object& exception_class)
335 : ExceptionValue(
336 py::str(py::getattr(exception_class, "__name__", py::str("")))),
337 exception_class_qualified_name_(
338 py::str(py::module::import("torch._jit_internal")
339 .attr("_qualified_name")(
340 exception_class,
341 /*mangle_name=*/false))) {}
342
343 std::string kind() const override {
344 return "Python exception";
345 }
346
347 std::shared_ptr<SugaredValue> call(
348 const SourceRange& loc,
349 GraphFunction& caller,
350 at::ArrayRef<NamedValue> args,
351 at::ArrayRef<NamedValue> kwargs,
352 size_t n_binders) override;
353
354 private:
355 std::string exception_class_qualified_name_;
356};
357
358// Python Slice class.
359struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
360 explicit PythonSliceClass() = default;
361
362 std::string kind() const override {
363 return "Python slice class";
364 }
365
366 std::shared_ptr<SugaredValue> call(
367 const SourceRange& loc,
368 GraphFunction& caller,
369 at::ArrayRef<NamedValue> args,
370 at::ArrayRef<NamedValue> kwargs,
371 size_t n_binders) override;
372};
373
374} // namespace jit
375} // namespace torch
376