1 | #pragma once |
2 | #include <c10/util/Optional.h> |
3 | #include <functional> |
4 | #include <memory> |
5 | #include <string> |
6 | #include <utility> |
7 | |
8 | #include <ATen/core/symbol.h> |
9 | #include <caffe2/serialize/versions.h> |
10 | #include <torch/csrc/jit/api/module.h> |
11 | #include <torch/csrc/jit/frontend/error_report.h> |
12 | #include <torch/csrc/jit/frontend/schema_matching.h> |
13 | #include <torch/csrc/jit/frontend/versioned_symbols.h> |
14 | #include <torch/csrc/jit/ir/ir.h> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | |
// Forward declaration so this alias does not depend on SugaredValue having
// been declared by some transitively included header.
struct SugaredValue;

// Shared-ownership handle used throughout the frontend for sugared values.
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
20 | |
21 | // The AST can contain nodes like `self`, `self.b` or `python_fn` that |
22 | // are not first-class values in the graph representation, but instead |
23 | // will be desugared based on how they are used in the AST. |
24 | |
25 | // SugaredValue is used to temporarily represent these values in a way |
26 | // that separates their behavior from the AST -> IR converter itself. |
27 | // This allows us to keep dependencies on python minimal. |
28 | |
29 | struct TORCH_API SugaredValue |
30 | : public std::enable_shared_from_this<SugaredValue> { |
31 | // what is this node? for error reporting (e.g. Module, python function) |
32 | virtual std::string kind() const = 0; |
33 | |
34 | // what can we do with this thing? |
35 | // use it as a value e.g. `this + 4` |
36 | virtual Value* asValue(const SourceRange& loc, GraphFunction& m) { |
37 | throw ErrorReport(loc) << kind() << " cannot be used as a value" ; |
38 | } |
39 | |
40 | // select an attribute on it, e.g. `this.field` |
41 | virtual std::shared_ptr<SugaredValue> attr( |
42 | const SourceRange& loc, |
43 | GraphFunction& m, |
44 | const std::string& field) { |
45 | throw ErrorReport(loc) << "attribute lookup is not defined on " << kind(); |
46 | } |
47 | |
48 | virtual bool hasAttr( |
49 | const SourceRange& loc, |
50 | GraphFunction& m, |
51 | const std::string& field) { |
52 | throw ErrorReport(loc) << "attribute lookup is not defined on " << kind(); |
53 | } |
54 | |
55 | // assign an attribute on it, e.g. `this.field = newValue` |
56 | virtual void setAttr( |
57 | const SourceRange& loc, |
58 | GraphFunction& m, |
59 | const std::string& field, |
60 | Value* newValue) { |
61 | throw ErrorReport(loc) << "attribute assignment is not defined on " |
62 | << kind(); |
63 | } |
64 | |
65 | // use it as a vector of values, e.g. a tuple of values as return value from |
66 | // a method invocation |
67 | virtual std::vector<std::shared_ptr<SugaredValue>> asTuple( |
68 | const SourceRange& loc, |
69 | GraphFunction& m, |
70 | const c10::optional<size_t>& size_hint = {}) { |
71 | throw ErrorReport(loc) << kind() << " cannot be used as a tuple" ; |
72 | } |
73 | |
74 | // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API |
75 | virtual SugaredValuePtr asTupleValue( |
76 | const SourceRange& loc, |
77 | GraphFunction& m) { |
78 | throw ErrorReport(loc) << kind() << " cannot be used as a tuplevalue" ; |
79 | } |
80 | |
81 | virtual std::vector<std::shared_ptr<SugaredValue>> asType( |
82 | const SourceRange& loc, |
83 | Method& m) { |
84 | throw ErrorReport(loc) << kind() << " cannot be used as a type" ; |
85 | } |
86 | |
87 | // call it like a function, e.g. `outputs = this(inputs)` |
88 | virtual std::shared_ptr<SugaredValue> call( |
89 | const SourceRange& loc, |
90 | GraphFunction& m, |
91 | // note: names for args will be 'argument 0', 'argument 1', etc.. |
92 | at::ArrayRef<NamedValue> args, |
93 | at::ArrayRef<NamedValue> kwargs, |
94 | size_t n_binders) { |
95 | // n_binders is always set to the number of variables an expression is |
96 | // syntactically bound to: |
97 | // a = foo() # 1 binder (note in this case the single binder might be a |
98 | // tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0 |
99 | // binders |
100 | // |
101 | // In subexpressions, like bar() in foo(bar()), n_binders is always set to |
102 | // 1. n_binders is used as a hint to subexpressions to determine how many |
103 | // values they should return when that number is ambiguous statically. In |
104 | // particular it is currently used to decide how many tensors a call to a |
105 | // python function will return. It is only a hint, functions do not have to |
106 | // check that n_binders match the number of things they are returning, the |
107 | // assignment logic will do that anyway. |
108 | |
109 | throw ErrorReport(loc) << "cannot call a " << kind(); |
110 | } |
111 | |
112 | // This function is called when to convert a SugaredValue to its iterator. |
113 | // For example, when iterating through a Dict we iterate over its keys |
114 | virtual std::shared_ptr<SugaredValue> iter( |
115 | const SourceRange& loc, |
116 | GraphFunction& m) { |
117 | throw ErrorReport(loc) << kind() << " cannot be used as an iterable" ; |
118 | } |
119 | |
120 | // If we are iterating over a Sugared Value and it returns a value from this |
121 | // function, then we emit an unrolled loop over the variable. This allows us |
122 | // to support containers of Heterogenous types, like Module Containers & |
123 | // Tuples |
124 | virtual c10::optional<int64_t> staticLen() { |
125 | return c10::nullopt; |
126 | } |
127 | |
128 | // When iterating over this SugaredValue, should we emit the for loop as an |
129 | // unrolled loop. |
130 | bool shouldEmitUnrolled() { |
131 | return staticLen() != c10::nullopt; |
132 | } |
133 | |
134 | // return length of this thing, if not then it can't be iterated. |
135 | // If it does not have a statically-determinable length, then it cannot |
136 | // be iterated over with a modulelist. If it does it must return a constant |
137 | // Value * |
138 | virtual Value* len(const SourceRange& loc, GraphFunction& m) { |
139 | throw ErrorReport(loc) << "'" << kind() << "'" |
140 | << " object is not iterable" ; |
141 | } |
142 | |
143 | // expression for ith elemement for iterable value |
144 | virtual std::shared_ptr<SugaredValue> getitem( |
145 | const SourceRange& loc, |
146 | GraphFunction& m, |
147 | Value* idx, |
148 | TypePtr type_hint = nullptr) { |
149 | throw ErrorReport(loc) << "'" << kind() << "'" |
150 | << " object is not subscriptable" ; |
151 | } |
152 | |
153 | virtual ~SugaredValue() = default; |
154 | }; |
155 | |
156 | // most things in the environment are just simple value types |
157 | // and not special python syntax sugar types |
158 | struct TORCH_API SimpleValue : public SugaredValue { |
159 | SimpleValue(Value* value) : value_(value) {} |
160 | std::string kind() const override { |
161 | std::stringstream ss; |
162 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
163 | ss << "value of type '" << value_->type()->annotation_str() << "'" ; |
164 | return ss.str(); |
165 | } |
166 | Value* asValue(const SourceRange& range, GraphFunction& m) override { |
167 | return value_; |
168 | } |
169 | std::vector<std::shared_ptr<SugaredValue>> asTuple( |
170 | const SourceRange& loc, |
171 | GraphFunction& m, |
172 | const c10::optional<size_t>& size_hint = {}) override; |
173 | std::shared_ptr<SugaredValue> attr( |
174 | const SourceRange& loc, |
175 | GraphFunction& m, |
176 | const std::string& field) override; |
177 | |
178 | bool hasAttr( |
179 | const SourceRange& loc, |
180 | GraphFunction& m, |
181 | const std::string& field) override; |
182 | |
183 | void setAttr( |
184 | const SourceRange& loc, |
185 | GraphFunction& m, |
186 | const std::string& field, |
187 | Value* newValue) override; |
188 | |
189 | std::shared_ptr<SugaredValue> call( |
190 | const SourceRange& loc, |
191 | GraphFunction& m, |
192 | // note: names for args will be 'argument 0', 'argument 1', etc.. |
193 | at::ArrayRef<NamedValue> args, |
194 | at::ArrayRef<NamedValue> kwargs, |
195 | size_t n_binders) override; |
196 | |
197 | std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m) |
198 | override; |
199 | |
200 | Value* getValue() const { |
201 | return value_; |
202 | } |
203 | |
204 | Value* len(const SourceRange& loc, GraphFunction& m) override; |
205 | SugaredValuePtr getitem( |
206 | const SourceRange& loc, |
207 | GraphFunction& m, |
208 | Value* idx, |
209 | TypePtr type_hint = nullptr) override; |
210 | |
211 | private: |
212 | Value* value_; |
213 | }; |
214 | |
215 | struct TORCH_API BuiltinFunction : public SugaredValue { |
216 | BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self) |
217 | : symbol(symbol), self(std::move(self)) {} |
218 | |
219 | // The symbol of the function (e.g. `aten::relu`). |
220 | Symbol symbol; |
221 | |
222 | // if this is method, then this is the self argument. |
223 | c10::optional<NamedValue> self; |
224 | std::string kind() const override { |
225 | return "builtin" ; |
226 | } |
227 | std::shared_ptr<SugaredValue> call( |
228 | const SourceRange& loc, |
229 | GraphFunction& m, |
230 | at::ArrayRef<NamedValue> args, |
231 | at::ArrayRef<NamedValue> kwargs, |
232 | size_t n_binders) override; |
233 | |
234 | // try to create this builtin but if it doesn't exist or the self argument |
235 | // cannot possibly match, then return nullptr. Use in situations where it is |
236 | // not clear if it is a valid builtin |
237 | static std::shared_ptr<BuiltinFunction> tryCreate( |
238 | Symbol symbol, |
239 | c10::optional<NamedValue> self); |
240 | }; |
241 | |
242 | struct TORCH_API SugaredTupleValue : public SugaredValue { |
243 | explicit SugaredTupleValue(std::vector<std::shared_ptr<SugaredValue>> tup) |
244 | : tup_(std::move(tup)){}; |
245 | |
246 | std::vector<std::shared_ptr<SugaredValue>> asTuple( |
247 | const SourceRange& loc, |
248 | GraphFunction& m, |
249 | const c10::optional<size_t>& size_hint = {}) override { |
250 | return tup_; |
251 | }; |
252 | |
253 | Value* asValue(const SourceRange& loc, GraphFunction& m) override { |
254 | std::vector<Value*> vec; |
255 | vec.reserve(tup_.size()); |
256 | for (const auto& sv : tup_) { |
257 | vec.push_back(sv->asValue(loc, m)); |
258 | } |
259 | Graph& g = *m.graph(); |
260 | return g.insertNode(g.createTuple(vec))->output(); |
261 | } |
262 | |
263 | std::string kind() const override { |
264 | return "Tuple" ; |
265 | } |
266 | |
267 | SugaredValuePtr getitem( |
268 | const SourceRange& loc, |
269 | GraphFunction& m, |
270 | Value* idx, |
271 | TypePtr type_hint = nullptr) override { |
272 | if (!(idx->type()->cast<IntType>() && toIValue(idx))) { |
273 | throw ErrorReport(loc) |
274 | << "Expected integer literal for index. " |
275 | << "ModuleList/Sequential indexing is only supported with integer literals. " |
276 | << "Enumeration is supported, e.g. 'for index, v in enumerate(self): ...'" ; |
277 | } |
278 | auto index = toIValue(idx)->toInt(); |
279 | int64_t adj_index = |
280 | (index < 0) ? index + static_cast<int64_t>(tup_.size()) : index; |
281 | if (!(adj_index >= 0 && adj_index < static_cast<int64_t>(tup_.size()))) { |
282 | throw ErrorReport(loc) |
283 | << "Index " << index << " out of range of length " << tup_.size(); |
284 | } |
285 | return tup_.at(adj_index); |
286 | } |
287 | |
288 | // This function is called when a SugaredValue is used to convert a |
289 | // SugaredValue to its iterator. For example, when iterating through a Dict we |
290 | // iterate over its keys |
291 | std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m) |
292 | override { |
293 | return shared_from_this(); |
294 | }; |
295 | |
296 | // Because this is used to contain SugaredValues of Heterogenous types, |
297 | // we define staticLen() so that when this is iterated over it is emitted |
298 | // as an unrolled loop. |
299 | c10::optional<int64_t> staticLen() override { |
300 | return static_cast<int64_t>(tup_.size()); |
301 | } |
302 | |
303 | std::vector<std::shared_ptr<SugaredValue>> tup_; |
304 | }; |
305 | |
306 | struct TORCH_API BuiltinModule : public SugaredValue { |
307 | BuiltinModule(std::string name, c10::optional<int64_t> version = at::nullopt) |
308 | : name(std::move(name)), version(version) {} |
309 | |
310 | std::string kind() const override { |
311 | return "builtin module" ; |
312 | } |
313 | std::shared_ptr<SugaredValue> attr( |
314 | const SourceRange& loc, |
315 | GraphFunction& m, |
316 | const std::string& field) override { |
317 | if (field == "autograd" ) { |
318 | // When refering torch.autograd, it is also considered to be a |
319 | // BuiltinModule and we will dispatch to the aten operators for the |
320 | // methods under its module. |
321 | return std::make_shared<BuiltinModule>("aten" , version); |
322 | } |
323 | |
324 | auto sym = Symbol::fromQualString(name + "::" + field); |
325 | return std::make_shared<BuiltinFunction>(sym, c10::nullopt); |
326 | } |
327 | |
328 | private: |
329 | std::string name; |
330 | // when we add operator versioning, emit this op as it exising at 'version' |
331 | // if not set, use the latest version |
332 | c10::optional<int64_t> version; |
333 | }; |
334 | |
335 | // Represents a class, analagous to `int` or `dict`. Instances of classes, |
336 | // like `1` or `{"foo": 5}`, are represented as SimpleValues |
337 | struct TORCH_API ClassValue : public SugaredValue { |
338 | explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {} |
339 | |
340 | // Call the type's constructor, as in: |
341 | // n = Foo(constructor_arg) |
342 | std::shared_ptr<SugaredValue> call( |
343 | const SourceRange& loc, |
344 | GraphFunction& m, |
345 | at::ArrayRef<NamedValue> args, |
346 | at::ArrayRef<NamedValue> kwargs, |
347 | size_t n_binders) override; |
348 | |
349 | std::shared_ptr<SugaredValue> attr( |
350 | const SourceRange& loc, |
351 | GraphFunction& m, |
352 | const std::string& field) override; |
353 | |
354 | std::string kind() const override { |
355 | return type_->str(); |
356 | } |
357 | |
358 | ClassTypePtr type_; |
359 | }; |
360 | |
361 | struct TORCH_API NamedTupleConstructor : public SugaredValue { |
362 | explicit NamedTupleConstructor(TupleTypePtr type) : type_(std::move(type)) {} |
363 | |
364 | std::shared_ptr<SugaredValue> call( |
365 | const SourceRange& loc, |
366 | GraphFunction& m, |
367 | at::ArrayRef<NamedValue> args, |
368 | at::ArrayRef<NamedValue> kwargs, |
369 | size_t n_binders) override; |
370 | |
371 | std::string kind() const override { |
372 | return type_->str(); |
373 | } |
374 | |
375 | TupleTypePtr type_; |
376 | }; |
377 | |
378 | struct FunctionValue : public SugaredValue { |
379 | FunctionValue(Function* callee) : callees_({callee}) {} |
380 | FunctionValue(const StrongFunctionPtr& p) |
381 | : callees_({p.function_}), cu_(p.cu_) {} |
382 | FunctionValue(const std::vector<StrongFunctionPtr>& callees) { |
383 | for (const StrongFunctionPtr& callee : callees) { |
384 | cu_ = cu_ ? cu_ : callee.cu_; |
385 | TORCH_INTERNAL_ASSERT(callee.cu_ == cu_); |
386 | callees_.push_back(callee.function_); |
387 | } |
388 | } |
389 | |
390 | std::string kind() const override { |
391 | return "function" ; |
392 | } |
393 | |
394 | std::shared_ptr<SugaredValue> call( |
395 | const SourceRange& loc, |
396 | GraphFunction& f, |
397 | at::ArrayRef<NamedValue> args, |
398 | at::ArrayRef<NamedValue> kwargs, |
399 | size_t n_binders) override { |
400 | std::vector<const FunctionSchema*> schemas; |
401 | for (Function* callee : callees_) { |
402 | try { |
403 | callee->ensure_defined(); |
404 | } catch (const RecursiveMethodCallError&) { |
405 | throw ErrorReport(loc) |
406 | << " function '" << callee->name() << "' is called recursively. " |
407 | << "Recursive calls are not supported" ; |
408 | } |
409 | schemas.push_back(&callee->getSchema()); |
410 | } |
411 | auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs); |
412 | Value* output = |
413 | f.graph()->insertFunctionCall(callees_[match.first], match.second); |
414 | output->node()->setSourceRange(loc); |
415 | return std::make_shared<SimpleValue>(output); |
416 | } |
417 | |
418 | const std::vector<Function*>& callees() { |
419 | return callees_; |
420 | } |
421 | |
422 | private: |
423 | std::vector<Function*> callees_; |
424 | // TODO holding this thing is creepy |
425 | std::shared_ptr<CompilationUnit> cu_; |
426 | }; |
427 | |
428 | struct TORCH_API ClosureValue : public SugaredValue { |
429 | ClosureValue(Value* value) : value_(value) { |
430 | TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure); |
431 | } |
432 | std::string kind() const override { |
433 | return "closure" ; |
434 | } |
435 | Value* asValue(const SourceRange& range, GraphFunction& m) override { |
436 | return value_; |
437 | } |
438 | Value* value_; |
439 | }; |
440 | |
441 | // defines how a method obtained from a module/class/interface behaves in script |
442 | struct MethodValue : public SugaredValue { |
443 | MethodValue(Value* self, std::vector<std::string> method_names) |
444 | : self_(self), method_names_(std::move(method_names)) {} |
445 | MethodValue(Value* self, std::string method_name) |
446 | : MethodValue(self, std::vector<std::string>({std::move(method_name)})) {} |
447 | |
448 | std::string kind() const override { |
449 | return "method" ; |
450 | } |
451 | |
452 | std::shared_ptr<SugaredValue> call( |
453 | const SourceRange& loc, |
454 | GraphFunction& f, |
455 | at::ArrayRef<NamedValue> args, |
456 | at::ArrayRef<NamedValue> kwargs, |
457 | size_t n_binders) override { |
458 | std::vector<NamedValue> argsWithSelf = {self_}; |
459 | argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); |
460 | std::vector<const FunctionSchema*> schemas; |
461 | for (const std::string& method_name : method_names_) { |
462 | if (auto class_type = self_->type()->cast<ClassType>()) { |
463 | Function& method = class_type->getMethod(method_name); |
464 | try { |
465 | method.ensure_defined(); |
466 | } catch (const RecursiveMethodCallError&) { |
467 | throw ErrorReport(loc) |
468 | << " method '" << method.name() << "' is called recursively. " |
469 | << "Recursive calls are not supported" ; |
470 | } |
471 | schemas.push_back(&method.getSchema()); |
472 | } else if (auto interface_type = self_->type()->cast<InterfaceType>()) { |
473 | schemas.push_back(interface_type->getMethod(method_name)); |
474 | } else { |
475 | TORCH_INTERNAL_ASSERT( |
476 | false, "method constructed that is not a class or interface" ); |
477 | } |
478 | } |
479 | auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs); |
480 | Value* output = |
481 | f.graph()->insertMethodCall(method_names_[match.first], match.second); |
482 | output->node()->setSourceRange(loc); |
483 | return std::make_shared<SimpleValue>(output); |
484 | } |
485 | |
486 | private: |
487 | Value* self_; |
488 | std::vector<std::string> method_names_; |
489 | }; |
490 | |
491 | struct TORCH_API PrintValue : public SugaredValue { |
492 | std::string kind() const override { |
493 | return "print" ; |
494 | } |
495 | std::shared_ptr<SugaredValue> call( |
496 | const SourceRange& loc, |
497 | GraphFunction& m, |
498 | at::ArrayRef<NamedValue> args, |
499 | at::ArrayRef<NamedValue> kwargs, |
500 | size_t n_binders) override; |
501 | }; |
502 | |
503 | // expressions like int(x) |
504 | // these are the same as call prim::Int or equivalent except it |
505 | // is a noop when the input is a subtype of 'type' |
506 | struct TORCH_API CastValue : public BuiltinFunction { |
507 | CastValue(TypePtr type, c10::Symbol method) |
508 | : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {} |
509 | std::shared_ptr<SugaredValue> call( |
510 | const SourceRange& loc, |
511 | GraphFunction& m, |
512 | at::ArrayRef<NamedValue> args, |
513 | at::ArrayRef<NamedValue> kwargs, |
514 | size_t n_binders) override { |
515 | if (args.size() == 1 && kwargs.empty()) { |
516 | auto len_op = std::make_shared<BuiltinFunction>(aten::len, at::nullopt); |
517 | auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, at::nullopt); |
518 | auto zero = m.graph()->insertConstant(0); |
519 | |
520 | auto v = args[0].value(*m.graph()); |
521 | if (v->type()->isSubtypeOf(*type_)) { |
522 | return std::make_shared<SimpleValue>(v); |
523 | } else if ( |
524 | *type_ == *BoolType::get() && |
525 | (v->type()->isSubtypeOf(*AnyListType::get()) || |
526 | v->type()->isSubtypeOf(*StringType::get()) || |
527 | v->type()->cast<DictType>())) { |
528 | auto len = len_op->call(loc, m, {v}, {}, 1); |
529 | return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1); |
530 | } |
531 | } |
532 | return BuiltinFunction::call(loc, m, args, kwargs, n_binders); |
533 | } |
534 | |
535 | private: |
536 | TypePtr type_; |
537 | }; |
538 | |
539 | struct TORCH_API TensorCastValue : public SugaredValue { |
540 | TensorCastValue(at::ScalarType type, NamedValue self) |
541 | : dtype_(type), self_(std::move(self)) {} |
542 | |
543 | std::string kind() const override { |
544 | return "Cast" ; |
545 | } |
546 | |
547 | std::shared_ptr<SugaredValue> call( |
548 | const SourceRange& loc, |
549 | GraphFunction& m, |
550 | at::ArrayRef<NamedValue> args, |
551 | at::ArrayRef<NamedValue> kwargs, |
552 | size_t n_binders) override { |
553 | TORCH_INTERNAL_ASSERT(args.empty() && kwargs.empty()); |
554 | Value* dtype_const = m.graph()->insertConstant(dtype_, loc); |
555 | std::vector<NamedValue> kwargs_{ |
556 | self_, NamedValue(loc, "dtype" , dtype_const)}; |
557 | Value* casted_val = m.graph()->insert( |
558 | /*opname=*/Symbol::fromQualString("aten::to" ), |
559 | /*args=*/args, |
560 | /*kwargs=*/kwargs_, |
561 | /*range=*/loc); |
562 | return std::make_shared<SimpleValue>(casted_val); |
563 | } |
564 | |
565 | at::ScalarType dtype_; |
566 | NamedValue self_; |
567 | }; |
568 | |
569 | // builtins operators and functions that call a method if it exists |
570 | // on a class type, like 'len(x)' and 'x + y' |
571 | struct TORCH_API MagicMethod : public SugaredValue { |
572 | MagicMethod(std::string desugared_name, SugaredValuePtr base) |
573 | : base_value_(std::move(base)), |
574 | desugared_name_(std::move(desugared_name)) {} |
575 | |
576 | std::string kind() const override { |
577 | return desugared_name_; |
578 | } |
579 | |
580 | std::shared_ptr<SugaredValue> call( |
581 | const SourceRange& loc, |
582 | GraphFunction& m, |
583 | at::ArrayRef<NamedValue> args, |
584 | at::ArrayRef<NamedValue> kwargs, |
585 | size_t n_binders) override; |
586 | |
587 | private: |
588 | SugaredValuePtr base_value_; |
589 | std::string desugared_name_; |
590 | }; |
591 | |
592 | // things that look like function applications, but |
593 | // perform non-standard evaluation are represented |
594 | // with SpecialFormValues, e.g. |
595 | // isinstance(x, int) |
596 | // fork(fn) |
597 | // annotate(int, 3) |
598 | // The implementation of each value is handled by a case inside emitApplyExpr |
599 | struct TORCH_API SpecialFormValue : public SugaredValue { |
600 | SpecialFormValue(Symbol form) : form_(form) {} |
601 | std::string kind() const override { |
602 | return form_.toUnqualString(); |
603 | } |
604 | Symbol form() const { |
605 | return form_; |
606 | } |
607 | static std::shared_ptr<SpecialFormValue> create(Symbol form) { |
608 | return std::make_shared<SpecialFormValue>(form); |
609 | } |
610 | |
611 | private: |
612 | Symbol form_; |
613 | }; |
614 | |
615 | struct TORCH_API LegacyTensorConstructor : public SpecialFormValue { |
616 | LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device) |
617 | : SpecialFormValue(form), device_(device), dtype_(dtype) {} |
618 | |
619 | static std::shared_ptr<LegacyTensorConstructor> create( |
620 | Symbol form, |
621 | at::ScalarType dtype, |
622 | at::Device device) { |
623 | return std::make_shared<LegacyTensorConstructor>(form, dtype, device); |
624 | } |
625 | at::ScalarType dtype() const { |
626 | return dtype_; |
627 | } |
628 | |
629 | private: |
630 | at::Device device_; |
631 | at::ScalarType dtype_; |
632 | }; |
633 | |
634 | // matched against for special handling of range expressions |
635 | struct TORCH_API RangeValue : SugaredValue { |
636 | RangeValue( |
637 | const SourceRange& loc, |
638 | GraphFunction& m, |
639 | std::vector<Value*> input, |
640 | c10::optional<int64_t> static_len = c10::nullopt); |
641 | |
642 | std::string kind() const override { |
643 | return "range" ; |
644 | } |
645 | Value* len(const SourceRange& loc, GraphFunction& m) override; |
646 | SugaredValuePtr getitem( |
647 | const SourceRange& loc, |
648 | GraphFunction& m, |
649 | Value* idx, |
650 | TypePtr type_hint = nullptr) override; |
651 | std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m) |
652 | override; |
653 | |
654 | // When Range is instantiated via enumerate(iterable_with_static_len), |
655 | // then it takes the static length of the iterable |
656 | c10::optional<int64_t> staticLen() override { |
657 | return static_len_; |
658 | } |
659 | |
660 | private: |
661 | Value* start_{}; |
662 | Value* end_{}; |
663 | Value* step_{}; |
664 | // a flag to determine if it's a simple range() call with only end_ from |
665 | // arguments If true, we will not insert length calculation and index |
666 | // derivation nodes to simplify the graph and enable more possible |
667 | // optimizations |
668 | bool has_only_end_{}; |
669 | c10::optional<int64_t> static_len_; |
670 | }; |
671 | |
672 | // Specialized Tree structure to matched against for special handling |
673 | // of builtin functions iterables expressions like zip(), enumerate(), etc. |
674 | // zip and enumerate can be modeled as a tree of SimpleValue/RangeValue: |
675 | // zip(x, y) -> (x, y) with tuple assignment to each loop target |
676 | // enumerate(x) -> (range(0, math.inf, 1), x) |
677 | // So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be: |
678 | // (a, (range(0, math.inf, 1), b), range(0, 100)) |
679 | // We use those base iterables to fill in the loop information like |
680 | // max_trip_count and set the value table for loop targets |
681 | // Iterables can contain lists of SugaredValues like ModuleLists. If it |
682 | // does, then we emit it unrolled and require that all values it contains |
683 | // have a statically-determinable length. |
684 | struct TORCH_API IterableTree : SugaredValue { |
685 | IterableTree() = default; |
686 | IterableTree( |
687 | const SourceRange& range, |
688 | GraphFunction& m, |
689 | at::ArrayRef<SugaredValuePtr> children) { |
690 | for (const auto& child : children) { |
691 | addChild(range, m, child); |
692 | } |
693 | } |
694 | std::string kind() const override { |
695 | return "iterabletree" ; |
696 | } |
697 | |
698 | std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m) |
699 | override { |
700 | return shared_from_this(); |
701 | } |
702 | |
703 | void addChild( |
704 | const SourceRange& range, |
705 | GraphFunction& m, |
706 | const SugaredValuePtr& iter_value); |
707 | |
708 | std::vector<SugaredValuePtr> get_children() { |
709 | return children_; |
710 | } |
711 | |
712 | // If this iterable contains a ModuleList or Tuple, then it will have a |
713 | // static length, and we will emit it as an unrolled for loop. |
714 | c10::optional<int64_t> staticLen() override { |
715 | return unroll_length_; |
716 | } |
717 | |
718 | // given a IterableTree node, get all the base iterables/leaves under the |
719 | // IterableTree node. This enables |
720 | // us to get all the basic SugaredValues that contains valid loop information |
721 | // with len() and getitem() |
722 | std::vector<SugaredValuePtr> get_base_iterables(); |
723 | |
724 | Value* len(const SourceRange& loc, GraphFunction& m) override; |
725 | SugaredValuePtr getitem( |
726 | const SourceRange& loc, |
727 | GraphFunction& m, |
728 | Value* idx, |
729 | TypePtr type_hint = nullptr) override; |
730 | |
731 | private: |
732 | c10::optional<int64_t> unroll_length_ = c10::nullopt; |
733 | std::vector<SugaredValuePtr> children_; |
734 | }; |
735 | |
736 | static inline std::vector<Value*> toValues( |
737 | Graph& g, |
738 | at::ArrayRef<NamedValue> nvs) { |
739 | return fmap(nvs, [&](const NamedValue& v) { return v.value(g); }); |
740 | } |
741 | |
742 | struct SimpleSelf : public Self { |
743 | explicit SimpleSelf(ClassTypePtr classType) |
744 | : Self(), classType_(std::move(classType)) {} |
745 | std::shared_ptr<SugaredValue> makeSugared(Value* v) const override { |
746 | v->setType(classType_); |
747 | return std::make_shared<SimpleValue>(v); |
748 | } |
749 | ClassTypePtr getClassType() const override { |
750 | return classType_; |
751 | } |
752 | |
753 | private: |
754 | ClassTypePtr classType_; |
755 | }; |
756 | |
757 | // This is not a SimpleValue so it can not pass through the code paths that |
758 | // expect a SimpleValue as a sugared value. |
759 | struct TORCH_API ExceptionMessageValue : public SugaredValue { |
760 | explicit ExceptionMessageValue( |
761 | Value* value, |
762 | Value* qualified_class_name = nullptr) |
763 | : value_(value), qualified_class_name_(qualified_class_name) {} |
764 | |
765 | std::string kind() const override { |
766 | return "exception message" ; |
767 | } |
768 | |
769 | Value* getValue() { |
770 | return value_; |
771 | } |
772 | |
773 | // qualified python class name |
774 | Value* getQualifiedClassName() { |
775 | return qualified_class_name_; |
776 | } |
777 | |
778 | private: |
779 | Value* value_; |
780 | Value* qualified_class_name_; |
781 | }; |
782 | |
783 | struct TORCH_API ExceptionValue : public SugaredValue { |
784 | explicit ExceptionValue(std::string message) : message_(std::move(message)) {} |
785 | |
786 | std::string kind() const override { |
787 | return "exception" ; |
788 | } |
789 | |
790 | std::shared_ptr<SugaredValue> call( |
791 | const SourceRange& loc, |
792 | GraphFunction& m, |
793 | at::ArrayRef<NamedValue> args, |
794 | at::ArrayRef<NamedValue> /*attributes*/, |
795 | size_t /*n_binders*/) override { |
796 | auto exception_message = insertConstant(*m.graph(), message_ + ": " , loc); |
797 | for (auto& input : args) { |
798 | auto input_str = input.value(*m.graph()); |
799 | if (!input_str->type()->isSubtypeOf(*StringType::get())) { |
800 | input_str = |
801 | emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {}); |
802 | } |
803 | exception_message = emitBuiltinCall( |
804 | loc, *m.graph(), aten::add, {exception_message, input_str}, {}); |
805 | } |
806 | return std::make_shared<ExceptionMessageValue>(exception_message); |
807 | } |
808 | |
809 | std::string message_; |
810 | }; |
811 | |
812 | struct TORCH_API SugaredEnumClass : public SugaredValue { |
813 | explicit SugaredEnumClass(EnumTypePtr enum_type) |
814 | : enum_type_(std::move(enum_type)) {} |
815 | |
816 | std::string kind() const override { |
817 | return "EnumClass" ; |
818 | } |
819 | |
820 | SugaredValuePtr attr( |
821 | const SourceRange& loc, |
822 | GraphFunction& m, |
823 | const std::string& field) override; |
824 | |
825 | SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override; |
826 | |
827 | private: |
828 | EnumTypePtr enum_type_; |
829 | }; |
830 | |
831 | struct TORCH_API SliceValue : public SugaredValue { |
832 | explicit SliceValue(Value* start, Value* stop, Value* step) |
833 | : start_(start), stop_(stop), step_(step) {} |
834 | |
835 | std::string kind() const override { |
836 | return "Python slice value" ; |
837 | } |
838 | |
839 | Value* start() { |
840 | return start_; |
841 | }; |
842 | Value* stop() { |
843 | return stop_; |
844 | }; |
845 | Value* step() { |
846 | return step_; |
847 | }; |
848 | |
849 | private: |
850 | Value* start_; |
851 | Value* stop_; |
852 | Value* step_; |
853 | }; |
854 | |
855 | } // namespace jit |
856 | } // namespace torch |
857 | |