1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/tensorexpr/test_base.h> |
4 | #include <torch/csrc/jit/codegen/fuser/interface.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
8 | #include <torch/csrc/jit/runtime/interpreter.h> |
9 | #include <torch/csrc/jit/testing/file_check.h> |
10 | #include <sstream> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | using namespace torch::jit::tensorexpr; |
16 | |
17 | struct WithCPUFuser { |
18 | WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { |
19 | overrideCanFuseOnCPU(val); |
20 | } |
21 | |
22 | ~WithCPUFuser() { |
23 | overrideCanFuseOnCPU(cpuFuserEnabled); |
24 | } |
25 | |
26 | bool cpuFuserEnabled; |
27 | }; |
28 | |
29 | TEST(TEFuserPass, FuserPass_1) { |
30 | WithCPUFuser cf; |
31 | const auto graph_string = R"IR( |
32 | graph(%0 : Float(128, strides=[1], device=cpu), |
33 | %1 : Float(128, strides=[1], device=cpu)): |
34 | %12 : int = prim::Constant[value=1]() |
35 | %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) |
36 | %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) |
37 | %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) |
38 | %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) |
39 | %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) |
40 | return (%5))IR" ; |
41 | auto g = std::make_shared<Graph>(); |
42 | torch::jit::parseIR(graph_string, g.get()); |
43 | |
44 | g->lint(); |
45 | FuseTensorExprs(g); |
46 | |
47 | // We should not be able to fuse across the in-place operation here. |
48 | testing::FileCheck() |
49 | .check("prim::TensorExprGroup_" ) |
50 | ->check("aten::add_" ) |
51 | ->check("prim::TensorExprGroup_" ) |
52 | ->run(*g); |
53 | } |
54 | |
55 | TEST(TEFuserPass, FuserPass_2) { |
56 | WithCPUFuser cf; |
57 | const auto graph_string = R"IR( |
58 | graph(%0 : Float(128, strides=[1], device=cpu), |
59 | %1 : Float(128, strides=[1], device=cpu)): |
60 | %12 : int = prim::Constant[value=1]() |
61 | %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) |
62 | %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) |
63 | %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) |
64 | %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) |
65 | return (%d))IR" ; |
66 | auto g = std::make_shared<Graph>(); |
67 | torch::jit::parseIR(graph_string, g.get()); |
68 | |
69 | g->lint(); |
70 | FuseTensorExprs(g); |
71 | |
72 | // We should not be able to fuse across the in-place operation here. |
73 | testing::FileCheck() |
74 | .check("aten::add_" ) |
75 | ->check("prim::TensorExprGroup_0" ) |
76 | ->run(*g); |
77 | } |
78 | |
79 | TEST(TEFuserPass, FuserPass_3) { |
80 | WithCPUFuser cf; |
81 | const auto graph_string = R"IR( |
82 | graph(%x : Float(128, strides=[1], device=cpu), |
83 | %y : Float(128, strides=[1], device=cpu)): |
84 | %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) |
85 | return (%r))IR" ; |
86 | { |
87 | auto g = std::make_shared<Graph>(); |
88 | torch::jit::parseIR(graph_string, g.get()); |
89 | |
90 | g->lint(); |
91 | FuseTensorExprs(g, /* min_group_size= */ 2); |
92 | |
93 | // We should not create a fusion group since its size would be too small |
94 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
95 | } |
96 | { |
97 | auto g = std::make_shared<Graph>(); |
98 | torch::jit::parseIR(graph_string, g.get()); |
99 | |
100 | g->lint(); |
101 | FuseTensorExprs(g, /* min_group_size= */ 1); |
102 | |
103 | // We should create a fusion group since its size is above the threshold |
104 | testing::FileCheck().check("prim::TensorExprGroup" )->run(*g); |
105 | } |
106 | } |
107 | |
108 | TEST(TEFuserPass, FuserPass_0DimInput) { |
109 | WithCPUFuser cf; |
110 | const auto graph_string = R"IR( |
111 | graph(%x : Float(device=cpu), |
112 | %y : Float(device=cpu)): |
113 | %one : int = prim::Constant[value=1]() |
114 | %a : Float(device=cpu) = aten::mul(%x, %y) |
115 | %b : Float(device=cpu) = aten::add(%x, %a, %one) |
116 | return (%b))IR" ; |
117 | auto g = std::make_shared<Graph>(); |
118 | torch::jit::parseIR(graph_string, g.get()); |
119 | |
120 | g->lint(); |
121 | FuseTensorExprs(g); |
122 | |
123 | // We should fuse 0-dim tensors too |
124 | testing::FileCheck().check("prim::TensorExprGroup" )->run(*g); |
125 | } |
126 | |
127 | TEST(TEFuserPass, FuserPass_UnfusibleDevice) { |
128 | WithCPUFuser cf(false); |
129 | const auto graph_string = R"IR( |
130 | graph(%x : Float(10, strides=[1], device=cpu), |
131 | %y : Float(10, strides=[1], device=cpu)): |
132 | %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) |
133 | return (%a))IR" ; |
134 | auto g = std::make_shared<Graph>(); |
135 | torch::jit::parseIR(graph_string, g.get()); |
136 | |
137 | g->lint(); |
138 | FuseTensorExprs(g, /* min_group_size= */ 1); |
139 | |
140 | // Test that we're not starting fusion groups from nodes with unfusible device |
141 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
142 | } |
143 | |
144 | TEST(TEFuserPass, FuserPass_UnknownShapes) { |
145 | WithCPUFuser cf; |
146 | const auto graph_string = R"IR( |
147 | graph(%x : Tensor, |
148 | %y : Tensor): |
149 | %a : Tensor = aten::mul(%x, %y) |
150 | %b : Tensor = aten::mul(%x, %a) |
151 | return (%b))IR" ; |
152 | auto g = std::make_shared<Graph>(); |
153 | torch::jit::parseIR(graph_string, g.get()); |
154 | |
155 | g->lint(); |
156 | FuseTensorExprs(g); |
157 | |
158 | // Test that we're not generating fusion groups when shapes are not known |
159 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
160 | } |
161 | |
162 | TEST(TEFuserPass, FuserPass_Multidevice) { |
163 | { |
164 | WithCPUFuser cf; |
165 | const auto graph_string = R"IR( |
166 | graph(%x : Float(10, strides=[1], device=cpu), |
167 | %y : Float(20, strides=[1], device=cpu), |
168 | %z : Float(30, strides=[1], device=cpu)): |
169 | %dim : int = prim::Constant[value=0]() |
170 | %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) |
171 | %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) |
172 | return (%cat))IR" ; |
173 | auto g = std::make_shared<Graph>(); |
174 | torch::jit::parseIR(graph_string, g.get()); |
175 | |
176 | g->lint(); |
177 | FuseTensorExprs(g, /* min_group_size= */ 1); |
178 | |
179 | // We should be able to fuse this |
180 | testing::FileCheck().check("prim::TensorExprGroup" )->run(*g); |
181 | } |
182 | { |
183 | WithCPUFuser cf; |
184 | const auto graph_string = R"IR( |
185 | graph(%x : Float(10, strides=[1], device=cpu), |
186 | %y : Float(20, strides=[1], device=cuda:0), |
187 | %z : Float(30, strides=[1], device=cpu)): |
188 | %dim : int = prim::Constant[value=0]() |
189 | %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) |
190 | %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) |
191 | return (%cat))IR" ; |
192 | auto g = std::make_shared<Graph>(); |
193 | torch::jit::parseIR(graph_string, g.get()); |
194 | |
195 | g->lint(); |
196 | FuseTensorExprs(g, /* min_group_size= */ 1); |
197 | |
198 | // We should not fuse this aten::cat since its inputs are from different |
199 | // devices |
200 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
201 | } |
202 | { |
203 | WithCPUFuser cf; |
204 | const auto graph_string = R"IR( |
205 | graph(%x : Float(10, strides=[1], device=cpu), |
206 | %y : Float(20, strides=[1], device=cpu), |
207 | %z : Float(10, strides=[1], device=cuda:0)): |
208 | %dim : int = prim::Constant[value=0]() |
209 | %xy_list : Tensor[] = prim::ListConstruct(%x, %y) |
210 | %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) |
211 | %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) |
212 | return (%r))IR" ; |
213 | auto g = std::make_shared<Graph>(); |
214 | torch::jit::parseIR(graph_string, g.get()); |
215 | |
216 | g->lint(); |
217 | FuseTensorExprs(g, /* min_group_size= */ 2); |
218 | |
219 | // Test that we check device before merging one node (cat) into another |
220 | // (mul) |
221 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
222 | } |
223 | { |
224 | WithCPUFuser cf; |
225 | const auto graph_string = R"IR( |
226 | graph(%x : Float(10, strides=[1], device=cpu), |
227 | %y : Float(20, strides=[1], device=cpu), |
228 | %z : Float(10, strides=[1], device=cuda:0)): |
229 | %z2 : Tensor = aten::mul(%z, %z) |
230 | %dim : int = prim::Constant[value=0]() |
231 | %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) |
232 | %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) |
233 | return (%cat))IR" ; |
234 | auto g = std::make_shared<Graph>(); |
235 | torch::jit::parseIR(graph_string, g.get()); |
236 | |
237 | g->lint(); |
238 | FuseTensorExprs(g, /* min_group_size= */ 2); |
239 | |
240 | // Test that we check device before merging one node (mul) into another |
241 | // (cat) |
242 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
243 | } |
244 | { |
245 | WithCPUFuser cf; |
246 | const auto graph_string = R"IR( |
247 | graph(%x : Float(10, strides=[1], device=cpu), |
248 | %y : Float(20, strides=[1], device=cuda:0)): |
249 | %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) |
250 | return (%r))IR" ; |
251 | auto g = std::make_shared<Graph>(); |
252 | torch::jit::parseIR(graph_string, g.get()); |
253 | |
254 | g->lint(); |
255 | FuseTensorExprs(g, /* min_group_size= */ 1); |
256 | |
257 | // We should not fuse this graph since its inputs are from different devices |
258 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
259 | } |
260 | { |
261 | WithCPUFuser cf; |
262 | const auto graph_string = R"IR( |
263 | graph(%x : Float(10, strides=[1], device=cuda:0), |
264 | %y : Float(20, strides=[1], device=cuda:1), |
265 | %z : Float(20, strides=[1], device=cpu)): |
266 | %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) |
267 | %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) |
268 | %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) |
269 | return (%x2, %y2, %z2))IR" ; |
270 | auto g = std::make_shared<Graph>(); |
271 | torch::jit::parseIR(graph_string, g.get()); |
272 | |
273 | g->lint(); |
274 | FuseTensorExprs(g, /* min_group_size= */ 2); |
275 | |
276 | // We should not fuse these two computations since they use different |
277 | // devices |
278 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
279 | } |
280 | } |
281 | |
282 | TEST(TEFuserPass, FuserPass_MergeGroups) { |
283 | WithCPUFuser cf; |
284 | const auto graph_string = R"IR( |
285 | graph(%a : Float(128, strides=[1], device=cpu), |
286 | %b : Float(128, strides=[1], device=cpu)): |
287 | %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) |
288 | %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) |
289 | return (%x, %y))IR" ; |
290 | auto g = std::make_shared<Graph>(); |
291 | torch::jit::parseIR(graph_string, g.get()); |
292 | |
293 | g->lint(); |
294 | FuseTensorExprs(g, /* min_group_size= */ 1); |
295 | |
296 | // The %x and %y computations are completely independent and yet we should put |
297 | // them into a single fusion group rather than having two separate ones. |
298 | testing::FileCheck() |
299 | .check("= prim::TensorExprGroup_" ) |
300 | ->check_not("= prim::TensorExprGroup_" ) |
301 | ->run(*g); |
302 | } |
303 | |
304 | TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { |
305 | WithCPUFuser cf; |
306 | const auto graph_string = R"IR( |
307 | graph(%x : Bool(8, strides=[1], device=cpu), |
308 | %y : Bool(8, strides=[1], device=cpu)): |
309 | %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) |
310 | %b : Tensor = aten::__or__(%a, %y) |
311 | return (%b) |
312 | )IR" ; |
313 | auto g = std::make_shared<Graph>(); |
314 | torch::jit::parseIR(graph_string, g.get()); |
315 | g->lint(); |
316 | FuseTensorExprs(g, /* min_group_size= */ 2); |
317 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
318 | } |
319 | |
320 | TEST(TEFuserPass, FuserPass_Where) { |
321 | WithCPUFuser cf; |
322 | const auto graph_string = R"IR( |
323 | graph(%x : Float(8, strides=[1], device=cpu), |
324 | %y : Float(8, strides=[1], device=cpu), |
325 | %z : Float(8, strides=[1], device=cpu)): |
326 | %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) |
327 | %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) |
328 | return (%b) |
329 | )IR" ; |
330 | auto g = std::make_shared<Graph>(); |
331 | torch::jit::parseIR(graph_string, g.get()); |
332 | g->lint(); |
333 | FuseTensorExprs(g, /* min_group_size= */ 2); |
334 | testing::FileCheck().check("prim::TensorExprGroup" )->run(*g); |
335 | } |
336 | |
337 | TEST(TEFuserPass, FuserPass_WhereList) { |
338 | WithCPUFuser cf; |
339 | const auto graph_string = R"IR( |
340 | graph(%x : Float(8, strides=[1], device=cpu), |
341 | %y : Float(8, strides=[1], device=cpu), |
342 | %z : Float(8, strides=[1], device=cpu)): |
343 | %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) |
344 | %b : Tensor[] = aten::where(%cond) |
345 | return (%b) |
346 | )IR" ; |
347 | auto g = std::make_shared<Graph>(); |
348 | torch::jit::parseIR(graph_string, g.get()); |
349 | g->lint(); |
350 | FuseTensorExprs(g, /* min_group_size= */ 2); |
351 | testing::FileCheck().check_not("prim::TensorExprGroup" )->run(*g); |
352 | } |
353 | |
354 | TEST(TEFuserPass, DynamicShapeFusion) { |
355 | WithCPUFuser cf; |
356 | const auto graph_string = R"IR( |
357 | graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), |
358 | %1 : Float(10, 5, strides=[5, 1], device=cpu)): |
359 | %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) |
360 | %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) |
361 | return (%3))IR" ; |
362 | auto g = std::make_shared<Graph>(); |
363 | torch::jit::parseIR(graph_string, g.get()); |
364 | |
365 | g->lint(); |
366 | FuseTensorExprs( |
367 | g, |
368 | /* min_group_size = */ 2, |
369 | /* add_composed_op = */ true, |
370 | /* fuse_to_dynamic_shapes = */ true); |
371 | Code code(g, "" ); |
372 | |
373 | testing::FileCheck() |
374 | .check("prim::TensorExprDynamicGroup_" ) |
375 | ->check("prim::TensorExprDynamicGuard" ) |
376 | ->check("prim::TensorExprGroup_" ) |
377 | ->run(*g); |
378 | |
379 | auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) { |
380 | TORCH_INTERNAL_ASSERT(inputs.size() == 2); |
381 | |
382 | auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); |
383 | |
384 | InterpreterState interp(code); |
385 | Stack stack(inputs.begin(), inputs.end()); |
386 | interp.run(stack); |
387 | at::Tensor out = pop(stack).toTensor(); |
388 | ASSERT_TRUE(at::allclose(out, ref)); |
389 | }; |
390 | |
391 | std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})}; |
392 | run_and_compare(inputs); |
393 | |
394 | std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; |
395 | run_and_compare(inputs2); |
396 | |
397 | std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; |
398 | run_and_compare(inputs3); |
399 | } |
400 | |
401 | } // namespace jit |
402 | } // namespace torch |
403 | |