#pragma once

#include <cstddef>
#include <functional>
#include <limits>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>

5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | TORCH_API bool canFuseOnCPULegacy(); |
9 | TORCH_API void overrideCanFuseOnCPULegacy(bool value); |
10 | |
11 | // NB: Be sure to run DCE before fusion, because dead instructions |
12 | // can prevent fusion opportunities from being exploited. |
13 | // On Windows will noop, NYI |
14 | TORCH_API void FuseGraph( |
15 | std::shared_ptr<Graph>& graph, |
16 | bool strict_fuser_check = false); |
17 | |
18 | // \brief Custom fusion pass using a node-level callback to |
19 | // determine the inclusion of nodes in a subgraph. |
20 | // |
21 | // This helper omits aliased inputs and fusion across control flow |
22 | // boundaries. |
23 | // |
24 | // \arg graph The graph to be modified in-place |
25 | // \arg is_fusable A callback run on each fusable node in the graph. |
26 | // \arg kind The label given to the resultant fused subgraph |
27 | // \arg arg_limit The maximum number of args the resultant fused subgraph |
28 | // should have. Note: This will likely develop into a general |
29 | // post condition on the fused subgraph. |
30 | TORCH_API void CustomFuseGraph( |
31 | std::shared_ptr<Graph>& graph, |
32 | const std::function<bool(Node*)>& is_fusable, |
33 | Symbol kind, |
34 | size_t arg_limit = std::numeric_limits<size_t>::max()); |
35 | |
36 | } // namespace jit |
37 | } // namespace torch |
38 | |