1#include <c10/util/Exception.h>
2#include <gtest/gtest.h>
3#include <stdexcept>
4
5using c10::Error;
6
7namespace {
8
9template <class Functor>
10inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
11 try {
12 std::forward<Functor>(functor)();
13 } catch (const Error& e) {
14 EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
15 return;
16 }
17 ADD_FAILURE() << "Expected to throw exception with message \""
18 << expectedMessage << "\" but didn't throw";
19}
20} // namespace
21
22TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
23#ifdef NDEBUG
24 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
25 ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false));
26 // Does nothing - `throw ...` should not be evaluated
27 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
28 ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
29 (throw std::runtime_error("I'm throwing..."), true)));
30#else
31 ASSERT_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false), c10::Error);
32 ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(true));
33#endif
34}
35
36// On these platforms there's no assert
37#if !defined(__ANDROID__) && !defined(__APPLE__) && \
38 !(defined(USE_ROCM) && ROCM_VERSION < 40100)
39TEST(ExceptionTest, CUDA_KERNEL_ASSERT) {
40 // This function always throws even in NDEBUG mode
41 ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert");
42}
43#endif
44
// Smoke test: emitting a warning through TORCH_WARN must not throw or crash.
// NOTE(review): nothing is asserted about the warning's text or where it is
// delivered; this only exercises the code path.
TEST(WarningTest, JustPrintWarning) {
  TORCH_WARN("I'm a warning");
}
48
49TEST(ExceptionTest, ErrorFormatting) {
50 expectThrowsEq(
51 []() { TORCH_CHECK(false, "This is invalid"); }, "This is invalid");
52
53 expectThrowsEq(
54 []() {
55 try {
56 TORCH_CHECK(false, "This is invalid");
57 } catch (Error& e) {
58 TORCH_RETHROW(e, "While checking X");
59 }
60 },
61 "This is invalid (While checking X)");
62
63 expectThrowsEq(
64 []() {
65 try {
66 try {
67 TORCH_CHECK(false, "This is invalid");
68 } catch (Error& e) {
69 TORCH_RETHROW(e, "While checking X");
70 }
71 } catch (Error& e) {
72 TORCH_RETHROW(e, "While checking Y");
73 }
74 },
75 R"msg(This is invalid
76 While checking X
77 While checking Y)msg");
78}
79
// Number of times getAssertionArgument() has been evaluated since the last
// reset; tests set it back to 0 before measuring.
static int assertionArgumentCounter = 0;

// Bumps assertionArgumentCounter and returns the updated count. Passed as a
// message argument to assertion macros so a test can detect whether a macro
// evaluates its arguments more than once.
static int getAssertionArgument() {
  assertionArgumentCounter += 1;
  return assertionArgumentCounter;
}
84
// Triggers a failing TORCH_CHECK whose message includes a call to
// getAssertionArgument(). The call must stay inside the macro's argument
// list: DontCallArgumentFunctionsTwiceOnFailure counts how many times the
// macro evaluates it, and hoisting it into a local would make that test
// vacuous.
static void failCheck() {
  TORCH_CHECK(false, "message ", getAssertionArgument());
}
88
// Triggers a failing TORCH_INTERNAL_ASSERT whose message includes a call to
// getAssertionArgument(). As in failCheck(), the call must stay inside the
// macro's arguments so the number of evaluations can be observed through
// assertionArgumentCounter.
static void failInternalAssert() {
  TORCH_INTERNAL_ASSERT(false, "message ", getAssertionArgument());
}
92
93TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {
94 assertionArgumentCounter = 0;
95 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
96 EXPECT_ANY_THROW(failCheck());
97 EXPECT_EQ(assertionArgumentCounter, 1) << "TORCH_CHECK called argument twice";
98
99 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
100 EXPECT_ANY_THROW(failInternalAssert());
101 EXPECT_EQ(assertionArgumentCounter, 2)
102 << "TORCH_INTERNAL_ASSERT called argument twice";
103}
104