1 | #pragma once |
2 | |
3 | /** |
4 | * This file contains functionality to take a C++ function and infer its |
5 | * c10::FunctionSchema. |
6 | */ |
7 | |
8 | #include <ATen/core/function_schema.h> |
9 | #include <c10/util/C++17.h> |
10 | #include <c10/util/Metaprogramming.h> |
11 | |
12 | namespace c10 { |
13 | namespace detail { |
14 | |
15 | namespace infer_schema { |
16 | |
17 | /// The templated inference code creates `ArgumentDef` instead of `Argument`, |
18 | /// because that can be constructed at compile time and has a much smaller |
19 | /// binary size than having calls to `Argument` constructors in the template. |
20 | /// Creating `Argument` objects from `ArgumentDef` can then be done at |
21 | /// runtime in a non-templated way. |
22 | struct ArgumentDef final { |
23 | using GetTypeFn = TypePtr(); |
24 | GetTypeFn* getTypeFn; |
25 | GetTypeFn* getFakeTypeFn; |
26 | constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} |
27 | explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} |
28 | }; |
29 | |
/// Lifts a compile-time boolean into a type whose `::value` is `V`.
/// bool_t<true> derives from std::true_type and bool_t<false> from
/// std::false_type, so it can be fed to type-level conjunctions.
template <bool V>
struct bool_t : std::integral_constant<bool, V> {};
34 | |
35 | /// Checks the static C++ types `Types` for correctness to catch common error cases. |
36 | template <class... Types> |
37 | constexpr int checkStaticTypes() { |
38 | // Give nice error messages for some of the common error cases. |
39 | // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT |
40 | static_assert(guts::conjunction< |
41 | bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>... |
42 | >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type" ); |
43 | static_assert(guts::conjunction< |
44 | bool_t<!std::is_same<Types, float>::value>... |
45 | >::value, "INVALID TYPE: float is not supported as an argument type, use double instead" ); |
46 | return 0; |
47 | } |
48 | |
49 | template <typename... Ts, size_t... Is> |
50 | constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...>) { |
51 | return ( |
52 | // Check types for common errors |
53 | checkStaticTypes<Ts...>(), |
54 | |
55 | // Create the return value |
56 | std::array<ArgumentDef, sizeof...(Ts)>{ |
57 | ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...} |
58 | ); |
59 | } |
60 | |
61 | /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified |
62 | /// as template arguments. |
63 | template<class ParameterTypes> struct createArguments final {}; |
64 | template<class... ParameterTypes> |
65 | struct createArguments<guts::typelist::typelist<ParameterTypes...>> final { |
66 | static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() { |
67 | return createArgumentVectorFromTypes<ParameterTypes...>( |
68 | std::make_index_sequence<sizeof...(ParameterTypes)>() |
69 | ); |
70 | } |
71 | }; |
72 | |
73 | /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified |
74 | /// as a tuple (i.e. in the way c10 kernels return values). |
75 | /// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C. |
76 | /// It can be an empty tuple<>, or void for kernels that don't return anything. |
77 | /// It can be a single type A (i.e. no tuple) for the case where a kernel just |
78 | /// returns one value. |
79 | template<class ReturnTypeTuple, class Enable = void> struct createReturns final {}; |
80 | |
81 | template<class... ReturnTypes> |
82 | struct createReturns<std::tuple<ReturnTypes...>, void> final { |
83 | static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() { |
84 | return createArgumentVectorFromTypes<ReturnTypes...>( |
85 | std::make_index_sequence<sizeof...(ReturnTypes)>() |
86 | ); |
87 | } |
88 | }; |
89 | |
90 | template<class ReturnType> |
91 | struct createReturns<ReturnType, std::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final { |
92 | static constexpr std::array<ArgumentDef, 1> call() { |
93 | return createReturns<std::tuple<ReturnType>>::call(); |
94 | } |
95 | }; |
96 | |
97 | template<> |
98 | struct createReturns<void, void> final { |
99 | static constexpr std::array<ArgumentDef, 0> call() { |
100 | return createReturns<std::tuple<>>::call(); |
101 | } |
102 | }; |
103 | |
104 | template <typename ReturnType> |
105 | struct createSingleReturn { |
106 | static constexpr std::array<ArgumentDef, 1> call() { |
107 | return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>()); |
108 | } |
109 | }; |
110 | |
111 | C10_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns); |
112 | C10_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns); |
113 | |
114 | /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a |
115 | /// function. Flattens std::tuple returns into multiple return types |
116 | template <typename FunctionTraits> |
117 | FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() { |
118 | using ReturnType = typename FunctionTraits::return_type; |
119 | using ParameterTypes = typename FunctionTraits::parameter_types; |
120 | |
121 | // arguments and returns are computed into a std::array at compile time and embedded into the binary. |
122 | // The only code executed at runtime here is the one that creates a std::vector |
123 | // of the arguments/returns from the std::array. |
124 | constexpr auto arguments = createArguments<ParameterTypes>::call(); |
125 | constexpr auto returns = createReturns<ReturnType>::call(); |
126 | |
127 | return make_function_schema(arguments, returns); |
128 | } |
129 | |
130 | /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a |
131 | /// function. Preserves std::tuple returns as a Tuple return type |
132 | template <typename FunctionTraits> |
133 | FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) { |
134 | using ReturnType = typename FunctionTraits::return_type; |
135 | using ParameterTypes = typename FunctionTraits::parameter_types; |
136 | |
137 | // arguments and returns are computed into a std::array at compile time and embedded into the binary. |
138 | // The only code executed at runtime here is the one that creates a std::vector |
139 | // of the arguments/returns from the std::array. |
140 | constexpr auto arguments = createArguments<ParameterTypes>::call(); |
141 | constexpr auto returns = createSingleReturn<ReturnType>::call(); |
142 | |
143 | return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); |
144 | } |
145 | |
146 | } |
147 | } |
148 | |
149 | template<class FuncType> |
150 | FunctionSchema inferFunctionSchemaFlattenedReturns() { |
151 | return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>(); |
152 | } |
153 | |
154 | template<class FuncType> |
155 | FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) { |
156 | return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name)); |
157 | } |
158 | |
159 | TORCH_API c10::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); |
160 | |
161 | } |
162 | |