#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include <stdexcept>

namespace torch {
namespace jit {

using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

class Kernel : public ::testing::Test {
 public:
  void SetUp() override {
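    // Do not require LLVM codegen on CPU, so these tests can also run with
    // the simple IR evaluator.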
    getTEMustUseLLVMOnCPU() = false;
  }
};

TEST_F(Kernel, ParallelExternalCallBuf) {
  const auto graph_string = R"IR(
      graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu),
            %1 : Float(1000, 5000, strides=[5000, 1], device=cpu),
            %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)):
        %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1)
        %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2)
        return (%4))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, &*graph);
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR";

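  // aten::matmul is lowered to an external call (hence the test name); the
  // pattern above checks that the loop surrounding it is parallelized.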
#ifdef TORCH_ENABLE_LLVM
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
#endif
}

TEST_F(Kernel, InliningIntermediates) {
  // Here, each mul has only one use, so it should be completely inlined.
  {
    const auto graph_string = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
              %1 : Float(5, 3, strides=[3, 1], device=cpu)):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
          %5 : Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one)
          return (%5))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);
    auto stmt = k.getCodeGenStmt();
    std::ostringstream oss;
    oss << *stmt;
    torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());
  }
  {
    const auto graph_template = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=${device}),
              %1 : Float(5, 3, strides=[3, 1], device=${device})):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one)
          %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one)
          %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0)
          return (%4, %5))IR";
    for (bool use_cuda : {false, true}) {
      if (!torch::cuda::is_available() && use_cuda) {
        continue;
      }

      at::jit::TemplateEnv env;
      env.s("device", use_cuda ? "cuda:0" : "cpu");
      const auto graph_string = format(graph_template, env);
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      auto stmt = k.getCodeGenStmt();
      std::ostringstream oss;
      oss << *stmt;
      // aten_mul only has one use, so it is inlined completely.
      torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());

      // aten_sub should be removed by metavar rewriting on the CUDA backend
      // and by horizontal fusion on the CPU backend.
      torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str());
    }
  }
}

TEST_F(Kernel, PreAllocIntermediateBufs) {
  const auto graph_string = R"IR(
      graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
            %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
        %2 : int = prim::Constant[value=1]()
        %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
        %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::matmul(a, b) + a;
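  // The last constructor argument enables pre-allocation of intermediate
  // buffers, which is what this test exercises.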
  TensorExprKernel k(graph, {}, {}, true);

  std::vector<at::Tensor> inputs = {a, b};
  auto stmt = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *stmt;

  // Check whether the intermediate buffer has been added to constants
  auto constants = k.getConstantDescriptors();
  ASSERT_EQ(constants.size(), 1);

  // Check the IR we produced
  torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str());
  torch::jit::testing::FileCheck().check_not("Free")->run(oss.str());

  // Check correctness
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, _1) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _2) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _3) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2), Slice(None, None, 2)});
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, Huge) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=0]()
        %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1)
        %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  std::ostringstream oss;
  oss << *k.getCodeGenStmt();
  // The loop with 4000000000 iterations will be split into 500000000 x 8, and
  // the outer loop will be parallel. If LLVM is not present, the loop is not
  // split; to cover both cases we look for "00000000ll;" in the output.
  const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

TEST_F(Kernel, ParallelStrided) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
            %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
        %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
               .index(
                   {Slice(None, None, 2),
                    Slice(None, None, 2),
                    Slice(None, None, 2)});
  auto ref = a * (a * b);
  auto o = at::zeros_like(ref);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, DISABLED_Shape_Inference) {
  // disabled: doesn't do stride propagation, and isn't being used currently

  // Test TensorExpr shape inference capabilities: it should only require
  // shapes for the inputs
  {
    const auto graph_string = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
              %1 : Float(5, 3, strides=[12, 2], device=cpu)):
          %2 : Tensor = aten::mul(%0, %1)
          %3 : Tensor = aten::mul(%0, %2)
          return (%3))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
                 .index({Slice(None, None, 2), Slice(None, None, 2)});
    auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = a * (a * b);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    for (size_t i = 0; i < 5 * 3; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    const auto graph_string = R"IR(
        graph(%0 : Float(8, 8, strides=[8, 1], device=cpu),
              %1 : Float(8, 8, strides=[8, 1], device=cpu)):
          %2 : Tensor = aten::mul(%0, %1)
          %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
          %r : Tensor = aten::mul(%3, %4)
          return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
    auto t = torch::chunk(a * b, 2, 1);
    auto ref = t[0] * t[1];
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    TORCH_CHECK_EQ(o.sizes()[0], 8);
    TORCH_CHECK_EQ(o.sizes()[1], 4);
    for (size_t i = 0; i < 8 * 4; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::unsqueeze

    const auto graph_string = R"IR(
        graph(%a : Float(4, 2, strides=[2, 1], device=cpu),
              %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu),
              %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)):
          %one : int = prim::Constant[value=1]()
          %minus_one : int = prim::Constant[value=-1]()
          %three : int = prim::Constant[value=3]()
          %minus_four : int = prim::Constant[value=-4]()
          %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2]
          %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
          %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1]
          %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
          %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1]
          %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2]
          return (%abc))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
        at::unsqueeze(c, -4);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_mul)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::cat

    const auto graph_string = R"IR(
        graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
              %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
              %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
          %dim : int = prim::Constant[value=1]()
          %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
          %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2]
          return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that we throw an error when input list for aten::cat is empty

    const auto graph_string = R"IR(
        graph():
          %dim : int = prim::Constant[value=1]()
          %inputs : Tensor[] = prim::ListConstruct()
          %r : Tensor = aten::cat(%inputs, %dim)
          return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    auto compile = [&]() {
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat");
  }
  {
    // Test that we throw an error when 'dim' passed to aten::cat is invalid

    const auto ir_dim_99 = R"IR(
        graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
              %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
          %dim : int = prim::Constant[value=99]()
          %inputs : Tensor[] = prim::ListConstruct(%a, %b)
          %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
          return (%r))IR";
    const auto ir_dim_minus_6 = R"IR(
        graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
              %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
          %dim : int = prim::Constant[value=-6]()
          %inputs : Tensor[] = prim::ListConstruct(%a, %b)
          %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
          return (%r))IR";

    auto compile = [](const std::string& graph_string) {
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index");
    ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index");
  }
}

TEST_F(Kernel, CatInputTypesPromotion) {
  {
    // Test that we properly promote input types for aten::cat

    const auto graph_string = R"IR(
        graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
              %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
              %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)):
          %dim : int = prim::Constant[value=1]()
          %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
          %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
          return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    TORCH_CHECK_EQ(o.dtype(), ref.dtype());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]);
    }
  }
}

TEST_F(Kernel, ToDType) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
        %1 : NoneType = prim::Constant()
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=6]()
        %4 : int = prim::Constant[value=15]()
        %5 : int = prim::Constant[value=5]()
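        # dtype codes above (c10::ScalarType): 6 = Float, 15 = BFloat16, 5 = Half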
        %6 : bool = prim::Constant[value=1]()
        %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
        %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
        %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
        %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
        %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
        %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
        return (%k.3))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_to
# CHECK-NEXT: }
# CHECK-NEXT: })IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
  auto ref =
      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));

  std::vector<at::Tensor> inputs = {a};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
#endif
}

TEST_F(Kernel, CatAndInlineWithAConstantDim) {
  const auto graph_string = R"IR(
      graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
            %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor[] = prim::ListConstruct(%0, %1)
        %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
        %6 : Tensor[] = prim::ListConstruct(%5)
        %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
        %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
        return (%8, %7))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);

  auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::_cast_Float(at::cat({a, b}, 1), 0);

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, CatWithEmptyInputs) {
  bool curr_cat_wo_conditionals = getCatWoConditionals();
  for (auto cat_wo_conditionals : {true, false}) {
    getCatWoConditionals() = cat_wo_conditionals;
    const auto graph_string = R"IR(
        graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu),
              %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)):
          %3 : int = prim::Constant[value=0]()
          %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0)
          %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1)
          %10 : Tensor[] = prim::ListConstruct(%6, %7)
          %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3)
          return (%11))IR";

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);

    auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0);

    std::vector<at::Tensor> inputs = {a, b};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
  getCatWoConditionals() = curr_cat_wo_conditionals;
}

TEST_F(Kernel, CatWoConditionals) {
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

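  // With cat lowered without conditionals, each of the three inputs gets its
  // own copy loop nest, so we expect three separate aten_cat loop nests.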
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::cat({a, b, c}, 1);

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getCatWoConditionals() = old_cat_wo_conditionals;
}

TEST_F(Kernel, OptimizeConditionals) {
  bool old_cat_wo_conditionals = getCatWoConditionals();
  bool old_opt_conditionals = getOptConditionals();
  getCatWoConditionals() = false;
  getOptConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, strides=[3, 1], device=cpu),
            %b : Float(5, 7, strides=[7, 1], device=cpu),
            %c : Float(5, 9, strides=[9, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim)
        %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r)
        return (%t))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK-NOT: Allocate
# CHECK-NOT: Free)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::relu(at::cat({a, b, c}, 1));

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getOptConditionals() = old_opt_conditionals;
  getCatWoConditionals() = old_cat_wo_conditionals;
}

namespace {

std::string dtypeConstant(ScalarType scalar_type) {
  if (scalar_type == ScalarType::Undefined) {
    return "None = prim::Constant()";
  } else {
    at::jit::TemplateEnv env_dtype;
    env_dtype.d("scalar_type", static_cast<int>(scalar_type));
    return format("int = prim::Constant[value=${scalar_type}]()", env_dtype);
  }
}

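// Returns a tensor of the given shape filled with 0, 1, 2, ... in row-major
// order.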
at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
  int64_t numel = std::accumulate(
      sizes.begin(),
      sizes.end(),
      1,
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  std::vector<float> values(numel);
  std::iota(values.begin(), values.end(), 0);
  auto a = at::tensor(values, options);
  return a.reshape(sizes);
}

} // namespace

TEST_F(Kernel, SumAllAxes) {
  // Test lowering of sum on all axes.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : ${dtype}
        %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
        return (%2))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
    at::jit::TemplateEnv env;
    env.s("dtype", dtypeConstant(scalar_type));
    if (scalar_type == ScalarType::Undefined) {
      env.s("out_dtype", "Float");
    } else {
      env.s("out_dtype", "Double");
    }
    const auto graph_string = format(graph_template, env);

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto o = at::empty({}, TensorOptions(kCPU));
    c10::optional<c10::ScalarType> dtype;
    if (scalar_type != ScalarType::Undefined) {
      dtype = static_cast<c10::ScalarType>(scalar_type);
    }
    auto ref = a.sum(/*dtype=*/dtype);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
}

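// Formats an integer array as a comma-separated list, e.g. {5, 3} -> "5, 3",
// for splicing sizes and strides into IR templates.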
std::string li_to_str(at::ArrayRef<int64_t> li) {
  std::stringstream out;
  bool first = true;
  for (auto elem : li) {
    if (!first) {
      out << ", ";
    }
    out << elem;
    first = false;
  }
  return out.str();
}

TEST_F(Kernel, SumOneAxis) {
  // Test lowering of sum on one axis.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int[] = prim::Constant[value=[${dim}]]()
        %2 : bool = prim::Constant[value=${keepdim}]()
        %3 : ${dtype}
        %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3)
        return (%4))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (int dim = -a.dim(); dim < a.dim(); ++dim) {
    for (bool keepdim : {false, true}) {
      for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
        at::jit::TemplateEnv env;
        env.d("dim", dim);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(scalar_type));
        c10::optional<c10::ScalarType> dtype;
        if (scalar_type != ScalarType::Undefined) {
          dtype = static_cast<c10::ScalarType>(scalar_type);
        }
        auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype);
        if (scalar_type == ScalarType::Undefined) {
          env.s("out_dtype", "Float");
        } else {
          env.s("out_dtype", "Double");
        }
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        const auto graph_string = format(graph_template, env);
        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        auto o = at::empty({}, TensorOptions(kCPU));
        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK-NEXT: sum
# CHECK-NEXT: for (int64_t
# CHECK-NEXT: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
      }
    }
  }
}

TEST_F(Kernel, SumMultipleAxes) {
  // Test lowering of sum on multiple axes.
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=${dim1}]()
        %2 : int = prim::Constant[value=${dim2}]()
        %3 : int[] = prim::ListConstruct(%1, %2)
        %4 : bool = prim::Constant[value=${keepdim}]()
        %5 : ${dtype}
        %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5)
        return (%6))IR";
  auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  // Only iterate over positive values of axes to keep the running time
  // reasonable, since the number of pairs is quadratic.
  for (const auto dim1 : c10::irange(a.dim())) {
    for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) {
      for (bool keepdim : {false, true}) {
        at::jit::TemplateEnv env;
        env.d("dim1", dim1);
        env.d("dim2", dim2);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(ScalarType::Undefined));
        auto o = at::empty({}, TensorOptions(kCPU));
        auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim);

        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref));
      }
    }
  }
}

// This test and the following Softmax tests only test with dim set to one of
// the valid input dimensions. They do not test with dim=None because that is
// supposed to be deprecated.
TEST_F(Kernel, Softmax2D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %dt_float : int = prim::Constant[value=7]()
        %dt_none : NoneType = prim::Constant()
        %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
        return (%4))IR";

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 5
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
        # CHECK-NEXT: aten_softmax)IR";

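  // Exercise both an explicit dtype argument (dt_float) and dtype=None
  // (dt_none) in the parsed IR.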
  for (bool empty_dtype : {false, true}) {
    for (auto log_softmax : {false, true}) {
      for (const auto softmax_dim : c10::irange(a.dim())) {
        auto softmax_dim_size = a.sizes()[softmax_dim];
        auto other_dim = (softmax_dim + 1) % a.dim();
        auto ref =
            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
        at::jit::TemplateEnv env;
        env.d("dim", softmax_dim);
        env.s("op", log_softmax ? "log_softmax" : "softmax");
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        env.s("dt", empty_dtype ? "dt_none" : "dt_float");

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        at::jit::TemplateEnv ver_env;
        ver_env.d("other_dim", other_dim);
        ver_env.d("other_dim_size", a.sizes()[other_dim]);
        ver_env.d("softmax_dim", softmax_dim);
        ver_env.d("softmax_dim_size", softmax_dim_size);
        const auto verification_pattern =
            format(verification_template, ver_env);

        // Verification string temporarily disabled until inlining of exp()
        // is benchmarked and a decision is made.
        // torch::jit::testing::FileCheck().run(verification_pattern,
        // oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        auto output = stack[0].toTensor();
        ASSERT_EQ(output.sizes(), ref.sizes());
        ASSERT_TRUE(at::allclose(output, ref));
      }
    }
  }
}

TEST_F(Kernel, Softmax3D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 3
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
        # CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // Verification string temporarily disabled until inlining of exp()
      // is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();

      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}

TEST_F(Kernel, Softmax4D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 2
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2
        # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
        # CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("dim3", other_dims[2]);
      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // Verification string temporarily disabled until inlining of exp()
      // is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();
      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}

TEST_F(Kernel, SignTest) {
  const auto graph_template = R"IR(
      graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
        %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
        return (%2))IR";

  auto run_test = [](const std::string& graph_string, const at::Tensor& input) {
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    TensorExprKernel k(graph);
    StmtPtr s = k.getCodeGenStmt();

    std::vector<at::Tensor> inputs = {input};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    auto ref = at::sign(input);
    ASSERT_TRUE(at::allclose(o, ref));
  };
  auto common_options = at::TensorOptions()
                            .layout(at::kStrided)
                            .device(at::kCPU)
                            .requires_grad(false);
  int default_input_size = 100;
  for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) {
    at::Tensor corner_case_inputs;
    at::jit::TemplateEnv env;
    auto options = common_options;
    switch (scalar_type) {
      case ScalarType::Float: {
        env.s("dtype", "Float");
        options = options.dtype(at::kFloat);
        std::vector<float> input_float = {
            0.0f,
            -0.0f,
            std::numeric_limits<float>::infinity(),
            -std::numeric_limits<float>::infinity(),
            std::nanf("1"),
            -std::nanf("1")};
        corner_case_inputs = at::from_blob(
            input_float.data(),
            {static_cast<long>(input_float.size())},
            options);
        auto rand_input = at::rand({default_input_size}, options);
        auto input = at::cat({rand_input, corner_case_inputs});
        env.d("size", at::numel(input));
        const auto graph_string = format(graph_template, env);
        run_test(graph_string, input);
        break;
      }
      case ScalarType::Double: {
        env.s("dtype", "Double");
        options = options.dtype(at::kDouble);
        std::vector<double> input_double = {
            0.0,
            -0.0,
            std::numeric_limits<double>::infinity(),
            -std::numeric_limits<double>::infinity(),
            std::nan("1"),
            -std::nan("1")};
        corner_case_inputs = at::from_blob(
            input_double.data(),
            {static_cast<long>(input_double.size())},
            options);
        auto rand_input = at::rand({default_input_size}, options);
        auto input = at::cat({rand_input, corner_case_inputs});
        env.d("size", at::numel(input));
        const auto graph_string = format(graph_template, env);
        run_test(graph_string, input);
        break;
      }
      default:
        throw unsupported_dtype();
    }
  }
}

TEST_F(Kernel, InlineProducerIntoReduction) {
  // Inline producer (mul) into reduction (sum).
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
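        # 7 is ScalarType::Double, so the sum produces a Double tensor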
        %3 : int = prim::Constant[value=7]()
        %4 : Double(device=cpu) = aten::sum(%2, %3)
        return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced.
  // We should have only one loop in the end.
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
        # CHECK-NEXT: sum
        # CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto ref = (a * b).sum(at::kDouble);
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, InlineReductionIntoConsumer) {
  // Inline producer (mul %2) into reduction (sum %4) but DO NOT
  // inline the reduction into consumer (mul %5).
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
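        # 6 is ScalarType::Float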
        %3 : int = prim::Constant[value=6]()
        %4 : Float(device=cpu) = aten::sum(%2, %3)
        %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
        return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced.
  // We should have two loops in the end.
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
        # CHECK-NEXT: sum
        # CHECK: for (int64_t i_2 = 0ll; i_2 < 5
        # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3
        # CHECK-NEXT: aten_mul
        # CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto ref = (a * b).sum(at::kFloat) * (a * b);
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, SanitizeNames_CUDA) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0),
            %1 : Float(5, 3, strides=[3, 1], device=cuda:0)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
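  // Give the inputs debug names containing characters that are illegal in
  // generated code; TensorExprKernel is expected to sanitize them.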
  graph->inputs().at(0)->setDebugName("aten::add:");
  graph->inputs().at(1)->setDebugName("aten::add_");
  TensorExprKernel k(graph);
  auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto ref = a * (a * b);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, SanitizeConstants_CUDA) {
  const auto graph_string = R"IR(
      graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
        %none : NoneType = prim::Constant()
        %size : int = prim::Constant[value=16]()
        %sizes : int[] = prim::ListConstruct(%size, %size)
        %30 : Device = prim::Constant[value="cuda"]()
        %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
        %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
        return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we insert a call to
  // aten::ones and then const-prop it
  ConstantPropagation(graph);

  // We set the name of the constant to include special characters that are
  // not allowed. This should be fixed by the sanitizer in TensorExprKernel.
  graph->nodes().front()->output()->setDebugName("illegal.name");

  // Check if we have a constant node with illegal name in the graph.
  auto const_node = graph->nodes().front();
  ASSERT_EQ(const_node->kind(), prim::Constant);
  ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, ConstantTensors) {
  const auto graph_string = R"IR(
      graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
        %none : NoneType = prim::Constant()
        %size : int = prim::Constant[value=16]()
        %sizes : int[] = prim::ListConstruct(%size, %size)
        %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none)
        %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
        return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we insert a call to
  // aten::ones and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, ConstantTensorsNonContiguous) {
  const auto graph_string = R"IR(
      graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
        %none : NoneType = prim::Constant()
        %dtype : int = prim::Constant[value=6]()
        %c0 : int = prim::Constant[value=0]()
        %c256 : int = prim::Constant[value=256]()
        %c16 : int = prim::Constant[value=16]()
        %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
        %sizes : int[] = prim::ListConstruct(%c16, %c16)
        %y_t : Tensor = aten::view(%y_flat, %sizes)
        %y : Tensor = aten::t(%y_t)
        %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
        return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce a non-contiguous constant tensor and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat))
               .view({16, 16})
               .t();
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}
1567
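// Unlike TensorExprKernel::run, which takes an IValue stack and allocates its
// outputs, runFast takes raw input and (preallocated) output data pointers
// and skips the IValue boxing/unboxing, as exercised below.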
TEST_F(Kernel, RunFast) {
#ifdef TORCH_ENABLE_LLVM
  // TODO: Implement call_raw in IREval and remove the ifdef

  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);

  k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()});
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

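// runWithAllocatedOutputs also writes into caller-allocated outputs, but
// takes an IValue stack. Note the stack layout: the preallocated outputs come
// first ({o, a, b}), followed by the inputs.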
TEST_F(Kernel, RunWithAllocatedOutputs) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);

  std::vector<at::Tensor> args = {o, a, b};
  std::vector<IValue> stack = fmap<IValue>(args);
  k.runWithAllocatedOutputs(stack);
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

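// Besides running kernels, TensorExprKernel lets us inspect what was
// generated: getCodeText("asm") returns the emitted assembly (with the LLVM
// backend), while getConstantDescriptors() and getBufferArgs() describe the
// parameters of the generated function.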
TEST_F(Kernel, CodegenInspection) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %dtype : int = prim::Constant[value=6]()
          %c0 : int = prim::Constant[value=0]()
          %c256 : int = prim::Constant[value=256]()
          %c16 : int = prim::Constant[value=16]()
          %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
          %sizes : int[] = prim::ListConstruct(%c16, %c16)
          %y_t : Tensor = aten::view(%y_flat, %sizes)
          %y : Tensor = aten::t(%y_t)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce a non-contiguous constant tensor and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  // Check that we can retrieve the generated assembly
  auto asm_str = k.getCodeText("asm");
  const std::string& asm_verification_pattern =
      R"ASM(
        # CHECK: .text
        # CHECK: retq)ASM";
  torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str);

  // Check that we can retrieve info about the codegen parameters
  auto constants = k.getConstantDescriptors();
  auto buf_args = k.getBufferArgs();
  // Expected buf args: [input0, output0, constant0]
  ASSERT_EQ(buf_args.size(), 3);
  ASSERT_EQ(constants.size(), 1);
  ASSERT_TRUE(
      !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar());
#endif
}

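// Custom lowerings let users override how an ATen op is translated into a TE
// computation. An NNCLoweringFunction receives the op's argument values, the
// output shape/strides, the output dtype, and the device, and returns the
// resulting Tensor. The lowering below is a simplified stand-in for
// aten::nan_to_num that only replaces NaNs with 0.0f (the real op also
// handles +/-inf and custom replacement values).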
Tensor lowerNanToNum(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const c10::optional<ScalarType>& outputType,
    at::Device device) {
  auto input_buf = c10::get<BufHandle>(inputs[0]);
  auto e = Compute(
      "custom_nan_to_num",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        auto load = input_buf.load(indices);
        return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load);
      });
  return e;
}

TEST_F(Kernel, CustomLowering) {
  const auto graph_string = R"IR(
      graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
        %none : NoneType = prim::Constant()
        %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
        return (%y)
)IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  std::unordered_map<c10::Symbol, NNCLoweringFunction> lowerings = {
      {aten::nan_to_num, lowerNanToNum}};
  TensorExprKernel k(graph, lowerings);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Check that our custom lowering is actually used
  torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str());
  torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
}

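// In TE IR, vectorized accesses show up as Ramp nodes (an index vector
// base, base + stride, ..., base + (lanes - 1) * stride used by vector loads
// and stores), so finding "Ramp" in the printed statement indicates that the
// loop was vectorized.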
TEST_F(Kernel, Vectorize) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
            %1 : Float(100, 16, strides=[16, 1], device=cpu)):
        %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
        %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 100 * 16; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

// TODO: To vectorize the loopnest for the 100x3 case, we need to flatten the
// loops first.
TEST_F(Kernel, DISABLED_FlattenVectorize) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
            %1 : Float(100, 3, strides=[3, 1], device=cpu)):
        %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 100 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

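// Inputs need not be contiguous: below, %1 has stride 2, i.e. b is a view of
// every other element of a length-6 tensor, and the kernel must honor that
// stride when indexing.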
TEST_F(Kernel, Strided1dWithinBounds) {
  auto ir = R"IR(
      graph(%0 : Float(3, strides=[1], device=cpu),
            %1 : Float(3, strides=[2], device=cpu)):
        %2 : int = prim::Constant[value=1]()
        %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2)});
  auto expect = a + b;

  std::vector<at::Tensor> inputs = {a, b};

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);

  auto output = stack[0].toTensor();

  for (size_t i = 0; i < 3; ++i) {
    TORCH_CHECK_EQ(
        ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
  }
}

TEST_F(Kernel, InputAsOutput) {
  const auto graph_string = R"IR(
      graph(%x : Float(5, 3, strides=[3, 1], device=cpu),
            %y : Float(5, 3, strides=[1, 5], device=cpu)):
        return (%x, %y))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y};

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  CHECK(at::allclose(x, stack[0].toTensor()));
  CHECK(at::allclose(y, stack[1].toTensor()));
}

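// Scalar (non-tensor) outputs are modeled as 0-dim buffers: each scalar is
// computed into a variable (a Let binding) and then stored into its output
// buffer at index [0], which is what the FileCheck pattern below expects.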
TEST_F(Kernel, ScalarOut) {
  auto ir = R"IR(
graph(%x : int, %y : int):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  return (%r, %z))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Verify the generated IR. We expect to see a scalar variable (Let) followed
  // by a store to a 0-dim buffer.
  const std::string& verification_pattern = R"IR(
# CHECK: int64_t
# CHECK-NEXT: [0ll] =
# CHECK-NEXT: int64_t
# CHECK-NEXT: [0ll] =
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  int64_t x = 2, y = 3, r = 0, z = 0;

  // Verify that TEK::runFast works correctly with scalar outputs
  std::vector<void*> inputs = {&x, &y};
  std::vector<void*> outputs = {&r, &z};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);

  // Verify that TEK::run works correctly with scalar outputs
  std::vector<IValue> stack = {x, y};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  TORCH_CHECK_EQ(stack[1], x * y);
}

TEST_F(Kernel, ScalarTensorOut) {
  auto ir = R"IR(
graph(%x : int,
      %xt : Long(3, strides=[1], device=cpu),
      %y : int,
      %yt : Long(3, strides=[1], device=cpu)):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y)
  %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt)
  return (%r, %rt, %z, %zt))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);
  int64_t x = 2, y = 3, r = 0, z = 0;
  auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2;
  auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3;
  auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
  auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));

  // Verify that TEK::runFast works correctly with mixed scalar and tensor
  // inputs/outputs
  std::vector<void*> inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()};
  std::vector<void*> outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);
  ASSERT_TRUE(at::equal(zt, xt * yt));
  ASSERT_TRUE(at::equal(rt, zt * xt));

  // Verify that TEK::run works correctly with mixed scalar and tensor
  // inputs/outputs
  std::vector<IValue> stack = {x, xt, y, yt};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt));
  TORCH_CHECK_EQ(stack[2], x * y);
  ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt));
}

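// In the tests below, SS(-N) in the IR denotes a symbolic shape: each -N
// listed in symbolic_shape_inputs is resolved at run time from the
// corresponding scalar input (%SS_2, %SS_3, ...). The FileCheck patterns
// verify that the per-input cat loops are fused under a single outer loop
// over the symbolic dimension rather than each getting its own "i" loop.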
TEST_F(Kernel, FuseLoopsWithVariableBounds) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # example new size: [5, 19, 2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim1, int dim2) {
    auto a =
        at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

TEST_F(Kernel, FuseLoopsWithVariableConcatDim) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # example new size: [5, 19, 2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim1, int dim2, int dim3) {
    auto a =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(3 * dim3);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int,
            %SS_6 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # example new size: [5, 19, 2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5, -6};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t j
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) {
    auto a =
        at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b}, 1);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(dim4);
    stack.emplace_back(dim5);
    stack.emplace_back(dim4 + dim5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15, 8);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

} // namespace jit
} // namespace torch