1#pragma once
2
3#include <c10/util/Array.h>
4#include <c10/util/TypeList.h>
5#include <array>
6#include <functional>
7#include <type_traits>
8
9namespace c10 {
10namespace guts {
11
12/**
13 * Access information about result type or arguments from a function type.
14 * Example:
15 * using A = function_traits<int (float, double)>::return_type // A == int
16 * using A = function_traits<int (float, double)>::parameter_types::tuple_type
17 * // A == tuple<float, double>
18 */
19template <class Func>
20struct function_traits {
21 static_assert(
22 !std::is_same<Func, Func>::value,
23 "In function_traits<Func>, Func must be a plain function type.");
24};
25template <class Result, class... Args>
26struct function_traits<Result(Args...)> {
27 using func_type = Result(Args...);
28 using return_type = Result;
29 using parameter_types = typelist::typelist<Args...>;
30 static constexpr auto number_of_parameters = sizeof...(Args);
31};
32
33/**
34 * infer_function_traits: creates a `function_traits` type for a simple
35 * function (pointer) or functor (lambda/struct). Currently does not support
36 * class methods.
37 */
38
39template <typename Functor>
40struct infer_function_traits {
41 using type = function_traits<
42 c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
43};
44
45template <typename Result, typename... Args>
46struct infer_function_traits<Result (*)(Args...)> {
47 using type = function_traits<Result(Args...)>;
48};
49
50template <typename Result, typename... Args>
51struct infer_function_traits<Result(Args...)> {
52 using type = function_traits<Result(Args...)>;
53};
54
55template <typename T>
56using infer_function_traits_t = typename infer_function_traits<T>::type;
57
58/**
59 * make_function_traits: creates a `function_traits` type given a Return type
60 * and a typelist of Argument types
61 *
62 * Example:
63 * bool f(int, int);
64 *
65 * infer_function_traits_t<f> == make_function_traits_t<bool,
66 * typelist::typelist<int, int>>
67 */
68template <typename Result, typename ArgList>
69struct make_function_traits {
70 static_assert(
71 false_t<ArgList>::value,
72 "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
73};
74
75template <typename Result, typename... Args>
76struct make_function_traits<Result, typelist::typelist<Args...>> {
77 using type = function_traits<Result(Args...)>;
78};
79
80template <typename Result, typename ArgList>
81using make_function_traits_t =
82 typename make_function_traits<Result, ArgList>::type;
83
/**
 * Use extract_arg_by_filtered_index to return the i-th argument whose
 * type fulfills a given type trait. The argument itself is perfectly forwarded.
 *
 * Example:
 * std::string arg1 = "Hello";
 * std::string arg2 = "World";
 * std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0,
 * arg1, 2.0, std::move(arg2));
 *
 * Note: Condition is applied to the deduced argument types, which are lvalue
 * reference types for lvalue arguments (std::string& for arg1 above), so the
 * trait must accept reference types for lvalues to be matched.
 *
 * Warning: Taking the result by rvalue reference can cause crashes because
 * ownership is not passed on from the original reference. If the selected
 * argument is a temporary, it is destroyed at the end of the full-expression
 * containing the call, and the returned reference then dangles.
 */
namespace detail {
// Recursive worker behind extract_arg_by_filtered_index. The Enable slot is
// always passed as void by callers; each partial specialization below uses it
// to select, via enable_if on the first remaining argument type, whether that
// argument is returned, skipped, or whether the search ran out of arguments.
template <
    template <class>
    class Condition,
    size_t index,
    class Enable,
    class... Args>
struct extract_arg_by_filtered_index_;

// No arguments remain: the requested match does not exist.
template <template <class> class Condition, size_t index>
struct extract_arg_by_filtered_index_<Condition, index, void> {
  static void call() {
    // `index != index` is always false but dependent, so this only fires
    // when call() is actually instantiated.
    static_assert(
        index != index, "extract_arg_by_filtered_index out of range.");
  }
};

// The first argument matches and it is the requested one: forward it.
template <
    template <class>
    class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<Condition<First>::value && index == 0>,
    First,
    Rest...> {
  static decltype(auto) call(First&& first, Rest&&... /*rest*/) {
    return std::forward<First>(first);
  }
};

// The first argument matches but a later match is requested: consume this
// match (index - 1) and keep searching the remaining arguments.
template <
    template <class>
    class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<Condition<First>::value && index != 0>,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next = extract_arg_by_filtered_index_<
        Condition,
        index - 1,
        void,
        Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};

// The first argument does not match: drop it without touching the index.
template <
    template <class>
    class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<!Condition<First>::value>,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next =
        extract_arg_by_filtered_index_<Condition, index, void, Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};
} // namespace detail
164template <template <class> class Condition, size_t index, class... Args>
165decltype(auto) extract_arg_by_filtered_index(Args&&... args) {
166 static_assert(
167 is_type_condition<Condition>::value,
168 "In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
169 return detail::
170 extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(
171 std::forward<Args>(args)...);
172}
173
174/**
175 * Use filter_map to map a subset of the arguments to values.
176 * The subset is defined by type traits, and will be evaluated at compile time.
177 * At runtime, it will just loop over the pre-filtered arguments to create an
178 * std::array.
179 *
180 * Example:
181 * std::array<double, 2> result = filter_map<double, std::is_integral>([] (auto
182 * a) {return (double)a;}, 3, "bla", 4);
183 * // result == {3.0, 4.0}
184 */
185namespace detail {
186
187template <class ResultType, size_t num_results>
188struct filter_map_ {
189 template <
190 template <class>
191 class Condition,
192 class Mapper,
193 class... Args,
194 size_t... INDEX>
195 static guts::array<ResultType, num_results> call(
196 const Mapper& mapper,
197 std::index_sequence<INDEX...>,
198 Args&&... args) {
199 return guts::array<ResultType, num_results>{
200 mapper(extract_arg_by_filtered_index<Condition, INDEX>(
201 std::forward<Args>(args)...))...};
202 }
203};
204template <class ResultType>
205struct filter_map_<ResultType, 0> {
206 template <
207 template <class>
208 class Condition,
209 class Mapper,
210 class... Args,
211 size_t... INDEX>
212 static guts::array<ResultType, 0> call(
213 const Mapper& /*mapper*/,
214 std::index_sequence<INDEX...>,
215 Args&&... /*args*/) {
216 return guts::array<ResultType, 0>{};
217 }
218};
219} // namespace detail
220
221template <
222 class ResultType,
223 template <class>
224 class Condition,
225 class Mapper,
226 class... Args>
227decltype(auto) filter_map(const Mapper& mapper, Args&&... args) {
228 static_assert(
229 is_type_condition<Condition>::value,
230 "In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
231
232 static constexpr size_t num_results =
233 typelist::count_if<Condition, typelist::typelist<Args...>>::value;
234 return detail::filter_map_<ResultType, num_results>::
235 template call<Condition, Mapper, Args...>(
236 mapper,
237 std::make_index_sequence<num_results>(),
238 std::forward<Args>(args)...);
239}
240
/**
 * make_offset_index_sequence<Start, N>
 * Like make_index_sequence<N>, but starting from Start instead of 0.
 *
 * Example:
 *  make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
 *
 * The sequence is derived from std::make_index_sequence<N> by adding Start to
 * every element. This avoids the O(N)-deep template recursion of a
 * hand-rolled generator, so large N does not run into the compiler's
 * template instantiation depth limit.
 */
namespace detail {
// Adds Start to every element of an index_sequence, then appends Tail...
template <size_t Start, class Offsets, size_t... Tail>
struct offset_index_sequence_;

template <size_t Start, size_t... Offsets, size_t... Tail>
struct offset_index_sequence_<Start, std::index_sequence<Offsets...>, Tail...> {
  using type = std::index_sequence<(Start + Offsets)..., Tail...>;
};
} // namespace detail

// make_offset_index_sequence_impl<Start, N, Is...>::type is
// std::index_sequence<Start, ..., Start + N - 1, Is...>.
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl {
  static_assert(
      static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
  // If N looks negative (the assertion above already failed), generate an
  // empty range instead of asking the compiler for an enormous sequence.
  using type = typename detail::offset_index_sequence_<
      Start,
      std::make_index_sequence<(static_cast<int>(N) >= 0 ? N : 0)>,
      Is...>::type;
};

template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
265
/**
 * tuple_elements(t, std::index_sequence<Is...>()) builds a new tuple that
 * holds the elements of `t` at positions Is..., in that order. The element
 * types of the result are exactly std::tuple_element_t<Is, Tuple>.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, double> result =
 *      tuple_elements(t, std::index_sequence<0, 2>());
 */
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  using selected_tuple = std::tuple<std::tuple_element_t<Is, Tuple>...>;
  return selected_tuple(std::get<Is>(t)...);
}
279
280/**
281 * Use tuple_take to extract the first or last n elements from the argument
282 * tuple into a result tuple.
283 *
284 * Example:
285 * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
286 * std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
287 * std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
288 */
289template <class Tuple, int N, class Enable = void>
290struct TupleTake {};
291
292template <class Tuple, int N>
293struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
294 static auto call(Tuple t) {
295 constexpr size_t size = std::tuple_size<Tuple>();
296 static_assert(N <= size, "tuple_take: N > size");
297 return tuple_elements(t, std::make_index_sequence<N>{});
298 }
299};
300
301template <class Tuple, int N>
302 struct TupleTake < Tuple,
303 N, std::enable_if_t<N<0, void>> {
304 static auto call(Tuple t) {
305 constexpr size_t size = std::tuple_size<Tuple>();
306 static_assert(-N <= size, "tuple_take: -N > size");
307 return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
308 }
309};
310
311template <class Tuple, int N>
312auto tuple_take(Tuple t) {
313 return TupleTake<Tuple, N>::call(t);
314}
315
316/**
317 * Use tuple_slice to extract a contiguous subtuple from the argument.
318 *
319 * Example:
320 * std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
321 * "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
322 * tuple_slice<decltype(t), 1, 2>(t);
323 */
324template <class Tuple, size_t Start, size_t N>
325constexpr auto tuple_slice(Tuple t) {
326 constexpr size_t size = std::tuple_size<Tuple>();
327 static_assert(Start + N <= size, "tuple_slice: Start + N > size");
328 return tuple_elements(t, make_offset_index_sequence<Start, N>{});
329}
330
331/**
332 * Use tuple_map to run a mapping function over a tuple to get a new tuple.
333 *
334 * Example 1:
335 * auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
336 * (int32_t a) -> int16_t {return a+1;});
337 * // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
338 *
339 * Example 2:
340 * struct Mapper {
341 * std::string operator()(int32_t a) const {
342 * return std::to_string(a);
343 * }
344 * int64_t operator()(const std::string& a) const {
345 * return atoi(a.c_str());
346 * }
347 * };
348 * auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
349 * Mapper());
350 * // result == std::tuple<std::string, int64_t>("3", 4)
351 *
352 * Example 3:
353 * struct A final {
354 * int32_t func() {
355 * return 5;
356 * }
357 * };
358 * struct B final {
359 * std::string func() {
360 * return "5";
361 * }
362 * };
363 * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
364 * a.func(); });
365 * // result == std::tuple<int32_t, std::string>(5, "5");
366 */
namespace detail {
// Calls `mapper` on each element of `tuple` (forwarded as its declared type
// Args, i.e. as an rvalue for non-reference elements) and collects the
// results, in order, into a tuple of the mapper's exact return types.
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  using mapped_tuple = std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>;
  return mapped_tuple(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail

// Maps every element of an rvalue tuple through `mapper`; see detail above.
template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  using all_indices = std::index_sequence_for<Args...>;
  return detail::tuple_map(std::move(tuple), mapper, all_indices());
}
383
384/**
385 * tuple_concat concatenates several tuples into one.
386 */
387
namespace detail {
// extract_tuple_element_by_index is a helper that takes a list of tuples and
// extracts the i-th element in a flattened view of the tuples. Example:
// extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5.
//
// The tuples may be passed as lvalues or rvalues. The deduced tuple types are
// stripped of references before querying std::tuple_size, which has no
// specialization for reference types. The element is returned with whatever
// reference type std::get yields for the forwarded tuple.

// Case 1: the requested element lives in the first tuple.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index < std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& head_tuple,
    TailTuples&&... /*tail_tuples*/) {
  // TODO if constexpr instead of enable_if
  return std::get<index>(std::forward<HeadTuple>(head_tuple));
}

// Case 2: the requested element lies past the first tuple; skip it and
// continue in the remaining tuples with the index shifted accordingly.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index >= std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& /*head_tuple*/,
    TailTuples&&... tail_tuples) {
  // TODO if constexpr instead of enable_if
  return extract_tuple_element_by_index<
      index - std::tuple_size<std::remove_reference_t<HeadTuple>>::value,
      TailTuples...>(std::forward<TailTuples>(tail_tuples)...);
}

static_assert(
    std::is_same<
        int&&,
        decltype(extract_tuple_element_by_index<2>(
            std::tuple<int32_t>(2),
            std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>::
        value,
    "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value");

// Builds ConcatenatedTuple from the elements at ElementIndices... of the
// flattened view of `tuples` (callers pass 0, ..., total_elements - 1).
template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices>
auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) {
  return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>(
      std::forward<Tuples>(tuples)...)...);
}
} // namespace detail
435
436template <class... Tuples>
437auto tuple_concat(Tuples&&... tuples) {
438 using flattened_types =
439 guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>;
440 using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>;
441 constexpr size_t num_elements = guts::typelist::size<flattened_types>::value;
442 return detail::tuple_concat<concatenated_tuple, Tuples...>(
443 std::forward<Tuples>(tuples)...,
444 std::make_index_sequence<num_elements>());
445}
446
447/**
448 * Concatenate multiple integer sequences
449 * Example:
450 * concat_iseq_t<std::index_sequence<2, 5, 3>, std::index_sequence<4, 2>,
451 * std::index_sequence<5>>
452 * == std::index_sequence<2, 5, 3, 4, 2, 5>
453 */
454template <class... ISeqs>
455struct concat_iseq {
456 static_assert(
457 false_t<ISeqs...>::value,
458 "In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType.");
459};
460template <>
461struct concat_iseq<> {
462 using type = std::index_sequence<>;
463};
464template <class IntType, IntType... Indices>
465struct concat_iseq<std::integer_sequence<IntType, Indices...>> {
466 using type = std::integer_sequence<IntType, Indices...>;
467};
468template <
469 class IntType,
470 IntType... Head1Indices,
471 IntType... Head2Indices,
472 class... TailISeqs>
473struct concat_iseq<
474 std::integer_sequence<IntType, Head1Indices...>,
475 std::integer_sequence<IntType, Head2Indices...>,
476 TailISeqs...> {
477 using type = typename concat_iseq<
478 std::integer_sequence<IntType, Head1Indices..., Head2Indices...>,
479 TailISeqs...>::type;
480};
481template <class... ISeqs>
482using concat_iseq_t = typename concat_iseq<ISeqs...>::type;
483
484} // namespace guts
485} // namespace c10
486