1#include <torch/csrc/jit/passes/constant_pooling.h>
2#include <torch/csrc/jit/passes/constant_propagation.h>
3#include <torch/csrc/jit/passes/remove_exceptions.h>
4
5#include <torch/csrc/jit/jit_log.h>
6
7namespace torch {
8namespace jit {
9
10bool certainlyThrows(Block* block) {
11 for (Node* n : block->nodes()) {
12 if (n->kind() == prim::RaiseException) {
13 return true;
14 }
15 }
16 return false;
17}
18
19void EliminateExceptions(Block* block) {
20 auto graph = block->owningGraph();
21 Value* false_const = graph->insertConstant(IValue(false));
22 Value* true_const = graph->insertConstant(IValue(true));
23 for (Node* n : block->nodes()) {
24 if (n->kind() == prim::If) {
25 Block* true_block = n->blocks()[0];
26 Block* false_block = n->blocks()[1];
27 if (certainlyThrows(true_block)) {
28 n->input(0)->replaceAllUsesWith(false_const);
29 } else if (certainlyThrows(false_block)) {
30 n->input(0)->replaceAllUsesWith(true_const);
31 }
32 }
33
34 for (Block* subblock : n->blocks()) {
35 EliminateExceptions(subblock);
36 }
37 }
38}
39
40void EliminateExceptions(std::shared_ptr<Graph>& graph) {
41 GRAPH_DUMP("Before EliminateExceptions: ", graph);
42 EliminateExceptions(graph->block());
43 ConstantPropagation(graph);
44 ConstantPooling(graph);
45 GRAPH_DUMP("After EliminateExceptions: ", graph);
46}
47
48} // namespace jit
49} // namespace torch
50