1#pragma once
2
3#include <c10/macros/Export.h>
4#include <torch/csrc/jit/ir/ir.h>
5
6/*
7 * API for querying node compatibility in CudaCodeGen
8 *
9 * It is used in the optimization passes, where the graph is traversed and the
10 * parts that could be handled by CudaCodegen are partitioned and stuffed into
11 * the `attr::Subgraph` of a `prim::CudaFusionGroup`.
12 *
13 * Logic right now is very simple. On top of device placement, we consider a
14 * `Node` compatible when we have a parsing rule for it in our parser.
15 */
16
17namespace torch {
18namespace jit {
19namespace fuser {
20namespace cuda {
21
// Single-node form of the node-compatibility query described in the comment at
// the top of this file. Only the declaration lives here; the definition is in
// another translation unit.
// NOTE(review): presumably returns true when `node` can be handled by the CUDA
// code generator (device placement plus an available parsing rule, per the
// file header comment) -- confirm against the definition.
22TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(const Node* node);
23
24// Two-node overload: considers whether `node` could be fused into `fusion`.
// Both arguments are taken as pointers to const `Node`. Only the declaration
// lives here; the definition is in another translation unit.
// NOTE(review): `fusion` is presumably an existing `prim::CudaFusionGroup`
// node (or a candidate to become one) -- confirm against the definition and
// its callers.
25TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(
26 const Node* fusion,
27 const Node* node);
28
29} // namespace cuda
30} // namespace fuser
31} // namespace jit
32} // namespace torch
33