1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | // Runs constant propagation over `graph`, which is modified in place. By default values of every type are eligible; |
9 | // if ignore_custom_classes is true, user-defined (custom) classes are skipped. This is |
10 | // useful to prevent early fusion of packing operations, which would otherwise lower |
11 | // away information about their constructors (e.g. prepacked::linear_clamp_prepack |
12 | // and prepacked::conv2d_clamp_prepack). |
13 | // Returns true if the pass made a change to the graph. |
14 | TORCH_API bool ConstantPropagation( |
15 | std::shared_ptr<Graph>& graph, |
16 | bool ignore_custom_classes = false); |
17 | |
18 | // Runs constant propagation only on ops that have non-aliasing inputs & outputs, modifying `graph` in place. |
19 | // Returns true if the pass made a change to the graph. |
20 | TORCH_API bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph); |
21 | |
22 | // Runs `node` if all of its inputs are constants. Callers must decide for themselves |
23 | // whether constant propagation is appropriate for the node - it is not, for example, for |
24 | // non-deterministic ops or ops with side effects. If ignore_custom_classes is true, nodes that output user-defined classes are not run. |
25 | // NOTE(review): presumably returns the node's outputs when it was run and nullopt otherwise; `db` defaults to nullptr and its use is not visible in this header - confirm against the implementation. |
26 | TORCH_API c10::optional<Stack> runNodeIfInputsAreConstant( |
27 | const Node* node, |
28 | bool ignore_custom_classes = false, |
29 | AliasDb* db = nullptr); |
30 | |
31 | } // namespace jit |
32 | } // namespace torch |
33 | |