1 | #if defined(USE_CUDA) |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <arith.h> |
5 | #include <codegen.h> |
6 | #include <disjoint_set.h> |
7 | #include <executor.h> |
8 | #include <executor_launch_params.h> |
9 | #include <expr_evaluator.h> |
10 | #include <fusion.h> |
11 | #include <fusion_segmenter.h> |
12 | #include <ir_all_nodes.h> |
13 | #include <ir_graphviz.h> |
14 | #include <ir_iostream.h> |
15 | #include <ir_printer.h> |
16 | #include <ir_utils.h> |
17 | #include <iter_visitor.h> |
18 | #include <kernel_cache.h> |
19 | #include <kernel_expr_evaluator.h> |
20 | #include <kernel_ir.h> |
21 | #include <lower2device.h> |
22 | #include <mma_type.h> |
23 | #include <mutator.h> |
24 | #include <ops/all_ops.h> |
25 | #include <register_interface.h> |
26 | #include <root_domain_map.h> |
27 | #include <scheduler/all_schedulers.h> |
28 | #include <scheduler/matmul.h> |
29 | #include <scheduler/reduction_utils.h> |
30 | #include <scheduler/utils.h> |
31 | #include <test/test_gpu_validator.h> |
32 | #include <test/test_utils.h> |
33 | #include <transform_replay.h> |
34 | #include <transform_rfactor.h> |
35 | |
36 | // fuser and IR parser |
37 | #include <ATen/cuda/CUDAContext.h> |
38 | #include <ATen/cuda/Exceptions.h> |
39 | #include <c10/cuda/CUDAStream.h> |
40 | |
41 | #include <algorithm> |
42 | #include <iostream> |
43 | |
44 | // Tests go in torch::jit |
45 | namespace torch { |
46 | namespace jit { |
47 | |
48 | using namespace torch::jit::fuser::cuda; |
49 | using namespace at::indexing; |
50 | |
51 | namespace { |
52 | |
53 | bool cudaArchGuardShouldSkip(int required_major, int required_minor) { |
54 | int capability_major = at::cuda::getCurrentDeviceProperties()->major; |
55 | int capability_minor = at::cuda::getCurrentDeviceProperties()->minor; |
56 | |
57 | if (capability_major < required_major || |
58 | (capability_major == required_major && |
59 | capability_minor < required_minor)) { |
60 | return true; |
61 | } |
62 | return false; |
63 | } |
64 | |
65 | #define NVFUSER_TEST_CUDA_ARCH_GUARD(REQUIRED_MAJOR, REQUIRED_MINOR) \ |
66 | if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) { \ |
67 | GTEST_SKIP() << "Requires GPU capability above " << REQUIRED_MAJOR << "." \ |
68 | << REQUIRED_MINOR << " to run.\n"; \ |
69 | } |
70 | |
71 | #define NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( \ |
72 | REQUIRED_MAJOR, REQUIRED_MINOR, COMPILE_FUSION) \ |
73 | if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) { \ |
74 | ASSERT_ANY_THROW(COMPILE_FUSION); \ |
75 | GTEST_SKIP() << "(Lowered Only) Requires GPU capability above " \ |
76 | << REQUIRED_MAJOR << "." << REQUIRED_MINOR << " to run.\n"; \ |
77 | } else { \ |
78 | COMPILE_FUSION; \ |
79 | } |
80 | |
81 | // util to track support matmul operand layout. |
82 | using MatmulLayout = MmaOptions::MmaInputLayout; |
83 | |
84 | static constexpr std::array<MatmulLayout, 3> kAllSupportedLayout = { |
85 | MatmulLayout::TT, |
86 | MatmulLayout::NT, |
87 | MatmulLayout::TN}; |
88 | |
89 | // Generic interface to get matmul op with the given layout. |
90 | TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) { |
91 | TORCH_CHECK( |
92 | a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests" ); |
93 | TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr; |
94 | switch (layout) { |
95 | case MatmulLayout::TT: |
96 | tv0b = broadcast(a, {false, false, true}); |
97 | tv1b = broadcast(b, {true, false, false}); |
98 | tv2 = fusedMultiplySum(tv0b, tv1b, {1}); |
99 | break; |
100 | case MatmulLayout::TN: |
101 | tv0b = broadcast(a, {false, true, false}); |
102 | tv1b = broadcast(b, {true, false, false}); |
103 | tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
104 | break; |
105 | case MatmulLayout::NT: |
106 | tv0b = broadcast(a, {false, false, true}); |
107 | tv1b = broadcast(b, {false, true, false}); |
108 | tv2 = fusedMultiplySum(tv0b, tv1b, {0}); |
109 | break; |
110 | default: |
111 | TORCH_CHECK(false, "unsupported data layout." ); |
112 | } |
113 | return tv2; |
114 | } |
115 | |
116 | // Utility to generate matmul input tensors based on given layout |
117 | at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { |
118 | switch (layout) { |
119 | case MatmulLayout::TT: |
120 | return a.matmul(b); |
121 | case MatmulLayout::TN: |
122 | return a.matmul(b.t()); |
123 | case MatmulLayout::NT: |
124 | return a.t().matmul(b); |
125 | default: |
126 | TORCH_CHECK(false, "unsupported data layout." ); |
127 | } |
128 | return at::Tensor(); |
129 | } |
130 | |
131 | // Utility to generate reference results based on given layout |
132 | std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput( |
133 | int M, |
134 | int N, |
135 | int K, |
136 | MatmulLayout layout) { |
137 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
138 | |
139 | switch (layout) { |
140 | case MatmulLayout::TT: |
141 | return std::make_pair( |
142 | at::randn({M, K}, options), at::randn({K, N}, options)); |
143 | case MatmulLayout::TN: |
144 | return std::make_pair( |
145 | at::randn({M, K}, options), at::randn({N, K}, options)); |
146 | case MatmulLayout::NT: |
147 | return std::make_pair( |
148 | at::randn({K, M}, options), at::randn({K, N}, options)); |
149 | default: |
150 | TORCH_CHECK(false, "unsupported data layout." ); |
151 | } |
152 | return std::make_pair(at::Tensor(), at::Tensor()); |
153 | } |
154 | |
155 | #define REQUIRE_DEVICE_SMEM_SIZE(required_size, device_idx) \ |
156 | if (at::cuda::getDeviceProperties(device_idx)->sharedMemPerBlockOptin < \ |
157 | required_size) { \ |
158 | GTEST_SKIP() << "not enough shared memory space on device to run test"; \ |
159 | } |
160 | |
161 | } // namespace |
162 | |
// MMA unit test for a single instruction tile: Volta 16x16x4 macro, TT layout
// (A is [M,K] = [16,4], B is [K,N] = [4,16]). The problem is exactly one
// instruction tile, so the CTA, warp and instruction tiles are identical.
// On devices below compute capability 7.0 the compile is expected to throw
// and the test is skipped (see NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK).
TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 4}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({4, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,K,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  // Reduce over K (axis 1 of the [M,K,N] domain).
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});

  fusion.addOutput(tv2);

  // TODO: should be able to completely remove it
  // in a follow up.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TT);
  mma_builder.configureMma(tv2);

  // Write A to smem
  auto tv0cw = tv0b->cacheAfter();
  // Read A from smem
  auto tv0cr = tv0cw->cacheAfter();

  // Write B to smem
  auto tv1cw = tv1b->cacheAfter();

  // Read B from smem
  auto tv1cr = tv1cw->cacheAfter();

  // Register accumulator
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // [M,K,N]->[M,N,K]
  tv0cr->reorder({{-2, -1}, {-1, -2}});

  // Schedule the instruction tile loops, which is the only
  // part we have in this unit test.
  // Assumes last 3 dims are mnk
  // The innermost loops are dictated by the type of mma used,
  // the scheduler needs to use mma_util::WarpMmaSwizzler to
  // get the right thread swizzle. Currently this is the only
  // method allowed to schedule the 3/2 inner most loops of
  // mma input/output.
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [M,K,N]->[M,N,K]
  tv1cr->reorder({{-2, -1}, {-1, -2}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [M,K,N]->[M,N,K]
  tv2c->reorder({{-2, -1}, {-1, -2}});

  // Schedule the output instruction tile.
  // Assumes last 3 dims are mnk
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Set memory type.
  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 4}, options);
  auto t1 = at::randn({4, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});

  // Reference: fp32 matmul of the fp16 inputs.
  auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}
257 | |
// MMA unit test for a single instruction tile: Volta 16x16x4 macro, TN layout
// (A is [M,K] = [16,4], B is [N,K] = [16,4]). Both operands are broadcast
// straight into [M,N,K] order, so unlike the TT/NT variants no reorder is
// needed before applying the mma swizzles (which assume the last 3 dims are
// m, n, k). On devices below compute capability 7.0 the compile is expected
// to throw and the test is skipped.
TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [M,K]
  auto tv0 = makeConcreteTensor({16, 4}, DataType::Half);
  // [N,K]
  auto tv1 = makeConcreteTensor({16, 4}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  // Reduce over K (axis 2 of the [M,N,K] domain).
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  // TODO: should be able to completely remove it
  // in a follow up.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  // Operand caches: *cw are written to shared memory (set below), *cr read
  // them back; tv2c is the accumulator computed before the output tv2.
  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 4}, options);
  auto t1 = at::randn({16, 4}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});
  // Reference: fp32 A @ B^T.
  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));
  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}
322 | |
// MMA unit test for a single instruction tile: Volta 16x16x4 macro, NT layout
// (A is [K,M] = [4,16], B is [K,N] = [4,16]). Operands are broadcast into
// [K,M,N] and reordered to [M,N,K] before the mma swizzles are applied.
// On devices below compute capability 7.0 the compile is expected to throw
// and the test is skipped.
TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [K,M]
  auto tv0 = makeConcreteTensor({4, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({4, 16}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [K,M,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {false, true, false});

  // Leaving both sets of mma inputs for volta outside
  // currently since they need to be swizzled.
  // Reduce over K (axis 0 of the [K,M,N] domain).
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 16, 4);
  gemm_tile.warp_tile = GemmTile(16, 16, 4);
  gemm_tile.instruction_tile = GemmTile(16, 16, 4);

  auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile)
                         .layout(MmaOptions::MmaInputLayout::NT);

  mma_builder.configureMma(tv2);

  // Operand caches: *cw are written to shared memory (set below), *cr read
  // them back; tv2c is the accumulator computed before the output tv2.
  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr = tv0cw->cacheAfter();
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr = tv1cw->cacheAfter();
  auto tv2c = tv2->cacheBefore();

  mma_builder.accumulatorTv(tv2c);

  // To MNK
  tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // To MNK
  tv1cr->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // To MNK
  tv2c->reorder({{0, 2}, {1, 0}, {2, 1}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({4, 16}, options);
  auto t1 = at::randn({4, 16}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 0, fe.compileFusion(&fusion, {t0, t1}));
  auto cg_outputs = fe.runFusion({t0, t1});
  // Reference: fp32 A^T @ B.
  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));
  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}
391 | |
392 | // Matmul test for Volta MMA: across supported layouts |
393 | TEST_F(NVFuserTest, FusionVoltaMatmul_CUDA) { |
394 | // Keep multiples of 8 to keep vectorizable. |
395 | int M = 264, N = 136, K = 248; |
396 | |
397 | for (auto layout : kAllSupportedLayout) { |
398 | Fusion fusion; |
399 | FusionGuard fg(&fusion); |
400 | auto tv0 = makeContigTensor(2, DataType::Half); |
401 | auto tv1 = makeContigTensor(2, DataType::Half); |
402 | |
403 | fusion.addInput(tv0); |
404 | fusion.addInput(tv1); |
405 | |
406 | auto tv2 = matmul(tv0, tv1, layout); |
407 | |
408 | fusion.addOutput(tv2); |
409 | |
410 | MatMulTileOptions gemm_tile; |
411 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
412 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
413 | gemm_tile.instruction_tile = GemmTile(16, 16, 4); |
414 | |
415 | auto mma_builder = |
416 | MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) |
417 | .layout(layout); |
418 | |
419 | MatmulParam params(mma_builder); |
420 | params.tile_sizes = gemm_tile; |
421 | scheduleMatmul(tv2, tv0, tv1, params); |
422 | |
423 | at::manual_seed(0); |
424 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
425 | |
426 | FusionExecutor fe; |
427 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
428 | 7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
429 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
430 | auto tref = atMatmul( |
431 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
432 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
433 | } |
434 | } |
435 | |
436 | // Matmul test for Volta MMA: across supported layouts |
437 | TEST_F(NVFuserTest, FusionVoltaMatmulRegDoubleBuffer_CUDA) { |
438 | // Keep multiples of 8 to keep vectorizable. |
439 | int M = 264, N = 136, K = 248; |
440 | |
441 | for (auto layout : kAllSupportedLayout) { |
442 | Fusion fusion; |
443 | FusionGuard fg(&fusion); |
444 | auto tv0 = makeContigTensor(2, DataType::Half); |
445 | auto tv1 = makeContigTensor(2, DataType::Half); |
446 | |
447 | fusion.addInput(tv0); |
448 | fusion.addInput(tv1); |
449 | |
450 | auto tv2 = matmul(tv0, tv1, layout); |
451 | |
452 | fusion.addOutput(tv2); |
453 | |
454 | MatMulTileOptions gemm_tile; |
455 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
456 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
457 | gemm_tile.instruction_tile = GemmTile(16, 16, 4); |
458 | |
459 | auto mma_builder = |
460 | MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) |
461 | .layout(layout); |
462 | |
463 | MatmulParam params(mma_builder); |
464 | params.tile_sizes = gemm_tile; |
465 | params.double_buffer_options.double_buffer_smem_read = true; |
466 | scheduleMatmul(tv2, tv0, tv1, params); |
467 | |
468 | at::manual_seed(0); |
469 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
470 | |
471 | FusionExecutor fe; |
472 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
473 | 7, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
474 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
475 | auto tref = atMatmul( |
476 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
477 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
478 | } |
479 | } |
480 | |
481 | // MMA unit test on Ampere |
482 | TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { |
483 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
484 | |
485 | Fusion fusion; |
486 | FusionGuard fg(&fusion); |
487 | |
488 | // [M,K] |
489 | auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); |
490 | // [N,K] |
491 | auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); |
492 | fusion.addInput(tv0); |
493 | fusion.addInput(tv1); |
494 | |
495 | // [M,N,K] |
496 | auto tv0b = broadcast(tv0, {false, true, false}); |
497 | auto tv1b = broadcast(tv1, {true, false, false}); |
498 | |
499 | // Leaving both sets of mma inputs for volta outside |
500 | // currently since they need to be swizzled. |
501 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
502 | |
503 | fusion.addOutput(tv2); |
504 | |
505 | MatMulTileOptions gemm_tile; |
506 | gemm_tile.cta_tile = GemmTile(16, 8, 16); |
507 | gemm_tile.warp_tile = GemmTile(16, 8, 16); |
508 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
509 | |
510 | auto mma_builder = |
511 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
512 | .layout(MmaOptions::MmaInputLayout::TN); |
513 | |
514 | mma_builder.configureMma(tv2); |
515 | |
516 | auto tv0cw = tv0b->cacheAfter(); |
517 | auto tv0cr = |
518 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
519 | auto tv1cw = tv1b->cacheAfter(); |
520 | auto tv1cr = |
521 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
522 | |
523 | auto tv2c = tv2->cacheBefore(); |
524 | |
525 | mma_builder.accumulatorTv(tv2c); |
526 | |
527 | // [M,N,K] -> [N,M,K] |
528 | tv0cr->reorder({{-2, -3}, {-3, -2}}); |
529 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
530 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
531 | tv2c->applyMmaSwizzle( |
532 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
533 | tv2->applyMmaSwizzle( |
534 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
535 | |
536 | tv0cw->setMemoryType(MemoryType::Shared); |
537 | tv1cw->setMemoryType(MemoryType::Shared); |
538 | |
539 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
540 | auto t0 = at::randn({16, 16}, options); |
541 | auto t1 = at::randn({8, 16}, options); |
542 | |
543 | FusionExecutor fe; |
544 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
545 | 8, 0, fe.compileFusion(&fusion, {t0, t1})); |
546 | auto cg_outputs = fe.runFusion({t0, t1}); |
547 | |
548 | auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); |
549 | |
550 | testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); |
551 | } |
552 | |
553 | // MMA unit test on Ampere |
554 | TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { |
555 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
556 | |
557 | Fusion fusion; |
558 | FusionGuard fg(&fusion); |
559 | |
560 | // [M,K] |
561 | auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); |
562 | // [K,N] |
563 | auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); |
564 | fusion.addInput(tv0); |
565 | fusion.addInput(tv1); |
566 | |
567 | // [M,K,N] |
568 | auto tv0b = broadcast(tv0, {false, false, true}); |
569 | auto tv1b = broadcast(tv1, {true, false, false}); |
570 | |
571 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); |
572 | |
573 | fusion.addOutput(tv2); |
574 | |
575 | MatMulTileOptions gemm_tile; |
576 | gemm_tile.cta_tile = GemmTile(16, 8, 16); |
577 | gemm_tile.warp_tile = GemmTile(16, 8, 16); |
578 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
579 | |
580 | auto mma_builder = |
581 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
582 | .layout(MmaOptions::MmaInputLayout::TT); |
583 | |
584 | mma_builder.configureMma(tv2); |
585 | |
586 | auto tv0cw = tv0b->cacheAfter(); |
587 | auto tv0cr = |
588 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
589 | auto tv1cw = tv1b->cacheAfter(); |
590 | auto tv1cr = |
591 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
592 | |
593 | auto tv2c = tv2->cacheBefore(); |
594 | |
595 | mma_builder.accumulatorTv(tv2c); |
596 | // [M,K,N] -> [N,M,K] |
597 | tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); |
598 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
599 | |
600 | // [M,K,N] -> [M,N,K] |
601 | tv1cr->reorder({{-2, -1}, {-1, -2}}); |
602 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
603 | |
604 | // [M,K,N] -> [M,N,K] |
605 | tv2c->reorder({{-2, -1}, {-1, -2}}); |
606 | tv2c->applyMmaSwizzle( |
607 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
608 | tv2->applyMmaSwizzle( |
609 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
610 | |
611 | tv0cw->setMemoryType(MemoryType::Shared); |
612 | tv1cw->setMemoryType(MemoryType::Shared); |
613 | |
614 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
615 | auto t0 = at::randn({16, 16}, options); |
616 | auto t1 = at::randn({16, 8}, options); |
617 | |
618 | FusionExecutor fe; |
619 | |
620 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
621 | 8, 0, fe.compileFusion(&fusion, {t0, t1})); |
622 | |
623 | auto cg_outputs = fe.runFusion({t0, t1}); |
624 | |
625 | auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); |
626 | |
627 | testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); |
628 | } |
629 | |
630 | // MMA unit test on Ampere |
631 | TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { |
632 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
633 | |
634 | Fusion fusion; |
635 | FusionGuard fg(&fusion); |
636 | |
637 | // [K,M] |
638 | auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); |
639 | // [K,N] |
640 | auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); |
641 | fusion.addInput(tv0); |
642 | fusion.addInput(tv1); |
643 | |
644 | // [K,M,N] |
645 | auto tv0b = broadcast(tv0, {false, false, true}); |
646 | auto tv1b = broadcast(tv1, {false, true, false}); |
647 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {0}); |
648 | |
649 | fusion.addOutput(tv2); |
650 | |
651 | MatMulTileOptions gemm_tile; |
652 | gemm_tile.cta_tile = GemmTile(16, 8, 16); |
653 | gemm_tile.warp_tile = GemmTile(16, 8, 16); |
654 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
655 | |
656 | auto mma_builder = |
657 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
658 | .layout(MmaOptions::MmaInputLayout::NT); |
659 | |
660 | mma_builder.configureMma(tv2); |
661 | |
662 | auto tv0cw = tv0b->cacheAfter(); |
663 | auto tv0cr = |
664 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
665 | auto tv1cw = tv1b->cacheAfter(); |
666 | auto tv1cr = |
667 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
668 | |
669 | auto tv2c = tv2->cacheBefore(); |
670 | mma_builder.accumulatorTv(tv2c); |
671 | |
672 | // [K,M,N] -> [N,M,K] |
673 | tv0cr->reorder({{-3, -1}, {-1, -3}}); |
674 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
675 | |
676 | // [K,M,N] -> [M,N,K] |
677 | tv1cr->reorder({ |
678 | {-3, -1}, |
679 | {-2, -3}, |
680 | {-1, -2}, |
681 | }); |
682 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
683 | |
684 | // [K,M,N] -> [M,N,K] |
685 | tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}}); |
686 | tv2c->applyMmaSwizzle( |
687 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
688 | tv2->applyMmaSwizzle( |
689 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
690 | |
691 | tv0cw->setMemoryType(MemoryType::Shared); |
692 | tv1cw->setMemoryType(MemoryType::Shared); |
693 | |
694 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
695 | auto t0 = at::randn({16, 16}, options); |
696 | auto t1 = at::randn({16, 8}, options); |
697 | |
698 | FusionExecutor fe; |
699 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
700 | 8, 0, fe.compileFusion(&fusion, {t0, t1})); |
701 | auto cg_outputs = fe.runFusion({t0, t1}); |
702 | |
703 | auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat)); |
704 | |
705 | testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); |
706 | } |
707 | |
708 | // Matmul test for Ampere MMA: across supported layouts |
709 | TEST_F(NVFuserTest, FusionAmpereMatmul_CUDA) { |
710 | // Keep multiples of 8 to keep vectorizable. |
711 | int M = 504, N = 136, K = 248; |
712 | |
713 | for (auto layout : kAllSupportedLayout) { |
714 | Fusion fusion; |
715 | FusionGuard fg(&fusion); |
716 | auto tv0 = makeContigTensor(2, DataType::Half); |
717 | auto tv1 = makeContigTensor(2, DataType::Half); |
718 | |
719 | fusion.addInput(tv0); |
720 | fusion.addInput(tv1); |
721 | |
722 | auto tv2 = matmul(tv0, tv1, layout); |
723 | |
724 | fusion.addOutput(tv2); |
725 | |
726 | MatMulTileOptions gemm_tile; |
727 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
728 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
729 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
730 | |
731 | auto mma_builder = |
732 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
733 | .layout(layout); |
734 | |
735 | MatmulParam params(mma_builder); |
736 | params.tile_sizes = gemm_tile; |
737 | params.async_gmem_load_operands = true; |
738 | params.double_buffer_options.double_buffer_smem_write = true; |
739 | params.double_buffer_options.smem_double_buffer_stage = 4; |
740 | scheduleMatmul(tv2, tv0, tv1, params); |
741 | |
742 | at::manual_seed(0); |
743 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
744 | |
745 | FusionExecutor fe; |
746 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
747 | 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
748 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
749 | auto tref = atMatmul( |
750 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
751 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
752 | } |
753 | } |
754 | |
755 | // Matmul test for Ampere MMA: with pipelined gmem load |
756 | TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { |
757 | // Keep multiples of 8 to keep vectorizable. |
758 | int M = 504, N = 136, K = 248; |
759 | REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); |
760 | |
761 | // Gmem pipeline stage |
762 | for (auto stage : {3, 4}) { |
763 | for (auto layout : kAllSupportedLayout) { |
764 | Fusion fusion; |
765 | FusionGuard fg(&fusion); |
766 | auto tv0 = makeContigTensor(2, DataType::Half); |
767 | auto tv1 = makeContigTensor(2, DataType::Half); |
768 | |
769 | fusion.addInput(tv0); |
770 | fusion.addInput(tv1); |
771 | |
772 | auto tv2 = matmul(tv0, tv1, layout); |
773 | |
774 | fusion.addOutput(tv2); |
775 | |
776 | MatMulTileOptions gemm_tile; |
777 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
778 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
779 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
780 | |
781 | auto mma_builder = |
782 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
783 | .layout(layout); |
784 | |
785 | MatmulParam params(mma_builder); |
786 | params.tile_sizes = gemm_tile; |
787 | params.tile_sizes = gemm_tile; |
788 | params.async_gmem_load_operands = true; |
789 | params.double_buffer_options.double_buffer_smem_write = true; |
790 | params.double_buffer_options.smem_double_buffer_stage = stage; |
791 | scheduleMatmul(tv2, tv0, tv1, params); |
792 | |
793 | at::manual_seed(0); |
794 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
795 | |
796 | FusionExecutor fe; |
797 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
798 | 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
799 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
800 | auto tref = atMatmul( |
801 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
802 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
803 | } |
804 | } |
805 | } |
806 | |
807 | TEST_F(NVFuserTest, FusionAmpereMatmulRegDbouleBuffer_CUDA) { |
808 | // Keep multiples of 8 to keep vectorizable. |
809 | int M = 504, N = 136, K = 248; |
810 | REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); |
811 | |
812 | // Gmem pipeline stage |
813 | for (auto stage : {3, 4}) { |
814 | for (auto layout : kAllSupportedLayout) { |
815 | Fusion fusion; |
816 | FusionGuard fg(&fusion); |
817 | auto tv0 = makeContigTensor(2, DataType::Half); |
818 | auto tv1 = makeContigTensor(2, DataType::Half); |
819 | |
820 | fusion.addInput(tv0); |
821 | fusion.addInput(tv1); |
822 | |
823 | auto tv2 = matmul(tv0, tv1, layout); |
824 | |
825 | fusion.addOutput(tv2); |
826 | |
827 | MatMulTileOptions gemm_tile; |
828 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
829 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
830 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
831 | |
832 | auto mma_builder = |
833 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
834 | .layout(layout); |
835 | |
836 | MatmulParam params(mma_builder); |
837 | params.tile_sizes = gemm_tile; |
838 | params.tile_sizes = gemm_tile; |
839 | params.async_gmem_load_operands = true; |
840 | params.double_buffer_options.double_buffer_smem_write = true; |
841 | params.double_buffer_options.smem_double_buffer_stage = stage; |
842 | params.double_buffer_options.double_buffer_smem_read = true; |
843 | scheduleMatmul(tv2, tv0, tv1, params); |
844 | |
845 | at::manual_seed(0); |
846 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
847 | |
848 | FusionExecutor fe; |
849 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
850 | 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
851 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
852 | auto tref = atMatmul( |
853 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
854 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
855 | } |
856 | } |
857 | } |
858 | |
859 | // Matmul-Matmul fusion test on Ampere |
860 | TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { |
861 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
862 | |
863 | Fusion fusion; |
864 | FusionGuard fg(&fusion); |
865 | int M = 512, N = 256, K1 = 128, K2 = 128; |
866 | |
867 | // Fusion definition (Both gemms are TN) |
868 | // [M,K1] |
869 | auto tv0 = makeConcreteTensor({M, K1}, DataType::Half); |
870 | // [K2,K1] |
871 | auto tv1 = makeConcreteTensor({K2, K1}, DataType::Half); |
872 | // [N,K2] |
873 | auto tv2 = makeConcreteTensor({N, K2}, DataType::Half); |
874 | |
875 | fusion.addInput(tv0); |
876 | fusion.addInput(tv1); |
877 | fusion.addInput(tv2); |
878 | |
879 | // [M,N,K] |
880 | auto tv0b = broadcast(tv0, {false, true, false}); |
881 | auto tv1b = broadcast(tv1, {true, false, false}); |
882 | auto tv2b = broadcast(tv2, {true, false, false}); |
883 | |
884 | // [M,K2,R] |
885 | auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); |
886 | |
887 | auto tv3h = castOp(DataType::Half, tv3); |
888 | auto tv3b = broadcast(tv3h, {false, true, false}); |
889 | |
890 | auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); |
891 | |
892 | fusion.addOutput(tv4); |
893 | |
894 | // Fusion: |
895 | // Gemm(M,K2,K1) x Gemm(M,N,K2) |
896 | |
897 | MatMulTileOptions gemm_tile1, gemm_tile2; |
898 | |
899 | // cta tile: |
900 | // To save register, n of cta tile 1 |
901 | // matches k of cta tile2 |
902 | gemm_tile1.cta_tile = GemmTile(128, 64, 32); |
903 | gemm_tile2.cta_tile = GemmTile(128, 32, 64); |
904 | |
905 | // Distribute to 2x2 warps |
906 | gemm_tile1.warp_tile = GemmTile(64, 32, 32); |
907 | gemm_tile2.warp_tile = GemmTile(64, 16, 64); |
908 | |
909 | // Using Ampere mma macro |
910 | gemm_tile2.instruction_tile = GemmTile(16, 8, 16); |
911 | gemm_tile2.instruction_tile = GemmTile(16, 8, 16); |
912 | |
913 | auto mma_builder1 = |
914 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile1) |
915 | .layout(MmaOptions::MmaInputLayout::TN); |
916 | |
917 | auto mma_builder2 = |
918 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile2) |
919 | .layout(MmaOptions::MmaInputLayout::TN); |
920 | |
921 | mma_builder1.configureMma(tv3); |
922 | mma_builder2.configureMma(tv4); |
923 | |
924 | // Global read for gemm 1 |
925 | auto tv0r = tv0->cacheAfter(); |
926 | auto tv1r = tv1->cacheAfter(); |
927 | |
928 | // Global read for gemm 2 |
929 | auto tv2r = tv2->cacheAfter(); |
930 | |
931 | // Gemm 1 main loop read |
932 | auto tv0cw = tv0r->cacheAfter(); |
933 | auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); |
934 | auto tv1cw = tv1r->cacheAfter(); |
935 | auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); |
936 | |
937 | // Gemm 1 accumulator reg |
938 | auto tv3c = tv3->cacheBefore(); |
939 | mma_builder1.accumulatorTv(tv3c); |
940 | |
941 | // Gemm 2 main loop read |
942 | auto tv3cw = tv3h->cacheAfter(); |
943 | auto tv3cr = tv3cw->cacheAfter(LoadStoreOpType::LdMatrix); |
944 | |
945 | auto tv2cw = tv2r->cacheAfter(); |
946 | auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); |
947 | |
948 | // Gemm 2 accumulator reg |
949 | auto tv4c = tv4->cacheBefore(); |
950 | mma_builder2.accumulatorTv(tv4c); |
951 | |
952 | // General idea is inlining gemm1's main loop inside gemm2's |
953 | |
954 | // Schedule gemm 2: |
955 | // ------------------------------------------------------------------ |
956 | tv4->split(-2, gemm_tile2.cta_tile.m); |
957 | tv4->split(-1, gemm_tile2.cta_tile.n); |
958 | |
959 | // 0 1 2 3 |
960 | // [Mo,M128, No, N128] |
961 | tv4->reorder({{1, 2}, {2, 1}}); |
962 | |
963 | // 0 1 2 3 |
964 | // [Mo,No, M128, N128] |
965 | tv2->computeAt(tv4, 2); |
966 | tv3->computeAt(tv4, 2); |
967 | |
968 | // Order K |
969 | // 0 1 2 3 4 5 |
970 | // [Mo,No, M128, N128, Ko, K32] |
971 | tv4c->split(-1, gemm_tile2.cta_tile.k); |
972 | tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
973 | |
974 | // 0 1 2 3 4 5 |
975 | // [Mo,No, Ko M128, N128, K32] |
976 | tv3->computeAt(tv4c, 3); // Implicitly defines cta tile of gemm1 |
977 | tv2r->computeAt(tv4c, 3); |
978 | |
979 | // Make warp tile |
980 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( |
981 | tv4c, gemm_tile2); |
982 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
983 | tv4, gemm_tile2); |
984 | // -8 -7 -6 -5 -4 -3 -2 -1 |
985 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
986 | tv3cr->computeAt(tv4c, -4); |
987 | tv2cr->computeAt(tv4c, -4); |
988 | |
989 | // Schedule tv2 gmem read and smem write: |
990 | // ---------------------------------------------------------------- |
991 | // [No,Ko,N,K] |
992 | tv2cw->merge(-2); |
993 | tv2r->merge(-2); |
994 | |
995 | // [No,Ko,i,wy,wx,v] |
996 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
997 | tv2cw, gemm_tile2, 8); |
998 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
999 | tv2r, gemm_tile2, 8); |
1000 | tv2cw->setMemoryType(MemoryType::Shared); |
1001 | |
1002 | // Schedule tv2 gmem read and smem write: |
1003 | // ---------------------------------------------------------------- |
1004 | |
1005 | // Schedule gemm 2 mma input |
1006 | // --------------------------------------------------------------------------- |
1007 | tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); |
1008 | |
1009 | // [... Mi, Ni, Ki] want [Ni, Mi, Ki] |
1010 | tv3b->reorder({{-2, -3}, {-3, -2}}); |
1011 | tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); |
1012 | |
1013 | tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); |
1014 | tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); |
1015 | |
1016 | // Schedule mma output |
1017 | // --------------------------------------------------------------------------- |
1018 | tv4c->applyMmaSwizzle( |
1019 | mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); |
1020 | tv4->applyMmaSwizzle( |
1021 | mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); |
1022 | |
1023 | // Schedule gemm 1: |
1024 | // ------------------------------------------------------------------ |
1025 | |
1026 | // CTA tile: |
1027 | tv0->computeAt(tv3, 2); |
1028 | tv1->computeAt(tv3, 2); |
1029 | |
1030 | // Schedule K dim for gemm 1: |
1031 | |
1032 | // Order K |
1033 | // 0 1 2 3 4 5 |
1034 | // [Mo,No, M128, N128, Ko, K32] |
1035 | tv3c->split(-1, gemm_tile1.cta_tile.k); |
1036 | tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
1037 | // 0 1 2 3 4 5 |
1038 | // [Mo,No, Ko M128, N128, K32] |
1039 | tv0r->computeAt(tv3c, 3); |
1040 | tv1r->computeAt(tv3c, 3); |
1041 | |
1042 | // Make warp tile: |
1043 | // ------------------------------------------------------------------------- |
1044 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( |
1045 | tv3c, gemm_tile1); |
1046 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
1047 | tv3cw, gemm_tile1); |
1048 | |
1049 | tv0cr->computeAt(tv3c, -4); |
1050 | tv1cr->computeAt(tv3c, -4); |
1051 | |
1052 | tv3->computeAt(tv3cw, -3); |
1053 | |
1054 | // Schedule gmem read and smem write: |
1055 | // --------------------------------------------------------------------------- |
1056 | // [Mo,Ko,M,K] |
1057 | tv0cw->merge(-2); |
1058 | tv0r->merge(-2); |
1059 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1060 | tv0cw, gemm_tile1, 8); |
1061 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1062 | tv0r, gemm_tile1, 8); |
1063 | tv0cw->setMemoryType(MemoryType::Shared); |
1064 | // [Mo,Ko,i,wy,wx,v] |
1065 | |
1066 | // [No,Ko,N,K] |
1067 | tv1cw->merge(-2); |
1068 | tv1r->merge(-2); |
1069 | // [No,Ko,i,wy,wx,v] |
1070 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1071 | tv1cw, gemm_tile1, 8); |
1072 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1073 | tv1r, gemm_tile1, 8); |
1074 | tv1cw->setMemoryType(MemoryType::Shared); |
1075 | |
1076 | // Schedule mma input |
1077 | // --------------------------------------------------------------------------- |
1078 | tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); |
1079 | // [... Mi, Ni, Ki] want [Ni, Mi, Ki] |
1080 | tv0b->reorder({{-2, -3}, {-3, -2}}); |
1081 | tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); |
1082 | |
1083 | tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); |
1084 | tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); |
1085 | |
1086 | // Schedule mma output |
1087 | // --------------------------------------------------------------------------- |
1088 | tv3c->applyMmaSwizzle( |
1089 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1090 | tv3cw->applyMmaSwizzle( |
1091 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1092 | tv3h->applyMmaSwizzle( |
1093 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1094 | tv3->applyMmaSwizzle( |
1095 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1096 | tv3cw->setMemoryType(MemoryType::Shared); |
1097 | |
1098 | // Parallelize |
1099 | // 0 1 2 3 4 5 6 7 |
1100 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
1101 | // Gemm 1 |
1102 | tv3c->axis(3)->parallelize(ParallelType::TIDz); |
1103 | tv3c->axis(4)->parallelize(ParallelType::TIDy); |
1104 | |
1105 | tv3->computeAt(tv3cw, -2); |
1106 | tv3cw->axis(2)->parallelize(ParallelType::TIDz); |
1107 | tv3cw->axis(3)->parallelize(ParallelType::TIDy); |
1108 | |
1109 | // Gemm 2 |
1110 | tv4->axis(2)->parallelize(ParallelType::TIDz); |
1111 | tv4->axis(3)->parallelize(ParallelType::TIDy); |
1112 | tv4c->axis(3)->parallelize(ParallelType::TIDz); |
1113 | tv4c->axis(4)->parallelize(ParallelType::TIDy); |
1114 | |
1115 | tv4->axis(0)->parallelize(ParallelType::BIDx); |
1116 | tv4->axis(1)->parallelize(ParallelType::BIDy); |
1117 | |
1118 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1119 | auto t0 = at::randn({M, K1}, options); |
1120 | auto t1 = at::randn({K2, K1}, options); |
1121 | auto t2 = at::randn({N, K2}, options); |
1122 | |
1123 | auto tref = t0.to(at::kFloat) |
1124 | .matmul(t1.t().to(at::kFloat)) |
1125 | .matmul(t2.t().to(at::kFloat)); |
1126 | |
1127 | FusionExecutor fe; |
1128 | |
1129 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
1130 | 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); |
1131 | |
1132 | auto cg_outputs = fe.runFusion({t0, t1, t2}); |
1133 | |
1134 | // relaxed check for now, err accumulation is significant. |
1135 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.1, 0.1)); |
1136 | } |
1137 | |
1138 | // Simplified Matmul-Softmax-Matmul test on Ampere |
1139 | // (To be extended in follow ups) |
1140 | TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { |
1141 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
1142 | |
1143 | Fusion fusion; |
1144 | FusionGuard fg(&fusion); |
1145 | |
1146 | // Omitting outer dimensions and pointwise ops |
1147 | |
1148 | const int seql_q = 32; |
1149 | const int seql_k = 128; |
1150 | const int hidden_size = 1024; |
1151 | const int num_heads = 16; |
1152 | const int head_dim = hidden_size / num_heads; |
1153 | |
1154 | // Gemm 1: |
1155 | // (80, 80, 64) |
1156 | const int M1 = seql_q, N1 = seql_k, K1 = head_dim; |
1157 | // (80, 64, 80) |
1158 | const int M2 = seql_q, N2 = head_dim, K2 = seql_k; |
1159 | |
1160 | // Fusion definition (Both gemms are TN) |
1161 | // [M,K1] |
1162 | auto inp = makeConcreteTensor({M1, K1}, DataType::Half); |
1163 | // Query matrix |
1164 | auto qk = makeConcreteTensor({N1, K1}, DataType::Half); |
1165 | // Second linear matrix |
1166 | auto acc = makeConcreteTensor({N2, K2}, DataType::Half); |
1167 | |
1168 | fusion.addInput(inp); |
1169 | fusion.addInput(qk); |
1170 | fusion.addInput(acc); |
1171 | |
1172 | // [M,N,K] |
1173 | auto tv0b = broadcast(inp, {false, true, false}); |
1174 | auto tv1b = broadcast(qk, {true, false, false}); |
1175 | auto tv2b = broadcast(acc, {true, false, false}); |
1176 | |
1177 | // [M,K2,R] |
1178 | auto tv3 = fusedMultiplySum(tv0b, tv1b, {2}); |
1179 | |
1180 | // Inline define softmax for now for scheduling |
1181 | auto x = tv3; |
1182 | const int kReductionAxis = 1; |
1183 | const int kNumberOfDims = 2; |
1184 | std::vector<bool> broadcast_mask(kNumberOfDims, false); |
1185 | broadcast_mask[kReductionAxis] = true; |
1186 | |
1187 | auto max_val = max(x, {kReductionAxis}); |
1188 | auto bcast_max = broadcast(max_val, broadcast_mask); |
1189 | auto x_max_sub = sub(x, bcast_max); |
1190 | auto exp_val = exp(x_max_sub); |
1191 | auto sum_exp = sum(exp_val, {kReductionAxis}); |
1192 | auto bcast_sum = broadcast(sum_exp, broadcast_mask); |
1193 | auto recip = reciprocal(bcast_sum); |
1194 | auto tv3sfm = mul(exp_val, recip); |
1195 | |
1196 | auto tv3h = castOp(DataType::Half, tv3sfm); |
1197 | auto tv3b = broadcast(tv3h, {false, true, false}); |
1198 | auto tv4 = fusedMultiplySum(tv3b, tv2b, {2}); |
1199 | |
1200 | fusion.addOutput(tv4); |
1201 | |
1202 | // Fusion: |
1203 | // Gemm(M,K2,K1) x Gemm(M,N,K2) |
1204 | MatMulTileOptions gemm_tile; |
1205 | |
1206 | // TODO: use very small tiles for now since |
1207 | // alias pass is not re-using smem. Fix later. |
1208 | gemm_tile.cta_tile = GemmTile(32, 128, 32); |
1209 | |
1210 | // Distribute to 2x2 warps |
1211 | gemm_tile.warp_tile = GemmTile(16, 64, 32); |
1212 | |
1213 | // Using Ampere mma macro |
1214 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
1215 | |
1216 | auto mma_builder1 = |
1217 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
1218 | .layout(MmaOptions::MmaInputLayout::TN); |
1219 | |
1220 | auto mma_builder2 = |
1221 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
1222 | .layout(MmaOptions::MmaInputLayout::TN); |
1223 | |
1224 | mma_builder1.configureMma(tv3); |
1225 | mma_builder2.configureMma(tv4); |
1226 | |
1227 | // Global read for gemm 1 |
1228 | auto tv0r = inp->cacheAfter(); |
1229 | auto tv1r = qk->cacheAfter(); |
1230 | |
1231 | // Global read for gemm 2 |
1232 | auto tv2r = acc->cacheAfter(); |
1233 | |
1234 | // Gemm 1 main loop read |
1235 | auto tv0cw = tv0r->cacheAfter(); |
1236 | auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); |
1237 | auto tv1cw = tv1r->cacheAfter(); |
1238 | auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); |
1239 | |
1240 | // Gemm 1 accumulator reg |
1241 | auto tv3c = tv3->cacheBefore(); |
1242 | mma_builder1.accumulatorTv(tv3c); |
1243 | |
1244 | // Softmax conversion: |
1245 | auto tv3ccr = tv3->cacheAfter(); |
1246 | |
1247 | // tv3ccr -> tv3h : softmax |
1248 | |
1249 | // Gemm 2 main loop read |
1250 | // auto tv3cw = tv3h->cacheAfter(); |
1251 | auto tv3cr = tv3h->cacheAfter(LoadStoreOpType::LdMatrix); |
1252 | |
1253 | auto tv2cw = tv2r->cacheAfter(); |
1254 | auto tv2cr = tv2cw->cacheAfter(LoadStoreOpType::LdMatrix); |
1255 | |
1256 | // Gemm 2 accumulator reg |
1257 | auto tv4c = tv4->cacheBefore(); |
1258 | mma_builder2.accumulatorTv(tv4c); |
1259 | |
1260 | // Schedule gemm 2: |
1261 | // ------------------------------------------------------------------ |
1262 | tv4->split(-2, gemm_tile.cta_tile.m); |
1263 | tv4->split(-1, gemm_tile.cta_tile.n); |
1264 | |
1265 | // 0 1 2 3 |
1266 | // [Mo,M128, No, N128] |
1267 | tv4->reorder({{1, 2}, {2, 1}}); |
1268 | |
1269 | // 0 1 2 3 |
1270 | // [Mo,No, M128, N128] |
1271 | acc->computeAt(tv4, 2); |
1272 | tv3->computeAt(tv4, 2); |
1273 | |
1274 | // Order K |
1275 | // 0 1 2 3 4 5 |
1276 | // [Mo,No, M128, N128, Ko, K32] |
1277 | tv4c->split(-1, gemm_tile.cta_tile.k); |
1278 | tv4c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
1279 | |
1280 | // 0 1 2 3 4 5 |
1281 | // [Mo,No, Ko M128, N128, K32] |
1282 | tv3->computeAt(tv4c, 2); |
1283 | tv2r->computeAt(tv4c, 3); |
1284 | |
1285 | // Make warp tile |
1286 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile); |
1287 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
1288 | tv4, gemm_tile); |
1289 | // -8 -7 -6 -5 -4 -3 -2 -1 |
1290 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
1291 | tv3cr->computeAt(tv4c, -4); |
1292 | tv2cr->computeAt(tv4c, -4); |
1293 | |
1294 | // Schedule tv2 gmem read and smem write: |
1295 | // ---------------------------------------------------------------- |
1296 | // [No,Ko,N,K] |
1297 | tv2cw->merge(-2); |
1298 | tv2r->merge(-2); |
1299 | |
1300 | // [No,Ko,i,wy,wx,v] |
1301 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1302 | tv2cw, gemm_tile, 8); |
1303 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1304 | tv2r, gemm_tile, 8); |
1305 | tv2cw->setMemoryType(MemoryType::Shared); |
1306 | |
1307 | // Schedule tv2 gmem read and smem write: |
1308 | // ---------------------------------------------------------------- |
1309 | |
1310 | // Schedule gemm 2 mma input |
1311 | // --------------------------------------------------------------------------- |
1312 | tv3cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); |
1313 | // [... Mi, Ni, Ki] want [Ni, Mi, Ki] |
1314 | tv3b->reorder({{-2, -3}, {-3, -2}}); |
1315 | tv3b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::A).build()); |
1316 | |
1317 | tv2cr->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); |
1318 | tv2b->applyMmaSwizzle(mma_builder2.operand(MmaOptions::Operand::B).build()); |
1319 | |
1320 | // Schedule mma output |
1321 | // --------------------------------------------------------------------------- |
1322 | tv4c->applyMmaSwizzle( |
1323 | mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); |
1324 | tv4->applyMmaSwizzle( |
1325 | mma_builder2.operand(MmaOptions::Operand::Accumulator).build()); |
1326 | |
1327 | // Schedule gemm 1: |
1328 | // ------------------------------------------------------------------ |
1329 | |
1330 | // CTA tile: |
1331 | // [Mo, Mi128, N80] |
1332 | |
1333 | tv3->split(-1, gemm_tile.cta_tile.n); |
1334 | // [Mo, Mi128, No, Ni128] |
1335 | |
1336 | tv3->reorder({{1, 2}, {2, 1}}); |
1337 | |
1338 | // [Mo, No, Mi128, Ni128] |
1339 | inp->computeAt(tv3, 2); |
1340 | qk->computeAt(tv3, 2); |
1341 | |
1342 | // Schedule K dim for gemm 1: |
1343 | |
1344 | // Order K |
1345 | // 0 1 2 3 4 5 |
1346 | // [Mo,No, M128, N128, Ko, K32] |
1347 | tv3c->split(-1, gemm_tile.cta_tile.k); |
1348 | tv3c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
1349 | // 0 1 2 3 4 5 |
1350 | // [Mo,No, Ko M128, N128, K32] |
1351 | tv0r->computeAt(tv3c, 3); |
1352 | tv1r->computeAt(tv3c, 3); |
1353 | |
1354 | // Make warp tile: |
1355 | // ------------------------------------------------------------------------- |
1356 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv3c, gemm_tile); |
1357 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
1358 | tv3, gemm_tile); |
1359 | |
1360 | tv0cr->computeAt(tv3c, -4); |
1361 | tv1cr->computeAt(tv3c, -4); |
1362 | |
1363 | // tv3->computeAt(tv3cw,-3); |
1364 | |
1365 | // Schedule gmem read and smem write: |
1366 | // --------------------------------------------------------------------------- |
1367 | // [Mo,Ko,M,K] |
1368 | tv0cw->merge(-2); |
1369 | tv0r->merge(-2); |
1370 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1371 | tv0cw, gemm_tile, 8); |
1372 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1373 | tv0r, gemm_tile, 8); |
1374 | tv0cw->setMemoryType(MemoryType::Shared); |
1375 | // [Mo,Ko,i,wy,wx,v] |
1376 | |
1377 | // [No,Ko,N,K] |
1378 | tv1cw->merge(-2); |
1379 | tv1r->merge(-2); |
1380 | // [No,Ko,i,wy,wx,v] |
1381 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1382 | tv1cw, gemm_tile, 8); |
1383 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
1384 | tv1r, gemm_tile, 8); |
1385 | tv1cw->setMemoryType(MemoryType::Shared); |
1386 | |
1387 | // Schedule mma input |
1388 | // --------------------------------------------------------------------------- |
1389 | tv0cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); |
1390 | // [... Mi, Ni, Ki] want [Ni, Mi, Ki] |
1391 | tv0b->reorder({{-2, -3}, {-3, -2}}); |
1392 | tv0b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::A).build()); |
1393 | |
1394 | tv1cr->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); |
1395 | tv1b->applyMmaSwizzle(mma_builder1.operand(MmaOptions::Operand::B).build()); |
1396 | |
1397 | // // Schedule mma output |
1398 | // // |
1399 | // --------------------------------------------------------------------------- |
1400 | tv3c->applyMmaSwizzle( |
1401 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1402 | tv3->applyMmaSwizzle( |
1403 | mma_builder1.operand(MmaOptions::Operand::Accumulator).build()); |
1404 | |
1405 | // mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(tv3ccw, |
1406 | // mma_builder1.build()); |
1407 | |
1408 | // Put tv3 result in smem |
1409 | tv3->setMemoryType(MemoryType::Shared); |
1410 | |
1411 | // schedule a reg persistent softmax: from tv3 |
1412 | // [Mo, M128, RN] |
1413 | max_val->split(-1, 128); |
1414 | // [Mo, M128, RN1, RN128] |
1415 | max_val->split(-1, 4); |
1416 | // Map to warp (2x2) |
1417 | max_val->split(-4, 4); |
1418 | max_val->split(-4, 2); |
1419 | |
1420 | // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] |
1421 | auto max_rf = max_val->rFactor({-1}); |
1422 | // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] |
1423 | |
1424 | // [Mo, M128, RN] |
1425 | sum_exp->split(-1, 128); |
1426 | // [Mo, M128, RN1, RN128] |
1427 | sum_exp->split(-1, 4); |
1428 | // Map to warp (2x2) |
1429 | sum_exp->split(-4, 4); |
1430 | sum_exp->split(-4, 2); |
1431 | |
1432 | // [Mo, Mo32, My2, Mx2, RN1, RNo32, RNi4] |
1433 | auto sum_exp_rf = sum_exp->rFactor({-1}); |
1434 | // [Mo, Mo32, My2, Mx2, RN1, I32, RNi4] |
1435 | |
1436 | exp_val->computeAt(sum_exp_rf, 4); |
1437 | exp_val->split(-1, 128); |
1438 | exp_val->split(-1, 4); |
1439 | bcast_max->computeAt(exp_val, -2); |
1440 | |
1441 | // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] |
1442 | |
1443 | // Read from smem |
1444 | tv3ccr->computeAt(max_rf, 4); |
1445 | // [Mo, Mo32, My2, Mx2, N80] |
1446 | tv3ccr->split(-1, 128); |
1447 | tv3ccr->split(-1, 4); |
1448 | // [Mo, Mo32, My2, Mx2, IN1, I32, INi4] |
1449 | |
1450 | // Write to second gemm |
1451 | tv3h->split(-1, 128); |
1452 | tv3h->split(-1, 4); |
1453 | // Map to warp (2x2) |
1454 | tv3h->split(-4, 4); |
1455 | tv3h->split(-4, 2); |
1456 | |
1457 | bcast_sum->computeAt(tv3h, -2); |
1458 | |
1459 | tv3h->setMemoryType(MemoryType::Shared); |
1460 | |
1461 | // Parallelize |
1462 | tv4->axis(0)->parallelize(ParallelType::BIDx); |
1463 | // 0 1 2 3 4 5 6 7 |
1464 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
1465 | // Gemm 1 |
1466 | tv3c->axis(3)->parallelize(ParallelType::TIDz); |
1467 | tv3c->axis(4)->parallelize(ParallelType::TIDy); |
1468 | tv3->axis(2)->parallelize(ParallelType::TIDz); |
1469 | tv3->axis(3)->parallelize(ParallelType::TIDy); |
1470 | |
1471 | auto parallelize_non_reduced_val = [](TensorView* tv) { |
1472 | tv->axis(-2)->parallelize(ParallelType::TIDx); |
1473 | tv->axis(2)->parallelize(ParallelType::TIDz); |
1474 | tv->axis(3)->parallelize(ParallelType::TIDy); |
1475 | }; |
1476 | |
1477 | auto parallelize_reduced_val = [](TensorView* tv) { |
1478 | tv->axis(-1)->parallelize(ParallelType::TIDx); |
1479 | tv->axis(2)->parallelize(ParallelType::TIDz); |
1480 | tv->axis(3)->parallelize(ParallelType::TIDy); |
1481 | }; |
1482 | |
1483 | parallelize_non_reduced_val(tv3h); |
1484 | parallelize_non_reduced_val(max_rf); |
1485 | parallelize_non_reduced_val(bcast_max); |
1486 | parallelize_non_reduced_val(exp_val); |
1487 | parallelize_non_reduced_val(sum_exp_rf); |
1488 | parallelize_non_reduced_val(bcast_sum); |
1489 | parallelize_non_reduced_val(recip); |
1490 | |
1491 | parallelize_reduced_val(max_val); |
1492 | parallelize_reduced_val(sum_exp); |
1493 | |
1494 | // 0 1 2 3 4 5 6 7 |
1495 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
1496 | // Gemm 2 |
1497 | tv4->axis(2)->parallelize(ParallelType::TIDz); |
1498 | tv4->axis(3)->parallelize(ParallelType::TIDy); |
1499 | tv4c->axis(3)->parallelize(ParallelType::TIDz); |
1500 | tv4c->axis(4)->parallelize(ParallelType::TIDy); |
1501 | |
1502 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1503 | auto t0 = at::randn({M1, K1}, options); |
1504 | auto t1 = at::randn({N1, K1}, options); |
1505 | auto t2 = at::randn({N2, K2}, options); |
1506 | |
1507 | FusionExecutor fe; |
1508 | |
1509 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
1510 | 8, 0, fe.compileFusion(&fusion, {t0, t1, t2})); |
1511 | |
1512 | auto cg_outputs = fe.runFusion({t0, t1, t2}); |
1513 | |
1514 | auto g1 = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); |
1515 | auto sg1 = at::_softmax(g1, -1, false); |
1516 | auto gsg1 = sg1.matmul(t2.t().to(at::kFloat)); |
1517 | |
1518 | TORCH_CHECK(cg_outputs[0].allclose(gsg1, 0.001, 0.001)); |
1519 | } |
1520 | |
1521 | // MMA unit test on Turing |
1522 | TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { |
1523 | Fusion fusion; |
1524 | FusionGuard fg(&fusion); |
1525 | |
1526 | // [M,K] |
1527 | auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); |
1528 | // [N,K] |
1529 | auto tv1 = makeConcreteTensor({8, 16}, DataType::Half); |
1530 | fusion.addInput(tv0); |
1531 | fusion.addInput(tv1); |
1532 | |
1533 | // [M,N,K] |
1534 | auto tv0b = broadcast(tv0, {false, true, false}); |
1535 | auto tv1b = broadcast(tv1, {true, false, false}); |
1536 | |
1537 | // Leaving both sets of mma inputs for volta outside |
1538 | // currently since they need to be swizzled. |
1539 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
1540 | |
1541 | fusion.addOutput(tv2); |
1542 | |
1543 | MatMulTileOptions gemm_tile; |
1544 | gemm_tile.cta_tile = GemmTile(16, 8, 16); |
1545 | gemm_tile.warp_tile = GemmTile(16, 8, 16); |
1546 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
1547 | |
1548 | auto mma_builder = |
1549 | MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) |
1550 | .layout(MmaOptions::MmaInputLayout::TN); |
1551 | |
1552 | mma_builder.configureMma(tv2); |
1553 | |
1554 | auto tv0cw = tv0b->cacheAfter(); |
1555 | auto tv0cr = |
1556 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
1557 | auto tv1cw = tv1b->cacheAfter(); |
1558 | auto tv1cr = |
1559 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
1560 | |
1561 | auto tv2c = tv2->cacheBefore(); |
1562 | mma_builder.accumulatorTv(tv2c); |
1563 | |
1564 | // [M,N,K] -> [N,M,K] |
1565 | tv0cr->reorder({{-2, -3}, {-3, -2}}); |
1566 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
1567 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
1568 | tv2c->applyMmaSwizzle( |
1569 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
1570 | tv2->applyMmaSwizzle( |
1571 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
1572 | |
1573 | tv0cw->setMemoryType(MemoryType::Shared); |
1574 | tv1cw->setMemoryType(MemoryType::Shared); |
1575 | |
1576 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1577 | auto t0 = at::randn({16, 16}, options); |
1578 | auto t1 = at::randn({8, 16}, options); |
1579 | |
1580 | FusionExecutor fe; |
1581 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
1582 | 7, 5, fe.compileFusion(&fusion, {t0, t1})); |
1583 | |
1584 | auto cg_outputs = fe.runFusion({t0, t1}); |
1585 | |
1586 | auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); |
1587 | |
1588 | testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); |
1589 | } |
1590 | |
1591 | // MMA unit test on Turing |
1592 | TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { |
1593 | Fusion fusion; |
1594 | FusionGuard fg(&fusion); |
1595 | |
1596 | // [M,K] |
1597 | auto tv0 = makeConcreteTensor({16, 16}, DataType::Half); |
1598 | // [K,N] |
1599 | auto tv1 = makeConcreteTensor({16, 8}, DataType::Half); |
1600 | fusion.addInput(tv0); |
1601 | fusion.addInput(tv1); |
1602 | |
1603 | // [M,K,N] |
1604 | auto tv0b = broadcast(tv0, {false, false, true}); |
1605 | auto tv1b = broadcast(tv1, {true, false, false}); |
1606 | |
1607 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {1}); |
1608 | |
1609 | fusion.addOutput(tv2); |
1610 | |
1611 | MatMulTileOptions gemm_tile; |
1612 | gemm_tile.cta_tile = GemmTile(16, 8, 16); |
1613 | gemm_tile.warp_tile = GemmTile(16, 8, 16); |
1614 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
1615 | |
1616 | auto mma_builder = |
1617 | MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) |
1618 | .layout(MmaOptions::MmaInputLayout::TT); |
1619 | |
1620 | mma_builder.configureMma(tv2); |
1621 | |
1622 | auto tv0cw = tv0b->cacheAfter(); |
1623 | auto tv0cr = |
1624 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
1625 | auto tv1cw = tv1b->cacheAfter(); |
1626 | auto tv1cr = |
1627 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
1628 | |
1629 | auto tv2c = tv2->cacheBefore(); |
1630 | mma_builder.accumulatorTv(tv2c); |
1631 | |
1632 | // [M,K,N] -> [N,M,K] |
1633 | tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); |
1634 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
1635 | |
1636 | // [M,K,N] -> [M,N,K] |
1637 | tv1cr->reorder({{-2, -1}, {-1, -2}}); |
1638 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
1639 | |
1640 | // [M,K,N] -> [M,N,K] |
1641 | tv2c->reorder({{-2, -1}, {-1, -2}}); |
1642 | tv2c->applyMmaSwizzle( |
1643 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
1644 | tv2->applyMmaSwizzle( |
1645 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
1646 | |
1647 | tv0cw->setMemoryType(MemoryType::Shared); |
1648 | tv1cw->setMemoryType(MemoryType::Shared); |
1649 | |
1650 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1651 | auto t0 = at::randn({16, 16}, options); |
1652 | auto t1 = at::randn({16, 8}, options); |
1653 | |
1654 | FusionExecutor fe; |
1655 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
1656 | 7, 5, fe.compileFusion(&fusion, {t0, t1})); |
1657 | |
1658 | auto cg_outputs = fe.runFusion({t0, t1}); |
1659 | |
1660 | auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat)); |
1661 | |
1662 | testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__); |
1663 | } |
1664 | |
1665 | // MMA unit test on Turing |
// Single MMA-instruction test with the NT input layout: tv0 holds A as
// [K,M], tv1 holds B as [K,N], and the output is [M,N]. The CTA, warp and
// instruction tiles are all one Turing_16_8_16 macro (16x8x16), so the whole
// problem is exactly one instruction tile. Compilation is checked against
// compute capability 7.5.
TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // [K,M]
  auto tv0 = makeConcreteTensor({16, 16}, DataType::Half);
  // [K,N]
  auto tv1 = makeConcreteTensor({16, 8}, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // Broadcast both operands into a common [K,M,N] space and multiply-reduce
  // over K (axis 0).
  // [K,M,N]
  auto tv0b = broadcast(tv0, {false, false, true});
  auto tv1b = broadcast(tv1, {false, true, false});
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});

  fusion.addOutput(tv2);

  // One 16x8x16 tile at every level (CTA == warp == instruction).
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(16, 8, 16);
  gemm_tile.warp_tile = GemmTile(16, 8, 16);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::NT);

  mma_builder.configureMma(tv2);

  // Operand staging for each input: broadcast -> tv*cw (set to shared memory
  // below) -> tv*cr (loaded with the operand's ldMatrix load type) -> mma.
  auto tv0cw = tv0b->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1b->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());

  // tv2c becomes the mma accumulator; tv2 is now written from tv2c.
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Each operand/accumulator is reordered so K is innermost before its
  // swizzle is applied.
  // [K,M,N] -> [N,M,K]
  tv0cr->reorder({{-3, -1}, {-1, -3}});
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [K,M,N] -> [M,N,K]
  tv1cr->reorder({
      {-3, -1},
      {-2, -3},
      {-1, -2},
  });
  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // [K,M,N] -> [M,N,K]
  tv2c->reorder({{-3, -1}, {-2, -3}, {-1, -2}});
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Stage both operands through shared memory.
  tv0cw->setMemoryType(MemoryType::Shared);
  tv1cw->setMemoryType(MemoryType::Shared);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({16, 16}, options);
  auto t1 = at::randn({16, 8}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      7, 5, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  // t0 is [K,M], so t0.t() is A as [M,K]; the fp32 matmul is the reference.
  auto tref = t0.t().to(at::kFloat).matmul(t1.to(at::kFloat));

  testValidate(&fusion, cg_outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
}
1741 | |
1742 | // Matmul test for Turing MMA: across supported layouts |
1743 | TEST_F(NVFuserTest, FusionTuringMatmul_CUDA) { |
1744 | // Keep multiples of 8 to keep vectorizable. |
1745 | int M = 504, N = 136, K = 248; |
1746 | |
1747 | for (auto layout : kAllSupportedLayout) { |
1748 | Fusion fusion; |
1749 | FusionGuard fg(&fusion); |
1750 | auto tv0 = makeContigTensor(2, DataType::Half); |
1751 | auto tv1 = makeContigTensor(2, DataType::Half); |
1752 | |
1753 | fusion.addInput(tv0); |
1754 | fusion.addInput(tv1); |
1755 | |
1756 | auto tv2 = matmul(tv0, tv1, layout); |
1757 | |
1758 | fusion.addOutput(tv2); |
1759 | |
1760 | MatMulTileOptions gemm_tile; |
1761 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
1762 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
1763 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
1764 | |
1765 | auto mma_builder = |
1766 | MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) |
1767 | .layout(layout); |
1768 | |
1769 | MatmulParam params(mma_builder); |
1770 | params.tile_sizes = gemm_tile; |
1771 | scheduleMatmul(tv2, tv0, tv1, params); |
1772 | |
1773 | at::manual_seed(0); |
1774 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
1775 | |
1776 | FusionExecutor fe; |
1777 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
1778 | 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
1779 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
1780 | auto tref = atMatmul( |
1781 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
1782 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
1783 | } |
1784 | } |
1785 | |
1786 | // Matmul test on ampere, using ampere memory ops |
// TN matmul scheduled by hand on Ampere: A is [M,K], B is [N,K], out is
// [M,N]. Global->shared copies use LoadStoreOpType::CpAsync and
// shared->register operand loads use LoadStoreOpType::LdMatrix. M and N are
// not multiples of the 128x128 CTA tile (presumably to exercise partial
// tiles). Compilation is checked against compute capability 8.0.
TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int M = 255, N = 511, K = 88;

  // [M,K]
  auto tv0 = makeContigTensor(2, DataType::Half);
  // [N,K]
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [M,N,K]
  auto tv0b = broadcast(tv0, {false, true, false});
  auto tv1b = broadcast(tv1, {true, false, false});

  // Multiply-and-reduce over K (axis 2). tv0b/tv1b are kept as explicit
  // broadcast ops; they are swizzled with applyMmaSwizzle below.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {2});

  fusion.addOutput(tv2);

  // 128x128x32 CTA tile, 64x64x32 warp tile, 16x8x16 MMA instruction.
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  // Operand staging: input -> tv*cw (CpAsync copy, placed in shared memory
  // below) -> tv*cr (LdMatrix load) -> broadcast -> mma.
  // tv2c is the mma accumulator; tv2 is written from it.
  auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync);
  auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix);
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [M,N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  //  0   1    2   3
  // [Mo,M128, No, N128]
  tv2->reorder({{1, 2}, {2, 1}});

  //  0   1    2   3
  // [Mo,No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  //  0   1    2   3     4    5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  //  0   1  2   3     4    5
  // [Mo,No, Ko M128, N128, K32]
  // Shared-memory writes are produced once per (Mo, No, Ko) tile.
  tv0cw->computeAt(tv2c, 3);
  tv1cw->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8   -7  -6 -5 -4 -3 -2 -1
  // [Mo No Ko Mwo  Nwo Kwo Mw Nw Mi Ni Ki]
  // LdMatrix reads are inlined outside the innermost 4 axes of tv2c.
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
  // [... Mi, Ni, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Parallelize: warp tiles of the accumulator map to TIDz (Mwo) / TIDy (Nwo).
  //  0  1  2  3   4   5  6  7  8  9  10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize: CTA tiles map to BIDx/BIDy; warp tiles match tv2c above.
  //  0  1  2   3   4   5  6  7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({M, K}, options);
  auto t1 = at::randn({N, K}, options);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  // fp32 reference: A [M,K] times B^T [K,N].
  auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

  TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}
1924 | |
// Batched TN matmul whose two batch dims are interleaved with the matrix
// dims: A is [B0, M, B1, K], B is [B0, N, B1, K], out is [B0, M, N, B1].
// The batch dims of the output are merged into its outer M tile, so the
// remaining schedule works on [*, No, M128, N128] like a plain 2-D matmul.
TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);
  int M = 511, N = 123, K = 88, B0 = 3, B1 = 5;

  // [B0 ,M, B1,K]
  auto tv0 = makeContigTensor(4, DataType::Half);
  // [B0, N, B1, K]
  auto tv1 = makeContigTensor(4, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // [B0,M,N,B1,K]
  auto tv0b = broadcast(tv0, {false, false, true, false, false});
  auto tv1b = broadcast(tv1, {false, true, false, false, false});

  // Multiply-and-reduce over K (axis 4). tv0b/tv1b are kept as explicit
  // broadcast ops; they are swizzled with applyMmaSwizzle below.
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {4});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile)
          .layout(MmaOptions::MmaInputLayout::TN);

  mma_builder.configureMma(tv2);

  // Operand staging: input -> tv*r -> tv*cw (shared memory, set below)
  // -> tv*cr (operand ldMatrix load) -> broadcast -> mma.
  auto tv0r = tv0->cacheAfter();
  auto tv1r = tv1->cacheAfter();
  auto tv0cw = tv0r->cacheAfter();
  auto tv0cr =
      tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
  auto tv1cw = tv1r->cacheAfter();
  auto tv1cr =
      tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
  auto tv2c = tv2->cacheBefore();
  mma_builder.accumulatorTv(tv2c);

  // Group the BATCHED DIMS at the front. reorder() takes an old->new
  // position map; the unmapped B0 fills the one remaining slot, so the batch
  // dims end up as [B1, B0]. Their order does not matter because they are
  // merged together below.
  //  -4  -3 -2 -1
  // [B0, M, N, B1]
  tv2->reorder({{-3, -2}, {-2, -1}, {-1, -4}});

  //  -4  -3  -2 -1
  // [B1, B0, M, N]

  // Make a CTA tile
  // ------------------------------------------------------------------
  // [B1, B0, M, N]
  tv2->split(-2, gemm_tile.cta_tile.m);
  tv2->split(-1, gemm_tile.cta_tile.n);

  //  0   1   2   3     4   5
  // [B1, B0, Mo, M128, No, N128]
  tv2->reorder({{-3, -2}, {-2, -3}});

  //  0   1   2   3   4     5
  // [B1, B0, Mo, No, M128, N128]

  // Merge the outer dims (the batch dims fold into Mo):
  tv2->merge(0);
  tv2->merge(0);

  //  0          1   2     3
  // [B1*B0*Mo, No, M128, N128]
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Order K
  //  0   1    2   3     4    5
  // [Mo,No, M128, N128, Ko, K32]
  tv2c->split(-1, gemm_tile.cta_tile.k);
  tv2c->reorder({{2, 3}, {3, 4}, {4, 2}});

  //  0   1  2   3     4    5
  // [Mo,No, Ko M128, N128, K32]
  tv0r->computeAt(tv2c, 3);
  tv1r->computeAt(tv2c, 3);

  // Make warp tile:
  // -------------------------------------------------------------------------
  scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
  scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
      tv2, gemm_tile);
  //           -8   -7  -6 -5 -4 -3 -2 -1
  // [Mo No Ko Mwo  Nwo Kwo Mw Nw Mi Ni Ki]
  tv0cr->computeAt(tv2c, -4);
  tv1cr->computeAt(tv2c, -4);

  // Schedule gmem read and smem write:
  // ---------------------------------------------------------------------------
  // [Mo,Ko,M,K]
  tv0cw->merge(-2);
  tv0r->merge(-2);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv0r, gemm_tile, 8);
  tv0cw->setMemoryType(MemoryType::Shared);
  // [Mo,Ko,i,wy,wx,v]

  // [No,Ko,N,K]
  tv1cw->merge(-2);
  tv1r->merge(-2);
  // [No,Ko,i,wy,wx,v]
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1cw, gemm_tile, 8);
  scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
      tv1r, gemm_tile, 8);
  tv1cw->setMemoryType(MemoryType::Shared);
  // Schedule mma input
  // ---------------------------------------------------------------------------
  tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  // [... Mi, Ni, Ki] want [Ni, Mi, Ki]
  tv0b->reorder({{-2, -3}, {-3, -2}});
  tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());

  tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
  tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());

  // Schedule mma output
  // ---------------------------------------------------------------------------
  tv2c->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());
  tv2->applyMmaSwizzle(
      mma_builder.operand(MmaOptions::Operand::Accumulator).build());

  // Parallelize
  //  0  1  2  3   4   5  6  7  8  9  10
  // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
  tv2c->axis(3)->parallelize(ParallelType::TIDz);
  tv2c->axis(4)->parallelize(ParallelType::TIDy);

  // Parallelize (axis 0 carries the merged batch dims, so BIDx also
  // enumerates batches)
  //  0  1  2   3   4   5  6  7
  // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::TIDz);
  tv2->axis(3)->parallelize(ParallelType::TIDy);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  auto t0 = at::randn({B0, M, B1, K}, options);
  auto t1 = at::randn({B0, N, B1, K}, options);

  FusionExecutor fe;

  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8, 0, fe.compileFusion(&fusion, {t0, t1}));

  auto cg_outputs = fe.runFusion({t0, t1});

  // ref implementation: bring both batch dims to the front, run a bmm over
  // the flattened B0*B1 batch, then permute back to [B0, M, N, B1].
  auto ref_t0 = t0.permute({0, 2, 1, 3})
                    .contiguous()
                    .view({B0 * B1, M, K}); // B0, B1, M, K
  auto ref_t1 = t1.permute({0, 2, 3, 1})
                    .contiguous()
                    .view({B0 * B1, K, N}); // B0, B1, K, N
  auto ref_permuted =
      ref_t0.to(at::kFloat).bmm(ref_t1.to(at::kFloat)); // B0*B1, M,N
  auto ref = ref_permuted.view({B0, B1, M, N})
                 .permute({0, 2, 3, 1})
                 .contiguous(); // B0,M,N,B1
  TORCH_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001));
}
2100 | |
2101 | // Matmul test on Ampere with a view on prolog |
2102 | TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { |
2103 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
2104 | |
2105 | Fusion fusion; |
2106 | FusionGuard fg(&fusion); |
2107 | int M = 511, N = 257, K = 88; |
2108 | int Ko = 11, Ki = 8; |
2109 | |
2110 | // [M,Ko,Ki] |
2111 | auto tv0 = makeContigTensor(3, DataType::Half); |
2112 | // [N,K] |
2113 | auto tv1 = makeContigTensor(2, DataType::Half); |
2114 | fusion.addInput(tv0); |
2115 | fusion.addInput(tv1); |
2116 | |
2117 | auto tv0_view = view(tv0, {M, Ko, Ki}, {M, K}); |
2118 | |
2119 | // [M,N,K] |
2120 | auto tv0b = broadcast(tv0_view, {false, true, false}); |
2121 | auto tv1b = broadcast(tv1, {true, false, false}); |
2122 | |
2123 | // Leaving both sets of mma inputs for volta outside |
2124 | // currently since they need to be swizzled. |
2125 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
2126 | |
2127 | fusion.addOutput(tv2); |
2128 | |
2129 | MatMulTileOptions gemm_tile; |
2130 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2131 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
2132 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
2133 | |
2134 | auto mma_builder = |
2135 | MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) |
2136 | .layout(MmaOptions::MmaInputLayout::TN); |
2137 | |
2138 | mma_builder.configureMma(tv2); |
2139 | |
2140 | auto tv0r = tv0->cacheAfter(); |
2141 | auto tv1r = tv1->cacheAfter(); |
2142 | auto tv0cw = tv0_view->cacheAfter(); |
2143 | auto tv0cr = |
2144 | tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); |
2145 | auto tv1cw = tv1r->cacheAfter(); |
2146 | auto tv1cr = |
2147 | tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); |
2148 | auto tv2c = tv2->cacheBefore(); |
2149 | mma_builder.accumulatorTv(tv2c); |
2150 | |
2151 | // Make a CTA tile |
2152 | // ------------------------------------------------------------------ |
2153 | // [M,N] |
2154 | tv2->split(-2, gemm_tile.cta_tile.m); |
2155 | tv2->split(-1, gemm_tile.cta_tile.n); |
2156 | |
2157 | // 0 1 2 3 |
2158 | // [Mo,M128, No, N128] |
2159 | tv2->reorder({{1, 2}, {2, 1}}); |
2160 | |
2161 | // 0 1 2 3 |
2162 | // [Mo,No, M128, N128] |
2163 | tv0->computeAt(tv2, 2); |
2164 | tv1->computeAt(tv2, 2); |
2165 | |
2166 | // Order K |
2167 | // 0 1 2 3 4 5 |
2168 | // [Mo,No, M128, N128, Ko, K32] |
2169 | tv2c->split(-1, gemm_tile.cta_tile.k); |
2170 | tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
2171 | |
2172 | // 0 1 2 3 4 5 |
2173 | // [Mo,No, Ko M128, N128, K32] |
2174 | tv0r->computeAt(tv2c, 3); |
2175 | tv1r->computeAt(tv2c, 3); |
2176 | |
2177 | // Make warp tile: |
2178 | // ------------------------------------------------------------------------- |
2179 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); |
2180 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
2181 | tv2, gemm_tile); |
2182 | // -8 -7 -6 -5 -4 -3 -2 -1 |
2183 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
2184 | tv0cr->computeAt(tv2c, -4); |
2185 | tv1cr->computeAt(tv2c, -4); |
2186 | |
2187 | // Schedule gmem read and smem write: |
2188 | // --------------------------------------------------------------------------- |
2189 | // [Mo,Ko,M,K] |
2190 | tv0cw->merge(-2); |
2191 | tv0r->merge(-2); |
2192 | tv0_view->merge(-2); |
2193 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2194 | tv0cw, gemm_tile, 8); |
2195 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2196 | tv0r, gemm_tile, 8); |
2197 | tv0cw->setMemoryType(MemoryType::Shared); |
2198 | // [Mo,Ko,i,wy,wx,v] |
2199 | |
2200 | // [No,Ko,N,K] |
2201 | tv1cw->merge(-2); |
2202 | tv1r->merge(-2); |
2203 | // [No,Ko,i,wy,wx,v] |
2204 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2205 | tv1cw, gemm_tile, 8); |
2206 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2207 | tv1r, gemm_tile, 8); |
2208 | tv1cw->setMemoryType(MemoryType::Shared); |
2209 | // Schedule mma input |
2210 | // --------------------------------------------------------------------------- |
2211 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2212 | |
2213 | // [... Mi, Ni, Ki] want [Ni, Mi, Ki] |
2214 | tv0b->reorder({{-2, -3}, {-3, -2}}); |
2215 | tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2216 | |
2217 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2218 | tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2219 | |
2220 | // Schedule mma output |
2221 | // --------------------------------------------------------------------------- |
2222 | tv2c->applyMmaSwizzle( |
2223 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2224 | tv2->applyMmaSwizzle( |
2225 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2226 | |
2227 | // Inline the view op with the shared mem write minus |
2228 | // the vectorization axes for now. |
2229 | tv0_view->computeAt(tv0cw, -2); |
2230 | |
2231 | // Parallelize |
2232 | // 0 1 2 3 4 5 6 7 8 9 10 |
2233 | // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] |
2234 | tv2c->axis(3)->parallelize(ParallelType::TIDz); |
2235 | tv2c->axis(4)->parallelize(ParallelType::TIDy); |
2236 | |
2237 | // Parallelize |
2238 | // 0 1 2 3 4 5 6 7 |
2239 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
2240 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
2241 | tv2->axis(1)->parallelize(ParallelType::BIDy); |
2242 | tv2->axis(2)->parallelize(ParallelType::TIDz); |
2243 | tv2->axis(3)->parallelize(ParallelType::TIDy); |
2244 | |
2245 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
2246 | auto t0 = at::randn({M, Ko, Ki}, options); |
2247 | auto t1 = at::randn({N, K}, options); |
2248 | |
2249 | FusionExecutor fe; |
2250 | |
2251 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
2252 | 8, 0, fe.compileFusion(&fusion, {t0, t1})); |
2253 | |
2254 | auto cg_outputs = fe.runFusion({t0, t1}); |
2255 | |
2256 | auto tref = |
2257 | at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); |
2258 | |
2259 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2260 | } |
2261 | |
2262 | // Initial test case for in-CTA split K with VoltaMMA |
2263 | TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) { |
2264 | NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); |
2265 | |
2266 | Fusion fusion; |
2267 | FusionGuard fg(&fusion); |
2268 | int M = 120, N = 264, K = 120; |
2269 | |
2270 | // [M,K] |
2271 | auto tv0 = makeContigTensor(2, DataType::Half); |
2272 | // [N,K] |
2273 | auto tv1 = makeContigTensor(2, DataType::Half); |
2274 | |
2275 | fusion.addInput(tv0); |
2276 | fusion.addInput(tv1); |
2277 | |
2278 | // [M,N,K] |
2279 | auto tv0b = broadcast(tv0, {false, true, false}); |
2280 | auto tv1b = broadcast(tv1, {true, false, false}); |
2281 | |
2282 | // Leaving both sets of mma inputs for volta outside |
2283 | // currently since they need to be swizzled. |
2284 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
2285 | |
2286 | fusion.addOutput(tv2); |
2287 | |
2288 | MatMulTileOptions gemm_tile; |
2289 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2290 | gemm_tile.warp_tile = GemmTile(64, 64, 16); |
2291 | gemm_tile.instruction_tile = GemmTile(16, 16, 4); |
2292 | |
2293 | auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) |
2294 | .layout(MmaOptions::MmaInputLayout::TN); |
2295 | |
2296 | mma_builder.configureMma(tv2); |
2297 | |
2298 | auto tv0r = tv0->cacheAfter(); |
2299 | auto tv1r = tv1->cacheAfter(); |
2300 | auto tv0cw = tv0b->cacheAfter(); |
2301 | auto tv0cr = tv0cw->cacheAfter(); |
2302 | auto tv1cw = tv1b->cacheAfter(); |
2303 | auto tv1cr = tv1cw->cacheAfter(); |
2304 | auto tv2c = tv2->cacheBefore(); |
2305 | |
2306 | // Make a CTA tile |
2307 | // ------------------------------------------------------------------ |
2308 | // [M,N] |
2309 | tv2->split(-2, gemm_tile.cta_tile.m); |
2310 | tv2->split(-1, gemm_tile.cta_tile.n); |
2311 | |
2312 | // 0 1 2 3 |
2313 | // [Mo,M128, No, N128] |
2314 | tv2->reorder({{1, 2}, {2, 1}}); |
2315 | |
2316 | // 0 1 2 3 |
2317 | // [Mo,No, M128, N128] |
2318 | tv0->computeAt(tv2, 2); |
2319 | tv1->computeAt(tv2, 2); |
2320 | |
2321 | // Order K |
2322 | // 0 1 2 3 4 5 |
2323 | // [Mo,No, M128, N128, Ko, K32] |
2324 | tv2c->split(-1, gemm_tile.cta_tile.k); |
2325 | tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
2326 | |
2327 | // 0 1 2 3 4 5 |
2328 | // [Mo,No, Ko M128, N128, K32] |
2329 | tv0r->computeAt(tv2c, 3); |
2330 | tv1r->computeAt(tv2c, 3); |
2331 | |
2332 | // Make warp tile: |
2333 | // ------------------------------------------------------------------------- |
2334 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); |
2335 | auto tv2c_rf = tv2c->rFactor({-9, -4, -1}); |
2336 | |
2337 | // tv2c_rf is the actual output of the mma op after |
2338 | // Rfactoring. |
2339 | mma_builder.accumulatorTv(tv2c_rf); |
2340 | |
2341 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
2342 | tv2, gemm_tile); |
2343 | |
2344 | // -8 -7 -6 -5 -4 -3 -2 -1 |
2345 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
2346 | tv0cr->computeAt(tv2c_rf, -4); |
2347 | tv1cr->computeAt(tv2c_rf, -4); |
2348 | |
2349 | // Schedule gmem read and smem write: |
2350 | // --------------------------------------------------------------------------- |
2351 | // [Mo,No,Ko,M,N,K] |
2352 | tv0cw->reorder({ |
2353 | {-3, -2}, |
2354 | {-2, -3}, |
2355 | }); |
2356 | // [Mo,No,Ko,N,M,K] |
2357 | tv0cw->merge(-2); |
2358 | tv0r->merge(-2); |
2359 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2360 | tv0cw, gemm_tile, 8); |
2361 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2362 | tv0r, gemm_tile, 8); |
2363 | tv0cw->setMemoryType(MemoryType::Shared); |
2364 | // [Mo,Ko,i,wy,wx,v] |
2365 | |
2366 | // [Mo,No,Ko,M,N,K] |
2367 | tv1cw->merge(-2); |
2368 | tv1r->merge(-2); |
2369 | // [Mo,No,Ko,i,wy,wx,v] |
2370 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2371 | tv1cw, gemm_tile, 8); |
2372 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2373 | tv1r, gemm_tile, 8); |
2374 | tv1cw->setMemoryType(MemoryType::Shared); |
2375 | // Schedule mma input |
2376 | // --------------------------------------------------------------------------- |
2377 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2378 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2379 | |
2380 | // Schedule mma output |
2381 | // --------------------------------------------------------------------------- |
2382 | tv2c_rf->applyMmaSwizzle( |
2383 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2384 | tv2c->applyMmaSwizzle( |
2385 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2386 | tv2->applyMmaSwizzle( |
2387 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2388 | |
2389 | tv0b->computeAt(tv0cw, -2); |
2390 | tv1b->computeAt(tv1cw, -2); |
2391 | |
2392 | tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); |
2393 | tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); |
2394 | // Parallelize |
2395 | // 0 1 2 3 4 5 6 7 8 9 10 |
2396 | // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] |
2397 | tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); |
2398 | tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); |
2399 | tv2c_rf->axis(3)->parallelize(ParallelType::TIDz); |
2400 | tv2c_rf->axis(4)->parallelize(ParallelType::TIDy); |
2401 | |
2402 | tv2c->axis(2)->parallelize(ParallelType::TIDz); |
2403 | tv2c->axis(3)->parallelize(ParallelType::TIDy); |
2404 | |
2405 | // Parallelize |
2406 | // 0 1 2 3 4 5 6 7 |
2407 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
2408 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
2409 | tv2->axis(1)->parallelize(ParallelType::BIDy); |
2410 | tv2->axis(2)->parallelize(ParallelType::TIDz); |
2411 | |
2412 | at::manual_seed(0); |
2413 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
2414 | auto t0 = at::randn({M, K}, options); |
2415 | auto t1 = at::randn({N, K}, options); |
2416 | |
2417 | FusionExecutor fe; |
2418 | fe.compileFusion(&fusion, {t0, t1}); |
2419 | auto cg_outputs = fe.runFusion({t0, t1}); |
2420 | auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); |
2421 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2422 | } |
2423 | |
2424 | // Initial test case for cross-CTA split K with VoltaMMA |
2425 | TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) { |
2426 | NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); |
2427 | |
2428 | Fusion fusion; |
2429 | FusionGuard fg(&fusion); |
2430 | int M = 120, N = 264, K = 120; |
2431 | |
2432 | // [M,K] |
2433 | auto tv0 = makeContigTensor(2, DataType::Half); |
2434 | // [N,K] |
2435 | auto tv1 = makeContigTensor(2, DataType::Half); |
2436 | |
2437 | fusion.addInput(tv0); |
2438 | fusion.addInput(tv1); |
2439 | |
2440 | // [M,N,K] |
2441 | auto tv0b = broadcast(tv0, {false, true, false}); |
2442 | auto tv1b = broadcast(tv1, {true, false, false}); |
2443 | |
2444 | // Leaving both sets of mma inputs for volta outside |
2445 | // currently since they need to be swizzled. |
2446 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
2447 | |
2448 | fusion.addOutput(tv2); |
2449 | |
2450 | MatMulTileOptions gemm_tile; |
2451 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2452 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
2453 | gemm_tile.instruction_tile = GemmTile(16, 16, 4); |
2454 | |
2455 | auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) |
2456 | .layout(MmaOptions::MmaInputLayout::TN); |
2457 | |
2458 | mma_builder.configureMma(tv2); |
2459 | |
2460 | auto tv0r = tv0->cacheAfter(); |
2461 | auto tv1r = tv1->cacheAfter(); |
2462 | auto tv0cw = tv0b->cacheAfter(); |
2463 | auto tv0cr = tv0cw->cacheAfter(); |
2464 | auto tv1cw = tv1b->cacheAfter(); |
2465 | auto tv1cr = tv1cw->cacheAfter(); |
2466 | auto tv2c = tv2->cacheBefore(); |
2467 | |
2468 | // Make a CTA tile |
2469 | // ------------------------------------------------------------------ |
2470 | // [M,N] |
2471 | tv2->split(-2, gemm_tile.cta_tile.m); |
2472 | tv2->split(-1, gemm_tile.cta_tile.n); |
2473 | |
2474 | // 0 1 2 3 |
2475 | // [Mo,M128, No, N128] |
2476 | tv2->reorder({{1, 2}, {2, 1}}); |
2477 | |
2478 | // 0 1 2 3 |
2479 | // [Mo,No, M128, N128] |
2480 | tv0->computeAt(tv2, 2); |
2481 | tv1->computeAt(tv2, 2); |
2482 | |
2483 | // Order K |
2484 | // 0 1 2 3 4 5 |
2485 | // [Mo,No, M128, N128, Ko, K32] |
2486 | tv2c->split(-1, gemm_tile.cta_tile.k); |
2487 | tv2c->split(-2, 2, true); |
2488 | // Order K |
2489 | // 0 1 2 3 4 5 6 |
2490 | // [Mo,No, M128, N128, Ko, K2CTA, K32] |
2491 | tv2c->reorder({{2, 4}, {3, 5}, {4, 3}, {5, 2}}); |
2492 | // 0 1 2 3 4 5 6 |
2493 | // [Mo,No, K2CTA, Ko M128, N128, K32] |
2494 | tv0r->computeAt(tv2c, 4); |
2495 | tv1r->computeAt(tv2c, 4); |
2496 | |
2497 | // Make warp tile: |
2498 | // ------------------------------------------------------------------------- |
2499 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); |
2500 | auto tv2c_rf = tv2c->rFactor({-9, -6, -1}); |
2501 | |
2502 | // tv2c_rf is the actual output of the mma op after |
2503 | // Rfactoring. |
2504 | mma_builder.accumulatorTv(tv2c_rf); |
2505 | |
2506 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
2507 | tv2, gemm_tile); |
2508 | |
2509 | // -8 -7 -6 -5 -4 -3 -2 -1 |
2510 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
2511 | tv0cr->computeAt(tv2c_rf, -4); |
2512 | tv1cr->computeAt(tv2c_rf, -4); |
2513 | |
2514 | // Schedule gmem read and smem write: |
2515 | // --------------------------------------------------------------------------- |
2516 | // [Mo,No,Ko,M,N,K] |
2517 | tv0cw->reorder({ |
2518 | {-3, -2}, |
2519 | {-2, -3}, |
2520 | }); |
2521 | // [Mo,No,Ko,N,M,K] |
2522 | tv0cw->merge(-2); |
2523 | tv0r->merge(-2); |
2524 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2525 | tv0cw, gemm_tile, 8); |
2526 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2527 | tv0r, gemm_tile, 8); |
2528 | tv0cw->setMemoryType(MemoryType::Shared); |
2529 | // [Mo,Ko,i,wy,wx,v] |
2530 | |
2531 | // [Mo,No,Ko,M,N,K] |
2532 | tv1cw->merge(-2); |
2533 | tv1r->merge(-2); |
2534 | // [Mo,No,Ko,i,wy,wx,v] |
2535 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2536 | tv1cw, gemm_tile, 8); |
2537 | scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( |
2538 | tv1r, gemm_tile, 8); |
2539 | tv1cw->setMemoryType(MemoryType::Shared); |
2540 | // Schedule mma input |
2541 | // --------------------------------------------------------------------------- |
2542 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2543 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2544 | |
2545 | // Schedule mma output |
2546 | // --------------------------------------------------------------------------- |
2547 | tv2c_rf->applyMmaSwizzle( |
2548 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2549 | tv2c->applyMmaSwizzle( |
2550 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2551 | tv2->applyMmaSwizzle( |
2552 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2553 | |
2554 | tv0b->computeAt(tv0cw, -2); |
2555 | tv1b->computeAt(tv1cw, -2); |
2556 | |
2557 | tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); |
2558 | tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); |
2559 | // Parallelize |
2560 | // 0 1 2 3 4 5 6 7 8 9 10 |
2561 | // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] |
2562 | tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); |
2563 | tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); |
2564 | tv2c_rf->axis(2)->parallelize(ParallelType::BIDz); |
2565 | tv2c_rf->axis(4)->parallelize(ParallelType::TIDz); |
2566 | tv2c_rf->axis(5)->parallelize(ParallelType::TIDy); |
2567 | |
2568 | tv2c->axis(0)->parallelize(ParallelType::BIDx); |
2569 | tv2c->axis(1)->parallelize(ParallelType::BIDy); |
2570 | tv2c->axis(2)->parallelize(ParallelType::BIDz); |
2571 | tv2c->axis(3)->parallelize(ParallelType::TIDz); |
2572 | tv2c->axis(4)->parallelize(ParallelType::TIDy); |
2573 | |
2574 | // Parallelize |
2575 | // 0 1 2 3 4 5 6 7 |
2576 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
2577 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
2578 | tv2->axis(1)->parallelize(ParallelType::BIDy); |
2579 | tv2->axis(2)->parallelize(ParallelType::TIDz); |
2580 | tv2->axis(3)->parallelize(ParallelType::TIDy); |
2581 | |
2582 | at::manual_seed(0); |
2583 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
2584 | auto t0 = at::randn({M, K}, options); |
2585 | auto t1 = at::randn({N, K}, options); |
2586 | |
2587 | FusionExecutor fe; |
2588 | fe.compileFusion(&fusion, {t0, t1}); |
2589 | auto cg_outputs = fe.runFusion({t0, t1}); |
2590 | auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); |
2591 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2592 | } |
2593 | |
2594 | // Test an end-to-end matmul case with swizzled smem |
2595 | // data layout. |
2596 | TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { |
2597 | NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); |
2598 | |
2599 | Fusion fusion; |
2600 | FusionGuard fg(&fusion); |
2601 | |
2602 | int M = 257, N = 511, K = 136; |
2603 | |
2604 | MatMulTileOptions gemm_tile; |
2605 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2606 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
2607 | gemm_tile.instruction_tile = GemmTile(16, 8, 16); |
2608 | |
2609 | // [M,K] |
2610 | auto tv0 = makeContigTensor(2, DataType::Half); |
2611 | // [N,K] |
2612 | auto tv1 = makeContigTensor(2, DataType::Half); |
2613 | fusion.addInput(tv0); |
2614 | fusion.addInput(tv1); |
2615 | |
2616 | // [M,N,K] |
2617 | auto tv0b = broadcast(tv0, {false, true, false}); |
2618 | auto tv1b = broadcast(tv1, {true, false, false}); |
2619 | |
2620 | auto mma_builder = |
2621 | MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) |
2622 | .layout(MmaOptions::MmaInputLayout::TN); |
2623 | |
2624 | auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); |
2625 | |
2626 | fusion.addOutput(tv2); |
2627 | |
2628 | mma_builder.configureMma(tv2); |
2629 | |
2630 | auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync); |
2631 | auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); |
2632 | auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync); |
2633 | auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); |
2634 | auto tv2c = tv2->cacheBefore(); |
2635 | |
2636 | mma_builder.accumulatorTv(tv2c); |
2637 | |
2638 | // Make a CTA tile |
2639 | // ------------------------------------------------------------------ |
2640 | // [M,N] |
2641 | tv2->split(-2, gemm_tile.cta_tile.m); |
2642 | tv2->split(-1, gemm_tile.cta_tile.n); |
2643 | |
2644 | // 0 1 2 3 |
2645 | // [Mo,M128, No, N128] |
2646 | tv2->reorder({{1, 2}, {2, 1}}); |
2647 | |
2648 | // 0 1 2 3 |
2649 | // [Mo,No, M128, N128] |
2650 | tv0->computeAt(tv2, 2); |
2651 | tv1->computeAt(tv2, 2); |
2652 | |
2653 | // Order K |
2654 | // 0 1 2 3 4 5 |
2655 | // [Mo,No, M128, N128, Ko, K32] |
2656 | tv2c->split(-1, gemm_tile.cta_tile.k); |
2657 | tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); |
2658 | |
2659 | // 0 1 2 3 4 5 |
2660 | // [Mo,No, Ko M128, N128, K32] |
2661 | tv0cw->computeAt(tv2c, 3); |
2662 | tv1cw->computeAt(tv2c, 3); |
2663 | |
2664 | // Make warp tile: |
2665 | // |
2666 | scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); |
2667 | scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( |
2668 | tv2, gemm_tile); |
2669 | // -8 -7 -6 -5 -4 -3 -2 -1 |
2670 | // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] |
2671 | tv0cr->computeAt(tv2c, -4); |
2672 | tv1cr->computeAt(tv2c, -4); |
2673 | |
2674 | // Schedule gmem read and smem write: |
2675 | // |
2676 | // [Mo,Ko,M,K] |
2677 | // Swizzle tv0: 128 x 32 tile: |
2678 | tv0cw->split(-2, 8); |
2679 | tv0cw->split(-2, 2); |
2680 | tv0cw->split(-1, 8); |
2681 | // -5 -4 -3 -2 -1 |
2682 | // [Mo,Ko,Mo16,M4,M2,Ko4,K8] |
2683 | tv0cw->swizzle(Swizzle2DType::XOR, -4, -2); |
2684 | tv0cw->merge(-4); |
2685 | tv0cw->merge(-3); |
2686 | // -3 -2 -1 |
2687 | // [Mo,Ko,Mo16,warp,K8] |
2688 | tv0cw->split(-3, 4); |
2689 | tv0cw->split(-3, 2); |
2690 | // -4 -3 -2 -1 |
2691 | // [Mo,Ko, S4, wz2, wy2, warp,K8] |
2692 | tv0cw->axis(-4)->parallelize(ParallelType::TIDz); |
2693 | tv0cw->axis(-3)->parallelize(ParallelType::TIDy); |
2694 | tv0cw->axis(-2)->parallelize(ParallelType::TIDx); |
2695 | tv0cw->axis(-1)->parallelize(ParallelType::Vectorize); |
2696 | |
2697 | tv0cw->setMemoryType(MemoryType::Shared); |
2698 | // [Mo,Ko,i,wy,wx,v] |
2699 | |
2700 | // [No,Ko,N,K] |
2701 | // Swizzle tv0: 128 x 32 tile: |
2702 | tv1cw->split(-2, 8); |
2703 | tv1cw->split(-2, 2); |
2704 | tv1cw->split(-1, 8); |
2705 | // -5 -4 -3 -2 -1 |
2706 | // [No,Ko,No16,N4,N2,Ko4,K8] |
2707 | tv1cw->swizzle(Swizzle2DType::XOR, -4, -2); |
2708 | tv1cw->merge(-4); |
2709 | tv1cw->merge(-3); |
2710 | // -3 -2 -1 |
2711 | // [No,Ko,No16,warp,K8] |
2712 | tv1cw->split(-3, 4); |
2713 | tv1cw->split(-3, 2); |
2714 | // -4 -3 -2 -1 |
2715 | // [No,Ko, S4, wz2, wy2, warp,K8] |
2716 | tv1cw->axis(-4)->parallelize(ParallelType::TIDz); |
2717 | tv1cw->axis(-3)->parallelize(ParallelType::TIDy); |
2718 | tv1cw->axis(-2)->parallelize(ParallelType::TIDx); |
2719 | tv1cw->axis(-1)->parallelize(ParallelType::Vectorize); |
2720 | |
2721 | tv1cw->setMemoryType(MemoryType::Shared); |
2722 | // Schedule mma input |
2723 | tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2724 | |
2725 | // [... Mi, Ni, Ki] |
2726 | tv0b->reorder({{-2, -3}, {-3, -2}}); |
2727 | tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); |
2728 | |
2729 | tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2730 | tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); |
2731 | |
2732 | // Schedule mma output |
2733 | tv2c->applyMmaSwizzle( |
2734 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2735 | tv2->applyMmaSwizzle( |
2736 | mma_builder.operand(MmaOptions::Operand::Accumulator).build()); |
2737 | |
2738 | // Parallelize |
2739 | // 0 1 2 3 4 5 6 7 8 9 10 |
2740 | // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] |
2741 | tv2c->axis(3)->parallelize(ParallelType::TIDz); |
2742 | tv2c->axis(4)->parallelize(ParallelType::TIDy); |
2743 | |
2744 | // Parallelize |
2745 | // 0 1 2 3 4 5 6 7 |
2746 | // [Mo No Mwo Nwo Mw Nw (Mi Ni)] |
2747 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
2748 | tv2->axis(1)->parallelize(ParallelType::BIDy); |
2749 | tv2->axis(2)->parallelize(ParallelType::TIDz); |
2750 | tv2->axis(3)->parallelize(ParallelType::TIDy); |
2751 | |
2752 | tv0cw->doubleBuffer(); |
2753 | tv1cw->doubleBuffer(); |
2754 | tv0cr->doubleBuffer(); |
2755 | tv1cr->doubleBuffer(); |
2756 | |
2757 | auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
2758 | auto t0 = at::randn({M, K}, options); |
2759 | auto t1 = at::randn({N, K}, options); |
2760 | |
2761 | FusionExecutor fe; |
2762 | fe.compileFusion(&fusion); |
2763 | auto cg_outputs = fe.runFusion({t0, t1}); |
2764 | |
2765 | auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); |
2766 | |
2767 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2768 | } |
2769 | |
2770 | // Matmul test on Ampere using ldmatrix.x4 to load operands |
2771 | TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { |
2772 | // Keep multiples of 8 to keep vectorizable. |
2773 | int M = 504, N = 136, K = 248; |
2774 | for (auto layout : kAllSupportedLayout) { |
2775 | Fusion fusion; |
2776 | FusionGuard fg(&fusion); |
2777 | auto tv0 = makeContigTensor(2, DataType::Half); |
2778 | auto tv1 = makeContigTensor(2, DataType::Half); |
2779 | |
2780 | fusion.addInput(tv0); |
2781 | fusion.addInput(tv1); |
2782 | |
2783 | auto tv2 = matmul(tv0, tv1, layout); |
2784 | |
2785 | fusion.addOutput(tv2); |
2786 | |
2787 | MatMulTileOptions gemm_tile; |
2788 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2789 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
2790 | gemm_tile.instruction_tile = GemmTile(16, 16, 16); |
2791 | |
2792 | auto mma_builder = |
2793 | MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) |
2794 | .layout(layout); |
2795 | |
2796 | MatmulParam params(mma_builder); |
2797 | params.tile_sizes = gemm_tile; |
2798 | params.async_gmem_load_operands = true; |
2799 | params.double_buffer_options.double_buffer_smem_write = true; |
2800 | params.double_buffer_options.smem_double_buffer_stage = 4; |
2801 | scheduleMatmul(tv2, tv0, tv1, params); |
2802 | |
2803 | at::manual_seed(0); |
2804 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
2805 | |
2806 | FusionExecutor fe; |
2807 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
2808 | 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
2809 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
2810 | auto tref = atMatmul( |
2811 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
2812 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2813 | } |
2814 | } |
2815 | |
2816 | // Matmul test for Turing MMA: across supported layouts |
2817 | TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) { |
2818 | // Keep multiples of 8 to keep vectorizable. |
2819 | int M = 504, N = 136, K = 248; |
2820 | |
2821 | for (auto layout : kAllSupportedLayout) { |
2822 | Fusion fusion; |
2823 | FusionGuard fg(&fusion); |
2824 | auto tv0 = makeContigTensor(2, DataType::Half); |
2825 | auto tv1 = makeContigTensor(2, DataType::Half); |
2826 | |
2827 | fusion.addInput(tv0); |
2828 | fusion.addInput(tv1); |
2829 | |
2830 | auto tv2 = matmul(tv0, tv1, layout); |
2831 | |
2832 | fusion.addOutput(tv2); |
2833 | |
2834 | MatMulTileOptions gemm_tile; |
2835 | gemm_tile.cta_tile = GemmTile(128, 128, 32); |
2836 | gemm_tile.warp_tile = GemmTile(64, 64, 32); |
2837 | gemm_tile.instruction_tile = GemmTile(16, 16, 16); |
2838 | |
2839 | auto mma_builder = |
2840 | MmaBuilder(MmaOptions::MacroType::Turing_16_16_16, gemm_tile) |
2841 | .layout(layout); |
2842 | |
2843 | MatmulParam params(mma_builder); |
2844 | params.tile_sizes = gemm_tile; |
2845 | scheduleMatmul(tv2, tv0, tv1, params); |
2846 | |
2847 | at::manual_seed(0); |
2848 | auto inputs = fp16MatmulAtInput(M, N, K, layout); |
2849 | |
2850 | FusionExecutor fe; |
2851 | NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( |
2852 | 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); |
2853 | auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); |
2854 | auto tref = atMatmul( |
2855 | inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); |
2856 | TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); |
2857 | } |
2858 | } |
2859 | |
2860 | #undef NVFUSER_TEST_CUDA_ARCH_GUARD |
2861 | |
2862 | } // namespace jit |
2863 | } // namespace torch |
2864 | |
2865 | #endif |
2866 | |