1#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2
3#include <torch/csrc/jit/ir/irparser.h>
4#include <torch/csrc/jit/ir/subgraph_matcher.h>
5
6#include <c10/util/irange.h>
7
8#include <utility>
9
10namespace torch {
11namespace jit {
12
13namespace {
14void update_source_range_and_cs_ptr(
15 const std::set<const Node*>& input_nodes,
16 const Match& m,
17 std::unordered_map<Node*, Node*>& pattern_node_map) {
18 // pattern_node_map, maps nodes of the replacement graph
19 // to the nodes of the pattern graph.
20 // Now we iterate over each node of the replacement graph
21 // and find the corresponding pattern node in the match.
22 // The matched's node's source range and callstack is then
23 // used to update replacement node's source range and callstack
24 for (auto& it : pattern_node_map) {
25 Node* replacement_node = it.first;
26 Node* pattern_node = it.second;
27 if (!input_nodes.count(pattern_node)) {
28 Node* orig_node = m.nodes_map.at(pattern_node);
29 replacement_node->setSourceRange(orig_node->sourceRange());
30 if (orig_node->callstack()) {
31 replacement_node->setCallStack(orig_node->callstack().value());
32 }
33 }
34 }
35}
36} // namespace
37
38void SubgraphRewriter::RegisterDefaultPatterns() {
39 // TODO: Add actual patterns (like Conv-Relu).
40 RegisterRewritePattern(
41 R"IR(
42graph(%x, %w, %b):
43 %c = aten::conv(%x, %w, %b)
44 %r = aten::relu(%c)
45 return (%r))IR",
46 R"IR(
47graph(%x, %w, %b):
48 %r = aten::convrelu(%x, %w, %b)
49 return (%r))IR",
50 {{"r", "c"}});
51}
52
53void SubgraphRewriter::RegisterRewritePattern(
54 const std::string& pattern,
55 const std::string& replacement,
56 const std::vector<std::pair<std::string, std::string>>& value_name_pairs) {
57 std::unordered_map<std::string, std::string> value_name_map(
58 value_name_pairs.begin(), value_name_pairs.end());
59 RewritePatternDescr d = {pattern, replacement, std::move(value_name_map)};
60 patterns_.push_back(std::move(d));
61}
62
63Module SubgraphRewriter::runOnModule(const Module& module) {
64 nodes_to_delete_.clear();
65 for (const auto& m : module.get_methods()) {
66 auto g = toGraphFunction(m.function()).graph();
67 runOnGraph(g);
68 }
69 return module;
70}
71
72void SubgraphRewriter::runOnGraph(
73 std::shared_ptr<Graph>& graph,
74 const std::vector<MatchFilter>& filters) {
75 for (const RewritePatternDescr& pattern : patterns_) {
76 rewriteSinglePatternOnGraph(graph, pattern, filters);
77 }
78}
79
80void SubgraphRewriter::rewriteSinglePatternOnGraph(
81 std::shared_ptr<Graph>& graph,
82 const RewritePatternDescr& pattern,
83 const std::vector<MatchFilter>& filters) {
84 std::unordered_map<Value*, Value*> rewrite_map;
85 std::vector<Value*> values_to_rewrite;
86
87 Graph pattern_graph;
88 std::unordered_map<std::string, Value*> vmap;
89 parseIR(pattern.pattern, &pattern_graph, vmap);
90
91 Graph replacement_graph;
92 std::unordered_map<std::string, Value*> vmap_replacement;
93 parseIR(pattern.replacement, &replacement_graph, vmap_replacement);
94
95 // First construct map of Node*-to-Node*
96 // This maps Nodes in replacement graph to nodes in pattern graph
97 // given the value_name_map, which maps value names from repalcement
98 // pattern to value name in pattern
99 std::unordered_map<Node*, Node*> pattern_node_map;
100 std::set<const Node*> pattern_input_nodes;
101 for (auto& it : vmap_replacement) {
102 const auto& replacement_value_name = it.first;
103 Node* replacement_value_node = it.second->node();
104 if (pattern.value_name_map.count(replacement_value_name)) {
105 const auto& pattern_value_name =
106 pattern.value_name_map.at(replacement_value_name);
107 TORCH_CHECK(
108 vmap.count(pattern_value_name),
109 "Value must be found in the replacement graph.");
110 Node* pattern_value_node = vmap.at(pattern_value_name)->node();
111 pattern_node_map.emplace(replacement_value_node, pattern_value_node);
112 }
113 }
114
115 const auto& matches = findPatternMatches(pattern_graph, *graph);
116 for (const Match& match : matches) {
117 if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
118 return f(match, vmap);
119 })) {
120 continue;
121 }
122 // Matches might overlap with each other, in that case some of the nodes in
123 // the current match might have already been used in another folded pattern.
124 // We need to skip such matches.
125 if (overlapsWithPreviousMatches(&match)) {
126 continue;
127 }
128
129 // Figure out what values we need to use as inputs and outputs for the
130 // replacement subgraph and where the replacement subgraph needs to be
131 // inserted.
132 Node* ins_point = nullptr;
133 std::vector<Value*> inputs, outputs;
134 for (Value* v : pattern_graph.inputs()) {
135 Value* input = match.values_map.at(v);
136 if (!ins_point || ins_point->isBefore(input->node())) {
137 ins_point = input->node();
138 }
139 inputs.push_back(input);
140 }
141 AT_ASSERT(ins_point);
142
143 // Check that the insertion point we've chosen precedes all the uses of the
144 // outputs - otherwise the replacement is incorrect and we have to skip it.
145 bool ins_point_before_uses = true;
146 for (Value* v : pattern_graph.outputs()) {
147 Value* output = match.values_map.at(v);
148 outputs.push_back(match.values_map.at(v));
149
150 for (const Use& u : output->uses()) {
151 if (u.user->isBefore(ins_point)) {
152 ins_point_before_uses = false;
153 break;
154 }
155 }
156 }
157
158 if (!ins_point_before_uses) {
159 continue;
160 }
161
162 // Before rewriting the graph, update source range and callstack
163 // info of the replacement pattern graph so that the rewritten graph
164 // has the updated info
165 update_source_range_and_cs_ptr(
166 pattern_input_nodes, match, pattern_node_map);
167 // Insert a clone of replacement subgraph.
168 // `inputs` vector holds values that we would use as incoming values to the
169 // new subgraph, and we will get `new_outputs` vector containing values
170 // produced by this new subgraph - we will then rewrite old outputs with the
171 // new ones.
172 WithInsertPoint insert_point(ins_point->next());
173 std::vector<Value*> new_outputs =
174 insertGraph(*graph, replacement_graph, inputs);
175
176 // Record all planned rewritings
177 AT_ASSERT(outputs.size() == new_outputs.size());
178 for (const auto idx : c10::irange(outputs.size())) {
179 values_to_rewrite.push_back(outputs[idx]);
180 rewrite_map[outputs[idx]] =
181 new_outputs[idx]->setType(outputs[idx]->type());
182 }
183 // Record all planned deletions
184 for (Node* pattern_n : pattern_graph.nodes()) {
185 if (match.nodes_map.count(pattern_n)) {
186 Node* n = match.nodes_map.at(pattern_n);
187 nodes_to_delete_.insert(n);
188 }
189 }
190 }
191
192 // Perform planned rewritings
193 for (auto v : values_to_rewrite) {
194 v->replaceAllUsesWith(rewrite_map.at(v));
195 }
196
197 // Perform planned deletions
198 for (auto n : nodes_to_delete_) {
199 n->removeAllInputs();
200 }
201 for (auto n : nodes_to_delete_) {
202 n->destroy();
203 }
204 nodes_to_delete_.clear();
205}
206
207bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
208 for (auto n : match->nodes_map) {
209 if (nodes_to_delete_.count(n.second)) {
210 return true;
211 }
212 }
213 return false;
214}
215
216Module PatternBasedRewrite(const Module& module) {
217 // TODO: Deep-copy the module
218 SubgraphRewriter subgraph_rewriter;
219 subgraph_rewriter.RegisterDefaultPatterns();
220 return subgraph_rewriter.runOnModule(module);
221}
222
223} // namespace jit
224} // namespace torch
225