1 | #include <c10/util/Exception.h> |
2 | #include <gtest/gtest.h> |
3 | #include <stdexcept> |
4 | |
5 | using c10::Error; |
6 | |
7 | namespace { |
8 | |
9 | template <class Functor> |
10 | inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) { |
11 | try { |
12 | std::forward<Functor>(functor)(); |
13 | } catch (const Error& e) { |
14 | EXPECT_STREQ(e.what_without_backtrace(), expectedMessage); |
15 | return; |
16 | } |
17 | ADD_FAILURE() << "Expected to throw exception with message \"" |
18 | << expectedMessage << "\" but didn't throw" ; |
19 | } |
20 | } // namespace |
21 | |
22 | TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) { |
23 | #ifdef NDEBUG |
24 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
25 | ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false)); |
26 | // Does nothing - `throw ...` should not be evaluated |
27 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
28 | ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
29 | (throw std::runtime_error("I'm throwing..." ), true))); |
30 | #else |
31 | ASSERT_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false), c10::Error); |
32 | ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(true)); |
33 | #endif |
34 | } |
35 | |
36 | // On these platforms there's no assert |
37 | #if !defined(__ANDROID__) && !defined(__APPLE__) && \ |
38 | !(defined(USE_ROCM) && ROCM_VERSION < 40100) |
39 | TEST(ExceptionTest, CUDA_KERNEL_ASSERT) { |
40 | // This function always throws even in NDEBUG mode |
41 | ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert" ); |
42 | } |
43 | #endif |
44 | |
45 | TEST(WarningTest, JustPrintWarning) { |
46 | TORCH_WARN("I'm a warning" ); |
47 | } |
48 | |
49 | TEST(ExceptionTest, ErrorFormatting) { |
50 | expectThrowsEq( |
51 | []() { TORCH_CHECK(false, "This is invalid" ); }, "This is invalid" ); |
52 | |
53 | expectThrowsEq( |
54 | []() { |
55 | try { |
56 | TORCH_CHECK(false, "This is invalid" ); |
57 | } catch (Error& e) { |
58 | TORCH_RETHROW(e, "While checking X" ); |
59 | } |
60 | }, |
61 | "This is invalid (While checking X)" ); |
62 | |
63 | expectThrowsEq( |
64 | []() { |
65 | try { |
66 | try { |
67 | TORCH_CHECK(false, "This is invalid" ); |
68 | } catch (Error& e) { |
69 | TORCH_RETHROW(e, "While checking X" ); |
70 | } |
71 | } catch (Error& e) { |
72 | TORCH_RETHROW(e, "While checking Y" ); |
73 | } |
74 | }, |
75 | R"msg(This is invalid |
76 | While checking X |
77 | While checking Y)msg" ); |
78 | } |
79 | |
// Number of times getAssertionArgument() has been called since the counter
// was last reset.
static int assertionArgumentCounter = 0;

// Bumps the call counter and returns its new value. Used as a message
// argument so that tests can observe how many times it was evaluated.
static int getAssertionArgument() {
  assertionArgumentCounter += 1;
  return assertionArgumentCounter;
}
84 | |
85 | static void failCheck() { |
86 | TORCH_CHECK(false, "message " , getAssertionArgument()); |
87 | } |
88 | |
89 | static void failInternalAssert() { |
90 | TORCH_INTERNAL_ASSERT(false, "message " , getAssertionArgument()); |
91 | } |
92 | |
93 | TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) { |
94 | assertionArgumentCounter = 0; |
95 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
96 | EXPECT_ANY_THROW(failCheck()); |
97 | EXPECT_EQ(assertionArgumentCounter, 1) << "TORCH_CHECK called argument twice" ; |
98 | |
99 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
100 | EXPECT_ANY_THROW(failInternalAssert()); |
101 | EXPECT_EQ(assertionArgumentCounter, 2) |
102 | << "TORCH_INTERNAL_ASSERT called argument twice" ; |
103 | } |
104 | |