1 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
2 | |
3 | #include <torch/csrc/jit/ir/alias_analysis.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/ir/node_hashing.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | |
8 | #include <unordered_map> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace { |
13 | |
14 | struct CommonSubexpressionEliminator { |
15 | CommonSubexpressionEliminator(std::shared_ptr<Graph> graph) |
16 | : graph_(std::move(graph)) {} |
17 | |
18 | bool run(std::function<Node*(Node*)> parent_lookup_fn) { |
19 | return run(graph_->block(), std::move(parent_lookup_fn)); |
20 | } |
21 | |
22 | // The function implements common subexpression elimination. |
23 | // Since the nodes are visited in topological order, one pass is enough. |
24 | // returns true if CSE made changes to a graph |
25 | bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) { |
26 | std::unordered_set<Node*, HashNode, EqualNode> subexprs; |
27 | bool changed = false; |
28 | for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { |
29 | auto node = *it; |
30 | |
31 | if (node->kind() == prim::profile) { |
32 | GRAPH_DEBUG( |
33 | "Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n" , |
34 | *node); |
35 | continue; |
36 | } |
37 | |
38 | if (node->hasSideEffects()) { |
39 | GRAPH_DEBUG("Node was skipped due to side effects:\n" , *node); |
40 | continue; |
41 | } |
42 | if (node->isNondeterministic()) { |
43 | GRAPH_DEBUG("Node was skipped due to its non determinism:\n" , *node); |
44 | continue; |
45 | } |
46 | |
47 | if (!node->blocks().empty()) { |
48 | // Traverse sub-blocks. |
49 | for (auto block : node->blocks()) { |
50 | changed |= run(block, [&](Node* n) { |
51 | auto existing = subexprs.find(n); |
52 | if (existing != subexprs.end()) { |
53 | return *existing; |
54 | } |
55 | |
56 | return parent_lookup_fn(n); |
57 | }); |
58 | } |
59 | |
60 | continue; |
61 | } |
62 | |
63 | if (getOrCreateAliasDb().hasWriters(node)) { |
64 | GRAPH_DEBUG("Node was skipped due to alias analysis result:\n" , *node); |
65 | // Do NOT have enough information to do CSE on these nodes. |
66 | continue; |
67 | } |
68 | |
69 | // Check for CSE opportunities in the parent block. |
70 | auto parent_lookup = parent_lookup_fn(node); |
71 | auto g_out = node->owningGraph()->outputs(); |
72 | if (parent_lookup != nullptr) { |
73 | if (!getOrCreateAliasDb().safeToChangeAliasingRelationship( |
74 | node->outputs(), parent_lookup->outputs())) { |
75 | continue; |
76 | } |
77 | |
78 | GRAPH_UPDATE("Replacing\n" , *node, "with\n" , *parent_lookup); |
79 | changed = true; |
80 | node->replaceAllUsesWith(parent_lookup); |
81 | it.destroyCurrent(); |
82 | continue; |
83 | } |
84 | |
85 | // Check whether the same subexpression already exists. |
86 | auto subit = subexprs.insert(node); |
87 | if (!subit.second) { |
88 | // Subexpression exists, replace the uses of node, and destroy it. |
89 | auto existing = *subit.first; |
90 | |
91 | // don't introduce new aliasing among graph outputs |
92 | if (getOrCreateAliasDb().mayContainAlias( |
93 | node->outputs(), node->owningGraph()->outputs()) && |
94 | getOrCreateAliasDb().mayContainAlias(existing->outputs(), g_out)) { |
95 | continue; |
96 | } |
97 | |
98 | GRAPH_UPDATE("Replacing\n" , *node, "with\n" , *existing); |
99 | changed = true; |
100 | node->replaceAllUsesWith(existing); |
101 | // Destroy the node. |
102 | it.destroyCurrent(); |
103 | } |
104 | } |
105 | |
106 | return changed; |
107 | } |
108 | |
109 | AliasDb& getOrCreateAliasDb() { |
110 | if (!alias_db_) { |
111 | alias_db_ = std::make_unique<AliasDb>(graph_); |
112 | } |
113 | |
114 | return *alias_db_; |
115 | } |
116 | |
117 | private: |
118 | std::unique_ptr<AliasDb> alias_db_; |
119 | std::shared_ptr<Graph> graph_; |
120 | }; |
121 | |
122 | } // namespace |
123 | |
124 | bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) { |
125 | GRAPH_DUMP("Before CSE" , graph); |
126 | CommonSubexpressionEliminator cse(graph); |
127 | return cse.run([](Node*) { return nullptr; }); |
128 | } |
129 | } // namespace jit |
130 | } // namespace torch |
131 | |