1 | // Copyright 2004-present Facebook. All Rights Reserved. |
2 | |
3 | #pragma once |
4 | |
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
10 | |
11 | namespace c10 { |
12 | |
/// Returns the sum of every element of `container`.
///
/// The running total is held in an int64_t rather than in `C::value_type`,
/// so summing many narrow integers (e.g. int32_t) does not overflow the
/// element type. An empty container yields 0.
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate's accumulator has the type of its initial value, so the
  // seed itself must already be 64-bit for the widening to take effect.
  const int64_t seed{0};
  const auto first = container.begin();
  const auto last = container.end();
  return std::accumulate(first, last, seed);
}
26 | |
/// Returns the sum of the integer elements in the range [begin, end).
///
/// Whatever the iterator's value_type, the total is accumulated in int64_t,
/// which keeps narrow element types from overflowing part-way through. An
/// empty range yields 0.
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // The initial value fixes std::accumulate's accumulator type, hence the
  // explicitly 64-bit zero.
  constexpr int64_t kZero = 0;
  return std::accumulate(begin, end, kZero);
}
41 | |
/// Returns the product of every element of `container`.
///
/// Partial products are carried in an int64_t instead of `C::value_type`, so
/// multiplying e.g. int32_t sizes does not overflow the element type. An
/// empty container yields 1, the multiplicative identity.
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(const C& container) {
  // Seeding std::accumulate with a 64-bit one makes int64_t the accumulator
  // type; std::multiplies<> then combines it with each element in turn.
  const int64_t identity{1};
  const auto first = container.begin();
  const auto last = container.end();
  return std::accumulate(first, last, identity, std::multiplies<>{});
}
58 | |
/// Returns the product of the integer elements in the range [begin, end).
///
/// The product is accumulated in int64_t regardless of the iterator's
/// value_type, so narrow element types do not overflow. An empty range
/// yields 1.
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // The seed's type determines std::accumulate's accumulator type, so start
  // from a 64-bit one.
  constexpr int64_t kOne = 1;
  return std::accumulate(begin, end, kOne, std::multiplies<>{});
}
74 | |
75 | /// Return product of all dimensions starting from k |
76 | /// Returns 1 if k>=dims.size() |
77 | template < |
78 | typename C, |
79 | typename std::enable_if< |
80 | std::is_integral<typename C::value_type>::value, |
81 | int>::type = 0> |
82 | inline int64_t numelements_from_dim(const int k, const C& dims) { |
83 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0); |
84 | |
85 | if (k > static_cast<int>(dims.size())) { |
86 | return 1; |
87 | } else { |
88 | auto cbegin = dims.cbegin(); |
89 | std::advance(cbegin, k); |
90 | return multiply_integers(cbegin, dims.cend()); |
91 | } |
92 | } |
93 | |
94 | /// Product of all dims up to k (not including dims[k]) |
95 | /// Throws an error if k>dims.size() |
96 | template < |
97 | typename C, |
98 | typename std::enable_if< |
99 | std::is_integral<typename C::value_type>::value, |
100 | int>::type = 0> |
101 | inline int64_t numelements_to_dim(const int k, const C& dims) { |
102 | TORCH_INTERNAL_ASSERT(0 <= k); |
103 | TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size()); |
104 | |
105 | auto cend = dims.cbegin(); |
106 | std::advance(cend, k); |
107 | return multiply_integers(dims.cbegin(), cend); |
108 | } |
109 | |
110 | /// Product of all dims between k and l (including dims[k] and excluding |
111 | /// dims[l]) k and l may be supplied in either order |
112 | template < |
113 | typename C, |
114 | typename std::enable_if< |
115 | std::is_integral<typename C::value_type>::value, |
116 | int>::type = 0> |
117 | inline int64_t numelements_between_dim(int k, int l, const C& dims) { |
118 | TORCH_INTERNAL_ASSERT(0 <= k); |
119 | TORCH_INTERNAL_ASSERT(0 <= l); |
120 | |
121 | if (k > l) { |
122 | std::swap(k, l); |
123 | } |
124 | |
125 | TORCH_INTERNAL_ASSERT((unsigned)l < dims.size()); |
126 | |
127 | auto cbegin = dims.cbegin(); |
128 | auto cend = dims.cbegin(); |
129 | std::advance(cbegin, k); |
130 | std::advance(cend, l); |
131 | return multiply_integers(cbegin, cend); |
132 | } |
133 | |
134 | } // namespace c10 |
135 | |