1 | #pragma once |
2 | #include <c10/util/Exception.h> |
3 | |
4 | #include <mutex> |
5 | #include <vector> |
6 | |
7 | namespace torch { |
8 | namespace autograd { |
9 | namespace utils { |
10 | |
11 | // Warning handler for multi-threaded contexts. Gather warnings from |
12 | // all threads into a single queue, then process together at the end |
13 | // in the main thread. |
14 | class DelayWarningHandler : public at::WarningHandler { |
15 | public: |
16 | ~DelayWarningHandler() override = default; |
17 | void replay_warnings(); |
18 | |
19 | private: |
20 | void process(const c10::Warning& warning) override; |
21 | |
22 | std::vector<c10::Warning> warnings_; |
23 | std::mutex mutex_; |
24 | }; |
25 | |
26 | } // namespace utils |
27 | } // namespace autograd |
28 | } // namespace torch |
29 | |