1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/ir/irparser.h> |
5 | #include <torch/csrc/jit/passes/concat_opt.h> |
6 | #include <torch/csrc/jit/passes/variadic_ops.h> |
7 | #include <torch/csrc/jit/runtime/interpreter.h> |
8 | #include <torch/csrc/jit/testing/file_check.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) { |
14 | auto graph = std::make_shared<Graph>(); |
15 | |
16 | const std::string input = |
17 | R"IR( |
18 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
19 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
20 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
21 | %5 : int = prim::Constant[value=0]() |
22 | %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5) |
23 | %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5) |
24 | %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3) |
25 | return (%res) |
26 | )IR" ; |
27 | parseIR(input, graph.get()); |
28 | std::vector<at::Tensor> inputs = { |
29 | at::rand({64, 56, 56}, at::kCPU), |
30 | at::rand({32, 56, 56}, at::kCPU), |
31 | at::rand({32, 56, 56}, at::kCPU)}; |
32 | auto orig_outputs = runGraph(graph, inputs); |
33 | |
34 | ASSERT_TRUE(EliminateConcatCommonInputs(graph)); |
35 | graph->lint(); |
36 | auto opt_outputs = runGraph(graph, inputs); |
37 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
38 | |
39 | // Graph after EliminateConcatCommonInputs: |
40 | // graph(%0 : ..., |
41 | // %1 : ..., |
42 | // %2 : ...): |
43 | // %3 : int = prim::Constant[value=0]() |
44 | // %4 : Tensor = prim::VarConcat(%0, %1, %3) |
45 | // %7 : Tensor = prim::VarConcat(%4, %2, %3) // UPDATED |
46 | // %8 : Tensor[] = prim::ListConstruct(%4, %7) |
47 | // return (%8) |
48 | |
49 | testing::FileCheck() |
50 | .check_count("= prim::VarConcat(%0, %1, %3)" , 1, /*exactly*/ true) |
51 | ->check_count("= prim::VarConcat(%4, %2, %3)" , 1, /*exactly*/ true) |
52 | ->check_count("= prim::ListConstruct(%4, %7)" , 1, /*exactly*/ true) |
53 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
54 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
55 | ->run(*graph); |
56 | } |
57 | |
58 | TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) { |
59 | auto graph = std::make_shared<Graph>(); |
60 | |
61 | const std::string input = |
62 | R"IR( |
63 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
64 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
65 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
66 | %5 : int = prim::Constant[value=0]() |
67 | %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5) |
68 | %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5) |
69 | %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3) |
70 | return (%res) |
71 | )IR" ; |
72 | parseIR(input, graph.get()); |
73 | std::vector<at::Tensor> inputs = { |
74 | at::rand({64, 56, 56}, at::kCPU), |
75 | at::rand({32, 56, 56}, at::kCPU), |
76 | at::rand({32, 56, 56}, at::kCPU)}; |
77 | auto orig_outputs = runGraph(graph, inputs); |
78 | |
79 | ASSERT_TRUE(EliminateConcatCommonInputs(graph)); |
80 | graph->lint(); |
81 | auto opt_outputs = runGraph(graph, inputs); |
82 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
83 | |
84 | // Graph after EliminateConcatCommonInputs: |
85 | // graph(%0 : ..., |
86 | // %1 : ..., |
87 | // %2 : ...): |
88 | // %3 : int = prim::Constant[value=0]() |
89 | // %4 : Tensor = prim::VarConcat(%1, %2, %3) |
90 | // %7 : Tensor = prim::VarConcat(%0, %4, %3) // UPDATED |
91 | // %8 : Tensor[] = prim::ListConstruct(%4, %7) |
92 | // return (%8) |
93 | |
94 | testing::FileCheck() |
95 | .check_count("= prim::VarConcat(%1, %2, %3)" , 1, /*exactly*/ true) |
96 | ->check_count("= prim::VarConcat(%0, %4, %3)" , 1, /*exactly*/ true) |
97 | ->check_count("= prim::ListConstruct(%4, %7)" , 1, /*exactly*/ true) |
98 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
99 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
100 | ->run(*graph); |
101 | } |
102 | |
103 | TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) { |
104 | auto graph = std::make_shared<Graph>(); |
105 | |
106 | const std::string input = |
107 | R"IR( |
108 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
109 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
110 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
111 | %5 : int = prim::Constant[value=0]() |
112 | |
113 | #CHECK: prim::VarConcat |
114 | %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5) |
115 | |
116 | #CHECK: prim::VarConcat |
117 | %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5) |
118 | |
119 | #CHECK: prim::ListConstruct |
120 | %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2) |
121 | return (%res) |
122 | )IR" ; |
123 | parseIR(input, graph.get()); |
124 | std::vector<at::Tensor> inputs = { |
125 | at::rand({64, 56, 56}, at::kCPU), |
126 | at::rand({32, 56, 56}, at::kCPU), |
127 | at::rand({32, 56, 56}, at::kCPU)}; |
128 | auto orig_outputs = runGraph(graph, inputs); |
129 | |
130 | ASSERT_FALSE(EliminateConcatCommonInputs(graph)); |
131 | graph->lint(); |
132 | auto opt_outputs = runGraph(graph, inputs); |
133 | |
134 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
135 | |
136 | // No optimizations should have happened in this case since the inputs |
137 | // to the `cat` are in different order. |
138 | testing::FileCheck().run(input, *graph); |
139 | } |
140 | |
141 | TEST(ConcatOptTest, MoreCommonInputsElimination) { |
142 | auto graph = std::make_shared<Graph>(); |
143 | |
144 | const std::string input = |
145 | R"IR( |
146 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
147 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
148 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
149 | %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
150 | %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
151 | %5 : int = prim::Constant[value=0]() |
152 | %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5) |
153 | %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5) |
154 | %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5) |
155 | %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5) |
156 | %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4) |
157 | return (%res) |
158 | )IR" ; |
159 | parseIR(input, graph.get()); |
160 | std::vector<at::Tensor> inputs = { |
161 | at::rand({64, 56, 56}, at::kCPU), |
162 | at::rand({32, 56, 56}, at::kCPU), |
163 | at::rand({32, 56, 56}, at::kCPU), |
164 | at::rand({32, 56, 56}, at::kCPU), |
165 | at::rand({32, 56, 56}, at::kCPU)}; |
166 | auto orig_outputs = runGraph(graph, inputs); |
167 | |
168 | ASSERT_TRUE(EliminateConcatCommonInputs(graph)); |
169 | graph->lint(); |
170 | auto opt_outputs = runGraph(graph, inputs); |
171 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
172 | |
173 | testing::FileCheck() |
174 | .check_count("= prim::VarConcat(%0, %1, %5)" , 1, /*exactly*/ true) |
175 | ->check_count("= prim::VarConcat(%6, %2, %5)" , 1, /*exactly*/ true) |
176 | ->check_count("= prim::VarConcat(%11, %3, %5)" , 1, /*exactly*/ true) |
177 | ->check_count("= prim::VarConcat(%12, %4, %5)" , 1, /*exactly*/ true) |
178 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
179 | ->run(*graph); |
180 | } |
181 | |
182 | TEST(ConcatOptTest, ExpandConcat) { |
183 | auto graph = std::make_shared<Graph>(); |
184 | |
185 | const std::string input = |
186 | R"IR( |
187 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
188 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
189 | %2 : int = prim::Constant[value=0]() |
190 | %3 : float = prim::Constant[value=0.5]() |
191 | %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3) |
192 | %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3) |
193 | %input : Tensor[] = prim::ListConstruct(%4, %5) |
194 | %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2) |
195 | return (%concat) |
196 | )IR" ; |
197 | parseIR(input, graph.get()); |
198 | std::vector<at::Tensor> inputs = { |
199 | at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)}; |
200 | auto orig_outputs = runGraph(graph, inputs); |
201 | |
202 | ExpandConcatAndEliminateRedundancy(graph); |
203 | graph->lint(); |
204 | auto opt_outputs = runGraph(graph, inputs); |
205 | |
206 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
207 | |
208 | // After full concat optimization we should have the following graph: |
209 | // |
210 | // graph(%0 : ..., |
211 | // %1 : ...): |
212 | // ... |
213 | // %4 : Tensor = aten::clamp_max(...) |
214 | // %5 : Tensor = aten::clamp_max(...) |
215 | // %13 : int[] = prim::ListConstruct(...) |
216 | // %14 : Tensor = aten::empty(%13, ...) // concat buffer |
217 | // %17 : Tensor = aten::slice(%14, ...) // slice for %4 |
218 | // %18 : Tensor = aten::copy_(%17, %4) |
219 | // %20 : Tensor = aten::slice(%14, ...) // slice for %5 |
220 | // %21 : Tensor = aten::copy_(%20, %5) |
221 | // return (%14) |
222 | testing::FileCheck() |
223 | .check_count("= aten::cat(" , 0, /*exactly*/ true) |
224 | ->check_count("= aten::clamp_max(" , 2, /*exactly*/ true) |
225 | ->check_count("= aten::empty(" , 1, /*exactly*/ true) |
226 | ->check_count("= aten::slice(" , 1, /*exactly*/ true) |
227 | ->check_count("= aten::copy_(" , 1, /*exactly*/ true) |
228 | ->check_count("= aten::slice(" , 1, /*exactly*/ true) |
229 | ->check_count("= aten::copy_(" , 1, /*exactly*/ true) |
230 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
231 | ->run(*graph); |
232 | } |
233 | |
234 | TEST(ConcatOptTest, ConcatWithoutResultShape) { |
235 | auto graph = std::make_shared<Graph>(); |
236 | |
237 | const std::string input = |
238 | R"IR( |
239 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
240 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
241 | %2 : int = prim::Constant[value=0]() |
242 | %3 : float = prim::Constant[value=0.5]() |
243 | # CHECK: clamp_max |
244 | %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3) |
245 | # CHECK: clamp_max |
246 | %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3) |
247 | # CHECK: prim::ListConstruct |
248 | %6 : Tensor[] = prim::ListConstruct(%4, %5) |
249 | # CHECK: aten::cat |
250 | %7 : Tensor = aten::cat(%6, %2) |
251 | return (%7) |
252 | )IR" ; |
253 | parseIR(input, graph.get()); |
254 | std::vector<at::Tensor> inputs = { |
255 | at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)}; |
256 | auto orig_outputs = runGraph(graph, inputs); |
257 | |
258 | ExpandConcatAndEliminateRedundancy(graph); |
259 | graph->lint(); |
260 | auto opt_outputs = runGraph(graph, inputs); |
261 | |
262 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
263 | |
264 | // No optimizations should have happened in this case since the output |
265 | // shape of `aten::cat` is not known. |
266 | testing::FileCheck().run(input, *graph); |
267 | } |
268 | |
269 | TEST(ConcatOptTest, ConcatWithoutInputShape) { |
270 | auto graph = std::make_shared<Graph>(); |
271 | |
272 | const std::string input = |
273 | R"IR( |
274 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
275 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
276 | %2 : int = prim::Constant[value=0]() |
277 | %3 : float = prim::Constant[value=0.5]() |
278 | # CHECK: clamp_max |
279 | %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3) |
280 | # CHECK: clamp_max |
281 | %5 : Tensor = aten::clamp_max(%1, %3) |
282 | # CHECK: prim::ListConstruct |
283 | %6 : Tensor[] = prim::ListConstruct(%4, %5) |
284 | # CHECK: aten::cat |
285 | %7 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%6, %2) |
286 | return (%7) |
287 | )IR" ; |
288 | parseIR(input, graph.get()); |
289 | std::vector<at::Tensor> inputs = { |
290 | at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)}; |
291 | auto orig_outputs = runGraph(graph, inputs); |
292 | |
293 | ExpandConcatAndEliminateRedundancy(graph); |
294 | graph->lint(); |
295 | auto opt_outputs = runGraph(graph, inputs); |
296 | |
297 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
298 | |
299 | // No optimizations should have happened in this case since the shape of %5, |
300 | // which is an input to `aten::cat`, is not known. |
301 | testing::FileCheck().run(input, *graph); |
302 | } |
303 | |
304 | TEST(ConcatOptTest, UseVariadicCat) { |
305 | auto graph = std::make_shared<Graph>(); |
306 | |
307 | const std::string input = |
308 | R"IR( |
309 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
310 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
311 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
312 | %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
313 | %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
314 | %5: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
315 | %10 : int = prim::Constant[value=0]() |
316 | %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5) |
317 | %concat : Float(224, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
318 | return (%concat) |
319 | )IR" ; |
320 | parseIR(input, graph.get()); |
321 | std::vector<at::Tensor> inputs = { |
322 | at::rand({64, 56, 56}, at::kCPU), |
323 | at::rand({32, 56, 56}, at::kCPU), |
324 | at::rand({32, 56, 56}, at::kCPU), |
325 | at::rand({32, 56, 56}, at::kCPU), |
326 | at::rand({32, 56, 56}, at::kCPU), |
327 | at::rand({32, 56, 56}, at::kCPU)}; |
328 | auto orig_outputs = runGraph(graph, inputs); |
329 | |
330 | ASSERT_TRUE(UseVariadicCat(graph)); |
331 | graph->lint(); |
332 | auto opt_outputs = runGraph(graph, inputs); |
333 | |
334 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
335 | |
336 | // After replacing `aten::cat` with `prim::VarConcat` we should have the |
337 | // following graph: |
338 | // |
339 | // graph(%0 : ..., |
340 | // %1 : ...): |
341 | // %zero : int = prim:Constant[value=0]() |
342 | // %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero) |
343 | // return (%varcat) |
344 | testing::FileCheck() |
345 | .check_count("= prim::VarConcat(" , 1, /*exactly*/ true) |
346 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
347 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
348 | ->run(*graph); |
349 | } |
350 | |
351 | TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) { |
352 | auto graph = std::make_shared<Graph>(); |
353 | |
354 | const std::string input = |
355 | R"IR( |
356 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
357 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
358 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
359 | %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
360 | %10 : int = prim::Constant[value=0]() |
361 | %input1 : Tensor[] = prim::ListConstruct(%0, %1) |
362 | %concat1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input1, %10) |
363 | %input2 : Tensor[] = prim::ListConstruct(%2, %3) |
364 | %concat2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input2, %10) |
365 | return (%concat1, %concat2) |
366 | )IR" ; |
367 | parseIR(input, graph.get()); |
368 | std::vector<at::Tensor> inputs = { |
369 | at::rand({64, 56, 56}, at::kCPU), |
370 | at::rand({32, 56, 56}, at::kCPU), |
371 | at::rand({32, 56, 56}, at::kCPU), |
372 | at::rand({32, 56, 56}, at::kCPU)}; |
373 | auto orig_outputs = runGraph(graph, inputs); |
374 | |
375 | ASSERT_TRUE(UseVariadicCat(graph)); |
376 | graph->lint(); |
377 | auto opt_outputs = runGraph(graph, inputs); |
378 | |
379 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
380 | |
381 | // After full concat optimization we should have the following graph: |
382 | // |
383 | // graph(%0 : ..., |
384 | // %1 : ..., |
385 | // %2 : ..., |
386 | // %3 : ....): |
387 | // %zero : int = prim:Constant[value=0]() |
388 | // %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero) |
389 | // %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero) |
390 | // return (%varcat1, %varcat2) |
391 | testing::FileCheck() |
392 | .check_count("= prim::VarConcat(" , 2, /*exactly*/ true) |
393 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
394 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
395 | ->run(*graph); |
396 | } |
397 | |
398 | TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) { |
399 | auto graph = std::make_shared<Graph>(); |
400 | |
401 | const std::string input = |
402 | R"IR( |
403 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
404 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
405 | %2 : int = prim::Constant[value=0]() |
406 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
407 | %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2) |
408 | return (%concat, %input) |
409 | )IR" ; |
410 | parseIR(input, graph.get()); |
411 | std::vector<at::Tensor> inputs = { |
412 | at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)}; |
413 | auto orig_outputs = runGraph(graph, inputs); |
414 | |
415 | ASSERT_TRUE(UseVariadicCat(graph)); |
416 | graph->lint(); |
417 | auto opt_outputs = runGraph(graph, inputs); |
418 | |
419 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
420 | |
421 | // After replacing `aten::cat` with `prim::VarConcat` we should have the |
422 | // following graph: |
423 | // |
424 | // graph(%0 : ..., |
425 | // %1 : ...): |
426 | // %zero : int = prim:Constant[value=0]() |
427 | // %input : Tensor[] = prim::ListConstruct(%0, %1) |
428 | // %varcat : Tensor = prim::VarConcat(%0, %1, %zero) |
429 | // return (%varcat, %input) |
430 | testing::FileCheck() |
431 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
432 | ->check_count("= prim::VarConcat(" , 1, /*exactly*/ true) |
433 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
434 | ->run(*graph); |
435 | } |
436 | |
437 | TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) { |
438 | auto graph = std::make_shared<Graph>(); |
439 | |
440 | const std::string input = |
441 | R"IR( |
442 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
443 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
444 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
445 | %10 : int = prim::Constant[value=0]() |
446 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
447 | %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
448 | %11 : Tensor = aten::append(%input, %2) |
449 | return (%concat, %input) |
450 | )IR" ; |
451 | parseIR(input, graph.get()); |
452 | std::vector<at::Tensor> inputs = { |
453 | at::rand({64, 56, 56}, at::kCPU), |
454 | at::rand({32, 56, 56}, at::kCPU), |
455 | at::rand({32, 56, 56}, at::kCPU)}; |
456 | auto orig_outputs = runGraph(graph, inputs); |
457 | |
458 | ASSERT_TRUE(UseVariadicCat(graph)); |
459 | graph->lint(); |
460 | auto opt_outputs = runGraph(graph, inputs); |
461 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
462 | |
463 | // The input list to `aten::cat` is mutated only after `aten::cat` op. So, |
464 | // it should have been replaced with `prim::VarConcat`. The transformed graph |
465 | // should look like the following: |
466 | // |
467 | // graph(%0 : ..., |
468 | // %1 : ..., |
469 | // %2 : ...): |
470 | // %3 : int = prim:Constant[value=0]() |
471 | // %4 : Tensor[] = prim::ListConstruct(%0, %1) |
472 | // %7 : Tensor = prim::VarConcat(%0, %1, %3) |
473 | // %6 : Tensor = aten::append(%4, %2) |
474 | // return (%7, %4) |
475 | testing::FileCheck() |
476 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
477 | ->check_count("= prim::VarConcat(" , 1, /*exactly*/ true) |
478 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
479 | ->run(*graph); |
480 | } |
481 | |
482 | TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) { |
483 | auto graph = std::make_shared<Graph>(); |
484 | |
485 | const std::string input = |
486 | R"IR( |
487 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
488 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
489 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
490 | %10 : int = prim::Constant[value=0]() |
491 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
492 | %11 : Tensor = aten::append(%input, %2) |
493 | %concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
494 | return (%concat) |
495 | )IR" ; |
496 | parseIR(input, graph.get()); |
497 | std::vector<at::Tensor> inputs = { |
498 | at::rand({64, 56, 56}, at::kCPU), |
499 | at::rand({32, 56, 56}, at::kCPU), |
500 | at::rand({32, 56, 56}, at::kCPU)}; |
501 | auto orig_outputs = runGraph(graph, inputs); |
502 | |
503 | { |
504 | ASSERT_FALSE(UseVariadicCat(graph)); |
505 | graph->lint(); |
506 | auto opt_outputs = runGraph(graph, inputs); |
507 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
508 | |
509 | // No transformation should have happened since the `prim::ListConstruct` is |
510 | // mutated before `aten::cat`. |
511 | testing::FileCheck() |
512 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
513 | ->check_count("= aten::cat(" , 1, /*exactly*/ true) |
514 | ->check_count("= prim::VarConcat(" , 0, /*exactly*/ true) |
515 | ->run(*graph); |
516 | } |
517 | |
518 | { |
519 | ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph)); |
520 | graph->lint(); |
521 | auto opt_outputs = runGraph(graph, inputs); |
522 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
523 | |
524 | // The mutation of the list must be removed and the `aten::cat` op must |
525 | // be replaced with the `prim::VarConcat` op in the graph. The transformed |
526 | // graph should look like the following: |
527 | // |
528 | // graph(%0 : ..., |
529 | // %1 : ..., |
530 | // %2 : ...): |
531 | // %3 : int = prim:Constant[value=0]() |
532 | // %7 : Tensor = prim::VarConcat(%0, %1, %2, %3) |
533 | // return (%7) |
534 | testing::FileCheck() |
535 | .check_count("= prim::VarConcat(" , 1, /*exactly*/ true) |
536 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
537 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
538 | ->run(*graph); |
539 | } |
540 | } |
541 | |
542 | TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) { |
543 | auto graph = std::make_shared<Graph>(); |
544 | |
545 | const std::string input = |
546 | R"IR( |
547 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
548 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
549 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
550 | %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
551 | %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
552 | %10 : int = prim::Constant[value=0]() |
553 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
554 | %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
555 | %11 : Tensor = aten::append(%input, %2) |
556 | %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
557 | %12 : Tensor = aten::append(%input, %3) |
558 | %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
559 | %13 : Tensor = aten::append(%input, %4) |
560 | %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10) |
561 | return (%concat.1, %concat.2, %concat.3, %concat.4) |
562 | )IR" ; |
563 | parseIR(input, graph.get()); |
564 | std::vector<at::Tensor> inputs = { |
565 | at::rand({64, 56, 56}, at::kCPU), |
566 | at::rand({32, 56, 56}, at::kCPU), |
567 | at::rand({32, 56, 56}, at::kCPU), |
568 | at::rand({32, 56, 56}, at::kCPU), |
569 | at::rand({32, 56, 56}, at::kCPU)}; |
570 | auto orig_outputs = runGraph(graph, inputs); |
571 | |
572 | ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph)); |
573 | graph->lint(); |
574 | auto opt_outputs = runGraph(graph, inputs); |
575 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
576 | |
577 | // All the mutations of the list must be removed and the `aten::cat` ops must |
578 | // be replaced with `prim::VarConcat` ops in the graph. The transformed graph |
579 | // should look like the following: |
580 | // |
581 | // graph(%0 : ..., |
582 | // %1 : ..., |
583 | // %2 : ..., |
584 | // %3 : ..., |
585 | // %4 : ...): |
586 | // %10 : int = prim:Constant[value=0]() |
587 | // %5 : Tensor = prim::VarConcat(%0, %1, %10) |
588 | // %6 : Tensor = prim::VarConcat(%0, %1, %2, %10) |
589 | // %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10) |
590 | // %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10) |
591 | // return (%5, %6, %7, %8) |
592 | testing::FileCheck() |
593 | .check_count("= prim::VarConcat(" , 4, /*exactly*/ true) |
594 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
595 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
596 | ->run(*graph); |
597 | } |
598 | |
599 | TEST( |
600 | ConcatOptTest, |
601 | RemoveListMutationUseVariadicCatAndCommonInputsElimination) { |
602 | auto graph = std::make_shared<Graph>(); |
603 | |
604 | const std::string input = |
605 | R"IR( |
606 | graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
607 | %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu), |
608 | %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)): |
609 | %5 : int = prim::Constant[value=0]() |
610 | |
611 | %features.2 : Tensor[] = prim::ListConstruct(%0, %1) |
612 | %6 : Tensor [] = aten::append(%features.2, %2) |
613 | %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5) |
614 | |
615 | %7 : Tensor [] = aten::append(%features.2, %0) |
616 | %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5) |
617 | |
618 | %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3) |
619 | return (%res) |
620 | )IR" ; |
621 | parseIR(input, graph.get()); |
622 | std::vector<at::Tensor> inputs = { |
623 | at::rand({64, 56, 56}, at::kCPU), |
624 | at::rand({32, 56, 56}, at::kCPU), |
625 | at::rand({32, 56, 56}, at::kCPU)}; |
626 | auto orig_outputs = runGraph(graph, inputs); |
627 | |
628 | ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph)); |
629 | ASSERT_TRUE(EliminateConcatCommonInputs(graph)); |
630 | graph->lint(); |
631 | auto opt_outputs = runGraph(graph, inputs); |
632 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
633 | |
634 | // After performing: |
635 | // * Remove list mutation |
636 | // * Use variadic cat |
637 | // * Eliminate common inputs |
638 | // we should have the following graph: |
639 | // |
640 | // graph(%0 : ..., |
641 | // %1 : ..., |
642 | // %2 : ...): |
643 | // %3 : int = prim::Constant[value=0]() |
644 | // %10 : Tensor = prim::VarConcat(%0, %1, %2, %3) |
645 | // %12 : Tensor = prim::VarConcat(%10, %0, %3) // UPDATED |
646 | // %8 : Tensor[] = prim::ListConstruct(%10, %12) |
647 | // return (%8) |
648 | testing::FileCheck() |
649 | .check_count("= prim::VarConcat(%0, %1, %2, %3)" , 1, /*exactly*/ true) |
650 | ->check_count("= prim::VarConcat(%10, %0, %3)" , 1, /*exactly*/ true) |
651 | ->check_count("= prim::ListConstruct(%10, %12)" , 1, /*exactly*/ true) |
652 | ->check_count("= aten::cat(" , 0, /*exactly*/ true) |
653 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
654 | ->run(*graph); |
655 | } |
656 | |
657 | TEST(ConcatOpt, CombineConcatsSimpleCase) { |
658 | auto graph = std::make_shared<Graph>(); |
659 | const std::string input = |
660 | R"IR( |
661 | graph(%0: Tensor): |
662 | %dim : int = prim::Constant[value=0]() |
663 | %input.1 : Tensor[] = prim::ListConstruct(%0, %0) |
664 | %concat.1 : Tensor = aten::cat(%input.1, %dim) |
665 | %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0) |
666 | %concat.2 : Tensor = aten::cat(%input.2, %dim) |
667 | return (%concat.2) |
668 | )IR" ; |
669 | parseIR(input, graph.get()); |
670 | std::vector<at::Tensor> inputs = {at::rand({1})}; |
671 | auto orig_outputs = runGraph(graph, inputs); |
672 | |
673 | ASSERT_TRUE(CombineConcats(graph)); |
674 | graph->lint(); |
675 | auto opt_outputs = runGraph(graph, inputs); |
676 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
677 | |
678 | // After performing CombineConcats: |
679 | // graph(%0 : Tensor): |
680 | // %dim : int = prim::Constant[value=0]() |
681 | // %input : Tensor[] = prim::ListConstruct(%0, %0, %0) |
682 | // %concat : Tensor = aten::cat(%input, %dim) |
683 | // return (%concat) |
684 | testing::FileCheck() |
685 | .check_count("prim::ListConstruct" , 1, /*exactly*/ true) |
686 | ->check_count("aten::cat" , 1, /*exactly*/ true) |
687 | ->run(*graph); |
688 | } |
689 | |
690 | TEST(ConcatOpt, CombineConcatsLongChain) { |
691 | auto graph = std::make_shared<Graph>(); |
692 | const std::string input = |
693 | R"IR( |
694 | graph(%0: Tensor, %1 : Tensor): |
695 | %dim : int = prim::Constant[value=0]() |
696 | %input.1 : Tensor[] = prim::ListConstruct(%0, %0) |
697 | %concat.1 : Tensor = aten::cat(%input.1, %dim) |
698 | %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1) |
699 | %concat.2 : Tensor = aten::cat(%input.2, %dim) |
700 | %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0) |
701 | %concat.3 : Tensor = aten::cat(%input.3, %dim) |
702 | return (%concat.3) |
703 | )IR" ; |
704 | parseIR(input, graph.get()); |
705 | std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})}; |
706 | auto orig_outputs = runGraph(graph, inputs); |
707 | |
708 | ASSERT_TRUE(CombineConcats(graph)); |
709 | graph->lint(); |
710 | auto opt_outputs = runGraph(graph, inputs); |
711 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
712 | |
713 | // After performing CombineConcats: |
714 | // graph(%0 : Tensor): |
715 | // %dim : int = prim::Constant[value=0]() |
716 | // %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0) |
717 | // %concat : Tensor = aten::cat(%input, %dim) |
718 | // return (%concat) |
719 | testing::FileCheck() |
720 | .check_count("prim::ListConstruct" , 1, /*exactly*/ true) |
721 | ->check_count("aten::cat" , 1, /*exactly*/ true) |
722 | ->run(*graph); |
723 | } |
724 | |
725 | TEST(ConcatOpt, CombineConcatsMutation) { |
726 | auto graph = std::make_shared<Graph>(); |
727 | const std::string input = |
728 | R"IR( |
729 | graph(%0: Tensor, %1 : Tensor): |
730 | %dim : int = prim::Constant[value=0]() |
731 | %input.1 : Tensor[] = prim::ListConstruct(%0, %0) |
732 | %concat.1 : Tensor = aten::cat(%input.1, %dim) |
733 | %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1) |
734 | %input.3 : Tensor[] = aten::append(%input.2, %0) |
735 | %concat.2 : Tensor = aten::cat(%input.2, %dim) |
736 | return (%concat.2) |
737 | )IR" ; |
738 | parseIR(input, graph.get()); |
739 | std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})}; |
740 | // No modifications due to aten::append |
741 | ASSERT_FALSE(CombineConcats(graph)); |
742 | } |
743 | |
744 | } // namespace jit |
745 | } // namespace torch |
746 | |