1 | #include <torch/csrc/jit/passes/add_if_then_else.h> |
---|---|
2 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | namespace { |
8 | |
9 | bool hasNoNodes(Block* block) { |
10 | auto nodes = block->nodes(); |
11 | return nodes.begin() == nodes.end(); |
12 | } |
13 | |
14 | bool hasTrivialSubBlocks(Node* node) { |
15 | const auto blocks = node->blocks(); |
16 | TORCH_DCHECK_EQ(blocks.size(), 2); |
17 | |
18 | return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]); |
19 | } |
20 | |
21 | } // namespace |
22 | |
23 | bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) { |
24 | std::vector<Node*> to_replace; |
25 | DepthFirstGraphNodeIterator graph_it(graph); |
26 | for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) { |
27 | if (node->kind() != prim::If) { |
28 | continue; |
29 | } |
30 | if (node->outputs().size() != 1) { |
31 | continue; |
32 | } |
33 | if (hasTrivialSubBlocks(node)) { |
34 | to_replace.push_back(node); |
35 | } |
36 | } |
37 | |
38 | for (auto* node : to_replace) { |
39 | auto* if_then_else_node = graph->create(prim::IfThenElse, 1); |
40 | if_then_else_node->addInput(node->input()); |
41 | auto blocks = node->blocks(); |
42 | if_then_else_node->addInput(blocks[0]->return_node()->input()); |
43 | if_then_else_node->addInput(blocks[1]->return_node()->input()); |
44 | |
45 | if_then_else_node->insertBefore(node); |
46 | if_then_else_node->output()->copyMetadata(node->output()); |
47 | |
48 | node->output()->replaceAllUsesWith(if_then_else_node->output()); |
49 | node->destroy(); |
50 | } |
51 | return !to_replace.empty(); |
52 | } |
53 | |
54 | } // namespace jit |
55 | } // namespace torch |
56 |