1#include <torch/csrc/jit/passes/liveness.h>
2
3#include <torch/csrc/jit/ir/alias_analysis.h>
4#include <torch/csrc/jit/ir/ir_views.h>
5#include <torch/csrc/jit/passes/constant_pooling.h>
6#include <memory>
7
8namespace torch {
9namespace jit {
10
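// LivenessAnalyzer computes "bailout" liveness: for every node it records the
// set of values live right before that node executes, i.e. the standard
// {LIVE_IN} = {GEN} | ({LIVE_OUT} - {KILL}), where GEN are the node's inputs
// and KILL are its outputs (nested If/Loop blocks are merged into their node).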
13struct LivenessAnalyzer {
14 explicit LivenessAnalyzer(std::shared_ptr<Graph> graph)
15 : graph_(std::move(graph)), changed_(false) {}
16
17 std::unordered_map<Node*, std::vector<Value*>> run() {
18 std::vector<Node*> counters;
19 insertExplicitUsesOfLoopCounters(graph_->block(), counters);
20
21 // we implement the canonical fixed-point liveness
22 // the analysis is run until there are no more changes
23 // to liveness sets for each node
24 do {
25 changed_ = false;
26 processBlock(graph_->block(), SparseBitVector{});
27 } while (changed_);
28
29 removeCounterNodes(counters);
30 std::unordered_map<Node*, std::vector<Value*>> result;
31
32 for (const auto& e : liveness_sets_) {
33 result.insert({e.first, toValueVector(e.second)});
34 }
35 return result;
36 }
37
38 // temporary make loop counts live for the duration of the loop
39 // as they are needed by BailOuts in the loop
40 void insertExplicitUsesOfLoopCounters(
41 Block* b,
42 std::vector<Node*>& counters) {
43 for (auto it : b->nodes()) {
44 if (it->kind() == prim::Loop) {
45 LoopView lv(it);
46 WithInsertPoint guard(lv.bodyBlock());
47 auto ctc = graph_->create(prim::Store, {lv.currentTripCount()}, 0);
48 graph_->insertNode(ctc);
49 counters.push_back(ctc);
50 auto mtc = graph_->create(prim::Store, {lv.maxTripCount()}, 0);
51 graph_->insertNode(mtc);
52 counters.push_back(mtc);
53 }
54
55 for (auto ib : it->blocks()) {
56 insertExplicitUsesOfLoopCounters(ib, counters);
57 }
58 }
59 }
60
61 void removeCounterNodes(std::vector<Node*>& counters) {
62 for (auto n : counters) {
63 n->destroy();
64 }
65 }
66
67 void dump(
68 const std::unordered_map<Node*, std::vector<Value*>>& liveness_sets) {
69 std::cout << "Liveness info:\n";
70 for (auto e : liveness_sets) {
71 if (!e.first->outputs().empty()) {
72 std::cout << e.first->outputs()[0]->debugName();
73 }
74
75 std::cout << " " << e.first->kind().toQualString();
76 std::cout << " = ";
77 dump(e.second);
78 std::cout << std::endl;
79 }
80 std::cout << "graph :\n";
81 graph_->dump();
82 }
83
84 void dump(const std::vector<Value*>& set) {
85 bool first = true;
86 std::cout << "[";
87 for (auto el : set) {
88 if (first) {
89 first = false;
90 } else {
91 std::cout << ", ";
92 }
93 std::cout << el->debugName() << "(" << el->unique() << ")";
94 }
95 std::cout << "]";
96 }
97
98 private:
99 SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) {
100 SparseBitVector sbv;
101 for (auto v : values) {
102 ids_to_values_[v->unique()] = v;
103 sbv.set(v->unique());
104 }
105 return sbv;
106 }
107
108 std::vector<Value*> toValueVector(const SparseBitVector& sbv) {
109 std::vector<Value*> vec;
110 for (auto id : sbv) {
111 vec.push_back(ids_to_values_[id]);
112 }
113 return vec;
114 }
115
116 SparseBitVector processBlock(Block* b, SparseBitVector liveness) {
117 // block outputs are the uses
118 auto block_outputs = toSparseBitVector(b->outputs());
119 liveness |= block_outputs;
120
121 SparseBitVector defs;
122 for (Node* it : b->nodes().reverse()) {
123 // kill outputs
124 liveness -= toSparseBitVector(it->outputs());
125 if (it->kind() == prim::Loop) {
126 LoopView lv(it);
127 // N.B. merge in changes from the loop header
128 auto loop_header = *lv.bodyBlock()->nodes().begin();
129 auto loop_block = liveness | liveness_sets_[loop_header];
130 loop_block = processBlock(lv.bodyBlock(), loop_block);
131 // loop block's inputs die outside loop's block
132 loop_block -= toSparseBitVector(lv.bodyBlock()->inputs());
133 liveness |= loop_block;
134 } else if (it->kind() == prim::If) {
135 IfView iv(it);
136 auto true_liveness = processBlock(iv.thenBlock(), liveness);
137 auto false_liveness = processBlock(iv.elseBlock(), liveness);
138 liveness |= true_liveness;
139 liveness |= false_liveness;
140 }
141 liveness |= toSparseBitVector(it->inputs());
142 // `|=` returns true if new bits were set in LHS
143 // after or/union with `liveness`
144 auto changed = liveness_sets_[it] |= liveness;
145 changed_ = changed_ | changed;
146 }
147 return liveness;
148 }
149
150 std::shared_ptr<Graph> graph_;
151 bool changed_;
152 std::map<Node*, SparseBitVector> liveness_sets_;
153 std::map<size_t, Value*> ids_to_values_;
154};
155
156std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets(
157 std::shared_ptr<Graph> graph) {
158 LivenessAnalyzer la(std::move(graph));
159 return la.run();
160}
161
162} // namespace jit
163} // namespace torch
164