1#pragma once
2
3#if defined(USE_GTEST)
4#include <gtest/gtest.h>
5#include <test/cpp/common/support.h>
6#else
7#include <cmath>
8#include "c10/util/Exception.h"
9#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
// Fallback assertion macros used when the tests are built without gtest.
// Each mirrors the gtest macro of the same name, but a failed check is
// reported through TORCH_INTERNAL_ASSERT. Any trailing arguments are
// forwarded to TORCH_INTERNAL_ASSERT as extra message parts.
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)

// Passes when |x - y| does not exceed `a`, matching gtest's ASSERT_NEAR
// (which uses <=, not <). The operands are converted to double before the
// subtraction, as gtest does, so unsigned inputs cannot wrap around.
#define ASSERT_NEAR(x, y, a, ...)                                   \
  TORCH_INTERNAL_ASSERT(                                            \
      std::fabs(static_cast<double>(x) - static_cast<double>(y)) <= \
          static_cast<double>(a),                                   \
      __VA_ARGS__)

#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))

// Passes only if evaluating `statement` throws a std::exception whose
// what() contains `substring`. The "did it throw" check runs after the
// try/catch. When it sat inside the try, its own failure was caught and
// matched against `substring`, so a non-throwing statement could pass.
// Locals carry macro-specific names so they do not shadow identifiers
// used inside `statement`.
#define ASSERT_THROWS_WITH(statement, substring)                       \
  {                                                                    \
    bool assert_throws_with_threw_ = false;                            \
    std::string assert_throws_with_what_;                              \
    try {                                                              \
      (void)statement;                                                 \
    } catch (const std::exception& e) {                                \
      assert_throws_with_threw_ = true;                                \
      assert_throws_with_what_ = e.what();                             \
    }                                                                  \
    ASSERT_TRUE(assert_throws_with_threw_);                            \
    ASSERT_NE(                                                         \
        assert_throws_with_what_.find(substring), std::string::npos);  \
  }

// Passes if evaluating `statement` throws an exception of any type, like
// gtest's ASSERT_ANY_THROW.
#define ASSERT_ANY_THROW(statement)         \
  {                                         \
    bool assert_any_throw_threw_ = false;   \
    try {                                   \
      (void)statement;                      \
    } catch (...) {                         \
      assert_any_throw_threw_ = true;       \
    }                                       \
    ASSERT_TRUE(assert_any_throw_threw_);   \
  }
41
42#endif // defined(USE_GTEST)
43
namespace torch {
namespace jit {
namespace tensorexpr {

/// Checks that `v1` and `v2` have the same length, then checks each pair of
/// same-index elements with ASSERT_NEAR using `threshold` as the tolerance.
///
/// @param v1        first vector
/// @param v2        second vector, compared position-by-position with v1
/// @param threshold tolerance forwarded to ASSERT_NEAR
/// @param name      label for the comparison (currently not used)
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t n = v1.size();
  for (size_t i = 0; i != n; ++i) {
    ASSERT_NEAR(v1[i], v2[i], threshold);
  }
}

/// Checks every element of `vec` against the single value `val` with
/// ASSERT_NEAR using `threshold` as the tolerance. An empty `vec` passes.
///
/// @param vec       vector whose elements are checked
/// @param val       reference value every element is compared to
/// @param threshold tolerance forwarded to ASSERT_NEAR
/// @param name      label for the comparison (currently not used)
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  size_t i = 0;
  while (i < vec.size()) {
    ASSERT_NEAR(vec[i], val, threshold);
    ++i;
  }
}

/// Checks with ASSERT_EQ that every element of `vec` equals `val`.
/// An empty `vec` passes.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  for (size_t i = 0; i < vec.size(); ++i) {
    const auto& elt = vec[i];
    ASSERT_EQ(elt, val);
  }
}

/// Checks that `v1` and `v2` have the same length and that each pair of
/// same-index elements compares equal with ASSERT_EQ.
template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t n = v1.size();
  for (size_t i = 0; i != n; ++i) {
    ASSERT_EQ(v1[i], v2[i]);
  }
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
88