#include <torch/csrc/jit/passes/common_subexpression_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/jit_log.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
9
10namespace torch {
11namespace jit {
12namespace {
13
14struct CommonSubexpressionEliminator {
15 CommonSubexpressionEliminator(std::shared_ptr<Graph> graph)
16 : graph_(std::move(graph)) {}
17
18 bool run(std::function<Node*(Node*)> parent_lookup_fn) {
19 return run(graph_->block(), std::move(parent_lookup_fn));
20 }
21
22 // The function implements common subexpression elimination.
23 // Since the nodes are visited in topological order, one pass is enough.
24 // returns true if CSE made changes to a graph
25 bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
26 std::unordered_set<Node*, HashNode, EqualNode> subexprs;
27 bool changed = false;
28 for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
29 auto node = *it;
30
31 if (node->kind() == prim::profile) {
32 GRAPH_DEBUG(
33 "Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n",
34 *node);
35 continue;
36 }
37
38 if (node->hasSideEffects()) {
39 GRAPH_DEBUG("Node was skipped due to side effects:\n", *node);
40 continue;
41 }
42 if (node->isNondeterministic()) {
43 GRAPH_DEBUG("Node was skipped due to its non determinism:\n", *node);
44 continue;
45 }
46
47 if (!node->blocks().empty()) {
48 // Traverse sub-blocks.
49 for (auto block : node->blocks()) {
50 changed |= run(block, [&](Node* n) {
51 auto existing = subexprs.find(n);
52 if (existing != subexprs.end()) {
53 return *existing;
54 }
55
56 return parent_lookup_fn(n);
57 });
58 }
59
60 continue;
61 }
62
63 if (getOrCreateAliasDb().hasWriters(node)) {
64 GRAPH_DEBUG("Node was skipped due to alias analysis result:\n", *node);
65 // Do NOT have enough information to do CSE on these nodes.
66 continue;
67 }
68
69 // Check for CSE opportunities in the parent block.
70 auto parent_lookup = parent_lookup_fn(node);
71 auto g_out = node->owningGraph()->outputs();
72 if (parent_lookup != nullptr) {
73 if (!getOrCreateAliasDb().safeToChangeAliasingRelationship(
74 node->outputs(), parent_lookup->outputs())) {
75 continue;
76 }
77
78 GRAPH_UPDATE("Replacing\n", *node, "with\n", *parent_lookup);
79 changed = true;
80 node->replaceAllUsesWith(parent_lookup);
81 it.destroyCurrent();
82 continue;
83 }
84
85 // Check whether the same subexpression already exists.
86 auto subit = subexprs.insert(node);
87 if (!subit.second) {
88 // Subexpression exists, replace the uses of node, and destroy it.
89 auto existing = *subit.first;
90
91 // don't introduce new aliasing among graph outputs
92 if (getOrCreateAliasDb().mayContainAlias(
93 node->outputs(), node->owningGraph()->outputs()) &&
94 getOrCreateAliasDb().mayContainAlias(existing->outputs(), g_out)) {
95 continue;
96 }
97
98 GRAPH_UPDATE("Replacing\n", *node, "with\n", *existing);
99 changed = true;
100 node->replaceAllUsesWith(existing);
101 // Destroy the node.
102 it.destroyCurrent();
103 }
104 }
105
106 return changed;
107 }
108
109 AliasDb& getOrCreateAliasDb() {
110 if (!alias_db_) {
111 alias_db_ = std::make_unique<AliasDb>(graph_);
112 }
113
114 return *alias_db_;
115 }
116
117 private:
118 std::unique_ptr<AliasDb> alias_db_;
119 std::shared_ptr<Graph> graph_;
120};
121
122} // namespace
123
124bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
125 GRAPH_DUMP("Before CSE", graph);
126 CommonSubexpressionEliminator cse(graph);
127 return cse.run([](Node*) { return nullptr; });
128}
129} // namespace jit
130} // namespace torch
131