#pragma once

#include <c10/macros/Export.h>
#include <torch/csrc/jit/ir/ir.h>

/*
 * This file handles compilation and execution of a CudaFusionGroup.
 *
 * A CudaFusionGroup node comes with `attr::Subgraph` containing the
 * computation graph. We compile the graph to generate a CUDA kernel and cache
 * it in a registry. Kernels are cached and reused across nodes that share an
 * identical graph.
 *
 * After compilation, the key of the cached kernel is assigned to the node as
 * an integer attribute `attr::cache_id`.
 */

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Get fusion_node ready for execution.
// Finds or compiles a `CudaKernel` for the graph stored in `attr::Subgraph`.
// This function assigns `attr::cache_id` to `fusion_node`.
TORCH_CUDA_CU_API void compileCudaFusionGroup(Node* fusion_node);
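
// A minimal usage sketch (illustrative only, not part of this header's API):
// assuming `fusion_node` is a CudaFusionGroup node, compilation followed by a
// lookup of the assigned key could look roughly like:
//
//   compileCudaFusionGroup(fusion_node);
//   // the compiled kernel is now cached; its key is stored on the node
//   int64_t key = fusion_node->i(attr::cache_id);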

// Execute fusion_node.
// The current protocol is that the function allocates the output tensors and
// appends them to `stack` after execution.
// TODO: support shape inference. Right now we only handle static shapes.
TORCH_CUDA_CU_API void runCudaFusionGroup(
    const Node* fusion_node,
    Stack& stack);
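
// A minimal usage sketch (illustrative only): assuming `fusion_node` has been
// compiled via compileCudaFusionGroup() and `inputs` is a hypothetical vector
// of input IValues, execution pushes the inputs onto a Stack and reads the
// outputs back off it:
//
//   Stack stack;
//   for (const auto& input : inputs) {
//     stack.push_back(input);
//   }
//   runCudaFusionGroup(fusion_node, stack);
//   // outputs were allocated by the call and appended to `stack`
//   at::Tensor out = stack.back().toTensor(); // assuming a single tensor output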

// Run the CUDA fusion pass over `graph`: fusible regions are grouped into
// CudaFusionGroup nodes, which are later compiled and executed by the
// functions declared above.
TORCH_CUDA_CU_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
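
// A minimal usage sketch (illustrative only): running the fusion pass from a
// custom pass pipeline; `module` is a hypothetical scripted module whose
// forward graph is eligible for fusion:
//
//   std::shared_ptr<Graph> graph = module.get_method("forward").graph();
//   CudaFuseGraph(graph);
//   // fusible regions are now CudaFusionGroup nodes, each compiled and run
//   // through the functions declared above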

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch