1#pragma once
2
3#include <type_traits>
4#include <unordered_map>
5
6#include <ATen/core/dynamic_type.h>
7#include <ATen/core/jit_type_base.h>
8#include <c10/macros/Macros.h>
9
10namespace c10 {
11
/// Primary template for the type factories. It is intentionally empty: only
/// the explicit specializations (for c10::DynamicType and c10::Type) provide
/// factory functions, so using it with any other T yields an empty type.
template <class T>
struct TORCH_API TypeFactoryBase {};
14
15template <>
16struct TORCH_API TypeFactoryBase<c10::DynamicType> {
17 template <typename T, typename... Args>
18 static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
19 return std::make_shared<c10::DynamicType>(
20 c10::DynamicTypeTrait<T>::tagValue(),
21 c10::DynamicType::Arguments(c10::ArrayRef<c10::TypePtr>(
22 {std::move(ty), std::forward<Args>(args)...})));
23 }
24 template <typename T>
25 static c10::DynamicTypePtr create(std::vector<c10::TypePtr> types) {
26 return std::make_shared<c10::DynamicType>(
27 c10::DynamicTypeTrait<T>::tagValue(),
28 c10::DynamicType::Arguments(types));
29 }
30 static c10::DynamicTypePtr createNamedTuple(
31 const std::string& name,
32 const std::vector<c10::string_view>& fields,
33 const std::vector<c10::TypePtr>& types) {
34 return std::make_shared<c10::DynamicType>(
35 c10::DynamicType::Tag::Tuple,
36 name,
37 c10::DynamicType::Arguments(fields, types));
38 }
39 template <typename T>
40 C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
41 return std::make_shared<c10::DynamicType>(
42 c10::DynamicTypeTrait<T>::tagValue(),
43 name,
44 c10::DynamicType::Arguments{});
45 }
46 template <typename T>
47 C10_ERASE static c10::DynamicTypePtr get() {
48 return DynamicTypeTrait<T>::getBaseType();
49 }
50 static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
51};
52
53using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
54
55// Helper functions for constructing DynamicTypes inline.
56template <
57 typename T,
58 std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
59C10_ERASE DynamicTypePtr dynT() {
60 return DynamicTypeFactory::get<T>();
61}
62
63template <
64 typename T,
65 typename... Args,
66 std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
67C10_ERASE DynamicTypePtr dynT(Args&&... args) {
68 return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
69}
70
71template <>
72struct TORCH_API TypeFactoryBase<c10::Type> {
73 template <typename T, typename... Args>
74 static c10::TypePtr create(TypePtr ty, Args&&... args) {
75 return T::create(std::move(ty), std::forward<Args>(args)...);
76 }
77 template <typename T>
78 static c10::TypePtr create(std::vector<c10::TypePtr> types) {
79 return T::create(std::move(types));
80 }
81 static c10::TypePtr createNamedTuple(
82 const std::string& name,
83 const std::vector<c10::string_view>& fields,
84 const std::vector<c10::TypePtr>& types);
85 template <typename T>
86 C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
87 return T::create(name);
88 }
89 static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
90 template <typename T>
91 C10_ERASE static c10::TypePtr get() {
92 return T::get();
93 }
94};
95
96using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
97
98using PlatformType =
99#ifdef C10_MOBILE
100 c10::DynamicType
101#else
102 c10::Type
103#endif
104 ;
105
106using TypeFactory = TypeFactoryBase<PlatformType>;
107
108} // namespace c10
109