1#include <gtest/gtest.h>
2
3#include "test/cpp/jit/test_utils.h"
4#include "torch/csrc/jit/ir/subgraph_matcher.h"
5
6namespace torch {
7namespace jit {
8
9TEST(SubgraphMatcherTest, Trivial1) {
10 Graph graph, pattern;
11 parseIR(
12 R"IR(
13graph(%0):
14 %a = a::aaa(%0)
15 return (%a))IR",
16 &graph);
17 parseIR(
18 R"IR(
19graph(%0):
20 %x = a::aaa(%0)
21 return (%x))IR",
22 &pattern);
23 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
24}
25
26TEST(SubgraphMatcherTest, Trivial2) {
27 Graph graph;
28 auto* g_in = graph.addInput();
29 auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
30 g_tanh->addInput(g_in);
31 graph.registerOutput(g_tanh->output());
32
33 Graph pattern;
34 auto* p_in = pattern.addInput();
35 auto* p_tanh =
36 pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
37 p_tanh->addInput(p_in);
38 pattern.registerOutput(p_tanh->output());
39
40 auto matches = findPatternMatches(pattern, graph);
41 AT_ASSERT(matches.size() == 1);
42 for (const Match& m : matches) {
43 AT_ASSERT(m.values_map.at(p_in) == g_in);
44 AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output());
45 AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh);
46 }
47}
48
49TEST(SubgraphMatcherTest, Trivial3) {
50 Graph graph, pattern;
51 parseIR(
52 R"IR(
53graph(%0):
54 %a = a::a(%0)
55 %b = a::b(%0)
56 %c = a::c(%a, %b)
57 return (%c))IR",
58 &graph);
59 parseIR(
60 R"IR(
61graph(%a, %b):
62 %c = a::c(%a, %b)
63 return (%c))IR",
64 &pattern);
65 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
66}
67
68TEST(SubgraphMatcherTest, Trivial4) {
69 Graph graph;
70 auto* g_in0 = graph.addInput();
71 auto* g_in1 = graph.addInput();
72 auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
73 g_mul->addInput(g_in0);
74 g_mul->addInput(g_in1);
75 graph.registerOutput(g_mul->output());
76
77 Graph pattern;
78 auto* p_in0 = pattern.addInput();
79 auto* p_in1 = pattern.addInput();
80 auto* p_mul =
81 pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
82 p_mul->addInput(p_in0);
83 p_mul->addInput(p_in1);
84 pattern.registerOutput(p_mul->output());
85
86 auto matches = findPatternMatches(pattern, graph);
87 AT_ASSERT(matches.size() == 1);
88 for (const Match& m : matches) {
89 AT_ASSERT(m.values_map.at(p_in0) == g_in0);
90 AT_ASSERT(m.values_map.at(p_in1) == g_in1);
91 AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
92 AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
93 }
94}
95
96TEST(SubgraphMatcherTest, Linear1) {
97 Graph graph, pattern;
98 parseIR(
99 R"IR(
100graph(%0):
101 %a = a::aaa(%0)
102 %b = b::bbb(%a)
103 %c = c::ccc(%b)
104 %d = d::ddd(%c)
105 %a = a::aaa(%0)
106 return (%d))IR",
107 &graph);
108 parseIR(
109 R"IR(
110graph(%0):
111 %x = b::bbb(%0)
112 %y = c::ccc(%x)
113 return (%y))IR",
114 &pattern);
115 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
116}
117
118TEST(SubgraphMatcherTest, Linear2) {
119 Graph graph;
120 auto* g_in = graph.addInput();
121
122 auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
123 g_tanh->addInput(g_in);
124
125 auto* g_tanh2 =
126 graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
127 g_tanh2->addInput(g_tanh->output());
128
129 graph.registerOutput(g_tanh2->output());
130
131 Graph pattern;
132 auto* p_in = pattern.addInput();
133
134 auto* p_tanh =
135 pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
136 p_tanh->addInput(p_in);
137
138 auto* p_tanh2 =
139 pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
140 p_tanh2->addInput(p_tanh->output());
141
142 pattern.registerOutput(p_tanh2->output());
143
144 auto matches = findPatternMatches(pattern, graph);
145 AT_ASSERT(matches.size() == 1);
146 for (const Match& m : matches) {
147 AT_ASSERT(m.values_map.at(p_in) == g_in);
148 AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output());
149 AT_ASSERT(m.values_map.at(p_tanh2->output()) == g_tanh2->output());
150 AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh);
151 AT_ASSERT(m.nodes_map.at(p_tanh2) == g_tanh2);
152 }
153}
154
155/**
156 * Test diamond pattern:
157 *
158 * ooo
159 * |
160 * aaa
161 * / \
162 * bbb ccc
163 * \ /
164 * ddd
165 * |
166 * eee
167 */
168TEST(SubgraphMatcherTest, Diamond1) {
169 Graph graph, pattern1, pattern2;
170 parseIR(
171 R"IR(
172graph(%0):
173 %o = o::ooo(%0)
174 %a = a::aaa(%o)
175 %b = b::bbb(%a)
176 %c = c::ccc(%a)
177 %d = d::ddd(%b, %c)
178 %e = e::eee(%d)
179 return (%e))IR",
180 &graph);
181
182 parseIR(
183 R"IR(
184graph(%0):
185 %a = a::aaa(%0)
186 %b = b::bbb(%a)
187 %c = c::ccc(%a)
188 %d = d::ddd(%b, %c)
189 return (%d))IR",
190 &pattern1);
191 AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
192
193 // Check that order of nodes inside the diamond does not affect the result
194 parseIR(
195 R"IR(
196graph(%0):
197 %a = a::aaa(%0)
198 %c = c::ccc(%a)
199 %b = b::bbb(%a)
200 %d = d::ddd(%b, %c)
201 return (%d))IR",
202 &pattern2);
203 AT_ASSERT(!findPatternMatches(pattern2, graph).empty());
204}
205
206/**
207 * Test diamond pattern:
208 *
209 * i0
210 * |
211 * chunk
212 * / \
213 * os[0] os[1]
214 * \ /
215 * *
216 * |
217 * o1
218 */
219TEST(SubgraphMatcherTest, Diamond2) {
220 Graph graph;
221 auto* g_in = graph.addInput();
222
223 auto* g_chunk =
224 graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/2));
225 g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
226 g_chunk->addInput(g_in);
227
228 auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
229 g_mul->addInput(g_chunk->outputs()[0]);
230 g_mul->addInput(g_chunk->outputs()[1]);
231 graph.registerOutput(g_mul->output());
232
233 Graph pattern;
234 auto* p_in = pattern.addInput();
235 auto* p_chunk = pattern.insertNode(
236 pattern.create(prim::ConstantChunk, /*num_outputs =*/2));
237 p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
238 p_chunk->addInput(p_in);
239
240 auto* p_mul =
241 pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
242 p_mul->addInput(p_chunk->outputs()[0]);
243 p_mul->addInput(p_chunk->outputs()[1]);
244 pattern.registerOutput(p_mul->output());
245
246 auto matches = findPatternMatches(pattern, graph);
247 AT_ASSERT(matches.size() == 1);
248 for (const Match& m : matches) {
249 AT_ASSERT(m.values_map.at(p_in) == g_in);
250 AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]);
251 AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]);
252 AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
253 AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
254 }
255}
256
257TEST(SubgraphMatcherTest, XPattern) {
258 Graph graph, pattern;
259 parseIR(
260 R"IR(
261graph(%0, %1):
262 %b = b::bbb(%0)
263 %c = c::ccc(%1)
264 %x = x::xxx(%b, %c)
265 %e = e::eee(%x)
266 %f = f::fff(%x)
267 %g = g::ggg(%e, %f)
268 return (%g))IR",
269 &graph);
270 parseIR(
271 R"IR(
272graph(%0, %1):
273 %b = b::bbb(%0)
274 %c = c::ccc(%1)
275 %x = x::xxx(%b, %c)
276 %e = e::eee(%x)
277 %f = f::fff(%x)
278 %g = g::ggg(%e, %f)
279 return (%g))IR",
280 &pattern);
281 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
282}
283
284TEST(SubgraphMatcherTest, MultipleMatches) {
285 Graph graph, pattern;
286 parseIR(
287 R"IR(
288graph(%t0):
289 %t1 = a::aaa(%t0)
290 %t2 = a::aaa(%t1)
291 %t3 = a::aaa(%t2)
292 %t4 = a::aaa(%t3)
293 return (%t4))IR",
294 &graph);
295 parseIR(
296 R"IR(
297graph(%t0):
298 %t1 = a::aaa(%t0)
299 return (%t1))IR",
300 &pattern);
301 auto matches = findPatternMatches(pattern, graph);
302 AT_ASSERT(matches.size() == 4);
303}
304
305TEST(SubgraphMatcherTest, OverlappingMatches) {
306 Graph graph, pattern;
307 parseIR(
308 R"IR(
309graph(%t0):
310 %t1 = a::aaa(%t0)
311 %t2 = a::aaa(%t1)
312 %t3 = a::aaa(%t2)
313 %t4 = a::aaa(%t3)
314 return (%t4))IR",
315 &graph);
316 parseIR(
317 R"IR(
318graph(%t0):
319 %t1 = a::aaa(%t0)
320 %t2 = a::aaa(%t1)
321 return (%t2))IR",
322 &pattern);
323 auto matches = findPatternMatches(pattern, graph);
324 AT_ASSERT(matches.size() == 3);
325}
326
327TEST(SubgraphMatcherTest, MatchInBasicBlocks1) {
328 Graph graph;
329 parseIR(
330 R"IR(
331graph(%a, %b, %c):
332 %d = aten::mul(%a, %b)
333 %x = prim::If(%c)
334 block0():
335 %x1 = aten::mul(%a, %d)
336 -> (%x1)
337 block1():
338 %x2 = aten::mul(%b, %d)
339 -> (%x2)
340 return (%x))IR",
341 &graph);
342
343 // Ensure the matches don't cross basic block boundaries
344 Graph pattern0;
345 parseIR(
346 R"IR(
347graph(%x, %y):
348 %z = aten::mul(%x, %y)
349 return (%z))IR",
350 &pattern0);
351 AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);
352
353 Graph pattern1;
354 parseIR(
355 R"IR(
356graph(%x, %y):
357 %z1 = aten::mul(%x, %y)
358 %z2 = aten::mul(%y, %z1)
359 return (%z2))IR",
360 &pattern1);
361 AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
362}
363
364TEST(SubgraphMatcherTest, MatchInBasicBlocks2) {
365 Graph graph;
366 parseIR(
367 R"IR(
368graph(%a, %b):
369 %x = my::mul(%a, %b)
370 %y = my::node_with_subblock()
371 block0():
372 %z = my::mul(%b, %x)
373 -> (%z)
374 return (%y))IR",
375 &graph);
376
377 // Check that we can match both mul ops
378 Graph pattern0;
379 parseIR(
380 R"IR(
381graph(%x, %y):
382 %z = my::mul(%x, %y)
383 return (%z))IR",
384 &pattern0);
385 AT_ASSERT(findPatternMatches(pattern0, graph).size() == 2);
386
387 // Ensure the matches don't cross basic block boundaries
388 Graph pattern1;
389 parseIR(
390 R"IR(
391graph(%x, %y):
392 %u = my::mul(%x, %y)
393 %v = my::mul(%y, %u)
394 return (%v))IR",
395 &pattern1);
396 AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
397}
398
399TEST(SubgraphMatcherTest, MatchesAttributes) {
400 Graph graph;
401 parseIR(
402 R"IR(
403graph(%0):
404 %a = a::a[isattr=[1,2]](%0)
405 %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
406 %c = a::c[myattr="qqq"](%a, %b)
407 return (%c))IR",
408 &graph);
409
410 {
411 Graph pattern;
412 parseIR(
413 R"IR(
414graph(%a, %b):
415 %c = a::c[myattr="qqq"](%a, %b)
416 return (%c))IR",
417 &pattern);
418 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
419 }
420 {
421 Graph pattern;
422 parseIR(
423 R"IR(
424graph(%a, %b):
425 %c = a::c[myattr="zzz"](%a, %b)
426 return (%c))IR",
427 &pattern);
428 AT_ASSERT(findPatternMatches(pattern, graph).empty());
429 }
430 {
431 Graph pattern;
432 parseIR(
433 R"IR(
434graph(%0):
435 %b = a::b[extraattr=10](%0)
436 return (%b))IR",
437 &pattern);
438 AT_ASSERT(findPatternMatches(pattern, graph).empty());
439 }
440 {
441 Graph pattern;
442 parseIR(
443 R"IR(
444graph(%0):
445 %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
446 return (%b))IR",
447 &pattern);
448 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
449 }
450 {
451 Graph pattern;
452 parseIR(
453 R"IR(
454graph(%0):
455 %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j, strattr="rrr"](%0)
456 return (%b))IR",
457 &pattern);
458 AT_ASSERT(findPatternMatches(pattern, graph).empty());
459 }
460 {
461 Graph pattern;
462 parseIR(
463 R"IR(
464graph(%0):
465 %a = a::a[isattr=[1,2]](%0)
466 return (%a))IR",
467 &pattern);
468 // Lists are not supported yet, thus we shouldn't match for now.
469 AT_ASSERT(findPatternMatches(pattern, graph).empty());
470 }
471 {
472 Graph pattern;
473 parseIR(
474 R"IR(
475graph(%a, %b):
476 %c = a::c[myattr="q.*"](%a, %b)
477 return (%c))IR",
478 &pattern);
479 AT_ASSERT(!findPatternMatches(pattern, graph).empty());
480 }
481}
482
483TEST(SubgraphMatcherTest, BadPattern) {
484 Graph graph, pattern1, pattern2;
485 parseIR(
486 R"IR(
487graph(%x):
488 %y = my::op1(%x)
489 %z = my::op2(%x)
490 return (%y, %z))IR",
491 &graph);
492
493 parseIR(
494 R"IR(
495graph(%x):
496 %y = my::node_with_subblock()
497 block0():
498 %z = my::op(%x)
499 -> (%z)
500 return (%y))IR",
501 &pattern1);
502 // No support for patterns with subblocks
503 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
504 ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));
505
506 parseIR(
507 R"IR(
508graph(%x):
509 %y = my::op1(%x)
510 %z = my::op2(%x)
511 return (%y, %z))IR",
512 &pattern2);
513 // Not supported multi-output pattern, because not the whole pattern is
514 // covered by a traversal up from the first output (`%z = ...` is not
515 // visited). See the note "Multi-output Patterns" in subgraph_matcher.h.
516 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
517 ASSERT_ANY_THROW(findPatternMatches(pattern2, graph));
518}
519
520TEST(SubgraphMatcherTest, MultiOutput) {
521 {
522 Graph graph, pattern;
523 parseIR(
524 R"IR(
525graph(%0):
526 %a = a::aaa(%0)
527 %b = b::bbb(%a)
528 %c = c::ccc(%a, %b)
529 %x = a::aaa(%c)
530 %y = b::bbb(%x)
531 %z = d::ddd(%x, %y)
532 return (%y))IR",
533 &graph);
534 parseIR(
535 R"IR(
536graph(%0):
537 %a = a::aaa(%0)
538 %b = b::bbb(%a)
539 return (%b, %a))IR",
540 &pattern);
541 AT_ASSERT(findPatternMatches(pattern, graph).size() == 2);
542 }
543 {
544 Graph graph, pattern;
545 parseIR(
546 R"IR(
547graph(%0, %1):
548 %a1, %a2 = a::aaa(%0, %1)
549 %b = b::bbb(%a1)
550 %c = c::ccc(%b)
551
552 %x1, %x2 = a::aaa(%c, %a2)
553 %y = b::bbb(%x1)
554 %z = d::ddd(%y)
555 return (%z))IR",
556 &graph);
557 parseIR(
558 R"IR(
559graph(%0, %1):
560 %a1, %a2 = a::aaa(%0, %1)
561 %b = b::bbb(%a1)
562 return (%b, %a2))IR",
563 &pattern);
564 AT_ASSERT(findPatternMatches(pattern, graph).size() == 2);
565 }
566}
567
568} // namespace jit
569} // namespace torch
570