1#pragma once
2
3/**
4 * This file contains functionality to take a C++ function and infer its
5 * c10::FunctionSchema.
6 */
7
8#include <ATen/core/function_schema.h>
9#include <c10/util/C++17.h>
10#include <c10/util/Metaprogramming.h>
11
12namespace c10 {
13namespace detail {
14
15namespace infer_schema {
16
17/// The templated inference code creates `ArgumentDef` instead of `Argument`,
18/// because that can be constructed at compile time and has a much smaller
19/// binary size than having calls to `Argument` constructors in the template.
20/// Creating `Argument` objects from `ArgumentDef` can then be done at
21/// runtime in a non-templated way.
22struct ArgumentDef final {
23 using GetTypeFn = TypePtr();
24 GetTypeFn* getTypeFn;
25 GetTypeFn* getFakeTypeFn;
26 constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {}
27 explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {}
28};
29
/// Lifts a compile-time boolean into `std::true_type` / `std::false_type`, so
/// it can be fed to a type-level conjunction. Both possible values are
/// explicitly specialized; the primary template is left empty.
template <bool Flag>
struct bool_t {};

template <>
struct bool_t<false> : public std::false_type {};

template <>
struct bool_t<true> : public std::true_type {};
34
35/// Checks the static C++ types `Types` for correctness to catch common error cases.
36template <class... Types>
37constexpr int checkStaticTypes() {
38 // Give nice error messages for some of the common error cases.
39 // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
40 static_assert(guts::conjunction<
41 bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
42 >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
43 static_assert(guts::conjunction<
44 bool_t<!std::is_same<Types, float>::value>...
45 >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
46 return 0;
47}
48
49template <typename... Ts, size_t... Is>
50constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...>) {
51 return (
52 // Check types for common errors
53 checkStaticTypes<Ts...>(),
54
55 // Create the return value
56 std::array<ArgumentDef, sizeof...(Ts)>{
57 ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...}
58 );
59}
60
61/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
62/// as template arguments.
63template<class ParameterTypes> struct createArguments final {};
64template<class... ParameterTypes>
65struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
66 static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() {
67 return createArgumentVectorFromTypes<ParameterTypes...>(
68 std::make_index_sequence<sizeof...(ParameterTypes)>()
69 );
70 }
71};
72
73/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
74/// as a tuple (i.e. in the way c10 kernels return values).
75/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
76/// It can be an empty tuple<>, or void for kernels that don't return anything.
77/// It can be a single type A (i.e. no tuple) for the case where a kernel just
78/// returns one value.
79template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
80
81template<class... ReturnTypes>
82struct createReturns<std::tuple<ReturnTypes...>, void> final {
83 static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() {
84 return createArgumentVectorFromTypes<ReturnTypes...>(
85 std::make_index_sequence<sizeof...(ReturnTypes)>()
86 );
87 }
88};
89
90template<class ReturnType>
91struct createReturns<ReturnType, std::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
92 static constexpr std::array<ArgumentDef, 1> call() {
93 return createReturns<std::tuple<ReturnType>>::call();
94 }
95};
96
97template<>
98struct createReturns<void, void> final {
99 static constexpr std::array<ArgumentDef, 0> call() {
100 return createReturns<std::tuple<>>::call();
101 }
102};
103
104template <typename ReturnType>
105struct createSingleReturn {
106 static constexpr std::array<ArgumentDef, 1> call() {
107 return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
108 }
109};
110
111C10_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
112C10_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
113
114/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
115/// function. Flattens std::tuple returns into multiple return types
116template <typename FunctionTraits>
117FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() {
118 using ReturnType = typename FunctionTraits::return_type;
119 using ParameterTypes = typename FunctionTraits::parameter_types;
120
121 // arguments and returns are computed into a std::array at compile time and embedded into the binary.
122 // The only code executed at runtime here is the one that creates a std::vector
123 // of the arguments/returns from the std::array.
124 constexpr auto arguments = createArguments<ParameterTypes>::call();
125 constexpr auto returns = createReturns<ReturnType>::call();
126
127 return make_function_schema(arguments, returns);
128}
129
130/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
131/// function. Preserves std::tuple returns as a Tuple return type
132template <typename FunctionTraits>
133FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
134 using ReturnType = typename FunctionTraits::return_type;
135 using ParameterTypes = typename FunctionTraits::parameter_types;
136
137 // arguments and returns are computed into a std::array at compile time and embedded into the binary.
138 // The only code executed at runtime here is the one that creates a std::vector
139 // of the arguments/returns from the std::array.
140 constexpr auto arguments = createArguments<ParameterTypes>::call();
141 constexpr auto returns = createSingleReturn<ReturnType>::call();
142
143 return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
144}
145
146}
147}
148
149template<class FuncType>
150FunctionSchema inferFunctionSchemaFlattenedReturns() {
151 return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>();
152}
153
154template<class FuncType>
155FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
156 return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
157}
158
159TORCH_API c10::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
160
161}
162