1// Copyright 2004-present Facebook. All Rights Reserved.
2
3#include <c10/util/accumulate.h>
4
5#include <gtest/gtest.h>
6
7#include <list>
8#include <vector>
9
10using namespace ::testing;
11
12TEST(accumulate_test, vector_test) {
13 std::vector<int> ints = {1, 2, 3, 4, 5};
14
15 EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
16 EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);
17
18 EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
19 EXPECT_EQ(
20 c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);
21
22 EXPECT_EQ(c10::sum_integers(ints.begin() + 1, ints.end() - 1), 2 + 3 + 4);
23 EXPECT_EQ(
24 c10::multiply_integers(ints.begin() + 1, ints.end() - 1), 2 * 3 * 4);
25
26 EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
27 EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
28 EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
29 EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
30}
31
32TEST(accumulate_test, list_test) {
33 std::list<int> ints = {1, 2, 3, 4, 5};
34
35 EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
36 EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);
37
38 EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
39 EXPECT_EQ(
40 c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);
41
42 EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
43 EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
44 EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
45 EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
46}
47
48TEST(accumulate_test, base_cases) {
49 std::vector<int> ints = {};
50
51 EXPECT_EQ(c10::sum_integers(ints), 0);
52 EXPECT_EQ(c10::multiply_integers(ints), 1);
53}
54
55TEST(accumulate_test, errors) {
56 std::vector<int> ints = {1, 2, 3, 4, 5};
57
58#ifndef NDEBUG
59 EXPECT_THROW(c10::numelements_from_dim(-1, ints), c10::Error);
60#endif
61
62 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
63 EXPECT_THROW(c10::numelements_to_dim(-1, ints), c10::Error);
64 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
65 EXPECT_THROW(c10::numelements_between_dim(-1, 10, ints), c10::Error);
66 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
67 EXPECT_THROW(c10::numelements_between_dim(10, -1, ints), c10::Error);
68
69 EXPECT_EQ(c10::numelements_from_dim(10, ints), 1);
70 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
71 EXPECT_THROW(c10::numelements_to_dim(10, ints), c10::Error);
72 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
73 EXPECT_THROW(c10::numelements_between_dim(10, 4, ints), c10::Error);
74 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
75 EXPECT_THROW(c10::numelements_between_dim(4, 10, ints), c10::Error);
76}
77