1#pragma once
2
3#include <torch/csrc/utils/variadic.h>
4#include <torch/types.h>
5
6#include <cstdint>
7#include <type_traits>
8
// Forward declaration of torch::nn::Module. The traits in torch::detail only
// need to name the class, so its definition is not required at this point.
namespace torch { namespace nn {
class Module;
}} // namespace torch::nn
14
15namespace torch {
16namespace detail {
/// Type trait reporting whether `T` has a member named `forward` whose address
/// can be taken unambiguously.
///
/// `value` is `true` when the expression `&T::forward` is well-formed and
/// `false` otherwise. Because the check takes the address of the member, a
/// `forward` that is an overload set or a member template is not detected.
template <typename T>
struct has_forward {
  // Marker return types. They only need to be distinct from one another so
  // that the overload chosen below can be identified from its return type.
  using yes = int8_t;
  using no = int16_t;

  // Fallback overload. A C-style variadic parameter accepts any argument, but
  // an ellipsis conversion sequence ranks below every standard and every
  // user-defined conversion sequence, so this overload is selected only when
  // no other candidate is viable.
  template <typename Candidate>
  static no test(...);

  // Preferred overload. It takes part in overload resolution only if
  // `&Candidate::forward` is a valid expression; otherwise substitution fails
  // and the candidate is silently discarded (SFINAE). Its parameter is a
  // pointer or pointer-to-member type, so the `nullptr` passed below converts
  // to it through a standard conversion.
  template <typename Candidate>
  static yes test(decltype(&Candidate::forward));

  // Probe with `nullptr` and look at which marker type overload resolution
  // produced. Neither overload is ever defined or called: only the return
  // type of the unevaluated call expression is inspected.
  static constexpr bool value =
      std::is_same<decltype(test<T>(nullptr)), yes>::value;
};
43
/// Returns `true` when none of the given types is an lvalue reference to a
/// non-const type, i.e. every type is a non-reference, an rvalue reference, or
/// an lvalue reference to const. An empty type list yields `true`.
///
/// The list is consumed one type at a time. `Head` defaults to `void` so that
/// the recursive call with an exhausted `Tail` resolves to the `<void>`
/// specialization below, which terminates the recursion.
template <typename Head = void, typename... Tail>
constexpr bool check_not_lvalue_references() {
  using Referred = typename std::remove_reference<Head>::type;
  // A mutable lvalue reference anywhere in the list makes the whole check
  // fail; otherwise the answer is whatever the remaining types produce.
  return (std::is_lvalue_reference<Head>::value &&
          !std::is_const<Referred>::value)
      ? false
      : check_not_lvalue_references<Tail...>();
}

/// Base case of the recursion: no types are left to inspect.
template <>
inline constexpr bool check_not_lvalue_references<void>() {
  return true;
}
55
56/// A type trait whose `value` member is true if `M` derives from `Module`.
57template <typename M>
58using is_module =
59 std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
60
61template <typename M, typename T = void>
62using enable_if_module_t =
63 typename std::enable_if<is_module<M>::value, T>::type;
64} // namespace detail
65} // namespace torch
66