1 | #include <torch/csrc/jit/jit_log.h> |
2 | #include <torch/csrc/jit/passes/inline_fork_wait.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | void InlineForkWait( |
8 | Block* b, |
9 | std::unordered_map<Value*, Value*>& future_remap) { |
10 | auto nodes = b->nodes(); |
11 | |
12 | // Track the futures returned by prim::fork. |
13 | for (auto it = nodes.begin(); it != nodes.end(); it++) { |
14 | auto node = *it; |
15 | if (node->kind() != prim::fork) { |
16 | continue; |
17 | } |
18 | WithInsertPoint insert_guard(node); |
19 | auto graph = b->owningGraph(); |
20 | auto subgraph = node->g(attr::Subgraph); |
21 | |
22 | auto output = insertGraph(*graph, *subgraph, node->inputs()); |
23 | |
24 | future_remap[node->output()] = output.at(0); |
25 | } |
26 | |
27 | // Remove aten::wait if its input future is returned by prim::fork. |
28 | auto reversed = b->nodes().reverse(); |
29 | for (auto it = reversed.begin(); it != reversed.end(); it++) { |
30 | auto node = *it; |
31 | if (node->kind() == prim::fork) { |
32 | // Account for the case where the aten::wait call isn't present in |
33 | // the current graph. |
34 | node->output()->replaceAllUsesWith(future_remap.at(node->output())); |
35 | it.destroyCurrent(); |
36 | } else if (node->kind() == aten::wait) { |
37 | AT_ASSERT(node->inputs().size() == 1); |
38 | AT_ASSERT(node->outputs().size() == 1); |
39 | // If the future does not map to a prim::fork, it could be |
40 | // returned from prim::rpc_async, which has side effect, so it shouldn't |
41 | // be dead code eliminated. |
42 | if (future_remap.count(node->input())) { |
43 | node->output()->replaceAllUsesWith(future_remap.at(node->input())); |
44 | it.destroyCurrent(); |
45 | } |
46 | } |
47 | } |
48 | |
49 | // Recursively inline fork/wait. |
50 | for (auto it = nodes.begin(); it != nodes.end(); it++) { |
51 | auto node = *it; |
52 | for (auto sub_b : node->blocks()) { |
53 | InlineForkWait(sub_b, future_remap); |
54 | } |
55 | } |
56 | } |
57 | |
58 | void InlineForkWait(const std::shared_ptr<Graph>& graph) { |
59 | std::unordered_map<Value*, Value*> future_remap; |
60 | InlineForkWait(graph->block(), future_remap); |
61 | GRAPH_DUMP("After InlineForkWait: " , graph); |
62 | } |
63 | |
64 | } // namespace jit |
65 | } // namespace torch |
66 | |