1// Copyright 2004-present Facebook. All Rights Reserved.
2
3#pragma once
4
#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
10
11namespace c10 {
12
/// Adds up every element of a container of integers, carrying the running
/// total in int64_t regardless of the element type. The 64-bit total itself
/// is not overflow-checked.
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate uses the type of its seed as the accumulator type. A plain
  // `0` would keep partial sums in `int` (or narrower); an int64_t seed widens
  // every addition to 64 bits instead.
  const auto first = std::begin(container);
  const auto last = std::end(container);
  return std::accumulate(first, last, int64_t{0});
}
26
/// Adds up the integers in the half-open iterator range [begin, end),
/// carrying the running total in int64_t regardless of the element type.
/// Returns 0 for an empty range. The 64-bit total is not overflow-checked.
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // The seed's type fixes std::accumulate's accumulator type, so a 64-bit
  // zero keeps narrow element types from overflowing mid-sum.
  const int64_t zero{0};
  return std::accumulate(begin, end, zero);
}
41
/// Multiplies together every element of a container of integers, carrying
/// the running product in int64_t regardless of the element type. Returns 1
/// (the empty product) for an empty container. The 64-bit product is not
/// overflow-checked.
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(const C& container) {
  // std::accumulate's accumulator takes the seed's type; seeding with an
  // int64_t one makes every partial product a 64-bit multiplication.
  const int64_t identity{1};
  return std::accumulate(
      std::cbegin(container),
      std::cend(container),
      identity,
      std::multiplies<>{});
}
58
/// Multiplies together the integers in the half-open iterator range
/// [begin, end), carrying the running product in int64_t regardless of the
/// element type. Returns 1 (the empty product) for an empty range. The 64-bit
/// product is not overflow-checked.
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // Seed with a 64-bit one: std::accumulate's accumulator inherits the seed's
  // type, so narrow element types are widened before each multiplication.
  return std::accumulate(begin, end, int64_t{1}, std::multiplies<>{});
}
74
75/// Return product of all dimensions starting from k
76/// Returns 1 if k>=dims.size()
77template <
78 typename C,
79 typename std::enable_if<
80 std::is_integral<typename C::value_type>::value,
81 int>::type = 0>
82inline int64_t numelements_from_dim(const int k, const C& dims) {
83 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
84
85 if (k > static_cast<int>(dims.size())) {
86 return 1;
87 } else {
88 auto cbegin = dims.cbegin();
89 std::advance(cbegin, k);
90 return multiply_integers(cbegin, dims.cend());
91 }
92}
93
94/// Product of all dims up to k (not including dims[k])
95/// Throws an error if k>dims.size()
96template <
97 typename C,
98 typename std::enable_if<
99 std::is_integral<typename C::value_type>::value,
100 int>::type = 0>
101inline int64_t numelements_to_dim(const int k, const C& dims) {
102 TORCH_INTERNAL_ASSERT(0 <= k);
103 TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
104
105 auto cend = dims.cbegin();
106 std::advance(cend, k);
107 return multiply_integers(dims.cbegin(), cend);
108}
109
110/// Product of all dims between k and l (including dims[k] and excluding
111/// dims[l]) k and l may be supplied in either order
112template <
113 typename C,
114 typename std::enable_if<
115 std::is_integral<typename C::value_type>::value,
116 int>::type = 0>
117inline int64_t numelements_between_dim(int k, int l, const C& dims) {
118 TORCH_INTERNAL_ASSERT(0 <= k);
119 TORCH_INTERNAL_ASSERT(0 <= l);
120
121 if (k > l) {
122 std::swap(k, l);
123 }
124
125 TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
126
127 auto cbegin = dims.cbegin();
128 auto cend = dims.cbegin();
129 std::advance(cbegin, k);
130 std::advance(cend, l);
131 return multiply_integers(cbegin, cend);
132}
133
134} // namespace c10
135