1 | #pragma once |
2 | |
3 | #if defined(USE_GTEST) |
4 | #include <gtest/gtest.h> |
5 | #include <test/cpp/common/support.h> |
6 | #else |
7 | #include <cmath> |
8 | #include "c10/util/Exception.h" |
9 | #include "test/cpp/tensorexpr/gtest_assert_float_eq.h" |
// Stand-ins for the gtest comparison assertions, used when USE_GTEST is not
// defined. Each macro builds the comparison expression and hands it, together
// with any optional trailing message arguments, to TORCH_INTERNAL_ASSERT.

// Equality and inequality.
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)

// Floating-point equality; the decision is delegated to AlmostEquals.
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)

// Strict and non-strict ordering.
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
18 | |
// Asserts that x and y differ by at most the absolute tolerance a. The bound
// is inclusive: a difference exactly equal to a passes, which is how gtest's
// ASSERT_NEAR behaves, so both build modes agree (for example a tolerance of 0
// applied to identical values). Optional trailing arguments are forwarded to
// TORCH_INTERNAL_ASSERT as the failure message.
#define ASSERT_NEAR(x, y, a, ...) \
  TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) <= (a), __VA_ARGS__)
21 | |
// ASSERT_TRUE is a direct alias of TORCH_INTERNAL_ASSERT, so it takes a
// condition followed by optional message arguments.
#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
// ASSERT_FALSE asserts the negated condition. It accepts the same optional
// trailing message arguments as the other assertion macros in this header;
// single-argument calls keep working.
#define ASSERT_FALSE(x, ...) ASSERT_TRUE(!(x), __VA_ARGS__)
// Asserts that evaluating `statement` throws a std::exception whose what()
// text contains `substring`.
//
// The "did it throw at all" check runs after the try/catch rather than inside
// the try block. If it ran inside, the exception raised by the failed
// assertion would itself be caught by the handler below and matched against
// `substring`, producing a misleading failure or even a false pass.
// Exceptions that do not derive from std::exception are not caught here and
// propagate to the caller.
#define ASSERT_THROWS_WITH(statement, substring)                           \
  {                                                                        \
    bool assert_throws_with_threw_ = false;                                \
    try {                                                                  \
      (void)statement;                                                     \
    } catch (const std::exception& e) {                                    \
      assert_throws_with_threw_ = true;                                    \
      ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
    }                                                                      \
    ASSERT_TRUE(assert_throws_with_threw_, "statement did not throw");     \
  }
// Asserts that evaluating `statement` throws an exception of any type.
// catch (...) is used so that exceptions not derived from std::exception are
// also counted, as the macro name promises and as gtest's ASSERT_ANY_THROW
// does.
#define ASSERT_ANY_THROW(statement)                              \
  {                                                              \
    bool assert_any_throw_threw_ = false;                        \
    try {                                                        \
      (void)statement;                                           \
    } catch (...) {                                              \
      assert_any_throw_threw_ = true;                            \
    }                                                            \
    ASSERT_TRUE(assert_any_throw_threw_, "statement did not throw"); \
  }
41 | |
42 | #endif // defined(USE_GTEST) |
43 | |
44 | namespace torch { |
45 | namespace jit { |
46 | namespace tensorexpr { |
47 | |
// Asserts that v1 and v2 have the same number of elements and that each pair
// of corresponding elements passes ASSERT_NEAR with `threshold`.
//
//   v1, v2     sequences compared element by element
//   threshold  tolerance handed to ASSERT_NEAR
//   name       optional label included in the failure diagnostics so the
//              failing comparison can be identified
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
#if defined(USE_GTEST)
  // gtest assertions receive their diagnostic text through operator<<.
  ASSERT_EQ(v1.size(), v2.size())
      << "ExpectAllNear '" << name << "': size mismatch";
  for (size_t i = 0; i < v1.size(); ++i) {
    ASSERT_NEAR(v1[i], v2[i], threshold)
        << "ExpectAllNear '" << name << "': mismatch at index " << i;
  }
#else
  // The non-gtest macros take the message as extra trailing arguments.
  ASSERT_EQ(v1.size(), v2.size(), "ExpectAllNear '", name, "': size mismatch");
  for (size_t i = 0; i < v1.size(); ++i) {
    ASSERT_NEAR(
        v1[i],
        v2[i],
        threshold,
        "ExpectAllNear '",
        name,
        "': mismatch at index ",
        i);
  }
#endif
}
59 | |
// Asserts that every element of vec passes ASSERT_NEAR against the single
// reference value `val` with `threshold`.
//
//   vec        sequence whose elements are checked
//   val        reference value each element is compared with
//   threshold  tolerance handed to ASSERT_NEAR
//   name       optional label included in the failure diagnostics so the
//              failing comparison can be identified
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  for (size_t i = 0; i < vec.size(); ++i) {
#if defined(USE_GTEST)
    // gtest assertions receive their diagnostic text through operator<<.
    ASSERT_NEAR(vec[i], val, threshold)
        << "ExpectAllNear '" << name << "': mismatch at index " << i;
#else
    // The non-gtest macro takes the message as extra trailing arguments.
    ASSERT_NEAR(
        vec[i],
        val,
        threshold,
        "ExpectAllNear '",
        name,
        "': mismatch at index ",
        i);
#endif
  }
}
70 | |
// Asserts, one element at a time via ASSERT_EQ, that every entry of vec
// compares equal to val. An empty vector performs no assertions.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  const auto last = vec.end();
  for (auto it = vec.begin(); it != last; ++it) {
    auto const& elt = *it;
    ASSERT_EQ(elt, val);
  }
}
77 | |
// Asserts that v1 and v2 hold the same number of elements and that the
// elements at each index compare equal under ASSERT_EQ. The size check runs
// first, before any element is read.
template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  size_t i = 0;
  while (i < v1.size()) {
    ASSERT_EQ(v1[i], v2[i]);
    ++i;
  }
}
85 | } // namespace tensorexpr |
86 | } // namespace jit |
87 | } // namespace torch |
88 | |