1#include <torch/csrc/jit/passes/bailout_graph.h>
2
3#include <ATen/core/function.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/jit/ir/alias_analysis.h>
6#include <torch/csrc/jit/ir/ir_views.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/clear_profiling.h>
9#include <torch/csrc/jit/passes/constant_pooling.h>
10#include <torch/csrc/jit/passes/liveness.h>
11#include <memory>
12#include <unordered_set>
13#include <utility>
14
15namespace torch {
16namespace jit {
17
18static bool shouldBeCapturedInByBailOut(Node* n) {
19 return n->kind() != prim::Constant;
20}
21
22struct BailOutGraphBuilderForNode {
23 explicit BailOutGraphBuilderForNode(
24 std::shared_ptr<Graph> graph,
25 std::shared_ptr<Graph> target)
26 : graph_(std::move(graph)), copy_graph_(std::move(target)) {}
27
28 // capture `old_value` into the bailout graph
29 // by creating a new input and mapping
30 // `old_value` to it
31 Value* addNewInputForValue(Value* old_value) {
32 auto node = old_value->node();
33 // this reduces the number of inputs to a bailout graph significantly
34 // making it easier to debug
35 if (node->kind() == prim::Constant) {
36 TORCH_INTERNAL_ASSERT(!shouldBeCapturedInByBailOut(node));
37 auto new_const = copy_graph_->createClone(node, {nullptr});
38 copy_graph_->block()->prependNode(new_const);
39 return new_const->output();
40 }
41
42 live_inputs_.push_back(old_value);
43 auto new_value = copy_graph_->block()->addInput();
44 GRAPH_DEBUG(
45 "Adding a new value %",
46 new_value->debugName(),
47 " for %",
48 old_value->debugName());
49 return mapValueAndCopyMetadata(old_value, new_value);
50 }
51
52 Value* mapValueAndCopyMetadata(Value* old_value, Value* new_value) {
53 this->old_to_new_[old_value] = new_value;
54 new_value->copyMetadata(old_value);
55 return new_value;
56 }
57
58 Value* getOrAddInputForValue(Value* v) {
59 if (this->old_to_new_.count(v) == 0) {
60 return addNewInputForValue(v);
61 } else {
62 return this->old_to_new_[v];
63 }
64 }
65
66 Value* getInputForValue(Value* v) {
67 TORCH_INTERNAL_ASSERT(this->old_to_new_.count(v));
68 return this->old_to_new_[v];
69 }
70
71 Node* cloneNode(Node* node) {
72 auto* block = copy_graph_->block();
73 auto env = [this](Value* v) { return getOrAddInputForValue(v); };
74
75 auto new_node = block->appendNode(copy_graph_->createClone(node, env));
76 for (size_t i = 0; i < node->outputs().size(); ++i) {
77 auto oo = node->outputs()[i];
78 auto no = new_node->outputs()[i];
79 old_to_new_[oo] = no;
80 }
81
82 return new_node;
83 }
84
85 // buildBailOutBlockFrom builds a bailout graph from
86 // a given node `n` until the end of the owning block
87 // If `n` belongs to `prim::If` or `prim::Loop`
88 // buildBailOutLoop/If continue
89 // from block's owning node (e.g. `prim::If` or
90 // `prim::Loop`)
91 void buildBailOutBlockFrom(Node* n) {
92 auto b = n->owningBlock();
93 for (auto it = n->iterator(); it != b->nodes().end(); it++) {
94 cloneNode(*it);
95 }
96
97 // we are either in `prim::If` or `prim::Loop`
98 // bailout graph building will continue from `outer_node` next
99 auto outer_node = n->owningBlock()->owningNode();
100 if (outer_node) {
101 if (outer_node->kind() == prim::Loop) {
102 buildBailOutLoop(outer_node);
103 } else if (outer_node->kind() == prim::If) {
104 buildBailOutIf(b->outputs(), outer_node);
105 } else {
106 AT_ERROR("Unexpected outer node");
107 }
108 }
109 }
110
111 void mapValues(
112 const at::ArrayRef<Value*> block_outputs,
113 const at::ArrayRef<Value*> carried_deps) {
114 TORCH_INTERNAL_ASSERT(block_outputs.size() == carried_deps.size());
115 for (const auto i : c10::irange(block_outputs.size())) {
116 auto nv = getOrAddInputForValue(block_outputs[i]);
117 old_to_new_[carried_deps[i]] = nv;
118 }
119 }
120
121 void buildBailOutLoop(Node* outer_node) {
122 LoopView lv(outer_node);
123 auto old_max_count = getOrAddInputForValue(lv.maxTripCount());
124 auto cur_iter = getInputForValue(lv.currentTripCount());
125 auto block_outputs = lv.bodyBlock()->outputs();
126
127 auto* block = copy_graph_->block();
128 // subtract the number of iterations
129 WithInsertPoint guard(*block->nodes().end());
130 auto updated_max_trip_count =
131 copy_graph_->insert(aten::sub, {old_max_count, cur_iter});
132 auto one = copy_graph_->insertConstant({1});
133 updated_max_trip_count =
134 copy_graph_->insert(aten::sub, {updated_max_trip_count, one});
135 auto cur_plus_one = copy_graph_->insert(aten::add, {one, cur_iter});
136
137 // We need to be careful when mapping `block_outputs` to continuation
138 // loop's inputs since `cloneFrom` will replace `%4` with the same value
139 // in both, `prim::Loop` and `aten::cat` in the example below:
140 //
141 // ... : Tensor = prim::Loop(%MAX_TRIP_COUNT, %COND, ..., %4)
142 // block0(%i.2 : int, ...):
143 // ...
144 // %y.5 : Double(3) = aten::cat(%22, %4)
145 // ...
146 //
147 // However for the cloned loop node, the values should be different.
148 // Namely, the value in `prim::Loop` should come from
149 // `lv.bodyBlock()->outputs()` which are mapped to the outputs of the
150 // current iteration whereas `%4` in `aten::cat` needs to be mapped to the
151 // cloned value of `%4` in a bailout graph. To work around this, we manually
152 // clone loop nodes
153
154 // map the residual loop's inputs to the outputs of the current iteration
155 // (i.e. `block_outputs`)
156 auto new_loop =
157 copy_graph_->insertNode(copy_graph_->create(prim::Loop, {}, 0))
158 ->setSourceRange(outer_node->sourceRange());
159 new_loop->addInput(updated_max_trip_count);
160 for (auto bo : block_outputs) {
161 new_loop->addInput(getOrAddInputForValue(bo));
162 }
163
164 // clone the loop body and map old loop's outputs to new loop's outputs
165 auto new_loop_body = new_loop->addBlock();
166 auto env = [this](Value* v) { return getOrAddInputForValue(v); };
167 new_loop_body->cloneFrom(lv.bodyBlock(), env);
168 for (auto ov : lv.carriedOutputs()) {
169 auto no = new_loop->addOutput();
170 mapValueAndCopyMetadata(ov, no);
171 }
172 LoopView new_lv(new_loop);
173 {
174 WithInsertPoint guard_in_loop(*new_lv.bodyBlock()->nodes().begin());
175 // `one` will be replaced with new_lv.currentTripCount()
176 // but it needs to be done after
177 // new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr);
178 // to avoid cyclical references
179 auto adj_iter_ctr = copy_graph_->insert(aten::add, {cur_plus_one, one});
180 new_lv.currentTripCount()->replaceAllUsesWith(adj_iter_ctr);
181 adj_iter_ctr->node()->replaceInputWith(one, new_lv.currentTripCount());
182 }
183
184 if (outer_node->next()) {
185 buildBailOutBlockFrom(outer_node->next());
186 }
187 }
188
189 void buildBailOutIf(
190 const at::ArrayRef<Value*> block_outputs,
191 Node* outer_node) {
192 auto if_outputs = outer_node->outputs();
193 mapValues(block_outputs, if_outputs);
194 buildBailOutBlockFrom(outer_node->next());
195 }
196
197 std::shared_ptr<Graph> buildBailOutGraphFrom(Node* n) {
198 // add graph inputs for guard's input
199 // and loop counts for loops `n` is contained in
200 // to make sure we can line bailout grap's inputs up properly
201 // with arguments to this BailOut node.
202 for (auto bi : n->inputs()) {
203 getOrAddInputForValue(bi);
204 }
205
206 buildBailOutBlockFrom(n);
207 // add graph outputs
208 for (auto ov : graph_->outputs()) {
209 copy_graph_->registerOutput(getOrAddInputForValue(ov));
210 }
211 return copy_graph_;
212 }
213
214 std::shared_ptr<Graph> graph_;
215 std::shared_ptr<Graph> copy_graph_;
216 std::vector<Value*> live_inputs_;
217 std::unordered_map<Value*, Value*> old_to_new_;
218};
219
220// `BailOutInserter` replaces prim::Guard nodes with
221// prim::BailOut nodes that allow interpreter to
222// resume execution of the unoptimized(deoptimized)
223// version of an original graph from a particular point
224struct BailOutInserter {
225 explicit BailOutInserter(std::shared_ptr<Graph> graph)
226 : graph_(std::move(graph)), bailout_index_(0) {}
227
228 void run() {
229 liveness_sets_ = BuildLivenessSets(graph_);
230 insertBailOuts(graph_->block());
231 replaceGuardsWithBailouts();
232 // embed a full original graph
233 addUnoptimizedFuncToBailouts();
234 }
235
236 // Packs the original unoptimized graph into a Function constant
237 // and add it as the first input to every prim::BailOut point
238 // This graph will be used to compute a bailout graph for
239 // any given bailout point
240 void addUnoptimizedFuncToBailouts() {
241 auto unoptimized_graph = graph_->copy();
242 auto unopt_func = graph_->create(prim::BailoutTemplate)
243 ->insertAfter(graph_->param_node());
244
245 // Returns an int so that we have an easy way to do graph traversal
246 unopt_func->output()->setType(IntType::get());
247 unopt_func->g_(attr::Subgraph, std::move(unoptimized_graph));
248 for (auto bn : bailouts_) {
249 bn->insertInput(0, unopt_func->output());
250 }
251 }
252
253 // Removes guards by hooking up the guarded tensor
254 // directly to its users and also clears
255 // profiling information on it.
256 void removeGuards(Block* b) {
257 for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
258 if (it->kind() == prim::Guard) {
259 // this will need to be profiled again
260 it->input()->setType(TensorType::get());
261 // destroy the guard
262 it->output()->replaceAllUsesWith(it->input());
263 it.destroyCurrent();
264 }
265
266 for (auto ib : it->blocks()) {
267 removeGuards(ib);
268 }
269 }
270 }
271
272 // replace each prim::Guard
273 // with its corresponding prim::BailOut
274 void replaceGuardsWithBailouts() {
275 for (auto e : replacements_) {
276 e.first->replaceAllUsesWith(e.second);
277 e.second->node()->insertAfter(e.first->node());
278 e.first->node()->destroy();
279 }
280 }
281
282 // Inserts prim::BailOut nodes for every prim::Guard
283 // Each BailOut point takes the set of inputs live
284 // at that particular execution point.
285 // An input is live if it's used beyond the guard/BailOut
286 // point to compute graph's outputs
287 void insertBailOuts(Block* b) {
288 for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
289 if (it->kind() == prim::Guard) {
290 auto bailout_node = b->owningGraph()->create(prim::BailOut);
291 bailouts_.push_back(bailout_node);
292
293 const auto& live_inputs = liveness_sets_[*it];
294
295 // guarded inputs come first
296 // currently, there's always one guarded input
297 bailout_node->addInput(it->input());
298 for (auto li : live_inputs) {
299 // Guarded inputs have already been added
300 // Also, skip some inputs that BailOutGraphBuilder can
301 // materialize into bailout graphs directly
302 if (!shouldBeCapturedInByBailOut(li->node()) || li == it->input()) {
303 continue;
304 }
305 bailout_node->addInput(li);
306 }
307
308 bailout_node->output()->setType(it->output()->type());
309 bailout_node->i_(attr::index, bailout_index_++);
310 // we can't immediately replace nodes since this action will corrupt
311 // the liveness sets of following BailOut nodes if any of their
312 // arguments are BailOut nodes themselves
313 replacements_.insert({it->output(), bailout_node->output()});
314
315 } else {
316 for (auto ib : it->blocks()) {
317 insertBailOuts(ib);
318 }
319 }
320 }
321 }
322
323 std::shared_ptr<Graph> graph_;
324 std::map<Node*, Node*> subgraphs;
325 std::size_t bailout_index_;
326 std::unordered_map<Node*, std::vector<Value*>> liveness_sets_;
327 std::vector<Node*> bailouts_;
328 std::map<Value*, Value*> replacements_;
329};
330
331void InsertBailOuts(std::shared_ptr<Graph> graph) {
332 BailOutInserter ibo(std::move(graph));
333 ibo.run();
334}
335
336// linearly scans through graph's nodes to locate prim::BailOut whose
337// index matches the given `index`
338static Node* locateBailOutNodeInUnoptimizedGraph(Block* b, int64_t index) {
339 for (auto n : b->nodes()) {
340 if ((n->kind() == prim::BailOut || n->kind() == prim::Guard) &&
341 n->hasAttribute(attr::index) && n->i(attr::index) == index) {
342 return n;
343 }
344 for (auto ib : n->blocks()) {
345 if (auto bn = locateBailOutNodeInUnoptimizedGraph(ib, index)) {
346 return bn;
347 }
348 }
349 }
350 return nullptr;
351}
352
353// Removes prim::BailOuts and hooks the guarded input directly
354// to its users
355static void removeBailouts(Block* b) {
356 for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
357 if (it->kind() == prim::BailOut || it->kind() == prim::Guard) {
358 // clear profiling information
359 it->inputs().at(0)->setType(TensorType::get());
360 it->output()->replaceAllUsesWith(it->inputs().at(0));
361 it.destroyCurrent();
362 } else {
363 for (auto ib : it->blocks()) {
364 removeBailouts(ib);
365 }
366 }
367 }
368}
369
370// see `bailout_graph.h`
371TORCH_API std::shared_ptr<Graph> BuildBailOutGraphFrom(
372 int64_t bailout_index,
373 const std::shared_ptr<Graph>& orig,
374 const std::shared_ptr<Graph>& target) {
375 auto orig_bailout_node =
376 locateBailOutNodeInUnoptimizedGraph(orig->block(), bailout_index);
377
378 GRAPH_DEBUG("bailout triggered for ", *orig_bailout_node);
379 GRAPH_DUMP("original bailout graph ", orig);
380 TORCH_INTERNAL_ASSERT(
381 orig_bailout_node->inputs().at(0)->type()->cast<FunctionType>() ==
382 nullptr);
383 TORCH_INTERNAL_ASSERT(
384 orig_bailout_node &&
385 (orig_bailout_node->kind() == prim::BailOut ||
386 orig_bailout_node->kind() == prim::Guard) &&
387 bailout_index == orig_bailout_node->i(attr::index));
388 BailOutGraphBuilderForNode bg(orig, target);
389 auto bailout_graph = bg.buildBailOutGraphFrom(orig_bailout_node);
390
391 removeBailouts(bailout_graph->block());
392 ClearProfilingInformation(bailout_graph);
393 GRAPH_DUMP("bailout_graph ", bailout_graph);
394 return bailout_graph;
395}
396
397} // namespace jit
398} // namespace torch
399