1#if defined(USE_CUDA)
2#include <gtest/gtest.h>
3
4#include <arith.h>
5#include <codegen.h>
6#include <disjoint_set.h>
7#include <executor.h>
8#include <executor_launch_params.h>
9#include <expr_evaluator.h>
10#include <fusion.h>
11#include <fusion_segmenter.h>
12#include <grouped_reduction.h>
13#include <inlining.h>
14#include <ir_all_nodes.h>
15#include <ir_builder.h>
16#include <ir_graphviz.h>
17#include <ir_iostream.h>
18#include <ir_utils.h>
19#include <iter_visitor.h>
20#include <kernel_cache.h>
21#include <kernel_expr_evaluator.h>
22#include <kernel_ir.h>
23#include <lower2device.h>
24#include <mutator.h>
25#include <register_interface.h>
26#include <root_domain_map.h>
27#include <scheduler/all_schedulers.h>
28#include <scheduler/reduction_utils.h>
29#include <scheduler/utils.h>
30#include <test/test_gpu_validator.h>
31#include <test/test_utils.h>
32#include <transform_replay.h>
33#include <transform_rfactor.h>
34
35// fuser and IR parser
36#include <ATen/cuda/CUDAContext.h>
37#include <ATen/cuda/Exceptions.h>
38#include <c10/cuda/CUDAStream.h>
39
40#include <algorithm>
41#include <iostream>
42
43// Tests go in torch::jit
44namespace torch {
45namespace jit {
46
47using namespace torch::jit::fuser::cuda;
48using namespace at::indexing;
49
50namespace {
51
52class KernelExprVisitor : private kir::IrVisitor {
53 public:
54 static std::vector<Expr*> getAllExprs(const kir::Kernel* kernel) {
55 KernelExprVisitor visitor(kernel);
56 return visitor.all_exprs_;
57 }
58
59 private:
60 KernelExprVisitor(const kir::Kernel* kernel) {
61 handle(kernel->topLevelExprs());
62 }
63
64 using kir::IrVisitor::handle;
65
66 void handle(Expr* expr) final {
67 all_exprs_.push_back(expr);
68 kir::IrVisitor::handle(expr);
69 }
70
71 private:
72 std::vector<Expr*> all_exprs_;
73};
74
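// Traverse all exprs of the lowered kernel and check that every
// BroadcastOp (including one nested inside a GridBroadcast) has no
// parallel types, i.e., that the broadcast was fused into the reduction.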
75void validateNoParallelBroadcastExist(kir::Kernel* kernel) {
76 for (auto expr : KernelExprVisitor::getAllExprs(kernel)) {
77 BroadcastOp* bc = dynamic_cast<BroadcastOp*>(expr);
78 if (bc == nullptr) {
79 auto grid_bc = dynamic_cast<kir::GridBroadcast*>(expr);
80 if (grid_bc != nullptr) {
81 std::cerr << "Grid broadcast: " << grid_bc->toString();
82 bc = grid_bc->broadcast_op();
83 }
84 }
85 if (bc == nullptr) {
86 continue;
87 }
88 TORCH_CHECK(
89 kernel->summary().broadcast_parallel_types.at(bc).none(),
90 "Parallel broadcast should not exist but was found: ",
91 bc->toString());
92 }
93}
94
95} // namespace
96
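// Basic grid allreduce pattern: a grid sum whose broadcast result is
// consumed together with the original input. With the fused (allreduce)
// grid reduction, no parallel broadcast should remain after lowering.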
97TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
98 const int nx = 999;
99 const int tidx = 128;
100 const int bidx = 4;
101
102 if (ceilDiv(nx, tidx) > deviceSMCount()) {
103 GTEST_SKIP() << "Not enough SMs to run this test";
104 }
105
106 Fusion fusion;
107 FusionGuard fg(&fusion);
108
109 auto tv0 = makeSymbolicTensor(1);
110 fusion.addInput(tv0);
111
112 auto tv1 = sum(tv0, {0});
113 auto tv2 = broadcast(tv1, {true});
114 auto tv3 = add(tv0, tv2);
115
116 fusion.addOutput(tv3);
117
118 tv3->split(0, tidx);
119 tv3->split(0, bidx);
120 tv3->split(0, 1); // unswitch
121 TransformPropagator propagator(tv3);
122 MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
123
124 tv3->axis(0)->parallelize(ParallelType::BIDy);
125 tv3->axis(2)->parallelize(ParallelType::BIDx);
126 tv3->axis(3)->parallelize(ParallelType::TIDx);
127 scheduler_utils::parallelizeAllLike(tv3);
128
129 // Just to make sure fused_reduction and work buffers are allocated
130 // uniquely
131 tv1->axis(1)->parallelize(ParallelType::Unswitch);
132
133 GpuLower gpulw(&fusion);
134 validateNoParallelBroadcastExist(gpulw.kernel());
135
136 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
137 at::manual_seed(0);
138 auto t0 = at::randn({nx}, options);
139
140 FusionExecutor fe;
141 fe.compileFusion(&fusion, {t0});
142 auto cg_outputs = fe.runFusion({t0});
143
144 auto ref = sum(t0).unsqueeze(0) + t0;
145
146 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
147}
148
149TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) {
150 const int nx = 99;
151 const int tidx = 32;
152
153 if (ceilDiv(nx, tidx) > deviceSMCount()) {
154 GTEST_SKIP() << "Not enough SMs to run this test";
155 }
156
157 Fusion fusion;
158 FusionGuard fg(&fusion);
159
160 auto tv0 = makeSymbolicTensor(1);
161 fusion.addInput(tv0);
162
163 auto tv1 = sum(tv0, {0});
164 auto tv2 = broadcast(tv1, {true});
165 auto tv3 = add(tv0, tv2);
166
167 fusion.addOutput(tv3);
168
169 tv3->split(0, tidx);
170 TransformPropagator propagator(tv3);
171 MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
172
173 tv3->axis(0)->parallelize(ParallelType::BIDx);
174 tv3->axis(1)->parallelize(ParallelType::TIDx);
175 scheduler_utils::parallelizeAllLike(tv3, {tv2});
176
177 // Broadcast on TIDy instead of TIDx. This still uses the fused
178 // reduction as it's broadcast on BIDx as well. Since TIDy is not
179 // predicated, the broadcast becomes a set op.
180 tv1->axis(0)->parallelize(ParallelType::BIDx);
181 tv1->axis(1)->parallelize(ParallelType::TIDy);
182
183 GpuLower gpulw(&fusion);
184 validateNoParallelBroadcastExist(gpulw.kernel());
185
186 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
187 at::manual_seed(0);
188 auto t0 = at::randn({nx}, options);
189
190 FusionExecutor fe;
191 fe.compileFusion(&fusion, {t0});
192 auto cg_outputs = fe.runFusion({t0});
193
194 auto ref = sum(t0).unsqueeze(0) + t0;
195
196 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
197}
198
199// Grid reduction with serial non-reduction axis. The global work
200// buffer is double buffered.
201TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) {
202 const int nx = 100;
203 const int ny = 5000;
204 const int tidx = 128;
205
206 if (ceilDiv(ny, tidx) > deviceSMCount()) {
207 GTEST_SKIP() << "Not enough SMs to run this test";
208 }
209
210 Fusion fusion;
211 FusionGuard fg(&fusion);
212
213 auto tv0 = makeSymbolicTensor(2);
214 fusion.addInput(tv0);
215
216 auto tv1 = sum(tv0, {1});
217 auto tv2 = broadcast(tv1, {false, true});
218 auto tv3 = add(tv0, tv2);
219
220 fusion.addOutput(tv3);
221
222 tv3->split(1, tidx);
223 TransformPropagator propagator(tv3);
224 MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
225
226 tv0->computeAt(tv3, 1);
227
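  // Leave axis 0 (the nx dimension) serial. With a serial non-reduction
  // loop around the grid reduction, the global work buffer is reused
  // across iterations, hence the double buffering noted above.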
228 tv3->axis(1)->parallelize(ParallelType::BIDx);
229 tv3->axis(2)->parallelize(ParallelType::TIDx);
230 scheduler_utils::parallelizeAllLike(tv3);
231
232 GpuLower gpulw(&fusion);
233 validateNoParallelBroadcastExist(gpulw.kernel());
234
235 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
236 at::manual_seed(0);
237 auto t0 = at::randn({nx, ny}, options);
238
239 FusionExecutor fe;
240 fe.compileFusion(&fusion, {t0});
241 auto cg_outputs = fe.runFusion({t0});
242
243 auto ref = sum(t0, {1}).unsqueeze(-1) + t0;
244
245 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
246}
247
248// Indirect reduction and broadcast
249TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) {
250 const int nx = 999;
251 const int tidx = 128;
252
253 if (ceilDiv(nx, tidx) > deviceSMCount()) {
254 GTEST_SKIP() << "Not enough SMs to run this test";
255 }
256
257 Fusion fusion;
258 FusionGuard fg(&fusion);
259
260 auto tv0 = makeSymbolicTensor(1);
261 fusion.addInput(tv0);
262
263 auto tv1 = sum(tv0, {0});
264 auto tv2 = add(tv1, IrBuilder::create<Double>(1));
265 auto tv3 = broadcast(tv2, {true});
266 auto tv4 = add(tv0, tv3);
267
268 fusion.addOutput(tv4);
269
270 tv4->split(0, tidx);
271 TransformPropagator propagator(tv4);
272 MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator);
273
274 tv4->axis(0)->parallelize(ParallelType::BIDx);
275 tv4->axis(1)->parallelize(ParallelType::TIDx);
276 scheduler_utils::parallelizeAllLike(tv4);
277
278 GpuLower gpulw(&fusion);
279 validateNoParallelBroadcastExist(gpulw.kernel());
280
281 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
282 at::manual_seed(0);
283 auto t0 = at::randn({nx}, options);
284
285 FusionExecutor fe;
286 fe.compileFusion(&fusion, {t0});
287 auto cg_outputs = fe.runFusion({t0});
288
289 auto ref = (sum(t0) + 1).unsqueeze(0) + t0;
290
291 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
292}
293
294// Unused block dimension in the kernel
295TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) {
296 const int nx = 999;
297 const int tidx = 128;
298 const int iter = 2;
299 const int bdimx = 9; // One more than required by the reduction
300 const int bdimy = 3; // Want an unused dimension
301
  // bdimx is intentionally bumped beyond what the reduction needs;
  // make sure the resulting grid still fits on this device.
303 if (bdimx * bdimy > deviceSMCount()) {
304 GTEST_SKIP() << "Not enough SMs to run this test";
305 }
306
307 Fusion fusion;
308 FusionGuard fg(&fusion);
309
  // This test wasn't set up with inlining for register usage, so just
  // leave the iter dimension concrete
312 auto tv0 = makeConcreteTensor({iter, -1});
313 fusion.addInput(tv0);
314
315 auto tv1 = sum(tv0, {1});
316 auto tv2 = add(tv1, IrBuilder::create<Double>(1));
317 auto tv3 = broadcast(tv2, {false, true});
318 auto tv4 = add(tv0, tv3);
319
320 fusion.addOutput(tv4);
321
322 // Dummy op to mess with parallelization
323 auto tv5 = makeSymbolicTensor(2);
324 fusion.addInput(tv5);
325 auto tv6 = set(tv5);
326 fusion.addOutput(tv6);
327
328 // Setup the reduction
329 tv4->split(1, tidx);
330 TransformPropagator propagator(tv4);
331 MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator);
332
333 tv4->axis(1)->parallelize(ParallelType::BIDx);
334 tv4->axis(2)->parallelize(ParallelType::TIDx);
335 scheduler_utils::parallelizeAllLike(tv4);
336
337 tv6->axis(0)->parallelize(ParallelType::BIDy);
338 tv6->axis(1)->parallelize(ParallelType::BIDx);
339
340 GpuLower gpulw(&fusion);
341 validateNoParallelBroadcastExist(gpulw.kernel());
342
343 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
344 at::manual_seed(0);
345 auto t0 = at::randn({iter, nx}, options);
346 auto t5 = at::randn({bdimy, bdimx}, options);
347
348 FusionExecutor fe;
349 fe.compileFusion(&fusion, {t0, t5});
350 auto cg_outputs = fe.runFusion({t0, t5});
351
352 auto ref = (sum(t0, {1}) + 1).unsqueeze(-1) + t0;
353
354 testValidate(&fusion, cg_outputs, {t0, t5}, {ref, t5}, __LINE__, __FILE__);
355}
356
357TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) {
358 Fusion fusion;
359 FusionGuard fg(&fusion);
360
361 std::vector<int64_t> shape({99, 200});
362
363 const int vec = 4;
364 const int tidx = 32;
365 const int tidy = 8;
366 const int bdimx = ceilDiv(shape[1], vec * tidx);
367 const int bdimy = ceilDiv(shape[0], tidy);
368
369 if (bdimx * bdimy > deviceSMCount()) {
370 GTEST_SKIP() << "Not enough SMs to run this test";
371 }
372
373 auto tv0 = makeSymbolicTensor(2);
374 fusion.addInput(tv0);
375
376 auto tv1 = set(tv0);
377 auto tv2 = sum(tv1, {0});
378 auto tv3 = broadcast(tv2, {true, false});
379 auto tv4 = add(tv0, tv3);
380 fusion.addOutput(tv4);
381
382 tv1->split(1, vec);
383 tv1->split(1, tidx);
384 tv1->split(0, tidy);
385 TransformPropagator propagator(tv1);
386 MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);
387
388 tv1->axis(0)->parallelize(ParallelType::BIDy);
389 tv1->axis(1)->parallelize(ParallelType::TIDy);
390 tv1->axis(2)->parallelize(ParallelType::BIDx);
391 tv1->axis(3)->parallelize(ParallelType::TIDx);
392
393 scheduler_utils::parallelizeAllLike(tv1);
394
395 tv1->axis(4)->parallelize(ParallelType::Vectorize);
396
397 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
398 at::manual_seed(0);
399 auto t0 = at::randn(shape, options);
400
401 FusionExecutor fe;
402 fe.compileFusion(&fusion, {t0});
403 auto outputs = fe.runFusion({t0});
404
405 auto t0_double = t0.to(at::kDouble);
406 auto ref = t0_double + t0_double.sum({0}).unsqueeze(0);
407
408 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
409}
410
411TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) {
412 const int nx = 999;
413 const int tidx = 128;
414
415 if (ceilDiv(nx, tidx) > deviceSMCount()) {
416 GTEST_SKIP() << "Not enough SMs to run this test";
417 }
418
419 Fusion fusion;
420 FusionGuard fg(&fusion);
421
422 auto tv0 = makeSymbolicTensor(1);
423 fusion.addInput(tv0);
424
425 auto tvs = Welford(tv0, {0});
426 auto tv2 = broadcast(tvs.avg, {true});
427 auto tv3 = broadcast(tvs.var_sum, {true});
428 auto tv4 = add(tv0, tv2);
429 auto tv5 = add(tv4, tv3);
430
431 fusion.addOutput(tv5);
432
433 tv5->split(0, tidx);
434 TransformPropagator propagator(tv5);
435 MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator);
436
437 tv5->axis(0)->parallelize(ParallelType::BIDx);
438 tv5->axis(1)->parallelize(ParallelType::TIDx);
439 scheduler_utils::parallelizeAllLike(tv5);
440
441 GpuLower gpulw(&fusion);
442 validateNoParallelBroadcastExist(gpulw.kernel());
443
444 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
445 at::manual_seed(0);
446 auto t0 = at::randn({nx}, options);
447
448 FusionExecutor fe;
449 fe.compileFusion(&fusion, {t0});
450 auto cg_outputs = fe.runFusion({t0});
451
452 auto ref =
453 (t0.mean({0}).unsqueeze(0) + t0) + t0.var({0}, false).unsqueeze(0) * nx;
454
455 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
456}
457
458// Grid welford reduction with serial non-reduction axis. The global
459// work buffer is double buffered.
460TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) {
461 const int nx = 100;
462 const int ny = 5000;
463 const int tidx = 128;
464
465 if (ceilDiv(ny, tidx) > deviceSMCount()) {
466 GTEST_SKIP() << "Not enough SMs to run this test";
467 }
468
469 Fusion fusion;
470 FusionGuard fg(&fusion);
471
472 auto tv0 = makeSymbolicTensor(2);
473 fusion.addInput(tv0);
474
475 auto tvs = Welford(tv0, {1});
476 auto tv2 = broadcast(tvs.avg, {false, true});
477 auto tv3 = add(tv0, tv2);
478
479 fusion.addOutput(tv3);
480
481 tv3->split(1, tidx);
482 TransformPropagator propagator(tv3);
483 MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
484
485 tv0->computeAt(tv3, 1);
486
487 tv3->axis(1)->parallelize(ParallelType::BIDx);
488 tv3->axis(2)->parallelize(ParallelType::TIDx);
489 scheduler_utils::parallelizeAllLike(tv3);
490
491 // There must be no parallel broadcast
492 GpuLower gpulw(&fusion);
493 validateNoParallelBroadcastExist(gpulw.kernel());
494
495 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
496 at::manual_seed(0);
497 auto t0 = at::randn({nx, ny}, options);
498
499 FusionExecutor fe;
500 fe.compileFusion(&fusion, {t0});
501 auto cg_outputs = fe.runFusion({t0});
502
503 auto ref = (sum(t0, {1}) / ny).unsqueeze(-1) + t0;
504
505 testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
506}
507
508// Persistent batchnorm. Uses the fused reduction for grid welford and
509// broadcast.
510TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) {
511 const std::vector<int64_t> input_shape{256, 2048, 14, 14};
512
513 std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
514 Fusion& fusion = *fusion_ptr.get();
515 FusionGuard fg(&fusion);
516
517 auto tv0 = makeSymbolicTensor(4, DataType::Half);
518 fusion.addInput(tv0);
519 auto tv1 = makeSymbolicTensor(1, DataType::Half);
520 fusion.addInput(tv1);
521 auto tv2 = makeSymbolicTensor(1, DataType::Half);
522 fusion.addInput(tv2);
523 auto tv3 = makeSymbolicTensor(1, DataType::Float);
524 fusion.addInput(tv3);
525 auto tv4 = makeSymbolicTensor(1, DataType::Float);
526 fusion.addInput(tv4);
527
528 auto d34 = IrBuilder::create<Double>(1);
529 auto tv5 = castOp(DataType::Float, tv0);
530 auto tv6 = castOp(DataType::Float, tv1);
531 auto tv7 = castOp(DataType::Float, tv2);
532 auto tvs = Welford(tv5, {0, 2, 3});
533 auto tv8 = tvs.avg;
534 auto tv9 = tvs.var_sum;
535 auto tv10 = tvs.n;
536 auto tv11 = mul(tv8, IrBuilder::create<Double>(0.1));
537 auto tv12 = mul(tv3, d34);
538 auto tv13 = add(tv12, tv11);
539 auto d43 = IrBuilder::create<Double>(0.5);
540 auto tv14 = mul(tv9, d43);
541 auto tv15 = mul(tv14, IrBuilder::create<Double>(0.1));
542 auto tv16 = mul(tv4, d34);
543 auto tv17 = add(tv16, tv15);
544 auto tv18 = broadcast(tv8, {true, false, true, true});
545 auto tv19 = sub(tv5, tv18);
546 auto tv20 = mul(tv9, d43);
547 auto tv21 = add(tv20, IrBuilder::create<Double>(0.0001));
548 auto tv22 = rsqrt(tv21);
549 auto tv23 = broadcast(tv22, {true, false, true, true});
550 auto tv24 = mul(tv19, tv23);
551 auto tv25 = broadcast(tv6, {true, false, true, true});
552 auto tv26 = mul(tv24, tv25);
553 auto tv27 = broadcast(tv7, {true, false, true, true});
554 auto tv28 = add(tv26, tv27);
555 auto tv29 = castOp(DataType::Half, tv28);
556 fusion.addOutput(tv13);
557 fusion.addOutput(tv17);
558 fusion.addOutput(tv29);
559
560 auto tv0_cache = tv0->cacheAfter();
561 auto tv1_cache = tv1->cacheAfter();
562 auto tv2_cache = tv2->cacheAfter();
563 auto tv3_cache = tv3->cacheAfter();
564 auto tv4_cache = tv4->cacheAfter();
565
566 auto tv13_cache = tv13->cacheBefore();
567 auto tv17_cache = tv17->cacheBefore();
568 auto tv29_cache = tv29->cacheBefore();
569
570 tv0->split(1, NamedScalar::getParallelDim(ParallelType::BIDx), false);
571 tv0->split(0, NamedScalar::getParallelDim(ParallelType::BIDy), false);
572 tv0->split(1, 8, false);
573 tv0->split(2, 8, false);
574 tv0->merge(-2, -1);
575 tv0->split(-1, 2);
576 tv0->split(-2, 1, false);
577 tv0->split(-2, 1, false);
578 tv0->reorder(
579 {{4, 0},
580 {5, 1},
581 {0, 2},
582 {3, 3},
583 {8, 4},
584 {1, 5},
585 {7, 6},
586 {2, 7},
587 {9, 8},
588 {6, 9}});
589
590 TransformPropagator propagator(tv0);
591 MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator);
592
593 ir_utils::rfactorHelper(tvs.avg, {-5, -4, -3, -2, -1});
594
595 tv0->computeAt(tv29, 2);
596 tv1->computeAt(tv29, 2);
597 tv2->computeAt(tv29, 2);
598 tv3->computeAt(tv13, 2);
599 tv4->computeAt(tv17, 2);
600
601 tv29->axis(0)->parallelize(ParallelType::BIDx);
602 tv29->axis(2)->parallelize(ParallelType::BIDy);
603 tv29->axis(3)->parallelize(ParallelType::TIDz);
604 tv29->axis(4)->parallelize(ParallelType::TIDx);
605 scheduler_utils::parallelizeAllLike(tv29);
606
607 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
608 auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
609 at::manual_seed(0);
610 auto t0 = at::randn(input_shape, options_half);
611 auto t1 = at::randn(input_shape[1], options_half);
612 auto t2 = at::randn(input_shape[1], options_half);
613 auto t3 = at::randn(input_shape[1], options);
614 auto t4 = at::randn(input_shape[1], options);
615 std::vector<IValue> aten_inputs = {t0, t1, t2, t3, t4};
616
617 GpuLower gpulw(&fusion);
618 validateNoParallelBroadcastExist(gpulw.kernel());
619
620 FusionExecutor fe;
621 LaunchParams launch_params(2, 2, -1, -1, -1, -1);
622 fe.compileFusion(&fusion, aten_inputs, launch_params);
623 auto cg_outputs = fe.runFusion(aten_inputs, launch_params);
624
625 auto t5 = t0.to(at::kFloat);
626 auto t6 = t1.to(at::kFloat);
627 auto t7 = t2.to(at::kFloat);
628 auto t8 = t5.mean({0, 2, 3});
629 auto t9 = t5.var({0, 2, 3}, false) * input_shape[0] * input_shape[2] *
630 input_shape[3];
631 auto t11 = t8 * 0.1;
632 auto t12 = t3 * 1;
633 auto t13 = t12 + t11;
634 auto t14 = t9 * 0.5;
635 auto t15 = t14 * 0.1;
636 auto t16 = t4 * 1;
637 auto t17 = t16 + t15;
638 auto t18 = t8.unsqueeze(0).unsqueeze(-1).unsqueeze(-1);
639 auto t19 = t5 - t18;
640 auto t20 = t9 * 0.5;
641 auto t21 = t20 + 0.0001;
642 auto t22 = rsqrt(t21);
643 auto t23 = t22.unsqueeze(0).unsqueeze(-1).unsqueeze(-1);
644 auto t24 = t19 * t23;
645 auto t25 = t6.unsqueeze(0).unsqueeze(-1).unsqueeze(-1);
646 auto t26 = t24 * t25;
647 auto t27 = t7.unsqueeze(0).unsqueeze(-1).unsqueeze(-1);
648 auto t28 = t26 + t27;
649 auto t29 = t28.to(at::kHalf);
650
651 testValidate(
652 &fusion,
653 cg_outputs,
654 aten_inputs,
655 {t13, t17, t29},
656 __LINE__,
657 __FILE__,
658 "",
659 launch_params);
660}
661
662// Simple grouped reduction
663TEST_F(NVFuserTest, FusionGroupedReduction1_CUDA) {
664 Fusion fusion;
665 FusionGuard fg(&fusion);
666
667 auto tv0 = makeSymbolicTensor(2);
668 fusion.addInput(tv0);
669
670 auto tv1 = sum(tv0, {1});
671 auto tv2 = sum(tv0, {1});
672 auto tv3 = add(tv1, tv2);
673 fusion.addOutput(tv3);
674
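  // Fuse the two identical reductions into a single grouped reduction
  // so they are computed together at runtime.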
675 groupReductions({tv1, tv2});
676
677 tv2->axis(0)->parallelize(ParallelType::BIDx);
678 tv2->axis(1)->parallelize(ParallelType::TIDx);
679 scheduler_utils::parallelizeAllLike(tv2);
680
681 std::vector<int64_t> shape({99, 999});
682
683 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
684
685 auto t0 = at::randn(shape, options);
686
687 FusionExecutor fe;
688 fe.compileFusion(&fusion, {t0});
689 auto outputs = fe.runFusion({t0});
690
691 auto ref = t0.sum({1}) * 2;
692
693 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
694}
695
696// Grouping reductions with different ops
697TEST_F(NVFuserTest, FusionGroupedReduction2_CUDA) {
698 Fusion fusion;
699 FusionGuard fg(&fusion);
700
701 auto tv0 = makeSymbolicTensor(2);
702 fusion.addInput(tv0);
703
704 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
705 auto tv2 = sum(tv1, {1});
706
707 auto tv3 = add(tv0, IrBuilder::create<Double>(2));
708 auto tv4 = max(tv3, {1});
709
710 auto tv5 = add(tv2, tv4);
711 fusion.addOutput(tv5);
712
713 groupReductions({tv2, tv4});
714
715 tv2->split(1, 128);
716 TransformPropagator propagator(tv2);
717 MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
718
719 tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined);
720
721 // tv4 is automatically parallelized in the same way
722 tv2->axis(0)->parallelize(ParallelType::BIDy);
723 tv2->axis(1)->parallelize(ParallelType::BIDx);
724 tv2->axis(2)->parallelize(ParallelType::TIDx);
725
726 std::vector<int64_t> shape({99, 999});
727
728 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
729
730 auto t0 = at::randn(shape, options);
731
732 FusionExecutor fe;
733 fe.compileFusion(&fusion, {t0});
734 auto outputs = fe.runFusion({t0});
735
736 auto ref = (t0 + 1).sum({1}) + std::get<0>((t0 + 2).max(1));
737
738 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
739}
740
741// Grouped reduction with different types
742TEST_F(NVFuserTest, FusionGroupedReduction3_CUDA) {
743 Fusion fusion;
744 FusionGuard fg(&fusion);
745
746 auto tv0 = makeSymbolicTensor(2);
747 fusion.addInput(tv0);
748
749 auto tv1 = sum(tv0, {1});
750
751 auto tv2 = castOp(DataType::Double, tv0);
752 auto tv3 = sum(tv2, {1});
753 auto tv4 = castOp(DataType::Float, tv3);
754
755 auto tv5 = add(tv1, tv4);
756 fusion.addOutput(tv5);
757
758 groupReductions({tv1, tv3});
759 tv1->split(1, 128);
760 TransformPropagator propagator(tv1);
761 MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);
762
763 tv0->computeAt(tv5, -1, ComputeAtMode::MostInlined);
764
765 tv1->axis(0)->parallelize(ParallelType::BIDy);
766 tv1->axis(1)->parallelize(ParallelType::BIDx);
767 tv1->axis(2)->parallelize(ParallelType::TIDx);
768
769 std::vector<int64_t> shape({99, 999});
770
771 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
772
773 auto t0 = at::randn(shape, options);
774
775 FusionExecutor fe;
776 fe.compileFusion(&fusion, {t0});
777 auto outputs = fe.runFusion({t0});
778
779 auto ref = t0.sum({1}) + t0.to(c10::kDouble).sum({1}).to(c10::kFloat);
780
781 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
782}
783
// Validation: grouping reductions with mismatched shapes must be rejected
785TEST_F(NVFuserTest, FusionGroupedReduction4_CUDA) {
786 Fusion fusion;
787 FusionGuard fg(&fusion);
788
789 auto tv0 = makeSymbolicTensor(2);
790 fusion.addInput(tv0);
791 auto tv1 = makeSymbolicTensor(2);
792 fusion.addInput(tv1);
793
794 auto tv2 = sum(tv0, {1});
795 auto tv3 = sum(tv1, {1});
796 auto tv4 = add(tv2, tv3);
797 fusion.addOutput(tv4);
798
799 // Invalid grouping as tv2 and tv3 are not guaranteed to have the
800 // same shape
801 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
802 ASSERT_ANY_THROW(groupReductions({tv2, tv3}));
803}
804
// Validation: grouping reductions with mismatched transformations must be
// rejected
806TEST_F(NVFuserTest, FusionGroupedReduction5_CUDA) {
807 Fusion fusion;
808 FusionGuard fg(&fusion);
809
810 auto tv0 = makeSymbolicTensor(2);
811 fusion.addInput(tv0);
812
813 auto tv1 = sum(tv0, {1});
814 auto tv2 = sum(tv0, {1});
815 auto tv3 = add(tv1, tv2);
816 fusion.addOutput(tv3);
817
818 tv1->split(1, 128);
819 tv2->split(1, 64);
820
821 // Invalid grouping as tv1 and tv2 don't have the same
822 // transformations
823 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
824 ASSERT_ANY_THROW(groupReductions({tv1, tv2}));
825}
826
827// Grouping 3 reductions
828TEST_F(NVFuserTest, FusionGroupedReduction6_CUDA) {
829 Fusion fusion;
830 FusionGuard fg(&fusion);
831
832 auto tv0 = makeSymbolicTensor(2);
833 fusion.addInput(tv0);
834
835 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
836 auto tv2 = sum(tv1, {1});
837
838 auto tv3 = add(tv0, IrBuilder::create<Double>(2));
839 auto tv4 = sum(tv3, {1});
840
841 auto tv5 = add(tv0, IrBuilder::create<Double>(3));
842 auto tv6 = sum(tv5, {1});
843
844 auto tv7 = add(add(tv2, tv4), tv6);
845
846 fusion.addOutput(tv7);
847
848 groupReductions({tv2, tv4, tv6});
849
850 // There's no runtime grid reduction function that can take more
851 // than 2 inputs, yet.
852 tv2->axis(0)->parallelize(ParallelType::BIDx);
853 tv2->axis(1)->parallelize(ParallelType::TIDx);
854
855 scheduler_utils::parallelizeAllLike(tv2);
856
857 std::vector<int64_t> shape({99, 999});
858
859 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
860
861 auto t0 = at::randn(shape, options);
862
863 FusionExecutor fe;
864 fe.compileFusion(&fusion, {t0});
865 auto outputs = fe.runFusion({t0});
866
867 auto ref = (t0 + 1).sum({1}) + (t0 + 2).sum({1}) + (t0 + 3).sum({1});
868
869 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
870}
871
872TEST_F(NVFuserTest, FusionGroupedReduction7_CUDA) {
873 Fusion fusion;
874 FusionGuard fg(&fusion);
875
876 auto tv0 = makeSymbolicTensor(2);
877 fusion.addInput(tv0);
878
879 auto tv1 = sum(tv0, {1});
880 auto tv2 = broadcast(tv1, {false, true});
881 auto tv3 = add(tv0, tv2);
882 auto tv4 = sum(tv3, {1});
883 fusion.addOutput(tv4);
884
  // Invalid grouping as tv4 depends on tv1 (through tv2 and tv3)
886 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
887 ASSERT_ANY_THROW(groupReductions({tv1, tv4}));
888}
889
890// Grouping rfactor'ed reductions
891TEST_F(NVFuserTest, FusionGroupedReductionRfactor1_CUDA) {
892 Fusion fusion;
893 FusionGuard fg(&fusion);
894
895 auto tv0 = makeSymbolicTensor(1);
896 fusion.addInput(tv0);
897
898 auto tv1 = sum(tv0, {0});
899 auto tv2 = sum(tv0, {0});
900 auto tv3 = add(tv1, tv2);
901 fusion.addOutput(tv3);
902
903 const size_t gdimx = 10;
904 const size_t bdimx = 128;
905
906 tv1->split(0, gdimx, false);
907 tv1->split(1, bdimx);
908 auto tv1_rf = tv1->rFactor({1});
909
910 tv2->split(0, gdimx, false);
911 tv2->split(1, bdimx);
912 auto tv2_rf = tv2->rFactor({1});
913
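  // Group both the rfactored partial reductions and the final
  // reductions, so each pair is executed as a grouped reduction.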
914 groupReductions({tv1_rf, tv2_rf});
915 groupReductions({tv1, tv2});
916
917 tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
918 tv1_rf->axis(2)->parallelize(ParallelType::TIDx);
919
920 scheduler_utils::parallelizeAllLike(tv1_rf);
921
922 std::vector<int64_t> shape({12345});
923
924 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
925
926 auto t0 = at::randn(shape, options);
927
928 FusionExecutor fe;
929 fe.compileFusion(&fusion, {t0});
930 auto outputs = fe.runFusion({t0});
931
932 auto ref = t0.sum({0}) * 2;
933
934 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
935}
936
937// Rfactoring grouped reductions
938TEST_F(NVFuserTest, FusionGroupedReductionRfactor2_CUDA) {
939 Fusion fusion;
940 FusionGuard fg(&fusion);
941
942 auto tv0 = makeSymbolicTensor(1);
943 fusion.addInput(tv0);
944
945 auto tv1 = sum(tv0, {0});
946 auto tv2 = sum(tv0, {0});
947 auto tv3 = add(tv1, tv2);
948 fusion.addOutput(tv3);
949
950 groupReductions({tv1, tv2});
951
952 const size_t gdimx = 10;
953 const size_t bdimx = 128;
954
955 tv1->split(0, gdimx, false);
956 tv1->split(1, bdimx);
957
958 // This should rfactor tv2 as well
959 auto rf_tvs = tv1->rFactor({1}, {tv1, tv2});
960 auto tv1_rf = rf_tvs.at(0);
961
962 tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
963 tv1_rf->axis(2)->parallelize(ParallelType::TIDx);
964
965 scheduler_utils::parallelizeAllLike(tv1_rf);
966
967 std::vector<int64_t> shape({12345});
968
969 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
970
971 auto t0 = at::randn(shape, options);
972
973 FusionExecutor fe;
974 fe.compileFusion(&fusion, {t0});
975 auto outputs = fe.runFusion({t0});
976
977 auto ref = t0.sum({0}) * 2;
978
979 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
980}
981
982// Group reductions of tensors that have computeAt positions set
983TEST_F(NVFuserTest, FusionGroupedReductionAfterComputeAt_CUDA) {
984 Fusion fusion;
985 FusionGuard fg(&fusion);
986 auto tv0 = makeSymbolicTensor(2);
987 fusion.addInput(tv0);
988
989 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
990 auto tv2 = sum(tv1, {1});
991 auto tv3 = sum(tv1, {1});
992 auto tv4 = add(tv2, tv3);
993 fusion.addOutput(tv4);
994
995 const size_t bdimx = 128;
996
997 tv2->split(1, bdimx);
998 auto tv2_rf = tv2->rFactor({1});
999 tv2_rf->reorder({{1, 2}});
1000
1001 tv3->split(1, bdimx);
1002 auto tv3_rf = tv3->rFactor({1});
1003 tv3_rf->reorder({{1, 2}});
1004
1005 tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined);
1006
1007 groupReductions({tv2_rf, tv3_rf});
1008 groupReductions({tv2, tv3});
1009
1010 tv2->axis(1)->parallelize(ParallelType::TIDx);
1011 scheduler_utils::parallelizeAllLike(tv2);
1012
1013 std::vector<int64_t> shape({3, 1234});
1014
1015 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1016
1017 auto t0 = at::randn(shape, options);
1018
1019 FusionExecutor fe;
1020 fe.compileFusion(&fusion, {t0});
1021 auto outputs = fe.runFusion({t0});
1022
1023 auto ref = (t0 + 1).sum({1}) * 2;
1024
1025 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
1026}
1027
1028TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) {
1029 Fusion fusion;
1030 FusionGuard fg(&fusion);
1031
1032 auto tv0 = makeSymbolicTensor(1);
1033 fusion.addInput(tv0);
1034
1035 auto tv1 = sum(tv0, {0});
1036 auto tv2 = broadcast(tv1, {true});
1037 auto tv3 = sum(tv0, {0});
1038 auto tv4 = broadcast(tv3, {true});
1039 auto tv5 = add(tv0, tv2);
1040 auto tv6 = add(tv5, tv4);
1041 fusion.addOutput(tv6);
1042
1043 groupReductions({tv1, tv3});
1044
1045 tv2->split(0, 128);
1046 TransformPropagator propagator(tv2);
1047 MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
1048
1049 tv2->axis(0)->parallelize(ParallelType::BIDx);
1050 tv2->axis(1)->parallelize(ParallelType::TIDx);
1051 scheduler_utils::parallelizeAllLike(tv2);
1052
1053 std::vector<int64_t> shape({999});
1054
1055 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1056
1057 auto t0 = at::randn(shape, options);
1058
1059 FusionExecutor fe;
1060 fe.compileFusion(&fusion, {t0});
1061 auto outputs = fe.runFusion({t0});
1062
1063 auto t3 = t0.sum({0}).unsqueeze(-1);
1064 auto ref = t0 + t3 + t3;
1065
1066 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
1067}
1068
// Grid reductions of different types
1070TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) {
1071 Fusion fusion;
1072 FusionGuard fg(&fusion);
1073
1074 auto tv0 = makeSymbolicTensor(2);
1075 fusion.addInput(tv0);
1076
1077 auto tv1 = sum(tv0, {1});
1078 auto tv2 = broadcast(tv1, {false, true});
1079
1080 auto tv3 = castOp(DataType::Double, tv0);
1081 auto tv4 = sum(tv3, {1});
1082 auto tv5 = broadcast(tv4, {false, true});
1083 auto tv6 = castOp(DataType::Float, tv5);
1084
1085 auto tv7 = add(tv0, tv2);
1086 auto tv8 = add(tv7, tv6);
1087 fusion.addOutput(tv8);
1088
1089 const int tidx = 512;
1090 groupReductions({tv1, tv4});
1091 tv1->split(1, tidx);
1092 TransformPropagator propagator(tv1);
1093 MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);
1094
1095 tv0->computeAt(tv8, -1, ComputeAtMode::MostInlined);
1096
1097 tv1->axis(0)->parallelize(ParallelType::BIDy);
1098 tv1->axis(1)->parallelize(ParallelType::BIDx);
1099 tv1->axis(2)->parallelize(ParallelType::TIDx);
1100 scheduler_utils::parallelizeAllLike(tv1);
1101
1102 std::vector<int64_t> shape({10, 999});
1103
1104 if (shape.at(0) * ceilDiv(shape.at(1), tidx) > deviceSMCount()) {
1105 GTEST_SKIP() << "Not enough SMs to run this test";
1106 }
1107
1108 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1109
1110 auto t0 = at::randn(shape, options);
1111
1112 FusionExecutor fe;
1113 fe.compileFusion(&fusion, {t0});
1114 auto outputs = fe.runFusion({t0});
1115
1116 auto t2 = t0.sum({1}).unsqueeze(-1);
1117 auto t6 = t0.to(c10::kDouble).sum({1}).unsqueeze(-1).to(c10::kFloat);
1118 auto ref = t0 + t2 + t6;
1119
1120 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
1121}
1122
1123// Grouping 3 grid allreduces
1124TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) {
1125 Fusion fusion;
1126 FusionGuard fg(&fusion);
1127
1128 auto tv0 = makeSymbolicTensor(1);
1129 fusion.addInput(tv0);
1130
1131 auto tv1 = sum(tv0, {0});
1132 auto tv2 = broadcast(tv1, {true});
1133 auto tv3 = div(tv0, tv2);
1134 auto tv4 = max(tv0, {0});
1135 auto tv5 = broadcast(tv4, {true});
1136 auto tv6 = div(tv0, tv5);
1137 auto tv7 = min(tv0, {0});
1138 auto tv8 = broadcast(tv7, {true});
1139 auto tv9 = sub(tv0, tv8);
1140 fusion.addOutput(tv3);
1141 fusion.addOutput(tv6);
1142 fusion.addOutput(tv9);
1143
1144 groupReductions({tv1, tv4, tv7});
1145
1146 tv1->split(0, 128);
1147 TransformPropagator propagator(tv1);
1148 MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);
1149
1150 tv1->axis(0)->parallelize(ParallelType::BIDx);
1151 tv1->axis(1)->parallelize(ParallelType::TIDx);
1152 scheduler_utils::parallelizeAllLike(tv1);
1153
1154 std::vector<int64_t> shape({999});
1155
1156 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1157
1158 auto t0 = at::randn(shape, options);
1159
1160 FusionExecutor fe;
1161 fe.compileFusion(&fusion, {t0});
1162 auto outputs = fe.runFusion({t0});
1163
1164 auto t3 = t0 / t0.sum({0}).unsqueeze(0);
1165 auto t6 = t0 / std::get<0>(t0.max(0)).unsqueeze(0);
1166 auto t9 = t0 - std::get<0>(t0.min(0)).unsqueeze(0);
1167
1168 testValidate(fe.kernel(), outputs, {t0}, {t3, t6, t9}, __LINE__, __FILE__);
1169}
1170
1171// Grouping 8 grid allreduces
1172TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) {
1173 Fusion fusion;
1174 FusionGuard fg(&fusion);
1175
1176 const int num_reductions = 8;
1177
1178 auto tv0 = makeSymbolicTensor(1);
1179 fusion.addInput(tv0);
1180
1181 auto tv_sum = tv0;
1182 std::vector<TensorView*> reduction_tvs;
1183
1184 for (int i = 0; i < num_reductions; ++i) {
1185 auto reduction = sum(add(tv0, IrBuilder::create<Double>(i)), {0});
1186 reduction_tvs.push_back(reduction);
1187 auto avg = div(reduction, tv0->axis(0)->extent());
1188 auto bc = broadcast(avg, {true});
1189 tv_sum = add(tv_sum, bc);
1190 }
1191
1192 fusion.addOutput(tv_sum);
1193
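  // Group all eight allreduces into a single grouped grid reduction.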
1194 groupReductions(reduction_tvs);
1195
1196 auto reduction_tv = reduction_tvs.at(0);
1197
1198 reduction_tv->split(0, 128);
1199 TransformPropagator propagator(reduction_tv);
1200 MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
1201
1202 reduction_tv->axis(0)->parallelize(ParallelType::BIDx);
1203 reduction_tv->axis(1)->parallelize(ParallelType::TIDx);
1204 scheduler_utils::parallelizeAllLike(reduction_tv);
1205
1206 std::vector<int64_t> shape({999});
1207
1208 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1209
1210 auto t0 = at::randn(shape, options);
1211
1212 FusionExecutor fe;
1213 fe.compileFusion(&fusion, {t0});
1214 auto outputs = fe.runFusion({t0});
1215
1216 at::Tensor ref = t0;
1217 for (int i = 0; i < num_reductions; ++i) {
1218 auto reduction = sum(add(t0, i), {0});
1219 auto avg = reduction / t0.sizes()[0];
1220 auto bc = avg.unsqueeze(0);
1221 ref = add(ref, bc);
1222 }
1223
1224 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
1225}
1226
// Variation of the grouped allreduce tests above, exercising grouped
// grid allreduces with mixed data types (float, double, int).
1229TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) {
1230 Fusion fusion;
1231 FusionGuard fg(&fusion);
1232
1233 auto tv0 = makeSymbolicTensor(1, DataType::Float);
1234 fusion.addInput(tv0);
1235 auto tv1 = sum(tv0, {0});
1236 auto tv2 = broadcast(tv1, {true});
1237 auto tv3 = div(tv0, tv2);
1238
1239 auto tv4 = makeSymbolicTensor(1, DataType::Double);
1240 fusion.addInput(tv4);
1241 auto tv5 = sum(tv4, {0});
1242 auto tv6 = broadcast(tv5, {true});
1243 auto tv7 = div(tv4, tv6);
1244
1245 auto tv8 = makeSymbolicTensor(1, DataType::Int);
1246 fusion.addInput(tv8);
1247 auto tv9 = sum(tv8, {0});
1248 auto tv10 = broadcast(tv9, {true});
1249 auto tv11 = div(tv8, tv10);
1250
1251 auto out = add(
1252 add(castOp(DataType::Double, tv3), tv7), castOp(DataType::Double, tv11));
1253
1254 fusion.addOutput(out);
1255
1256 groupReductions({tv1, tv5, tv9});
1257
1258 tv1->split(0, 128);
1259 TransformPropagator propagator(tv1);
1260 MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);
1261
1262 tv1->axis(0)->parallelize(ParallelType::BIDx);
1263 tv1->axis(1)->parallelize(ParallelType::TIDx);
1264 scheduler_utils::parallelizeAllLike(tv1);
1265
1266 std::vector<int64_t> shape({999});
1267
1268 auto options_float =
1269 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1270 auto options_double =
1271 at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0);
1272 auto options_long = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
1273
1274 auto t0 = at::randn(shape, options_float);
1275 auto t4 = at::randn(shape, options_double);
1276 auto t8 = torch::randint(0, 1000, shape, options_long);
1277 std::vector<IValue> aten_inputs = {t0, t4, t8};
1278
1279 std::vector<at::indexing::TensorIndex> indices({at::indexing::Slice(0, 10)});
1280
1281 FusionExecutor fe;
1282 fe.compileFusion(&fusion, aten_inputs);
1283 auto outputs = fe.runFusion(aten_inputs);
1284
1285 auto t3 = t0 / t0.sum({0}).unsqueeze(0).to(at::kDouble);
1286 auto t7 = t4 / t4.sum({0}).unsqueeze(0);
1287 auto t11 = t8 / t8.sum({0}).unsqueeze(0).to(at::kDouble);
1288 auto ref = t3 + t7 + t11;
1289
1290 testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__);
1291}
1292
1293// Persistent batchnorm backward with grouped allreduce
1294TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) {
1295 const std::vector<int64_t> shape({64, 1024, 14, 14});
1296
1297 Fusion fusion;
1298 FusionGuard fg(&fusion);
1299
1300 auto input = makeContigTensor(4);
1301 fusion.addInput(input);
1302 auto grad_output = makeContigTensor(4);
1303 fusion.addInput(grad_output);
1304 auto weight = makeContigTensor(1);
1305 fusion.addInput(weight);
1306 auto save_mean = makeContigTensor(1);
1307 fusion.addInput(save_mean);
1308 auto save_invstd = makeContigTensor(1);
1309 fusion.addInput(save_invstd);
1310
1311 const bool kTraining = true;
1312 const bool channels_last = false;
1313
1314 const size_t kNumberOfDims =
1315 TensorDomain::noReductions(input->getMaybeRFactorDomain()).size();
1316 size_t c_axis = channels_last ? kNumberOfDims - 1 : 1;
1317
1318 std::vector<int> reduction_axes;
1319 std::vector<bool> broadcast_mask(kNumberOfDims, false);
1320 Val* num_features = nullptr;
1321 for (const auto axis : c10::irange(kNumberOfDims)) {
1322 if (axis != c_axis) {
1323 reduction_axes.push_back(axis);
1324 broadcast_mask[axis] = true;
1325 if (num_features == nullptr) {
1326 num_features =
1327 castOp(DataType::Double, input->domain()->domain()[axis]->extent());
1328 } else {
1329 num_features =
1330 mul(num_features, input->domain()->domain()[axis]->extent());
1331 }
1332 }
1333 }
1334
1335 auto mean = save_mean;
1336 auto invstd = save_invstd;
1337
1338 mean = broadcast(mean, broadcast_mask);
1339
1340 auto norm = reciprocal(num_features);
1341
1342 auto grad_output_sum = sum(grad_output, reduction_axes);
1343 auto dot_p = sum(mul(grad_output, sub(input, mean)), reduction_axes);
1344
1345 auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask);
1346
1347 auto proj_scale =
1348 broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask);
1349
1350 TensorView* grad_scale = nullptr;
1351
1352 if (weight == nullptr) {
1353 grad_scale =
1354 mul(broadcast(invstd, broadcast_mask),
1355 IrBuilder::create<Double>(input->container(), 1));
1356 } else {
1357 grad_scale = mul(
1358 broadcast(invstd, broadcast_mask), broadcast(weight, broadcast_mask));
1359 }
1360
1361 TensorView* grad_input = nullptr;
1362 if (kTraining) {
1363 auto proj = mul(sub(input, mean), proj_scale);
1364 grad_input = mul(sub(sub(grad_output, proj), grad_mean), grad_scale);
1365 } else {
1366 grad_input = mul(grad_output, grad_scale);
1367 }
1368
1369 fusion.addOutput(grad_input);
1370
1371 // Scheduling strategy
1372 // 1. Cache inputs
1373 // 2. Group the reductions (automatically fused with broadcasts)
1374 // 3. Merge HW and vectorize with the outer parallelized by TIDx
1375 // 4. Split N by TIDy with the outer parallelized by BIDx and
1376 // inner by TIDy
1377 // 5. Split C by BIDy and let the outer be the serial outermost loop
1378
1379 auto input_cache = input->cacheAfter();
1380 auto grad_output_cache = grad_output->cacheAfter();
1381 auto weight_cache = weight->cacheAfter();
1382 auto save_mean_cache = save_mean->cacheAfter();
1383 auto save_invstd_cache = save_invstd->cacheAfter();
1384
1385 // Group the two reductions
1386 groupReductions({grad_output_sum, dot_p});
1387
1388 // Transform grad_input to: [C/bidy, N/tidy, tidy, bidy, HW/vec_width,
1389 // vec_width]
1390 const int tidy = 8;
1391 const int bidy = 4;
1392 const int bidx = ceilDiv(shape[0], (int64_t)tidy);
1393 const int vec_width = 4;
1394 TORCH_CHECK(
1395 (shape[2] * shape[3]) % vec_width == 0,
1396 "Invalid vector width: ",
1397 vec_width);
1398
1399 grad_input->merge(-2, -1);
1400 grad_input->split(-1, vec_width);
1401
1402 grad_input->split(0, tidy);
1403 grad_input->split(2, bidy);
1404 TORCH_CHECK(
1405 grad_input->nDims() == 6,
1406 "Unexpected number of dimensions: ",
1407 grad_input->toString());
1408
1409 grad_input->reorder({{2, 0}, {0, 1}, {1, 2}});
1410
1411 grad_input->axis(1)->parallelize(ParallelType::BIDx);
1412 grad_input->axis(2)->parallelize(ParallelType::TIDy);
1413 grad_input->axis(3)->parallelize(ParallelType::BIDy);
1414 grad_input->axis(4)->parallelize(ParallelType::TIDx);
1415
1416 TransformPropagator propagator(grad_input);
1417 MaxRootDomainInfoSpanningTree(grad_input).traverse(&propagator);
1418
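  // rFactor both grouped reductions together so they keep identical
  // transformations, which grouping requires.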
1419 auto rf_tensors = grad_output_sum->rFactor(
1420 {-1}, std::vector<TensorView*>({grad_output_sum, dot_p}));
1421
1422 for (auto fusion_input :
1423 ir_utils::filterByType<TensorView>(fusion.inputs())) {
1424 fusion_input->computeAt(grad_input, 1);
1425 }
1426
1427 // Parallelization
1428 scheduler_utils::parallelizeAllLike(grad_input);
1429 input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1430 grad_output_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1431
1432 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1433 auto at_input = at::randn(shape, options);
1434 auto at_grad_output = at::randn(shape, options);
1435 auto at_weight = at::randn({shape[c_axis]}, options);
1436 auto at_save_mean = at::randn({shape[c_axis]}, options);
1437 auto at_save_invstd = at::randn({shape[c_axis]}, options);
1438 std::vector<IValue> aten_inputs(
1439 {at_input, at_grad_output, at_weight, at_save_mean, at_save_invstd});
1440
1441 GpuLower gpulw(&fusion);
1442 validateNoParallelBroadcastExist(gpulw.kernel());
1443
1444 FusionExecutor fe;
1445 fe.compileFusion(&fusion, aten_inputs);
1446
1447 if (bidx * bidy > deviceSMCount()) {
1448 GTEST_SKIP() << "Not enough SMs to run this test";
1449 }
1450
1451 auto outputs = fe.runFusion(aten_inputs);
1452
1453 std::vector<int64_t> at_reduction_axes;
1454 std::copy(
1455 reduction_axes.begin(),
1456 reduction_axes.end(),
1457 std::back_inserter(at_reduction_axes));
1458
1459 // MSVC bug on lambda non-capture of const integral type
1460 // https://developercommunity.visualstudio.com/t/lambda-fails-to-implicitly-capture-constexpr-value/610504
1461 auto at_bcast = [=](const auto& tensor) {
1462 if (channels_last) {
      return tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0);
1464 } else {
1465 return tensor.unsqueeze(0).unsqueeze(-1).unsqueeze(-1);
1466 }
1467 };
1468
1469 auto at_mean = at_save_mean;
1470 const auto& at_invstd = at_save_invstd;
1471 at_mean = at_bcast(at_mean);
1472 auto at_norm = 1.0f / static_cast<float>(shape[0] * shape[2] * shape[3]);
1473
1474 auto at_grad_output_sum = sum(at_grad_output, at_reduction_axes);
1475 auto at_dot_p =
1476 sum(mul(at_grad_output, sub(at_input, at_mean)), at_reduction_axes);
1477
1478 auto at_grad_mean = at_bcast(at_grad_output_sum * at_norm);
1479
1480 auto at_proj_scale = at_bcast((at_dot_p * at_norm) * (at_invstd * at_invstd));
1481
1482 at::Tensor at_grad_scale;
1483
1484 if (weight == nullptr) {
1485 at_grad_scale = at_bcast(at_invstd);
1486 } else {
1487 at_grad_scale = at_bcast(at_invstd) * at_bcast(at_weight);
1488 }
1489
1490 at::Tensor at_grad_input;
1491 if (kTraining) {
1492 auto at_proj = (at_input - at_mean) * at_proj_scale;
1493 at_grad_input = (at_grad_output - at_proj - at_grad_mean) * at_grad_scale;
1494 } else {
1495 at_grad_input = at_grad_output * at_grad_scale;
1496 }
1497
1498 testValidate(
1499 fe.kernel(), outputs, aten_inputs, {at_grad_input}, __LINE__, __FILE__);
1500}
1501
1502TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) {
1503 Fusion fusion;
1504 FusionGuard fg(&fusion);
1505
1506 auto tv0 = makeSymbolicTensor(2);
1507 fusion.addInput(tv0);
1508
1509 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
1510 auto tv2 = sum(tv1, {0});
1511
1512 auto tv3 = add(tv0, IrBuilder::create<Double>(2));
1513 auto tv4 = sum(tv3, {0});
1514
1515 auto tv5 = add(tv2, tv4);
1516 fusion.addOutput(tv5);
1517
1518 groupReductions({tv2, tv4});
1519
1520 auto tv0_cache = tv0->cacheAfter();
1521
1522 const int vec = 2;
1523 const int tidx = 64;
1524 const int tidy = 8;
1525
1526 tv2->split(1, vec);
1527 tv2->split(1, tidx);
1528
1529 tv2->split(0, tidy);
1530 TransformPropagator propagator(tv2);
1531 MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
1532
1533 tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1534
1535 tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined);
1536
1537 tv2->axis(0)->parallelize(ParallelType::BIDy);
1538 tv2->axis(1)->parallelize(ParallelType::TIDy);
1539 tv2->axis(2)->parallelize(ParallelType::BIDx);
1540 tv2->axis(3)->parallelize(ParallelType::TIDx);
1541
1542 scheduler_utils::parallelizeAllLike(tv2);
1543
1544 std::vector<int64_t> shape({99, 999});
1545
1546 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1547 at::manual_seed(0);
1548
1549 auto t0 = at::randn(shape, options);
1550
1551 FusionExecutor fe;
1552 fe.compileFusion(&fusion, {t0});
1553 auto outputs = fe.runFusion({t0});
1554
1555 auto t0_double = t0.to(at::kDouble);
1556 auto ref = (t0_double + 1).sum({0}) + (t0_double + 2).sum({0});
1557
1558 testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
1559}
1560
1561// Channels-last batch norm with vectorization. Relies on re-entrant
1562// GroupedGridReduction
1563TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) {
1564 Fusion fusion;
1565 FusionGuard fg(&fusion);
1566
1567 const std::vector<int64_t> shape({64, 14, 14, 32});
1568
1569 auto tv0 = makeContigTensor(4, DataType::Half);
1570 fusion.addInput(tv0);
1571 auto tv1 = makeContigTensor(4, DataType::Half);
1572 fusion.addInput(tv1);
1573 auto tv2 = makeContigTensor(1);
1574 fusion.addInput(tv2);
1575
1576 std::vector<int> reduction_axes({0, 1, 2});
1577 std::vector<bool> broadcast_mask({true, true, true, false});
1578
1579 auto tv3 = castOp(DataType::Float, tv0);
1580 auto tv4 = castOp(DataType::Float, tv1);
1581
1582 auto tv5 = sum(tv3, reduction_axes);
1583
1584 auto tv6 = broadcast(tv2, broadcast_mask);
1585 auto tv7 = sub(tv4, tv6);
1586 auto tv8 = mul(tv3, tv7);
1587 auto tv9 = sum(tv8, reduction_axes);
1588
1589 auto tv10 = castOp(DataType::Half, tv5);
1590 auto tv11 = castOp(DataType::Half, tv9);
1591
1592 fusion.addOutput(tv10);
1593 fusion.addOutput(tv11);
1594
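  // Group the two outer reductions (the sum and the product-sum) so
  // they are executed together.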
1595 groupReductions({tv5, tv9});
1596
1597 // Applies the outer-reduction schedule
1598 const int64_t num_channels = shape.back();
1599 const int64_t vector = 2;
1600 TORCH_CHECK(num_channels % vector == 0);
1601 // Use at most 32 TIDx threads
1602 const int64_t tidx = std::min<int64_t>(32l, num_channels / vector);
1603 const auto bidx = ceilDiv(num_channels, tidx * vector);
1604
1605 const int64_t tidy = 8;
1606 const auto bidy = ceilDiv(
1607 at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 4, bidx);
1608
1609 auto tv0_cache = tv0->cacheAfter();
1610 auto tv1_cache = tv1->cacheAfter();
1611
1612 auto ref = tv5;
1613
  // Move the reduction domains to the inner positions
1615 ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}});
1616
1617 // Parallelizing the reduction domains
1618 ref->merge(2, 3);
1619 ref->merge(1, 2);
1620 ref->split(1, tidy);
1621 ref->split(1, bidy, false);
1622
1623 // Parallelizing the iteration domains
1624 ref->split(0, vector);
1625 ref->split(0, tidx);
1626
1627 // Move the vector axis to the innermost position
1628 ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}});
1629 // Move the serial reduction to the right of the vector axis
1630 ref->reorder({{3, 4}, {4, 3}});
1631
1632 TransformPropagator propagator(ref);
1633 MaxRootDomainInfoSpanningTree(ref).traverse(&propagator);
1634
1635 auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9});
1636 auto tv5_rf = rf_tvs.at(0);
1637 auto tv9_rf = rf_tvs.at(1);
1638
1639 tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort);
1640 tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort);
1641 tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort);
1642 tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort);
1643
1644 ref = tv5_rf;
1645
1646 ref->axis(0)->parallelize(ParallelType::BIDx);
1647 ref->axis(1)->parallelize(ParallelType::TIDx);
1648 ref->axis(2)->parallelize(ParallelType::BIDy);
1649 ref->axis(3)->parallelize(ParallelType::TIDy);
1650 ref->axis(4)->parallelize(ParallelType::Serial);
1651 ref->axis(5)->parallelize(ParallelType::Serial);
1652
1653 scheduler_utils::parallelizeAllLike(ref);
1654
1655 tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1656 tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1657
1658 auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
1659 auto options_float =
1660 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1661 auto t0 = at::randn(shape, options_half);
1662 auto t1 = at::randn(shape, options_half);
1663 auto t2 = at::randn({shape.back()}, options_float);
1664 std::vector<IValue> aten_inputs({t0, t1, t2});
1665
1666 FusionExecutor fe;
1667 fe.compileFusion(&fusion, aten_inputs);
1668 auto outputs = fe.runFusion(aten_inputs);
1669
1670 auto t0_double = t0.to(at::kDouble);
1671 auto t1_double = t1.to(at::kDouble);
1672 auto t2_double = t2.to(at::kDouble);
1673
1674 std::vector<int64_t> at_reduction_axes(
1675 {reduction_axes.begin(), reduction_axes.end()});
1676 auto t5 = t0_double.sum(at_reduction_axes);
1677 auto t8 = t0_double *
1678 (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0));
1679 auto t9 = t8.sum(at_reduction_axes);
1680
1681 testValidate(fe.kernel(), outputs, aten_inputs, {t5, t9}, __LINE__, __FILE__);
1682}
1683
1684// Test the grouped grid allreduce with BN-like outer reductions
1685TEST_F(
1686 NVFuserTest,
1687 FusionGroupedReductionPersistentChannelsLastBatchNormLike_CUDA) {
1688 Fusion fusion;
1689 FusionGuard fg(&fusion);
1690
1691 const std::vector<int64_t> shape({64, 14, 14, 32});
1692
1693 auto tv0 = makeContigTensor(4, DataType::Half);
1694 fusion.addInput(tv0);
1695 auto tv1 = makeContigTensor(4, DataType::Half);
1696 fusion.addInput(tv1);
1697 auto tv2 = makeContigTensor(1);
1698 fusion.addInput(tv2);
1699
1700 std::vector<int> reduction_axes({0, 1, 2});
1701 std::vector<bool> broadcast_mask({true, true, true, false});
1702
1703 auto tv3 = castOp(DataType::Float, tv0);
1704 auto tv4 = castOp(DataType::Float, tv1);
1705
1706 auto tv5 = sum(tv3, reduction_axes);
1707
1708 auto tv6 = broadcast(tv2, broadcast_mask);
1709 auto tv7 = sub(tv4, tv6);
1710 auto tv8 = mul(tv3, tv7);
1711 auto tv9 = sum(tv8, reduction_axes);
1712
1713 auto tv10 = broadcast(tv5, broadcast_mask);
1714 auto tv11 = add(tv3, tv10);
1715
1716 auto tv12 = broadcast(tv9, broadcast_mask);
1717 auto tv13 = add(tv4, tv12);
1718
1719 auto tv14 = castOp(DataType::Half, tv11);
1720 auto tv15 = castOp(DataType::Half, tv13);
1721
1722 fusion.addOutput(tv14);
1723 fusion.addOutput(tv15);
1724
1725 groupReductions({tv5, tv9});
1726
1727 // Applies the outer-reduction schedule
1728 const int64_t num_channels = shape.back();
1729 const int64_t vector = 2;
1730 TORCH_CHECK(num_channels % vector == 0);
1731 // Use at most 32 TIDx threads
1732 const int64_t tidx = std::min<int64_t>(32l, num_channels / vector);
1733 const auto bidx = ceilDiv(num_channels, tidx * vector);
1734
1735 const int64_t tidy = 8;
1736 const int64_t reduction_work_per_thread = 8;
1737
1738 auto tv0_cache = tv0->cacheAfter();
1739 auto tv1_cache = tv1->cacheAfter();
1740
1741 auto ref = tv5;
1742
  // Move the reduction domains to the inner positions
1744 ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}});
1745
1746 // Parallelizing the reduction domains
1747 ref->merge(2, 3);
1748 ref->merge(1, 2);
1749 ref->split(1, tidy);
1750 ref->split(1, reduction_work_per_thread);
1751
1752 // Parallelizing the iteration domains
1753 ref->split(0, vector);
1754 ref->split(0, tidx);
1755
1756 // Move the vector axis to the innermost position
1757 ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}});
1758 // Move the serial reduction to the right of the vector axis
1759 ref->reorder({{3, 4}, {4, 3}});
1760
1761 TransformPropagator propagator(ref);
1762 MaxRootDomainInfoSpanningTree(ref).traverse(&propagator);
1763
1764 auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9});
1765 auto tv5_rf = rf_tvs.at(0);
1766 auto tv9_rf = rf_tvs.at(1);
1767
1768 tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort);
1769 tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort);
1770 tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort);
1771 tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort);
1772
1773 ref = tv5_rf;
1774
1775 ref->axis(0)->parallelize(ParallelType::BIDx);
1776 ref->axis(1)->parallelize(ParallelType::TIDx);
1777 ref->axis(2)->parallelize(ParallelType::BIDy);
1778 ref->axis(3)->parallelize(ParallelType::TIDy);
1779 ref->axis(4)->parallelize(ParallelType::Serial);
1780 ref->axis(5)->parallelize(ParallelType::Serial);
1781
1782 scheduler_utils::parallelizeAllLike(ref);
1783
1784 tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1785 tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
1786
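  // Additionally group the innermost iteration domain so the grid
  // allreduce is performed once across its iterations.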
1787 tv5->axis(-1)->parallelize(ParallelType::Group);
1788
1789 auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
1790 auto options_float =
1791 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1792 auto t0 = at::randn(shape, options_half);
1793 auto t1 = at::randn(shape, options_half);
1794 auto t2 = at::randn({shape.back()}, options_float);
1795 std::vector<IValue> aten_inputs({t0, t1, t2});
1796
1797 FusionExecutor fe;
1798 fe.compileFusion(&fusion, aten_inputs);
1799 auto outputs = fe.runFusion(aten_inputs);
1800
1801 auto t0_double = t0.to(at::kDouble);
1802 auto t1_double = t1.to(at::kDouble);
1803 auto t2_double = t2.to(at::kDouble);
1804
1805 std::vector<int64_t> at_reduction_axes(
1806 {reduction_axes.begin(), reduction_axes.end()});
1807 auto t5 = t0_double.sum(at_reduction_axes);
1808 auto t8 = t0_double *
1809 (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0));
1810 auto t9 = t8.sum(at_reduction_axes);
1811
1812 auto t10 = t5.unsqueeze(0).unsqueeze(0).unsqueeze(0);
1813 auto t11 = t0_double + t10;
1814 auto t12 = t9.unsqueeze(0).unsqueeze(0).unsqueeze(0);
1815 auto t13 = t1_double + t12;
1816
1817 testValidate(
1818 fe.kernel(), outputs, aten_inputs, {t11, t13}, __LINE__, __FILE__);
1819}
1820
1821TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) {
1822 Fusion fusion;
1823 FusionGuard fg(&fusion);
1824
1825 auto tv0 = makeSymbolicTensor(2);
1826 fusion.addInput(tv0);
1827
  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {0});
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv0, tv3);
  fusion.addOutput(tv4);

  const int vec = 2;
  const int tidx = 32;
  const int tidy = 8;

  tv1->split(1, vec);
  tv1->split(1, tidx);
  tv1->split(0, tidy);
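  // tv1 loop structure after the splits:
  // [i0/tidy, tidy, i1/(tidx*vec), tidx, vec]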
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::TIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);

  // Make sure the reduction expr is converted to GroupedGridReduction
  // and the non-reduction domains of the output TV are either
  // grouped or parallelized
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction =
        dynamic_cast<kir::GroupedGridReduction*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    auto out = ir_utils::getTvOutput(grouped_grid_reduction);
    for (auto out_axis : out->domain()->domain()) {
      auto out_axis_pt = out_axis->getParallelType();
      TORCH_CHECK(
          isParallelTypeThread(out_axis_pt) ||
              out_axis_pt == ParallelType::Group,
          "Invalid parallel type of the reduction tensor: ",
          out_axis_pt,
          ". Reduction output tensor: ",
          out->toString());
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridReduction found.");

  std::vector<int64_t> shape({99, 101});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto ref = t0_double + t0_double.sum({0}).unsqueeze(0);

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// Test grouping of two domains
TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {0});
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv0, tv3);
  fusion.addOutput(tv4);

  const int vec1 = 2;
  const int vec2 = 3;
  const int tidx = 16;
  const int tidy = 8;

  tv1->split(1, vec1);
  tv1->split(1, vec2);
  tv1->split(1, tidx);
  tv1->split(0, tidy);
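  // tv1 loop structure after the splits:
  // [i0/tidy, tidy, i1/(tidx*vec2*vec1), tidx, vec2, vec1]; the two innermost
  // axes are grouped below, so each thread feeds vec2 * vec1 values into one
  // grouped grid reduction.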
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::TIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);
  tv2->axis(5)->parallelize(ParallelType::Group);

  std::vector<int64_t> shape({99, 129});

  // Make sure the reduction expr is converted to GroupedGridReduction
  // and the non-reduction domains of the output TV are either
  // grouped or parallelized
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction =
        dynamic_cast<kir::GroupedGridReduction*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    auto out = ir_utils::getTvOutput(grouped_grid_reduction);
    for (auto out_axis : out->domain()->domain()) {
      auto out_axis_pt = out_axis->getParallelType();
      TORCH_CHECK(
          isParallelTypeThread(out_axis_pt) ||
              out_axis_pt == ParallelType::Group,
          "Invalid parallel type of the reduction tensor: ",
          out_axis_pt,
          ". Reduction output tensor: ",
          out->toString());
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridReduction found.");

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto ref = t0_double + t0_double.sum({0}).unsqueeze(0);

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// Group both expressions and iterations
TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {0});
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv1, tv3);

  auto tv5 = add(tv0, IrBuilder::create<Double>(2));
  auto tv6 = sum(tv5, {0});
  auto tv7 = broadcast(tv6, {true, false});
  auto tv8 = add(tv5, tv7);

  auto tv9 = add(tv4, tv8);
  fusion.addOutput(tv9);

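  // groupReductions fuses the two reduction exprs (tv2 and tv6) into one
  // grouped reduction; combined with the Group axis set below, both
  // expressions and iterations end up grouped.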
  groupReductions({tv2, tv6});

  const int vec = 2;
  const int tidx = 32;
  const int tidy = 8;

  tv1->split(1, vec);
  tv1->split(1, tidx);
  tv1->split(0, tidy);
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::TIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);

  // Make sure the reduction expr is converted to GroupedGridReduction
  // and the non-reduction domains of the output TV are either
  // grouped or parallelized
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction =
        dynamic_cast<kir::GroupedGridReduction*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    auto out = ir_utils::getTvOutput(grouped_grid_reduction);
    for (auto out_axis : out->domain()->domain()) {
      auto out_axis_pt = out_axis->getParallelType();
      TORCH_CHECK(
          isParallelTypeThread(out_axis_pt) ||
              out_axis_pt == ParallelType::Group,
          "Invalid parallel type of the reduction tensor: ",
          out_axis_pt,
          ". Reduction output tensor: ",
          out->toString());
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridReduction found.");

  std::vector<int64_t> shape({99, 101});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto t4 = t0_double + 1 + (t0_double + 1).sum({0}).unsqueeze(0);
  auto t8 = t0_double + 2 + (t0_double + 2).sum({0}).unsqueeze(0);
  auto ref = t4 + t8;

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// ParallelType::Group with computeAt
TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = sum(tv1, {0});
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv0, tv3);
  fusion.addOutput(tv4);

  const int vec = 2;
  const int tidx = 32;
  const int tidy = 8;

  tv2->reorder({{0, 1}});
  tv2->split(0, vec);
  tv2->split(0, tidx);
  tv2->split(-1, tidy);
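  // tv2 loop structure after the reorder and splits:
  // [i1/(tidx*vec), tidx, vec, r0/tidy, tidy]; axis 2 (vec) is grouped below.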

  TransformPropagator propagator(tv2);
  MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);

  tv2->axis(2)->parallelize(ParallelType::Group);

  // This should avoid inlining the grouped domain
  tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined);

  TORCH_CHECK(
      tv1->getComputeAtPosition() == 2,
      "Invalid computeAt position: ",
      tv1->toString());
  TORCH_CHECK(
      tv2->getComputeAtPosition() == 2,
      "Invalid computeAt position: ",
      tv2->toString());

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::TIDx);

  for (auto tv : ir_utils::allTvs(&fusion)) {
    tv->axis(-2)->parallelize(ParallelType::BIDy);
    tv->axis(-1)->parallelize(ParallelType::TIDy);
  }

  // Make sure the reduction expr is converted to GroupedGridReduction
  // and the non-reduction domains of the output TV are either
  // grouped or parallelized
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction =
        dynamic_cast<kir::GroupedGridReduction*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    auto out = ir_utils::getTvOutput(grouped_grid_reduction);
    for (auto out_axis : out->domain()->domain()) {
      auto out_axis_pt = out_axis->getParallelType();
      TORCH_CHECK(
          isParallelTypeThread(out_axis_pt) ||
              out_axis_pt == ParallelType::Group,
          "Invalid parallel type of the reduction tensor: ",
          out_axis_pt,
          ". Reduction output tensor: ",
          out->toString());
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridReduction found.");

  std::vector<int64_t> shape({99, 101});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto ref = t0_double + t0_double.sum({0}).unsqueeze(0);

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

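  // Same allreduce pattern as above, but the reduction is a Welford mean.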
  auto tv1 = set(tv0);
  auto tv2 = Welford(tv1, {0}).avg;
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv0, tv3);
  fusion.addOutput(tv4);

  const int vec = 2;
  const int tidx = 32;
  const int tidy = 8;

  tv1->split(1, vec);
  tv1->split(1, tidx);
  tv1->split(0, tidy);
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::TIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);

  // Make sure the Welford expr is converted to GroupedGridWelford
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction = dynamic_cast<kir::GroupedGridWelford*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridWelford found.");

  std::vector<int64_t> shape({99, 101});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto ref = t0_double + t0_double.mean({0}).unsqueeze(0);

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// Test grouping of two domains
TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = Welford(tv1, {0}).avg;
  auto tv3 = broadcast(tv2, {true, false});
  auto tv4 = add(tv0, tv3);
  fusion.addOutput(tv4);

  const int vec1 = 2;
  const int vec2 = 3;
  const int tidx = 16;
  const int tidy = 8;

  tv1->split(1, vec1);
  tv1->split(1, vec2);
  tv1->split(1, tidx);
  tv1->split(0, tidy);
  TransformPropagator propagator(tv1);
  MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::TIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);
  tv2->axis(5)->parallelize(ParallelType::Group);

  std::vector<int64_t> shape({99, 129});

  // Make sure the Welford expr is converted to GroupedGridWelford
  GpuLower gpulw(&fusion);
  bool validated = false;
  for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
    auto grouped_grid_reduction = dynamic_cast<kir::GroupedGridWelford*>(expr);
    if (grouped_grid_reduction == nullptr) {
      continue;
    }
    validated = true;
  }
  TORCH_CHECK(
      validated, "Invalid lowered kernel. No GroupedGridWelford found.");

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto outputs = fe.runFusion({t0});

  auto t0_double = t0.to(at::kDouble);
  auto ref = t0_double + t0_double.mean({0}).unsqueeze(0);

  testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__);
}

// Follows the pattern of persistent outer grid welford in batchnorm
TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) {
  struct Params {
    int N;
    int H;
    int W;
    int C;
    int tidx;
    int tidy;
    int vect;
    int persistent_buffer;
    int bidx;
  };

  auto test = [](const Params& params) {
    Fusion fusion;
    FusionGuard fg(&fusion);

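    // NHWC input: reduce over N, H, and W and keep the channel dimension, as
    // is done for batchnorm statistics.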
    std::vector<bool> bcast_pattern{true, true, true, false};
    std::vector<int> reduction_dims{2, 1, 0};

    auto tv0 = makeSymbolicTensor(4);
    fusion.addInput(tv0);

    auto tv1 = set(tv0);
    auto tvs = Welford(tv1, reduction_dims);
    auto tv2 = tvs.avg;
    auto tv3 = tvs.var_sum;
    auto tv4 = tvs.n;
    auto tv5 = broadcast(tv2, bcast_pattern);
    auto tv6 = broadcast(tv3, bcast_pattern);
    auto tv7 = broadcast(tv4, bcast_pattern);
    auto tv8 = sub(tv1, tv5);
    auto tv9 = add(tv8, tv6);
    // auto tv10 = div(tv9, tv7);
    // fusion.addOutput(tv10);
    fusion.addOutput(tv9);

    // Schedule the fusion as it will be done by the persistent
    // scheduler

    auto input_cache = tv1;
    auto output_cache = tv9->cacheBefore();

    auto transform_ref = tv2;

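    // Collapse the N, H, and W reduction domains into a single domain:
    // [r(N*H*W), iC]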
    transform_ref->merge(0)->merge(0);

    int reduction_pos = 1;

    transform_ref->split(0, params.tidy);
    ++reduction_pos;
    transform_ref->axis(1)->parallelize(ParallelType::TIDy);

    // Persistent buffer
    transform_ref->split(0, params.persistent_buffer);
    ++reduction_pos;

    // Unswitch
    transform_ref->split(0, 1);
    ++reduction_pos;
    transform_ref->axis(1)->parallelize(ParallelType::Unswitch);

    transform_ref->axis(0)->parallelize(ParallelType::BIDy);

    transform_ref->split(reduction_pos, params.vect);
    transform_ref->axis(reduction_pos + 1)
        ->parallelize(ParallelType::Vectorize);

    transform_ref->split(reduction_pos, params.tidx);
    transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::TIDx);
    transform_ref->split(reduction_pos, params.bidx);
    transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::BIDx);
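    // Resulting transform_ref loop structure (reduction domains first):
    // [BIDy, Unswitch, persistent_buffer, TIDy,
    //  C/(vect*tidx*bidx), BIDx, TIDx, Vectorize]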

    auto transform_ref_rf =
        reduction_scheduler_utils::sortAndRFactor(transform_ref);

    TransformPropagator propagator(transform_ref_rf);
    MaxRootDomainInfoSpanningTree(transform_ref_rf).traverse(&propagator);

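    // Temporarily reset the Vectorize and Unswitch axes of the rfactor
    // reference to Serial so parallelizeAllLike below does not propagate them
    // to every tensor; they are re-applied selectively afterwards.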
    int vec_id = std::distance(
        transform_ref_rf->domain()->domain().begin(),
        std::find_if(
            transform_ref_rf->domain()->domain().begin(),
            transform_ref_rf->domain()->domain().end(),
            [](auto id) {
              return id->getParallelType() == ParallelType::Vectorize;
            }));
    transform_ref_rf->axis(vec_id)->parallelize(ParallelType::Serial);

    int unswitch_id = std::distance(
        transform_ref_rf->domain()->domain().begin(),
        std::find_if(
            transform_ref_rf->domain()->domain().begin(),
            transform_ref_rf->domain()->domain().end(),
            [](auto id) {
              return id->getParallelType() == ParallelType::Unswitch;
            }));
    transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Serial);

    scheduler_utils::parallelizeAllLike(
        transform_ref_rf, ir_utils::allTvs(&fusion));

    ParallelType vec_pt = ParallelType::Vectorize;
    tv1->axis(vec_id)->parallelize(vec_pt);
    tv9->axis(vec_id)->parallelize(vec_pt);

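    // On the Welford reference, the vector-width axis becomes a grouped axis
    // rather than a vectorized one, so each thread feeds params.vect values
    // into one grouped grid welford.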
    transform_ref->axis(vec_id)->parallelize(ParallelType::Group);

    transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Unswitch);

    inlineMost();

    // Make sure the Welford expr is converted to GroupedGridWelford
    GpuLower gpulw(&fusion);
    bool validated = false;
    for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) {
      auto grouped_grid_reduction =
          dynamic_cast<kir::GroupedGridWelford*>(expr);
      if (grouped_grid_reduction == nullptr) {
        continue;
      }
      validated = true;
    }
    TORCH_CHECK(
        validated, "Invalid lowered kernel. No GroupedGridWelford found.");

    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::manual_seed(0);

    const std::vector<int64_t> input_shape{
        params.N, params.H, params.W, params.C};
    auto t0 = at::randn(input_shape, options);

    FusionExecutor fe;
    fe.compileFusion(&fusion, {t0});

    // Skip running this parameter combination if the required number of SMs
    // exceeds the available SM count
    const auto num_required_sms = params.bidx *
        ceilDiv(ceilDiv(params.N * params.H * params.W, params.tidy),
                params.persistent_buffer);
    if (num_required_sms > deviceSMCount()) {
      return;
    }

    auto cg_outputs = fe.runFusion({t0});

    auto t1 = t0.to(at::kDouble);
    auto t2 = t1.mean({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0);
    auto t3 =
        at::var(t1, {0, 1, 2}, false).unsqueeze(0).unsqueeze(0).unsqueeze(0);
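    // Welford's var_sum is the population variance times the number of
    // reduced elements, hence the t3 * t4 term below.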
    auto t4 = params.N * params.H * params.W;
    auto ref = (t1 - t2 + (t3 * t4));

    testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__, "");
  };

  std::vector<Params> base_params;
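  // Each entry is {N, H, W, C, tidx, tidy, vect, persistent_buffer, bidx};
  // C is overwritten by the sweep below.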
  base_params.push_back({256, 7, 7, 1, 8, 32, 2, 32, 4});
  base_params.push_back({256, 7, 7, 1, 16, 16, 4, 50, 4});
  base_params.push_back({128, 7, 7, 1, 16, 16, 4, 32, 4});
  base_params.push_back({128, 14, 14, 1, 16, 16, 4, 32, 1});
  base_params.push_back({128, 14, 14, 1, 16, 16, 2, 64, 2});
  base_params.push_back({128, 14, 14, 1, 8, 32, 4, 50, 4});
  base_params.push_back({128, 14, 14, 1, 8, 32, 2, 50, 4});

  std::vector<Params> param_vec;
  for (const auto base_p : base_params) {
    for (const auto c_dim : {512, 1024, 2048}) {
      auto tmp = base_p;
      tmp.C = c_dim;
      param_vec.push_back(tmp);
    }
  }

  for (const auto& params : param_vec) {
    test(params);
  }
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)