#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>

#include <ATen/NativeFunctions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/xnnpack/OpContext.h>

// Standard library used directly by this file: std::function/std::tuple
// (BinaryFloat/UnaryFloat test tables), std::accumulate (the `prod`
// helpers), std::vector buffers, fixed-width ints.
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <tuple>
#include <vector>
29
30namespace torch {
31namespace jit {
32using namespace torch::jit::tensorexpr;
33
34TEST(ExternalCall, Conv1d_float) {
35 BufHandle Input("Input", {1, 100, 115}, kFloat);
36 BufHandle Weight("Weight", {100, 1, 7}, kFloat);
37 BufHandle Bias("Bias", {100}, kFloat);
38 BufHandle ResultBuf("Result", {1, 100, 115}, kFloat);
39 int64_t stride = 1;
40 int64_t pad = 3;
41 int64_t dilation = 1;
42 int64_t groups = 100;
43
44 Tensor Result = Tensor(
45 ResultBuf.node(),
46 ExternalCall::make(
47 ResultBuf,
48 "nnc_aten_conv1d",
49 {Input, Weight, Bias},
50 {stride, pad, dilation, groups}));
51 LoopNest l({Result});
52 l.prepareForCodegen();
53 l.simplify();
54
55 auto options = at::TensorOptions()
56 .dtype(at::kFloat)
57 .layout(at::kStrided)
58 .device(at::kCPU)
59 .requires_grad(false);
60 at::Tensor input = at::ones({1, 100, 115}, options) * 5.f;
61 at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f;
62 at::Tensor bias = at::ones({100}, options) * 11.f;
63 at::Tensor ref =
64 at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);
65
66 at::Tensor nnc_result;
67 std::vector<float> input_buf(1 * 100 * 115, 5.f);
68 std::vector<float> weight_buf(100 * 1 * 7, 6.f);
69 std::vector<float> bias_buf(100, 11.f);
70 std::vector<float> result_buf(1 * 100 * 115, -1.f);
71
72#ifdef TORCH_ENABLE_LLVM
73 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});
74
75 llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
76 nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
77 ASSERT_TRUE(at::allclose(nnc_result, ref));
78#endif
79
80 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});
81
82 ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
83 nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
84 ASSERT_TRUE(at::allclose(nnc_result, ref));
85}
86
87TEST(ExternalCall, Conv1d_int) {
88 // A similar test, but now using kInt tensors
89 BufHandle Input("Input", {1, 100, 115}, kInt);
90 BufHandle Weight("Weight", {100, 1, 7}, kInt);
91 BufHandle Bias("Bias", {100}, kInt);
92 BufHandle ResultBuf("Result", {1, 100, 115}, kInt);
93 int64_t stride = 1;
94 int64_t pad = 3;
95 int64_t dilation = 1;
96 int64_t groups = 100;
97
98 Tensor Result = Tensor(
99 ResultBuf.node(),
100 ExternalCall::make(
101 ResultBuf,
102 "nnc_aten_conv1d",
103 {Input, Weight, Bias},
104 {stride, pad, dilation, groups}));
105 LoopNest l({Result});
106 l.prepareForCodegen();
107 l.simplify();
108
109 auto options = at::TensorOptions()
110 .dtype(at::kInt)
111 .layout(at::kStrided)
112 .device(at::kCPU)
113 .requires_grad(false);
114 at::Tensor input = at::ones({1, 100, 115}, options) * 5;
115 at::Tensor weight = at::ones({100, 1, 7}, options) * 6;
116 at::Tensor bias = at::ones({100}, options) * 11;
117 at::Tensor ref =
118 at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);
119
120 at::Tensor nnc_result;
121 std::vector<int32_t> input_buf(1 * 100 * 115, 5);
122 std::vector<int32_t> weight_buf(100 * 1 * 7, 6);
123 std::vector<int32_t> bias_buf(100, 11);
124 std::vector<int32_t> result_buf(1 * 100 * 115, -1);
125
126#ifdef TORCH_ENABLE_LLVM
127 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});
128
129 llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
130 nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
131 ASSERT_TRUE(at::allclose(nnc_result, ref));
132#endif
133
134 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});
135
136 ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
137 nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
138 ASSERT_TRUE(at::allclose(nnc_result, ref));
139}
140
141TEST(ExternalCall, Conv1d_nobias_noargs) {
142 BufHandle Input("Input", {1, 1, 115}, kFloat);
143 BufHandle Weight("Weight", {10, 1, 7}, kFloat);
144 BufHandle ResultBuf("Result", {1, 10, 109}, kFloat);
145
146 Tensor Result = Tensor(
147 ResultBuf.node(),
148 ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {}));
149 LoopNest l({Result});
150 l.prepareForCodegen();
151 l.simplify();
152
153 auto options = at::TensorOptions()
154 .dtype(at::kFloat)
155 .layout(at::kStrided)
156 .device(at::kCPU)
157 .requires_grad(false);
158 at::Tensor input = at::ones({1, 1, 115}, options) * 5.f;
159 at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f;
160 at::Tensor ref = at::conv1d(input, weight);
161
162 at::Tensor nnc_result;
163 std::vector<float> input_buf(1 * 1 * 115, 5.f);
164 std::vector<float> weight_buf(10 * 1 * 7, 6.f);
165 std::vector<float> result_buf(1 * 10 * 109, -1.f);
166
167#ifdef TORCH_ENABLE_LLVM
168 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});
169
170 llvm_codegen.call({input_buf, weight_buf, result_buf});
171 nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
172 ASSERT_TRUE(at::allclose(nnc_result, ref));
173#endif
174
175 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});
176
177 ir_eval.call({input_buf, weight_buf, result_buf});
178 nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
179 ASSERT_TRUE(at::allclose(nnc_result, ref));
180}
181
182TEST(ExternalCall, Conv2d_float) {
183 BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
184 BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat);
185 BufHandle Bias("Bias", {16}, kFloat);
186 BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
187 int64_t stride = 2;
188 int64_t pad = 1;
189 int64_t dilation = 1;
190 int64_t groups = 1;
191
192 Tensor Result = Tensor(
193 ResultBuf.node(),
194 ExternalCall::make(
195 ResultBuf,
196 "nnc_aten_conv2d",
197 {Input, Weight, Bias},
198 {stride, stride, pad, pad, dilation, dilation, groups}));
199 LoopNest l({Result});
200 l.prepareForCodegen();
201 l.simplify();
202
203 auto options = at::TensorOptions()
204 .dtype(at::kFloat)
205 .layout(at::kStrided)
206 .device(at::kCPU)
207 .requires_grad(false);
208 at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f;
209 at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f;
210 at::Tensor bias = at::ones({16}, options) * 11.f;
211 at::Tensor ref = at::conv2d(
212 input,
213 weight,
214 bias,
215 {stride, stride},
216 {pad, pad},
217 {dilation, dilation},
218 groups);
219
220 at::Tensor nnc_result;
221 std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f);
222 std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f);
223 std::vector<float> bias_buf(16, 11.f);
224 std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);
225
226#ifdef TORCH_ENABLE_LLVM
227 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});
228
229 llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
230 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
231 ASSERT_TRUE(at::allclose(nnc_result, ref));
232#endif
233
234 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});
235
236 ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
237 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
238 ASSERT_TRUE(at::allclose(nnc_result, ref));
239}
240
241TEST(ExternalCall, Conv2d_int) {
242 // A similar test, but now using kInt tensors
243
244 BufHandle Input("Input", {1, 3, 224, 224}, kInt);
245 BufHandle Weight("Weight", {16, 3, 3, 3}, kInt);
246 BufHandle Bias("Bias", {16}, kInt);
247 BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt);
248 int64_t stride = 2;
249 int64_t pad = 1;
250 int64_t dilation = 1;
251 int64_t groups = 1;
252
253 Tensor Result = Tensor(
254 ResultBuf.node(),
255 ExternalCall::make(
256 ResultBuf,
257 "nnc_aten_conv2d",
258 {Input, Weight, Bias},
259 {stride, stride, pad, pad, dilation, dilation, groups}));
260 LoopNest l({Result});
261 l.prepareForCodegen();
262 l.simplify();
263
264 auto options = at::TensorOptions()
265 .dtype(at::kInt)
266 .layout(at::kStrided)
267 .device(at::kCPU)
268 .requires_grad(false);
269 at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5;
270 at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6;
271 at::Tensor bias = at::ones({16}, options) * 11;
272 at::Tensor ref = at::conv2d(
273 input,
274 weight,
275 bias,
276 {stride, stride},
277 {pad, pad},
278 {dilation, dilation},
279 groups);
280
281 at::Tensor nnc_result;
282 std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5);
283 std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6);
284 std::vector<int32_t> bias_buf(16, 11);
285 std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1);
286
287#ifdef TORCH_ENABLE_LLVM
288 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});
289
290 llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
291 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
292 ASSERT_TRUE(at::allclose(nnc_result, ref));
293#endif
294
295 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});
296
297 ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
298 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
299 ASSERT_TRUE(at::allclose(nnc_result, ref));
300}
301
302TEST(ExternalCall, Conv2d_nobias_noargs) {
303 BufHandle Input("Input", {1, 16, 112, 112}, kFloat);
304 BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat);
305 BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
306
307 Tensor Result = Tensor(
308 ResultBuf.node(),
309 ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {}));
310 LoopNest l({Result});
311 l.prepareForCodegen();
312 l.simplify();
313
314 auto options = at::TensorOptions()
315 .dtype(at::kFloat)
316 .layout(at::kStrided)
317 .device(at::kCPU)
318 .requires_grad(false);
319 at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f;
320 at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
321 at::Tensor ref = at::conv2d(input, weight);
322
323 at::Tensor nnc_result;
324 std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f);
325 std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
326 std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);
327
328#ifdef TORCH_ENABLE_LLVM
329 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});
330
331 llvm_codegen.call({input_buf, weight_buf, result_buf});
332 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
333 ASSERT_TRUE(at::allclose(nnc_result, ref));
334#endif
335
336 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});
337
338 ir_eval.call({input_buf, weight_buf, result_buf});
339 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
340 ASSERT_TRUE(at::allclose(nnc_result, ref));
341}
342
343TEST(ExternalCall, Addmm_float) {
344 BufHandle Input("Input", {100, 300}, kFloat);
345 BufHandle Mat1("Mat1", {100, 200}, kFloat);
346 BufHandle Mat2("Mat2", {200, 300}, kFloat);
347 BufHandle ResultBuf("Result", {100, 300}, kFloat);
348 int64_t beta = 2;
349 int64_t alpha = 2;
350
351 Tensor Result = Tensor(
352 ResultBuf.node(),
353 ExternalCall::make(
354 ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha}));
355 LoopNest l({Result});
356 l.prepareForCodegen();
357 l.simplify();
358
359 auto options = at::TensorOptions()
360 .dtype(at::kFloat)
361 .layout(at::kStrided)
362 .device(at::kCPU)
363 .requires_grad(false);
364 at::Tensor input = at::ones({100, 300}, options) * 5.f;
365 at::Tensor mat1 = at::ones({100, 200}, options) * 6.f;
366 at::Tensor mat2 = at::ones({200, 300}, options) * 11.f;
367 at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha);
368
369 at::Tensor nnc_result;
370 std::vector<float> input_buf(100 * 300, 5.f);
371 std::vector<float> mat1_buf(100 * 200, 6.f);
372 std::vector<float> mat2_buf(200 * 300, 11.f);
373 std::vector<float> result_buf(100 * 300, -1.f);
374
375#ifdef TORCH_ENABLE_LLVM
376 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result});
377
378 llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf});
379 nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
380 ASSERT_TRUE(at::allclose(nnc_result, ref));
381#endif
382
383 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result});
384
385 ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf});
386 nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
387 ASSERT_TRUE(at::allclose(nnc_result, ref));
388}
389
390TEST(ExternalCall, Embedding) {
391 BufHandle Weight("Weight", {256, 100}, kFloat);
392 BufHandle Indices("Indices", {1, 115}, kLong);
393 BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
394 int64_t padding_idx = -1;
395 bool scale_grad_by_freq = false;
396 bool sparse = false;
397
398 Tensor Result = Tensor(
399 ResultBuf.node(),
400 ExternalCall::make(
401 ResultBuf,
402 "nnc_aten_embedding",
403 {Weight, Indices},
404 {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
405 LoopNest l({Result});
406 l.prepareForCodegen();
407 l.simplify();
408
409 auto options = at::TensorOptions()
410 .layout(at::kStrided)
411 .device(at::kCPU)
412 .requires_grad(false);
413
414 at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
415 at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
416 at::Tensor ref =
417 at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
418
419 at::Tensor nnc_result;
420 std::vector<float> weight_buf(256 * 100, 5.f);
421 std::vector<int64_t> indices_buf(1 * 115, 6);
422 std::vector<float> result_buf(1 * 115 * 100, -1.f);
423
424#ifdef TORCH_ENABLE_LLVM
425 LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});
426
427 llvm_codegen.call({weight_buf, indices_buf, result_buf});
428 nnc_result = at::from_blob(
429 result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
430 ASSERT_TRUE(at::allclose(nnc_result, ref));
431#endif
432
433 SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});
434
435 ir_eval.call({weight_buf, indices_buf, result_buf});
436 nnc_result = at::from_blob(
437 result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
438 ASSERT_TRUE(at::allclose(nnc_result, ref));
439}
440
441TEST(ExternalCall, MaxReduction) {
442 BufHandle Input("Input", {1, 115, 152}, kFloat);
443 BufHandle ResultBuf("Result", {1, 152}, kFloat);
444 int64_t dim = 1;
445 bool keep_dim = false;
446
447 Tensor Result = Tensor(
448 ResultBuf.node(),
449 ExternalCall::make(
450 ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim}));
451 LoopNest l({Result});
452 l.prepareForCodegen();
453 l.simplify();
454
455 auto options = at::TensorOptions()
456 .dtype(at::kFloat)
457 .layout(at::kStrided)
458 .device(at::kCPU)
459 .requires_grad(false);
460
461 at::Tensor input = at::ones({1, 115, 152}, options) * 5.f;
462 at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim));
463
464 at::Tensor nnc_result;
465 std::vector<float> input_buf(1 * 115 * 152, 5.f);
466 std::vector<float> result_buf(1 * 152, -1.f);
467
468#ifdef TORCH_ENABLE_LLVM
469 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result});
470
471 llvm_codegen.call({input_buf, result_buf});
472 nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
473 ASSERT_TRUE(at::allclose(nnc_result, ref));
474#endif
475
476 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result});
477
478 ir_eval.call({input_buf, result_buf});
479 nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
480 ASSERT_TRUE(at::allclose(nnc_result, ref));
481}
482
483#ifdef USE_XNNPACK
484
485TEST(ExternalCall, Prepacked_Linear_float) {
486 using namespace at::native::xnnpack;
487
488 BufHandle Input("Input", {100, 200}, kFloat);
489 BufHandle ResultBuf("Result", {100, 300}, kFloat);
490
491 // Calculate reference result using at::linear.
492 auto options = at::TensorOptions()
493 .dtype(at::kFloat)
494 .layout(at::kStrided)
495 .device(at::kCPU)
496 .requires_grad(false);
497 at::Tensor input =
498 at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
499 at::Tensor weight =
500 at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
501 at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
502 at::Tensor ref = at::linear(input, weight, bias);
503
504 // Create prepacked xnnpack context object.
505 auto linear_clamp_prepack_op =
506 c10::Dispatcher::singleton()
507 .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
508 .typed<c10::intrusive_ptr<LinearOpContext>(
509 at::Tensor,
510 c10::optional<at::Tensor>,
511 const c10::optional<at::Scalar>&,
512 const c10::optional<at::Scalar>&)>();
513 auto prepacked = linear_clamp_prepack_op.call(
514 weight, bias, c10::optional<at::Scalar>(), c10::optional<at::Scalar>());
515
516 BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
517 Tensor Result = Tensor(
518 ResultBuf.node(),
519 ExternalCall::make(
520 ResultBuf,
521 "nnc_prepacked_linear_clamp_run",
522 {Input, DummyPrepacked},
523 {}));
524 LoopNest l({Result});
525 l.prepareForCodegen();
526 l.simplify();
527
528 at::Tensor nnc_result;
529 std::vector<float> input_buf(
530 input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
531 std::vector<float> result_buf(100 * 300, -1.f);
532
533#ifdef TORCH_ENABLE_LLVM
534 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});
535
536 llvm_codegen.call({input_buf, prepacked.get(), result_buf});
537 nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
538 ASSERT_TRUE(at::allclose(nnc_result, ref));
539#endif
540
541 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});
542
543 ir_eval.call({input_buf, prepacked.get(), result_buf});
544 nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
545 ASSERT_TRUE(at::allclose(nnc_result, ref));
546}
547
548TEST(ExternalCall, Prepacked_Conv2d_float) {
549 using namespace at::native::xnnpack;
550
551 BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
552 BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
553 int64_t stride = 2;
554 int64_t pad = 1;
555 int64_t dilation = 1;
556 int64_t groups = 1;
557
558 // Calculate reference result using at::conv2d.
559 auto options = at::TensorOptions()
560 .dtype(at::kFloat)
561 .layout(at::kStrided)
562 .device(at::kCPU)
563 .requires_grad(false);
564 at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
565 .resize_({1, 3, 224, 224});
566 at::Tensor weight =
567 at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
568 at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
569 at::Tensor ref = at::conv2d(
570 input,
571 weight,
572 bias,
573 {stride, stride},
574 {pad, pad},
575 {dilation, dilation},
576 groups);
577
578 // Create prepacked xnnpack context object.
579 auto conv2d_clamp_prepack_op =
580 c10::Dispatcher::singleton()
581 .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "")
582 .typed<c10::intrusive_ptr<Conv2dOpContext>(
583 at::Tensor,
584 c10::optional<at::Tensor>,
585 std::vector<int64_t>,
586 std::vector<int64_t>,
587 std::vector<int64_t>,
588 int64_t,
589 const c10::optional<at::Scalar>&,
590 const c10::optional<at::Scalar>&)>();
591 auto prepacked = conv2d_clamp_prepack_op.call(
592 weight,
593 bias,
594 {stride, stride},
595 {pad, pad},
596 {dilation, dilation},
597 groups,
598 c10::optional<at::Scalar>(),
599 c10::optional<at::Scalar>());
600
601 BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
602 Tensor Result = Tensor(
603 ResultBuf.node(),
604 ExternalCall::make(
605 ResultBuf,
606 "nnc_prepacked_conv2d_clamp_run",
607 {Input, DummyPrepacked},
608 {}));
609 LoopNest l({Result});
610 l.prepareForCodegen();
611 l.simplify();
612
613 at::Tensor nnc_result;
614 std::vector<float> input_buf(
615 input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
616 std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);
617
618#ifdef TORCH_ENABLE_LLVM
619 LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});
620
621 llvm_codegen.call({input_buf, prepacked.get(), result_buf});
622 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
623 ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
624#endif
625
626 SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});
627
628 ir_eval.call({input_buf, prepacked.get(), result_buf});
629 nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
630 ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
631}
632
633#endif // USE_XNNPACK
634
635TEST(ExternalCall, BinaryFloat) {
636 using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
637 using Test = std::tuple<
638 std::vector<int64_t>,
639 std::vector<int64_t>,
640 std::vector<int64_t>,
641 TensorFunc,
642 std::string>;
643 std::vector<Test> tests = {};
644 tests.push_back(
645 Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
646 tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
647 tests.push_back(
648 Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});
649 for (auto curTest : tests) {
650 std::vector<int64_t> aShape, bShape, resShape;
651 TensorFunc torchFunc;
652 std::string externCallName;
653 std::tie(aShape, bShape, resShape, torchFunc, externCallName) = curTest;
654 auto toExprHandleVec = [](std::vector<int64_t> v) {
655 auto intV = std::vector<int>(v.begin(), v.end());
656 return std::vector<ExprHandle>(intV.begin(), intV.end());
657 };
658 BufHandle A("A", toExprHandleVec(aShape), kFloat);
659 BufHandle B("B", toExprHandleVec(bShape), kFloat);
660 BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
661
662 Tensor Result = Tensor(
663 ResultBuf.node(),
664 ExternalCall::make(ResultBuf, externCallName, {A, B}, {}));
665 LoopNest l({Result});
666 l.prepareForCodegen();
667 l.simplify();
668
669 auto options = at::TensorOptions()
670 .dtype(at::kFloat)
671 .layout(at::kStrided)
672 .device(at::kCPU)
673 .requires_grad(false);
674 at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
675 at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
676 at::Tensor ref = torchFunc(a, b);
677
678 auto prod = [](std::vector<int64_t> v) {
679 // NOLINTNEXTLINE(modernize-use-transparent-functors)
680 return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
681 };
682
683 at::Tensor nnc_result;
684 std::vector<float> a_buf(prod(aShape), 5.f);
685 std::vector<float> b_buf(prod(bShape), 6.f);
686 std::vector<float> result_buf(prod(resShape), -1.f);
687
688#ifdef TORCH_ENABLE_LLVM
689 LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});
690
691 llvm_codegen.call({a_buf, b_buf, result_buf});
692 nnc_result =
693 at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
694 ASSERT_TRUE(at::allclose(nnc_result, ref));
695#endif
696
697 SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
698 ir_eval.call({a_buf, b_buf, result_buf});
699 nnc_result =
700 at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
701 ASSERT_TRUE(at::allclose(nnc_result, ref));
702 }
703}
704
705TEST(ExternalCall, UnaryFloat) {
706 using TensorFunc = std::function<at::Tensor(at::Tensor)>;
707 auto toExprHandleVec = [](std::vector<int64_t> v) {
708 auto intV = std::vector<int>(v.begin(), v.end());
709 return std::vector<ExprHandle>(intV.begin(), intV.end());
710 };
711 using Test = std::tuple<
712 std::vector<int64_t>,
713 std::vector<int64_t>,
714 TensorFunc,
715 std::string,
716 std::vector<ExprHandle>>;
717 std::vector<Test> tests = {};
718 tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
719 {1, 64, 8, 9},
720 {1, 64, 5, 7},
721 [](at::Tensor x) {
722 return at::adaptive_avg_pool2d(x, {5, 7});
723 },
724 "nnc_aten_adaptive_avg_pool2d",
725 toExprHandleVec({5, 7})});
726 tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
727 {100, 200},
728 {100},
729 [](at::Tensor x) { return at::mean(x, {1}); },
730 "nnc_aten_mean",
731 toExprHandleVec({1, /*keepdim=*/0})});
732 for (auto curTest : tests) {
733 std::vector<int64_t> aShape, resShape;
734 TensorFunc torchFunc;
735 std::string externCallName;
736 std::vector<ExprHandle> externCallArgs;
737 std::tie(aShape, resShape, torchFunc, externCallName, externCallArgs) =
738 curTest;
739 BufHandle A("A", toExprHandleVec(aShape), kFloat);
740 BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
741
742 Tensor Result = Tensor(
743 ResultBuf.node(),
744 ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
745 LoopNest l({Result});
746 l.prepareForCodegen();
747 l.simplify();
748
749 auto options = at::TensorOptions()
750 .dtype(at::kFloat)
751 .layout(at::kStrided)
752 .device(at::kCPU)
753 .requires_grad(false);
754 at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
755 at::Tensor ref = torchFunc(a);
756
757 auto prod = [](std::vector<int64_t> v) {
758 // NOLINTNEXTLINE(modernize-use-transparent-functors)
759 return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
760 };
761
762 at::Tensor nnc_result;
763 std::vector<float> a_buf(prod(aShape), 5.f);
764 std::vector<float> result_buf(prod(resShape), -1.f);
765
766#ifdef TORCH_ENABLE_LLVM
767 LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});
768
769 llvm_codegen.call({a_buf, result_buf});
770 nnc_result =
771 at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
772 ASSERT_TRUE(at::allclose(nnc_result, ref));
773#endif
774
775 SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
776 ir_eval.call({a_buf, result_buf});
777 nnc_result =
778 at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
779 ASSERT_TRUE(at::allclose(nnc_result, ref));
780 }
781}
782
783TEST(ExternalCall, ComputeInterop) {
784 // This test verifies that Tensors using external calls can be used by and can
785 // use Tensors built with Compute API.
786
787 BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat);
788 BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat);
789
790 Tensor Input = Compute(
791 "Input",
792 {1, 16, 32, 32},
793 [&](const VarHandle& n,
794 const VarHandle& c,
795 const VarHandle& h,
796 const VarHandle& w) { return FloatImm::make(5.0f); });
797 Tensor Weight = Compute(
798 "Weight",
799 {16, 16, 1, 1},
800 [&](const VarHandle& n,
801 const VarHandle& c,
802 const VarHandle& h,
803 const VarHandle& w) { return FloatImm::make(6.0f); });
804
805 Tensor ConvResult = Tensor(
806 ConvResultBuf.node(),
807 ExternalCall::make(
808 ConvResultBuf,
809 "nnc_aten_conv2d",
810 {BufHandle(Input.buf()), BufHandle(Weight.buf())},
811 {}));
812 Tensor MatmulResult = Tensor(
813 MatmulResultBuf.node(),
814 ExternalCall::make(
815 MatmulResultBuf,
816 "nnc_aten_matmul",
817 {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
818 {}));
819 Tensor Result = Compute(
820 "Result",
821 {1, 16, 32, 32},
822 [&](const VarHandle& n,
823 const VarHandle& c,
824 const VarHandle& h,
825 const VarHandle& w) {
826 return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
827 });
828
829 LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});
830
831 // Inlining should not inline anything here since all Bufs are either defined
832 // or used in ExternalCalls - we run it just for testing
833 l.inlineIntermediateBufs(true);
834
835 l.prepareForCodegen();
836 l.simplify();
837
838 auto options = at::TensorOptions()
839 .dtype(at::kFloat)
840 .layout(at::kStrided)
841 .device(at::kCPU)
842 .requires_grad(false);
843 at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
844 at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
845 at::Tensor t = at::conv2d(input, weight);
846 at::Tensor t2 = at::matmul(t, t);
847 at::Tensor ref = t + t2;
848
849 at::Tensor nnc_result;
850 std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
851 std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
852 std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
853 std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
854 std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);
855
856#ifdef TORCH_ENABLE_LLVM
857 LLVMCodeGen llvm_codegen(
858 l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});
859
860 llvm_codegen.call(
861 {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
862 nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
863 ASSERT_TRUE(at::allclose(nnc_result, ref));
864#endif
865
866 SimpleIREvaluator ir_eval(
867 l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});
868
869 ir_eval.call(
870 {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
871 nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
872 ASSERT_TRUE(at::allclose(nnc_result, ref));
873}
874
875TEST(ExternalCall, Inlining) {
876 // This test verifies that Tensors using external calls can be used by and
877 // can use Tensors built with Compute API.
878
879 BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);
880
881 Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
882 return FloatImm::make(5.0f);
883 });
884 Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
885 return FloatImm::make(4.0f);
886 });
887 Tensor MatmulResult = Tensor(
888 MatmulResultBuf.node(),
889 ExternalCall::make(
890 MatmulResultBuf,
891 "nnc_aten_matmul",
892 {BufHandle(A.buf()), BufHandle(B.buf())},
893 {}));
894 Tensor Result =
895 Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
896 return MatmulResult.load(i, j) + FloatImm::make(3.0f);
897 });
898
899 StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>(
900 {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
901 LoopNest l(root_stmt, {Result.buf()});
902
903 // Inlining should not inline anything here since all Bufs are either
904 // defined or used in ExternalCalls
905 l.inlineIntermediateBufs(false);
906
907 l.prepareForCodegen();
908 l.simplify();
909
910 auto options = at::TensorOptions()
911 .dtype(at::kFloat)
912 .layout(at::kStrided)
913 .device(at::kCPU)
914 .requires_grad(false);
915 at::Tensor a = at::ones({8, 8}, options) * 5.f;
916 at::Tensor b = at::ones({8, 8}, options) * 4.f;
917 at::Tensor t = at::matmul(a, b);
918 at::Tensor ref = t + 3.f;
919
920 at::Tensor nnc_result;
921 std::vector<float> result_buf(8 * 8);
922
923#ifdef TORCH_ENABLE_LLVM
924 LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});
925
926 llvm_codegen.call({result_buf});
927 nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
928 ASSERT_TRUE(at::allclose(nnc_result, ref));
929#endif
930
931 SimpleIREvaluator ir_eval(l.root_stmt(), {Result});
932
933 ir_eval.call({result_buf});
934 nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
935 ASSERT_TRUE(at::allclose(nnc_result, ref));
936}
937
938TEST(ExternalCall, JitCustomFusionOp) {
939 const char* custom_op_schema_literal =
940 "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";
941 const char* external_func_name = "nnc_add_mul";
942
943 auto add_mul_lowering_func =
944 [external_func_name](
945 const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
946 const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
947 const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
948 const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
949 at::Device device) {
950 auto output_dtype = Dtype(*output_type);
951 torch::jit::tensorexpr::BufHandle result_buf(
952 "nnc_add_mul_res_buf", output_shape, output_dtype);
953 const torch::jit::tensorexpr::BufHandle& a =
954 c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
955 const torch::jit::tensorexpr::BufHandle& b =
956 c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
957 const torch::jit::tensorexpr::BufHandle& c =
958 c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
959 torch::jit::tensorexpr::StmtPtr s =
960 torch::jit::tensorexpr::ExternalCall::make(
961 result_buf, external_func_name, {a, b, c}, {});
962 return Tensor(result_buf.node(), s);
963 };
964
965 auto add_mul_external_func = [](int64_t bufs_num,
966 void** buf_data,
967 int64_t* buf_ranks,
968 int64_t* buf_dims,
969 int64_t* buf_strides,
970 int8_t* buf_dtypes,
971 int64_t args_num,
972 int64_t* extra_args) {};
973
974 torch::jit::RegisterOperators reg({Operator(
975 custom_op_schema_literal,
976 [](const Node* node) -> Operation {
977 return [](Stack& _stack) {
978 auto a = std::move(peek(_stack, 0, 3)).toTensor();
979 auto b = std::move(peek(_stack, 1, 3)).toTensor();
980 auto c = std::move(peek(_stack, 2, 3)).toTensor();
981 drop(_stack, 3);
982 auto result = (a + b) * c;
983 pack(_stack, std::move(result));
984 return 0;
985 };
986 },
987 c10::AliasAnalysisKind::FROM_SCHEMA)});
988
989 auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();
990 custom_operator_set.insert({custom_op_schema_literal});
991
992 auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();
993 te_lowering_registry.insert(
994 parseSchema(custom_op_schema_literal), add_mul_lowering_func);
995
996 auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
997 te_nnc_func_registry[external_func_name] = add_mul_external_func;
998
999 std::string graph_string = R"IR(
1000 graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
1001 %b : Float(10, 20, strides=[20, 1], device=cpu),
1002 %c : Float(10, 20, strides=[20, 1], device=cpu)):
1003 %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
1004 return (%res))IR";
1005
1006 auto graph = std::make_shared<Graph>();
1007 torch::jit::parseIR(graph_string, graph.get());
1008
1009 std::string shape_compute_python_string = R"PY(
1010 def computOutput(a: List[int], b: List[int], c: List[int]):
1011 expandedSizes: List[int] = []
1012 dimsA = len(a)
1013 dimsB = len(b)
1014 dimsC = len(c)
1015 ndim = max(dimsA, dimsB, dimsC)
1016 for i in range(ndim):
1017 offset = ndim - 1 - i
1018 dimA = dimsA - 1 - offset
1019 dimB = dimsB - 1 - offset
1020 dimC = dimsC - 1 - offset
1021 sizeA = a[dimA] if (dimA >= 0) else 1
1022 sizeB = b[dimB] if (dimB >= 0) else 1
1023 sizeC = a[dimC] if (dimC >= 0) else 1
1024
1025 if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1:
1026 # TODO: only assertion error is bound in C++ compilation right now
1027 raise AssertionError(
1028 "The size of tensor a {} must match the size of tensor b ("
1029 "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i)
1030 )
1031
1032 expandedSizes.append(max(sizeA, sizeB, sizeC))
1033
1034 return expandedSizes
1035 )PY";
1036 auto cu_ptr = torch::jit::compile(shape_compute_python_string);
1037 torch::jit::GraphFunction* gf =
1038 (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput");
1039 ASSERT_TRUE(gf);
1040
1041#ifdef TORCH_ENABLE_LLVM
1042 auto static_graph_case = graph->copy();
1043 FuseTensorExprs(static_graph_case, 1);
1044 torch::jit::testing::FileCheck()
1045 .check("prim::TensorExprGroup_")
1046 ->check("nnc_custom::add_mul")
1047 ->run(*static_graph_case);
1048
1049 auto dynamic_graph_case = graph->copy();
1050 auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);
1051 ASSERT_TRUE(custom_op);
1052 torch::jit::RegisterShapeComputeGraphForSchema(
1053 custom_op->schema(), gf->graph());
1054 FuseTensorExprs(dynamic_graph_case, 1, false, true);
1055 torch::jit::testing::FileCheck()
1056 .check("prim::TensorExprGroup_")
1057 ->check("nnc_custom::add_mul")
1058 ->run(*dynamic_graph_case);
1059#else
1060 torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);
1061#endif
1062}
1063
1064} // namespace jit
1065} // namespace torch
1066