#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Type.h>

#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>
11
12namespace c10 {
13
14Error::Error(std::string msg, std::string backtrace, const void* caller)
15 : msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
16 refresh_what();
17}
18
19// PyTorch-style error message
20// Error::Error(SourceLocation source_location, const std::string& msg)
21// NB: This is defined in Logging.cpp for access to GetFetchStackTrace
22
23// Caffe2-style error message
24Error::Error(
25 const char* file,
26 const uint32_t line,
27 const char* condition,
28 const std::string& msg,
29 const std::string& backtrace,
30 const void* caller)
31 : Error(
32 str("[enforce fail at ",
33 detail::StripBasename(file),
34 ":",
35 line,
36 "] ",
37 condition,
38 ". ",
39 msg),
40 backtrace,
41 caller) {}
42
43std::string Error::compute_what(bool include_backtrace) const {
44 std::ostringstream oss;
45
46 oss << msg_;
47
48 if (context_.size() == 1) {
49 // Fold error and context in one line
50 oss << " (" << context_[0] << ")";
51 } else {
52 for (const auto& c : context_) {
53 oss << "\n " << c;
54 }
55 }
56
57 if (include_backtrace) {
58 oss << "\n" << backtrace_;
59 }
60
61 return oss.str();
62}
63
64void Error::refresh_what() {
65 what_ = compute_what(/*include_backtrace*/ true);
66 what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
67}
68
69void Error::add_context(std::string new_msg) {
70 context_.push_back(std::move(new_msg));
71 // TODO: Calling add_context O(n) times has O(n^2) cost. We can fix
72 // this perf problem by populating the fields lazily... if this ever
73 // actually is a problem.
74 // NB: If you do fix this, make sure you do it in a thread safe way!
75 // what() is almost certainly expected to be thread safe even when
76 // accessed across multiple threads
77 refresh_what();
78}
79
80namespace detail {
81
82void torchCheckFail(
83 const char* func,
84 const char* file,
85 uint32_t line,
86 const std::string& msg) {
87 throw ::c10::Error({func, file, line}, msg);
88}
89
90void torchCheckFail(
91 const char* func,
92 const char* file,
93 uint32_t line,
94 const char* msg) {
95 throw ::c10::Error({func, file, line}, msg);
96}
97
98void torchInternalAssertFail(
99 const char* func,
100 const char* file,
101 uint32_t line,
102 const char* condMsg,
103 const char* userMsg) {
104 torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
105}
106
107// This should never be called. It is provided in case of compilers
108// that don't do any dead code stripping in debug builds.
109void torchInternalAssertFail(
110 const char* func,
111 const char* file,
112 uint32_t line,
113 const char* condMsg,
114 const std::string& userMsg) {
115 torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
116}
117
118} // namespace detail
119
120namespace WarningUtils {
121
122namespace {
123WarningHandler* getBaseHandler() {
124 static WarningHandler base_warning_handler_ = WarningHandler();
125 return &base_warning_handler_;
126};
127
128class ThreadWarningHandler {
129 public:
130 ThreadWarningHandler() = delete;
131
132 static WarningHandler* get_handler() {
133 if (!warning_handler_) {
134 warning_handler_ = getBaseHandler();
135 }
136 return warning_handler_;
137 }
138
139 static void set_handler(WarningHandler* handler) {
140 warning_handler_ = handler;
141 }
142
143 private:
144 static thread_local WarningHandler* warning_handler_;
145};
146
147thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
148
149} // namespace
150
151void set_warning_handler(WarningHandler* handler) noexcept(true) {
152 ThreadWarningHandler::set_handler(handler);
153}
154
155WarningHandler* get_warning_handler() noexcept(true) {
156 return ThreadWarningHandler::get_handler();
157}
158
159bool warn_always = false;
160
161void set_warnAlways(bool setting) noexcept(true) {
162 warn_always = setting;
163}
164
165bool get_warnAlways() noexcept(true) {
166 return warn_always;
167}
168
169WarnAlways::WarnAlways(bool setting /*=true*/)
170 : prev_setting(get_warnAlways()) {
171 set_warnAlways(setting);
172}
173
174WarnAlways::~WarnAlways() {
175 set_warnAlways(prev_setting);
176}
177
178} // namespace WarningUtils
179
180void warn(const Warning& warning) {
181 WarningUtils::ThreadWarningHandler::get_handler()->process(warning);
182}
183
184Warning::Warning(
185 warning_variant_t type,
186 const SourceLocation& source_location,
187 std::string msg,
188 const bool verbatim)
189 : type_(type),
190 source_location_(source_location),
191 msg_(std::move(msg)),
192 verbatim_(verbatim) {}
193
194Warning::Warning(
195 warning_variant_t type,
196 SourceLocation source_location,
197 detail::CompileTimeEmptyString msg,
198 const bool verbatim)
199 : Warning(type, source_location, "", verbatim) {}
200
201Warning::Warning(
202 warning_variant_t type,
203 SourceLocation source_location,
204 const char* msg,
205 const bool verbatim)
206 : type_(type),
207 source_location_(source_location),
208 msg_(std::string(msg)),
209 verbatim_(verbatim) {}
210
211Warning::warning_variant_t Warning::type() const {
212 return type_;
213}
214
215const SourceLocation& Warning::source_location() const {
216 return source_location_;
217}
218
219const std::string& Warning::msg() const {
220 return msg_;
221}
222
223bool Warning::verbatim() const {
224 return verbatim_;
225}
226
227void WarningHandler::process(const Warning& warning) {
228 LOG_AT_FILE_LINE(
229 WARNING, warning.source_location().file, warning.source_location().line)
230 << "Warning: " << warning.msg() << " (function "
231 << warning.source_location().function << ")";
232}
233
234std::string GetExceptionString(const std::exception& e) {
235#ifdef __GXX_RTTI
236 return demangle(typeid(e).name()) + ": " + e.what();
237#else
238 return std::string("Exception (no RTTI available): ") + e.what();
239#endif // __GXX_RTTI
240}
241
242} // namespace c10
243