1 | #include <torch/csrc/jit/passes/inline_forked_closures.h> |
2 | |
3 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | // Closure nodes are emitted as a tuple of (function %, context tuple %) |
9 | // Inside the closure the closure is then unpacked so that all closed over |
10 | // values are set. A function closing over a and b would look like: |
11 | // def foo(context): |
12 | // a, b = context |
13 | // |
14 | // To fork the closure, we need to set each value in the context tuple |
15 | // as an explicit input to the fork node, and then within the closure |
16 | // subgraph, replace the context unpacking value with the new graph input. |
17 | // fork(foo) -> |
18 | // def foo(a, b): |
19 | void inlineForkedClosure(Node* fork_closure, NodeKind genKind) { |
20 | Node* function_context_node = fork_closure->input()->node(); |
21 | |
22 | if (function_context_node->inputs().size() != 2 || |
23 | function_context_node->inputs().at(0)->node()->kind() != prim::Closure || |
24 | function_context_node->inputs().at(1)->node()->kind() != |
25 | prim::TupleConstruct) { |
26 | throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value" ; |
27 | } |
28 | |
29 | Node* function = function_context_node->inputs().at(0)->node(); |
30 | Node* context = function_context_node->inputs().at(1)->node(); |
31 | auto fork_graph = function->g(attr::Subgraph)->copy(); |
32 | auto g = fork_closure->owningGraph(); |
33 | Node* fork_node = g->create(genKind, 1) |
34 | ->insertAfter(fork_closure) |
35 | ->setSourceRange(fork_closure->sourceRange()); |
36 | |
37 | if (fork_graph->inputs().size() != 1 || |
38 | !fork_graph->inputs().at(0)->type()->cast<TupleType>()) { |
39 | throw ErrorReport(fork_node->sourceRange()) |
40 | << "Cannot fork lambda with parameters" ; |
41 | } |
42 | auto fork_graph_context = fork_graph->inputs().at(0); |
43 | AT_ASSERT(fork_graph_context->uses().size() == 1); |
44 | auto fork_graph_unpack = fork_graph_context->uses().at(0).user; |
45 | |
46 | for (size_t i = 0; i < context->inputs().size(); ++i) { |
47 | auto cont_input = context->inputs().at(i); |
48 | fork_node->addInput(cont_input); |
49 | auto inp = fork_graph->insertInput(i)->copyMetadata(cont_input); |
50 | fork_graph_unpack->outputs().at(i)->replaceAllUsesWith(inp); |
51 | } |
52 | fork_graph_unpack->destroy(); |
53 | fork_graph->eraseInput(fork_graph->inputs().size() - 1); |
54 | fork_node->output()->copyMetadata(fork_closure->output()); |
55 | fork_closure->output()->replaceAllUsesWith(fork_node->output()); |
56 | fork_closure->destroy(); |
57 | fork_node->g_(attr::Subgraph, fork_graph); |
58 | runCleanupPasses(fork_graph); |
59 | } |
60 | |
61 | void inlineForkedClosures(Block* block) { |
62 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
63 | Node* n = *it; |
64 | it++; |
65 | switch (n->kind()) { |
66 | case prim::forkClosure: { |
67 | inlineForkedClosure(n, prim::fork); |
68 | } break; |
69 | case prim::awaitableClosure: { |
70 | inlineForkedClosure(n, prim::awaitable); |
71 | } break; |
72 | default: { |
73 | for (Block* b : n->blocks()) { |
74 | inlineForkedClosures(b); |
75 | } |
76 | } break; |
77 | } |
78 | } |
79 | } |
80 | |
81 | void inlineForkedClosures(std::shared_ptr<Graph>& to_clean) { |
82 | inlineForkedClosures(to_clean->block()); |
83 | } |
84 | |
85 | } // namespace jit |
86 | } // namespace torch |
87 | |