1 | #pragma once |
2 | |
3 | #include <c10/util/Array.h> |
4 | #include <c10/util/TypeList.h> |
5 | #include <array> |
6 | #include <functional> |
7 | #include <type_traits> |
8 | |
9 | namespace c10 { |
10 | namespace guts { |
11 | |
12 | /** |
13 | * Access information about result type or arguments from a function type. |
14 | * Example: |
15 | * using A = function_traits<int (float, double)>::return_type // A == int |
16 | * using A = function_traits<int (float, double)>::parameter_types::tuple_type |
17 | * // A == tuple<float, double> |
18 | */ |
19 | template <class Func> |
20 | struct function_traits { |
21 | static_assert( |
22 | !std::is_same<Func, Func>::value, |
23 | "In function_traits<Func>, Func must be a plain function type." ); |
24 | }; |
25 | template <class Result, class... Args> |
26 | struct function_traits<Result(Args...)> { |
27 | using func_type = Result(Args...); |
28 | using return_type = Result; |
29 | using parameter_types = typelist::typelist<Args...>; |
30 | static constexpr auto number_of_parameters = sizeof...(Args); |
31 | }; |
32 | |
33 | /** |
34 | * infer_function_traits: creates a `function_traits` type for a simple |
35 | * function (pointer) or functor (lambda/struct). Currently does not support |
36 | * class methods. |
37 | */ |
38 | |
39 | template <typename Functor> |
40 | struct infer_function_traits { |
41 | using type = function_traits< |
42 | c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>; |
43 | }; |
44 | |
45 | template <typename Result, typename... Args> |
46 | struct infer_function_traits<Result (*)(Args...)> { |
47 | using type = function_traits<Result(Args...)>; |
48 | }; |
49 | |
50 | template <typename Result, typename... Args> |
51 | struct infer_function_traits<Result(Args...)> { |
52 | using type = function_traits<Result(Args...)>; |
53 | }; |
54 | |
55 | template <typename T> |
56 | using infer_function_traits_t = typename infer_function_traits<T>::type; |
57 | |
58 | /** |
59 | * make_function_traits: creates a `function_traits` type given a Return type |
60 | * and a typelist of Argument types |
61 | * |
62 | * Example: |
63 | * bool f(int, int); |
64 | * |
 *  infer_function_traits_t<decltype(f)> == make_function_traits_t<bool,
 *  typelist::typelist<int, int>>
67 | */ |
68 | template <typename Result, typename ArgList> |
69 | struct make_function_traits { |
70 | static_assert( |
71 | false_t<ArgList>::value, |
72 | "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>." ); |
73 | }; |
74 | |
75 | template <typename Result, typename... Args> |
76 | struct make_function_traits<Result, typelist::typelist<Args...>> { |
77 | using type = function_traits<Result(Args...)>; |
78 | }; |
79 | |
80 | template <typename Result, typename ArgList> |
81 | using make_function_traits_t = |
82 | typename make_function_traits<Result, ArgList>::type; |
83 | |
84 | /** |
85 | * Use extract_arg_by_filtered_index to return the i-th argument whose |
86 | * type fulfills a given type trait. The argument itself is perfectly forwarded. |
87 | * |
88 | * Example: |
89 | * std::string arg1 = "Hello"; |
90 | * std::string arg2 = "World"; |
91 | * std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0, |
92 | * arg1, 2.0, std::move(arg2)); |
93 | * |
 * Warning: Taking the result by rvalue reference can cause segfaults because
 * ownership will not be passed on from the original reference. The original
 * reference dies after the expression and the resulting reference is then
 * left dangling.
 */
namespace detail {

// Implementation helper for extract_arg_by_filtered_index.
//
// extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(args...)
// walks the argument list from left to right and returns (perfectly
// forwarded) the argument that is the `index`-th one, counting from zero,
// among those whose type satisfies Condition. The Enable parameter is always
// passed as `void`; it only exists so the partial specializations below can be
// selected through std::enable_if_t.
template <
    template <class> class Condition,
    size_t index,
    class Enable,
    class... Args>
struct extract_arg_by_filtered_index_;

// Head does not satisfy Condition: drop it and keep looking for the same
// filtered index in the tail.
template <
    template <class> class Condition,
    size_t index,
    class Head,
    class... Tail>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<!Condition<Head>::value>,
    Head,
    Tail...> {
  static decltype(auto) call(Head&& /*head*/, Tail&&... tail) {
    return extract_arg_by_filtered_index_<Condition, index, void, Tail...>::
        call(std::forward<Tail>(tail)...);
  }
};

// Head satisfies Condition but is not the requested match yet: drop it and
// look for filtered index `index - 1` in the tail.
template <
    template <class> class Condition,
    size_t index,
    class Head,
    class... Tail>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<Condition<Head>::value && index != 0>,
    Head,
    Tail...> {
  static decltype(auto) call(Head&& /*head*/, Tail&&... tail) {
    return extract_arg_by_filtered_index_<Condition, index - 1, void, Tail...>::
        call(std::forward<Tail>(tail)...);
  }
};

// Argument list exhausted before the requested match was found. The asserted
// condition is always false but depends on `index`, so it only fires when this
// specialization is actually instantiated.
template <template <class> class Condition, size_t index>
struct extract_arg_by_filtered_index_<Condition, index, void> {
  static void call() {
    static_assert(
        index != index, "extract_arg_by_filtered_index out of range.");
  }
};

// Head satisfies Condition and is the requested match: forward it, preserving
// its value category.
template <
    template <class> class Condition,
    size_t index,
    class Head,
    class... Tail>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    std::enable_if_t<Condition<Head>::value && index == 0>,
    Head,
    Tail...> {
  static decltype(auto) call(Head&& head, Tail&&... /*tail*/) {
    return std::forward<Head>(head);
  }
};

} // namespace detail
164 | template <template <class> class Condition, size_t index, class... Args> |
165 | decltype(auto) (Args&&... args) { |
166 | static_assert( |
167 | is_type_condition<Condition>::value, |
168 | "In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member." ); |
169 | return detail:: |
170 | extract_arg_by_filtered_index_<Condition, index, void, Args...>::call( |
171 | std::forward<Args>(args)...); |
172 | } |
173 | |
174 | /** |
175 | * Use filter_map to map a subset of the arguments to values. |
176 | * The subset is defined by type traits, and will be evaluated at compile time. |
177 | * At runtime, it will just loop over the pre-filtered arguments to create an |
178 | * std::array. |
179 | * |
180 | * Example: |
181 | * std::array<double, 2> result = filter_map<double, std::is_integral>([] (auto |
182 | * a) {return (double)a;}, 3, "bla", 4); |
183 | * // result == {3.0, 4.0} |
184 | */ |
185 | namespace detail { |
186 | |
187 | template <class ResultType, size_t num_results> |
188 | struct filter_map_ { |
189 | template < |
190 | template <class> |
191 | class Condition, |
192 | class Mapper, |
193 | class... Args, |
194 | size_t... INDEX> |
195 | static guts::array<ResultType, num_results> call( |
196 | const Mapper& mapper, |
197 | std::index_sequence<INDEX...>, |
198 | Args&&... args) { |
199 | return guts::array<ResultType, num_results>{ |
200 | mapper(extract_arg_by_filtered_index<Condition, INDEX>( |
201 | std::forward<Args>(args)...))...}; |
202 | } |
203 | }; |
204 | template <class ResultType> |
205 | struct filter_map_<ResultType, 0> { |
206 | template < |
207 | template <class> |
208 | class Condition, |
209 | class Mapper, |
210 | class... Args, |
211 | size_t... INDEX> |
212 | static guts::array<ResultType, 0> call( |
213 | const Mapper& /*mapper*/, |
214 | std::index_sequence<INDEX...>, |
215 | Args&&... /*args*/) { |
216 | return guts::array<ResultType, 0>{}; |
217 | } |
218 | }; |
219 | } // namespace detail |
220 | |
221 | template < |
222 | class ResultType, |
223 | template <class> |
224 | class Condition, |
225 | class Mapper, |
226 | class... Args> |
227 | decltype(auto) filter_map(const Mapper& mapper, Args&&... args) { |
228 | static_assert( |
229 | is_type_condition<Condition>::value, |
230 | "In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member." ); |
231 | |
232 | static constexpr size_t num_results = |
233 | typelist::count_if<Condition, typelist::typelist<Args...>>::value; |
234 | return detail::filter_map_<ResultType, num_results>:: |
235 | template call<Condition, Mapper, Args...>( |
236 | mapper, |
237 | std::make_index_sequence<num_results>(), |
238 | std::forward<Args>(args)...); |
239 | } |
240 | |
241 | /** |
242 | * make_offset_index_sequence<Start, N> |
243 | * Like make_index_sequence<N>, but starting from Start instead of 0. |
244 | * |
245 | * Example: |
246 | * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> |
247 | */ |
// Recursive step: prepend the largest index still missing (Start + N - 1) to
// the accumulated pack and continue with N - 1.
template <size_t Start, size_t N, size_t... Accumulated>
struct make_offset_index_sequence_impl
    : make_offset_index_sequence_impl<
          Start,
          N - 1,
          Start + N - 1,
          Accumulated...> {
  // Start and N are unsigned; viewing them as int flags values that started
  // out as negative numbers before conversion.
  static_assert(
      static_cast<int>(Start) >= 0, "make_offset_index_sequence: Start < 0");
  static_assert(
      static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
};

// Base case: nothing left to add, the accumulated pack is the result.
template <size_t Start, size_t... Accumulated>
struct make_offset_index_sequence_impl<Start, 0, Accumulated...> {
  using type = std::index_sequence<Accumulated...>;
};

// std::index_sequence<Start, Start + 1, ..., Start + N - 1>.
template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
265 | |
266 | /** |
267 | * Use tuple_elements to extract a position-indexed subset of elements |
268 | * from the argument tuple into a result tuple. |
269 | * |
270 | * Example: |
271 | * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); |
272 | * std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0, |
273 | * 2>()); |
274 | */ |
// Builds a new tuple from the elements of `t` at positions Is..., in the
// order given. Each result element has the type of the corresponding element
// of Tuple and is initialized from `std::get<Is>(t)` on the by-value copy `t`.
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  using selection = std::tuple<std::tuple_element_t<Is, Tuple>...>;
  return selection(std::get<Is>(t)...);
}
279 | |
280 | /** |
281 | * Use tuple_take to extract the first or last n elements from the argument |
282 | * tuple into a result tuple. |
283 | * |
284 | * Example: |
285 | * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); |
286 | * std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t); |
287 | * std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t); |
288 | */ |
289 | template <class Tuple, int N, class Enable = void> |
290 | struct TupleTake {}; |
291 | |
292 | template <class Tuple, int N> |
293 | struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> { |
294 | static auto call(Tuple t) { |
295 | constexpr size_t size = std::tuple_size<Tuple>(); |
296 | static_assert(N <= size, "tuple_take: N > size" ); |
297 | return tuple_elements(t, std::make_index_sequence<N>{}); |
298 | } |
299 | }; |
300 | |
301 | template <class Tuple, int N> |
302 | struct TupleTake < Tuple, |
303 | N, std::enable_if_t<N<0, void>> { |
304 | static auto call(Tuple t) { |
305 | constexpr size_t size = std::tuple_size<Tuple>(); |
306 | static_assert(-N <= size, "tuple_take: -N > size" ); |
307 | return tuple_elements(t, make_offset_index_sequence<size + N, -N>{}); |
308 | } |
309 | }; |
310 | |
311 | template <class Tuple, int N> |
312 | auto tuple_take(Tuple t) { |
313 | return TupleTake<Tuple, N>::call(t); |
314 | } |
315 | |
316 | /** |
317 | * Use tuple_slice to extract a contiguous subtuple from the argument. |
318 | * |
319 | * Example: |
 *  std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
 *  "HEY", 2.0, false);
 *  std::tuple<const char*, double> middle_two =
 *  tuple_slice<decltype(t), 1, 2>(t);
323 | */ |
324 | template <class Tuple, size_t Start, size_t N> |
325 | constexpr auto tuple_slice(Tuple t) { |
326 | constexpr size_t size = std::tuple_size<Tuple>(); |
327 | static_assert(Start + N <= size, "tuple_slice: Start + N > size" ); |
328 | return tuple_elements(t, make_offset_index_sequence<Start, N>{}); |
329 | } |
330 | |
331 | /** |
332 | * Use tuple_map to run a mapping function over a tuple to get a new tuple. |
333 | * |
334 | * Example 1: |
335 | * auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), [] |
336 | * (int32_t a) -> int16_t {return a+1;}); |
337 | * // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6) |
338 | * |
339 | * Example 2: |
340 | * struct Mapper { |
341 | * std::string operator()(int32_t a) const { |
342 | * return std::to_string(a); |
343 | * } |
344 | * int64_t operator()(const std::string& a) const { |
345 | * return atoi(a.c_str()); |
346 | * } |
347 | * }; |
348 | * auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"), |
349 | * Mapper()); |
350 | * // result == std::tuple<std::string, int64_t>("3", 4) |
351 | * |
352 | * Example 3: |
353 | * struct A final { |
354 | * int32_t func() { |
355 | * return 5; |
356 | * } |
357 | * }; |
358 | * struct B final { |
359 | * std::string func() { |
360 | * return "5"; |
361 | * } |
362 | * }; |
363 | * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return |
364 | * a.func(); }); |
365 | * // result == std::tuple<int32_t, std::string>(5, "5"); |
366 | */ |
namespace detail {
// Index-driven worker for tuple_map: calls `mapper` once per element, forwards
// each element with its declared type, and packs the results into a tuple
// whose element types are exactly the mapper's return types.
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  using mapped_tuple = std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>;
  return mapped_tuple(
      mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail

// Maps every element of an rvalue tuple through `mapper` and returns the tuple
// of results. The mapper may return a different type for each element type.
template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  using element_indices = std::make_index_sequence<sizeof...(Args)>;
  return detail::tuple_map(std::move(tuple), mapper, element_indices{});
}
383 | |
384 | /** |
385 | * tuple_concat concatenates several tuples into one. |
386 | */ |
387 | |
388 | namespace detail { |
389 | // extract_tuple_element_by_index is a helper that takes a list of tuples and |
390 | // extracts the i-th element in a flattened view of the tuples. Example: |
391 | // extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5. |
392 | |
// Case 1: the requested flattened index falls inside the first tuple. Return
// that element, preserving the value category of the tuple it came from.
// The tuple size is taken from the reference-stripped type so that tuples
// passed as lvalues (HeadTuple deduced as a reference) are handled as well.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index < std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& head_tuple,
    TailTuples&&... /*tail_tuples*/) {
  // TODO if constexpr instead of enable_if
  return std::get<index>(std::forward<HeadTuple>(head_tuple));
}

// Case 2: the requested flattened index lies beyond the first tuple. Skip the
// first tuple and continue in the remaining ones with the index reduced by the
// first tuple's size.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index >= std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& /*head_tuple*/,
    TailTuples&&... tail_tuples) {
  // TODO if constexpr instead of enable_if
  return extract_tuple_element_by_index<
      index - std::tuple_size<std::remove_reference_t<HeadTuple>>::value,
      TailTuples...>(std::forward<TailTuples>(tail_tuples)...);
}

// Compile-time check that elements of rvalue tuples come back as rvalue
// references instead of being moved into a new value.
static_assert(
    std::is_same<
        int&&,
        decltype(extract_tuple_element_by_index<2>(
            std::tuple<int32_t>(2),
            std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>::
        value,
    "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value");
428 | |
429 | template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices> |
430 | auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) { |
431 | return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>( |
432 | std::forward<Tuples>(tuples)...)...); |
433 | } |
434 | } // namespace detail |
435 | |
436 | template <class... Tuples> |
437 | auto tuple_concat(Tuples&&... tuples) { |
438 | using flattened_types = |
439 | guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>; |
440 | using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>; |
441 | constexpr size_t num_elements = guts::typelist::size<flattened_types>::value; |
442 | return detail::tuple_concat<concatenated_tuple, Tuples...>( |
443 | std::forward<Tuples>(tuples)..., |
444 | std::make_index_sequence<num_elements>()); |
445 | } |
446 | |
447 | /** |
448 | * Concatenate multiple integer sequences |
449 | * Example: |
450 | * concat_iseq_t<std::index_sequence<2, 5, 3>, std::index_sequence<4, 2>, |
451 | * std::index_sequence<5>> |
452 | * == std::index_sequence<2, 5, 3, 4, 2, 5> |
453 | */ |
454 | template <class... ISeqs> |
455 | struct concat_iseq { |
456 | static_assert( |
457 | false_t<ISeqs...>::value, |
458 | "In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType." ); |
459 | }; |
460 | template <> |
461 | struct concat_iseq<> { |
462 | using type = std::index_sequence<>; |
463 | }; |
464 | template <class IntType, IntType... Indices> |
465 | struct concat_iseq<std::integer_sequence<IntType, Indices...>> { |
466 | using type = std::integer_sequence<IntType, Indices...>; |
467 | }; |
468 | template < |
469 | class IntType, |
470 | IntType... Head1Indices, |
471 | IntType... Head2Indices, |
472 | class... TailISeqs> |
473 | struct concat_iseq< |
474 | std::integer_sequence<IntType, Head1Indices...>, |
475 | std::integer_sequence<IntType, Head2Indices...>, |
476 | TailISeqs...> { |
477 | using type = typename concat_iseq< |
478 | std::integer_sequence<IntType, Head1Indices..., Head2Indices...>, |
479 | TailISeqs...>::type; |
480 | }; |
481 | template <class... ISeqs> |
482 | using concat_iseq_t = typename concat_iseq<ISeqs...>::type; |
483 | |
484 | } // namespace guts |
485 | } // namespace c10 |
486 | |