1 | #include <torch/csrc/jit/passes/bailout_graph.h> |
2 | |
3 | #include <ATen/core/function.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/ir/alias_analysis.h> |
6 | #include <torch/csrc/jit/ir/ir_views.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/clear_profiling.h> |
9 | #include <torch/csrc/jit/passes/constant_pooling.h> |
10 | #include <torch/csrc/jit/passes/liveness.h> |
11 | #include <memory> |
12 | #include <unordered_set> |
13 | #include <utility> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | |
18 | static bool shouldBeCapturedInByBailOut(Node* n) { |
19 | return n->kind() != prim::Constant; |
20 | } |
21 | |
22 | struct BailOutGraphBuilderForNode { |
23 | explicit BailOutGraphBuilderForNode( |
24 | std::shared_ptr<Graph> graph, |
25 | std::shared_ptr<Graph> target) |
26 | : graph_(std::move(graph)), copy_graph_(std::move(target)) {} |
27 | |
28 | // capture `old_value` into the bailout graph |
29 | // by creating a new input and mapping |
30 | // `old_value` to it |
31 | Value* addNewInputForValue(Value* old_value) { |
32 | auto node = old_value->node(); |
33 | // this reduces the number of inputs to a bailout graph significantly |
34 | // making it easier to debug |
35 | if (node->kind() == prim::Constant) { |
36 | TORCH_INTERNAL_ASSERT(!shouldBeCapturedInByBailOut(node)); |
37 | auto new_const = copy_graph_->createClone(node, {nullptr}); |
38 | copy_graph_->block()->prependNode(new_const); |
39 | return new_const->output(); |
40 | } |
41 | |
42 | live_inputs_.push_back(old_value); |
43 | auto new_value = copy_graph_->block()->addInput(); |
44 | GRAPH_DEBUG( |
45 | "Adding a new value %" , |
46 | new_value->debugName(), |
47 | " for %" , |
48 | old_value->debugName()); |
49 | return mapValueAndCopyMetadata(old_value, new_value); |
50 | } |
51 | |
52 | Value* mapValueAndCopyMetadata(Value* old_value, Value* new_value) { |
53 | this->old_to_new_[old_value] = new_value; |
54 | new_value->copyMetadata(old_value); |
55 | return new_value; |
56 | } |
57 | |
58 | Value* getOrAddInputForValue(Value* v) { |
59 | if (this->old_to_new_.count(v) == 0) { |
60 | return addNewInputForValue(v); |
61 | } else { |
62 | return this->old_to_new_[v]; |
63 | } |
64 | } |
65 | |
66 | Value* getInputForValue(Value* v) { |
67 | TORCH_INTERNAL_ASSERT(this->old_to_new_.count(v)); |
68 | return this->old_to_new_[v]; |
69 | } |
70 | |
71 | Node* cloneNode(Node* node) { |
72 | auto* block = copy_graph_->block(); |
73 | auto env = [this](Value* v) { return getOrAddInputForValue(v); }; |
74 | |
75 | auto new_node = block->appendNode(copy_graph_->createClone(node, env)); |
76 | for (size_t i = 0; i < node->outputs().size(); ++i) { |
77 | auto oo = node->outputs()[i]; |
78 | auto no = new_node->outputs()[i]; |
79 | old_to_new_[oo] = no; |
80 | } |
81 | |
82 | return new_node; |
83 | } |
84 | |
85 | // buildBailOutBlockFrom builds a bailout graph from |
86 | // a given node `n` until the end of the owning block |
87 | // If `n` belongs to `prim::If` or `prim::Loop` |
88 | // buildBailOutLoop/If continue |
89 | // from block's owning node (e.g. `prim::If` or |
90 | // `prim::Loop`) |
91 | void buildBailOutBlockFrom(Node* n) { |
92 | auto b = n->owningBlock(); |
93 | for (auto it = n->iterator(); it != b->nodes().end(); it++) { |
94 | cloneNode(*it); |
95 | } |
96 | |
97 | // we are either in `prim::If` or `prim::Loop` |
98 | // bailout graph building will continue from `outer_node` next |
99 | auto outer_node = n->owningBlock()->owningNode(); |
100 | if (outer_node) { |
101 | if (outer_node->kind() == prim::Loop) { |
102 | buildBailOutLoop(outer_node); |
103 | } else if (outer_node->kind() == prim::If) { |
104 | buildBailOutIf(b->outputs(), outer_node); |
105 | } else { |
106 | AT_ERROR("Unexpected outer node" ); |
107 | } |
108 | } |
109 | } |
110 | |
111 | void mapValues( |
112 | const at::ArrayRef<Value*> block_outputs, |
113 | const at::ArrayRef<Value*> carried_deps) { |
114 | TORCH_INTERNAL_ASSERT(block_outputs.size() == carried_deps.size()); |
115 | for (const auto i : c10::irange(block_outputs.size())) { |
116 | auto nv = getOrAddInputForValue(block_outputs[i]); |
117 | old_to_new_[carried_deps[i]] = nv; |
118 | } |
119 | } |
120 | |
121 | void buildBailOutLoop(Node* outer_node) { |
122 | LoopView lv(outer_node); |
123 | auto old_max_count = getOrAddInputForValue(lv.maxTripCount()); |
124 | auto cur_iter = getInputForValue(lv.currentTripCount()); |
125 | auto block_outputs = lv.bodyBlock()->outputs(); |
126 | |
127 | auto* block = copy_graph_->block(); |
128 | // subtract the number of iterations |
129 | WithInsertPoint guard(*block->nodes().end()); |
130 | auto updated_max_trip_count = |
131 | copy_graph_->insert(aten::sub, {old_max_count, cur_iter}); |
132 | auto one = copy_graph_->insertConstant({1}); |
133 | updated_max_trip_count = |
134 | copy_graph_->insert(aten::sub, {updated_max_trip_count, one}); |
135 | auto cur_plus_one = copy_graph_->insert(aten::add, {one, cur_iter}); |
136 | |
137 | // We need to be careful when mapping `block_outputs` to continuation |
138 | // loop's inputs since `cloneFrom` will replace `%4` with the same value |
139 | // in both, `prim::Loop` and `aten::cat` in the example below: |
140 | // |
141 | // ... : Tensor = prim::Loop(%MAX_TRIP_COUNT, %COND, ..., %4) |
142 | // block0(%i.2 : int, ...): |
143 | // ... |
144 | // %y.5 : Double(3) = aten::cat(%22, %4) |
145 | // ... |
146 | // |
147 | // However for the cloned loop node, the values should be different. |
148 | // Namely, the value in `prim::Loop` should come from |
149 | // `lv.bodyBlock()->outputs()` which are mapped to the outputs of the |
150 | // current iteration whereas `%4` in `aten::cat` needs to be mapped to the |
151 | // cloned value of `%4` in a bailout graph. To work around this, we manually |
152 | // clone loop nodes |
153 | |
154 | // map the residual loop's inputs to the outputs of the current iteration |
155 | // (i.e. `block_outputs`) |
156 | auto new_loop = |
157 | copy_graph_->insertNode(copy_graph_->create(prim::Loop, {}, 0)) |
158 | ->setSourceRange(outer_node->sourceRange()); |
159 | new_loop->addInput(updated_max_trip_count); |
160 | for (auto bo : block_outputs) { |
161 | new_loop->addInput(getOrAddInputForValue(bo)); |
162 | } |
163 | |
164 | // clone the loop body and map old loop's outputs to new loop's outputs |
165 | auto new_loop_body = new_loop->addBlock(); |
166 | auto env = [this](Value* v) { return getOrAddInputForValue(v); }; |
167 | new_loop_body->cloneFrom(lv.bodyBlock(), env); |
168 | for (auto ov : lv.carriedOutputs()) { |
169 | auto no = new_loop->addOutput(); |
170 | mapValueAndCopyMetadata(ov, no); |
171 | } |
172 | LoopView new_lv(new_loop); |
173 | { |
174 | WithInsertPoint guard_in_loop(*new_lv.bodyBlock()->nodes().begin()); |
175 | // `one` will be replaced with new_lv.currentTripCount() |
176 | // but it needs to be done after |
177 | // new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr); |
178 | // to avoid cyclical references |
179 | auto adj_iter_ctr = copy_graph_->insert(aten::add, {cur_plus_one, one}); |
180 | new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr); |
181 | adj_iter_ctr->node()->replaceInputWith(one, new_lv.currentTripCount()); |
182 | } |
183 | |
184 | if (outer_node->next()) { |
185 | buildBailOutBlockFrom(outer_node->next()); |
186 | } |
187 | } |
188 | |
189 | void buildBailOutIf( |
190 | const at::ArrayRef<Value*> block_outputs, |
191 | Node* outer_node) { |
192 | auto if_outputs = outer_node->outputs(); |
193 | mapValues(block_outputs, if_outputs); |
194 | buildBailOutBlockFrom(outer_node->next()); |
195 | } |
196 | |
197 | std::shared_ptr<Graph> buildBailOutGraphFrom(Node* n) { |
198 | // add graph inputs for guard's input |
199 | // and loop counts for loops `n` is contained in |
200 | // to make sure we can line bailout grap's inputs up properly |
201 | // with arguments to this BailOut node. |
202 | for (auto bi : n->inputs()) { |
203 | getOrAddInputForValue(bi); |
204 | } |
205 | |
206 | buildBailOutBlockFrom(n); |
207 | // add graph outputs |
208 | for (auto ov : graph_->outputs()) { |
209 | copy_graph_->registerOutput(getOrAddInputForValue(ov)); |
210 | } |
211 | return copy_graph_; |
212 | } |
213 | |
214 | std::shared_ptr<Graph> graph_; |
215 | std::shared_ptr<Graph> copy_graph_; |
216 | std::vector<Value*> live_inputs_; |
217 | std::unordered_map<Value*, Value*> old_to_new_; |
218 | }; |
219 | |
220 | // `BailOutInserter` replaces prim::Guard nodes with |
221 | // prim::BailOut nodes that allow interpreter to |
222 | // resume execution of the unoptimized(deoptimized) |
223 | // version of an original graph from a particular point |
224 | struct BailOutInserter { |
225 | explicit BailOutInserter(std::shared_ptr<Graph> graph) |
226 | : graph_(std::move(graph)), bailout_index_(0) {} |
227 | |
228 | void run() { |
229 | liveness_sets_ = BuildLivenessSets(graph_); |
230 | insertBailOuts(graph_->block()); |
231 | replaceGuardsWithBailouts(); |
232 | // embed a full original graph |
233 | addUnoptimizedFuncToBailouts(); |
234 | } |
235 | |
236 | // Packs the original unoptimized graph into a Function constant |
237 | // and add it as the first input to every prim::BailOut point |
238 | // This graph will be used to compute a bailout graph for |
239 | // any given bailout point |
240 | void addUnoptimizedFuncToBailouts() { |
241 | auto unoptimized_graph = graph_->copy(); |
242 | auto unopt_func = graph_->create(prim::BailoutTemplate) |
243 | ->insertAfter(graph_->param_node()); |
244 | |
245 | // Returns an int so that we have an easy way to do graph traversal |
246 | unopt_func->output()->setType(IntType::get()); |
247 | unopt_func->g_(attr::Subgraph, std::move(unoptimized_graph)); |
248 | for (auto bn : bailouts_) { |
249 | bn->insertInput(0, unopt_func->output()); |
250 | } |
251 | } |
252 | |
253 | // Removes guards by hooking up the guarded tensor |
254 | // directly to its users and also clears |
255 | // profiling information on it. |
256 | void removeGuards(Block* b) { |
257 | for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) { |
258 | if (it->kind() == prim::Guard) { |
259 | // this will need to be profiled again |
260 | it->input()->setType(TensorType::get()); |
261 | // destroy the guard |
262 | it->output()->replaceAllUsesWith(it->input()); |
263 | it.destroyCurrent(); |
264 | } |
265 | |
266 | for (auto ib : it->blocks()) { |
267 | removeGuards(ib); |
268 | } |
269 | } |
270 | } |
271 | |
272 | // replace each prim::Guard |
273 | // with its corresponding prim::BailOut |
274 | void replaceGuardsWithBailouts() { |
275 | for (auto e : replacements_) { |
276 | e.first->replaceAllUsesWith(e.second); |
277 | e.second->node()->insertAfter(e.first->node()); |
278 | e.first->node()->destroy(); |
279 | } |
280 | } |
281 | |
282 | // Inserts prim::BailOut nodes for every prim::Guard |
283 | // Each BailOut point takes the set of inputs live |
284 | // at that particular execution point. |
285 | // An input is live if it's used beyond the guard/BailOut |
286 | // point to compute graph's outputs |
287 | void insertBailOuts(Block* b) { |
288 | for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) { |
289 | if (it->kind() == prim::Guard) { |
290 | auto bailout_node = b->owningGraph()->create(prim::BailOut); |
291 | bailouts_.push_back(bailout_node); |
292 | |
293 | const auto& live_inputs = liveness_sets_[*it]; |
294 | |
295 | // guarded inputs come first |
296 | // currently, there's always one guarded input |
297 | bailout_node->addInput(it->input()); |
298 | for (auto li : live_inputs) { |
299 | // Guarded inputs have already been added |
300 | // Also, skip some inputs that BailOutGraphBuilder can |
301 | // materialize into bailout graphs directly |
302 | if (!shouldBeCapturedInByBailOut(li->node()) || li == it->input()) { |
303 | continue; |
304 | } |
305 | bailout_node->addInput(li); |
306 | } |
307 | |
308 | bailout_node->output()->setType(it->output()->type()); |
309 | bailout_node->i_(attr::index, bailout_index_++); |
310 | // we can't immediately replace nodes since this action will corrupt |
311 | // the liveness sets of following BailOut nodes if any of their |
312 | // arguments are BailOut nodes themselves |
313 | replacements_.insert({it->output(), bailout_node->output()}); |
314 | |
315 | } else { |
316 | for (auto ib : it->blocks()) { |
317 | insertBailOuts(ib); |
318 | } |
319 | } |
320 | } |
321 | } |
322 | |
323 | std::shared_ptr<Graph> graph_; |
324 | std::map<Node*, Node*> subgraphs; |
325 | std::size_t bailout_index_; |
326 | std::unordered_map<Node*, std::vector<Value*>> liveness_sets_; |
327 | std::vector<Node*> bailouts_; |
328 | std::map<Value*, Value*> replacements_; |
329 | }; |
330 | |
331 | void InsertBailOuts(std::shared_ptr<Graph> graph) { |
332 | BailOutInserter ibo(std::move(graph)); |
333 | ibo.run(); |
334 | } |
335 | |
336 | // linearly scans through graph's nodes to locate prim::BailOut whose |
337 | // index matches the given `index` |
338 | static Node* locateBailOutNodeInUnoptimizedGraph(Block* b, int64_t index) { |
339 | for (auto n : b->nodes()) { |
340 | if ((n->kind() == prim::BailOut || n->kind() == prim::Guard) && |
341 | n->hasAttribute(attr::index) && n->i(attr::index) == index) { |
342 | return n; |
343 | } |
344 | for (auto ib : n->blocks()) { |
345 | if (auto bn = locateBailOutNodeInUnoptimizedGraph(ib, index)) { |
346 | return bn; |
347 | } |
348 | } |
349 | } |
350 | return nullptr; |
351 | } |
352 | |
353 | // Removes prim::BailOuts and hooks the guarded input directly |
354 | // to its users |
355 | static void removeBailouts(Block* b) { |
356 | for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { |
357 | if (it->kind() == prim::BailOut || it->kind() == prim::Guard) { |
358 | // clear profiling information |
359 | it->inputs().at(0)->setType(TensorType::get()); |
360 | it->output()->replaceAllUsesWith(it->inputs().at(0)); |
361 | it.destroyCurrent(); |
362 | } else { |
363 | for (auto ib : it->blocks()) { |
364 | removeBailouts(ib); |
365 | } |
366 | } |
367 | } |
368 | } |
369 | |
370 | // see `bailout_graph.h` |
371 | TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom( |
372 | int64_t bailout_index, |
373 | const std::shared_ptr<Graph>& orig, |
374 | const std::shared_ptr<Graph>& target) { |
375 | auto orig_bailout_node = |
376 | locateBailOutNodeInUnoptimizedGraph(orig->block(), bailout_index); |
377 | |
378 | GRAPH_DEBUG("bailout triggered for " , *orig_bailout_node); |
379 | GRAPH_DUMP("original bailout graph " , orig); |
380 | TORCH_INTERNAL_ASSERT( |
381 | orig_bailout_node->inputs().at(0)->type()->cast<FunctionType>() == |
382 | nullptr); |
383 | TORCH_INTERNAL_ASSERT( |
384 | orig_bailout_node && |
385 | (orig_bailout_node->kind() == prim::BailOut || |
386 | orig_bailout_node->kind() == prim::Guard) && |
387 | bailout_index == orig_bailout_node->i(attr::index)); |
388 | BailOutGraphBuilderForNode bg(orig, target); |
389 | auto bailout_graph = bg.buildBailOutGraphFrom(orig_bailout_node); |
390 | |
391 | removeBailouts(bailout_graph->block()); |
392 | ClearProfilingInformation(bailout_graph); |
393 | GRAPH_DUMP("bailout_graph " , bailout_graph); |
394 | return bailout_graph; |
395 | } |
396 | |
397 | } // namespace jit |
398 | } // namespace torch |
399 | |