1 | #if defined(USE_CUDA) |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <arith.h> |
5 | #include <codegen.h> |
6 | #include <disjoint_set.h> |
7 | #include <executor.h> |
8 | #include <executor_launch_params.h> |
9 | #include <expr_evaluator.h> |
10 | #include <fusion.h> |
11 | #include <fusion_segmenter.h> |
12 | #include <ir_all_nodes.h> |
13 | #include <ir_builder.h> |
14 | #include <ir_graphviz.h> |
15 | #include <ir_iostream.h> |
16 | #include <ir_utils.h> |
17 | #include <iter_visitor.h> |
18 | #include <kernel_cache.h> |
19 | #include <kernel_expr_evaluator.h> |
20 | #include <kernel_ir.h> |
21 | #include <lower2device.h> |
22 | #include <mutator.h> |
23 | #include <ops/all_ops.h> |
24 | #include <register_interface.h> |
25 | #include <root_domain_map.h> |
26 | #include <scheduler/all_schedulers.h> |
27 | #include <scheduler/utils.h> |
28 | #include <test/test_gpu_validator.h> |
29 | #include <test/test_utils.h> |
30 | #include <transform_replay.h> |
31 | #include <transform_rfactor.h> |
32 | |
33 | // fuser and IR parser |
34 | #include <ATen/cuda/CUDAContext.h> |
35 | #include <ATen/cuda/Exceptions.h> |
36 | #include <c10/cuda/CUDAStream.h> |
37 | |
38 | #include <algorithm> |
39 | #include <iostream> |
40 | |
41 | // Tests go in torch::jit |
42 | namespace torch { |
43 | namespace jit { |
44 | |
45 | using namespace torch::jit::fuser::cuda; |
46 | using namespace at::indexing; |
47 | |
48 | namespace { |
49 | |
// Used to signify invalid ranges, i.e., values at offset 0 to
// start_offset, and values at offset stop_offset to the end of the
// domain. Reference helpers (shift/gather below) write this marker into
// regions a fusion is not required to compute.
// Note: `static` is redundant inside an anonymous namespace, so plain
// constexpr is used (clang-tidy:
// readability-static-definition-in-anonymous-namespace).
constexpr int invalid_marker = 1;
54 | |
// ATen version of tensor shifting. Rolls `tensor` by `offsets[i]` along
// each dimension i, zero-fills the region that wrapped around, and then
// overwrites with `invalid_marker` the part of that region that lies
// outside the per-dimension `padding` (defaults to abs(offsets)), so
// comparisons can ignore values a fusion is not required to produce.
auto shift(
    at::Tensor tensor,
    const std::vector<int>& offsets,
    std::vector<int> padding = {}) {
  TORCH_INTERNAL_ASSERT(
      tensor.ndimension() == static_cast<int64_t>(offsets.size()));
  if (padding.empty()) {
    // Default padding covers the entire shifted-in region.
    padding = offsets;
    for (auto& p : padding) {
      p = std::abs(p);
    }
  }
  at::Tensor t = tensor;
  for (size_t i = 0; i < offsets.size(); ++i) {
    auto offset = offsets[i];
    t = t.roll(offsets[i], i);
    if (offset == 0) {
      continue;
    }
    // Zero padding
    std::vector<at::indexing::TensorIndex> indices(
        tensor.ndimension(), at::indexing::Slice(0, at::indexing::None));
    if (offset > 0) {
      indices[i] = at::indexing::Slice(0, offset);
    } else {
      // Negative slice start selects the trailing |offset| elements.
      indices[i] = at::indexing::Slice(offset, at::indexing::None);
    }
    t.index(indices) = 0;
    // Fill the outside range by the special marker value.
    const auto pad = padding[i];
    if (offset > 0) {
      indices[i] = at::indexing::Slice(0, offset - pad);
    } else {
      // assumes pad <= |offset| for negative offsets — enforced below
      offset += pad;
      TORCH_INTERNAL_ASSERT(offset <= 0);
      if (offset == 0) {
        // Padding covers the whole shifted-in range; nothing to mark.
        continue;
      }
      indices[i] = at::indexing::Slice(offset, at::indexing::None);
    }
    t.index(indices) = invalid_marker;
  }
  return t;
}
100 | |
// ATen version of tensor gather: for each input dimension i a trailing
// window dimension of extent window_shape[i] is appended, holding the
// neighborhood of each point. pad_width[i] = {left, right} pads each
// side, and strides[i] (default 1) subsamples the output dimension.
// Out-of-bounds window entries are zero; rows that the padding does not
// cover are filled with invalid_marker (see shift above).
auto gather(
    at::Tensor tensor,
    const std::vector<int>& window_shape,
    const std::vector<std::vector<int>>& pad_width,
    std::vector<int> strides = {}) {
  TORCH_CHECK(
      tensor.ndimension() == static_cast<int64_t>(window_shape.size()),
      "Invalid window shape: " ,
      window_shape,
      ". Size of the window shape is different from the tensor dimension." );
  TORCH_CHECK(
      tensor.ndimension() == static_cast<int64_t>(pad_width.size()),
      "Invalid pad width: " ,
      pad_width,
      ". Size of the pad width is different from the tensor dimension." );
  if (strides.empty()) {
    // Default: unit stride in every dimension.
    strides = std::vector<int>(tensor.ndimension(), 1);
  } else {
    TORCH_CHECK(
        tensor.ndimension() == static_cast<int64_t>(strides.size()),
        "Invalid strides: " ,
        strides,
        ". Size of strides is different from the tensor dimension." );
  }
  at::Tensor t = tensor;
  for (size_t i = 0; i < window_shape.size(); ++i) {
    const auto w_size = window_shape[i];
    TORCH_CHECK(w_size != 0);
    const auto& pad = pad_width[i];
    TORCH_CHECK(pad.size() == 2);
    // Adjustment of the output extent; must be <= 0 because the window
    // can only shrink the valid region (padding at most w_size - 1).
    const auto out_extent_adj = -w_size + 1 + pad[0] + pad[1];
    TORCH_INTERNAL_ASSERT(out_extent_adj <= 0);
    const auto stride = strides[i];
    TORCH_CHECK(stride >= 1);

    at::Tensor concat_tensor;

    // Materialize one shifted copy of the tensor per window position
    // and concatenate them along a new trailing dimension.
    for (int w = 0; w < w_size; ++w) {
      std::vector<int> shift_offsets(t.ndimension(), 0);
      shift_offsets[i] = pad[0] - w;
      auto shifted = shift(t, shift_offsets);
      // Apply stride
      if (stride != 1) {
        std::vector<at::indexing::TensorIndex> indices(
            shifted.ndimension(), at::indexing::Slice(0, at::indexing::None));
        if (out_extent_adj == 0) {
          indices[i] = at::indexing::Slice(0, at::indexing::None, strides[i]);
        } else {
          // Negative stop trims the invalid tail before subsampling.
          indices[i] = at::indexing::Slice(0, out_extent_adj, strides[i]);
        }
        shifted = shifted.index(indices);
      }
      shifted = shifted.unsqueeze(-1);
      if (w == 0) {
        concat_tensor = shifted;
      } else {
        concat_tensor = at::cat({concat_tensor, shifted}, -1);
      }
    }
    t = concat_tensor;
  }

  // Fill invalid regions with the marker. Note that when non-unit
  // stride is used, it trims invalid regions, so no marking is
  // necessary.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    if (strides[i] != 1) {
      continue;
    }

    const auto out_extent_adj =
        -window_shape[i] + 1 + pad_width[i][0] + pad_width[i][1];
    if (out_extent_adj < 0) {
      std::vector<at::indexing::TensorIndex> indices(
          t.ndimension(), at::indexing::Slice(0, at::indexing::None));
      indices[i] = at::indexing::Slice(out_extent_adj, at::indexing::None);
      t.index(indices) = invalid_marker;
    }
  }

  return t;
}
184 | |
185 | } // namespace |
186 | |
187 | // Shift an input tensor |
188 | TEST_F(NVFuserTest, FusionShift1_CUDA) { |
189 | Fusion fusion; |
190 | FusionGuard fg(&fusion); |
191 | |
192 | auto tv0 = makeSymbolicTensor(2); |
193 | fusion.addInput(tv0); |
194 | |
195 | auto tv1 = shift(tv0, {-1, 0}); |
196 | fusion.addOutput(tv1); |
197 | |
198 | auto tv2 = shift(tv0, {0, 1}); |
199 | fusion.addOutput(tv2); |
200 | |
201 | auto tv3 = shift(tv0, {2, 2}); |
202 | fusion.addOutput(tv3); |
203 | |
204 | auto tv4 = shift(tv0, {-2, -2}); |
205 | fusion.addOutput(tv4); |
206 | |
207 | int numel_x = 9; |
208 | int numel_y = 11; |
209 | |
210 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
211 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
212 | std::vector<IValue> inputs = {t0}; |
213 | |
214 | FusionExecutor fe; |
215 | fe.compileFusion(&fusion, inputs); |
216 | auto outputs = fe.runFusion(inputs); |
217 | |
218 | auto t1 = shift(t0, {-1, 0}); |
219 | TORCH_CHECK(t1.equal(outputs[0])); |
220 | |
221 | auto t2 = shift(t0, {0, 1}); |
222 | TORCH_CHECK(t2.equal(outputs[1])); |
223 | |
224 | auto t3 = shift(t0, {2, 2}); |
225 | TORCH_CHECK(t3.equal(outputs[2])); |
226 | |
227 | auto t4 = shift(t0, {-2, -2}); |
228 | TORCH_CHECK(t4.equal(outputs[3])); |
229 | } |
230 | |
// Shifts an intermediate tensor
TEST_F(NVFuserTest, FusionShift2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {-1, 0});
  fusion.addOutput(tv2);

  // make it a little more complex
  auto tv3 = add(tv0, IrBuilder::create<Double>(3));
  auto tv4 = add(tv3, IrBuilder::create<Double>(4));
  auto tv5 = shift(tv4, {-1, 0});
  auto tv6 = shift(tv4, {0, -1});
  auto tv7 = shift(tv4, {1, 0});
  auto tv8 = shift(tv4, {0, 0});
  auto tv9 = add(tv5, tv6);
  auto tv10 = add(tv9, tv7);
  auto tv11 = add(tv10, tv8);
  fusion.addOutput(tv11);

  // Use global memory so every intermediate gets a real allocation
  // whose shape can be inspected after lowering.
  for (auto tv : {tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv8, tv9, tv10, tv11}) {
    tv->setMemoryType(MemoryType::Global);
  }

  // Expected allocation extents (symbolic size plus halo for shifts):
  // t1 allocation: (t1.size[0] + 1) * (t1.size[1])
  // t3 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
  // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
  GpuLower gpulw(&fusion);

  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          if (tensor_name == 1 && i == 1) {
            // t1's second dimension has no halo, so its extent must stay
            // the plain symbolic size.
            TORCH_CHECK(alloc->shape().at(i)->isA<NamedScalar>());
            continue;
          }
          // Otherwise the extent must be (symbolic size + constant halo).
          auto def =
              dynamic_cast<BinaryOp*>(alloc->shape().at(i)->definition());
          TORCH_CHECK(
              def != nullptr && def->getBinaryOpType() == BinaryOpType::Add);
          TORCH_CHECK(def->as<BinaryOp>()->lhs()->isA<NamedScalar>());
          auto rhs = dynamic_cast<Int*>(def->as<BinaryOp>()->rhs());
          TORCH_CHECK(rhs != nullptr && rhs->isConst());
          int rhs_value = *rhs->value();
          if (tensor_name == 1) {
            TORCH_CHECK(i == 0);
            TORCH_CHECK(rhs_value == 1);
          } else {
            if (i == 0) {
              TORCH_CHECK(rhs_value == 2);
            } else {
              TORCH_CHECK(rhs_value == 1);
            }
          }
        }
      }
    }
  }

  int numel_x = 9;
  int numel_y = 11;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference computation mirroring the fusion definition above.
  auto t1 = t0 + 1;
  auto t2 = shift(t1, {-1, 0});

  auto t3 = t0 + 3;
  auto t4 = t3 + 4;
  auto t5 = shift(t4, {-1, 0});
  auto t6 = shift(t4, {0, -1});
  auto t7 = shift(t4, {1, 0});
  auto t8 = shift(t4, {0, 0});
  auto t9 = t5 + t6;
  auto t10 = t9 + t7;
  auto t11 = t10 + t8;

  testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__);
}
322 | |
323 | TEST_F(NVFuserTest, FusionShiftRightOfCA_CUDA) { |
324 | Fusion fusion; |
325 | FusionGuard fg(&fusion); |
326 | |
327 | auto tv0 = makeSymbolicTensor(2); |
328 | fusion.addInput(tv0); |
329 | |
330 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
331 | auto tv2 = shift(tv1, {0, 1}); |
332 | fusion.addOutput(tv2); |
333 | |
334 | tv0->computeAt(tv2, -2); |
335 | |
336 | tv1->setMemoryType(MemoryType::Global); |
337 | |
338 | int numel_x = 100; |
339 | int numel_y = 101; |
340 | |
341 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
342 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
343 | std::vector<IValue> inputs = {t0}; |
344 | |
345 | FusionExecutor fe; |
346 | fe.compileFusion(&fusion, inputs); |
347 | auto outputs = fe.runFusion(inputs); |
348 | |
349 | auto t1 = t0 + 1; |
350 | auto t2 = shift(t1, {0, 1}); |
351 | |
352 | TORCH_CHECK(t2.allclose(outputs[0])); |
353 | } |
354 | |
355 | TEST_F(NVFuserTest, FusionShiftLeftOfCA_CUDA) { |
356 | Fusion fusion; |
357 | FusionGuard fg(&fusion); |
358 | |
359 | auto tv0 = makeSymbolicTensor(2); |
360 | fusion.addInput(tv0); |
361 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
362 | auto tv2 = add(tv1, IrBuilder::create<Double>(1)); |
363 | auto tv3 = shift(tv2, {-1, 0}); |
364 | auto tv4 = add(tv3, IrBuilder::create<Double>(1)); |
365 | fusion.addOutput(tv4); |
366 | |
367 | tv0->computeAt(tv4, -1); |
368 | |
369 | // Lowering should trigger an assertion failure as a shifted axis is |
370 | // found inside an allocation position. |
371 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
372 | ASSERT_ANY_THROW(fusion.printKernel()); |
373 | } |
374 | |
// Shift with a split inner axis; verifies the halo-extended allocation
// size of the shared producer.
TEST_F(NVFuserTest, FusionShiftSplit1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {0, 1});
  auto tv3 = shift(tv1, {0, -2});
  fusion.addOutput(tv2);
  fusion.addOutput(tv3);

  int split_factor = 4;
  tv2->split(-1, split_factor);
  tv3->split(-1, split_factor);

  tv0->computeAt(tv2, -2);
  tv0->computeAt(tv3, -2);

  // t1 allocation: 7 (= split_factor 4 plus a halo of 3 needed by the
  // {0, 1} and {0, -2} shift consumers)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1) {
        TORCH_CHECK(alloc->shape().size() == 1);
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(
            size != nullptr && size->isConst() && size->value().value() == 7);
      }
    }
  }

  int numel_x = 9;
  int numel_y = 11;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference mirroring the fusion.
  auto t1 = t0 + 1;
  auto t2 = shift(t1, {0, 1});
  auto t3 = shift(t1, {0, -2});

  testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__);
}
425 | |
// Two independent shift chains with split axes; verifies halo-extended
// allocations, including that a {0, 0} shift adds no halo.
TEST_F(NVFuserTest, FusionShiftSplit2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
  auto tv3 = shift(tv2, {0, -1});
  auto tv4 = shift(tv2, {0, 1});
  auto tv5 = add(tv3, tv4);
  fusion.addOutput(tv5);

  auto tv6 = add(tv0, IrBuilder::create<Double>(1));
  auto tv7 = shift(tv6, {0, 0});
  auto tv8 = add(tv7, IrBuilder::create<Double>(1));
  fusion.addOutput(tv8);

  int split_factor = 4;

  tv5->split(-1, split_factor);
  tv8->split(-1, split_factor);

  tv0->computeAt(tv5, -2);
  tv0->computeAt(tv8, -2);

  // t1 and t2 allocation: 6 (= split_factor 4 plus halo 2 for the
  // {0, -1} and {0, 1} shifts)
  // t4 allocation: 4 (a {0, 0} shift requires no halo)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 2) {
        TORCH_CHECK(alloc->shape().size() == 1);
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(
            size != nullptr && size->isConst() && size->value().value() == 6);
      } else if (tensor_name == 4) {
        TORCH_CHECK(alloc->shape().size() == 1);
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(size != nullptr && size->isConst());
        int size_value = *size->value();
        TORCH_CHECK(size_value == split_factor);
      }
    }
  }

  int numel_x = 9;
  int numel_y = 11;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference; the {0, 0} shift is the identity, so t7 == t6.
  auto t1 = t0 + 2;
  auto t3 = shift(t1, {0, -1});
  auto t4 = shift(t1, {0, 1});
  auto t5 = t3 + t4;

  auto t6 = t0 + 1;
  auto t7 = t6;
  auto t8 = t7 + 1;

  testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__);
}
496 | |
// Shift with two nested splits; the allocation must be sized by the
// outer split factor plus halo, regardless of the inner split.
TEST_F(NVFuserTest, FusionShiftDoubleSplit_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(2));
  auto tv3 = shift(tv2, {0, 1});
  fusion.addOutput(tv3);

  int split_factor1 = 8;
  int split_factor2 = 4;

  tv3->split(-1, split_factor1);

  tv0->computeAt(tv3, -2);

  tv1->split(-1, split_factor2);

  // Resulting loop structures:
  // t1: [i1, i2/8, 8/4, 4]
  // t2: [i1, i2/8, 8]
  // t3: [i1, i2/8, 8]

  // t1 and t2 allocation: (split_factor1 + 1) = 9
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 2) {
        TORCH_CHECK(alloc->shape().size() == 1);
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(
            size != nullptr && size->isConst() && size->value().value() == 9);
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: the two adds collapse to (+ 3), then shift.
  auto t1 = t0 + 3;
  auto ref = shift(t1, {0, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
551 | |
// 1D 3-point stencil with a cached input and double buffering.
TEST_F(NVFuserTest, FusionShift3ptStencil_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 3-pt stencil
  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  std::vector<std::vector<int>> offsets = {{-1}, {1}};

  std::vector<TensorView*> tvs;
  for (const auto& offset : offsets) {
    tvs.push_back(shift(tv0, offset));
  }

  auto tv_out = tv0;

  for (auto tv : tvs) {
    tv_out = add(tv_out, tv);
  }

  // Average of the center point and its two neighbors.
  tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1));

  fusion.addOutput(tv_out);

  int split_factor = 4;

  tv_out->split(0, split_factor);

  // This seems fine but not verified yet
  // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);

  auto cache = tv0->cacheAfter();

  tv0->computeAt(tv_out, 1);

  // Inline completely except for the cache
  for (auto tv : tvs) {
    tv->computeAt(tv_out, -1);
  }

  // cache allocation: (split_factor + 2), i.e., the tile plus one halo
  // element on each side for the +/-1 shifts
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == cache->name()) {
        TORCH_CHECK(alloc->shape().size() == 1);
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(
            size != nullptr && size->isConst() &&
            size->value().value() == split_factor + 2);
      }
    }
  }

  cache->doubleBuffer();

  int numel_x = 99;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
624 | |
// 2D 5-point stencil with a tiled schedule, cached input, and double
// buffering.
TEST_F(NVFuserTest, FusionShift5ptStencil_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 5-pt stencil
  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};

  std::vector<TensorView*> tvs;
  for (const auto& offset : offsets) {
    tvs.push_back(shift(tv0, offset));
  }

  auto tv_out = tv0;

  for (auto tv : tvs) {
    tv_out = add(tv_out, tv);
  }

  // Average of the center point and its four neighbors.
  tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1));

  fusion.addOutput(tv_out);

  std::vector<int> split_factor({4, 8});

  // Tile the output into (4 x 8) blocks.
  tv_out->split(-1, split_factor[1]);
  tv_out->split(0, split_factor[0]);
  tv_out->reorder({{1, 2}, {2, 1}});

  auto cache = tv0->cacheAfter();

  tv0->computeAt(tv_out, 2);

  // Inline completely except for the cache
  for (auto tv : tvs) {
    tv->computeAt(tv_out, -1);
  }

  // cache allocation: (split_factor + 2) * (split_factor + 2), i.e.,
  // one halo element on each side per dimension
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == cache->name()) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == split_factor[i] + 2);
        }
      }
    }
  }

  cache->doubleBuffer();

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = t0;
  for (const auto& offset : offsets) {
    ref = ref + shift(t0, offset);
  }
  ref = ref / int(offsets.size() + 1);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
702 | |
// 2D 9-point stencil (all 8 neighbors plus the center), tiled schedule
// with cached input and double buffering.
TEST_F(NVFuserTest, FusionShift9ptStencil_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 9-pt stencil: every offset in {-1, 0, 1}^2 except {0, 0}.
  std::vector<std::vector<int>> offsets;
  for (int i = -1; i < 2; ++i) {
    for (int j = -1; j < 2; ++j) {
      if (i == 0 && j == 0) {
        continue;
      }
      offsets.push_back({i, j});
    }
  }

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  std::vector<TensorView*> tvs;
  for (const auto& offset : offsets) {
    tvs.push_back(shift(tv0, offset));
  }

  auto tv_out = tv0;

  for (auto tv : tvs) {
    tv_out = add(tv_out, tv);
  }

  // Average of the center point and its eight neighbors.
  tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1));

  fusion.addOutput(tv_out);

  // Tile the output into (4 x 8) blocks.
  std::vector<int> split_factor({4, 8});
  tv_out->split(-1, split_factor[1]);
  tv_out->split(0, split_factor[0]);
  tv_out->reorder({{1, 2}, {2, 1}});

  auto cache = tv0->cacheAfter();

  tv0->computeAt(tv_out, 2);

  // Inline completely except for the cache
  for (auto tv : tvs) {
    tv->computeAt(tv_out, -1);
  }

  // This seems fine but not yet verified
  // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);

  // cache allocation: (split_factor + 2) * (split_factor + 2), i.e.,
  // one halo element on each side per dimension
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == cache->name()) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == split_factor[i] + 2);
        }
      }
    }
  }

  cache->doubleBuffer();

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = t0;
  for (const auto& offset : offsets) {
    ref = ref + shift(t0, offset);
  }
  ref = ref / int(offsets.size() + 1);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
790 | |
// Shift with a shared-memory producer and TIDx parallelization;
// verifies the halo-extended shared-memory allocation size.
TEST_F(NVFuserTest, FusionShiftSmemBlocking_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {0, 1});
  fusion.addOutput(tv2);

  int smem_block_factor = 32;

  tv2->split(-1, smem_block_factor);

  tv0->computeAt(tv2, -2);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  tv1->setMemoryType(MemoryType::Shared);

  // tv1 allocation: (split_factor + 1) — one halo element for the
  // {0, 1} shift
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == tv1->name()) {
        TORCH_CHECK(alloc->shape().size() == 1);
        for (int i = 0; i < 1; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == smem_block_factor + 1);
        }
      }
    }
  }

  int numel_x = 100;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {0, 1});
  auto ref = t2;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
846 | |
847 | TEST_F(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { |
848 | Fusion fusion; |
849 | FusionGuard fg(&fusion); |
850 | |
851 | // 3-pt stencil |
852 | auto tv0 = makeSymbolicTensor(1); |
853 | fusion.addInput(tv0); |
854 | std::vector<TensorView*> tvs; |
855 | tvs.push_back(shift(tv0, {-1})); |
856 | tvs.push_back(shift(tv0, {1})); |
857 | |
858 | auto tv_out = tv0; |
859 | |
860 | for (auto tv : tvs) { |
861 | tv_out = add(tv_out, tv); |
862 | } |
863 | |
864 | tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1)); |
865 | |
866 | fusion.addOutput(tv_out); |
867 | |
868 | int smem_block_factor = 32; |
869 | |
870 | tv_out->split(0, smem_block_factor); |
871 | // tv_out->axis(-1)->parallelize(ParallelType::Unswitch); |
872 | |
873 | auto tv0_cache = tv0->cacheAfter(); |
874 | |
875 | tv0->computeAt(tv_out, 1); |
876 | |
877 | for (auto tv : tvs) { |
878 | tv->computeAt(tv_out, -1); |
879 | } |
880 | |
881 | tv0_cache->setMemoryType(MemoryType::Shared); |
882 | tv_out->axis(-1)->parallelize(ParallelType::TIDx); |
883 | tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); |
884 | |
885 | tv0_cache->doubleBuffer(); |
886 | |
887 | int numel_x = 99; |
888 | |
889 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
890 | at::Tensor t0 = at::randn({numel_x}, options); |
891 | std::vector<IValue> inputs = {t0}; |
892 | |
893 | FusionExecutor fe; |
894 | fe.compileFusion(&fusion, inputs); |
895 | auto outputs = fe.runFusion(inputs); |
896 | |
897 | auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; |
898 | |
899 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
900 | } |
901 | |
902 | TEST_F(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { |
903 | Fusion fusion; |
904 | FusionGuard fg(&fusion); |
905 | |
906 | // 5-pt stencil |
907 | auto tv0 = makeSymbolicTensor(2); |
908 | fusion.addInput(tv0); |
909 | std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; |
910 | |
911 | std::vector<TensorView*> tvs; |
912 | for (const auto& offset : offsets) { |
913 | tvs.push_back(shift(tv0, offset)); |
914 | } |
915 | |
916 | auto tv_out = tv0; |
917 | |
918 | for (auto tv : tvs) { |
919 | tv_out = add(tv_out, tv); |
920 | } |
921 | |
922 | tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1)); |
923 | |
924 | fusion.addOutput(tv_out); |
925 | |
926 | int smem_block_factor = 32; |
927 | |
928 | tv_out->split(-1, smem_block_factor); |
929 | tv_out->split(0, smem_block_factor); |
930 | |
931 | tv_out->reorder({{1, 2}, {2, 1}}); |
932 | |
933 | auto tv0_cache = tv0->cacheAfter(); |
934 | |
935 | tv0->computeAt(tv_out, 2); |
936 | |
937 | for (auto tv : tvs) { |
938 | tv->computeAt(tv_out, -1); |
939 | } |
940 | |
941 | tv_out->axis(-1)->parallelize(ParallelType::TIDx); |
942 | tv_out->axis(-2)->parallelize(ParallelType::TIDy); |
943 | tv_out->axis(-3)->parallelize(ParallelType::BIDx); |
944 | tv_out->axis(-4)->parallelize(ParallelType::BIDy); |
945 | |
946 | tv0_cache->setMemoryType(MemoryType::Shared); |
947 | tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); |
948 | tv0_cache->axis(-2)->parallelize(ParallelType::TIDy); |
949 | |
950 | int numel_x = 99; |
951 | int numel_y = 101; |
952 | |
953 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
954 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
955 | std::vector<IValue> inputs = {t0}; |
956 | |
957 | FusionExecutor fe; |
958 | fe.compileFusion(&fusion, inputs); |
959 | auto outputs = fe.runFusion(inputs); |
960 | |
961 | auto ref = t0; |
962 | for (const auto& offset : offsets) { |
963 | ref = ref + shift(t0, offset); |
964 | } |
965 | ref = ref / int(offsets.size() + 1); |
966 | |
967 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
968 | } |
969 | |
// Shift with tiled-then-merged inner axes; verifies the 2D halo-extended
// allocation of the producer.
TEST_F(NVFuserTest, FusionShiftMerge1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {-1, 1});
  fusion.addOutput(tv2);

  int split_factor = 4;

  // Tile into (4 x 4) blocks, then merge the two intra-tile axes.
  tv2->split(-1, split_factor);
  tv2->split(0, split_factor);
  tv2->reorder({{1, 2}, {2, 1}});
  tv2->merge(2, 3);

  tv0->computeAt(tv2, 2);

  // t1 allocation: (split_factor + 1) * (split_factor + 1) — one halo
  // element per dimension for the {-1, 1} shift
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == split_factor + 1);
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {-1, 1});
  auto ref = t2;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1023 | |
1024 | TEST_F(NVFuserTest, FusionShiftMerge2_CUDA) { |
1025 | Fusion fusion; |
1026 | FusionGuard fg(&fusion); |
1027 | |
1028 | auto tv0 = makeSymbolicTensor(2); |
1029 | fusion.addInput(tv0); |
1030 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
1031 | auto tv2 = shift(tv1, {1, -1}); |
1032 | auto tv3 = shift(tv1, {-1, 1}); |
1033 | auto tv4 = add(tv2, tv3); |
1034 | fusion.addOutput(tv4); |
1035 | |
1036 | int split_factor = 4; |
1037 | |
1038 | tv4->split(-1, split_factor); |
1039 | tv4->split(0, split_factor); |
1040 | tv4->reorder({{1, 2}, {2, 1}}); |
1041 | tv4->merge(2, 3); |
1042 | |
1043 | tv0->computeAt(tv4, -2); |
1044 | |
1045 | // t1 allocation: (split_factor + 2) * (split_factor + 2) |
1046 | GpuLower gpulw(&fusion); |
1047 | for (const auto expr : gpulw.kernel()->unordered_exprs()) { |
1048 | if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) { |
1049 | auto tensor_name = alloc->buffer()->name(); |
1050 | if (tensor_name == 1) { |
1051 | TORCH_CHECK(alloc->shape().size() == 2); |
1052 | for (int i = 0; i < 2; ++i) { |
1053 | auto size = dynamic_cast<Int*>(alloc->shape().at(i)); |
1054 | TORCH_CHECK( |
1055 | size != nullptr && size->isConst() && |
1056 | size->value().value() == split_factor + 2); |
1057 | } |
1058 | } |
1059 | } |
1060 | } |
1061 | |
1062 | int numel_x = 99; |
1063 | int numel_y = 101; |
1064 | |
1065 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1066 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
1067 | std::vector<IValue> inputs = {t0}; |
1068 | |
1069 | FusionExecutor fe; |
1070 | fe.compileFusion(&fusion, inputs); |
1071 | auto outputs = fe.runFusion(inputs); |
1072 | |
1073 | auto t1 = t0 + 1; |
1074 | auto t2 = shift(t1, {1, -1}); |
1075 | auto t3 = shift(t1, {-1, 1}); |
1076 | auto t4 = t2 + t3; |
1077 | |
1078 | TORCH_CHECK(t4.allclose(outputs[0])); |
1079 | } |
1080 | |
// Shift with the producer placed in global memory. The intermediate
// global allocations must be extended by the halo, so their extents are
// symbolic (input size) plus a constant halo of 1.
TEST_F(NVFuserTest, FusionShiftGlobal_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {0, 1});
  auto tv3 = shift(tv1, {-1, 0});
  auto tv4 = add(tv2, tv3);
  fusion.addOutput(tv4);

  // Different split factors on each tensor; tv1 is also merged back.
  tv1->split(-1, 4);
  tv2->split(-1, 8);
  tv3->split(-1, 2);
  tv4->split(-1, 3);

  tv1->merge(-2, -1);

  tv1->setMemoryType(MemoryType::Global);
  tv2->setMemoryType(MemoryType::Global);
  tv3->setMemoryType(MemoryType::Global);

  // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1)
  // Each extent should be lowered as a BinaryOp add of a NamedScalar
  // (the symbolic input size) and the constant halo 1.
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto def =
              dynamic_cast<BinaryOp*>(alloc->shape().at(i)->definition());
          TORCH_CHECK(
              def != nullptr && def->getBinaryOpType() == BinaryOpType::Add);
          TORCH_CHECK(def->as<BinaryOp>()->lhs()->isA<NamedScalar>());
          auto rhs = dynamic_cast<Int*>(def->as<BinaryOp>()->rhs());
          TORCH_CHECK(rhs != nullptr && rhs->isConst());
          int rhs_value = *rhs->value();
          TORCH_CHECK(rhs_value == 1);
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {0, 1});
  auto t3 = shift(t1, {-1, 0});
  auto t4 = t2 + t3;
  auto ref = t4;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1146 | |
// Shift through a producer that is split a second time and merged back.
// The intermediate allocations should be sized by the outer split
// factor plus the halo, regardless of the inner split/merge.
TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(2));
  auto tv3 = shift(tv2, {0, 1});
  fusion.addOutput(tv3);

  int split_factor1 = 8;
  int split_factor2 = 4;

  tv3->split(-1, split_factor1);

  tv0->computeAt(tv3, -2);

  // Additional split + merge on tv1 after computeAt; must not change
  // the halo-extended allocation size.
  tv1->split(-1, split_factor2);
  tv1->merge(-2, -1);

  // t1 and t2 allocation: (split_factor1 + 1)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 2) {
        auto size = dynamic_cast<Int*>(alloc->shape().at(0));
        TORCH_CHECK(
            size != nullptr && size->isConst() &&
            size->value().value() == split_factor1 + 1);
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // +1 then +2 collapses to +3 in the reference.
  auto t1 = t0 + 3;
  auto ref = shift(t1, {0, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1198 | |
// Shift with a two-level tiling: both dimensions are split twice, the
// tile axes reordered, and then merged into a single block axis and a
// single thread axis. Intermediates live in shared memory.
TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(2));
  auto tv3 = shift(tv2, {1, 1});
  fusion.addOutput(tv3);

  auto out = tv3;

  int split_factor1 = 32;
  int split_factor2 = 4;

  // [I0/32, 32/4, 4, I1/32, 32/4, 4] -> reorder -> merge the four
  // intra-tile axes into one, and the two inter-tile axes into one.
  out->split(-1, split_factor1);
  out->split(-1, split_factor2);
  out->split(0, split_factor1);
  out->split(1, split_factor2);
  out->reorder({{3, 1}, {1, 2}, {4, 3}, {2, 4}});
  out->merge(2, 3);
  out->merge(2, 3);
  out->merge(2, 3);
  out->merge(0, 1);

  // Replicate the output schedule on the producers.
  TransformPropagator propagator(out);
  MaxRootDomainInfoSpanningTree(out).traverse(&propagator);

  tv0->computeAt(out, 1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {tv1, tv2});

  for (auto tv : {tv1, tv2}) {
    tv->setMemoryType(MemoryType::Shared);
  }

  // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1)
  // Halo is 1 per tile dimension for the {1, 1} shift.
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 2) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == split_factor1 + 1);
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = shift(t0 + 1 + 2, {1, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1272 | |
// 5-point stencil with the two tile axes merged into a single 1D thread
// block. The shared-memory input cache must be extended by 2 in each
// original tile dimension (halo of 1 on each side).
TEST_F(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 5-pt stencil
  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};

  std::vector<TensorView*> tvs;
  for (const auto& offset : offsets) {
    tvs.push_back(shift(tv0, offset));
  }

  auto tv_out = tv0;

  for (auto tv : tvs) {
    tv_out = add(tv_out, tv);
  }

  // Average of the center point and the four neighbors.
  tv_out = div(tv_out, IrBuilder::create<Double>(tvs.size() + 1));

  fusion.addOutput(tv_out);

  std::vector<int> split_factor({4, 32});

  tv_out->split(-1, split_factor[1]);
  tv_out->split(0, split_factor[0]);
  tv_out->reorder({{1, 2}, {2, 1}});

  auto tv0_cache = tv0->cacheAfter();

  // Merge the inner-most two axes and create
  // a 1D thread block of split_factor1*split_factor2 threads
  tv_out->merge(-2, -1);

  tv0->computeAt(tv_out, 2);

  // Inline completely except for the cache
  for (auto tv : tvs) {
    tv->computeAt(tv_out, -1);
  }

  // Match the cache's axes to the merged output schedule.
  tv0_cache->merge(-2, -1);

  tv_out->axis(-1)->parallelize(ParallelType::TIDx);
  tv_out->axis(1)->parallelize(ParallelType::BIDx);
  tv_out->axis(0)->parallelize(ParallelType::BIDy);

  tv0_cache->setMemoryType(MemoryType::Shared);
  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);

  // cache allocation: (split_factor1 + 2) * (split_factor2 + 2)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == tv0_cache->name()) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(
              size != nullptr && size->isConst() &&
              size->value().value() == split_factor[i] + 2);
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = t0;
  for (const auto& offset : offsets) {
    ref = ref + shift(t0, offset);
  }
  ref = ref / int(offsets.size() + 1);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1361 | |
1362 | TEST_F(NVFuserTest, FusionShiftChain1_CUDA) { |
1363 | Fusion fusion; |
1364 | FusionGuard fg(&fusion); |
1365 | |
1366 | auto tv0 = makeSymbolicTensor(2); |
1367 | fusion.addInput(tv0); |
1368 | auto tv1 = shift(tv0, {0, 1}); |
1369 | auto tv2 = shift(tv1, {0, 1}); |
1370 | fusion.addOutput(tv2); |
1371 | |
1372 | int split_factor = 4; |
1373 | tv2->split(-1, split_factor); |
1374 | |
1375 | tv0->computeAt(tv2, -2); |
1376 | |
1377 | int numel_x = 99; |
1378 | int numel_y = 101; |
1379 | |
1380 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1381 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
1382 | std::vector<IValue> inputs = {t0}; |
1383 | |
1384 | FusionExecutor fe; |
1385 | fe.compileFusion(&fusion, inputs); |
1386 | auto outputs = fe.runFusion(inputs); |
1387 | |
1388 | auto ref = shift(shift(t0, {0, 1}), {0, 1}); |
1389 | |
1390 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1391 | } |
1392 | |
1393 | TEST_F(NVFuserTest, FusionShiftChain2_CUDA) { |
1394 | Fusion fusion; |
1395 | FusionGuard fg(&fusion); |
1396 | |
1397 | auto tv0 = makeSymbolicTensor(2); |
1398 | fusion.addInput(tv0); |
1399 | auto tv1 = shift(tv0, {0, 1}); |
1400 | auto tv2 = shift(tv1, {0, -1}); |
1401 | fusion.addOutput(tv2); |
1402 | |
1403 | tv2->split(-1, 4); |
1404 | |
1405 | tv0->computeAt(tv2, -2); |
1406 | |
1407 | int numel_x = 99; |
1408 | int numel_y = 101; |
1409 | |
1410 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1411 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
1412 | std::vector<IValue> inputs = {t0}; |
1413 | |
1414 | FusionExecutor fe; |
1415 | fe.compileFusion(&fusion, inputs); |
1416 | auto outputs = fe.runFusion(inputs); |
1417 | |
1418 | auto ref = shift(shift(t0, {0, 1}), {0, -1}); |
1419 | |
1420 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1421 | } |
1422 | |
// Chain of two shifts through a pointwise op. Verifies that the halo of
// the earliest producer accumulates both downstream shifts while the
// intermediate only accounts for the remaining one.
TEST_F(NVFuserTest, FusionShiftChain3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = shift(tv1, {0, 1});
  auto tv3 = shift(tv2, {0, 1});
  fusion.addOutput(tv3);

  int split_factor = 4;
  tv3->split(-1, split_factor);

  tv0->computeAt(tv3, -2);

  // Halo size of tv1 is 2 as it needs to account for both of the two
  // shift operations , while that of tv2 is still just 1

  // tv1: (split_factor + 2)
  // tv2: (split_factor + 1)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == 1 || tensor_name == 2) {
        TORCH_CHECK(alloc->shape().size() == 1);
        for (int i = 0; i < 1; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(size != nullptr && size->isConst());
          if (tensor_name == 1) {
            TORCH_CHECK(size->value().value() == split_factor + 2);
          } else if (tensor_name == 2) {
            TORCH_CHECK(size->value().value() == split_factor + 1);
          }
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {0, 1});
  auto t3 = shift(t2, {0, 1});
  auto ref = t3;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1481 | |
1482 | TEST_F(NVFuserTest, FusionShiftChain4_CUDA) { |
1483 | Fusion fusion; |
1484 | FusionGuard fg(&fusion); |
1485 | |
1486 | auto tv0 = makeSymbolicTensor(2); |
1487 | fusion.addInput(tv0); |
1488 | auto tv1 = shift(tv0, {1, -1}); |
1489 | auto tv2 = shift(tv1, {2, -2}); |
1490 | auto tv3 = shift(tv2, {3, -3}); |
1491 | auto tv4 = shift(tv3, {4, -4}); |
1492 | auto tv_out = tv4; |
1493 | |
1494 | fusion.addOutput(tv_out); |
1495 | |
1496 | int split_factor = 4; |
1497 | |
1498 | tv_out->split(-1, split_factor); |
1499 | tv_out->split(0, split_factor); |
1500 | tv_out->reorder({{1, 2}, {2, 1}}); |
1501 | |
1502 | tv0->computeAt(tv_out, 2); |
1503 | |
1504 | tv1->merge(-2, -1); |
1505 | tv2->merge(-2, -1); |
1506 | tv3->merge(-2, -1); |
1507 | |
1508 | // tv1: (split_factor + 9) * (split_factor + 9) |
1509 | // tv2: (split_factor + 7) * (split_factor + 7) |
1510 | // tv3: (split_factor + 4) * (split_factor + 4) |
1511 | GpuLower gpulw(&fusion); |
1512 | for (const auto expr : gpulw.kernel()->unordered_exprs()) { |
1513 | if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) { |
1514 | auto tensor_name = alloc->buffer()->name(); |
1515 | if (tensor_name == 1 || tensor_name == 2) { |
1516 | TORCH_CHECK(alloc->shape().size() == 2); |
1517 | for (int i = 0; i < 2; ++i) { |
1518 | auto size = dynamic_cast<Int*>(alloc->shape().at(i)); |
1519 | TORCH_CHECK(size != nullptr && size->isConst()); |
1520 | auto size_val = size->value().value(); |
1521 | if (tensor_name == 1) { |
1522 | TORCH_CHECK(size_val == split_factor + 9); |
1523 | } else if (tensor_name == 2) { |
1524 | TORCH_CHECK(size_val == split_factor + 7); |
1525 | } else if (tensor_name == 3) { |
1526 | TORCH_CHECK(size_val == split_factor + 4); |
1527 | } |
1528 | } |
1529 | } |
1530 | } |
1531 | } |
1532 | |
1533 | int numel_x = 99; |
1534 | int numel_y = 101; |
1535 | |
1536 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1537 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
1538 | std::vector<IValue> inputs = {t0}; |
1539 | |
1540 | FusionExecutor fe; |
1541 | fe.compileFusion(&fusion, inputs); |
1542 | auto outputs = fe.runFusion(inputs); |
1543 | |
1544 | auto t1 = shift(t0, {1, -1}); |
1545 | auto t2 = shift(t1, {2, -2}); |
1546 | auto t3 = shift(t2, {3, -3}); |
1547 | auto t4 = shift(t3, {4, -4}); |
1548 | auto ref = t4; |
1549 | |
1550 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1551 | } |
1552 | |
// Two chained 5-point stencils. The input cache needs a halo of 4 per
// tile dimension (2 per stencil), while the first stencil's output only
// needs 2 for the remaining stencil.
TEST_F(NVFuserTest, FusionShift5ptStencilChain_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};

  // First stencil: 5pt stencil
  // stencil1 = (tv0 + tv0[+1][0] + tv0[-1][0] + tv0[0][+1] + tv0[0][-1]) / 5
  std::vector<TensorView*> tv_stencil1_shifts;
  for (const auto& offset : offsets) {
    tv_stencil1_shifts.push_back(shift(tv0, offset));
  }

  auto tv_stencil1 = tv0;
  for (auto tv : tv_stencil1_shifts) {
    tv_stencil1 = add(tv_stencil1, tv);
  }

  tv_stencil1 = div(
      tv_stencil1, IrBuilder::create<Double>(tv_stencil1_shifts.size() + 1));

  // Second stencil: Same 5pt stencil
  std::vector<TensorView*> tv_stencil2_shifts;
  for (const auto& offset : offsets) {
    tv_stencil2_shifts.push_back(shift(tv_stencil1, offset));
  }

  auto tv_stencil2 = tv_stencil1;
  for (auto tv : tv_stencil2_shifts) {
    tv_stencil2 = add(tv_stencil2, tv);
  }

  tv_stencil2 = div(
      tv_stencil2, IrBuilder::create<Double>(tv_stencil2_shifts.size() + 1));

  auto tv_out = tv_stencil2;

  fusion.addOutput(tv_out);

  auto tv0_cache = tv0->cacheAfter();

  std::vector<int> split_factor({16, 16});

  // Tile the output 16x16 and move the tile axes innermost.
  tv_out->split(-1, split_factor[1]);
  tv_out->split(0, split_factor[0]);
  tv_out->reorder({{1, 2}, {2, 1}});

  tv0->computeAt(tv_out, 2);

  // Inline completely all inputs to the first stencil output, except for the
  // tv0 cache
  for (auto tv : tv_stencil1_shifts) {
    tv->computeAt(tv_stencil1, -1);
  }

  // Inline completely all inputs to the second stencil output, except
  // for the first stencil output
  for (auto tv : tv_stencil2_shifts) {
    tv->computeAt(tv_stencil2, -1);
  }

  tv_out->axis(1)->parallelize(ParallelType::BIDx);
  tv_out->axis(0)->parallelize(ParallelType::BIDy);

  // Parallelize every tensor on the path from input to output the same
  // way over the tile axes.
  auto all_values = DependencyCheck::getAllValsBetween(
      {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
  for (auto tv : ir_utils::filterByType<TensorView>(all_values)) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
    tv->axis(-2)->parallelize(ParallelType::TIDy);
  }

  tv0_cache->setMemoryType(MemoryType::Shared);
  tv_stencil1->setMemoryType(MemoryType::Shared);

  // tv0_cache: (split_factor + 4) * (split_factor + 4)
  // tv_stencil1: (split_factor + 2) * (split_factor + 2)
  GpuLower gpulw(&fusion);
  for (const auto expr : gpulw.kernel()->unordered_exprs()) {
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      auto tensor_name = alloc->buffer()->name();
      if (tensor_name == tv0_cache->name() ||
          tensor_name == tv_stencil1->name()) {
        TORCH_CHECK(alloc->shape().size() == 2);
        for (int i = 0; i < 2; ++i) {
          auto size = dynamic_cast<Int*>(alloc->shape().at(i));
          TORCH_CHECK(size != nullptr && size->isConst());
          if (tensor_name == tv0_cache->name()) {
            TORCH_CHECK(size->value().value() == split_factor[i] + 4);
          } else if (tensor_name == tv_stencil1->name()) {
            TORCH_CHECK(size->value().value() == split_factor[i] + 2);
          }
        }
      }
    }
  }

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Reference: apply the same 5-point averaging twice with ATen.
  auto stencil1 = t0;
  for (const auto& offset : offsets) {
    stencil1 = stencil1 + shift(t0, offset);
  }
  stencil1 = stencil1 / int(offsets.size() + 1);
  auto stencil2 = stencil1;
  for (const auto& offset : offsets) {
    stencil2 = stencil2 + shift(stencil1, offset);
  }
  stencil2 = stencil2 / int(offsets.size() + 1);
  auto ref = stencil2;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1676 | |
1677 | // Shift a reduced tensor |
1678 | TEST_F(NVFuserTest, FusionShiftReduction1_CUDA) { |
1679 | Fusion fusion; |
1680 | FusionGuard fg(&fusion); |
1681 | |
1682 | auto tv0 = makeSymbolicTensor(2); |
1683 | fusion.addInput(tv0); |
1684 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
1685 | auto tv2 = sum(tv1, {1}); |
1686 | auto tv3 = shift(tv2, {1}); |
1687 | fusion.addOutput(tv3); |
1688 | |
1689 | tv3->split(0, 4); |
1690 | tv0->computeAt(tv3, 1); |
1691 | tv0->computeAt(tv2, -1); |
1692 | |
1693 | const int numel_x = 9; |
1694 | const int numel_y = 11; |
1695 | |
1696 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1697 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
1698 | std::vector<IValue> inputs = {t0}; |
1699 | |
1700 | FusionExecutor fe; |
1701 | fe.compileFusion(&fusion, inputs); |
1702 | auto outputs = fe.runFusion(inputs); |
1703 | |
1704 | auto t1 = t0 + 1; |
1705 | auto t2 = sum(t1, {1}); |
1706 | auto t3 = shift(t2, {1}); |
1707 | auto ref = t3; |
1708 | |
1709 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1710 | } |
1711 | |
1712 | // Parallelized version of FusionShiftReduction1 |
TEST_F(NVFuserTest, FusionShiftReduction2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {1});
  auto tv3 = shift(tv2, {1});
  fusion.addOutput(tv3);

  tv3->split(0, 4);
  tv0->computeAt(tv3, 1);

  // Parallelize the reduction across TIDx.
  tv2->split(-1, 32);
  tv0->computeAt(tv2, -1);

  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  // The reduced tensor is consumed by a shift from a different thread,
  // so it must live in shared memory.
  tv2->setMemoryType(MemoryType::Shared);

  const int numel_x = 201;
  const int numel_y = 301;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = sum(t1, {1});
  auto t3 = shift(t2, {1});
  auto ref = t3;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1752 | |
// Shift of a reduction that is rfactored: the reduction is split and
// partially reduced via rFactor before the final shift.
TEST_F(NVFuserTest, FusionShiftRfactor1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {1});
  auto tv3 = shift(tv2, {1});
  fusion.addOutput(tv3);

  tv3->split(0, 4);
  tv0->computeAt(tv3, 1);

  // Split the reduction axis and rfactor the outer part.
  tv2->split(-1, 32);
  auto rf = tv2->rFactor({-2});
  tv0->computeAt(tv2, -1);
  tv0->computeAt(rf, -1);

  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  // tv2 is consumed by a shift from a different thread, so it must be
  // in shared memory.
  tv2->setMemoryType(MemoryType::Shared);

  const int numel_x = 201;
  const int numel_y = 301;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = sum(t1, {1});
  auto t3 = shift(t2, {1});
  auto ref = t3;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1794 | |
1795 | TEST_F(NVFuserTest, FusionShiftBcast1_CUDA) { |
1796 | Fusion fusion; |
1797 | FusionGuard fg(&fusion); |
1798 | |
1799 | auto tv0 = makeSymbolicTensor(1); |
1800 | fusion.addInput(tv0); |
1801 | auto tv1 = makeSymbolicTensor(2); |
1802 | fusion.addInput(tv1); |
1803 | auto tv2 = broadcast(tv0, {false, true}); |
1804 | auto tv3 = shift(tv2, {0, 1}); |
1805 | auto tv4 = add(tv3, tv1); |
1806 | fusion.addOutput(tv4); |
1807 | |
1808 | tv0->computeAt(tv4, -1); |
1809 | tv1->computeAt(tv4, -1); |
1810 | |
1811 | const int numel_x = 9; |
1812 | const int numel_y = 11; |
1813 | |
1814 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1815 | at::Tensor t0 = at::randn({numel_x}, options); |
1816 | at::Tensor t1 = at::randn({numel_x, numel_y}, options); |
1817 | std::vector<IValue> inputs = {t0, t1}; |
1818 | |
1819 | FusionExecutor fe; |
1820 | fe.compileFusion(&fusion, inputs); |
1821 | auto outputs = fe.runFusion(inputs); |
1822 | |
1823 | auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1; |
1824 | auto ref = t4; |
1825 | |
1826 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1827 | } |
1828 | |
1829 | TEST_F(NVFuserTest, FusionShiftBcast2_CUDA) { |
1830 | Fusion fusion; |
1831 | FusionGuard fg(&fusion); |
1832 | |
1833 | auto tv0 = makeSymbolicTensor(1); |
1834 | fusion.addInput(tv0); |
1835 | auto tv1 = makeSymbolicTensor(2); |
1836 | fusion.addInput(tv1); |
1837 | auto tv2 = broadcast(tv0, {false, true}); |
1838 | auto tv3 = shift(tv2, {1, 0}); |
1839 | auto tv4 = add(tv3, tv1); |
1840 | fusion.addOutput(tv4); |
1841 | |
1842 | tv4->split(0, 4); |
1843 | tv0->computeAt(tv4, 1); |
1844 | |
1845 | const int numel_x = 9; |
1846 | const int numel_y = 11; |
1847 | |
1848 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1849 | at::Tensor t0 = at::randn({numel_x}, options); |
1850 | at::Tensor t1 = at::randn({numel_x, numel_y}, options); |
1851 | std::vector<IValue> inputs = {t0, t1}; |
1852 | |
1853 | FusionExecutor fe; |
1854 | fe.compileFusion(&fusion, inputs); |
1855 | auto outputs = fe.runFusion(inputs); |
1856 | |
1857 | auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); |
1858 | auto t3 = shift(t2, {1, 0}); |
1859 | auto ref = t3 + t1; |
1860 | |
1861 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
1862 | } |
1863 | |
1864 | // Combine ShiftBcast1 and ShiftBcast2 with parallelization |
TEST_F(NVFuserTest, FusionShiftBcast3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  auto tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  // Dim 1 of tv2 is a broadcast domain; shifting along it has no effect
  // on values, only shifts along dim 0 matter.
  auto tv2 = broadcast(tv0, {false, true});
  auto tv3 = shift(tv2, {1, 0});
  auto tv4 = shift(tv2, {0, 1});
  auto tv5 = shift(tv2, {-1, -1});
  auto tv6 = add(tv3, tv4);
  auto tv7 = add(tv6, tv5);
  auto tv8 = add(tv7, tv1);
  fusion.addOutput(tv8);

  tv8->split(0, 4);
  tv8->split(-1, 4);
  tv0->computeAt(tv8, 1);

  tv8->axis(-1)->parallelize(ParallelType::TIDx);
  for (auto tv : {tv8, tv7, tv6, tv5, tv4, tv3, tv2}) {
    tv->axis(1)->parallelize(ParallelType::TIDy);
  }

  // tv2 is consumed across threads by the shifts.
  tv2->setMemoryType(MemoryType::Shared);

  const int numel_x = 101;
  const int numel_y = 201;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x}, options);
  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Reference drops the broadcast-axis components of each shift:
  // tv4's {0, 1} becomes an identity, tv5's {-1, -1} becomes {-1, 0}.
  auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y});
  auto t3 = shift(t2, {1, 0});
  auto t4 = t2;
  auto t5 = shift(t2, {-1, 0});
  auto ref = t3 + t4 + t5 + t1;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
1913 | |
1914 | // See issue #893 |
TEST_F(NVFuserTest, FusionShiftSyncPlacement1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv0, IrBuilder::create<Double>(2));
  auto tv3 = add(tv1, tv2);
  auto tv4 = shift(tv3, {0, 1});
  fusion.addOutput(tv4);

  tv4->split(1, 8);
  tv0->computeAt(tv4, 2);

  tv2->computeAt(tv3, -1);

  // tv1 and tv3 are produced and consumed by different threads (tv3 is
  // read shifted by tv4), so they must be in shared memory; the lowering
  // has to place syncs correctly around them (see issue #893).
  tv1->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = t0 + 2;
  auto t3 = add(t1, t2);
  auto t4 = shift(t3, {0, 1});

  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
}
1957 | |
1958 | // See issue #893. Top-level placement. |
1959 | TEST_F(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { |
1960 | Fusion fusion; |
1961 | FusionGuard fg(&fusion); |
1962 | |
1963 | auto tv0 = makeSymbolicTensor(1); |
1964 | fusion.addInput(tv0); |
1965 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
1966 | auto tv2 = add(tv0, IrBuilder::create<Double>(2)); |
1967 | auto tv3 = add(tv1, tv2); |
1968 | auto tv4 = shift(tv3, {1}); |
1969 | fusion.addOutput(tv4); |
1970 | |
1971 | tv2->computeAt(tv3, -1); |
1972 | |
1973 | tv1->setMemoryType(MemoryType::Shared); |
1974 | tv3->setMemoryType(MemoryType::Shared); |
1975 | |
1976 | tv1->axis(-1)->parallelize(ParallelType::TIDx); |
1977 | tv3->axis(-1)->parallelize(ParallelType::TIDx); |
1978 | tv4->axis(-1)->parallelize(ParallelType::TIDx); |
1979 | |
1980 | int numel_x = 99; |
1981 | |
1982 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1983 | at::Tensor t0 = at::randn({numel_x}, options); |
1984 | std::vector<IValue> inputs = {t0}; |
1985 | |
1986 | FusionExecutor fe; |
1987 | fe.compileFusion(&fusion, inputs); |
1988 | auto outputs = fe.runFusion(inputs); |
1989 | |
1990 | auto t1 = t0 + 1; |
1991 | auto t2 = t0 + 2; |
1992 | auto t3 = add(t1, t2); |
1993 | auto t4 = shift(t3, {1}); |
1994 | |
1995 | testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); |
1996 | } |
1997 | |
1998 | // Based on original CUDA provided by Vishal Mehta. |
1999 | // Major differences with the original version: |
2000 | // - The original version uses additional 2 warps to load the halos |
2001 | // along the Y dimension. The other 10 warps are used to load a 32x10 |
2002 | // tile, and all warps will do coalesced loads. No such optimization |
2003 | // is done in the fuser version. |
TEST_F(NVFuserTest, FusionHdiff_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 3D input field and per-point diffusion coefficient (layout Z, Y, X).
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);
  auto coeff = makeSymbolicTensor(3);
  fusion.addInput(coeff);

  // 4-neighborhood in the Y/X plane; the Z axis is never shifted.
  std::vector<std::vector<int>> offsets{
      {0, 1, 0}, {0, -1, 0}, {0, 0, 1}, {0, 0, -1}};

  // T2, T3, T4, T5
  // NOTE(review): the trailing `false` appears to disable padding of
  // the shifted tensors -- confirm against the shift() overload.
  std::vector<TensorView*> inp_neighbors;
  for (const auto& offset : offsets) {
    inp_neighbors.push_back(shift(inp, offset, false));
  }

  // T8: sum of the four shifted neighbors
  TensorView* sum_of_neighbors = nullptr;
  for (auto inp_neighbor : inp_neighbors) {
    if (sum_of_neighbors == nullptr) {
      sum_of_neighbors = inp_neighbor;
    } else {
      sum_of_neighbors = add(sum_of_neighbors, inp_neighbor);
    }
  }

  // 5-point (negated) Laplacian: 4 * center - sum of neighbors
  // T9 = T0 * 4
  // T10 = T9 - T8
  auto lap = sub(mul(inp, IrBuilder::create<Double>(4)), sum_of_neighbors);

  // X-direction flux: difference of the Laplacian with its shift
  // T11 = shift(T10)
  // T12 = T11 - T10
  auto flx = sub(shift(lap, {0, 0, -1}, false), lap);
  // Flux limiter: zero the flux where it points up the input gradient
  // T14 = T13 - T0
  // T15 = T12 * T14
  // T16 = T15 > 0
  // T17 = T16 ? 0 : T12
  auto flx_cond =
      gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)),
         IrBuilder::create<Double>(0));
  auto flx0 = where(flx_cond, IrBuilder::create<Double>(0), flx);

  // Y-direction flux, same construction as the X direction
  // T18 = shift(T10)
  // T19 = T18 - T10
  auto fly = sub(shift(lap, {0, -1, 0}, false), lap);
  // T20 = shift(T0)
  // T21 = T20 - T0
  // T22 = T19 * T21
  // T23 = T22 > 0
  auto fly_cond =
      gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)),
         IrBuilder::create<Double>(0));
  // T24 = T23 ? 0 : T19
  auto fly0 = where(fly_cond, IrBuilder::create<Double>(0), fly);

  // Final update: out = inp - coeff * (flux divergence)
  // T25 = shift(flx0)
  // T26 = T17 - T25
  // T27 = shift(fly0)
  // T28 = T24 - T27
  // T29 = T26 + T28
  // T30 = T1 * T29
  // T31 = T0 - T30
  auto out =
      sub(inp,
          mul(coeff,
              add(sub(flx0, shift(flx0, {0, 0, 1}, false)),
                  sub(fly0, shift(fly0, {0, 1, 0}, false)))));

  fusion.addOutput(out);

  /////////////////////////////////
  // Scheduling
  /////////////////////////////////

  out->setContiguity(false);

  // Step 1: 2D Tiling
  // Tile the Y/X plane into tile_y x tile_x blocks.

  const int tile_x = 32;
  const int tile_y = 8;

  out->split(-1, tile_x);
  out->split(-3, tile_y);
  out->reorder({{-2, -3}});
  // Position the inputs at the tile level.
  inp->computeAt(out, -3);
  coeff->computeAt(out, -3);

  // Step 2: Inlining
  // Fully inline every intermediate into the nearest of the four
  // "staging" tensors (lap, flx0, fly0, out).

  // Inline inputs to lap
  auto lap_vals = DependencyCheck::getAllValsBetween({inp}, {lap});
  for (auto val : ir_utils::filterByType<TensorView>(lap_vals)) {
    if (val != lap && val != inp) {
      val->computeAt(lap, -1);
    }
  }

  // Inline inputs to flx0
  auto flx0_vals = DependencyCheck::getAllValsBetween({lap, inp}, {flx0});
  for (auto val : ir_utils::filterByType<TensorView>(flx0_vals)) {
    if (val != lap && val != flx0 && val != inp) {
      val->computeAt(flx0, -1);
    }
  }

  // Inline inputs to fly0
  auto flxy_vals = DependencyCheck::getAllValsBetween({lap, inp}, {fly0});
  for (auto val : ir_utils::filterByType<TensorView>(flxy_vals)) {
    if (val != lap && val != fly0 && val != inp) {
      val->computeAt(fly0, -1);
    }
  }

  // Inline inputs to out
  auto out_vals = DependencyCheck::getAllValsBetween({flx0, fly0}, {out});
  for (auto val : ir_utils::filterByType<TensorView>(out_vals)) {
    if (val != flx0 && val != fly0 && val != out) {
      val->computeAt(out, -1);
    }
  }

  // Step 3: Parallelization

  // Block parallelization
  out->axis(0)->parallelize(ParallelType::BIDz);
  out->axis(1)->parallelize(ParallelType::BIDy);
  out->axis(2)->parallelize(ParallelType::BIDx);
  // Thread parallelization
  out->axis(3)->parallelize(ParallelType::TIDy);
  out->axis(4)->parallelize(ParallelType::TIDx);
  // Apply the same parallelization to all other tensors
  scheduler_utils::parallelizeAllLike(out);

  // Store intermediate stencil results on smem so that they can be
  // accessed by threads
  for (auto tv : {flx0, fly0, lap}) {
    tv->setMemoryType(MemoryType::Shared);
  }

  /////////////////////////////////
  int numel_x = 101;
  int numel_y = 99;
  int numel_z = 10;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options);
  at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options);
  std::vector<IValue> inputs = {inp_at, coeff_at};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto fuser_output = fe.runFusion(inputs)[0];

  // Trim the outer rim
  // The stencil reaches two elements deep in Y and X, so the outermost
  // two rows/columns are excluded from the comparison.
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(0, at::indexing::None),
      at::indexing::Slice(2, -2),
      at::indexing::Slice(2, -2)};
  fuser_output = fuser_output.index(indices);

  {
    // ATen reference implementation mirroring the fusion above.
    at::Tensor zeros = at::zeros({numel_z, numel_y, numel_x}, options);
    auto lap = inp_at * 4 -
        (shift(inp_at, {0, 1, 0}) + shift(inp_at, {0, -1, 0}) +
         shift(inp_at, {0, 0, 1}) + shift(inp_at, {0, 0, -1}));
    auto flx = shift(lap, {0, 0, -1}) - lap;
    auto flx_cond = (flx * (shift(inp_at, {0, 0, -1}) - inp_at)) > 0;
    auto flx0 = at::where(flx_cond, zeros, flx);
    auto fly = shift(lap, {0, -1, 0}) - lap;
    auto fly_cond = (fly * (shift(inp_at, {0, -1, 0}) - inp_at)) > 0;
    auto fly0 = at::where(fly_cond, zeros, fly);

    auto ref = inp_at -
        coeff_at *
        ((flx0 - shift(flx0, {0, 0, 1})) + (fly0 - shift(fly0, {0, 1, 0})));
    ref = ref.index(indices);

    testValidate(&fusion, {fuser_output}, inputs, {ref}, __LINE__, __FILE__);
  }
}
2186 | |
TEST_F(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Same horizontal-diffusion stencil as FusionHdiff, but scheduled
  // with partial splits and an unswitched Z-column loop.
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);
  auto coeff = makeSymbolicTensor(3);
  fusion.addInput(coeff);

  // 4-neighborhood in the Y/X plane; the Z axis is never shifted.
  std::vector<std::vector<int>> offsets{
      {0, 1, 0}, {0, -1, 0}, {0, 0, 1}, {0, 0, -1}};

  // T2, T3, T4, T5
  // NOTE(review): the trailing `false` appears to disable padding of
  // the shifted tensors -- confirm against the shift() overload.
  std::vector<TensorView*> inp_neighbors;
  for (const auto& offset : offsets) {
    inp_neighbors.push_back(shift(inp, offset, false));
  }

  // T8: sum of the four shifted neighbors
  TensorView* sum_of_neighbors = nullptr;
  for (auto inp_neighbor : inp_neighbors) {
    if (sum_of_neighbors == nullptr) {
      sum_of_neighbors = inp_neighbor;
    } else {
      sum_of_neighbors = add(sum_of_neighbors, inp_neighbor);
    }
  }

  // 5-point (negated) Laplacian: 4 * center - sum of neighbors
  // T9 = T0 * 4
  // T10 = T9 - T8
  auto lap = sub(mul(inp, IrBuilder::create<Double>(4)), sum_of_neighbors);

  // X-direction flux
  // T11 = shift(T10)
  // T12 = T11 - T10
  auto flx = sub(shift(lap, {0, 0, -1}, false), lap);
  // Flux limiter: zero the flux where it points up the input gradient
  // T14 = T13 - T0
  // T15 = T12 * T14
  // T16 = T15 > 0
  // T17 = T16 ? 0 : T12
  auto flx_cond =
      gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)),
         IrBuilder::create<Double>(0));
  auto flx0 = where(flx_cond, IrBuilder::create<Double>(0), flx);

  // Y-direction flux, same construction as the X direction
  // T18 = shift(T10)
  // T19 = T18 - T10
  auto fly = sub(shift(lap, {0, -1, 0}, false), lap);
  // T20 = shift(T0)
  // T21 = T20 - T0
  // T22 = T19 * T21
  // T23 = T22 > 0
  auto fly_cond =
      gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)),
         IrBuilder::create<Double>(0));
  // T24 = T23 ? 0 : T19
  auto fly0 = where(fly_cond, IrBuilder::create<Double>(0), fly);

  // Final update: out = inp - coeff * (flux divergence)
  // T25 = shift(flx0)
  // T26 = T17 - T25
  // T27 = shift(fly0)
  // T28 = T24 - T27
  // T29 = T26 + T28
  // T30 = T1 * T29
  // T31 = T0 - T30
  auto out =
      sub(inp,
          mul(coeff,
              add(sub(flx0, shift(flx0, {0, 0, 1}, false)),
                  sub(fly0, shift(fly0, {0, 1, 0}, false)))));

  fusion.addOutput(out);

  out->setContiguity(false);

  /////////////////////////////////
  // Scheduling
  /////////////////////////////////

  // Snapshot of every tensor in the fusion, used for parallelization
  // propagation below.
  const auto all_vals = fusion.usedMathVals();
  const std::vector<TensorView*> all_tensors(
      {ir_utils::filterByType<TensorView>(all_vals).begin(),
       ir_utils::filterByType<TensorView>(all_vals).end()});

  // Step 1: Blocking
  // - Thread block size: (tile_x, tile_y)
  // - Each thread computes a vertical column of length tile_z along the Z
  // axis.
  // - Grid dize: (NX / block_x, NY / block_y, NZ / tile_z)

  const int tile_x = 32;
  const int tile_y = 8;
  const int tile_z = 16;

  out->split(0, tile_z);
  // NOTE(review): the two trailing `true` arguments appear to request
  // a partial split (trimming out-of-bounds elements), matching the
  // test name -- confirm against the TensorView::split() signature.
  out->split(-1, tile_x, true, true);
  out->split(-3, tile_y, true, true);
  // out: [NZ/tz, tz, NY/by, by, NX/bx, bx]
  out->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}});
  // out: [NZ/tz, NY/by, NX/bx, tz, by, bx]

  // Replay the same transformations on every other tensor.
  TransformPropagator propagator(out);
  MaxRootDomainInfoSpanningTree(out).traverse(&propagator);

  inp->computeAt(out, 4);

  // Step 2: Inlining
  // Fully inline every intermediate into the nearest of the four
  // "staging" tensors (lap, flx0, fly0, out).

  // Inline inputs to lap
  auto lap_vals = DependencyCheck::getAllValsBetween({inp}, {lap});
  for (auto val : ir_utils::filterByType<TensorView>(lap_vals)) {
    if (val != lap && val != inp) {
      val->computeAt(lap, -1);
    }
  }

  // Inline inputs to flx0
  auto flx0_vals = DependencyCheck::getAllValsBetween({lap, inp}, {flx0});
  for (auto val : ir_utils::filterByType<TensorView>(flx0_vals)) {
    if (val != lap && val != flx0 && val != inp) {
      val->computeAt(flx0, -1);
    }
  }

  // Inline inputs to fly0
  auto flxy_vals = DependencyCheck::getAllValsBetween({lap, inp}, {fly0});
  for (auto val : ir_utils::filterByType<TensorView>(flxy_vals)) {
    if (val != lap && val != fly0 && val != inp) {
      val->computeAt(fly0, -1);
    }
  }

  // Inline inputs to out
  auto out_vals = DependencyCheck::getAllValsBetween({flx0, fly0}, {out});
  for (auto val : ir_utils::filterByType<TensorView>(out_vals)) {
    if (val != flx0 && val != fly0 && val != out) {
      val->computeAt(out, -1);
    }
  }

  // Step 3: Parallelization

  // Block parallelization
  out->axis(0)->parallelize(ParallelType::BIDz);
  out->axis(1)->parallelize(ParallelType::BIDy);
  out->axis(2)->parallelize(ParallelType::BIDx);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);
  // Unswitch at the tz axis
  out->axis(3)->parallelize(ParallelType::Unswitch);

  scheduler_utils::parallelizeAllLike(out, all_tensors);

  // These need to be on smem
  for (auto tv : {flx0, fly0, lap}) {
    tv->setMemoryType(MemoryType::Shared);
  }

  /////////////////////////////////
  // Problem sizes chosen so the interior (minus the halo) divides the
  // tile sizes evenly.
  const int halo_extent = 2;
  const int numel_x = 64 + halo_extent * 2;
  const int numel_y = 64 + halo_extent * 2;
  const int numel_z = 32;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options);
  at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options);
  std::vector<IValue> inputs = {inp_at, coeff_at};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto fuser_output = fe.runFusion(inputs)[0];

  // Trim the outer rim
  // The stencil reaches two elements deep in Y and X, so the outermost
  // two rows/columns are excluded from the comparison.
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(0, at::indexing::None),
      at::indexing::Slice(2, -2),
      at::indexing::Slice(2, -2)};
  fuser_output = fuser_output.index(indices);

  {
    // ATen reference implementation mirroring the fusion above.
    at::Tensor zeros = at::zeros({numel_z, numel_y, numel_x}, options);
    auto lap = inp_at * 4 -
        (shift(inp_at, {0, 1, 0}) + shift(inp_at, {0, -1, 0}) +
         shift(inp_at, {0, 0, 1}) + shift(inp_at, {0, 0, -1}));
    auto flx = shift(lap, {0, 0, -1}) - lap;
    auto flx_cond = (flx * (shift(inp_at, {0, 0, -1}) - inp_at)) > 0;
    auto flx0 = at::where(flx_cond, zeros, flx);
    auto fly = shift(lap, {0, -1, 0}) - lap;
    auto fly_cond = (fly * (shift(inp_at, {0, -1, 0}) - inp_at)) > 0;
    auto fly0 = at::where(fly_cond, zeros, fly);

    auto ref = inp_at -
        coeff_at *
        ((flx0 - shift(flx0, {0, 0, 1})) + (fly0 - shift(fly0, {0, 1, 0})));
    ref = ref.index(indices);

    testValidate(&fusion, {fuser_output}, inputs, {ref}, __LINE__, __FILE__);
  }
}
2386 | |
2387 | // 3x3 max pooling |
TEST_F(NVFuserTest, FusionMaxPooling_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Format: CHW
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // 3x3 pooling of the HW spatial domain
  // Offsets of the 8 neighbors around the center; {0, 0} is skipped
  // because the unshifted input itself serves as the center.
  std::vector<std::vector<int>> offsets;
  for (int i = -1; i <= 1; ++i) {
    for (int j = -1; j <= 1; ++j) {
      if (i == 0 && j == 0) {
        continue;
      }
      offsets.push_back({i, j});
    }
  }

  // The input plus its 8 shifted copies; the prepended 0 keeps the
  // channel dimension unshifted.
  std::vector<TensorView*> inp_tile({inp});
  for (auto offset : offsets) {
    offset.insert(offset.begin(), 0);
    inp_tile.push_back(shift(inp, offset));
  }

  // Reduce the 9 tensors with an element-wise max.
  TensorView* max_tensor = nullptr;
  for (auto tv : inp_tile) {
    if (max_tensor == nullptr) {
      max_tensor = tv;
    } else {
      max_tensor = binaryOp(BinaryOpType::Max, max_tensor, tv);
    }
  }

  fusion.addOutput(max_tensor);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Tiling the spatial domain
  const int tile_x = 32;
  const int tile_y = 8;

  max_tensor->split(-2, tile_y);
  max_tensor->axis(-2)->parallelize(ParallelType::TIDy);
  max_tensor->split(-1, tile_x);
  max_tensor->axis(-1)->parallelize(ParallelType::TIDx);
  max_tensor->reorder({{-3, -2}});

  // Stage the input tile in shared memory so that threads can read
  // their neighbors' elements.
  inp_cache->computeAt(max_tensor, 3);
  inp_cache->axis(-2)->parallelize(ParallelType::TIDy);
  inp_cache->axis(-1)->parallelize(ParallelType::TIDx);
  inp_cache->setMemoryType(MemoryType::Shared);

  // Fully inline every intermediate between the cache and the output.
  auto max_tensor_dep =
      DependencyCheck::getAllValsBetween({inp_cache}, {max_tensor});
  for (auto tv : ir_utils::filterByType<TensorView>(max_tensor_dep)) {
    if (tv == inp_cache || tv == max_tensor) {
      continue;
    }
    tv->computeAt(max_tensor, -1);
  }

  // One block per channel.
  max_tensor->axis(0)->parallelize(ParallelType::BIDx);

  const int hw = 50;
  const int num_channels = 20;
  const int pooling_window = 3;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_inp = at::randn({num_channels, hw, hw}, options);
  // shift always pads by zero, so if all surrounding values are
  // negative, max pooling would pick a padded value, which isn't the
  // correct behavior. We need to be able to choose the value of
  // padding. In this case, padding by the minimum value would not
  // have this problem. For now, avoid the problem by making sure all
  // values are not negative.
  aten_inp = at::abs(aten_inp);
  std::vector<IValue> inputs = {aten_inp};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Reference: 3x3 max pool, stride 1, padding 1 (same-size output).
  auto ref = at::max_pool2d(
      aten_inp, {pooling_window, pooling_window}, {1, 1}, {1, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
2479 | |
2480 | TEST_F(NVFuserTest, FusionGather1_CUDA) { |
2481 | Fusion fusion; |
2482 | FusionGuard fg(&fusion); |
2483 | |
2484 | auto tv0 = makeSymbolicTensor(2); |
2485 | fusion.addInput(tv0); |
2486 | |
2487 | const std::vector<int> window_shape = {1, 3}; |
2488 | const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}}; |
2489 | |
2490 | auto tv1 = gather(tv0, window_shape, padding_width); |
2491 | |
2492 | fusion.addOutput(tv1); |
2493 | |
2494 | const int s1 = 11; |
2495 | const int s2 = 13; |
2496 | |
2497 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2498 | at::Tensor t0 = at::randn({s1, s2}, options); |
2499 | |
2500 | auto ref = gather(t0, window_shape, padding_width); |
2501 | |
2502 | FusionExecutor fe; |
2503 | fe.compileFusion(&fusion, {t0}); |
2504 | auto outputs = fe.runFusion({t0}); |
2505 | |
2506 | TORCH_CHECK(ref.equal(outputs[0])); |
2507 | } |
2508 | |
TEST_F(NVFuserTest, FusionGather2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 1x3 window padded by one on each side of the second dimension, so
  // the gathered output keeps the input's spatial extents.
  const std::vector<int> window_shape = {1, 3};
  const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  auto tv2 = gather(tv1, window_shape, padding_width);

  // Reduce over the innermost (window) axis.
  auto tv3 = sum(tv2, {-1});

  fusion.addOutput(tv3);

  tv3->split(1, 32);
  tv0->computeAt(tv3, 2);
  tv2->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDy);
  tv3->axis(1)->parallelize(ParallelType::BIDx);
  tv3->axis(2)->parallelize(ParallelType::TIDx);
  tv1->axis(2)->parallelize(ParallelType::TIDx);

  // tv1 is consumed by the gather, which reads neighboring elements
  // produced by other threads, so it is staged in shared memory.
  tv1->setMemoryType(MemoryType::Shared);

  const int s1 = 99;
  const int s2 = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1, s2}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen-side reference using the file-local gather helper.
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
2555 | |
2556 | TEST_F(NVFuserTest, FusionGather3_CUDA) { |
2557 | Fusion fusion; |
2558 | FusionGuard fg(&fusion); |
2559 | |
2560 | auto tv0 = makeSymbolicTensor(2); |
2561 | fusion.addInput(tv0); |
2562 | |
2563 | const std::vector<int> window_shape = {1, 3}; |
2564 | const std::vector<std::vector<int>> padding_width = {{0, 0}, {0, 0}}; |
2565 | |
2566 | auto tv1 = gather(tv0, window_shape, padding_width); |
2567 | |
2568 | fusion.addOutput(tv1); |
2569 | |
2570 | const int s1 = 11; |
2571 | const int s2 = 13; |
2572 | |
2573 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2574 | std::vector<int64_t> size({s1, s2}); |
2575 | at::Tensor t0 = at::randn(size, options); |
2576 | size.insert(size.end(), window_shape.begin(), window_shape.end()); |
2577 | // Use a pre-allocated output tensor filled with 1 so that invalid |
2578 | // writes to outside valid ranges can be detected |
2579 | at::Tensor output = at::ones(size, options); |
2580 | |
2581 | FusionExecutor fe; |
2582 | fe.compileFusion(&fusion, {t0}); |
2583 | auto outputs = fe.runFusion({t0}, {output}); |
2584 | |
2585 | auto ref = gather(t0, window_shape, padding_width); |
2586 | TORCH_CHECK(ref.equal(outputs[0])); |
2587 | } |
2588 | |
TEST_F(NVFuserTest, FusionGather4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // 3x3 window with no padding in either dimension.
  const std::vector<int> window_shape = {3, 3};
  const std::vector<std::vector<int>> padding_width = {{0, 0}, {0, 0}};

  auto tv1 = gather(tv0, window_shape, padding_width);

  fusion.addOutput(tv1);

  const int s1 = 11;
  const int s2 = 13;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Output shape: input extents followed by the window extents.
  std::vector<int64_t> size({s1, s2});
  at::Tensor t0 = at::randn(size, options);
  size.insert(size.end(), window_shape.begin(), window_shape.end());
  // Use a pre-allocated output tensor filled with 1 so that invalid
  // writes to outside valid ranges can be detected
  at::Tensor output = at::ones(size, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0}, {output});

  // Bit-exact comparison against the ATen-side gather helper.
  auto ref = gather(t0, window_shape, padding_width);

  TORCH_CHECK(ref.equal(outputs[0]));
}
2622 | |
2623 | TEST_F(NVFuserTest, FusionGather5_CUDA) { |
2624 | Fusion fusion; |
2625 | FusionGuard fg(&fusion); |
2626 | |
2627 | auto tv0 = makeSymbolicTensor(2); |
2628 | fusion.addInput(tv0); |
2629 | |
2630 | const std::vector<int> window_shape = {3, 3}; |
2631 | const std::vector<std::vector<int>> padding_width = {{1, 0}, {0, 1}}; |
2632 | |
2633 | auto tv1 = gather(tv0, window_shape, padding_width); |
2634 | |
2635 | fusion.addOutput(tv1); |
2636 | |
2637 | const int s1 = 11; |
2638 | const int s2 = 13; |
2639 | |
2640 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2641 | std::vector<int64_t> size({s1, s2}); |
2642 | at::Tensor t0 = at::randn(size, options); |
2643 | size.insert(size.end(), window_shape.begin(), window_shape.end()); |
2644 | // Use a pre-allocated output tensor filled with 1 so that invalid |
2645 | // writes to outside valid ranges can be detected |
2646 | at::Tensor output = at::ones(size, options); |
2647 | |
2648 | FusionExecutor fe; |
2649 | fe.compileFusion(&fusion, {t0}); |
2650 | auto outputs = fe.runFusion({t0}, {output}); |
2651 | |
2652 | auto ref = gather(t0, window_shape, padding_width); |
2653 | |
2654 | TORCH_CHECK(ref.equal(outputs[0])); |
2655 | } |
2656 | |
2657 | // Conv-like pattern with no padding |
2658 | TEST_F(NVFuserTest, FusionGather6_CUDA) { |
2659 | Fusion fusion; |
2660 | FusionGuard fg(&fusion); |
2661 | |
2662 | auto tv0 = makeSymbolicTensor(2); |
2663 | fusion.addInput(tv0); |
2664 | |
2665 | const std::vector<int> window_shape = {3, 4}; |
2666 | const std::vector<std::vector<int>> padding_width = {{0, 0}, {0, 0}}; |
2667 | |
2668 | auto tv1 = gather(tv0, window_shape, padding_width); |
2669 | |
2670 | fusion.addOutput(tv1); |
2671 | |
2672 | // Blocking the spatial dimensions |
2673 | const int block_x = 16; |
2674 | const int block_y = 8; |
2675 | |
2676 | auto tv0_cache = tv0->cacheAfter(); |
2677 | auto out = tv1; |
2678 | auto out_cache = out->cacheBefore(); |
2679 | |
2680 | out->split(1, block_x); |
2681 | out->split(0, block_y); |
2682 | out->reorder({{1, 2}, {2, 1}}); |
2683 | |
2684 | TransformPropagator propagator(out); |
2685 | MaxRootDomainInfoSpanningTree(out).traverse(&propagator); |
2686 | |
2687 | tv0->computeAt(out, 2); |
2688 | |
2689 | tv0_cache->setMemoryType(MemoryType::Shared); |
2690 | |
2691 | out->axis(0)->parallelize(ParallelType::BIDy); |
2692 | out->axis(1)->parallelize(ParallelType::BIDx); |
2693 | out->axis(2)->parallelize(ParallelType::TIDy); |
2694 | out->axis(3)->parallelize(ParallelType::TIDx); |
2695 | scheduler_utils::parallelizeAllLike(out); |
2696 | |
2697 | const int s1 = 101; |
2698 | const int s2 = 99; |
2699 | |
2700 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2701 | std::vector<int64_t> size({s1, s2}); |
2702 | at::Tensor t0 = at::randn(size, options); |
2703 | size.insert(size.end(), window_shape.begin(), window_shape.end()); |
2704 | // Use a pre-allocated output tensor filled with 1 so that invalid |
2705 | // writes to outside valid ranges can be detected |
2706 | at::Tensor output = at::ones(size, options); |
2707 | |
2708 | FusionExecutor fe; |
2709 | fe.compileFusion(&fusion, {t0}); |
2710 | auto outputs = fe.runFusion({t0}, {output}); |
2711 | |
2712 | auto ref = gather(t0, window_shape, padding_width); |
2713 | |
2714 | TORCH_CHECK(ref.equal(outputs[0])); |
2715 | } |
2716 | |
2717 | // Conv-like pattern with irregular padding |
2718 | TEST_F(NVFuserTest, FusionGather7_CUDA) { |
2719 | Fusion fusion; |
2720 | FusionGuard fg(&fusion); |
2721 | |
2722 | auto tv0 = makeSymbolicTensor(2); |
2723 | fusion.addInput(tv0); |
2724 | |
2725 | const std::vector<int> window_shape = {3, 4}; |
2726 | const std::vector<std::vector<int>> padding_width = {{0, 2}, {2, 1}}; |
2727 | |
2728 | auto tv1 = gather(tv0, window_shape, padding_width); |
2729 | |
2730 | fusion.addOutput(tv1); |
2731 | |
2732 | // Blocking the spatial dimensions |
2733 | const int block_x = 16; |
2734 | const int block_y = 8; |
2735 | |
2736 | auto tv0_cache = tv0->cacheAfter(); |
2737 | auto out = tv1; |
2738 | auto out_cache = out->cacheBefore(); |
2739 | |
2740 | out->split(1, block_x); |
2741 | out->split(0, block_y); |
2742 | out->reorder({{1, 2}, {2, 1}}); |
2743 | |
2744 | TransformPropagator propagator(out); |
2745 | MaxRootDomainInfoSpanningTree(out).traverse(&propagator); |
2746 | |
2747 | tv0->computeAt(out, 2); |
2748 | |
2749 | tv0_cache->setMemoryType(MemoryType::Shared); |
2750 | |
2751 | out->axis(0)->parallelize(ParallelType::BIDy); |
2752 | out->axis(1)->parallelize(ParallelType::BIDx); |
2753 | out->axis(2)->parallelize(ParallelType::TIDy); |
2754 | out->axis(3)->parallelize(ParallelType::TIDx); |
2755 | scheduler_utils::parallelizeAllLike(out); |
2756 | |
2757 | const int s1 = 101; |
2758 | const int s2 = 99; |
2759 | |
2760 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2761 | std::vector<int64_t> size({s1, s2}); |
2762 | at::Tensor t0 = at::randn(size, options); |
2763 | size.insert(size.end(), window_shape.begin(), window_shape.end()); |
2764 | at::Tensor output = at::ones(size, options); |
2765 | |
2766 | FusionExecutor fe; |
2767 | fe.compileFusion(&fusion, {t0}); |
2768 | auto outputs = fe.runFusion({t0}, {output}); |
2769 | |
2770 | auto ref = gather(t0, window_shape, padding_width); |
2771 | |
2772 | TORCH_CHECK(ref.equal(outputs[0])); |
2773 | } |
2774 | |
2775 | // With no padding but with striding |
TEST_F(NVFuserTest, FusionGather8_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // 2x3 window, no padding, stride 3 along both dimensions.
  const std::vector<int> window_shape = {2, 3};
  const std::vector<std::vector<int>> padding_width = {{0, 0}, {0, 0}};
  const std::vector<int> strides = {3, 3};

  auto tv1 = gather(tv0, window_shape, padding_width, strides);

  fusion.addOutput(tv1);

  const int s1 = 11;
  const int s2 = 13;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  std::vector<int64_t> size({s1, s2});
  at::Tensor t0 = at::randn(size, options);
  // Strided output extent per dimension:
  //   ceilDiv(extent - window + 1 + pad_left + pad_right, stride)
  for (const auto i : c10::irange(size.size())) {
    size[i] = ceilDiv(
        size[i] - window_shape[i] + 1 + padding_width[i][0] +
            padding_width[i][1],
        strides[i]);
  }
  // Append the window extents to form the full output shape.
  size.insert(size.end(), window_shape.begin(), window_shape.end());
  // Use a pre-allocated output tensor filled with 1 so that invalid
  // writes to outside valid ranges can be detected
  at::Tensor output = at::ones(size, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0}, {output});

  auto ref = gather(t0, window_shape, padding_width, strides);

  TORCH_CHECK(ref.equal(outputs[0]));
}
2816 | |
2817 | // Similar to Gather8 but with splitting and parallelization |
// Similar to Gather8 but with splitting and parallelization
TEST_F(NVFuserTest, FusionGather9_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  const std::vector<int> window_shape = {3, 4};
  const std::vector<std::vector<int>> padding_width = {{0, 0}, {0, 0}};
  const std::vector<int> strides = {2, 2};

  auto tv1 = gather(tv0, window_shape, padding_width, strides);

  fusion.addOutput(tv1);

  // Blocking the spatial dimensions
  const int block_x = 16;
  const int block_y = 8;

  auto tv0_cache = tv0->cacheAfter();
  auto out = tv1;
  auto out_cache = out->cacheBefore();

  // Tile the two spatial axes into [outer_y, outer_x, block_y, block_x]
  out->split(1, block_x);
  out->split(0, block_y);
  out->reorder({{1, 2}, {2, 1}});

  // Propagate the tiling transforms to all producers in the fusion
  TransformPropagator propagator(out);
  MaxRootDomainInfoSpanningTree(out).traverse(&propagator);

  tv0->computeAt(out, 2);

  // Stage the gathered input tile in shared memory
  tv0_cache->setMemoryType(MemoryType::Shared);

  out->axis(0)->parallelize(ParallelType::BIDy);
  out->axis(1)->parallelize(ParallelType::BIDx);
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(3)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(out);

  const int s1 = 101;
  const int s2 = 99;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  std::vector<int64_t> size({s1, s2});
  at::Tensor t0 = at::randn(size, options);
  // Compute the expected output shape of the strided gather:
  // ceilDiv(extent - window + 1 + pads, stride) per spatial axis, with
  // the window dimensions appended.
  for (const auto i : c10::irange(size.size())) {
    size[i] = ceilDiv(
        size[i] - window_shape[i] + 1 + padding_width[i][0] +
            padding_width[i][1],
        strides[i]);
  }
  size.insert(size.end(), window_shape.begin(), window_shape.end());
  // Use a pre-allocated output tensor filled with 1 so that invalid
  // writes to outside valid ranges can be detected
  at::Tensor output = at::ones(size, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0}, {output});

  auto ref = gather(t0, window_shape, padding_width, strides);

  TORCH_CHECK(ref.equal(outputs[0]));
}
2883 | |
// 3x3 "same" convolution expressed as gather + broadcast + mul + sum:
// each input element gathers its 3x3 neighborhood, which is multiplied
// with the weights and reduced over channel and window dimensions.
// Validated against at::conv2d.
TEST_F(NVFuserTest, FusionConv2D_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 3, 3]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [1, 3, 3] with padding size of 1 for each
  // side of the spatial dimensions
  auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}});
  // inp_tile: [C, H, W, 1, 3, 3]

  // Broadcast both operands to the common 7-D shape [K, C, H, W, 1, 3, 3]
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference: stride 1, padding 1 conv2d
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
2975 | |
// Same as FusionConv2D but with a "valid" (no padding) convolution: the
// gather trims the spatial extents to H-2 x W-2. TF32 is disabled in
// cuDNN so the ATen reference matches the fuser's fp32 math.
TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
  ContextCudnnTF32Disabled disabling_tf32_cudnn;

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 3, 3]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [1, 3, 3] with no padding; the trailing
  // "true" trims the output extents to the valid region
  auto inp_tile =
      gather(inp, {1, 3, 3}, {{0, 0}, {0, 0}, {0, 0}}, {1, 1, 1}, true);
  // inp_tile: [C, H-2, W-2, 1, 3, 3]

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference: stride 1, padding 0 ("valid") conv2d
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  std::vector<int64_t> stride = {1, 1};
  std::vector<int64_t> padding = {0, 0};
  auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding);
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3070 | |
// 2x2 convolution with no padding and stride 2, implemented via a
// strided gather. Validated against at::conv2d with stride {2, 2}.
TEST_F(NVFuserTest, FusionConv2DNoPaddingStrided_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 2, 2]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [2, 2] with no padding and strides of
  // [2, 2]
  auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 0}, {0, 0}}, {1, 2, 2});
  // inp_tile: [C, H/2, W/2, 1, 2, 2]

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference: stride 2, padding 0 conv2d
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  std::vector<int64_t> stride = {2, 2};
  std::vector<int64_t> padding = {0, 0};
  auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding);
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3164 | |
3165 | // 5x5 followed by 3x3 |
3166 | TEST_F(NVFuserTest, FusionConv2DChain_CUDA) { |
3167 | const int dim_w1_h = 5; |
3168 | const int dim_w1_w = 5; |
3169 | const int dim_pad1_h = (dim_w1_h - 1) / 2; |
3170 | const int dim_pad1_w = (dim_w1_w - 1) / 2; |
3171 | const int dim_w2_h = 3; |
3172 | const int dim_w2_w = 3; |
3173 | const int dim_pad2_h = (dim_w2_h - 1) / 2; |
3174 | const int dim_pad2_w = (dim_w2_w - 1) / 2; |
3175 | |
3176 | Fusion fusion; |
3177 | FusionGuard fg(&fusion); |
3178 | |
3179 | // Input: [K1, H, W] |
3180 | auto inp = makeSymbolicTensor(3); |
3181 | fusion.addInput(inp); |
3182 | |
3183 | // Weights: [K2, K1, S1, T1] |
3184 | auto w1 = makeSymbolicTensor(4); |
3185 | fusion.addInput(w1); |
3186 | |
3187 | // Weights: [K3, K2, S2, T2] |
3188 | auto w2 = makeSymbolicTensor(4); |
3189 | fusion.addInput(w2); |
3190 | |
3191 | // Gather a neighbor tile of [w1_h, w1_w] with padding |
3192 | auto inp_tile = gather( |
3193 | inp, |
3194 | {1, dim_w1_h, dim_w1_w}, |
3195 | {{0, 0}, {dim_pad1_h, dim_pad1_h}, {dim_pad1_w, dim_pad1_w}}); |
3196 | // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] |
3197 | |
3198 | auto inp_bc = |
3199 | broadcast(inp_tile, {true, false, false, false, false, false, false}); |
3200 | auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false}); |
3201 | |
3202 | auto inp_times_w1 = mul(inp_bc, w1_bc); |
3203 | |
3204 | // Reduce the channel and neighbor tile dimensions |
3205 | auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); |
3206 | |
3207 | // Second conv |
3208 | auto out1_tile = gather( |
3209 | out1, |
3210 | {1, dim_w2_h, dim_w2_w}, |
3211 | {{0, 0}, {dim_pad2_h, dim_pad2_h}, {dim_pad2_w, dim_pad2_w}}); |
3212 | |
3213 | auto out1_bc = |
3214 | broadcast(out1_tile, {true, false, false, false, false, false, false}); |
3215 | auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); |
3216 | |
3217 | auto out1_times_w2 = mul(out1_bc, w2_bc); |
3218 | |
3219 | auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); |
3220 | |
3221 | fusion.addOutput(out2); |
3222 | |
3223 | //////////////////////////////////// |
3224 | // Cache the input and weight tensors |
3225 | auto inp_cache = inp->cacheAfter(); |
3226 | |
3227 | // Blocking the spatial dimensions |
3228 | const int block_w = 16; |
3229 | const int block_h = 4; |
3230 | |
3231 | out2->split(2, block_h); |
3232 | out2->split(4, block_w); |
3233 | out2->reorder({{3, 4}}); |
3234 | // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] |
3235 | |
3236 | // Create a [block_x, block_y] tile on smem |
3237 | inp_cache->computeAt(out2, 4); |
3238 | // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] |
3239 | inp_cache->setMemoryType(MemoryType::Shared); |
3240 | |
3241 | // Move Ci forward |
3242 | out1->reorder({{5, 3}, {3, 4}, {4, 5}}); |
3243 | out1->setMemoryType(MemoryType::Shared); |
3244 | |
3245 | inp_cache->computeAt(out1, 4); |
3246 | |
3247 | inp_tile->computeAt(out1, -1); |
3248 | w1->computeAt(out1, -1); |
3249 | |
3250 | out1_tile->computeAt(out2, -1); |
3251 | w2->computeAt(out2, -1); |
3252 | |
3253 | out2->axis(0)->parallelize(ParallelType::BIDx); |
3254 | out2->axis(4)->parallelize(ParallelType::TIDy); |
3255 | out2->axis(5)->parallelize(ParallelType::TIDx); |
3256 | |
3257 | scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); |
3258 | |
3259 | const int dim_h = 99; |
3260 | const int dim_w = 101; |
3261 | const int dim_k1 = 3; |
3262 | const int dim_k2 = 5; |
3263 | const int dim_k3 = 7; |
3264 | |
3265 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
3266 | at::manual_seed(0); |
3267 | at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options); |
3268 | at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); |
3269 | at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); |
3270 | std::vector<IValue> inputs = {at_inp, at_w1, at_w2}; |
3271 | |
3272 | FusionExecutor fe; |
3273 | fe.compileFusion(&fusion, inputs); |
3274 | auto cg_outputs = fe.runFusion(inputs); |
3275 | |
3276 | at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis |
3277 | auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); |
3278 | auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); |
3279 | at_out2 = at_out2.squeeze(0); // drop the N axis |
3280 | |
3281 | testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); |
3282 | } |
3283 | |
// 2x2 (even-sized window) convolution: padding is applied only on the
// right side so the output spatial extents match the input's. The ATen
// reference pads both sides, so its leading row/column is dropped
// before comparison.
TEST_F(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 2, 2]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [2, 2] with padding size of 1 only for
  // the right side of the spatial dimensions. The left padding is
  // zero so that the output axis stays the same.
  auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}});
  // inp_tile: [C, H, W, 1, 2, 2]

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
  at_out = at_out.squeeze(0); // drop the N axis
  // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas
  // the fuser output has dim_h*dim_w. Drop the first elements to make
  // it match with the fuser output.
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(0, at::indexing::None),
      at::indexing::Slice(1, at::indexing::None),
      at::indexing::Slice(1, at::indexing::None)};
  at_out = at_out.index(indices);

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3384 | |
// 4x4 convolution with padding 1 on each side; since the padding does
// not fully compensate for the even window, the gather trims the
// output extents to (H-1) x (W-1).
TEST_F(NVFuserTest, FusionConv4x4Pad1x1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 4, 4]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [4, 4] with padding size of 1 for both
  // sides of the spatial dimensions. The resulting extent is
  // decreased by one.
  auto inp_tile =
      gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 1, 1}, true);
  // inp_tile: [C, H-1, W-1, 1, 4, 4]

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 4]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 4]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 4]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference in double precision to reduce accumulation error
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out =
      at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, 1);
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3479 | |
// 4x5 convolution with asymmetric window sizes and paddings of 1 (H)
// and 2 (W); the trimmed gather yields (H-1) x W output extents.
TEST_F(NVFuserTest, FusionConv4x5Pad1x2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 4, 5]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [4, 5] with padding size of 1 and 2 for
  // each side of the spatial dimensions.
  auto inp_tile =
      gather(inp, {1, 4, 5}, {{0, 0}, {1, 1}, {2, 2}}, {1, 1, 1}, true);
  // inp_tile: [C, H-1, W, 1, 4, 5]

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  // Blocking the channel dimension
  const int block_c = 8;

  out->split(2, block_h);
  out->split(4, block_w);
  out->reorder({{3, 4}});
  // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 5]

  out->split(1, block_c);
  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 5]

  // rFactor the Co and window reduction axes into out_rf, leaving only
  // the Ci reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 5]
  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
  inp_cache->setMemoryType(MemoryType::Shared);

  // Move Ci forward
  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
  inp_cache->computeAt(out_rf, 5);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 4, 5}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference in double precision; padding {1, 2} matches the gather
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out =
      at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, {1, 2});
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3573 | |
// 4x4 convolution with padding 1 and stride 4, using a strided gather
// and an Unswitch axis in the schedule.
TEST_F(NVFuserTest, FusionConv4x4Pad1x1Stride4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 4, 4]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [4, 4] with padding size of 1 for both
  // sides of the spatial dimensions. Set the stride width as 4.
  auto inp_tile = gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 4, 4});
  // inp_tile: [C, H/4, W/4, 1, 4, 4] (strided spatial axes)

  // Broadcast both operands to a common 7-D shape
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  const int block_c = 2;

  // [K, C, H/s, W/s, 1, 4, 4]
  out->split(2, block_h);
  // [K, C, H/s/block_h, block_h, W/s, 1, 4, 4]
  out->split(4, block_w);
  // [K, C, H/s/block_h, block_h, W/s/block_w, block_w, 1, 4, 4]
  out->reorder({{3, 4}});
  // [K, C, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4, 4]
  out->split(1, block_c);
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4,
  // 4]
  // Unit split to create an axis for Unswitch below
  out->split(4, 1);
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1,
  // 4, 4]

  // rFactor the C/block_c and window reduction axes into out_rf,
  // leaving only the block_c reduction in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1,
  // 4, 4]

  // out: [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w]

  inp_cache->computeAt(out, 5);
  inp_cache->setMemoryType(MemoryType::Shared);
  // [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, C/block_c, 1,
  // 4, 4]

  // Move C/block_c before block_h/2 and share the domain from
  // inp_cache to out_rf
  out_rf->reorder({{7, 5}, {5, 6}, {6, 7}});
  inp_cache->computeAt(out_rf, 6);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::Unswitch);
  out->axis(5)->parallelize(ParallelType::TIDy);
  out->axis(6)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  // Reference in double precision; stride 4, padding {1, 1}
  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out =
      at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 4, {1, 1});
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3675 | |
3676 | // POC implementation of im2col for 3-by-3 kernels |
TEST_F(NVFuserTest, FusionIm2Col_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [N, C, H, W]
  auto inp = makeSymbolicTensor(4);
  fusion.addInput(inp);

  // Gather a neighbor tile of [3, 3] with padding size of 1 for each
  // side of the spatial dimensions
  auto inp_tile = gather(inp, {1, 1, 3, 3}, {{0, 0}, {0, 0}, {1, 1}, {1, 1}});
  // inp_tile: [N, C, H, W, 1, 1, 3, 3]

  // Move the channel axis after the spatial axes so each output
  // position holds its [C, 3, 3] patch together
  auto inp_col = permute(inp_tile, {0, 2, 3, 1, 4, 5, 6, 7});
  // inp_col: [N, H, W, C, 1, 1, 3, 3]

  fusion.addOutput(inp_col);

  ////////////////////////////////////

  // Cache the input tensor
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;

  auto out = inp_col;

  out->split(1, block_h);
  out->split(3, block_w);
  out->reorder({{2, 3}});
  // out: [N, Ho, Wo, Hi, Wi, C, 1, 1, 3, 3]
  // Move the C axis out of Hi*Wi
  out->reorder({{5, 3}, {3, 4}, {4, 5}});
  // out: [N, Ho, Wo, C, Hi, Wi, 1, 1, 3, 3]

  // Create a [block_x, block_y] tile on smem
  inp_cache->computeAt(out, 4);
  inp_cache->setMemoryType(MemoryType::Shared);
  // Fully inline inp_tile
  inp_tile->computeAt(out, -1);

  // One CTA per [N, Ho, Wo] tile; threads cover the [Hi, Wi] tile
  out->axis(0)->parallelize(ParallelType::BIDz);
  out->axis(1)->parallelize(ParallelType::BIDy);
  out->axis(2)->parallelize(ParallelType::BIDx);
  out->axis(4)->parallelize(ParallelType::TIDy);
  out->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, inp_tile});

  // Sizes deliberately not multiples of the block sizes to exercise
  // boundary handling
  const int dim_h = 31;
  const int dim_w = 33;
  const int dim_c = 5;
  const int dim_n = 3;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_n, dim_c, dim_h, dim_w}, options);
  std::vector<IValue> inputs = {at_inp};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  auto at_out = at::im2col(at_inp, {3, 3}, {1, 1}, {1, 1}, {1, 1});

  // at::im2col outputs [N, C*3*3, H*W]; transpose and reshape it to
  // match the fusion output layout of [N, H, W, C, 1, 1, 3, 3]
  at_out = at::transpose(at_out, 1, 2);
  at_out = at::reshape(at_out, {dim_n, dim_h, dim_w, dim_c, 1, 1, 3, 3});

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
3750 | |
TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  // Non-padded shifts: the one-element borders of tv2/tv3 (and thus
  // tv4) are not computed
  auto tv2 = shift(tv1, {1, -1}, false);
  auto tv3 = shift(tv1, {-1, 1}, false);
  auto tv4 = add(tv2, tv3);
  auto tv5 = sum(tv4, {0, 1});

  fusion.addOutput(tv5);

  // tv1 is read at shifted offsets by both tv2 and tv3
  tv1->setMemoryType(MemoryType::Shared);

  tv5->split(0, 4);
  tv5->split(-1, 8);
  tv5->reorder({{1, 2}});

  TransformPropagator propagator(tv5);
  MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator);

  tv2->computeAt(tv5, -1);
  tv3->computeAt(tv5, -1);

  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
  scheduler_utils::parallelizeAllLike(tv5);

  // Deliberately not multiples of the split factors
  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {1, -1});
  auto t3 = shift(t1, {-1, 1});
  auto t4 = t2 + t3;
  // Drop the invalid one-element borders before the reduction
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(1, -1), at::indexing::Slice(1, -1)};
  t4 = t4.index(indices);
  auto ref = t4.sum(at::ArrayRef<int64_t>{0, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
3805 | |
3806 | // Split and merge |
3807 | TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) { |
3808 | Fusion fusion; |
3809 | FusionGuard fg(&fusion); |
3810 | |
3811 | auto tv0 = makeSymbolicTensor(2); |
3812 | fusion.addInput(tv0); |
3813 | |
3814 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
3815 | auto tv2 = shift(tv1, {1, -1}, false); |
3816 | auto tv3 = shift(tv1, {-1, 1}, false); |
3817 | auto tv4 = add(tv2, tv3); |
3818 | auto tv5 = sum(tv4, {0, 1}); |
3819 | |
3820 | fusion.addOutput(tv5); |
3821 | |
3822 | tv1->setMemoryType(MemoryType::Shared); |
3823 | |
3824 | tv5->split(0, 4); |
3825 | tv5->split(-1, 8); |
3826 | tv5->reorder({{1, 2}}); |
3827 | tv5->merge(-2, -1); |
3828 | |
3829 | TransformPropagator propagator(tv5); |
3830 | MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); |
3831 | |
3832 | tv2->computeAt(tv5, -1); |
3833 | tv3->computeAt(tv5, -1); |
3834 | |
3835 | tv5->axis(-1)->parallelize(ParallelType::TIDx); |
3836 | scheduler_utils::parallelizeAllLike(tv5); |
3837 | |
3838 | int numel_x = 99; |
3839 | int numel_y = 101; |
3840 | |
3841 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
3842 | at::manual_seed(0); |
3843 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
3844 | std::vector<IValue> inputs = {t0}; |
3845 | |
3846 | FusionExecutor fe; |
3847 | fe.compileFusion(&fusion, inputs); |
3848 | auto outputs = fe.runFusion(inputs); |
3849 | |
3850 | auto t1 = t0 + 1; |
3851 | auto t2 = shift(t1, {1, -1}); |
3852 | auto t3 = shift(t1, {-1, 1}); |
3853 | auto t4 = t2 + t3; |
3854 | std::vector<at::indexing::TensorIndex> indices{ |
3855 | at::indexing::Slice(1, -1), at::indexing::Slice(1, -1)}; |
3856 | t4 = t4.index(indices); |
3857 | auto ref = t4.sum(at::ArrayRef<int64_t>{0, 1}); |
3858 | |
3859 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
3860 | } |
3861 | |
3862 | // Split and merge, then welford |
TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  // Non-padded shifts: the one-element borders of tv4 are invalid
  auto tv2 = shift(tv1, {1, -1}, false);
  auto tv3 = shift(tv1, {-1, 1}, false);
  auto tv4 = add(tv2, tv3);
  // Welford reduction over both dimensions; avg, M2 (var_sum), and
  // the element count are all outputs
  auto tvs = Welford(tv4, {0, 1});
  auto tv_avg = tvs.avg;
  auto tv_M2 = tvs.var_sum;
  auto tv_N = tvs.n;

  fusion.addOutput(tv_avg);
  fusion.addOutput(tv_M2);
  fusion.addOutput(tv_N);

  // tv1 is read at shifted offsets by tv2 and tv3
  tv1->setMemoryType(MemoryType::Shared);

  tv_avg->split(0, 4);
  tv_avg->split(-1, 8);
  tv_avg->reorder({{1, 2}});
  tv_avg->merge(-2, -1);

  TransformPropagator propagator(tv_avg);
  MaxRootDomainInfoSpanningTree(tv_avg).traverse(&propagator);

  tv2->computeAt(tv_avg, -1);
  tv3->computeAt(tv_avg, -1);

  tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(tv_avg);

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // var_sum is the unnormalized sum of squared deviations; divide by
  // the count of valid elements to compare against t4.var(..., false)
  outputs[1] /= (numel_x - 2) * (numel_y - 2);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {1, -1});
  auto t3 = shift(t1, {-1, 1});
  auto t4 = t2 + t3;
  // Drop the invalid one-element borders before the reductions
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(1, -1), at::indexing::Slice(1, -1)};
  t4 = t4.index(indices);
  auto ref_avg = t4.mean(at::ArrayRef<int64_t>{0, 1});
  auto ref_M2 = t4.var(at::ArrayRef<int64_t>{0, 1}, false);
  auto ref_N = at::ones({}, options_int) * (numel_x - 2) * (numel_y - 2);

  testValidate(
      fe.kernel(),
      outputs,
      inputs,
      {ref_avg, ref_M2, ref_N},
      __LINE__,
      __FILE__);
}
3933 | |
3934 | // Shift indexing and predication with contiguous merge |
3935 | TEST_F(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { |
3936 | Fusion fusion; |
3937 | FusionGuard fg(&fusion); |
3938 | |
3939 | auto tv0 = makeSymbolicTensor(2); |
3940 | fusion.addInput(tv0); |
3941 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
3942 | auto tv2 = shift(tv1, {1, -1}, true); |
3943 | auto tv3 = shift(tv1, {-1, 1}, false); |
3944 | auto tv4 = add(tv2, tv3); |
3945 | fusion.addOutput(tv4); |
3946 | |
3947 | tv2->merge(0); |
3948 | tv3->merge(0); |
3949 | tv4->merge(0); |
3950 | |
3951 | tv1->setMemoryType(MemoryType::Global); |
3952 | tv2->setMemoryType(MemoryType::Global); |
3953 | tv3->setMemoryType(MemoryType::Global); |
3954 | |
3955 | int numel_x = 9; |
3956 | int numel_y = 11; |
3957 | |
3958 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
3959 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
3960 | std::vector<IValue> inputs = {t0}; |
3961 | |
3962 | FusionExecutor fe; |
3963 | fe.compileFusion(&fusion, inputs); |
3964 | auto outputs = fe.runFusion(inputs); |
3965 | |
3966 | std::vector<at::indexing::TensorIndex> indices{ |
3967 | at::indexing::Slice(1, -1), at::indexing::Slice(1, -1)}; |
3968 | |
3969 | auto fuser_out = outputs[0].index(indices); |
3970 | |
3971 | auto t1 = t0 + 1; |
3972 | auto t2 = shift(t1, {1, -1}); |
3973 | auto t3 = shift(t1, {-1, 1}); |
3974 | auto ref = t2 + t3; |
3975 | |
3976 | ref = ref.index(indices); |
3977 | |
3978 | testValidate(&fusion, {fuser_out}, inputs, {ref}, __LINE__, __FILE__); |
3979 | } |
3980 | |
3981 | TEST_F(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { |
3982 | Fusion fusion; |
3983 | FusionGuard fg(&fusion); |
3984 | |
3985 | auto tv0 = makeSymbolicTensor(2); |
3986 | fusion.addInput(tv0); |
3987 | |
3988 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
3989 | auto tv2 = shift(tv1, {1, -1}, false); |
3990 | auto tv3 = shift(tv2, {1, -1}, false); |
3991 | auto tv4 = sum(tv3, {0, 1}); |
3992 | fusion.addOutput(tv4); |
3993 | |
3994 | tv1->setMemoryType(MemoryType::Shared); |
3995 | tv2->setMemoryType(MemoryType::Shared); |
3996 | |
3997 | tv4->split(0, 4); |
3998 | tv4->split(-1, 8); |
3999 | tv4->reorder({{1, 2}}); |
4000 | |
4001 | tv1->computeAt(tv4, 2); |
4002 | |
4003 | tv4->axis(-1)->parallelize(ParallelType::TIDx); |
4004 | tv4->axis(-2)->parallelize(ParallelType::TIDy); |
4005 | |
4006 | tv4->axis(0)->parallelize(ParallelType::BIDy); |
4007 | tv4->axis(1)->parallelize(ParallelType::BIDx); |
4008 | |
4009 | scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3}); |
4010 | |
4011 | int numel_x = 99; |
4012 | int numel_y = 101; |
4013 | |
4014 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4015 | auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); |
4016 | at::manual_seed(0); |
4017 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
4018 | std::vector<IValue> inputs = {t0}; |
4019 | |
4020 | FusionExecutor fe; |
4021 | fe.compileFusion(&fusion, inputs); |
4022 | auto outputs = fe.runFusion(inputs); |
4023 | |
4024 | auto t1 = t0 + 1; |
4025 | auto t2 = shift(t1, {1, -1}); |
4026 | auto t3 = shift(t2, {1, -1}); |
4027 | std::vector<at::indexing::TensorIndex> indices{ |
4028 | at::indexing::Slice(2, at::indexing::None), at::indexing::Slice(0, -2)}; |
4029 | t3 = t3.index(indices); |
4030 | auto ref = t3.sum(at::ArrayRef<int64_t>{0, 1}); |
4031 | |
4032 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
4033 | } |
4034 | |
4035 | // Rfactor is not allowed with partial domains |
4036 | TEST_F(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { |
4037 | Fusion fusion; |
4038 | FusionGuard fg(&fusion); |
4039 | |
4040 | auto tv0 = makeSymbolicTensor(2); |
4041 | fusion.addInput(tv0); |
4042 | |
4043 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
4044 | auto tv2 = shift(tv1, {1, -1}, false); |
4045 | auto tv3 = sum(tv2, {0, 1}); |
4046 | fusion.addOutput(tv3); |
4047 | |
4048 | tv3->split(0, 4); |
4049 | tv3->split(-1, 8); |
4050 | tv3->reorder({{1, 2}}); |
4051 | |
4052 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
4053 | ASSERT_ANY_THROW(tv3->rFactor({-2})); |
4054 | } |
4055 | |
TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  // Shifts with explicit padding widths that only partially cover the
  // shift offsets, so part of the boundary of tv4 remains invalid
  auto tv2 = shift(tv1, {2, -2}, {1, 1});
  auto tv3 = shift(tv1, {-3, 2}, {2, 2});
  auto tv4 = add(tv2, tv3);
  auto tv5 = sum(tv4, {0, 1});

  fusion.addOutput(tv5);

  // tv1 is read at shifted offsets by tv2 and tv3
  tv1->setMemoryType(MemoryType::Shared);

  tv5->split(0, 4);
  tv5->split(-1, 8);
  tv5->reorder({{1, 2}});

  TransformPropagator propagator(tv5);
  MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator);

  tv2->computeAt(tv5, -1);
  tv3->computeAt(tv5, -1);

  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
  scheduler_utils::parallelizeAllLike(tv5);

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto t1 = t0 + 1;
  auto t2 = shift(t1, {2, -2});
  auto t3 = shift(t1, {-3, 2});
  auto t4 = t2 + t3;
  // Drop the boundary values not covered by the padding before the
  // reduction
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(1, -1), at::indexing::Slice(0, -1)};
  t4 = t4.index(indices);
  auto ref = t4.sum(at::ArrayRef<int64_t>{0, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4110 | |
4111 | TEST_F(NVFuserTest, FusionPartialSplit1_CUDA) { |
4112 | Fusion fusion; |
4113 | FusionGuard fg(&fusion); |
4114 | |
4115 | auto tv0 = makeSymbolicTensor(1); |
4116 | // [I] |
4117 | fusion.addInput(tv0); |
4118 | |
4119 | auto tv1 = add(tv0, IrBuilder::create<Double>(0)); |
4120 | // [I] |
4121 | auto tv2 = shift(tv1, {1}, false); |
4122 | // [1:I] |
4123 | auto tv3 = shift(tv1, {-1}, false); |
4124 | // [0:I-1] |
4125 | auto tv4 = add(tv2, tv3); |
4126 | // [1:I-1] |
4127 | fusion.addOutput(tv4); |
4128 | |
4129 | // Partial split of tv4. Split only the valid range, which is |
4130 | // [1:-1]. |
4131 | tv4->split(0, 8, true, true); |
4132 | // [(I-2)/8, 8] |
4133 | |
4134 | // Propagates the partial split back to tv1. This means that all of |
4135 | // the other tensors are also shaped as [(I-2)/8, 8], which appears |
4136 | // to mean only the sub region of ((I-2)/8 * 8) is |
4137 | // computed for tv1, tv2 and tv3. It's fine for the tv2 and tv3 |
4138 | // tensors as only that sub region is used by tv4. It's also fine |
4139 | // for tv1 since it has halo of size one at each side, so the whole |
4140 | // region is actually calculated for tv1. |
4141 | tv1->computeAt(tv4, 1); |
4142 | |
4143 | tv4->axis(-1)->parallelize(ParallelType::TIDx); |
4144 | tv4->axis(-2)->parallelize(ParallelType::BIDx); |
4145 | scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3}); |
4146 | |
4147 | tv1->setMemoryType(MemoryType::Shared); |
4148 | |
4149 | // gridDim.x is ceilDiv(numel_x - 2, 8), not ceilDiv(numel_x, 8), |
4150 | // so it's going to be just 2 rather than 3. |
4151 | const int numel_x = 18; |
4152 | |
4153 | ExpressionEvaluator evaluator(&fusion); |
4154 | auto root_extent = tv4->getRootDomain()[0]->extent(); |
4155 | evaluator.bind(root_extent, numel_x); |
4156 | auto extent_eval = evaluator.evaluate(tv4->axis(0)->extent()); |
4157 | TORCH_CHECK( |
4158 | extent_eval.has_value(), |
4159 | "Invalid evaluation of outer domain extent of partial split" ); |
4160 | TORCH_CHECK( |
4161 | extent_eval.value() == (numel_x - 2) / 8, |
4162 | "Invalid extent of outer domain of partial split" ); |
4163 | |
4164 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4165 | auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); |
4166 | at::manual_seed(0); |
4167 | at::Tensor t0 = at::randn({numel_x}, options); |
4168 | std::vector<IValue> inputs = {t0}; |
4169 | |
4170 | FusionExecutor fe; |
4171 | fe.compileFusion(&fusion, inputs); |
4172 | auto outputs = fe.runFusion(inputs); |
4173 | |
4174 | std::vector<at::indexing::TensorIndex> indices{at::indexing::Slice(1, -1)}; |
4175 | |
4176 | outputs[0] = outputs[0].index(indices); |
4177 | |
4178 | auto ref = (shift(t0, {1}) + shift(t0, {-1})).index(indices); |
4179 | |
4180 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
4181 | } |
4182 | |
4183 | TEST_F(NVFuserTest, FusionPartialSplit2_CUDA) { |
4184 | Fusion fusion; |
4185 | FusionGuard fg(&fusion); |
4186 | |
4187 | auto tv0 = makeSymbolicTensor(1); |
4188 | fusion.addInput(tv0); |
4189 | |
4190 | auto tv1 = add(tv0, IrBuilder::create<Double>(0)); |
4191 | auto tv2 = shift(tv1, {1}, false); |
4192 | auto tv3 = shift(tv1, {-1}, false); |
4193 | auto tv4 = add(tv2, tv3); |
4194 | fusion.addOutput(tv4); |
4195 | |
4196 | auto tv5 = add(tv1, IrBuilder::create<Double>(1)); |
4197 | auto tv6 = add(tv5, IrBuilder::create<Double>(1)); |
4198 | fusion.addOutput(tv6); |
4199 | |
4200 | tv4->split(0, 4, true, true); |
4201 | |
4202 | // This causes tv5 and tv6 also to be split with the same partial |
4203 | // offsets, however, since they need to be calculated entirely, the |
4204 | // resulting code would be invalid. It should be detected as part of |
4205 | // initial fusion validation during lowering. |
4206 | tv1->computeAt(tv4, 1); |
4207 | |
4208 | // Validation should throw an error due to tv5 and tv6. |
4209 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
4210 | ASSERT_ANY_THROW(fusion.printKernel()); |
4211 | } |
4212 | |
4213 | // 2D version of PartialSplit1 |
4214 | TEST_F(NVFuserTest, FusionPartialSplit3_CUDA) { |
4215 | Fusion fusion; |
4216 | FusionGuard fg(&fusion); |
4217 | |
4218 | auto tv0 = makeSymbolicTensor(2); |
4219 | fusion.addInput(tv0); |
4220 | |
4221 | auto tv1 = add(tv0, IrBuilder::create<Double>(0)); |
4222 | auto tv2 = shift(tv1, {1, 2}, false); |
4223 | auto tv3 = shift(tv1, {-2, -1}, false); |
4224 | auto tv4 = add(tv2, tv3); |
4225 | fusion.addOutput(tv4); |
4226 | |
4227 | tv4->split(1, 8, true, true); |
4228 | tv4->split(0, 4, true, true); |
4229 | tv4->reorder({{1, 2}, {2, 1}}); |
4230 | |
4231 | tv1->computeAt(tv4, 2); |
4232 | |
4233 | tv4->axis(0)->parallelize(ParallelType::BIDy); |
4234 | tv4->axis(1)->parallelize(ParallelType::BIDx); |
4235 | tv4->axis(2)->parallelize(ParallelType::TIDy); |
4236 | tv4->axis(3)->parallelize(ParallelType::TIDx); |
4237 | scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3}); |
4238 | |
4239 | tv1->setMemoryType(MemoryType::Shared); |
4240 | |
4241 | const int numel_x = 32 + 3; |
4242 | const int numel_y = 32 + 3; |
4243 | |
4244 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4245 | auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); |
4246 | at::manual_seed(0); |
4247 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
4248 | std::vector<IValue> inputs = {t0}; |
4249 | |
4250 | FusionExecutor fe; |
4251 | fe.compileFusion(&fusion, inputs); |
4252 | auto outputs = fe.runFusion(inputs); |
4253 | |
4254 | std::vector<at::indexing::TensorIndex> indices{ |
4255 | at::indexing::Slice(1, -2), at::indexing::Slice(2, -1)}; |
4256 | |
4257 | outputs[0] = outputs[0].index(indices); |
4258 | |
4259 | auto ref = (shift(t0, {1, 2}) + shift(t0, {-2, -1})).index(indices); |
4260 | |
4261 | testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
4262 | } |
4263 | |
4264 | // Almost same fusion with Shift5ptStencilChain but non-padded shift |
4265 | // and partial split. |
TEST_F(NVFuserTest, FusionPartialSplit4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};

  // First stencil: 5pt stencil
  // stencil1 = (tv0 + tv0[+1][0] + tv0[-1][0] + tv0[0][+1] + tv0[0][-1]) / 5
  std::vector<TensorView*> tv_stencil1_shifts;
  for (const auto& offset : offsets) {
    // Non-padded shifts: the borders of each shifted tensor are invalid
    tv_stencil1_shifts.push_back(shift(tv0, offset, false));
  }

  auto tv_stencil1 = tv0;
  for (auto tv : tv_stencil1_shifts) {
    tv_stencil1 = add(tv_stencil1, tv);
  }

  tv_stencil1 = div(
      tv_stencil1, IrBuilder::create<Double>(tv_stencil1_shifts.size() + 1));

  // Second stencil: Same 5pt stencil
  std::vector<TensorView*> tv_stencil2_shifts;
  for (const auto& offset : offsets) {
    tv_stencil2_shifts.push_back(shift(tv_stencil1, offset, false));
  }

  auto tv_stencil2 = tv_stencil1;
  for (auto tv : tv_stencil2_shifts) {
    tv_stencil2 = add(tv_stencil2, tv);
  }

  tv_stencil2 = div(
      tv_stencil2, IrBuilder::create<Double>(tv_stencil2_shifts.size() + 1));

  auto tv_out = tv_stencil2;

  fusion.addOutput(tv_out);

  // Cache the input for reuse by all five taps
  auto tv0_cache = tv0->cacheAfter();

  // Tile sizes for the two spatial dimensions
  std::vector<int> split_factor({16, 16});

  // Partial splits: only the valid interior region is tiled
  tv_out->split(-1, split_factor[1], true, true);
  tv_out->split(0, split_factor[0], true, true);
  tv_out->reorder({{1, 2}, {2, 1}});

  tv0->computeAt(tv_out, 2);

  // Inline completely all inputs to the first stencil output, except for the
  // tv0 cache
  for (auto tv : tv_stencil1_shifts) {
    tv->computeAt(tv_stencil1, -1);
  }

  // Inline completely all inputs to the second stencil output, except
  // for the first stencil output
  for (auto tv : tv_stencil2_shifts) {
    tv->computeAt(tv_stencil2, -1);
  }

  tv_out->axis(0)->parallelize(ParallelType::BIDy);
  tv_out->axis(1)->parallelize(ParallelType::BIDx);
  tv_out->axis(2)->parallelize(ParallelType::TIDy);
  tv_out->axis(3)->parallelize(ParallelType::TIDx);

  // Propagate the output parallelization to every tensor between the
  // fusion inputs and outputs
  auto all_values = DependencyCheck::getAllValsBetween(
      {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
  for (auto tv : ir_utils::filterByType<TensorView>(all_values)) {
    scheduler_utils::parallelizeAllLike(tv_out, {tv});
  }

  // Both the cached input and the intermediate stencil result are read
  // at shifted offsets by their consumers
  tv0_cache->setMemoryType(MemoryType::Shared);
  tv_stencil1->setMemoryType(MemoryType::Shared);

  // Input matrix size is 68x68, and the output is 64x64. Both
  // gridDim.x and gridim.y should be ceilDiv(numel - 4,
  // split_factor), which is 4. If full split is used, the grid
  // dimension would be 5.
  const int numel_x = 64 + 4;
  const int numel_y = 64 + 4;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // The two chained stencils invalidate a two-element border in total
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(2, -2), at::indexing::Slice(2, -2)};

  outputs[0] = outputs[0].index(indices);

  auto stencil1 = t0;
  for (const auto& offset : offsets) {
    stencil1 = stencil1 + shift(t0, offset);
  }
  stencil1 = stencil1 / int(offsets.size() + 1);
  auto stencil2 = stencil1;
  for (const auto& offset : offsets) {
    stencil2 = stencil2 + shift(stencil1, offset);
  }
  stencil2 = stencil2 / int(offsets.size() + 1);
  auto ref = stencil2.index(indices);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4378 | |
4379 | TEST_F(NVFuserTest, FusionPartialSplit5_CUDA) { |
4380 | Fusion fusion; |
4381 | FusionGuard fg(&fusion); |
4382 | |
4383 | const int numel_x = 10; |
4384 | const int numel_y = 11; |
4385 | |
4386 | // auto tv0 = makeSymbolicTensor(2); |
4387 | auto tv0 = makeConcreteTensor({numel_x, numel_y}); |
4388 | fusion.addInput(tv0); |
4389 | |
4390 | auto tv1 = shift(tv0, {0, 1}, false); |
4391 | auto tv2 = add(tv1, IrBuilder::create<Double>(1)); |
4392 | |
4393 | fusion.addOutput(tv2); |
4394 | |
4395 | // Partially split tv2 but not tv1. Producer indexing with tv2 as a consumer |
4396 | // requires adjustment of the index to account for the difference of split |
4397 | // offsets. |
4398 | tv2->split(1, 4, true, true); |
4399 | tv1->split(1, 4); |
4400 | |
4401 | tv1->computeAt(tv2, 1); |
4402 | |
4403 | tv2->axis(1)->parallelize(ParallelType::TIDx); |
4404 | tv1->axis(1)->parallelize(ParallelType::TIDx); |
4405 | |
4406 | tv1->setMemoryType(MemoryType::Shared); |
4407 | |
4408 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4409 | at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
4410 | std::vector<IValue> inputs = {t0}; |
4411 | |
4412 | FusionExecutor fe; |
4413 | fe.compileFusion(&fusion, inputs); |
4414 | auto outputs = fe.runFusion(inputs); |
4415 | |
4416 | std::vector<at::indexing::TensorIndex> indices{ |
4417 | at::indexing::Slice(0, at::indexing::None), |
4418 | at::indexing::Slice(1, at::indexing::None)}; |
4419 | |
4420 | outputs[0] = outputs[0].index(indices); |
4421 | |
4422 | auto ref = (shift(t0, {0, 1}) + 1).index(indices); |
4423 | |
4424 | testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); |
4425 | } |
4426 | |
4427 | TEST_F(NVFuserTest, FusionPartialSplit6_CUDA) { |
4428 | Fusion fusion; |
4429 | FusionGuard fg(&fusion); |
4430 | |
4431 | const int numel_x = 9; |
4432 | |
4433 | auto tv0 = makeConcreteTensor({numel_x}); |
4434 | fusion.addInput(tv0); |
4435 | |
4436 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
4437 | auto tv2 = shift(tv1, {1}, false); |
4438 | auto tv3 = add(tv2, IrBuilder::create<Double>(1)); |
4439 | |
4440 | fusion.addOutput(tv3); |
4441 | |
4442 | // Another mix of partial and non-partial split |
4443 | tv1->split(0, 4); |
4444 | tv2->split(0, 4, true, true); |
4445 | tv3->split(0, 4); |
4446 | |
4447 | // Just make it easier for compute-sanitizer to flag invalid memory accesses |
4448 | tv1->setMemoryType(MemoryType::Shared); |
4449 | tv2->setMemoryType(MemoryType::Shared); |
4450 | |
4451 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4452 | at::Tensor t0 = at::randn({numel_x}, options); |
4453 | std::vector<IValue> inputs = {t0}; |
4454 | |
4455 | FusionExecutor fe; |
4456 | fe.compileFusion(&fusion, inputs); |
4457 | auto outputs = fe.runFusion(inputs); |
4458 | |
4459 | std::vector<at::indexing::TensorIndex> indices{ |
4460 | at::indexing::Slice(1, at::indexing::None)}; |
4461 | |
4462 | outputs[0] = outputs[0].index(indices); |
4463 | |
4464 | auto ref = (shift(t0 + 1, {1}) + 1).index(indices); |
4465 | |
4466 | testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); |
4467 | } |
4468 | |
TEST_F(NVFuserTest, FusionShiftUnswitch1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // Padded (default) shifts in several directions; each output is
  // checked for exact equality below
  auto tv1 = shift(tv0, {-1, 0});
  fusion.addOutput(tv1);

  auto tv2 = shift(tv0, {0, 1});
  fusion.addOutput(tv2);

  auto tv3 = shift(tv0, {2, 2});
  fusion.addOutput(tv3);

  auto tv4 = shift(tv0, {-2, -2});
  fusion.addOutput(tv4);

  auto tv5 = add(tv0, IrBuilder::create<Double>(1));
  auto tv6 = shift(tv5, {0, -1});
  fusion.addOutput(tv6);

  // Unswitch one axis of each shifted output
  tv1->axis(1)->parallelize(ParallelType::Unswitch);
  tv2->axis(1)->parallelize(ParallelType::Unswitch);
  tv3->axis(0)->parallelize(ParallelType::Unswitch);
  tv4->axis(0)->parallelize(ParallelType::Unswitch);

  // tv5 is read at a shifted offset by tv6, hence shared memory
  tv5->axis(1)->parallelize(ParallelType::TIDx);
  tv6->axis(1)->parallelize(ParallelType::TIDx);
  tv5->axis(0)->parallelize(ParallelType::Unswitch);
  tv5->setMemoryType(MemoryType::Shared);

  int numel_x = 9;
  int numel_y = 11;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Bitwise comparison against the ATen shift helper
  auto t1 = shift(t0, {-1, 0});
  TORCH_CHECK(t1.equal(outputs[0]));

  auto t2 = shift(t0, {0, 1});
  TORCH_CHECK(t2.equal(outputs[1]));

  auto t3 = shift(t0, {2, 2});
  TORCH_CHECK(t3.equal(outputs[2]));

  auto t4 = shift(t0, {-2, -2});
  TORCH_CHECK(t4.equal(outputs[3]));

  auto t6 = shift(t0 + 1, {0, -1});
  TORCH_CHECK(t6.equal(outputs[4]));
}
4528 | |
TEST_F(NVFuserTest, FusionGatherUnswitch1_CUDA) {
  const int tv1_gather = 3;
  const int tv1_gather_pad = 1;
  const int tv2_gather = 5;
  const int tv2_gather_pad = 2;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  // Gathers with window/padding sizes passed through named constants
  auto tv1 = gather(tv0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}});
  fusion.addOutput(tv1);

  auto tv2 = gather(tv0, {tv2_gather}, {{tv2_gather_pad, tv2_gather_pad}});
  fusion.addOutput(tv2);

  // Static gather
  auto tv3 = gather(tv0, {3}, {{1, 1}});
  fusion.addOutput(tv3);

  // Static gather
  auto tv4 = gather(tv0, {5}, {{2, 2}});
  fusion.addOutput(tv4);

  // Cache the input in shared memory so all four gathers can reuse it
  auto tv0_cache = tv0->cacheAfter();
  tv0_cache->setMemoryType(MemoryType::Shared);

  tv4->split(0, 32);

  tv0->computeAt(tv4, 1);

  // Unswitch the outer loop; threads cover the inner 32-element tile
  tv4->axis(0)->parallelize(ParallelType::Unswitch);
  tv4->axis(1)->parallelize(ParallelType::TIDx);

  const int numel_x = 100;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Bitwise comparison against the ATen gather helper
  auto t1 = gather(t0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}});
  TORCH_CHECK(t1.equal(outputs[0]));

  auto t2 = gather(t0, {tv2_gather}, {{tv2_gather_pad, tv2_gather_pad}});
  TORCH_CHECK(t2.equal(outputs[1]));

  auto t3 = gather(t0, {3}, {{1, 1}});
  TORCH_CHECK(t3.equal(outputs[2]));

  auto t4 = gather(t0, {5}, {{2, 2}});
  TORCH_CHECK(t4.equal(outputs[3]));
}
4587 | |
// Basic strided gather: a [1, 3] window with a stride of 3 along the
// inner dimension. Checks the TensorView dimensionality, the output
// tensor's shape per dimension, and the values against the ATen gather
// reference.
TEST_F(NVFuserTest, FusionGatherStrided1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  const std::vector<int> window_shape = {1, 3};
  const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};

  const std::vector<int> strides = {1, 3};

  auto tv1 = gather(tv0, window_shape, padding_width, strides);

  fusion.addOutput(tv1);

  const int s1 = 11;
  const int s2 = 13;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1, s2}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  // tv1 has a stride dimension, so its number of dimensions should be
  // input_ndims + window_ndims + 1 stride domain (tv0->nDims() * 2 + 1).
  TORCH_CHECK(tv1->nDims() == tv0->nDims() * 2 + 1);

  // However, the number of dimensions of the Aten tensor should still
  // be just twice the number of dimensions of the input tensor.
  auto fuser_out = outputs[0];
  TORCH_CHECK(
      fuser_out.ndimension() == static_cast<int64_t>(tv0->nDims()) * 2,
      "Invalid dimensionality of output tensor: " ,
      fuser_out.ndimension());

  // Each output dimension should be: ceilDiv(input_size + padding_width -
  // window + 1, stride), matching the expression below.
  for (const auto i : c10::irange(window_shape.size())) {
    auto valid_dim = ceilDiv(
        t0.size(i) + padding_width[i][0] + padding_width[i][1] -
            window_shape[i] + 1,
        strides[i]);
    auto actual_dim = outputs[0].size(i);
    TORCH_CHECK(
        valid_dim == actual_dim,
        "Invalid output size at dimension " ,
        i,
        ". Expected: " ,
        valid_dim,
        ", actual: " ,
        actual_dim);
  }

  auto ref = gather(t0, window_shape, padding_width, strides);

  TORCH_CHECK(ref.equal(outputs[0]));
}
4649 | |
// Split strided domain: gather a window of 3 with stride 3, reduce the
// window, and split the post-stride output domain, checking that the
// split propagates back through computeAt to the pre-stride domains.
TEST_F(NVFuserTest, FusionGatherStrided2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};
  const std::vector<int> strides = {3};

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  auto tv2 = gather(tv1, window_shape, padding_width, strides);

  // Reduce the gathered window
  auto tv3 = sum(tv2, {-1});

  fusion.addOutput(tv3);

  // Split the strided domain
  tv3->split(0, 4);

  // Propagate the split by 4 of the tv3 domain to pre-stride domains,
  // making them split by 4 * 3
  tv0->computeAt(tv3, 1);

  // Inline the window reduction fully into tv3
  tv2->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(1)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2});

  // tv1 is read at neighboring offsets by the gather, so stage it in
  // shared memory
  tv1->setMemoryType(MemoryType::Shared);

  const int s1 = 100;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: add 1, gather, reduce the window
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width, strides);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4701 | |
// Outer split: same fusion as GatherStrided2 but the strided output
// domain is split with an outer split instead of an inner one.
TEST_F(NVFuserTest, FusionGatherStrided3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};
  const std::vector<int> strides = {3};

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  auto tv2 = gather(tv1, window_shape, padding_width, strides);

  // Reduce the gathered window
  auto tv3 = sum(tv2, {-1});
  fusion.addOutput(tv3);

  // Outer split (the `false` flag): the factor 2 becomes the outer
  // domain
  tv3->split(0, 2, false);

  tv0->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(1)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2});

  // Both producers of the reduction are staged in shared memory
  tv1->setMemoryType(MemoryType::Shared);
  tv2->setMemoryType(MemoryType::Shared);

  const int s1 = 100;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: add 1, gather, reduce the window
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width, strides);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4749 | |
// Two strided gathers (same stride) from a shared producer; tests that
// a split on one consumer branch propagates through computeAt to the
// other branch as well.
TEST_F(NVFuserTest, FusionGatherStrided4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};
  const std::vector<int> strides = {3};

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  // Test propagation of split from one gather output to another
  auto tv2 = gather(tv1, window_shape, padding_width, strides);
  auto tv3 = gather(tv1, window_shape, padding_width, strides);

  auto tv4 = sum(tv2, {-1});
  fusion.addOutput(tv4);

  auto tv5 = sum(tv3, {-1});
  fusion.addOutput(tv5);

  tv4->split(0, 2);

  // Test forward computeAt propagation from tv1 to tv3
  tv0->computeAt(tv4, 1);

  const int s1 = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // Both fusion outputs compute the same value, so validate both
  // against a single ATen reference
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width, strides);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref, ref}, __LINE__, __FILE__);
}
4794 | |
4795 | // Same as GatherStrided1 but with stride != window |
4796 | TEST_F(NVFuserTest, FusionGatherStrided5_CUDA) { |
4797 | Fusion fusion; |
4798 | FusionGuard fg(&fusion); |
4799 | |
4800 | auto tv0 = makeSymbolicTensor(2); |
4801 | fusion.addInput(tv0); |
4802 | |
4803 | const std::vector<int> window_shape = {1, 3}; |
4804 | const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}}; |
4805 | |
4806 | const std::vector<int> strides = {1, 2}; |
4807 | |
4808 | auto tv1 = gather(tv0, window_shape, padding_width, strides); |
4809 | |
4810 | fusion.addOutput(tv1); |
4811 | |
4812 | const int s1 = 11; |
4813 | const int s2 = 13; |
4814 | |
4815 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
4816 | at::Tensor t0 = at::randn({s1, s2}, options); |
4817 | |
4818 | FusionExecutor fe; |
4819 | fe.compileFusion(&fusion, {t0}); |
4820 | auto outputs = fe.runFusion({t0}); |
4821 | |
4822 | auto ref = gather(t0, window_shape, padding_width, strides); |
4823 | |
4824 | TORCH_CHECK(ref.equal(outputs[0])); |
4825 | } |
4826 | |
// Same as GatherStrided2 but with stride != window
TEST_F(NVFuserTest, FusionGatherStrided6_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};
  // Stride (2) smaller than the window (3): gathered windows overlap
  const std::vector<int> strides = {2};

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  auto tv2 = gather(tv1, window_shape, padding_width, strides);

  // Reduce the gathered window
  auto tv3 = sum(tv2, {-1});

  fusion.addOutput(tv3);

  // Split the strided domain
  tv3->split(0, 4);

  // Propagate the split by 4 of the tv3 domain to pre-stride domains,
  // making them split by 4 * 2
  tv0->computeAt(tv3, 1);

  // Inline the window reduction fully into tv3
  tv2->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(1)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2});

  // tv1 is read at neighboring offsets by the gather, so stage it in
  // shared memory
  tv1->setMemoryType(MemoryType::Shared);

  const int s1 = 100;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: add 1, gather, reduce the window
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width, strides);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4878 | |
4879 | // Same as GatherStrided4 but different strides |
4880 | TEST_F(NVFuserTest, FusionGatherStrided7_CUDA) { |
4881 | Fusion fusion; |
4882 | FusionGuard fg(&fusion); |
4883 | |
4884 | const std::vector<int> window_shape = {3}; |
4885 | const std::vector<std::vector<int>> padding_width = {{1, 1}}; |
4886 | |
4887 | auto tv0 = makeSymbolicTensor(1); |
4888 | fusion.addInput(tv0); |
4889 | |
4890 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
4891 | |
4892 | // Use different strides |
4893 | auto tv2 = gather(tv1, window_shape, padding_width, {3}); |
4894 | auto tv3 = gather(tv1, window_shape, padding_width, {2}); |
4895 | |
4896 | auto tv4 = sum(tv2, {-1}); |
4897 | fusion.addOutput(tv4); |
4898 | |
4899 | auto tv5 = sum(tv3, {-1}); |
4900 | fusion.addOutput(tv5); |
4901 | |
4902 | tv4->split(0, 2); |
4903 | |
4904 | // Since tv3 has a different stride factor, this should fail. |
4905 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
4906 | ASSERT_ANY_THROW(tv0->computeAt(tv4, 1)); |
4907 | } |
4908 | |
// Same as GatherStrided2 but with unswitch
TEST_F(NVFuserTest, FusionGatherStrided8_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};
  const std::vector<int> strides = {3};

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));

  auto tv2 = gather(tv1, window_shape, padding_width, strides);

  // Reduce the gathered window
  auto tv3 = sum(tv2, {-1});

  fusion.addOutput(tv3);

  const int tidx = 32;

  // Split the strided domain
  tv3->split(0, tidx);

  // Split for unswitch
  tv3->split(0, 1);

  tv0->computeAt(tv3, 2);

  tv2->computeAt(tv3, -1);

  // [BIDx, Unswitch, TIDx, ...]
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(1)->parallelize(ParallelType::Unswitch);
  tv3->axis(2)->parallelize(ParallelType::TIDx);
  scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2});

  tv1->setMemoryType(MemoryType::Shared);

  // 1023 is not a multiple of tidx (32), exercising boundary handling
  // under the unswitched loop
  const int s1 = 1023;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: add 1, gather, reduce the window
  auto t1 = t0 + 1;
  auto t2 = gather(t1, window_shape, padding_width, strides);
  auto ref = sum(t2, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
4964 | |
4965 | // Chained strided gather. Not supported yet. |
4966 | TEST_F(NVFuserTest, FusionGatherStridedChain_CUDA) { |
4967 | Fusion fusion; |
4968 | FusionGuard fg(&fusion); |
4969 | |
4970 | const std::vector<int> window_shape = {3}; |
4971 | const std::vector<std::vector<int>> padding_width = {{1, 1}}; |
4972 | const std::vector<int> strides = {3}; |
4973 | // const std::vector<int> strides = {1}; |
4974 | |
4975 | auto tv0 = makeSymbolicTensor(1); |
4976 | fusion.addInput(tv0); |
4977 | |
4978 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
4979 | |
4980 | auto tv2 = gather(tv1, window_shape, padding_width, strides); |
4981 | // Reduce gathered window |
4982 | auto tv3 = sum(tv2, {-1}); |
4983 | |
4984 | // Repeat |
4985 | auto tv4 = gather(tv3, window_shape, padding_width, strides); |
4986 | auto tv5 = sum(tv4, {-1}); |
4987 | auto out = tv5; |
4988 | |
4989 | fusion.addOutput(out); |
4990 | |
4991 | // This should throw an error at HaloInfo::build. |
4992 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
4993 | ASSERT_ANY_THROW(GpuLower gpulw(&fusion)); |
4994 | } |
4995 | |
// Strided max pooling: a 3x3 gather with stride 3 and padding 1,
// followed by a max reduction over the window domains, validated
// against at::max_pool2d.
TEST_F(NVFuserTest, FusionMaxPoolingStrided_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: CHW
  // Pooling window: 3x3
  // Strides: 3
  // Padding: 1 at each end of the inner 2 dimensions

  // [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // [C, H/3, W/3, 1, 3, 3]
  auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}, {1, 3, 3});

  // [C, H/3, W/3]
  // Max reduction over the window dims, initialized to the lowest float
  auto max_tensor = reductionOp(
      BinaryOpType::Max,
      {-3, -2, -1},
      IrBuilder::create<Double>(std::numeric_limits<float>::lowest()),
      inp_tile);
  fusion.addOutput(max_tensor);

  ////////////////////////////////////

  // Cache the input tensor
  auto inp_cache = inp->cacheAfter();

  // Tiling the spatial domain
  const int tile_x = 32;
  const int tile_y = 8;

  max_tensor->split(1, tile_y);
  max_tensor->split(3, tile_x);
  max_tensor->reorder({{2, 3}});
  // [C, H/tile_y, W/tile_x, tile_y, tile_x]
  // Extent-1 split to create an unswitch domain
  max_tensor->split(2, 1);
  // [C, H/tile_y, W/tile_x, 1, tile_y, tile_x]

  inp->computeAt(max_tensor, 4);

  max_tensor->axis(0)->parallelize(ParallelType::BIDx);
  max_tensor->axis(3)->parallelize(ParallelType::Unswitch);
  max_tensor->axis(4)->parallelize(ParallelType::TIDy);
  max_tensor->axis(5)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(max_tensor);

  inp_cache->setMemoryType(MemoryType::Shared);

  const int hw = 50;
  const int num_channels = 20;
  const int pooling_window = 3;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_inp = at::randn({num_channels, hw, hw}, options);
  // We always pad inputs by zero, so if all surrounding values are
  // negative, max pooling would pick a padded value, which isn't the
  // correct behavior. We need to be able to choose the value of
  // padding. In this case, padding by the minimum value would not
  // have this problem. For now, avoid the problem by making sure all
  // values are not negative.
  aten_inp = at::abs(aten_inp);
  std::vector<IValue> inputs = {aten_inp};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference: 3x3 window, stride 3, padding 1
  auto ref = at::max_pool2d(
      aten_inp, {pooling_window, pooling_window}, {3, 3}, {1, 1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
5071 | |
// 2D convolution with a static 3x3 window and stride 3, implemented as
// strided gather + broadcast + multiply + reduction, validated against
// at::conv2d.
TEST_F(NVFuserTest, FusionConv2DStaticStrided_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input: [C, H, W]
  auto inp = makeSymbolicTensor(3);
  fusion.addInput(inp);

  // Weights: [K, C, 3, 3]
  auto w = makeSymbolicTensor(4);
  fusion.addInput(w);

  // Gather a neighbor tile of [3, 3] with padding size of 1 for each
  // side of the spatial dimensions
  auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}, {1, 3, 3});
  // inp_tile: [C, H/3, W/3, 1, 3, 3] (the broadcast flag counts below
  // imply a 6-dim tensor here)

  // Broadcast the input tile against K, and the weights against the
  // strided spatial domains
  auto inp_bc =
      broadcast(inp_tile, {true, false, false, false, false, false, false});
  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});

  auto inp_times_w = mul(inp_bc, w_bc);

  // Reduce the channel and neighbor tile dimensions
  auto out = sum(inp_times_w, {1, 4, 5, 6});

  fusion.addOutput(out);

  ////////////////////////////////////

  // Cache the input and weight tensors
  auto inp_cache = inp->cacheAfter();

  // Blocking the spatial dimensions
  const int block_w = 16;
  const int block_h = 4;
  const int block_c = 2;

  // [K, C, H/s, W/s, 1, 3, 3]
  out->split(2, block_h);
  // [K, C, H/s/block_h, block_h, W/s, 1, 3, 3]
  out->split(4, block_w);
  // [K, C, H/s/block_h, block_h, W/s/block_w, block_w, 1, 3, 3]
  out->reorder({{3, 4}});
  // [K, C, H/s/block_h, W/s/block_w, block_h, block_w, 1, 3, 3]
  out->split(1, block_c);
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, block_h, block_w, 1, 3,
  // 3]
  out->split(4, 1);
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1,
  // 3, 3]

  // Partial reduction over C/block_c and the 1x3x3 window domains; the
  // remaining block_c reduction stays in out
  auto out_rf = out->rFactor({1, -3, -2, -1});
  // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1,
  // 3, 3]

  // out: [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w]

  inp_cache->computeAt(out, 5);
  inp_cache->setMemoryType(MemoryType::Shared);
  // [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, C/block_c, 1,
  // 3, 3]

  // Move C/block_c before block_h/2 and share the domain from
  // inp_cache to out_rf
  out_rf->reorder({{7, 5}, {5, 6}, {6, 7}});
  inp_cache->computeAt(out_rf, 6);

  inp_tile->computeAt(out_rf, -1);
  w->computeAt(out_rf, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(1)->parallelize(ParallelType::TIDz);
  out->axis(4)->parallelize(ParallelType::Unswitch);
  out->axis(5)->parallelize(ParallelType::TIDy);
  out->axis(6)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});

  const int dim_h = 99;
  const int dim_w = 101;
  const int dim_c = 10;
  const int dim_f = 20;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
  at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options);
  std::vector<IValue> inputs = {at_inp, at_w};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto cg_outputs = fe.runFusion(inputs);

  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
  auto at_out = at::conv2d(at_inp, at_w, {}, 3, 1);
  at_out = at_out.squeeze(0); // drop the N axis

  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
}
5172 | |
5173 | TEST_F(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { |
5174 | Fusion fusion; |
5175 | FusionGuard fg(&fusion); |
5176 | |
5177 | auto tv0 = makeSymbolicTensor(1); |
5178 | fusion.addInput(tv0); |
5179 | |
5180 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
5181 | auto tv2 = shift(tv1, {-1}); |
5182 | fusion.addOutput(tv2); |
5183 | |
5184 | // [I] |
5185 | tv2->split(0, 8); |
5186 | // [I/8, 8] |
5187 | tv2->split(1, 3); |
5188 | // [I/8, 3, 3] |
5189 | |
5190 | tv0->computeAt(tv2, -2); |
5191 | |
5192 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
5193 | at::Tensor t0 = at::randn({24}, options); |
5194 | |
5195 | FusionExecutor fe; |
5196 | fe.compileFusion(&fusion, {t0}); |
5197 | auto cg_outputs = fe.runFusion({t0}); |
5198 | |
5199 | auto ref = shift((t0 + 1), {-1}); |
5200 | |
5201 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
5202 | } |
5203 | |
// 3x3 gather plus full reduction on a 111x222 input whose sizes are
// not divisible by the 50x8 / 50x16 grid/block tiling.
TEST_F(NVFuserTest, FusionNonDivisibleHalo2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // 3x3 neighborhood sum added back to the input, then a full reduction
  auto tv1 = gather(tv0, {3, 3}, {{1, 1}, {1, 1}});
  auto tv2 = sum(tv1, {-2, -1});
  auto tv3 = add(tv0, tv2);
  auto tv4 = sum(tv3, {0, 1});
  fusion.addOutput(tv4);

  // Grid (gy, gx) and block (by, bx) tile sizes
  const int gy = 50;
  const int gx = 50;
  const int by = 8;
  const int bx = 16;

  auto tv5 = tv0->cacheAfter();

  // [I, J]
  tv4->split(0, gy);
  // [I/gy, gy, J]
  tv4->split(1, by);
  // [I/gy, gy/by, by, J]
  tv4->split(-1, gx);
  // [I/gy, gy/by, by, J/gx, gx]
  tv4->split(-1, bx);
  // [I/gy, gy/by, by, J/gx, gx/bx, bx]
  tv4->reorder({{3, 1}, {1, 2}, {4, 3}, {2, 4}});
  // [I/gy, J/gx, gy/by, gx/bx, by, bx]

  // Per-block partial reduction over the gy/by and gx/bx domains
  auto tv6 = tv4->rFactor({2, 3});

  tv0->computeAt(tv6, 4);

  tv4->axis(0)->parallelize(ParallelType::BIDy);
  tv4->axis(1)->parallelize(ParallelType::BIDx);
  tv4->axis(2)->parallelize(ParallelType::TIDy);
  tv4->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3, tv5, tv6});

  tv5->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Sizes deliberately not divisible by the tile sizes above
  at::Tensor t0 = at::randn({111, 222}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});

  // ATen reference mirroring the fusion definition
  auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}});
  auto t2 = t1.sum({-2, -1});
  auto t3 = t0 + t2;
  auto t4 = t3.sum({-2, -1});

  testValidate(&fusion, cg_outputs, {t0}, {t4}, __LINE__, __FILE__);
}
5263 | |
// 9-point box stencil (3x3 gather, sum, divide by 9) with the
// shared-memory input cache double buffered.
TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = gather(tv0, {3, 3}, {{1, 1}, {1, 1}});
  auto tv2 = sum(tv1, {-2, -1});
  auto tv3 = div(tv2, IrBuilder::create<Double>(9));

  auto out = tv3;

  fusion.addOutput(out);

  auto tv0_cache = tv0->cacheAfter();

  tv0_cache->setMemoryType(MemoryType::Shared);

  // Tile to [I/4, J/32, 4, 32] and propagate the transform to all
  // tensors
  out->split(-2, 4);
  out->split(-1, 32);
  out->reorder({{1, 2}, {2, 1}});
  TransformPropagator propagator(out);
  MaxRootDomainInfoSpanningTree(out).traverse(&propagator);

  tv0->computeAt(out, 2);

  out->axis(3)->parallelize(ParallelType::TIDx);
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(0)->parallelize(ParallelType::BIDx);

  scheduler_utils::parallelizeAllLike(out);

  // Double buffer the shared-memory input cache
  tv0_cache->doubleBuffer();

  int numel_x = 99;
  int numel_y = 101;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  // ATen reference mirroring the fusion definition
  auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}});
  auto t2 = sum(t1, {-2, -1});
  auto t3 = t2 / 9;
  auto ref = t3;

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
5317 | |
// Shift in both directions from a shared-memory producer:
// out = tv1 + shift(tv1, 1) + shift(tv1, -1).
TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = shift(tv1, {1});
  auto tv3 = shift(tv1, {-1});
  auto tv4 = add(tv1, tv2);
  auto tv5 = add(tv4, tv3);
  fusion.addOutput(tv5);

  // tv1 is read at shifted offsets by tv2/tv3, so keep it in shared
  // memory
  tv1->setMemoryType(MemoryType::Shared);

  // [I] -> [I/1024, 512, 2], propagated to all tensors
  tv5->split(-1, 1024);
  tv5->split(-1, 2);
  TransformPropagator propagator(tv5);
  MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator);

  tv0->computeAt(tv5, 1);

  tv5->axis(1)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv5);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({1024 * 32}, options);
  std::vector<IValue> inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  auto ref = t0 + shift(t0, {1}) + shift(t0, {-1});

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
5357 | |
// Test IterType promotion with gather: adding a gather output to a
// tensor with a normal iteration axis must yield a normal Iteration
// domain.
TEST_F(NVFuserTest, FusionGatherIterTypePromotion_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int s1 = 11;
  const int s2 = 3;

  auto tv0 = makeConcreteTensor({s1});
  fusion.addInput(tv0);
  auto tv1 = makeConcreteTensor({s1, s2});
  fusion.addInput(tv1);

  const std::vector<int> window_shape = {3};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};

  // tv2's window axis is combined with tv1's iteration axis by the add
  auto tv2 = gather(tv0, window_shape, padding_width);
  auto tv3 = add(tv2, tv1);

  fusion.addOutput(tv3);

  // The combined axis must be promoted to a normal Iteration domain
  TORCH_CHECK(
      tv3->axis(1)->getIterType() == IterType::Iteration,
      "Invalid IterType promotion: " ,
      tv3->axis(1)->toString());

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  at::Tensor t1 = at::randn({s1, s2}, options);
  std::vector<IValue> inputs = {t0, t1};

  auto ref = gather(t0, window_shape, padding_width) + t1;

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}
5397 | |
// Shift predication with a contiguous merged domain: a padded shift
// (tv1) and a non-padded shift (tv2) share the same schedule, and the
// preallocated outputs are checked for out-of-bounds writes.
TEST_F(NVFuserTest, FusionContigPredicateShift_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  std::vector<int64_t> shape({2, 2});

  auto tv0 = makeConcreteTensor(shape);
  // [0:I]
  fusion.addInput(tv0);

  // Below, tv1 and tv2 are mostly the same, except that tv1 is padded
  // with 0, whereas tv2 is not, so the valid range of tv2 is [0:I-1]

  // [0:I]
  auto tv1 = shift(tv0, {-1, 0});

  // [0:I-1]
  auto tv2 = shift(tv0, {-1, 0}, false);

  // tv3 is not an output of shift, but it gets a partial root
  // domain from tv2, so it must be predicated at the root domain
  auto tv3 = add(tv2, IrBuilder::create<Double>(1));

  fusion.addOutput(tv1);
  fusion.addOutput(tv3);

  // contig merge
  tv1->merge(0);
  tv1->split(0, 4);
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // Create 3x2 and trim to 2x2. This would cause the output tensor
  // non-zero values if not properly predicated.
  at::Tensor t0 = at::randn({3, 2}, options);
  t0 = t0.index(
      {at::indexing::Slice(0, 2), at::indexing::Slice(0, at::indexing::None)});

  // Use random output to detect invalid writes
  at::Tensor t1 = at::rand_like(t0, options);
  // Use zero-cleared output to detect invalid writes
  at::Tensor t3 = at::zeros_like(t0, options);

  std::vector<IValue> inputs = {t0};
  std::vector<at::Tensor> outputs = {t1, t3};

  // Valid region of the non-padded shift output: all but the last row
  std::vector<at::indexing::TensorIndex> indices{
      at::indexing::Slice(0, -1), at::indexing::Slice(0, at::indexing::None)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  fe.runFusion(inputs, outputs);

  // Make sure the padded region is zero filled
  TORCH_CHECK(t1[1].equal(at::zeros(2, options)));
  // Make sure not touched as the shift is not padded
  TORCH_CHECK(t3[1].equal(at::zeros(2, options)));

  auto ref = shift(t0, {-1, 0});

  TORCH_CHECK(t1.equal(ref));
  TORCH_CHECK(t3.index(indices).equal((ref + 1).index(indices)));
}
5463 | |
5464 | } // namespace jit |
5465 | } // namespace torch |
5466 | #endif // #if defined(USE_CUDA) |
5467 | |