1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/tensorexpr/test_base.h> |
4 | |
5 | #include <torch/csrc/jit/ir/irparser.h> |
6 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
7 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
8 | #include <torch/csrc/jit/runtime/custom_operator.h> |
9 | #include <torch/csrc/jit/tensorexpr/kernel.h> |
10 | |
11 | #include <test/cpp/tensorexpr/test_utils.h> |
12 | #include <torch/csrc/jit/runtime/operator.h> |
13 | #include <torch/csrc/jit/runtime/symbolic_shape_registry.h> |
14 | #include <torch/csrc/jit/tensorexpr/eval.h> |
15 | #include <torch/csrc/jit/tensorexpr/external_functions_registry.h> |
16 | #include <torch/csrc/jit/tensorexpr/ir.h> |
17 | #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
18 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
19 | #include <torch/csrc/jit/tensorexpr/llvm_codegen.h> |
20 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
21 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
22 | |
23 | #include <torch/csrc/jit/testing/file_check.h> |
24 | #include <torch/jit.h> |
25 | |
26 | #include <ATen/NativeFunctions.h> |
27 | #include <ATen/core/dispatch/Dispatcher.h> |
28 | #include <ATen/native/xnnpack/OpContext.h> |
29 | |
30 | namespace torch { |
31 | namespace jit { |
32 | using namespace torch::jit::tensorexpr; |
33 | |
34 | TEST(ExternalCall, Conv1d_float) { |
35 | BufHandle Input("Input" , {1, 100, 115}, kFloat); |
36 | BufHandle Weight("Weight" , {100, 1, 7}, kFloat); |
37 | BufHandle Bias("Bias" , {100}, kFloat); |
38 | BufHandle ResultBuf("Result" , {1, 100, 115}, kFloat); |
39 | int64_t stride = 1; |
40 | int64_t pad = 3; |
41 | int64_t dilation = 1; |
42 | int64_t groups = 100; |
43 | |
44 | Tensor Result = Tensor( |
45 | ResultBuf.node(), |
46 | ExternalCall::make( |
47 | ResultBuf, |
48 | "nnc_aten_conv1d" , |
49 | {Input, Weight, Bias}, |
50 | {stride, pad, dilation, groups})); |
51 | LoopNest l({Result}); |
52 | l.prepareForCodegen(); |
53 | l.simplify(); |
54 | |
55 | auto options = at::TensorOptions() |
56 | .dtype(at::kFloat) |
57 | .layout(at::kStrided) |
58 | .device(at::kCPU) |
59 | .requires_grad(false); |
60 | at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; |
61 | at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; |
62 | at::Tensor bias = at::ones({100}, options) * 11.f; |
63 | at::Tensor ref = |
64 | at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); |
65 | |
66 | at::Tensor nnc_result; |
67 | std::vector<float> input_buf(1 * 100 * 115, 5.f); |
68 | std::vector<float> weight_buf(100 * 1 * 7, 6.f); |
69 | std::vector<float> bias_buf(100, 11.f); |
70 | std::vector<float> result_buf(1 * 100 * 115, -1.f); |
71 | |
72 | #ifdef TORCH_ENABLE_LLVM |
73 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); |
74 | |
75 | llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); |
76 | nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); |
77 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
78 | #endif |
79 | |
80 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); |
81 | |
82 | ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); |
83 | nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); |
84 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
85 | } |
86 | |
87 | TEST(ExternalCall, Conv1d_int) { |
88 | // A similar test, but now using kInt tensors |
89 | BufHandle Input("Input" , {1, 100, 115}, kInt); |
90 | BufHandle Weight("Weight" , {100, 1, 7}, kInt); |
91 | BufHandle Bias("Bias" , {100}, kInt); |
92 | BufHandle ResultBuf("Result" , {1, 100, 115}, kInt); |
93 | int64_t stride = 1; |
94 | int64_t pad = 3; |
95 | int64_t dilation = 1; |
96 | int64_t groups = 100; |
97 | |
98 | Tensor Result = Tensor( |
99 | ResultBuf.node(), |
100 | ExternalCall::make( |
101 | ResultBuf, |
102 | "nnc_aten_conv1d" , |
103 | {Input, Weight, Bias}, |
104 | {stride, pad, dilation, groups})); |
105 | LoopNest l({Result}); |
106 | l.prepareForCodegen(); |
107 | l.simplify(); |
108 | |
109 | auto options = at::TensorOptions() |
110 | .dtype(at::kInt) |
111 | .layout(at::kStrided) |
112 | .device(at::kCPU) |
113 | .requires_grad(false); |
114 | at::Tensor input = at::ones({1, 100, 115}, options) * 5; |
115 | at::Tensor weight = at::ones({100, 1, 7}, options) * 6; |
116 | at::Tensor bias = at::ones({100}, options) * 11; |
117 | at::Tensor ref = |
118 | at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); |
119 | |
120 | at::Tensor nnc_result; |
121 | std::vector<int32_t> input_buf(1 * 100 * 115, 5); |
122 | std::vector<int32_t> weight_buf(100 * 1 * 7, 6); |
123 | std::vector<int32_t> bias_buf(100, 11); |
124 | std::vector<int32_t> result_buf(1 * 100 * 115, -1); |
125 | |
126 | #ifdef TORCH_ENABLE_LLVM |
127 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); |
128 | |
129 | llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); |
130 | nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); |
131 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
132 | #endif |
133 | |
134 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); |
135 | |
136 | ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); |
137 | nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); |
138 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
139 | } |
140 | |
141 | TEST(ExternalCall, Conv1d_nobias_noargs) { |
142 | BufHandle Input("Input" , {1, 1, 115}, kFloat); |
143 | BufHandle Weight("Weight" , {10, 1, 7}, kFloat); |
144 | BufHandle ResultBuf("Result" , {1, 10, 109}, kFloat); |
145 | |
146 | Tensor Result = Tensor( |
147 | ResultBuf.node(), |
148 | ExternalCall::make(ResultBuf, "nnc_aten_conv1d" , {Input, Weight}, {})); |
149 | LoopNest l({Result}); |
150 | l.prepareForCodegen(); |
151 | l.simplify(); |
152 | |
153 | auto options = at::TensorOptions() |
154 | .dtype(at::kFloat) |
155 | .layout(at::kStrided) |
156 | .device(at::kCPU) |
157 | .requires_grad(false); |
158 | at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; |
159 | at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; |
160 | at::Tensor ref = at::conv1d(input, weight); |
161 | |
162 | at::Tensor nnc_result; |
163 | std::vector<float> input_buf(1 * 1 * 115, 5.f); |
164 | std::vector<float> weight_buf(10 * 1 * 7, 6.f); |
165 | std::vector<float> result_buf(1 * 10 * 109, -1.f); |
166 | |
167 | #ifdef TORCH_ENABLE_LLVM |
168 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); |
169 | |
170 | llvm_codegen.call({input_buf, weight_buf, result_buf}); |
171 | nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); |
172 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
173 | #endif |
174 | |
175 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); |
176 | |
177 | ir_eval.call({input_buf, weight_buf, result_buf}); |
178 | nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); |
179 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
180 | } |
181 | |
182 | TEST(ExternalCall, Conv2d_float) { |
183 | BufHandle Input("Input" , {1, 3, 224, 224}, kFloat); |
184 | BufHandle Weight("Weight" , {16, 3, 3, 3}, kFloat); |
185 | BufHandle Bias("Bias" , {16}, kFloat); |
186 | BufHandle ResultBuf("Result" , {1, 16, 112, 112}, kFloat); |
187 | int64_t stride = 2; |
188 | int64_t pad = 1; |
189 | int64_t dilation = 1; |
190 | int64_t groups = 1; |
191 | |
192 | Tensor Result = Tensor( |
193 | ResultBuf.node(), |
194 | ExternalCall::make( |
195 | ResultBuf, |
196 | "nnc_aten_conv2d" , |
197 | {Input, Weight, Bias}, |
198 | {stride, stride, pad, pad, dilation, dilation, groups})); |
199 | LoopNest l({Result}); |
200 | l.prepareForCodegen(); |
201 | l.simplify(); |
202 | |
203 | auto options = at::TensorOptions() |
204 | .dtype(at::kFloat) |
205 | .layout(at::kStrided) |
206 | .device(at::kCPU) |
207 | .requires_grad(false); |
208 | at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; |
209 | at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; |
210 | at::Tensor bias = at::ones({16}, options) * 11.f; |
211 | at::Tensor ref = at::conv2d( |
212 | input, |
213 | weight, |
214 | bias, |
215 | {stride, stride}, |
216 | {pad, pad}, |
217 | {dilation, dilation}, |
218 | groups); |
219 | |
220 | at::Tensor nnc_result; |
221 | std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f); |
222 | std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f); |
223 | std::vector<float> bias_buf(16, 11.f); |
224 | std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f); |
225 | |
226 | #ifdef TORCH_ENABLE_LLVM |
227 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); |
228 | |
229 | llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); |
230 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
231 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
232 | #endif |
233 | |
234 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); |
235 | |
236 | ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); |
237 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
238 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
239 | } |
240 | |
241 | TEST(ExternalCall, Conv2d_int) { |
242 | // A similar test, but now using kInt tensors |
243 | |
244 | BufHandle Input("Input" , {1, 3, 224, 224}, kInt); |
245 | BufHandle Weight("Weight" , {16, 3, 3, 3}, kInt); |
246 | BufHandle Bias("Bias" , {16}, kInt); |
247 | BufHandle ResultBuf("Result" , {1, 16, 112, 112}, kInt); |
248 | int64_t stride = 2; |
249 | int64_t pad = 1; |
250 | int64_t dilation = 1; |
251 | int64_t groups = 1; |
252 | |
253 | Tensor Result = Tensor( |
254 | ResultBuf.node(), |
255 | ExternalCall::make( |
256 | ResultBuf, |
257 | "nnc_aten_conv2d" , |
258 | {Input, Weight, Bias}, |
259 | {stride, stride, pad, pad, dilation, dilation, groups})); |
260 | LoopNest l({Result}); |
261 | l.prepareForCodegen(); |
262 | l.simplify(); |
263 | |
264 | auto options = at::TensorOptions() |
265 | .dtype(at::kInt) |
266 | .layout(at::kStrided) |
267 | .device(at::kCPU) |
268 | .requires_grad(false); |
269 | at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; |
270 | at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; |
271 | at::Tensor bias = at::ones({16}, options) * 11; |
272 | at::Tensor ref = at::conv2d( |
273 | input, |
274 | weight, |
275 | bias, |
276 | {stride, stride}, |
277 | {pad, pad}, |
278 | {dilation, dilation}, |
279 | groups); |
280 | |
281 | at::Tensor nnc_result; |
282 | std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5); |
283 | std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6); |
284 | std::vector<int32_t> bias_buf(16, 11); |
285 | std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1); |
286 | |
287 | #ifdef TORCH_ENABLE_LLVM |
288 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); |
289 | |
290 | llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); |
291 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
292 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
293 | #endif |
294 | |
295 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); |
296 | |
297 | ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); |
298 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
299 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
300 | } |
301 | |
302 | TEST(ExternalCall, Conv2d_nobias_noargs) { |
303 | BufHandle Input("Input" , {1, 16, 112, 112}, kFloat); |
304 | BufHandle Weight("Weight" , {16, 16, 1, 1}, kFloat); |
305 | BufHandle ResultBuf("Result" , {1, 16, 112, 112}, kFloat); |
306 | |
307 | Tensor Result = Tensor( |
308 | ResultBuf.node(), |
309 | ExternalCall::make(ResultBuf, "nnc_aten_conv2d" , {Input, Weight}, {})); |
310 | LoopNest l({Result}); |
311 | l.prepareForCodegen(); |
312 | l.simplify(); |
313 | |
314 | auto options = at::TensorOptions() |
315 | .dtype(at::kFloat) |
316 | .layout(at::kStrided) |
317 | .device(at::kCPU) |
318 | .requires_grad(false); |
319 | at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; |
320 | at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; |
321 | at::Tensor ref = at::conv2d(input, weight); |
322 | |
323 | at::Tensor nnc_result; |
324 | std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f); |
325 | std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f); |
326 | std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f); |
327 | |
328 | #ifdef TORCH_ENABLE_LLVM |
329 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); |
330 | |
331 | llvm_codegen.call({input_buf, weight_buf, result_buf}); |
332 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
333 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
334 | #endif |
335 | |
336 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); |
337 | |
338 | ir_eval.call({input_buf, weight_buf, result_buf}); |
339 | nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); |
340 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
341 | } |
342 | |
343 | TEST(ExternalCall, Addmm_float) { |
344 | BufHandle Input("Input" , {100, 300}, kFloat); |
345 | BufHandle Mat1("Mat1" , {100, 200}, kFloat); |
346 | BufHandle Mat2("Mat2" , {200, 300}, kFloat); |
347 | BufHandle ResultBuf("Result" , {100, 300}, kFloat); |
348 | int64_t beta = 2; |
349 | int64_t alpha = 2; |
350 | |
351 | Tensor Result = Tensor( |
352 | ResultBuf.node(), |
353 | ExternalCall::make( |
354 | ResultBuf, "nnc_aten_addmm" , {Input, Mat1, Mat2}, {beta, alpha})); |
355 | LoopNest l({Result}); |
356 | l.prepareForCodegen(); |
357 | l.simplify(); |
358 | |
359 | auto options = at::TensorOptions() |
360 | .dtype(at::kFloat) |
361 | .layout(at::kStrided) |
362 | .device(at::kCPU) |
363 | .requires_grad(false); |
364 | at::Tensor input = at::ones({100, 300}, options) * 5.f; |
365 | at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; |
366 | at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; |
367 | at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); |
368 | |
369 | at::Tensor nnc_result; |
370 | std::vector<float> input_buf(100 * 300, 5.f); |
371 | std::vector<float> mat1_buf(100 * 200, 6.f); |
372 | std::vector<float> mat2_buf(200 * 300, 11.f); |
373 | std::vector<float> result_buf(100 * 300, -1.f); |
374 | |
375 | #ifdef TORCH_ENABLE_LLVM |
376 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); |
377 | |
378 | llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); |
379 | nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); |
380 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
381 | #endif |
382 | |
383 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); |
384 | |
385 | ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); |
386 | nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); |
387 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
388 | } |
389 | |
390 | TEST(ExternalCall, Embedding) { |
391 | BufHandle Weight("Weight" , {256, 100}, kFloat); |
392 | BufHandle Indices("Indices" , {1, 115}, kLong); |
393 | BufHandle ResultBuf("Result" , {1, 115, 100}, kFloat); |
394 | int64_t padding_idx = -1; |
395 | bool scale_grad_by_freq = false; |
396 | bool sparse = false; |
397 | |
398 | Tensor Result = Tensor( |
399 | ResultBuf.node(), |
400 | ExternalCall::make( |
401 | ResultBuf, |
402 | "nnc_aten_embedding" , |
403 | {Weight, Indices}, |
404 | {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); |
405 | LoopNest l({Result}); |
406 | l.prepareForCodegen(); |
407 | l.simplify(); |
408 | |
409 | auto options = at::TensorOptions() |
410 | .layout(at::kStrided) |
411 | .device(at::kCPU) |
412 | .requires_grad(false); |
413 | |
414 | at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; |
415 | at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; |
416 | at::Tensor ref = |
417 | at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); |
418 | |
419 | at::Tensor nnc_result; |
420 | std::vector<float> weight_buf(256 * 100, 5.f); |
421 | std::vector<int64_t> indices_buf(1 * 115, 6); |
422 | std::vector<float> result_buf(1 * 115 * 100, -1.f); |
423 | |
424 | #ifdef TORCH_ENABLE_LLVM |
425 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); |
426 | |
427 | llvm_codegen.call({weight_buf, indices_buf, result_buf}); |
428 | nnc_result = at::from_blob( |
429 | result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); |
430 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
431 | #endif |
432 | |
433 | SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); |
434 | |
435 | ir_eval.call({weight_buf, indices_buf, result_buf}); |
436 | nnc_result = at::from_blob( |
437 | result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); |
438 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
439 | } |
440 | |
441 | TEST(ExternalCall, MaxReduction) { |
442 | BufHandle Input("Input" , {1, 115, 152}, kFloat); |
443 | BufHandle ResultBuf("Result" , {1, 152}, kFloat); |
444 | int64_t dim = 1; |
445 | bool keep_dim = false; |
446 | |
447 | Tensor Result = Tensor( |
448 | ResultBuf.node(), |
449 | ExternalCall::make( |
450 | ResultBuf, "nnc_aten_max_red" , {Input}, {dim, (int64_t)keep_dim})); |
451 | LoopNest l({Result}); |
452 | l.prepareForCodegen(); |
453 | l.simplify(); |
454 | |
455 | auto options = at::TensorOptions() |
456 | .dtype(at::kFloat) |
457 | .layout(at::kStrided) |
458 | .device(at::kCPU) |
459 | .requires_grad(false); |
460 | |
461 | at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; |
462 | at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); |
463 | |
464 | at::Tensor nnc_result; |
465 | std::vector<float> input_buf(1 * 115 * 152, 5.f); |
466 | std::vector<float> result_buf(1 * 152, -1.f); |
467 | |
468 | #ifdef TORCH_ENABLE_LLVM |
469 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); |
470 | |
471 | llvm_codegen.call({input_buf, result_buf}); |
472 | nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); |
473 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
474 | #endif |
475 | |
476 | SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); |
477 | |
478 | ir_eval.call({input_buf, result_buf}); |
479 | nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); |
480 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
481 | } |
482 | |
483 | #ifdef USE_XNNPACK |
484 | |
TEST(ExternalCall, Prepacked_Linear_float) {
  // Runs xnnpack's prepacked linear through an ExternalCall and compares it
  // against at::linear. The prepacked OpContext is created through the
  // dispatcher and passed to the generated code as a raw pointer.
  using namespace at::native::xnnpack;

  BufHandle Input("Input" , {100, 200}, kFloat);
  BufHandle ResultBuf("Result" , {100, 300}, kFloat);

  // Calculate reference result using at::linear.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input =
      at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
  at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
  at::Tensor ref = at::linear(input, weight, bias);

  // Create prepacked xnnpack context object.
  // The typed<> signature must match the registered schema of
  // prepacked::linear_clamp_prepack exactly.
  auto linear_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::linear_clamp_prepack" , "" )
          .typed<c10::intrusive_ptr<LinearOpContext>(
              at::Tensor,
              c10::optional<at::Tensor>,
              const c10::optional<at::Scalar>&,
              const c10::optional<at::Scalar>&)>();
  auto prepacked = linear_clamp_prepack_op.call(
      weight, bias, c10::optional<at::Scalar>(), c10::optional<at::Scalar>());

  // DummyPrepacked only stands in for the opaque context inside the IR;
  // the real OpContext pointer is supplied at call time below.
  BufHandle DummyPrepacked("DummyPrepacked" , {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_linear_clamp_run" ,
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  // Copy the reference input's data into a plain buffer for the codegen call.
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  // prepacked.get() is passed in the DummyPrepacked argument slot.
  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
547 | |
TEST(ExternalCall, Prepacked_Conv2d_float) {
  // Runs xnnpack's prepacked conv2d through an ExternalCall and compares it
  // against at::conv2d. The prepacked OpContext is created through the
  // dispatcher and passed to the generated code as a raw pointer.
  using namespace at::native::xnnpack;

  BufHandle Input("Input" , {1, 3, 224, 224}, kFloat);
  BufHandle ResultBuf("Result" , {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Calculate reference result using at::conv2d.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
                         .resize_({1, 3, 224, 224});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
  at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  // Create prepacked xnnpack context object.
  // The typed<> signature must match the registered schema of
  // prepacked::conv2d_clamp_prepack exactly.
  auto conv2d_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::conv2d_clamp_prepack" , "" )
          .typed<c10::intrusive_ptr<Conv2dOpContext>(
              at::Tensor,
              c10::optional<at::Tensor>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              int64_t,
              const c10::optional<at::Scalar>&,
              const c10::optional<at::Scalar>&)>();
  auto prepacked = conv2d_clamp_prepack_op.call(
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups,
      c10::optional<at::Scalar>(),
      c10::optional<at::Scalar>());

  // DummyPrepacked only stands in for the opaque context inside the IR;
  // the real OpContext pointer is supplied at call time below.
  BufHandle DummyPrepacked("DummyPrepacked" , {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_conv2d_clamp_run" ,
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  // Copy the reference input's data into a plain buffer for the codegen call.
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  // prepacked.get() is passed in the DummyPrepacked argument slot.
  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  // NOTE(review): looser 1e-3 tolerance than the other tests — presumably the
  // prepacked xnnpack kernel accumulates differently from ATen; confirm.
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
}
632 | |
633 | #endif // USE_XNNPACK |
634 | |
635 | TEST(ExternalCall, BinaryFloat) { |
636 | using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>; |
637 | using Test = std::tuple< |
638 | std::vector<int64_t>, |
639 | std::vector<int64_t>, |
640 | std::vector<int64_t>, |
641 | TensorFunc, |
642 | std::string>; |
643 | std::vector<Test> tests = {}; |
644 | tests.push_back( |
645 | Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul" }); |
646 | tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv" }); |
647 | tests.push_back( |
648 | Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm" }); |
649 | for (auto curTest : tests) { |
650 | std::vector<int64_t> aShape, bShape, resShape; |
651 | TensorFunc torchFunc; |
652 | std::string externCallName; |
653 | std::tie(aShape, bShape, resShape, torchFunc, externCallName) = curTest; |
654 | auto toExprHandleVec = [](std::vector<int64_t> v) { |
655 | auto intV = std::vector<int>(v.begin(), v.end()); |
656 | return std::vector<ExprHandle>(intV.begin(), intV.end()); |
657 | }; |
658 | BufHandle A("A" , toExprHandleVec(aShape), kFloat); |
659 | BufHandle B("B" , toExprHandleVec(bShape), kFloat); |
660 | BufHandle ResultBuf("Result" , toExprHandleVec(resShape), kFloat); |
661 | |
662 | Tensor Result = Tensor( |
663 | ResultBuf.node(), |
664 | ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); |
665 | LoopNest l({Result}); |
666 | l.prepareForCodegen(); |
667 | l.simplify(); |
668 | |
669 | auto options = at::TensorOptions() |
670 | .dtype(at::kFloat) |
671 | .layout(at::kStrided) |
672 | .device(at::kCPU) |
673 | .requires_grad(false); |
674 | at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; |
675 | at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; |
676 | at::Tensor ref = torchFunc(a, b); |
677 | |
678 | auto prod = [](std::vector<int64_t> v) { |
679 | // NOLINTNEXTLINE(modernize-use-transparent-functors) |
680 | return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>()); |
681 | }; |
682 | |
683 | at::Tensor nnc_result; |
684 | std::vector<float> a_buf(prod(aShape), 5.f); |
685 | std::vector<float> b_buf(prod(bShape), 6.f); |
686 | std::vector<float> result_buf(prod(resShape), -1.f); |
687 | |
688 | #ifdef TORCH_ENABLE_LLVM |
689 | LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); |
690 | |
691 | llvm_codegen.call({a_buf, b_buf, result_buf}); |
692 | nnc_result = |
693 | at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); |
694 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
695 | #endif |
696 | |
697 | SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result}); |
698 | ir_eval.call({a_buf, b_buf, result_buf}); |
699 | nnc_result = |
700 | at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); |
701 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
702 | } |
703 | } |
704 | |
TEST(ExternalCall, UnaryFloat) {
  // Table-driven test for unary external calls (adaptive_avg_pool2d, mean).
  // Each entry holds: input shape, result shape, the ATen reference function,
  // the NNC external function name, and the extra scalar args for the call.
  using TensorFunc = std::function<at::Tensor(at::Tensor)>;
  // Convert int64 dims to ExprHandles (via int, which ExprHandle accepts).
  auto toExprHandleVec = [](std::vector<int64_t> v) {
    auto intV = std::vector<int>(v.begin(), v.end());
    return std::vector<ExprHandle>(intV.begin(), intV.end());
  };
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string,
      std::vector<ExprHandle>>;
  std::vector<Test> tests = {};
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {1, 64, 8, 9},
                       {1, 64, 5, 7},
                       [](at::Tensor x) {
                         return at::adaptive_avg_pool2d(x, {5, 7});
                       },
                       "nnc_aten_adaptive_avg_pool2d" ,
                       toExprHandleVec({5, 7})});
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {100, 200},
                       {100},
                       [](at::Tensor x) { return at::mean(x, {1}); },
                       "nnc_aten_mean" ,
                       toExprHandleVec({1, /*keepdim=*/0})});
  for (auto curTest : tests) {
    std::vector<int64_t> aShape, resShape;
    TensorFunc torchFunc;
    std::string externCallName;
    std::vector<ExprHandle> externCallArgs;
    std::tie(aShape, resShape, torchFunc, externCallName, externCallArgs) =
        curTest;
    BufHandle A("A" , toExprHandleVec(aShape), kFloat);
    BufHandle ResultBuf("Result" , toExprHandleVec(resShape), kFloat);

    // Build the result tensor as a single ExternalCall node.
    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    // ATen reference computation on a constant-filled input.
    auto options = at::TensorOptions()
                       .dtype(at::kFloat)
                       .layout(at::kStrided)
                       .device(at::kCPU)
                       .requires_grad(false);
    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor ref = torchFunc(a);

    // Number of elements in a given shape.
    auto prod = [](std::vector<int64_t> v) {
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
    };

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    // LLVM backend (only built when TORCH_ENABLE_LLVM is set).
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});

    llvm_codegen.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

    // Interpreter backend (always available).
    SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
    ir_eval.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}
782 | |
TEST(ExternalCall, ComputeInterop) {
  // This test verifies that Tensors using external calls can be used by and can
  // use Tensors built with Compute API.
  // Pipeline: Compute inputs -> external conv2d -> external matmul ->
  // Compute sum of the two external results.

  BufHandle ConvResultBuf("ConvResult" , {1, 16, 32, 32}, kFloat);
  BufHandle MatmulResultBuf("MatmulResult" , {1, 16, 32, 32}, kFloat);

  // Constant-filled inputs produced with the Compute API.
  Tensor Input = Compute(
      "Input" ,
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(5.0f); });
  Tensor Weight = Compute(
      "Weight" ,
      {16, 16, 1, 1},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(6.0f); });

  // External conv2d consuming the Compute-produced bufs.
  Tensor ConvResult = Tensor(
      ConvResultBuf.node(),
      ExternalCall::make(
          ConvResultBuf,
          "nnc_aten_conv2d" ,
          {BufHandle(Input.buf()), BufHandle(Weight.buf())},
          {}));
  // External matmul consuming the result of the previous external call.
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul" ,
          {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
          {}));
  // Compute-API tensor consuming both external-call results.
  Tensor Result = Compute(
      "Result" ,
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) {
        return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
      });

  LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});

  // Inlining should not inline anything here since all Bufs are either defined
  // or used in ExternalCalls - we run it just for testing
  l.inlineIntermediateBufs(true);

  l.prepareForCodegen();
  l.simplify();

  // ATen reference for the same pipeline.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor t = at::conv2d(input, weight);
  at::Tensor t2 = at::matmul(t, t);
  at::Tensor ref = t + t2;

  // All five bufs are codegen arguments, so each needs a backing buffer.
  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  llvm_codegen.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  ir_eval.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
874 | |
875 | TEST(ExternalCall, Inlining) { |
876 | // This test verifies that Tensors using external calls can be used by and |
877 | // can use Tensors built with Compute API. |
878 | |
879 | BufHandle MatmulResultBuf("MatmulResult" , {8, 8}, kFloat); |
880 | |
881 | Tensor A = Compute("A" , {8, 8}, [&](const VarHandle& i, const VarHandle& j) { |
882 | return FloatImm::make(5.0f); |
883 | }); |
884 | Tensor B = Compute("B" , {8, 8}, [&](const VarHandle& i, const VarHandle& j) { |
885 | return FloatImm::make(4.0f); |
886 | }); |
887 | Tensor MatmulResult = Tensor( |
888 | MatmulResultBuf.node(), |
889 | ExternalCall::make( |
890 | MatmulResultBuf, |
891 | "nnc_aten_matmul" , |
892 | {BufHandle(A.buf()), BufHandle(B.buf())}, |
893 | {})); |
894 | Tensor Result = |
895 | Compute("Result" , {8, 8}, [&](const VarHandle& i, const VarHandle& j) { |
896 | return MatmulResult.load(i, j) + FloatImm::make(3.0f); |
897 | }); |
898 | |
899 | StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>( |
900 | {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); |
901 | LoopNest l(root_stmt, {Result.buf()}); |
902 | |
903 | // Inlining should not inline anything here since all Bufs are either |
904 | // defined or used in ExternalCalls |
905 | l.inlineIntermediateBufs(false); |
906 | |
907 | l.prepareForCodegen(); |
908 | l.simplify(); |
909 | |
910 | auto options = at::TensorOptions() |
911 | .dtype(at::kFloat) |
912 | .layout(at::kStrided) |
913 | .device(at::kCPU) |
914 | .requires_grad(false); |
915 | at::Tensor a = at::ones({8, 8}, options) * 5.f; |
916 | at::Tensor b = at::ones({8, 8}, options) * 4.f; |
917 | at::Tensor t = at::matmul(a, b); |
918 | at::Tensor ref = t + 3.f; |
919 | |
920 | at::Tensor nnc_result; |
921 | std::vector<float> result_buf(8 * 8); |
922 | |
923 | #ifdef TORCH_ENABLE_LLVM |
924 | LLVMCodeGen llvm_codegen(l.root_stmt(), {Result}); |
925 | |
926 | llvm_codegen.call({result_buf}); |
927 | nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); |
928 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
929 | #endif |
930 | |
931 | SimpleIREvaluator ir_eval(l.root_stmt(), {Result}); |
932 | |
933 | ir_eval.call({result_buf}); |
934 | nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); |
935 | ASSERT_TRUE(at::allclose(nnc_result, ref)); |
936 | } |
937 | |
938 | TEST(ExternalCall, JitCustomFusionOp) { |
939 | const char* custom_op_schema_literal = |
940 | "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor" ; |
941 | const char* external_func_name = "nnc_add_mul" ; |
942 | |
943 | auto add_mul_lowering_func = |
944 | [external_func_name]( |
945 | const std::vector<torch::jit::tensorexpr::ArgValue>& inputs, |
946 | const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape, |
947 | const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides, |
948 | const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type, |
949 | at::Device device) { |
950 | auto output_dtype = Dtype(*output_type); |
951 | torch::jit::tensorexpr::BufHandle result_buf( |
952 | "nnc_add_mul_res_buf" , output_shape, output_dtype); |
953 | const torch::jit::tensorexpr::BufHandle& a = |
954 | c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]); |
955 | const torch::jit::tensorexpr::BufHandle& b = |
956 | c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]); |
957 | const torch::jit::tensorexpr::BufHandle& c = |
958 | c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]); |
959 | torch::jit::tensorexpr::StmtPtr s = |
960 | torch::jit::tensorexpr::ExternalCall::make( |
961 | result_buf, external_func_name, {a, b, c}, {}); |
962 | return Tensor(result_buf.node(), s); |
963 | }; |
964 | |
965 | auto add_mul_external_func = [](int64_t bufs_num, |
966 | void** buf_data, |
967 | int64_t* buf_ranks, |
968 | int64_t* buf_dims, |
969 | int64_t* buf_strides, |
970 | int8_t* buf_dtypes, |
971 | int64_t args_num, |
972 | int64_t* ) {}; |
973 | |
974 | torch::jit::RegisterOperators reg({Operator( |
975 | custom_op_schema_literal, |
976 | [](const Node* node) -> Operation { |
977 | return [](Stack& _stack) { |
978 | auto a = std::move(peek(_stack, 0, 3)).toTensor(); |
979 | auto b = std::move(peek(_stack, 1, 3)).toTensor(); |
980 | auto c = std::move(peek(_stack, 2, 3)).toTensor(); |
981 | drop(_stack, 3); |
982 | auto result = (a + b) * c; |
983 | pack(_stack, std::move(result)); |
984 | return 0; |
985 | }; |
986 | }, |
987 | c10::AliasAnalysisKind::FROM_SCHEMA)}); |
988 | |
989 | auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet(); |
990 | custom_operator_set.insert({custom_op_schema_literal}); |
991 | |
992 | auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); |
993 | te_lowering_registry.insert( |
994 | parseSchema(custom_op_schema_literal), add_mul_lowering_func); |
995 | |
996 | auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); |
997 | te_nnc_func_registry[external_func_name] = add_mul_external_func; |
998 | |
999 | std::string graph_string = R"IR( |
1000 | graph(%a : Float(10, 20, strides=[20, 1], device=cpu), |
1001 | %b : Float(10, 20, strides=[20, 1], device=cpu), |
1002 | %c : Float(10, 20, strides=[20, 1], device=cpu)): |
1003 | %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c) |
1004 | return (%res))IR" ; |
1005 | |
1006 | auto graph = std::make_shared<Graph>(); |
1007 | torch::jit::parseIR(graph_string, graph.get()); |
1008 | |
1009 | std::string shape_compute_python_string = R"PY( |
1010 | def computOutput(a: List[int], b: List[int], c: List[int]): |
1011 | expandedSizes: List[int] = [] |
1012 | dimsA = len(a) |
1013 | dimsB = len(b) |
1014 | dimsC = len(c) |
1015 | ndim = max(dimsA, dimsB, dimsC) |
1016 | for i in range(ndim): |
1017 | offset = ndim - 1 - i |
1018 | dimA = dimsA - 1 - offset |
1019 | dimB = dimsB - 1 - offset |
1020 | dimC = dimsC - 1 - offset |
1021 | sizeA = a[dimA] if (dimA >= 0) else 1 |
1022 | sizeB = b[dimB] if (dimB >= 0) else 1 |
1023 | sizeC = a[dimC] if (dimC >= 0) else 1 |
1024 | |
1025 | if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1: |
1026 | # TODO: only assertion error is bound in C++ compilation right now |
1027 | raise AssertionError( |
1028 | "The size of tensor a {} must match the size of tensor b (" |
1029 | "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i) |
1030 | ) |
1031 | |
1032 | expandedSizes.append(max(sizeA, sizeB, sizeC)) |
1033 | |
1034 | return expandedSizes |
1035 | )PY" ; |
1036 | auto cu_ptr = torch::jit::compile(shape_compute_python_string); |
1037 | torch::jit::GraphFunction* gf = |
1038 | (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput" ); |
1039 | ASSERT_TRUE(gf); |
1040 | |
1041 | #ifdef TORCH_ENABLE_LLVM |
1042 | auto static_graph_case = graph->copy(); |
1043 | FuseTensorExprs(static_graph_case, 1); |
1044 | torch::jit::testing::FileCheck() |
1045 | .check("prim::TensorExprGroup_" ) |
1046 | ->check("nnc_custom::add_mul" ) |
1047 | ->run(*static_graph_case); |
1048 | |
1049 | auto dynamic_graph_case = graph->copy(); |
1050 | auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal); |
1051 | ASSERT_TRUE(custom_op); |
1052 | torch::jit::RegisterShapeComputeGraphForSchema( |
1053 | custom_op->schema(), gf->graph()); |
1054 | FuseTensorExprs(dynamic_graph_case, 1, false, true); |
1055 | torch::jit::testing::FileCheck() |
1056 | .check("prim::TensorExprGroup_" ) |
1057 | ->check("nnc_custom::add_mul" ) |
1058 | ->run(*dynamic_graph_case); |
1059 | #else |
1060 | torch::jit::testing::FileCheck().check("nnc_custom::add_mul" )->run(*graph); |
1061 | #endif |
1062 | } |
1063 | |
1064 | } // namespace jit |
1065 | } // namespace torch |
1066 | |