1 | #pragma once |
---|---|
2 | #include <manager.h> |
3 | #include <transform_view.h> |
4 | |
5 | #include <c10/macros/Export.h> |
6 | #include <torch/csrc/jit/codegen/cuda/interface.h> |
7 | #include <torch/csrc/jit/ir/ir.h> |
8 | #include <torch/csrc/jit/passes/pass_manager.h> |
9 | #include <torch/csrc/jit/runtime/profiling_record.h> |
10 | |
/*
 * This file contains the APIs for the CUDA fuser.
 *
 * We use an empty static struct to hold the function pointers, which are
 * registered separately. This is to support CPU-only compilation.
 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp
 */
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | namespace fuser { |
22 | namespace cuda { |
23 | |
24 | TORCH_CUDA_CU_API bool complyWith( |
25 | const at::Tensor& tensor, |
26 | const c10::TensorTypePtr& guard_tensor_type); |
27 | |
28 | struct TORCH_CUDA_CU_API NVFuserPassManager |
29 | : public PassManager<NVFuserPassManager> { |
30 | static bool registerPass(bool enabled) { |
31 | bool old_value = PassManager::isRegistered(); |
32 | if (enabled) { |
33 | PassManager::registerPass(fuseGraph); |
34 | } else { |
35 | PassManager::clearPass(); |
36 | } |
37 | return old_value; |
38 | } |
39 | |
40 | static bool isRegistered() { |
41 | return PassManager::isRegistered(); |
42 | } |
43 | }; |
44 | |
45 | } // namespace cuda |
46 | } // namespace fuser |
47 | } // namespace jit |
48 | } // namespace torch |
49 |