1 | #pragma once |
2 | |
3 | #include <type_traits> |
4 | #include <unordered_map> |
5 | |
6 | #include <ATen/core/dynamic_type.h> |
7 | #include <ATen/core/jit_type_base.h> |
8 | #include <c10/macros/Macros.h> |
9 | |
10 | namespace c10 { |
11 | |
// Primary template of the type factory. It is intentionally empty: the
// factory functions are provided only by the explicit specializations for
// c10::DynamicType and c10::Type.
template <class T>
struct TORCH_API TypeFactoryBase {
};
14 | |
15 | template <> |
16 | struct TORCH_API TypeFactoryBase<c10::DynamicType> { |
17 | template <typename T, typename... Args> |
18 | static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) { |
19 | return std::make_shared<c10::DynamicType>( |
20 | c10::DynamicTypeTrait<T>::tagValue(), |
21 | c10::DynamicType::Arguments(c10::ArrayRef<c10::TypePtr>( |
22 | {std::move(ty), std::forward<Args>(args)...}))); |
23 | } |
24 | template <typename T> |
25 | static c10::DynamicTypePtr create(std::vector<c10::TypePtr> types) { |
26 | return std::make_shared<c10::DynamicType>( |
27 | c10::DynamicTypeTrait<T>::tagValue(), |
28 | c10::DynamicType::Arguments(types)); |
29 | } |
30 | static c10::DynamicTypePtr createNamedTuple( |
31 | const std::string& name, |
32 | const std::vector<c10::string_view>& fields, |
33 | const std::vector<c10::TypePtr>& types) { |
34 | return std::make_shared<c10::DynamicType>( |
35 | c10::DynamicType::Tag::Tuple, |
36 | name, |
37 | c10::DynamicType::Arguments(fields, types)); |
38 | } |
39 | template <typename T> |
40 | C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) { |
41 | return std::make_shared<c10::DynamicType>( |
42 | c10::DynamicTypeTrait<T>::tagValue(), |
43 | name, |
44 | c10::DynamicType::Arguments{}); |
45 | } |
46 | template <typename T> |
47 | C10_ERASE static c10::DynamicTypePtr get() { |
48 | return DynamicTypeTrait<T>::getBaseType(); |
49 | } |
50 | static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes(); |
51 | }; |
52 | |
53 | using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>; |
54 | |
55 | // Helper functions for constructing DynamicTypes inline. |
56 | template < |
57 | typename T, |
58 | std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0> |
59 | C10_ERASE DynamicTypePtr dynT() { |
60 | return DynamicTypeFactory::get<T>(); |
61 | } |
62 | |
63 | template < |
64 | typename T, |
65 | typename... Args, |
66 | std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0> |
67 | C10_ERASE DynamicTypePtr dynT(Args&&... args) { |
68 | return DynamicTypeFactory::create<T>(std::forward<Args>(args)...); |
69 | } |
70 | |
71 | template <> |
72 | struct TORCH_API TypeFactoryBase<c10::Type> { |
73 | template <typename T, typename... Args> |
74 | static c10::TypePtr create(TypePtr ty, Args&&... args) { |
75 | return T::create(std::move(ty), std::forward<Args>(args)...); |
76 | } |
77 | template <typename T> |
78 | static c10::TypePtr create(std::vector<c10::TypePtr> types) { |
79 | return T::create(std::move(types)); |
80 | } |
81 | static c10::TypePtr createNamedTuple( |
82 | const std::string& name, |
83 | const std::vector<c10::string_view>& fields, |
84 | const std::vector<c10::TypePtr>& types); |
85 | template <typename T> |
86 | C10_ERASE static c10::TypePtr createNamed(const std::string& name) { |
87 | return T::create(name); |
88 | } |
89 | static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes(); |
90 | template <typename T> |
91 | C10_ERASE static c10::TypePtr get() { |
92 | return T::get(); |
93 | } |
94 | }; |
95 | |
96 | using DefaultTypeFactory = TypeFactoryBase<c10::Type>; |
97 | |
98 | using PlatformType = |
99 | #ifdef C10_MOBILE |
100 | c10::DynamicType |
101 | #else |
102 | c10::Type |
103 | #endif |
104 | ; |
105 | |
106 | using TypeFactory = TypeFactoryBase<PlatformType>; |
107 | |
108 | } // namespace c10 |
109 | |