1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
5 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
6 | #include <torch/csrc/jit/testing/file_check.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | using namespace testing; |
11 | |
12 | TEST(SubgraphRewriterTest, FilterMatch) { |
13 | auto graph = std::make_shared<Graph>(); |
14 | |
15 | parseIR( |
16 | R"IR( |
17 | graph(%0): |
18 | %a = a::aaa(%0) |
19 | %b : int = prim::Constant[value=1]() |
20 | %c = c::ccc(%a, %b) |
21 | return (%c))IR" , |
22 | graph.get()); |
23 | |
24 | std::string pattern = R"IR( |
25 | graph(%a, %b): |
26 | %c = c::ccc(%a, %b) |
27 | return (%c))IR" ; |
28 | Graph pattern_graph; |
29 | std::unordered_map<std::string, Value*> vmap; |
30 | |
31 | parseIR(pattern, &pattern_graph, vmap); |
32 | |
33 | auto b_is_constant = [](const Match& match, |
34 | const std::unordered_map<std::string, Value*>& vmap) { |
35 | const auto& match_vmap = match.values_map; |
36 | auto b_node = match_vmap.at(vmap.at("b" ))->node(); |
37 | return b_node->kind() == prim::Constant; |
38 | }; |
39 | |
40 | auto b_is_one = [](const Match& match, |
41 | const std::unordered_map<std::string, Value*>& vmap) { |
42 | const auto& match_vmap = match.values_map; |
43 | auto b_val = toIValue(match_vmap.at(vmap.at("b" ))); |
44 | return b_val && b_val->isInt() && b_val->toInt() == 1; |
45 | }; |
46 | |
47 | auto b_is_two = [](const Match& match, |
48 | const std::unordered_map<std::string, Value*>& vmap) { |
49 | const auto& match_vmap = match.values_map; |
50 | auto b_val = toIValue(match_vmap.at(vmap.at("b" ))); |
51 | return b_val && b_val->isInt() && b_val->toInt() == 2; |
52 | }; |
53 | |
54 | std::string replacement = R"IR( |
55 | graph(%a, %b): |
56 | %d = d::ddd(%a, %b) |
57 | return (%d))IR" ; |
58 | |
59 | SubgraphRewriter rewriter; |
60 | rewriter.RegisterRewritePattern(pattern, replacement); |
61 | |
62 | // b is constant, so the match will succeed |
63 | { |
64 | auto g = graph->copy(); |
65 | rewriter.runOnGraph(g, b_is_constant); |
66 | FileCheck().check("d::ddd" )->check_not("c::ccc" )->run(*g); |
67 | } |
68 | |
69 | // b is constant and the value is one, the match will succeed |
70 | { |
71 | auto g = graph->copy(); |
72 | rewriter.runOnGraph(g, {b_is_constant, b_is_one}); |
73 | FileCheck().check("d::ddd" )->check_not("c::ccc" )->run(*g); |
74 | } |
75 | |
76 | // b is constant but the value is not two, the match will fail |
77 | { |
78 | auto g = graph->copy(); |
79 | rewriter.runOnGraph(g, {b_is_constant, b_is_two}); |
80 | FileCheck().check("c::ccc" )->check_not("d::ddd" )->run(*g); |
81 | } |
82 | } |
83 | |
84 | TEST(SubgraphRewriterTest, FilterNoMatch) { |
85 | auto graph = std::make_shared<Graph>(); |
86 | parseIR( |
87 | R"IR( |
88 | graph(%0): |
89 | %a = a::aaa(%0) |
90 | %b = prim::Constant[value=1]() |
91 | %c = c::ccc(%a, %b) |
92 | return (%c))IR" , |
93 | graph.get()); |
94 | |
95 | std::string pattern = R"IR( |
96 | graph(%a, %b): |
97 | %c = c::ccc(%a, %b) |
98 | return (%c))IR" ; |
99 | Graph pattern_graph; |
100 | std::unordered_map<std::string, Value*> vmap; |
101 | |
102 | parseIR(pattern, &pattern_graph, vmap); |
103 | |
104 | auto filter = [](const Match& match, |
105 | const std::unordered_map<std::string, Value*>& vmap) { |
106 | const auto& match_vmap = match.values_map; |
107 | auto b_node = match_vmap.at(vmap.at("b" ))->node(); |
108 | // b_node is not prim::Assign, so this won't match and we'll skip the |
109 | // rewrite |
110 | return b_node->kind() == prim::Assign; |
111 | }; |
112 | |
113 | std::string replacement = R"IR( |
114 | graph(%a, %b): |
115 | %d = d::ddd(%a, %b) |
116 | return (%d))IR" ; |
117 | |
118 | SubgraphRewriter rewriter; |
119 | rewriter.RegisterRewritePattern(pattern, replacement); |
120 | rewriter.runOnGraph(graph, filter); |
121 | |
122 | FileCheck().check("c::ccc" )->check_not("d::ddd" )->run(*graph); |
123 | } |
124 | |
125 | TEST(SubgraphRewriterTest, MultiOutput) { |
126 | { |
127 | auto graph = std::make_shared<Graph>(); |
128 | |
129 | // Basic multi-output pattern rewriting |
130 | parseIR( |
131 | R"IR( |
132 | graph(%0, %1): |
133 | %a1, %a2 = a::aaa(%0, %1) |
134 | %b = b::bbb(%a1) |
135 | %c = c::ccc(%b) |
136 | |
137 | %x1, %x2 = a::aaa(%c, %a2) |
138 | %y = b::bbb(%x1) |
139 | %z = d::ddd(%y) |
140 | return (%z))IR" , |
141 | graph.get()); |
142 | |
143 | std::string pattern = R"IR( |
144 | graph(%0, %1): |
145 | %a1, %a2 = a::aaa(%0, %1) |
146 | %b = b::bbb(%a1) |
147 | return (%b, %a2))IR" ; |
148 | |
149 | std::string replacement = R"IR( |
150 | graph(%a, %b): |
151 | %x, %y = ab::ababab(%a, %b) |
152 | return (%x, %y))IR" ; |
153 | |
154 | SubgraphRewriter rewriter; |
155 | rewriter.RegisterRewritePattern(pattern, replacement); |
156 | |
157 | auto g = graph->copy(); |
158 | rewriter.runOnGraph(g); |
159 | FileCheck().check("ab::ababab" )->check("ab::ababab" )->run(*g); |
160 | } |
161 | { |
162 | auto graph = std::make_shared<Graph>(); |
163 | |
164 | // Mimic a real model case |
165 | parseIR( |
166 | R"IR( |
167 | graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4): |
168 | %a1 = aa::aaa(%x1, %k) |
169 | %b1_1, %b1_2 = bb::bbb(%y1, %a1) |
170 | %a2 = aa::aaa(%x2, %k) |
171 | %b2_1, %b2_2 = bb::bbb(%y2, %a2) |
172 | %a3 = aa::aaa(%x3, %k) |
173 | %b3_1, %b3_2 = bb::bbb(%y3, %a3) |
174 | %a4 = aa::aaa(%x4, %k) |
175 | %b4_1, %b4_2 = bb::bbb(%y4, %a4) |
176 | %c = cc::ccc(%b4_1) |
177 | %d1 = dd::ddd(%b1_2, %m) |
178 | %e1 = ee::eee(%b1_1, %d1) |
179 | %d2 = dd::ddd(%b2_2, %m) |
180 | %e2 = ee::eee(%b2_1, %d2) |
181 | %d3 = dd::ddd(%b3_2, %m) |
182 | %e3 = ee::eee(%b3_1, %d3) |
183 | %d4 = dd::ddd(%b4_2, %m) |
184 | %e4 = ee::eee(%b4_1, %d4) |
185 | return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4) |
186 | )IR" , |
187 | graph.get()); |
188 | |
189 | std::string pattern = R"IR( |
190 | graph(%a, %b, %c, %d): |
191 | %y0 = aa::aaa(%b, %c) |
192 | %y1, %y2 = bb::bbb(%a, %y0) |
193 | %y3 = dd::ddd(%y2, %d) |
194 | return (%y3, %y1))IR" ; |
195 | |
196 | std::string replacement = R"IR( |
197 | graph(%a, %b, %c, %d): |
198 | %x, %y = ab::ababab(%a, %b, %c, %d) |
199 | return (%x, %y))IR" ; |
200 | |
201 | SubgraphRewriter rewriter; |
202 | rewriter.RegisterRewritePattern(pattern, replacement); |
203 | |
204 | auto g = graph->copy(); |
205 | rewriter.runOnGraph(g); |
206 | FileCheck().check("ab::ababab" )->check("ab::ababab" )->run(*g); |
207 | } |
208 | { |
209 | auto graph = std::make_shared<Graph>(); |
210 | |
211 | // A case where no rewriting should occur due to data dependencies |
212 | parseIR( |
213 | R"IR( |
214 | graph(%x, %y): |
215 | %a = aa::aaa(%x) |
216 | %b = bb::bbb(%a) |
217 | %e = ee::eee(%b) |
218 | %c = cc::ccc(%y) |
219 | %d = dd::ddd(%b, %c) |
220 | %f = ff::fff(%b, %d) |
221 | return (%f) |
222 | )IR" , |
223 | graph.get()); |
224 | |
225 | std::string pattern = R"IR( |
226 | graph(%a, %c): |
227 | %b = bb::bbb(%a) |
228 | %d = dd::ddd(%b, %c) |
229 | return (%d, %b))IR" ; |
230 | |
231 | std::string replacement = R"IR( |
232 | graph(%a, %c): |
233 | %d, %b = db::fused(%a, %c) |
234 | return (%d, %b))IR" ; |
235 | |
236 | SubgraphRewriter rewriter; |
237 | rewriter.RegisterRewritePattern(pattern, replacement); |
238 | |
239 | auto g = graph->copy(); |
240 | rewriter.runOnGraph(g); |
241 | // We should not perform the replacement on the given graph due to data |
242 | // dependency constraints: the output %b is used in %e, which precedes one |
243 | // def of the input %c. |
244 | FileCheck().check_not("db::fused" )->run(*g); |
245 | } |
246 | } |
247 | |
248 | TEST(SubgraphRewriterTest, OutputType) { |
249 | std::string pattern = R"IR( |
250 | graph(%a, %b): |
251 | %c = c::ccc(%a, %b) |
252 | return (%c))IR" ; |
253 | Graph pattern_graph; |
254 | std::unordered_map<std::string, Value*> vmap; |
255 | |
256 | parseIR(pattern, &pattern_graph, vmap); |
257 | |
258 | auto b_is_constant = [](const Match& match, |
259 | const std::unordered_map<std::string, Value*>& vmap) { |
260 | const auto& match_vmap = match.values_map; |
261 | auto b_node = match_vmap.at(vmap.at("b" ))->node(); |
262 | return b_node->kind() == prim::Constant; |
263 | }; |
264 | |
265 | std::string replacement = R"IR( |
266 | graph(%a, %b): |
267 | %d = d::ddd(%a, %b) |
268 | return (%d))IR" ; |
269 | |
270 | SubgraphRewriter rewriter; |
271 | rewriter.RegisterRewritePattern(pattern, replacement); |
272 | { |
273 | auto graph = std::make_shared<Graph>(); |
274 | |
275 | parseIR( |
276 | R"IR( |
277 | graph(%0): |
278 | %a : Float(10, 20) = a::aaa(%0) |
279 | %b : int = prim::Constant[value=1]() |
280 | %c : Float(10, 20) = c::ccc(%a, %b) |
281 | return (%c))IR" , |
282 | graph.get()); |
283 | |
284 | // output has shape info. |
285 | rewriter.runOnGraph(graph, b_is_constant); |
286 | FileCheck() |
287 | .check("Float(10, 20) = d::ddd" ) |
288 | ->check_not("c::ccc" ) |
289 | ->run(*graph); |
290 | } |
291 | { |
292 | auto graph = std::make_shared<Graph>(); |
293 | |
294 | parseIR( |
295 | R"IR( |
296 | graph(%0): |
297 | %a = a::aaa(%0) |
298 | %b : int = prim::Constant[value=1]() |
299 | %c = c::ccc(%a, %b) |
300 | return (%c))IR" , |
301 | graph.get()); |
302 | |
303 | // output has not shape info. |
304 | rewriter.runOnGraph(graph, b_is_constant); |
305 | FileCheck().check("Tensor = d::ddd" )->check_not("c::ccc" )->run(*graph); |
306 | } |
307 | } |
308 | |
309 | } // namespace jit |
310 | } // namespace torch |
311 | |