1#include <torch/csrc/jit/passes/dead_code_elimination.h>
2
3#include <c10/util/irange.h>
4#include <torch/csrc/jit/ir/alias_analysis.h>
5#include <torch/csrc/jit/ir/ir_views.h>
6#include <torch/csrc/jit/jit_log.h>
7#include <torch/csrc/utils/memory.h>
8
9#include <unordered_map>
10
11namespace torch {
12namespace jit {
13
// Make the c10::prim interned symbols (prim::Loop, prim::If, prim::fork, ...)
// reachable as torch::jit::prim::* within this translation unit.
namespace prim {
using namespace ::c10::prim;
} // namespace prim
17
18class DeadCodeEliminator {
19 public:
20 explicit DeadCodeEliminator(
21 std::shared_ptr<Graph> graph,
22 DCESideEffectPolicy sideEffectPolicy)
23 : sideEffectPolicy_(sideEffectPolicy),
24 graph_(std::move(graph)),
25 useAliasDb_(true) {}
26 DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
27 : sideEffectPolicy_(sideEffectPolicy) {}
28
29 // The algorithm is an inverse mark-and-sweep. Starting from the return node,
30 // we mark "live" nodes that are necessary for the output. Nodes that have
31 // side effects are also marked.
32 void run(Block* block, bool recurse) {
33 // clean up unused fork inputs before starting the main algorithm
34 eliminateDeadForkInputs(block, recurse);
35
36 // Initialize by marking the return node and all its consumed values as live
37 mark(block->return_node());
38
39 mark(block);
40
41 deleteCallback_(liveValues_);
42
43 sweep(block, recurse);
44 }
45
46 void setDeleteCallback(
47 std::function<void(const std::unordered_set<const Value*>&)>
48 deleteCallback) {
49 deleteCallback_ = std::move(deleteCallback);
50 }
51
52 private:
53 void eliminateDeadForkInputs(Block* block, bool recurse) {
54 for (Node* node : block->nodes()) {
55 if (recurse) {
56 for (Block* sb : node->blocks()) {
57 eliminateDeadForkInputs(sb, recurse);
58 }
59 }
60 if (node->kind() != prim::fork) {
61 continue;
62 }
63 Graph& g = *node->g(attr::Subgraph);
64 // WARNING: Do not use a ranged loop. The loop bounds are changed by the
65 // loop body.
66 for (size_t i = 0; i < g.inputs().size(); ++i) {
67 if (!g.inputs().at(i)->hasUses()) {
68 GRAPH_UPDATE(
69 "Dead ",
70 i,
71 "-th input ",
72 node->inputs().at(i)->debugName(),
73 "(",
74 g.inputs().at(i)->debugName(),
75 " in a subgraph) will be removed");
76 g.eraseInput(i);
77 node->removeInput(i);
78 }
79 }
80 }
81 }
82
83 // Special handling for block return nodes. Unlike other nodes, the block
84 // return node doesn't really "use" its inputs. Consider:
85 //
86 // %a0 = aten::foo()
87 // %b = aten::foo()
88 // %a2, %b2 = prim::If(%cond) {
89 // block0() {
90 // %a1 = aten::foo(%.0)
91 // %b1 = aten::foo(%b)
92 // } -> (%a1, %b1)
93 // }
94 // return (%a2)
95 //
96 // We want to be able to DCE all the %b stuff. So when processing block
97 // returns, we only mark producers for values that "live" (i.e. used outside
98 // the block).
99 //
100 // Returns true iff this marked something we haven't marked before.
101 bool markReturnNode(Node* node) {
102 if (marked_.count(node)) {
103 return false;
104 }
105
106 AT_ASSERT(node->owningBlock()->return_node() == node);
107 auto outerNode = node->owningBlock()->owningNode();
108 if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
109 // If there's no outer node, we're looking at the graph's top-level
110 // return block. We consider all graph outputs to be "used", so just mark
111 // this node normally.
112 return mark(node);
113 }
114
115 // Collect all inputs that are actually live
116 if (outerNode->kind() == prim::Loop ||
117 outerNode->kind() == c10::onnx::Loop) {
118 // Special handling to deal with loop carried dependencies.
119 auto loop = LoopView(outerNode);
120 for (const auto i : c10::irange(loop.carriedOutputs().size())) {
121 if (outerNode->kind() == c10::onnx::Loop) {
122 // Special handling for onnx loop.
123 // The number of body carried inputs and outputs are different.
124 // They cannot be mapped to each other easily by the same index.
125 liveValues_.insert(loop.bodyCarriedOutputs().at(i));
126 continue;
127 }
128 auto innerInput = loop.bodyCarriedInputs().at(i);
129 auto innerOutput = loop.bodyCarriedOutputs().at(i);
130 auto outerOutput = loop.carriedOutputs().at(i);
131 if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
132 liveValues_.insert(innerOutput);
133 }
134 }
135
136 // Also mark the loop next condition as live, since it will be used inside
137 // the loop body.
138 liveValues_.insert(loop.nextCond());
139 } else {
140 AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
141 for (const auto i : c10::irange(outerNode->outputs().size())) {
142 auto innerOutput = node->inputs()[i];
143 auto outerOutput = outerNode->outputs()[i];
144 if (liveValues_.count(outerOutput)) {
145 liveValues_.insert(innerOutput);
146 }
147 }
148 }
149
150 marked_.insert(node);
151 return true;
152 }
153
154 // Loops are special, because we need to run them to convergence.
155 // Consider the following loop:
156 // for i in range(3):
157 // tot += a[0][0]
158 // b = a[0]
159 // b[0] += 1
160 // print(tot)
161 //
162 // If we only process the loop block once, we will conclude that `b[0]` and
163 // `b` are dead, even though `b[0] += 1` mutates a live memory location (since
164 // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
165 // iteration
166 //
167 // We need to mark the loop again with the information that `a` is live, and
168 // repeat until we're not marking new stuff anymore.
169 //
170 // Returns true iff this marked something we haven't marked before.
171 bool markLoop(Node* node) {
172 TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
173 // Did a single iteration over the loop block mark anything new?
174 // If this is false, we've converged.
175 bool marked = false;
176 // Did we ever mark anything new?
177 bool anyMarked = false;
178 do {
179 marked = mark(node->blocks().at(0));
180 anyMarked |= marked;
181 } while (marked);
182 return anyMarked;
183 }
184
185 // Returns true iff this marked something we haven't marked before.
186 bool mark(Block* block) {
187 bool anyMarked = false;
188 // Mark all nodes with side effects.
189 for (auto node : block->nodes()) {
190 if (sideEffectPolicy_ ==
191 DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
192 hasSideEffects(node)) {
193 anyMarked |= mark(node);
194 }
195 }
196
197 // Initialize by marking the return node
198 anyMarked |= markReturnNode(block->return_node());
199
200 for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
201 auto node = *it;
202 if (node->kind() == prim::Loop) {
203 // Special casing for loops, see comment in markLoop.
204 anyMarked |= markLoop(node);
205 } else {
206 // Other nodes with sub-blocks get marked normally.
207 for (auto subBlock : node->blocks()) {
208 anyMarked |= mark(subBlock);
209 }
210 }
211 anyMarked |= markIfLive(node);
212 }
213 return anyMarked;
214 }
215
216 // If we output or write to a live memory location, mark this node
217 // Returns true iff this marked something we haven't marked before.
218 bool markIfLive(Node* node) {
219 for (const auto output : node->outputs()) {
220 if (liveValues_.count(output)) {
221 return mark(node);
222 }
223 }
224
225 if (useAliasDb_) {
226 if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
227 return mark(node);
228 }
229 }
230
231 return false;
232 }
233
234 // Mark this node as live and add this node's inputs and aliases to the live
235 // value sets.
236 // Returns true iff this marked something we haven't marked before.
237 bool mark(Node* node) {
238 if (marked_.count(node)) {
239 return false;
240 }
241
242 marked_.insert(node);
243
244 // Mark all nodes in this node's blockchain (since owning nodes are
245 // considered live if they contain a live node)
246 auto curNode = node;
247 while (curNode) {
248 if (!curNode->owningBlock()) {
249 break;
250 }
251
252 mark(curNode);
253 curNode = curNode->owningBlock()->owningNode();
254 }
255
256 // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
257 for (const auto input : node->inputs()) {
258 if (liveValues_.count(input)) {
259 continue;
260 }
261 liveValues_.insert(input);
262 }
263 return true;
264 }
265
266 // Delete all unmarked nodes.
267 void sweep(Block* block, bool recurse) {
268 auto nodes = block->nodes().reverse();
269 for (auto it = nodes.begin(); it != nodes.end(); it++) {
270 auto node = *it;
271 // note these occur before the recursion because we want to uncover
272 // dead code in the blocks used to calculate the output
273 removeDeadBlockOutputs(node);
274 removeDeadLoopOutputs(node);
275 if (recurse) {
276 for (Block* block : node->blocks()) {
277 sweep(block, true);
278 }
279 }
280 // NB: Checking hasUses() is required. AD graphs are not perfectly
281 // valid, as a node in grad_desc.f might be used in reverse_block.
282 // Reverse_block is inlined in grad_desc.f before it's separated
283 // to grad_desc.df.
284 if (!(marked_.count(node) || node->hasUses())) {
285 GRAPH_UPDATE(
286 "Node ",
287 it->kind().toQualString(),
288 " which outputs ",
289 (!node->outputs().empty() ? node->outputs().at(0)->debugName()
290 : "n/a"),
291 " will be removed");
292 it.destroyCurrent();
293 }
294 }
295 }
296
297 bool hasUntrackedMutation(Node* node) {
298 if (!useAliasDb_) {
299 // If we don't have alias information, all mutable ops have unknown
300 // effects and can't be considered for elimination.
301
302 if (node->kind() == prim::SetAttr) {
303 // SetAttr is a special case: it doesn't have a schema, but does
304 // have untracked mutations
305 return true;
306 }
307
308 // onnx export calls EliminateDeadCode but sometimes passes invalid
309 // aten operators. So we call maybeSchema so we handle the cases when
310 // there is no valid schema for a node
311 auto schema = node->maybeSchema();
312 return schema && schema->is_mutable();
313 } else {
314 return getOrCreateAliasDb()->writesToWildcard(node);
315 }
316 }
317
318 bool hasSideEffects(Node* node) {
319 auto it = memo_.find(node);
320 if (it != memo_.end())
321 return it->second;
322 bool has_side_effects = node->hasSideEffects() ||
323 std::any_of(node->blocks().begin(),
324 node->blocks().end(),
325 [&](Block* b) {
326 return std::any_of(
327 b->nodes().begin(), b->nodes().end(), [&](Node* n) {
328 return hasSideEffects(n);
329 });
330 }) ||
331 hasUntrackedMutation(node);
332
333 memo_.emplace(node, has_side_effects);
334 return has_side_effects;
335 }
336
337 void removeDeadBlockOutputs(Node* node) {
338 if (node->kind() != prim::If && node->kind() != prim::GradOf) {
339 return;
340 }
341
342 for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
343 size_t i = i_1 - 1;
344 if (!node->outputs().at(i)->hasUses()) {
345 GRAPH_UPDATE(
346 "Dead ",
347 i,
348 "-th output ",
349 node->outputs().at(i)->debugName(),
350 " of node ",
351 node->kind().toQualString(),
352 " will be removed");
353 node->eraseOutput(i);
354 for (Block* b : node->blocks()) {
355 GRAPH_UPDATE(
356 "\tCorresponding block output ",
357 b->outputs().at(i)->debugName(),
358 " will be removed");
359 b->eraseOutput(i);
360 }
361 }
362 }
363 }
364
365 void removeDeadLoopOutputs(Node* node) {
366 if (node->kind() != prim::Loop)
367 return;
368 auto loop_body = node->blocks().at(0);
369 auto loop_input_offset = 2; // offset of loop carried deps in input list
370 auto loop_body_offset =
371 1; // offset to the loop carried dependencies in block inputs/outputs
372
373 for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
374 size_t i = i_1 - 1;
375 if (!node->outputs().at(i)->hasUses() &&
376 !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
377 logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset);
378 node->eraseOutput(i);
379 node->removeInput(loop_input_offset + i);
380 loop_body->eraseInput(loop_body_offset + i);
381 loop_body->eraseOutput(loop_body_offset + i);
382 }
383 }
384 }
385
386 void logDeadLoopOutputs(
387 Node* node,
388 size_t i,
389 size_t loop_input_offset,
390 size_t loop_body_offset) {
391 auto loop_body = node->blocks().at(0);
392 GRAPH_UPDATE(
393 "Dead ",
394 loop_input_offset + i,
395 "-th input ",
396 node->inputs().at(i)->debugName(),
397 " will be removed");
398 GRAPH_UPDATE(
399 "Dead ",
400 i,
401 "-th output ",
402 node->outputs().at(i)->debugName(),
403 " will be removed");
404 GRAPH_UPDATE(
405 "\tDead block input ",
406 loop_body->inputs().at(loop_body_offset + i)->debugName(),
407 "at offset ",
408 loop_body_offset + i,
409 " will be removed");
410 GRAPH_UPDATE(
411 "\tDead block output ",
412 loop_body->outputs().at(loop_body_offset + i)->debugName(),
413 "at offset ",
414 loop_body_offset + i,
415 " will be removed");
416 }
417
418 AliasDb* getOrCreateAliasDb() {
419 if (!aliasDb_) {
420 aliasDb_ = std::make_unique<AliasDb>(graph_);
421 }
422 return aliasDb_.get();
423 }
424
425 DCESideEffectPolicy sideEffectPolicy_;
426
427 std::shared_ptr<Graph> graph_;
428 bool useAliasDb_ = false;
429 // lazily initialized
430 std::unique_ptr<AliasDb> aliasDb_ = nullptr;
431 std::unordered_map<Node*, bool> memo_;
432 std::unordered_set<Node*> marked_;
433 std::unordered_set<const Value*> liveValues_;
434 std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
435 [](const std::unordered_set<const Value*>&) {};
436};
437
438void EliminateDeadCode(
439 const std::shared_ptr<Graph>& graph,
440 DCESideEffectPolicy sideEffectPolicy) {
441 DeadCodeEliminator(graph, sideEffectPolicy)
442 .run(graph->block(), /*recurse=*/true);
443 GRAPH_DUMP("After EliminateDeadCode: ", graph);
444}
445
446void EliminateDeadCode(
447 Block* block,
448 bool recurse,
449 DCESideEffectPolicy sideEffectPolicy) {
450 DeadCodeEliminator(sideEffectPolicy).run(block, recurse);
451}
452
453void EliminateDeadCode(
454 Block* block,
455 std::function<void(const std::unordered_set<const Value*>&)> cb,
456 DCESideEffectPolicy sideEffectPolicy) {
457 DeadCodeEliminator eliminator(sideEffectPolicy);
458 eliminator.setDeleteCallback(std::move(cb));
459 eliminator.run(block, /*recurse=*/true);
460}
461
462} // namespace jit
463} // namespace torch
464