1#include <torch/csrc/jit/passes/lift_closures.h>
2
3#include <torch/csrc/jit/frontend/ir_emitter.h>
4#include <torch/csrc/jit/ir/ir.h>
5
6#include <utility>
7
8namespace torch {
9namespace jit {
10
11// Closures are initially emitted as prim::Closure nodes with a single block.
12// Here, we convert the block to a subgraph, adding all closed over variables
13// as a context tuple input to the closure node.
14// At this point the closure has already undergone conversion to SSA,
15// so closed over variables will just be value * that are not set in the
16// closure block.
17// Within the closure subgraph, the context tuple is unpacked and the unpacked
18// values are used for closed over values.
19void liftClosure(Node* closure) {
20 auto block = closure->blocks().at(0);
21 auto subgraph = std::make_shared<Graph>();
22 // closures/forks can be nested, so use closure owning graph
23 auto g = closure->owningGraph();
24 Node* pack_context =
25 g->create(prim::TupleConstruct, {}, 1)->insertAfter(closure);
26 Value* context = subgraph->addInput("context");
27 // cannot use createTupleUnpack because the type is not known yet
28 Node* unpack_context =
29 subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
30
31 std::unordered_map<Value*, Value*> captures;
32 auto env = [&](Value* v) -> Value* {
33 auto it = captures.find(v);
34 if (it != captures.end()) {
35 return it->second;
36 }
37 pack_context->addInput(v);
38 Value* r = unpack_context->addOutput()->copyMetadata(v);
39 captures[v] = r;
40 return r;
41 };
42 subgraph->block()->cloneFrom(block, env);
43 auto context_type = TupleType::create(
44 fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
45 context->setType(context_type);
46 pack_context->output()->setType(context_type);
47 auto closure_tuple =
48 g->create(prim::TupleConstruct, {}, 1)->insertAfter(pack_context);
49 closure->output()->replaceAllUsesWith(closure_tuple->output());
50 closure_tuple->addInput(closure->output());
51 closure_tuple->addInput(pack_context->output());
52 closure_tuple->output()->setType(
53 TupleType::create({closure->output()->type(), std::move(context_type)}));
54 closure->eraseBlock(0);
55 closure->g_(attr::Subgraph, std::move(subgraph));
56 runCleanupPasses(closure->g(attr::Subgraph));
57}
58
59void liftClosures(Block* block) {
60 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
61 Node* n = *it;
62 it++;
63 switch (n->kind()) {
64 case prim::Closure: {
65 liftClosure(n);
66 } break;
67 default: {
68 for (Block* b : n->blocks()) {
69 liftClosures(b);
70 }
71 }
72 }
73 }
74}
75
76void liftClosures(const std::shared_ptr<Graph>& to_clean) {
77 liftClosures(to_clean->block());
78}
79
80} // namespace jit
81} // namespace torch
82