1#pragma once
2
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
4
5namespace torch {
6namespace jit {
7
8TORCH_API bool canFuseOnCPULegacy();
9TORCH_API void overrideCanFuseOnCPULegacy(bool value);
10
11// NB: Be sure to run DCE before fusion, because dead instructions
12// can prevent fusion opportunities from being exploited.
13// On Windows will noop, NYI
14TORCH_API void FuseGraph(
15 std::shared_ptr<Graph>& graph,
16 bool strict_fuser_check = false);
17
18// \brief Custom fusion pass using a node-level callback to
19// determine the inclusion of nodes in a subgraph.
20//
21// This helper omits aliased inputs and fusion across control flow
22// boundaries.
23//
24// \arg graph The graph to be modified in-place
25// \arg is_fusable A callback run on each fusable node in the graph.
26// \arg kind The label given to the resultant fused subgraph
27// \arg arg_limit The maximum number of args the resultant fused subgraph
28// should have. Note: This will likely develop into a general
29// post condition on the fused subgraph.
30TORCH_API void CustomFuseGraph(
31 std::shared_ptr<Graph>& graph,
32 const std::function<bool(Node*)>& is_fusable,
33 Symbol kind,
34 size_t arg_limit = std::numeric_limits<size_t>::max());
35
36} // namespace jit
37} // namespace torch
38