1#pragma once
2#include <c10/util/Exception.h>
3
4#include <mutex>
5#include <vector>
6
7namespace torch {
8namespace autograd {
9namespace utils {
10
11// Warning handler for multi-threaded contexts. Gather warnings from
12// all threads into a single queue, then process together at the end
13// in the main thread.
14class DelayWarningHandler : public at::WarningHandler {
15 public:
16 ~DelayWarningHandler() override = default;
17 void replay_warnings();
18
19 private:
20 void process(const c10::Warning& warning) override;
21
22 std::vector<c10::Warning> warnings_;
23 std::mutex mutex_;
24};
25
26} // namespace utils
27} // namespace autograd
28} // namespace torch
29