1 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
2 | |
3 | #include <torch/csrc/jit/ir/irparser.h> |
4 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
5 | |
6 | #include <c10/util/irange.h> |
7 | |
8 | #include <utility> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | namespace { |
14 | void update_source_range_and_cs_ptr( |
15 | const std::set<const Node*>& input_nodes, |
16 | const Match& m, |
17 | std::unordered_map<Node*, Node*>& pattern_node_map) { |
18 | // pattern_node_map, maps nodes of the replacement graph |
19 | // to the nodes of the pattern graph. |
20 | // Now we iterate over each node of the replacement graph |
21 | // and find the corresponding pattern node in the match. |
22 | // The matched's node's source range and callstack is then |
23 | // used to update replacement node's source range and callstack |
24 | for (auto& it : pattern_node_map) { |
25 | Node* replacement_node = it.first; |
26 | Node* pattern_node = it.second; |
27 | if (!input_nodes.count(pattern_node)) { |
28 | Node* orig_node = m.nodes_map.at(pattern_node); |
29 | replacement_node->setSourceRange(orig_node->sourceRange()); |
30 | if (orig_node->callstack()) { |
31 | replacement_node->setCallStack(orig_node->callstack().value()); |
32 | } |
33 | } |
34 | } |
35 | } |
36 | } // namespace |
37 | |
38 | void SubgraphRewriter::RegisterDefaultPatterns() { |
39 | // TODO: Add actual patterns (like Conv-Relu). |
40 | RegisterRewritePattern( |
41 | R"IR( |
42 | graph(%x, %w, %b): |
43 | %c = aten::conv(%x, %w, %b) |
44 | %r = aten::relu(%c) |
45 | return (%r))IR" , |
46 | R"IR( |
47 | graph(%x, %w, %b): |
48 | %r = aten::convrelu(%x, %w, %b) |
49 | return (%r))IR" , |
50 | {{"r" , "c" }}); |
51 | } |
52 | |
53 | void SubgraphRewriter::RegisterRewritePattern( |
54 | const std::string& pattern, |
55 | const std::string& replacement, |
56 | const std::vector<std::pair<std::string, std::string>>& value_name_pairs) { |
57 | std::unordered_map<std::string, std::string> value_name_map( |
58 | value_name_pairs.begin(), value_name_pairs.end()); |
59 | RewritePatternDescr d = {pattern, replacement, std::move(value_name_map)}; |
60 | patterns_.push_back(std::move(d)); |
61 | } |
62 | |
63 | Module SubgraphRewriter::runOnModule(const Module& module) { |
64 | nodes_to_delete_.clear(); |
65 | for (const auto& m : module.get_methods()) { |
66 | auto g = toGraphFunction(m.function()).graph(); |
67 | runOnGraph(g); |
68 | } |
69 | return module; |
70 | } |
71 | |
72 | void SubgraphRewriter::runOnGraph( |
73 | std::shared_ptr<Graph>& graph, |
74 | const std::vector<MatchFilter>& filters) { |
75 | for (const RewritePatternDescr& pattern : patterns_) { |
76 | rewriteSinglePatternOnGraph(graph, pattern, filters); |
77 | } |
78 | } |
79 | |
80 | void SubgraphRewriter::rewriteSinglePatternOnGraph( |
81 | std::shared_ptr<Graph>& graph, |
82 | const RewritePatternDescr& pattern, |
83 | const std::vector<MatchFilter>& filters) { |
84 | std::unordered_map<Value*, Value*> rewrite_map; |
85 | std::vector<Value*> values_to_rewrite; |
86 | |
87 | Graph pattern_graph; |
88 | std::unordered_map<std::string, Value*> vmap; |
89 | parseIR(pattern.pattern, &pattern_graph, vmap); |
90 | |
91 | Graph replacement_graph; |
92 | std::unordered_map<std::string, Value*> vmap_replacement; |
93 | parseIR(pattern.replacement, &replacement_graph, vmap_replacement); |
94 | |
95 | // First construct map of Node*-to-Node* |
96 | // This maps Nodes in replacement graph to nodes in pattern graph |
97 | // given the value_name_map, which maps value names from repalcement |
98 | // pattern to value name in pattern |
99 | std::unordered_map<Node*, Node*> pattern_node_map; |
100 | std::set<const Node*> pattern_input_nodes; |
101 | for (auto& it : vmap_replacement) { |
102 | const auto& replacement_value_name = it.first; |
103 | Node* replacement_value_node = it.second->node(); |
104 | if (pattern.value_name_map.count(replacement_value_name)) { |
105 | const auto& pattern_value_name = |
106 | pattern.value_name_map.at(replacement_value_name); |
107 | TORCH_CHECK( |
108 | vmap.count(pattern_value_name), |
109 | "Value must be found in the replacement graph." ); |
110 | Node* pattern_value_node = vmap.at(pattern_value_name)->node(); |
111 | pattern_node_map.emplace(replacement_value_node, pattern_value_node); |
112 | } |
113 | } |
114 | |
115 | const auto& matches = findPatternMatches(pattern_graph, *graph); |
116 | for (const Match& match : matches) { |
117 | if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) { |
118 | return f(match, vmap); |
119 | })) { |
120 | continue; |
121 | } |
122 | // Matches might overlap with each other, in that case some of the nodes in |
123 | // the current match might have already been used in another folded pattern. |
124 | // We need to skip such matches. |
125 | if (overlapsWithPreviousMatches(&match)) { |
126 | continue; |
127 | } |
128 | |
129 | // Figure out what values we need to use as inputs and outputs for the |
130 | // replacement subgraph and where the replacement subgraph needs to be |
131 | // inserted. |
132 | Node* ins_point = nullptr; |
133 | std::vector<Value*> inputs, outputs; |
134 | for (Value* v : pattern_graph.inputs()) { |
135 | Value* input = match.values_map.at(v); |
136 | if (!ins_point || ins_point->isBefore(input->node())) { |
137 | ins_point = input->node(); |
138 | } |
139 | inputs.push_back(input); |
140 | } |
141 | AT_ASSERT(ins_point); |
142 | |
143 | // Check that the insertion point we've chosen precedes all the uses of the |
144 | // outputs - otherwise the replacement is incorrect and we have to skip it. |
145 | bool ins_point_before_uses = true; |
146 | for (Value* v : pattern_graph.outputs()) { |
147 | Value* output = match.values_map.at(v); |
148 | outputs.push_back(match.values_map.at(v)); |
149 | |
150 | for (const Use& u : output->uses()) { |
151 | if (u.user->isBefore(ins_point)) { |
152 | ins_point_before_uses = false; |
153 | break; |
154 | } |
155 | } |
156 | } |
157 | |
158 | if (!ins_point_before_uses) { |
159 | continue; |
160 | } |
161 | |
162 | // Before rewriting the graph, update source range and callstack |
163 | // info of the replacement pattern graph so that the rewritten graph |
164 | // has the updated info |
165 | update_source_range_and_cs_ptr( |
166 | pattern_input_nodes, match, pattern_node_map); |
167 | // Insert a clone of replacement subgraph. |
168 | // `inputs` vector holds values that we would use as incoming values to the |
169 | // new subgraph, and we will get `new_outputs` vector containing values |
170 | // produced by this new subgraph - we will then rewrite old outputs with the |
171 | // new ones. |
172 | WithInsertPoint insert_point(ins_point->next()); |
173 | std::vector<Value*> new_outputs = |
174 | insertGraph(*graph, replacement_graph, inputs); |
175 | |
176 | // Record all planned rewritings |
177 | AT_ASSERT(outputs.size() == new_outputs.size()); |
178 | for (const auto idx : c10::irange(outputs.size())) { |
179 | values_to_rewrite.push_back(outputs[idx]); |
180 | rewrite_map[outputs[idx]] = |
181 | new_outputs[idx]->setType(outputs[idx]->type()); |
182 | } |
183 | // Record all planned deletions |
184 | for (Node* pattern_n : pattern_graph.nodes()) { |
185 | if (match.nodes_map.count(pattern_n)) { |
186 | Node* n = match.nodes_map.at(pattern_n); |
187 | nodes_to_delete_.insert(n); |
188 | } |
189 | } |
190 | } |
191 | |
192 | // Perform planned rewritings |
193 | for (auto v : values_to_rewrite) { |
194 | v->replaceAllUsesWith(rewrite_map.at(v)); |
195 | } |
196 | |
197 | // Perform planned deletions |
198 | for (auto n : nodes_to_delete_) { |
199 | n->removeAllInputs(); |
200 | } |
201 | for (auto n : nodes_to_delete_) { |
202 | n->destroy(); |
203 | } |
204 | nodes_to_delete_.clear(); |
205 | } |
206 | |
207 | bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) { |
208 | for (auto n : match->nodes_map) { |
209 | if (nodes_to_delete_.count(n.second)) { |
210 | return true; |
211 | } |
212 | } |
213 | return false; |
214 | } |
215 | |
216 | Module PatternBasedRewrite(const Module& module) { |
217 | // TODO: Deep-copy the module |
218 | SubgraphRewriter subgraph_rewriter; |
219 | subgraph_rewriter.RegisterDefaultPatterns(); |
220 | return subgraph_rewriter.runOnModule(module); |
221 | } |
222 | |
223 | } // namespace jit |
224 | } // namespace torch |
225 | |