#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <mutex>
#include <utility>
3
4namespace torch {
5namespace jit {
6
7static CudaFuserComparisonCallback comparison_callback = {false, nullptr};
8static std::mutex comparison_callback_lock;
9
10CudaFuserComparisonCallback getCudaFuserComparisonCallback() {
11 std::lock_guard<std::mutex> guard(comparison_callback_lock);
12 return comparison_callback;
13}
14
15void setCudaFuserComparisonCallback(CudaFuserComparisonCallback callback) {
16 std::lock_guard<std::mutex> guard(comparison_callback_lock);
17 comparison_callback = callback;
18}
19
20} // namespace jit
21} // namespace torch
22