1 | #include <torch/csrc/jit/passes/constant_pooling.h> |
---|---|
2 | #include <torch/csrc/jit/passes/constant_propagation.h> |
3 | #include <torch/csrc/jit/passes/remove_exceptions.h> |
4 | |
5 | #include <torch/csrc/jit/jit_log.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | bool certainlyThrows(Block* block) { |
11 | for (Node* n : block->nodes()) { |
12 | if (n->kind() == prim::RaiseException) { |
13 | return true; |
14 | } |
15 | } |
16 | return false; |
17 | } |
18 | |
19 | void EliminateExceptions(Block* block) { |
20 | auto graph = block->owningGraph(); |
21 | Value* false_const = graph->insertConstant(IValue(false)); |
22 | Value* true_const = graph->insertConstant(IValue(true)); |
23 | for (Node* n : block->nodes()) { |
24 | if (n->kind() == prim::If) { |
25 | Block* true_block = n->blocks()[0]; |
26 | Block* false_block = n->blocks()[1]; |
27 | if (certainlyThrows(true_block)) { |
28 | n->input(0)->replaceAllUsesWith(false_const); |
29 | } else if (certainlyThrows(false_block)) { |
30 | n->input(0)->replaceAllUsesWith(true_const); |
31 | } |
32 | } |
33 | |
34 | for (Block* subblock : n->blocks()) { |
35 | EliminateExceptions(subblock); |
36 | } |
37 | } |
38 | } |
39 | |
40 | void EliminateExceptions(std::shared_ptr<Graph>& graph) { |
41 | GRAPH_DUMP("Before EliminateExceptions: ", graph); |
42 | EliminateExceptions(graph->block()); |
43 | ConstantPropagation(graph); |
44 | ConstantPooling(graph); |
45 | GRAPH_DUMP("After EliminateExceptions: ", graph); |
46 | } |
47 | |
48 | } // namespace jit |
49 | } // namespace torch |
50 |