1 | #include <gtest/gtest.h> |
2 | |
3 | #include <ATen/code_template.h> |
4 | #include <c10/util/irange.h> |
5 | #include <test/cpp/tensorexpr/test_base.h> |
6 | #include <torch/csrc/jit/ir/ir.h> |
7 | #include <torch/csrc/jit/ir/irparser.h> |
8 | #include <torch/csrc/jit/passes/constant_propagation.h> |
9 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
10 | #include <torch/csrc/jit/tensorexpr/kernel.h> |
11 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
12 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
13 | #include <torch/csrc/jit/testing/file_check.h> |
14 | #include <torch/torch.h> |
15 | #include <cmath> |
16 | #include <sstream> |
17 | #include <stdexcept> |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | using namespace torch::indexing; |
23 | using namespace torch::jit::tensorexpr; |
24 | |
// Test fixture shared by all TensorExprKernel tests below: it clears the
// "must use LLVM on CPU" flag before each test so kernels may fall back to
// non-LLVM codegen when LLVM is unavailable.
class Kernel : public ::testing::Test {
 public:
  void SetUp() override {
    // Reset the global flag for every test; individual tests may rely on
    // running without the LLVM backend.
    getTEMustUseLLVMOnCPU() = false;
  }
};
31 | |
32 | TEST_F(Kernel, ParallelExternalCallBuf) { |
33 | const auto graph_string = R"IR( |
34 | graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), |
35 | %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), |
36 | %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): |
37 | %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) |
38 | %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) |
39 | return (%4))IR" ; |
40 | auto graph = std::make_shared<Graph>(); |
41 | torch::jit::parseIR(graph_string, &*graph); |
42 | const std::string& verification_pattern = |
43 | R"IR( |
44 | # CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR" ; |
45 | |
46 | #ifdef TORCH_ENABLE_LLVM |
47 | TensorExprKernel k(graph); |
48 | StmtPtr s = k.getCodeGenStmt(); |
49 | std::ostringstream oss; |
50 | oss << *s; |
51 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
52 | #endif |
53 | } |
54 | |
55 | TEST_F(Kernel, InliningIntermediates) { |
56 | // here, each mul has only one use, so it should be completely inlined |
57 | { |
58 | const auto graph_string = R"IR( |
59 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
60 | %1 : Float(5, 3, strides=[3, 1], device=cpu)): |
61 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
62 | %one : int = prim::Constant[value=1]() |
63 | %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
64 | %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) |
65 | return (%5))IR" ; |
66 | auto graph = std::make_shared<Graph>(); |
67 | parseIR(graph_string, &*graph); |
68 | TensorExprKernel k(graph); |
69 | auto stmt = k.getCodeGenStmt(); |
70 | std::ostringstream oss; |
71 | oss << *stmt; |
72 | torch::jit::testing::FileCheck().check_not("aten_mul" )->run(oss.str()); |
73 | } |
74 | { |
75 | const auto graph_template = R"IR( |
76 | graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), |
77 | %1 : Float(5, 3, strides=[3, 1], device=${device})): |
78 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
79 | %one : int = prim::Constant[value=1]() |
80 | %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) |
81 | %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) |
82 | %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) |
83 | return (%4, %5))IR" ; |
84 | for (bool use_cuda : {false, true}) { |
85 | if (!torch::cuda::is_available() && use_cuda) { |
86 | continue; |
87 | } |
88 | |
89 | at::jit::TemplateEnv env; |
90 | env.s("device" , use_cuda ? "cuda:0" : "cpu" ); |
91 | const auto graph_string = format(graph_template, env); |
92 | auto graph = std::make_shared<Graph>(); |
93 | parseIR(graph_string, &*graph); |
94 | TensorExprKernel k(graph); |
95 | auto stmt = k.getCodeGenStmt(); |
96 | std::ostringstream oss; |
97 | oss << *stmt; |
98 | // aten_mul only has one use, inlined completely |
99 | torch::jit::testing::FileCheck().check_not("aten_mul" )->run(oss.str()); |
100 | |
101 | // aten_sub should be removed by the CUDA backend by metavar rewriting |
102 | // and by the CPU backend by horizontal fusion. |
103 | torch::jit::testing::FileCheck().check_not("aten_sub" )->run(oss.str()); |
104 | } |
105 | } |
106 | } |
107 | |
108 | TEST_F(Kernel, PreAllocIntermediateBufs) { |
109 | const auto graph_string = R"IR( |
110 | graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), |
111 | %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): |
112 | %2 : int = prim::Constant[value=1]() |
113 | %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 |
114 | %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 |
115 | return (%3))IR" ; |
116 | auto graph = std::make_shared<Graph>(); |
117 | parseIR(graph_string, &*graph); |
118 | |
119 | auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); |
120 | auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); |
121 | auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); |
122 | auto ref = at::matmul(a, b) + a; |
123 | TensorExprKernel k(graph, {}, {}, true); |
124 | |
125 | std::vector<at::Tensor> inputs = {a, b}; |
126 | auto stmt = k.getCodeGenStmt(); |
127 | |
128 | std::ostringstream oss; |
129 | oss << *stmt; |
130 | |
131 | // Check whether the intermediate buffer has been added to constants |
132 | auto constants = k.getConstantDescriptors(); |
133 | ASSERT_EQ(constants.size(), 1); |
134 | |
135 | // Check the IR we produced |
136 | torch::jit::testing::FileCheck().check_not("Alloc" )->run(oss.str()); |
137 | torch::jit::testing::FileCheck().check_not("Free" )->run(oss.str()); |
138 | |
139 | // Check correctness |
140 | std::vector<IValue> stack = fmap<IValue>(inputs); |
141 | k.run(stack); |
142 | o = stack[0].toTensor(); |
143 | ASSERT_TRUE(at::allclose(o, ref)); |
144 | } |
145 | |
146 | TEST_F(Kernel, _1) { |
147 | const auto graph_string = R"IR( |
148 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
149 | %1 : Float(5, 3, strides=[3, 1], device=cpu)): |
150 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
151 | %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
152 | return (%3))IR" ; |
153 | auto graph = std::make_shared<Graph>(); |
154 | parseIR(graph_string, &*graph); |
155 | |
156 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
157 | auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
158 | auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
159 | auto ref = a * (a * b); |
160 | TensorExprKernel k(graph); |
161 | std::vector<at::Tensor> inputs = {a, b}; |
162 | StmtPtr s = k.getCodeGenStmt(); |
163 | |
164 | std::ostringstream oss; |
165 | oss << *s; |
166 | |
167 | // Check the IR we produced |
168 | const std::string& verification_pattern = |
169 | R"IR( |
170 | # CHECK: for |
171 | # CHECK-NEXT: for |
172 | # CHECK-NOT: for)IR" ; |
173 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
174 | |
175 | std::vector<IValue> stack = fmap<IValue>(inputs); |
176 | k.run(stack); |
177 | o = stack[0].toTensor(); |
178 | for (size_t i = 0; i < 5 * 3; i++) { |
179 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
180 | } |
181 | } |
182 | |
183 | TEST_F(Kernel, _2) { |
184 | const auto graph_string = R"IR( |
185 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
186 | %1 : Float(5, 3, strides=[1, 5], device=cpu)): |
187 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
188 | %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
189 | return (%3))IR" ; |
190 | auto graph = std::make_shared<Graph>(); |
191 | parseIR(graph_string, &*graph); |
192 | |
193 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
194 | auto b = |
195 | at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); |
196 | auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
197 | auto ref = a * (a * b); |
198 | TensorExprKernel k(graph); |
199 | std::vector<at::Tensor> inputs = {a, b}; |
200 | StmtPtr s = k.getCodeGenStmt(); |
201 | |
202 | std::ostringstream oss; |
203 | oss << *s; |
204 | |
205 | // Check the IR we produced |
206 | const std::string& verification_pattern = |
207 | R"IR( |
208 | # CHECK: for |
209 | # CHECK-NEXT: for |
210 | # CHECK-NOT: for)IR" ; |
211 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
212 | |
213 | std::vector<IValue> stack = fmap<IValue>(inputs); |
214 | k.run(stack); |
215 | o = stack[0].toTensor(); |
216 | for (size_t i = 0; i < 5 * 3; i++) { |
217 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
218 | } |
219 | } |
220 | |
221 | TEST_F(Kernel, _3) { |
222 | const auto graph_string = R"IR( |
223 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
224 | %1 : Float(5, 3, strides=[12, 2], device=cpu)): |
225 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
226 | %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
227 | return (%3))IR" ; |
228 | auto graph = std::make_shared<Graph>(); |
229 | parseIR(graph_string, &*graph); |
230 | |
231 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
232 | auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) |
233 | .index({Slice(None, None, 2), Slice(None, None, 2)}); |
234 | auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
235 | auto ref = a * (a * b); |
236 | TensorExprKernel k(graph); |
237 | std::vector<at::Tensor> inputs = {a, b}; |
238 | StmtPtr s = k.getCodeGenStmt(); |
239 | |
240 | std::ostringstream oss; |
241 | oss << *s; |
242 | |
243 | // Check the IR we produced |
244 | const std::string& verification_pattern = |
245 | R"IR( |
246 | # CHECK: for |
247 | # CHECK-NEXT: for |
248 | # CHECK-NOT: for)IR" ; |
249 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
250 | |
251 | std::vector<IValue> stack = fmap<IValue>(inputs); |
252 | k.run(stack); |
253 | o = stack[0].toTensor(); |
254 | for (size_t i = 0; i < 5 * 3; i++) { |
255 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
256 | } |
257 | } |
258 | |
259 | TEST_F(Kernel, Huge) { |
260 | const auto graph_string = R"IR( |
261 | graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): |
262 | %1 : int = prim::Constant[value=0]() |
263 | %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) |
264 | %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) |
265 | return (%3))IR" ; |
266 | auto graph = std::make_shared<Graph>(); |
267 | parseIR(graph_string, &*graph); |
268 | TensorExprKernel k(graph); |
269 | std::ostringstream oss; |
270 | oss << *k.getCodeGenStmt(); |
271 | // The 4000000000 iterations loop will be split into 500000000 x 8 and the |
272 | // outer loop will be parallel. If LLVM is not present, it will not be split, |
273 | // and to cover both of these cases we're looking for 00000000ll; in the |
274 | // output. |
275 | const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR" ; |
276 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
277 | } |
278 | |
279 | TEST_F(Kernel, ParallelStrided) { |
280 | const auto graph_string = R"IR( |
281 | graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), |
282 | %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): |
283 | %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) |
284 | %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) |
285 | return (%3))IR" ; |
286 | auto graph = std::make_shared<Graph>(); |
287 | parseIR(graph_string, &*graph); |
288 | |
289 | auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); |
290 | auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) |
291 | .index( |
292 | {Slice(None, None, 2), |
293 | Slice(None, None, 2), |
294 | Slice(None, None, 2)}); |
295 | auto ref = a * (a * b); |
296 | auto o = at::zeros_like(ref); |
297 | TensorExprKernel k(graph); |
298 | std::vector<at::Tensor> inputs = {a, b}; |
299 | std::vector<IValue> stack = fmap<IValue>(inputs); |
300 | k.run(stack); |
301 | o = stack[0].toTensor(); |
302 | for (size_t i = 0; i < 5 * 3; i++) { |
303 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
304 | } |
305 | } |
306 | |
TEST_F(Kernel, DISABLED_Shape_Inference) {
  // disabled: doesn't do stride propagation, and isn't being used currently

  // Test TensorExpr shape inference capabilities: it should only require shapes
  // for the inputs
  {
    // Elementwise graph with unannotated intermediate/output shapes; shape
    // inference must derive them from the input annotations.
    const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor = aten::mul(%0, %2)
        return (%3))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    // Step-2 slice of a 10x6 tensor gives the strides=[12, 2] input.
    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
                 .index({Slice(None, None, 2), Slice(None, None, 2)});
    auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = a * (a * b);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    for (size_t i = 0; i < 5 * 3; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles prim::ConstantChunk: splitting the
    // 8x8 product along dim 1 yields two 8x4 tensors.
    const auto graph_string = R"IR(
      graph(%0 : Float(8, 8, strides=[8, 1], device=cpu),
            %1 : Float(8, 8, strides=[8, 1], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
        %r : Tensor = aten::mul(%3, %4)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
    auto t = torch::chunk(a * b, 2, 1);
    auto ref = t[0] * t[1];
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    // The inferred output shape must be 8x4 (half of dim 1).
    TORCH_CHECK_EQ(o.sizes()[0], 8);
    TORCH_CHECK_EQ(o.sizes()[1], 4);
    for (size_t i = 0; i < 8 * 4; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::unsqueeze

    const auto graph_string = R"IR(
      graph(%a : Float(4, 2, strides=[2, 1], device=cpu),
            %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu),
            %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)):
        %one : int = prim::Constant[value=1]()
        %minus_one : int = prim::Constant[value=-1]()
        %three : int = prim::Constant[value=3]()
        %minus_four : int = prim::Constant[value=-4]()
        %a1 : Tensor = aten::unsqueeze(%a, %one)        # new size: [4,1,2]
        %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
        %b1 : Tensor = aten::unsqueeze(%b, %three)      # new size: [4,3,2,1]
        %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
        %ab : Tensor = aten::mul(%a2, %b1)         # expected size: [4,3,2,1]
        %abc : Tensor = aten::mul(%ab, %c1)        # expected size: [4,3,2,2]
        return (%abc))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    // Reference mirrors the unsqueeze/broadcast sequence in the graph.
    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
        at::unsqueeze(c, -4);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_mul)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::cat

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Tensor = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that we throw an error when input list for aten::cat is empty

    const auto graph_string = R"IR(
      graph():
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct()
        %r : Tensor = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    auto compile = [&]() {
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat");
  }
  {
    // Test that we throw an error when 'dim' passed to aten::cat is invalid
    // (out of range in either direction for rank-3 inputs).

    const auto ir_dim_99 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=99]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";
    const auto ir_dim_minus_6 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=-6]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";

    auto compile = [](const std::string& graph_string) {
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index");
    ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index");
  }
}
547 | |
548 | TEST_F(Kernel, CatInputTypesPromotion) { |
549 | { |
550 | // Test that we properly promote input types for aten::cat |
551 | |
552 | const auto graph_string = R"IR( |
553 | graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), |
554 | %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), |
555 | %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): |
556 | %dim : int = prim::Constant[value=1]() |
557 | %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) |
558 | %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) |
559 | return (%r))IR" ; |
560 | auto graph = std::make_shared<Graph>(); |
561 | parseIR(graph_string, &*graph); |
562 | |
563 | auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
564 | auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
565 | auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); |
566 | auto ref = at::cat({a, b, c}, 1); |
567 | |
568 | TensorExprKernel k(graph); |
569 | std::vector<at::Tensor> inputs = {a, b, c}; |
570 | StmtPtr s = k.getCodeGenStmt(); |
571 | |
572 | std::ostringstream oss; |
573 | oss << *s; |
574 | |
575 | // Check the IR we produced |
576 | const std::string& verification_pattern = |
577 | R"IR( |
578 | # CHECK: for |
579 | # CHECK-NEXT: for |
580 | # CHECK-NEXT: for |
581 | # CHECK-NEXT: aten_cat)IR" ; |
582 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
583 | |
584 | std::vector<IValue> stack = fmap<IValue>(inputs); |
585 | k.run(stack); |
586 | auto o = stack[0].toTensor(); |
587 | |
588 | // Check sizes |
589 | TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); |
590 | TORCH_CHECK_EQ(o.dtype(), ref.dtype()); |
591 | size_t num_el = 1; |
592 | for (const auto idx : c10::irange(ref.sizes().size())) { |
593 | TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); |
594 | num_el *= ref.sizes()[idx]; |
595 | } |
596 | |
597 | // Check the contents |
598 | for (const auto i : c10::irange(num_el)) { |
599 | TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); |
600 | } |
601 | } |
602 | } |
603 | |
TEST_F(Kernel, ToDType) {
#ifdef TORCH_ENABLE_LLVM
  // A chain of dtype conversions (BFloat16 <-> Float via autocast helpers
  // and aten::to). The whole chain should collapse into a single loop nest
  // with one aten_to store; correctness is checked against an eager
  // sigmoid + _to_copy reference.
  const auto graph_string = R"IR(
      graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
        %1 : NoneType = prim::Constant()
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=6]()
        %4 : int = prim::Constant[value=15]()
        %5 : int = prim::Constant[value=5]()
        %6 : bool = prim::Constant[value=1]()
        %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
        %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
        %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
        %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
        %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
        %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
        return (%k.3))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Expect a single fused 2-level loop containing one aten_to store.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_to
# CHECK-NEXT: }
# CHECK-NEXT: })IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
  auto ref =
      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));

  std::vector<at::Tensor> inputs = {a};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  // Loose tolerance: BFloat16 round-trips lose precision vs the reference.
  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
#endif
}
651 | |
652 | TEST_F(Kernel, CatAndInlineWithAConstantDim) { |
653 | const auto graph_string = R"IR( |
654 | graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), |
655 | %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): |
656 | %2 : bool = prim::Constant[value=0]() |
657 | %3 : int = prim::Constant[value=1]() |
658 | %4 : Tensor[] = prim::ListConstruct(%0, %1) |
659 | %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) |
660 | %6 : Tensor[] = prim::ListConstruct(%5) |
661 | %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) |
662 | %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) |
663 | return (%8, %7))IR" ; |
664 | |
665 | auto graph = std::make_shared<Graph>(); |
666 | parseIR(graph_string, &*graph); |
667 | TensorExprKernel k(graph); |
668 | |
669 | auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); |
670 | auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); |
671 | auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); |
672 | |
673 | std::vector<at::Tensor> inputs = {a, b}; |
674 | std::vector<IValue> stack = fmap<IValue>(inputs); |
675 | k.run(stack); |
676 | auto o = stack[0].toTensor(); |
677 | ASSERT_EQ(o.sizes(), ref.sizes()); |
678 | ASSERT_EQ(o.dtype(), ref.dtype()); |
679 | ASSERT_TRUE(at::allclose(o, ref)); |
680 | } |
681 | |
682 | TEST_F(Kernel, CatWithEmptyInputs) { |
683 | bool curr_cat_wo_conditionals = getCatWoConditionals(); |
684 | for (auto cat_wo_conditionals : {true, false}) { |
685 | getCatWoConditionals() = cat_wo_conditionals; |
686 | const auto graph_string = R"IR( |
687 | graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), |
688 | %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): |
689 | %3 : int = prim::Constant[value=0]() |
690 | %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) |
691 | %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) |
692 | %10 : Tensor[] = prim::ListConstruct(%6, %7) |
693 | %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) |
694 | return (%11))IR" ; |
695 | |
696 | auto graph = std::make_shared<Graph>(); |
697 | parseIR(graph_string, &*graph); |
698 | TensorExprKernel k(graph); |
699 | |
700 | auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); |
701 | auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); |
702 | auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); |
703 | |
704 | std::vector<at::Tensor> inputs = {a, b}; |
705 | std::vector<IValue> stack = fmap<IValue>(inputs); |
706 | k.run(stack); |
707 | auto o = stack[0].toTensor(); |
708 | ASSERT_EQ(o.sizes(), ref.sizes()); |
709 | ASSERT_EQ(o.dtype(), ref.dtype()); |
710 | ASSERT_TRUE(at::allclose(o, ref)); |
711 | } |
712 | getCatWoConditionals() = curr_cat_wo_conditionals; |
713 | } |
714 | |
715 | TEST_F(Kernel, CatWoConditionals) { |
716 | bool old_cat_wo_conditionals = getCatWoConditionals(); |
717 | getCatWoConditionals() = true; |
718 | const auto graph_string = R"IR( |
719 | graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), |
720 | %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), |
721 | %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): |
722 | %dim : int = prim::Constant[value=1]() |
723 | %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) |
724 | %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) |
725 | return (%r))IR" ; |
726 | |
727 | auto graph = std::make_shared<Graph>(); |
728 | parseIR(graph_string, &*graph); |
729 | |
730 | TensorExprKernel k(graph); |
731 | StmtPtr s = k.getCodeGenStmt(); |
732 | std::ostringstream oss; |
733 | oss << *s; |
734 | |
735 | const std::string& verification_pattern = |
736 | R"IR( |
737 | # CHECK: for |
738 | # CHECK: for |
739 | # CHECK: for |
740 | # CHECK: aten_cat |
741 | # CHECK: for |
742 | # CHECK: for |
743 | # CHECK: aten_cat |
744 | # CHECK: for |
745 | # CHECK: for |
746 | # CHECK: aten_cat)IR" ; |
747 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
748 | |
749 | auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
750 | auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
751 | auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); |
752 | auto ref = at::cat({a, b, c}, 1); |
753 | |
754 | std::vector<at::Tensor> inputs = {a, b, c}; |
755 | std::vector<IValue> stack = fmap<IValue>(inputs); |
756 | k.run(stack); |
757 | auto o = stack[0].toTensor(); |
758 | |
759 | // Check sizes |
760 | TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); |
761 | TORCH_CHECK_EQ(o.dtype(), ref.dtype()); |
762 | size_t num_el = 1; |
763 | for (const auto idx : c10::irange(ref.sizes().size())) { |
764 | TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); |
765 | num_el *= ref.sizes()[idx]; |
766 | } |
767 | |
768 | // Check the contents |
769 | for (const auto i : c10::irange(num_el)) { |
770 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
771 | } |
772 | getCatWoConditionals() = old_cat_wo_conditionals; |
773 | } |
774 | |
TEST_F(Kernel, OptimizeConditionals) {
  // Lower aten::cat + aten::relu with conditional-based cat lowering
  // (getCatWoConditionals() == false) and conditional optimization turned on
  // (getOptConditionals() == true). Save the global flags up front so they
  // can be restored at the end regardless of their defaults.
  bool old_cat_wo_conditionals = getCatWoConditionals();
  bool old_opt_conditionals = getOptConditionals();
  getCatWoConditionals() = false;
  getOptConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, strides=[3, 1], device=cpu),
            %b : Float(5, 7, strides=[7, 1], device=cpu),
            %c : Float(5, 9, strides=[9, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim)
        %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r)
        return (%t))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // The optimized IR should compute relu once per cat-input region (three
  // aten_relu loop nests) and must not allocate or free any intermediate
  // buffer for the cat result.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK-NOT: Allocate
# CHECK-NOT: Free)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::relu(at::cat({a, b, c}, 1));

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  // Restore the global lowering flags for subsequent tests.
  getOptConditionals() = old_opt_conditionals;
  getCatWoConditionals() = old_cat_wo_conditionals;
}
840 | |
841 | namespace { |
842 | |
843 | std::string dtypeConstant(ScalarType scalar_type) { |
844 | if (scalar_type == ScalarType::Undefined) { |
845 | return "None = prim::Constant()" ; |
846 | } else { |
847 | at::jit::TemplateEnv env_dtype; |
848 | env_dtype.d("scalar_type" , static_cast<int>(scalar_type)); |
849 | return format("int = prim::Constant[value=${scalar_type}]()" , env_dtype); |
850 | } |
851 | } |
852 | |
853 | at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { |
854 | int64_t numel = std::accumulate( |
855 | sizes.begin(), |
856 | sizes.end(), |
857 | 1, |
858 | // NOLINTNEXTLINE(modernize-use-transparent-functors) |
859 | std::multiplies<int64_t>()); |
860 | std::vector<float> values(numel); |
861 | std::iota(values.begin(), values.end(), 0); |
862 | auto a = at::tensor(values, options); |
863 | return a.reshape(sizes); |
864 | } |
865 | |
866 | } // namespace |
867 | |
868 | TEST_F(Kernel, SumAllAxes) { |
869 | // Test lowering of sum on all axes. |
870 | const auto graph_template = R"IR( |
871 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): |
872 | %1 : ${dtype} |
873 | %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1) |
874 | return (%2))IR" ; |
875 | auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
876 | |
877 | for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { |
878 | at::jit::TemplateEnv env; |
879 | env.s("dtype" , dtypeConstant(scalar_type)); |
880 | if (scalar_type == ScalarType::Undefined) { |
881 | env.s("out_dtype" , "Float" ); |
882 | } else { |
883 | env.s("out_dtype" , "Double" ); |
884 | } |
885 | const auto graph_string = format(graph_template, env); |
886 | |
887 | auto graph = std::make_shared<Graph>(); |
888 | parseIR(graph_string, &*graph); |
889 | |
890 | auto o = at::empty({}, TensorOptions(kCPU)); |
891 | c10::optional<c10::ScalarType> dtype; |
892 | if (scalar_type != ScalarType::Undefined) { |
893 | dtype = static_cast<c10::ScalarType>(scalar_type); |
894 | } |
895 | auto ref = a.sum(/*dtype=*/dtype); |
896 | TensorExprKernel k(graph); |
897 | std::vector<at::Tensor> inputs = {a}; |
898 | StmtPtr s = k.getCodeGenStmt(); |
899 | |
900 | std::ostringstream oss; |
901 | oss << *s; |
902 | |
903 | // Check the IR we produced |
904 | const std::string& verification_pattern = |
905 | R"IR( |
906 | # CHECK: for |
907 | # CHECK-NEXT: for)IR" ; |
908 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
909 | |
910 | std::vector<IValue> stack = fmap<IValue>(inputs); |
911 | k.run(stack); |
912 | o = stack[0].toTensor(); |
913 | ASSERT_EQ(o.sizes(), ref.sizes()); |
914 | ASSERT_EQ(o.dtype(), ref.dtype()); |
915 | ASSERT_TRUE(at::allclose(o, ref)); |
916 | } |
917 | } |
918 | |
919 | std::string li_to_str(at::ArrayRef<int64_t> li) { |
920 | std::stringstream out; |
921 | bool first = true; |
922 | for (auto elem : li) { |
923 | if (!first) { |
924 | out << ", " ; |
925 | } |
926 | out << elem; |
927 | first = false; |
928 | } |
929 | return out.str(); |
930 | } |
931 | |
932 | TEST_F(Kernel, SumOneAxis) { |
933 | // Test lowering of sum on one axis. |
934 | const auto graph_template = R"IR( |
935 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): |
936 | %1 : int[] = prim::Constant[value=[${dim}]]() |
937 | %2 : bool = prim::Constant[value=${keepdim}]() |
938 | %3 : ${dtype} |
939 | %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) |
940 | return (%4))IR" ; |
941 | auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
942 | |
943 | for (int dim = -a.dim(); dim < a.dim(); ++dim) { |
944 | for (bool keepdim : {false, true}) { |
945 | for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { |
946 | at::jit::TemplateEnv env; |
947 | env.d("dim" , dim); |
948 | env.d("keepdim" , keepdim); |
949 | env.s("dtype" , dtypeConstant(scalar_type)); |
950 | c10::optional<c10::ScalarType> dtype; |
951 | if (scalar_type != ScalarType::Undefined) { |
952 | dtype = static_cast<c10::ScalarType>(scalar_type); |
953 | } |
954 | auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); |
955 | if (scalar_type == ScalarType::Undefined) { |
956 | env.s("out_dtype" , "Float" ); |
957 | } else { |
958 | env.s("out_dtype" , "Double" ); |
959 | } |
960 | env.s("size" , li_to_str(ref.sizes())); |
961 | env.s("strides" , li_to_str(ref.strides())); |
962 | const auto graph_string = format(graph_template, env); |
963 | auto graph = std::make_shared<Graph>(); |
964 | parseIR(graph_string, &*graph); |
965 | |
966 | auto o = at::empty({}, TensorOptions(kCPU)); |
967 | TensorExprKernel k(graph); |
968 | std::vector<at::Tensor> inputs = {a}; |
969 | StmtPtr s = k.getCodeGenStmt(); |
970 | |
971 | std::ostringstream oss; |
972 | oss << *s; |
973 | |
974 | // Check the IR we produced |
975 | const std::string& verification_pattern = |
976 | R"IR( |
977 | # CHECK: for (int64_t |
978 | # CHECK-NEXT: sum |
979 | # CHECK-NEXT: for (int64_t |
980 | # CHECK-NEXT: sum)IR" ; |
981 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
982 | |
983 | std::vector<IValue> stack = fmap<IValue>(inputs); |
984 | k.run(stack); |
985 | o = stack[0].toTensor(); |
986 | ASSERT_EQ(o.sizes(), ref.sizes()); |
987 | ASSERT_EQ(o.dtype(), ref.dtype()); |
988 | ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); |
989 | } |
990 | } |
991 | } |
992 | } |
993 | |
994 | TEST_F(Kernel, SumMultipleAxes) { |
995 | // Test lowering of sum on multiple axes. |
996 | const auto graph_template = R"IR( |
997 | graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)): |
998 | %1 : int = prim::Constant[value=${dim1}]() |
999 | %2 : int = prim::Constant[value=${dim2}]() |
1000 | %3 : int[] = prim::ListConstruct(%1, %2) |
1001 | %4 : bool = prim::Constant[value=${keepdim}]() |
1002 | %5 : ${dtype} |
1003 | %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5) |
1004 | return (%6))IR" ; |
1005 | auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1006 | |
1007 | // Only iterate over positive values of axes to keep the running time |
1008 | // reasonable, since the number of pairs is quadratic. |
1009 | for (const auto dim1 : c10::irange(a.dim())) { |
1010 | for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { |
1011 | for (bool keepdim : {false, true}) { |
1012 | at::jit::TemplateEnv env; |
1013 | env.d("dim1" , dim1); |
1014 | env.d("dim2" , dim2); |
1015 | env.d("keepdim" , keepdim); |
1016 | env.s("dtype" , dtypeConstant(ScalarType::Undefined)); |
1017 | auto o = at::empty({}, TensorOptions(kCPU)); |
1018 | auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); |
1019 | |
1020 | env.s("size" , li_to_str(ref.sizes())); |
1021 | env.s("strides" , li_to_str(ref.strides())); |
1022 | |
1023 | const auto graph_string = format(graph_template, env); |
1024 | |
1025 | auto graph = std::make_shared<Graph>(); |
1026 | parseIR(graph_string, &*graph); |
1027 | |
1028 | TensorExprKernel k(graph); |
1029 | std::vector<at::Tensor> inputs = {a}; |
1030 | StmtPtr s = k.getCodeGenStmt(); |
1031 | |
1032 | std::ostringstream oss; |
1033 | oss << *s; |
1034 | |
1035 | // Check the IR we produced |
1036 | const std::string& verification_pattern = |
1037 | R"IR( |
1038 | # CHECK: for (int64_t |
1039 | # CHECK: for (int64_t |
1040 | # CHECK: for (int64_t |
1041 | # CHECK: for (int64_t |
1042 | # CHECK: sum)IR" ; |
1043 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1044 | |
1045 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1046 | k.run(stack); |
1047 | o = stack[0].toTensor(); |
1048 | ASSERT_EQ(o.sizes(), ref.sizes()); |
1049 | ASSERT_EQ(o.dtype(), ref.dtype()); |
1050 | ASSERT_TRUE(at::allclose(o, ref)); |
1051 | } |
1052 | } |
1053 | } |
1054 | } |
1055 | |
1056 | // This test and the following ones testing Softmax only tests with dim set |
1057 | // to one of the valid input dimensions. It does not test with dim=None |
1058 | // because that is supposed to be deprecated. |
TEST_F(Kernel, Softmax2D) {
  // Lower softmax/log_softmax on a 2-d input, sweeping the softmax dim and
  // both dtype-argument forms (explicit Float constant vs None), and compare
  // the kernel output against the eager reference.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %dt_float : int = prim::Constant[value=7]()
        %dt_none : NoneType = prim::Constant()
        %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
        return (%4))IR";

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 5
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
# CHECK-NEXT: aten_softmax)IR";

  for (bool empty_dtype : {false, true}) {
    for (auto log_softmax : {false, true}) {
      for (const auto softmax_dim : c10::irange(a.dim())) {
        auto softmax_dim_size = a.sizes()[softmax_dim];
        auto other_dim = (softmax_dim + 1) % a.dim();
        auto ref =
            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
        at::jit::TemplateEnv env;
        env.d("dim", softmax_dim);
        env.s("op", log_softmax ? "log_softmax" : "softmax");
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        env.s("dt", empty_dtype ? "dt_none" : "dt_float");

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        at::jit::TemplateEnv ver_env;
        ver_env.d("other_dim", other_dim);
        ver_env.d("other_dim_size", a.sizes()[other_dim]);
        ver_env.d("softmax_dim", softmax_dim);
        ver_env.d("softmax_dim_size", softmax_dim_size);
        const auto verification_pattern =
            format(verification_template, ver_env);

        // verification string temporarily disabled until
        // inlining of exp() is benchmarked and determined
        // torch::jit::testing::FileCheck().run(verification_pattern,
        // oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        auto output = stack[0].toTensor();
        ASSERT_EQ(output.sizes(), ref.sizes());
        ASSERT_TRUE(at::allclose(output, ref));
      }
    }
  }
}
1130 | |
TEST_F(Kernel, Softmax3D) {
  // Lower softmax/log_softmax on a 3-d input, sweeping the softmax dim, and
  // compare the kernel output against the eager reference.
  const auto graph_template = R"IR(
      graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
# CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
# CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 3
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4
# CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
# CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      // Collect the two non-softmax dims for the verification template.
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // verification string temporarily disabled until
      // inlining of exp() is benchmarked and determined
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();

      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}
1208 | |
TEST_F(Kernel, Softmax4D) {
  // Lower softmax/log_softmax on a 4-d input, sweeping the softmax dim, and
  // compare the kernel output against the eager reference.
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
# CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
# CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
# CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
# CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 2
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
# CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2
# CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
# CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      // Collect the three non-softmax dims for the verification template.
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("dim3", other_dims[2]);
      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // verification string temporarily disabled until
      // inlining of exp() is benchmarked and determined
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();
      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}
1290 | |
1291 | TEST_F(Kernel, SignTest) { |
1292 | const auto graph_template = R"IR( |
1293 | graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): |
1294 | %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) |
1295 | return (%2))IR" ; |
1296 | |
1297 | auto run_test = [](const std::string& graph_string, const at::Tensor& input) { |
1298 | auto graph = std::make_shared<Graph>(); |
1299 | parseIR(graph_string, &*graph); |
1300 | |
1301 | TensorExprKernel k(graph); |
1302 | StmtPtr s = k.getCodeGenStmt(); |
1303 | |
1304 | std::vector<at::Tensor> inputs = {input}; |
1305 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1306 | k.run(stack); |
1307 | auto o = stack[0].toTensor(); |
1308 | auto ref = at::sign(input); |
1309 | ASSERT_TRUE(at::allclose(o, ref)); |
1310 | }; |
1311 | auto common_options = at::TensorOptions() |
1312 | .layout(at::kStrided) |
1313 | .device(at::kCPU) |
1314 | .requires_grad(false); |
1315 | int default_input_size = 100; |
1316 | for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { |
1317 | at::Tensor corner_case_inputs; |
1318 | at::jit::TemplateEnv env; |
1319 | auto options = common_options; |
1320 | switch (scalar_type) { |
1321 | case ScalarType::Float: { |
1322 | env.s("dtype" , "Float" ); |
1323 | options = options.dtype(at::kFloat); |
1324 | std::vector<float> input_float = { |
1325 | 0.0f, |
1326 | -0.0f, |
1327 | std::numeric_limits<float>::infinity(), |
1328 | -std::numeric_limits<float>::infinity(), |
1329 | std::nanf("1" ), |
1330 | -std::nanf("1" )}; |
1331 | corner_case_inputs = at::from_blob( |
1332 | input_float.data(), |
1333 | {static_cast<long>(input_float.size())}, |
1334 | options); |
1335 | auto rand_input = at::rand({default_input_size}, options); |
1336 | auto input = at::cat({rand_input, corner_case_inputs}); |
1337 | env.d("size" , at::numel(input)); |
1338 | const auto graph_string = format(graph_template, env); |
1339 | run_test(graph_string, input); |
1340 | break; |
1341 | } |
1342 | case ScalarType::Double: { |
1343 | env.s("dtype" , "Double" ); |
1344 | options = options.dtype(at::kDouble); |
1345 | std::vector<double> input_double = { |
1346 | 0.0, |
1347 | -0.0, |
1348 | std::numeric_limits<double>::infinity(), |
1349 | -std::numeric_limits<double>::infinity(), |
1350 | std::nan("1" ), |
1351 | -std::nan("1" )}; |
1352 | corner_case_inputs = at::from_blob( |
1353 | input_double.data(), |
1354 | {static_cast<long>(input_double.size())}, |
1355 | options); |
1356 | auto rand_input = at::rand({default_input_size}, options); |
1357 | auto input = at::cat({rand_input, corner_case_inputs}); |
1358 | env.d("size" , at::numel(input)); |
1359 | const auto graph_string = format(graph_template, env); |
1360 | run_test(graph_string, input); |
1361 | break; |
1362 | } |
1363 | default: |
1364 | throw unsupported_dtype(); |
1365 | } |
1366 | } |
1367 | } |
1368 | |
1369 | TEST_F(Kernel, InlineProducerIntoReduction) { |
1370 | // Inline producer (mul) into reduction (sum). |
1371 | const auto graph_string = R"IR( |
1372 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
1373 | %1 : Float(5, 3, strides=[3, 1], device=cpu)): |
1374 | %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) |
1375 | %3 : int = prim::Constant[value=7]() |
1376 | %4 : Double(device=cpu) = aten::sum(%2, %3) |
1377 | return (%4))IR" ; |
1378 | auto graph = std::make_shared<Graph>(); |
1379 | parseIR(graph_string, &*graph); |
1380 | |
1381 | TensorExprKernel k(graph); |
1382 | StmtPtr s = k.getCodeGenStmt(); |
1383 | std::ostringstream oss; |
1384 | oss << *s; |
1385 | |
1386 | // Check the IR we produced. |
1387 | // We should have only one loop in the end. |
1388 | const std::string& verification_pattern = |
1389 | R"IR( |
1390 | # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 |
1391 | # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 |
1392 | # CHECK-NEXT: sum |
1393 | # CHECK-NOT: for)IR" ; |
1394 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1395 | |
1396 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1397 | auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1398 | std::vector<at::Tensor> inputs = {a, b}; |
1399 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1400 | k.run(stack); |
1401 | auto o = stack[0].toTensor(); |
1402 | auto ref = (a * b).sum(at::kDouble); |
1403 | ASSERT_TRUE(at::allclose(o, ref)); |
1404 | } |
1405 | |
1406 | TEST_F(Kernel, InlineReductionIntoConsumer) { |
1407 | // Inline producer (mul %2) into reduction (sum %4) but DO NOT |
1408 | // inline the reduction into consumer (mul %4). |
1409 | const auto graph_string = R"IR( |
1410 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
1411 | %1 : Float(5, 3, strides=[3, 1], device=cpu)): |
1412 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
1413 | %3 : int = prim::Constant[value=6]() |
1414 | %4 : Float(device=cpu) = aten::sum(%2, %3) |
1415 | %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4) |
1416 | return (%5))IR" ; |
1417 | auto graph = std::make_shared<Graph>(); |
1418 | parseIR(graph_string, &*graph); |
1419 | |
1420 | TensorExprKernel k(graph); |
1421 | StmtPtr s = k.getCodeGenStmt(); |
1422 | std::ostringstream oss; |
1423 | oss << *s; |
1424 | |
1425 | // Check the IR we produced. |
1426 | // We should have two loops in the end. |
1427 | const std::string& verification_pattern = |
1428 | R"IR( |
1429 | # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 |
1430 | # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 |
1431 | # CHECK-NEXT: sum |
1432 | # CHECK: for (int64_t i_2 = 0ll; i_2 < 5 |
1433 | # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3 |
1434 | # CHECK-NEXT: aten_mul |
1435 | # CHECK-NOT: for)IR" ; |
1436 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1437 | |
1438 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1439 | auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1440 | std::vector<at::Tensor> inputs = {a, b}; |
1441 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1442 | k.run(stack); |
1443 | auto o = stack[0].toTensor(); |
1444 | auto ref = (a * b).sum(at::kFloat) * (a * b); |
1445 | ASSERT_TRUE(at::allclose(o, ref)); |
1446 | } |
1447 | |
1448 | TEST_F(Kernel, SanitizeNames_CUDA) { |
1449 | const auto graph_string = R"IR( |
1450 | graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0), |
1451 | %1 : Float(5, 3, strides=[3, 1], device=cuda:0)): |
1452 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
1453 | %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
1454 | return (%4))IR" ; |
1455 | auto graph = std::make_shared<Graph>(); |
1456 | parseIR(graph_string, &*graph); |
1457 | graph->inputs().at(0)->setDebugName("aten::add:" ); |
1458 | graph->inputs().at(1)->setDebugName("aten::add_" ); |
1459 | TensorExprKernel k(graph); |
1460 | auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); |
1461 | auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); |
1462 | auto ref = a * (a * b); |
1463 | std::vector<at::Tensor> inputs = {a, b}; |
1464 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1465 | k.run(stack); |
1466 | auto o = stack[0].toTensor(); |
1467 | ASSERT_TRUE(at::allclose(o, ref)); |
1468 | } |
1469 | |
TEST_F(Kernel, SanitizeConstants_CUDA) {
  // A constant node whose debug name contains an illegal character ('.')
  // must be sanitized by TensorExprKernel, and the kernel must still
  // produce correct results.
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
          %none : NoneType = prim::Constant()
          %size : int = prim::Constant[value=16]()
          %sizes : int[] = prim::ListConstruct(%size, %size)
          %30 : Device = prim::Constant[value="cuda"]()
          %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
          %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we insert a call to
  // aten::ones and then const-prop it
  ConstantPropagation(graph);

  // We set the name of the constant to include special characters that are
  // not allowed. This should be fixed by the sanitizer in TensorExprKernel.
  graph->nodes().front()->output()->setDebugName("illegal.name");

  // Check if we have a constant node with illegal name in the graph.
  auto const_node = graph->nodes().front();
  ASSERT_EQ(const_node->kind(), prim::Constant);
  ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  // The const-propagated constant is a 16x16 ones tensor, so the expected
  // output is simply x * ones.
  auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}
1506 | |
1507 | TEST_F(Kernel, ConstantTensors) { |
1508 | const auto graph_string = R"IR( |
1509 | graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): |
1510 | %none : NoneType = prim::Constant() |
1511 | %size : int = prim::Constant[value=16]() |
1512 | %sizes : int[] = prim::ListConstruct(%size, %size) |
1513 | %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) |
1514 | %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) |
1515 | return (%z))IR" ; |
1516 | auto graph = std::make_shared<Graph>(); |
1517 | parseIR(graph_string, &*graph); |
1518 | // IRParser doesn't support tensor constants, so we insert a call to |
1519 | // aten::ones and then const-prop it |
1520 | ConstantPropagation(graph); |
1521 | |
1522 | TensorExprKernel k(graph); |
1523 | |
1524 | auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); |
1525 | std::vector<at::Tensor> inputs = {x}; |
1526 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1527 | k.run(stack); |
1528 | auto o = stack[0].toTensor(); |
1529 | auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); |
1530 | auto ref = x * y; |
1531 | ASSERT_TRUE(at::allclose(o, ref)); |
1532 | } |
1533 | |
TEST_F(Kernel, ConstantTensorsNonContiguous) {
  // Same as ConstantTensors, but the constant is made non-contiguous via
  // aten::t, exercising constant handling for strided layouts.
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %dtype : int = prim::Constant[value=6]()
          %c0 : int = prim::Constant[value=0]()
          %c256 : int = prim::Constant[value=256]()
          %c16 : int = prim::Constant[value=16]()
          %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
          %sizes : int[] = prim::ListConstruct(%c16, %c16)
          %y_t : Tensor = aten::view(%y_flat, %sizes)
          %y : Tensor = aten::t(%y_t)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce non-contiguous constant tensor and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  // Recreate the const-propagated constant eagerly: arange(0..255) reshaped
  // to 16x16 and transposed (non-contiguous).
  auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat))
               .view({16, 16})
               .t();
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}
1567 | |
1568 | TEST_F(Kernel, RunFast) { |
1569 | #ifdef TORCH_ENABLE_LLVM |
1570 | // TODO: Implement call_raw in IREval and remove the ifdef |
1571 | |
1572 | const auto graph_string = R"IR( |
1573 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
1574 | %1 : Float(5, 3, strides=[1, 5], device=cpu)): |
1575 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
1576 | %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
1577 | return (%3))IR" ; |
1578 | auto graph = std::make_shared<Graph>(); |
1579 | parseIR(graph_string, &*graph); |
1580 | |
1581 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1582 | auto b = |
1583 | at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); |
1584 | auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1585 | auto ref = a * (a * b); |
1586 | TensorExprKernel k(graph); |
1587 | |
1588 | k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); |
1589 | for (size_t i = 0; i < 5 * 3; i++) { |
1590 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
1591 | } |
1592 | #endif |
1593 | } |
1594 | |
1595 | TEST_F(Kernel, RunWithAllocatedOutputs) { |
1596 | #ifdef TORCH_ENABLE_LLVM |
1597 | const auto graph_string = R"IR( |
1598 | graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), |
1599 | %1 : Float(5, 3, strides=[1, 5], device=cpu)): |
1600 | %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) |
1601 | %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) |
1602 | return (%3))IR" ; |
1603 | auto graph = std::make_shared<Graph>(); |
1604 | parseIR(graph_string, &*graph); |
1605 | |
1606 | auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1607 | auto b = |
1608 | at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); |
1609 | auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1610 | auto ref = a * (a * b); |
1611 | TensorExprKernel k(graph); |
1612 | |
1613 | std::vector<at::Tensor> args = {o, a, b}; |
1614 | std::vector<IValue> stack = fmap<IValue>(args); |
1615 | k.runWithAllocatedOutputs(stack); |
1616 | for (size_t i = 0; i < 5 * 3; i++) { |
1617 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
1618 | } |
1619 | #endif |
1620 | } |
1621 | |
TEST_F(Kernel, CodegenInspection) {
#ifdef TORCH_ENABLE_LLVM
  // Verifies the codegen introspection APIs: retrieving the generated
  // assembly text and querying buffer/constant descriptors of the kernel.
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %dtype : int = prim::Constant[value=6]()
          %c0 : int = prim::Constant[value=0]()
          %c256 : int = prim::Constant[value=256]()
          %c16 : int = prim::Constant[value=16]()
          %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
          %sizes : int[] = prim::ListConstruct(%c16, %c16)
          %y_t : Tensor = aten::view(%y_flat, %sizes)
          %y : Tensor = aten::t(%y_t)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce a non-contiguous constant tensor and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  // Check that we could retrieve generated assembly
  auto asm_str = k.getCodeText("asm");
  // Minimal sanity check on the asm dump: a .text section and a return.
  const std::string& asm_verification_pattern =
      R"ASM(
        # CHECK: .text
        # CHECK: retq)ASM";
  torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str);

  // Check that we could retrieve info about codegen parameters
  auto constants = k.getConstantDescriptors();
  auto buf_args = k.getBufferArgs();
  // Expected buf args: [input0, output0, constant0]
  ASSERT_EQ(buf_args.size(), 3);
  ASSERT_EQ(constants.size(), 1);
  // All three args are buffers, not scalar vars.
  ASSERT_TRUE(
      !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar());
#endif
}
1663 | |
1664 | Tensor lowerNanToNum( |
1665 | const std::vector<ArgValue>& inputs, |
1666 | const std::vector<ExprHandle>& outputShape, |
1667 | const std::vector<ExprHandle>& outputStrides, |
1668 | const c10::optional<ScalarType>& outputType, |
1669 | at::Device device) { |
1670 | auto input_buf = c10::get<BufHandle>(inputs[0]); |
1671 | auto e = Compute( |
1672 | "custom_nan_to_num" , |
1673 | outputShape, |
1674 | outputStrides, |
1675 | [&](const std::vector<VarHandle>& axes) { |
1676 | std::vector<ExprHandle> indices(axes.begin(), axes.end()); |
1677 | auto load = input_buf.load(indices); |
1678 | return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); |
1679 | }); |
1680 | return e; |
1681 | } |
1682 | |
1683 | TEST_F(Kernel, CustomLowering) { |
1684 | const auto graph_string = R"IR( |
1685 | graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): |
1686 | %none : NoneType = prim::Constant() |
1687 | %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) |
1688 | return (%y) |
1689 | )IR" ; |
1690 | auto graph = std::make_shared<Graph>(); |
1691 | parseIR(graph_string, &*graph); |
1692 | |
1693 | std::unordered_map<c10::Symbol, NNCLoweringFunction> lowerings = { |
1694 | {aten::nan_to_num, lowerNanToNum}}; |
1695 | TensorExprKernel k(graph, lowerings); |
1696 | |
1697 | auto stmt = k.getCodeGenStmt(); |
1698 | std::ostringstream oss; |
1699 | oss << *stmt; |
1700 | |
1701 | // Check that our custom lowering is actually used |
1702 | torch::jit::testing::FileCheck().check("custom_nan_to_num" )->run(oss.str()); |
1703 | torch::jit::testing::FileCheck().check("isnan" )->run(oss.str()); |
1704 | } |
1705 | |
1706 | TEST_F(Kernel, Vectorize) { |
1707 | #ifdef TORCH_ENABLE_LLVM |
1708 | const auto graph_string = R"IR( |
1709 | graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), |
1710 | %1 : Float(100, 16, strides=[16, 1], device=cpu)): |
1711 | %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) |
1712 | %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) |
1713 | return (%3))IR" ; |
1714 | auto graph = std::make_shared<Graph>(); |
1715 | parseIR(graph_string, &*graph); |
1716 | |
1717 | auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); |
1718 | auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); |
1719 | auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); |
1720 | auto ref = a * (a * b); |
1721 | TensorExprKernel k(graph); |
1722 | std::vector<at::Tensor> inputs = {a, b}; |
1723 | StmtPtr s = k.getCodeGenStmt(); |
1724 | |
1725 | std::ostringstream oss; |
1726 | oss << *s; |
1727 | |
1728 | // Check the IR we produced |
1729 | const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR" ; |
1730 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1731 | |
1732 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1733 | k.run(stack); |
1734 | o = stack[0].toTensor(); |
1735 | for (size_t i = 0; i < 100 * 16; i++) { |
1736 | TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); |
1737 | } |
1738 | #endif |
1739 | } |
1740 | |
1741 | // TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. |
TEST_F(Kernel, DISABLED_FlattenVectorize) {
#ifdef TORCH_ENABLE_LLVM
  // Disabled: the 100x3 inner dimension is too small to vectorize on its own;
  // the loop nest must be flattened first (see TODO above). Once flattening
  // is implemented, this expects a Ramp node just like the Vectorize test.
  const auto graph_string = R"IR(
      graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
            %1 : Float(100, 3, strides=[3, 1], device=cpu)):
        %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // Also verify numerical correctness element by element.
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 100 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}
1776 | |
1777 | TEST_F(Kernel, Strided1dWithinBounds) { |
1778 | auto ir = R"IR( |
1779 | graph(%0 : Float(3, strides=[1], device=cpu), |
1780 | %1 : Float(3, strides=[2], device=cpu)): |
1781 | %2 : int = prim::Constant[value=1]() |
1782 | %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) |
1783 | return (%3))IR" ; |
1784 | auto graph = std::make_shared<Graph>(); |
1785 | std::unordered_map<std::string, Value*> vmap; |
1786 | parseIR(ir, graph.get(), vmap); |
1787 | TensorExprKernel k(graph); |
1788 | |
1789 | auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1790 | auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) |
1791 | .index({Slice(None, None, 2)}); |
1792 | auto expect = a + b; |
1793 | |
1794 | std::vector<at::Tensor> inputs = {a, b}; |
1795 | |
1796 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1797 | k.run(stack); |
1798 | |
1799 | auto output = stack[0].toTensor(); |
1800 | |
1801 | for (size_t i = 0; i < 3; ++i) { |
1802 | TORCH_CHECK_EQ( |
1803 | ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); |
1804 | } |
1805 | } |
1806 | |
1807 | TEST_F(Kernel, InputAsOutput) { |
1808 | const auto graph_string = R"IR( |
1809 | graph(%x : Float(5, 3, strides=[3, 1], device=cpu), |
1810 | %y : Float(5, 3, strides=[1, 5], device=cpu)): |
1811 | return (%x, %y))IR" ; |
1812 | auto graph = std::make_shared<Graph>(); |
1813 | parseIR(graph_string, &*graph); |
1814 | |
1815 | auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); |
1816 | auto y = |
1817 | at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); |
1818 | TensorExprKernel k(graph); |
1819 | std::vector<at::Tensor> inputs = {x, y}; |
1820 | |
1821 | std::vector<IValue> stack = fmap<IValue>(inputs); |
1822 | k.run(stack); |
1823 | CHECK(at::allclose(x, stack[0].toTensor())); |
1824 | CHECK(at::allclose(y, stack[1].toTensor())); |
1825 | } |
1826 | |
TEST_F(Kernel, ScalarOut) {
  // Kernel whose outputs are plain int scalars rather than tensors.
  // Output order follows the graph's return: (%r, %z).
  auto ir = R"IR(
graph(%x : int, %y : int):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  return (%r, %z))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Verify the generated IR. We expect to see a scalar variable (Let) followed
  // by a store to a 0-dim buffer.
  const std::string& verification_pattern = R"IR(
# CHECK: int64_t
# CHECK-NEXT: [0ll] =
# CHECK-NEXT: int64_t
# CHECK-NEXT: [0ll] =
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  int64_t x = 2, y = 3, r = 0, z = 0;

  // Verify that TEK::runFast works correctly with scalar outputs
  // (raw pointers to the int64_t locals stand in for output buffers).
  std::vector<void*> inputs = {&x, &y};
  std::vector<void*> outputs = {&r, &z};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);

  // Verify that TEK::run works correctly with scalar outputs
  // (outputs replace the inputs on the IValue stack).
  std::vector<IValue> stack = {x, y};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  TORCH_CHECK_EQ(stack[1], x * y);
}
1867 | |
TEST_F(Kernel, ScalarTensorOut) {
  // Kernel mixing scalar (int) and tensor inputs/outputs in one graph.
  // Output order follows the graph's return: (%r, %rt, %z, %zt).
  auto ir = R"IR(
graph(%x : int,
      %xt : Long(3, strides=[1], device=cpu),
      %y : int,
      %yt : Long(3, strides=[1], device=cpu)):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y)
  %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt)
  return (%r, %rt, %z, %zt))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);
  int64_t x = 2, y = 3, r = 0, z = 0;
  auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2;
  auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3;
  auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
  auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));

  // Verify that TEK::runFast works correctly with mixed scalar and tensor
  // inputs/outputs. Scalars are passed as raw pointers; tensors as data_ptrs.
  std::vector<void*> inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()};
  std::vector<void*> outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);
  ASSERT_TRUE(at::equal(zt, xt * yt));
  ASSERT_TRUE(at::equal(rt, zt * xt));

  // Verify that TEK::run works correctly with mixed scalar and tensor
  // inputs/outputs. Outputs replace the inputs on the IValue stack.
  std::vector<IValue> stack = {x, xt, y, yt};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt));
  TORCH_CHECK_EQ(stack[2], x * y);
  ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt));
}
1908 | |
TEST_F(Kernel, FuseLoopsWithVariableBounds) {
#ifdef TORCH_ENABLE_LLVM
  // aten::cat with symbolic outer/inner dims: SS(-2) and SS(-3) are symbolic
  // shape placeholders whose concrete values arrive as the trailing int
  // inputs (%SS_2, %SS_3). With cat-without-conditionals enabled, the outer
  // loops of the per-input copies should fuse into a single loop over dim 0.
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

  // All inputs and the output are described as contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Expect one fused outer `i` loop containing three j/k nests (one per cat
  // input); CHECK-NOT guards against a second, unfused outer loop.
  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i
        # CHECK-NEXT: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // Runs the kernel with concrete sizes for the symbolic dims and compares
  // against eager at::cat.
  auto run_kernel = [&](int dim1, int dim2) {
    auto a =
        at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20);
  // Restore the global flag so other tests are unaffected.
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}
1981 | |
TEST_F(Kernel, FuseLoopsWithVariableConcatDim) {
#ifdef TORCH_ENABLE_LLVM
  // Like FuseLoopsWithVariableBounds, but the concat dimension itself is
  // symbolic: all three inputs share SS(-4) on dim 1 and the output gets
  // SS(-5). The caller must pass SS_5 == 3 * SS_4 (see run_kernel below).
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5};

  // All inputs and the output are described as contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Expect one fused outer `i` loop containing three j/k nests (one per cat
  // input); CHECK-NOT guards against a second, unfused outer loop.
  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i
        # CHECK-NEXT: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim1, int dim2, int dim3) {
    auto a =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    // SS_5 (the output concat dim) = sum of the three input concat dims.
    stack.emplace_back(3 * dim3);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15);
  // Restore the global flag so other tests are unaffected.
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}
2058 | |
TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) {
#ifdef TORCH_ENABLE_LLVM
  // Negative case for loop fusion: the two cat inputs have *different*
  // symbolic concat dims (SS(-4) vs SS(-5)), so only the outer `i` loops may
  // fuse — the pattern below expects exactly two j/k nests and no third.
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int,
            %SS_6 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5, -6};

  // Both inputs and the output are described as contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i
        # CHECK-NEXT: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK: for (int64_t j
        # CHECK-NEXT: for (int64_t k
        # CHECK-NOT: for (int64_t j
        # CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) {
    auto a =
        at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b}, 1);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(dim4);
    stack.emplace_back(dim5);
    // SS_6 (the output concat dim) = sum of the two input concat dims.
    stack.emplace_back(dim4 + dim5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15, 8);
  // Restore the global flag so other tests are unaffected.
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}
2131 | |
2132 | } // namespace jit |
2133 | } // namespace torch |
2134 | |