1#include <torch/csrc/jit/passes/inline_forked_closures.h>
2
3#include <torch/csrc/jit/frontend/ir_emitter.h>
4
5namespace torch {
6namespace jit {
7
8// Closure nodes are emitted as a tuple of (function %, context tuple %)
9// Inside the closure the closure is then unpacked so that all closed over
10// values are set. A function closing over a and b would look like:
11// def foo(context):
12// a, b = context
13//
14// To fork the closure, we need to set each value in the context tuple
15// as an explicit input to the fork node, and then within the closure
16// subgraph, replace the context unpacking value with the new graph input.
17// fork(foo) ->
18// def foo(a, b):
19void inlineForkedClosure(Node* fork_closure, NodeKind genKind) {
20 Node* function_context_node = fork_closure->input()->node();
21
22 if (function_context_node->inputs().size() != 2 ||
23 function_context_node->inputs().at(0)->node()->kind() != prim::Closure ||
24 function_context_node->inputs().at(1)->node()->kind() !=
25 prim::TupleConstruct) {
26 throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value";
27 }
28
29 Node* function = function_context_node->inputs().at(0)->node();
30 Node* context = function_context_node->inputs().at(1)->node();
31 auto fork_graph = function->g(attr::Subgraph)->copy();
32 auto g = fork_closure->owningGraph();
33 Node* fork_node = g->create(genKind, 1)
34 ->insertAfter(fork_closure)
35 ->setSourceRange(fork_closure->sourceRange());
36
37 if (fork_graph->inputs().size() != 1 ||
38 !fork_graph->inputs().at(0)->type()->cast<TupleType>()) {
39 throw ErrorReport(fork_node->sourceRange())
40 << "Cannot fork lambda with parameters";
41 }
42 auto fork_graph_context = fork_graph->inputs().at(0);
43 AT_ASSERT(fork_graph_context->uses().size() == 1);
44 auto fork_graph_unpack = fork_graph_context->uses().at(0).user;
45
46 for (size_t i = 0; i < context->inputs().size(); ++i) {
47 auto cont_input = context->inputs().at(i);
48 fork_node->addInput(cont_input);
49 auto inp = fork_graph->insertInput(i)->copyMetadata(cont_input);
50 fork_graph_unpack->outputs().at(i)->replaceAllUsesWith(inp);
51 }
52 fork_graph_unpack->destroy();
53 fork_graph->eraseInput(fork_graph->inputs().size() - 1);
54 fork_node->output()->copyMetadata(fork_closure->output());
55 fork_closure->output()->replaceAllUsesWith(fork_node->output());
56 fork_closure->destroy();
57 fork_node->g_(attr::Subgraph, fork_graph);
58 runCleanupPasses(fork_graph);
59}
60
61void inlineForkedClosures(Block* block) {
62 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
63 Node* n = *it;
64 it++;
65 switch (n->kind()) {
66 case prim::forkClosure: {
67 inlineForkedClosure(n, prim::fork);
68 } break;
69 case prim::awaitableClosure: {
70 inlineForkedClosure(n, prim::awaitable);
71 } break;
72 default: {
73 for (Block* b : n->blocks()) {
74 inlineForkedClosures(b);
75 }
76 } break;
77 }
78 }
79}
80
81void inlineForkedClosures(std::shared_ptr<Graph>& to_clean) {
82 inlineForkedClosures(to_clean->block());
83}
84
85} // namespace jit
86} // namespace torch
87