1 | #include <gtest/gtest.h> |
2 | |
3 | #include "test/cpp/jit/test_utils.h" |
4 | #include "torch/csrc/jit/ir/subgraph_matcher.h" |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | TEST(SubgraphMatcherTest, Trivial1) { |
10 | Graph graph, pattern; |
11 | parseIR( |
12 | R"IR( |
13 | graph(%0): |
14 | %a = a::aaa(%0) |
15 | return (%a))IR" , |
16 | &graph); |
17 | parseIR( |
18 | R"IR( |
19 | graph(%0): |
20 | %x = a::aaa(%0) |
21 | return (%x))IR" , |
22 | &pattern); |
23 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
24 | } |
25 | |
26 | TEST(SubgraphMatcherTest, Trivial2) { |
27 | Graph graph; |
28 | auto* g_in = graph.addInput(); |
29 | auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); |
30 | g_tanh->addInput(g_in); |
31 | graph.registerOutput(g_tanh->output()); |
32 | |
33 | Graph pattern; |
34 | auto* p_in = pattern.addInput(); |
35 | auto* p_tanh = |
36 | pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); |
37 | p_tanh->addInput(p_in); |
38 | pattern.registerOutput(p_tanh->output()); |
39 | |
40 | auto matches = findPatternMatches(pattern, graph); |
41 | AT_ASSERT(matches.size() == 1); |
42 | for (const Match& m : matches) { |
43 | AT_ASSERT(m.values_map.at(p_in) == g_in); |
44 | AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); |
45 | AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); |
46 | } |
47 | } |
48 | |
49 | TEST(SubgraphMatcherTest, Trivial3) { |
50 | Graph graph, pattern; |
51 | parseIR( |
52 | R"IR( |
53 | graph(%0): |
54 | %a = a::a(%0) |
55 | %b = a::b(%0) |
56 | %c = a::c(%a, %b) |
57 | return (%c))IR" , |
58 | &graph); |
59 | parseIR( |
60 | R"IR( |
61 | graph(%a, %b): |
62 | %c = a::c(%a, %b) |
63 | return (%c))IR" , |
64 | &pattern); |
65 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
66 | } |
67 | |
68 | TEST(SubgraphMatcherTest, Trivial4) { |
69 | Graph graph; |
70 | auto* g_in0 = graph.addInput(); |
71 | auto* g_in1 = graph.addInput(); |
72 | auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1)); |
73 | g_mul->addInput(g_in0); |
74 | g_mul->addInput(g_in1); |
75 | graph.registerOutput(g_mul->output()); |
76 | |
77 | Graph pattern; |
78 | auto* p_in0 = pattern.addInput(); |
79 | auto* p_in1 = pattern.addInput(); |
80 | auto* p_mul = |
81 | pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1)); |
82 | p_mul->addInput(p_in0); |
83 | p_mul->addInput(p_in1); |
84 | pattern.registerOutput(p_mul->output()); |
85 | |
86 | auto matches = findPatternMatches(pattern, graph); |
87 | AT_ASSERT(matches.size() == 1); |
88 | for (const Match& m : matches) { |
89 | AT_ASSERT(m.values_map.at(p_in0) == g_in0); |
90 | AT_ASSERT(m.values_map.at(p_in1) == g_in1); |
91 | AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); |
92 | AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); |
93 | } |
94 | } |
95 | |
96 | TEST(SubgraphMatcherTest, Linear1) { |
97 | Graph graph, pattern; |
98 | parseIR( |
99 | R"IR( |
100 | graph(%0): |
101 | %a = a::aaa(%0) |
102 | %b = b::bbb(%a) |
103 | %c = c::ccc(%b) |
104 | %d = d::ddd(%c) |
105 | %a = a::aaa(%0) |
106 | return (%d))IR" , |
107 | &graph); |
108 | parseIR( |
109 | R"IR( |
110 | graph(%0): |
111 | %x = b::bbb(%0) |
112 | %y = c::ccc(%x) |
113 | return (%y))IR" , |
114 | &pattern); |
115 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
116 | } |
117 | |
118 | TEST(SubgraphMatcherTest, Linear2) { |
119 | Graph graph; |
120 | auto* g_in = graph.addInput(); |
121 | |
122 | auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); |
123 | g_tanh->addInput(g_in); |
124 | |
125 | auto* g_tanh2 = |
126 | graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); |
127 | g_tanh2->addInput(g_tanh->output()); |
128 | |
129 | graph.registerOutput(g_tanh2->output()); |
130 | |
131 | Graph pattern; |
132 | auto* p_in = pattern.addInput(); |
133 | |
134 | auto* p_tanh = |
135 | pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); |
136 | p_tanh->addInput(p_in); |
137 | |
138 | auto* p_tanh2 = |
139 | pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1)); |
140 | p_tanh2->addInput(p_tanh->output()); |
141 | |
142 | pattern.registerOutput(p_tanh2->output()); |
143 | |
144 | auto matches = findPatternMatches(pattern, graph); |
145 | AT_ASSERT(matches.size() == 1); |
146 | for (const Match& m : matches) { |
147 | AT_ASSERT(m.values_map.at(p_in) == g_in); |
148 | AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); |
149 | AT_ASSERT(m.values_map.at(p_tanh2->output()) == g_tanh2->output()); |
150 | AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); |
151 | AT_ASSERT(m.nodes_map.at(p_tanh2) == g_tanh2); |
152 | } |
153 | } |
154 | |
155 | /** |
156 | * Test diamond pattern: |
157 | * |
158 | * ooo |
159 | * | |
160 | * aaa |
161 | * / \ |
162 | * bbb ccc |
163 | * \ / |
164 | * ddd |
165 | * | |
166 | * eee |
167 | */ |
168 | TEST(SubgraphMatcherTest, Diamond1) { |
169 | Graph graph, pattern1, pattern2; |
170 | parseIR( |
171 | R"IR( |
172 | graph(%0): |
173 | %o = o::ooo(%0) |
174 | %a = a::aaa(%o) |
175 | %b = b::bbb(%a) |
176 | %c = c::ccc(%a) |
177 | %d = d::ddd(%b, %c) |
178 | %e = e::eee(%d) |
179 | return (%e))IR" , |
180 | &graph); |
181 | |
182 | parseIR( |
183 | R"IR( |
184 | graph(%0): |
185 | %a = a::aaa(%0) |
186 | %b = b::bbb(%a) |
187 | %c = c::ccc(%a) |
188 | %d = d::ddd(%b, %c) |
189 | return (%d))IR" , |
190 | &pattern1); |
191 | AT_ASSERT(!findPatternMatches(pattern1, graph).empty()); |
192 | |
193 | // Check that order of nodes inside the diamond does not affect the result |
194 | parseIR( |
195 | R"IR( |
196 | graph(%0): |
197 | %a = a::aaa(%0) |
198 | %c = c::ccc(%a) |
199 | %b = b::bbb(%a) |
200 | %d = d::ddd(%b, %c) |
201 | return (%d))IR" , |
202 | &pattern2); |
203 | AT_ASSERT(!findPatternMatches(pattern2, graph).empty()); |
204 | } |
205 | |
206 | /** |
207 | * Test diamond pattern: |
208 | * |
209 | * i0 |
210 | * | |
211 | * chunk |
212 | * / \ |
213 | * os[0] os[1] |
214 | * \ / |
215 | * * |
216 | * | |
217 | * o1 |
218 | */ |
219 | TEST(SubgraphMatcherTest, Diamond2) { |
220 | Graph graph; |
221 | auto* g_in = graph.addInput(); |
222 | |
223 | auto* g_chunk = |
224 | graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/2)); |
225 | g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); |
226 | g_chunk->addInput(g_in); |
227 | |
228 | auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1)); |
229 | g_mul->addInput(g_chunk->outputs()[0]); |
230 | g_mul->addInput(g_chunk->outputs()[1]); |
231 | graph.registerOutput(g_mul->output()); |
232 | |
233 | Graph pattern; |
234 | auto* p_in = pattern.addInput(); |
235 | auto* p_chunk = pattern.insertNode( |
236 | pattern.create(prim::ConstantChunk, /*num_outputs =*/2)); |
237 | p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); |
238 | p_chunk->addInput(p_in); |
239 | |
240 | auto* p_mul = |
241 | pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1)); |
242 | p_mul->addInput(p_chunk->outputs()[0]); |
243 | p_mul->addInput(p_chunk->outputs()[1]); |
244 | pattern.registerOutput(p_mul->output()); |
245 | |
246 | auto matches = findPatternMatches(pattern, graph); |
247 | AT_ASSERT(matches.size() == 1); |
248 | for (const Match& m : matches) { |
249 | AT_ASSERT(m.values_map.at(p_in) == g_in); |
250 | AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]); |
251 | AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]); |
252 | AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); |
253 | AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); |
254 | } |
255 | } |
256 | |
257 | TEST(SubgraphMatcherTest, XPattern) { |
258 | Graph graph, pattern; |
259 | parseIR( |
260 | R"IR( |
261 | graph(%0, %1): |
262 | %b = b::bbb(%0) |
263 | %c = c::ccc(%1) |
264 | %x = x::xxx(%b, %c) |
265 | %e = e::eee(%x) |
266 | %f = f::fff(%x) |
267 | %g = g::ggg(%e, %f) |
268 | return (%g))IR" , |
269 | &graph); |
270 | parseIR( |
271 | R"IR( |
272 | graph(%0, %1): |
273 | %b = b::bbb(%0) |
274 | %c = c::ccc(%1) |
275 | %x = x::xxx(%b, %c) |
276 | %e = e::eee(%x) |
277 | %f = f::fff(%x) |
278 | %g = g::ggg(%e, %f) |
279 | return (%g))IR" , |
280 | &pattern); |
281 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
282 | } |
283 | |
284 | TEST(SubgraphMatcherTest, MultipleMatches) { |
285 | Graph graph, pattern; |
286 | parseIR( |
287 | R"IR( |
288 | graph(%t0): |
289 | %t1 = a::aaa(%t0) |
290 | %t2 = a::aaa(%t1) |
291 | %t3 = a::aaa(%t2) |
292 | %t4 = a::aaa(%t3) |
293 | return (%t4))IR" , |
294 | &graph); |
295 | parseIR( |
296 | R"IR( |
297 | graph(%t0): |
298 | %t1 = a::aaa(%t0) |
299 | return (%t1))IR" , |
300 | &pattern); |
301 | auto matches = findPatternMatches(pattern, graph); |
302 | AT_ASSERT(matches.size() == 4); |
303 | } |
304 | |
305 | TEST(SubgraphMatcherTest, OverlappingMatches) { |
306 | Graph graph, pattern; |
307 | parseIR( |
308 | R"IR( |
309 | graph(%t0): |
310 | %t1 = a::aaa(%t0) |
311 | %t2 = a::aaa(%t1) |
312 | %t3 = a::aaa(%t2) |
313 | %t4 = a::aaa(%t3) |
314 | return (%t4))IR" , |
315 | &graph); |
316 | parseIR( |
317 | R"IR( |
318 | graph(%t0): |
319 | %t1 = a::aaa(%t0) |
320 | %t2 = a::aaa(%t1) |
321 | return (%t2))IR" , |
322 | &pattern); |
323 | auto matches = findPatternMatches(pattern, graph); |
324 | AT_ASSERT(matches.size() == 3); |
325 | } |
326 | |
327 | TEST(SubgraphMatcherTest, MatchInBasicBlocks1) { |
328 | Graph graph; |
329 | parseIR( |
330 | R"IR( |
331 | graph(%a, %b, %c): |
332 | %d = aten::mul(%a, %b) |
333 | %x = prim::If(%c) |
334 | block0(): |
335 | %x1 = aten::mul(%a, %d) |
336 | -> (%x1) |
337 | block1(): |
338 | %x2 = aten::mul(%b, %d) |
339 | -> (%x2) |
340 | return (%x))IR" , |
341 | &graph); |
342 | |
343 | // Ensure the matches don't cross basic block boundaries |
344 | Graph pattern0; |
345 | parseIR( |
346 | R"IR( |
347 | graph(%x, %y): |
348 | %z = aten::mul(%x, %y) |
349 | return (%z))IR" , |
350 | &pattern0); |
351 | AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3); |
352 | |
353 | Graph pattern1; |
354 | parseIR( |
355 | R"IR( |
356 | graph(%x, %y): |
357 | %z1 = aten::mul(%x, %y) |
358 | %z2 = aten::mul(%y, %z1) |
359 | return (%z2))IR" , |
360 | &pattern1); |
361 | AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); |
362 | } |
363 | |
364 | TEST(SubgraphMatcherTest, MatchInBasicBlocks2) { |
365 | Graph graph; |
366 | parseIR( |
367 | R"IR( |
368 | graph(%a, %b): |
369 | %x = my::mul(%a, %b) |
370 | %y = my::node_with_subblock() |
371 | block0(): |
372 | %z = my::mul(%b, %x) |
373 | -> (%z) |
374 | return (%y))IR" , |
375 | &graph); |
376 | |
377 | // Check that we can match both mul ops |
378 | Graph pattern0; |
379 | parseIR( |
380 | R"IR( |
381 | graph(%x, %y): |
382 | %z = my::mul(%x, %y) |
383 | return (%z))IR" , |
384 | &pattern0); |
385 | AT_ASSERT(findPatternMatches(pattern0, graph).size() == 2); |
386 | |
387 | // Ensure the matches don't cross basic block boundaries |
388 | Graph pattern1; |
389 | parseIR( |
390 | R"IR( |
391 | graph(%x, %y): |
392 | %u = my::mul(%x, %y) |
393 | %v = my::mul(%y, %u) |
394 | return (%v))IR" , |
395 | &pattern1); |
396 | AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); |
397 | } |
398 | |
399 | TEST(SubgraphMatcherTest, MatchesAttributes) { |
400 | Graph graph; |
401 | parseIR( |
402 | R"IR( |
403 | graph(%0): |
404 | %a = a::a[isattr=[1,2]](%0) |
405 | %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0) |
406 | %c = a::c[myattr="qqq"](%a, %b) |
407 | return (%c))IR" , |
408 | &graph); |
409 | |
410 | { |
411 | Graph pattern; |
412 | parseIR( |
413 | R"IR( |
414 | graph(%a, %b): |
415 | %c = a::c[myattr="qqq"](%a, %b) |
416 | return (%c))IR" , |
417 | &pattern); |
418 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
419 | } |
420 | { |
421 | Graph pattern; |
422 | parseIR( |
423 | R"IR( |
424 | graph(%a, %b): |
425 | %c = a::c[myattr="zzz"](%a, %b) |
426 | return (%c))IR" , |
427 | &pattern); |
428 | AT_ASSERT(findPatternMatches(pattern, graph).empty()); |
429 | } |
430 | { |
431 | Graph pattern; |
432 | parseIR( |
433 | R"IR( |
434 | graph(%0): |
435 | %b = a::b[extraattr=10](%0) |
436 | return (%b))IR" , |
437 | &pattern); |
438 | AT_ASSERT(findPatternMatches(pattern, graph).empty()); |
439 | } |
440 | { |
441 | Graph pattern; |
442 | parseIR( |
443 | R"IR( |
444 | graph(%0): |
445 | %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0) |
446 | return (%b))IR" , |
447 | &pattern); |
448 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
449 | } |
450 | { |
451 | Graph pattern; |
452 | parseIR( |
453 | R"IR( |
454 | graph(%0): |
455 | %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j, strattr="rrr"](%0) |
456 | return (%b))IR" , |
457 | &pattern); |
458 | AT_ASSERT(findPatternMatches(pattern, graph).empty()); |
459 | } |
460 | { |
461 | Graph pattern; |
462 | parseIR( |
463 | R"IR( |
464 | graph(%0): |
465 | %a = a::a[isattr=[1,2]](%0) |
466 | return (%a))IR" , |
467 | &pattern); |
468 | // Lists are not supported yet, thus we shouldn't match for now. |
469 | AT_ASSERT(findPatternMatches(pattern, graph).empty()); |
470 | } |
471 | { |
472 | Graph pattern; |
473 | parseIR( |
474 | R"IR( |
475 | graph(%a, %b): |
476 | %c = a::c[myattr="q.*"](%a, %b) |
477 | return (%c))IR" , |
478 | &pattern); |
479 | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); |
480 | } |
481 | } |
482 | |
483 | TEST(SubgraphMatcherTest, BadPattern) { |
484 | Graph graph, pattern1, pattern2; |
485 | parseIR( |
486 | R"IR( |
487 | graph(%x): |
488 | %y = my::op1(%x) |
489 | %z = my::op2(%x) |
490 | return (%y, %z))IR" , |
491 | &graph); |
492 | |
493 | parseIR( |
494 | R"IR( |
495 | graph(%x): |
496 | %y = my::node_with_subblock() |
497 | block0(): |
498 | %z = my::op(%x) |
499 | -> (%z) |
500 | return (%y))IR" , |
501 | &pattern1); |
502 | // No support for patterns with subblocks |
503 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
504 | ASSERT_ANY_THROW(findPatternMatches(pattern1, graph)); |
505 | |
506 | parseIR( |
507 | R"IR( |
508 | graph(%x): |
509 | %y = my::op1(%x) |
510 | %z = my::op2(%x) |
511 | return (%y, %z))IR" , |
512 | &pattern2); |
513 | // Not supported multi-output pattern, because not the whole pattern is |
514 | // covered by a traversal up from the first output (`%z = ...` is not |
515 | // visited). See the note "Multi-output Patterns" in subgraph_matcher.h. |
516 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
517 | ASSERT_ANY_THROW(findPatternMatches(pattern2, graph)); |
518 | } |
519 | |
520 | TEST(SubgraphMatcherTest, MultiOutput) { |
521 | { |
522 | Graph graph, pattern; |
523 | parseIR( |
524 | R"IR( |
525 | graph(%0): |
526 | %a = a::aaa(%0) |
527 | %b = b::bbb(%a) |
528 | %c = c::ccc(%a, %b) |
529 | %x = a::aaa(%c) |
530 | %y = b::bbb(%x) |
531 | %z = d::ddd(%x, %y) |
532 | return (%y))IR" , |
533 | &graph); |
534 | parseIR( |
535 | R"IR( |
536 | graph(%0): |
537 | %a = a::aaa(%0) |
538 | %b = b::bbb(%a) |
539 | return (%b, %a))IR" , |
540 | &pattern); |
541 | AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); |
542 | } |
543 | { |
544 | Graph graph, pattern; |
545 | parseIR( |
546 | R"IR( |
547 | graph(%0, %1): |
548 | %a1, %a2 = a::aaa(%0, %1) |
549 | %b = b::bbb(%a1) |
550 | %c = c::ccc(%b) |
551 | |
552 | %x1, %x2 = a::aaa(%c, %a2) |
553 | %y = b::bbb(%x1) |
554 | %z = d::ddd(%y) |
555 | return (%z))IR" , |
556 | &graph); |
557 | parseIR( |
558 | R"IR( |
559 | graph(%0, %1): |
560 | %a1, %a2 = a::aaa(%0, %1) |
561 | %b = b::bbb(%a1) |
562 | return (%b, %a2))IR" , |
563 | &pattern); |
564 | AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); |
565 | } |
566 | } |
567 | |
568 | } // namespace jit |
569 | } // namespace torch |
570 | |