1#pragma once
2
#include <ATen/Context.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <functional>
#include <string>
#include <utility>
9
10namespace torch {
11namespace jit {
12
13// Register CudaFuseGraph in custom passes
14struct TORCH_API RegisterCudaFuseGraph
15 : public PassManager<RegisterCudaFuseGraph> {
16 static bool registerPass(bool enabled) {
17 TORCH_WARN(
18 "RegisterCudaFuseGraph::registerPass() is deprecated. "
19 "Please use torch::jit::fuser::cuda::setEnabled().");
20 return fuser::cuda::setEnabled(enabled);
21 }
22
23 static bool isRegistered() {
24 TORCH_WARN(
25 "RegisterCudaFuseGraph::isRegistered() is deprecated. "
26 "Please use torch::jit::fuser::cuda::isEnabled().");
27 return fuser::cuda::isEnabled();
28 }
29};
30
31struct CudaFuserComparisonCallback {
32 using callback_type =
33 std::function<void(const Stack&, const Stack&, const std::string&)>;
34 bool run_fallback;
35 callback_type callback;
36};
37
38TORCH_API CudaFuserComparisonCallback getCudaFuserComparisonCallback();
39TORCH_API void setCudaFuserComparisonCallback(CudaFuserComparisonCallback);
40
41} // namespace jit
42} // namespace torch
43