1 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/ir/ir_views.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | #include <torch/csrc/utils/memory.h> |
8 | |
9 | #include <unordered_map> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | namespace prim { |
15 | using namespace ::c10::prim; |
16 | } |
17 | |
18 | class DeadCodeEliminator { |
19 | public: |
20 | explicit DeadCodeEliminator( |
21 | std::shared_ptr<Graph> graph, |
22 | DCESideEffectPolicy sideEffectPolicy) |
23 | : sideEffectPolicy_(sideEffectPolicy), |
24 | graph_(std::move(graph)), |
25 | useAliasDb_(true) {} |
26 | DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy) |
27 | : sideEffectPolicy_(sideEffectPolicy) {} |
28 | |
29 | // The algorithm is an inverse mark-and-sweep. Starting from the return node, |
30 | // we mark "live" nodes that are necessary for the output. Nodes that have |
31 | // side effects are also marked. |
32 | void run(Block* block, bool recurse) { |
33 | // clean up unused fork inputs before starting the main algorithm |
34 | eliminateDeadForkInputs(block, recurse); |
35 | |
36 | // Initialize by marking the return node and all its consumed values as live |
37 | mark(block->return_node()); |
38 | |
39 | mark(block); |
40 | |
41 | deleteCallback_(liveValues_); |
42 | |
43 | sweep(block, recurse); |
44 | } |
45 | |
46 | void setDeleteCallback( |
47 | std::function<void(const std::unordered_set<const Value*>&)> |
48 | deleteCallback) { |
49 | deleteCallback_ = std::move(deleteCallback); |
50 | } |
51 | |
52 | private: |
53 | void eliminateDeadForkInputs(Block* block, bool recurse) { |
54 | for (Node* node : block->nodes()) { |
55 | if (recurse) { |
56 | for (Block* sb : node->blocks()) { |
57 | eliminateDeadForkInputs(sb, recurse); |
58 | } |
59 | } |
60 | if (node->kind() != prim::fork) { |
61 | continue; |
62 | } |
63 | Graph& g = *node->g(attr::Subgraph); |
64 | // WARNING: Do not use a ranged loop. The loop bounds are changed by the |
65 | // loop body. |
66 | for (size_t i = 0; i < g.inputs().size(); ++i) { |
67 | if (!g.inputs().at(i)->hasUses()) { |
68 | GRAPH_UPDATE( |
69 | "Dead " , |
70 | i, |
71 | "-th input " , |
72 | node->inputs().at(i)->debugName(), |
73 | "(" , |
74 | g.inputs().at(i)->debugName(), |
75 | " in a subgraph) will be removed" ); |
76 | g.eraseInput(i); |
77 | node->removeInput(i); |
78 | } |
79 | } |
80 | } |
81 | } |
82 | |
83 | // Special handling for block return nodes. Unlike other nodes, the block |
84 | // return node doesn't really "use" its inputs. Consider: |
85 | // |
86 | // %a0 = aten::foo() |
87 | // %b = aten::foo() |
88 | // %a2, %b2 = prim::If(%cond) { |
89 | // block0() { |
90 | // %a1 = aten::foo(%.0) |
91 | // %b1 = aten::foo(%b) |
92 | // } -> (%a1, %b1) |
93 | // } |
94 | // return (%a2) |
95 | // |
96 | // We want to be able to DCE all the %b stuff. So when processing block |
97 | // returns, we only mark producers for values that "live" (i.e. used outside |
98 | // the block). |
99 | // |
100 | // Returns true iff this marked something we haven't marked before. |
101 | bool markReturnNode(Node* node) { |
102 | if (marked_.count(node)) { |
103 | return false; |
104 | } |
105 | |
106 | AT_ASSERT(node->owningBlock()->return_node() == node); |
107 | auto outerNode = node->owningBlock()->owningNode(); |
108 | if (outerNode == nullptr || outerNode->kind() == prim::Reverse) { |
109 | // If there's no outer node, we're looking at the graph's top-level |
110 | // return block. We consider all graph outputs to be "used", so just mark |
111 | // this node normally. |
112 | return mark(node); |
113 | } |
114 | |
115 | // Collect all inputs that are actually live |
116 | if (outerNode->kind() == prim::Loop || |
117 | outerNode->kind() == c10::onnx::Loop) { |
118 | // Special handling to deal with loop carried dependencies. |
119 | auto loop = LoopView(outerNode); |
120 | for (const auto i : c10::irange(loop.carriedOutputs().size())) { |
121 | if (outerNode->kind() == c10::onnx::Loop) { |
122 | // Special handling for onnx loop. |
123 | // The number of body carried inputs and outputs are different. |
124 | // They cannot be mapped to each other easily by the same index. |
125 | liveValues_.insert(loop.bodyCarriedOutputs().at(i)); |
126 | continue; |
127 | } |
128 | auto innerInput = loop.bodyCarriedInputs().at(i); |
129 | auto innerOutput = loop.bodyCarriedOutputs().at(i); |
130 | auto outerOutput = loop.carriedOutputs().at(i); |
131 | if (liveValues_.count(outerOutput) || innerInput->hasUses()) { |
132 | liveValues_.insert(innerOutput); |
133 | } |
134 | } |
135 | |
136 | // Also mark the loop next condition as live, since it will be used inside |
137 | // the loop body. |
138 | liveValues_.insert(loop.nextCond()); |
139 | } else { |
140 | AT_ASSERT(outerNode->outputs().size() == node->inputs().size()); |
141 | for (const auto i : c10::irange(outerNode->outputs().size())) { |
142 | auto innerOutput = node->inputs()[i]; |
143 | auto outerOutput = outerNode->outputs()[i]; |
144 | if (liveValues_.count(outerOutput)) { |
145 | liveValues_.insert(innerOutput); |
146 | } |
147 | } |
148 | } |
149 | |
150 | marked_.insert(node); |
151 | return true; |
152 | } |
153 | |
154 | // Loops are special, because we need to run them to convergence. |
155 | // Consider the following loop: |
156 | // for i in range(3): |
157 | // tot += a[0][0] |
158 | // b = a[0] |
159 | // b[0] += 1 |
160 | // print(tot) |
161 | // |
162 | // If we only process the loop block once, we will conclude that `b[0]` and |
163 | // `b` are dead, even though `b[0] += 1` mutates a live memory location (since |
164 | // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next |
165 | // iteration |
166 | // |
167 | // We need to mark the loop again with the information that `a` is live, and |
168 | // repeat until we're not marking new stuff anymore. |
169 | // |
170 | // Returns true iff this marked something we haven't marked before. |
171 | bool markLoop(Node* node) { |
172 | TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop); |
173 | // Did a single iteration over the loop block mark anything new? |
174 | // If this is false, we've converged. |
175 | bool marked = false; |
176 | // Did we ever mark anything new? |
177 | bool anyMarked = false; |
178 | do { |
179 | marked = mark(node->blocks().at(0)); |
180 | anyMarked |= marked; |
181 | } while (marked); |
182 | return anyMarked; |
183 | } |
184 | |
185 | // Returns true iff this marked something we haven't marked before. |
186 | bool mark(Block* block) { |
187 | bool anyMarked = false; |
188 | // Mark all nodes with side effects. |
189 | for (auto node : block->nodes()) { |
190 | if (sideEffectPolicy_ == |
191 | DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && |
192 | hasSideEffects(node)) { |
193 | anyMarked |= mark(node); |
194 | } |
195 | } |
196 | |
197 | // Initialize by marking the return node |
198 | anyMarked |= markReturnNode(block->return_node()); |
199 | |
200 | for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) { |
201 | auto node = *it; |
202 | if (node->kind() == prim::Loop) { |
203 | // Special casing for loops, see comment in markLoop. |
204 | anyMarked |= markLoop(node); |
205 | } else { |
206 | // Other nodes with sub-blocks get marked normally. |
207 | for (auto subBlock : node->blocks()) { |
208 | anyMarked |= mark(subBlock); |
209 | } |
210 | } |
211 | anyMarked |= markIfLive(node); |
212 | } |
213 | return anyMarked; |
214 | } |
215 | |
216 | // If we output or write to a live memory location, mark this node |
217 | // Returns true iff this marked something we haven't marked before. |
218 | bool markIfLive(Node* node) { |
219 | for (const auto output : node->outputs()) { |
220 | if (liveValues_.count(output)) { |
221 | return mark(node); |
222 | } |
223 | } |
224 | |
225 | if (useAliasDb_) { |
226 | if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) { |
227 | return mark(node); |
228 | } |
229 | } |
230 | |
231 | return false; |
232 | } |
233 | |
234 | // Mark this node as live and add this node's inputs and aliases to the live |
235 | // value sets. |
236 | // Returns true iff this marked something we haven't marked before. |
237 | bool mark(Node* node) { |
238 | if (marked_.count(node)) { |
239 | return false; |
240 | } |
241 | |
242 | marked_.insert(node); |
243 | |
244 | // Mark all nodes in this node's blockchain (since owning nodes are |
245 | // considered live if they contain a live node) |
246 | auto curNode = node; |
247 | while (curNode) { |
248 | if (!curNode->owningBlock()) { |
249 | break; |
250 | } |
251 | |
252 | mark(curNode); |
253 | curNode = curNode->owningBlock()->owningNode(); |
254 | } |
255 | |
256 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
257 | for (const auto input : node->inputs()) { |
258 | if (liveValues_.count(input)) { |
259 | continue; |
260 | } |
261 | liveValues_.insert(input); |
262 | } |
263 | return true; |
264 | } |
265 | |
266 | // Delete all unmarked nodes. |
267 | void sweep(Block* block, bool recurse) { |
268 | auto nodes = block->nodes().reverse(); |
269 | for (auto it = nodes.begin(); it != nodes.end(); it++) { |
270 | auto node = *it; |
271 | // note these occur before the recursion because we want to uncover |
272 | // dead code in the blocks used to calculate the output |
273 | removeDeadBlockOutputs(node); |
274 | removeDeadLoopOutputs(node); |
275 | if (recurse) { |
276 | for (Block* block : node->blocks()) { |
277 | sweep(block, true); |
278 | } |
279 | } |
280 | // NB: Checking hasUses() is required. AD graphs are not perfectly |
281 | // valid, as a node in grad_desc.f might be used in reverse_block. |
282 | // Reverse_block is inlined in grad_desc.f before it's separated |
283 | // to grad_desc.df. |
284 | if (!(marked_.count(node) || node->hasUses())) { |
285 | GRAPH_UPDATE( |
286 | "Node " , |
287 | it->kind().toQualString(), |
288 | " which outputs " , |
289 | (!node->outputs().empty() ? node->outputs().at(0)->debugName() |
290 | : "n/a" ), |
291 | " will be removed" ); |
292 | it.destroyCurrent(); |
293 | } |
294 | } |
295 | } |
296 | |
297 | bool hasUntrackedMutation(Node* node) { |
298 | if (!useAliasDb_) { |
299 | // If we don't have alias information, all mutable ops have unknown |
300 | // effects and can't be considered for elimination. |
301 | |
302 | if (node->kind() == prim::SetAttr) { |
303 | // SetAttr is a special case: it doesn't have a schema, but does |
304 | // have untracked mutations |
305 | return true; |
306 | } |
307 | |
308 | // onnx export calls EliminateDeadCode but sometimes passes invalid |
309 | // aten operators. So we call maybeSchema so we handle the cases when |
310 | // there is no valid schema for a node |
311 | auto schema = node->maybeSchema(); |
312 | return schema && schema->is_mutable(); |
313 | } else { |
314 | return getOrCreateAliasDb()->writesToWildcard(node); |
315 | } |
316 | } |
317 | |
318 | bool hasSideEffects(Node* node) { |
319 | auto it = memo_.find(node); |
320 | if (it != memo_.end()) |
321 | return it->second; |
322 | bool has_side_effects = node->hasSideEffects() || |
323 | std::any_of(node->blocks().begin(), |
324 | node->blocks().end(), |
325 | [&](Block* b) { |
326 | return std::any_of( |
327 | b->nodes().begin(), b->nodes().end(), [&](Node* n) { |
328 | return hasSideEffects(n); |
329 | }); |
330 | }) || |
331 | hasUntrackedMutation(node); |
332 | |
333 | memo_.emplace(node, has_side_effects); |
334 | return has_side_effects; |
335 | } |
336 | |
337 | void removeDeadBlockOutputs(Node* node) { |
338 | if (node->kind() != prim::If && node->kind() != prim::GradOf) { |
339 | return; |
340 | } |
341 | |
342 | for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) { |
343 | size_t i = i_1 - 1; |
344 | if (!node->outputs().at(i)->hasUses()) { |
345 | GRAPH_UPDATE( |
346 | "Dead " , |
347 | i, |
348 | "-th output " , |
349 | node->outputs().at(i)->debugName(), |
350 | " of node " , |
351 | node->kind().toQualString(), |
352 | " will be removed" ); |
353 | node->eraseOutput(i); |
354 | for (Block* b : node->blocks()) { |
355 | GRAPH_UPDATE( |
356 | "\tCorresponding block output " , |
357 | b->outputs().at(i)->debugName(), |
358 | " will be removed" ); |
359 | b->eraseOutput(i); |
360 | } |
361 | } |
362 | } |
363 | } |
364 | |
365 | void removeDeadLoopOutputs(Node* node) { |
366 | if (node->kind() != prim::Loop) |
367 | return; |
368 | auto loop_body = node->blocks().at(0); |
369 | auto loop_input_offset = 2; // offset of loop carried deps in input list |
370 | auto loop_body_offset = |
371 | 1; // offset to the loop carried dependencies in block inputs/outputs |
372 | |
373 | for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) { |
374 | size_t i = i_1 - 1; |
375 | if (!node->outputs().at(i)->hasUses() && |
376 | !loop_body->inputs().at(loop_body_offset + i)->hasUses()) { |
377 | logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset); |
378 | node->eraseOutput(i); |
379 | node->removeInput(loop_input_offset + i); |
380 | loop_body->eraseInput(loop_body_offset + i); |
381 | loop_body->eraseOutput(loop_body_offset + i); |
382 | } |
383 | } |
384 | } |
385 | |
386 | void logDeadLoopOutputs( |
387 | Node* node, |
388 | size_t i, |
389 | size_t loop_input_offset, |
390 | size_t loop_body_offset) { |
391 | auto loop_body = node->blocks().at(0); |
392 | GRAPH_UPDATE( |
393 | "Dead " , |
394 | loop_input_offset + i, |
395 | "-th input " , |
396 | node->inputs().at(i)->debugName(), |
397 | " will be removed" ); |
398 | GRAPH_UPDATE( |
399 | "Dead " , |
400 | i, |
401 | "-th output " , |
402 | node->outputs().at(i)->debugName(), |
403 | " will be removed" ); |
404 | GRAPH_UPDATE( |
405 | "\tDead block input " , |
406 | loop_body->inputs().at(loop_body_offset + i)->debugName(), |
407 | "at offset " , |
408 | loop_body_offset + i, |
409 | " will be removed" ); |
410 | GRAPH_UPDATE( |
411 | "\tDead block output " , |
412 | loop_body->outputs().at(loop_body_offset + i)->debugName(), |
413 | "at offset " , |
414 | loop_body_offset + i, |
415 | " will be removed" ); |
416 | } |
417 | |
418 | AliasDb* getOrCreateAliasDb() { |
419 | if (!aliasDb_) { |
420 | aliasDb_ = std::make_unique<AliasDb>(graph_); |
421 | } |
422 | return aliasDb_.get(); |
423 | } |
424 | |
425 | DCESideEffectPolicy sideEffectPolicy_; |
426 | |
427 | std::shared_ptr<Graph> graph_; |
428 | bool useAliasDb_ = false; |
429 | // lazily initialized |
430 | std::unique_ptr<AliasDb> aliasDb_ = nullptr; |
431 | std::unordered_map<Node*, bool> memo_; |
432 | std::unordered_set<Node*> marked_; |
433 | std::unordered_set<const Value*> liveValues_; |
434 | std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ = |
435 | [](const std::unordered_set<const Value*>&) {}; |
436 | }; |
437 | |
438 | void EliminateDeadCode( |
439 | const std::shared_ptr<Graph>& graph, |
440 | DCESideEffectPolicy sideEffectPolicy) { |
441 | DeadCodeEliminator(graph, sideEffectPolicy) |
442 | .run(graph->block(), /*recurse=*/true); |
443 | GRAPH_DUMP("After EliminateDeadCode: " , graph); |
444 | } |
445 | |
446 | void EliminateDeadCode( |
447 | Block* block, |
448 | bool recurse, |
449 | DCESideEffectPolicy sideEffectPolicy) { |
450 | DeadCodeEliminator(sideEffectPolicy).run(block, recurse); |
451 | } |
452 | |
453 | void EliminateDeadCode( |
454 | Block* block, |
455 | std::function<void(const std::unordered_set<const Value*>&)> cb, |
456 | DCESideEffectPolicy sideEffectPolicy) { |
457 | DeadCodeEliminator eliminator(sideEffectPolicy); |
458 | eliminator.setDeleteCallback(std::move(cb)); |
459 | eliminator.run(block, /*recurse=*/true); |
460 | } |
461 | |
462 | } // namespace jit |
463 | } // namespace torch |
464 | |