1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | |
6 | /* |
 * API for querying node compatibility with CudaCodeGen.
8 | * |
 * It is used in the optimization passes, where the graph is traversed and the
 * parts that can be handled by CudaCodeGen are partitioned and placed in the
 * `attr::Subgraph` of a `prim::CudaFusionGroup`.
12 | * |
13 | * Logic right now is very simple. On top of device placement, we consider a |
14 | * `Node` compatible when we have a parsing rule for it in our parser. |
15 | */ |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | namespace fuser { |
20 | namespace cuda { |
21 | |
22 | TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(const Node* node); |
23 | |
24 | // consider if `node` could be fused into `fusion` |
25 | TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup( |
26 | const Node* fusion, |
27 | const Node* node); |
28 | |
29 | } // namespace cuda |
30 | } // namespace fuser |
31 | } // namespace jit |
32 | } // namespace torch |
33 | |