1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/ir/subgraph_matcher.h>
5#include <torch/csrc/jit/passes/subgraph_rewrite.h>
6#include <torch/csrc/jit/testing/file_check.h>
7
8namespace torch {
9namespace jit {
10using namespace testing;
11
12TEST(SubgraphRewriterTest, FilterMatch) {
13 auto graph = std::make_shared<Graph>();
14
15 parseIR(
16 R"IR(
17graph(%0):
18 %a = a::aaa(%0)
19 %b : int = prim::Constant[value=1]()
20 %c = c::ccc(%a, %b)
21 return (%c))IR",
22 graph.get());
23
24 std::string pattern = R"IR(
25graph(%a, %b):
26 %c = c::ccc(%a, %b)
27 return (%c))IR";
28 Graph pattern_graph;
29 std::unordered_map<std::string, Value*> vmap;
30
31 parseIR(pattern, &pattern_graph, vmap);
32
33 auto b_is_constant = [](const Match& match,
34 const std::unordered_map<std::string, Value*>& vmap) {
35 const auto& match_vmap = match.values_map;
36 auto b_node = match_vmap.at(vmap.at("b"))->node();
37 return b_node->kind() == prim::Constant;
38 };
39
40 auto b_is_one = [](const Match& match,
41 const std::unordered_map<std::string, Value*>& vmap) {
42 const auto& match_vmap = match.values_map;
43 auto b_val = toIValue(match_vmap.at(vmap.at("b")));
44 return b_val && b_val->isInt() && b_val->toInt() == 1;
45 };
46
47 auto b_is_two = [](const Match& match,
48 const std::unordered_map<std::string, Value*>& vmap) {
49 const auto& match_vmap = match.values_map;
50 auto b_val = toIValue(match_vmap.at(vmap.at("b")));
51 return b_val && b_val->isInt() && b_val->toInt() == 2;
52 };
53
54 std::string replacement = R"IR(
55graph(%a, %b):
56 %d = d::ddd(%a, %b)
57 return (%d))IR";
58
59 SubgraphRewriter rewriter;
60 rewriter.RegisterRewritePattern(pattern, replacement);
61
62 // b is constant, so the match will succeed
63 {
64 auto g = graph->copy();
65 rewriter.runOnGraph(g, b_is_constant);
66 FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
67 }
68
69 // b is constant and the value is one, the match will succeed
70 {
71 auto g = graph->copy();
72 rewriter.runOnGraph(g, {b_is_constant, b_is_one});
73 FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
74 }
75
76 // b is constant but the value is not two, the match will fail
77 {
78 auto g = graph->copy();
79 rewriter.runOnGraph(g, {b_is_constant, b_is_two});
80 FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g);
81 }
82}
83
84TEST(SubgraphRewriterTest, FilterNoMatch) {
85 auto graph = std::make_shared<Graph>();
86 parseIR(
87 R"IR(
88graph(%0):
89 %a = a::aaa(%0)
90 %b = prim::Constant[value=1]()
91 %c = c::ccc(%a, %b)
92 return (%c))IR",
93 graph.get());
94
95 std::string pattern = R"IR(
96graph(%a, %b):
97 %c = c::ccc(%a, %b)
98 return (%c))IR";
99 Graph pattern_graph;
100 std::unordered_map<std::string, Value*> vmap;
101
102 parseIR(pattern, &pattern_graph, vmap);
103
104 auto filter = [](const Match& match,
105 const std::unordered_map<std::string, Value*>& vmap) {
106 const auto& match_vmap = match.values_map;
107 auto b_node = match_vmap.at(vmap.at("b"))->node();
108 // b_node is not prim::Assign, so this won't match and we'll skip the
109 // rewrite
110 return b_node->kind() == prim::Assign;
111 };
112
113 std::string replacement = R"IR(
114graph(%a, %b):
115 %d = d::ddd(%a, %b)
116 return (%d))IR";
117
118 SubgraphRewriter rewriter;
119 rewriter.RegisterRewritePattern(pattern, replacement);
120 rewriter.runOnGraph(graph, filter);
121
122 FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
123}
124
125TEST(SubgraphRewriterTest, MultiOutput) {
126 {
127 auto graph = std::make_shared<Graph>();
128
129 // Basic multi-output pattern rewriting
130 parseIR(
131 R"IR(
132graph(%0, %1):
133 %a1, %a2 = a::aaa(%0, %1)
134 %b = b::bbb(%a1)
135 %c = c::ccc(%b)
136
137 %x1, %x2 = a::aaa(%c, %a2)
138 %y = b::bbb(%x1)
139 %z = d::ddd(%y)
140 return (%z))IR",
141 graph.get());
142
143 std::string pattern = R"IR(
144graph(%0, %1):
145 %a1, %a2 = a::aaa(%0, %1)
146 %b = b::bbb(%a1)
147 return (%b, %a2))IR";
148
149 std::string replacement = R"IR(
150graph(%a, %b):
151 %x, %y = ab::ababab(%a, %b)
152 return (%x, %y))IR";
153
154 SubgraphRewriter rewriter;
155 rewriter.RegisterRewritePattern(pattern, replacement);
156
157 auto g = graph->copy();
158 rewriter.runOnGraph(g);
159 FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g);
160 }
161 {
162 auto graph = std::make_shared<Graph>();
163
164 // Mimic a real model case
165 parseIR(
166 R"IR(
167 graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4):
168 %a1 = aa::aaa(%x1, %k)
169 %b1_1, %b1_2 = bb::bbb(%y1, %a1)
170 %a2 = aa::aaa(%x2, %k)
171 %b2_1, %b2_2 = bb::bbb(%y2, %a2)
172 %a3 = aa::aaa(%x3, %k)
173 %b3_1, %b3_2 = bb::bbb(%y3, %a3)
174 %a4 = aa::aaa(%x4, %k)
175 %b4_1, %b4_2 = bb::bbb(%y4, %a4)
176 %c = cc::ccc(%b4_1)
177 %d1 = dd::ddd(%b1_2, %m)
178 %e1 = ee::eee(%b1_1, %d1)
179 %d2 = dd::ddd(%b2_2, %m)
180 %e2 = ee::eee(%b2_1, %d2)
181 %d3 = dd::ddd(%b3_2, %m)
182 %e3 = ee::eee(%b3_1, %d3)
183 %d4 = dd::ddd(%b4_2, %m)
184 %e4 = ee::eee(%b4_1, %d4)
185 return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4)
186 )IR",
187 graph.get());
188
189 std::string pattern = R"IR(
190 graph(%a, %b, %c, %d):
191 %y0 = aa::aaa(%b, %c)
192 %y1, %y2 = bb::bbb(%a, %y0)
193 %y3 = dd::ddd(%y2, %d)
194 return (%y3, %y1))IR";
195
196 std::string replacement = R"IR(
197 graph(%a, %b, %c, %d):
198 %x, %y = ab::ababab(%a, %b, %c, %d)
199 return (%x, %y))IR";
200
201 SubgraphRewriter rewriter;
202 rewriter.RegisterRewritePattern(pattern, replacement);
203
204 auto g = graph->copy();
205 rewriter.runOnGraph(g);
206 FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g);
207 }
208 {
209 auto graph = std::make_shared<Graph>();
210
211 // A case where no rewriting should occur due to data dependencies
212 parseIR(
213 R"IR(
214 graph(%x, %y):
215 %a = aa::aaa(%x)
216 %b = bb::bbb(%a)
217 %e = ee::eee(%b)
218 %c = cc::ccc(%y)
219 %d = dd::ddd(%b, %c)
220 %f = ff::fff(%b, %d)
221 return (%f)
222 )IR",
223 graph.get());
224
225 std::string pattern = R"IR(
226 graph(%a, %c):
227 %b = bb::bbb(%a)
228 %d = dd::ddd(%b, %c)
229 return (%d, %b))IR";
230
231 std::string replacement = R"IR(
232 graph(%a, %c):
233 %d, %b = db::fused(%a, %c)
234 return (%d, %b))IR";
235
236 SubgraphRewriter rewriter;
237 rewriter.RegisterRewritePattern(pattern, replacement);
238
239 auto g = graph->copy();
240 rewriter.runOnGraph(g);
241 // We should not perform the replacement on the given graph due to data
242 // dependency constraints: the output %b is used in %e, which precedes one
243 // def of the input %c.
244 FileCheck().check_not("db::fused")->run(*g);
245 }
246}
247
248TEST(SubgraphRewriterTest, OutputType) {
249 std::string pattern = R"IR(
250graph(%a, %b):
251 %c = c::ccc(%a, %b)
252 return (%c))IR";
253 Graph pattern_graph;
254 std::unordered_map<std::string, Value*> vmap;
255
256 parseIR(pattern, &pattern_graph, vmap);
257
258 auto b_is_constant = [](const Match& match,
259 const std::unordered_map<std::string, Value*>& vmap) {
260 const auto& match_vmap = match.values_map;
261 auto b_node = match_vmap.at(vmap.at("b"))->node();
262 return b_node->kind() == prim::Constant;
263 };
264
265 std::string replacement = R"IR(
266graph(%a, %b):
267 %d = d::ddd(%a, %b)
268 return (%d))IR";
269
270 SubgraphRewriter rewriter;
271 rewriter.RegisterRewritePattern(pattern, replacement);
272 {
273 auto graph = std::make_shared<Graph>();
274
275 parseIR(
276 R"IR(
277 graph(%0):
278 %a : Float(10, 20) = a::aaa(%0)
279 %b : int = prim::Constant[value=1]()
280 %c : Float(10, 20) = c::ccc(%a, %b)
281 return (%c))IR",
282 graph.get());
283
284 // output has shape info.
285 rewriter.runOnGraph(graph, b_is_constant);
286 FileCheck()
287 .check("Float(10, 20) = d::ddd")
288 ->check_not("c::ccc")
289 ->run(*graph);
290 }
291 {
292 auto graph = std::make_shared<Graph>();
293
294 parseIR(
295 R"IR(
296 graph(%0):
297 %a = a::aaa(%0)
298 %b : int = prim::Constant[value=1]()
299 %c = c::ccc(%a, %b)
300 return (%c))IR",
301 graph.get());
302
303 // output has not shape info.
304 rewriter.runOnGraph(graph, b_is_constant);
305 FileCheck().check("Tensor = d::ddd")->check_not("c::ccc")->run(*graph);
306 }
307}
308
309} // namespace jit
310} // namespace torch
311