1 | #include <torch/csrc/jit/passes/lift_closures.h> |
2 | |
3 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | // Closures are initially emitted as prim::Closure nodes with a single block. |
12 | // Here, we convert the block to a subgraph, adding all closed over variables |
13 | // as a context tuple input to the closure node. |
14 | // At this point the closure has already undergone conversion to SSA, |
15 | // so closed over variables will just be value * that are not set in the |
16 | // closure block. |
17 | // Within the closure subgraph, the context tuple is unpacked and the unpacked |
18 | // values are used for closed over values. |
19 | void liftClosure(Node* closure) { |
20 | auto block = closure->blocks().at(0); |
21 | auto subgraph = std::make_shared<Graph>(); |
22 | // closures/forks can be nested, so use closure owning graph |
23 | auto g = closure->owningGraph(); |
24 | Node* pack_context = |
25 | g->create(prim::TupleConstruct, {}, 1)->insertAfter(closure); |
26 | Value* context = subgraph->addInput("context" ); |
27 | // cannot use createTupleUnpack because the type is not known yet |
28 | Node* unpack_context = |
29 | subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0)); |
30 | |
31 | std::unordered_map<Value*, Value*> captures; |
32 | auto env = [&](Value* v) -> Value* { |
33 | auto it = captures.find(v); |
34 | if (it != captures.end()) { |
35 | return it->second; |
36 | } |
37 | pack_context->addInput(v); |
38 | Value* r = unpack_context->addOutput()->copyMetadata(v); |
39 | captures[v] = r; |
40 | return r; |
41 | }; |
42 | subgraph->block()->cloneFrom(block, env); |
43 | auto context_type = TupleType::create( |
44 | fmap(pack_context->inputs(), [](Value* v) { return v->type(); })); |
45 | context->setType(context_type); |
46 | pack_context->output()->setType(context_type); |
47 | auto closure_tuple = |
48 | g->create(prim::TupleConstruct, {}, 1)->insertAfter(pack_context); |
49 | closure->output()->replaceAllUsesWith(closure_tuple->output()); |
50 | closure_tuple->addInput(closure->output()); |
51 | closure_tuple->addInput(pack_context->output()); |
52 | closure_tuple->output()->setType( |
53 | TupleType::create({closure->output()->type(), std::move(context_type)})); |
54 | closure->eraseBlock(0); |
55 | closure->g_(attr::Subgraph, std::move(subgraph)); |
56 | runCleanupPasses(closure->g(attr::Subgraph)); |
57 | } |
58 | |
59 | void liftClosures(Block* block) { |
60 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
61 | Node* n = *it; |
62 | it++; |
63 | switch (n->kind()) { |
64 | case prim::Closure: { |
65 | liftClosure(n); |
66 | } break; |
67 | default: { |
68 | for (Block* b : n->blocks()) { |
69 | liftClosures(b); |
70 | } |
71 | } |
72 | } |
73 | } |
74 | } |
75 | |
76 | void liftClosures(const std::shared_ptr<Graph>& to_clean) { |
77 | liftClosures(to_clean->block()); |
78 | } |
79 | |
80 | } // namespace jit |
81 | } // namespace torch |
82 | |