1#include <torch/csrc/jit/passes/add_if_then_else.h>
2#include <torch/csrc/jit/runtime/graph_iterator.h>
3
4namespace torch {
5namespace jit {
6
7namespace {
8
9bool hasNoNodes(Block* block) {
10 auto nodes = block->nodes();
11 return nodes.begin() == nodes.end();
12}
13
14bool hasTrivialSubBlocks(Node* node) {
15 const auto blocks = node->blocks();
16 TORCH_DCHECK_EQ(blocks.size(), 2);
17
18 return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]);
19}
20
21} // namespace
22
23bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
24 std::vector<Node*> to_replace;
25 DepthFirstGraphNodeIterator graph_it(graph);
26 for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
27 if (node->kind() != prim::If) {
28 continue;
29 }
30 if (node->outputs().size() != 1) {
31 continue;
32 }
33 if (hasTrivialSubBlocks(node)) {
34 to_replace.push_back(node);
35 }
36 }
37
38 for (auto* node : to_replace) {
39 auto* if_then_else_node = graph->create(prim::IfThenElse, 1);
40 if_then_else_node->addInput(node->input());
41 auto blocks = node->blocks();
42 if_then_else_node->addInput(blocks[0]->return_node()->input());
43 if_then_else_node->addInput(blocks[1]->return_node()->input());
44
45 if_then_else_node->insertBefore(node);
46 if_then_else_node->output()->copyMetadata(node->output());
47
48 node->output()->replaceAllUsesWith(if_then_else_node->output());
49 node->destroy();
50 }
51 return !to_replace.empty();
52}
53
54} // namespace jit
55} // namespace torch
56