1#pragma once
2
3#include <ATen/core/custom_class.h>
4#include <ATen/core/jit_type_base.h>
5#include <ATen/core/TensorBody.h>
6#include <ATen/core/functional.h>
7#include <ATen/core/symbol.h>
8#include <ATen/core/type_factory.h>
9#include <ATen/core/qualified_name.h>
10#include <c10/util/TypeList.h>
11#include <c10/util/Optional.h>
12#include <c10/core/SymFloat.h>
13
#include <array>
#include <atomic>
#include <memory>
#include <ostream>
#include <sstream>
#include <type_traits>
#include <utility>
20
// Forward declaration only: torch::jit::Function is not defined in this
// header, it is merely named so pointers/references to it can be declared.
namespace torch {
namespace jit {
struct Function;
} // namespace jit
} // namespace torch
26
27namespace c10 {
28
// Forward declarations of types that are defined elsewhere.
template<class Key, class Value>
class Dict;
struct IValue;
struct FunctionSchema;
struct NamedType;
// Optional list of names (e.g. field names); nullopt when no names are given.
using OptNameList = c10::optional<std::vector<std::string>>;

// Defined out of line. NOTE(review): from the names, these presumably
// normalize the member list of a Union (e.g. flattening nested unions); the
// first overload reads `reference` and writes into `to_fill`, the second
// rewrites `to_flatten` in place. Confirm against the implementation.
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
38
39inline bool is_contiguous_strides(
40 const IntArrayRef sizes,
41 const IntArrayRef strides) {
42 int n_dim = static_cast<int>(sizes.size());
43 if (n_dim == 0) {
44 return true;
45 }
46
47 if (strides[n_dim - 1] != 1) {
48 return false;
49 }
50
51 for (int i = n_dim - 2; i >= 0; i--) {
52 if (strides[i] != strides[i + 1] * sizes[i + 1]) {
53 return false;
54 }
55 }
56 return true;
57}
58
59struct AnyType;
60using AnyTypePtr = SingletonTypePtr<AnyType>;
61// Any is the top of the type hierarchy, all other types are subtypes
62// T <: Any, forall T
63struct TORCH_API AnyType : public Type {
64 bool equals(const Type& rhs) const override {
65 return rhs.kind() == kind();
66 }
67 std::string str() const override {
68 return "Any";
69 }
70 static const TypeKind Kind = TypeKind::AnyType;
71 // global singleton
72 static AnyTypePtr get();
73
74 private:
75 AnyType() : Type(TypeKind::AnyType) {}
76};
77
78inline std::string toString(const Type& type) {
79 return type.str();
80}
81
82// Shim for compatibility with code that uses TypePtr.
83inline std::string toString(const TypePtr& typePtr) {
84 return toString(*typePtr);
85}
86
87inline bool operator!=(const Type& lhs, const Type& rhs) {
88 return !(lhs == rhs);
89}
90
91// common base for all types that have a single sub element
92// e.g. Future[T], Optional[T], List[T]
93template <TypeKind K, typename T>
94struct SingleElementType : public SharedType {
95 static const TypeKind Kind = K;
96
97 const TypePtr& getElementType() const {
98 return elem;
99 }
100
101 bool hasFreeVariables() const override {
102 return getElementType()->hasFreeVariables();
103 }
104
105 at::ArrayRef<TypePtr> containedTypes() const override {
106 return elem;
107 }
108
109 bool equals(const Type& rhs) const override {
110 if (auto rhs_ = rhs.cast<T>()) {
111 return *getElementType() == *rhs_->getElementType();
112 }
113 return false;
114 }
115
116 protected:
117 SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
118 if (!this->elem) {
119 throw std::runtime_error(c10::str(
120 "Can not create ", typeKindToString(Kind), " with None type"));
121 }
122 }
123
124 private:
125 TypePtr elem;
126};
127
struct UnionType;
using UnionTypePtr = std::shared_ptr<UnionType>;
// Union[T1, T2, ...]: a type whose members are stored in `types_`.
// Most behavior (subtyping, printing, equality, construction) is defined out
// of line; only trivial accessors are inline here.
struct TORCH_API UnionType : public SharedType {
  friend struct Type;

  static const TypeKind Kind = TypeKind::UnionType;

  bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;

  std::string str() const override;

  // Factory for Union[reference...]; defined out of line.
  static UnionTypePtr create(std::vector<TypePtr> reference);

  bool equals(const Type& rhs) const override;

  bool isUnionType() const override {
    return true;
  }

  at::ArrayRef<TypePtr> containedTypes() const override {
    return types_;
  }

  // For testing purposes only
  at::ArrayRef<TypePtr> getTypes() const {
    return types_;
  }

  // Rebuilds a union from new member types via create().
  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
    return create(std::move(contained_types));
  }

  // Defined out of line. NOTE(review): presumably true when `type` is accepted
  // by at least one union member — confirm against the implementation.
  bool canHoldType(const Type& type) const;

  // Returns the cached flag; it is presumably computed by the constructor
  // (not visible here).
  bool hasFreeVariables() const override {
    return has_free_variables_;
  }

  // Defined out of line. NOTE(review): presumably returns an equivalent
  // Optional type when this union can be expressed as one — confirm.
  c10::optional<TypePtr> toOptional() const;

  // Defined out of line. NOTE(review): presumably the union with the types in
  // `to_subtract` removed, or nullopt if nothing remains — confirm.
  c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;

 protected:
  // `kind` defaults to UnionType; presumably it exists so subclasses such as
  // OptionalType (which derives from UnionType) can supply their own kind.
  explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
  std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
  // Shared printer for str()/annotation_str(); `is_annotation_str` selects
  // the annotation form.
  std::string unionStr(
      TypePrinter printer = nullptr,
      bool is_annotation_str = false) const;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool has_free_variables_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<TypePtr> types_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  // NOTE(review): from the name, whether None is one of the permitted member
  // types — confirm where it is set.
  bool can_hold_none_;

};
184
185struct OptionalType;
186using OptionalTypePtr = std::shared_ptr<OptionalType>;
187// This type represents an optional type. There is one `Optional` for
188// each element type. `Optional[T]` can accept both `T` and
189// `None`(`c10::nullopt` in C++)
190// Subtype hierarchy for Optional:
191// - Optional[T] <: Optional[R] iff T <: R
192// - T <: Optional[R] if T <: R
193// - None <: Optional[T] for all T
194// - Optional[T] == Union[T, None] for all T
195struct TORCH_API OptionalType : public UnionType {
196 static OptionalTypePtr create(TypePtr contained);
197
198 static const TypeKind Kind = TypeKind::OptionalType;
199
200 friend struct Type;
201
202 bool equals(const Type& rhs) const override;
203
204 const TypePtr& getElementType() const {
205 return contained_;
206 }
207
208 at::ArrayRef<TypePtr> containedTypes() const override {
209 return contained_;
210 }
211
212 std::string str() const override {
213 std::stringstream ss;
214 ss << getElementType()->str() << "?";
215 return ss.str();
216 }
217
218 TypePtr createWithContained(
219 std::vector<TypePtr> contained_types) const override {
220 AT_ASSERT(contained_types.size() == 1);
221 return create(std::move(contained_types[0]));
222 }
223
224 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
225
226 bool isUnionType() const override {
227 return true;
228 }
229
230 // common cast Optional[Tensor] for undefined tensor type
231 static TypePtr ofTensor();
232 //
233 // global singleton
234 static TypePtr get(TypePtr inner);
235
236 private:
237 explicit OptionalType(TypePtr contained);
238
239 TypePtr contained_;
240
241 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
242 std::stringstream ss;
243 ss << "Optional[" << getElementType()->annotation_str(std::move(printer)) << "]";
244 return ss.str();
245 }
246};
247
248template <typename T>
249inline c10::optional<T> merge_primitive(
250 const c10::optional<T>& a,
251 const c10::optional<T>& b) {
252 if (a.has_value() && b.has_value() && a.value() == b.value()) {
253 return a;
254 }
255 return c10::optional<T>{};
256}
257
// If we see `a + b + c` and know that a, b, and c are the same size and have
// two dimensions (WxH), then we can generate a fused kernel for them. That
// fused kernel would likely have indexing math to handle both the W and H
// dimensions. However, if we knew the WxH dimensions were contiguous, we could
// pretend we only have a single dimension, simplifying the indexing logic.
// This can be performed even if the dimensions are transposed,
// as long as a, b, and c are transposed in the same way.
// We'd like to have the compiler be able to do this dimensionality reduction,
// but simply knowing sizes is not enough.
// We can extend profiling to also record stride information.
// Rather than recording specific strides,
// we can simply order the strides from smallest to largest with
// `stride_indices`. A contiguity marker on the smallest stride (c0) indicates
// the stride is precisely 1; otherwise a contiguity marker means that
// $stride_n = size_{n-1} \cdot stride_{n-1}$.
273struct TORCH_API Stride {
274 Stride() = default;
275 Stride(
276 const c10::optional<size_t>& stride_index,
277 c10::optional<bool> contiguous,
278 const c10::optional<size_t>& stride)
279 : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {}
280
281 bool operator==(const Stride& b) const {
282 return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ &&
283 stride_ == b.stride_;
284 }
285
286 bool isComplete() const {
287 return stride_index_ && contiguous_ && stride_;
288 }
289
290 c10::optional<size_t> stride_index_;
291 c10::optional<bool> contiguous_;
292 c10::optional<size_t> stride_;
293};
294
295template <>
296inline c10::optional<Stride> merge_primitive(
297 const c10::optional<Stride>& a,
298 const c10::optional<Stride>& b) {
299 c10::optional<Stride> left = a;
300 c10::optional<Stride> right = b;
301 if (!left.has_value()) {
302 left = {Stride()};
303 }
304 if (!right.has_value()) {
305 right = {Stride()};
306 }
307
308 auto merged_index =
309 merge_primitive(left->stride_index_, right->stride_index_);
310 auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_);
311 auto merged_stride = merge_primitive(left->stride_, right->stride_);
312 auto r = Stride(merged_index, merged_cont, merged_stride);
313 // normalize
314 if (!r.stride_index_.has_value() && !r.contiguous_.has_value() &&
315 !r.stride_.has_value()) {
316 return c10::optional<Stride>{};
317 }
318
319 return r;
320}
321
322struct TORCH_API ShapeSymbol {
323 // needed for use in `std::map`
324 ShapeSymbol() : value_(-1) {}
325 // is this symbol a fixed/static dimension
326 bool is_static() const {
327 return value_ >= 0;
328 };
329 bool operator==(const ShapeSymbol& b) const {
330 return value_ == b.value_;
331 }
332 bool operator<(const ShapeSymbol& b) const {
333 return value_ < b.value_;
334 }
335
336 static ShapeSymbol fromStaticSize(int64_t val) {
337 return ShapeSymbol(val);
338 }
339 int64_t static_size() const {
340 TORCH_CHECK(is_static());
341 return value_;
342 };
343
344 int64_t value() const {
345 return value_;
346 };
347
348 static ShapeSymbol newSymbol() {
349 return fromStaticSize(-static_cast<int64_t>(++num_symbols));
350 };
351 friend TORCH_API std::ostream& operator<<(
352 std::ostream& os,
353 const ShapeSymbol& s);
354
355 private:
356 ShapeSymbol(int64_t val) : value_(val) {}
357 int64_t value_;
358 static std::atomic<size_t> num_symbols;
359};
360
361inline ShapeSymbol merge_primitive(
362 const ShapeSymbol& a,
363 const ShapeSymbol& b) {
364 if (a.is_static() && b.is_static() && a == b) {
365 return a;
366 }
367 return ShapeSymbol::newSymbol();
368}
369
370// Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown
371// dims, partially known and fully known shapes are all supported.
372struct TORCH_API SymbolicShape {
373 // Unranked shape constructor.
374 SymbolicShape() : dims_(c10::nullopt) {}
375
376 // Known rank but unknown dimentions.
377 SymbolicShape(c10::optional<size_t> rank) : dims_(c10::nullopt) {
378 if(!rank) {
379 return;
380 }
381
382 std::vector<ShapeSymbol> shape_symbols;
383 shape_symbols.reserve(*rank);
384 for(size_t i = 0; i < *rank; ++i) {
385 shape_symbols.push_back(ShapeSymbol::newSymbol());
386 }
387 dims_ = shape_symbols;
388 }
389
390 // Mix of known and unknown ranks
391 SymbolicShape(const std::vector<c10::optional<int64_t>>& dims) {
392 std::vector<ShapeSymbol> shape_symbols;
393 shape_symbols.reserve(dims.size());
394 for(c10::optional<int64_t> dim: dims) {
395 if(!dim) {
396 shape_symbols.push_back(ShapeSymbol::newSymbol());
397 } else {
398 shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim));
399 }
400 }
401 dims_ = shape_symbols;
402 }
403
404 void dump() const;
405
406 SymbolicShape(std::vector<ShapeSymbol> dims) : dims_(std::move(dims)) {}
407
408 SymbolicShape(c10::IntArrayRef dims) {
409 std::vector<ShapeSymbol> shape_symbols;
410 shape_symbols.reserve(dims.size());
411 for(int64_t dim : dims) {
412 shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim));
413 }
414 dims_ = shape_symbols;
415 }
416
417 ShapeSymbol operator[](size_t i) const {
418 if (!dims_) {
419 throw std::runtime_error("Rank isn't fixed");
420 }
421 return (*dims_).at(i);
422 }
423
424 ShapeSymbol at(size_t i) const {
425 if (!dims_) {
426 throw std::runtime_error("Rank isn't fixed");
427 }
428 return (*dims_).at(i);
429 }
430
431 // Returns rank or nullopt in case of unranked shape.
432 c10::optional<size_t> rank() const {
433 if(!dims_) {
434 return c10::nullopt;
435 }
436 return dims_->size();
437 }
438
439 c10::optional<std::vector<ShapeSymbol>> sizes() const {
440 return dims_;
441 }
442
443 c10::optional<std::vector<bool>> symbolicDims() const {
444 if (!dims_) {
445 return c10::nullopt;
446 }
447 auto symbolic_dims = std::vector<bool>();
448 for (const ShapeSymbol& s : *dims_) {
449 symbolic_dims.push_back(!s.is_static());
450 }
451 return symbolic_dims;
452 }
453
454 // Checks whether the shape is fully defined/complete, ie. rank and sizes
455 // of every dimension are known.
456 bool isComplete() const {
457 if(!dims_) {
458 return false;
459 }
460 for(auto d : *dims_) {
461 if(!d.is_static()) {
462 return false;
463 }
464 }
465 return true;
466 }
467
468 // Create new SymbolicShape that is result of merging self and another
469 // SymbolicShape. Only dimensions that are static and equal will be
470 // preserved.
471 // If either of two shapes are of unknown rank or they have unmatching rank,
472 // result will be unranked.
473 SymbolicShape merge(const SymbolicShape& other) const;
474
475 friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
476 return lhs.dims_ == rhs.dims_;
477 }
478
479 friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
480 return !(lhs == rhs);
481 }
482
483 private:
484 c10::optional<std::vector<ShapeSymbol>> dims_;
485};
486
487namespace detail {
488inline bool isComplete(const Stride& s) {
489 return s.isComplete();
490}
491
492template<typename T>
493inline bool isComplete(const T& /*t*/) {
494 return true;
495}
496}
497
498template <typename T>
499struct VaryingShape {
500 using ListOfOptionalElements = std::vector<c10::optional<T>>;
501 VaryingShape(const std::vector<T>& vec)
502 : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
503
504 VaryingShape(c10::ArrayRef<T> vec)
505 : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
506
507 VaryingShape(c10::optional<size_t> size = c10::nullopt) : dims_(c10::nullopt) {
508 if (size) {
509 dims_ = ListOfOptionalElements(*size);
510 }
511 }
512
513 VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {}
514
515 VaryingShape(size_t size) : VaryingShape(c10::optional<size_t>(size)) {}
516
517 bool operator==(const VaryingShape& other) const {
518 return dims_ == other.dims_;
519 }
520
521 const c10::optional<T> &operator[](size_t i) const {
522 if (!dims_) {
523 throw std::runtime_error("Rank isn't fixed");
524 }
525 return (*dims_).at(i);
526 }
527
528 c10::optional<size_t> size() const {
529 if (!dims_) {
530 return c10::nullopt;
531 }
532 const auto& dims = dims_.value();
533 return dims.size();
534 }
535
536 const c10::optional<ListOfOptionalElements>& sizes() const {
537 return dims_;
538 }
539
540 TORCH_API VaryingShape merge(const VaryingShape& other) const;
541
542 c10::optional<std::vector<T>> concrete_sizes() const {
543 if (!dims_) {
544 return c10::nullopt;
545 }
546 std::vector<T> sizes;
547 for (auto d : *dims_) {
548 if (!d) {
549 return c10::nullopt;
550 }
551 sizes.push_back(d.value());
552 }
553 return sizes;
554 }
555
556 bool isComplete() const {
557 if (!dims_) {
558 return false;
559 }
560 for (auto d : *dims_) {
561 if (!d || !detail::isComplete(*d)) {
562 return false;
563 }
564 }
565 return true;
566 }
567
568 private:
569 c10::optional<ListOfOptionalElements> dims_;
570};
571
572struct TensorType;
573// TODO: investigate making this SingletonOrSharedTypePtr<TensorType>
574using TensorTypePtr = std::shared_ptr<TensorType>;
575// This type represents a single Tensor with a specific size
576struct TORCH_API TensorType : public SharedType {
577 static TensorTypePtr create(const at::Tensor& t);
578
579 // used by TensorType::create(size_t dim) which in turn used by
580 // shape_analysis.cpp
581 static TensorTypePtr create(
582 c10::optional<at::ScalarType> scalar_type,
583 c10::optional<Device> device,
584 const VaryingShape<int64_t>& sizes,
585 const VaryingShape<int64_t>& strides,
586 c10::optional<bool> requires_grad,
587 c10::optional<bool> undefined = false,
588 bool tensor_contiguity = false);
589
590 static TensorTypePtr create(
591 c10::optional<at::ScalarType> scalar_type,
592 c10::optional<Device> device,
593 const SymbolicShape& sizes,
594 const VaryingShape<Stride>& stride_,
595 c10::optional<bool> requires_grad,
596 c10::optional<bool> undefined = false);
597
598 static TensorTypePtr create(
599 c10::optional<at::ScalarType> scalar_type,
600 c10::optional<Device> device,
601 c10::optional<size_t> dim,
602 c10::optional<bool> requires_grad);
603
604 // overloaded create variadic template argument as it could not distinguish
605 // initializer list
606 static TensorTypePtr createContiguous(
607 at::ScalarType scalar_type,
608 at::Device device,
609 at::IntArrayRef sizes);
610
611 static TypePtr fromNumberType(const Type& typ);
612 static TypePtr fromBoolType();
613
614 c10::optional<size_t> dim() const {
615 return sizes().size();
616 }
617
618 VaryingShape<int64_t> sizes() const;
619
620 VaryingShape<int64_t> strides() const;
621
622 const VaryingShape<Stride>& stride_properties() const {
623 return strides_;
624 }
625
626 c10::optional<at::Device> device() const {
627 return device_;
628 }
629 c10::optional<at::ScalarType> scalarType() const {
630 return scalar_type_;
631 }
632 c10::optional<bool> requiresGrad() const {
633 return requires_grad_;
634 }
635 bool requires_grad() const override {
636 return requires_grad_ ? *requires_grad_ : true;
637 }
638
639 bool equals(const Type& rhs) const override;
640 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
641
642 std::string str() const override;
643
644 std::string repr_str() const override {
645 if (isInferredType()) {
646 return str() + " (inferred)";
647 } else {
648 return str();
649 }
650 }
651
652 c10::optional<size_t> numel() const {
653 size_t prod = 1;
654 const auto& shape = sizes();
655
656 for (size_t i = 0; i < shape.size(); i++) {
657 if (!shape[i]) {
658 return c10::optional<size_t>{};
659 }
660 prod *= shape[i].value();
661 }
662 return prod;
663 }
664
665 TensorTypePtr withRequiresGrad(c10::optional<bool> s) {
666 auto copy = clone();
667 copy->requires_grad_ = s;
668 return copy;
669 }
670
671 TensorTypePtr withScalarType(c10::optional<ScalarType> st) {
672 auto copy = clone();
673 copy->scalar_type_ = st;
674 return copy;
675 }
676
677 TensorTypePtr withDim(c10::optional<size_t> d) {
678 auto copy = clone();
679 // withDim is only used by the legacy executor
680 // that only cares about the rank, so create dummy symbols)) :
681 copy->sizes_ = SymbolicShape(d);
682 copy->strides_ = VaryingShape<Stride>(d);
683 return copy;
684 }
685
686 TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const {
687 auto cloned = clone();
688 cloned->strides_ = sstrides;
689 return cloned;
690 }
691
692 TensorTypePtr withSizesStrides(
693 at::IntArrayRef sizes,
694 at::IntArrayRef strides) const {
695 auto cloned = clone();
696 auto ssizes = SymbolicShape(sizes);
697 cloned->sizes_ = ssizes;
698 cloned->strides_ = computeStrideProps(sizes, strides);
699 return cloned;
700 }
701
702 TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const {
703 auto cloned = clone();
704 cloned->sizes_ = std::move(ssizes);
705 return cloned;
706 }
707
708 TensorTypePtr withSizes(at::IntArrayRef sizes) const {
709 return withSizesStrides(
710 sizes, contiguousStridesOf(sizes));
711 }
712
713 TensorTypePtr withDevice(const c10::optional<at::Device> device) const {
714 auto copy = clone();
715 copy->device_ = device;
716 return copy;
717 }
718
719 TensorTypePtr dimensionedOnly() const {
720 auto copy = clone();
721 copy->sizes_ = SymbolicShape(sizes().size());
722 copy->strides_ = VaryingShape<Stride>(sizes().size());
723 return copy;
724 }
725
726 TensorTypePtr contiguous() const {
727 auto cloned = clone();
728 TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
729 auto strides = computeStrideProps(
730 *sizes().concrete_sizes(),
731 contiguousStridesOf(*sizes().concrete_sizes()));
732 cloned->strides_ = strides;
733 return cloned;
734 }
735
736 const SymbolicShape& symbolic_sizes() const;
737
738 TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const;
739
740 bool matchTensor(const at::Tensor& t);
741
742 // is all information about the type specified except for autograd?
743 // This replaces the notion of a 'CompleteTensorType' that used to exist
744 // in the type-hierarchy. Excluding require_grad and undefined allows
745 // this to match the old behavior.
746 bool isComplete() const {
747 return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
748 }
749
750 bool isInferredType() const {
751 return is_inferred_;
752 }
753
754 static TensorTypePtr getInferred() {
755 static auto valueInferred = TensorType::create(
756 /*scalar_type=*/{},
757 /*device=*/{},
758 /*sizes=*/SymbolicShape(),
759 /*stride=*/VaryingShape<Stride>{},
760 /*requires_grad=*/{},
761 /*undefined=*/false);
762 valueInferred->is_inferred_ = true;
763 return valueInferred;
764 }
765
766 // this property is used by GuardElimination
767 // please see `checkInputs` for more details
768 bool isSummarized() const {
769 return !(isComplete() && requiresGrad().has_value() &&
770 undefined().has_value());
771 }
772
773 TensorTypePtr withUndefined() {
774 auto r = clone();
775 r->undefined_ = true;
776 return r;
777 }
778
779 TensorTypePtr withPossiblyUndefined() {
780 auto r = clone();
781 r->undefined_ = c10::nullopt;
782 return r;
783 }
784
785 c10::optional<bool> undefined() const { return undefined_; }
786
787 static const TensorTypePtr& get();
788
789 static const TypeKind Kind = TypeKind::TensorType;
790
791 static std::vector<int64_t> contiguousStridesOf(
792 at::IntArrayRef in_sizes,
793 at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
794 auto contiguous_fn = [](const at::IntArrayRef& sizes,
795 const std::vector<int64_t>& dim_order) {
796 std::vector<int64_t> strides(sizes.size());
797 if (sizes.empty()) // zero-dim case
798 return strides;
799
800 strides[dim_order[0]] = 1;
801 for (size_t i = 1; i < dim_order.size(); i++) {
802 auto cur_dim = dim_order[i];
803 auto pre_dim = dim_order[i - 1];
804 strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
805 }
806 return strides;
807 };
808
809 std::vector<int64_t> dim_order(in_sizes.size());
810 if (memory_format == MemoryFormat::ChannelsLast) {
811 dim_order = {1, 3, 2, 0};
812 } else if (memory_format == MemoryFormat::ChannelsLast3d) {
813 dim_order = {1, 4, 3, 2, 0};
814 } else {
815 auto ndims = in_sizes.size();
816 for (size_t i = 0; i < ndims; i++) {
817 dim_order[i] = ndims - i - 1; // Reverse
818 }
819 }
820 return contiguous_fn(in_sizes, dim_order);
821 }
822
823 private:
824 TensorType(
825 c10::optional<at::ScalarType> scalar_type,
826 c10::optional<Device> device,
827 const SymbolicShape& sizes,
828 const VaryingShape<Stride>& strides,
829 c10::optional<bool> requires_grad,
830 c10::optional<bool> undefined = false);
831
832 TensorTypePtr clone() const {
833 return TensorTypePtr(new TensorType(
834 scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_));
835 }
836
837 static VaryingShape<Stride> computeStrideProps(
838 at::IntArrayRef sizes,
839 at::IntArrayRef strides,
840 bool tensor_contiguity = false);
841
842 c10::optional<at::ScalarType> scalar_type_;
843 c10::optional<at::Device> device_;
844 SymbolicShape sizes_;
845 VaryingShape<Stride> strides_;
846 c10::optional<bool> requires_grad_;
847 // we exploit the fact certain tensors must be zero in the autograd to
848 // optimize gradient computation. Such zero tensors are currently implemented
849 // with `UndefinedTensorImpl.` They can be handled only by special operators
850 // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false.
851 // Normally, `undefined_` is set to false, unless a type was created
852 // with `withUndefined`
853 // This will also mean that `undefined` tensors will fail
854 // `subtypeOf(TensorType::get())` check
855 // undefined_ may become `c10::nullopt` if the tensor was observed to be both
856 // defined and undefined. However, no tensor type starts out with
857 // `undefined_` set to `c10::nullopt`
858 c10::optional<bool> undefined_;
859 // Represents whether or not this type was inferred.
860 bool is_inferred_ = false;
861};
862
863struct ListType;
864using ListTypePtr = std::shared_ptr<ListType>;
865struct TORCH_API ListType
866 : public SingleElementType<TypeKind::ListType, ListType> {
867 // It's not exactly a singleton, but there should be exactly one instance of
868 // List[T] for every T
869 friend struct Type;
870 template <typename... T>
871 static ListTypePtr create(T&&... all) {
872 return ListTypePtr(
873 new ListType(std::forward<T>(all)...)); // NOLINT(modernize-make-shared)
874 }
875
876 std::string str() const override {
877 std::stringstream ss;
878 ss << getElementType()->str() << "[]";
879 return ss.str();
880 }
881 TypePtr createWithContained(
882 std::vector<TypePtr> contained_types) const override {
883 return create(std::move(contained_types.at(0)));
884 }
885
886 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
887
888 // global singleton
889 // Given an inner type T and an identifier,
890 // this function wil return the global singleton type pointer
891 // the type List<T>.
892 // The extra "identifier" argument is needed beccause we have multiple container types
893 // that all re-use this function (List<T>, array<T, N>, etc.)
894 static TypePtr get(std::string identifier, TypePtr inner);
895
896 // common cast List[Tensor]
897 static ListTypePtr ofTensors();
898 static ListTypePtr ofOptionalTensors();
899 static ListTypePtr ofInts();
900 static ListTypePtr ofFloats();
901 static ListTypePtr ofComplexDoubles();
902 static ListTypePtr ofBools();
903 static ListTypePtr ofStrings();
904
905 private:
906 ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}
907
908 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
909 std::stringstream ss;
910 ss << "List[" << getElementType()->annotation_str(std::move(printer)) << "]";
911 return ss.str();
912 }
913};
914
915struct DictType;
916using DictTypePtr = std::shared_ptr<DictType>;
917struct TORCH_API DictType : public SharedType {
918 friend struct Type;
919 static const TypeKind Kind = TypeKind::DictType;
920
921 static DictTypePtr create(TypePtr key, TypePtr value) {
922 auto kind = key->kind();
923 if (auto dyn = key->castRaw<DynamicType>()) {
924 kind = dyn->dynamicKind();
925 }
926 switch (kind) {
927 case TypeKind::AnyType:
928 case TypeKind::IntType:
929 case TypeKind::BoolType:
930 case TypeKind::FloatType:
931 case TypeKind::ComplexType:
932 case TypeKind::StringType:
933 case TypeKind::TensorType:
934 case TypeKind::DeviceObjType:
935 return DictTypePtr(new DictType(std::move(key), std::move(value)));
936 default:
937 AT_ERROR(
938 "Cannot create dict for key type '",
939 key->str(),
940 "', only int, float, complex, Tensor, device and string keys are supported");
941 }
942 }
943
944 // aligned with the format in FunctionSchema
945 std::string str() const override {
946 std::stringstream ss;
947 ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
948 << ")";
949 return ss.str();
950 }
951
952 TypePtr createWithContained(
953 std::vector<TypePtr> contained_types) const override {
954 if (contained_types.size() != 2) {
955 throw std::runtime_error("Expected 2 contained types");
956 }
957 return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
958 }
959
960 const TypePtr& getKeyType() const {
961 return types.at(0);
962 }
963
964 const TypePtr& getValueType() const {
965 return types.at(1);
966 }
967
968 bool hasFreeVariables() const override {
969 return has_free_variables;
970 }
971
972 at::ArrayRef<TypePtr> containedTypes() const override {
973 return types;
974 }
975
976 bool equals(const Type& rhs) const override {
977 if (auto* dict_rhs = rhs.castRaw<DictType>()) {
978 return *getKeyType() == *(dict_rhs->getKeyType()) &&
979 *getValueType() == *(dict_rhs->getValueType());
980 }
981 return false;
982 }
983
984 // global singleton
985 // Given an inner type T and an identifier,
986 // this function wil return the global singleton type pointer
987 // the type List<T>.
988 // The extra "identifier" argument is needed beccause we have multiple container types
989 // that all re-use this function (Dict<K, V> and unordered_map<K, V>)
990 static TypePtr get(std::string identifier, TypePtr key, TypePtr val);
991
992 private:
993 DictType(TypePtr key, TypePtr value)
994 : SharedType(TypeKind::DictType),
995 has_free_variables(
996 key->hasFreeVariables() || value->hasFreeVariables()) {
997 types.reserve(2);
998 types.push_back(std::move(key));
999 types.push_back(std::move(value));
1000 }
1001
1002 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1003 std::stringstream ss;
1004 ss << "Dict[" << getKeyType()->annotation_str(printer) << ", ";
1005 ss << getValueType()->annotation_str(std::move(printer)) << "]";
1006 return ss.str();
1007 }
1008
1009 std::vector<TypePtr> types;
1010 bool has_free_variables;
1011};
1012
1013struct FutureType;
1014using FutureTypePtr = std::shared_ptr<FutureType>;
1015
1016struct TORCH_API FutureType
1017 : public SingleElementType<TypeKind::FutureType, FutureType> {
1018 friend struct Type;
1019 template <typename... T>
1020 static FutureTypePtr create(TypePtr elem) {
1021 return FutureTypePtr(
1022 new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
1023 }
1024
1025 std::string str() const override {
1026 std::stringstream ss;
1027 ss << "Future(" << getElementType()->str() << ")";
1028 return ss.str();
1029 }
1030 TypePtr createWithContained(
1031 std::vector<TypePtr> contained_types) const override {
1032 return create(std::move(contained_types.at(0)));
1033 }
1034
1035 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1036 if (Type::isSubtypeOfExt(rhs, why_not)) {
1037 return true;
1038 }
1039 if (auto rhs_ = rhs.castRaw<FutureType>()) {
1040 return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1041 }
1042 return false;
1043 }
1044
1045 private:
1046 FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1047
1048 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1049 std::stringstream ss;
1050 ss << "Future[" << getElementType()->annotation_str(std::move(printer)) << "]";
1051 return ss.str();
1052 }
1053};
1054
1055struct AwaitType;
1056using AwaitTypePtr = std::shared_ptr<AwaitType>;
1057
1058struct TORCH_API AwaitType
1059 : public SingleElementType<TypeKind::AwaitType, AwaitType> {
1060 friend struct Type;
1061 template <typename... T>
1062 static AwaitTypePtr create(TypePtr elem) {
1063 return AwaitTypePtr(
1064 new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared)
1065 }
1066
1067 std::string str() const override {
1068 std::stringstream ss;
1069 ss << "Await(" << getElementType()->str() << ")";
1070 return ss.str();
1071 }
1072 TypePtr createWithContained(
1073 std::vector<TypePtr> contained_types) const override {
1074 return create(std::move(contained_types.at(0)));
1075 }
1076
1077 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1078 if (Type::isSubtypeOfExt(rhs, why_not)) {
1079 return true;
1080 }
1081 if (auto rhs_ = rhs.castRaw<AwaitType>()) {
1082 return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1083 }
1084 return false;
1085 }
1086
1087 private:
1088 AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1089
1090 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1091 std::stringstream ss;
1092 ss << "Await[" << getElementType()->annotation_str(printer) << "]";
1093 return ss.str();
1094 }
1095};
1096
1097struct RRefType;
1098using RRefTypePtr = std::shared_ptr<RRefType>;
1099
1100struct TORCH_API RRefType
1101 : public SingleElementType<TypeKind::RRefType, RRefType> {
1102 friend struct Type;
1103 template <typename... T>
1104 static RRefTypePtr create(TypePtr elem) {
1105 return RRefTypePtr(
1106 new RRefType(std::move(elem))); // NOLINT(modernize-make-shared)
1107 }
1108
1109 std::string str() const override {
1110 std::stringstream ss;
1111 ss << "RRef(" << getElementType()->str() << ")";
1112 return ss.str();
1113 }
1114 TypePtr createWithContained(
1115 std::vector<TypePtr> contained_types) const override {
1116 return create(std::move(contained_types.at(0)));
1117 }
1118
1119 private:
1120 RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1121
1122 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1123 std::stringstream ss;
1124 ss << "RRef[" << getElementType()->annotation_str(std::move(printer)) << "]";
1125 return ss.str();
1126 }
1127};
1128
// Any should never appear in a named type like a class, namedtuple or
// interface. If it does, then dynamic type information will be lost in the
// Pickler, leading to hard-to-track-down bugs that will only occur
// after saving or loading a model. This is because we rely on the
// static types in named types to reconstruct type tags of loaded
// values. Lifting this restriction requires solving the serialization
// problem first.
//
// Enforces the rule above for a single attribute. The definition is not in
// this header.
//   base     - the type being checked (presumably the named type that owns
//              the attribute)
//   what     - short description of the attribute kind, presumably used in
//              the error message — TODO confirm in the definition
//   attrname - name of the attribute
//   attrtype - declared type of the attribute
// NOTE(review): presumably raises when `attrtype` contains Any; confirm in
// the .cpp.
TORCH_API void checkNoAny(
    const Type& base,
    const char* what,
    const std::string& attrname,
    const TypePtr& attrtype);
1141
1142struct TupleType;
1143using TupleTypePtr = std::shared_ptr<TupleType>;
1144using NameList = std::vector<std::string>;
1145// This type represents a Tuple
1146struct TORCH_API TupleType : public NamedType {
1147
1148 static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name,
1149 const std::vector<std::string>& field_names,
1150 const std::vector<TypePtr>& field_types,
1151 std::vector<IValue>& field_defaults);
1152
1153 static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name,
1154 const std::vector<std::string>& field_names,
1155 const std::vector<TypePtr>& field_types);
1156
1157 static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name,
1158 const std::vector<c10::string_view>& field_names,
1159 const std::vector<TypePtr>& field_types);
1160
1161 static TupleTypePtr create(
1162 std::vector<TypePtr> types) {
1163 return TupleTypePtr(new TupleType(
1164 std::move(types),
1165 c10::nullopt,
1166 nullptr)); // NOLINT(modernize-make-shared)
1167 }
1168 static TupleTypePtr create() {
1169 return create({});
1170 }
1171
1172 at::ArrayRef<TypePtr> elements() const {
1173 return elements_;
1174 }
1175
1176 bool equals(const Type& rhs) const override;
1177 bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
1178
1179 std::string str() const override;
1180 bool hasFreeVariables() const override {
1181 return has_free_variables_;
1182 }
1183 at::ArrayRef<TypePtr> containedTypes() const override {
1184 return elements_;
1185 }
1186 TypePtr createWithContained(
1187 std::vector<TypePtr> contained_types) const override {
1188 return std::shared_ptr<TupleType>(
1189 new TupleType(std::move(contained_types), name(), schema()));
1190 }
1191 const std::shared_ptr<FunctionSchema>& schema() const {
1192 return schema_;
1193 }
1194 c10::optional<std::vector<c10::string_view>> names() const;
1195
1196 static const TypeKind Kind = TypeKind::TupleType;
1197
1198 private:
1199 template <typename S>
1200 static TupleTypePtr createWithSpec(
1201 const c10::optional<c10::QualifiedName>& name,
1202 const std::vector<S>& field_names,
1203 const std::vector<TypePtr>& field_types,
1204 std::vector<IValue>& field_defaults);
1205
1206 TupleType(
1207 std::vector<TypePtr> elements_,
1208 c10::optional<c10::QualifiedName> name,
1209 std::shared_ptr<FunctionSchema> schema);
1210
1211 bool compare(
1212 const Type& rhs,
1213 std::function<bool(const Type&, const Type&)> fn) const {
1214 if (rhs.kind() != kind()) {
1215 return false;
1216 }
1217
1218 const auto& l_elements = elements();
1219 const auto& r_elements = rhs.castRaw<TupleType>()->elements();
1220 if (l_elements.size() != r_elements.size())
1221 return false;
1222 for (size_t i = 0; i < l_elements.size(); ++i) {
1223 if (!fn(*l_elements[i], *r_elements[i]))
1224 return false;
1225 }
1226 return true;
1227 }
1228
1229 std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
1230
1231 std::vector<TypePtr> elements_;
1232 bool has_free_variables_;
1233 std::shared_ptr<FunctionSchema> schema_;
1234};
1235
1236// the common supertype of all Enums, only used in operator registraion.
1237// EnumType <: AnyEnumType for all Enums
1238struct AnyEnumType;
1239using AnyEnumTypePtr = SingletonTypePtr<AnyEnumType>;
1240struct TORCH_API AnyEnumType final : public Type {
1241 bool equals(const Type& rhs) const override {
1242 return rhs.kind() == kind();
1243 }
1244 std::string str() const override {
1245 return "AnyEnumType";
1246 }
1247 static const TypeKind Kind = TypeKind::AnyEnumType;
1248 // global singleton
1249 static AnyEnumTypePtr get();
1250private:
1251 AnyEnumType()
1252 : Type(TypeKind::AnyEnumType) {}
1253};
1254
1255struct NumberType;
1256using NumberTypePtr = SingletonTypePtr<NumberType>;
1257// This type represents a Python number
1258// Subtype hierarchy for Number Types (NumberType as the base type):
1259// IntType <: NumberType
1260// FloatType <: NumberType
1261// ComplexType <:NumberType
1262//
1263// WARNING: if you add a new subtype of NumberType that is not
1264// represented by a global singleton, you need to change NumberTypePtr
1265// to a SingletonOrSharedTypePtr and deal with NumberType needing to
1266// both inherit and not inherit from SharedType!
1267struct TORCH_API NumberType : public Type {
1268 bool equals(const Type& rhs) const override;
1269
1270 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
1271
1272 std::string str() const override {
1273 return "Scalar"; // match what PythonArgParser says for clarity
1274 }
1275 static const TypeKind Kind = TypeKind::NumberType;
1276 // global singleton
1277 static NumberTypePtr get();
1278
1279 protected:
1280 NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
1281
1282 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1283 (void)printer; // Suppress unused variable warning
1284 return "number"; // technically not a valid python type, but
1285 // we need to use it when parsing back in annotations
1286 // for implicit conversions
1287 }
1288};
1289
1290struct FloatType;
1291using FloatTypePtr = SingletonTypePtr<FloatType>;
1292// This type represents a Python float number
1293struct TORCH_API FloatType : public NumberType {
1294 bool equals(const Type& rhs) const override {
1295 return rhs.kind() == kind();
1296 }
1297 std::string str() const override {
1298 return "float";
1299 }
1300 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1301 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1302 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1303 }
1304 static const TypeKind Kind = TypeKind::FloatType;
1305 // global singleton
1306 static FloatTypePtr get();
1307
1308 private:
1309 FloatType() : NumberType(TypeKind::FloatType) {}
1310 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1311 (void)printer; // Suppress unused variable warning
1312 return "float";
1313 }
1314};
1315
1316struct ComplexType;
1317using ComplexTypePtr = SingletonTypePtr<ComplexType>;
1318// This type represents a Python float number
1319struct TORCH_API ComplexType : public NumberType {
1320 bool equals(const Type& rhs) const override {
1321 return rhs.kind() == kind();
1322 }
1323 std::string str() const override {
1324 return "complex";
1325 }
1326 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1327 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1328 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1329 }
1330 static const TypeKind Kind = TypeKind::ComplexType;
1331 // global singleton
1332 static ComplexTypePtr get();
1333
1334 private:
1335 ComplexType() : NumberType(TypeKind::ComplexType) {}
1336 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1337 (void)printer; // Suppress unused variable warning
1338 return "complex";
1339 }
1340};
1341
1342// We need to introduce `SymIntType` to represent the `SymInt` type
1343// used in function schemas e.g. `aten::narrow_copy(... SymInt length)
1344// `SymInt` will be used to enable tracing arithmetic operations on
1345// dimension values. Please see [SymInt.h] for more information
1346struct SymIntType;
1347using SymIntTypePtr = SingletonTypePtr<SymIntType>;
1348struct TORCH_API SymIntType : public Type {
1349 bool equals(const Type& rhs) const override {
1350 return rhs.kind() == kind();
1351 }
1352 std::string str() const override {
1353 return "SymInt";
1354 }
1355 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1356 return "int";
1357 }
1358 static const TypeKind Kind = TypeKind::SymIntType;
1359 // global singleton
1360 static SymIntTypePtr get();
1361
1362 private:
1363 SymIntType() : Type(TypeKind::SymIntType) {}
1364};
1365
1366struct SymFloatType;
1367using SymFloatTypePtr = SingletonTypePtr<SymFloatType>;
1368struct TORCH_API SymFloatType : public Type {
1369 bool equals(const Type& rhs) const override {
1370 return rhs.kind() == kind();
1371 }
1372 std::string str() const override {
1373 return "SymFloat";
1374 }
1375 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1376 return "float";
1377 }
1378 static const TypeKind Kind = TypeKind::SymFloatType;
1379 // global singleton
1380 static SymFloatTypePtr get();
1381
1382 private:
1383 SymFloatType() : Type(TypeKind::SymFloatType) {}
1384};
1385
1386struct IntType;
1387using IntTypePtr = SingletonTypePtr<IntType>;
1388// This type represents a Python int number
1389struct TORCH_API IntType : public NumberType {
1390 bool equals(const Type& rhs) const override {
1391 return rhs.kind() == kind();
1392 }
1393 std::string str() const override {
1394 return "int";
1395 }
1396 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1397 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1398 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1399 }
1400 static const TypeKind Kind = TypeKind::IntType;
1401 // global singleton
1402 static IntTypePtr get();
1403
1404 private:
1405 IntType() : NumberType(TypeKind::IntType) {}
1406 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1407 (void)printer; // Suppress unused variable warning
1408 return "int";
1409 }
1410};
1411
1412struct BoolType;
1413using BoolTypePtr = SingletonTypePtr<BoolType>;
1414// This node represents a Python bool value
1415struct TORCH_API BoolType : public Type {
1416 bool equals(const Type& rhs) const override {
1417 return rhs.kind() == kind();
1418 }
1419 std::string str() const override {
1420 return "bool";
1421 }
1422 static const TypeKind Kind = TypeKind::BoolType;
1423 // global singleton
1424 static BoolTypePtr get();
1425
1426 private:
1427 BoolType() : Type(TypeKind::BoolType) {}
1428};
1429
1430struct StringType;
1431using StringTypePtr = SingletonTypePtr<StringType>;
1432// This type represents a Python string
1433struct TORCH_API StringType : public Type {
1434 bool equals(const Type& rhs) const override {
1435 return rhs.kind() == kind();
1436 }
1437 std::string str() const override {
1438 // we only use "str" (not "string") in both FunctionSchema and script
1439 return annotation_str();
1440 }
1441 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1442 (void)printer; // Suppress unused variable warning
1443 return "str";
1444 }
1445 static const TypeKind Kind = TypeKind::StringType;
1446 // global singleton
1447 static StringTypePtr get();
1448
1449 private:
1450 StringType() : Type(TypeKind::StringType) {}
1451};
1452
1453struct StorageType;
1454using StorageTypePtr = SingletonTypePtr<StorageType>;
1455struct TORCH_API StorageType : public Type {
1456 bool equals(const Type& rhs) const override {
1457 return rhs.kind() == kind();
1458 }
1459 std::string str() const override {
1460 return annotation_str();
1461 }
1462 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1463 (void)printer; // Suppress unused variable warning
1464 return "Storage";
1465 }
1466 static const TypeKind Kind = TypeKind::StorageType;
1467 // global singleton
1468 static StorageTypePtr get();
1469
1470 private:
1471 StorageType() : Type(TypeKind::StorageType) {}
1472};
1473
1474struct FunctionType;
1475using FunctionTypePtr = std::shared_ptr<FunctionType>;
1476struct TORCH_API FunctionType : public NamedType {
1477 static FunctionTypePtr create(torch::jit::Function* function) {
1478 return FunctionTypePtr(
1479 new FunctionType(function)); // NOLINT(modernize-make-shared)
1480 }
1481 bool equals(const Type& rhs) const override {
1482 if (auto func_type = rhs.cast<FunctionType>()) {
1483 return func_type->function_ == function_;
1484 }
1485
1486 return false;
1487 }
1488 std::string str() const override {
1489 return "Function";
1490 }
1491 torch::jit::Function* function() const {
1492 return function_;
1493 }
1494 static const TypeKind Kind = TypeKind::FunctionType;
1495
1496 private:
1497 FunctionType(torch::jit::Function* function);
1498 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1499 (void)printer; // Suppress unused variable warning
1500 const auto& n = name().value();
1501 return n.qualifiedName();
1502 }
1503 torch::jit::Function* function_;
1504};
1505
1506struct NoneType;
1507using NoneTypePtr = SingletonTypePtr<NoneType>;
1508// This type represents a Python None
1509struct TORCH_API NoneType : public Type {
1510 bool equals(const Type& rhs) const override {
1511 return rhs.kind() == kind();
1512 }
1513 std::string str() const override {
1514 return "NoneType";
1515 }
1516 bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override;
1517
1518 static const TypeKind Kind = TypeKind::NoneType;
1519 // global singleton
1520 static NoneTypePtr get();
1521
1522 private:
1523 NoneType() : Type(TypeKind::NoneType) {}
1524};
1525
1526struct GeneratorType;
1527using GeneratorTypePtr = SingletonTypePtr<GeneratorType>;
1528// This type represents a Generator
1529struct TORCH_API GeneratorType : public Type {
1530 bool equals(const Type& rhs) const override {
1531 return rhs.kind() == kind();
1532 }
1533 std::string str() const override {
1534 return "Generator";
1535 }
1536 static const TypeKind Kind = TypeKind::GeneratorType;
1537 // global singleton
1538 static GeneratorTypePtr get();
1539
1540 private:
1541 GeneratorType() : Type(TypeKind::GeneratorType) {}
1542};
1543
1544struct QuantizerType;
1545using QuantizerTypePtr = SingletonTypePtr<QuantizerType>;
1546// This type represents a Quantizer
1547struct TORCH_API QuantizerType : public Type {
1548 bool equals(const Type& rhs) const override {
1549 return rhs.kind() == kind();
1550 }
1551 std::string str() const override {
1552 return "Quantizer";
1553 }
1554 static const TypeKind Kind = TypeKind::QuantizerType;
1555 // global singleton
1556 static QuantizerTypePtr get();
1557
1558 private:
1559 QuantizerType() : Type(TypeKind::QuantizerType) {}
1560};
1561
1562struct QSchemeType;
1563using QSchemeTypePtr = SingletonTypePtr<QSchemeType>;
1564// This type represents a QScheme
1565struct TORCH_API QSchemeType : public Type {
1566 bool equals(const Type& rhs) const override {
1567 return rhs.kind() == kind();
1568 }
1569 std::string str() const override {
1570 return "QScheme";
1571 }
1572 static const TypeKind Kind = TypeKind::QSchemeType;
1573 // global singleton
1574 static QSchemeTypePtr get();
1575
1576 private:
1577 QSchemeType() : Type(TypeKind::QSchemeType) {}
1578};
1579
1580struct DeviceObjType;
1581using DeviceObjTypePtr = SingletonTypePtr<DeviceObjType>;
1582// This type represents a Device
1583struct TORCH_API DeviceObjType : public Type {
1584 bool equals(const Type& rhs) const override {
1585 return rhs.kind() == kind();
1586 }
1587 std::string str() const override {
1588 return "Device";
1589 }
1590 static const TypeKind Kind = TypeKind::DeviceObjType;
1591 // global singleton
1592 static DeviceObjTypePtr get();
1593
1594 private:
1595 DeviceObjType() : Type(TypeKind::DeviceObjType) {}
1596};
1597
1598struct StreamObjType;
1599using StreamObjTypePtr = SingletonTypePtr<StreamObjType>;
1600// This type represents a Generator
1601struct TORCH_API StreamObjType : public Type {
1602 bool equals(const Type& rhs) const override {
1603 return rhs.kind() == kind();
1604 }
1605 std::string str() const override {
1606 return "Stream";
1607 }
1608 static const TypeKind Kind = TypeKind::StreamObjType;
1609 // global singleton
1610 static StreamObjTypePtr get();
1611
1612private:
1613 StreamObjType() : Type(TypeKind::StreamObjType) {}
1614};
1615
1616struct VarType;
1617using VarTypePtr = std::shared_ptr<VarType>;
1618// This type represents a type variable, used in FunctionSchema
1619struct VarType : public SharedType {
1620 static VarTypePtr create(std::string name_) {
1621 return VarTypePtr(new VarType(std::move(name_)));
1622 }
1623 bool equals(const Type& rhs) const override {
1624 return rhs.kind() == kind();
1625 }
1626 std::string str() const override {
1627 return name();
1628 }
1629 const std::string& name() const {
1630 return name_;
1631 }
1632 bool hasFreeVariables() const override {
1633 return true;
1634 }
1635 static const TypeKind Kind = TypeKind::VarType;
1636
1637 private:
1638 VarType(std::string name_)
1639 : SharedType(TypeKind::VarType), name_(std::move(name_)) {}
1640 std::string name_;
1641};
1642
1643struct CapsuleType;
1644using CapsuleTypePtr = SingletonTypePtr<CapsuleType>;
1645// This type represents a Python Capsule.
1646// It does not appear in the IR and is only used during runtime
1647struct TORCH_API CapsuleType : public Type {
1648 bool equals(const Type& rhs) const override {
1649 return rhs.kind() == kind();
1650 }
1651 std::string str() const override {
1652 return "Capsule";
1653 }
1654 static const TypeKind Kind = TypeKind::CapsuleType;
1655 // global singleton
1656 static CapsuleTypePtr get();
1657private:
1658 CapsuleType()
1659 : Type(TypeKind::CapsuleType) {}
1660};
1661
1662struct PyObjectType;
1663using PyObjectTypePtr = SingletonTypePtr<PyObjectType>;
1664// This type represents a PyObject Type
1665struct TORCH_API PyObjectType : public Type {
1666 bool equals(const Type& rhs) const override {
1667 return rhs.kind() == kind();
1668 }
1669 std::string str() const override {
1670 return "PyObject";
1671 }
1672 static const TypeKind Kind = TypeKind::PyObjectType;
1673 // global singleton
1674 static PyObjectTypePtr get();
1675private:
1676 PyObjectType()
1677 : Type(TypeKind::PyObjectType) {}
1678};
1679
// Verbosity levels, presumably consulted when types are printed (see
// type_verbosity()) — confirm at the use sites. `Default` is an alias for
// `Full`. The underlying type is spelled out but is the same as the implicit
// `int` of a scoped enum.
enum class TypeVerbosity : int {
  None,
  Type,
  TypeAndStride,
  Full,
  Symbolic,
  Default = Full
};
1688
// Defined out of line; presumably returns the currently active verbosity
// setting — confirm in the definition.
TORCH_API TypeVerbosity type_verbosity();

// Stream-printing helpers for types and shape/stride metadata. All are
// declared here and defined out of line.
TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t);
template <typename T>
TORCH_API std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<T>& t);
TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
1699// what is the type, ignoring extra size/shape information?
1700// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
1701
1702// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
1703// subtypes as simply "Tensor"; we also create a new version of any
1704// container types in which internal Tensors have undergone the same
1705// operation. This is used for type comparisons between two Tensor types
1706// (`unshapedType` means that we don't falsely return `false` for e.g.
1707// Tensors of different dimensions). It's also used in the alias
1708// analysis pass.
1709// Be careful with calls because this can be very slow. If calling this
1710// on a graph, use `EraseShapeInformation` in shape_analysis.h
1711inline TypePtr unshapedType(const TypePtr& type) {
1712 if (type->isSubtypeOf(*TensorType::get())) {
1713 return TensorType::get();
1714 }
1715 at::ArrayRef<TypePtr> contained = type->containedTypes();
1716 if (contained.empty()) {
1717 return type;
1718 }
1719 return type->withContained(fmap(type->containedTypes(), unshapedType));
1720}
1721
1722inline TypePtr TensorType::fromNumberType(const Type& typ) {
1723 if (typ.isSubtypeOf(*IntType::get())) {
1724 return TensorType::createContiguous(at::kLong, at::kCPU, {});
1725 } else if (typ.isSubtypeOf(*FloatType::get())) {
1726 return TensorType::createContiguous(at::kDouble, at::kCPU, {});
1727 } else if (typ.isSubtypeOf(*BoolType::get())) {
1728 return TensorType::createContiguous(at::kBool, at::kCPU, {});
1729 } else if (typ.kind() == NumberType::Kind) {
1730 return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
1731 }
1732 TORCH_CHECK(false, "Unknown number type: ", typ.str());
1733}
1734inline TypePtr TensorType::fromBoolType() {
1735 return TensorType::createContiguous(at::kBool, at::kCPU, {});
1736}
1737
1738inline c10::optional<c10::ScalarType> tryScalarTypeFromJitType(const Type& type) {
1739 if (type == *FloatType::get()) {
1740 return at::typeMetaToScalarType(c10::get_default_dtype());
1741 } else if (type == *IntType::get()) {
1742 return at::ScalarType::Long;
1743 } else if (type == *BoolType::get()) {
1744 return at::ScalarType::Bool;
1745 }
1746 return c10::nullopt;
1747}
1748
1749inline at::ScalarType scalarTypeFromJitType(const Type& type) {
1750 auto result = tryScalarTypeFromJitType(type);
1751 TORCH_CHECK(
1752 result,
1753 "Add new condition, expected Float, Complex, Int, or Bool but got",
1754 type.str());
1755 return *result;
1756}
1757
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, then nullopt will be returned if
// `default_to_union` is false, and `Union[t1, t2]` will be returned
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
// then `t2` will be returned (and vice versa).
//
// Two different TensorTypes will return dynamic.
//
// Currently we choose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type.
TORCH_API c10::optional<TypePtr> unifyTypes(
    const TypePtr& t1,
    const TypePtr& t2,
    bool default_to_union = false,
    TypePtr type_hint = nullptr);

// List form of unifyTypes over all of `elements`. `default_to_union` and
// `type_hint` presumably carry the same meaning as for unifyTypes.
// NOTE(review): `why_not` is presumably where an explanation is written when
// no common type is found — confirm in the definition.
TORCH_API c10::optional<TypePtr> unifyTypeList(
    at::ArrayRef<TypePtr> elements,
    std::ostream& why_not,
    bool default_to_union = false,
    TypePtr type_hint = nullptr);
1784
namespace detail {
// Primary template mapping a C++ type T to its JIT type. Types without an
// explicit specialization below are looked up as registered custom classes.
// If that lookup throws a c10::Error, it is replaced by an error naming T
// that says it could not be converted.
template <typename T>
struct getTypePtr_ final {
  static decltype(auto) call() {
    return ([]() {
      try {
        return getCustomClassType<T>();
      } catch(const c10::Error&) {
        TORCH_CHECK(
          false,
          "Type ",
          c10::util::get_fully_qualified_type_name<T>(),
          " could not be converted to any of the known types."
        );
      }
    }());
  }
};

// Dispatch layer that threads a `fake` flag through the mapping. The flag
// matters for the SymInt/SymFloat specializations below (fake => plain
// IntType/FloatType) and is forwarded by the container specializations to
// their element types. Everything else falls through to getTypePtr_<T>.
template <typename T, bool fake>
struct getMaybeFakeTypePtr_ final {
  static decltype(auto) call() {
    return getTypePtr_<T>::call();
  }
};

// ---- Leaf types: each returns the corresponding global singleton. ----

// An IValue can hold anything, so it maps to Any.
template <>
struct getTypePtr_<at::IValue> final {
  static decltype(auto) call() {
    return AnyType::get();
  }
};

template <>
struct getTypePtr_<at::Tensor> final {
  static decltype(auto) call() {
    return TensorType::get();
  }
};
template <>
struct getTypePtr_<c10::Storage> final {
  static decltype(auto) call() {
    return StorageType::get();
  }
};
template <>
struct getTypePtr_<c10::Stream> final {
  static decltype(auto) call() {
    return StreamObjType::get();
  }
};
// C++ double is the JIT `float` type.
template <>
struct getTypePtr_<double> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};
template <>
struct getTypePtr_<c10::complex<double>> final {
  static decltype(auto) call() {
    return ComplexType::get();
  }
};
// C++ int64_t is the JIT `int` type.
template <>
struct getTypePtr_<int64_t> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};

// SymInt / SymFloat: real mapping uses the Sym* types, fake mapping
// degrades to plain int / float.
template <>
struct getMaybeFakeTypePtr_<SymInt, false> final {
  static decltype(auto) call() {
    return SymIntType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<SymInt, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};

template <>
struct getMaybeFakeTypePtr_<SymFloat, false> final {
  static decltype(auto) call() {
    return SymFloatType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<SymFloat, true> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};

template <>
struct getTypePtr_<c10::Device> final {
  static decltype(auto) call() {
    return DeviceObjType::get();
  }
};
template <>
struct getTypePtr_<bool> final {
  static decltype(auto) call() {
    return BoolType::get();
  }
};
// at::Scalar maps to the generic Number type.
template <>
struct getTypePtr_<at::Scalar> final {
  static decltype(auto) call() {
    return NumberType::get();
  }
};
template <>
struct getTypePtr_<c10::QScheme> final {
  static decltype(auto) call() {
    return QSchemeType::get();
  }
};
// Note: at::Generator maps to Optional[Generator], not a bare Generator.
template <>
struct getTypePtr_<at::Generator> final {
  static decltype(auto) call() {
    return TypeFactory::create<OptionalType>(
        TypeFactory::get<GeneratorType>());
  }
};
// std::string, string_view and Dimname all map to the JIT `str` type.
template <>
struct getTypePtr_<std::string> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
template <>
struct getTypePtr_<c10::string_view> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
template <>
struct getTypePtr_<at::Dimname> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};

// ---- Containers: each caches its type in a function-local static and
// returns a const reference to it. Element types are mapped with the same
// `fake` flag. ----
template <class T, bool fake>
struct getMaybeFakeTypePtr_<std::vector<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per vector<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("vector", inner_type);
    return type;
  }
};
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::ArrayRef<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per ArrayRef<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("ArrayRef", inner_type);
    return type;
  }
};
// Unlike the other list specializations, this one calls ListType::create
// rather than the keyed ListType::get.
template <bool fake>
struct getMaybeFakeTypePtr_<c10::SymIntArrayRef, fake> final {
  static const auto& call() {
    static auto type = ListType::create(getMaybeFakeTypePtr_<c10::SymInt, fake>::call());
    return type;
  }
};
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::List<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per List<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("List", inner_type);
    return type;
  }
};
// Uses the same "List" key as c10::List<T> above.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::IListRef<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    static auto type = ListType::get("List", inner_type);
    return type;
  }
};
template <class T, size_t N, bool fake>
struct getMaybeFakeTypePtr_<std::array<T, N>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per array<T, N>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    // (Concatenating the length onto the end of the string because we want a unique
    // type_ptr created for every std::array<T, N> type).
    static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type);
    return type;
  }
};
template <class K, class V, bool fake>
struct getMaybeFakeTypePtr_<std::unordered_map<K, V>, fake> final {
  static const auto& call() {
    static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
    static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
    // The "per unordered_map<K, V>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type);
    return type;
  }
};
template <class K, class V, bool fake>
struct getMaybeFakeTypePtr_<c10::Dict<K, V>, fake> final {
  static const auto& call() {
    static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
    static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
    // The "per Dict<K, V>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = DictType::get("Dict", inner_key_type, inner_val_type);
    return type;
  }
};

template <class T, bool fake>
struct getMaybeFakeTypePtr_<at::optional<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};


// Specialized on getTypePtr_ (not getMaybeFakeTypePtr_), so its inner list
// type is always built with fake = false.
template<>
struct getTypePtr_<at::OptionalIntArrayRef> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<IntArrayRef, false>::call();
    // The "per optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};

template <bool fake>
struct getMaybeFakeTypePtr_<at::OptionalSymIntArrayRef, fake> final {
  static const auto& call() {
    // The "per optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto inner_type = getMaybeFakeTypePtr_<SymIntArrayRef, fake>::call();
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};

// std::tuple maps to an unnamed TupleType whose elements are the mapped
// element types, in order.
template <class... Contained, bool fake>
struct getMaybeFakeTypePtr_<std::tuple<Contained...>, fake> final {
  static const auto& call() {
    static auto type = ([]() {
      std::vector<TypePtr> contained_types = {
        (getMaybeFakeTypePtr_<Contained, fake>::call())...
      };
      return TupleType::create(std::move(contained_types));
    })();
    return type;
  }
};
// A void return maps to NoneType.
template <>
struct getTypePtr_<void> final {
  static decltype(auto) call() {
    return NoneType::get();
  }
};
} // namespace detail
2063template <class T>
2064inline decltype(auto) getTypePtr() {
2065 // TODO: static_assert that a templated function exists, and throw a friendly
2066 // error message if not
2067 return detail::getMaybeFakeTypePtr_<T, false>::call();
2068}
2069
2070template <class T>
2071inline TypePtr getTypePtrCopy() {
2072 // TODO: static_assert that a templated function exists, and throw a friendly
2073 // error message if not
2074 return getTypePtr<T>();
2075}
2076
2077template <class T>
2078inline decltype(auto) getFakeTypePtr() {
2079 return detail::getMaybeFakeTypePtr_<T, true>::call();
2080}
2081
2082template <class T>
2083inline TypePtr getFakeTypePtrCopy() {
2084 return getFakeTypePtr<T>();
2085}
2086
2087using TypeEnv = std::unordered_map<std::string, TypePtr>;
2088struct MatchTypeReturn {
2089 MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {}
2090 static MatchTypeReturn Success() {
2091 return MatchTypeReturn();
2092 }
2093 bool success() const {
2094 return !reason_.has_value();
2095 }
2096 const std::string& reason() const {
2097 return reason_.value();
2098 }
2099
2100 private:
2101 MatchTypeReturn()
2102 : reason_(c10::nullopt) {}
2103 c10::optional<std::string> reason_; // is there is no match, this contains the reason
2104};
2105
// Attempt to match the type variables in `formal` to `actual`, adding them to
// `type_env`. If no match is possible, this returns a MatchTypeReturn `r` with
// r.success() == false and an r.reason() that describes why it could not match.
// Note: it is possible to successfully match a formal while type variables in
// the formal remain undefined. In particular, None matches Optional[T] but
// does not define the value of T.
TORCH_API MatchTypeReturn
matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env);

// Replace the type variables appearing in `type` with their values in
// `type_env`. Returns nullptr if a variable used in `type` does not appear in
// `type_env`.
TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env);

// NOTE(review): defined elsewhere. Judging by the name, this reports whether a
// container's element type `elem_type` can be inferred from its members;
// confirm against the definition.
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
2121
2122struct InterfaceType;
2123using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
2124
2125// Interfaces are a list of abstract methods that a class might meet.
2126// If a class provides those methods, it implicitly meets the interface.
2127
2128// Subtype relations for Interface with ClassType:
2129// lhs (ClassType or InterfaceType) is a subtype of rhs if:
2130// 1. lhs methods are a superset of rhs methods
2131// 2. if rhs is module interface, the lhs must be module interface or module itself
2132struct TORCH_API InterfaceType : public NamedType {
2133 static InterfaceTypePtr create(
2134 QualifiedName qualifiedName, bool is_module=false);
2135
2136 bool equals(const Type& rhs) const override {
2137 if (auto user_rhs = rhs.castRaw<InterfaceType>()) {
2138 return isSubTypeImpl(*this, *user_rhs, nullptr) &&
2139 isSubTypeImpl(*user_rhs, *this, nullptr);
2140 }
2141 return false;
2142 }
2143
2144 std::string str() const override {
2145 return std::string("InterfaceType<") + name()->name() + ">";
2146 }
2147
2148 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
2149
2150 // try to find a method of this interface,
2151 // returns nullptr if not found.
2152 const FunctionSchema* getMethod(const std::string& name) const;
2153 void addMethod(FunctionSchema schema);
2154 const std::vector<FunctionSchema>& methods() const {
2155 return *methods_;
2156 }
2157
2158 bool is_module() const override{
2159 return is_module_;
2160 }
2161 static const TypeKind Kind = TypeKind::InterfaceType;
2162 ~InterfaceType() override;
2163 private:
2164 InterfaceType(QualifiedName name, bool is_module);
2165 static bool isSubTypeImpl(
2166 const InterfaceType& lhs,
2167 const InterfaceType& rhs,
2168 std::ostream* why_not);
2169
2170 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
2171 (void)printer; // Suppress unused variable warning
2172 return name()->qualifiedName();
2173 }
2174
2175 // shared_ptr so that this header does not have to depend on
2176 // FunctionSchema.h
2177 std::shared_ptr<std::vector<FunctionSchema>> methods_;
2178 // flag to distinguish if it's an interface type from a module or not
2179 bool is_module_;
2180};
2181
2182template <TypeKind K>
2183struct EnumerationType : public Type {
2184static const TypeKind Kind = K;
2185
2186bool equals(const Type& rhs) const override {
2187 return rhs.kind() == kind();
2188}
2189
2190protected:
2191EnumerationType() : Type(Kind) {}
2192};
2193
// WARNING: The enumeration types below are NOT actually parsed out of the
// logical schema strings; instead, they are mapped to ints. To observe these
// types, use real_type() instead of type() on Argument.
2197
2198struct ScalarTypeType;
2199using ScalarTypeTypePtr = SingletonTypePtr<ScalarTypeType>;
2200struct TORCH_API ScalarTypeType : public EnumerationType<TypeKind::ScalarTypeType> {
2201std::string str() const override {
2202return "ScalarType";
2203}
2204static const TypeKind Kind = TypeKind::ScalarTypeType;
2205// global singleton
2206static ScalarTypeTypePtr get();
2207
2208private:
2209ScalarTypeType() : EnumerationType() {}
2210};
2211
2212struct MemoryFormatType;
2213using MemoryFormatTypePtr = SingletonTypePtr<MemoryFormatType>;
2214struct TORCH_API MemoryFormatType : public EnumerationType<TypeKind::MemoryFormatType> {
2215std::string str() const override {
2216return "MemoryFormat";
2217}
2218static const TypeKind Kind = TypeKind::MemoryFormatType;
2219// global singleton
2220static MemoryFormatTypePtr get();
2221
2222private:
2223MemoryFormatType() : EnumerationType() {}
2224};
2225
2226struct LayoutType;
2227using LayoutTypePtr = SingletonTypePtr<LayoutType>;
2228struct TORCH_API LayoutType : public EnumerationType<TypeKind::LayoutType> {
2229std::string str() const override {
2230return "Layout";
2231}
2232static const TypeKind Kind = TypeKind::LayoutType;
2233// global singleton
2234static LayoutTypePtr get();
2235
2236private:
2237LayoutType() : EnumerationType() {}
2238};
2239
2240namespace detail {
2241template <>
2242struct getMaybeFakeTypePtr_<c10::ScalarType, false> final {
2243 static decltype(auto) call() {
2244 return ScalarTypeType::get();
2245 }
2246};
2247template <>
2248struct getMaybeFakeTypePtr_<c10::Layout, false> final {
2249 static decltype(auto) call() {
2250 return LayoutType::get();
2251 }
2252};
2253template <>
2254struct getMaybeFakeTypePtr_<c10::MemoryFormat, false> final {
2255 static decltype(auto) call() {
2256 return MemoryFormatType::get();
2257 }
2258};
2259template <>
2260struct getMaybeFakeTypePtr_<c10::ScalarType, true> final {
2261 static decltype(auto) call() {
2262 return IntType::get();
2263 }
2264};
2265template <>
2266struct getMaybeFakeTypePtr_<c10::Layout, true> final {
2267 static decltype(auto) call() {
2268 return IntType::get();
2269 }
2270};
2271template <>
2272struct getMaybeFakeTypePtr_<c10::MemoryFormat, true> final {
2273 static decltype(auto) call() {
2274 return IntType::get();
2275 }
2276};
2277} // namespace detail
2278
2279// the common supertype of all lists,
2280// List[T] <: AnyList for all T
2281struct AnyListType;
2282using AnyListTypePtr = SingletonTypePtr<AnyListType>;
2283struct TORCH_API AnyListType : public Type {
2284 bool equals(const Type& rhs) const override {
2285 return rhs.kind() == kind();
2286 }
2287 std::string str() const override {
2288 return "list";
2289 }
2290 static const TypeKind Kind = TypeKind::AnyListType;
2291 // global singleton
2292 static AnyListTypePtr get();
2293private:
2294 AnyListType()
2295 : Type(TypeKind::AnyListType) {}
2296};
2297
2298// the common supertype of all tuples,
2299// Tuple[T...] <: AnyTuple for all T
2300struct AnyTupleType;
2301using AnyTupleTypePtr = SingletonTypePtr<AnyTupleType>;
2302struct TORCH_API AnyTupleType : public Type {
2303 bool equals(const Type& rhs) const override {
2304 return rhs.kind() == kind();
2305 }
2306
2307 std::string str() const override {
2308 return "tuple";
2309 }
2310 static const TypeKind Kind = TypeKind::AnyTupleType;
2311
2312 // global singleton
2313 static AnyTupleTypePtr get();
2314private:
2315 AnyTupleType()
2316 : Type(TypeKind::AnyTupleType) {}
2317};
2318
2319// the common supertype of all classes,
2320// ClassType <: AnyClassType for all classes
2321struct AnyClassType;
2322using AnyClassTypePtr = SingletonTypePtr<AnyClassType>;
2323struct TORCH_API AnyClassType : public Type {
2324 bool equals(const Type& rhs) const override {
2325 return rhs.kind() == kind();
2326 }
2327 std::string str() const override {
2328 return "AnyClassType";
2329 }
2330 static const TypeKind Kind = TypeKind::AnyClassType;
2331 // global singleton
2332 static AnyClassTypePtr get();
2333private:
2334 AnyClassType()
2335 : Type(TypeKind::AnyClassType) {}
2336};
2337
2338template<>
2339inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
2340 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2341 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2342 return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
2343 }
2344 return nullptr;
2345}
2346
2347template<>
2348inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
2349 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2350 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2351 return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
2352 }
2353 return nullptr;
2354}
2355
2356template<>
2357inline const NamedType* Type::castRaw<NamedType>() const {
2358 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2359 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2360 return static_cast<const NamedType*>(this);
2361 }
2362 return nullptr;
2363}
2364
2365// Used as a return type when inferring the IValue type of a Python object.
2366struct InferredType {
2367 /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {}
2368 /* implicit */ InferredType(std::string reason)
2369 : type_(nullptr), reason_(std::move(reason)) {}
2370 TypePtr type() const {
2371 TORCH_INTERNAL_ASSERT(
2372 type_,
2373 "Tried to get the type from an InferredType but the type is null. ",
2374 "Reason: ",
2375 reason_);
2376 return type_;
2377 }
2378 bool success() const {
2379 return type_ != nullptr;
2380 }
2381 const std::string& reason() const {
2382 TORCH_INTERNAL_ASSERT(!type_);
2383 return reason_;
2384 }
2385
2386private:
2387 TypePtr type_;
2388 std::string reason_;
2389};
2390
// NOTE(review): defined elsewhere. Presumably reports whether `type` is, or
// (transitively) contains, AnyType; confirm against the definition.
TORCH_API bool containsAnyType(const TypePtr& type);
2392
2393} // namespace c10
2394