1#pragma once
2#include <c10/util/Optional.h>
3#include <functional>
4#include <memory>
5#include <string>
6#include <utility>
7
8#include <ATen/core/symbol.h>
9#include <caffe2/serialize/versions.h>
10#include <torch/csrc/jit/api/module.h>
11#include <torch/csrc/jit/frontend/error_report.h>
12#include <torch/csrc/jit/frontend/schema_matching.h>
13#include <torch/csrc/jit/frontend/versioned_symbols.h>
14#include <torch/csrc/jit/ir/ir.h>
15
16namespace torch {
17namespace jit {
18
19using SugaredValuePtr = std::shared_ptr<SugaredValue>;
20
21// The AST can contain nodes like `self`, `self.b` or `python_fn` that
22// are not first-class values in the graph representation, but instead
23// will be desugared based on how they are used in the AST.
24
25// SugaredValue is used to temporarily represent these values in a way
26// that separates their behavior from the AST -> IR converter itself.
27// This allows us to keep dependencies on python minimal.
28
29struct TORCH_API SugaredValue
30 : public std::enable_shared_from_this<SugaredValue> {
31 // what is this node? for error reporting (e.g. Module, python function)
32 virtual std::string kind() const = 0;
33
34 // what can we do with this thing?
35 // use it as a value e.g. `this + 4`
36 virtual Value* asValue(const SourceRange& loc, GraphFunction& m) {
37 throw ErrorReport(loc) << kind() << " cannot be used as a value";
38 }
39
40 // select an attribute on it, e.g. `this.field`
41 virtual std::shared_ptr<SugaredValue> attr(
42 const SourceRange& loc,
43 GraphFunction& m,
44 const std::string& field) {
45 throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
46 }
47
48 virtual bool hasAttr(
49 const SourceRange& loc,
50 GraphFunction& m,
51 const std::string& field) {
52 throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
53 }
54
55 // assign an attribute on it, e.g. `this.field = newValue`
56 virtual void setAttr(
57 const SourceRange& loc,
58 GraphFunction& m,
59 const std::string& field,
60 Value* newValue) {
61 throw ErrorReport(loc) << "attribute assignment is not defined on "
62 << kind();
63 }
64
65 // use it as a vector of values, e.g. a tuple of values as return value from
66 // a method invocation
67 virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
68 const SourceRange& loc,
69 GraphFunction& m,
70 const c10::optional<size_t>& size_hint = {}) {
71 throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
72 }
73
74 // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
75 virtual SugaredValuePtr asTupleValue(
76 const SourceRange& loc,
77 GraphFunction& m) {
78 throw ErrorReport(loc) << kind() << " cannot be used as a tuplevalue";
79 }
80
81 virtual std::vector<std::shared_ptr<SugaredValue>> asType(
82 const SourceRange& loc,
83 Method& m) {
84 throw ErrorReport(loc) << kind() << " cannot be used as a type";
85 }
86
87 // call it like a function, e.g. `outputs = this(inputs)`
88 virtual std::shared_ptr<SugaredValue> call(
89 const SourceRange& loc,
90 GraphFunction& m,
91 // note: names for args will be 'argument 0', 'argument 1', etc..
92 at::ArrayRef<NamedValue> args,
93 at::ArrayRef<NamedValue> kwargs,
94 size_t n_binders) {
95 // n_binders is always set to the number of variables an expression is
96 // syntactically bound to:
97 // a = foo() # 1 binder (note in this case the single binder might be a
98 // tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0
99 // binders
100 //
101 // In subexpressions, like bar() in foo(bar()), n_binders is always set to
102 // 1. n_binders is used as a hint to subexpressions to determine how many
103 // values they should return when that number is ambiguous statically. In
104 // particular it is currently used to decide how many tensors a call to a
105 // python function will return. It is only a hint, functions do not have to
106 // check that n_binders match the number of things they are returning, the
107 // assignment logic will do that anyway.
108
109 throw ErrorReport(loc) << "cannot call a " << kind();
110 }
111
112 // This function is called when to convert a SugaredValue to its iterator.
113 // For example, when iterating through a Dict we iterate over its keys
114 virtual std::shared_ptr<SugaredValue> iter(
115 const SourceRange& loc,
116 GraphFunction& m) {
117 throw ErrorReport(loc) << kind() << " cannot be used as an iterable";
118 }
119
120 // If we are iterating over a Sugared Value and it returns a value from this
121 // function, then we emit an unrolled loop over the variable. This allows us
122 // to support containers of Heterogenous types, like Module Containers &
123 // Tuples
124 virtual c10::optional<int64_t> staticLen() {
125 return c10::nullopt;
126 }
127
128 // When iterating over this SugaredValue, should we emit the for loop as an
129 // unrolled loop.
130 bool shouldEmitUnrolled() {
131 return staticLen() != c10::nullopt;
132 }
133
134 // return length of this thing, if not then it can't be iterated.
135 // If it does not have a statically-determinable length, then it cannot
136 // be iterated over with a modulelist. If it does it must return a constant
137 // Value *
138 virtual Value* len(const SourceRange& loc, GraphFunction& m) {
139 throw ErrorReport(loc) << "'" << kind() << "'"
140 << " object is not iterable";
141 }
142
143 // expression for ith elemement for iterable value
144 virtual std::shared_ptr<SugaredValue> getitem(
145 const SourceRange& loc,
146 GraphFunction& m,
147 Value* idx,
148 TypePtr type_hint = nullptr) {
149 throw ErrorReport(loc) << "'" << kind() << "'"
150 << " object is not subscriptable";
151 }
152
153 virtual ~SugaredValue() = default;
154};
155
156// most things in the environment are just simple value types
157// and not special python syntax sugar types
158struct TORCH_API SimpleValue : public SugaredValue {
159 SimpleValue(Value* value) : value_(value) {}
160 std::string kind() const override {
161 std::stringstream ss;
162 // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
163 ss << "value of type '" << value_->type()->annotation_str() << "'";
164 return ss.str();
165 }
166 Value* asValue(const SourceRange& range, GraphFunction& m) override {
167 return value_;
168 }
169 std::vector<std::shared_ptr<SugaredValue>> asTuple(
170 const SourceRange& loc,
171 GraphFunction& m,
172 const c10::optional<size_t>& size_hint = {}) override;
173 std::shared_ptr<SugaredValue> attr(
174 const SourceRange& loc,
175 GraphFunction& m,
176 const std::string& field) override;
177
178 bool hasAttr(
179 const SourceRange& loc,
180 GraphFunction& m,
181 const std::string& field) override;
182
183 void setAttr(
184 const SourceRange& loc,
185 GraphFunction& m,
186 const std::string& field,
187 Value* newValue) override;
188
189 std::shared_ptr<SugaredValue> call(
190 const SourceRange& loc,
191 GraphFunction& m,
192 // note: names for args will be 'argument 0', 'argument 1', etc..
193 at::ArrayRef<NamedValue> args,
194 at::ArrayRef<NamedValue> kwargs,
195 size_t n_binders) override;
196
197 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
198 override;
199
200 Value* getValue() const {
201 return value_;
202 }
203
204 Value* len(const SourceRange& loc, GraphFunction& m) override;
205 SugaredValuePtr getitem(
206 const SourceRange& loc,
207 GraphFunction& m,
208 Value* idx,
209 TypePtr type_hint = nullptr) override;
210
211 private:
212 Value* value_;
213};
214
215struct TORCH_API BuiltinFunction : public SugaredValue {
216 BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self)
217 : symbol(symbol), self(std::move(self)) {}
218
219 // The symbol of the function (e.g. `aten::relu`).
220 Symbol symbol;
221
222 // if this is method, then this is the self argument.
223 c10::optional<NamedValue> self;
224 std::string kind() const override {
225 return "builtin";
226 }
227 std::shared_ptr<SugaredValue> call(
228 const SourceRange& loc,
229 GraphFunction& m,
230 at::ArrayRef<NamedValue> args,
231 at::ArrayRef<NamedValue> kwargs,
232 size_t n_binders) override;
233
234 // try to create this builtin but if it doesn't exist or the self argument
235 // cannot possibly match, then return nullptr. Use in situations where it is
236 // not clear if it is a valid builtin
237 static std::shared_ptr<BuiltinFunction> tryCreate(
238 Symbol symbol,
239 c10::optional<NamedValue> self);
240};
241
242struct TORCH_API SugaredTupleValue : public SugaredValue {
243 explicit SugaredTupleValue(std::vector<std::shared_ptr<SugaredValue>> tup)
244 : tup_(std::move(tup)){};
245
246 std::vector<std::shared_ptr<SugaredValue>> asTuple(
247 const SourceRange& loc,
248 GraphFunction& m,
249 const c10::optional<size_t>& size_hint = {}) override {
250 return tup_;
251 };
252
253 Value* asValue(const SourceRange& loc, GraphFunction& m) override {
254 std::vector<Value*> vec;
255 vec.reserve(tup_.size());
256 for (const auto& sv : tup_) {
257 vec.push_back(sv->asValue(loc, m));
258 }
259 Graph& g = *m.graph();
260 return g.insertNode(g.createTuple(vec))->output();
261 }
262
263 std::string kind() const override {
264 return "Tuple";
265 }
266
267 SugaredValuePtr getitem(
268 const SourceRange& loc,
269 GraphFunction& m,
270 Value* idx,
271 TypePtr type_hint = nullptr) override {
272 if (!(idx->type()->cast<IntType>() && toIValue(idx))) {
273 throw ErrorReport(loc)
274 << "Expected integer literal for index. "
275 << "ModuleList/Sequential indexing is only supported with integer literals. "
276 << "Enumeration is supported, e.g. 'for index, v in enumerate(self): ...'";
277 }
278 auto index = toIValue(idx)->toInt();
279 int64_t adj_index =
280 (index < 0) ? index + static_cast<int64_t>(tup_.size()) : index;
281 if (!(adj_index >= 0 && adj_index < static_cast<int64_t>(tup_.size()))) {
282 throw ErrorReport(loc)
283 << "Index " << index << " out of range of length " << tup_.size();
284 }
285 return tup_.at(adj_index);
286 }
287
288 // This function is called when a SugaredValue is used to convert a
289 // SugaredValue to its iterator. For example, when iterating through a Dict we
290 // iterate over its keys
291 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
292 override {
293 return shared_from_this();
294 };
295
296 // Because this is used to contain SugaredValues of Heterogenous types,
297 // we define staticLen() so that when this is iterated over it is emitted
298 // as an unrolled loop.
299 c10::optional<int64_t> staticLen() override {
300 return static_cast<int64_t>(tup_.size());
301 }
302
303 std::vector<std::shared_ptr<SugaredValue>> tup_;
304};
305
306struct TORCH_API BuiltinModule : public SugaredValue {
307 BuiltinModule(std::string name, c10::optional<int64_t> version = at::nullopt)
308 : name(std::move(name)), version(version) {}
309
310 std::string kind() const override {
311 return "builtin module";
312 }
313 std::shared_ptr<SugaredValue> attr(
314 const SourceRange& loc,
315 GraphFunction& m,
316 const std::string& field) override {
317 if (field == "autograd") {
318 // When refering torch.autograd, it is also considered to be a
319 // BuiltinModule and we will dispatch to the aten operators for the
320 // methods under its module.
321 return std::make_shared<BuiltinModule>("aten", version);
322 }
323
324 auto sym = Symbol::fromQualString(name + "::" + field);
325 return std::make_shared<BuiltinFunction>(sym, c10::nullopt);
326 }
327
328 private:
329 std::string name;
330 // when we add operator versioning, emit this op as it exising at 'version'
331 // if not set, use the latest version
332 c10::optional<int64_t> version;
333};
334
// Represents a class, analogous to `int` or `dict`. Instances of classes,
// like `1` or `{"foo": 5}`, are represented as SimpleValues.
337struct TORCH_API ClassValue : public SugaredValue {
338 explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {}
339
340 // Call the type's constructor, as in:
341 // n = Foo(constructor_arg)
342 std::shared_ptr<SugaredValue> call(
343 const SourceRange& loc,
344 GraphFunction& m,
345 at::ArrayRef<NamedValue> args,
346 at::ArrayRef<NamedValue> kwargs,
347 size_t n_binders) override;
348
349 std::shared_ptr<SugaredValue> attr(
350 const SourceRange& loc,
351 GraphFunction& m,
352 const std::string& field) override;
353
354 std::string kind() const override {
355 return type_->str();
356 }
357
358 ClassTypePtr type_;
359};
360
361struct TORCH_API NamedTupleConstructor : public SugaredValue {
362 explicit NamedTupleConstructor(TupleTypePtr type) : type_(std::move(type)) {}
363
364 std::shared_ptr<SugaredValue> call(
365 const SourceRange& loc,
366 GraphFunction& m,
367 at::ArrayRef<NamedValue> args,
368 at::ArrayRef<NamedValue> kwargs,
369 size_t n_binders) override;
370
371 std::string kind() const override {
372 return type_->str();
373 }
374
375 TupleTypePtr type_;
376};
377
378struct FunctionValue : public SugaredValue {
379 FunctionValue(Function* callee) : callees_({callee}) {}
380 FunctionValue(const StrongFunctionPtr& p)
381 : callees_({p.function_}), cu_(p.cu_) {}
382 FunctionValue(const std::vector<StrongFunctionPtr>& callees) {
383 for (const StrongFunctionPtr& callee : callees) {
384 cu_ = cu_ ? cu_ : callee.cu_;
385 TORCH_INTERNAL_ASSERT(callee.cu_ == cu_);
386 callees_.push_back(callee.function_);
387 }
388 }
389
390 std::string kind() const override {
391 return "function";
392 }
393
394 std::shared_ptr<SugaredValue> call(
395 const SourceRange& loc,
396 GraphFunction& f,
397 at::ArrayRef<NamedValue> args,
398 at::ArrayRef<NamedValue> kwargs,
399 size_t n_binders) override {
400 std::vector<const FunctionSchema*> schemas;
401 for (Function* callee : callees_) {
402 try {
403 callee->ensure_defined();
404 } catch (const RecursiveMethodCallError&) {
405 throw ErrorReport(loc)
406 << " function '" << callee->name() << "' is called recursively. "
407 << "Recursive calls are not supported";
408 }
409 schemas.push_back(&callee->getSchema());
410 }
411 auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs);
412 Value* output =
413 f.graph()->insertFunctionCall(callees_[match.first], match.second);
414 output->node()->setSourceRange(loc);
415 return std::make_shared<SimpleValue>(output);
416 }
417
418 const std::vector<Function*>& callees() {
419 return callees_;
420 }
421
422 private:
423 std::vector<Function*> callees_;
424 // TODO holding this thing is creepy
425 std::shared_ptr<CompilationUnit> cu_;
426};
427
428struct TORCH_API ClosureValue : public SugaredValue {
429 ClosureValue(Value* value) : value_(value) {
430 TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure);
431 }
432 std::string kind() const override {
433 return "closure";
434 }
435 Value* asValue(const SourceRange& range, GraphFunction& m) override {
436 return value_;
437 }
438 Value* value_;
439};
440
441// defines how a method obtained from a module/class/interface behaves in script
442struct MethodValue : public SugaredValue {
443 MethodValue(Value* self, std::vector<std::string> method_names)
444 : self_(self), method_names_(std::move(method_names)) {}
445 MethodValue(Value* self, std::string method_name)
446 : MethodValue(self, std::vector<std::string>({std::move(method_name)})) {}
447
448 std::string kind() const override {
449 return "method";
450 }
451
452 std::shared_ptr<SugaredValue> call(
453 const SourceRange& loc,
454 GraphFunction& f,
455 at::ArrayRef<NamedValue> args,
456 at::ArrayRef<NamedValue> kwargs,
457 size_t n_binders) override {
458 std::vector<NamedValue> argsWithSelf = {self_};
459 argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end());
460 std::vector<const FunctionSchema*> schemas;
461 for (const std::string& method_name : method_names_) {
462 if (auto class_type = self_->type()->cast<ClassType>()) {
463 Function& method = class_type->getMethod(method_name);
464 try {
465 method.ensure_defined();
466 } catch (const RecursiveMethodCallError&) {
467 throw ErrorReport(loc)
468 << " method '" << method.name() << "' is called recursively. "
469 << "Recursive calls are not supported";
470 }
471 schemas.push_back(&method.getSchema());
472 } else if (auto interface_type = self_->type()->cast<InterfaceType>()) {
473 schemas.push_back(interface_type->getMethod(method_name));
474 } else {
475 TORCH_INTERNAL_ASSERT(
476 false, "method constructed that is not a class or interface");
477 }
478 }
479 auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs);
480 Value* output =
481 f.graph()->insertMethodCall(method_names_[match.first], match.second);
482 output->node()->setSourceRange(loc);
483 return std::make_shared<SimpleValue>(output);
484 }
485
486 private:
487 Value* self_;
488 std::vector<std::string> method_names_;
489};
490
491struct TORCH_API PrintValue : public SugaredValue {
492 std::string kind() const override {
493 return "print";
494 }
495 std::shared_ptr<SugaredValue> call(
496 const SourceRange& loc,
497 GraphFunction& m,
498 at::ArrayRef<NamedValue> args,
499 at::ArrayRef<NamedValue> kwargs,
500 size_t n_binders) override;
501};
502
503// expressions like int(x)
504// these are the same as call prim::Int or equivalent except it
505// is a noop when the input is a subtype of 'type'
506struct TORCH_API CastValue : public BuiltinFunction {
507 CastValue(TypePtr type, c10::Symbol method)
508 : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
509 std::shared_ptr<SugaredValue> call(
510 const SourceRange& loc,
511 GraphFunction& m,
512 at::ArrayRef<NamedValue> args,
513 at::ArrayRef<NamedValue> kwargs,
514 size_t n_binders) override {
515 if (args.size() == 1 && kwargs.empty()) {
516 auto len_op = std::make_shared<BuiltinFunction>(aten::len, at::nullopt);
517 auto gt_op = std::make_shared<BuiltinFunction>(aten::gt, at::nullopt);
518 auto zero = m.graph()->insertConstant(0);
519
520 auto v = args[0].value(*m.graph());
521 if (v->type()->isSubtypeOf(*type_)) {
522 return std::make_shared<SimpleValue>(v);
523 } else if (
524 *type_ == *BoolType::get() &&
525 (v->type()->isSubtypeOf(*AnyListType::get()) ||
526 v->type()->isSubtypeOf(*StringType::get()) ||
527 v->type()->cast<DictType>())) {
528 auto len = len_op->call(loc, m, {v}, {}, 1);
529 return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1);
530 }
531 }
532 return BuiltinFunction::call(loc, m, args, kwargs, n_binders);
533 }
534
535 private:
536 TypePtr type_;
537};
538
539struct TORCH_API TensorCastValue : public SugaredValue {
540 TensorCastValue(at::ScalarType type, NamedValue self)
541 : dtype_(type), self_(std::move(self)) {}
542
543 std::string kind() const override {
544 return "Cast";
545 }
546
547 std::shared_ptr<SugaredValue> call(
548 const SourceRange& loc,
549 GraphFunction& m,
550 at::ArrayRef<NamedValue> args,
551 at::ArrayRef<NamedValue> kwargs,
552 size_t n_binders) override {
553 TORCH_INTERNAL_ASSERT(args.empty() && kwargs.empty());
554 Value* dtype_const = m.graph()->insertConstant(dtype_, loc);
555 std::vector<NamedValue> kwargs_{
556 self_, NamedValue(loc, "dtype", dtype_const)};
557 Value* casted_val = m.graph()->insert(
558 /*opname=*/Symbol::fromQualString("aten::to"),
559 /*args=*/args,
560 /*kwargs=*/kwargs_,
561 /*range=*/loc);
562 return std::make_shared<SimpleValue>(casted_val);
563 }
564
565 at::ScalarType dtype_;
566 NamedValue self_;
567};
568
569// builtins operators and functions that call a method if it exists
570// on a class type, like 'len(x)' and 'x + y'
571struct TORCH_API MagicMethod : public SugaredValue {
572 MagicMethod(std::string desugared_name, SugaredValuePtr base)
573 : base_value_(std::move(base)),
574 desugared_name_(std::move(desugared_name)) {}
575
576 std::string kind() const override {
577 return desugared_name_;
578 }
579
580 std::shared_ptr<SugaredValue> call(
581 const SourceRange& loc,
582 GraphFunction& m,
583 at::ArrayRef<NamedValue> args,
584 at::ArrayRef<NamedValue> kwargs,
585 size_t n_binders) override;
586
587 private:
588 SugaredValuePtr base_value_;
589 std::string desugared_name_;
590};
591
592// things that look like function applications, but
593// perform non-standard evaluation are represented
594// with SpecialFormValues, e.g.
595// isinstance(x, int)
596// fork(fn)
597// annotate(int, 3)
598// The implementation of each value is handled by a case inside emitApplyExpr
599struct TORCH_API SpecialFormValue : public SugaredValue {
600 SpecialFormValue(Symbol form) : form_(form) {}
601 std::string kind() const override {
602 return form_.toUnqualString();
603 }
604 Symbol form() const {
605 return form_;
606 }
607 static std::shared_ptr<SpecialFormValue> create(Symbol form) {
608 return std::make_shared<SpecialFormValue>(form);
609 }
610
611 private:
612 Symbol form_;
613};
614
615struct TORCH_API LegacyTensorConstructor : public SpecialFormValue {
616 LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device)
617 : SpecialFormValue(form), device_(device), dtype_(dtype) {}
618
619 static std::shared_ptr<LegacyTensorConstructor> create(
620 Symbol form,
621 at::ScalarType dtype,
622 at::Device device) {
623 return std::make_shared<LegacyTensorConstructor>(form, dtype, device);
624 }
625 at::ScalarType dtype() const {
626 return dtype_;
627 }
628
629 private:
630 at::Device device_;
631 at::ScalarType dtype_;
632};
633
634// matched against for special handling of range expressions
635struct TORCH_API RangeValue : SugaredValue {
636 RangeValue(
637 const SourceRange& loc,
638 GraphFunction& m,
639 std::vector<Value*> input,
640 c10::optional<int64_t> static_len = c10::nullopt);
641
642 std::string kind() const override {
643 return "range";
644 }
645 Value* len(const SourceRange& loc, GraphFunction& m) override;
646 SugaredValuePtr getitem(
647 const SourceRange& loc,
648 GraphFunction& m,
649 Value* idx,
650 TypePtr type_hint = nullptr) override;
651 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
652 override;
653
654 // When Range is instantiated via enumerate(iterable_with_static_len),
655 // then it takes the static length of the iterable
656 c10::optional<int64_t> staticLen() override {
657 return static_len_;
658 }
659
660 private:
661 Value* start_{};
662 Value* end_{};
663 Value* step_{};
664 // a flag to determine if it's a simple range() call with only end_ from
665 // arguments If true, we will not insert length calculation and index
666 // derivation nodes to simplify the graph and enable more possible
667 // optimizations
668 bool has_only_end_{};
669 c10::optional<int64_t> static_len_;
670};
671
// Specialized tree structure that is matched against for special handling
// of builtin iterable expressions like zip(), enumerate(), etc.
674// zip and enumerate can be modeled as a tree of SimpleValue/RangeValue:
675// zip(x, y) -> (x, y) with tuple assignment to each loop target
676// enumerate(x) -> (range(0, math.inf, 1), x)
677// So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be:
678// (a, (range(0, math.inf, 1), b), range(0, 100))
679// We use those base iterables to fill in the loop information like
680// max_trip_count and set the value table for loop targets
681// Iterables can contain lists of SugaredValues like ModuleLists. If it
682// does, then we emit it unrolled and require that all values it contains
683// have a statically-determinable length.
684struct TORCH_API IterableTree : SugaredValue {
685 IterableTree() = default;
686 IterableTree(
687 const SourceRange& range,
688 GraphFunction& m,
689 at::ArrayRef<SugaredValuePtr> children) {
690 for (const auto& child : children) {
691 addChild(range, m, child);
692 }
693 }
694 std::string kind() const override {
695 return "iterabletree";
696 }
697
698 std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
699 override {
700 return shared_from_this();
701 }
702
703 void addChild(
704 const SourceRange& range,
705 GraphFunction& m,
706 const SugaredValuePtr& iter_value);
707
708 std::vector<SugaredValuePtr> get_children() {
709 return children_;
710 }
711
712 // If this iterable contains a ModuleList or Tuple, then it will have a
713 // static length, and we will emit it as an unrolled for loop.
714 c10::optional<int64_t> staticLen() override {
715 return unroll_length_;
716 }
717
718 // given a IterableTree node, get all the base iterables/leaves under the
719 // IterableTree node. This enables
720 // us to get all the basic SugaredValues that contains valid loop information
721 // with len() and getitem()
722 std::vector<SugaredValuePtr> get_base_iterables();
723
724 Value* len(const SourceRange& loc, GraphFunction& m) override;
725 SugaredValuePtr getitem(
726 const SourceRange& loc,
727 GraphFunction& m,
728 Value* idx,
729 TypePtr type_hint = nullptr) override;
730
731 private:
732 c10::optional<int64_t> unroll_length_ = c10::nullopt;
733 std::vector<SugaredValuePtr> children_;
734};
735
736static inline std::vector<Value*> toValues(
737 Graph& g,
738 at::ArrayRef<NamedValue> nvs) {
739 return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
740}
741
742struct SimpleSelf : public Self {
743 explicit SimpleSelf(ClassTypePtr classType)
744 : Self(), classType_(std::move(classType)) {}
745 std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
746 v->setType(classType_);
747 return std::make_shared<SimpleValue>(v);
748 }
749 ClassTypePtr getClassType() const override {
750 return classType_;
751 }
752
753 private:
754 ClassTypePtr classType_;
755};
756
757// This is not a SimpleValue so it can not pass through the code paths that
758// expect a SimpleValue as a sugared value.
759struct TORCH_API ExceptionMessageValue : public SugaredValue {
760 explicit ExceptionMessageValue(
761 Value* value,
762 Value* qualified_class_name = nullptr)
763 : value_(value), qualified_class_name_(qualified_class_name) {}
764
765 std::string kind() const override {
766 return "exception message";
767 }
768
769 Value* getValue() {
770 return value_;
771 }
772
773 // qualified python class name
774 Value* getQualifiedClassName() {
775 return qualified_class_name_;
776 }
777
778 private:
779 Value* value_;
780 Value* qualified_class_name_;
781};
782
783struct TORCH_API ExceptionValue : public SugaredValue {
784 explicit ExceptionValue(std::string message) : message_(std::move(message)) {}
785
786 std::string kind() const override {
787 return "exception";
788 }
789
790 std::shared_ptr<SugaredValue> call(
791 const SourceRange& loc,
792 GraphFunction& m,
793 at::ArrayRef<NamedValue> args,
794 at::ArrayRef<NamedValue> /*attributes*/,
795 size_t /*n_binders*/) override {
796 auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc);
797 for (auto& input : args) {
798 auto input_str = input.value(*m.graph());
799 if (!input_str->type()->isSubtypeOf(*StringType::get())) {
800 input_str =
801 emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {});
802 }
803 exception_message = emitBuiltinCall(
804 loc, *m.graph(), aten::add, {exception_message, input_str}, {});
805 }
806 return std::make_shared<ExceptionMessageValue>(exception_message);
807 }
808
809 std::string message_;
810};
811
812struct TORCH_API SugaredEnumClass : public SugaredValue {
813 explicit SugaredEnumClass(EnumTypePtr enum_type)
814 : enum_type_(std::move(enum_type)) {}
815
816 std::string kind() const override {
817 return "EnumClass";
818 }
819
820 SugaredValuePtr attr(
821 const SourceRange& loc,
822 GraphFunction& m,
823 const std::string& field) override;
824
825 SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
826
827 private:
828 EnumTypePtr enum_type_;
829};
830
831struct TORCH_API SliceValue : public SugaredValue {
832 explicit SliceValue(Value* start, Value* stop, Value* step)
833 : start_(start), stop_(stop), step_(step) {}
834
835 std::string kind() const override {
836 return "Python slice value";
837 }
838
839 Value* start() {
840 return start_;
841 };
842 Value* stop() {
843 return stop_;
844 };
845 Value* step() {
846 return step_;
847 };
848
849 private:
850 Value* start_;
851 Value* stop_;
852 Value* step_;
853};
854
855} // namespace jit
856} // namespace torch
857