1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/ir/irparser.h>
5#include <torch/csrc/jit/passes/variadic_ops.h>
6#include <torch/csrc/jit/runtime/interpreter.h>
7#include <torch/csrc/jit/testing/file_check.h>
8
9namespace torch {
10namespace jit {
11
12TEST(StackOptTest, UseVariadicStack) {
13 auto graph = std::make_shared<Graph>();
14
15 const std::string input =
16 R"IR(
17 graph(%0: Float(56, 56, 56),
18 %1: Float(56, 56, 56),
19 %2: Float(56, 56, 56),
20 %3: Float(56, 56, 56),
21 %4: Float(56, 56, 56),
22 %5: Float(56, 56, 56)):
23 %10 : int = prim::Constant[value=0]()
24 %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
25 %stack : Float(5, 56, 56, 56) = aten::stack(%input, %10)
26 return (%stack)
27 )IR";
28 parseIR(input, graph.get());
29 std::vector<at::Tensor> inputs = {
30 at::rand({56, 56, 56}, at::kCPU),
31 at::rand({56, 56, 56}, at::kCPU),
32 at::rand({56, 56, 56}, at::kCPU),
33 at::rand({56, 56, 56}, at::kCPU),
34 at::rand({56, 56, 56}, at::kCPU),
35 at::rand({56, 56, 56}, at::kCPU)};
36 auto orig_outputs = runGraph(graph, inputs);
37
38 ASSERT_TRUE(UseVariadicStack(graph));
39 graph->lint();
40 auto opt_outputs = runGraph(graph, inputs);
41
42 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
43
44 // After replacing `aten::stack` with `prim::VarStack` we should have the
45 // following graph:
46 //
47 // graph(%0 : ...,
48 // %1 : ...):
49 // %zero : int = prim:Constant[value=0]()
50 // %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero)
51 // return (%varstack)
52 testing::FileCheck()
53 .check_count("= prim::VarStack(", 1, /*exactly*/ true)
54 ->check_count("= aten::stack(", 0, /*exactly*/ true)
55 ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
56 ->run(*graph);
57}
58
59TEST(StackOptTest, UseVariadicStackReplaceMultiple) {
60 auto graph = std::make_shared<Graph>();
61
62 const std::string input =
63 R"IR(
64 graph(%0: Float(56, 56, 56),
65 %1: Float(56, 56, 56),
66 %2: Float(56, 56, 56),
67 %3: Float(56, 56, 56)):
68 %10 : int = prim::Constant[value=0]()
69 %input1 : Tensor[] = prim::ListConstruct(%0, %1)
70 %stack1 : Float(4, 56, 56, 56) = aten::stack(%input1, %10)
71 %input2 : Tensor[] = prim::ListConstruct(%2, %3)
72 %stack2 : Float(4, 56, 56, 56) = aten::stack(%input2, %10)
73 return (%stack1, %stack2)
74 )IR";
75 parseIR(input, graph.get());
76 std::vector<at::Tensor> inputs = {
77 at::rand({56, 56, 56}, at::kCPU),
78 at::rand({56, 56, 56}, at::kCPU),
79 at::rand({56, 56, 56}, at::kCPU),
80 at::rand({56, 56, 56}, at::kCPU)};
81 auto orig_outputs = runGraph(graph, inputs);
82
83 ASSERT_TRUE(UseVariadicStack(graph));
84 graph->lint();
85 auto opt_outputs = runGraph(graph, inputs);
86
87 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
88
89 // After full stack optimization we should have the following graph:
90 //
91 // graph(%0 : ...,
92 // %1 : ...,
93 // %2 : ...,
94 // %3 : ....):
95 // %zero : int = prim:Constant[value=0]()
96 // %varcat1 : Tensor = prim::VarStack(%0, %1, %zero)
97 // %varcat2 : Tensor = prim::VarStack(%2, %3, %zero)
98 // return (%varcat1, %varcat2)
99 testing::FileCheck()
100 .check_count("= prim::VarStack(", 2, /*exactly*/ true)
101 ->check_count("= aten::stack(", 0, /*exactly*/ true)
102 ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
103 ->run(*graph);
104}
105
106TEST(StackOptTest, UseVariadicStackWithMultipleListUses) {
107 auto graph = std::make_shared<Graph>();
108
109 const std::string input =
110 R"IR(
111 graph(%0: Float(56, 56, 56),
112 %1: Float(56, 56, 56)):
113 %2 : int = prim::Constant[value=0]()
114 %input : Tensor[] = prim::ListConstruct(%0, %1)
115 %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2)
116 return (%stack, %input)
117 )IR";
118 parseIR(input, graph.get());
119 std::vector<at::Tensor> inputs = {
120 at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)};
121 auto orig_outputs = runGraph(graph, inputs);
122
123 ASSERT_TRUE(UseVariadicStack(graph));
124 graph->lint();
125 auto opt_outputs = runGraph(graph, inputs);
126
127 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
128
129 // After replacing `aten::stack` with `prim::VarStack` we should have the
130 // following graph:
131 //
132 // graph(%0 : ...,
133 // %1 : ...):
134 // %zero : int = prim:Constant[value=0]()
135 // %input : Tensor[] = prim::ListConstruct(%0, %1)
136 // %varcat : Tensor = prim::VarStack(%0, %1, %zero)
137 // return (%varcat, %input)
138 testing::FileCheck()
139 .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
140 ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
141 ->check_count("= aten::stack(", 0, /*exactly*/ true)
142 ->run(*graph);
143}
144
145TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) {
146 auto graph = std::make_shared<Graph>();
147
148 const std::string input =
149 R"IR(
150 graph(%0: Float(56, 56, 56),
151 %1: Float(56, 56, 56),
152 %2: Float(56, 56, 56)):
153 %10 : int = prim::Constant[value=0]()
154 %input : Tensor[] = prim::ListConstruct(%0, %1)
155 %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
156 %11 : Tensor = aten::append(%input, %2)
157 return (%stack, %input)
158 )IR";
159 parseIR(input, graph.get());
160 std::vector<at::Tensor> inputs = {
161 at::rand({56, 56, 56}, at::kCPU),
162 at::rand({56, 56, 56}, at::kCPU),
163 at::rand({56, 56, 56}, at::kCPU)};
164 auto orig_outputs = runGraph(graph, inputs);
165
166 ASSERT_TRUE(UseVariadicStack(graph));
167 graph->lint();
168 auto opt_outputs = runGraph(graph, inputs);
169 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
170
171 // The input list to `aten::stack` is mutated only after `aten::stack` op. So,
172 // it should have been replaced with `prim::VarStack`. The transformed graph
173 // should look like the following:
174 //
175 // graph(%0 : ...,
176 // %1 : ...,
177 // %2 : ...):
178 // %3 : int = prim:Constant[value=0]()
179 // %4 : Tensor[] = prim::ListConstruct(%0, %1)
180 // %7 : Tensor = prim::VarStack(%0, %1, %3)
181 // %6 : Tensor = aten::append(%4, %2)
182 // return (%7, %4)
183 testing::FileCheck()
184 .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
185 ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
186 ->check_count("= aten::stack(", 0, /*exactly*/ true)
187 ->run(*graph);
188}
189
190TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) {
191 auto graph = std::make_shared<Graph>();
192
193 const std::string input =
194 R"IR(
195 graph(%0: Float(56, 56, 56),
196 %1: Float(56, 56, 56),
197 %2: Float(56, 56, 56)):
198 %10 : int = prim::Constant[value=0]()
199 %input : Tensor[] = prim::ListConstruct(%0, %1)
200 %11 : Tensor = aten::append(%input, %2)
201 %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
202 return (%stack)
203 )IR";
204 parseIR(input, graph.get());
205 std::vector<at::Tensor> inputs = {
206 at::rand({56, 56, 56}, at::kCPU),
207 at::rand({56, 56, 56}, at::kCPU),
208 at::rand({56, 56, 56}, at::kCPU)};
209 auto orig_outputs = runGraph(graph, inputs);
210
211 {
212 ASSERT_FALSE(UseVariadicStack(graph));
213 graph->lint();
214 auto opt_outputs = runGraph(graph, inputs);
215 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
216
217 // No transformation should have happened since the `prim::ListConstruct` is
218 // mutated before `aten::stack`.
219 testing::FileCheck()
220 .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
221 ->check_count("= aten::stack(", 1, /*exactly*/ true)
222 ->check_count("= prim::VarStack(", 0, /*exactly*/ true)
223 ->run(*graph);
224 }
225
226 {
227 ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
228 graph->lint();
229 auto opt_outputs = runGraph(graph, inputs);
230 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
231
232 // The mutation of the list must be removed and the `aten::stack` op must
233 // be replaced with the `prim::VarStack` op in the graph. The transformed
234 // graph should look like the following:
235 //
236 // graph(%0 : ...,
237 // %1 : ...,
238 // %2 : ...):
239 // %3 : int = prim:Constant[value=0]()
240 // %7 : Tensor = prim::VarStack(%0, %1, %2, %3)
241 // return (%7)
242 testing::FileCheck()
243 .check_count("= prim::VarStack(", 1, /*exactly*/ true)
244 ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
245 ->check_count("= aten::stack(", 0, /*exactly*/ true)
246 ->run(*graph);
247 }
248}
249
250TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) {
251 auto graph = std::make_shared<Graph>();
252
253 const std::string input =
254 R"IR(
255 graph(%0: Float(56, 56, 56),
256 %1: Float(56, 56, 56),
257 %2: Float(56, 56, 56),
258 %3: Float(56, 56, 56),
259 %4: Float(56, 56, 56)):
260 %10 : int = prim::Constant[value=0]()
261 %input : Tensor[] = prim::ListConstruct(%0, %1)
262 %stack.1 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
263 %11 : Tensor = aten::append(%input, %2)
264 %stack.2 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
265 %12 : Tensor = aten::append(%input, %3)
266 %stack.3 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
267 %13 : Tensor = aten::append(%input, %4)
268 %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
269 return (%stack.1, %stack.2, %stack.3, %stack.4)
270 )IR";
271 parseIR(input, graph.get());
272 std::vector<at::Tensor> inputs = {
273 at::rand({56, 56, 56}, at::kCPU),
274 at::rand({56, 56, 56}, at::kCPU),
275 at::rand({56, 56, 56}, at::kCPU),
276 at::rand({56, 56, 56}, at::kCPU),
277 at::rand({56, 56, 56}, at::kCPU)};
278 auto orig_outputs = runGraph(graph, inputs);
279
280 ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
281 graph->lint();
282 auto opt_outputs = runGraph(graph, inputs);
283 ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));
284
285 // All the mutations of the list must be removed and the `aten::stack` ops
286 // must be replaced with `prim::VarStack` ops in the graph. The transformed
287 // graph should look like the following:
288 //
289 // graph(%0 : ...,
290 // %1 : ...,
291 // %2 : ...,
292 // %3 : ...,
293 // %4 : ...):
294 // %10 : int = prim:Constant[value=0]()
295 // %5 : Tensor = prim::VarStack(%0, %1, %10)
296 // %6 : Tensor = prim::VarStack(%0, %1, %2, %10)
297 // %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10)
298 // %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10)
299 // return (%5, %6, %7, %8)
300 testing::FileCheck()
301 .check_count("= prim::VarStack(", 4, /*exactly*/ true)
302 ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
303 ->check_count("= aten::stack(", 0, /*exactly*/ true)
304 ->run(*graph);
305}
306
307} // namespace jit
308} // namespace torch
309