1 | #include <c10/util/Backtrace.h> |
2 | #include <c10/util/Exception.h> |
3 | #include <c10/util/Logging.h> |
4 | #include <c10/util/Type.h> |
5 | |
6 | #include <iostream> |
7 | #include <numeric> |
8 | #include <sstream> |
9 | #include <string> |
10 | #include <utility> |
11 | |
12 | namespace c10 { |
13 | |
14 | Error::Error(std::string msg, std::string backtrace, const void* caller) |
15 | : msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) { |
16 | refresh_what(); |
17 | } |
18 | |
19 | // PyTorch-style error message |
20 | // Error::Error(SourceLocation source_location, const std::string& msg) |
21 | // NB: This is defined in Logging.cpp for access to GetFetchStackTrace |
22 | |
23 | // Caffe2-style error message |
24 | Error::Error( |
25 | const char* file, |
26 | const uint32_t line, |
27 | const char* condition, |
28 | const std::string& msg, |
29 | const std::string& backtrace, |
30 | const void* caller) |
31 | : Error( |
32 | str("[enforce fail at " , |
33 | detail::StripBasename(file), |
34 | ":" , |
35 | line, |
36 | "] " , |
37 | condition, |
38 | ". " , |
39 | msg), |
40 | backtrace, |
41 | caller) {} |
42 | |
43 | std::string Error::compute_what(bool include_backtrace) const { |
44 | std::ostringstream oss; |
45 | |
46 | oss << msg_; |
47 | |
48 | if (context_.size() == 1) { |
49 | // Fold error and context in one line |
50 | oss << " (" << context_[0] << ")" ; |
51 | } else { |
52 | for (const auto& c : context_) { |
53 | oss << "\n " << c; |
54 | } |
55 | } |
56 | |
57 | if (include_backtrace) { |
58 | oss << "\n" << backtrace_; |
59 | } |
60 | |
61 | return oss.str(); |
62 | } |
63 | |
64 | void Error::refresh_what() { |
65 | what_ = compute_what(/*include_backtrace*/ true); |
66 | what_without_backtrace_ = compute_what(/*include_backtrace*/ false); |
67 | } |
68 | |
69 | void Error::add_context(std::string new_msg) { |
70 | context_.push_back(std::move(new_msg)); |
71 | // TODO: Calling add_context O(n) times has O(n^2) cost. We can fix |
72 | // this perf problem by populating the fields lazily... if this ever |
73 | // actually is a problem. |
74 | // NB: If you do fix this, make sure you do it in a thread safe way! |
75 | // what() is almost certainly expected to be thread safe even when |
76 | // accessed across multiple threads |
77 | refresh_what(); |
78 | } |
79 | |
80 | namespace detail { |
81 | |
82 | void torchCheckFail( |
83 | const char* func, |
84 | const char* file, |
85 | uint32_t line, |
86 | const std::string& msg) { |
87 | throw ::c10::Error({func, file, line}, msg); |
88 | } |
89 | |
90 | void torchCheckFail( |
91 | const char* func, |
92 | const char* file, |
93 | uint32_t line, |
94 | const char* msg) { |
95 | throw ::c10::Error({func, file, line}, msg); |
96 | } |
97 | |
98 | void torchInternalAssertFail( |
99 | const char* func, |
100 | const char* file, |
101 | uint32_t line, |
102 | const char* condMsg, |
103 | const char* userMsg) { |
104 | torchCheckFail(func, file, line, c10::str(condMsg, userMsg)); |
105 | } |
106 | |
107 | // This should never be called. It is provided in case of compilers |
108 | // that don't do any dead code stripping in debug builds. |
109 | void torchInternalAssertFail( |
110 | const char* func, |
111 | const char* file, |
112 | uint32_t line, |
113 | const char* condMsg, |
114 | const std::string& userMsg) { |
115 | torchCheckFail(func, file, line, c10::str(condMsg, userMsg)); |
116 | } |
117 | |
118 | } // namespace detail |
119 | |
120 | namespace WarningUtils { |
121 | |
122 | namespace { |
123 | WarningHandler* getBaseHandler() { |
124 | static WarningHandler base_warning_handler_ = WarningHandler(); |
125 | return &base_warning_handler_; |
126 | }; |
127 | |
128 | class ThreadWarningHandler { |
129 | public: |
130 | ThreadWarningHandler() = delete; |
131 | |
132 | static WarningHandler* get_handler() { |
133 | if (!warning_handler_) { |
134 | warning_handler_ = getBaseHandler(); |
135 | } |
136 | return warning_handler_; |
137 | } |
138 | |
139 | static void set_handler(WarningHandler* handler) { |
140 | warning_handler_ = handler; |
141 | } |
142 | |
143 | private: |
144 | static thread_local WarningHandler* warning_handler_; |
145 | }; |
146 | |
147 | thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr; |
148 | |
149 | } // namespace |
150 | |
151 | void set_warning_handler(WarningHandler* handler) noexcept(true) { |
152 | ThreadWarningHandler::set_handler(handler); |
153 | } |
154 | |
155 | WarningHandler* get_warning_handler() noexcept(true) { |
156 | return ThreadWarningHandler::get_handler(); |
157 | } |
158 | |
159 | bool warn_always = false; |
160 | |
161 | void set_warnAlways(bool setting) noexcept(true) { |
162 | warn_always = setting; |
163 | } |
164 | |
165 | bool get_warnAlways() noexcept(true) { |
166 | return warn_always; |
167 | } |
168 | |
169 | WarnAlways::WarnAlways(bool setting /*=true*/) |
170 | : prev_setting(get_warnAlways()) { |
171 | set_warnAlways(setting); |
172 | } |
173 | |
174 | WarnAlways::~WarnAlways() { |
175 | set_warnAlways(prev_setting); |
176 | } |
177 | |
178 | } // namespace WarningUtils |
179 | |
180 | void warn(const Warning& warning) { |
181 | WarningUtils::ThreadWarningHandler::get_handler()->process(warning); |
182 | } |
183 | |
184 | Warning::Warning( |
185 | warning_variant_t type, |
186 | const SourceLocation& source_location, |
187 | std::string msg, |
188 | const bool verbatim) |
189 | : type_(type), |
190 | source_location_(source_location), |
191 | msg_(std::move(msg)), |
192 | verbatim_(verbatim) {} |
193 | |
194 | Warning::Warning( |
195 | warning_variant_t type, |
196 | SourceLocation source_location, |
197 | detail::CompileTimeEmptyString msg, |
198 | const bool verbatim) |
199 | : Warning(type, source_location, "" , verbatim) {} |
200 | |
201 | Warning::Warning( |
202 | warning_variant_t type, |
203 | SourceLocation source_location, |
204 | const char* msg, |
205 | const bool verbatim) |
206 | : type_(type), |
207 | source_location_(source_location), |
208 | msg_(std::string(msg)), |
209 | verbatim_(verbatim) {} |
210 | |
211 | Warning::warning_variant_t Warning::type() const { |
212 | return type_; |
213 | } |
214 | |
215 | const SourceLocation& Warning::source_location() const { |
216 | return source_location_; |
217 | } |
218 | |
219 | const std::string& Warning::msg() const { |
220 | return msg_; |
221 | } |
222 | |
223 | bool Warning::verbatim() const { |
224 | return verbatim_; |
225 | } |
226 | |
227 | void WarningHandler::process(const Warning& warning) { |
228 | LOG_AT_FILE_LINE( |
229 | WARNING, warning.source_location().file, warning.source_location().line) |
230 | << "Warning: " << warning.msg() << " (function " |
231 | << warning.source_location().function << ")" ; |
232 | } |
233 | |
234 | std::string GetExceptionString(const std::exception& e) { |
235 | #ifdef __GXX_RTTI |
236 | return demangle(typeid(e).name()) + ": " + e.what(); |
237 | #else |
238 | return std::string("Exception (no RTTI available): " ) + e.what(); |
239 | #endif // __GXX_RTTI |
240 | } |
241 | |
242 | } // namespace c10 |
243 | |