1 | #pragma once |
---|---|
2 | |
#include <ATen/Context.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <functional>
#include <string>
#include <utility>
8 | #include <utility> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | // Register CudaFuseGraph in custom passes |
14 | struct TORCH_API RegisterCudaFuseGraph |
15 | : public PassManager<RegisterCudaFuseGraph> { |
16 | static bool registerPass(bool enabled) { |
17 | TORCH_WARN( |
18 | "RegisterCudaFuseGraph::registerPass() is deprecated. " |
19 | "Please use torch::jit::fuser::cuda::setEnabled()."); |
20 | return fuser::cuda::setEnabled(enabled); |
21 | } |
22 | |
23 | static bool isRegistered() { |
24 | TORCH_WARN( |
25 | "RegisterCudaFuseGraph::isRegistered() is deprecated. " |
26 | "Please use torch::jit::fuser::cuda::isEnabled()."); |
27 | return fuser::cuda::isEnabled(); |
28 | } |
29 | }; |
30 | |
31 | struct CudaFuserComparisonCallback { |
32 | using callback_type = |
33 | std::function<void(const Stack&, const Stack&, const std::string&)>; |
34 | bool run_fallback; |
35 | callback_type callback; |
36 | }; |
37 | |
38 | TORCH_API CudaFuserComparisonCallback getCudaFuserComparisonCallback(); |
39 | TORCH_API void setCudaFuserComparisonCallback(CudaFuserComparisonCallback); |
40 | |
41 | } // namespace jit |
42 | } // namespace torch |
43 |