1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | |
/*
 * This file handles compilation and execution of a CudaFusionGroup.
 *
 * A CudaFusionGroup node comes with `attr::Subgraph` containing the computation
 * graph. We compile the graph to generate a CUDA function and cache it in a
 * registry. Cached kernels are reused across nodes sharing an identical graph.
 *
 * After compilation, we assign the key of the cached kernel to the node as the
 * integer attribute `attr::cache_id`.
 */
16 | |
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Get `fusion_node` ready for execution:
// finds or compiles a `CudaKernel` for the graph stored in `attr::Subgraph`,
// then records the cache key on `fusion_node` as the integer attribute
// `attr::cache_id`.
TORCH_CUDA_CU_API void compileCudaFusionGroup(Node* fusion_node);

// Execute `fusion_node`.
// Current protocol is that the function allocates output tensors and appends
// them to `stack` after execution.
// TODO: support shape inference. Right now we only handle static shapes.
TORCH_CUDA_CU_API void runCudaFusionGroup(
    const Node* fusion_node,
    Stack& stack);

// Run the CUDA fusion pass over `graph` (in-place transformation).
// NOTE(review): presumably groups fusible ops into CudaFusionGroup nodes —
// confirm against the implementation.
TORCH_CUDA_CU_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
41 | |