1#pragma once
2
3#include <ATen/core/Tensor.h>
4#include <ATen/core/Variadic.h>
5#include <torch/csrc/autograd/variable.h>
6
7#include <cstdint>
8#include <tuple>
9#include <type_traits>
10#include <utility>
11
12namespace torch {
13
14using at::IterArgs;
15
16struct CountTensors : IterArgs<CountTensors> {
17 size_t out = 0;
18 void operator()(const at::Tensor& x) {
19 out += 1;
20 }
21 void operator()(const c10::optional<at::Tensor>& x) {
22 out += x.has_value();
23 }
24 void operator()(at::ArrayRef<at::Tensor> xs) {
25 out += xs.size();
26 }
27};
28
29template <typename... Args>
30size_t count_tensors(Args&&... args) {
31 return CountTensors().apply(std::forward<Args>(args)...).out;
32}
33
34struct CountVariables : IterArgs<CountVariables> {
35 size_t out = 0;
36 void operator()(const autograd::Variable& x) {
37 out += 1;
38 }
39 void operator()(at::ArrayRef<autograd::Variable> xs) {
40 out += xs.size();
41 }
42};
43
44template <typename... Args>
45inline size_t count_variables(Args&&... args) {
46 return CountVariables().apply(std::forward<Args>(args)...).out;
47}
48
49//===----------------------------------------------------------------------===//
50// std::index_sequence shim for C++11
51//===----------------------------------------------------------------------===//
52
// Empty tag type whose only job is to carry a compile-time list of indices in
// its template arguments, so that the list can be deduced as a pack.
template <size_t... Values>
struct Indices {};

// Recursive generator: each step peels one off the counter and prepends the
// new counter value (Remaining - 1) to the indices accumulated so far, i.e.
//   MakeIndices<3> -> MakeIndices<2, 2> -> MakeIndices<1, 1, 2>
//                  -> MakeIndices<0, 0, 1, 2>.
template <size_t Remaining, size_t... Accumulated>
struct MakeIndices
    : MakeIndices<Remaining - 1, Remaining - 1, Accumulated...> {};

// Recursion terminator. Once the counter reaches zero the accumulated pack is
// exactly 0, 1, ..., N-1 for the originally requested N; it is published as
// the nested `indices` type, which every earlier step of the recursion
// exposes through inheritance.
template <size_t... Accumulated>
struct MakeIndices<0, Accumulated...> {
  using indices = Indices<Accumulated...>;
};
70
71//===----------------------------------------------------------------------===//
72// Utilities
73//===----------------------------------------------------------------------===//
74
// Shorthand for `std::enable_if<...>::type`: names `Result` (default `void`)
// when `condition` is true and is a substitution failure otherwise.
template <bool condition, typename Result = void>
using enable_if_t = typename std::enable_if<condition, Result>::type;
77
// Mirror image of `enable_if_t`: names `Result` (default `void`) when
// `condition` is false and is a substitution failure when it is true.
template <bool condition, typename Result = void>
using disable_if_t = typename std::enable_if<!condition, Result>::type;
80
// Shorthand for `std::decay<...>::type`: strips references and top-level
// cv-qualifiers, and turns arrays/functions into pointers.
template <typename Type>
using decay_t = typename std::decay<Type>::type;
83
namespace detail {
// Carrier for a compile-time list of booleans. It is only ever named as a
// type and never instantiated, so it is deliberately left incomplete.
template <bool...>
struct pack;
} // namespace detail

// `all_of<bs...>::value` is true exactly when every boolean in the pack is
// true (and therefore also for an empty pack).
//
// How it works: appending `true` to the list and prepending `true` to the
// list yield the same type only if shifting the list by one position changes
// nothing, which is the case only when every element is `true`.
template <bool... conditions>
struct all_of : std::is_same<
                    detail::pack<conditions..., true>,
                    detail::pack<true, conditions...>> {};
93
// `any_of<bs...>::value` is true when at least one boolean in the pack is
// true; it is false for an empty pack.
//
// Every specialization derives from `std::integral_constant<bool, ...>`, so in
// addition to `::value` the trait offers `::type`, `value_type` and the
// conversions to `std::true_type` / `std::false_type` that tag dispatch relies
// on, matching `all_of`.
template <bool...>
struct any_of;

// Base case: nothing left to inspect, so no element was true.
template <>
struct any_of<> : std::false_type {};

// Recursive case: true if the first element is true or any remaining one is.
template <bool head, bool... tail>
struct any_of<head, tail...>
    : std::integral_constant<bool, head || any_of<tail...>::value> {};
104
105template <bool... values>
106struct none_of {
107 static constexpr bool value = !any_of<values...>::value;
108};
109
110template <bool... values>
111using enable_if_all_of_t = enable_if_t<all_of<values...>::value>;
112
113template <typename T, typename... Ts>
114using disable_if_contains_t =
115 enable_if_all_of_t<(!std::is_same<T, decay_t<Ts>>::value)...>;
116
// Invokes `function` once for every argument in `ts`, strictly from left to
// right, forwarding each argument with its original value category. Any value
// returned by `function` is discarded. With no arguments, `function` is never
// called.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector
  // Expands the calls inside the braced initializer of a dummy array, because
  // the elements of a braced initializer list are evaluated in order.
  // Each element has the form `(static_cast<void>(function(arg)), 0)`: the
  // call is evaluated, its result is discarded, and the comma operator then
  // yields the zero that is stored in the array. The cast to `void` guarantees
  // that the built-in comma operator is used even when `function` returns a
  // class type that overloads `operator,`. The leading zero ensures the array
  // is not empty when `ts` is empty.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  int _[]{0, (static_cast<void>(function(std::forward<Ts>(ts))), 0)...};
  (void)_;
}
129
130template <
131 typename ReturnType,
132 typename... Ts,
133 typename Function,
134 typename Accessor>
135ReturnType unpack(Function function, Accessor accessor) {
136 return ReturnType(unpack<ReturnType, Ts...>(
137 std::move(function),
138 std::move(accessor),
139 typename MakeIndices<sizeof...(Ts)>::indices()));
140}
141
142template <
143 typename ReturnType,
144 typename... Ts,
145 typename Function,
146 typename Accessor,
147 size_t... Is>
148ReturnType unpack(Function function, Accessor accessor, Indices<Is...>) {
149 return ReturnType(function(accessor.template operator()<Ts>(Is)...));
150}
151
152} // namespace torch
153