1#pragma once
2
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include <ATen/core/qualified_name.h>
#include <ATen/core/type_ptr.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
17
18namespace c10 {
19
// X-macro enumerating every type kind exactly once. C10_FORALL_TYPES(M)
// expands to M(AnyType) M(EnumType) ... M(DynamicType), in the order written
// below. The order is significant: TypeKind's enumerators are generated from
// this list.
#define C10_FORALL_TYPES(X) \
  X(AnyType)                \
  X(EnumType)               \
  X(AnyEnumType)            \
  X(TensorType)             \
  X(StorageType)            \
  X(TupleType)              \
  X(ListType)               \
  X(DictType)               \
  X(NumberType)             \
  X(FloatType)              \
  X(ComplexType)            \
  X(FutureType)             \
  X(AwaitType)              \
  X(RRefType)               \
  X(IntType)                \
  X(NoneType)               \
  X(StringType)             \
  X(GeneratorType)          \
  X(QuantizerType)          \
  X(BoolType)               \
  X(OptionalType)           \
  X(VarType)                \
  X(DeviceObjType)          \
  X(StreamObjType)          \
  X(FunctionType)           \
  X(ClassType)              \
  X(PyObjectType)           \
  X(CapsuleType)            \
  X(InterfaceType)          \
  X(QSchemeType)            \
  X(ScalarTypeType)         \
  X(LayoutType)             \
  X(MemoryFormatType)       \
  X(AnyListType)            \
  X(AnyTupleType)           \
  X(AnyClassType)           \
  X(SymIntType)             \
  X(SymFloatType)           \
  X(UnionType)              \
  X(DynamicType)
61
62enum class TypeKind {
63#define DEFINE_TYPE(T) T,
64 C10_FORALL_TYPES(DEFINE_TYPE)
65#undef DEFINE_TYPE
66};
67
// Maps a TypeKind to a C string. Only declared here; the definition lives in
// another translation unit (presumably it yields the enumerator's name —
// verify against the implementation).
TORCH_API const char* typeKindToString(TypeKind kind);

// Forward declarations; both are defined later in this header.
struct Type;
struct SharedType;

// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter = std::function<c10::optional<std::string>(const Type&)>;
77
namespace detail {
// Compile-time trait: `value` is false for every type unless a specialization
// (generated by TORCH_DECLARE_SINGLETON) says otherwise.
template <typename T>
struct IsSingletonType : public std::false_type {};
} // namespace detail

// Forward-declares `Type` and specializes detail::IsSingletonType<Type> so
// that its `value` is true. Must be used at the scope that encloses
// `namespace detail`, because the expansion reopens that namespace.
#define TORCH_DECLARE_SINGLETON(Type)                             \
  struct Type;                                                    \
  namespace detail {                                              \
  template <>                                                     \
  struct IsSingletonType<Type> : public std::true_type {};        \
  }
87
// Types registered as singletons. Each line forward-declares the struct and
// marks it in detail::IsSingletonType, which in turn selects the
// SingletonTypePtr-returning overloads of Type::cast<T>().
TORCH_DECLARE_SINGLETON(AnyType);
TORCH_DECLARE_SINGLETON(AnyEnumType);
TORCH_DECLARE_SINGLETON(NumberType);
TORCH_DECLARE_SINGLETON(FloatType);
TORCH_DECLARE_SINGLETON(ComplexType);
TORCH_DECLARE_SINGLETON(IntType);
TORCH_DECLARE_SINGLETON(BoolType);
TORCH_DECLARE_SINGLETON(StringType);
TORCH_DECLARE_SINGLETON(StorageType);
TORCH_DECLARE_SINGLETON(NoneType);
TORCH_DECLARE_SINGLETON(GeneratorType);
TORCH_DECLARE_SINGLETON(QuantizerType);
TORCH_DECLARE_SINGLETON(QSchemeType);
TORCH_DECLARE_SINGLETON(DeviceObjType);
TORCH_DECLARE_SINGLETON(StreamObjType);
TORCH_DECLARE_SINGLETON(CapsuleType);
TORCH_DECLARE_SINGLETON(PyObjectType);
TORCH_DECLARE_SINGLETON(ScalarTypeType);
TORCH_DECLARE_SINGLETON(LayoutType);
TORCH_DECLARE_SINGLETON(MemoryFormatType);
TORCH_DECLARE_SINGLETON(AnyListType);
TORCH_DECLARE_SINGLETON(AnyTupleType);
TORCH_DECLARE_SINGLETON(AnyClassType);
111
112namespace detail {
113template <typename T, typename Enable = void>
114struct CastReturnType {
115 using type = std::shared_ptr<T>;
116};
117
118template <typename T>
119struct CastReturnType<T, typename std::enable_if<IsSingletonType<T>::value>::type> {
120 using type = SingletonTypePtr<T>;
121};
122
123template <typename T, typename Enable = void>
124struct CastConstReturnType {
125 using type = std::shared_ptr<const T>;
126};
127
128template <typename T>
129struct CastConstReturnType<T, typename std::enable_if<IsSingletonType<T>::value>::type> {
130 using type = SingletonTypePtr<const T>;
131};
132
133template <typename T>
134struct as_shared_type {
135 using type = SharedType*;
136};
137
138template <typename T>
139struct as_shared_type<const T*> {
140 using type = const SharedType *;
141};
142} // namespace detail
143
144struct TORCH_API Type {
145 friend TORCH_API bool operator==(const Type& lhs, const Type& rhs);
146 private:
147 TypeKind kind_;
148
149 protected:
150 Type(TypeKind kind) : kind_(kind) {}
151
152 virtual std::string annotation_str_impl(TypePrinter /*printer*/) const {
153 return str();
154 }
155 // a == b
156 virtual bool equals(const Type& rhs) const = 0;
157 // a == b <=> b == a
158 virtual bool symmetric() const {
159 return true;
160 }
161
162 public:
163 template <typename T>
164 class SingletonOrSharedTypePtr {
165 public:
166 using element_type = typename std::shared_ptr<T>::element_type;
167
168 SingletonOrSharedTypePtr() = default;
169
170 /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr<T> x)
171 : repr_(std::move(x)) {}
172
173 template <typename U, std::enable_if_t<std::is_convertible<U*, T*>::value, bool> = true>
174 /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr<U> x)
175 : repr_(std::move(x)) {}
176
177 /* implicit */ SingletonOrSharedTypePtr(std::nullptr_t)
178 : repr_(nullptr) {}
179
180 /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
181 : repr_(p) {}
182
183 template <typename U, std::enable_if_t<std::is_convertible<U*, T*>::value, bool> = true>
184 /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
185 : repr_(SingletonTypePtr<T>(p.get())) {}
186
187
188 // We need to support construction from T* for pybind. The problem
189 // is that it's not clear if we are supposed to be taking shared
190 // ownership or not.
191 //
192 // Case 1: if T is known statically to derive from SharedType, we should use
193 // shared_from_this() and take shared_ownership.
194 //
195 // Case 2: if T is exactly Type, we need to do a dynamic_cast to
196 // check if it's a SharedType and do the right thing.
197 //
198 // Case 3: Otherwise, T is not a SharedType. (debug-check this
199 // assumption!) Use a singleton pointer.
200
201 template <typename U = T, std::enable_if_t<std::is_base_of<SharedType, U>::value, bool> = true>
202 /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
203
204 template <typename U = T, std::enable_if_t<std::is_same<Type, U>::value, bool> = true>
205 /* implicit */ SingletonOrSharedTypePtr(T* p) {
206 if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
207 repr_ = Repr(shared_p->shared_from_this());
208 } else {
209 repr_ = Repr(p);
210 }
211 }
212
213 template <typename U = T, std::enable_if_t<!std::is_same<Type, U>::value && !std::is_base_of<SharedType, U>::value, bool> = true>
214 /* implicit */ SingletonOrSharedTypePtr(T* p)
215 : repr_(p) {
216 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
217 }
218
219 SingletonOrSharedTypePtr(const SingletonOrSharedTypePtr&) = default;
220 SingletonOrSharedTypePtr(SingletonOrSharedTypePtr&&) noexcept = default;
221 SingletonOrSharedTypePtr& operator=(const SingletonOrSharedTypePtr&) = default;
222 SingletonOrSharedTypePtr& operator=(SingletonOrSharedTypePtr&&) noexcept = default;
223
224 T* get() const {
225 return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
226 }
227
228 operator bool() const {
229 return repr_.isNonNull();
230 }
231
232 bool operator==(std::nullptr_t) const {
233 return !repr_.isNonNull();
234 }
235
236 bool operator!=(std::nullptr_t) const {
237 return repr_.isNonNull();
238 }
239
240 template <typename U = T, std::enable_if_t<!std::is_same<std::remove_const_t<U>, void>::value, bool> = true>
241 U& operator*() const {
242 return *get();
243 }
244
245 T* operator->() const {
246 return get();
247 }
248
249 private:
250 // NOTE: SharedPtrWrapper exists to work around a baffling bug in
251 // nvcc; see comment in destroy() below.
252 struct SharedPtrWrapper {
253 SharedPtrWrapper(std::shared_ptr<T> &&x)
254 : repr_(std::move(x)) {}
255 std::shared_ptr<T> repr_;
256 };
257 union Repr {
258 Repr() : Repr(nullptr) {}
259
260 explicit Repr(std::shared_ptr<T> x)
261 : shared_(std::move(x)) {}
262
263 explicit Repr(std::nullptr_t)
264 : singletonRepr_(nullptr) {}
265
266 explicit Repr(SingletonTypePtr<T> p)
267 : singletonRepr_(p.get()) {}
268
269 ~Repr() {
270 destroy();
271 }
272
273 // NOTE: the only non-UB way to access our null state is through
274 // rawRepr(), because our copy operation doesn't preserve which
275 // union member is active for null pointers.
276 Repr(const Repr& rhs) {
277 if (rhs.isSharedAndNonNull()) {
278 new (&shared_) SharedPtrWrapper(rhs.shared_);
279 } else {
280 singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
281 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
282 singletonRepr_.unused_ = nullptr;
283 }
284 }
285
286 Repr(Repr&& rhs) noexcept {
287 if (rhs.isSharedAndNonNull()) {
288 new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
289 } else {
290 singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
291 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
292 singletonRepr_.unused_ = nullptr;
293 }
294 }
295
296 Repr& operator=(const Repr& rhs) {
297 if (&rhs == this) {
298 return *this;
299 }
300 if (rhs.isSharedAndNonNull()) {
301 if (isSharedAndNonNull()) {
302 shared_ = rhs.shared_;
303 } else {
304 new (&shared_) SharedPtrWrapper(rhs.shared_);
305 }
306 } else {
307 if (isSharedAndNonNull()) {
308 destroy();
309 }
310 singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
311 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
312 singletonRepr_.unused_ = nullptr;
313 }
314 return *this;
315 }
316
317 Repr& operator=(Repr&& rhs) noexcept {
318 if (&rhs == this) {
319 return *this;
320 }
321 if (rhs.isSharedAndNonNull()) {
322 if (isSharedAndNonNull()) {
323 shared_ = std::move(rhs.shared_);
324 } else {
325 new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
326 }
327 } else {
328 if (isSharedAndNonNull()) {
329 destroy();
330 }
331 singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
332 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
333 singletonRepr_.unused_ = nullptr;
334 }
335 return *this;
336 }
337
338 SharedPtrWrapper shared_;
339
340 struct SingletonRepr {
341 explicit SingletonRepr(T* s) : singleton_(s) {}
342 T* singleton_;
343 void* unused_ = nullptr;
344 } singletonRepr_;
345 struct RawRepr {
346 void* first;
347 void* nullIfSingleton_;
348 };
349
350 // It is UB to read the singleton part of Repr if it was
351 // constructed as a shared_ptr and vice versa, but memcpying out
352 // the representation is always OK, so here's an accessor to obey
353 // the letter of the law.
354 RawRepr rawRepr() const {
355 RawRepr repr;
356 memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
357 return repr;
358 }
359
360 bool isNonNull() const {
361 auto repr = rawRepr();
362 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
363 return repr.first != nullptr;
364 }
365
366 bool isSharedAndNonNull() const {
367 return rawRepr().nullIfSingleton_ != nullptr;
368 }
369
370 private:
371 void destroy() {
372 if (isSharedAndNonNull()) {
373 // Without SharedPtrWrapper, this line would read
374 // `shared_.~shared_ptr()` and nvcc would complain with
375 // "error: expected primary-expression before '>' token"
376 // referring to the "t" in "shared_ptr". SharedPtrWrapper
377 // exists to work around this compiler bug.
378 shared_.~SharedPtrWrapper();
379 }
380 }
381 } repr_;
382 };
383
384 using TypePtr = SingletonOrSharedTypePtr<Type>;
385 using Ptr = TypePtr;
386 using ElementType = Type;
387
388 // subtyping relation. By default, we return true for the case
389 // when the type is exactly equal or if this <: T where rhs = Optional[T]
390
391 // if this returns false and the why_not stream is non-null, it contains
392 // additional details that describe why this is not a subtype of 'rhs'.
393 // This additional information should only contain details that are not
394 // obvious from the annotation_str() that describes the type. For instance it
395 // is clear that `int <: str` is false but not clear why `Foo <: InterfaceBar`
396 // might be false.
397 virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const;
398 virtual bool is_module() const;
399 bool isSubtypeOf(const Type& rhs) const {
400 return isSubtypeOfExt(rhs, nullptr);
401 }
402 // Compatibility shims to accommodate existing code that passes shared_ptrs
403 // around. Ideally, we would just delete this, but it should be harmless.
404 template <typename T>
405 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
406 isSubtypeOf(const std::shared_ptr<T>& rhs) const {
407 return isSubtypeOf(*rhs);
408 }
409
410 template <typename T>
411 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
412 isSubtypeOf(const SingletonOrSharedTypePtr<T>& rhs) const {
413 return isSubtypeOf(*rhs);
414 }
415
416 template <typename T>
417 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
418 isSubtypeOf(SingletonTypePtr<T> rhs) const {
419 return isSubtypeOf(*rhs);
420 }
421
422 template <typename T>
423 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
424 isSubtypeOfExt(const SingletonOrSharedTypePtr<T>& rhs, std::ostream* why_not) const {
425 return isSubtypeOfExt(*rhs, why_not);
426 }
427
428 template <typename T>
429 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
430 isSubtypeOfExt(const std::shared_ptr<T>& rhs, std::ostream* why_not) const {
431 return isSubtypeOfExt(*rhs, why_not);
432 }
433
434 template <typename T>
435 typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type
436 isSubtypeOfExt(SingletonTypePtr<T> rhs, std::ostream* why_not) const {
437 return isSubtypeOfExt(*rhs, why_not);
438 }
439
440 // How this type will appear in FunctionSchema declarations
441 virtual std::string str() const = 0;
442
443 // How this type will appear as if it were a type annotation in Python
444 // which is sometimes different than how it appears in declarations (e.g.
445 // int[] vs List[int])
446 //
447 // Takes a custom printer that users can pass in to customize the output of
448 // this method.
449 std::string annotation_str(TypePrinter printer) const {
450 if (printer) {
451 // the printer can return nullopt to fall through to the default impl
452 if (auto renamed = printer(*this)) {
453 return *renamed;
454 }
455 }
456 return annotation_str_impl(std::move(printer));
457 }
458 std::string annotation_str() const {
459 // Overload instead of define a default value for `printer` to help
460 // debuggers out.
461 return annotation_str(nullptr);
462 }
463
464 // Returns a human readable string that includes additional information like
465 // "type is inferred rather than explictly defined" to help construct more
466 // user-friendly messages.
467 virtual std::string repr_str() const {
468 return annotation_str();
469 }
470
471 TypeKind kind() const {
472 return kind_;
473 }
474
475 virtual bool isUnionType() const {
476 return false;
477 }
478
479 virtual bool requires_grad() const {
480 for (const auto& ct : containedTypes()) {
481 if (ct->requires_grad()) {
482 return true;
483 }
484 }
485 return false;
486 }
487
488 // Dynamically cast this object to the subclass indicated by the
489 // template variable, returning nullptr if the cast is invalid.
490 template <typename T, std::enable_if_t<!detail::IsSingletonType<T>::value, bool> = true>
491 typename detail::CastReturnType<T>::type cast() {
492 if (T::Kind == kind()) {
493 return std::static_pointer_cast<T>(static_cast<T*>(this)->shared_from_this());
494 }
495 return nullptr;
496 }
497 template <typename T, std::enable_if_t<detail::IsSingletonType<T>::value, bool> = true>
498 typename detail::CastReturnType<T>::type cast() {
499 if (T::Kind == kind()) {
500 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get());
501 return typename detail::CastReturnType<T>::type(static_cast<T*>(this));
502 }
503 return nullptr;
504 }
505 template <typename T, std::enable_if_t<!detail::IsSingletonType<T>::value, bool> = true>
506 typename detail::CastConstReturnType<T>::type cast() const {
507 if (T::Kind == kind()) {
508 return std::static_pointer_cast<const T>(static_cast<const T*>(this)->shared_from_this());
509 }
510 return nullptr;
511 }
512 template <typename T, std::enable_if_t<detail::IsSingletonType<T>::value, bool> = true>
513 typename detail::CastConstReturnType<T>::type cast() const {
514 if (T::Kind == kind()) {
515 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get());
516 return typename detail::CastConstReturnType<T>::type(static_cast<const T*>(this));
517 }
518 return nullptr;
519 }
520 template <typename T>
521 T* castRaw() {
522 if (T::Kind == kind()) {
523 return static_cast<T*>(this);
524 }
525 return nullptr;
526 }
527 template <typename T>
528 const T* castRaw() const {
529 if (T::Kind == kind()) {
530 return static_cast<const T*>(this);
531 }
532 return nullptr;
533 }
534 template <typename T>
535 auto expect() {
536 auto r = cast<T>();
537 AT_ASSERT(r);
538 return r;
539 }
540 template <typename T>
541 auto expect() const {
542 auto r = cast<const T>();
543 AT_ASSERT(r);
544 return r;
545 }
546 template <typename T>
547 T& expectRef() {
548 auto* r = castRaw<T>();
549 AT_ASSERT(r);
550 return *r;
551 }
552 template <typename T>
553 const T& expectRef() const {
554 auto* r = castRaw<const T>();
555 AT_ASSERT(r);
556 return *r;
557 }
558 virtual ~Type() = default;
559 virtual bool hasFreeVariables() const {
560 return false;
561 }
562 // list of types this type contains, e.g. for a List then element type of a
563 // list for a tuple, the types of the tuple elements
564 virtual at::ArrayRef<TypePtr> containedTypes() const {
565 return {};
566 }
567 virtual TypePtr containedType(size_t i) const {
568 return containedTypes().at(i);
569 }
570 virtual size_t containedTypeSize() const {
571 return containedTypes().size();
572 }
573 // create a new version of this type, replacing its contained types with
574 // contained_types
575 TypePtr withContained(std::vector<TypePtr> contained_types);
576 // per-type constructor, you only need to override this if the
577 // containedTypes() is not empty
578 virtual TypePtr createWithContained(
579 std::vector<TypePtr> /*contained_types*/) const {
580 AT_ERROR(
581 "type with contained types did not overload createWithContained: ",
582 str());
583 }
584
585};
586
// Namespace-scope alias so the nested class template can be spelled
// c10::SingletonOrSharedTypePtr<T> without the `Type::` qualifier.
template <typename T>
using SingletonOrSharedTypePtr = Type::SingletonOrSharedTypePtr<T>;
589
590
591template <typename T, typename U>
592bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
593 return (void*)x.get() == (void*)y.get();
594}
595
596template <typename T, typename U>
597bool operator==(const SingletonOrSharedTypePtr<T>& x, const std::shared_ptr<U>& y) {
598 return (void*)x.get() == (void*)y.get();
599}
600
601template <typename T, typename U>
602bool operator==(const std::shared_ptr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
603 return (void*)x.get() == (void*)y.get();
604}
605
606template <typename T, typename U>
607bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonTypePtr<U>& y) {
608 return (void*)x.get() == (void*)y.get();
609}
610
611template <typename T, typename U>
612bool operator==(const SingletonTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
613 return (void*)x.get() == (void*)y.get();
614}
615
616template <typename T, typename U>
617bool operator!=(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
618 return !(x == y);
619}
620
621template <typename T, typename U>
622bool operator!=(const SingletonOrSharedTypePtr<T>& x, const std::shared_ptr<U>& y) {
623 return !(x == y);
624}
625
626template <typename T, typename U>
627bool operator!=(const std::shared_ptr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
628 return !(x == y);
629}
630
631template <typename T, typename U>
632bool operator!=(const SingletonOrSharedTypePtr<T>& x, const SingletonTypePtr<U>& y) {
633 return !(x == y);
634}
635
636template <typename T, typename U>
637bool operator!=(const SingletonTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
638 return !(x == y);
639}
640
// Namespace-scope pointer aliases for the base Type: mutable and const
// pointee, respectively.
using TypePtr = SingletonOrSharedTypePtr<Type>;
using ConstTypePtr = SingletonOrSharedTypePtr<const Type>;
643
// Explicitly enable MaybeOwned<SingletonOrSharedTypePtr<T>>, rather than
// allowing MaybeOwned to be used for any type right away. The behavior is
// inherited unchanged from MaybeOwnedTraitsGenericImpl.
template <typename T>
struct MaybeOwnedTraits<SingletonOrSharedTypePtr<T>>
    : public MaybeOwnedTraitsGenericImpl<SingletonOrSharedTypePtr<T>> {};
649
// Base class for Types that are guaranteed to be owned by std::shared_ptr.
// Deriving from enable_shared_from_this<SharedType> is what lets
// Type::cast<T>() and Type::withContained() call shared_from_this().
struct TORCH_API SharedType : public Type, public std::enable_shared_from_this<SharedType> {
  // Inherit Type's constructors (i.e. construction from a TypeKind).
  using Type::Type;
};
654
655inline TypePtr Type::withContained(std::vector<TypePtr> contained_types) {
656 auto current_contained = containedTypes();
657 // Types with no contained_types don't need this call. Check before calling!
658 //
659 // (We can't support this efficiently because types without
660 // contained types may be singletons, in which case
661 // shared_from_this will crash; we would have to provide a virtual
662 // typeptr_from_this or isSingleton.)
663 TORCH_INTERNAL_ASSERT(!current_contained.empty() && current_contained.size() == contained_types.size());
664 if (current_contained.equals(contained_types)) {
665 return std::static_pointer_cast<Type>(static_cast<SharedType *>(this)->shared_from_this());
666 }
667 return createWithContained(std::move(contained_types));
668}
669
670
671TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) {
672 if (C10_UNLIKELY(!rhs.symmetric())) {
673 return rhs.equals(lhs);
674 }
675 return lhs.equals(rhs);
676}
677
// Forward declaration plus shared_ptr aliases for NamedType, which is defined
// immediately below.
struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
681
// A SharedType that carries an optional fully qualified name. Only the kinds
// listed in the constructor's assertion may derive from it.
struct TORCH_API NamedType : public SharedType {
  // Stores `name` and asserts that `tk` is one of TupleType, FunctionType,
  // ClassType, InterfaceType or EnumType.
  NamedType(TypeKind tk, c10::optional<QualifiedName> name)
      : SharedType(tk), name_(std::move(name)) {
    TORCH_INTERNAL_ASSERT(
        tk == TypeKind::TupleType || tk == TypeKind::FunctionType ||
        tk == TypeKind::ClassType || tk == TypeKind::InterfaceType ||
        tk == TypeKind::EnumType,
        "If you add a new kind of NamedType, ",
        "please update the cast<NamedType> specialization and this assert");
  }

  // Fully qualified name of type
  // Looks like: "foo.bar.Baz".
  // Returned by reference; holds no value if no name was given at construction.
  const c10::optional<QualifiedName>& name() const {
    return name_;
  }

 private:
  c10::optional<QualifiedName> name_;
};
702
703} // namespace c10
704
705namespace std {
706template <typename T>
707struct hash<c10::SingletonOrSharedTypePtr<T>> {
708 size_t operator()(const c10::SingletonOrSharedTypePtr<T>& x) const {
709 return std::hash<T*>()(x.get());
710 }
711};
712} // namespace std
713