1#include <gtest/gtest.h>
2
3#include <test/cpp/tensorexpr/test_base.h>
4#include <torch/csrc/jit/codegen/fuser/interface.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/jit/ir/irparser.h>
7#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
8#include <torch/csrc/jit/runtime/interpreter.h>
9#include <torch/csrc/jit/testing/file_check.h>
10#include <sstream>
11
12namespace torch {
13namespace jit {
14
15using namespace torch::jit::tensorexpr;
16
17struct WithCPUFuser {
18 WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
19 overrideCanFuseOnCPU(val);
20 }
21
22 ~WithCPUFuser() {
23 overrideCanFuseOnCPU(cpuFuserEnabled);
24 }
25
26 bool cpuFuserEnabled;
27};
28
29TEST(TEFuserPass, FuserPass_1) {
30 WithCPUFuser cf;
31 const auto graph_string = R"IR(
32 graph(%0 : Float(128, strides=[1], device=cpu),
33 %1 : Float(128, strides=[1], device=cpu)):
34 %12 : int = prim::Constant[value=1]()
35 %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
36 %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
37 %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
38 %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
39 %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
40 return (%5))IR";
41 auto g = std::make_shared<Graph>();
42 torch::jit::parseIR(graph_string, g.get());
43
44 g->lint();
45 FuseTensorExprs(g);
46
47 // We should not be able to fuse across the in-place operation here.
48 testing::FileCheck()
49 .check("prim::TensorExprGroup_")
50 ->check("aten::add_")
51 ->check("prim::TensorExprGroup_")
52 ->run(*g);
53}
54
55TEST(TEFuserPass, FuserPass_2) {
56 WithCPUFuser cf;
57 const auto graph_string = R"IR(
58 graph(%0 : Float(128, strides=[1], device=cpu),
59 %1 : Float(128, strides=[1], device=cpu)):
60 %12 : int = prim::Constant[value=1]()
61 %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
62 %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
63 %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
64 %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
65 return (%d))IR";
66 auto g = std::make_shared<Graph>();
67 torch::jit::parseIR(graph_string, g.get());
68
69 g->lint();
70 FuseTensorExprs(g);
71
72 // We should not be able to fuse across the in-place operation here.
73 testing::FileCheck()
74 .check("aten::add_")
75 ->check("prim::TensorExprGroup_0")
76 ->run(*g);
77}
78
79TEST(TEFuserPass, FuserPass_3) {
80 WithCPUFuser cf;
81 const auto graph_string = R"IR(
82 graph(%x : Float(128, strides=[1], device=cpu),
83 %y : Float(128, strides=[1], device=cpu)):
84 %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
85 return (%r))IR";
86 {
87 auto g = std::make_shared<Graph>();
88 torch::jit::parseIR(graph_string, g.get());
89
90 g->lint();
91 FuseTensorExprs(g, /* min_group_size= */ 2);
92
93 // We should not create a fusion group since its size would be too small
94 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
95 }
96 {
97 auto g = std::make_shared<Graph>();
98 torch::jit::parseIR(graph_string, g.get());
99
100 g->lint();
101 FuseTensorExprs(g, /* min_group_size= */ 1);
102
103 // We should create a fusion group since its size is above the threshold
104 testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
105 }
106}
107
108TEST(TEFuserPass, FuserPass_0DimInput) {
109 WithCPUFuser cf;
110 const auto graph_string = R"IR(
111 graph(%x : Float(device=cpu),
112 %y : Float(device=cpu)):
113 %one : int = prim::Constant[value=1]()
114 %a : Float(device=cpu) = aten::mul(%x, %y)
115 %b : Float(device=cpu) = aten::add(%x, %a, %one)
116 return (%b))IR";
117 auto g = std::make_shared<Graph>();
118 torch::jit::parseIR(graph_string, g.get());
119
120 g->lint();
121 FuseTensorExprs(g);
122
123 // We should fuse 0-dim tensors too
124 testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
125}
126
127TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
128 WithCPUFuser cf(false);
129 const auto graph_string = R"IR(
130 graph(%x : Float(10, strides=[1], device=cpu),
131 %y : Float(10, strides=[1], device=cpu)):
132 %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
133 return (%a))IR";
134 auto g = std::make_shared<Graph>();
135 torch::jit::parseIR(graph_string, g.get());
136
137 g->lint();
138 FuseTensorExprs(g, /* min_group_size= */ 1);
139
140 // Test that we're not starting fusion groups from nodes with unfusible device
141 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
142}
143
144TEST(TEFuserPass, FuserPass_UnknownShapes) {
145 WithCPUFuser cf;
146 const auto graph_string = R"IR(
147 graph(%x : Tensor,
148 %y : Tensor):
149 %a : Tensor = aten::mul(%x, %y)
150 %b : Tensor = aten::mul(%x, %a)
151 return (%b))IR";
152 auto g = std::make_shared<Graph>();
153 torch::jit::parseIR(graph_string, g.get());
154
155 g->lint();
156 FuseTensorExprs(g);
157
158 // Test that we're not generating fusion groups when shapes are not known
159 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
160}
161
162TEST(TEFuserPass, FuserPass_Multidevice) {
163 {
164 WithCPUFuser cf;
165 const auto graph_string = R"IR(
166 graph(%x : Float(10, strides=[1], device=cpu),
167 %y : Float(20, strides=[1], device=cpu),
168 %z : Float(30, strides=[1], device=cpu)):
169 %dim : int = prim::Constant[value=0]()
170 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
171 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
172 return (%cat))IR";
173 auto g = std::make_shared<Graph>();
174 torch::jit::parseIR(graph_string, g.get());
175
176 g->lint();
177 FuseTensorExprs(g, /* min_group_size= */ 1);
178
179 // We should be able to fuse this
180 testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
181 }
182 {
183 WithCPUFuser cf;
184 const auto graph_string = R"IR(
185 graph(%x : Float(10, strides=[1], device=cpu),
186 %y : Float(20, strides=[1], device=cuda:0),
187 %z : Float(30, strides=[1], device=cpu)):
188 %dim : int = prim::Constant[value=0]()
189 %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
190 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
191 return (%cat))IR";
192 auto g = std::make_shared<Graph>();
193 torch::jit::parseIR(graph_string, g.get());
194
195 g->lint();
196 FuseTensorExprs(g, /* min_group_size= */ 1);
197
198 // We should not fuse this aten::cat since its inputs are from different
199 // devices
200 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
201 }
202 {
203 WithCPUFuser cf;
204 const auto graph_string = R"IR(
205 graph(%x : Float(10, strides=[1], device=cpu),
206 %y : Float(20, strides=[1], device=cpu),
207 %z : Float(10, strides=[1], device=cuda:0)):
208 %dim : int = prim::Constant[value=0]()
209 %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
210 %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
211 %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
212 return (%r))IR";
213 auto g = std::make_shared<Graph>();
214 torch::jit::parseIR(graph_string, g.get());
215
216 g->lint();
217 FuseTensorExprs(g, /* min_group_size= */ 2);
218
219 // Test that we check device before merging one node (cat) into another
220 // (mul)
221 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
222 }
223 {
224 WithCPUFuser cf;
225 const auto graph_string = R"IR(
226 graph(%x : Float(10, strides=[1], device=cpu),
227 %y : Float(20, strides=[1], device=cpu),
228 %z : Float(10, strides=[1], device=cuda:0)):
229 %z2 : Tensor = aten::mul(%z, %z)
230 %dim : int = prim::Constant[value=0]()
231 %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
232 %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
233 return (%cat))IR";
234 auto g = std::make_shared<Graph>();
235 torch::jit::parseIR(graph_string, g.get());
236
237 g->lint();
238 FuseTensorExprs(g, /* min_group_size= */ 2);
239
240 // Test that we check device before merging one node (mul) into another
241 // (cat)
242 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
243 }
244 {
245 WithCPUFuser cf;
246 const auto graph_string = R"IR(
247 graph(%x : Float(10, strides=[1], device=cpu),
248 %y : Float(20, strides=[1], device=cuda:0)):
249 %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
250 return (%r))IR";
251 auto g = std::make_shared<Graph>();
252 torch::jit::parseIR(graph_string, g.get());
253
254 g->lint();
255 FuseTensorExprs(g, /* min_group_size= */ 1);
256
257 // We should not fuse this graph since its inputs are from different devices
258 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
259 }
260 {
261 WithCPUFuser cf;
262 const auto graph_string = R"IR(
263 graph(%x : Float(10, strides=[1], device=cuda:0),
264 %y : Float(20, strides=[1], device=cuda:1),
265 %z : Float(20, strides=[1], device=cpu)):
266 %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
267 %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
268 %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
269 return (%x2, %y2, %z2))IR";
270 auto g = std::make_shared<Graph>();
271 torch::jit::parseIR(graph_string, g.get());
272
273 g->lint();
274 FuseTensorExprs(g, /* min_group_size= */ 2);
275
276 // We should not fuse these two computations since they use different
277 // devices
278 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
279 }
280}
281
282TEST(TEFuserPass, FuserPass_MergeGroups) {
283 WithCPUFuser cf;
284 const auto graph_string = R"IR(
285 graph(%a : Float(128, strides=[1], device=cpu),
286 %b : Float(128, strides=[1], device=cpu)):
287 %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
288 %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
289 return (%x, %y))IR";
290 auto g = std::make_shared<Graph>();
291 torch::jit::parseIR(graph_string, g.get());
292
293 g->lint();
294 FuseTensorExprs(g, /* min_group_size= */ 1);
295
296 // The %x and %y computations are completely independent and yet we should put
297 // them into a single fusion group rather than having two separate ones.
298 testing::FileCheck()
299 .check("= prim::TensorExprGroup_")
300 ->check_not("= prim::TensorExprGroup_")
301 ->run(*g);
302}
303
304TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
305 WithCPUFuser cf;
306 const auto graph_string = R"IR(
307 graph(%x : Bool(8, strides=[1], device=cpu),
308 %y : Bool(8, strides=[1], device=cpu)):
309 %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
310 %b : Tensor = aten::__or__(%a, %y)
311 return (%b)
312 )IR";
313 auto g = std::make_shared<Graph>();
314 torch::jit::parseIR(graph_string, g.get());
315 g->lint();
316 FuseTensorExprs(g, /* min_group_size= */ 2);
317 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
318}
319
320TEST(TEFuserPass, FuserPass_Where) {
321 WithCPUFuser cf;
322 const auto graph_string = R"IR(
323 graph(%x : Float(8, strides=[1], device=cpu),
324 %y : Float(8, strides=[1], device=cpu),
325 %z : Float(8, strides=[1], device=cpu)):
326 %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
327 %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
328 return (%b)
329 )IR";
330 auto g = std::make_shared<Graph>();
331 torch::jit::parseIR(graph_string, g.get());
332 g->lint();
333 FuseTensorExprs(g, /* min_group_size= */ 2);
334 testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
335}
336
337TEST(TEFuserPass, FuserPass_WhereList) {
338 WithCPUFuser cf;
339 const auto graph_string = R"IR(
340 graph(%x : Float(8, strides=[1], device=cpu),
341 %y : Float(8, strides=[1], device=cpu),
342 %z : Float(8, strides=[1], device=cpu)):
343 %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
344 %b : Tensor[] = aten::where(%cond)
345 return (%b)
346 )IR";
347 auto g = std::make_shared<Graph>();
348 torch::jit::parseIR(graph_string, g.get());
349 g->lint();
350 FuseTensorExprs(g, /* min_group_size= */ 2);
351 testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
352}
353
354TEST(TEFuserPass, DynamicShapeFusion) {
355 WithCPUFuser cf;
356 const auto graph_string = R"IR(
357 graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
358 %1 : Float(10, 5, strides=[5, 1], device=cpu)):
359 %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
360 %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
361 return (%3))IR";
362 auto g = std::make_shared<Graph>();
363 torch::jit::parseIR(graph_string, g.get());
364
365 g->lint();
366 FuseTensorExprs(
367 g,
368 /* min_group_size = */ 2,
369 /* add_composed_op = */ true,
370 /* fuse_to_dynamic_shapes = */ true);
371 Code code(g, "");
372
373 testing::FileCheck()
374 .check("prim::TensorExprDynamicGroup_")
375 ->check("prim::TensorExprDynamicGuard")
376 ->check("prim::TensorExprGroup_")
377 ->run(*g);
378
379 auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) {
380 TORCH_INTERNAL_ASSERT(inputs.size() == 2);
381
382 auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]);
383
384 InterpreterState interp(code);
385 Stack stack(inputs.begin(), inputs.end());
386 interp.run(stack);
387 at::Tensor out = pop(stack).toTensor();
388 ASSERT_TRUE(at::allclose(out, ref));
389 };
390
391 std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})};
392 run_and_compare(inputs);
393
394 std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})};
395 run_and_compare(inputs2);
396
397 std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})};
398 run_and_compare(inputs3);
399}
400
401} // namespace jit
402} // namespace torch
403