1 | #include <torch/csrc/jit/passes/cuda_graph_fuser.h> |
---|---|
2 | #include <mutex> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | static CudaFuserComparisonCallback comparison_callback = {false, nullptr}; |
8 | static std::mutex comparison_callback_lock; |
9 | |
10 | CudaFuserComparisonCallback getCudaFuserComparisonCallback() { |
11 | std::lock_guard<std::mutex> guard(comparison_callback_lock); |
12 | return comparison_callback; |
13 | } |
14 | |
15 | void setCudaFuserComparisonCallback(CudaFuserComparisonCallback callback) { |
16 | std::lock_guard<std::mutex> guard(comparison_callback_lock); |
17 | comparison_callback = callback; |
18 | } |
19 | |
20 | } // namespace jit |
21 | } // namespace torch |
22 |