#if defined(USE_CUDA)
#include <gtest/gtest.h>

#include <arith.h>
#include <codegen.h>
#include <disjoint_set.h>
#include <executor.h>
#include <executor_launch_params.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <fusion_segmenter.h>
#include <ir_all_nodes.h>
#include <ir_graphviz.h>
#include <ir_iostream.h>
#include <ir_printer.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_cache.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <mma_type.h>
#include <mutator.h>
#include <ops/all_ops.h>
#include <register_interface.h>
#include <root_domain_map.h>
#include <scheduler/all_schedulers.h>
#include <scheduler/matmul.h>
#include <scheduler/reduction_utils.h>
#include <scheduler/utils.h>
#include <test/test_gpu_validator.h>
#include <test/test_utils.h>
#include <transform_replay.h>
#include <transform_rfactor.h>

// fuser and IR parser
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAStream.h>

#include <algorithm>
#include <iostream>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;
using namespace at::indexing;

namespace {

bool cudaArchGuardShouldSkip(int required_major, int required_minor) {
  int capability_major = at::cuda::getCurrentDeviceProperties()->major;
  int capability_minor = at::cuda::getCurrentDeviceProperties()->minor;

  if (capability_major < required_major ||
      (capability_major == required_major &&
       capability_minor < required_minor)) {
    return true;
  }
  return false;
}

#define NVFUSER_TEST_CUDA_ARCH_GUARD(REQUIRED_MAJOR, REQUIRED_MINOR)          \
  if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) {              \
    GTEST_SKIP() << "Requires GPU capability above " << REQUIRED_MAJOR << "." \
                 << REQUIRED_MINOR << " to run.\n";                           \
  }

#define NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(                                \
    REQUIRED_MAJOR, REQUIRED_MINOR, COMPILE_FUSION)                          \
  if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) {             \
    ASSERT_ANY_THROW(COMPILE_FUSION);                                        \
    GTEST_SKIP() << "(Lowered Only) Requires GPU capability above "          \
                 << REQUIRED_MAJOR << "." << REQUIRED_MINOR << " to run.\n"; \
  } else {                                                                   \
    COMPILE_FUSION;                                                          \
  }
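
// Example usage of the two guards above (illustrative sketch only;
// "FusionFoo_CUDA" is a hypothetical test name):
//
//   TEST_F(NVFuserTest, FusionFoo_CUDA) {
//     // Skip outright on devices below sm_80:
//     NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
//     ...
//   }
//
// or, to additionally assert that compilation throws on unsupported devices:
//
//   NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
//       8, 0, fe.compileFusion(&fusion, {t0, t1}));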

// Utility to track the supported matmul operand layouts.
using MatmulLayout = MmaOptions::MmaInputLayout;

static constexpr std::array<MatmulLayout, 3> kAllSupportedLayout = {
    MatmulLayout::TT,
    MatmulLayout::NT,
    MatmulLayout::TN};

// Generic interface to get a matmul op with the given layout.
TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) {
  TORCH_CHECK(
      a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests");
  TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr;
  switch (layout) {
    case MatmulLayout::TT:
      tv0b = broadcast(a, {false, false, true});
      tv1b = broadcast(b, {true, false, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {1});
      break;
    case MatmulLayout::TN:
      tv0b = broadcast(a, {false, true, false});
      tv1b = broadcast(b, {true, false, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {2});
      break;
    case MatmulLayout::NT:
      tv0b = broadcast(a, {false, false, true});
      tv1b = broadcast(b, {false, true, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {0});
      break;
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return tv2;
}

// Utility to generate reference results based on given layout
at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) {
  switch (layout) {
    case MatmulLayout::TT:
      return a.matmul(b);
    case MatmulLayout::TN:
      return a.matmul(b.t());
    case MatmulLayout::NT:
      return a.t().matmul(b);
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return at::Tensor();
}

// Utility to generate matmul input tensors based on given layout
std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput(
    int M,
    int N,
    int K,
    MatmulLayout layout) {
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

  switch (layout) {
    case MatmulLayout::TT:
      return std::make_pair(
          at::randn({M, K}, options), at::randn({K, N}, options));
    case MatmulLayout::TN:
      return std::make_pair(
          at::randn({M, K}, options), at::randn({N, K}, options));
    case MatmulLayout::NT:
      return std::make_pair(
          at::randn({K, M}, options), at::randn({K, N}, options));
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return std::make_pair(at::Tensor(), at::Tensor());
}
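
// For reference, the operand shapes each layout implies, as encoded by the
// helpers above:
//   TT: A is [M,K], B is [K,N]
//   TN: A is [M,K], B is [N,K]
//   NT: A is [K,M], B is [K,N]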

#define REQUIRE_DEVICE_SMEM_SIZE(required_size, device_idx)                 \
  if (at::cuda::getDeviceProperties(device_idx)->sharedMemPerBlockOptin <   \
      required_size) {                                                      \
    GTEST_SKIP() << "not enough shared memory space on device to run test"; \
  }
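
// Example: skip a test unless device 0 offers at least 70KB of opt-in shared
// memory (as used by the pipelined-gmem tests below):
//   REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);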

} // namespace

// MMA unit test for a single instruction tile. VoltaTT
TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 4}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({4, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,K,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});

  fusion.addOutput(tv2);

  // TODO: should be able to completely remove it
  // in a follow up.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TT);
  mma_builder.configureMma(tv2);

  // Write A to smem
  auto tv0cw = tv0b->cacheAfter();
  // Read A from smem
  auto tv0cr = tv0cw->cacheAfter();

  // Write B to smem
  auto tv1cw = tv1b->cacheAfter();

  // Read B from smem
  auto tv1cr = tv1cw->cacheAfter();

  // Register accumulator
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // [M,K,N]->[M,N,K]
  tv0cr->reorder({{-2, -1}, {-1, -2}});

  // Schedule the instruction tile loops, which is the only
  // part we have in this unit test.
  // Assumes the last 3 dims are mnk.
  // The innermost loops are dictated by the type of mma used;
  // the scheduler needs to use mma_util::WarpMmaSwizzler to
  // get the right thread swizzle. Currently this is the only
  // method allowed to schedule the 3/2 innermost loops of
  // mma input/output.
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [M,K,N]->[M,N,K]
  tv1cr->reorder({{-2, -1}, {-1, -2}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [M,K,N]->[M,N,K]
  tv2c->reorder({{-2, -1}, {-1, -2}});

  // Schedule the output instruction tile.
  // Assumes the last 3 dims are mnk.
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Set memory type.
  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 4}, options);
  auto t1 = at::randn({4, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test for a single instruction tile. VoltaTN
TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 4}, DataType::Half);
  // [N,K]
  auto tv1 = makeConcreteTensor({16, 4}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  // TODO: should be able to completely remove it
  // in a follow up.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 4}, options);
  auto t1 = at::randn({16, 4}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});
  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test for a single instruction tile. VoltaNT
TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [K,M]
  auto tv0 = makeConcreteTensor({4, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({4, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [K,M,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {false, true, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::NT);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // To MNK
  tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // To MNK
  tv1cr->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  tv2c->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({4, 16}, options);
  auto t1 = at::randn({4, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});
  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// Matmul test for Volta MMA: across supported layouts
TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) {
  // Keep dimensions as multiples of 8 to keep them vectorizable.
  int M = 264, N = 136, K = 248;

  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 16, 4);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

// Matmul test for Volta MMA: with register double buffering, across
// supported layouts
TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) {
  // Keep dimensions as multiples of 8 to keep them vectorizable.
  int M = 264, N = 136, K = 248;

  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 16, 4);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
    params.double_buffer_options.double_buffer_smem_read = true;
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

// MMA unit test on Ampere
TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [N,K]
  auto tv1 = makeConcreteTensor({8, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // [M,N,K] -> [N,M,K]
  tv0cr->reorder({{-2, -3}, {-3, -2}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({8, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test on Ampere
TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,K,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {true, false, false});

  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TT);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);
  // [M,K,N] -> [N,M,K]
  tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [M,K,N] -> [M,N,K]
  tv1cr->reorder({{-2, -1}, {-1, -2}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [M,K,N] -> [M,N,K]
  tv2c->reorder({{-2, -1}, {-1, -2}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({16, 8}, options);

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test on Ampere
TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  // [K,M]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [K,M,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {false, true, false});
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::NT);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // [K,M,N] -> [N,M,K]
  tv0cr->reorder({{-3, -1}, {-1, -3}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [K,M,N] -> [M,N,K]
  tv1cr->reorder({
      {-3, -1},
      {-2, -3},
      {-1, -2},
  });
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [K,M,N] -> [M,N,K]
  tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({16, 8}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// Matmul test for Ampere MMA: across supported layouts
TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) {
  // Keep dimensions as multiples of 8 to keep them vectorizable.
  int M = 504, N = 136, K = 248;

  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 8, 16);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
    params.async_gmem_load_operands = true;
    params.double_buffer_options.double_buffer_smem_write = true;
    params.double_buffer_options.smem_double_buffer_stage = 4;
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

// Matmul test for Ampere MMA: with pipelined gmem load
TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) {
  // Keep dimensions as multiples of 8 to keep them vectorizable.
  int M = 504, N = 136, K = 248;
  REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);

  // Gmem pipeline stage
  for (auto stage : {3, 4}) {
    for (auto layout : kAllSupportedLayout) {
      Fusion fusion;
      FusionGuard fg(&fusion);
      auto tv0 = makeContigTensor(2, DataType::Half);
      auto tv1 = makeContigTensor(2, DataType::Half);

      fusion.addInput(tv0);
      fusion.addInput(tv1);

      auto tv2 = matmul(tv0, tv1, layout);

      fusion.addOutput(tv2);

      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
      gemm_tile.instruction_tile = GemmTile(16, 8, 16);

      auto mma_builder =
          MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
              .layout(layout);

      MatmulParam params(mma_builder);
      params.tile_sizes = gemm_tile;
      params.async_gmem_load_operands = true;
      params.double_buffer_options.double_buffer_smem_write = true;
      params.double_buffer_options.smem_double_buffer_stage = stage;
      scheduleMatmul(tv2, tv0, tv1, params);

      at::manual_seed(0);
      auto inputs = fp16MatmulAtInput(M, N, K, layout);

      FusionExecutor fe;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
      auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
    }
  }
}

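// Matmul test for Ampere MMA: with pipelined gmem load and register
// double buffering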
TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) {
  // Keep dimensions as multiples of 8 to keep them vectorizable.
  int M = 504, N = 136, K = 248;
  REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);

  // Gmem pipeline stage
  for (auto stage : {3, 4}) {
    for (auto layout : kAllSupportedLayout) {
      Fusion fusion;
      FusionGuard fg(&fusion);
      auto tv0 = makeContigTensor(2, DataType::Half);
      auto tv1 = makeContigTensor(2, DataType::Half);

      fusion.addInput(tv0);
      fusion.addInput(tv1);

      auto tv2 = matmul(tv0, tv1, layout);

      fusion.addOutput(tv2);

      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 32);
      gemm_tile.warp_tile = GemmTile(64, 64, 32);
      gemm_tile.instruction_tile = GemmTile(16, 8, 16);

      auto mma_builder =
          MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
              .layout(layout);

      MatmulParam params(mma_builder);
      params.tile_sizes = gemm_tile;
      params.async_gmem_load_operands = true;
      params.double_buffer_options.double_buffer_smem_write = true;
      params.double_buffer_options.smem_double_buffer_stage = stage;
      params.double_buffer_options.double_buffer_smem_read = true;
      scheduleMatmul(tv2, tv0, tv1, params);

      at::manual_seed(0);
      auto inputs = fp16MatmulAtInput(M, N, K, layout);

      FusionExecutor fe;
      NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
          8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
      auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
      auto tref = atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
      TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
    }
  }
}

// Matmul-Matmul fusion test on Ampere
TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 512, N = 256, K1 = 128, K2 = 128;

  // Fusion definition (Both gemms are TN)
  // [M,K1]
  auto tv0 = makeConcreteTensor({M, K1}, DataType::Half);
  // [K2,K1]
  auto tv1 = makeConcreteTensor({K2, K1}, DataType::Half);
  // [N,K2]
  auto tv2 = makeConcreteTensor({N, K2}, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});
  auto tv2b = broadcast(tv2, {true, false, false});

  // [M,K2,R]
  auto tv3 = fusedMultiplySum(tv0b, tv1b, {2});

  auto tv3h = castOp(DataType::Half, tv3);
  auto tv3b = broadcast(tv3h, {false, true, false});

  auto tv4 = fusedMultiplySum(tv3b, tv2b, {2});

  fusion.addOutput(tv4);

  // Fusion:
  // Gemm(M,K2,K1) x Gemm(M,N,K2)

  MatMulTileOptions gemm_tile1, gemm_tile2;

  // cta tile:
  // To save registers, the n of cta tile 1
  // matches the k of cta tile 2.
  gemm_tile1.cta_tile = GemmTile(128, 64, 32);
  gemm_tile2.cta_tile = GemmTile(128, 32, 64);

  // Distribute to 2x2 warps
  gemm_tile1.warp_tile = GemmTile(64, 32, 32);
  gemm_tile2.warp_tile = GemmTile(64, 16, 64);

  // Using Ampere mma macro
  gemm_tile1.instruction_tile = GemmTile(16, 8, 16);
  gemm_tile2.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder1 =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1)
          .layout(MmaOptions::MmaInputLayout::TN);

  auto mma_builder2 =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder1.configureMma(tv3);
  mma_builder2.configureMma(tv4);

  // Global read for gemm 1
  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();

  // Global read for gemm 2
  auto tv2r = tv2->cacheAfter();

  // Gemm 1 main loop read
  auto tv0cw = tv0r->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv1cw = tv1r->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix);

  // Gemm 1 accumulator reg
  auto tv3c = tv3->cacheBefore();
  mma_builder1.accumulatorTv(tv3c);

  // Gemm 2 main loop read
  auto tv3cw = tv3h->cacheAfter();
  auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix);

  auto tv2cw = tv2r->cacheAfter();
  auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix);

  // Gemm 2 accumulator reg
  auto tv4c = tv4->cacheBefore();
  mma_builder2.accumulatorTv(tv4c);

  // General idea is inlining gemm1's main loop inside gemm2's

  // Schedule gemm 2:
  // ------------------------------------------------------------------
  tv4->split(-2, gemm_tile2.cta_tile.m);
  tv4->split(-1, gemm_tile2.cta_tile.n);

  //  0   1    2   3
  // [Mo,M128, No, N128]
  tv4->reorder({{1, 2}, {2, 1}});

  //  0   1   2     3
  // [Mo,No, M128, N128]
  tv2->computeAt(tv4, 2);
  tv3->computeAt(tv4, 2);

  // Order K
  //  0   1   2     3     4   5
  // [Mo,No, M128, N128, Ko, K32]
  tv4c->split(-1, gemm_tile2.cta_tile.k);
  tv4c->reorder({{2, 3}, {3, 4}, {4, 2}});

  //  0   1   2   3     4     5
  // [Mo,No, Ko, M128, N128, K32]
  tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1
  tv2r->computeAt(tv4c, 3);

  // Make warp tile
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(
      tv4c, gemm_tile2);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv4, gemm_tile2);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv3cr->computeAt(tv4c, -4);
  tv2cr->computeAt(tv4c, -4);

  // Schedule tv2 gmem read and smem write:
  // ----------------------------------------------------------------
  // [No,Ko,N,K]
  tv2cw->merge(-2);
  tv2r->merge(-2);

  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv2cw, gemm_tile2, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv2r, gemm_tile2, 8);
  tv2cw->setMemoryType(MemoryType::Shared);

  // Schedule gemm 2 mma input
  // ---------------------------------------------------------------------------
  tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build());

  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv3b->reorder({{-2, -3}, {-3, -2}});
  tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build());

  tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build());
  tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv4c->applyMmaSwizzle(
      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
  tv4->applyMmaSwizzle(
      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());

  // Schedule gemm 1:
  // ------------------------------------------------------------------

  // CTA tile:
  tv0->computeAt(tv3, 2);
  tv1->computeAt(tv3, 2);

  // Schedule K dim for gemm 1:

  // Order K
  //  0   1   2     3     4   5
  // [Mo,No, M128, N128, Ko, K32]
  tv3c->split(-1, gemm_tile1.cta_tile.k);
  tv3c->reorder({{2, 3}, {3, 4}, {4, 2}});
  //  0   1   2   3     4     5
  // [Mo,No, Ko, M128, N128, K32]
  tv0r->computeAt(tv3c, 3);
  tv1r->computeAt(tv3c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(
      tv3c, gemm_tile1);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv3cw, gemm_tile1);

  tv0cr->computeAt(tv3c, -4);
  tv1cr->computeAt(tv3c, -4);

  tv3->computeAt(tv3cw, -3);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile1, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile1, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile1, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile1, 8);
  tv1cw->setMemoryType(MemoryType::Shared);

  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build());
  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv3c->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
  tv3cw->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
  tv3h->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
  tv3->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
  tv3cw->setMemoryType(MemoryType::Shared);

  // Parallelize
  //  0  1   2   3  4  5   6  7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  // Gemm 1
  tv3c->axis(3)->parallelize(ParallelType::TIDz);
  tv3c->axis(4)->parallelize(ParallelType::TIDy);

  tv3->computeAt(tv3cw, -2);
  tv3cw->axis(2)->parallelize(ParallelType::TIDz);
  tv3cw->axis(3)->parallelize(ParallelType::TIDy);

  // Gemm 2
  tv4->axis(2)->parallelize(ParallelType::TIDz);
  tv4->axis(3)->parallelize(ParallelType::TIDy);
  tv4c->axis(3)->parallelize(ParallelType::TIDz);
  tv4c->axis(4)->parallelize(ParallelType::TIDy);

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::BIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K1}, options);
  auto t1 = at::randn({K2, K1}, options);
  auto t2 = at::randn({N, K2}, options);

  auto tref = t0.to(at::kFloat)
                  .matmul(t1.t().to(at::kFloat))
                  .matmul(t2.t().to(at::kFloat));

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1, t2}));

  auto cg_outputs = fe.runFusion({t0, t1, t2});

  // relaxed check for now, err accumulation is significant.
  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1));
}

// Simplified Matmul-Softmax-Matmul test on Ampere
// (To be extended in follow ups)
TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Omitting outer dimensions and pointwise ops

  const int seql_q = 32;
  const int seql_k = 128;
  const int hidden_size = 1024;
  const int num_heads = 16;
  const int head_dim = hidden_size / num_heads;

  // Gemm 1:
  // (M1 = 32, N1 = 128, K1 = 64)
  const int M1 = seql_q, N1 = seql_k, K1 = head_dim;
  // Gemm 2:
  // (M2 = 32, N2 = 64, K2 = 128)
  const int M2 = seql_q, N2 = head_dim, K2 = seql_k;

  // Fusion definition (Both gemms are TN)
  // [M,K1]
  auto inp = makeConcreteTensor({M1, K1}, DataType::Half);
  // Query matrix
  auto qk = makeConcreteTensor({N1, K1}, DataType::Half);
  // Second linear matrix
  auto acc = makeConcreteTensor({N2, K2}, DataType::Half);

  fusion.addInput(inp);
  fusion.addInput(qk);
  fusion.addInput(acc);

  // [M,N,K]
  auto tv0b = broadcast(inp, {false, true, false});
  auto tv1b = broadcast(qk, {true, false, false});
  auto tv2b = broadcast(acc, {true, false, false});

  // [M,K2,R]
  auto tv3 = fusedMultiplySum(tv0b, tv1b, {2});

  // Inline define softmax for now for scheduling
  auto x = tv3;
  const int kReductionAxis = 1;
  const int kNumberOfDims = 2;
  std::vector<bool> broadcast_mask(kNumberOfDims, false);
  broadcast_mask[kReductionAxis] = true;

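  // Numerically stable softmax along the reduction axis:
  //   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
  // built below as max -> subtract -> exp -> sum -> reciprocal -> multiply.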
  auto max_val = max(x, {kReductionAxis});
  auto bcast_max = broadcast(max_val, broadcast_mask);
  auto x_max_sub = sub(x, bcast_max);
  auto exp_val = exp(x_max_sub);
  auto sum_exp = sum(exp_val, {kReductionAxis});
  auto bcast_sum = broadcast(sum_exp, broadcast_mask);
  auto recip = reciprocal(bcast_sum);
  auto tv3sfm = mul(exp_val, recip);

  auto tv3h = castOp(DataType::Half, tv3sfm);
  auto tv3b = broadcast(tv3h, {false, true, false});
  auto tv4 = fusedMultiplySum(tv3b, tv2b, {2});

  fusion.addOutput(tv4);

  // Fusion:
  // Gemm(M,K2,K1) x Gemm(M,N,K2)
  MatMulTileOptions gemm_tile;

  // TODO: use very small tiles for now since
  // alias pass is not re-using smem. Fix later.
  gemm_tile.cta_tile = GemmTile(32, 128, 32);

  // Distribute to 2x2 warps
  gemm_tile.warp_tile = GemmTile(16, 64, 32);

  // Using Ampere mma macro
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder1 =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  auto mma_builder2 =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder1.configureMma(tv3);
  mma_builder2.configureMma(tv4);

  // Global read for gemm 1
  auto tv0r = inp->cacheAfter();
  auto tv1r = qk->cacheAfter();

  // Global read for gemm 2
  auto tv2r = acc->cacheAfter();

  // Gemm 1 main loop read
  auto tv0cw = tv0r->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv1cw = tv1r->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix);

  // Gemm 1 accumulator reg
  auto tv3c = tv3->cacheBefore();
  mma_builder1.accumulatorTv(tv3c);

  // Softmax conversion:
  auto tv3ccr = tv3->cacheAfter();

  // tv3ccr -> tv3h : softmax

  // Gemm 2 main loop read
  // auto tv3cw = tv3h->cacheAfter();
  auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix);

  auto tv2cw = tv2r->cacheAfter();
  auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix);

  // Gemm 2 accumulator reg
  auto tv4c = tv4->cacheBefore();
  mma_builder2.accumulatorTv(tv4c);

  // Schedule gemm 2:
  // ------------------------------------------------------------------
  tv4->split(-2, gemm_tile.cta_tile.m);
  tv4->split(-1, gemm_tile.cta_tile.n);

  //  0   1    2   3
  // [Mo,M128, No, N128]
  tv4->reorder({{1, 2}, {2, 1}});

  //  0   1   2     3
  // [Mo,No, M128, N128]
  acc->computeAt(tv4, 2);
  tv3->computeAt(tv4, 2);

  // Order K
  //  0   1   2     3     4   5
  // [Mo,No, M128, N128, Ko, K32]
  tv4c->split(-1, gemm_tile.cta_tile.k);
  tv4c->reorder({{2, 3}, {3, 4}, {4, 2}});

  //  0   1   2   3     4     5
  // [Mo,No, Ko, M128, N128, K32]
  tv3->computeAt(tv4c, 2);
  tv2r->computeAt(tv4c, 3);

  // Make warp tile
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv4, gemm_tile);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv3cr->computeAt(tv4c, -4);
  tv2cr->computeAt(tv4c, -4);

  // Schedule tv2 gmem read and smem write:
  // ----------------------------------------------------------------
  // [No,Ko,N,K]
  tv2cw->merge(-2);
  tv2r->merge(-2);

  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv2cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv2r, gemm_tile, 8);
  tv2cw->setMemoryType(MemoryType::Shared);

  // Schedule gemm 2 mma input
  // ---------------------------------------------------------------------------
  tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build());
  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv3b->reorder({{-2, -3}, {-3, -2}});
  tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build());

  tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build());
  tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv4c->applyMmaSwizzle(
      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());
  tv4->applyMmaSwizzle(
      mma_builder2.operand(MmaOptions::Operand::Accumulator).build());

  // Schedule gemm 1:
  // ------------------------------------------------------------------

  // CTA tile:
  // [Mo, Mi128, N80]

  tv3->split(-1, gemm_tile.cta_tile.n);
  // [Mo, Mi128, No, Ni128]

  tv3->reorder({{1, 2}, {2, 1}});

  // [Mo, No, Mi128, Ni128]
  inp->computeAt(tv3, 2);
  qk->computeAt(tv3, 2);

  // Schedule K dim for gemm 1:

  // Order K
  //  0   1   2     3     4   5
  // [Mo,No, M128, N128, Ko, K32]
  tv3c->split(-1, gemm_tile.cta_tile.k);
  tv3c->reorder({{2, 3}, {3, 4}, {4, 2}});
  //  0   1   2   3     4     5
  // [Mo,No, Ko, M128, N128, K32]
  tv0r->computeAt(tv3c, 3);
  tv1r->computeAt(tv3c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv3, gemm_tile);

  tv0cr->computeAt(tv3c, -4);
  tv1cr->computeAt(tv3c, -4);

  // tv3->computeAt(tv3cw,-3);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);

  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build());
  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv3c->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());
  tv3->applyMmaSwizzle(
      mma_builder1.operand(MmaOptions::Operand::Accumulator).build());

  // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw,
  // mma_builder1.build());

  // Put tv3 result in smem
  tv3->setMemoryType(MemoryType::Shared);

  // schedule a reg persistent softmax: from tv3
  // [Mo, M128, RN]
  max_val->split(-1, 128);
  // [Mo, M128, RN1, RN128]
  max_val->split(-1, 4);
  // Map to warp (2x2)
  max_val->split(-4, 4);
  max_val->split(-4, 2);

  // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4]
  auto max_rf = max_val->rFactor({-1});
  // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4]

  // [Mo, M128, RN]
  sum_exp->split(-1, 128);
  // [Mo, M128, RN1, RN128]
  sum_exp->split(-1, 4);
  // Map to warp (2x2)
  sum_exp->split(-4, 4);
  sum_exp->split(-4, 2);

  // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4]
  auto sum_exp_rf = sum_exp->rFactor({-1});
  // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4]

  exp_val->computeAt(sum_exp_rf, 4);
  exp_val->split(-1, 128);
  exp_val->split(-1, 4);
  bcast_max->computeAt(exp_val, -2);

  // [Mo, Mo32, My2, Mx2, IN1, I32, INi4]

  // Read from smem
  tv3ccr->computeAt(max_rf, 4);
  // [Mo, Mo32, My2, Mx2, N80]
  tv3ccr->split(-1, 128);
  tv3ccr->split(-1, 4);
  // [Mo, Mo32, My2, Mx2, IN1, I32, INi4]

  // Write to second gemm
  tv3h->split(-1, 128);
  tv3h->split(-1, 4);
  // Map to warp (2x2)
  tv3h->split(-4, 4);
  tv3h->split(-4, 2);

  bcast_sum->computeAt(tv3h, -2);

  tv3h->setMemoryType(MemoryType::Shared);

  // Parallelize
  tv4->axis(0)->parallelize(ParallelType::BIDx);
  //  0  1   2   3  4  5   6  7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  // Gemm 1
  tv3c->axis(3)->parallelize(ParallelType::TIDz);
  tv3c->axis(4)->parallelize(ParallelType::TIDy);
  tv3->axis(2)->parallelize(ParallelType::TIDz);
  tv3->axis(3)->parallelize(ParallelType::TIDy);

  auto parallelize_non_reduced_val = [](TensorView* tv) {
    tv->axis(-2)->parallelize(ParallelType::TIDx);
    tv->axis(2)->parallelize(ParallelType::TIDz);
    tv->axis(3)->parallelize(ParallelType::TIDy);
  };

  auto parallelize_reduced_val = [](TensorView* tv) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
    tv->axis(2)->parallelize(ParallelType::TIDz);
    tv->axis(3)->parallelize(ParallelType::TIDy);
  };

  parallelize_non_reduced_val(tv3h);
  parallelize_non_reduced_val(max_rf);
  parallelize_non_reduced_val(bcast_max);
  parallelize_non_reduced_val(exp_val);
  parallelize_non_reduced_val(sum_exp_rf);
  parallelize_non_reduced_val(bcast_sum);
  parallelize_non_reduced_val(recip);

  parallelize_reduced_val(max_val);
  parallelize_reduced_val(sum_exp);

  //  0  1   2   3  4  5   6  7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  // Gemm 2
  tv4->axis(2)->parallelize(ParallelType::TIDz);
  tv4->axis(3)->parallelize(ParallelType::TIDy);
  tv4c->axis(3)->parallelize(ParallelType::TIDz);
  tv4c->axis(4)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M1, K1}, options);
  auto t1 = at::randn({N1, K1}, options);
  auto t2 = at::randn({N2, K2}, options);

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1, t2}));

  auto cg_outputs = fe.runFusion({t0, t1, t2});

  auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
  auto sg1 = at::_softmax(g1, -1, false);
  auto gsg1 = sg1.matmul(t2.t().to(at::kFloat));

  TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001));
}

// MMA unit test on Turing
TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [N,K]
  auto tv1 = makeConcreteTensor({8, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // [M,N,K] -> [N,M,K]
  tv0cr->reorder({{-2, -3}, {-3, -2}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({8, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 5, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test on Turing
TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,K,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {true, false, false});

  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TT);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // [M,K,N] -> [N,M,K]
  tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [M,K,N] -> [M,N,K]
  tv1cr->reorder({{-2, -1}, {-1, -2}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [M,K,N] -> [M,N,K]
  tv2c->reorder({{-2, -1}, {-1, -2}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({16, 8}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 5, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}

// MMA unit test on Turing
TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [K,M]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [K,M,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {false, true, false});
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::NT);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // [K,M,N] -> [N,M,K]
  tv0cr->reorder({{-3, -1}, {-1, -3}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [K,M,N] -> [M,N,K]
  tv1cr->reorder({
      {-3, -1},
      {-2, -3},
      {-1, -2},
  });
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [K,M,N] -> [M,N,K]
  tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({16, 8}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 5, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}
1741
1742// Matmul test for Turing MMA: across supported layouts
1743TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) {
  // Keep dimensions as multiples of 8 so the loads stay vectorizable.
  int M = 504, N = 136, K = 248;

  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 8, 16);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
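    // scheduleMatmul is the automated matmul scheduler; it performs the
    // same tiling, caching, and swizzling steps that the unit tests above
    // spell out by hand.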
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

// Matmul test on Ampere, using Ampere memory ops (cp.async and ldmatrix)
TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int M = 255, N = 511, K = 88;

  // [M,K]
  auto tv0 = makeContigTensor(2, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // The broadcast mma operands are kept outside the mma op for now,
  // since they need to be swizzled explicitly.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

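  // On Ampere the operands are staged in shared memory with cp.async
  // (asynchronous gmem->smem copies) and read back into registers with
  // ldmatrix.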
  auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  // 0 1 2 3
  // [Mo,No, M128, N128]
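  // computeAt(tv2, 2) nests the operand tensors inside the two outer
  // CTA-tile loops (Mo, No), so each CTA only works on its own tile.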
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  // 0 1 2 3 4 5
  // [Mo,No, Ko M128, N128, K32]
  tv0cw->computeAt(tv2c, 3);
  tv1cw->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
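  // scheduleWarpTileWithReduction subdivides the accumulator's CTA tile
  // into warp tiles and instruction tiles (including the K axis); the
  // NoReduction variant does the same for the output tensor, which has no
  // K dimension left.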
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
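  // A vector width of 8 here means 8 fp16 elements, i.e. 16-byte (128-bit)
  // global loads and shared-memory stores.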
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

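  // BIDx/BIDy walk the CTA tiles over M and N, while TIDz/TIDy enumerate
  // the warps within each CTA tile.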
  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 511, N = 123, K = 88, B0 = 3, B1 = 5;

  // [B0, M, B1, K]
  auto tv0 = makeContigTensor(4, DataType::Half);
  // [B0, N, B1, K]
  auto tv1 = makeContigTensor(4, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [B0,M,N,B1,K]
  auto tv0b = broadcast(tv0, {false, false, true, false, false});
  auto tv1b = broadcast(tv1, {false, true, false, false, false});

  // The broadcast mma operands are kept outside the mma op for now,
  // since they need to be swizzled explicitly.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {4});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();
  auto tv0cw = tv0r->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1r->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Group the BATCHED DIMS:
  // -4 -3 -2 -1
  // [B0, M, N, B1]
  tv2->reorder({{-3, -2}, {-2, -1}, {-1, -4}});

  // -4 -3 -2 -1
  // [B0, B1, M, N]

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [B0, B1, M, N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3 4 5
  // [B0, B1, Mo,M128, No, N128]
  tv2->reorder({{-3, -2}, {-2, -3}});

  // 0 1 2 3 4 5
  // [B0, B1, Mo, No, M128, N128]

  // Merge the outer dims:
  tv2->merge(0);
  tv2->merge(0);

  // 0 1 2 3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  // 0 1 2 3 4 5
  // [Mo,No, Ko M128, N128, K32]
  tv0r->computeAt(tv2c, 3);
  tv1r->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({B0, M, B1, K}, options);
  auto t1 = at::randn({B0, N, B1, K}, options);

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  // ref implementation:
  auto ref_t0 = t0.permute({0, 2, 1, 3})
                    .contiguous()
                    .view({B0 * B1, M, K}); // B0, B1, M, K
  auto ref_t1 = t1.permute({0, 2, 3, 1})
                    .contiguous()
                    .view({B0 * B1, K, N}); // B0, B1, K, N
  auto ref_permuted =
      ref_t0.to(at::kFloat).bmm(ref_t1.to(at::kFloat)); // B0*B1, M,N
  auto ref = ref_permuted.view({B0, B1, M, N})
                 .permute({0, 2, 3, 1})
                 .contiguous(); // B0,M,N,B1
  TORCH_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001));
}

// Matmul test on Ampere with a view op in the prologue
TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 511, N = 257, K = 88;
  int Ko = 11, Ki = 8;
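  // Note that K == Ko * Ki (88 == 11 * 8), so the view below is a pure
  // merge of the two inner K factors.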

  // [M,Ko,Ki]
  auto tv0 = makeContigTensor(3, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

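  // The prologue reshapes the [M, Ko, Ki] operand to [M, K] before the
  // broadcast; the test checks that this view op can be fused and
  // scheduled together with the matmul.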
  auto tv0_view = view(tv0, {M, Ko, Ki}, {M, K});

  // [M,N,K]
  auto tv0b = broadcast(tv0_view, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // The broadcast mma operands are kept outside the mma op for now,
  // since they need to be swizzled explicitly.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();
  auto tv0cw = tv0_view->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1r->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  // 0 1 2 3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  // 0 1 2 3 4 5
  // [Mo,No, Ko M128, N128, K32]
  tv0r->computeAt(tv2c, 3);
  tv1r->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  tv0_view->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Inline the view op with the shared mem write, at all but the
  // vectorization axes for now.
  tv0_view->computeAt(tv0cw, -2);

  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, Ko, Ki}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref =
      at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat));

  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Initial test case for in-CTA split K with Volta MMA
TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 120, N = 264, K = 120;

  // [M,K]
  auto tv0 = makeContigTensor(2, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // The mma operands are kept outside the mma op for now since Volta
  // requires them to be swizzled explicitly.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 16);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();
  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  // 0 1 2 3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  // 0 1 2 3 4 5
  // [Mo,No, Ko M128, N128, K32]
  tv0r->computeAt(tv2c, 3);
  tv1r->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  auto tv2c_rf = tv2c->rFactor({-9, -4, -1});

  // tv2c_rf is the actual output of the mma op after rFactoring.
  mma_builder.accumulatorTv(tv2c_rf);
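  // rFactor splits the K reduction into two stages: tv2c_rf performs the
  // per-warp partial sums through the mma op, and tv2c then reduces those
  // partial results across warps, giving the in-CTA split-K.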

  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);

  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c_rf, -4);
  tv1cr->computeAt(tv2c_rf, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,No,Ko,M,N,K]
  tv0cw->reorder({
      {-3, -2},
      {-2, -3},
  });
  // [Mo,No,Ko,N,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [Mo,No,Ko,M,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [Mo,No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c_rf->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0b->computeAt(tv0cw, -2);
  tv1b->computeAt(tv1cw, -2);

  tv0cr->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1cr->axis(-1)->parallelize(ParallelType::Vectorize);
  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv2c_rf->axis(1)->parallelize(ParallelType::BIDy);
  tv2c_rf->axis(3)->parallelize(ParallelType::TIDz);
  tv2c_rf->axis(4)->parallelize(ParallelType::TIDy);

  tv2c->axis(2)->parallelize(ParallelType::TIDz);
  tv2c->axis(3)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1});
  auto cg_outputs = fe.runFusion({t0, t1});
  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t());
  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Initial test case for cross-CTA split K with Volta MMA
TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 120, N = 264, K = 120;

  // [M,K]
  auto tv0 = makeContigTensor(2, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // The mma operands are kept outside the mma op for now since Volta
  // requires them to be swizzled explicitly.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();
  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  // 0 1 2 3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
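  // The next split peels a K2CTA factor of 2 off the K loop; it is later
  // bound to BIDz, so the K reduction is distributed across two CTAs whose
  // partial sums are then reduced across the grid (cross-CTA split-K).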
  tv2c->split(-2, 2, true);
  // Order K
  // 0 1 2 3 4 5 6
  // [Mo,No, M128, N128, Ko, K2CTA, K32]
  tv2c->reorder({{2, 4}, {3, 5}, {4, 3}, {5, 2}});
  // 0 1 2 3 4 5 6
  // [Mo,No, K2CTA, Ko M128, N128, K32]
  tv0r->computeAt(tv2c, 4);
  tv1r->computeAt(tv2c, 4);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  auto tv2c_rf = tv2c->rFactor({-9, -6, -1});

  // tv2c_rf is the actual output of the mma op after rFactoring.
  mma_builder.accumulatorTv(tv2c_rf);

  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);

  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c_rf, -4);
  tv1cr->computeAt(tv2c_rf, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,No,Ko,M,N,K]
  tv0cw->reorder({
      {-3, -2},
      {-2, -3},
  });
  // [Mo,No,Ko,N,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [Mo,No,Ko,M,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [Mo,No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c_rf->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0b->computeAt(tv0cw, -2);
  tv1b->computeAt(tv1cw, -2);

  tv0cr->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1cr->axis(-1)->parallelize(ParallelType::Vectorize);
  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv2c_rf->axis(1)->parallelize(ParallelType::BIDy);
  tv2c_rf->axis(2)->parallelize(ParallelType::BIDz);
  tv2c_rf->axis(4)->parallelize(ParallelType::TIDz);
  tv2c_rf->axis(5)->parallelize(ParallelType::TIDy);

  tv2c->axis(0)->parallelize(ParallelType::BIDx);
  tv2c->axis(1)->parallelize(ParallelType::BIDy);
  tv2c->axis(2)->parallelize(ParallelType::BIDz);
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1});
  auto cg_outputs = fe.runFusion({t0, t1});
  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t());
  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Test an end-to-end matmul case with swizzled smem data layout.
TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  int M = 257, N = 511, K = 136;

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  // [M,K]
  auto tv0 = makeContigTensor(2, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  mma_builder.configureMma(tv2);

  auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  // 0 1 2 3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  // 0 1 2 3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  // 0 1 2 3 4 5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  // 0 1 2 3 4 5
  // [Mo,No, Ko M128, N128, K32]
  tv0cw->computeAt(tv2c, 3);
  tv1cw->computeAt(tv2c, 3);

  // Make warp tile:
  //
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8  -7  -6  -5 -4 -3 -2 -1
  // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  //
  // [Mo,Ko,M,K]
  // Swizzle tv0: 128 x 32 tile:
  tv0cw->split(-2, 8);
  tv0cw->split(-2, 2);
  tv0cw->split(-1, 8);
  // -5 -4 -3 -2 -1
  // [Mo,Ko,Mo16,M4,M2,Ko4,K8]
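  // XOR-swizzle the row and column tile coordinates of the shared-memory
  // layout so that the ldmatrix reads avoid smem bank conflicts.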
  tv0cw->swizzle(Swizzle2DType::XOR, -4, -2);
  tv0cw->merge(-4);
  tv0cw->merge(-3);
  // -3 -2 -1
  // [Mo,Ko,Mo16,warp,K8]
  tv0cw->split(-3, 4);
  tv0cw->split(-3, 2);
  // -4 -3 -2 -1
  // [Mo,Ko, S4, wz2, wy2, warp,K8]
  tv0cw->axis(-4)->parallelize(ParallelType::TIDz);
  tv0cw->axis(-3)->parallelize(ParallelType::TIDy);
  tv0cw->axis(-2)->parallelize(ParallelType::TIDx);
  tv0cw->axis(-1)->parallelize(ParallelType::Vectorize);

  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  // Swizzle tv1: 128 x 32 tile:
  tv1cw->split(-2, 8);
  tv1cw->split(-2, 2);
  tv1cw->split(-1, 8);
  // -5 -4 -3 -2 -1
  // [No,Ko,No16,N4,N2,Ko4,K8]
  tv1cw->swizzle(Swizzle2DType::XOR, -4, -2);
  tv1cw->merge(-4);
  tv1cw->merge(-3);
  // -3 -2 -1
  // [No,Ko,No16,warp,K8]
  tv1cw->split(-3, 4);
  tv1cw->split(-3, 2);
  // -4 -3 -2 -1
  // [No,Ko, S4, wz2, wy2, warp,K8]
  tv1cw->axis(-4)->parallelize(ParallelType::TIDz);
  tv1cw->axis(-3)->parallelize(ParallelType::TIDy);
  tv1cw->axis(-2)->parallelize(ParallelType::TIDx);
  tv1cw->axis(-1)->parallelize(ParallelType::Vectorize);

  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [... Mi, Ni, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Parallelize
  // 0 1 2 3 4 5 6 7 8 9 10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize
  // 0 1 2 3 4 5 6 7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

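  // Double buffer the smem staging tensors and the register read caches so
  // that the next tile's loads can overlap with the current tile's mma work.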
  tv0cw->doubleBuffer();
  tv1cw->doubleBuffer();
  tv0cr->doubleBuffer();
  tv1cr->doubleBuffer();

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0, t1});
  auto cg_outputs = fe.runFusion({t0, t1});

  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Matmul test on Ampere using ldmatrix.x4 to load operands
TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
  // Keep dimensions as multiples of 8 so the loads stay vectorizable.
  int M = 504, N = 136, K = 248;
  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 16, 16);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
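    // Use cp.async for the operand loads with a 4-stage circular
    // shared-memory buffer, so multiple K tiles are in flight at once.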
    params.async_gmem_load_operands = true;
    params.double_buffer_options.double_buffer_smem_write = true;
    params.double_buffer_options.smem_double_buffer_stage = 4;
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

// Matmul test for Turing MMA: across supported layouts
TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) {
  // Keep dimensions as multiples of 8 so the loads stay vectorizable.
  int M = 504, N = 136, K = 248;

  for (auto layout : kAllSupportedLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout);

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 16, 16);

    auto mma_builder =
        MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile)
            .layout(layout);

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
    scheduleMatmul(tv2, tv0, tv1, params);

    at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
  }
}

#undef NVFUSER_TEST_CUDA_ARCH_GUARD

} // namespace jit
} // namespace torch

#endif