1 | #pragma once |
2 | |
3 | #include <torch/csrc/utils/variadic.h> |
4 | #include <torch/types.h> |
5 | |
6 | #include <cstdint> |
7 | #include <type_traits> |
8 | |
9 | namespace torch { |
10 | namespace nn { |
11 | class Module; |
12 | } // namespace nn |
13 | } // namespace torch |
14 | |
15 | namespace torch { |
16 | namespace detail { |
/// Compile-time probe for a member named `forward`.
///
/// `has_forward<T>::value` is `true` when the expression `&T::forward` is
/// well-formed and `false` otherwise. Because the probe takes the address of
/// the member, an overloaded `forward` cannot be resolved to a single entity
/// and is reported as `false`.
template <typename T>
struct has_forward {
  // Tag types naming the two possible outcomes of the probe. They are distinct
  // types, so the overload picked by the compiler can be identified from the
  // probe's return type alone.
  using yes = int8_t;
  using no = int16_t;

  // Preferred probe. It only takes part in overload resolution when
  // `&Candidate::forward` is a valid expression; otherwise substitution fails
  // and this overload is silently discarded. Its parameter is a pointer(-to-
  // member) type, so a `nullptr` argument converts to it through a standard
  // conversion sequence.
  template <typename Candidate>
  static yes test(decltype(&Candidate::forward));

  // Fallback probe. An ellipsis conversion sequence ranks below both standard
  // and user-defined conversion sequences, so this overload is selected only
  // when the preferred probe above has been discarded.
  template <typename Candidate>
  static no test(...);

  // Neither probe is ever called: the call expression is only inspected in an
  // unevaluated context to learn which overload (and thus which tag type)
  // overload resolution chose.
  static constexpr bool value =
      std::is_same<decltype(test<T>(nullptr)), yes>::value;
};
43 | |
/// Returns `true` if none of the given types is an lvalue reference to a
/// non-const type, and `false` as soon as one of them is.
///
/// Value types, rvalue references and lvalue references to const all pass the
/// check. The pack is consumed one type at a time; once `Tail` is empty the
/// recursive call falls back to the default `Head = void`, which selects the
/// explicit specialization below and ends the recursion.
template <typename Head = void, typename... Tail>
constexpr bool check_not_lvalue_references() {
  using Referent = typename std::remove_reference<Head>::type;
  // `Head` is rejected only when it is an lvalue reference AND the type it
  // refers to is not const-qualified; otherwise keep checking the rest.
  return (std::is_lvalue_reference<Head>::value &&
          !std::is_const<Referent>::value)
      ? false
      : check_not_lvalue_references<Tail...>();
}

/// Recursion terminator: reached when the pack is exhausted (or when called
/// with `void` directly). There is nothing left to reject, so the answer is
/// `true`.
template <>
inline constexpr bool check_not_lvalue_references<void>() {
  return true;
}
55 | |
56 | /// A type trait whose `value` member is true if `M` derives from `Module`. |
57 | template <typename M> |
58 | using is_module = |
59 | std::is_base_of<torch::nn::Module, typename std::decay<M>::type>; |
60 | |
61 | template <typename M, typename T = void> |
62 | using enable_if_module_t = |
63 | typename std::enable_if<is_module<M>::value, T>::type; |
64 | } // namespace detail |
65 | } // namespace torch |
66 | |