1 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/ir/alias_analysis.h> |
4 | #include <torch/csrc/jit/ir/ir_views.h> |
5 | #include <torch/csrc/jit/jit_log.h> |
6 | #include <torch/csrc/utils/memory.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | void RemoveRedundantProfiles(Block* block, AliasDb& db) { |
12 | for (auto it = block->nodes().end()->reverseIterator(); |
13 | it != block->nodes().begin();) { |
14 | Node* n = *it; |
15 | it++; |
16 | |
17 | for (Block* b : n->blocks()) { |
18 | RemoveRedundantProfiles(b, db); |
19 | } |
20 | |
21 | // we only check prim::profile and not prim::profile_ivalue bc profile |
22 | // is inserted on each use, while profile_ivalue is inserted on the def |
23 | if (n->kind() != prim::profile || |
24 | n->input()->node()->kind() != prim::profile) { |
25 | continue; |
26 | } |
27 | |
28 | Node* input_node = n->input()->node(); |
29 | if (input_node->ty(attr::profiled_type) != n->ty(attr::profiled_type)) { |
30 | continue; |
31 | } |
32 | |
33 | if (!db.moveBeforeTopologicallyValid(input_node, n)) { |
34 | continue; |
35 | } |
36 | |
37 | n->output()->replaceAllUsesWith(n->input()); |
38 | n->destroy(); |
39 | } |
40 | } |
41 | |
42 | void RemoveRedundantProfiles(std::shared_ptr<Graph>& graph) { |
43 | AliasDb db(graph); |
44 | RemoveRedundantProfiles(graph->block(), db); |
45 | } |
46 | |
47 | } // namespace jit |
48 | } // namespace torch |
49 |