1 | #pragma once |
2 | |
3 | #include <functional> |
4 | #include <memory> |
5 | #include <string> |
6 | #include <utility> |
7 | |
8 | #include <ATen/core/qualified_name.h> |
9 | #include <ATen/core/type_ptr.h> |
10 | #include <c10/core/SymInt.h> |
11 | #include <c10/core/SymFloat.h> |
12 | #include <c10/core/SymIntArrayRef.h> |
13 | #include <c10/macros/Macros.h> |
14 | #include <c10/util/ArrayRef.h> |
15 | #include <c10/util/Exception.h> |
16 | #include <c10/util/Optional.h> |
17 | |
18 | namespace c10 { |
19 | |
// X-macro listing every JIT type kind. The argument `_` is applied once to
// each kind name, in exactly the order written here. TypeKind is generated
// from this list, so reordering entries changes the enumerator values.
#define C10_FORALL_TYPES(_)  \
  _(AnyType)                 \
  _(EnumType)                \
  _(AnyEnumType)             \
  _(TensorType)              \
  _(StorageType)             \
  _(TupleType)               \
  _(ListType)                \
  _(DictType)                \
  _(NumberType)              \
  _(FloatType)               \
  _(ComplexType)             \
  _(FutureType)              \
  _(AwaitType)               \
  _(RRefType)                \
  _(IntType)                 \
  _(NoneType)                \
  _(StringType)              \
  _(GeneratorType)           \
  _(QuantizerType)           \
  _(BoolType)                \
  _(OptionalType)            \
  _(VarType)                 \
  _(DeviceObjType)           \
  _(StreamObjType)           \
  _(FunctionType)            \
  _(ClassType)               \
  _(PyObjectType)            \
  _(CapsuleType)             \
  _(InterfaceType)           \
  _(QSchemeType)             \
  _(ScalarTypeType)          \
  _(LayoutType)              \
  _(MemoryFormatType)        \
  _(AnyListType)             \
  _(AnyTupleType)            \
  _(AnyClassType)            \
  _(SymIntType)              \
  _(SymFloatType)            \
  _(UnionType)               \
  _(DynamicType)
61 | |
62 | enum class TypeKind { |
63 | #define DEFINE_TYPE(T) T, |
64 | C10_FORALL_TYPES(DEFINE_TYPE) |
65 | #undef DEFINE_TYPE |
66 | }; |
67 | |
// Maps a TypeKind to a C string. Only the declaration lives in this header.
// NOTE(review): presumably returns the enumerator's spelling -- confirm
// against the out-of-line definition.
TORCH_API const char* typeKindToString(TypeKind kind);
69 | |
// Forward declarations; both structs are defined later in this header.
struct Type;
struct SharedType;

// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter = std::function<c10::optional<std::string>(const Type&)>;
77 | |
namespace detail {
// Primary template of the singleton trait: `value` is false unless a
// specialization for T says otherwise.
template <typename T>
struct IsSingletonType : public std::false_type {};
} // namespace detail
// Forward-declares `SingletonName` and specializes detail::IsSingletonType
// for it so that the trait's `value` is true.
#define TORCH_DECLARE_SINGLETON(SingletonName)                       \
  struct SingletonName;                                              \
  namespace detail {                                                 \
  template <>                                                        \
  struct IsSingletonType<SingletonName> : public std::true_type {};  \
  }
87 | |
// Types flagged as singletons. Each line forward-declares the struct and
// makes detail::IsSingletonType<T>::value true for it, which in turn makes
// detail::CastReturnType / CastConstReturnType select SingletonTypePtr.
TORCH_DECLARE_SINGLETON(AnyType);
TORCH_DECLARE_SINGLETON(AnyEnumType);
TORCH_DECLARE_SINGLETON(NumberType);
TORCH_DECLARE_SINGLETON(FloatType);
TORCH_DECLARE_SINGLETON(ComplexType);
TORCH_DECLARE_SINGLETON(IntType);
TORCH_DECLARE_SINGLETON(BoolType);
TORCH_DECLARE_SINGLETON(StringType);
TORCH_DECLARE_SINGLETON(StorageType);
TORCH_DECLARE_SINGLETON(NoneType);
TORCH_DECLARE_SINGLETON(GeneratorType);
TORCH_DECLARE_SINGLETON(QuantizerType);
TORCH_DECLARE_SINGLETON(QSchemeType);
TORCH_DECLARE_SINGLETON(DeviceObjType);
TORCH_DECLARE_SINGLETON(StreamObjType);
TORCH_DECLARE_SINGLETON(CapsuleType);
TORCH_DECLARE_SINGLETON(PyObjectType);
TORCH_DECLARE_SINGLETON(ScalarTypeType);
TORCH_DECLARE_SINGLETON(LayoutType);
TORCH_DECLARE_SINGLETON(MemoryFormatType);
TORCH_DECLARE_SINGLETON(AnyListType);
TORCH_DECLARE_SINGLETON(AnyTupleType);
TORCH_DECLARE_SINGLETON(AnyClassType);
111 | |
112 | namespace detail { |
113 | template <typename T, typename Enable = void> |
114 | struct CastReturnType { |
115 | using type = std::shared_ptr<T>; |
116 | }; |
117 | |
118 | template <typename T> |
119 | struct CastReturnType<T, typename std::enable_if<IsSingletonType<T>::value>::type> { |
120 | using type = SingletonTypePtr<T>; |
121 | }; |
122 | |
123 | template <typename T, typename Enable = void> |
124 | struct CastConstReturnType { |
125 | using type = std::shared_ptr<const T>; |
126 | }; |
127 | |
128 | template <typename T> |
129 | struct CastConstReturnType<T, typename std::enable_if<IsSingletonType<T>::value>::type> { |
130 | using type = SingletonTypePtr<const T>; |
131 | }; |
132 | |
133 | template <typename T> |
134 | struct as_shared_type { |
135 | using type = SharedType*; |
136 | }; |
137 | |
138 | template <typename T> |
139 | struct as_shared_type<const T*> { |
140 | using type = const SharedType *; |
141 | }; |
142 | } // namespace detail |
143 | |
144 | struct TORCH_API Type { |
145 | friend TORCH_API bool operator==(const Type& lhs, const Type& rhs); |
146 | private: |
147 | TypeKind kind_; |
148 | |
149 | protected: |
150 | Type(TypeKind kind) : kind_(kind) {} |
151 | |
152 | virtual std::string annotation_str_impl(TypePrinter /*printer*/) const { |
153 | return str(); |
154 | } |
155 | // a == b |
156 | virtual bool equals(const Type& rhs) const = 0; |
157 | // a == b <=> b == a |
158 | virtual bool symmetric() const { |
159 | return true; |
160 | } |
161 | |
162 | public: |
163 | template <typename T> |
164 | class SingletonOrSharedTypePtr { |
165 | public: |
166 | using element_type = typename std::shared_ptr<T>::element_type; |
167 | |
168 | SingletonOrSharedTypePtr() = default; |
169 | |
170 | /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr<T> x) |
171 | : repr_(std::move(x)) {} |
172 | |
173 | template <typename U, std::enable_if_t<std::is_convertible<U*, T*>::value, bool> = true> |
174 | /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr<U> x) |
175 | : repr_(std::move(x)) {} |
176 | |
177 | /* implicit */ SingletonOrSharedTypePtr(std::nullptr_t) |
178 | : repr_(nullptr) {} |
179 | |
180 | /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p) |
181 | : repr_(p) {} |
182 | |
183 | template <typename U, std::enable_if_t<std::is_convertible<U*, T*>::value, bool> = true> |
184 | /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p) |
185 | : repr_(SingletonTypePtr<T>(p.get())) {} |
186 | |
187 | |
188 | // We need to support construction from T* for pybind. The problem |
189 | // is that it's not clear if we are supposed to be taking shared |
190 | // ownership or not. |
191 | // |
192 | // Case 1: if T is known statically to derive from SharedType, we should use |
193 | // shared_from_this() and take shared_ownership. |
194 | // |
195 | // Case 2: if T is exactly Type, we need to do a dynamic_cast to |
196 | // check if it's a SharedType and do the right thing. |
197 | // |
198 | // Case 3: Otherwise, T is not a SharedType. (debug-check this |
199 | // assumption!) Use a singleton pointer. |
200 | |
201 | template <typename U = T, std::enable_if_t<std::is_base_of<SharedType, U>::value, bool> = true> |
202 | /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {} |
203 | |
204 | template <typename U = T, std::enable_if_t<std::is_same<Type, U>::value, bool> = true> |
205 | /* implicit */ SingletonOrSharedTypePtr(T* p) { |
206 | if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) { |
207 | repr_ = Repr(shared_p->shared_from_this()); |
208 | } else { |
209 | repr_ = Repr(p); |
210 | } |
211 | } |
212 | |
213 | template <typename U = T, std::enable_if_t<!std::is_same<Type, U>::value && !std::is_base_of<SharedType, U>::value, bool> = true> |
214 | /* implicit */ SingletonOrSharedTypePtr(T* p) |
215 | : repr_(p) { |
216 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr); |
217 | } |
218 | |
219 | SingletonOrSharedTypePtr(const SingletonOrSharedTypePtr&) = default; |
220 | SingletonOrSharedTypePtr(SingletonOrSharedTypePtr&&) noexcept = default; |
221 | SingletonOrSharedTypePtr& operator=(const SingletonOrSharedTypePtr&) = default; |
222 | SingletonOrSharedTypePtr& operator=(SingletonOrSharedTypePtr&&) noexcept = default; |
223 | |
224 | T* get() const { |
225 | return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first); |
226 | } |
227 | |
228 | operator bool() const { |
229 | return repr_.isNonNull(); |
230 | } |
231 | |
232 | bool operator==(std::nullptr_t) const { |
233 | return !repr_.isNonNull(); |
234 | } |
235 | |
236 | bool operator!=(std::nullptr_t) const { |
237 | return repr_.isNonNull(); |
238 | } |
239 | |
240 | template <typename U = T, std::enable_if_t<!std::is_same<std::remove_const_t<U>, void>::value, bool> = true> |
241 | U& operator*() const { |
242 | return *get(); |
243 | } |
244 | |
245 | T* operator->() const { |
246 | return get(); |
247 | } |
248 | |
249 | private: |
250 | // NOTE: SharedPtrWrapper exists to work around a baffling bug in |
251 | // nvcc; see comment in destroy() below. |
252 | struct SharedPtrWrapper { |
253 | SharedPtrWrapper(std::shared_ptr<T> &&x) |
254 | : repr_(std::move(x)) {} |
255 | std::shared_ptr<T> repr_; |
256 | }; |
257 | union Repr { |
258 | Repr() : Repr(nullptr) {} |
259 | |
260 | explicit Repr(std::shared_ptr<T> x) |
261 | : shared_(std::move(x)) {} |
262 | |
263 | explicit Repr(std::nullptr_t) |
264 | : singletonRepr_(nullptr) {} |
265 | |
266 | explicit Repr(SingletonTypePtr<T> p) |
267 | : singletonRepr_(p.get()) {} |
268 | |
269 | ~Repr() { |
270 | destroy(); |
271 | } |
272 | |
273 | // NOTE: the only non-UB way to access our null state is through |
274 | // rawRepr(), because our copy operation doesn't preserve which |
275 | // union member is active for null pointers. |
276 | Repr(const Repr& rhs) { |
277 | if (rhs.isSharedAndNonNull()) { |
278 | new (&shared_) SharedPtrWrapper(rhs.shared_); |
279 | } else { |
280 | singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |
281 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); |
282 | singletonRepr_.unused_ = nullptr; |
283 | } |
284 | } |
285 | |
286 | Repr(Repr&& rhs) noexcept { |
287 | if (rhs.isSharedAndNonNull()) { |
288 | new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); |
289 | } else { |
290 | singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |
291 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); |
292 | singletonRepr_.unused_ = nullptr; |
293 | } |
294 | } |
295 | |
296 | Repr& operator=(const Repr& rhs) { |
297 | if (&rhs == this) { |
298 | return *this; |
299 | } |
300 | if (rhs.isSharedAndNonNull()) { |
301 | if (isSharedAndNonNull()) { |
302 | shared_ = rhs.shared_; |
303 | } else { |
304 | new (&shared_) SharedPtrWrapper(rhs.shared_); |
305 | } |
306 | } else { |
307 | if (isSharedAndNonNull()) { |
308 | destroy(); |
309 | } |
310 | singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |
311 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); |
312 | singletonRepr_.unused_ = nullptr; |
313 | } |
314 | return *this; |
315 | } |
316 | |
317 | Repr& operator=(Repr&& rhs) noexcept { |
318 | if (&rhs == this) { |
319 | return *this; |
320 | } |
321 | if (rhs.isSharedAndNonNull()) { |
322 | if (isSharedAndNonNull()) { |
323 | shared_ = std::move(rhs.shared_); |
324 | } else { |
325 | new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); |
326 | } |
327 | } else { |
328 | if (isSharedAndNonNull()) { |
329 | destroy(); |
330 | } |
331 | singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first); |
332 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); |
333 | singletonRepr_.unused_ = nullptr; |
334 | } |
335 | return *this; |
336 | } |
337 | |
338 | SharedPtrWrapper shared_; |
339 | |
340 | struct SingletonRepr { |
341 | explicit SingletonRepr(T* s) : singleton_(s) {} |
342 | T* singleton_; |
343 | void* unused_ = nullptr; |
344 | } singletonRepr_; |
345 | struct RawRepr { |
346 | void* first; |
347 | void* nullIfSingleton_; |
348 | }; |
349 | |
350 | // It is UB to read the singleton part of Repr if it was |
351 | // constructed as a shared_ptr and vice versa, but memcpying out |
352 | // the representation is always OK, so here's an accessor to obey |
353 | // the letter of the law. |
354 | RawRepr rawRepr() const { |
355 | RawRepr repr; |
356 | memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr)); |
357 | return repr; |
358 | } |
359 | |
360 | bool isNonNull() const { |
361 | auto repr = rawRepr(); |
362 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr); |
363 | return repr.first != nullptr; |
364 | } |
365 | |
366 | bool isSharedAndNonNull() const { |
367 | return rawRepr().nullIfSingleton_ != nullptr; |
368 | } |
369 | |
370 | private: |
371 | void destroy() { |
372 | if (isSharedAndNonNull()) { |
373 | // Without SharedPtrWrapper, this line would read |
374 | // `shared_.~shared_ptr()` and nvcc would complain with |
375 | // "error: expected primary-expression before '>' token" |
376 | // referring to the "t" in "shared_ptr". SharedPtrWrapper |
377 | // exists to work around this compiler bug. |
378 | shared_.~SharedPtrWrapper(); |
379 | } |
380 | } |
381 | } repr_; |
382 | }; |
383 | |
384 | using TypePtr = SingletonOrSharedTypePtr<Type>; |
385 | using Ptr = TypePtr; |
386 | using ElementType = Type; |
387 | |
388 | // subtyping relation. By default, we return true for the case |
389 | // when the type is exactly equal or if this <: T where rhs = Optional[T] |
390 | |
391 | // if this returns false and the why_not stream is non-null, it contains |
392 | // additional details that describe why this is not a subtype of 'rhs'. |
393 | // This additional information should only contain details that are not |
394 | // obvious from the annotation_str() that describes the type. For instance it |
395 | // is clear that `int <: str` is false but not clear why `Foo <: InterfaceBar` |
396 | // might be false. |
397 | virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const; |
398 | virtual bool is_module() const; |
399 | bool isSubtypeOf(const Type& rhs) const { |
400 | return isSubtypeOfExt(rhs, nullptr); |
401 | } |
402 | // Compatibility shims to accommodate existing code that passes shared_ptrs |
403 | // around. Ideally, we would just delete this, but it should be harmless. |
404 | template <typename T> |
405 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
406 | isSubtypeOf(const std::shared_ptr<T>& rhs) const { |
407 | return isSubtypeOf(*rhs); |
408 | } |
409 | |
410 | template <typename T> |
411 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
412 | isSubtypeOf(const SingletonOrSharedTypePtr<T>& rhs) const { |
413 | return isSubtypeOf(*rhs); |
414 | } |
415 | |
416 | template <typename T> |
417 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
418 | isSubtypeOf(SingletonTypePtr<T> rhs) const { |
419 | return isSubtypeOf(*rhs); |
420 | } |
421 | |
422 | template <typename T> |
423 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
424 | isSubtypeOfExt(const SingletonOrSharedTypePtr<T>& rhs, std::ostream* why_not) const { |
425 | return isSubtypeOfExt(*rhs, why_not); |
426 | } |
427 | |
428 | template <typename T> |
429 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
430 | isSubtypeOfExt(const std::shared_ptr<T>& rhs, std::ostream* why_not) const { |
431 | return isSubtypeOfExt(*rhs, why_not); |
432 | } |
433 | |
434 | template <typename T> |
435 | typename std::enable_if<std::is_base_of<Type, T>::value, bool>::type |
436 | isSubtypeOfExt(SingletonTypePtr<T> rhs, std::ostream* why_not) const { |
437 | return isSubtypeOfExt(*rhs, why_not); |
438 | } |
439 | |
440 | // How this type will appear in FunctionSchema declarations |
441 | virtual std::string str() const = 0; |
442 | |
443 | // How this type will appear as if it were a type annotation in Python |
444 | // which is sometimes different than how it appears in declarations (e.g. |
445 | // int[] vs List[int]) |
446 | // |
447 | // Takes a custom printer that users can pass in to customize the output of |
448 | // this method. |
449 | std::string annotation_str(TypePrinter printer) const { |
450 | if (printer) { |
451 | // the printer can return nullopt to fall through to the default impl |
452 | if (auto renamed = printer(*this)) { |
453 | return *renamed; |
454 | } |
455 | } |
456 | return annotation_str_impl(std::move(printer)); |
457 | } |
458 | std::string annotation_str() const { |
459 | // Overload instead of define a default value for `printer` to help |
460 | // debuggers out. |
461 | return annotation_str(nullptr); |
462 | } |
463 | |
464 | // Returns a human readable string that includes additional information like |
465 | // "type is inferred rather than explictly defined" to help construct more |
466 | // user-friendly messages. |
467 | virtual std::string repr_str() const { |
468 | return annotation_str(); |
469 | } |
470 | |
471 | TypeKind kind() const { |
472 | return kind_; |
473 | } |
474 | |
475 | virtual bool isUnionType() const { |
476 | return false; |
477 | } |
478 | |
479 | virtual bool requires_grad() const { |
480 | for (const auto& ct : containedTypes()) { |
481 | if (ct->requires_grad()) { |
482 | return true; |
483 | } |
484 | } |
485 | return false; |
486 | } |
487 | |
488 | // Dynamically cast this object to the subclass indicated by the |
489 | // template variable, returning nullptr if the cast is invalid. |
490 | template <typename T, std::enable_if_t<!detail::IsSingletonType<T>::value, bool> = true> |
491 | typename detail::CastReturnType<T>::type cast() { |
492 | if (T::Kind == kind()) { |
493 | return std::static_pointer_cast<T>(static_cast<T*>(this)->shared_from_this()); |
494 | } |
495 | return nullptr; |
496 | } |
497 | template <typename T, std::enable_if_t<detail::IsSingletonType<T>::value, bool> = true> |
498 | typename detail::CastReturnType<T>::type cast() { |
499 | if (T::Kind == kind()) { |
500 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); |
501 | return typename detail::CastReturnType<T>::type(static_cast<T*>(this)); |
502 | } |
503 | return nullptr; |
504 | } |
505 | template <typename T, std::enable_if_t<!detail::IsSingletonType<T>::value, bool> = true> |
506 | typename detail::CastConstReturnType<T>::type cast() const { |
507 | if (T::Kind == kind()) { |
508 | return std::static_pointer_cast<const T>(static_cast<const T*>(this)->shared_from_this()); |
509 | } |
510 | return nullptr; |
511 | } |
512 | template <typename T, std::enable_if_t<detail::IsSingletonType<T>::value, bool> = true> |
513 | typename detail::CastConstReturnType<T>::type cast() const { |
514 | if (T::Kind == kind()) { |
515 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); |
516 | return typename detail::CastConstReturnType<T>::type(static_cast<const T*>(this)); |
517 | } |
518 | return nullptr; |
519 | } |
520 | template <typename T> |
521 | T* castRaw() { |
522 | if (T::Kind == kind()) { |
523 | return static_cast<T*>(this); |
524 | } |
525 | return nullptr; |
526 | } |
527 | template <typename T> |
528 | const T* castRaw() const { |
529 | if (T::Kind == kind()) { |
530 | return static_cast<const T*>(this); |
531 | } |
532 | return nullptr; |
533 | } |
534 | template <typename T> |
535 | auto expect() { |
536 | auto r = cast<T>(); |
537 | AT_ASSERT(r); |
538 | return r; |
539 | } |
540 | template <typename T> |
541 | auto expect() const { |
542 | auto r = cast<const T>(); |
543 | AT_ASSERT(r); |
544 | return r; |
545 | } |
546 | template <typename T> |
547 | T& expectRef() { |
548 | auto* r = castRaw<T>(); |
549 | AT_ASSERT(r); |
550 | return *r; |
551 | } |
552 | template <typename T> |
553 | const T& expectRef() const { |
554 | auto* r = castRaw<const T>(); |
555 | AT_ASSERT(r); |
556 | return *r; |
557 | } |
558 | virtual ~Type() = default; |
559 | virtual bool hasFreeVariables() const { |
560 | return false; |
561 | } |
562 | // list of types this type contains, e.g. for a List then element type of a |
563 | // list for a tuple, the types of the tuple elements |
564 | virtual at::ArrayRef<TypePtr> containedTypes() const { |
565 | return {}; |
566 | } |
567 | virtual TypePtr containedType(size_t i) const { |
568 | return containedTypes().at(i); |
569 | } |
570 | virtual size_t containedTypeSize() const { |
571 | return containedTypes().size(); |
572 | } |
573 | // create a new version of this type, replacing its contained types with |
574 | // contained_types |
575 | TypePtr withContained(std::vector<TypePtr> contained_types); |
576 | // per-type constructor, you only need to override this if the |
577 | // containedTypes() is not empty |
578 | virtual TypePtr createWithContained( |
579 | std::vector<TypePtr> /*contained_types*/) const { |
580 | AT_ERROR( |
581 | "type with contained types did not overload createWithContained: " , |
582 | str()); |
583 | } |
584 | |
585 | }; |
586 | |
// Namespace-scope shorthand for the pointer template nested inside Type.
template <typename T>
using SingletonOrSharedTypePtr = Type::SingletonOrSharedTypePtr<T>;
589 | |
590 | |
591 | template <typename T, typename U> |
592 | bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
593 | return (void*)x.get() == (void*)y.get(); |
594 | } |
595 | |
596 | template <typename T, typename U> |
597 | bool operator==(const SingletonOrSharedTypePtr<T>& x, const std::shared_ptr<U>& y) { |
598 | return (void*)x.get() == (void*)y.get(); |
599 | } |
600 | |
601 | template <typename T, typename U> |
602 | bool operator==(const std::shared_ptr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
603 | return (void*)x.get() == (void*)y.get(); |
604 | } |
605 | |
606 | template <typename T, typename U> |
607 | bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonTypePtr<U>& y) { |
608 | return (void*)x.get() == (void*)y.get(); |
609 | } |
610 | |
611 | template <typename T, typename U> |
612 | bool operator==(const SingletonTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
613 | return (void*)x.get() == (void*)y.get(); |
614 | } |
615 | |
616 | template <typename T, typename U> |
617 | bool operator!=(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
618 | return !(x == y); |
619 | } |
620 | |
621 | template <typename T, typename U> |
622 | bool operator!=(const SingletonOrSharedTypePtr<T>& x, const std::shared_ptr<U>& y) { |
623 | return !(x == y); |
624 | } |
625 | |
626 | template <typename T, typename U> |
627 | bool operator!=(const std::shared_ptr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
628 | return !(x == y); |
629 | } |
630 | |
631 | template <typename T, typename U> |
632 | bool operator!=(const SingletonOrSharedTypePtr<T>& x, const SingletonTypePtr<U>& y) { |
633 | return !(x == y); |
634 | } |
635 | |
636 | template <typename T, typename U> |
637 | bool operator!=(const SingletonTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) { |
638 | return !(x == y); |
639 | } |
640 | |
// Namespace-scope aliases for the two most common pointer instantiations.
using TypePtr = SingletonOrSharedTypePtr<Type>;
using ConstTypePtr = SingletonOrSharedTypePtr<const Type>;

// Explicitly enable MaybeOwned<SingletonOrSharedTypePtr<T>>, rather than
// allowing MaybeOwned to be used for any type right away.
template <typename T>
struct MaybeOwnedTraits<SingletonOrSharedTypePtr<T>>
    : public MaybeOwnedTraitsGenericImpl<SingletonOrSharedTypePtr<T>> {};
649 | |
// Base class for Types that are guaranteed to be owned by std::shared_ptr.
// Deriving from enable_shared_from_this<SharedType> is what provides the
// shared_from_this() calls used by Type::cast, Type::withContained and the
// raw-pointer constructors of SingletonOrSharedTypePtr.
struct TORCH_API SharedType : public Type, public std::enable_shared_from_this<SharedType> {
  // Inherit Type's constructors unchanged.
  using Type::Type;
};
654 | |
655 | inline TypePtr Type::withContained(std::vector<TypePtr> contained_types) { |
656 | auto current_contained = containedTypes(); |
657 | // Types with no contained_types don't need this call. Check before calling! |
658 | // |
659 | // (We can't support this efficiently because types without |
660 | // contained types may be singletons, in which case |
661 | // shared_from_this will crash; we would have to provide a virtual |
662 | // typeptr_from_this or isSingleton.) |
663 | TORCH_INTERNAL_ASSERT(!current_contained.empty() && current_contained.size() == contained_types.size()); |
664 | if (current_contained.equals(contained_types)) { |
665 | return std::static_pointer_cast<Type>(static_cast<SharedType *>(this)->shared_from_this()); |
666 | } |
667 | return createWithContained(std::move(contained_types)); |
668 | } |
669 | |
670 | |
671 | TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) { |
672 | if (C10_UNLIKELY(!rhs.symmetric())) { |
673 | return rhs.equals(lhs); |
674 | } |
675 | return lhs.equals(rhs); |
676 | } |
677 | |
struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;

// A SharedType that may carry a qualified name. The constructor only accepts
// the kinds Tuple, Function, Class, Interface and Enum.
struct TORCH_API NamedType : public SharedType {
  // `tk` is forwarded to SharedType; `name` is moved into name_.
  // Fails an internal assert for any kind not listed in the condition below.
  // NOTE(review): the message refers to a cast<NamedType> specialization that
  // is not visible in this header -- confirm where it lives when adding kinds.
  NamedType(TypeKind tk, c10::optional<QualifiedName> name)
      : SharedType(tk), name_(std::move(name)) {
    TORCH_INTERNAL_ASSERT(
        tk == TypeKind::TupleType || tk == TypeKind::FunctionType ||
        tk == TypeKind::ClassType || tk == TypeKind::InterfaceType ||
        tk == TypeKind::EnumType,
        "If you add a new kind of NamedType, ",
        "please update the cast<NamedType> specialization and this assert");
  }

  // Fully qualified name of type
  // Looks like: "foo.bar.Baz".
  // Empty optional when no name was supplied at construction.
  const c10::optional<QualifiedName>& name() const {
    return name_;
  }

 private:
  c10::optional<QualifiedName> name_;
};
702 | |
703 | } // namespace c10 |
704 | |
705 | namespace std { |
706 | template <typename T> |
707 | struct hash<c10::SingletonOrSharedTypePtr<T>> { |
708 | size_t operator()(const c10::SingletonOrSharedTypePtr<T>& x) const { |
709 | return std::hash<T*>()(x.get()); |
710 | } |
711 | }; |
712 | } // namespace std |
713 | |