1 | #pragma once |
2 | |
3 | #include <ATen/core/Tensor.h> |
4 | #include <ATen/core/Variadic.h> |
5 | #include <torch/csrc/autograd/variable.h> |
6 | |
7 | #include <cstdint> |
8 | #include <tuple> |
9 | #include <type_traits> |
10 | #include <utility> |
11 | |
12 | namespace torch { |
13 | |
14 | using at::IterArgs; |
15 | |
16 | struct CountTensors : IterArgs<CountTensors> { |
17 | size_t out = 0; |
18 | void operator()(const at::Tensor& x) { |
19 | out += 1; |
20 | } |
21 | void operator()(const c10::optional<at::Tensor>& x) { |
22 | out += x.has_value(); |
23 | } |
24 | void operator()(at::ArrayRef<at::Tensor> xs) { |
25 | out += xs.size(); |
26 | } |
27 | }; |
28 | |
29 | template <typename... Args> |
30 | size_t count_tensors(Args&&... args) { |
31 | return CountTensors().apply(std::forward<Args>(args)...).out; |
32 | } |
33 | |
34 | struct CountVariables : IterArgs<CountVariables> { |
35 | size_t out = 0; |
36 | void operator()(const autograd::Variable& x) { |
37 | out += 1; |
38 | } |
39 | void operator()(at::ArrayRef<autograd::Variable> xs) { |
40 | out += xs.size(); |
41 | } |
42 | }; |
43 | |
44 | template <typename... Args> |
45 | inline size_t count_variables(Args&&... args) { |
46 | return CountVariables().apply(std::forward<Args>(args)...).out; |
47 | } |
48 | |
49 | //===----------------------------------------------------------------------===// |
50 | // std::index_sequence shim for C++11 |
51 | //===----------------------------------------------------------------------===// |
52 | |
// Compile-time list of size_t indices (non-type template parameters). It is
// used as a tag argument so a function template can deduce the indices from
// it.
template <size_t... Is>
struct Indices {};

// MakeIndices<N>::indices is Indices<0, 1, ..., N - 1>.
//
// Recursive step: decrement the counter and prepend the decremented value to
// the list built so far. For N = 3 the inheritance chain is
//   MakeIndices<3> -> MakeIndices<2, 2> -> MakeIndices<1, 1, 2>
//                  -> MakeIndices<0, 0, 1, 2>
// and the nested `indices` alias is inherited from that last (base-case) type.
template <size_t Remaining, size_t... Tail>
struct MakeIndices : MakeIndices<Remaining - 1, Remaining - 1, Tail...> {};

// Base case: the counter has reached zero, so `Tail` now holds 0 .. N-1 for
// the N originally requested; expose it as the `indices` alias.
template <size_t... Tail>
struct MakeIndices<0, Tail...> {
  using indices = Indices<Tail...>;
};
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | // Utilities |
73 | //===----------------------------------------------------------------------===// |
74 | |
// Local equivalents of the standard `std::enable_if_t` / `std::decay_t`
// aliases, plus the negated `disable_if_t`.

// Names `T` when `Cond` is true; substitution failure otherwise.
template <bool Cond, typename T = void>
using enable_if_t = typename std::enable_if<Cond, T>::type;

// Names `T` when `Cond` is false; substitution failure otherwise.
template <bool Cond, typename T = void>
using disable_if_t = typename std::enable_if<!Cond, T>::type;

// The type `T` becomes under std::decay (references and cv-qualifiers
// removed, arrays/functions turned into pointers).
template <typename T>
using decay_t = typename std::decay<T>::type;
83 | |
namespace detail {
// Deliberately left undefined: it is only ever used as a distinct type that
// carries a list of bools, so two such lists can be compared by std::is_same.
template <bool...>
struct pack;
} // namespace detail

// all_of<Bs...>::value is true when every element of Bs is true, including
// the empty list; it derives from std::is_same and therefore from
// std::true_type / std::false_type.
//
// Trick: the list with `true` appended equals the list with `true` prepended
// exactly when each element equals its predecessor, i.e. when all are true.
// E.g. <true, false>: pack<true, false, true> vs pack<true, true, false>.
template <bool... Bs>
struct all_of
    : std::is_same<detail::pack<Bs..., true>, detail::pack<true, Bs...>> {};
93 | |
// any_of<Bs...>::value is true when at least one element of Bs is true, and
// false for the empty list.
//
// Every specialization derives from std::integral_constant<bool, ...>,
// matching all_of and the empty any_of<> (std::false_type). That gives a
// uniform `::type` and constexpr bool conversion. It also means `value` is the
// standard library's member, so odr-using it (e.g. binding it to a
// `const bool&`) does not need an out-of-class definition from this file,
// which a bare in-class `static constexpr bool` would require before C++17.
template <bool...>
struct any_of;

template <>
struct any_of<> : std::false_type {};

template <bool head, bool... tail>
struct any_of<head, tail...>
    : std::integral_constant<bool, (head || any_of<tail...>::value)> {};

// none_of<Bs...>::value is true when every element of Bs is false, including
// the empty list.
template <bool... values>
struct none_of : std::integral_constant<bool, (!any_of<values...>::value)> {};
109 | |
110 | template <bool... values> |
111 | using enable_if_all_of_t = enable_if_t<all_of<values...>::value>; |
112 | |
113 | template <typename T, typename... Ts> |
114 | using disable_if_contains_t = |
115 | enable_if_all_of_t<(!std::is_same<T, decay_t<Ts>>::value)...>; |
116 | |
// Calls `function` once for each argument in `ts`, strictly left to right,
// forwarding each argument with its original value category. Whatever
// `function` returns is discarded.
//
// Pack-expansion idiom: the initializer-clauses of a braced-init-list are
// evaluated in order, so expanding `(call, 0)...` into a dummy int array runs
// the calls in sequence. Each call is cast to `void` first. That guarantees
// the built-in comma operator is used even if `function` returns a type with
// an overloaded `operator,`, which could otherwise yield a non-int element or
// run arbitrary code. It also lets `function` return any type. The leading
// `0` keeps the array non-empty when `ts` is empty.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  int _[]{0, (static_cast<void>(function(std::forward<Ts>(ts))), 0)...};
  (void)_;
}
129 | |
130 | template < |
131 | typename ReturnType, |
132 | typename... Ts, |
133 | typename Function, |
134 | typename Accessor> |
135 | ReturnType unpack(Function function, Accessor accessor) { |
136 | return ReturnType(unpack<ReturnType, Ts...>( |
137 | std::move(function), |
138 | std::move(accessor), |
139 | typename MakeIndices<sizeof...(Ts)>::indices())); |
140 | } |
141 | |
142 | template < |
143 | typename ReturnType, |
144 | typename... Ts, |
145 | typename Function, |
146 | typename Accessor, |
147 | size_t... Is> |
148 | ReturnType unpack(Function function, Accessor accessor, Indices<Is...>) { |
149 | return ReturnType(function(accessor.template operator()<Ts>(Is)...)); |
150 | } |
151 | |
152 | } // namespace torch |
153 | |