1#if defined(USE_CUDA)
2#include <gmock/gmock-matchers.h>
3#include <gtest/gtest.h>
4
5#include <arith.h>
6#include <codegen.h>
7#include <disjoint_set.h>
8#include <executor.h>
9#include <executor_launch_params.h>
10#include <expr_evaluator.h>
11#include <fusion.h>
12#include <fusion_segmenter.h>
13#include <grouped_reduction.h>
14#include <inlining.h>
15#include <ir_all_nodes.h>
16#include <ir_builder.h>
17#include <ir_graphviz.h>
18#include <ir_iostream.h>
19#include <ir_utils.h>
20#include <iter_visitor.h>
21#include <kernel_cache.h>
22#include <kernel_expr_evaluator.h>
23#include <kernel_ir.h>
24#include <kernel_ir_dispatch.h>
25#include <lower2device.h>
26#include <lower_magic_zero.h>
27#include <mutator.h>
28#include <ops/all_ops.h>
29#include <register_interface.h>
30#include <root_domain_map.h>
31#include <scheduler/all_schedulers.h>
32#include <scheduler/reduction_utils.h>
33#include <scheduler/utils.h>
34#include <test/test_gpu_validator.h>
35#include <test/test_utils.h>
36#include <transform_replay.h>
37#include <transform_rfactor.h>
38
39#include <test/cpp/jit/test_utils.h>
40#include <torch/csrc/jit/api/function_impl.h>
41#include <parser.h>
42#include <torch/csrc/jit/ir/irparser.h>
43#include <torch/torch.h>
44
45#include <ATen/cuda/CUDAContext.h>
46#include <ATen/cuda/Exceptions.h>
47#include <c10/cuda/CUDAStream.h>
48
49#include <algorithm>
50#include <iostream>
51#include <sstream>
52#include <thread>
53
54// Tests go in torch::jit
55namespace torch {
56namespace jit {
57
58using namespace torch::jit::fuser::cuda;
59using namespace at::indexing;
60
61// A few smoke tests for IrGraphGenerator
62// (These tests exercise IrGraphGenerator through a non-trivial IR,
63// to make sure that it runs w/o crashing. The actual output is not
64// validated)
TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Make sure we can handle empty IRs
  TORCH_CHECK(!IrGraphGenerator::toGraphviz(
                   &fusion, IrGraphGenerator::DetailLevel::Basic)
                   .empty());

  // Construct an interesting IR
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // Note: the tv3 -> tv4 -> tv5 chain is never added as an output, so the
  // generator also has to handle dangling (unused) subgraphs.
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(3.141));
  TensorView* tv3 = broadcast(tv0, {false, true, false, true});
  TensorView* tv4 =
      reductionOp(BinaryOpType::Add, {2}, IrBuilder::create<Double>(0), tv3);
  TensorView* tv5 = clamp(
      tv4, IrBuilder::create<Double>(0.f), IrBuilder::create<Double>(1.f));
  TensorView* tv6 = add(tv2, tv2);

  // Another checkpoint before adding outputs
  TORCH_CHECK(!IrGraphGenerator::toGraphviz(
                   &fusion, IrGraphGenerator::DetailLevel::Explicit)
                   .empty());

  fusion.addOutput(tv6);

  // Exercise scheduling node types: parallelization, merge/split/reorder,
  // and computeAt
  tv4->axis(2)->parallelize(ParallelType::BIDy);
  tv6->merge(0);
  tv6->split(0, 4);
  tv6->axis(0)->parallelize(ParallelType::BIDx);
  tv5->reorder({{-1, 0}});
  tv2->computeAt(tv6, 1);

  // Another checkpoint with more node types
  TORCH_CHECK(!IrGraphGenerator::toGraphviz(
                   &fusion, IrGraphGenerator::DetailLevel::ComputeOnly)
                   .empty());

  // Parallelize the innermost axis of every non-input TensorView
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  // Final IR graph
  TORCH_CHECK(!IrGraphGenerator::toGraphviz(
                   &fusion, IrGraphGenerator::DetailLevel::Verbose)
                   .empty());
}
118
119TEST_F(NVFuserTest, FusionDispatch_CUDA) {
120 Fusion fusion;
121 FusionGuard fg(&fusion);
122
123 Double* f = IrBuilder::create<Double>(2.f);
124 std::stringstream ss1, ss2, ss3;
125 ss1 << f;
126 ss2 << static_cast<Val*>(f);
127 ss3 << static_cast<Statement*>(f);
128 TORCH_CHECK(
129 ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0,
130 "Error with dispatch system where results differ by passing Double* vs Val* vs Statement*.");
131}
132
133// Evaluate basic scalar operations with constant values
134TEST_F(NVFuserTest, FusionExprEvalConstants_CUDA) {
135 Fusion fusion;
136 FusionGuard fg(&fusion);
137
138 ExpressionEvaluator evaluator(&fusion);
139
140 auto* a = IrBuilder::create<Int>(7);
141 auto* b = IrBuilder::create<Int>(3);
142
143 // Avoid div operation because it casts int operands to float
144 checkIntValue(evaluator, neg(a), -7);
145 checkIntValue(evaluator, add(a, b), 10);
146 checkIntValue(evaluator, neg(mul(sub(a, b), add(a, b))), -40);
147 checkIntValue(evaluator, mod(a, b), 1);
148 checkIntValue(evaluator, ceilDiv(a, b), 3);
149}
150
151TEST_F(NVFuserTest, FusionExprEvalDouble_CUDA) {
152 auto fusion = std::make_unique<Fusion>();
153 FusionGuard fg(fusion.get());
154 auto ten = IrBuilder::create<Double>(10);
155 auto two = IrBuilder::create<Double>(2);
156 auto three = IrBuilder::create<Double>(3);
157 auto val = castOp(DataType::Int, ceilDiv(sub(ten, two), three));
158 auto reference = static_cast<int64_t>(std::ceil((10.0 - 2.0) / 3.0));
159 TORCH_CHECK(reference == val->evaluateInt());
160}
161
162// Evaluate basic scalar operations with bound values
// Evaluate basic scalar operations with bound values
TEST_F(NVFuserTest, FusionExprEvalBindings_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  ExpressionEvaluator evaluator(&fusion);

  // a and b start unbound; c and d are derived expressions:
  //   c = a + b,  d = -ceilDiv(c, b)
  auto* a = IrBuilder::create<Int>();
  auto* b = IrBuilder::create<Int>();
  auto* c = add(a, b);
  auto* d = neg(ceilDiv(c, b));
  auto* e = IrBuilder::create<Int>(0);

  // trying to evaluate before binding should give empty results
  TORCH_CHECK(!evaluator.evaluate(a).has_value());
  TORCH_CHECK(!evaluator.evaluate(d).has_value());

  evaluator.bind(a, 7);
  evaluator.bind(b, 3);

  // can't bind to the results of expressions
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(evaluator.bind(c, 100));

  // can't bind to concrete values
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(evaluator.bind(e, 100));

  // With a = 7, b = 3: c = 10, d = -ceilDiv(10, 3) = -4
  checkIntValue(evaluator, c, 10);
  checkIntValue(evaluator, sub(a, b), 4);
  checkIntValue(evaluator, mod(a, b), 1);
  checkIntValue(evaluator, ceilDiv(a, b), 3);
  checkIntValue(evaluator, d, -4);

  // Reset evaluation context
  evaluator = ExpressionEvaluator(&fusion);

  evaluator.bind(a, 2);
  evaluator.bind(b, 5);

  // With a = 2, b = 5: c = 7, d = -ceilDiv(7, 5) = -2
  checkIntValue(evaluator, c, 7);
  checkIntValue(evaluator, sub(a, b), -3);
  checkIntValue(evaluator, mod(a, b), 2);
  checkIntValue(evaluator, ceilDiv(a, b), 1);
  checkIntValue(evaluator, d, -2);
}
208
209// Evaluate expressions in a simple IR
// Evaluate expressions in a simple IR
TEST_F(NVFuserTest, FusionExprEvalBasic_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Create a non-trivial IR
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  // Split axis 0 by 4: [i0, i1] -> [ceilDiv(i0, 4), 4, i1]
  tv3->split(0, 4);

  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // 1. Create an evaluator
  ExpressionEvaluator evaluator(&fusion);

  // 2. Bind values
  //
  // IMPORTANT:
  // a. The bindings are only as stable as the Vals are in the fusion graph
  // b. You must use the original (rootDomain) extents
  //    (ex. `tv0->getRootDomain()[0]->extent()`
  //     instead of `tv0->axis(0)->extent()`)
  //
  evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
  evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
  evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
  evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);

  // 3. Evaluate and check result values
  // With root extents {6, 128} and split(0, 4):
  //   axis(0) = ceilDiv(6, 4) = 2, axis(1) = 4, axis(2) = 128
  TORCH_CHECK(tv2->domain()->nDims() == 3);
  checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
  checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
  checkIntValue(evaluator, tv2->axis(2)->extent(), 128);

  TORCH_CHECK(tv3->domain()->nDims() == 3);
  checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
  checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
  checkIntValue(evaluator, tv3->axis(2)->extent(), 128);
}
264
265// Evaluate expressions in a more complex IR
// Evaluate expressions in a more complex IR
TEST_F(NVFuserTest, FusionExprEvalComplex_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(-1.0));
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(3.0));
  TensorView* tv3 = mul(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv4 = add(tv2, tv1);
  TensorView* tv5 = add(tv4, tv3);
  TensorView* tv6 = add(tv0, tv3);

  fusion.addOutput(tv5);
  fusion.addOutput(tv6);

  // tv5: swap the two axes, then merge them into one
  tv5->reorder({{-1, 0}});

  // tv6: split axis 0 by 5 -> [ceilDiv(i0, 5), 5, i1]
  tv6->split(0, 5);
  tv5->merge(0);

  // 1. Create an evaluator
  ExpressionEvaluator evaluator(&fusion);

  // 2. Bind values
  evaluator.bind(tv0->getRootDomain()[0]->extent(), 129);
  evaluator.bind(tv0->getRootDomain()[1]->extent(), 127);

  // Evaluate and check extent values
  TORCH_CHECK(tv0->domain()->nDims() == 2);
  checkIntValue(evaluator, tv0->axis(0)->extent(), 129);
  checkIntValue(evaluator, tv0->axis(1)->extent(), 127);

  TORCH_CHECK(tv3->domain()->nDims() == 2);
  checkIntValue(evaluator, tv3->axis(0)->extent(), 129);
  checkIntValue(evaluator, tv3->axis(1)->extent(), 127);

  TORCH_CHECK(tv4->domain()->nDims() == 2);
  checkIntValue(evaluator, tv4->axis(0)->extent(), 129);
  checkIntValue(evaluator, tv4->axis(1)->extent(), 127);

  // merged extent: 127 * 129 = 16383
  TORCH_CHECK(tv5->domain()->nDims() == 1);
  checkIntValue(evaluator, tv5->axis(0)->extent(), 16383);

  // split extents: ceilDiv(129, 5) = 26, factor 5, untouched 127
  TORCH_CHECK(tv6->domain()->nDims() == 3);
  checkIntValue(evaluator, tv6->axis(0)->extent(), 26);
  checkIntValue(evaluator, tv6->axis(1)->extent(), 5);
  checkIntValue(evaluator, tv6->axis(2)->extent(), 127);
}
316
317// Evaluate expressions post lowering
// Evaluate expressions post lowering
TEST_F(NVFuserTest, FusionExprEvalPostLower_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Create a non-trivial IR
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  tv3->split(0, 4);

  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // Derived scalars built from the scheduled leaf extents *before* lowering;
  // they must remain evaluable afterwards.
  auto* bid_x = add(tv3->axis(0)->extent(), IrBuilder::create<Int>(0));
  auto* tid_x = add(tv3->axis(-1)->extent(), IrBuilder::create<Int>(0));

  // Lower
  GpuLower gpulw(&fusion);

  // 1. Create an evaluation context
  ExpressionEvaluator evaluator(&fusion);

  // 2. Bind values
  evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
  evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
  evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
  evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);

  // 3. Evaluate and check result values
  // With root extents {6, 128} and split(0, 4): {ceilDiv(6, 4)=2, 4, 128}
  TORCH_CHECK(tv2->domain()->nDims() == 3);
  checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
  checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
  checkIntValue(evaluator, tv2->axis(2)->extent(), 128);

  TORCH_CHECK(tv3->domain()->nDims() == 3);
  checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
  checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
  checkIntValue(evaluator, tv3->axis(2)->extent(), 128);

  // The pre-lowering derived scalars still evaluate correctly
  checkIntValue(evaluator, bid_x, 2);
  checkIntValue(evaluator, tid_x, 128);
}
374
375// Kernel IR: Evaluate basic scalar operations with constant values
376TEST_F(NVFuserTest, FusionKernelExprEvalConstants_CUDA) {
377 Fusion fusion;
378 kir::Kernel kernel(&fusion);
379 FusionGuard fg((&kernel)->as<Fusion>());
380
381 auto a = IrBuilder::create<Int>(7);
382 auto b = IrBuilder::create<Int>(3);
383 auto c = IrBuilder::subExpr(a, b);
384 auto d = IrBuilder::divExpr(a, b);
385 auto e = IrBuilder::mulExpr(c, d);
386
387 kir::ExpressionEvaluator evaluator;
388
389 checkIntValue(evaluator, IrBuilder::negExpr(a), -7);
390 checkIntValue(evaluator, IrBuilder::addExpr(a, b), 10);
391 checkIntValue(evaluator, IrBuilder::negExpr(e), -8);
392 checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1);
393 checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3);
394}
395
396// Kernel IR: Evaluate basic scalar operations with bound values
// Kernel IR: Evaluate basic scalar operations with bound values
TEST_F(NVFuserTest, FusionKernelExprEvalBindings_CUDA) {
  Fusion fusion;
  kir::Kernel kernel(&fusion);
  FusionGuard fg((&kernel)->as<Fusion>());

  kir::ExpressionEvaluator evaluator;

  // a and b start unbound; c and d are derived expressions:
  //   c = a + b,  d = -ceilDiv(c, b)
  auto a = IrBuilder::create<Int>(c10::nullopt);
  auto b = IrBuilder::create<Int>(c10::nullopt);
  auto c = IrBuilder::addExpr(a, b);
  auto d = IrBuilder::negExpr(IrBuilder::ceilDivExpr(c, b));
  auto e = IrBuilder::create<Int>(0);

  // trying to evaluate before binding should give empty results
  TORCH_CHECK(!evaluator.evaluate(a).has_value());
  TORCH_CHECK(!evaluator.evaluate(d).has_value());

  evaluator.bind(a, 7);
  evaluator.bind(b, 3);

  // can't bind to the results of expressions
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(evaluator.bind(c, 100));

  // can't bind to concrete values
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(evaluator.bind(e, 100));

  // With a = 7, b = 3: c = 10, d = -ceilDiv(10, 3) = -4
  checkIntValue(evaluator, c, 10);
  checkIntValue(evaluator, IrBuilder::subExpr(a, b), 4);
  checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1);
  checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3);
  checkIntValue(evaluator, d, -4);

  // Reset the evaluation context
  evaluator = kir::ExpressionEvaluator();

  evaluator.bind(a, 2);
  evaluator.bind(b, 5);

  // With a = 2, b = 5: c = 7, d = -ceilDiv(7, 5) = -2
  checkIntValue(evaluator, c, 7);
  checkIntValue(evaluator, IrBuilder::subExpr(a, b), -3);
  checkIntValue(evaluator, IrBuilder::modExpr(a, b), 2);
  checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 1);
  checkIntValue(evaluator, d, -2);
}
443
TEST_F(NVFuserTest, FusionClear_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 1. Create a dummy IR

  {
    TensorView* tv0 = makeSymbolicTensor(2);
    TensorView* tv1 = makeSymbolicTensor(2);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
    TensorView* tv3 = add(tv0, tv2);

    fusion.addOutput(tv3);

    tv3->split(0, 4);
    tv0->computeAt(tv3, 1);
    tv1->computeAt(tv3, 1);

    tv3->axis(0)->parallelize(ParallelType::BIDx);
    tv2->axis(1)->parallelize(ParallelType::Unroll);
    tv3->axis(-1)->parallelize(ParallelType::TIDx);
  }

  // 2. Clear the IR

  fusion.clear();

  // A cleared fusion must have no expressions, values, inputs or outputs
  TORCH_CHECK(fusion.unordered_exprs().empty());
  TORCH_CHECK(fusion.vals().empty());

  TORCH_CHECK(fusion.inputs().empty());
  TORCH_CHECK(fusion.outputs().empty());

  TORCH_CHECK(ir_utils::getReductionOps(&fusion).empty());

  // 3. Rebuild the IR
  // (a different rank and schedule than step 1, to prove the cleared
  // object is fully reusable)

  {
    TensorView* tv0 = makeSymbolicTensor(3);
    TensorView* tv1 = makeSymbolicTensor(3);
    TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
    TensorView* tv3 = add(tv0, tv2);

    fusion.addInput(tv0);
    fusion.addInput(tv1);
    fusion.addOutput(tv3);

    // tv3 [i0, i1, i2]
    tv3->reorder({{0, 2}, {2, 0}});
    // tv3 [i2, i1, i0]
    tv3->split(-1, 4);
    // tv3 [i2, i1, i0outer, i0inner{4}]
    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
    // tv3 [i0outer, i0inner{4}, i1, i2]
    tv0->computeAt(tv3, -1);
    tv1->computeAt(tv3, -1);
    tv3->axis(1)->parallelize(ParallelType::BIDx);
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor input1 = at::randn({16, 8, 8}, options);
  at::Tensor input2 = at::randn_like(input1);

  // Compile and run the rebuilt fusion end-to-end and validate numerically
  FusionExecutor fe;
  fe.compileFusion(&fusion, {input1, input2});
  auto outputs = fe.runFusion({input1, input2});

  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor output_ref = input1 + tv2_ref;

  TORCH_CHECK(output_ref.equal(outputs[0]));
}
521
TEST_F(NVFuserTest, FusionCopy_CUDA) {
  Fusion original_fusion;

  // Create the test IR
  {
    FusionGuard fg(&original_fusion);

    auto tv0 = makeSymbolicTensor(3);
    auto tv1 = makeSymbolicTensor(3);
    auto tv2 = add(tv1, IrBuilder::create<Double>(2.0));
    auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);

    original_fusion.addInput(tv0);
    original_fusion.addInput(tv1);
    original_fusion.addOutput(tv3);

    tv3->reorder({{0, 2}, {2, 0}});
    tv3->split(-1, 4);
    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});

    tv0->computeAt(tv3, -1);
    tv1->computeAt(tv3, -1);

    tv3->axis(0)->parallelize(ParallelType::BIDx);
    tv3->axis(-1)->parallelize(ParallelType::TIDx);
  }

  // Test copy before lowering
  Fusion clone = original_fusion;

  // Compare IR dumps: copies are validated by textual equality of the
  // printed IR, which covers values, exprs and the schedule
  std::stringstream original_ir;
  std::stringstream clone_ir;
  original_ir << original_fusion;
  clone_ir << clone;
  ASSERT_EQ(original_ir.str(), clone_ir.str());

  // Lower original fusion
  std::string original_kernel;
  {
    // TODO(kir): remove this guard once we implement the cuda codegen visitor
    FusionGuard fg(&original_fusion);
    original_kernel =
        codegen::generateCudaKernel(GpuLower(&original_fusion).kernel());
  }

  // Make sure the "before lowering" clone was not mutated
  // while lowering the original fusion IR
  std::stringstream before_lowering_ir;
  before_lowering_ir << clone;
  ASSERT_EQ(original_ir.str(), before_lowering_ir.str());

  // Test copy after lowering (including assignment operator)
  Fusion before_lowering = clone;
  clone = original_fusion;

  // Compare IR dumps
  std::stringstream original_lowered_ir;
  std::stringstream clone_lowered_ir;
  original_lowered_ir << original_fusion;
  clone_lowered_ir << clone;
  ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str());

  // Lower the "before lowering" and compare kernels: lowering the clone
  // must produce exactly the same CUDA source as lowering the original
  std::string clone_kernel;
  {
    // TODO(kir): remove this guard once we implement the cuda codegen visitor
    FusionGuard fg(&before_lowering);
    clone_kernel =
        codegen::generateCudaKernel(GpuLower(&before_lowering).kernel());
  }
  ASSERT_EQ(original_kernel, clone_kernel);
}
595
TEST_F(NVFuserTest, FusionMove_CUDA) {
  Fusion fusion;

  // Create the test IR
  {
    FusionGuard fg(&fusion);

    auto tv0 = makeSymbolicTensor(3);
    auto tv1 = makeSymbolicTensor(3);
    auto tv2 = add(tv1, IrBuilder::create<Double>(2.0));
    auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);

    fusion.addInput(tv0);
    fusion.addInput(tv1);
    fusion.addOutput(tv3);

    tv3->reorder({{0, 2}, {2, 0}});
    tv3->split(-1, 4);
    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});

    tv0->computeAt(tv3, -1);
    tv1->computeAt(tv3, -1);

    tv3->axis(0)->parallelize(ParallelType::BIDx);
    tv3->axis(-1)->parallelize(ParallelType::TIDx);
  }

  std::stringstream original_ir;
  original_ir << fusion;

  // Test move before lowering
  Fusion another_fusion = std::move(fusion);

  // Check that the original fusion is "empty"
  //
  // IMPORTANT: these checks assume knowledge of the internal
  // implementation of the move operations. General uses
  // should only assume that the moved-from object is in
  // a valid, but unspecified state. This is similar to the
  // standard library containers:
  // https://en.cppreference.com/w/cpp/utility/move
  //
  TORCH_CHECK(fusion.unordered_exprs().empty());
  TORCH_CHECK(fusion.vals().empty());
  TORCH_CHECK(fusion.inputs().empty());
  TORCH_CHECK(fusion.outputs().empty());

  // clear() has no pre-conditions so it's valid to call on a moved-from object
  fusion.clear();

  // Compare IR dumps: moving must preserve the printed IR exactly
  std::stringstream another_ir;
  another_ir << another_fusion;
  ASSERT_EQ(original_ir.str(), another_ir.str());

  // Lower the fusion IR
  GpuLower lower(&another_fusion);

  std::stringstream lowered_ir;
  lowered_ir << another_fusion;

  // Test move assignment after lowering
  fusion = std::move(another_fusion);

  // Compare IR dumps
  std::stringstream moved_lowered_ir;
  moved_lowered_ir << fusion;
  ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str());
}
665
TEST_F(NVFuserTest, FusionSimpleArith_CUDA) {
  std::stringstream ss1, ss2;

  Fusion fusion;
  FusionGuard fg(&fusion);

  Double* d1 = IrBuilder::create<Double>(1.f);
  Double* d2 = IrBuilder::create<Double>(2.f);
  Double* d3 = IrBuilder::create<Double>();

  // Disrupt the fusion to make sure guard works well
  {
    Fusion fusion2;
    FusionGuard fg(&fusion2);

    // These locals intentionally shadow the outer d1/d2: while this guard
    // is in scope, new IR nodes are registered into fusion2, not fusion.
    Double* d1 = IrBuilder::create<Double>(1.f);
    Double* d2 = IrBuilder::create<Double>(2.f);
    add(d1, d2);
    ss2 << fusion2;
  }

  // Build the equivalent add explicitly in the outer fusion; its printed
  // IR must match the implicit add() built above in fusion2
  IrBuilder::create<BinaryOp>(BinaryOpType::Add, d3, d1, d2);
  ss1 << fusion;

  TORCH_CHECK(
      ss1.str().compare(ss2.str()) == 0,
      "Error where explicit add nodes don't match implicit add nodes.");
}
694
695TEST_F(NVFuserTest, FusionScalarTypePromote_CUDA) {
696 Fusion fusion;
697 FusionGuard fg(&fusion);
698
699 Bool* b = IrBuilder::create<Bool>(true);
700 Double* d = IrBuilder::create<Double>(4.f);
701 Int* i = IrBuilder::create<Int>(3);
702 ComplexDouble* c =
703 IrBuilder::create<ComplexDouble>(c10::complex<double>(1, 2));
704
705 TORCH_CHECK(add(b, b)->getDataType() == DataType::Bool);
706 TORCH_CHECK(add(b, d)->getDataType() == DataType::Double);
707 TORCH_CHECK(add(b, i)->getDataType() == DataType::Int);
708 TORCH_CHECK(add(b, c)->getDataType() == DataType::ComplexDouble);
709
710 TORCH_CHECK(add(d, b)->getDataType() == DataType::Double);
711 TORCH_CHECK(add(d, d)->getDataType() == DataType::Double);
712 TORCH_CHECK(add(d, i)->getDataType() == DataType::Double);
713 TORCH_CHECK(add(d, c)->getDataType() == DataType::ComplexDouble);
714
715 TORCH_CHECK(add(i, b)->getDataType() == DataType::Int);
716 TORCH_CHECK(add(i, d)->getDataType() == DataType::Double);
717 TORCH_CHECK(add(i, i)->getDataType() == DataType::Int);
718 TORCH_CHECK(add(i, c)->getDataType() == DataType::ComplexDouble);
719
720 TORCH_CHECK(add(c, b)->getDataType() == DataType::ComplexDouble);
721 TORCH_CHECK(add(c, d)->getDataType() == DataType::ComplexDouble);
722 TORCH_CHECK(add(c, i)->getDataType() == DataType::ComplexDouble);
723 TORCH_CHECK(add(c, c)->getDataType() == DataType::ComplexDouble);
724}
725
726TEST_F(NVFuserTest, FusionComplexAbsTypes_CUDA) {
727 Fusion fusion;
728 FusionGuard fg(&fusion);
729
730 auto options = at::TensorOptions().device(at::kCUDA, 0);
731 auto tensor_cf = at::randn({4, 4, 4}, options.dtype(at::kComplexFloat));
732 auto tensor_cd = at::randn({4, 4, 4}, options.dtype(at::kComplexDouble));
733
734 auto type_cf = TensorType::create(tensor_cf);
735 auto tv_cf = IrBuilder::create<TensorView>(type_cf);
736 auto type_cd = TensorType::create(tensor_cd);
737 auto tv_cd = IrBuilder::create<TensorView>(type_cd);
738
739 TORCH_CHECK(
740 tensor_cf.abs().scalar_type() ==
741 data_type_to_aten(abs(tv_cf)->getDataType().value()));
742 TORCH_CHECK(
743 tensor_cd.abs().scalar_type() ==
744 data_type_to_aten(abs(tv_cd)->getDataType().value()));
745}
746
747TEST_F(NVFuserTest, FusionRegister_CUDA) {
748 Fusion fusion;
749 FusionGuard fg(&fusion);
750 Double* v1 = IrBuilder::create<Double>(1.f);
751 Double* v2 = IrBuilder::create<Double>(2.f);
752 Val* v3 = binaryOp(BinaryOpType::Add, v1, v2);
753 Val* v4 = binaryOp(BinaryOpType::Add, v1, v2);
754 TORCH_CHECK(v1->name() + 1 == v2->name());
755 TORCH_CHECK(v2->name() + 1 == v3->name());
756 TORCH_CHECK(v3->name() + 1 == v4->name());
757 TORCH_CHECK(v3->definition()->name() + 1 == v4->definition()->name());
758}
759
760// dummy expr with 2 outputs only for toposort test.
struct DummyExpr : public Expr {
  ~DummyExpr() = default;
  // Registers two outputs and two inputs; ExprType::UnaryOp is borrowed
  // only to satisfy the base-class constructor.
  DummyExpr(
      IrBuilderPasskey passkey,
      Val* _outlhs,
      Val* _outrhs,
      Val* _lhs,
      Val* _rhs)
      : Expr(passkey, ExprType::UnaryOp) // Not terribly safe...
  {
    addOutput(_outlhs);
    addOutput(_outrhs);
    addInput(_lhs);
    addInput(_rhs);
  }
  // Copy/move are deleted: the node is only created through IrBuilder
  DummyExpr(const DummyExpr& other) = delete;
  DummyExpr& operator=(const DummyExpr& other) = delete;
  DummyExpr(DummyExpr&& other) = delete;
  DummyExpr& operator=(DummyExpr&& other) = delete;
  // Cloning is not exercised by the toposort test
  Expr* shallowCopy() const override {
    return nullptr;
  }
};
784
TEST_F(NVFuserTest, FusionTopoSort_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // e0: v3, v2 = dummy(v1, v0)
  // e1: v4 = add(v3, v2)
  // e2: v5 = add(v2, v4)
  // e3: v6 = add(v5, v5)
  Double* v0 = IrBuilder::create<Double>();
  Double* v1 = IrBuilder::create<Double>();
  Double* v2 = IrBuilder::create<Double>();
  Double* v3 = IrBuilder::create<Double>();
  Double* v4 = IrBuilder::create<Double>();
  Double* v5 = IrBuilder::create<Double>();
  Double* v6 = IrBuilder::create<Double>();

  std::vector<Val*> inputs = {v0, v1};
  for (auto val : inputs) {
    fusion.addInput(val);
  }

  Expr* e0 = IrBuilder::create<DummyExpr>(v3, v2, v1, v0);
  Expr* e1 = IrBuilder::create<BinaryOp>(BinaryOpType::Add, v4, v3, v2);
  Expr* e2 = IrBuilder::create<BinaryOp>(BinaryOpType::Add, v5, v2, v4);
  Expr* e3 = IrBuilder::create<BinaryOp>(BinaryOpType::Add, v6, v5, v5);

  // Only e0 is needed to produce the current outputs (v2, v3)
  fusion.addOutput(v2);
  fusion.addOutput(v3);
  auto exprs = fusion.exprs();
  TORCH_CHECK(exprs.size() == 1, "Found ", exprs.size(), " but expecting 1");
  TORCH_CHECK(exprs[0] == e0);

  // v5 pulls in e1 and e2 as well, in topological order
  fusion.addOutput(v5);
  exprs = fusion.exprs();
  TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
  TORCH_CHECK(exprs[0] == e0);
  TORCH_CHECK(exprs[1] == e1);
  TORCH_CHECK(exprs[2] == e2);

  // v4 is already computed by e1, so no new exprs are added
  fusion.addOutput(v4);
  exprs = fusion.exprs();
  TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
  TORCH_CHECK(exprs[0] == e0);
  TORCH_CHECK(exprs[1] == e1);
  TORCH_CHECK(exprs[2] == e2);

  // v6 finally pulls in e3
  fusion.addOutput(v6);
  exprs = fusion.exprs();
  TORCH_CHECK(exprs.size() == 4, "Found ", exprs.size(), " but expecting 4");
  TORCH_CHECK(exprs[0] == e0);
  TORCH_CHECK(exprs[1] == e1);
  TORCH_CHECK(exprs[2] == e2);
  TORCH_CHECK(exprs[3] == e3);

  // Expr names reflect creation order (e0..e3 -> 0..3)
  TORCH_CHECK(v2->definition()->name() == 0);
  TORCH_CHECK(v3->definition()->name() == 0);
  TORCH_CHECK(v4->definition()->name() == 1);
  TORCH_CHECK(v5->definition()->name() == 2);
  TORCH_CHECK(v6->definition()->name() == 3);
}
845
TEST_F(NVFuserTest, FusionTensor_CUDA) {
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Plain contiguous tensor: every dimension should be contiguous
  {
    auto tensor = at::randn({2, 3, 4, 5}, options);
    auto tensor_type = TensorType::create(tensor);
    auto fuser_tensor = IrBuilder::create<TensorView>(tensor_type);
    TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim());
    TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float);
    TORCH_CHECK(fuser_tensor->domain() != nullptr);
    for (const auto i : c10::irange(fuser_tensor->nDims())) {
      // size-1 dimensions are marked as broadcast
      TORCH_CHECK(
          fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1));
      // check contiguity information
      TORCH_CHECK(fuser_tensor->domain()->contiguity()[i]);
    }
  }

  // TensorType::create fills stride_properties, which helps us to mark
  // IterDomain properly
  // Note: implementation could change, depending on how much we want to invest
  // in our home-brew contiguity coalescing. For now let's make sure that we
  // properly test what we are using.
  {
    // Slicing with step 2 makes the middle dimension non-contiguous
    auto tensor = at::randn({4, 4, 4}, options);
    auto sliced_tensor = tensor.slice(1, 0, -1, 2);

    auto tensor_type = TensorType::create(sliced_tensor);
    auto fuser_tensor = IrBuilder::create<TensorView>(tensor_type);
    TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim());
    TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float);
    TORCH_CHECK(fuser_tensor->domain() != nullptr);
    for (const auto i : c10::irange(fuser_tensor->nDims())) {
      // size-1 dimensions are marked as broadcast
      TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false);
    }
    TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]);
    TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]);
    TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]);
  }

  // Permuted tensor: only the dimension whose stride stays innermost-
  // contiguous is reported as contiguous
  {
    auto tensor = at::randn({2, 3, 4, 5}, options);
    auto permuted_tensor = tensor.permute({0, 3, 1, 2});
    auto tensor_type = TensorType::create(permuted_tensor);
    auto fuser_tensor = IrBuilder::create<TensorView>(tensor_type);
    TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim());
    TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float);
    TORCH_CHECK(fuser_tensor->domain() != nullptr);
    for (const auto i : c10::irange(fuser_tensor->nDims())) {
      // size-1 dimensions are marked as broadcast
      TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false);
    }
    TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]);
    TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]);
    TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]);
    TORCH_CHECK(!fuser_tensor->domain()->contiguity()[3]);
  }
}
909
910TEST_F(NVFuserTest, FusionFilterVals_CUDA) {
911 Fusion fusion;
912 FusionGuard fg(&fusion);
913
914 auto tv0 = makeSymbolicTensor(1);
915 auto tv1 = makeSymbolicTensor(1);
916 auto scalar0 = IrBuilder::create<Double>(0);
917 auto scalar1 = IrBuilder::create<Int>(0);
918 auto scalar2 = IrBuilder::create<Int>(1);
919
920 const std::vector<Val*> vals = {tv0, scalar0, tv1, scalar1, scalar2};
921
922 std::vector<TensorView*> tvs(
923 ir_utils::filterByType<TensorView>(vals).begin(),
924 ir_utils::filterByType<TensorView>(vals).end());
925 TORCH_CHECK(tvs.size() == 2);
926 TORCH_CHECK(tvs[0] == tv0);
927 TORCH_CHECK(tvs[1] == tv1);
928
929 std::vector<Double*> floats(
930 ir_utils::filterByType<Double>(vals).begin(),
931 ir_utils::filterByType<Double>(vals).end());
932 TORCH_CHECK(floats.size() == 1);
933 TORCH_CHECK(floats[0] == scalar0);
934
935 std::vector<Int*> ints(
936 ir_utils::filterByType<Int>(vals).begin(),
937 ir_utils::filterByType<Int>(vals).end());
938 TORCH_CHECK(ints.size() == 2);
939 TORCH_CHECK(ints[0] == scalar1);
940 TORCH_CHECK(ints[1] == scalar2);
941
942 TORCH_CHECK(
943 ir_utils::filterByType<Expr>(vals).begin() ==
944 ir_utils::filterByType<Expr>(vals).end(),
945 "Not expecting any results");
946}
947
TEST_F(NVFuserTest, FusionTVSplit_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv = makeSymbolicTensor(3);

  // Split axis 2 by factor 2: [i0, i1, i2] -> [i0, i1, ceilDiv(i2, 2), 2]
  tv = tv->split(2, 2);
  TORCH_CHECK(tv->nDims() == 4);
  Expr* outer = tv->axis(2)->extent()->definition();

  // The outer extent must be ceilDiv(<root extent of axis 2>, 2)
  TORCH_CHECK(
      outer->getExprType().value() == ExprType::BinaryOp &&
      static_cast<BinaryOp*>(outer)->getBinaryOpType() ==
          BinaryOpType::CeilDiv &&
      static_cast<BinaryOp*>(outer)->lhs()->sameAs(
          tv->getRootDomain()[2]->extent()) &&
      static_cast<Int*>(static_cast<BinaryOp*>(outer)->rhs())
          ->sameAs(IrBuilder::create<Int>(2)));

  // The inner extent is the constant split factor
  IterDomain* inner = static_cast<IterDomain*>(tv->axis(3));
  TORCH_CHECK(
      inner->extent()->isScalar() &&
      static_cast<Int*>(inner->extent())->isConst() &&
      static_cast<Int*>(inner->extent())->value().value() == 2);
}
973
974TEST_F(NVFuserTest, FusionTVMerge_CUDA) {
975 Fusion fusion;
976 FusionGuard fg(&fusion);
977
978 TensorView* tv = makeSymbolicTensor(3);
979
980 tv = tv->merge(1);
981 Expr* axisOp = tv->axis(1)->extent()->definition();
982
983 TORCH_CHECK(
984 tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp &&
985 static_cast<BinaryOp*>(axisOp)->getBinaryOpType() == BinaryOpType::Mul &&
986 static_cast<BinaryOp*>(axisOp)->lhs() ==
987 tv->getRootDomain()[1]->extent() &&
988 static_cast<BinaryOp*>(axisOp)->rhs() ==
989 tv->getRootDomain()[2]->extent());
990}
991
992TEST_F(NVFuserTest, FusionTVReorder_CUDA) {
993 Fusion fusion;
994 FusionGuard fg(&fusion);
995
996 std::unordered_map<int, int> shift_right{{-1, 0}};
997
998 std::unordered_map<int, int> shift_left{{0, -1}};
999
1000 std::unordered_map<int, int> shift_left_2{{0, -1}, {1, 0}, {2, 1}};
1001
1002 std::unordered_map<int, int> swap{{0, 2}, {2, 0}};
1003
1004 auto tv = makeSymbolicTensor(3);
1005 std::vector<IterDomain*> ref;
1006 ref = std::vector<IterDomain*>(
1007 tv->domain()->domain().begin(), tv->domain()->domain().end());
1008
1009 tv->reorder(shift_left);
1010 for (const auto i : c10::irange(tv->nDims())) {
1011 TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
1012 }
1013
1014 tv = makeSymbolicTensor(3);
1015 ref = std::vector<IterDomain*>(
1016 tv->domain()->domain().begin(), tv->domain()->domain().end());
1017
1018 tv->reorder(shift_left);
1019 for (const auto i : c10::irange(tv->nDims())) {
1020 TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
1021 }
1022
1023 tv = makeSymbolicTensor(3);
1024 ref = std::vector<IterDomain*>(
1025 tv->domain()->domain().begin(), tv->domain()->domain().end());
1026
1027 tv->reorder(shift_right);
1028 TORCH_CHECK(ref[ref.size() - 1]->sameAs(tv->axis(0)));
1029 for (const auto i : c10::irange(1, tv->nDims())) {
1030 TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i)));
1031 }
1032
1033 tv = makeSymbolicTensor(3);
1034 ref = std::vector<IterDomain*>(
1035 tv->domain()->domain().begin(), tv->domain()->domain().end());
1036 tv->reorder(swap);
1037 TORCH_CHECK(ref[0]->sameAs(tv->axis(2)));
1038 TORCH_CHECK(ref[2]->sameAs(tv->axis(0)));
1039 TORCH_CHECK(ref[1]->sameAs(tv->axis(1)));
1040}
1041
TEST_F(NVFuserTest, FusionEquality_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Symbolic (value-less) scalars: sameAs holds only for the identical
  // node; constant scalars compare equal by value.
  Double* fval1 = IrBuilder::create<Double>();
  Double* fval1_copy = fval1;
  Double* fval2 = IrBuilder::create<Double>();
  Double* fone = IrBuilder::create<Double>(1.0);

  TORCH_CHECK(fval1->sameAs(fval1_copy));
  TORCH_CHECK(!fval1->sameAs(fval2));
  TORCH_CHECK(!fone->sameAs(fval1));
  TORCH_CHECK(fone->sameAs(IrBuilder::create<Double>(1.0)));

  // Same rules for integer scalars.
  Int* ival1 = IrBuilder::create<Int>();
  Int* ival1_copy = ival1;
  Int* ival2 = IrBuilder::create<Int>();
  Int* ione = IrBuilder::create<Int>(1);

  TORCH_CHECK(ival1->sameAs(ival1_copy));
  TORCH_CHECK(!ival1->sameAs(ival2));
  TORCH_CHECK(!ione->sameAs(ival1));
  TORCH_CHECK(ione->sameAs(IrBuilder::create<Int>(1)));

  // Expressions: distinct nodes with the same op type and operands are
  // sameAs; differing op type or operands are not.
  BinaryOp* add1 = IrBuilder::create<BinaryOp>(
      BinaryOpType::Add, IrBuilder::create<Double>(), fval1, ival1);
  BinaryOp* add1_copy = IrBuilder::create<BinaryOp>(
      BinaryOpType::Add, IrBuilder::create<Double>(), fval1, ival1);
  BinaryOp* sub1 = IrBuilder::create<BinaryOp>(
      BinaryOpType::Sub, IrBuilder::create<Double>(), fval1, ival1);

  UnaryOp* neg1 = IrBuilder::create<UnaryOp>(
      UnaryOpType::Neg, IrBuilder::create<Double>(), fval1);
  UnaryOp* neg2 = IrBuilder::create<UnaryOp>(
      UnaryOpType::Neg, IrBuilder::create<Double>(), fval2);
  UnaryOp* neg1_copy = IrBuilder::create<UnaryOp>(
      UnaryOpType::Neg, IrBuilder::create<Double>(), fval1);

  TORCH_CHECK(add1->sameAs(add1_copy));
  TORCH_CHECK(!add1->sameAs(sub1));

  TORCH_CHECK(neg1->sameAs(neg1_copy));
  // Cross-kind comparison (UnaryOp vs BinaryOp) must be false.
  TORCH_CHECK(!static_cast<Expr*>(neg1)->sameAs(add1));
  TORCH_CHECK(!neg1->sameAs(neg2));
}
1087
TEST_F(NVFuserTest, FusionDependency_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Build a small scalar expression DAG:
  //   d2 = d0 + d1;  d3 = d2 + d2
  //   d6 = d4 + d5;  d9 = d7 + d8;  d10 = d6 + d9
  //   d11 = d3 + d10
  Double* d0 = IrBuilder::create<Double>(0.f);
  Double* d1 = IrBuilder::create<Double>(1.f);
  auto d2 = add(d0, d1);

  auto d3 = add(d2, d2);

  Double* d4 = IrBuilder::create<Double>(4.f);
  Double* d5 = IrBuilder::create<Double>(5.f);
  auto d6 = add(d4, d5);

  Double* d7 = IrBuilder::create<Double>(7.f);
  Double* d8 = IrBuilder::create<Double>(8.f);
  auto d9 = add(d7, d8);

  auto d10 = add(d6, d9);

  auto d11 = add(d3, d10);

  // Every producer (transitively) is a dependency of d11.
  TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d1, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d3, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d6, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d9, d11));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d2));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d3));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d4, d6));
  TORCH_CHECK(DependencyCheck::isDependencyOf(d8, d10));

  // Dependency is directional: nothing depends on the final output, and
  // consumers are not dependencies of their producers.
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d0));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d1));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d2));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d3));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d4));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d5));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d2, d0));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d3, d2));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d6, d4));
  TORCH_CHECK(!DependencyCheck::isDependencyOf(d10, d8));

  // A single dependency chain walks producer -> ... -> consumer; the
  // chain is consumed from the back (consumer end).
  auto dep_chain = DependencyCheck::getSingleDependencyChain(d0, d11);
  TORCH_CHECK(dep_chain.back() == d11);
  dep_chain.pop_back();
  TORCH_CHECK(dep_chain.back() == d3);
  dep_chain.pop_back();
  TORCH_CHECK(dep_chain.back() == d2);
  dep_chain.pop_back();

  dep_chain = DependencyCheck::getSingleDependencyChain(d6, d11);
  TORCH_CHECK(dep_chain.back() == d11);
  dep_chain.pop_back();
  TORCH_CHECK(dep_chain.back() == d10);
  dep_chain.pop_back();

  dep_chain = DependencyCheck::getSingleDependencyChain(d4, d11);
  TORCH_CHECK(dep_chain.back() == d11);
  dep_chain.pop_back();
  TORCH_CHECK(dep_chain.back() == d10);
  dep_chain.pop_back();
  TORCH_CHECK(dep_chain.back() == d6);
  dep_chain.pop_back();

  // Reversed direction: no chain exists, so the result is empty.
  dep_chain = DependencyCheck::getSingleDependencyChain(d11, d2);
  TORCH_CHECK(dep_chain.empty());
}
1157
TEST_F(NVFuserTest, FusionParser_CUDA) {
  // This test may not pass if using a custom block sync as there may
  // be additional calls. Skip the test as it's not specifically
  // relevant with block synchronization.
  if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
    return;
  }
  // Parse a small TorchScript graph (two chained aten::mul ops) and check
  // the generated CUDA kernel against a golden string.
  auto g = std::make_shared<Graph>();
  const auto graph0_string = R"IR(
    graph(%0 : Float(2, strides=[1]),
          %1 : Float(2, strides=[1])):
      %c0 : Float(2, strides=[1]) = aten::mul(%0, %1)
      %d0 : Float(2, strides=[1]) = aten::mul(%c0, %0)
      return (%d0))IR";
  parseIR(graph0_string, g.get());

  // strides are not yet supported in the irparser.
  for (auto val : g->block()->inputs()) {
    if (val->isCompleteTensor())
      val->setType(val->type()->castRaw<TensorType>()->contiguous());
  }
  for (auto node : g->block()->nodes()) {
    for (auto val : node->outputs()) {
      if (val->isCompleteTensor())
        val->setType(val->type()->castRaw<TensorType>()->contiguous());
    }
  }

  auto fusion = parseJitIR(g);
  FusionGuard fg(fusion.get());
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Avoid vectorization here as those kernels can't be lowered twice at the
  // moment
  at::Tensor input1 = at::randn({16}, options);
  at::Tensor input2 = at::randn({16}, options);
  auto lparams = schedulePointwise(fusion.get(), {input1, input2});

  // CONSIDER:
  // 1. this can be moved to a dedicated "golden" file
  // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
  const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
  int64_t i50;
  i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
  if ((i50 < T0.size[0])) {
    float T5[1];
    T5[0] = 0;
    T5[0]
       = T1[i50];
    float T4[1];
    T4[0] = 0;
    T4[0]
       = T0[i50];
    float T2[1];
    T2[0]
      = T4[0]
      * T5[0];
    float T6[1];
    T6[0]
      = T2[0]
      * T4[0];
    T3[i50]
       = T6[0];
  }
}
)";

  // Lower the fusion again and do an exact string comparison against the
  // golden kernel; on mismatch, print both and the first differing snippet.
  const std::string actual_kernel =
      "\n" + codegen::generateCudaKernel(GpuLower(fusion.get()).kernel());
  if (expected_kernel.size() != actual_kernel.size() ||
      expected_kernel.compare(actual_kernel) != 0) {
    std::cerr
        << " Codegen mismatch, codegen possibly changed, or is incorrect. "
        << " \n ========= EXPECTED ========= \n"
        << expected_kernel << "\n========= ACTUAL ========== \n"
        << actual_kernel << "\n=================" << std::endl;
    auto it = std::mismatch(
        expected_kernel.begin(),
        expected_kernel.end(),
        actual_kernel.begin(),
        actual_kernel.end());
    std::string actual_mismatched_snippet(it.second, actual_kernel.end());
    actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10);
    std::string expected_mismatched_snippet(it.first, expected_kernel.end());
    expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10);
    std::cerr << "First mismatch found at: " << actual_mismatched_snippet
              << ", expected: " << expected_mismatched_snippet << std::endl;
    TORCH_CHECK(false);
  }

  // Also execute the kernel and validate numerically.
  FusionExecutor fe;
  fe.compileFusion(fusion.get(), {input1, input2}, lparams);
  auto outputs = fe.runFusion({input1, input2}, lparams);
  at::Tensor output_ref = input1 * input2 * input1;
  TORCH_CHECK(output_ref.equal(outputs[0]));
}
1254
TEST_F(NVFuserTest, FusionOuterSplit_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(3);

  // tv0 is produced from constants (0.0 + 1.0); the fusion has no inputs.
  IrBuilder::create<BinaryOp>(
      BinaryOpType::Add,
      tv0,
      IrBuilder::create<Double>(0.0),
      IrBuilder::create<Double>(1.0));
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(3.0));
  fusion.addOutput(tv2);

  //[I0, I1, I2]
  // `false` requests an outer split: the constant factor 4 becomes the
  // extent of the outer axis rather than the inner one.
  tv2->split(-1, 4, false);
  //[I0, I1, I2o{4}, I2i]
  tv2->merge(0);
  tv2->merge(0);
  //[I0*I1*I2o{4}, I2i]
  tv2->split(0, 2);
  //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i]
  tv2->reorder({{0, 1}, {1, 0}});
  // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i]

  tv0->computeAt(tv2, -1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // No fusion inputs: the kernel fills `output` with the computed constant.
  at::Tensor output = at::empty({2, 6, 32}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion({}, {output});

  at::Tensor output_ref = at::zeros_like(output, options);
  output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;

  TORCH_CHECK(output_ref.equal(output));
}
1296
TEST_F(NVFuserTest, FusionCodeGen_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(3);

  // tv0 is produced from constants (0.0 + 1.0); the fusion has no inputs.
  IrBuilder::create<BinaryOp>(
      BinaryOpType::Add,
      tv0,
      IrBuilder::create<Double>(0.0),
      IrBuilder::create<Double>(1.0));
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(3.0));
  fusion.addOutput(tv2);

  // Exercise split/merge/reorder on the output, then inline the chain.
  //[I0, I1, I2]
  tv2 = tv2->split(0, 4);
  //[I0o, I0i{4}, I1, I2]
  tv2 = tv2->merge(1);
  //[I0o, I0i{4}*I1, I2]
  tv2 = tv2->split(-1, 2);
  //[I0o, I0i{4}*I1, I2o, I2i{2}]
  tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}});
  //[I0i{4}*I1, I0o, I2i{2}, I2o]

  tv0->computeAt(tv2, -1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // No fusion inputs: the kernel fills `output` with the computed constant.
  at::Tensor output = at::empty({16, 8, 8}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion({}, {output});

  at::Tensor output_ref = at::zeros_like(output, options);
  output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;

  TORCH_CHECK(output_ref.equal(output));
}
1337
TEST_F(NVFuserTest, FusionCodeGen2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // tv3 = tv0 + (tv1 + 2.0), with a reorder/split/reorder schedule and
  // block/thread parallelization on the output.
  TensorView* tv0 = makeSymbolicTensor(3);
  TensorView* tv1 = makeSymbolicTensor(3);
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = add(tv0, tv2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv3);

  //[I0, I1, I2]
  tv3->reorder({{0, 2}, {2, 0}});
  //[I2, I1, I0]
  tv3->split(-1, 4);
  //[I2, I1, I0o, I0i{4}]
  tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
  // I0o, I0i{4}, I1, I2]

  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor input1 = at::randn({16, 8, 8}, options);
  at::Tensor input2 = at::randn_like(input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input1, input2});
  auto outputs = fe.runFusion({input1, input2});

  // Reference computed with ATen eager ops.
  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor output_ref = input1 + tv2_ref;

  TORCH_CHECK(output_ref.equal(outputs[0]));
}
1379
TEST_F(NVFuserTest, FusionSimplePWise_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
  // dimensionality of the problem
  int nDims = 3;

  // Set up your input tensor views
  TensorView* tv0 = makeContigTensor(nDims);
  TensorView* tv1 = makeContigTensor(nDims);

  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = add(tv0, tv2);

  // Register your outputs
  fusion.addOutput(tv3);

  // Do transformations, remember, transformations are outputs to inputs
  // This doesn't have to be in this order
  // Flatten [I0, I1, I2] -> [I0*I1*I2]
  tv3->merge(1);
  tv3->merge(0);

  // Split by n_threads: [I/4/128, 4, 128]
  tv3->split(0, 128);
  tv3->split(0, 4);

  // For all inputs, computeAt the output inline, temporaries should be squeezed
  // between them
  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);

  // Parallelize TV3: blocks over the outer axis, unrolled middle axis,
  // 128 threads over the inner axis.
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor input1 = at::randn({64, 2, 128}, options);
  at::Tensor input2 = at::rand_like(input1);
  at::Tensor output = at::empty_like(input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input1, input2});
  fe.runFusion({input1, input2}, {output});

  // Reference computed with ATen eager ops.
  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor output_ref = input1 + tv2_ref;

  TORCH_CHECK(output_ref.equal(output));
}
1436
TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) {
  // Same schedule as FusionSimplePWise, but with complex-float tensors and
  // a complex scalar addend.
  Fusion fusion;
  FusionGuard fg(&fusion);
  // dimensionality of the problem
  int nDims = 3;

  // Set up your input tensor views
  TensorView* tv0 = makeContigTensor(nDims, DataType::ComplexFloat);
  TensorView* tv1 = makeContigTensor(nDims, DataType::ComplexFloat);

  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  c10::complex<double> scalar1(2.0, 3.0);
  TensorView* tv2 = add(tv1, IrBuilder::create<ComplexDouble>(scalar1));
  TensorView* tv3 = add(tv0, tv2);

  // Register your outputs
  fusion.addOutput(tv3);

  // Do transformations, remember, transformations are outputs to inputs
  // This doesn't have to be in this order
  tv3->merge(1);
  tv3->merge(0);

  // Split by n_threads
  tv3->split(0, 128);
  tv3->split(0, 4);

  // For all inputs, computeAt the output inline, temporaries should be squeezed
  // between them
  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);

  // Parallelize TV3
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  auto options =
      at::TensorOptions().dtype(at::kComplexFloat).device(at::kCUDA, 0);

  at::Tensor input1 = at::randn({64, 2, 128}, options);
  at::Tensor input2 = at::rand_like(input1);
  at::Tensor output = at::empty_like(input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input1, input2});
  fe.runFusion({input1, input2}, {output});

  // Reference computed with ATen eager ops.
  at::Tensor tv2_ref = input2 + scalar1;
  at::Tensor output_ref = input1 + tv2_ref;

  TORCH_CHECK(output_ref.equal(output));
}
1495
1496TEST_F(NVFuserTest, FusionExecKernel_CUDA) {
1497 Fusion fusion;
1498 FusionGuard fg(&fusion);
1499
1500 // Set up your input tensor views
1501 TensorView* tv0 = makeSymbolicTensor(2);
1502 TensorView* tv1 = makeSymbolicTensor(2);
1503
1504 // Register your inputs
1505 fusion.addInput(tv0);
1506 fusion.addInput(tv1);
1507
1508 // Do math with it, it returns a `Val*` but can be static_casted back to
1509 // TensorView
1510 TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
1511 TensorView* tv3 = add(tv0, tv2);
1512
1513 // Register your outputs
1514 fusion.addOutput(tv3);
1515
1516 tv3->merge(0);
1517 tv3->split(0, 128);
1518 tv3->split(0, 4);
1519
1520 // For all inputs, computeAt the output inline, temporaries should be squeezed
1521 // between them
1522 tv0->computeAt(tv3, 1);
1523 tv1->computeAt(tv3, 1);
1524
1525 // Parallelize TV3
1526 tv3->axis(0)->parallelize(ParallelType::BIDx);
1527 tv2->axis(1)->parallelize(ParallelType::Unroll);
1528 tv3->axis(1)->parallelize(ParallelType::Unroll);
1529 tv2->axis(-1)->parallelize(ParallelType::TIDx);
1530 tv3->axis(-1)->parallelize(ParallelType::TIDx);
1531
1532 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1533
1534 at::Tensor input1 = at::ones({1, 128}, options);
1535 at::Tensor input2 = at::ones_like(input1);
1536
1537 FusionExecutor fe;
1538 fe.compileFusion(&fusion, {input1, input2});
1539 auto outputs = fe.runFusion({input1, input2});
1540
1541 at::Tensor check = at::full({1, 128}, 4, options);
1542 ;
1543 TORCH_CHECK(outputs[0].equal(check));
1544}
1545
// Ceiling division helper for the tests: smallest integer >= a / b.
// Assumes a >= 0 and b > 0 (the only way the tests use it).
// Uses a / b + (a % b != 0) instead of (a + b - 1) / b so that the
// intermediate sum cannot overflow int when a is near INT_MAX.
int ceilDiv_(int a, int b) {
  return a / b + (a % b != 0 ? 1 : 0);
}
1549
1550TEST_F(NVFuserTest, FusionAdvancedComputeAt1_CUDA) {
1551 // Case 1
1552 // tv1 = tv0 * 0.5
1553 // tv2 = tv1 * -1
1554 // tv3 = tv1 + 3
1555 // tv4 = tv1 * 2
1556 // tv5 = tv3 + tv2
1557 // tv6 = tv5 + tv4
1558 // tv7 = tv1 + tv4
1559 Fusion fusion;
1560 FusionGuard fg(&fusion);
1561
1562 TensorView* tv0 = makeSymbolicTensor(2);
1563 fusion.addInput(tv0);
1564
1565 TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
1566 TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
1567 TensorView* tv3 = add(tv1, IrBuilder::create<Double>(3.0));
1568 TensorView* tv4 = mul(tv1, IrBuilder::create<Double>(2.0));
1569 TensorView* tv5 = add(tv3, tv2);
1570
1571 TensorView* tv6 = add(tv5, tv4);
1572 TensorView* tv7 = add(tv1, tv4);
1573
1574 fusion.addOutput(tv6);
1575 fusion.addOutput(tv7);
1576
1577 // Lets setup to actually run
1578 tv7->merge(0);
1579 tv7->split(0, 128);
1580 tv7->split(0, 4);
1581
1582 tv7->axis(0)->parallelize(ParallelType::BIDx);
1583
1584 tv0->computeAt(tv7, 1);
1585
1586 ComputeAtMap ca_map(&fusion);
1587
1588 // The this-position of the last tensor should be zero.
1589 TORCH_CHECK(
1590 tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
1591 tv7->getMaxProducerPosition() == 1);
1592 TORCH_CHECK(
1593 tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
1594 tv6->getMaxProducerPosition() == 1);
1595 // The position of every other tensor should be 1.
1596 for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
1597 TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
1598
1599 TORCH_CHECK(
1600 ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE));
1601 }
1602
1603 for (Val* val : fusion.vals()) {
1604 if (!val->isFusionInput() &&
1605 val->getValType().value() == ValType::TensorView) {
1606 TensorView* tv = static_cast<TensorView*>(val);
1607 tv->axis(1)->parallelize(ParallelType::Unroll);
1608 tv->axis(-1)->parallelize(ParallelType::TIDx);
1609 }
1610 }
1611
1612 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1613
1614 at::Tensor aten_input = at::randn({129, 127}, options);
1615
1616 auto t1 = aten_input.mul({0.5});
1617 auto t2 = t1.mul({-1.0});
1618 auto t3 = t1.add({3.0});
1619 auto t4 = t1.mul({2.0});
1620 auto t5 = t3.add(t2);
1621 auto t6 = t5.add(t4);
1622 auto t7 = t1.add(t4);
1623
1624 std::vector<at::Tensor> aten_outputs = {t6, t7};
1625 std::vector<at::Tensor> cg_outputs = {
1626 at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
1627
1628 FusionExecutor fe;
1629 fe.compileFusion(&fusion, {aten_input});
1630 fe.runFusion({aten_input}, cg_outputs);
1631
1632 testValidate(
1633 &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
1634}
1635
TEST_F(NVFuserTest, FusionAdvancedComputeAt2_CUDA) {
  // Case 2
  // tv1 = tv0 * -1
  // tv2 = tv0 + 3
  // tv3 = tv0 * 2
  // tv4 = tv2 + tv1
  // tv5 = tv4 + tv3
  // tv6 = tv5 + tv3
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(-1.0));
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(3.0));
  TensorView* tv3 = mul(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv4 = add(tv2, tv1);

  TensorView* tv5 = add(tv4, tv3);
  TensorView* tv6 = add(tv5, tv3);

  fusion.addOutput(tv5);
  fusion.addOutput(tv6);

  // Lets setup to actually run
  tv6->merge(0);
  tv6->split(0, 128);
  tv6->split(0, 4);

  tv6->axis(0)->parallelize(ParallelType::BIDx);

  tv0->computeAt(tv6, 1);

  // Parallelize every non-input tensor like tv6.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Odd, non-multiple-of-block sizes exercise the predicate/tail path.
  at::Tensor input = at::randn({129, 127}, options);

  auto t1 = input.mul({-1.0});
  auto t2 = input.add({3.0});
  auto t3 = input.mul({2.0});
  auto t4 = t2.add(t1);
  auto t5 = t4.add(t3);
  auto t6 = t5.add(t3);

  std::vector<at::Tensor> aten_outputs = {t5, t6};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
}
1698
TEST_F(NVFuserTest, FusionAdvancedComputeAt3_CUDA) {
  // Case 3
  // T2 = T1 * 0.979361
  // T3 = T2 * T0
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);

  TensorView* tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv1);

  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(.979361));
  TensorView* tv3 = mul(tv2, tv0);

  fusion.addOutput(tv3);

  // Lets setup to actually run
  // Flatten all four axes of the output, then split as [I/4/128, 4, 128].
  while (tv3->nDims() > 1)
    tv3->merge(0);
  tv3->split(0, 128);
  tv3->split(0, 4);

  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);

  // Parallelize every non-input tensor like tv3.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);

  auto t2 = t1.mul({0.979361});
  auto aten_output = t2.mul(t0);

  std::vector<IValue> aten_inputs = {t0, t1};

  at::Tensor cg_output = at::empty_like(t0, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
1756
TEST_F(NVFuserTest, FusionAdvancedComputeAt4_CUDA) {
  // Case 4
  // T4 = T2 - T3
  // T5 = T1 + T4
  // T6 = T5 - T0
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);

  TensorView* tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv1);

  TensorView* tv2 = makeSymbolicTensor(4);
  fusion.addInput(tv2);

  TensorView* tv3 = makeSymbolicTensor(4);
  fusion.addInput(tv3);

  TensorView* tv4 = sub(tv2, tv3);
  TensorView* tv5 = add(tv1, tv4);
  TensorView* tv6 = sub(tv5, tv0);

  fusion.addOutput(tv6);

  // Lets setup to actually run
  // Flatten all four axes of the output, then split as [I/4/128, 4, 128].
  while (tv6->nDims() > 1)
    tv6->merge(0);
  tv6->split(0, 128);
  tv6->split(0, 4);

  // Inline all four inputs at position 1 of the output schedule.
  tv0->computeAt(tv6, 1);
  tv1->computeAt(tv6, 1);
  tv2->computeAt(tv6, 1);
  tv3->computeAt(tv6, 1);

  tv6->axis(0)->parallelize(ParallelType::BIDx);

  // Parallelize every non-input tensor like tv6.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);
  at::Tensor t2 = at::rand_like(t0, options);
  at::Tensor t3 = at::rand_like(t0, options);

  auto t4 = t2.sub(t3);
  auto t5 = t1.add(t4);
  auto aten_output = t5.sub(t0);

  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
1825
TEST_F(NVFuserTest, FusionAdvancedComputeAt5_CUDA) {
  // Case 5
  // tv2 = tv0 + 2.0
  // tv3 = tv1 * tv2
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = mul(tv1, tv2);
  fusion.addOutput(tv3);

  // Flatten the output, then split the inner axis twice.
  tv3->merge(0);
  tv3->split(-1, 8);
  tv3->split(-1, 4);

  // computeAt from the intermediate (not the input) into the output.
  tv2->computeAt(tv3, 1);
  tv3->axis(0)->parallelize(ParallelType::BIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);

  auto t2 = t0.add(2.0);
  auto aten_output = t1.mul(t2);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
1865
TEST_F(NVFuserTest, FusionAdvancedComputeAt6_CUDA) {
  // Like case 5, but producer and consumer are scheduled differently
  // before computeAt, forcing a replay to reconcile them.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = mul(tv1, tv2);
  fusion.addOutput(tv3);

  // tv2 gets an extra inner split that tv3 does not have.
  tv2->merge(0);
  tv2->split(-1, 8);
  tv2->split(-1, 4);
  tv3->merge(0);
  tv3->split(-1, 8);

  tv2->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);

  auto t2 = t0.add(2.0);
  auto aten_output = t1.mul(t2);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
1904
TEST_F(NVFuserTest, FusionAdvancedComputeAt7_CUDA) {
  // Two outputs sharing a producer (tv1): tv4 = tv1 + tv3 (1-D), and
  // tv7 = broadcast(tv1) * tv6 (2-D). Checks that later computeAt calls
  // on the tv4 branch do not disturb the already-replayed tv5 domain.
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1.0));

  auto tv2 = makeSymbolicTensor(1);
  fusion.addInput(tv2);

  auto tv3 = add(tv2, IrBuilder::create<Double>(3.0));

  auto tv4 = add(tv1, tv3);
  fusion.addOutput(tv4);

  auto tv5 = broadcast(tv1, {false, true});

  auto tv6 = makeSymbolicTensor(2);
  fusion.addInput(tv6);

  auto tv7 = mul(tv5, tv6);

  fusion.addOutput(tv7);

  tv7->split(1, 2);
  tv7->merge(0);
  tv7->split(0, 4);
  tv7->split(0, 128);

  tv7->axis(0)->parallelize(ParallelType::BIDx);
  tv7->axis(1)->parallelize(ParallelType::TIDx);

  tv0->computeAt(tv7, 1);
  // Snapshot the tv5 domain after the tv7-branch computeAt.
  auto tv5_domain = tv5->domain()->domain();

  // These computeAt transformations should not affect the TV5 domain
  tv0->computeAt(tv4, -1);
  tv2->computeAt(tv4, -1);

  auto tv5_domain_current = tv5->domain()->domain();
  TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain");

  const int numel_x = 100;
  const int numel_y = 200;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({numel_x}, options);
  auto t2 = at::randn({numel_x}, options);
  auto t6 = at::randn({numel_x, numel_y}, options);

  // ATen reference: unsqueeze models the broadcast.
  auto t1 = t0.add(1.0);
  auto t3 = t2.add(3.0);
  auto t4 = t1.add(t3);
  auto t5 = t1.unsqueeze(1);
  auto t7 = t5.mul(t6);

  std::vector<IValue> aten_inputs = {t0, t2, t6};
  std::vector<at::Tensor> aten_outputs = {t4, t7};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
1973
// Same fusion as FusionAdvancedComputeAt7 (two outputs sharing the
// producer tv1 = tv0 + 1) but with the tv7 splits passing `false` as the
// third argument (the non-default split variant) and the computeAt calls
// issued in the reverse order.
TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1.0));

  auto tv2 = makeSymbolicTensor(1);
  fusion.addInput(tv2);

  auto tv3 = add(tv2, IrBuilder::create<Double>(3.0));

  auto tv4 = add(tv1, tv3);
  fusion.addOutput(tv4);

  auto tv5 = broadcast(tv1, {false, true});

  auto tv6 = makeSymbolicTensor(2);
  fusion.addInput(tv6);

  auto tv7 = mul(tv5, tv6);

  fusion.addOutput(tv7);

  // Note: unlike ComputeAt7, these splits pass `false` as the third argument.
  tv7->split(1, 2);
  tv7->merge(0);
  tv7->split(0, 128, false);
  tv7->split(0, 4, false);

  tv7->axis(0)->parallelize(ParallelType::BIDx);
  tv7->axis(1)->parallelize(ParallelType::TIDx);

  // Reverse computeAt structure from previous test
  tv0->computeAt(tv4, -1);
  tv2->computeAt(tv4, -1);
  tv0->computeAt(tv7, -1);

  const int numel_x = 100;
  const int numel_y = 200;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({numel_x}, options);
  auto t2 = at::randn({numel_x}, options);
  auto t6 = at::randn({numel_x, numel_y}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = t0.add(1.0);
  auto t3 = t2.add(3.0);
  auto t4 = t1.add(t3);
  auto t5 = t1.unsqueeze(1);
  auto t7 = t5.mul(t6);

  std::vector<IValue> aten_inputs = {t0, t2, t6};
  std::vector<at::Tensor> aten_outputs = {t4, t7};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
2037
2038TEST_F(NVFuserTest, FusionAdvancedComputeWith1_CUDA) {
2039 // Case 1
2040 // tv1 = tv0 * 0.5
2041 // tv2 = tv1 * -1
2042 // tv3 = tv1 + 3
2043 // tv4 = tv1 * 2
2044 // tv5 = tv3 + tv2
2045 // tv6 = tv5 + tv4
2046 // tv7 = tv1 + tv4
2047 Fusion fusion;
2048 FusionGuard fg(&fusion);
2049
2050 TensorView* tv0 = makeSymbolicTensor(2);
2051 fusion.addInput(tv0);
2052
2053 TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
2054 TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
2055 TensorView* tv3 = add(tv1, IrBuilder::create<Double>(3.0));
2056 TensorView* tv4 = mul(tv1, IrBuilder::create<Double>(2.0));
2057 TensorView* tv5 = add(tv3, tv2);
2058
2059 TensorView* tv6 = add(tv5, tv4);
2060 TensorView* tv7 = add(tv1, tv4);
2061
2062 fusion.addOutput(tv6);
2063 fusion.addOutput(tv7);
2064
2065 // Lets setup to actually run
2066 tv0->merge(0);
2067 tv0->split(0, 128);
2068 tv0->split(0, 4);
2069
2070 tv0->axis(0)->parallelize(ParallelType::BIDx);
2071
2072 tv0->computeWith(tv7, 1);
2073
2074 GpuLower gpulw(&fusion);
2075
2076 // The this-position of the last tensor should be zero.
2077 TORCH_CHECK(
2078 tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
2079 tv7->getMaxProducerPosition() == 1);
2080 TORCH_CHECK(
2081 tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
2082 tv6->getMaxProducerPosition() == 1);
2083
2084 ComputeAtMap ca_map(&fusion);
2085
2086 // The position of every other tensor should be 1.
2087 for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
2088 TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
2089 TORCH_CHECK(
2090 ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE));
2091 }
2092
2093 for (Val* val : fusion.vals()) {
2094 if (!val->isFusionInput() &&
2095 val->getValType().value() == ValType::TensorView) {
2096 TensorView* tv = static_cast<TensorView*>(val);
2097 tv->axis(1)->parallelize(ParallelType::Unroll);
2098 tv->axis(-1)->parallelize(ParallelType::TIDx);
2099 }
2100 }
2101
2102 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
2103
2104 at::Tensor aten_input = at::randn({129, 127}, options);
2105
2106 auto t1 = aten_input.mul({0.5});
2107 auto t2 = t1.mul({-1.0});
2108 auto t3 = t1.add({3.0});
2109 auto t4 = t1.mul({2.0});
2110 auto t5 = t3.add(t2);
2111 auto t6 = t5.add(t4);
2112 auto t7 = t1.add(t4);
2113
2114 std::vector<at::Tensor> aten_outputs = {t6, t7};
2115 std::vector<at::Tensor> cg_outputs = {
2116 at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
2117
2118 FusionExecutor fe;
2119 fe.compileFusion(&fusion, {aten_input});
2120 fe.runFusion({aten_input}, cg_outputs);
2121
2122 testValidate(
2123 &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
2124}
2125
TEST_F(NVFuserTest, FusionAdvancedComputeWith2_CUDA) {
  // Case 2
  // tv1 = tv0 * -1
  // tv2 = tv0 + 3
  // tv3 = tv0 * 2
  // tv4 = tv2 + tv1
  // tv5 = tv4 + tv3
  // tv6 = tv5 + tv3
  // The input tv0 is scheduled and computeWith(tv6, 1) propagates the
  // schedule to the rest of the fusion; results are validated against
  // an ATen reference.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(-1.0));
  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(3.0));
  TensorView* tv3 = mul(tv0, IrBuilder::create<Double>(2.0));
  TensorView* tv4 = add(tv2, tv1);

  TensorView* tv5 = add(tv4, tv3);
  TensorView* tv6 = add(tv5, tv3);

  fusion.addOutput(tv5);
  fusion.addOutput(tv6);

  // Lets setup to actually run
  tv0->merge(0);
  tv0->split(0, 128);
  tv0->split(0, 4);

  tv0->axis(0)->parallelize(ParallelType::BIDx);

  tv0->computeWith(tv6, 1);

  // Parallelize every non-input tensor identically.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({129, 127}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = input.mul({-1.0});
  auto t2 = input.add({3.0});
  auto t3 = input.mul({2.0});
  auto t4 = t2.add(t1);
  auto t5 = t4.add(t3);
  auto t6 = t5.add(t3);

  std::vector<at::Tensor> aten_outputs = {t5, t6};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
}
2188
TEST_F(NVFuserTest, FusionAdvancedComputeWith3_CUDA) {
  // Case 3
  // T2 = T1 * 0.979361
  // T3 = T2 * T0
  // Both 4-D inputs are flattened and split identically, then each is
  // computed with the output tv3 at position 1.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);

  TensorView* tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv1);

  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(.979361));
  TensorView* tv3 = mul(tv2, tv0);

  fusion.addOutput(tv3);

  // Lets setup to actually run
  // Flatten tv0 to 1-D, then split twice.
  while (tv0->nDims() > 1)
    tv0->merge(0);
  tv0->split(0, 128);
  tv0->split(0, 4);

  // Same schedule for tv1.
  while (tv1->nDims() > 1)
    tv1->merge(0);
  tv1->split(0, 128);
  tv1->split(0, 4);

  tv0->computeWith(tv3, 1);
  tv1->computeWith(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);

  // Parallelize every non-input tensor identically.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);

  // ATen reference computation mirroring the fusion definition.
  auto t2 = t1.mul({0.979361});
  auto aten_output = t2.mul(t0);

  std::vector<IValue> aten_inputs = {t0, t1};

  at::Tensor cg_output = at::empty_like(t0, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
2251
TEST_F(NVFuserTest, FusionAdvancedComputeWith4_CUDA) {
  // Case 4
  // T4 = T2 - T3
  // T5 = T1 + T4
  // T6 = T5 - T0
  // Four 4-D inputs; three of them are flattened, split, and computed
  // with the output tv6.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);

  TensorView* tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv1);

  TensorView* tv2 = makeSymbolicTensor(4);
  fusion.addInput(tv2);

  TensorView* tv3 = makeSymbolicTensor(4);
  fusion.addInput(tv3);

  TensorView* tv4 = sub(tv2, tv3);
  TensorView* tv5 = add(tv1, tv4);
  TensorView* tv6 = sub(tv5, tv0);

  fusion.addOutput(tv6);
  // NOTE(review): tv3 (also an input) is deliberately left out of the
  // scheduled list here — presumably the propagation from the others
  // covers it; confirm this is intentional.
  std::vector<TensorView*> tvs = {tv0, tv1, tv2};
  for (auto tv : tvs) {
    // Lets setup to actually run
    while (tv->nDims() > 1) {
      tv->merge(0);
    }
    tv->split(0, 128);
    tv->split(0, 4);
    tv->computeWith(tv6, 1);
  }

  tv6->axis(0)->parallelize(ParallelType::BIDx);

  // Parallelize every non-input tensor identically.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
  at::Tensor t1 = at::rand_like(t0, options);
  at::Tensor t2 = at::rand_like(t0, options);
  at::Tensor t3 = at::rand_like(t0, options);

  // ATen reference computation mirroring the fusion definition.
  auto t4 = t2.sub(t3);
  auto t5 = t1.add(t4);
  auto aten_output = t5.sub(t0);

  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
2319
2320TEST_F(NVFuserTest, FusionAdvancedComputeWith5_CUDA) {
2321 // Case 5
2322 // tv2 = tv0 + 2.0
2323 // tv3 = tv1 * tv2
2324 Fusion fusion;
2325 FusionGuard fg(&fusion);
2326
2327 // Set up your input tensor views
2328 TensorView* tv0 = makeSymbolicTensor(2);
2329 fusion.addInput(tv0);
2330 TensorView* tv1 = makeSymbolicTensor(2);
2331 fusion.addInput(tv1);
2332 TensorView* tv2 = add(tv0, IrBuilder::create<Double>(2.0));
2333 TensorView* tv3 = mul(tv1, tv2);
2334 fusion.addOutput(tv3);
2335
2336 tv2->merge(0);
2337 tv2->split(-1, 8);
2338 tv2->split(-1, 4);
2339
2340 tv2->computeWith(tv3, 1);
2341 tv3->axis(0)->parallelize(ParallelType::BIDx);
2342
2343 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
2344 at::Tensor t0 = at::randn({63, 65}, options);
2345 at::Tensor t1 = at::rand_like(t0, options);
2346
2347 auto t2 = t0.add(2.0);
2348 auto aten_output = t1.mul(t2);
2349
2350 std::vector<IValue> aten_inputs = {t0, t1};
2351
2352 FusionExecutor fe;
2353 fe.compileFusion(&fusion, aten_inputs);
2354 auto cg_outputs = fe.runFusion(aten_inputs);
2355
2356 testValidate(
2357 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
2358}
2359
2360TEST_F(NVFuserTest, FusionAdvancedComputeWith6_CUDA) {
2361 Fusion fusion;
2362 FusionGuard fg(&fusion);
2363
2364 TensorView* tv0 = makeSymbolicTensor(2);
2365 fusion.addInput(tv0);
2366 TensorView* tv1 = makeSymbolicTensor(2);
2367 fusion.addInput(tv1);
2368 TensorView* tv2 = add(tv0, IrBuilder::create<Double>(2.0));
2369 TensorView* tv3 = mul(tv1, tv2);
2370 fusion.addOutput(tv3);
2371
2372 tv2->merge(0);
2373 tv2->split(-1, 8);
2374 tv2->split(-1, 4);
2375 tv3->merge(0);
2376 tv3->split(-1, 8);
2377
2378 tv2->computeWith(tv3, 1);
2379
2380 tv3->axis(0)->parallelize(ParallelType::BIDx);
2381
2382 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
2383 at::Tensor t0 = at::randn({63, 65}, options);
2384 at::Tensor t1 = at::rand_like(t0, options);
2385
2386 auto t2 = t0.add(2.0);
2387 auto aten_output = t1.mul(t2);
2388
2389 std::vector<IValue> aten_inputs = {t0, t1};
2390
2391 FusionExecutor fe;
2392 fe.compileFusion(&fusion, aten_inputs);
2393 auto cg_outputs = fe.runFusion(aten_inputs);
2394
2395 testValidate(
2396 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
2397}
2398
TEST_F(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv1 * -2
  // tv1 has two consumers (tv2, tv3) with no common consumer; the test
  // checks how computeAt(tv3) indirectly positions tv2 as well.
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
  TensorView* tv3 = mul(tv1, IrBuilder::create<Double>(-2.0));
  fusion.addOutput(tv2);
  fusion.addOutput(tv3);

  // This computeAt will affect tv2 as well, even though tv2 is not in
  // the data-flow path between tv1 and tv3. The reason is that tv1 is
  // now computed at tv3, so tv2 must also be computed at the same
  // location. Overall, what will happen is basically we merge
  // expressions of all tensors and compute them in a single loop
  // nest.
  TensorView* computeAtTarget = tv3;
  computeAtTarget->split(0, 128);
  tv1->computeAt(computeAtTarget, 1);

  TensorView* affected_tensors[] = {tv1, tv2, tv3};
  for (auto tv : affected_tensors) {
    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
  }

  GpuLower gpulw(&fusion);

  // tv1 is inlined into its consumers; the two outputs stay at
  // position 0 but see tv1 as a producer at position 1.
  TORCH_CHECK(tv1->getComputeAtPosition() == 1);
  TORCH_CHECK(
      tv2->getComputeAtPosition() == 0 && tv2->getMaxProducerPosition() == 1);
  TORCH_CHECK(
      tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1);

  ComputeAtMap ca_map(&fusion);

  // Note that tv2 is also computed at tv3.
  for (auto tv : {tv1, tv2}) {
    TORCH_CHECK(ca_map.areMapped(
        tv->axis(0), computeAtTarget->axis(0), IdMappingMode::PERMISSIVE));
  }

  TORCH_CHECK(tv3->getComputeAtPosition() == 0);

  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
  for (auto tv : affected_tensors) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor aten_input = at::randn({1000}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = aten_input * 0.5;
  auto t2 = t1 * -1.0;
  auto t3 = t1 * -2.0;

  std::vector<at::Tensor> aten_outputs = {t2, t3};

  std::vector<at::Tensor> cg_outputs = {
      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, cg_outputs);

  testValidate(
      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
}
2473
2474// Similar to ComputeAtMultiConsumers, but with a common consumer.
TEST_F(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv1 * -2
  // tv4 = tv2 + tv3
  // tv5 = tv4 * 5
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
  TensorView* tv3 = mul(tv1, IrBuilder::create<Double>(-2.0));
  TensorView* tv4 = add(tv2, tv3);
  TensorView* tv5 = mul(tv4, IrBuilder::create<Double>(5.0));
  fusion.addOutput(tv3);
  fusion.addOutput(tv4);
  fusion.addOutput(tv5);

  // Computing tv1 at tv3. This will affect tv2 as discussed in the
  // multi-consumer test above. Additionally, in this case, notice that
  // tv4 is the common consumer of tv2 and tv3, so they are computed at
  // tv4. The indirect propagation of the computeAt should stop at the
  // common consumer, and no further change should occur. More
  // specifically, the computeAt position of tv4 and tv5 should be zero.
  TensorView* computeAtTarget = tv3;
  computeAtTarget->split(0, 128);
  tv1->computeAt(computeAtTarget, 1);

  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4};
  for (auto tv : affected_tensors) {
    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
  }

  TORCH_CHECK(tv1->getComputeAtPosition() == 1);
  TORCH_CHECK(tv2->getComputeAtPosition() == 1);
  TORCH_CHECK(tv3->getComputeAtPosition() == 1);
  TORCH_CHECK(tv4->getComputeAtPosition() == 0);
  TORCH_CHECK(tv5->getComputeAtPosition() == 0);

  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);

  for (auto tv : affected_tensors) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  // Transform tv5 to make it look like the rest
  tv5->split(0, 128);
  tv5->axis(1)->parallelize(ParallelType::TIDx);
  tv5->axis(0)->parallelize(ParallelType::BIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor aten_input = at::randn({1000}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = aten_input * 0.5;
  auto t2 = t1 * -1.0;
  auto t3 = t1 * -2.0;
  auto t4 = t2 + t3;
  auto t5 = t4 * 5.0;

  std::vector<at::Tensor> aten_outputs = {t3, t4, t5};
  std::vector<at::Tensor> cg_outputs = {
      at::empty_like(aten_input, options),
      at::empty_like(aten_input, options),
      at::empty_like(aten_input, options)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, cg_outputs);

  testValidate(
      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
}
2551
TEST_F(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv2 * -1
  // tv4 = tv1 + 4
  // tv5 = tv3 + tv4
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
  TensorView* tv3 = mul(tv2, IrBuilder::create<Double>(-1.0));
  TensorView* tv4 = add(tv1, IrBuilder::create<Double>(4.0));
  TensorView* tv5 = add(tv3, tv4);

  fusion.addOutput(tv5);

  TensorView* computeAtTarget = tv3;

  computeAtTarget->merge(0);
  computeAtTarget->split(0, 128);
  computeAtTarget->split(0, 4);

  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);

  // This computeAt will affect all tensors including tv3, tv4 and
  // tv5, even though it appears to impact only tv1 and tv2. The
  // reason is that tv1 is now computed at tv3, so tv4 must also be
  // computed at the same location. Similarly, the consumer of tv4,
  // tv5, must also be computed at the same location. Overall, what
  // will happen is basically we merge expressions of all tensors and
  // compute them in a single loop nest. Internally, this will be
  // realized by making all tensors, except for those in the path
  // between tv1 and tv3, computed at tv5, which we call the common
  // consumer.
  tv1->computeAt(computeAtTarget, 1);

  // All tensors should have the same dimensionality as the target
  for (Val* val : fusion.vals()) {
    if (val->isFusionInput() ||
        val->getValType().value() != ValType::TensorView) {
      continue;
    }
    TensorView* tv = val->as<TensorView>();
    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
    if (tv == tv5) {
      // The common consumer itself stays at position 0.
      TORCH_CHECK(tv->getComputeAtPosition() == 0);
    } else {
      TORCH_CHECK(tv->getComputeAtPosition() == 1);
    }
  }

  // Parallelize every non-input tensor identically.
  for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
    if (!tv->isFusionInput()) {
      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor aten_input = at::randn({129, 127}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = aten_input.mul({0.5});
  auto t2 = t1.mul({-1.0});
  auto t3 = t2.mul({-1.0});
  auto t4 = t1.add({4.0});
  auto aten_output = t3 + t4;

  at::Tensor cg_output = at::empty_like(aten_input, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, {cg_output});

  testValidate(
      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
}
2633
2634// Similar to the above common consumer test but adds an additional
2635// tensor that has no common consumer with the other tensors.
TEST_F(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv2 * -1
  // tv4 = tv1 + 4
  // tv5 = tv2 + tv3
  // tv6 = tv1 + 6
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
  TensorView* tv3 = mul(tv2, IrBuilder::create<Double>(-1.0));
  TensorView* tv4 = add(tv1, IrBuilder::create<Double>(4.0));
  TensorView* tv5 = add(tv3, tv4);
  TensorView* tv6 = add(tv1, IrBuilder::create<Double>(6.0));

  fusion.addOutput(tv5);
  fusion.addOutput(tv6);

  TensorView* computeAtTarget = tv3;

  computeAtTarget->merge(0);
  computeAtTarget->split(0, 128);
  computeAtTarget->split(0, 4);

  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);

  // This will have the same impact on the tensors except for tv5 and
  // tv6. tv6 does not have any common consumer with the computeAt
  // target, but since it uses tv1, it must be also computed at the
  // same location as the other impacted tensors. We can either make
  // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5
  // should be computed at tv6 just because the current implementation
  // orders the computeAt relationship based on the order in which
  // tensors are specified as outputs.

  tv1->computeAt(computeAtTarget, 1);

  // All tensors should have the same dimensionality as the target
  for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
    if (tv->isFusionInput()) {
      continue;
    }
    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
    if (tv == tv5 || tv == tv6) {
      // The two outputs stay at position 0 but see producers at 1.
      TORCH_CHECK(tv->getComputeAtPosition() == 0);
      TORCH_CHECK(tv->getMaxProducerPosition() == 1);
    } else {
      TORCH_CHECK(tv->getComputeAtPosition() == 1);
    }
  }

  // Parallelize every non-input tensor identically.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = val->as<TensorView>();
      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor aten_input = at::randn({129, 127}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = aten_input.mul({0.5});
  auto t2 = t1.mul({-1.0});
  auto t3 = t2.mul({-1.0});
  auto t4 = t1.add({4.0});
  auto t5 = t3 + t4;
  auto t6 = t1.add({6.0});

  std::vector<at::Tensor> aten_outputs = {t5, t6};
  std::vector<at::Tensor> cg_outputs = {
      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, cg_outputs);

  testValidate(
      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
}
2723
// Similar to ComputeAtCommonConsumer1 but with an additional tensor
// that has no data dependency with the consumer.
TEST_F(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv1 * -2
  // tv4 = tv2 + tv3
  // tv5 = tv4 * 5
  // tv6 = tv1 * 6
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(-1.0));
  TensorView* tv3 = mul(tv1, IrBuilder::create<Double>(-2.0));
  TensorView* tv4 = add(tv2, tv3);
  TensorView* tv5 = mul(tv4, IrBuilder::create<Double>(5.0));
  // Notice that tv6 is not a consumer of tv4.
  TensorView* tv6 = mul(tv1, IrBuilder::create<Double>(6.0));
  fusion.addOutput(tv3);
  fusion.addOutput(tv4);
  fusion.addOutput(tv5);
  fusion.addOutput(tv6);

  TensorView* computeAtTarget = tv3;
  computeAtTarget->split(0, 128);
  tv1->computeAt(computeAtTarget, 1);

  // Every tensor downstream of tv1 picks up the target's schedule;
  // the terminating tensors tv5 and tv6 stay at position 0.
  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv5, tv6};
  for (auto tv : affected_tensors) {
    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
    if (tv == tv6 || tv == tv5) {
      TORCH_CHECK(tv->getComputeAtPosition() == 0);
    } else {
      TORCH_CHECK(tv->getComputeAtPosition() == 1);
    }
  }

  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);

  for (auto tv : affected_tensors) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor aten_input = at::randn({1000}, options);

  // ATen reference computation mirroring the fusion definition.
  auto t1 = aten_input * 0.5;
  auto t2 = t1 * -1.0;
  auto t3 = t1 * -2.0;
  auto t4 = t2 + t3;
  auto t5 = t4 * 5.0;
  auto t6 = t1 * 6.0;

  std::vector<at::Tensor> aten_outputs = {t3, t4, t5, t6};
  std::vector<at::Tensor> cg_outputs = {
      at::empty_like(aten_input, options),
      at::empty_like(aten_input, options),
      at::empty_like(aten_input, options),
      at::empty_like(aten_input, options)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, cg_outputs);

  testValidate(
      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
}
2796
2797namespace {
2798
2799void checkIdMapped(
2800 ComputeAtRootDomainMap& root_map,
2801 TensorView* v0,
2802 IterDomain* id0,
2803 TensorView* v1,
2804 IterDomain* id1,
2805 bool should_map) {
2806 if (should_map) {
2807 TORCH_CHECK(
2808 root_map.canMap(v0->domain(), id0, v1->domain(), id1),
2809 "Should be mappable: ",
2810 id0,
2811 " of ",
2812 v0,
2813 " and ",
2814 id1,
2815 " of ",
2816 v1);
2817 } else {
2818 TORCH_CHECK(
2819 !root_map.canMap(v0->domain(), id0, v1->domain(), id1),
2820 "Should not be mappable: ",
2821 id0,
2822 " of ",
2823 v0,
2824 " and ",
2825 id1,
2826 " of ",
2827 v1);
2828 }
2829}
2830
2831void checkIdMapped(
2832 TensorView* v0,
2833 const std::vector<IterDomain*>& root0,
2834 const std::vector<bool> should_map0,
2835 TensorView* v1,
2836 const std::vector<IterDomain*>& root1,
2837 const std::vector<bool> should_map1) {
2838 ComputeAtRootDomainMap map;
2839 map.build();
2840 TORCH_INTERNAL_ASSERT(root0.size() == should_map0.size());
2841 TORCH_INTERNAL_ASSERT(root1.size() == should_map1.size());
2842 size_t idx0 = 0;
2843 for (const auto i : c10::irange(root0.size())) {
2844 size_t idx1 = 0;
2845 for (const auto j : c10::irange(root1.size())) {
2846 if (should_map0[i] && should_map1[j] && idx0 == idx1) {
2847 checkIdMapped(map, v0, root0[i], v1, root1[j], true);
2848 } else {
2849 checkIdMapped(map, v0, root0[i], v1, root1[j], false);
2850 }
2851 if (should_map1[j])
2852 ++idx1;
2853 }
2854 if (should_map0[i])
2855 ++idx0;
2856 }
2857}
2858
2859void checkIdMapped(
2860 TensorView* v0,
2861 const std::vector<IterDomain*>& root0,
2862 TensorView* v1,
2863 const std::vector<IterDomain*>& root1) {
2864 checkIdMapped(
2865 v0,
2866 root0,
2867 std::vector<bool>(root0.size(), true),
2868 v1,
2869 root1,
2870 std::vector<bool>(root1.size(), true));
2871}
2872
2873} // namespace
2874
// Checks root-domain mappability across two broadcasts feeding a common
// add: broadcast dims map to each other and to the consumer, while each
// input's dims only map where the broadcast did not insert a new dim.
TEST_F(NVFuserTest, FusionRootMappingBasic_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv3 = broadcast(tv0, {true, false, false});
  auto tv4 = broadcast(tv1, {false, true, false});
  auto tv5 = add(tv3, tv4);
  fusion.addOutput(tv5);

  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv4,
      tv4->getRootDomain(),
      {false, true, true});
  checkIdMapped(
      tv1,
      tv1->getRootDomain(),
      {true, true},
      tv4,
      tv4->getRootDomain(),
      {true, false, true});
  // tv0 and tv1 only share the innermost dimension.
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {false, true},
      tv1,
      tv1->getRootDomain(),
      {false, true});
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv5,
      tv5->getRootDomain(),
      {false, true, true});
  checkIdMapped(
      tv1,
      tv1->getRootDomain(),
      {true, true},
      tv5,
      tv5->getRootDomain(),
      {true, false, true});
  // The broadcast outputs and the add output map fully to one another.
  checkIdMapped(tv3, tv3->getRootDomain(), tv4, tv4->getRootDomain());
  checkIdMapped(tv3, tv3->getRootDomain(), tv5, tv5->getRootDomain());
  checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain());
}
2928
// Checks root-domain mappability through a reduction with an rfactor:
// reduction dims (and the rfactored reduction dim) must not map to the
// consumer side, while the non-reduced dims map through.
TEST_F(NVFuserTest, FusionRootMappingRfactor_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [I,I]
  TensorView* tv0 = makeSymbolicTensor(2);
  // [I,I,I]
  TensorView* tv1 = makeSymbolicTensor(3);

  //[I,I,R]
  auto tv2 = sum(tv1, {2});
  auto tv3 = add(tv2, tv0);

  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv3);

  // scheduling:
  //[B,I,R0,R1=128], root = [B,I,R]
  tv2->split(2, 128);

  // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
  auto tv4 = tv2->rFactor({3});

  checkIdMapped(tv1, tv1->getRootDomain(), tv4, tv4->getRootDomain());
  checkIdMapped(
      tv4,
      tv4->getRFactorDomain(),
      {true, true, true, false},
      tv2,
      tv2->getRootDomain(),
      {true, true, true});
  checkIdMapped(
      tv1,
      tv1->getRootDomain(),
      {true, true, false},
      tv2,
      tv2->getRootDomain(),
      {true, true, false});
  checkIdMapped(
      tv1,
      tv1->getRootDomain(),
      {true, true, false},
      tv3,
      tv3->getRootDomain(),
      {true, true});
  checkIdMapped(
      tv2,
      tv2->getRootDomain(),
      {true, true, false},
      tv3,
      tv3->getRootDomain(),
      {true, true});
  checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv1,
      tv1->getRootDomain(),
      {true, true, false});
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv2,
      tv2->getRootDomain(),
      {true, true, false});
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv4,
      tv4->getRFactorDomain(),
      {true, true, false, false});
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, true},
      tv4,
      tv4->getRootDomain(),
      {true, true, false});
}
3012
// Reduction followed by a broadcast back into the reduced dimension:
// the reduced/broadcast dimension must stay unmapped everywhere.
TEST_F(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  auto tv1 = sum(tv0, {1});
  auto tv2 = broadcast(tv1, {false, true});
  fusion.addOutput(tv2);

  // The second dimension cannot be mapped as it would require recomputation.
  checkIdMapped(tv0, tv0->getRootDomain(), tv1, tv1->getRootDomain());
  checkIdMapped(
      tv1,
      tv1->getRootDomain(),
      {true, false},
      tv2,
      tv2->getRootDomain(),
      {true, false});
  checkIdMapped(
      tv0,
      tv0->getRootDomain(),
      {true, false},
      tv2,
      tv2->getRootDomain(),
      {true, false});
}
3039
3040TEST_F(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) {
3041 Fusion fusion;
3042 FusionGuard fg(&fusion);
3043
3044 TensorView* tv0 = makeSymbolicTensor(2);
3045 auto tv1 = sum(tv0, {1});
3046 auto tv2 = broadcast(tv1, {false, true});
3047 auto tv3 = add(tv0, tv2);
3048 fusion.addOutput(tv3);
3049
3050 checkIdMapped(
3051 tv0,
3052 tv0->getRootDomain(),
3053 {true, false},
3054 tv1,
3055 tv1->getRootDomain(),
3056 {true, false});
3057 checkIdMapped(
3058 tv1,
3059 tv1->getRootDomain(),
3060 {true, false},
3061 tv2,
3062 tv2->getRootDomain(),
3063 {true, false});
3064 checkIdMapped(
3065 tv0,
3066 tv0->getRootDomain(),
3067 {true, false},
3068 tv3,
3069 tv3->getRootDomain(),
3070 {true, false});
3071 checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
3072}
3073
3074TEST_F(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) {
3075 Fusion fusion;
3076 FusionGuard fg(&fusion);
3077
3078 TensorView* tv0 = makeSymbolicTensor(2);
3079 auto tv1 = sum(tv0, {1});
3080 auto tv2 = broadcast(tv1, {false, true});
3081 fusion.addOutput(tv2);
3082
3083 tv1->split(-1, 4);
3084 auto tv3 = tv1->rFactor({-2});
3085
3086 checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
3087 checkIdMapped(
3088 tv3,
3089 tv3->getMaybeRFactorDomain(),
3090 {true, false, true},
3091 tv1,
3092 tv1->getRootDomain(),
3093 {true, true});
3094 checkIdMapped(
3095 tv1,
3096 tv1->getRootDomain(),
3097 {true, false},
3098 tv2,
3099 tv2->getRootDomain(),
3100 {true, false});
3101}
3102
3103TEST_F(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) {
3104 Fusion fusion;
3105 FusionGuard fg(&fusion);
3106
3107 TensorView* tv0 = makeSymbolicTensor(2);
3108 auto tv1 = sum(tv0, {1});
3109 auto tv2 = broadcast(tv1, {false, true});
3110 auto tv3 = add(tv0, tv2);
3111 fusion.addOutput(tv3);
3112
3113 tv1->split(-1, 4);
3114 auto tv4 = tv1->rFactor({-2});
3115
3116 checkIdMapped(
3117 tv0,
3118 tv0->getRootDomain(),
3119 {true, false},
3120 tv4,
3121 tv4->getRootDomain(),
3122 {true, false});
3123 checkIdMapped(
3124 tv4,
3125 tv4->getMaybeRFactorDomain(),
3126 {true, false, true},
3127 tv1,
3128 tv1->getRootDomain(),
3129 {true, true});
3130 checkIdMapped(
3131 tv1,
3132 tv1->getRootDomain(),
3133 {true, false},
3134 tv2,
3135 tv2->getRootDomain(),
3136 {true, false});
3137 checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
3138 checkIdMapped(
3139 tv0,
3140 tv0->getRootDomain(),
3141 {true, false},
3142 tv2,
3143 tv2->getRootDomain(),
3144 {true, false});
3145}
3146
3147// Reproducer of issue #749
3148TEST_F(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) {
3149 Fusion fusion;
3150 FusionGuard fg(&fusion);
3151
3152 auto tv0 = makeSymbolicTensor(2);
3153 fusion.addInput(tv0);
3154 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
3155 auto tv2 = sum(tv1, {1});
3156 auto tv3 = broadcast(tv2, {false, true});
3157 auto tv4 = add(tv0, tv3);
3158 auto tv5 = add(tv4, tv1);
3159 fusion.addOutput(tv5);
3160
3161 checkIdMapped(
3162 tv0,
3163 tv0->getRootDomain(),
3164 {true, false},
3165 tv1,
3166 tv1->getRootDomain(),
3167 {true, false});
3168 checkIdMapped(
3169 tv1,
3170 tv1->getRootDomain(),
3171 {true, false},
3172 tv2,
3173 tv2->getRootDomain(),
3174 {true, false});
3175 checkIdMapped(
3176 tv2,
3177 tv2->getRootDomain(),
3178 {true, false},
3179 tv3,
3180 tv3->getRootDomain(),
3181 {true, false});
3182 checkIdMapped(
3183 tv3,
3184 tv3->getRootDomain(),
3185 {true, true},
3186 tv4,
3187 tv4->getRootDomain(),
3188 {true, true});
3189 checkIdMapped(
3190 tv0,
3191 tv0->getRootDomain(),
3192 {true, false},
3193 tv4,
3194 tv4->getRootDomain(),
3195 {true, false});
3196 checkIdMapped(
3197 tv4,
3198 tv4->getRootDomain(),
3199 {true, true},
3200 tv5,
3201 tv5->getRootDomain(),
3202 {true, true});
3203}
3204
3205// Similar to RootMappingReductionDependency5 but with rFactor
3206TEST_F(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) {
3207 Fusion fusion;
3208 FusionGuard fg(&fusion);
3209
3210 auto tv0 = makeSymbolicTensor(2);
3211 fusion.addInput(tv0);
3212 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
3213 auto tv2 = sum(tv1, {1});
3214 auto tv3 = broadcast(tv2, {false, true});
3215 auto tv4 = add(tv0, tv3);
3216 auto tv5 = add(tv4, tv1);
3217 fusion.addOutput(tv5);
3218
3219 tv2->split(1, 4);
3220 auto tv6 = tv2->rFactor({-1});
3221
3222 checkIdMapped(
3223 tv0,
3224 tv0->getRootDomain(),
3225 {true, false},
3226 tv1,
3227 tv1->getRootDomain(),
3228 {true, false});
3229 checkIdMapped(
3230 tv1,
3231 tv1->getRootDomain(),
3232 {true, false},
3233 tv6,
3234 tv6->getRootDomain(),
3235 {true, false});
3236 checkIdMapped(
3237 tv6,
3238 tv6->getMaybeRFactorDomain(),
3239 {true, true, false},
3240 tv2,
3241 tv2->getRootDomain(),
3242 {true, true});
3243 checkIdMapped(
3244 tv1,
3245 tv1->getRootDomain(),
3246 {true, false},
3247 tv2,
3248 tv2->getRootDomain(),
3249 {true, false});
3250 checkIdMapped(
3251 tv2,
3252 tv2->getRootDomain(),
3253 {true, false},
3254 tv3,
3255 tv3->getRootDomain(),
3256 {true, false});
3257 checkIdMapped(
3258 tv3,
3259 tv3->getRootDomain(),
3260 {true, true},
3261 tv4,
3262 tv4->getRootDomain(),
3263 {true, true});
3264 checkIdMapped(
3265 tv0,
3266 tv0->getRootDomain(),
3267 {true, false},
3268 tv4,
3269 tv4->getRootDomain(),
3270 {true, false});
3271 checkIdMapped(
3272 tv4,
3273 tv4->getRootDomain(),
3274 {true, true},
3275 tv5,
3276 tv5->getRootDomain(),
3277 {true, true});
3278}
3279
3280TEST_F(
3281 NVFuserTest,
3282 FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) {
3283 Fusion fusion;
3284 FusionGuard fg(&fusion);
3285
3286 TensorView* tv0 = makeSymbolicTensor(1);
3287 auto tv1 = broadcast(tv0, {false, true});
3288 auto tv2 = broadcast(tv0, {true, false});
3289 fusion.addOutput(tv1);
3290 fusion.addOutput(tv2);
3291
3292 // If there is no common consumer, there is no recomputation constraint.
3293 checkIdMapped(
3294 tv0,
3295 tv0->getRootDomain(),
3296 {true},
3297 tv1,
3298 tv1->getRootDomain(),
3299 {true, false});
3300 checkIdMapped(
3301 tv0,
3302 tv0->getRootDomain(),
3303 {true},
3304 tv2,
3305 tv2->getRootDomain(),
3306 {false, true});
3307 checkIdMapped(
3308 tv1,
3309 tv1->getRootDomain(),
3310 {true, false},
3311 tv2,
3312 tv2->getRootDomain(),
3313 {false, true});
3314}
3315
3316TEST_F(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) {
3317 Fusion fusion;
3318 FusionGuard fg(&fusion);
3319
3320 auto tv0 = makeSymbolicTensor(1);
3321 fusion.addInput(tv0);
3322 auto tv1 = makeSymbolicTensor(2);
3323 fusion.addInput(tv1);
3324 auto tv2 = makeSymbolicTensor(2);
3325 fusion.addInput(tv2);
3326 auto tv3 = broadcast(tv0, {false, true});
3327 auto tv4 = add(tv1, tv3);
3328 fusion.addOutput(tv4);
3329 auto tv5 = add(tv2, tv3);
3330 fusion.addOutput(tv5);
3331
3332 // Broadcast domains can be used with multiple domains with
3333 // different sizes. In this test, the broadcast domain of tv3 has
3334 // two consumers, tv4 and tv5, which may have different sizes. Each
3335 // of the consumers is used with the broadcast domain of tv3, but
3336 // the two consumers may not have the same size, it is not possible
3337 // to map those domains.
3338 checkIdMapped(
3339 tv0,
3340 tv0->getRootDomain(),
3341 {true},
3342 tv3,
3343 tv3->getRootDomain(),
3344 {true, false});
3345 checkIdMapped(
3346 tv0,
3347 tv0->getRootDomain(),
3348 {true},
3349 tv1,
3350 tv1->getRootDomain(),
3351 {true, false});
3352 checkIdMapped(
3353 tv0,
3354 tv0->getRootDomain(),
3355 {true},
3356 tv2,
3357 tv2->getRootDomain(),
3358 {true, false});
3359 checkIdMapped(
3360 tv1,
3361 tv1->getRootDomain(),
3362 {true, false},
3363 tv2,
3364 tv2->getRootDomain(),
3365 {true, false});
3366 checkIdMapped(
3367 tv1,
3368 tv1->getRootDomain(),
3369 {true, false},
3370 tv3,
3371 tv3->getRootDomain(),
3372 {true, false});
3373 checkIdMapped(
3374 tv2,
3375 tv2->getRootDomain(),
3376 {true, false},
3377 tv3,
3378 tv3->getRootDomain(),
3379 {true, false});
3380 checkIdMapped(
3381 tv3,
3382 tv3->getRootDomain(),
3383 {true, false},
3384 tv4,
3385 tv4->getRootDomain(),
3386 {true, false});
3387 checkIdMapped(
3388 tv3,
3389 tv3->getRootDomain(),
3390 {true, false},
3391 tv5,
3392 tv5->getRootDomain(),
3393 {true, false});
3394 checkIdMapped(
3395 tv4,
3396 tv4->getRootDomain(),
3397 {true, false},
3398 tv5,
3399 tv5->getRootDomain(),
3400 {true, false});
3401}
3402
3403TEST_F(NVFuserTest, FusionRootMappingBroadcast_CUDA) {
3404 Fusion fusion;
3405 FusionGuard fg(&fusion);
3406
3407 auto tv0 = makeSymbolicTensor(1);
3408 // tv0[I0]
3409 fusion.addInput(tv0);
3410 auto tv1 = broadcast(tv0, {true, false});
3411 // tv1[B1, I0]
3412 auto tv2 = broadcast(tv1, {true, false, false});
3413 // tv2[B2, B1, I0]
3414 fusion.addOutput(tv2);
3415
3416 // In this case, tv1 and tv2 has one and two broadcast domains,
3417 // respectively. It is the second broadcast domain that is mapped to
3418 // the broadcast of tv1.
3419 checkIdMapped(
3420 tv0,
3421 tv0->getRootDomain(),
3422 {true},
3423 tv1,
3424 tv1->getRootDomain(),
3425 {false, true});
3426 checkIdMapped(
3427 tv1,
3428 tv1->getRootDomain(),
3429 {true, true},
3430 tv2,
3431 tv2->getRootDomain(),
3432 {false, true, true}); // Not {true, false, true}
3433 checkIdMapped(
3434 tv0,
3435 tv0->getRootDomain(),
3436 {true},
3437 tv2,
3438 tv2->getRootDomain(),
3439 {false, false, true});
3440}
3441
3442// Reproducer of issue #723
3443TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
3444 Fusion fusion;
3445 FusionGuard fg(&fusion);
3446
3447 auto tv0 = makeSymbolicTensor(1);
3448 auto tv1 = makeSymbolicTensor(2);
3449
3450 fusion.addInput(tv0);
3451 fusion.addInput(tv1);
3452
3453 auto tv2 = broadcast(tv0, {true, false});
3454 auto tv3 = sum(tv2, {0});
3455 auto tv4 = add(tv2, tv1);
3456
3457 fusion.addOutput(tv3);
3458 fusion.addOutput(tv4);
3459
3460 ComputeAtRootDomainMap map;
3461 map.build();
3462
3463 checkIdMapped(
3464 map, tv2, tv2->getRootDomain()[0], tv4, tv4->getRootDomain()[0], true);
3465 checkIdMapped(
3466 map, tv2, tv2->getRootDomain()[0], tv3, tv3->getRootDomain()[0], true);
3467
3468 tv2->computeAt(tv4, -1);
3469
3470 const int x = 11;
3471 const int y = 12;
3472 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
3473 at::Tensor t0 = at::randn({x}, options);
3474 at::Tensor t1 = at::randn({y, x}, options);
3475 std::vector<IValue> aten_inputs = {t0, t1};
3476
3477 FusionExecutor fe;
3478 fe.compileFusion(&fusion, aten_inputs);
3479 auto outputs = fe.runFusion(aten_inputs);
3480
3481 auto t3 = t0;
3482 auto t4 = t0.unsqueeze(0).expand({y, x}) + t1;
3483
3484 testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
3485}
3486
3487// Repro of issue #1950
3488TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) {
3489 Fusion fusion;
3490 FusionGuard fg(&fusion);
3491 auto tv0 = makeSymbolicTensor(3);
3492 auto tv1 = makeSymbolicTensor(3);
3493 auto tv2 = makeSymbolicTensor(3);
3494
3495 fusion.addInput(tv0);
3496 fusion.addInput(tv1);
3497 fusion.addInput(tv2);
3498
3499 auto tv3 = set(tv0);
3500 auto tv4 = mul(tv1, tv3);
3501 auto tv5 = mul(tv1, tv2);
3502 auto tv6 = mul(tv5, tv3);
3503 auto tv7 = sum(tv6, {2});
3504 auto tv8 = broadcast(tv7, {false, false, true});
3505 auto tv9 = mul(tv3, tv8);
3506
3507 // Issue #1950 was caused by a particular traversal ordering based
3508 // on the output tensor ordering as below
3509 fusion.addOutput(tv9);
3510 fusion.addOutput(tv5);
3511 fusion.addOutput(tv4);
3512
3513 ComputeAtRootDomainMap root_map;
3514 root_map.build();
3515
3516 checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false);
3517}
3518
3519TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) {
3520 Fusion fusion;
3521 FusionGuard fg(&fusion);
3522
3523 auto tv0 = makeSymbolicTensor(1);
3524 fusion.addInput(tv0);
3525 // [I1]
3526 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
3527 // [B2, I2]
3528 auto tv2 = broadcast(tv1, {true, false});
3529 // [I3, B3]
3530 auto tv3 = broadcast(tv1, {false, true});
3531 // [I4, I5]
3532 auto tv4 = add(tv2, tv3);
3533 fusion.addOutput(tv4);
3534
3535 // IterDomainGraph maps B2, I3 and I4 together, and similarly I2,
3536 // B3 and I5. The problem is I1 is mapped with both of the ID
3537 // groups, so eventually all of the IDs are mapped
3538 // together. IterDomainGraph should throw an exception as this
3539 // pattern of domain mappings is not supported.
3540
3541 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
3542 ASSERT_ANY_THROW({ IterDomainGraph id_graph(&fusion); });
3543}
3544
// Exercises scalar (Double) fusion inputs: four scalar inputs are combined
// into derived scalars inside the fusion and mixed with tensor math, then
// the result is validated against an eager-mode computation.
TEST_F(NVFuserTest, FusionScalarInputs_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);

  // Four scalar fusion inputs; d4 and d5 are scalars derived inside the
  // fusion (not inputs themselves).
  Double* d0 = IrBuilder::create<Double>();
  fusion.addInput(d0);
  Double* d1 = IrBuilder::create<Double>();
  fusion.addInput(d1);
  Double* d2 = IrBuilder::create<Double>();
  fusion.addInput(d2);
  Double* d3 = IrBuilder::create<Double>();
  fusion.addInput(d3);
  Val* d4 = mul(d0, d1);
  Val* d5 = sub(d2, d3);

  // Tensor math mixing tensors and scalars:
  //   tv2 = tv1 - d4; tv3 = tv0 + d5; tv4 = tv3 * tv2
  TensorView* tv2 = sub(tv1, d4);
  TensorView* tv3 = add(tv0, d5);
  TensorView* tv4 = mul(tv3, tv2);

  fusion.addOutput(tv4);

  // Lets setup to actually run
  // Flatten the output to 1-D, then tile as [BIDx, Unroll(4), TIDx(128)].
  while (tv4->nDims() > 1)
    tv4->merge(0);
  tv4->split(0, 128);
  tv4->split(0, 4);

  tv0->computeAt(tv4, 1);
  tv1->computeAt(tv4, 1);

  tv4->axis(0)->parallelize(ParallelType::BIDx);

  // Parallelize all intermediate/output tensors the same way.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast<TensorView*>(val);

      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  // d4 = d0 * d1
  // d5 = d2 - d3
  // t2 = t1 - d4
  // t3 = t0 + d5
  // t4 = t3 * t2

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // Eager-mode reference of the scalar arithmetic above.
  float fl0 = 0.1;
  float fl1 = -0.2;
  float fl2 = 0.3;
  float fl3 = -0.4;
  float fl4 = fl0 * fl1;
  float fl5 = fl2 - fl3;

  // Odd sizes (129 x 127) deliberately do not divide the tile sizes.
  at::Tensor t0 = at::randn({129, 127}, options);
  at::Tensor t1 = at::rand_like(t0, options);

  auto t2 = t1.sub(fl4);
  auto t3 = t0.add(fl5);
  auto aten_output = t3.mul(t2);

  at::Tensor cg_output = at::empty_like(t0, options);

  at::Scalar test(fl0);

  // Scalar fusion inputs are passed as IValue scalars alongside tensors.
  std::vector<IValue> aten_inputs = {
      t0,
      t1,
      at::Scalar(fl0),
      at::Scalar(fl1),
      at::Scalar(fl2),
      at::Scalar(fl3)};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
3633
3634TEST_F(NVFuserTest, FusionLoopUnroll_CUDA) {
3635 Fusion fusion;
3636 FusionGuard fg(&fusion);
3637
3638 // Set up your input tensor views
3639 TensorView* tv0 = makeSymbolicTensor(3);
3640 TensorView* tv1 = makeSymbolicTensor(3);
3641
3642 // Register your inputs
3643 fusion.addInput(tv0);
3644 fusion.addInput(tv1);
3645
3646 // Do math with it, it returns a `Val*` but can be static_casted back to
3647 // TensorView
3648 TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
3649 TensorView* tv3 = add(tv0, tv2);
3650
3651 // Register your outputs
3652 fusion.addOutput(tv3);
3653
3654 int block_size = 16;
3655
3656 tv3->merge(0, 1);
3657 tv3->merge(0, 1);
3658
3659 tv3->split(0, block_size);
3660 tv3->split(0, 4);
3661
3662 // For all inputs, computeAt the output inline, temporaries should be squeezed
3663 // between them
3664 tv0->computeAt(tv3, 1);
3665 tv1->computeAt(tv3, 1);
3666
3667 // Parallelize
3668 tv2->axis(1)->parallelize(ParallelType::Unroll);
3669 tv3->axis(1)->parallelize(ParallelType::Unroll);
3670 tv2->axis(-1)->parallelize(ParallelType::TIDx);
3671 tv3->axis(-1)->parallelize(ParallelType::TIDx);
3672 tv3->axis(0)->parallelize(ParallelType::BIDx);
3673
3674 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
3675
3676 at::Tensor input0 = at::randn({129, 13, 3}, options);
3677 at::Tensor input1 = at::randn({129, 13, 3}, options);
3678
3679 FusionExecutor fe;
3680 fe.compileFusion(&fusion, {input0, input1});
3681 auto outputs = fe.runFusion({input0, input1});
3682
3683 TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
3684}
3685
3686/*
3687 * Helper function for single op testing that generates a codegen operand
3688 */
3689
3690Val* gen_jit_operand(std::pair<ValType, DataType> desc) {
3691 if (desc.first == ValType::TensorView) {
3692 return makeSymbolicTensor(2, desc.second);
3693 } else if (desc.first == ValType::Scalar) {
3694 if (desc.second == DataType::Float) {
3695 return IrBuilder::create<Double>();
3696 } else if (desc.second == DataType::Double) {
3697 return IrBuilder::create<Double>();
3698 } else if (desc.second == DataType::ComplexFloat) {
3699 return IrBuilder::create<ComplexDouble>();
3700 } else if (desc.second == DataType::ComplexDouble) {
3701 return IrBuilder::create<ComplexDouble>();
3702 } else if (desc.second == DataType::Int) {
3703 return IrBuilder::create<Int>();
3704 } else {
3705 TORCH_CHECK(false, "Not currently supported type: ", desc.first);
3706 }
3707 } else {
3708 TORCH_CHECK(false, "Not currently supported type: ", desc.first);
3709 }
3710 return nullptr;
3711}
3712
3713/*
3714 * Helper function for single op testing that generates an ATen operand
3715 */
3716
3717IValue gen_aten_operand(
3718 std::pair<ValType, DataType> desc,
3719 int blocks,
3720 int threads,
3721 bool rand) {
3722 if (desc.first == ValType::TensorView) {
3723 if (desc.second == DataType::Double || desc.second == DataType::Float ||
3724 desc.second == DataType::ComplexDouble ||
3725 desc.second == DataType::ComplexFloat ||
3726 desc.second == DataType::Half || desc.second == DataType::BFloat16) {
3727 auto options = at::TensorOptions()
3728 .dtype(data_type_to_aten(desc.second))
3729 .device(at::kCUDA, 0);
3730 if (rand) {
3731 return IValue(at::rand({blocks, threads}, options));
3732 } else {
3733 return IValue(at::empty({blocks, threads}, options));
3734 }
3735 } else if (desc.second == DataType::Int || desc.second == DataType::Int32) {
3736 auto dtype = desc.second == DataType::Int32 ? at::kInt : at::kLong;
3737 if (rand) {
3738 auto options =
3739 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
3740 return IValue(at::randn({blocks, threads}, options).mul(5).to(dtype));
3741 } else {
3742 auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
3743 return IValue(at::empty({blocks, threads}, options));
3744 }
3745 } else if (desc.second == DataType::Bool) {
3746 if (rand) {
3747 auto options =
3748 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
3749 return IValue(
3750 at::rand({blocks, threads}, options).round().to(at::kBool));
3751 } else {
3752 auto options =
3753 at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
3754 return IValue(at::empty({blocks, threads}, options));
3755 }
3756 } else {
3757 TORCH_CHECK(false, "Not currently supported type: ", desc.second)
3758 }
3759 } else if (desc.first == ValType::Scalar) {
3760 // IValue scalars can only be double int64 or bool
3761 if (desc.second == DataType::ComplexDouble ||
3762 desc.second == DataType::ComplexFloat) {
3763 return IValue(at::Scalar(c10::complex<double>(1.0, 0.0)));
3764 } else if (
3765 desc.second == DataType::Double || desc.second == DataType::Float ||
3766 desc.second == DataType::Half || desc.second == DataType::BFloat16) {
3767 return IValue(at::Scalar(1.0));
3768 } else if (desc.second == DataType::Int) {
3769 return IValue(at::Scalar(1));
3770 } else {
3771 TORCH_CHECK(false, "Not currently supported type: ", desc.first);
3772 }
3773 } else {
3774 TORCH_CHECK(false, "Not currently supported type: ", desc.first);
3775 }
3776 return nullptr;
3777}
3778
3779/*
3780 * Templatized Helper Function To generate single Op comparison between the
3781 * JIT codegen for Cuda and the ATen Library.
3782 */
3783
3784using OutputPair = std::pair<ValType, DataType>;
// Core single-op test driver. Builds a fusion from the JIT functor `jf`
// over operands described by the (ValType, DataType) pairs in `it`,
// schedules it with a simple BIDx/TIDx parallelization, runs it over
// (blocks x threads) random inputs, and validates against the ATen
// functor `af`. NumInputs... is an index pack (0..N-1) used to expand the
// input-descriptor tuple.
template <
    typename AtenFunc,
    typename JitFunc,
    typename InputTuple,
    size_t... NumInputs>
void test_op(
    int blocks,
    int threads,
    std::string op_str,
    AtenFunc af,
    JitFunc jf,
    OutputPair op,
    InputTuple it,
    std::index_sequence<NumInputs...>) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Generate Input JIT function Inputs and add them as Inputs to the Fusion
  // Graph
  std::array<Val*, sizeof...(NumInputs)> jit_inputs = {
      gen_jit_operand(std::get<NumInputs>(it))...};
  std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
    fusion.addInput(v);
  });
  TensorView* out =
      static_cast<TensorView*>(jf(std::get<NumInputs>(jit_inputs)...));
  fusion.addOutput(out);

  // Inline all tensor inputs into the output, then parallelize the output
  // over blocks (outer axis) and threads (inner axis).
  std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
    if (v->getValType() == ValType::TensorView)
      static_cast<TensorView*>(v)->computeAt(out, -1);
  });
  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(-1)->parallelize(ParallelType::TIDx);

  // Matching ATen inputs (randomized) and an uninitialized output buffer.
  std::array<IValue, sizeof...(NumInputs)> aten_inputs = {gen_aten_operand(
      std::get<NumInputs>(it), blocks, threads, /*rand*/ true)...};
  const at::ArrayRef<IValue> aten_inputs_ivalues(aten_inputs);

  at::Tensor cg_output =
      gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
  std::vector<at::Tensor> output_vect = {cg_output};
  cudaDeviceSynchronize();
  // Re-seed before each of the two runs so stochastic ops (e.g. rand-like)
  // produce identical sequences in the fusion and the ATen reference.
  if (fusion.isStochastic())
    at::manual_seed(0);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs_ivalues);
  fe.runFusion(aten_inputs_ivalues, output_vect);
  cudaDeviceSynchronize();

  if (fusion.isStochastic())
    at::manual_seed(0);
  at::Tensor aten_output = af(aten_inputs);
  cudaDeviceSynchronize(); // This sync shouldn't be necessary;

  std::string op_msg = "Operation " + op_str;

  testValidate(
      &fusion,
      {cg_output},
      aten_inputs,
      {aten_output},
      __LINE__,
      __FILE__,
      op_msg);
}
3852
3853/*
3854 * Templatized Helper Function that uses variadic templates to
3855 * process a variable length Input Tuple of different Operand Type.
3856 */
3857template <typename AtenFunc, typename JitFunc, typename InputTuple>
3858void test_op(
3859 int blocks,
3860 int threads,
3861 std::string op_str,
3862 AtenFunc af,
3863 JitFunc jf,
3864 OutputPair op,
3865 InputTuple it) {
3866 static constexpr auto size = std::tuple_size<InputTuple>::value;
3867 test_op(
3868 blocks,
3869 threads,
3870 op_str,
3871 af,
3872 jf,
3873 op,
3874 it,
3875 std::make_index_sequence<size>{});
3876}
3877
3878TEST_F(NVFuserTest, FusionUnaryOps_CUDA) {
3879 using OpTuple =
3880 std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
3881
3882 // [Note: explicit tuple type for uniform initialization list]
3883 // Tuple type must be explicitly specified for each uniform initialization
3884 // list within the vector to make this code compatible with some old env
3885 // which we still need to support. eg. gcc 5.4 + cuda 9.2.
3886 std::vector<OpTuple> ops{
3887 OpTuple{at::acos, UnaryOpType::Acos, "acos"},
3888 OpTuple{at::asin, UnaryOpType::Asin, "asin"},
3889 OpTuple{at::atan, UnaryOpType::Atan, "atan"},
3890 // There does not appear to be an appropriate ATen function for atanh
3891 // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" },
3892 OpTuple{at::cos, UnaryOpType::Cos, "cos"},
3893 OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
3894 OpTuple{at::exp, UnaryOpType::Exp, "exp"},
3895 // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
3896 OpTuple{at::log, UnaryOpType::Log, "log"},
3897 OpTuple{at::log10, UnaryOpType::Log10, "log10"},
3898 OpTuple{at::neg, UnaryOpType::Neg, "neg"},
3899 OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
3900 OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
3901 OpTuple{at::sin, UnaryOpType::Sin, "sin"},
3902 OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
3903 OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
3904 OpTuple{at::tan, UnaryOpType::Tan, "tan"},
3905 OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
3906 OpTuple{at::isfinite, UnaryOpType::IsFinite, "isfinite"},
3907 OpTuple{at::isinf, UnaryOpType::IsInf, "isinf"},
3908 OpTuple{at::isnan, UnaryOpType::IsNan, "isnan"},
3909 OpTuple{at::isreal, UnaryOpType::IsReal, "isreal"},
3910 };
3911
3912 // The following ops has no complex support in eager mode
3913 std::vector<OpTuple> ops_without_complex{
3914 OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
3915 OpTuple{at::floor, UnaryOpType::Floor, "floor"},
3916 OpTuple{at::frac, UnaryOpType::Frac, "frac"},
3917 OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"},
3918 OpTuple{at::round, UnaryOpType::Round, "round"},
3919 OpTuple{at::relu, UnaryOpType::Relu, "relu"},
3920 OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
3921 OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
3922 OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
3923 OpTuple{at::erf, UnaryOpType::Erf, "erf"},
3924 OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
3925 OpTuple{at::isneginf, UnaryOpType::IsNegInf, "isneginf"},
3926 OpTuple{at::isposinf, UnaryOpType::IsPosInf, "isposinf"},
3927 };
3928
3929 // The following ops only supports complex
3930 std::vector<OpTuple> ops_complex_only{
3931 // real is supported via UnaryOpType::Set for non-complex types, and
3932 // UnaryOpType::Real requires input to be complex
3933 OpTuple{at::real, UnaryOpType::Real, "real"},
3934 OpTuple{at::imag, UnaryOpType::Imag, "imag"},
3935 };
3936
3937 // Complex support for the following op is not working in nvFuser yet
3938 std::vector<OpTuple> ops_skip_complex{
3939 // TODO: abs is actually supported in nvFuser, but it has bug!!!
3940 // In eager mode, abs(complex_tensor) returns floating point tensor
3941 // but in nvFuser, it wrongly returns complex tensor!
3942 // We need to:
3943 // 1. change our type promotion logic to make a special case for abs
3944 // 2. why this bug is not detected here? we should bump up test coverage
3945 OpTuple{at::abs, UnaryOpType::Abs, "abs"},
3946 // TODO: the following two ops fails with compilation error like
3947 // "undefined function rsqrt(complex)", we could implement them in
3948 // helpers.cu, but I think it is better to check with Jiterator first,
3949 // because Jiterator uses the same string for complex support.
3950 OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
3951 OpTuple{at::log2, UnaryOpType::Log2, "log2"}};
3952
3953 std::vector<DataType> dtypes = {
3954 DataType::Float,
3955 DataType::Double,
3956 DataType::ComplexFloat,
3957 DataType::ComplexDouble};
3958
3959 for (auto dtype : dtypes) {
3960 auto ops_to_test = ops;
3961 if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) {
3962 ops_to_test.insert(
3963 ops_to_test.end(),
3964 ops_without_complex.begin(),
3965 ops_without_complex.end());
3966 ops_to_test.insert(
3967 ops_to_test.end(), ops_skip_complex.begin(), ops_skip_complex.end());
3968 } else {
3969 ops_to_test.insert(
3970 ops_to_test.end(), ops_complex_only.begin(), ops_complex_only.end());
3971 }
3972 std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) {
3973 test_op(
3974 /*blocks*/ 640,
3975 /*threads*/ 64,
3976 /*name*/ std::get<2>(op),
3977 /*Aten Func */
3978 [&op](std::array<IValue, 1>& vals) {
3979 return std::get<0>(op)(vals[0].toTensor());
3980 },
3981 /*JIT Func */
3982 [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
3983 /*Output */ std::make_pair(ValType::TensorView, dtype),
3984 /*Inputs Tuple*/
3985 std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
3986 });
3987 }
3988
3989 dtypes = {DataType::Int, DataType::Int32, DataType::Bool};
3990 for (auto dtype : dtypes) {
3991 test_op(
3992 /*blocks*/ 128,
3993 /*threads*/ 64,
3994 /*name*/ "bitwise_not",
3995 /*Aten Func */
3996 [](std::array<IValue, 1>& vals) {
3997 return at::bitwise_not(vals[0].toTensor());
3998 },
3999 /*JIT Func */
4000 [](Val* in1) -> Val* { return unaryOp(UnaryOpType::Not, in1); },
4001 /*Output */ std::make_pair(ValType::TensorView, dtype),
4002 /*Inputs Tuple*/
4003 std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
4004 }
4005}
4006
// Runs supported binary ops through test_op for each floating/complex
// dtype. Comparison ops produce Bool outputs; ordering comparisons and
// several math ops are excluded for complex dtypes (no eager support).
// add_alpha/sub_alpha additionally take a scalar alpha operand.
TEST_F(NVFuserTest, FusionBinaryOps_CUDA) {
  using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
  using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;

  std::vector<DataType> dtypes = {
      DataType::Double,
      DataType::Float,
      DataType::ComplexFloat,
      DataType::ComplexDouble};

  // see [Note: explicit tuple type for uniform initialization list]
  std::vector<OpTuple> equal_ops{
      OpTuple{at::eq, BinaryOpType::Eq, "eq"},
      OpTuple{at::ne, BinaryOpType::NE, "ne"}};

  // Complex numbers are not ordered
  std::vector<OpTuple> order_ops{
      OpTuple{at::ge, BinaryOpType::GE, "ge"},
      OpTuple{at::gt, BinaryOpType::GT, "gt"},
      OpTuple{at::le, BinaryOpType::LE, "le"},
      OpTuple{at::lt, BinaryOpType::LT, "lt"}};

  // see [Note: explicit tuple type for uniform initialization list]
  std::vector<OpTuple> math_ops{
      OpTuple{at::div, BinaryOpType::Div, "div"},
      OpTuple{at::mul, BinaryOpType::Mul, "mul"},
      OpTuple{at::pow, BinaryOpType::Pow, "pow"}};

  // The following ops has no complex support in eager mode
  std::vector<OpTuple> math_ops_without_complex{
      OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
      OpTuple{at::max, BinaryOpType::Max, "max"},
      OpTuple{at::min, BinaryOpType::Min, "min"},
      OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
      // NOTE: Remainder does not match the Aten impl exactly
      // despite using an identical function.
      OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"}};

  for (auto dtype : dtypes) {
    // Comparison ops: equality for all dtypes, ordering only for
    // non-complex dtypes. Output dtype is always Bool.
    auto logic_ops = equal_ops;
    if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) {
      logic_ops.insert(logic_ops.end(), order_ops.begin(), order_ops.end());
    }
    std::for_each(logic_ops.begin(), logic_ops.end(), [&](OpTuple& op) {
      test_op(
          /*blocks*/ 640,
          /*threads*/ 64,
          /*name*/ std::get<2>(op),
          /*Aten Func */
          [&op](std::array<IValue, 2>& vals) {
            return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
          },
          /*JIT Func */
          [&op](Val* in1, Val* in2) -> Val* {
            return binaryOp(std::get<1>(op), in1, in2);
          },
          /*Output */ std::make_pair(ValType::TensorView, DataType::Bool),
          /*Inputs Tuple*/
          std::make_tuple(
              std::make_pair(ValType::TensorView, dtype),
              std::make_pair(ValType::TensorView, dtype)));
    });

    // Arithmetic ops: same-dtype output; complex dtypes skip the ops
    // without eager complex support.
    auto enabled_math_ops = math_ops;
    if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) {
      enabled_math_ops.insert(
          enabled_math_ops.end(),
          math_ops_without_complex.begin(),
          math_ops_without_complex.end());
    }
    std::for_each(
        enabled_math_ops.begin(), enabled_math_ops.end(), [&](OpTuple& op) {
          test_op(
              /*blocks*/ 640,
              /*threads*/ 64,
              /*name*/ std::get<2>(op),
              /*Aten Func */
              [&op](std::array<IValue, 2>& vals) {
                return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
              },
              /*JIT Func */
              [&op](Val* in1, Val* in2) -> Val* {
                return binaryOp(std::get<1>(op), in1, in2);
              },
              /*Output */ std::make_pair(ValType::TensorView, dtype),
              /*Inputs Tuple*/
              std::make_tuple(
                  std::make_pair(ValType::TensorView, dtype),
                  std::make_pair(ValType::TensorView, dtype)));
        });

    // add_alpha / sub_alpha: tensor-tensor-scalar ops tested separately
    // because of the extra scalar operand.
    test_op(
        /*blocks*/ 640,
        /*threads*/ 64,
        /*name*/ "add_alpha",
        /*Aten Func */
        [](std::array<IValue, 3>& vals) {
          return at::add(
              vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
        },
        /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
        /*Output */ std::make_pair(ValType::TensorView, dtype),
        /*Inputs Tuple*/
        std::make_tuple(
            std::make_pair(ValType::TensorView, dtype),
            std::make_pair(ValType::TensorView, dtype),
            std::make_pair(ValType::Scalar, dtype)));

    test_op(
        /*blocks*/ 640,
        /*threads*/ 64,
        /*name*/ "sub_alpha",
        /*Aten Func */
        [](std::array<IValue, 3>& vals) {
          return at::sub(
              vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
        },
        /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
        /*Output */ std::make_pair(ValType::TensorView, dtype),
        /*Inputs Tuple*/
        std::make_tuple(
            std::make_pair(ValType::TensorView, dtype),
            std::make_pair(ValType::TensorView, dtype),
            std::make_pair(ValType::Scalar, dtype)));
  }
}
4133
4134TEST_F(NVFuserTest, FusionTernaryOps_CUDA) {
4135 std::vector<DataType> dtypes = {
4136 DataType::Double,
4137 DataType::Float,
4138 DataType::ComplexFloat,
4139 DataType::ComplexDouble};
4140
4141 for (auto dtype : dtypes) {
4142 // clamp and threshold are not supported for complex on eager mode
4143 if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) {
4144 test_op(
4145 /*blocks*/ 640,
4146 /*threads*/ 64,
4147 /*name*/ "clamp",
4148 /*Aten Func */
4149 [](std::array<IValue, 1>& vals) {
4150 return at::clamp(vals[0].toTensor(), 0.f, 1.f);
4151 },
4152 /*JIT Func */
4153 [&](Val* in1) -> Val* {
4154 if (dtype == DataType::Float) {
4155 return clamp(
4156 in1,
4157 IrBuilder::create<Double>(0.f),
4158 IrBuilder::create<Double>(1.f));
4159 } else {
4160 return clamp(
4161 in1,
4162 IrBuilder::create<Double>(0.f),
4163 IrBuilder::create<Double>(1.f));
4164 }
4165 },
4166 /*Output */ std::make_pair(ValType::TensorView, dtype),
4167 /*Inputs Tuple*/
4168 std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
4169 test_op(
4170 /*blocks*/ 640,
4171 /*threads*/ 64,
4172 /*name*/ "threshold",
4173 /*Aten Func */
4174 [](std::array<IValue, 1>& vals) {
4175 return at::threshold(vals[0].toTensor(), 0.f, 1.f);
4176 },
4177 /*JIT Func */
4178 [&](Val* in1) -> Val* {
4179 if (dtype == DataType::Float) {
4180 return threshold(
4181 in1,
4182 IrBuilder::create<Double>(0.f),
4183 IrBuilder::create<Double>(1.f));
4184 } else {
4185 return threshold(
4186 in1,
4187 IrBuilder::create<Double>(0.f),
4188 IrBuilder::create<Double>(1.f));
4189 }
4190 },
4191 /*Output */ std::make_pair(ValType::TensorView, dtype),
4192 /*Inputs Tuple*/
4193 std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
4194 }
4195 test_op(
4196 /*blocks*/ 640,
4197 /*threads*/ 64,
4198 /*name*/ "where",
4199 /*Aten Func */
4200 [](std::array<IValue, 3>& vals) {
4201 return at::where(
4202 vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
4203 },
4204 /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
4205 /*Output */ std::make_pair(ValType::TensorView, dtype),
4206 /*Inputs Tuple*/
4207 std::make_tuple(
4208 std::make_pair(ValType::TensorView, DataType::Bool),
4209 std::make_pair(ValType::TensorView, dtype),
4210 std::make_pair(ValType::TensorView, dtype)));
4211 }
4212}
4213
4214TEST_F(NVFuserTest, FusionCompoundOps_CUDA) {
4215 std::vector<DataType> dtypes = {
4216 DataType::Double,
4217 DataType::Float,
4218 DataType::ComplexFloat,
4219 DataType::ComplexDouble};
4220
4221 for (auto dtype : dtypes) {
4222 test_op(
4223 /*blocks*/ 640,
4224 /*threads*/ 64,
4225 /*name*/ "lerp",
4226 /*Aten Func */
4227 [](std::array<IValue, 3>& vals) {
4228 return at::lerp(
4229 vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
4230 },
4231 /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
4232 /*Output */ std::make_pair(ValType::TensorView, dtype),
4233 /*Inputs Tuple*/
4234 std::make_tuple(
4235 std::make_pair(ValType::TensorView, dtype),
4236 std::make_pair(ValType::TensorView, dtype),
4237 std::make_pair(ValType::TensorView, dtype)));
4238 test_op(
4239 /*blocks*/ 640,
4240 /*threads*/ 64,
4241 /*name*/ "addcmul",
4242 /*Aten Func */
4243 [](std::array<IValue, 4>& vals) {
4244 return at::addcmul(
4245 vals[0].toTensor(),
4246 vals[1].toTensor(),
4247 vals[2].toTensor(),
4248 vals[3].toScalar());
4249 },
4250 /*JIT Func */
4251 static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
4252 /*Output */ std::make_pair(ValType::TensorView, dtype),
4253 /*Inputs Tuple*/
4254 std::make_tuple(
4255 std::make_pair(ValType::TensorView, dtype),
4256 std::make_pair(ValType::TensorView, dtype),
4257 std::make_pair(ValType::TensorView, dtype),
4258 std::make_pair(ValType::Scalar, dtype)));
4259 }
4260}
4261
4262TEST_F(NVFuserTest, FusionCastOps_CUDA) {
4263 Fusion fusion;
4264 FusionGuard fg(&fusion);
4265
4266 TensorView* tv0 = makeSymbolicTensor(2, DataType::Half);
4267
4268 TensorView* intrm1 = castOp(DataType::Float, tv0);
4269 TensorView* out = castOp(DataType::Half, intrm1);
4270
4271 fusion.addInput(tv0);
4272 fusion.addOutput(out);
4273 tv0->computeAt(out, -1);
4274
4275 out->axis(0)->parallelize(ParallelType::BIDx);
4276 out->axis(-1)->parallelize(ParallelType::TIDx);
4277
4278 auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
4279
4280 at::Tensor input1 = at::randn({1, 4}, options);
4281 at::Tensor ref_output = at::empty_like(input1);
4282
4283 std::array<IValue, 1> inputs = {input1};
4284 const at::ArrayRef<IValue> input_ivalues(inputs);
4285
4286 FusionExecutor fe;
4287 fe.compileFusion(&fusion, input_ivalues);
4288 auto outputs = fe.runFusion(input_ivalues);
4289
4290 ref_output = at::_cast_Half(at::_cast_Double(input1));
4291
4292 TORCH_CHECK(
4293 outputs[0].equal(ref_output),
4294 "\nOp Type: -- ",
4295 "cast FP16->FP32->FP16",
4296 " -- had a mismatch.\n",
4297 "\nABS MAX DIFF: ",
4298 outputs[0].sub(ref_output).abs().max(),
4299 "\n");
4300}
4301
4302// Start off simple, block on the outer dim
4303// block stride + thread all reduce + unrolling on inner dim
TEST_F(NVFuserTest, FusionReduction1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  // Split the reduction axis into three stages: serial outer, unrolled
  // middle (factor 4), and thread-parallel inner (factor 128).
  tv1->split(1, 128);
  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
  tv1->split(1, 4);
  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]

  // Peel the serial outer reduction stage into its own tensor.
  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
  // tv1[I0, R1oi{4}, R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]

  // Peel the unrolled middle stage as well, leaving only the inner
  // reduction in tv1.
  TensorView* tv3 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
  // tv3[I0, R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
  // tv1[I0, R1i{128}] = tv3[I0, R1oi{4}, Ir1i{128}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv3, 1);
  tv3->computeAt(tv1, 1);

  // Re do it all at once, because why not.
  tv0->computeAt(tv1, 1);

  tv2->axis(2)->parallelize(ParallelType::Unroll);
  tv1->axis(0)->parallelize(ParallelType::BIDx);

  // Bind the innermost axis of all three reduction tensors to TIDx.
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // 1025 is not a multiple of 128, so the split tail path is exercised.
  int numel_x = 65000;
  int numel_y = 1025;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Reference accumulates in double for a tighter comparison.
  auto aten_output = input.to(at::kDouble).sum({1});

  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
4366
TEST_F(NVFuserTest, FusionReduction2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);

  fusion.addOutput(tv1);

  // switches to try some different scenarios. maybe we should iterate on all
  // permutations.
  bool bind_bidx = true;
  bool bind_tidx = true;
  bool bind_tidy = true;
  bool bind_unroll = true;

  int numel_x = 1025; // Cannot exceed block dim max size / tidy
  int numel_y = 129;
  int tidx = 16;
  int tidy = 8;
  int unroll_factor = 4;

  tv1->split(1, tidx);
  // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1]

  tv1->split(1, unroll_factor);
  // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1]

  // Also split the iteration axis so rows can be spread over TIDy.
  tv1->split(0, tidy);

  // Peel the serial outer reduction stage into tv2.
  TensorView* tv2 = tv1->rFactor({-3});
  // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
  // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}]

  // Peel the unrolled stage into tv3; tv1 keeps only the TIDx reduction.
  TensorView* tv3 = tv1->rFactor({-2});
  // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
  // tv3[I0, R1oi{unroll}, Ir1i{tidx}]
  // tv1[I0o, I0i{tidy}, R1i{tidx}]

  tv0->computeAt(tv1, -2);

  // Apply the parallel bindings selected by the switches above.
  if (bind_unroll)
    tv2->axis(-2)->parallelize(ParallelType::Unroll);
  if (bind_bidx)
    tv1->axis(0)->parallelize(ParallelType::BIDx);
  if (bind_tidy)
    tv1->axis(1)->parallelize(ParallelType::TIDy);

  if (bind_tidx) {
    tv2->axis(-1)->parallelize(ParallelType::TIDx);
    tv3->axis(-1)->parallelize(ParallelType::TIDx);
    tv1->axis(-1)->parallelize(ParallelType::TIDx);
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  // Reference accumulates in double.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
}
4436
TEST_F(NVFuserTest, FusionReduction3_CUDA) {
  // What if Z participates in the reduction with X?
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);

  fusion.addOutput(tv1);

  int numel_x = 1025; // Cannot exceed block dim max size / tidy
  int numel_y = 129;
  int tidx = 16;
  int tidz = 8;

  tv1->split(1, tidz);
  // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1]

  tv1->split(1, tidx);
  // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1]

  // Peel the serial outer stage; the remainder reduces across TIDx and
  // TIDz together.
  TensorView* tv2 = tv1->rFactor({-3});
  // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}]
  // tv1[I0o, R1oi{tidx}, R1i{tidz}]

  tv0->computeAt(tv1, -3);

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(-2)->parallelize(ParallelType::TIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDz);

  // Producer bindings match the consumer's thread bindings.
  tv2->axis(-2)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDz);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  fe.runFusion({aten_input}, {cg_output});

  // Reference accumulates in double.
  auto aten_output = aten_input.to(at::kDouble).sum({1});

  testValidate(
      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
}
4489
TEST_F(NVFuserTest, FusionReduction4_CUDA) {
  // Reduction with a pointwise producer (add) and a pointwise consumer
  // (mul) fused around it.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  TensorView* tv2 = add(tv0, tv1);
  // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv3 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv2);
  // tv3[I0, R1] = tv2[I0, I1]

  TensorView* tv4 = makeSymbolicTensor(1);
  fusion.addInput(tv4);

  // tv5[I0] = tv3[I0, R1] * tv4[I0]
  TensorView* tv5 = mul(tv3, tv4);
  fusion.addOutput(tv5);

  int tidx = 16;

  // RFactor the reduction
  tv3->split(1, tidx);
  // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1]

  TensorView* tv6 = tv3->rFactor({-2});
  // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1]
  // tv3[I0, R1i{tidx}] = tv6[I0, R1o, iR1i{tidx}]
  tv2->computeAt(tv6, 2);

  // Compute at inline with tv5 (only 1D)
  tv6->computeAt(tv3, 1);
  tv3->computeAt(tv5, 1);

  tv5->axis(0)->parallelize(ParallelType::BIDx);

  // Intermediate tensors only need this, but doesn't hurt to do on inputs
  // tv0, 1, 4
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv6->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 1025;
  int numel_y = 129;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
  at::Tensor t4 = at::randn({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1, t4});
  auto cg_outputs = fe.runFusion({t0, t1, t4});

  // Reference: (t0 + t1).sum(1) * t4, accumulating in double.
  auto t2 = t0.add(t1);
  auto t3 = t2.to(at::kDouble).sum({1});
  auto aten_output = t3.mul(t4);

  testValidate(
      &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
}
4557
TEST_F(NVFuserTest, FusionReduction5_CUDA) {
  // 3D input reduced over the middle dimension.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(3);

  fusion.addInput(tv0);

  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);

  fusion.addOutput(tv1);

  int bidy = 2;
  int tidy = 4;
  int tidx = 5;

  int dim1 = 11;

  tv1->split(-2, tidy);

  TensorView* tv2 = tv1->rFactor({-3});

  tv0->computeAt(tv1, 1);
  tv1->axis(0)->parallelize(ParallelType::BIDy);

  // Bind the innermost axis of every non-input TensorView to TIDx.
  for (auto* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  tv2->axis(-2)->parallelize(ParallelType::TIDy);
  tv1->axis(-2)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({bidy, dim1, tidx}, options);

  at::Tensor cg_output = at::empty({bidy, tidx}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Reference accumulates in double.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
4608
TEST_F(NVFuserTest, FusionReduction6_CUDA) {
  // Multi-axis reduction: both the last two dimensions of a 3D tensor are
  // reduced, with two rfactor stages.
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int bdimx = 64;
  const int bdimy = 8;

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(3);
  fusion.addInput(tv0);

  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(2, bdimx);
  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
  tv1->split(1, bdimy);
  // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]

  // First rfactor: peel the serial R2o stage.
  TensorView* tv2 = tv1->rFactor({3});
  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
  // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]

  // Second rfactor: peel the serial R1o stage.
  TensorView* tv3 = tv1->rFactor({1});
  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
  // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
  // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}]

  tv3->computeAt(tv1, 1);
  tv2->computeAt(tv3, 2);

  // One block per outer (non-reduced) element.
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(0)->parallelize(ParallelType::BIDx);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // tv2 has an extra axis (R2o), so its TIDy axis sits one position
  // further out than in tv1/tv3.
  tv1->axis(-2)->parallelize(ParallelType::TIDy);
  tv3->axis(-2)->parallelize(ParallelType::TIDy);
  tv2->axis(-3)->parallelize(ParallelType::TIDy);

  int numel_x = 650;
  int numel_y = 1000;
  int numel_z = 4;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  // Reference accumulates in double.
  auto aten_output = input.to(at::kDouble).sum({1, 2});
  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
}
4672
4673TEST_F(NVFuserTest, FusionMultiGridReduction_CUDA) {
4674 Fusion fusion;
4675 FusionGuard fg(&fusion);
4676
4677 TensorView* tv0 = makeSymbolicTensor(2);
4678 fusion.addInput(tv0);
4679 TensorView* tv1 = max(tv0, {0});
4680 TensorView* tv2 = sum(tv0, {0});
4681
4682 fusion.addOutput(tv1);
4683 fusion.addOutput(tv2);
4684
4685 int numel_x = 4;
4686 int numel_y = 2;
4687
4688 tv1->axis(0)->parallelize(ParallelType::BIDx);
4689 tv1->axis(1)->parallelize(ParallelType::TIDx);
4690
4691 tv2->axis(0)->parallelize(ParallelType::BIDx);
4692 tv2->axis(1)->parallelize(ParallelType::TIDx);
4693
4694 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
4695 at::Tensor input = at::randn({numel_x, numel_y}, options);
4696
4697 FusionExecutor fe;
4698 fe.compileFusion(&fusion, {input});
4699 auto cg_outputs = fe.runFusion({input});
4700
4701 std::vector<at::Tensor> aten_outputs = {
4702 std::get<0>(input.to(at::kDouble).max(0)), input.to(at::kDouble).sum(0)};
4703 testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
4704}
4705
4706TEST_F(NVFuserTest, FusionMultiGridReduction2_CUDA) {
4707 Fusion fusion;
4708 FusionGuard fg(&fusion);
4709
4710 auto tv0 = makeSymbolicTensor(2);
4711 fusion.addInput(tv0);
4712 auto tv1 = sum(tv0, {0});
4713 auto tv2 = sum(tv1, {0});
4714 fusion.addOutput(tv2);
4715
4716 tv1->axis(0)->parallelize(ParallelType::BIDx);
4717 tv1->axis(1)->parallelize(ParallelType::BIDy);
4718 tv2->axis(0)->parallelize(ParallelType::BIDy);
4719
4720 FusionExecutor fe;
4721 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
4722 ASSERT_ANY_THROW(fe.compileFusion(&fusion));
4723}
4724
TEST_F(NVFuserTest, FusionReductionTFT_CUDA) {
  // Reduction parallelized across TIDx, TIDy, and TIDz simultaneously.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);

  fusion.addOutput(tv1);

  int numel_x = 1025;
  int numel_y = 129;
  int tidx = 16;
  int tidy = 8;
  int tidz = 8;

  tv1->split(1, tidx);
  // tv1[I0, R1o, R1i{tidx}]

  tv1->split(1, tidz);
  // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]

  // Rows are split over TIDy as well.
  tv1->split(0, tidy);
  // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]

  // Peel the serial outer reduction stage into tv2.
  TensorView* tv2 = tv1->rFactor({2});
  // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}]
  // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}]

  tv2->computeAt(tv1, 2);

  tv1->axis(1)->parallelize(ParallelType::TIDy);

  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  tv1->axis(-2)->parallelize(ParallelType::TIDz);
  tv2->axis(-2)->parallelize(ParallelType::TIDz);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Reference accumulates in double.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
4780
TEST_F(NVFuserTest, FusionReductionOuterSplit_CUDA) {
  // based off FusionReduction4
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  TensorView* tv2 = add(tv0, tv1);
  // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv3 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv2);
  // tv3[I0, R1] = tv2[I0, I1]

  TensorView* tv4 = makeSymbolicTensor(1);
  fusion.addInput(tv4);

  // tv5[I0] = tv3[I0, R1] * tv4[I0]
  TensorView* tv5 = mul(tv3, tv4);
  fusion.addOutput(tv5);

  // RFactor the reduction. The trailing `false` requests an outer split:
  // the factor 16 lands on the outer domain rather than the inner one.
  tv3->split(1, 16, false);
  // tv3[I0, R1o{16}, R1i{tidx}] = tv2[I0, I1]

  TensorView* tv6 = tv3->rFactor({-2});
  // tv6[I0, R1o{16}, iR1i{tidx}] = tv2[I0, I1]
  // tv3[I0, R1i{tidx}] = tv6[I0, R1o{16}, iR1i{tidx}]
  tv2->computeAt(tv6, 2);

  // Compute at inline with tv5 (only 1D)
  tv6->computeAt(tv3, 1);
  tv3->computeAt(tv5, 1);

  tv5->axis(0)->parallelize(ParallelType::BIDx);

  // Intermediate tensors only need this, but doesn't hurt to do on inputs
  // tv0, 1, 4
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv6->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 1025;
  int numel_y = 129;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
  at::Tensor t4 = at::randn({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1, t4});
  auto cg_outputs = fe.runFusion({t0, t1, t4});

  // Reference: (t0 + t1).sum(1) * t4, accumulating in double.
  auto t2 = t0.add(t1);
  auto t3 = t2.to(at::kDouble).sum({1});
  auto aten_output = t3.mul(t4);

  testValidate(
      &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
}
4847
TEST_F(NVFuserTest, FusionBranches_CUDA) {
  // A producer (tv3) consumed by two branches (tv4, tv5) that re-join at
  // the output (tv6).
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);
  TensorView* tv2 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

  auto tv3 = add(tv0, IrBuilder::create<Double>(1.0));
  auto tv4 = add(tv3, tv1);
  auto tv5 = add(tv3, tv2);
  auto tv6 = add(tv4, tv5);

  fusion.addOutput(tv6);

  constexpr int x = 63, y = 33;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({x, y}, options);
  at::Tensor t1 = at::randn({x, y}, options);
  at::Tensor t2 = at::randn({x, y}, options);

  FusionExecutor fe;
  // Flatten, then tile the output: [BIDx, Unroll{4}, TIDx{128}].
  tv6->merge(0);
  tv6->split(0, 128);
  tv6->split(0, 4);

  tv6->axis(0)->parallelize(ParallelType::BIDx);

  // Inline all three inputs into the output's loop nest.
  tv0->computeAt(tv6, 1);
  tv1->computeAt(tv6, 1);
  tv2->computeAt(tv6, 1);

  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-2)->parallelize(ParallelType::Unroll);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::Unroll);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  tv6->axis(-1)->parallelize(ParallelType::TIDx);

  std::vector<IValue> aten_inputs = {t0, t1, t2};

  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  // Reference mirrors the fusion graph.
  auto t3 = t0.add(1.0);
  auto t4 = t3.add(t1);
  auto t5 = t3.add(t2);
  auto aten_output = t4.add(t5);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
4907
TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) {
  // Two 2D tensors broadcast along different (inner vs. outer) axes and
  // added into a 3D output.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.5));

  TensorView* tv2 = makeSymbolicTensor(2);
  fusion.addInput(tv2);
  TensorView* tv3 = makeSymbolicTensor(2);
  fusion.addInput(tv3);
  TensorView* tv4 = sub(tv2, tv3);

  // tv5[x, y, b] ; tv6[b, y, z]
  TensorView* tv5 = broadcast(tv1, {false, false, true});
  TensorView* tv6 = broadcast(tv4, {true, false, false});

  TensorView* tv7 = add(tv5, tv6);
  fusion.addOutput(tv7);

  tv7->split(-1, 4);
  tv7->split(0, 8);

  // Fully inline both producer chains into the output.
  tv0->computeAt(tv7, -1);
  tv2->computeAt(tv7, -1);

  tv7->axis(0)->parallelize(ParallelType::BIDx);
  tv7->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int x = 63, y = 33, z = 15;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({x, y}, options);
  at::Tensor t1 = t0.add(1.5);

  at::Tensor t2 = at::randn({y, z}, options);
  at::Tensor t3 = at::randn({y, z}, options);

  // Reference mirrors the fusion's broadcast pattern.
  at::Tensor t4 = t2.sub(t3);
  at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z});

  at::Tensor t6 = t4.expand({x, y, z});

  at::Tensor aten_output = t5.add(t6);

  std::vector<IValue> aten_inputs = {t0, t2, t3};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
4964
TEST_F(NVFuserTest, FusionSimpleBCast2_CUDA) {
  // Like FusionSimpleBCast1 but with a merged output domain and a
  // pre-allocated output buffer.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv0, tv1);

  // tv3[x, y, b] = bcast(tv2)
  TensorView* tv3 = broadcast(tv2, {false, false, true});

  TensorView* tv4 = makeSymbolicTensor(2);
  fusion.addInput(tv4);

  TensorView* tv5 = sub(tv4, IrBuilder::create<Double>(0.1));

  // tv6[b, y, z] = bcast(tv5)
  TensorView* tv6 = broadcast(tv5, {true, false, false});

  TensorView* tv7 = add(tv3, tv6);

  fusion.addOutput(tv7);

  // Collapse the first two output axes.
  tv7->merge(0, 1);

  tv0->computeAt(tv7, -1);
  tv4->computeAt(tv7, -1);

  tv7->axis(0)->parallelize(ParallelType::BIDx);
  tv7->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int x = 63, y = 33, z = 15;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // Reference mirrors the fusion's broadcast pattern.
  at::Tensor t0 = at::randn({x, y}, options);
  at::Tensor t1 = at::randn({x, y}, options);
  at::Tensor t2 = t0.add(t1);
  at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z});

  at::Tensor t4 = at::randn({y, z}, options);
  at::Tensor t5 = t4.sub(0.1);
  at::Tensor t6 = t5.expand({x, y, z});
  at::Tensor aten_output = t3.add(t6);

  at::Tensor cg_output = at::empty({x, y, z}, options);

  std::vector<IValue> aten_inputs = {t0, t1, t4};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5023
TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) {
  // Broadcast via a concrete size-1 dimension rather than an explicit
  // broadcast() op.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up input tensor views
  // tv0[I1, B{1}]
  TensorView* tv0 = makeConcreteTensor({-1, 1});
  fusion.addInput(tv0);

  // tv2[I0, I1, I2]
  TensorView* tv2 = makeSymbolicTensor(3);
  fusion.addInput(tv2);

  TensorView* tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  // Flatten the 3D output into one axis.
  tv3->merge(0);
  tv3->merge(0);

  tv0->computeAt(tv3, -1);
  tv2->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);

  constexpr int x = 2, y = 3, z = 4;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // t0's trailing size-1 dim broadcasts against t2 under ATen semantics.
  at::Tensor t0 = at::randn({y, 1}, options);
  at::Tensor t2 = at::randn({x, y, z}, options);
  auto aten_output = t0.add(t2);

  std::vector<IValue> aten_inputs = {t0, t2};
  at::Tensor cg_output = at::empty({x, y, z}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5067
TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) {
  // Broadcast via a concrete leading size-1 dimension, with a tiled and
  // unrolled output schedule.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  // tv0[B{1}, I]
  TensorView* tv0 = makeConcreteTensor({1, -1});

  TensorView* tv1 = makeSymbolicTensor(3);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv3 = add(tv0, tv1);

  // Flatten, then tile: [BIDx, Unroll{4}, TIDx{128}].
  tv3->merge(0);
  tv3->merge(0);
  tv3->split(0, 128);
  tv3->split(0, 4);

  fusion.addOutput(tv3);

  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-2)->parallelize(ParallelType::Unroll);

  constexpr int x = 63, y = 33, z = 15;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // t0's leading size-1 dim broadcasts against t1 under ATen semantics.
  at::Tensor t0 = at::randn({1, z}, options);
  at::Tensor t1 = at::randn({x, y, z}, options);

  auto aten_output = t0.add(t1);

  at::Tensor cg_output = at::empty({x, y, z}, options);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5115
// Outer-product-style broadcast: {m,k} and {k,n} inputs are broadcast to
// {m,k,n} along complementary axes and added. The output is flattened and
// both inputs are fully inlined; validated against ATen expand + add.
TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int m = 2, k = 3, n = 4;
  auto tv0 = makeConcreteTensor({m, k});
  auto tv1 = makeConcreteTensor({k, n});

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // tv2[m, k, b] and tv3[b, k, n] broadcast on opposite ends.
  TensorView* tv2 = broadcast(tv0, {false, false, true});
  TensorView* tv3 = broadcast(tv1, {true, false, false});

  TensorView* tv4 = add(tv2, tv3);

  fusion.addOutput(tv4);

  // Flatten the 3D output to a single axis before inlining.
  tv4->merge(0);
  tv4->merge(0);

  tv0->computeAt(tv4, -1);
  tv1->computeAt(tv4, -1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({m, k}, options);
  at::Tensor t1 = at::randn({k, n}, options);

  auto t2 = t0.unsqueeze(-1).expand({m, k, n});
  auto t3 = t1.expand({m, k, n});
  auto aten_output = t2.add(t3);

  at::Tensor cg_output = at::empty({m, k, n}, options);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  fe.runFusion(aten_inputs, {cg_output});

  testValidate(
      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5160
// Chained broadcasts: a 1D input is broadcast to 2D, multiplied, then the
// product is broadcast to 3D and added. The final output is flattened and
// everything is inlined from the first producer; validated against ATen.
TEST_F(NVFuserTest, FusionComplexBCast1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int x = 2, y = 3, z = 4;

  auto tv0 = makeConcreteTensor({y});
  auto tv1 = div(tv0, IrBuilder::create<Double>(2.0));
  auto tv2 = broadcast(tv1, {false, true});
  auto tv3 = makeConcreteTensor({y, z});
  auto tv4 = mul(tv2, tv3);
  auto tv5 = broadcast(tv4, {true, false, false});
  auto tv6 = makeConcreteTensor({x, y, z});
  auto tv7 = add(tv5, tv6);

  // tv0[ i1 ] = input
  // tv1[ i1 ] = tv0/2.0
  // tv2[ i1, b2] = bcast(tv1)
  // tv3[ i1, i2] = input
  // tv4[ i1, i2] = tv2 * tv3
  // tv5[b0, i1, i2] = bcast(tv4)
  // tv6[i0, i1, i2] = input
  // tv7[i0, i1, i2] = tv5 + tv6

  // tv4 = bcast(tv1) * tv3
  // tv7 = bcast(tv4) + tv6

  fusion.addInput(tv0);
  fusion.addInput(tv3);
  fusion.addInput(tv6);

  fusion.addOutput(tv7);

  // Flatten the output, then inline the whole chain starting at tv0.
  tv7->merge(0);
  tv7->merge(0);
  tv0->computeAt(tv7, -1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({y}, options);
  at::Tensor t3 = at::randn({y, z}, options);
  at::Tensor t6 = at::randn({x, y, z}, options);

  auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3;
  auto aten_output = t4.unsqueeze(0).expand({x, y, z}) + t6;

  std::vector<IValue> aten_inputs = {t0, t3, t6};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5216
5217TEST_F(NVFuserTest, FusionComplexBCast2_CUDA) {
5218 Fusion fusion;
5219 FusionGuard fg(&fusion);
5220
5221 int x = 2, y = 3, z = 4;
5222
5223 auto tv0 = makeConcreteTensor({y, z});
5224 auto tv1 = div(tv0, IrBuilder::create<Double>(2.0));
5225 auto tv2 = sum(tv1, {1});
5226 auto tv3 = broadcast(tv2, {true, false});
5227 auto tv4 = makeConcreteTensor({x, y});
5228 auto tv5 = add(tv3, tv4);
5229
5230 // tv0[ i1, i2] = input
5231 // tv1[ i1, i2] = tv0/2.0
5232 // tv2[ i1 ] = sum(tv1, 1)
5233 // tv3[b0, i1 ] = bcast(tv2)
5234 // tv4[i0, i1 ] = input
5235 // tv5[i0, i1 ] = tv3 + tv4
5236
5237 // tv2 = sum(tv0/2.0, 1)
5238 // tv5 = bcast(tv2) + tv4
5239
5240 fusion.addInput(tv0);
5241 fusion.addInput(tv4);
5242
5243 fusion.addOutput(tv5);
5244
5245 tv5->merge(0);
5246 tv0->computeAt(tv5, -1);
5247 tv1->computeAt(tv2, -1);
5248
5249 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
5250
5251 at::Tensor t0 = at::randn({y, z}, options);
5252 at::Tensor t4 = at::randn({x, y}, options);
5253
5254 FusionExecutor fe;
5255 fe.compileFusion(&fusion, {t0, t4});
5256 auto cg_outputs = fe.runFusion({t0, t4});
5257
5258 auto t1 = t0.div(2.0);
5259 auto t2 = t1.to(at::kDouble).sum(1);
5260 auto t3 = t2.unsqueeze(0).expand({x, y});
5261 auto aten_output = t3.add(t4);
5262
5263 testValidate(
5264 &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__);
5265}
5266
// Indexing through a leading broadcast: a 3D input is broadcast to 4D and
// added to a 4D input. All four output axes are merged from the front,
// then tiled as [BIDx, Unroll(4), TIDx(128)]; intermediates are computed
// at the first split and parallelized to match.
TEST_F(NVFuserTest, FusionAdvancedIndexing1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int w = 3, x = 4, y = 7, z = 8;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  auto tv0 = makeSymbolicTensor(3);
  auto tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = add(tv0, IrBuilder::create<Double>(1.0));
  auto tv3 = broadcast(tv2, {true, false, false, false});
  auto tv4 = add(tv3, tv1);

  fusion.addOutput(tv4);

  // Merge from the front so the broadcast axis is folded in first.
  tv4->merge(0);
  tv4->merge(0);
  tv4->merge(0);

  tv4->split(0, 128);
  tv4->split(0, 4);

  tv2->computeAt(tv4, 1);

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::Unroll);
  tv4->axis(2)->parallelize(ParallelType::TIDx);

  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(2)->parallelize(ParallelType::TIDx);

  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(2)->parallelize(ParallelType::TIDx);

  FusionExecutor fe;

  at::Tensor t0 = at::randn({x, y, z}, options);
  at::Tensor t1 = at::randn({w, x, y, z}, options);

  auto t3 = t0.add(1.0);
  auto aten_output = t3.add(t1);

  std::vector<IValue> aten_inputs = {t0, t1};

  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5320
// Same fusion as FusionAdvancedIndexing1, but the output axes are merged
// from the back (merge(-2) three times) instead of from the front, so the
// broadcast axis is folded in last. Tiling and parallelization are the
// same.
TEST_F(NVFuserTest, FusionAdvancedIndexing2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int w = 3, x = 4, y = 7, z = 8;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  auto tv0 = makeSymbolicTensor(3);
  auto tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = add(tv0, IrBuilder::create<Double>(1.0));
  auto tv3 = broadcast(tv2, {true, false, false, false});
  auto tv4 = add(tv3, tv1);

  fusion.addOutput(tv4);

  // Merge from the back, unlike test 1 which merges from the front.
  tv4->merge(-2);
  tv4->merge(-2);
  tv4->merge(-2);

  tv4->split(0, 128);
  tv4->split(0, 4);

  tv2->computeAt(tv4, 1);

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::Unroll);
  tv4->axis(2)->parallelize(ParallelType::TIDx);

  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(2)->parallelize(ParallelType::TIDx);

  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(2)->parallelize(ParallelType::TIDx);

  FusionExecutor fe;

  at::Tensor t0 = at::randn({x, y, z}, options);
  at::Tensor t1 = at::randn({w, x, y, z}, options);

  auto t3 = t0.add(1.0);
  auto aten_output = t3.add(t1);

  std::vector<IValue> aten_inputs = {t0, t1};

  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5374
// Implicit-broadcast add of a 3D and a 4D input, scheduled automatically
// by the pointwise scheduler rather than with manual transformations.
TEST_F(NVFuserTest, FusionAdvancedIndexing3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int w = 3, x = 4, y = 7, z = 8;

  auto tv0 = makeSymbolicTensor(3);
  auto tv1 = makeSymbolicTensor(4);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = add(tv0, IrBuilder::create<Double>(1.0));
  // Rank mismatch here relies on the op's implicit broadcasting.
  auto tv3 = add(tv2, tv1);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({x, y, z}, options);
  at::Tensor t1 = at::randn({w, x, y, z}, options);

  auto t2 = t0.add(1.0);
  auto aten_output = t2.add(t1);

  std::vector<IValue> aten_inputs = {t0, t1};

  // Let the pointwise scheduler pick the launch parameters.
  auto lparams = schedulePointwise(&fusion, aten_inputs);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs, lparams);
  auto cg_outputs = fe.runFusion(aten_inputs, lparams);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5408
// Leading-broadcast add on concrete-sized tensors with no explicit
// scheduling at all — exercises the default (unscheduled) codegen path.
TEST_F(NVFuserTest, FusionAdvancedIndexing4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeConcreteTensor({4, 8});
  fusion.addInput(tv0);
  TensorView* tv1 = makeConcreteTensor({4, 4, 8});
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(1));
  TensorView* tv3 = broadcast(tv2, {true, false, false});
  TensorView* tv4 = add(tv3, tv1);
  fusion.addOutput(tv4);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({4, 8}, options);
  at::Tensor t1 = at::randn({4, 4, 8}, options);

  auto t2 = t0.add(1.0);
  auto aten_output = t2.add(t1);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5440
// Double-sided broadcast (1D input broadcast on both outer and inner
// axes) with chained merge/split scheduling applied to both the broadcast
// intermediate and the output, and inputs computed at the first split.
TEST_F(NVFuserTest, FusionAdvancedIndexing5_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(3);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv0, IrBuilder::create<Double>(1));
  // Broadcast on both sides of the original axis: [b, i, b].
  TensorView* tv3 = broadcast(tv2, {true, false, true});
  TensorView* tv4 = add(tv3, tv1);
  fusion.addOutput(tv4);

  // Same flatten-then-tile transformation on intermediate and output.
  tv3->merge(0)->merge(0)->split(0, 2)->split(0, 3);
  tv4->merge(0)->merge(0)->split(0, 2)->split(0, 3);

  tv0->computeAt(tv4, 1);
  tv1->computeAt(tv4, 1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({7}, options);
  at::Tensor t1 = at::randn({5, 7, 11}, options);

  auto t2 = t0.add(1.0);
  auto aten_output = t2.unsqueeze(-1).add(t1);

  std::vector<IValue> aten_inputs = {t0, t1};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5478
// Broadcast add followed by a multi-axis sum, scheduled by the reduction
// scheduler's own heuristics; validates with the heuristic-chosen launch
// parameters.
TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  std::vector<int64_t> tensor0_shape{7, 4, 7};
  std::vector<int64_t> tensor1_shape{4, 7};

  TensorView* tv0 = makeSymbolicTensor(tensor0_shape.size());
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(tensor1_shape.size());
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv0, tv1);
  TensorView* tv3 = sum(tv2, {0, 1});
  fusion.addOutput(tv3);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor input0 = at::randn(tensor0_shape, options);
  at::Tensor input1 = at::randn(tensor1_shape, options);

  std::vector<int64_t> reduction_axes{0, 1};
  auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
  scheduleReduction(&fusion, *reduction_params);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1}, reduction_params->lparams);
  auto cg_outputs = fe.runFusion({input0, input1}, reduction_params->lparams);

  // Reference accumulates in double to avoid float-accumulation error.
  auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes);

  testValidate(
      &fusion,
      cg_outputs,
      {input0, input1},
      {aten_output},
      __LINE__,
      __FILE__,
      "",
      reduction_params->lparams);
}
5522
// Manual schedule covering the same issue as test 6: broadcast add into a
// full (all-axes) sum, with an rFactor of the two outer split axes and
// the final reduction parallelized over TIDx.
TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) {
  // Might be able to use this one without 6 as the heuristics in 6 may change
  // and this test is to cover the same issue.
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = broadcast(tv0, {false, true});

  auto tv2 = makeSymbolicTensor(2);
  fusion.addInput(tv2);

  auto tv3 = add(tv1, tv2);
  auto tv4 = sum(tv3, {0, 1});
  fusion.addOutput(tv4);

  // Flatten the reduction domain, then split twice (inner splits).
  tv4->merge(0, 1);
  tv4->split(0, 128);
  tv4->split(0, 4);

  // Factor out the two outer axes into a partial reduction.
  auto tv5 = tv4->rFactor({0, 1});

  tv5->computeAt(tv4, -1);
  tv0->computeAt(tv5, -1);

  tv4->axis(0)->parallelize(ParallelType::TIDx);

  const int numel_x = 100;
  const int numel_y = 200;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto at_t0 = at::randn({numel_x}, options);
  auto at_t1 = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {at_t0, at_t1});
  auto cg_outputs = fe.runFusion({at_t0, at_t1});

  auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
                         .to(at::kDouble)
                         .sum();

  testValidate(
      &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
}
5569
// Variant of test 7 where both splits are outer splits (split factor
// applies to the outer axis); otherwise identical schedule and reference.
TEST_F(NVFuserTest, FusionAdvancedIndexing8_CUDA) {
  // Same as 7 but with outer splits instead of inner
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = broadcast(tv0, {false, true});

  auto tv2 = makeSymbolicTensor(2);
  fusion.addInput(tv2);

  auto tv3 = add(tv1, tv2);
  auto tv4 = sum(tv3, {0, 1});
  fusion.addOutput(tv4);

  tv4->merge(0, 1);
  // The trailing `false` selects an outer split.
  tv4->split(0, 128, false);
  tv4->split(0, 4, false);

  auto tv5 = tv4->rFactor({0, 1});

  tv5->computeAt(tv4, -1);
  tv0->computeAt(tv5, -1);

  tv4->axis(0)->parallelize(ParallelType::TIDx);

  const int numel_x = 100;
  const int numel_y = 200;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto at_t0 = at::randn({numel_x}, options);
  auto at_t1 = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {at_t0, at_t1});
  auto cg_outputs = fe.runFusion({at_t0, at_t1});

  auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
                         .to(at::kDouble)
                         .sum();

  testValidate(
      &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
}
5615
// A broadcast-derived intermediate (tv2) that is both a fusion output and
// consumed by a downstream add; scheduled via schedulePointwise and
// validated on both outputs. (The old "same as 7" comment was a stale
// copy-paste: this test has no manual splits at all.)
TEST_F(NVFuserTest, FusionAdvancedIndexing9_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = broadcast(tv0, {false, true});

  // tv2 is an output AND feeds tv4 below.
  auto tv2 = mul(tv1, IrBuilder::create<Double>(2));
  fusion.addOutput(tv2);

  auto tv3 = makeSymbolicTensor(3);
  fusion.addInput(tv3);

  auto tv4 = add(tv3, tv2);
  fusion.addOutput(tv4);

  const int numel_x = 200;
  const int numel_y = 300;
  const int numel_z = 400;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto at_t0 = at::randn({numel_y}, options);
  auto at_t3 = at::randn({numel_x, numel_y, numel_z}, options);
  std::vector<IValue> aten_inputs = {at_t0, at_t3};

  auto lparams = schedulePointwise(&fusion, aten_inputs);

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs, lparams);
  auto cg_outputs = fe.runFusion(aten_inputs, lparams);

  auto at_t1 = at_t0.unsqueeze(-1);
  auto at_t2 = at_t1.mul(2.0);

  auto at_t4 = at_t3.add(at_t2);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__);
}
5657
// Vectorized cached loads: both contiguous inputs are cached with
// cacheAfter, every tensor gets the same split/reorder schedule, and the
// cache tensors' innermost axis is vectorized. Checked for exact
// bit-equality against the eager reference.
TEST_F(NVFuserTest, FusionAdvancedIndexing10_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = makeContigTensor(2);

  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2.0));
  TensorView* tv3 = add(tv0, tv2);

  // Register your outputs
  fusion.addOutput(tv3);

  auto tv0_cache = tv0->cacheAfter();
  auto tv1_cache = tv1->cacheAfter();

  std::vector<TensorView*> tvs = {tv0_cache, tv1_cache, tv2, tv3};

  for (auto tv : tvs) {
    tv->split(1, 2, false);
    tv->split(1, 1);
    tv->split(-1, 4);
    // [I0, 2, 1, I1/2/4, 4]
    tv->reorder({{1, 2}, {2, 3}, {3, 1}});
    tv->axis(0)->parallelize(ParallelType::BIDx);
    tv->axis(1)->parallelize(ParallelType::TIDx);
  }

  // For all inputs, computeAt the output inline, temporaries should be squeezed
  // between them
  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  // Vectorize the innermost axis of the cached input loads.
  tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor input1 = at::randn({64, 128}, options);
  at::Tensor input2 = at::rand_like(input1);
  at::Tensor output = at::empty_like(input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input1, input2});
  fe.runFusion({input1, input2}, {output});

  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor output_ref = input1 + tv2_ref;

  // Exact equality (not testValidate tolerance): pure adds should match.
  TORCH_CHECK(output_ref.equal(output));
}
5716
// Broadcast intermediate staged through global memory: tv2 is computed at
// position 3 of the output and placed in global memory, with the grid
// parallelized over BIDx/BIDy and an Unswitch axis.
TEST_F(NVFuserTest, FusionAdvancedIndexing11_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int w = 3, x = 4, y = 7, z = 8;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  auto tv0 = makeSymbolicTensor(4);
  auto tv1 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = add(tv1, IrBuilder::create<Double>(1.0));
  // Broadcast the 1D input into every axis but the second.
  auto tv3 = broadcast(tv2, {true, false, true, true});
  auto tv4 = add(tv3, tv0);

  fusion.addOutput(tv4);

  tv4->merge(0);
  tv4->merge(1);

  tv4->split(1, 32);
  tv4->split(0, 1);

  tv4->reorder({{2, 1}});

  tv2->computeAt(tv4, 3);

  // Stage the intermediate in global memory.
  tv2->setMemoryType(MemoryType::Global);

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::BIDy);
  tv4->axis(2)->parallelize(ParallelType::Unswitch);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);

  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  FusionExecutor fe;

  at::Tensor t0 = at::randn({w, x, y, z}, options);
  at::Tensor t1 = at::randn({x}, options);

  auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1);
  auto aten_output = t3.add(t0);

  std::vector<IValue> aten_inputs = {t0, t1};

  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5770
5771// Intended to stress the lowering of our code generator
// A shared producer (tv1) feeding both a pointwise output and an
// rFactored reduction output, with tv1 computed at the rFactor tensor.
TEST_F(NVFuserTest, FusionAdvancedLowering1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeConcreteTensor({9, 5});
  fusion.addInput(tv0);

  // tv1 is consumed by both tv2 (output) and tv3 (reduction input).
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1));
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2));
  TensorView* tv3 = add(tv1, IrBuilder::create<Double>(3));
  TensorView* tv4 = sum(tv3, {1});

  fusion.addOutput(tv2);
  fusion.addOutput(tv4);

  tv4->split(1, 4);
  auto tv5 = tv4->rFactor({2});

  tv1->computeAt(tv5, 2);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(1);
  at::Tensor aten_input = at::randn({9, 5}, options);

  auto t1 = aten_input.add(1.0);
  auto t2 = t1.add(2.0);
  auto t3 = t1.add(3.0);
  auto t4 = t3.sum(1);

  std::vector<at::Tensor> aten_outputs = {t2, t4};

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  auto cg_outputs = fe.runFusion({aten_input});

  testValidate(
      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
}
5810
// Progressive broadcasting across ranks (1D -> 2D -> 3D adds) with a
// split/merge schedule producing [I0*I1o, I1i*I2], everything inlined and
// parallelized over BIDx/TIDx.
TEST_F(NVFuserTest, FusionAdvancedLowering2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Progressively broadcast tensors
  TensorView* tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  TensorView* tv2 = makeSymbolicTensor(3);
  fusion.addInput(tv2);

  TensorView* tv3 = add(tv0, IrBuilder::create<Double>(1));
  TensorView* tv4 = broadcast(tv3, {false, true});
  TensorView* tv5 = add(tv4, tv1);
  TensorView* tv6 = add(tv5, tv2);

  fusion.addOutput(tv6);

  // Split inner dimension
  tv6->split(1, 4);
  // Merge middle dims with outer dimensions
  tv6->merge(2);
  tv6->merge(0);

  // tv6[I0*I1o, I1i*I2]

  // Compute everything inline
  tv0->computeAt(tv6, -1);

  tv6->axis(0)->parallelize(ParallelType::BIDx);
  tv6->axis(1)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  int x = 13, y = 9, z = 5;
  at::Tensor t0 = at::randn({y}, options);
  at::Tensor t1 = at::randn({y, z}, options);
  at::Tensor t2 = at::randn({x, y, z}, options);

  auto t3 = t0.add(1.0);
  auto t4 = t3.unsqueeze(-1);
  auto t5 = t4.add(t1);
  auto t6 = t5.add(t2);

  std::vector<IValue> aten_inputs = {t0, t1, t2};
  std::vector<at::Tensor> aten_outputs = {t6};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
5865
5866// TODO: Complete test
// A broadcast-dim producer (tv2) shared by a broadcast-shaped output and
// a concretized output, with one intermediate (tv3) forced into global
// memory.
TEST_F(NVFuserTest, FusionAdvancedLowering3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeConcreteTensor({1, -1});
  auto tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [b0, i1]
  auto tv2 = add(tv0, IrBuilder::create<Double>(2.0));

  // [i0, i1]
  auto tv3 = add(tv1, IrBuilder::create<Double>(3.0));

  // [b0, i1]
  auto tv4 = add(tv2, IrBuilder::create<Double>(4.0));

  // [io, i1]
  auto tv5 = add(tv2, tv3);

  fusion.addOutput(tv4);
  fusion.addOutput(tv5);

  tv0->computeAt(tv4, -1);

  // tv3 crosses a broadcast-concretization boundary; keep it in global.
  tv3->setMemoryType(MemoryType::Global);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  int x = 13, y = 9;
  at::Tensor t0 = at::randn({1, y}, options);
  at::Tensor t1 = at::randn({x, y}, options);

  auto t4 = t0 + 2 + 4;
  auto t5 = t0 + 2 + t1 + 3;

  std::vector<IValue> aten_inputs = {t0, t1};
  std::vector<at::Tensor> aten_outputs = {t4, t5};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
5913
// This exercises indexing with broadcast root axes. Non-broadcast
// axes need to be preferred when propagating index exprs to root
// axes. See, e.g., Index::getConsumerIndex_impl.
// Stacked broadcasts (1D -> 2D -> 3D) merged into one axis: index
// expressions must propagate through non-broadcast root axes when the
// output is flattened and split. See the note above the test.
TEST_F(NVFuserTest, FusionAdvancedLowering4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  auto tv1 = broadcast(tv0, {false, true});
  auto tv2 = broadcast(tv1, {false, false, true});
  auto tv3 = makeSymbolicTensor(3);
  fusion.addInput(tv3);
  auto tv4 = add(tv2, tv3);
  fusion.addOutput(tv4);

  // Flatten all three axes, then split; inputs computed at the split.
  tv4->merge(1)->merge(0);
  tv4->split(0, 8);
  tv0->computeAt(tv4, 1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  const int bx = 10;
  const int by = 20;
  const int bz = 30;
  at::Tensor t0 = at::randn({bx}, options);
  at::Tensor t3 = at::randn({bx, by, bz}, options);
  std::vector<IValue> aten_inputs = {t0, t3};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  auto aten_output =
      t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3;

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
5952
// Merge across a middle broadcast axis on the intermediate (tv2) while
// the consumer keeps its original 3D shape; tv1 is computed at tv2.
TEST_F(NVFuserTest, FusionAdvancedLowering5_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeConcreteTensor({5, 4, 3});
  fusion.addInput(tv0);

  TensorView* tv1 = makeConcreteTensor({5, 3});
  fusion.addInput(tv1);

  // Insert a broadcast in the middle axis: [5, b, 3].
  auto tv2 = broadcast(tv1, {false, true, false});

  auto tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  // Merge the concrete outer axis with the broadcast axis.
  tv2->merge(0);
  tv1->computeAt(tv2, 1);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(1);
  at::Tensor t0 = at::randn({5, 4, 3}, options);
  at::Tensor t1 = at::randn({5, 3}, options);
  auto t2 = t1.unsqueeze(1);
  auto t3 = t0 + t2;

  std::vector<IValue> aten_inputs = {t0, t1};
  std::vector<at::Tensor> aten_outputs = {t3};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
5989
// Set-op copies feeding both a reduction path (sum over {0,2} then add)
// and a broadcast path, inlined with ComputeAtMode::BestEffort; both
// outputs are validated.
TEST_F(NVFuserTest, FusionAdvancedLowering6_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeConcreteTensor({5, 4, 3});
  fusion.addInput(tv0);
  auto tv1 = makeConcreteTensor({4});
  fusion.addInput(tv1);
  // Copies of the inputs so each can be inlined into its consumer chain.
  auto tv2 = unaryOp(UnaryOpType::Set, tv0);
  auto tv3 = unaryOp(UnaryOpType::Set, tv1);

  auto tv4 = sum(tv2, {0, 2});
  auto tv5 = add(tv4, tv3);
  fusion.addOutput(tv5);

  auto tv6 = broadcast(tv3, {true, false, true});
  auto tv7 = add(tv2, tv6);
  fusion.addOutput(tv7);

  // BestEffort: inline as deep as the schedules allow without failing.
  tv2->computeAt(tv4, -1, ComputeAtMode::BestEffort);
  tv3->computeAt(tv7, -1, ComputeAtMode::BestEffort);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(1);
  at::Tensor t0 = at::randn({5, 4, 3}, options);
  at::Tensor t1 = at::randn({4}, options);

  auto t2 = t0;
  auto t3 = t1;

  std::vector<int64_t> reduction_axes{0, 2};
  auto t4 = t2.sum(reduction_axes);
  auto t5 = add(t4, t3);
  auto t6 = t3.unsqueeze(0).unsqueeze(-1);
  auto t7 = t2.add(t6);

  std::vector<IValue> aten_inputs = {t0, t1};
  std::vector<at::Tensor> aten_outputs = {t5, t7};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
}
6036
6037// Test a simple Gemm but also play around with fusion executor features
// GEMM expressed as broadcast + multiply + sum, scheduled by hand with an
// rFactor on the K dimension. Also exercises FusionExecutor features:
// compiling with explicit LaunchParams, running with user-specified
// bounds, and running again with no launch params at all.
TEST_F(NVFuserTest, FusionSimpleGemm_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2); // M, K
  TensorView* tv1 = makeSymbolicTensor(2); // K, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 = broadcast(tv0, {false, false, true});
  // tv2[I0, I1, B] = tv0[I0, I1]

  TensorView* tv3 = broadcast(tv1, {true, false, false});
  // tv3[B, I1, I2] = tv1[I1, I2]

  // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
  TensorView* tv4 = mul(tv2, tv3);
  // tv5[I0, R1, I2] = tv4[I0, I1, I2]
  TensorView* tv5 = sum(tv4, {1});
  fusion.addOutput(tv5);

  tv5->split(1, 32);
  // tv5[I0, R1o, R1i{32}, I2]

  auto tv6 = tv5->rFactor({1});
  // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
  // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]

  tv5->split(0, 4);
  tv5->split(-1, 4);
  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]

  tv0->computeAt(tv5, -1);
  tv1->computeAt(tv5, -1);

  // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
  // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
  //--> (line symbolizes compute at location)
  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

  tv0->computeAt(tv6, -1);
  tv1->computeAt(tv6, -1);
  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

  tv5->axis(0)->parallelize(ParallelType::BIDz);
  tv5->axis(1)->parallelize(ParallelType::TIDz);

  tv5->axis(-2)->parallelize(ParallelType::BIDy);
  tv5->axis(-1)->parallelize(ParallelType::TIDy);

  tv5->axis(2)->parallelize(ParallelType::TIDx);
  tv6->axis(2)->parallelize(ParallelType::TIDx);

  constexpr int M = 65, K = 33, N = 17;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
  // Lets specify a few bounds in launch params to make sure it works
  fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));

  // Make sure bad launch params throws
  // TODO: Re-enable once we have parallelization validation in.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));

  // Don't specify any launch params
  auto cg_outputs = fe.runFusion({t0, t1});

  // Reference in double precision to absorb float accumulation error.
  auto aten_output = t0.to(at::kDouble).matmul(t1.to(at::kDouble));

  testValidate(
      &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
}
6122
6123// Softmax with a 1D tensor. Parallelized only with a single thread block.
6124TEST_F(NVFuserTest, FusionSoftmax1D_CUDA) {
6125 Fusion fusion;
6126 FusionGuard fg(&fusion);
6127
6128 const int tidx = 128;
6129 const int dimx = 1000;
6130
6131 // Set up your input tensor views
6132 TensorView* input_tv0 = makeSymbolicTensor(1);
6133 fusion.addInput(input_tv0);
6134
6135 TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
6136 TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
6137 TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
6138
6139 // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
6140 // computed at sum_exp_rf_tv8.
6141 TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
6142
6143 TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
6144
6145 fusion.addOutput(output_tv4);
6146
6147 bcast_sum_tv3->split(0, tidx);
6148
6149 sum_exp_tv2->split(-1, tidx);
6150 TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
6151
6152 output_tv4->split(-1, tidx);
6153
6154 exp_tv1->computeAt(sum_exp_rf_tv5, -1);
6155 exp_tv1_copy->computeAt(output_tv4, -1);
6156
6157 TensorView* tensors_to_parallelize[] = {
6158 sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
6159
6160 for (auto tv : tensors_to_parallelize) {
6161 tv->axis(-1)->parallelize(ParallelType::TIDx);
6162 }
6163
6164 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
6165 at::Tensor t0 = at::randn({dimx}, options);
6166 at::Tensor cg_output = at::empty({dimx}, options);
6167 at::Tensor t3_output = at::empty_like(cg_output, options);
6168
6169 FusionExecutor fe;
6170 fe.compileFusion(&fusion, {t0});
6171 fe.runFusion({t0}, {cg_output});
6172
6173 auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false);
6174
6175 testValidate(&fusion, {cg_output}, {t0}, {aten_output}, __LINE__, __FILE__);
6176}
6177
6178// Softmax with a 1D tensor with input normalization.
6179TEST_F(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
6180 Fusion fusion;
6181 FusionGuard fg(&fusion);
6182
6183 const int tidx = 128;
6184 const int dimx = 1000;
6185
6186 // Set up your input tensor views
6187 TensorView* input_tv0 = makeSymbolicTensor(1);
6188 fusion.addInput(input_tv0);
6189
6190 // Normalize with the max value before computing exp.
6191 TensorView* max_val_tv1 = reductionOp(
6192 BinaryOpType::Max, {-1}, IrBuilder::create<Double>(0), input_tv0);
6193 TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true});
6194 TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
6195 TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
6196 TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
6197 TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true});
6198
6199 // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
6200 // computed at sum_exp_rf_tv8.
6201 TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
6202 TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
6203
6204 TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
6205
6206 fusion.addOutput(output_tv7);
6207 bcast_max_tv2->split(0, tidx);
6208 bcast_sum_tv6->split(0, tidx);
6209
6210 max_val_tv1->split(-1, tidx);
6211 TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
6212
6213 sum_exp_tv5->split(-1, tidx);
6214 TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
6215
6216 output_tv7->split(-1, tidx);
6217
6218 sub_tv3->computeAt(sum_exp_rf_tv9, -1);
6219 sub_tv3_copy->computeAt(output_tv7, -1);
6220
6221 TensorView* tensors_to_parallelize[] = {
6222 max_val_tv1,
6223 bcast_max_tv2,
6224 sum_exp_tv5,
6225 bcast_sum_tv6,
6226 output_tv7,
6227 max_val_rf_tv8,
6228 sum_exp_rf_tv9};
6229
6230 for (auto tv : tensors_to_parallelize) {
6231 tv->axis(-1)->parallelize(ParallelType::TIDx);
6232 }
6233
6234 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
6235 at::Tensor input = at::randn({dimx}, options);
6236 at::Tensor t3_output = at::empty({dimx}, options);
6237
6238 FusionExecutor fe;
6239 fe.compileFusion(&fusion, {input});
6240 auto cg_outputs = fe.runFusion({input});
6241
6242 auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
6243
6244 testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
6245}
6246
// Softmax with a 3D tensor, where the inner-most 3rd dimension is
// normalized. Parallelized with multiple thread blocks.
TEST_F(NVFuserTest, FusionSoftmax3D_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int tidx = 32;
  const int dimx = 32;
  const int dimy = 16;
  const int dimz = 130;

  // Set up your input tensor views
  TensorView* input_tv0 = makeSymbolicTensor(3);
  fusion.addInput(input_tv0);

  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
  TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});

  // Replicate exp_tv1 as exp_tv1_copy because exp_tv1 is going to be
  // computed at sum_exp_rf_tv5.
  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);

  TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);

  fusion.addOutput(output_tv4);

  bcast_sum_tv3->split(-1, tidx);

  sum_exp_tv2->split(-1, tidx);
  TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});

  output_tv4->split(-1, tidx);

  exp_tv1->computeAt(sum_exp_rf_tv5, -1);
  exp_tv1_copy->computeAt(output_tv4, -1);

  // Outer two dims ride on the grid (BIDx/BIDy); the split inner dim
  // rides on the threads (TIDx).
  TensorView* tensors_to_parallelize[] = {
      sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};

  for (auto tv : tensors_to_parallelize) {
    tv->axis(0)->parallelize(ParallelType::BIDx);
    tv->axis(1)->parallelize(ParallelType::BIDy);
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({dimx, dimy, dimz}, options);

  at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Validate against ATen softmax computed in double precision.
  auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);

  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
6307
6308// Softmax with a 3D tensor with input normalization.
6309TEST_F(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
6310 Fusion fusion;
6311 FusionGuard fg(&fusion);
6312
6313 const int tidx = 32;
6314 const int dimx = 32;
6315 const int dimy = 16;
6316 const int dimz = 130;
6317
6318 // Set up your input tensor views
6319 TensorView* input_tv0 = makeSymbolicTensor(3);
6320 fusion.addInput(input_tv0);
6321
6322 // Normalize with the max value before computing exp.
6323 TensorView* max_val_tv1 = reductionOp(
6324 BinaryOpType::Max, {-1}, IrBuilder::create<Double>(0), input_tv0);
6325 TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
6326 TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
6327 TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
6328 TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
6329 TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true});
6330
6331 // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
6332 // computed at sum_exp_rf_tv8.
6333 TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
6334 TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
6335
6336 TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
6337
6338 fusion.addOutput(output_tv7);
6339
6340 bcast_max_tv2->split(-1, tidx);
6341 bcast_sum_tv6->split(-1, tidx);
6342
6343 max_val_tv1->split(-1, tidx);
6344 TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
6345
6346 sum_exp_tv5->split(-1, tidx);
6347 TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
6348
6349 output_tv7->split(-1, tidx);
6350
6351 sub_tv3->computeAt(sum_exp_rf_tv9, -1);
6352 sub_tv3_copy->computeAt(output_tv7, -1);
6353
6354 TensorView* tensors_to_parallelize[] = {
6355 max_val_tv1,
6356 bcast_max_tv2,
6357 sum_exp_tv5,
6358 bcast_sum_tv6,
6359 output_tv7,
6360 max_val_rf_tv8,
6361 sum_exp_rf_tv9};
6362
6363 for (auto tv : tensors_to_parallelize) {
6364 tv->axis(0)->parallelize(ParallelType::BIDx);
6365 tv->axis(1)->parallelize(ParallelType::BIDy);
6366 tv->axis(-1)->parallelize(ParallelType::TIDx);
6367 }
6368
6369 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
6370 at::Tensor input = at::randn({dimx, dimy, dimz}, options);
6371 at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options);
6372
6373 FusionExecutor fe;
6374 fe.compileFusion(&fusion, {input});
6375 auto cg_outputs = fe.runFusion({input});
6376
6377 auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
6378
6379 testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
6380}
6381
6382TEST_F(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
6383 Fusion fusion;
6384 FusionGuard fg(&fusion);
6385
6386 // Set up your input tensor views
6387 TensorView* tv0 = makeSymbolicTensor(2);
6388 fusion.addInput(tv0);
6389
6390 auto tv1 = sum(tv0, {1});
6391 auto tv2 = broadcast(tv1, {false, true});
6392
6393 auto tv3 = add(tv0, IrBuilder::create<Double>(1.0));
6394
6395 auto tv4 = mul(tv2, tv3);
6396
6397 auto tv5 = sum(tv4, {1});
6398 auto tv6 = broadcast(tv5, {false, true});
6399
6400 auto tv7 = sub(tv6, tv4);
6401 fusion.addOutput(tv7);
6402
6403 tv1->computeAt(tv7, 1);
6404 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
6405 ASSERT_ANY_THROW(tv1->computeAt(tv7, -1));
6406}
6407
// Similar to FusionReduction but uses grid reduction
TEST_F(NVFuserTest, FusionGridReduction1_CUDA) {
  const int gdimx = 32;
  const int bdimx = 128;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(1, bdimx);
  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
  tv1->split(1, gdimx);
  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
  // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv1, 1);

  // Re do it all at once, because why not.
  tv0->computeAt(tv1, 1);

  // The remaining reduction axis R1oi is bound to BIDx, making this a
  // grid (cross-block) reduction; R1i is reduced within a block on TIDx.
  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::BIDx);
  tv2->axis(2)->parallelize(ParallelType::BIDx);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 10000;
  int numel_y = 65000;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1});

  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
6468
// Same test as the above but uses BIDy and TIDx for reduction
TEST_F(NVFuserTest, FusionGridReduction2_CUDA) {
  const int gdimy = 32;
  const int bdimx = 128;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(1, bdimx);
  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
  tv1->split(1, gdimy);
  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
  // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv1, 1);

  // Re do it all at once, because why not.
  tv0->computeAt(tv1, 1);

  // Cross-block reduction over BIDy this time; the outer axis is BIDx.
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::BIDy);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 10000;
  int numel_y = 65000;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1});

  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
}
6527
// Same test but uses BIDy and BIDz for reduction. No TID used.
TEST_F(NVFuserTest, FusionGridReduction3dim1_CUDA) {
  // Grid reductions when there aren't any threads are serial reductions
  // keep these numbers low so our error isn't too high compared to normal cuda
  // reductions
  const int gdimz = 15;
  const int gdimy = 9;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(1, gdimy);
  // tv1[I0, R1o, R1i{9}] = tv0[I0, I1]
  tv1->split(1, gdimz);
  // tv1[I0, R1oo, R1oi{15}, R1i{9}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{15}, Ir1i{9}] = tv0[I0, I1]
  // tv1[I0, R1oi{15}, R1i{9}] = tv2[I0, R1oo, Ir1oi{15}, Ir1i{9}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv1, 1);

  // Re do it all at once, because why not.
  tv0->computeAt(tv1, 1);

  // Both reduction pieces are bound to block dims (BIDz/BIDy), no TID:
  // a thread-free grid reduction.
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::BIDz);
  tv2->axis(2)->parallelize(ParallelType::BIDz);
  tv1->axis(-1)->parallelize(ParallelType::BIDy);
  tv2->axis(-1)->parallelize(ParallelType::BIDy);

  int numel_x = 100;
  int numel_y = 6500;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
6589
// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
TEST_F(NVFuserTest, FusionGridReduction3dim0_CUDA) {
  // Grid reductions when there aren't any threads are serial reductions
  // keep these numbers low so our error isn't too high compared to normal cuda
  // reductions
  const int gdimz = 15;
  const int gdimy = 9;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[R0, I1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {0}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(0, gdimy);
  // tv1[R0o, R0i{9}, I1] = tv0[I0, I1]
  tv1->split(0, gdimz);
  // tv1[R0oo, R0oi{15}, R0i{9}, I1] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({0});
  // tv2[R0oo, I0oi{15}, I0i{9}, I1] = tv0[I0, I1]
  // tv1[ R0oi{15}, R0i{9}, I1] = tv2[R0oo, I0oi{15}, I0i{9}, I1]

  // Note that computeAt isn't going to make anything better as there
  // is no dynamically sized dimension.

  // Map parallelism as [Serial, BIDz, BIDy, BIDx]
  tv1->axis(-1)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::BIDx);
  tv1->axis(-2)->parallelize(ParallelType::BIDy);
  tv2->axis(-2)->parallelize(ParallelType::BIDy);
  tv1->axis(-3)->parallelize(ParallelType::BIDz);
  tv2->axis(-3)->parallelize(ParallelType::BIDz);

  int numel_x = 6500;
  int numel_y = 100;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Fixed seed keeps the serialized grid reduction numerically stable
  // across runs for validation.
  at::manual_seed(0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({0});

  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
}
6649
// This is similar to the FusionReduction, but swaps BIDx and TIDx
TEST_F(NVFuserTest, FusionGridReduction4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int bdimx = 128;
  const int gdimx = 1024;

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(1, gdimx);
  // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
  tv1->split(1, 4);
  // tv1[I0, R1oo, R1oi{4}, R1i{1024}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
  // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]

  TensorView* tv3 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
  // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
  // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv3, 1);
  tv3->computeAt(tv1, 1);

  // Re do it all at once, because why not.
  tv0->computeAt(tv1, 1);

  // The swap under test: the outer (non-reduction) axis rides on TIDx
  // and the innermost reduction axis rides on BIDx.
  tv2->axis(2)->parallelize(ParallelType::Unroll);
  tv1->axis(0)->parallelize(ParallelType::TIDx);

  tv1->axis(-1)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::BIDx);
  tv3->axis(-1)->parallelize(ParallelType::BIDx);

  // numel_x matches bdimx because axis 0 is bound to TIDx without a split.
  int numel_x = bdimx;
  int numel_y = 65000;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
6715
// Grid reduction with 2D thread blocks but only TIDx and BIDx are
// mapped to a reduction dim
TEST_F(NVFuserTest, FusionGridReduction5_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int bdimx = 64;
  const int bdimy = 16;
  const int gdimx = 4;

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  tv1->split(1, bdimx);
  // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
  tv1->split(1, gdimx);
  // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
  // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]

  tv0->computeAt(tv1, 1);

  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  tv1->axis(-2)->parallelize(ParallelType::BIDx);
  tv2->axis(-2)->parallelize(ParallelType::BIDx);

  // The non-reduction axis is bound to TIDy, giving 2D thread blocks.
  tv1->axis(0)->parallelize(ParallelType::TIDy);

  // numel_x matches bdimy because axis 0 is bound to TIDy without a split.
  int numel_x = bdimy;
  int numel_y = 6500;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1});
  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
}
6771
// Similar to FusionGridReduction1 but with 3D tensors
TEST_F(NVFuserTest, FusionGridReduction6_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(3);
  fusion.addInput(tv0);

  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(
      ir_utils::getReductionOps(&fusion).size(),
      "Could not detect reduction in fusion.");

  // Splitting for TID
  tv1->split(2, 128);
  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]

  // Splitting for BID
  tv1->split(1, 128);

  // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]

  TensorView* tv2 = tv1->rFactor({3});
  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
  // tv1[I0, R1o, R1i{128}, R2i{128}]

  TensorView* tv3 = tv1->rFactor({1});
  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
  // tv3[I0, R1o, I1i{128}, I2i{128}]
  // tv1[I0, R1i{128}, R2i{128}]

  tv3->computeAt(tv1, 1);
  tv2->computeAt(tv3, 3);

  tv1->axis(0)->parallelize(ParallelType::BIDy);

  // Inner split axes are reduced within a block on TIDx ...
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // ... while the R1i piece is a grid (BIDx) reduction on all three.
  tv1->axis(-2)->parallelize(ParallelType::BIDx);
  tv2->axis(-3)->parallelize(ParallelType::BIDx);
  tv3->axis(-2)->parallelize(ParallelType::BIDx);

  int numel_x = 6500;
  int numel_y = 200;
  int numel_z = numel_y;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);
  at::Tensor cg_output = at::empty({numel_x}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  fe.runFusion({input}, {cg_output});

  // Validate against ATen sum computed in double precision.
  auto aten_output = input.to(at::kDouble).sum({1, 2});

  testValidate(
      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
}
6838
6839// See issue #1049
6840TEST_F(NVFuserTest, FusionGridReduction7_CUDA) {
6841 Fusion fusion;
6842 FusionGuard fg(&fusion);
6843
6844 auto tv0 = makeSymbolicTensor(1);
6845 fusion.addInput(tv0);
6846
6847 auto tv1 = sum(tv0, {0});
6848 fusion.addOutput(tv1);
6849
6850 tv1->split(0, 1000);
6851
6852 tv1->axis(0)->parallelize(ParallelType::BIDx);
6853 tv1->axis(1)->parallelize(ParallelType::BIDy);
6854
6855 const int numel_x = 1;
6856
6857 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
6858 at::Tensor input = at::randn({numel_x}, options);
6859 at::Tensor cg_output = at::empty({numel_x}, options);
6860
6861 FusionExecutor fe;
6862 fe.compileFusion(&fusion, {input});
6863 auto out = fe.runFusion({input});
6864
6865 auto aten_output = input.sum({0});
6866
6867 testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
6868}
6869
6870TEST_F(NVFuserTest, FusionGridReduction8_CUDA) {
6871 Fusion fusion;
6872 FusionGuard fg(&fusion);
6873
6874 auto tv0 = makeSymbolicTensor(2);
6875 fusion.addInput(tv0);
6876
6877 auto tv1 = sum(tv0, {0});
6878 fusion.addOutput(tv1);
6879
6880 tv1->axis(0)->parallelize(ParallelType::BIDx);
6881 tv1->axis(1)->parallelize(ParallelType::TIDx);
6882
6883 const int numel_x = 2;
6884 const int numel_y = 4;
6885
6886 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
6887 at::Tensor input = at::randn({numel_x, numel_y}, options);
6888
6889 FusionExecutor fe;
6890 fe.compileFusion(&fusion, {input});
6891 auto out = fe.runFusion({input});
6892
6893 auto aten_output = input.sum({0});
6894
6895 testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
6896}
6897
// Grid reduction whose result feeds a pointwise add with a second
// fusion input; the split reduction axis pieces map to BIDx/BIDy.
TEST_F(NVFuserTest, FusionGridReduction9_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = sum(tv0, {1});

  auto tv2 = makeSymbolicTensor(1);
  fusion.addInput(tv2);

  auto tv3 = add(tv2, tv1);
  fusion.addOutput(tv3);

  tv1->split(1, 2);

  // Both pieces of the reduction axis live on the grid.
  tv1->axis(1)->parallelize(ParallelType::BIDx);
  tv1->axis(2)->parallelize(ParallelType::BIDy);

  tv1->computeAt(tv3, 1);

  const int numel_x = 4;
  const int numel_y = 10;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
  at::Tensor t2 = at::randn({numel_x}, options);

  std::vector<IValue> aten_inputs = {t0, t2};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_output = fe.runFusion(aten_inputs);

  // Reference: row sums of t0 plus t2.
  auto aten_output = t0.sum({1}).add(t2);

  testValidate(&fusion, cg_output, {t0, t2}, {aten_output}, __LINE__, __FILE__);
}
6936
// Three chained sum reductions collapsing a 4D tensor down to 1D; all
// tensors share the TIDx/BIDx mapping on their leading two axes.
TEST_F(NVFuserTest, FusionGridReduction10_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);

  // Each step reduces the (current) innermost axis.
  auto tv1 = sum(tv0, {-1});
  auto tv2 = sum(tv1, {-1});
  auto tv3 = sum(tv2, {-1});

  fusion.addOutput(tv3);
  tv1->axis(0)->parallelize(ParallelType::TIDx);
  tv1->axis(1)->parallelize(ParallelType::BIDx);
  tv1->axis(2)->parallelize(ParallelType::TIDy);
  tv1->axis(3)->parallelize(ParallelType::TIDz);

  tv2->axis(0)->parallelize(ParallelType::TIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDx);
  tv2->axis(2)->parallelize(ParallelType::TIDy);

  tv3->axis(0)->parallelize(ParallelType::TIDx);
  tv3->axis(1)->parallelize(ParallelType::BIDx);

  tv0->computeAt(tv3, 1);

  const int numel_w = 2;
  const int numel_x = 3;
  const int numel_y = 4;
  const int numel_z = 5;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({numel_w, numel_x, numel_y, numel_z}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_output = fe.runFusion({t0});

  // The three innermost reductions together sum over dims {1, 2, 3}.
  auto aten_output = t0.sum({1, 2, 3});

  testValidate(&fusion, cg_output, {t0}, {aten_output}, __LINE__, __FILE__);
}
6979
6980TEST_F(NVFuserTest, FusionNonRedAxisBind_CUDA) {
6981 int bid_x = 3;
6982 int tid_x = 2;
6983 int red_dim = 0;
6984
6985 Fusion fusion;
6986 FusionGuard fg(&fusion);
6987
6988 // Set up your input tensor views
6989 TensorView* tv0 = makeSymbolicTensor(2);
6990 fusion.addInput(tv0);
6991
6992 TensorView* tv1 = reductionOp(
6993 BinaryOpType::Add, {red_dim}, IrBuilder::create<Double>(0), tv0);
6994 fusion.addOutput(tv1);
6995
6996 tv1->split(-1, tid_x);
6997 tv1->axis(-2)->parallelize(ParallelType::BIDx);
6998 tv1->axis(-1)->parallelize(ParallelType::TIDx);
6999
7000 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7001 at::Tensor input = at::randn({16, bid_x * tid_x}, options);
7002
7003 FusionExecutor fe;
7004 fe.compileFusion(&fusion, {input});
7005 auto cg_outputs = fe.runFusion({input});
7006
7007 auto aten_output = input.to(at::kDouble).sum({red_dim});
7008
7009 testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
7010}
7011
// Reduction + broadcast + div with a split broadcast axis; all tensors
// share the BIDx/BIDy/TIDx mapping. Note: this test only checks that
// compilation and execution succeed — the output is not validated.
TEST_F(NVFuserTest, FusionSplitBCast_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* input_tv0 = makeSymbolicTensor(3);
  TensorView* input_tv1 = makeSymbolicTensor(3);
  fusion.addInput(input_tv0);
  fusion.addInput(input_tv1);

  TensorView* sum_tv2 = reductionOp(
      BinaryOpType::Add, {2}, IrBuilder::create<Double>(0), input_tv0);
  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
  TensorView* output_tv4 = div(input_tv1, bcast_tv3);

  sum_tv2->split(-1, 32);
  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});

  bcast_tv3->split(-1, 32);
  output_tv4->split(-1, 32);

  // Outer axis -> BIDx on every tensor.
  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
  bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
  output_tv4->axis(0)->parallelize(ParallelType::BIDx);

  // Middle axis -> BIDy on every tensor.
  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
  bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
  output_tv4->axis(1)->parallelize(ParallelType::BIDy);

  // Split inner axis -> TIDx on every tensor.
  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);

  fusion.addOutput(output_tv4);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({32, 32, 128}, options);
  at::Tensor t1 = at::randn({32, 32, 128}, options);
  at::Tensor cg_output = at::empty({32, 32, 128}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1});
  fe.runFusion({t0, t1}, {cg_output});
}
7059
7060TEST_F(NVFuserTest, FusionBCastInnerDim_CUDA) {
7061 Fusion fusion;
7062 FusionGuard fg(&fusion);
7063
7064 TensorView* tv0 = makeSymbolicTensor(2);
7065 fusion.addInput(tv0);
7066
7067 // reduce then broadcast
7068 auto tv1 = sum(tv0, {0});
7069 auto tv2 = broadcast(tv1, {false, true});
7070
7071 TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
7072}
7073
7074TEST_F(NVFuserTest, FusionBCastReduce_CUDA) {
7075 Fusion fusion;
7076 FusionGuard fg(&fusion);
7077
7078 // Set up your input tensor views
7079 TensorView* tv0 = makeSymbolicTensor(2);
7080
7081 auto tv1 = broadcast(tv0, {true, false, false});
7082 auto tv2 = sum(tv1, {1});
7083 TORCH_CHECK(
7084 tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
7085 !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
7086}
7087
// Multiple consumer reduction with computeAt
// https://github.com/csarofeen/pytorch/issues/110
TEST_F(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
  auto tv2 =
      reductionOp(BinaryOpType::Max, {-1}, IrBuilder::create<Double>(0), tv1);
  auto tv3 =
      reductionOp(BinaryOpType::Min, {-1}, IrBuilder::create<Double>(0), tv1);
  auto tv4 = add(tv2, tv3);
  fusion.addOutput(tv4);
  // tv1 feeds both reductions (tv2 and tv3); best-effort computeAt must
  // pick a position valid for both consumers.
  tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort);

  TORCH_CHECK(tv1->getComputeAtPosition() == 2);
}
7106
// Two iterations: flip which of the two sibling adds (tv1/tv2) becomes
// an output and which one is computeAt'ed into tv3, to exercise
// expression ordering when siblings share a producer.
TEST_F(NVFuserTest, FusionComputeAtExprOrder1_CUDA) {
  for (const auto i : c10::irange(2)) {
    Fusion fusion;
    FusionGuard fg(&fusion);

    // Set up your input tensor views
    TensorView* tv0 = makeSymbolicTensor(1);
    fusion.addInput(tv0);

    auto tv1 = add(tv0, IrBuilder::create<Double>(1));
    auto tv2 = add(tv0, IrBuilder::create<Double>(1));
    TensorView* tv3 = add(tv1, tv2);
    // Set outputs tv2 or tv1 and then tv3
    if (i == 0) {
      fusion.addOutput(tv2);
    } else {
      fusion.addOutput(tv1);
    }
    fusion.addOutput(tv3);

    // computeAt the sibling that is NOT a fusion output.
    if (i == 0) {
      tv1->computeAt(tv3, -1);
    } else {
      tv2->computeAt(tv3, -1);
    }

    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor aten_input = at::randn({100}, options);
    // Expected outputs: the sibling (input + 1) and tv3 = 2 * (input + 1).
    std::vector<at::Tensor> aten_outputs = {
        aten_input + 1, (aten_input + 1) * 2};

    FusionExecutor fe;
    fe.compileFusion(&fusion, {aten_input});
    auto cg_outputs = fe.runFusion({aten_input});

    testValidate(
        &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
  }
}
7146
7147TEST_F(NVFuserTest, FusionComputeAtExprOrder2_CUDA) {
7148 Fusion fusion;
7149 FusionGuard fg(&fusion);
7150
7151 // Set up your input tensor views
7152 TensorView* tv0 = makeSymbolicTensor(2);
7153 fusion.addInput(tv0);
7154
7155 auto tv1 = add(tv0, IrBuilder::create<Double>(1));
7156 auto tv2 = add(tv0, IrBuilder::create<Double>(1));
7157 TensorView* tv3 = add(tv1, tv2);
7158 fusion.addOutput(tv3);
7159
7160 tv3->split(-1, 32);
7161
7162 tv1->computeAt(tv3, -1);
7163 tv2->computeAt(tv3, -2);
7164
7165 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7166 at::Tensor aten_input = at::randn({100, 100}, options);
7167 auto aten_output = (aten_input + 1) * 2;
7168
7169 at::Tensor cg_output = at::empty_like(aten_input, options);
7170
7171 FusionExecutor fe;
7172 fe.compileFusion(&fusion, {aten_input});
7173 fe.runFusion({aten_input}, {cg_output});
7174
7175 testValidate(
7176 &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
7177}
7178
7179TEST_F(NVFuserTest, FusionComputeAtExprOrder3_CUDA) {
7180 Fusion fusion;
7181 FusionGuard fg(&fusion);
7182
7183 const int64_t dimx = 13;
7184 const int64_t dimy = 15;
7185
7186 TensorView* tv0 = makeConcreteTensor({dimx, dimy});
7187 fusion.addInput(tv0);
7188 TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1));
7189 TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2));
7190 TensorView* tv3 = add(tv2, IrBuilder::create<Double>(3));
7191 TensorView* tv4 = add(tv3, IrBuilder::create<Double>(4));
7192 TensorView* tv5 = mul(tv2, tv4);
7193 fusion.addOutput(tv5);
7194
7195 tv1->computeAt(tv2, 2);
7196 tv3->computeAt(tv4, 1);
7197 tv4->computeAt(tv5, 2);
7198
7199 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7200 at::Tensor aten_input = at::randn({dimx, dimy}, options);
7201 auto t1 = aten_input.add(1.);
7202 auto t2 = t1.add(2.);
7203 auto t3 = t2.add(3.);
7204 auto t4 = t3.add(4.);
7205 auto aten_output = t2.mul(t4);
7206
7207 torch::jit::fuser::cuda::FusionExecutor fe;
7208 fe.compileFusion(&fusion, {aten_input});
7209 auto cg_outputs = fe.runFusion({aten_input});
7210
7211 testValidate(
7212 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7213}
7214
7215TEST_F(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
7216 Fusion fusion;
7217 FusionGuard fg(&fusion);
7218
7219 TensorView* tv0 = makeSymbolicTensor(1);
7220 fusion.addInput(tv0);
7221
7222 auto tv1 = sum(tv0, {0});
7223 auto tv2 = add(tv1, IrBuilder::create<Double>(1));
7224 fusion.addOutput(tv2);
7225 TORCH_CHECK(tv2->nDims() == 0);
7226 tv1->computeAt(tv2, 0);
7227
7228 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7229 at::Tensor aten_input = at::randn({100}, options);
7230 auto aten_output = aten_input.to(at::kDouble).sum() + 1;
7231
7232 FusionExecutor fe;
7233 fe.compileFusion(&fusion, {aten_input});
7234 auto cg_outputs = fe.runFusion({aten_input});
7235
7236 testValidate(
7237 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7238}
7239
7240TEST_F(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
7241 Fusion fusion;
7242 FusionGuard fg(&fusion);
7243
7244 TensorView* tv0 = makeSymbolicTensor(0);
7245 fusion.addInput(tv0);
7246
7247 auto tv1 = broadcast(tv0, {true, true});
7248 TORCH_CHECK(tv1->nDims() == 2);
7249
7250 TensorView* tv2 = makeSymbolicTensor(2);
7251 fusion.addInput(tv2);
7252
7253 auto tv3 = add(tv1, tv2);
7254 auto tv4 = sum(tv3, {0, 1});
7255 fusion.addOutput(tv4);
7256
7257 tv3->computeAt(tv4, -1);
7258 tv3->axis(-2)->parallelize(ParallelType::TIDx);
7259 tv3->axis(-1)->parallelize(ParallelType::TIDy);
7260
7261 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7262 at::Tensor t0 = at::randn({}, options);
7263 at::Tensor t1 = at::randn({10, 10}, options);
7264
7265 auto aten_output = (t0.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + t1)
7266 .to(at::kDouble)
7267 .sum();
7268
7269 std::vector<IValue> aten_inputs = {t0, t1};
7270 at::Tensor cg_output = at::empty({}, options);
7271
7272 FusionExecutor fe;
7273 fe.compileFusion(&fusion, aten_inputs);
7274 fe.runFusion(aten_inputs, {cg_output});
7275
7276 testValidate(
7277 &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
7278}
7279
7280TEST_F(NVFuserTest, FusionZeroDimReduction_CUDA) {
7281 Fusion fusion;
7282 FusionGuard fg(&fusion);
7283
7284 const int bdimx = 32;
7285 const int gdimx = 32;
7286
7287 TensorView* tv0 = makeSymbolicTensor(1);
7288 fusion.addInput(tv0);
7289
7290 auto tv1 = sum(tv0, {0});
7291 fusion.addOutput(tv1);
7292
7293 tv1->split(0, bdimx);
7294 tv1->split(0, gdimx);
7295 auto tv2 = tv1->rFactor({0});
7296
7297 tv1->axis(-1)->parallelize(ParallelType::TIDx);
7298 tv2->axis(-1)->parallelize(ParallelType::TIDx);
7299 tv1->axis(-2)->parallelize(ParallelType::BIDx);
7300 tv2->axis(-2)->parallelize(ParallelType::BIDx);
7301
7302 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7303 at::Tensor aten_input = at::randn({1000}, options);
7304 auto aten_output = aten_input.to(at::kDouble).sum();
7305
7306 at::Tensor cg_output = at::empty({}, options);
7307
7308 FusionExecutor fe;
7309 fe.compileFusion(&fusion, {aten_input});
7310 fe.runFusion({aten_input}, {cg_output});
7311
7312 testValidate(
7313 &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
7314}
7315
7316TEST_F(NVFuserTest, FusionBCastAfterReduce_CUDA) {
7317 Fusion fusion;
7318 FusionGuard fg(&fusion);
7319 const int tidx = 128;
7320
7321 // Set up your input tensor views
7322 TensorView* tv0 = makeSymbolicTensor(2);
7323 fusion.addInput(tv0);
7324
7325 auto tv1 = sum(tv0, {1});
7326 auto tv2 = broadcast(tv1, {false, true});
7327
7328 tv1->split(1, tidx);
7329 auto tv3 = tv1->rFactor({-2});
7330
7331 TensorView* tv4 = makeSymbolicTensor(2);
7332 fusion.addInput(tv4);
7333
7334 auto tv5 = add(tv2, tv4);
7335 fusion.addOutput(tv5);
7336 tv5->split(1, tidx);
7337
7338 tv3->computeAt(tv5, 1);
7339
7340 tv2->split(1, tidx);
7341
7342 tv1->axis(-1)->parallelize(ParallelType::TIDx);
7343 tv2->axis(-1)->parallelize(ParallelType::TIDx);
7344 tv3->axis(-1)->parallelize(ParallelType::TIDx);
7345 tv5->axis(-1)->parallelize(ParallelType::TIDx);
7346
7347 tv5->axis(0)->parallelize(ParallelType::BIDx);
7348
7349 int x = 63, y = 200;
7350
7351 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7352
7353 at::Tensor t0 = at::randn({x, y}, options);
7354 at::Tensor t4 = at::randn({x, y}, options);
7355
7356 auto t3 = t0.to(at::kDouble).sum({1}).unsqueeze(-1).expand({x, y});
7357 auto aten_output = t3.add(t4);
7358
7359 std::vector<IValue> aten_inputs = {t0, t4};
7360 FusionExecutor fe;
7361 fe.compileFusion(&fusion, {t0, t4});
7362 auto cg_outputs = fe.runFusion({t0, t4});
7363
7364 testValidate(
7365 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
7366}
7367
7368TEST_F(NVFuserTest, FusionOutputBroadcast_CUDA) {
7369 Fusion fusion;
7370 FusionGuard fg(&fusion);
7371
7372 TensorView* tv0 = makeConcreteTensor({2, 3});
7373 fusion.addInput(tv0);
7374
7375 TensorView* tv1 = broadcast(tv0, {true, false, true, false, true});
7376
7377 fusion.addOutput(tv1);
7378
7379 const auto options =
7380 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7381
7382 at::Tensor aten_input = at::randn({2, 3}, options);
7383 auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0);
7384
7385 FusionExecutor fe;
7386 fe.compileFusion(&fusion, {aten_input});
7387 auto cg_outputs = fe.runFusion({aten_input});
7388
7389 testValidate(
7390 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7391}
7392
7393TEST_F(NVFuserTest, FusionReductionKeepDimBasic_CUDA) {
7394 Fusion fusion;
7395 FusionGuard fg(&fusion);
7396
7397 TensorView* tv0 = makeConcreteTensor({2, 3, 4, 5, 6});
7398 fusion.addInput(tv0);
7399
7400 TensorView* tv1 = sum(tv0, {0, 2, -1}, /*keep_dim=*/true);
7401
7402 fusion.addOutput(tv1);
7403
7404 const auto options =
7405 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7406
7407 at::Tensor aten_input = at::randn({2, 3, 4, 5, 6}, options);
7408 auto aten_output =
7409 aten_input.to(at::kDouble).sum({0, 2, -1}, /*keepdim=*/true);
7410
7411 FusionExecutor fe;
7412 fe.compileFusion(&fusion, {aten_input});
7413 auto cg_outputs = fe.runFusion({aten_input});
7414
7415 testValidate(
7416 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7417}
7418
7419TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) {
7420 constexpr int bid_x = 80;
7421 constexpr int tid_x = 4096;
7422 constexpr int red_dim = 1;
7423
7424 Fusion fusion;
7425 FusionGuard fg(&fusion);
7426
7427 // Set up your input tensor views
7428 TensorView* tv0 = makeConcreteTensor({bid_x, tid_x});
7429 fusion.addInput(tv0);
7430
7431 TensorView* tv1 = reductionOp(
7432 BinaryOpType::Add,
7433 {red_dim},
7434 IrBuilder::create<Double>(0),
7435 tv0,
7436 /*keep_dim=*/true);
7437
7438 fusion.addOutput(tv1);
7439
7440 const auto options =
7441 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7442
7443 at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
7444 auto aten_output =
7445 aten_input.to(at::kDouble).sum({red_dim}, /*keepdim=*/true);
7446
7447 // Apply reduction heuristic
7448 auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
7449 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
7450 scheduleReduction(&fusion, *reduction_params);
7451
7452 auto lparams = reduction_params->lparams;
7453
7454 FusionExecutor fe;
7455 fe.compileFusion(&fusion, {aten_input}, lparams);
7456 auto cg_outputs = fe.runFusion({aten_input}, lparams);
7457
7458 testValidate(
7459 &fusion,
7460 cg_outputs,
7461 {aten_input},
7462 {aten_output},
7463 __LINE__,
7464 __FILE__,
7465 "",
7466 lparams);
7467}
7468
7469TEST_F(NVFuserTest, FusionSumTo_CUDA) {
7470 Fusion fusion;
7471 FusionGuard fg(&fusion);
7472
7473 std::vector<int64_t> tensor_shape{2, 3, 4, 5, 6};
7474 std::vector<int64_t> sum_to_shape{1, 5, 6};
7475
7476 std::vector<int64_t> tensor_shape_ref{2, 3, 4, 5, 6};
7477 std::vector<int64_t> sum_to_shape_ref{1, 5, 6};
7478
7479 std::vector<Int*> sum_to_symb;
7480 std::transform(
7481 sum_to_shape.begin(),
7482 sum_to_shape.end(),
7483 std::back_inserter(sum_to_symb),
7484 [](int s) -> Int* { return IrBuilder::create<Int>(s); });
7485
7486 TensorView* tv0 = makeConcreteTensor(tensor_shape);
7487 fusion.addInput(tv0);
7488
7489 TensorView* tv1 = sum_to(tv0, sum_to_symb);
7490 fusion.addOutput(tv1);
7491
7492 const auto options =
7493 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7494
7495 at::Tensor aten_input = at::randn(tensor_shape_ref, options);
7496 auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
7497
7498 FusionExecutor fe;
7499 fe.compileFusion(&fusion, {aten_input});
7500 auto cg_outputs = fe.runFusion({aten_input});
7501
7502 TORCH_CHECK(
7503 cg_outputs[0].dim() == static_cast<int64_t>(sum_to_shape.size()),
7504 "sum_to not keeping the final dimension");
7505
7506 testValidate(
7507 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7508}
7509
7510TEST_F(NVFuserTest, FusionSumToNoop_CUDA) {
7511 Fusion fusion;
7512 FusionGuard fg(&fusion);
7513
7514 std::vector<int64_t> tensor_shape{4, 5, 6};
7515 std::vector<int64_t> sum_to_shape{4, 5, 6};
7516
7517 std::vector<int64_t> tensor_shape_ref{4, 5, 6};
7518 std::vector<int64_t> sum_to_shape_ref{4, 5, 6};
7519
7520 std::vector<Int*> sum_to_symb;
7521 std::transform(
7522 sum_to_shape.begin(),
7523 sum_to_shape.end(),
7524 std::back_inserter(sum_to_symb),
7525 [](int s) -> Int* { return IrBuilder::create<Int>(s); });
7526
7527 TensorView* tv0 = makeConcreteTensor(tensor_shape);
7528 fusion.addInput(tv0);
7529
7530 TensorView* tv1 = sum_to(tv0, sum_to_symb);
7531
7532 // Dummy operator to avoid tv0 both input and output
7533 TensorView* tv2 = add(tv1, IrBuilder::create<Double>(0));
7534 fusion.addOutput(tv2);
7535
7536 const auto options =
7537 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7538
7539 at::Tensor aten_input = at::randn(tensor_shape_ref, options);
7540
7541 FusionExecutor fe;
7542 fe.compileFusion(&fusion, {aten_input});
7543 auto cg_outputs = fe.runFusion({aten_input});
7544 auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
7545
7546 TORCH_CHECK(
7547 cg_outputs[0].dim() == static_cast<int64_t>(sum_to_shape.size()),
7548 "sum_to not keeping the final dimension");
7549
7550 testValidate(
7551 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
7552}
7553
7554TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
7555 constexpr int bid_x = 80;
7556 constexpr int tid_x = 4096;
7557 constexpr int red_dim = 1;
7558
7559 Fusion fusion;
7560 FusionGuard fg(&fusion);
7561
7562 // Set up your input tensor views
7563 TensorView* tv0 = makeSymbolicTensor(2);
7564 fusion.addInput(tv0);
7565
7566 TensorView* tv1 = reductionOp(
7567 BinaryOpType::Add, {red_dim}, IrBuilder::create<Double>(0), tv0);
7568 fusion.addOutput(tv1);
7569
7570 const auto options =
7571 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7572
7573 at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
7574 auto aten_output = aten_input.to(at::kDouble).sum({red_dim});
7575
7576 // Apply reduction heuristic
7577 auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
7578 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
7579 scheduleReduction(&fusion, *reduction_params);
7580
7581 auto lparams = reduction_params->lparams;
7582
7583 FusionExecutor fe;
7584 fe.compileFusion(&fusion, {aten_input}, lparams);
7585 // no broadcasting needed, omitting the last optional argument;
7586 auto cg_outputs = fe.runFusion({aten_input}, lparams);
7587
7588 testValidate(
7589 &fusion,
7590 cg_outputs,
7591 {aten_input},
7592 {aten_output},
7593 __LINE__,
7594 __FILE__,
7595 "",
7596 lparams);
7597}
7598
7599// This test checks if our system could correctly handles the case where both
7600// reduction and trivial reduction exist in the fusion. Trivial reduction
7601// deserve testing because trivial reduction is handled more like a broadcasting
7602// rather than a reduction.
TEST_F(NVFuserTest, FusionReductionWithTrivialReduction_CUDA) {
  constexpr int bid_x = 80;
  constexpr int tid_x = 4096;

  // Each shape has exactly one size-1 dim; -1 extents are concretized below.
  // Reducing over the size-1 dim is the "trivial reduction" under test.
  std::vector<std::vector<int64_t>> shapes = {
      {-1, -1, 1}, {-1, 1, -1}, {1, -1, -1}};

  for (auto shape : shapes) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    Fusion& fusion = *fusion_ptr;
    FusionGuard fg(&fusion);

    // Every subset of the three axes, so each fusion mixes real reductions
    // with the trivial (size-1) one.
    std::vector<std::vector<int64_t>> reduction_dims = {
        {0},
        {1},
        {2},
        {0, 1},
        {0, 2},
        {1, 2},
        {0, 1, 2},
    };

    // Set up your input tensor views
    TensorView* tv0 = makeConcreteTensor(shape);
    fusion.addInput(tv0);

    // One output per reduction-dim combination, all reading the same input.
    for (auto rdims : reduction_dims) {
      std::vector<int> rdims_(rdims.begin(), rdims.end());
      auto tv = sum(tv0, rdims_);
      fusion.addOutput(tv);
    }

    const auto options =
        at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

    // Replace the -1 placeholders with concrete extents, consuming
    // {bid_x, tid_x} in order.
    auto concrete_shape = shape;
    std::deque<int64_t> concrete_values = {bid_x, tid_x};
    for (auto& s : concrete_shape) {
      if (s == -1) {
        s = concrete_values.front();
        concrete_values.pop_front();
      }
    }

    at::Tensor aten_input = at::randn(concrete_shape, options);
    std::vector<at::Tensor> aten_outputs;
    for (auto rdims : reduction_dims) {
      aten_outputs.push_back(aten_input.sum(rdims));
    }

    // Run through the cache so segmentation/scheduling is chosen
    // automatically.
    FusionExecutorCache executor_cache(std::move(fusion_ptr));
    auto cg_outputs = executor_cache.runFusionWithInputs({aten_input});

    testValidate(
        &fusion,
        cg_outputs,
        {aten_input},
        aten_outputs,
        __LINE__,
        __FILE__,
        "");
  }
}
7666
7667// Simple reduction parallelized on a symbolic size.
TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
  fusion.addOutput(tv1);

  // Interface should just be a direct split with a Parallel type. We can
  // include the parallelize call if we do this.
  // Split by the runtime-bound TIDx extent rather than a compile-time
  // constant; the concrete size comes from LaunchParams below.
  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1]
  // tv1[I0, R1oi{4}, R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv1, 1);

  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  int numel_x = 65000;
  int numel_y = 1025;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
  auto aten_output = aten_input.to(at::kDouble).sum({1});

  // How many threads to use for the block reduction
  int runtime_threadIdx_dim = 128;

  // Only bdimx is specified; the other launch dimensions stay inferred (-1).
  LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input}, lparams);
  auto cg_outputs = fe.runFusion({aten_input}, lparams);

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input},
      {aten_output},
      __LINE__,
      __FILE__,
      "",
      lparams);
}
7725
7726TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
7727 const std::vector<int> red_dims = {0, 2};
7728 // Copy is because CodeGen requires int and Pytorch requires int64_t
7729 // for a vector of reduction dimensions
7730 const std::vector<int64_t> red_dims64 = {0, 2};
7731 const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
7732 const std::vector<int64_t> tensor_dims_out = {10, 20};
7733
7734 Fusion fusion;
7735 FusionGuard fg(&fusion);
7736
7737 // Set up your input tensor views
7738 TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
7739 fusion.addInput(tv0);
7740
7741 TensorView* tv1 = reductionOp(
7742 BinaryOpType::Add, red_dims, IrBuilder::create<Double>(0), tv0);
7743 fusion.addOutput(tv1);
7744
7745 const auto options =
7746 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7747 at::Tensor aten_input = at::randn(tensor_dims_in, options);
7748 auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
7749 at::Tensor cg_output = at::empty(tensor_dims_out, options);
7750
7751 // Apply reduction heuristic
7752 auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
7753 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
7754 scheduleReduction(&fusion, *reduction_params);
7755 auto lparams = reduction_params->lparams;
7756
7757 FusionExecutor fe;
7758 fe.compileFusion(&fusion, {aten_input}, lparams);
7759 fe.runFusion({aten_input}, {cg_output}, lparams);
7760
7761 testValidate(
7762 &fusion,
7763 {cg_output},
7764 {aten_input},
7765 {aten_output},
7766 __LINE__,
7767 __FILE__,
7768 "",
7769 lparams);
7770}
7771
7772TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
7773 const std::vector<int> red_dims = {1, 3};
7774 // Copy is because CodeGen requires int and Pytorch requires int64_t
7775 // for a vector of reduction dimensions
7776 const std::vector<int64_t> red_dims64 = {1, 3};
7777 const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
7778
7779 Fusion fusion;
7780 FusionGuard fg(&fusion);
7781
7782 // Set up your input tensor views
7783 TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
7784 fusion.addInput(tv0);
7785
7786 TensorView* tv1 = reductionOp(
7787 BinaryOpType::Add, red_dims, IrBuilder::create<Double>(0), tv0);
7788 fusion.addOutput(tv1);
7789
7790 const auto options =
7791 at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
7792 at::Tensor aten_input = at::randn(tensor_dims_in, options);
7793 auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
7794
7795 auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
7796 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
7797 scheduleReduction(&fusion, *reduction_params);
7798 auto lparams = reduction_params->lparams;
7799
7800 FusionExecutor fe;
7801 fe.compileFusion(&fusion, {aten_input}, lparams);
7802 auto cg_outputs = fe.runFusion({aten_input}, lparams);
7803
7804 testValidate(
7805 &fusion,
7806 cg_outputs,
7807 {aten_input},
7808 {aten_output},
7809 __LINE__,
7810 __FILE__,
7811 "",
7812 lparams);
7813}
7814
TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) {
  // Sweep 1-D (no outer dim) reductions across dtypes and sizes.
  std::vector<DataType> dtypes = {
      DataType::Double, DataType::Float, DataType::Half};
  // TODO: add test for complex. Currently complex fails with the following
  // NVRTC compilation error message:
  //  error: no suitable user-defined conversion from
  //  "CudaCodeGen::std::complex<double>" to "CudaCodeGen::std::complex<float>"
  //  exists
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  // BFloat16 is only exercised on compute capability 8.0+ with CUDA 11+.
  if (at::cuda::getDeviceProperties(0)->major >= 8) {
    dtypes.insert(dtypes.end(), DataType::BFloat16);
  }
#endif

  std::vector<int> red_dims;

  // Tried to cut down the number iterations with just
  // doing every other power of 2.
  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
    red_dims.push_back(i);
  }

  for (auto dtype : dtypes) {
    at::ScalarType aten_dtype = data_type_to_aten(dtype);
    for (auto& rdim : red_dims) {
      Fusion fusion;
      FusionGuard fg(&fusion);

      bool is_fp16 = dtype == DataType::Half;
      bool is_bf16 = dtype == DataType::BFloat16;

      TensorView* tv0 = makeSymbolicTensor(1, dtype);
      fusion.addInput(tv0);

      // Reduced-precision inputs are upcast to float for the reduction...
      TensorView* tv0_cast = tv0;
      if (is_fp16 || is_bf16) {
        tv0_cast = castOp(DataType::Float, tv0);
      }

      TensorView* tv1 = sum(tv0_cast, {0});

      // ...and the result is cast back to the original dtype.
      TensorView* tv1_cast = tv1;
      if (is_fp16) {
        tv1_cast = castOp(DataType::Half, tv1);
      }
      if (is_bf16) {
        tv1_cast = castOp(DataType::BFloat16, tv1);
      }

      fusion.addOutput(tv1_cast);

      auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);

      at::Tensor aten_input = at::randn({rdim}, options);
      // Reference accumulated in double for accuracy across dtypes.
      auto aten_output = aten_input.to(at::kDouble).sum({0});

      auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
      TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!");
      scheduleReduction(&fusion, *reduction_params);
      auto lparams = reduction_params->lparams;

      FusionExecutor fe;
      fe.compileFusion(&fusion, {aten_input}, lparams);
      auto cg_outputs = fe.runFusion({aten_input}, lparams);

      testValidate(
          &fusion,
          cg_outputs,
          {aten_input},
          {aten_output},
          __LINE__,
          __FILE__,
          "",
          lparams);
    }
  }
}
7892
TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
  // Sweep 2-D reductions across dtype x reduction-axis x output-size x
  // reduction-size.
  std::vector<DataType> dtypes = {
      DataType::Double, DataType::Float, DataType::Half};
  // TODO: add complex support. Currently, complex fails with the following
  // NVRTC compilation error:
  //   error: no instance of overloaded function "__shfl_xor_sync" matches the
  //   argument list
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  // BFloat16 is only exercised on compute capability 8.0+ with CUDA 11+.
  if (at::cuda::getDeviceProperties(0)->major >= 8) {
    dtypes.insert(dtypes.end(), DataType::BFloat16);
  }
#endif

  std::vector<int> red_axis = {1, 0};
  std::vector<int> output_dims = {160, 320};
  std::vector<int> red_dims;

  // Tried to cut down the number iterations with just
  // doing every other power of 2.
  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
    red_dims.push_back(i);
  }

  for (auto dtype : dtypes) {
    at::ScalarType aten_dtype = data_type_to_aten(dtype);
    for (auto& axis : red_axis) {
      for (auto& odim : output_dims) {
        for (auto& rdim : red_dims) {
          Fusion fusion;
          FusionGuard fg(&fusion);

          bool is_fp16 = dtype == DataType::Half;
          bool is_bf16 = dtype == DataType::BFloat16;

          TensorView* tv0 = makeSymbolicTensor(2, dtype);
          fusion.addInput(tv0);

          // Reduced-precision inputs are upcast to float for the reduction...
          TensorView* tv0_cast = tv0;
          if (is_fp16 || is_bf16) {
            tv0_cast = castOp(DataType::Float, tv0);
          }

          TensorView* tv1 = sum(tv0_cast, {axis});

          // ...and the result is cast back to the original dtype.
          TensorView* tv1_cast = tv1;
          if (is_fp16) {
            tv1_cast = castOp(DataType::Half, tv1);
          }
          if (is_bf16) {
            tv1_cast = castOp(DataType::BFloat16, tv1);
          }
          fusion.addOutput(tv1_cast);

          auto options =
              at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);

          // Place the reduction extent on the axis being reduced.
          at::Tensor aten_input =
              (axis ? at::randn({odim, rdim}, options)
                    : at::randn({rdim, odim}, options));

          auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
          TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!");
          scheduleReduction(&fusion, *reduction_params);
          auto lparams = reduction_params->lparams;

          FusionExecutor fe;
          fe.compileFusion(&fusion, {aten_input}, lparams);
          auto cg_outputs = fe.runFusion({aten_input}, lparams);
          // Reference accumulated in double for accuracy across dtypes.
          auto aten_output = aten_input.to(at::kDouble).sum({axis});
          testValidate(
              &fusion,
              cg_outputs,
              {aten_input},
              {aten_output},
              __LINE__,
              __FILE__,
              "",
              lparams);
        }
      }
    }
  }
}
7976
7977TEST_F(NVFuserTest, FusionCacheBefore_CUDA) {
7978 // TVM Cache Write
7979 Fusion fusion;
7980 FusionGuard fg(&fusion);
7981
7982 TensorView* tv0 = makeSymbolicTensor(2);
7983 TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.0));
7984 TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(3.0));
7985 fusion.addInput(tv0);
7986 fusion.addOutput(tv2);
7987
7988 // Before: TV2 = TV1 * 3
7989 // After: TV3 = TV1 * 3;
7990 // TV2 = TV3;
7991 TensorView* tv3 = tv2->cacheBefore();
7992
7993 constexpr int BSX = 32;
7994 tv2->split(-1, BSX);
7995 tv0->computeAt(tv2, -1);
7996
7997 // Thread and Block binding
7998 tv2->axis(0)->parallelize(ParallelType::BIDx);
7999 tv2->axis(-1)->parallelize(ParallelType::TIDx);
8000
8001 constexpr int M = 32, N = 750;
8002
8003 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8004 at::Tensor aten_input = at::randn({M, N}, options);
8005 at::Tensor aten_output = (aten_input + 1.0) * 3.0;
8006
8007 FusionExecutor fe;
8008 fe.compileFusion(&fusion, {aten_input});
8009 auto cg_outputs = fe.runFusion({aten_input});
8010
8011 testValidate(
8012 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
8013}
8014
8015TEST_F(NVFuserTest, FusionCacheAfter_CUDA) {
8016 // TVM Cache Read
8017 Fusion fusion;
8018 FusionGuard fg(&fusion);
8019
8020 TensorView* tv0 = makeSymbolicTensor(2);
8021 TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.0));
8022 TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(3.0));
8023 fusion.addInput(tv0);
8024 fusion.addOutput(tv2);
8025
8026 // Before: TV1 = TV0 + 1
8027 // After: TV3 = TV0;
8028 // TV1 = TV3 + 1
8029 TensorView* tv3 = tv0->cacheAfter();
8030
8031 constexpr int BSX = 32;
8032 tv2->split(-1, BSX);
8033 tv0->computeAt(tv2, -1);
8034
8035 // Thread and Block binding
8036 tv2->axis(0)->parallelize(ParallelType::BIDx);
8037 tv2->axis(-1)->parallelize(ParallelType::TIDx);
8038
8039 constexpr int M = 32, N = 457;
8040
8041 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8042 at::Tensor aten_input = at::randn({M, N}, options);
8043 at::Tensor aten_output = (aten_input + 1.0) * 3.0;
8044
8045 FusionExecutor fe;
8046 fe.compileFusion(&fusion, {aten_input});
8047 auto cg_outputs = fe.runFusion({aten_input});
8048
8049 testValidate(
8050 &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
8051}
8052
TEST_F(NVFuserTest, FusionCacheFork_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1.0));
  TensorView* tv2 = mul(tv1, IrBuilder::create<Double>(3.0));
  fusion.addInput(tv0);
  fusion.addOutput(tv1);
  fusion.addOutput(tv2);
  // Before: TV1 = TV0 + 1
  //         TV2 = TV1 * 3
  // Output: TV1, TV2

  // After:  TV1 = TV0 + 1
  //         TV3 = TV1
  //         TV2 = TV1 * 3
  // Output: TV3, TV2

  // cacheFork !!does not!! automatically apply ComputeAt to the cache
  auto tv3 = tv1->cacheFork();

  constexpr int BSX = 32;
  tv2->split(-1, BSX);
  tv0->computeAt(tv2, -1);

  // Thread and Block binding
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 32, N = 457;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({M, N}, options);
  at::Tensor aten_output1 = aten_input + 1.0;
  at::Tensor aten_output2 = aten_output1 * 3.0;

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  auto cg_outputs = fe.runFusion({aten_input});

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input},
      {aten_output1, aten_output2},
      __LINE__,
      __FILE__);
}
8102
8103TEST_F(NVFuserTest, FusionCacheIndirect_CUDA) {
8104 Fusion fusion;
8105 FusionGuard fg(&fusion);
8106
8107 TensorView* tv0 = makeSymbolicTensor(2);
8108 TensorView* tv1 = makeSymbolicTensor(2);
8109 TensorView* tv2 = makeSymbolicTensor(2);
8110 TensorView* tv3 = makeSymbolicTensor(2);
8111 TensorView* tv4 = sub(tv2, tv3);
8112 TensorView* tv5 = add(tv1, tv4);
8113 TensorView* tv6 = sub(tv5, tv0);
8114 fusion.addInput(tv0);
8115 fusion.addInput(tv1);
8116 fusion.addInput(tv2);
8117 fusion.addInput(tv3);
8118 fusion.addOutput(tv6);
8119 // t6 = ((t1 + (t2 - t3)) - t0)
8120
8121 tv5->cacheAfter();
8122 tv5->cacheBefore();
8123
8124 // cacheAfter on inputs placed before schedule
8125 constexpr int BSX = 32;
8126 tv6->split(-1, BSX);
8127 tv2->computeAt(tv6, -1);
8128
8129 // Thread and Block binding
8130 tv6->axis(0)->parallelize(ParallelType::BIDx);
8131 tv6->axis(-1)->parallelize(ParallelType::TIDx);
8132
8133 constexpr int M = 32, N = 810;
8134
8135 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8136 at::Tensor t0 = at::randn({M, N}, options);
8137 at::Tensor t1 = at::randn({M, N}, options);
8138 at::Tensor t2 = at::randn({M, N}, options);
8139 at::Tensor t3 = at::randn({M, N}, options);
8140
8141 std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
8142 at::Tensor aten_output = (t1 + (t2 - t3)) - t0;
8143
8144 FusionExecutor fe;
8145 fe.compileFusion(&fusion, aten_inputs);
8146 auto cg_outputs = fe.runFusion(aten_inputs);
8147
8148 testValidate(
8149 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
8150}
8151
TEST_F(NVFuserTest, FusionCacheBcast_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Outer product via broadcasts: tv4[M, N] = tv0[M] * tv2[N].
  // Algorithm
  TensorView* tv0 = makeSymbolicTensor(1); // (M)
  TensorView* tv1 = broadcast(tv0, {false, true}); // (M, b)
  TensorView* tv2 = makeSymbolicTensor(1); // (N)
  TensorView* tv3 = broadcast(tv2, {true, false}); // (b, N)
  TensorView* tv4 = mul(tv1, tv3);
  fusion.addInput(tv0);
  fusion.addInput(tv2);
  fusion.addOutput(tv4);

  // Case 1: cache-read of an input.
  tv0->cacheAfter();

  // Case 2: cache-write before a broadcast.
  tv1->cacheBefore();

  // Case 3: cache-read after a broadcast.
  tv1->cacheAfter();

  // Case 4: cache-write of the fusion output.
  TensorView* tv8 = tv4->cacheBefore();

  constexpr int BSX = 128;
  tv4->split(0, BSX);
  tv4->split(-1, BSX);
  tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
  // M/BSX, N/BSY, BSX, BSY
  tv0->computeAt(tv4, 2);
  tv2->computeAt(tv4, 2);
  // 0, 1 | 2, 3, 4

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::BIDy);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Replay on TV3
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv8->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 92, N = 500;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M}, options);
  at::Tensor t1 = at::randn({N}, options);
  std::vector<IValue> aten_inputs = {t0, t1};
  // Reference outer product computed in double precision.
  at::Tensor aten_output =
      t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0));

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
8210
// Two independent add chains consume the same input tv0; each chain's first
// intermediate (tv1, tv3) gets a cacheBefore placed in shared memory. Both
// outputs compute (input + 1) + 2, so one aten reference is validated
// against both fusion outputs.
TEST_F(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(1);
  TensorView* tv1 = add(tv0, IrBuilder::create<Double>(1));
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(2));
  TensorView* tv3 = add(tv0, IrBuilder::create<Double>(1));
  TensorView* tv4 = add(tv3, IrBuilder::create<Double>(2));

  fusion.addInput(tv0);
  fusion.addOutput(tv2);
  fusion.addOutput(tv4);

  auto tv5 = tv1->cacheBefore();
  auto tv6 = tv3->cacheBefore();
  tv5->setMemoryType(MemoryType::Shared);
  tv6->setMemoryType(MemoryType::Shared);

  tv1->computeAt(tv2, -1);
  tv3->computeAt(tv4, -1);

  // Fails because tensor must be recomputed twice
  // auto tv7 = tv0->cacheAfter();

  constexpr int N = 800;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({N}, options);
  auto aten_output = (aten_input + 1) + 2;

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  auto cg_outputs = fe.runFusion({aten_input});

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input},
      {aten_output, aten_output},
      __LINE__,
      __FILE__);
}
8254
8255TEST_F(NVFuserTest, FusionSmem_CUDA) {
8256 Fusion fusion;
8257 FusionGuard fg(&fusion);
8258
8259 // Algorithm
8260 TensorView* tv0 = makeSymbolicTensor(2); // (M, N)
8261 TensorView* tv1 = makeSymbolicTensor(2); // (M, N)
8262 TensorView* tv2 = mul(tv0, tv1);
8263 fusion.addInput(tv0);
8264 fusion.addInput(tv1);
8265 fusion.addOutput(tv2);
8266
8267 // Schedule
8268 TensorView* tv3 = tv0->cacheAfter();
8269 TensorView* tv4 = tv1->cacheAfter();
8270 tv3->setMemoryType(MemoryType::Shared);
8271 tv4->setMemoryType(MemoryType::Shared);
8272
8273 constexpr int BSY = 32;
8274 constexpr int BSX = 128;
8275 tv2->split(0, BSY);
8276 tv2->split(2, BSX);
8277 // M/BSX, BSX, N/BSX, BSX
8278 tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
8279 // M/BSX, N/BSX, BSX, BSX
8280
8281 tv0->computeAt(tv2, 2);
8282 tv1->computeAt(tv2, 2);
8283
8284 // Thread and Block binding
8285 tv2->axis(0)->parallelize(ParallelType::BIDx);
8286 tv2->axis(1)->parallelize(ParallelType::BIDy);
8287 tv2->axis(-1)->parallelize(ParallelType::TIDx);
8288 // Manual Binding
8289 tv3->axis(-1)->parallelize(ParallelType::TIDx);
8290 tv4->axis(-1)->parallelize(ParallelType::TIDx);
8291
8292 constexpr int M = 128, N = 10240;
8293
8294 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8295 at::Tensor t0 = at::randn({M, N}, options);
8296 at::Tensor t1 = at::randn({M, N}, options);
8297 at::Tensor aten_output = mul(t0, t1);
8298
8299 std::vector<IValue> aten_inputs = {t0, t1};
8300
8301 FusionExecutor fe;
8302 fe.compileFusion(&fusion, {t0, t1});
8303 auto cg_outputs = fe.runFusion({t0, t1});
8304
8305 testValidate(
8306 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
8307
8308 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
8309}
8310
// Sum-reduction over the middle (K) axis of an (M, K, N) tensor, with the
// input staged in shared memory via cacheAfter and the reduction split and
// rFactor-ed. Also checks that no WAR-hazard syncs were inserted.
TEST_F(NVFuserTest, FusionSmemReduce_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
  TensorView* tv1 = sum(tv0, {1}); // M, R, N
  fusion.addInput(tv0);
  fusion.addOutput(tv1);

  // Stage the input in shared memory
  TensorView* tv2 = tv0->cacheAfter();
  tv2->setMemoryType(MemoryType::Shared);

  // Schedule
  constexpr int BSX = 32;
  tv1->split(2, BSX);
  tv1->split(1, 128);
  tv1->split(0, BSX);
  // M/BSX, BSX, K/128, 128, N/BSX, BSX
  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
  TensorView* tv3 = tv1->rFactor({-2});

  tv0->computeAt(tv1, -2);
  tv0->computeAt(tv3, -2);

  // Thread and Block binding
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::BIDy);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 154, K = 45, N = 1524;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({M, K, N}, options);
  at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1});

  FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input});
  auto cg_outputs = fe.runFusion({aten_input});

  testValidate(
      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
}
8358
8359TEST_F(NVFuserTest, FusionSmemBlockGemm_CUDA) {
8360 Fusion fusion;
8361 FusionGuard fg(&fusion);
8362
8363 // Algorithm
8364 TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
8365 TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
8366 TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
8367 TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
8368 TensorView* tv4 = mul(tv2, tv3); // M, K, N
8369 TensorView* tv5 = sum(tv4, {1}); // M, R, N
8370 fusion.addInput(tv0);
8371 fusion.addInput(tv1);
8372 fusion.addOutput(tv5);
8373
8374 // Schedule
8375 constexpr int BSX = 16;
8376 tv5->split(2, BSX - 1);
8377 tv5->split(1, BSX);
8378 tv5->split(0, BSX + 1);
8379 // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
8380 tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
8381 // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
8382 TensorView* tv6 = tv5->rFactor({-1});
8383
8384 tv2->setMemoryType(MemoryType::Shared);
8385 tv3->setMemoryType(MemoryType::Shared);
8386 tv4->setMemoryType(MemoryType::Shared);
8387 tv6->setMemoryType(MemoryType::Shared);
8388
8389 tv0->computeAt(tv5, 3);
8390 tv1->computeAt(tv5, 3);
8391
8392 // Thread and Block binding
8393 tv5->axis(0)->parallelize(ParallelType::BIDx);
8394 tv5->axis(1)->parallelize(ParallelType::BIDy);
8395 tv5->axis(-2)->parallelize(ParallelType::TIDy);
8396 tv5->axis(-1)->parallelize(ParallelType::TIDx);
8397 // Manual Binding
8398 tv2->axis(-3)->parallelize(ParallelType::TIDy);
8399 tv2->axis(-1)->parallelize(ParallelType::TIDx);
8400 tv3->axis(-1)->parallelize(ParallelType::TIDx);
8401 tv4->axis(-3)->parallelize(ParallelType::TIDy);
8402 tv4->axis(-1)->parallelize(ParallelType::TIDx);
8403 tv6->axis(-3)->parallelize(ParallelType::TIDy);
8404 tv6->axis(-2)->parallelize(ParallelType::TIDx);
8405
8406 // Make sure BIDx is makred as exact (see issue #1119)
8407 GpuLower gpulw(&fusion);
8408 TORCH_CHECK(gpulw.parallelDimensionMap().isExact(ParallelType::BIDx));
8409
8410 constexpr int M = 154, K = 45, N = 1524;
8411
8412 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8413 at::Tensor t0 = at::randn({M, K}, options);
8414 at::Tensor t1 = at::randn({K, N}, options);
8415
8416 std::vector<IValue> aten_inputs = {t0, t1};
8417 at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
8418
8419 FusionExecutor fe;
8420 fe.compileFusion(&fusion, {t0, t1});
8421 auto cg_outputs = fe.runFusion({t0, t1});
8422
8423 testValidate(
8424 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
8425
8426 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
8427}
8428
8429TEST_F(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
8430 Fusion fusion;
8431 FusionGuard fg(&fusion);
8432
8433 // Algorithm
8434 TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
8435 TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
8436 TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
8437 TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
8438 TensorView* tv4 = mul(tv2, tv3); // M, K, N
8439 TensorView* tv5 = sum(tv4, {1}); // M, R, N
8440 fusion.addInput(tv0);
8441 fusion.addInput(tv1);
8442 fusion.addOutput(tv5);
8443
8444 // Schedule
8445 // Remove reduction axis from tv5
8446 // tv6 = (M, R, N)
8447 // tv5 = (M, N)
8448 TensorView* tv6 = tv5->cacheBefore();
8449
8450 constexpr int BSX = 16;
8451 tv5->split(1, BSX);
8452 tv5->split(0, BSX);
8453 // M/BSX, BSX, N/BSX, BSX
8454 tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
8455 // tv5 = M/BSX, N/BSX, MSX, NSX
8456
8457 tv6->computeAt(tv5, 2);
8458 tv6->computeAt(tv5, 2);
8459
8460 tv6->split(-1, BSX);
8461 // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
8462 tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
8463 // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
8464 TensorView* tv7 = tv6->rFactor({-1});
8465 // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
8466 // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX
8467
8468 tv0->computeAt(tv6, 3);
8469 tv1->computeAt(tv6, 3);
8470
8471 tv0->computeAt(tv7, 3);
8472 tv1->computeAt(tv7, 3);
8473
8474 tv2->setMemoryType(MemoryType::Shared);
8475 tv3->setMemoryType(MemoryType::Shared);
8476 tv4->setMemoryType(MemoryType::Shared);
8477 tv6->setMemoryType(MemoryType::Shared);
8478 tv7->setMemoryType(MemoryType::Shared);
8479 // Memory Type
8480
8481 // Thread and Block binding
8482 tv5->axis(0)->parallelize(ParallelType::BIDx);
8483 tv5->axis(1)->parallelize(ParallelType::BIDy);
8484 tv5->axis(-2)->parallelize(ParallelType::TIDy);
8485 tv5->axis(-1)->parallelize(ParallelType::TIDx);
8486 // Manual Binding
8487 tv2->axis(-3)->parallelize(ParallelType::TIDy);
8488 tv2->axis(-1)->parallelize(ParallelType::TIDx);
8489 tv3->axis(-1)->parallelize(ParallelType::TIDx);
8490 tv4->axis(-3)->parallelize(ParallelType::TIDy);
8491 tv4->axis(-1)->parallelize(ParallelType::TIDx);
8492
8493 tv7->axis(-3)->parallelize(ParallelType::TIDy);
8494 tv7->axis(-2)->parallelize(ParallelType::TIDx);
8495
8496 tv6->axis(-2)->parallelize(ParallelType::TIDy);
8497 tv6->axis(-1)->parallelize(ParallelType::TIDx);
8498
8499 constexpr int M = 154, K = 45, N = 1524;
8500
8501 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8502 at::Tensor t0 = at::randn({M, K}, options);
8503 at::Tensor t1 = at::randn({K, N}, options);
8504 at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
8505
8506 std::vector<IValue> aten_inputs = {t0, t1};
8507
8508 FusionExecutor fe;
8509 fe.compileFusion(&fusion, aten_inputs);
8510 auto cg_outputs = fe.runFusion(aten_inputs);
8511
8512 testValidate(
8513 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
8514
8515 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
8516}
8517
// Hand-scheduled persistent softmax over rows of a 2D tensor. The inner
// split factor is a scalar Int fusion input (tidx), bound to 128 at run
// time, so the block size is dynamic. The cached input and the exp
// intermediate are kept in shared memory across the two row reductions
// (max and sum).
TEST_F(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* x = makeSymbolicTensor(2);
  fusion.addInput(x);
  // Numerically-stable softmax: subtract the row max before exponentiating.
  TensorView* max_val = reductionOp(
      BinaryOpType::Max,
      {-1},
      IrBuilder::create<Double>(std::numeric_limits<float>::lowest()),
      x); // (M)
  TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
  TensorView* x_max_sub = sub(x, bcast_max); // (M, N)
  TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N)
  TensorView* sum_exp = sum(exp, {-1}); // (M, R)
  TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
  TensorView* softmax = div(exp, bcast_sum); // (M, N)
  fusion.addOutput(softmax);

  // Read Input into Shared Memory
  // Load Input + Pwise into shared memory
  auto cache_x = x->cacheAfter();
  cache_x->setMemoryType(MemoryType::Shared);
  exp->setMemoryType(MemoryType::Shared);

  std::vector<TensorView*> all_tensors(
      {x,
       cache_x,
       max_val,
       bcast_max,
       x_max_sub,
       exp,
       sum_exp,
       bcast_sum,
       softmax});

  // The inner split factor is a runtime scalar fusion input.
  auto tidx = IrBuilder::create<Int>();
  fusion.addInput(tidx);

  for (auto tensor : all_tensors) {
    tensor->split(-1, tidx);
  }

  auto sum_exp_rf = sum_exp->rFactor({1});
  all_tensors.push_back(sum_exp_rf);

  // computeAt
  x->computeAt(x_max_sub, 1);
  exp->computeAt(softmax, 1);
  x_max_sub->computeAt(exp, 2);

  softmax->axis(0)->parallelize(ParallelType::BIDx);
  for (auto tensor : all_tensors) {
    tensor->axis(-1)->parallelize(ParallelType::TIDx);
  }

  const int64_t dimx = 1024;
  const int64_t dimy = 4096;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({dimx, dimy}, options);
  auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);

  torch::jit::fuser::cuda::FusionExecutor fe;
  // The trailing 128 binds the tidx scalar input.
  fe.compileFusion(&fusion, {aten_input, 128});
  auto cg_outputs = fe.runFusion({aten_input, 128});

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input, 128},
      {aten_output},
      __LINE__,
      __FILE__);
}
8592
8593TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) {
8594 Fusion fusion;
8595 FusionGuard fg(&fusion);
8596
8597 const int kReductionAxis = 3;
8598 std::vector<int64_t> input_shape{10, 10, 10, 67};
8599 TensorView* input = makeSymbolicTensor(input_shape.size());
8600 fusion.addInput(input);
8601
8602 auto output = softmax(input, kReductionAxis);
8603
8604 fusion.addOutput(output);
8605
8606 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8607 at::Tensor aten_input = at::randn(input_shape, options);
8608 auto aten_output =
8609 at::_softmax(aten_input.to(at::kDouble), kReductionAxis, false);
8610
8611 auto reduction_params = getPersistentHeuristics(&fusion, {aten_input});
8612 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
8613
8614 schedulePersistentKernel(&fusion, *reduction_params);
8615
8616 auto lparams = reduction_params->lparams;
8617
8618 torch::jit::fuser::cuda::FusionExecutor fe;
8619 fe.compileFusion(&fusion, {aten_input}, lparams);
8620 auto cg_outputs = fe.runFusion({aten_input}, lparams);
8621
8622 testValidate(
8623 &fusion,
8624 cg_outputs,
8625 {aten_input},
8626 {aten_output},
8627 __LINE__,
8628 __FILE__,
8629 "",
8630 lparams);
8631}
8632
TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) {
  // This test is testing the usage of all padding tokens
  // with softmax like Bert might use in a full padding
  // sequence.
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int kReductionAxis = 3;
  std::vector<int64_t> input_shape{256, 16, 128, 128};
  TensorView* input = makeSymbolicTensor(input_shape.size());
  TensorView* mask = makeSymbolicTensor(input_shape.size());
  fusion.addInput(input);
  fusion.addInput(mask);

  // Additive attention mask applied before the softmax reduction.
  auto out1 = add(input, mask);
  auto output = softmax(out1, kReductionAxis);

  fusion.addOutput(output);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn(input_shape, options);
  at::Tensor aten_mask = at::ones(input_shape, options);
  // -10,000 is used here as a magic number because the padding
  // tokens need to be a value that gives a value close to zero
  // as to not influence softmax. Bert, in particular, does
  // not use -Infinity because sometimes it will have a
  // softmax of all padding tokens that can result in a divide by
  // zero that creates a NaN result.
  aten_mask = aten_mask * -10000.0;
  auto aten_out1 = aten_input + aten_mask;
  auto aten_output = at::_softmax(aten_out1, kReductionAxis, false);

  auto reduction_params =
      getPersistentHeuristics(&fusion, {aten_input, aten_mask});
  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");

  schedulePersistentKernel(&fusion, *reduction_params);

  auto lparams = reduction_params->lparams;

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams);
  auto cg_outputs = fe.runFusion({aten_input, aten_mask}, lparams);

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input, aten_mask},
      {aten_output},
      __LINE__,
      __FILE__,
      "",
      lparams);
}
8687
// Runs the layer_norm_backward composite op through FusionExecutorCache and
// validates grad_input / grad_weight / grad_bias against
// at::native_layer_norm_backward. The saved mean/rstd passed to the fusion
// come from a forward at::native_layer_norm call on the same input.
TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  std::vector<int64_t> shape{20, 100, 35, 67};
  std::vector<int64_t> norm_shape{67};

  const size_t kM = shape.size();
  const size_t kN = norm_shape.size();
  const size_t kOuterNumDims = kM - kN;

  // mean/rstd have the outer dims of shape, with the normalized dims
  // collapsed to size-1 (second loop pads 1s; its idx is unused).
  std::vector<int64_t> outer_shape;
  for (const auto idx : c10::irange(kOuterNumDims)) {
    outer_shape.push_back(shape[idx]);
  }
  for (const auto idx : c10::irange(kOuterNumDims, kM)) {
    outer_shape.push_back(1);
  }

  auto grad_out = makeSymbolicTensor(shape.size());
  auto input = makeSymbolicTensor(shape.size());
  auto mean = makeConcreteTensor(outer_shape);
  auto rstd = makeConcreteTensor(outer_shape);
  auto weight = makeSymbolicTensor(norm_shape.size());
  auto bias = makeSymbolicTensor(norm_shape.size());
  fusion.addInput(grad_out);
  fusion.addInput(input);
  fusion.addInput(mean);
  fusion.addInput(rstd);
  fusion.addInput(weight);
  fusion.addInput(bias);

  // Request all three gradients.
  auto grads = layer_norm_backward(
      grad_out,
      input,
      norm_shape,
      mean,
      rstd,
      weight,
      bias,
      {true, true, true});

  fusion.addOutput(grads.grad_input);
  fusion.addOutput(grads.grad_weight);
  fusion.addOutput(grads.grad_bias);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_grad_out = at::randn(shape, options);
  at::Tensor aten_input = at::randn(shape, options);
  at::Tensor aten_weight = at::randn(norm_shape, options);
  at::Tensor aten_bias = at::randn(norm_shape, options);
  auto at_weight = c10::optional<at::Tensor>(aten_weight);
  auto at_bias = c10::optional<at::Tensor>(aten_bias);

  // Forward pass only to obtain the saved mean/rstd consumed by backward.
  const float kEps = 1e-5;
  auto aten_results =
      at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
  auto aten_output = std::get<0>(aten_results);
  auto aten_mean = std::get<1>(aten_results);
  auto aten_rstd = std::get<2>(aten_results);

  FusionExecutorCache fec(std::move(fusion_ptr));
  std::vector<IValue> aten_inputs = {
      aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);

  // Double-precision aten reference for all three gradients.
  auto aten_gradients = at::native_layer_norm_backward(
      aten_grad_out.to(at::kDouble),
      aten_input.to(at::kDouble),
      norm_shape,
      aten_mean.to(at::kDouble),
      aten_rstd.to(at::kDouble),
      c10::optional<at::Tensor>(aten_weight.to(at::kDouble)),
      c10::optional<at::Tensor>(aten_bias.to(at::kDouble)),
      {true, true, true});

  testValidate(
      &fusion,
      cg_outputs,
      aten_inputs,
      {std::get<0>(aten_gradients),
       std::get<1>(aten_gradients),
       std::get<2>(aten_gradients)},
      __LINE__,
      __FILE__);
}
8775
// Runs the rms_norm_backward composite op through FusionExecutorCache and
// validates grad_input / grad_weight against a manually constructed aten
// reference (there is no at:: rms_norm_backward to compare against).
TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormBackward_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);
  const int64_t NORM_SIZE = 1024;
  std::vector<int64_t> shape{8, 56, NORM_SIZE};
  std::vector<int64_t> norm_shape{NORM_SIZE};

  const size_t kM = shape.size();
  const size_t kN = norm_shape.size();
  const size_t kOuterNumDims = kM - kN;

  // rstd has the outer dims of shape with the normalized dims collapsed to
  // size-1 (second loop pads 1s; its idx is unused).
  std::vector<int64_t> outer_shape;
  for (const auto idx : c10::irange(kOuterNumDims)) {
    outer_shape.push_back(shape[idx]);
  }
  for (const auto idx : c10::irange(kOuterNumDims, kM)) {
    outer_shape.push_back(1);
  }

  auto grad_out = makeContigTensor(shape.size());
  auto input = makeContigTensor(shape.size());
  auto rstd = makeConcreteTensor(outer_shape);
  auto weight = makeContigTensor(norm_shape.size());
  fusion.addInput(grad_out);
  fusion.addInput(input);
  fusion.addInput(rstd);
  fusion.addInput(weight);

  auto grads = rms_norm_backward(
      grad_out, input, norm_shape, rstd, weight, {true, true});

  fusion.addOutput(grads.grad_input);
  fusion.addOutput(grads.grad_weight);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_grad_out = at::randn(shape, options);
  at::Tensor aten_input = at::randn(shape, options);
  at::Tensor aten_weight = at::randn(norm_shape, options);
  auto at_weight = c10::optional<at::Tensor>(aten_weight);

  // Manual forward statistics: rstd = (mean(x^2) + eps)^-0.5.
  const float kEps = 1e-6;
  auto pow2 = at::pow(aten_input, 2);
  auto sum = at::sum(pow2, -1, true);
  auto var = at::mul(sum, 1.0 / NORM_SIZE);
  auto aten_rstd = at::pow(at::add(var, kEps), -0.5);

  FusionExecutorCache fec(std::move(fusion_ptr));
  std::vector<IValue> aten_inputs = {
      aten_grad_out, aten_input, aten_rstd, aten_weight};
  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);

  // grad_weight = sum over outer dims of (grad_out * x * rstd).
  auto in_mul_rstd = at::mul(aten_input, aten_rstd);
  auto grad_out_mul = at::mul(aten_grad_out, in_mul_rstd);
  auto aten_grad_weight = at::sum(grad_out_mul, c10::IntArrayRef{0, 1});
  auto sum_loss1 = at::sum(at::mul(aten_grad_out, aten_weight), -1, true);
  auto sum_loss2 = at::sum(
      at::mul(
          at::mul(at::mul(aten_grad_out, aten_weight), aten_input), aten_rstd),
      -1,
      true);

  // grad_input = rstd/H * (H*grad_out*w - sum_loss1 - x*rstd*sum_loss2).
  const float fH = NORM_SIZE;
  auto term1 = at::mul(aten_rstd, 1.0 / fH);
  auto aten_grad_input = at::mul(at::mul(aten_grad_out, fH), aten_weight);
  aten_grad_input = at::sub(aten_grad_input, sum_loss1);
  aten_grad_input = at::sub(
      aten_grad_input, at::mul(at::mul(aten_input, aten_rstd), sum_loss2));
  aten_grad_input = at::mul(aten_grad_input, term1);
  testValidate(
      &fusion,
      cg_outputs,
      aten_inputs,
      {aten_grad_input, aten_grad_weight},
      __LINE__,
      __FILE__);
}
8853
// Runs the layer_norm composite op (no weight/bias) through
// FusionExecutorCache and validates output, mean, and invstd against
// at::native_layer_norm.
TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  const float kEps = 1e-5;
  Double* eps_ptr = IrBuilder::create<Double>(kEps);

  std::vector<int64_t> input_shape{20, 100, 35, 67};
  std::vector<int64_t> norm_shape{67};

  auto input = makeSymbolicTensor(input_shape.size());
  fusion.addInput(input);

  // nullptr weight/bias: plain normalization without affine transform.
  auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);

  fusion.addOutput(result.output);
  fusion.addOutput(result.mean);
  fusion.addOutput(result.invstd);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn(input_shape, options);
  c10::optional<at::Tensor> aten_weight = c10::nullopt;
  c10::optional<at::Tensor> aten_bias = c10::nullopt;
  auto aten_outputs = at::native_layer_norm(
      aten_input, norm_shape, aten_weight, aten_bias, kEps);

  // Check reduction axis is same for all reductions
  // Generate Launch Parameters
  auto reduction_params = getPersistentHeuristics(&fusion, {aten_input});
  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto cg_outputs = fec.runFusionWithInputs({aten_input});

  testValidate(
      &fusion,
      cg_outputs,
      {aten_input},
      {std::get<0>(aten_outputs),
       std::get<1>(aten_outputs),
       std::get<2>(aten_outputs)},
      __LINE__,
      __FILE__,
      "");
}
8900
8901TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormalization_CUDA) {
8902 std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
8903 Fusion& fusion = *fusion_ptr.get();
8904 FusionGuard fg(&fusion);
8905
8906 int64_t NORM_SIZE = 1024;
8907 const float kEps = 1e-6;
8908 Double* eps_ptr = IrBuilder::create<Double>(kEps);
8909
8910 std::vector<int64_t> input_shape{8, 56, NORM_SIZE};
8911 std::vector<int64_t> norm_shape{NORM_SIZE};
8912
8913 auto input = makeContigTensor(input_shape.size());
8914 fusion.addInput(input);
8915 auto result = rms_norm(input, norm_shape, nullptr, eps_ptr);
8916
8917 fusion.addOutput(result.output);
8918 fusion.addOutput(result.invstd);
8919
8920 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8921 at::Tensor aten_input = at::randn(input_shape, options);
8922 c10::optional<at::Tensor> aten_weight = c10::nullopt;
8923
8924 auto pow2 = at::pow(aten_input, 2);
8925
8926 auto sum = at::sum(pow2, -1, true);
8927 auto var = at::mul(sum, 1.0 / NORM_SIZE);
8928 auto invstd = at::pow(at::add(var, kEps), -0.5);
8929 auto output = at::mul(aten_input, invstd);
8930 //// Check reduction axis is same for all reductions
8931 //// Generate Launch Parameters
8932 auto reduction_params = getPersistentHeuristics(&fusion, {aten_input});
8933 TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
8934
8935 FusionExecutorCache fec(std::move(fusion_ptr));
8936 auto cg_outputs = fec.runFusionWithInputs({aten_input});
8937
8938 testValidate(
8939 &fusion,
8940 cg_outputs,
8941 {aten_input},
8942 {output, invstd},
8943 __LINE__,
8944 __FILE__,
8945 "");
8946}
8947
8948TEST_F(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) {
8949 if (!deviceMajorMinorCheck(7)) {
8950 GTEST_SKIP() << "skipping tests on pre-Volta GPUs";
8951 return;
8952 }
8953 auto fusion = std::make_unique<Fusion>();
8954 FusionGuard fg(fusion.get());
8955
8956 const float kMomentum = 0.1;
8957 const float kEps = 1e-5;
8958 const bool kTraining = true;
8959 std::vector<int64_t> input_shape{20, 100, 35, 45};
8960
8961 auto input = makeSymbolicTensor(input_shape.size());
8962 auto weight = makeSymbolicTensor(1);
8963 auto bias = makeSymbolicTensor(1);
8964 auto running_mean = makeSymbolicTensor(1);
8965 auto running_var = makeSymbolicTensor(1);
8966 fusion->addInput(input);
8967 fusion->addInput(weight);
8968 fusion->addInput(bias);
8969 fusion->addInput(running_mean);
8970 fusion->addInput(running_var);
8971
8972 Double* momentum = IrBuilder::create<Double>(kMomentum);
8973 Double* eps = IrBuilder::create<Double>(kEps);
8974
8975 auto result = batch_norm(
8976 input, weight, bias, running_mean, running_var, kTraining, momentum, eps);
8977
8978 fusion->addOutput(result.output);
8979 fusion->addOutput(result.mean);
8980 fusion->addOutput(result.invstd);
8981
8982 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
8983 auto at_input = at::randn(input_shape, options);
8984 auto at_weight = at::ones({input_shape[1]}, options);
8985 auto at_bias = at::zeros({input_shape[1]}, options);
8986 auto at_run_mean = at::zeros({input_shape[1]}, options);
8987 auto at_run_var = at::ones({input_shape[1]}, options);
8988
8989 std::vector<IValue> aten_inputs = {
8990 at_input, at_weight, at_bias, at_run_mean, at_run_var};
8991
8992 FusionExecutorCache executor_cache(std::move(fusion));
8993
8994 auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
8995
8996 auto aten_outputs = at::native_batch_norm(
8997 at_input,
8998 c10::optional<at::Tensor>(at_weight),
8999 c10::optional<at::Tensor>(at_bias),
9000 c10::optional<at::Tensor>(at_run_mean),
9001 c10::optional<at::Tensor>(at_run_var),
9002 kTraining,
9003 kMomentum,
9004 kEps);
9005
9006 testValidate(
9007 executor_cache.fusion(),
9008 cg_outputs,
9009 aten_inputs,
9010 {std::get<0>(aten_outputs),
9011 std::get<1>(aten_outputs),
9012 std::get<2>(aten_outputs)},
9013 __LINE__,
9014 __FILE__,
9015 "");
9016}
9017
// Runs instance_norm (with running stats, kUseInputStats = true) through
// FusionExecutorCache and validates only the normalized output against
// at::instance_norm; mean/invstd outputs are intentionally disabled below.
TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalization_CUDA) {
  if (!deviceMajorMinorCheck(7)) {
    GTEST_SKIP() << "skipping tests on pre-Volta GPUs";
    return;
  }
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  const bool kUseInputStats = true;
  std::vector<int64_t> input_shape{20, 100, 35, 45};

  auto input = makeSymbolicTensor(input_shape.size());
  auto weight = makeSymbolicTensor(1);
  auto bias = makeSymbolicTensor(1);
  auto running_mean = makeSymbolicTensor(1);
  auto running_var = makeSymbolicTensor(1);
  fusion->addInput(input);
  fusion->addInput(weight);
  fusion->addInput(bias);
  fusion->addInput(running_mean);
  fusion->addInput(running_var);

  Double* momentum = IrBuilder::create<Double>(kMomentum);
  Double* eps = IrBuilder::create<Double>(kEps);

  auto result = instance_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      kUseInputStats,
      momentum,
      eps);

  fusion->addOutput(result.output);
  // fusion->addOutput(result.mean);
  // fusion->addOutput(result.invstd);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto at_input = at::randn(input_shape, options);
  auto at_weight = at::ones({input_shape[1]}, options);
  auto at_bias = at::zeros({input_shape[1]}, options);
  auto at_run_mean = at::zeros({input_shape[1]}, options);
  auto at_run_var = at::ones({input_shape[1]}, options);

  std::vector<IValue> aten_inputs = {
      at_input, at_weight, at_bias, at_run_mean, at_run_var};

  FusionExecutorCache executor_cache(std::move(fusion));

  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
  // NOTE(review): cg_outputs_full is constructed but never used (only
  // cg_outputs is validated below) — kept as-is; remove or use it when the
  // running-stat outputs are re-enabled.
  auto cg_outputs_full = {at_run_mean, at_run_var, cg_outputs[0]};

  auto aten_outputs = at::instance_norm(
      at_input,
      c10::optional<at::Tensor>(at_weight),
      c10::optional<at::Tensor>(at_bias),
      c10::optional<at::Tensor>(at_run_mean),
      c10::optional<at::Tensor>(at_run_var),
      kUseInputStats,
      kMomentum,
      kEps,
      false);

  testValidate(
      executor_cache.fusion(),
      cg_outputs,
      aten_inputs,
      // TODO: can run_mean/run_var be checked here?
      // fusion_outputs.size() == aten_outputs.size() && aten_outputs.size() ==
      // fusion->outputs().size() - output_alias_indices.size()
      {aten_outputs},
      __LINE__,
      __FILE__,
      "");
}
9097
// Exercises the instance-norm backward pass end-to-end: runs a forward
// instance-norm fusion (explicit channels-last layout) to obtain the saved
// mean/invstd, then builds a second fusion for the backward pass and
// validates grad_input/grad_weight/grad_bias against at::instance_norm +
// autograd.
TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalizationBackward_CUDA) {
  if (!deviceMajorMinorCheck(7)) {
    GTEST_SKIP() << "skipping tests on pre-Volta GPUs";
    return;
  }
  auto fusion_forward = std::make_unique<Fusion>();
  FusionGuard fg_forward(fusion_forward.get());

  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  const bool kUseInputStats = true;
  const bool channels_last = true;
  const int B = 2;
  const int C = 5;
  const int S = 3;
  // ATen side uses NCDHW shape; the fusion gets the same data permuted to
  // the explicit channels-last shape below.
  std::vector<int64_t> input_shape{B, C, S, S, S};
  // explicit channels-last for NVFuser
  std::vector<int64_t> nvfuser_input_shape{B, S, S, S, C};

  auto input = makeContigTensor(input_shape.size());
  auto weight = makeContigTensor(1);
  auto bias = makeContigTensor(1);
  fusion_forward->addInput(input);
  fusion_forward->addInput(weight);
  fusion_forward->addInput(bias);

  Double* momentum = IrBuilder::create<Double>(kMomentum);
  Double* eps = IrBuilder::create<Double>(kEps);
  // Running stats are not tracked here (nullptr run_mean/run_var); the
  // forward result carries output, mean, and invstd.
  auto result_forward = instance_norm(
      input,
      weight,
      bias,
      nullptr,
      nullptr,
      kUseInputStats,
      momentum,
      eps,
      channels_last);
  fusion_forward->addOutput(result_forward.output);
  // mean/invstd are outputs so they can be fed to the backward fusion.
  fusion_forward->addOutput(result_forward.mean);
  fusion_forward->addOutput(result_forward.invstd);

  FusionExecutorCache executor_cache_forward(std::move(fusion_forward));

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // ATen inputs require grad so autograd produces the reference gradients;
  // the nvFuser copies are detached, permuted clones of the same data.
  auto at_input = at::randn(input_shape, options)
                      .to(at::MemoryFormat::ChannelsLast3d)
                      .set_requires_grad(true);
  auto at_input_nvfuser = at_input.clone().detach().permute({0, 2, 3, 4, 1});
  auto at_weight = at::ones({input_shape[1]}, options).set_requires_grad(true);
  auto at_weight_nvfuser = at_weight.clone().detach();
  auto at_bias = at::zeros({input_shape[1]}, options).set_requires_grad(true);
  auto at_bias_nvfuser = at_bias.clone().detach();
  std::vector<torch::jit::IValue> aten_inputs_forward = {
      at_input_nvfuser, at_weight_nvfuser, at_bias_nvfuser};
  // out, mean, invstd
  auto outputs_forward =
      executor_cache_forward.runFusionWithInputs(aten_inputs_forward);
  // ATen reference forward (training mode, no running stats).
  auto at_out = at::instance_norm(
      at_input,
      c10::optional<at::Tensor>(at_weight),
      c10::optional<at::Tensor>(at_bias),
      c10::optional<at::Tensor>(c10::nullopt),
      c10::optional<at::Tensor>(c10::nullopt),
      kUseInputStats,
      kMomentum,
      kEps,
      false);
  auto at_grad =
      at::randn(input_shape, options).to(at::MemoryFormat::ChannelsLast3d);
  auto at_grad_nvfuser = at_grad.clone().detach().permute({0, 2, 3, 4, 1});
  // Populate at_input.grad()/at_weight.grad()/at_bias.grad() via autograd.
  at_out.backward(at_grad);
  auto fusion_backward = std::make_unique<Fusion>();
  FusionGuard fg_backward(fusion_backward.get());

  input = makeContigTensor(input_shape.size());
  auto grad_output = makeContigTensor(input_shape.size());
  weight = makeContigTensor(1);
  // mean/invstd from the forward pass are rank-2 (batch x channel).
  auto save_mean = makeContigTensor(2);
  auto save_invstd = makeContigTensor(2);
  auto dummy = makeContigTensor(0);

  fusion_backward->addInput(input);
  fusion_backward->addInput(grad_output);
  fusion_backward->addInput(weight);
  fusion_backward->addInput(dummy); // dummy for run_mean
  fusion_backward->addInput(dummy); // dummy for run_var
  fusion_backward->addInput(save_mean);
  fusion_backward->addInput(save_invstd);

  // NOTE: `eps` was created while fusion_forward was the active fusion;
  // it is reused here for the backward definition.
  auto result_backward = instance_norm_backward(
      input,
      grad_output,
      weight,
      nullptr,
      nullptr,
      save_mean,
      save_invstd,
      kUseInputStats,
      eps,
      {true, true, true},
      channels_last);

  fusion_backward->addOutput(result_backward.grad_input);
  fusion_backward->addOutput(result_backward.grad_weight);
  fusion_backward->addOutput(result_backward.grad_bias);

  FusionExecutorCache executor_cache_backward(std::move(fusion_backward));
  // Empty tensors stand in for the unused run_mean/run_var inputs.
  std::vector<torch::jit::IValue> aten_inputs_backward = {
      at_input_nvfuser,
      at_grad_nvfuser,
      at_weight_nvfuser,
      at::empty({}),
      at::empty({}),
      outputs_forward[1],
      outputs_forward[2]};
  auto outputs_backward =
      executor_cache_backward.runFusionWithInputs(aten_inputs_backward);
  // Permute grad_input back from channels-last to NCDHW for comparison.
  outputs_backward[0] = outputs_backward[0].permute({0, 4, 1, 2, 3});
  testValidate(
      executor_cache_backward.fusion(),
      outputs_backward,
      aten_inputs_backward,
      {at_input.grad(), at_weight.grad(), at_bias.grad()},
      __LINE__,
      __FILE__,
      "");
}
9226
9227TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalShared_CUDA) {
9228 Fusion fusion;
9229 FusionGuard fg(&fusion);
9230
9231 const int pixels_per_thread = 64;
9232 const int TIDX = 128;
9233 const int static_size = pixels_per_thread * TIDX;
9234
9235 TensorView* sx = makeConcreteTensor({-1, static_size});
9236 TensorView* dx = makeSymbolicTensor(2);
9237 fusion.addInput(sx);
9238 fusion.addInput(dx);
9239
9240 TensorView* max_sx = reductionOp(
9241 BinaryOpType::Max,
9242 {-1},
9243 IrBuilder::create<Double>(std::numeric_limits<float>::lowest()),
9244 sx); // (M)
9245 TensorView* max_dx = reductionOp(
9246 BinaryOpType::Max,
9247 {-1},
9248 IrBuilder::create<Double>(std::numeric_limits<float>::lowest()),
9249 dx); // (M)
9250
9251 // Reduction => merge local and shared memory TensorViews
9252 TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx);
9253 TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
9254
9255 TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N)
9256 TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N)
9257
9258 TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N)
9259 TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N)
9260
9261 TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R)
9262 TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R)
9263
9264 // Reduction => merge local and shared memory TensorViews
9265 TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp);
9266 TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
9267
9268 TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
9269 TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
9270 fusion.addOutput(sx_softmax);
9271 fusion.addOutput(dx_softmax);
9272
9273 auto sx_cache = sx->cacheAfter();
9274 auto dx_cache = dx->cacheAfter();
9275 dx_cache->setMemoryType(MemoryType::Shared);
9276 dx_exp->setMemoryType(MemoryType::Shared);
9277
9278 // Reduction and Broadcast Tensors common to both memory TVs
9279 std::vector<TensorView*> common_tensors(
9280 {max_val, sum_exp, bcast_max, bcast_sum});
9281
9282 // Static Local Memory TVs
9283 std::vector<TensorView*> static_tensors(
9284 {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax});
9285
9286 // Dynamic Local Memory TVs
9287 std::vector<TensorView*> dynamic_tensors(
9288 {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax});
9289
9290 std::vector<TensorView*> all_tensors;
9291 all_tensors.insert(
9292 all_tensors.end(), common_tensors.begin(), common_tensors.end());
9293 all_tensors.insert(
9294 all_tensors.end(), static_tensors.begin(), static_tensors.end());
9295 all_tensors.insert(
9296 all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
9297
9298 // M => M
9299 // M, N => M, N/128, 128
9300 for (auto tensor : all_tensors) {
9301 if (tensor->nDims() > 1) {
9302 tensor->split(-1, TIDX);
9303 }
9304 }
9305
9306 auto sx_sum_exp_rf = sx_sum_exp->rFactor({1});
9307 auto dx_sum_exp_rf = dx_sum_exp->rFactor({1});
9308 all_tensors.push_back(sx_sum_exp_rf);
9309 all_tensors.push_back(dx_sum_exp_rf);
9310
9311 // computeAt
9312 sx->computeAt(sx_max_sub, 1);
9313 dx->computeAt(dx_max_sub, 1);
9314
9315 sx_exp->computeAt(sx_softmax, 1);
9316 dx_exp->computeAt(dx_softmax, 1);
9317
9318 sx_max_sub->computeAt(sx_exp, 2);
9319 dx_max_sub->computeAt(dx_exp, 2);
9320
9321 sx_softmax->axis(0)->parallelize(ParallelType::BIDx);
9322 dx_softmax->axis(0)->parallelize(ParallelType::BIDx);
9323 for (auto tensor : all_tensors) {
9324 if (tensor->nDims() > 1) {
9325 tensor->axis(-1)->parallelize(ParallelType::TIDx);
9326 }
9327 }
9328
9329 const int64_t dimx = 1024;
9330 const int64_t dimy = 16384;
9331
9332 auto properties = at::cuda::getDeviceProperties(0);
9333 const size_t required_smem_size =
9334 (dimy - static_size) * sizeof(float) + TIDX * sizeof(float);
9335 if (properties->sharedMemPerBlockOptin < required_smem_size) {
9336 GTEST_SKIP() << "not enough shared memory space on device to run test: "
9337 << properties->sharedMemPerBlock;
9338 }
9339
9340 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
9341 at::Tensor aten_input = at::randn({dimx, dimy}, options);
9342 at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
9343 at::Tensor aten_dynamic_in =
9344 aten_input.narrow(1, static_size, dimy - static_size);
9345
9346 at::Tensor out = at::zeros({dimx, dimy}, options);
9347 at::Tensor cg_static_out = out.narrow(1, 0, static_size);
9348 at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);
9349
9350 std::vector<at::Tensor> aten_outputs;
9351
9352 auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);
9353 at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
9354 at::Tensor aten_dynamic_out =
9355 aten_output.narrow(1, static_size, dimy - static_size);
9356
9357 torch::jit::fuser::cuda::FusionExecutor fe;
9358 fe.compileFusion(&fusion, {aten_static_in, aten_dynamic_in});
9359 fe.runFusion(
9360 {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out});
9361
9362 testValidate(
9363 &fusion,
9364 {cg_static_out, cg_dynamic_out},
9365 {aten_static_in, aten_dynamic_in},
9366 {cg_static_out, cg_dynamic_out},
9367 __LINE__,
9368 __FILE__);
9369}
9370
// Persistent layer-norm-style normalization where each row is split
// column-wise into a statically-sized part (sx, local memory) and a
// dynamically-sized part (dx, shared memory). Partial sums from the two
// parts are merged so the mean/variance match those of the full row.
TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int pixels_per_thread = 64;
  const int TIDX = 128;
  const int static_size = pixels_per_thread * TIDX;

  TensorView* sx = makeConcreteTensor({-1, static_size});
  TensorView* dx = makeSymbolicTensor(2);
  fusion.addInput(sx);
  fusion.addInput(dx);

  // Scalar inputs: affine parameters, epsilon, and the (full) row length N
  // used to turn sums into means.
  Double* gamma = IrBuilder::create<Double>();
  Double* beta = IrBuilder::create<Double>();
  Double* eps = IrBuilder::create<Double>();
  Int* N = IrBuilder::create<Int>();
  fusion.addInput(gamma);
  fusion.addInput(beta);
  fusion.addInput(eps);
  fusion.addInput(N);

  // Reduction
  auto sx_sum = sum(sx, {-1}); // (M, R)
  auto dx_sum = sum(dx, {-1}); // (M, R)
  // Reduction => merge local and shared memory TensorViews
  auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum);

  // Broadcast
  auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
  // Pwise
  auto x_mean = div(x_sum_bcast, N); // (M, B)

  auto sx_mean_sub = sub(sx, x_mean); // (M, N)
  auto dx_mean_sub = sub(dx, x_mean); // (M, N)

  auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N)
  auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N)

  // Reduction
  auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R)
  auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R)
  // Reduction => merge local and shared memory TensorViews
  auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum);

  // Broadcast
  auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
  // Pwise
  // Biased variance: divide by N (matches at::var(..., unbiased=false)).
  auto var = div(var_sum_bcast, N); // (M, B)
  auto var_eps = add(var, eps); // (M, B)
  auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)

  auto sx_norm = mul(sx_mean_sub, rvar);
  auto dx_norm = mul(dx_mean_sub, rvar);

  auto sx_norm_gamma = mul(sx_norm, gamma);
  auto dx_norm_gamma = mul(dx_norm, gamma);

  auto sx_norm_gamma_beta = add(sx_norm_gamma, beta);
  auto dx_norm_gamma_beta = add(dx_norm_gamma, beta);

  fusion.addOutput(sx_norm_gamma_beta);
  fusion.addOutput(dx_norm_gamma_beta);

  // Outputs are written into narrowed (strided) views of a larger tensor,
  // so they cannot be assumed contiguous.
  sx_norm_gamma_beta->setContiguity(false);
  dx_norm_gamma_beta->setContiguity(false);

  // Read Input into Shared Memory
  // Read Input minus Input_Mean into Shared Memory
  auto sx_cache = sx->cacheAfter();
  auto dx_cache = dx->cacheAfter();
  dx_cache->setMemoryType(MemoryType::Shared);
  dx_mean_sub->setMemoryType(MemoryType::Shared);

  std::vector<TensorView*> common_tensors(
      {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar});

  std::vector<TensorView*> static_tensors(
      {sx,
       sx_cache,
       sx_sum,
       sx_mean_sub,
       sx_mean_sub_pow,
       sx_var_sum,
       sx_norm,
       sx_norm_gamma,
       sx_norm_gamma_beta});

  std::vector<TensorView*> dynamic_tensors(
      {dx,
       dx_cache,
       dx_sum,
       dx_mean_sub,
       dx_mean_sub_pow,
       dx_var_sum,
       dx_norm,
       dx_norm_gamma,
       dx_norm_gamma_beta});

  std::vector<TensorView*> all_tensors;
  all_tensors.insert(
      all_tensors.end(), common_tensors.begin(), common_tensors.end());
  all_tensors.insert(
      all_tensors.end(), static_tensors.begin(), static_tensors.end());
  all_tensors.insert(
      all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());

  // M => M
  // M, N => M, N/128, 128
  for (auto tensor : all_tensors) {
    if (tensor->nDims() > 1) {
      tensor->split(-1, TIDX);
    }
  }

  // Local Sum => Block Broadcast
  TensorView* sx_sum_rf = sx_sum->rFactor({1});
  TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1});
  TensorView* dx_sum_rf = dx_sum->rFactor({1});
  TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1});
  all_tensors.push_back(sx_sum_rf);
  all_tensors.push_back(sx_var_sum_rf);
  all_tensors.push_back(dx_sum_rf);
  all_tensors.push_back(dx_var_sum_rf);

  // ComputeAt
  sx->computeAt(sx_mean_sub_pow, 1);
  dx->computeAt(dx_mean_sub_pow, 1);

  var_sum->computeAt(rvar, 1);

  sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2);
  dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2);

  sx_norm->computeAt(sx_norm_gamma_beta, 2);
  dx_norm->computeAt(dx_norm_gamma_beta, 2);

  // One CTA per row; TIDX threads across the inner split.
  sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
  dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
  for (auto tensor : all_tensors) {
    if (tensor->nDims() > 1) {
      tensor->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  const int dimx = 1024;
  const int dimy = 16384;
  const float kGamma = 1.0f;
  const float kBeta = 0.0f;
  const float kEps = 1e-5;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // Skip when the dynamic part plus a reduction buffer does not fit in the
  // device's opt-in per-block shared memory limit.
  auto properties = at::cuda::getDeviceProperties(0);
  const size_t required_smem_size =
      (dimy - static_size) * sizeof(float) + TIDX * sizeof(float);
  if (properties->sharedMemPerBlockOptin < required_smem_size) {
    GTEST_SKIP() << "not enough shared memory space on device to run test: "
                 << properties->sharedMemPerBlock;
  }

  at::Tensor aten_input = at::randn({dimx, dimy}, options);
  at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
  at::Tensor aten_dynamic_in =
      aten_input.narrow(1, static_size, dimy - static_size);

  at::Tensor out = at::zeros({dimx, dimy}, options);
  at::Tensor cg_static_out = out.narrow(1, 0, static_size);
  at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);

  std::vector<IValue> aten_inputs = {
      aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy};

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);

  fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out});

  // Reference in double precision over the full rows; unbiased=false
  // matches the fusion's division of var_sum by N.
  auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
  auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1);
  auto at_rvar = at::rsqrt(at::add(at_var, kEps));
  auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
  auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
  at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
  at::Tensor aten_dynamic_out =
      aten_output.narrow(1, static_size, dimy - static_size);

  testValidate(
      &fusion,
      {cg_static_out, cg_dynamic_out},
      aten_inputs,
      {aten_static_out, aten_dynamic_out},
      __LINE__,
      __FILE__);
}
9565
9566TEST_F(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) {
9567 Fusion fusion;
9568 FusionGuard fg(&fusion);
9569
9570 // Set up your input tensor views
9571 auto x = makeSymbolicTensor(2);
9572 Double* gamma = IrBuilder::create<Double>();
9573 Double* beta = IrBuilder::create<Double>();
9574 Double* eps = IrBuilder::create<Double>();
9575 Int* N = IrBuilder::create<Int>();
9576 fusion.addInput(x);
9577 fusion.addInput(gamma);
9578 fusion.addInput(beta);
9579 fusion.addInput(eps);
9580 fusion.addInput(N);
9581
9582 // Reduction
9583 auto x_sum = sum(x, {-1}); // (M, R)
9584 // Broadcast
9585 auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
9586 // Pwise
9587 auto x_mean = div(x_sum_bcast, N); // (M, B)
9588 auto x_mean_sub = sub(x, x_mean); // (M, N)
9589 auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N)
9590 // Reduction
9591 auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R)
9592 // Broadcast
9593 auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
9594 // Pwise
9595 auto var = div(var_sum_bcast, N); // (M, B)
9596 auto var_eps = add(var, eps); // (M, B)
9597 auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
9598 auto norm = mul(x_mean_sub, rvar);
9599 auto norm_gamma = mul(norm, gamma);
9600 auto norm_gamma_beta = add(norm_gamma, beta);
9601 fusion.addOutput(norm_gamma_beta);
9602
9603 // Read Input into Shared Memory
9604 // Read Input minus Input_Mean into Shared Memory
9605 auto cache_x = x->cacheAfter();
9606 cache_x->setMemoryType(MemoryType::Shared);
9607 x_mean_sub->setMemoryType(MemoryType::Shared);
9608
9609 std::vector<TensorView*> all_tensors(
9610 {x_sum,
9611 x_mean,
9612 cache_x,
9613 x_sum_bcast,
9614 x_mean_sub,
9615 x_mean_sub_pow,
9616 var_sum,
9617 var_sum_bcast,
9618 var,
9619 var_eps,
9620 rvar,
9621 norm,
9622 norm_gamma,
9623 norm_gamma_beta});
9624
9625 auto tidx = IrBuilder::create<Int>();
9626 fusion.addInput(tidx);
9627
9628 for (auto tensor : all_tensors) {
9629 tensor->split(-1, tidx);
9630 }
9631
9632 // Local Sum => Block Broadcast
9633 TensorView* x_sum_rf = x_sum->rFactor({1});
9634 TensorView* var_sum_rf = var_sum->rFactor({1});
9635 all_tensors.push_back(x_sum_rf);
9636 all_tensors.push_back(var_sum_rf);
9637
9638 // ComputeAt
9639 x->computeAt(x_mean_sub_pow, 1);
9640 var_sum->computeAt(rvar, 1);
9641 x_mean_sub_pow->computeAt(var_sum_rf, 2);
9642 norm->computeAt(norm_gamma_beta, 2);
9643
9644 for (auto tv : all_tensors) {
9645 tv->axis(0)->parallelize(ParallelType::BIDx);
9646 tv->axis(-1)->parallelize(ParallelType::TIDx);
9647 }
9648
9649 const int dimx = 128;
9650 const int dimy = 2048;
9651 const float kGamma = 1.0f;
9652 const float kBeta = 0.0f;
9653 const float kEps = 1e-5;
9654 const int TIDX = 128;
9655
9656 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
9657 at::Tensor aten_input = at::randn({dimx, dimy}, options);
9658 auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
9659 auto at_var = at::var(aten_input.to(at::kDouble), -1).unsqueeze(1);
9660 auto at_rvar = at::rsqrt(at::add(at_var, kEps));
9661 auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
9662 auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
9663
9664 std::vector<IValue> aten_inputs = {
9665 aten_input, kGamma, kBeta, kEps, dimy, TIDX};
9666
9667 torch::jit::fuser::cuda::FusionExecutor fe;
9668 fe.compileFusion(&fusion, aten_inputs);
9669 auto cg_outputs = fe.runFusion(aten_inputs);
9670
9671 testValidate(
9672 &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
9673}
9674
9675TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
9676 Fusion fusion;
9677 FusionGuard fg(&fusion);
9678
9679 // Set up your input tensor views
9680 TensorView* tv0 = makeSymbolicTensor(2);
9681 TensorView* tv1 =
9682 reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
9683 fusion.addInput(tv0);
9684 fusion.addOutput(tv1);
9685 // tv1[I0, R1] = tv0[I0, I1]
9686
9687 // Interface should just be a direct split with a Parallel type. We can
9688 // include the parallelize call if we do this.
9689 tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
9690 // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
9691
9692 TensorView* tv2 = tv1->rFactor({2});
9693 tv2->setMemoryType(MemoryType::Shared);
9694 // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1]
9695 // tv1[I0, R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}]
9696
9697 tv0->computeAt(tv1, 1);
9698
9699 tv2->axis(-1)->parallelize(ParallelType::TIDx);
9700 tv1->axis(0)->parallelize(ParallelType::BIDx);
9701
9702 constexpr int numel_x = 65000, numel_y = 1024;
9703
9704 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
9705 at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
9706 auto aten_output = aten_input.to(at::kDouble).sum({1});
9707
9708 // How many threads to use for the block reduction
9709 constexpr int runtime_threadIdx_dim = 128;
9710
9711 LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
9712
9713 FusionExecutor fe;
9714 fe.compileFusion(&fusion, {aten_input}, lparams);
9715 auto cg_outputs = fe.runFusion({aten_input}, lparams);
9716
9717 testValidate(
9718 &fusion,
9719 cg_outputs,
9720 {aten_input},
9721 {aten_output},
9722 __LINE__,
9723 __FILE__,
9724 "",
9725 lparams);
9726 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
9727}
9728
9729TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
9730 Fusion fusion;
9731 FusionGuard fg(&fusion);
9732
9733 // Algorithm
9734 Int* sym_bsx = IrBuilder::create<Int>();
9735 TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
9736 fusion.addInput(tv0);
9737 fusion.addInput(sym_bsx);
9738
9739 TensorView* tv1 = sum(tv0, {1}); // M, R, N
9740 fusion.addOutput(tv1);
9741
9742 TensorView* tv2 = tv0->cacheAfter();
9743 tv2->setMemoryType(MemoryType::Shared);
9744
9745 // Schedule
9746 constexpr int BSX = 32;
9747 tv1->split(2, BSX);
9748 tv1->split(1, sym_bsx);
9749 tv1->split(0, BSX);
9750 // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
9751 tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
9752 TensorView* tv3 = tv1->rFactor({-2});
9753
9754 tv0->computeAt(tv1, -2);
9755 tv0->computeAt(tv3, -2);
9756
9757 // Thread and Block binding
9758 tv1->axis(0)->parallelize(ParallelType::BIDx);
9759 tv1->axis(1)->parallelize(ParallelType::BIDy);
9760 tv1->axis(-1)->parallelize(ParallelType::TIDx);
9761 // Manual Binding
9762 tv2->axis(-1)->parallelize(ParallelType::TIDx);
9763 tv3->axis(-1)->parallelize(ParallelType::TIDx);
9764
9765 constexpr int M = 154, K = 45, N = 1524;
9766
9767 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
9768 at::Tensor aten_input = at::randn({M, K, N}, options);
9769 at::Tensor aten_output = aten_input.to(at::kDouble).sum({1});
9770
9771 // How many threads to use for the block reduction
9772 constexpr int runtime_threadIdx_dim = 128;
9773
9774 auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
9775
9776 FusionExecutor fe;
9777 fe.compileFusion(&fusion, {aten_input, runtime_threadIdx_dim}, lparams);
9778 auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams);
9779
9780 testValidate(
9781 &fusion,
9782 cg_outputs,
9783 {aten_input, runtime_threadIdx_dim},
9784 {aten_output},
9785 __LINE__,
9786 __FILE__,
9787 "",
9788 lparams);
9789
9790 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
9791}
9792
9793TEST_F(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
9794 Fusion fusion;
9795 FusionGuard fg(&fusion);
9796
9797 Int* sym_bsx = IrBuilder::create<Int>();
9798 TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
9799 TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
9800 TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
9801 TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
9802 TensorView* tv4 = mul(tv2, tv3); // M, K, N
9803 fusion.addInput(tv0);
9804 fusion.addInput(tv1);
9805 fusion.addInput(sym_bsx);
9806 fusion.addOutput(tv4);
9807 // Algorithm
9808
9809 tv2->setMemoryType(MemoryType::Shared);
9810 tv3->setMemoryType(MemoryType::Shared);
9811
9812 constexpr int BSX = 32;
9813 tv4->split(2, BSX);
9814 tv4->split(1, sym_bsx);
9815 tv4->split(0, BSX);
9816 // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
9817 tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}});
9818 // M/BSX, K/BSX, N/BSX, MSX, KSX, NSX
9819
9820 tv0->computeAt(tv4, 3);
9821 tv1->computeAt(tv4, 3);
9822 // Schedule
9823
9824 tv4->axis(0)->parallelize(ParallelType::BIDx);
9825 tv4->axis(2)->parallelize(ParallelType::BIDy);
9826 // Manual Binding
9827 tv2->axis(-2)->parallelize(ParallelType::TIDx);
9828 tv3->axis(-1)->parallelize(ParallelType::TIDx);
9829 // Thread and Block binding
9830
9831 constexpr int M = 128, K = 457, N = 1024;
9832
9833 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
9834 at::Tensor t0 = at::randn({M, K}, options);
9835 at::Tensor t1 = at::randn({K, N}, options);
9836 at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0));
9837 std::vector<IValue> aten_inputs = {t0, t1, BSX};
9838
9839 LaunchParams lparams(-1, -1, -1, BSX, -1, -1);
9840
9841 FusionExecutor fe;
9842 fe.compileFusion(&fusion, aten_inputs, lparams);
9843 auto cg_outputs = fe.runFusion(aten_inputs, lparams);
9844
9845 testValidate(
9846 &fusion,
9847 cg_outputs,
9848 aten_inputs,
9849 {aten_output},
9850 __LINE__,
9851 __FILE__,
9852 "",
9853 lparams);
9854
9855 TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
9856}
9857
// Tiled matrix multiply expressed as broadcast + multiply + sum, with a 3D
// tile whose M/K extents are runtime scalar inputs and whose N extent is a
// compile-time constant. Operand tiles are staged through shared memory;
// the K reduction is rFactored into inter-CTA and intra-CTA stages.
TEST_F(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Symbolic integers we will use for runtime tiling
  Int* symbolic_m_tile_dim = IrBuilder::create<Int>(); // bound to threadIdx.z
  Int* symbolic_split_k_tile_dim =
      IrBuilder::create<Int>(); // bound to blockIdx.x
  Int* symbolic_block_k_tile_dim =
      IrBuilder::create<Int>(); // bound to threadIdx.x
  // Compile-time integer for tiling
  int n_smem_tile = 8; // bound to threadIdx.y

  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
  TensorView* tv0 = makeSymbolicTensor(2);
  TensorView* tv1 = makeSymbolicTensor(2);

  // Broadcast tv0 to [M, K, *]
  TensorView* tv2 = broadcast(tv0, {false, false, true});
  // Broadcast tv1 to [*, K, N]
  TensorView* tv3 = broadcast(tv1, {true, false, false});

  // Pointwise multiplication resulting in tv3[M, K, N]
  TensorView* tv4 = mul(tv2, tv3);

  // Turn the K-dimension of tv4 into a reduction dimension
  TensorView* tv5 = sum(tv4, {1});

  // Register inputs and outputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Register runtime tile dims as inputs
  fusion.addInput(symbolic_m_tile_dim);
  fusion.addInput(symbolic_split_k_tile_dim);
  fusion.addInput(symbolic_block_k_tile_dim);

  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
  // dims are inserted
  // [M, K, N]
  tv5->split(2, n_smem_tile);
  tv5->split(1, symbolic_block_k_tile_dim);
  tv5->split(1, symbolic_split_k_tile_dim);
  tv5->split(0, symbolic_m_tile_dim);
  // [Mo, Mi, Koo, Koi, Ki, No, Ni]

  // Reorder so all outer tiles are in the leftmost 3 positions
  tv5->reorder({{1, 5}, {5, 1}});
  // [Mo, No, Koo, Koi, Ki, Mi, Ni]

  // Factor out the outer reduction IterDomain, then run the inter-cta
  // reduction, and intra-cta reduction
  auto tv6 = tv5->rFactor({2});
  // [Mo, No, rKoo, rKoi, rKi, Mi, Ni]
  // [Mo, No, rKoi, rKi, Mi, Ni]

  // Scope computations
  tv6->computeAt(tv5, 2);
  // [Mo, No, rKoo, Koi, Ki, Mi, Ni]
  // [Mo, No, rKoi, rKi, Mi, Ni]

  // Setup compute at schedule
  tv0->computeAt(tv6, 3);
  tv1->computeAt(tv6, 3);
  tv4->computeAt(tv6, -1);
  //
  // T2[Mo, bNo, Koo, Koi, Kii, Mi, bNi] CA(4, 3)
  // T3[bMo, No, Koo, Koi, Kii, bMi, Ni] CA(4, 3)
  // T4[ Mo, No, Koo, Koi, Kii, Mi, Ni]
  // T6[ Mo, No, rKoo, Koi, Kii, Mi, Ni]
  // T5[ Mo, No, rKoi, rKii, Mi, Ni]

  // Cache smem tiles
  tv2->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Local);
  tv6->setMemoryType(MemoryType::Local);

  // Outer M/N tile loops across the grid.
  tv5->axis(0)->parallelize(ParallelType::BIDz);
  tv5->axis(1)->parallelize(ParallelType::BIDy);

  // Inner M/N tile elements across threadIdx.z / threadIdx.y.
  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
  for (auto tv : tv_list) {
    tv->axis(-2)->parallelize(ParallelType::TIDz);
    tv->axis(-1)->parallelize(ParallelType::TIDy);
  }
  // Intra-CTA K reduction on threadIdx.x; note tv5's axis indices shift by
  // one because its rKoo axis was factored out into tv6.
  tv2->axis(3)->parallelize(ParallelType::TIDx);
  tv3->axis(3)->parallelize(ParallelType::TIDx);
  tv4->axis(3)->parallelize(ParallelType::TIDx);
  tv6->axis(3)->parallelize(ParallelType::TIDx);
  tv5->axis(2)->parallelize(ParallelType::TIDx);

  // Split-K stage across blockIdx.x.
  tv2->axis(4)->parallelize(ParallelType::BIDx);
  tv3->axis(4)->parallelize(ParallelType::BIDx);
  tv4->axis(4)->parallelize(ParallelType::BIDx);
  tv6->axis(4)->parallelize(ParallelType::BIDx);
  tv5->axis(3)->parallelize(ParallelType::BIDx);

  // Deliberately awkward sizes that do not divide the tile dims evenly.
  constexpr int M = 31, K = 65, N = 33;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  // Runtime tiling
  int m_tile = 4; // bound to threadIdx.z
  int split_k = 7; // bound to blockIdx.x
  int intra_cta = 8; // bound to threadIdx.x

  std::vector<IValue> aten_inputs = {t0, t1, m_tile, split_k, intra_cta};
  // Double-precision reference GEMM via broadcast-multiply-sum.
  at::Tensor aten_output =
      mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);

  FusionExecutor fe;
  // Generate CUDA and compile with nvRTC
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);

  // One WAR-hazard sync is expected for the reused smem operand tiles.
  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
}
9982
9983} // namespace jit
9984} // namespace torch
9985#endif // #if defined(USE_CUDA)
9986