1#include <torch/csrc/jit/jit_log.h>
2#include <torch/csrc/jit/passes/inline_fork_wait.h>
3
4namespace torch {
5namespace jit {
6
7void InlineForkWait(
8 Block* b,
9 std::unordered_map<Value*, Value*>& future_remap) {
10 auto nodes = b->nodes();
11
12 // Track the futures returned by prim::fork.
13 for (auto it = nodes.begin(); it != nodes.end(); it++) {
14 auto node = *it;
15 if (node->kind() != prim::fork) {
16 continue;
17 }
18 WithInsertPoint insert_guard(node);
19 auto graph = b->owningGraph();
20 auto subgraph = node->g(attr::Subgraph);
21
22 auto output = insertGraph(*graph, *subgraph, node->inputs());
23
24 future_remap[node->output()] = output.at(0);
25 }
26
27 // Remove aten::wait if its input future is returned by prim::fork.
28 auto reversed = b->nodes().reverse();
29 for (auto it = reversed.begin(); it != reversed.end(); it++) {
30 auto node = *it;
31 if (node->kind() == prim::fork) {
32 // Account for the case where the aten::wait call isn't present in
33 // the current graph.
34 node->output()->replaceAllUsesWith(future_remap.at(node->output()));
35 it.destroyCurrent();
36 } else if (node->kind() == aten::wait) {
37 AT_ASSERT(node->inputs().size() == 1);
38 AT_ASSERT(node->outputs().size() == 1);
39 // If the future does not map to a prim::fork, it could be
40 // returned from prim::rpc_async, which has side effect, so it shouldn't
41 // be dead code eliminated.
42 if (future_remap.count(node->input())) {
43 node->output()->replaceAllUsesWith(future_remap.at(node->input()));
44 it.destroyCurrent();
45 }
46 }
47 }
48
49 // Recursively inline fork/wait.
50 for (auto it = nodes.begin(); it != nodes.end(); it++) {
51 auto node = *it;
52 for (auto sub_b : node->blocks()) {
53 InlineForkWait(sub_b, future_remap);
54 }
55 }
56}
57
58void InlineForkWait(const std::shared_ptr<Graph>& graph) {
59 std::unordered_map<Value*, Value*> future_remap;
60 InlineForkWait(graph->block(), future_remap);
61 GRAPH_DUMP("After InlineForkWait: ", graph);
62}
63
64} // namespace jit
65} // namespace torch
66