#include <torch/csrc/jit/passes/liveness.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/constant_pooling.h>

#include <iostream>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | // LivenessAnalyzer computes "bailout" liveness which is equivalent to |
12 | // "{LIVE_IN} or {GEN}" or "{LIVE_OUT} - {KILL}" |
13 | struct LivenessAnalyzer { |
14 | explicit LivenessAnalyzer(std::shared_ptr<Graph> graph) |
15 | : graph_(std::move(graph)), changed_(false) {} |
16 | |
17 | std::unordered_map<Node*, std::vector<Value*>> run() { |
18 | std::vector<Node*> counters; |
19 | insertExplicitUsesOfLoopCounters(graph_->block(), counters); |
20 | |
21 | // we implement the canonical fixed-point liveness |
22 | // the analysis is run until there are no more changes |
23 | // to liveness sets for each node |
24 | do { |
25 | changed_ = false; |
26 | processBlock(graph_->block(), SparseBitVector{}); |
27 | } while (changed_); |
28 | |
29 | removeCounterNodes(counters); |
30 | std::unordered_map<Node*, std::vector<Value*>> result; |
31 | |
32 | for (const auto& e : liveness_sets_) { |
33 | result.insert({e.first, toValueVector(e.second)}); |
34 | } |
35 | return result; |
36 | } |
37 | |
38 | // temporary make loop counts live for the duration of the loop |
39 | // as they are needed by BailOuts in the loop |
40 | void insertExplicitUsesOfLoopCounters( |
41 | Block* b, |
42 | std::vector<Node*>& counters) { |
43 | for (auto it : b->nodes()) { |
44 | if (it->kind() == prim::Loop) { |
45 | LoopView lv(it); |
46 | WithInsertPoint guard(lv.bodyBlock()); |
47 | auto ctc = graph_->create(prim::Store, {lv.currentTripCount()}, 0); |
48 | graph_->insertNode(ctc); |
49 | counters.push_back(ctc); |
50 | auto mtc = graph_->create(prim::Store, {lv.maxTripCount()}, 0); |
51 | graph_->insertNode(mtc); |
52 | counters.push_back(mtc); |
53 | } |
54 | |
55 | for (auto ib : it->blocks()) { |
56 | insertExplicitUsesOfLoopCounters(ib, counters); |
57 | } |
58 | } |
59 | } |
60 | |
61 | void removeCounterNodes(std::vector<Node*>& counters) { |
62 | for (auto n : counters) { |
63 | n->destroy(); |
64 | } |
65 | } |
66 | |
67 | void dump( |
68 | const std::unordered_map<Node*, std::vector<Value*>>& liveness_sets) { |
69 | std::cout << "Liveness info:\n" ; |
70 | for (auto e : liveness_sets) { |
71 | if (!e.first->outputs().empty()) { |
72 | std::cout << e.first->outputs()[0]->debugName(); |
73 | } |
74 | |
75 | std::cout << " " << e.first->kind().toQualString(); |
76 | std::cout << " = " ; |
77 | dump(e.second); |
78 | std::cout << std::endl; |
79 | } |
80 | std::cout << "graph :\n" ; |
81 | graph_->dump(); |
82 | } |
83 | |
84 | void dump(const std::vector<Value*>& set) { |
85 | bool first = true; |
86 | std::cout << "[" ; |
87 | for (auto el : set) { |
88 | if (first) { |
89 | first = false; |
90 | } else { |
91 | std::cout << ", " ; |
92 | } |
93 | std::cout << el->debugName() << "(" << el->unique() << ")" ; |
94 | } |
95 | std::cout << "]" ; |
96 | } |
97 | |
98 | private: |
99 | SparseBitVector toSparseBitVector(at::ArrayRef<Value*> values) { |
100 | SparseBitVector sbv; |
101 | for (auto v : values) { |
102 | ids_to_values_[v->unique()] = v; |
103 | sbv.set(v->unique()); |
104 | } |
105 | return sbv; |
106 | } |
107 | |
108 | std::vector<Value*> toValueVector(const SparseBitVector& sbv) { |
109 | std::vector<Value*> vec; |
110 | for (auto id : sbv) { |
111 | vec.push_back(ids_to_values_[id]); |
112 | } |
113 | return vec; |
114 | } |
115 | |
116 | SparseBitVector processBlock(Block* b, SparseBitVector liveness) { |
117 | // block outputs are the uses |
118 | auto block_outputs = toSparseBitVector(b->outputs()); |
119 | liveness |= block_outputs; |
120 | |
121 | SparseBitVector defs; |
122 | for (Node* it : b->nodes().reverse()) { |
123 | // kill outputs |
124 | liveness -= toSparseBitVector(it->outputs()); |
125 | if (it->kind() == prim::Loop) { |
126 | LoopView lv(it); |
127 | // N.B. merge in changes from the loop header |
128 | auto = *lv.bodyBlock()->nodes().begin(); |
129 | auto loop_block = liveness | liveness_sets_[loop_header]; |
130 | loop_block = processBlock(lv.bodyBlock(), loop_block); |
131 | // loop block's inputs die outside loop's block |
132 | loop_block -= toSparseBitVector(lv.bodyBlock()->inputs()); |
133 | liveness |= loop_block; |
134 | } else if (it->kind() == prim::If) { |
135 | IfView iv(it); |
136 | auto true_liveness = processBlock(iv.thenBlock(), liveness); |
137 | auto false_liveness = processBlock(iv.elseBlock(), liveness); |
138 | liveness |= true_liveness; |
139 | liveness |= false_liveness; |
140 | } |
141 | liveness |= toSparseBitVector(it->inputs()); |
142 | // `|=` returns true if new bits were set in LHS |
143 | // after or/union with `liveness` |
144 | auto changed = liveness_sets_[it] |= liveness; |
145 | changed_ = changed_ | changed; |
146 | } |
147 | return liveness; |
148 | } |
149 | |
150 | std::shared_ptr<Graph> graph_; |
151 | bool changed_; |
152 | std::map<Node*, SparseBitVector> liveness_sets_; |
153 | std::map<size_t, Value*> ids_to_values_; |
154 | }; |
155 | |
156 | std::unordered_map<Node*, std::vector<Value*>> BuildLivenessSets( |
157 | std::shared_ptr<Graph> graph) { |
158 | LivenessAnalyzer la(std::move(graph)); |
159 | return la.run(); |
160 | } |
161 | |
162 | } // namespace jit |
163 | } // namespace torch |
164 | |