#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/concat_opt.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

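// EliminateConcatCommonInputs rewrites a prim::VarConcat whose leading (or
// trailing) inputs match those of an earlier VarConcat on the same dim so
// that it reuses the earlier result instead of re-concatenating the shared
// tensors. A rough sketch, inferred from the tests below (see
// passes/concat_opt.h for the pass itself):
//
//   %a = prim::VarConcat(%0, %1, %dim)
//   %b = prim::VarConcat(%0, %1, %2, %dim)
//     ==>
//   %a = prim::VarConcat(%0, %1, %dim)
//   %b = prim::VarConcat(%a, %2, %dim)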
TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %7 : Tensor = prim::VarConcat(%4, %2, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)

  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%4, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%1, %2, %3)
  //    %7 : Tensor = prim::VarConcat(%0, %4, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)

  testing::FileCheck()
      .check_count("= prim::VarConcat(%1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%0, %4, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          #CHECK: prim::VarConcat
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)

          #CHECK: prim::VarConcat
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5)

          #CHECK: prim::ListConstruct
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_FALSE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the inputs
  // to the `cat` are in a different order.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, MoreCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%6, %2, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%11, %3, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%12, %4, %5)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

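// ExpandConcatAndEliminateRedundancy replaces an aten::cat whose input and
// output shapes are statically known with an explicit output buffer that the
// inputs are copied into, roughly (a sketch inferred from the tests below;
// the pass is declared in passes/concat_opt.h):
//
//   %out = aten::cat([%a, %b], %dim)
//     ==>
//   %buf = aten::empty(<output sizes>, ...)
//   %s0 = aten::slice(%buf, ...);  aten::copy_(%s0, %a)
//   %s1 = aten::slice(%buf, ...);  aten::copy_(%s1, %b)
//
// When any of those shapes is unknown, the cat is left untouched (see the
// ConcatWithout*Shape tests).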
TEST(ConcatOptTest, ExpandConcat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          %input : Tensor[] = prim::ListConstruct(%4, %5)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After ExpandConcatAndEliminateRedundancy we should have the following
  // graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    ...
  //    %4 : Tensor = aten::clamp_max(...)
  //    %5 : Tensor = aten::clamp_max(...)
  //    %13 : int[] = prim::ListConstruct(...)
  //    %14 : Tensor = aten::empty(%13, ...)   // concat buffer
  //    %17 : Tensor = aten::slice(%14, ...)   // slice for %4
  //    %18 : Tensor = aten::copy_(%17, %4)
  //    %20 : Tensor = aten::slice(%14, ...)   // slice for %5
  //    %21 : Tensor = aten::copy_(%20, %5)
  //    return (%14)
  testing::FileCheck()
      .check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= aten::clamp_max(", 2, /*exactly*/ true)
      ->check_count("= aten::empty(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, ConcatWithoutResultShape) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Tensor = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the output
  // shape of `aten::cat` is not known.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, ConcatWithoutInputShape) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Tensor = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the shape of %5,
  // which is an input to `aten::cat`, is not known.
  testing::FileCheck().run(input, *graph);
}

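// UseVariadicCat replaces a prim::ListConstruct + aten::cat pair with a
// single prim::VarConcat that takes the tensors and the dim directly, as
// long as the list is not mutated before the cat. The ListConstruct itself
// disappears only when the cat was its sole user. (Summary inferred from the
// tests below; UseVariadicCat is declared in passes/variadic_ops.h.)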
TEST(ConcatOptTest, UseVariadicCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %5: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %concat : Float(224, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varcat)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %concat1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %concat2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input2, %10)
          return (%concat1, %concat2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing both `aten::cat` ops with `prim::VarConcat` we should
  // have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero)
  //    %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero)
  //    return (%varcat1, %varcat2)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 2, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %zero)
  //    return (%varcat, %input)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list of `aten::cat` is mutated only after the `aten::cat` op,
  // so the `aten::cat` should still have been replaced with `prim::VarConcat`.
  // The transformed graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

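// RemoveListMutationAndUseVariadicCat first removes list mutations such as
// aten::append where it safely can, so that the list's contents are fully
// visible at its construction, and then applies the same variadic-cat rewrite
// as UseVariadicCat. This lets it handle lists that are mutated before the
// aten::cat, which plain UseVariadicCat has to skip, as the next two tests
// show. (Summary inferred from the tests below.)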
TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  {
    ASSERT_FALSE(UseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // No transformation should have happened since the list produced by
    // `prim::ListConstruct` is mutated before the `aten::cat`.
    testing::FileCheck()
        .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
        ->check_count("= aten::cat(", 1, /*exactly*/ true)
        ->check_count("= prim::VarConcat(", 0, /*exactly*/ true)
        ->run(*graph);
  }

  {
    ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // The mutation of the list must be removed and the `aten::cat` op must
    // be replaced with the `prim::VarConcat` op in the graph. The transformed
    // graph should look like the following:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck()
        .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
        ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
        ->check_count("= aten::cat(", 0, /*exactly*/ true)
        ->run(*graph);
  }
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat.1, %concat.2, %concat.3, %concat.4)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::cat` ops must
  // be replaced with `prim::VarConcat` ops in the graph. The transformed graph
  // should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarConcat(%0, %1, %10)
  //    %6 : Tensor = prim::VarConcat(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(
    ConcatOptTest,
    RemoveListMutationUseVariadicCatAndCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          %features.2 : Tensor[] = prim::ListConstruct(%0, %1)
          %6 : Tensor[] = aten::append(%features.2, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %7 : Tensor[] = aten::append(%features.2, %0)
          %concat.3 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing:
  //   * Remove list mutation
  //   * Use variadic cat
  //   * Eliminate common inputs
  // we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %10 : Tensor = prim::VarConcat(%0, %1, %2, %3)
  //    %12 : Tensor = prim::VarConcat(%10, %0, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%10, %12)
  //    return (%8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%10, %0, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%10, %12)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

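// CombineConcats flattens a chain of aten::cat ops along the same dim whose
// intermediate result feeds only the next cat, e.g. (a sketch inferred from
// the tests below):
//
//   %c1 = aten::cat([%x, %x], %dim)
//   %c2 = aten::cat([%y, %c1, %y], %dim)
//     ==>
//   %c2 = aten::cat([%y, %x, %x, %y], %dim)
//
// It bails out (returns false) when one of the input lists is mutated, as in
// CombineConcatsMutation below.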
TEST(ConcatOpt, CombineConcatsSimpleCase) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  //  graph(%0 : Tensor):
  //    %dim : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //    %concat : Tensor = aten::cat(%input, %dim)
  //    return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsLongChain) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0)
          %concat.3 : Tensor = aten::cat(%input.3, %dim)
          return (%concat.3)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  //  graph(%0 : Tensor, %1 : Tensor):
  //    %dim : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0)
  //    %concat : Tensor = aten::cat(%input, %dim)
  //    return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsMutation) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %input.3 : Tensor[] = aten::append(%input.2, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  // No modifications due to aten::append
  ASSERT_FALSE(CombineConcats(graph));
}

} // namespace jit
} // namespace torch