1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/ir/irparser.h> |
5 | #include <torch/csrc/jit/passes/variadic_ops.h> |
6 | #include <torch/csrc/jit/runtime/interpreter.h> |
7 | #include <torch/csrc/jit/testing/file_check.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(StackOptTest, UseVariadicStack) { |
13 | auto graph = std::make_shared<Graph>(); |
14 | |
15 | const std::string input = |
16 | R"IR( |
17 | graph(%0: Float(56, 56, 56), |
18 | %1: Float(56, 56, 56), |
19 | %2: Float(56, 56, 56), |
20 | %3: Float(56, 56, 56), |
21 | %4: Float(56, 56, 56), |
22 | %5: Float(56, 56, 56)): |
23 | %10 : int = prim::Constant[value=0]() |
24 | %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5) |
25 | %stack : Float(5, 56, 56, 56) = aten::stack(%input, %10) |
26 | return (%stack) |
27 | )IR" ; |
28 | parseIR(input, graph.get()); |
29 | std::vector<at::Tensor> inputs = { |
30 | at::rand({56, 56, 56}, at::kCPU), |
31 | at::rand({56, 56, 56}, at::kCPU), |
32 | at::rand({56, 56, 56}, at::kCPU), |
33 | at::rand({56, 56, 56}, at::kCPU), |
34 | at::rand({56, 56, 56}, at::kCPU), |
35 | at::rand({56, 56, 56}, at::kCPU)}; |
36 | auto orig_outputs = runGraph(graph, inputs); |
37 | |
38 | ASSERT_TRUE(UseVariadicStack(graph)); |
39 | graph->lint(); |
40 | auto opt_outputs = runGraph(graph, inputs); |
41 | |
42 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
43 | |
44 | // After replacing `aten::stack` with `prim::VarStack` we should have the |
45 | // following graph: |
46 | // |
47 | // graph(%0 : ..., |
48 | // %1 : ...): |
49 | // %zero : int = prim:Constant[value=0]() |
50 | // %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero) |
51 | // return (%varstack) |
52 | testing::FileCheck() |
53 | .check_count("= prim::VarStack(" , 1, /*exactly*/ true) |
54 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
55 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
56 | ->run(*graph); |
57 | } |
58 | |
59 | TEST(StackOptTest, UseVariadicStackReplaceMultiple) { |
60 | auto graph = std::make_shared<Graph>(); |
61 | |
62 | const std::string input = |
63 | R"IR( |
64 | graph(%0: Float(56, 56, 56), |
65 | %1: Float(56, 56, 56), |
66 | %2: Float(56, 56, 56), |
67 | %3: Float(56, 56, 56)): |
68 | %10 : int = prim::Constant[value=0]() |
69 | %input1 : Tensor[] = prim::ListConstruct(%0, %1) |
70 | %stack1 : Float(4, 56, 56, 56) = aten::stack(%input1, %10) |
71 | %input2 : Tensor[] = prim::ListConstruct(%2, %3) |
72 | %stack2 : Float(4, 56, 56, 56) = aten::stack(%input2, %10) |
73 | return (%stack1, %stack2) |
74 | )IR" ; |
75 | parseIR(input, graph.get()); |
76 | std::vector<at::Tensor> inputs = { |
77 | at::rand({56, 56, 56}, at::kCPU), |
78 | at::rand({56, 56, 56}, at::kCPU), |
79 | at::rand({56, 56, 56}, at::kCPU), |
80 | at::rand({56, 56, 56}, at::kCPU)}; |
81 | auto orig_outputs = runGraph(graph, inputs); |
82 | |
83 | ASSERT_TRUE(UseVariadicStack(graph)); |
84 | graph->lint(); |
85 | auto opt_outputs = runGraph(graph, inputs); |
86 | |
87 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
88 | |
89 | // After full stack optimization we should have the following graph: |
90 | // |
91 | // graph(%0 : ..., |
92 | // %1 : ..., |
93 | // %2 : ..., |
94 | // %3 : ....): |
95 | // %zero : int = prim:Constant[value=0]() |
96 | // %varcat1 : Tensor = prim::VarStack(%0, %1, %zero) |
97 | // %varcat2 : Tensor = prim::VarStack(%2, %3, %zero) |
98 | // return (%varcat1, %varcat2) |
99 | testing::FileCheck() |
100 | .check_count("= prim::VarStack(" , 2, /*exactly*/ true) |
101 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
102 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
103 | ->run(*graph); |
104 | } |
105 | |
106 | TEST(StackOptTest, UseVariadicStackWithMultipleListUses) { |
107 | auto graph = std::make_shared<Graph>(); |
108 | |
109 | const std::string input = |
110 | R"IR( |
111 | graph(%0: Float(56, 56, 56), |
112 | %1: Float(56, 56, 56)): |
113 | %2 : int = prim::Constant[value=0]() |
114 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
115 | %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2) |
116 | return (%stack, %input) |
117 | )IR" ; |
118 | parseIR(input, graph.get()); |
119 | std::vector<at::Tensor> inputs = { |
120 | at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; |
121 | auto orig_outputs = runGraph(graph, inputs); |
122 | |
123 | ASSERT_TRUE(UseVariadicStack(graph)); |
124 | graph->lint(); |
125 | auto opt_outputs = runGraph(graph, inputs); |
126 | |
127 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
128 | |
129 | // After replacing `aten::stack` with `prim::VarStack` we should have the |
130 | // following graph: |
131 | // |
132 | // graph(%0 : ..., |
133 | // %1 : ...): |
134 | // %zero : int = prim:Constant[value=0]() |
135 | // %input : Tensor[] = prim::ListConstruct(%0, %1) |
136 | // %varcat : Tensor = prim::VarStack(%0, %1, %zero) |
137 | // return (%varcat, %input) |
138 | testing::FileCheck() |
139 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
140 | ->check_count("= prim::VarStack(" , 1, /*exactly*/ true) |
141 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
142 | ->run(*graph); |
143 | } |
144 | |
145 | TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) { |
146 | auto graph = std::make_shared<Graph>(); |
147 | |
148 | const std::string input = |
149 | R"IR( |
150 | graph(%0: Float(56, 56, 56), |
151 | %1: Float(56, 56, 56), |
152 | %2: Float(56, 56, 56)): |
153 | %10 : int = prim::Constant[value=0]() |
154 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
155 | %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) |
156 | %11 : Tensor = aten::append(%input, %2) |
157 | return (%stack, %input) |
158 | )IR" ; |
159 | parseIR(input, graph.get()); |
160 | std::vector<at::Tensor> inputs = { |
161 | at::rand({56, 56, 56}, at::kCPU), |
162 | at::rand({56, 56, 56}, at::kCPU), |
163 | at::rand({56, 56, 56}, at::kCPU)}; |
164 | auto orig_outputs = runGraph(graph, inputs); |
165 | |
166 | ASSERT_TRUE(UseVariadicStack(graph)); |
167 | graph->lint(); |
168 | auto opt_outputs = runGraph(graph, inputs); |
169 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
170 | |
171 | // The input list to `aten::stack` is mutated only after `aten::stack` op. So, |
172 | // it should have been replaced with `prim::VarStack`. The transformed graph |
173 | // should look like the following: |
174 | // |
175 | // graph(%0 : ..., |
176 | // %1 : ..., |
177 | // %2 : ...): |
178 | // %3 : int = prim:Constant[value=0]() |
179 | // %4 : Tensor[] = prim::ListConstruct(%0, %1) |
180 | // %7 : Tensor = prim::VarStack(%0, %1, %3) |
181 | // %6 : Tensor = aten::append(%4, %2) |
182 | // return (%7, %4) |
183 | testing::FileCheck() |
184 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
185 | ->check_count("= prim::VarStack(" , 1, /*exactly*/ true) |
186 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
187 | ->run(*graph); |
188 | } |
189 | |
190 | TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) { |
191 | auto graph = std::make_shared<Graph>(); |
192 | |
193 | const std::string input = |
194 | R"IR( |
195 | graph(%0: Float(56, 56, 56), |
196 | %1: Float(56, 56, 56), |
197 | %2: Float(56, 56, 56)): |
198 | %10 : int = prim::Constant[value=0]() |
199 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
200 | %11 : Tensor = aten::append(%input, %2) |
201 | %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) |
202 | return (%stack) |
203 | )IR" ; |
204 | parseIR(input, graph.get()); |
205 | std::vector<at::Tensor> inputs = { |
206 | at::rand({56, 56, 56}, at::kCPU), |
207 | at::rand({56, 56, 56}, at::kCPU), |
208 | at::rand({56, 56, 56}, at::kCPU)}; |
209 | auto orig_outputs = runGraph(graph, inputs); |
210 | |
211 | { |
212 | ASSERT_FALSE(UseVariadicStack(graph)); |
213 | graph->lint(); |
214 | auto opt_outputs = runGraph(graph, inputs); |
215 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
216 | |
217 | // No transformation should have happened since the `prim::ListConstruct` is |
218 | // mutated before `aten::stack`. |
219 | testing::FileCheck() |
220 | .check_count("= prim::ListConstruct(" , 1, /*exactly*/ true) |
221 | ->check_count("= aten::stack(" , 1, /*exactly*/ true) |
222 | ->check_count("= prim::VarStack(" , 0, /*exactly*/ true) |
223 | ->run(*graph); |
224 | } |
225 | |
226 | { |
227 | ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); |
228 | graph->lint(); |
229 | auto opt_outputs = runGraph(graph, inputs); |
230 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
231 | |
232 | // The mutation of the list must be removed and the `aten::stack` op must |
233 | // be replaced with the `prim::VarStack` op in the graph. The transformed |
234 | // graph should look like the following: |
235 | // |
236 | // graph(%0 : ..., |
237 | // %1 : ..., |
238 | // %2 : ...): |
239 | // %3 : int = prim:Constant[value=0]() |
240 | // %7 : Tensor = prim::VarStack(%0, %1, %2, %3) |
241 | // return (%7) |
242 | testing::FileCheck() |
243 | .check_count("= prim::VarStack(" , 1, /*exactly*/ true) |
244 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
245 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
246 | ->run(*graph); |
247 | } |
248 | } |
249 | |
250 | TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) { |
251 | auto graph = std::make_shared<Graph>(); |
252 | |
253 | const std::string input = |
254 | R"IR( |
255 | graph(%0: Float(56, 56, 56), |
256 | %1: Float(56, 56, 56), |
257 | %2: Float(56, 56, 56), |
258 | %3: Float(56, 56, 56), |
259 | %4: Float(56, 56, 56)): |
260 | %10 : int = prim::Constant[value=0]() |
261 | %input : Tensor[] = prim::ListConstruct(%0, %1) |
262 | %stack.1 : Float(5, 56, 56, 56) = aten::stack(%input, %10) |
263 | %11 : Tensor = aten::append(%input, %2) |
264 | %stack.2 : Float(5, 56, 56, 56) = aten::stack(%input, %10) |
265 | %12 : Tensor = aten::append(%input, %3) |
266 | %stack.3 : Float(5, 56, 56, 56) = aten::stack(%input, %10) |
267 | %13 : Tensor = aten::append(%input, %4) |
268 | %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10) |
269 | return (%stack.1, %stack.2, %stack.3, %stack.4) |
270 | )IR" ; |
271 | parseIR(input, graph.get()); |
272 | std::vector<at::Tensor> inputs = { |
273 | at::rand({56, 56, 56}, at::kCPU), |
274 | at::rand({56, 56, 56}, at::kCPU), |
275 | at::rand({56, 56, 56}, at::kCPU), |
276 | at::rand({56, 56, 56}, at::kCPU), |
277 | at::rand({56, 56, 56}, at::kCPU)}; |
278 | auto orig_outputs = runGraph(graph, inputs); |
279 | |
280 | ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); |
281 | graph->lint(); |
282 | auto opt_outputs = runGraph(graph, inputs); |
283 | ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); |
284 | |
285 | // All the mutations of the list must be removed and the `aten::stack` ops |
286 | // must be replaced with `prim::VarStack` ops in the graph. The transformed |
287 | // graph should look like the following: |
288 | // |
289 | // graph(%0 : ..., |
290 | // %1 : ..., |
291 | // %2 : ..., |
292 | // %3 : ..., |
293 | // %4 : ...): |
294 | // %10 : int = prim:Constant[value=0]() |
295 | // %5 : Tensor = prim::VarStack(%0, %1, %10) |
296 | // %6 : Tensor = prim::VarStack(%0, %1, %2, %10) |
297 | // %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10) |
298 | // %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10) |
299 | // return (%5, %6, %7, %8) |
300 | testing::FileCheck() |
301 | .check_count("= prim::VarStack(" , 4, /*exactly*/ true) |
302 | ->check_count("= prim::ListConstruct(" , 0, /*exactly*/ true) |
303 | ->check_count("= aten::stack(" , 0, /*exactly*/ true) |
304 | ->run(*graph); |
305 | } |
306 | |
307 | } // namespace jit |
308 | } // namespace torch |
309 | |