1#if defined(USE_CUDA)
2#include <gmock/gmock-matchers.h>
3#include <gtest/gtest.h>
4
5#include <executor.h>
6#include <inlining.h>
7#include <kernel_cache.h>
8#include <ops/all_ops.h>
9#include <scheduler/all_schedulers.h>
10#include <scheduler/transpose.h>
11#include <scheduler/utils.h>
12#include <test/test_gpu_validator.h>
13#include <test/test_utils.h>
14
15// Tests go in torch::jit
16namespace torch {
17namespace jit {
18
19using namespace torch::jit::fuser::cuda;
20
21TEST_F(NVFuserTest, FusionTranspose1_CUDA) {
22 Fusion fusion;
23 FusionGuard fg(&fusion);
24
25 constexpr int M = 10;
26 constexpr int N = 20;
27
28 auto tv0 = makeSymbolicTensor(2);
29 auto tv1 = transpose(tv0);
30 fusion.addInput(tv0);
31 fusion.addOutput(tv1);
32
33 tv1->axis(0)->parallelize(ParallelType::BIDx);
34 tv1->axis(1)->parallelize(ParallelType::TIDx);
35
36 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
37 at::manual_seed(0);
38 at::Tensor t0 = at::randn({M, N}, options);
39 std::vector<IValue> aten_inputs = {t0};
40
41 FusionExecutor fe;
42 fe.compileFusion(&fusion, aten_inputs);
43 auto outputs = fe.runFusion(aten_inputs);
44
45 at::Tensor aten_output = t0.t();
46
47 testValidate(
48 &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
49}
50
51TEST_F(NVFuserTest, FusionTranspose2_CUDA) {
52 Fusion fusion;
53 FusionGuard fg(&fusion);
54
55 constexpr int M = 10;
56 constexpr int N = 20;
57
58 auto tv0 = makeSymbolicTensor(2);
59 auto tv1 = transpose(tv0);
60 fusion.addInput(tv0);
61 fusion.addOutput(tv1);
62
63 tv1->merge(0);
64 tv1->split(0, 32);
65
66 tv1->axis(0)->parallelize(ParallelType::BIDx);
67 tv1->axis(1)->parallelize(ParallelType::TIDx);
68
69 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
70 at::manual_seed(0);
71 at::Tensor t0 = at::randn({M, N}, options);
72 std::vector<IValue> aten_inputs = {t0};
73
74 FusionExecutor fe;
75 fe.compileFusion(&fusion, aten_inputs);
76 auto outputs = fe.runFusion(aten_inputs);
77
78 at::Tensor aten_output = t0.t();
79
80 testValidate(
81 &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
82}
83
// Manually scheduled 2D transpose: tile the output into BSxBS CTA tiles,
// stage each input tile through a shared-memory cache, and swizzle that
// cache so the transposed access pattern avoids bank conflicts.
TEST_F(NVFuserTest, FusionTransposeWithSwizzle_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = transpose(tv0);
  fusion.addOutput(tv1);

  // tv0: [I0, I1]
  // tv1: [I1, I0]

  const int BS = 32;

  // CTA tiling by BS*BS
  tv1->split(1, BS);
  tv1->split(0, BS);
  tv1->reorder({{1, 2}});
  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]

  // Create a smem buffer to cache each tile
  auto tv0_cache = tv0->cacheAfter();
  tv0_cache->setMemoryType(MemoryType::Shared);

  // Compute the cache within each tile (position 2 = after the two
  // tile-index axes).
  tv0->computeAt(tv1, 2);
  // tv0: [I0, I1]
  // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)]
  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]

  // Assign each thread block to a tile
  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::BIDx);

  // Thread mapping for each tile. For both of the input and output
  // tiles, map TIDx to the fastest-changing dimension to facilitate
  // coalesced gmem accesses.
  tv1->axis(2)->parallelize(ParallelType::TIDy);
  tv1->axis(3)->parallelize(ParallelType::TIDx);
  // Note that the fastest-changing axis is next to the inner-most
  // axis since computeAt reorders the axes as the output tensor.
  tv0_cache->axis(2)->parallelize(ParallelType::TIDx);
  tv0_cache->axis(3)->parallelize(ParallelType::TIDy);

  // Swizzles the smem cache to avoid bank conflicts
  tv0_cache->swizzle(SwizzleType::Transpose, {3, 2});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Sizes deliberately not multiples of BS, so partial tiles are exercised.
  const int bx = 100;
  const int by = 200;
  at::Tensor t0 = at::randn({bx, by}, options);
  std::vector<IValue> aten_inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  auto aten_output = t0.t();

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
145
// Same shared-memory transpose idea as the test above, but with a 1D
// (BDIM-wide) thread block: the 2D tile is flattened and re-split so a
// single TIDx axis covers it; the swizzle is applied to the cache's 2D
// tile axes before they are flattened.
TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = transpose(tv0);
  fusion.addOutput(tv1);

  // tv0: [I0, I1]
  // tv1: [I1, I0]

  const int BS = 32;
  const int BDIM = 256;

  // CTA tiling by BS*BS
  tv1->split(1, BS);
  tv1->split(0, BS);
  tv1->reorder({{1, 2}});
  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]

  // Create a smem buffer to cache each tile
  auto tv0_cache = tv0->cacheAfter();
  tv0_cache->setMemoryType(MemoryType::Shared);

  tv0->computeAt(tv1, 2);
  // At this point (before the merge/split below):
  // tv0: [I0, I1]
  // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)]
  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]

  // Transform the tile axes for 1D thread mapping
  tv1->merge(-2, -1);
  tv1->split(-1, BDIM);
  // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]

  // Transform the cache similarly but apply swizzle to the 2D tile axes.
  tv0_cache->reorder({{-2, -1}});
  tv0_cache->swizzle(SwizzleType::Transpose, {2, 3});
  tv0_cache->merge(-2, -1);
  tv0_cache->split(-1, BDIM);
  // tv0_cache: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]

  // Assign each thread block to a tile
  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::BIDx);

  // Thread mapping for each tile.
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Sizes deliberately not multiples of BS, so partial tiles are exercised.
  const int bx = 100;
  const int by = 200;
  at::Tensor t0 = at::randn({bx, by}, options);
  std::vector<IValue> aten_inputs = {t0};

  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  auto aten_output = t0.t();

  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}
211
212// x->sin->transpose->cos->y
213TEST_F(NVFuserTest, FusionScheduleTransposeSimple_CUDA) {
214 Fusion fusion;
215 FusionGuard fg(&fusion);
216
217 auto tv0 = makeContigTensor(3);
218 fusion.addInput(tv0);
219 auto tv1 = sin(tv0);
220 auto tv2 = transpose(tv1, 1, 2);
221 auto tv3 = cos(tv2);
222 fusion.addOutput(tv3);
223
224 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
225 at::Tensor input = at::randn({256, 1024, 1024}, options);
226
227 auto lparams = scheduleTranspose(&fusion, {input});
228
229 FusionExecutor fe;
230 fe.compileFusion(&fusion, {input}, lparams);
231 auto outputs = fe.runFusion({input}, lparams);
232
233 auto tv_ref = input.sin().transpose(1, 2).cos();
234
235 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
236}
237
// x->transpose->sin->transpose->cos->y
239TEST_F(NVFuserTest, FusionScheduleTransposeSinTransposeCos_CUDA) {
240 Fusion fusion;
241 FusionGuard fg(&fusion);
242
243 auto tv0 = makeContigTensor(3);
244 fusion.addInput(tv0);
245 auto tv1 = transpose(tv0, 0, 2);
246 auto tv2 = sin(tv1);
247 auto tv3 = transpose(tv2, 1, 2);
248 auto tv4 = cos(tv3);
249 fusion.addOutput(tv4);
250
251 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
252 at::Tensor input = at::randn({256, 1024, 1024}, options);
253
254 auto lparams = scheduleTranspose(&fusion, {input});
255
256 FusionExecutor fe;
257 fe.compileFusion(&fusion, {input}, lparams);
258 auto outputs = fe.runFusion({input}, lparams);
259
260 auto tv_ref = input.transpose(0, 2).sin().transpose(1, 2).cos();
261
262 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
263}
264
265/*
266 * t0->transpose--.
267 * \
268 * t1->transpose---add-->sin->t5
269 */
270TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInput_CUDA) {
271 Fusion fusion;
272 FusionGuard fg(&fusion);
273
274 auto tv0 = makeContigTensor(3);
275 auto tv1 = makeContigTensor(3);
276 fusion.addInput(tv0);
277 fusion.addInput(tv1);
278 auto tv2 = transpose(tv0, 0, 2);
279 auto tv3 = transpose(tv1, 0, 2);
280 auto tv4 = add(tv2, tv3);
281 auto tv5 = sin(tv4);
282 fusion.addOutput(tv5);
283
284 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
285 at::Tensor input0 = at::randn({256, 1024, 1024}, options);
286 at::Tensor input1 = at::randn({256, 1024, 1024}, options);
287
288 auto lparams = scheduleTranspose(&fusion, {input0, input1});
289
290 FusionExecutor fe;
291 fe.compileFusion(&fusion, {input0, input1}, lparams);
292 auto outputs = fe.runFusion({input0, input1}, lparams);
293
294 auto tv_ref = (input0.transpose(0, 2) + input1.transpose(0, 2)).sin();
295
296 testValidate(
297 &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
298}
299
300// t0->sin->transpose->t5
301// `->cos->transpose->t6
302TEST_F(NVFuserTest, FusionScheduleTransposeMultipleOutput_CUDA) {
303 Fusion fusion;
304 FusionGuard fg(&fusion);
305
306 auto tv0 = makeContigTensor(3);
307 fusion.addInput(tv0);
308 auto tv2 = sin(tv0);
309 auto tv3 = cos(tv0);
310 auto tv5 = transpose(tv2, 0, 2);
311 auto tv6 = transpose(tv3, 0, 2);
312 fusion.addOutput(tv5);
313 fusion.addOutput(tv6);
314
315 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
316 at::Tensor input = at::randn({256, 1024, 1024}, options);
317
318 auto lparams = scheduleTranspose(&fusion, {input});
319
320 FusionExecutor fe;
321 fe.compileFusion(&fusion, {input}, lparams);
322 auto outputs = fe.runFusion({input}, lparams);
323
324 auto tv_ref1 = input.sin().transpose(0, 2);
325 auto tv_ref2 = input.cos().transpose(0, 2);
326
327 testValidate(
328 &fusion, outputs, {input}, {tv_ref1, tv_ref2}, __LINE__, __FILE__);
329}
330
331/*
332 * t0->transpose->sin->t3
333 * \_.-->cos->t5
334 * /
335 * t1
336 */
337TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInputOutput_CUDA) {
338 Fusion fusion;
339 FusionGuard fg(&fusion);
340
341 auto tv0 = makeContigTensor(3);
342 auto tv1 = makeContigTensor(3);
343 fusion.addInput(tv0);
344 fusion.addInput(tv1);
345 auto tv2 = transpose(tv0, 0, 2);
346 auto tv3 = sin(tv2);
347 fusion.addOutput(tv3);
348 auto tv4 = add(tv0, tv1);
349 auto tv5 = cos(tv4);
350 fusion.addOutput(tv5);
351
352 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
353 at::Tensor input0 = at::randn({256, 1024, 1024}, options);
354 at::Tensor input1 = at::randn({256, 1024, 1024}, options);
355
356 auto lparams = scheduleTranspose(&fusion, {input0, input1});
357
358 FusionExecutor fe;
359 fe.compileFusion(&fusion, {input0, input1}, lparams);
360 auto outputs = fe.runFusion({input0, input1}, lparams);
361
362 auto tv_ref1 = input0.transpose(0, 2).sin();
363 auto tv_ref2 = (input0 + input1).cos();
364
365 testValidate(
366 &fusion,
367 outputs,
368 {input0, input1},
369 {tv_ref1, tv_ref2},
370 __LINE__,
371 __FILE__);
372}
373
374/*
375 * .------>sin------>z
376 * x->transpose->transpose->add->y
377 * \_______________________/
378 */
379TEST_F(NVFuserTest, FusionScheduleTransposeMatchingSkipConnection_CUDA) {
380 Fusion fusion;
381 FusionGuard fg(&fusion);
382
383 auto tv0 = makeContigTensor(3);
384 fusion.addInput(tv0);
385 auto tv1 = transpose(tv0, 0, 2);
386 auto tv2 = transpose(tv1, 0, 2);
387 auto tv3 = add(tv0, tv2);
388 fusion.addOutput(tv3);
389 auto tv4 = sin(tv1);
390 fusion.addOutput(tv4);
391
392 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
393 at::Tensor input = at::randn({256, 1024, 1024}, options);
394
395 auto lparams = scheduleTranspose(&fusion, {input});
396
397 FusionExecutor fe;
398 fe.compileFusion(&fusion, {input}, lparams);
399 auto outputs = fe.runFusion({input}, lparams);
400
401 auto tv_ref1 = input.transpose(0, 2).transpose(0, 2) + input;
402 auto tv_ref2 = input.transpose(0, 2).sin();
403
404 testValidate(
405 &fusion, outputs, {input}, {tv_ref1, tv_ref2}, __LINE__, __FILE__);
406}
407
408// x->transpose--add->z
409// y->broadcast-/
410TEST_F(NVFuserTest, FusionScheduleTransposeBroadcast_CUDA) {
411 Fusion fusion;
412 FusionGuard fg(&fusion);
413
414 auto tv0 = makeContigTensor(3);
415 auto tv1 = makeContigTensor(2);
416 fusion.addInput(tv0);
417 fusion.addInput(tv1);
418 auto tv2 = transpose(tv0, 1, 2);
419 auto tv3 = broadcast(tv1, {false, false, true});
420 auto tv4 = add(tv2, tv3);
421 fusion.addOutput(tv4);
422
423 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
424 at::Tensor input0 = at::randn({1024, 256, 1024}, options);
425 at::Tensor input1 = at::randn({1024, 1024}, options);
426
427 auto lparams = scheduleTranspose(&fusion, {input0, input1});
428 // auto lparams = schedulePointwise(&fusion, {input0, input1});
429
430 FusionExecutor fe;
431 fe.compileFusion(&fusion, {input0, input1}, lparams);
432 auto outputs = fe.runFusion({input0, input1}, lparams);
433
434 auto tv_ref = input0.transpose(1, 2) + input1.unsqueeze(2);
435
436 testValidate(
437 &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
438}
439
440// x->broadcast--add->z
441// y->broadcast-/
442TEST_F(NVFuserTest, FusionScheduleTransposeNoReference_CUDA) {
443 Fusion fusion;
444 FusionGuard fg(&fusion);
445
446 auto tv0 = makeContigTensor(2);
447 auto tv1 = makeContigTensor(2);
448 fusion.addInput(tv0);
449 fusion.addInput(tv1);
450 auto tv2 = broadcast(tv0, {false, true, false});
451 auto tv3 = broadcast(tv1, {false, false, true});
452 auto tv4 = add(tv2, tv3);
453 fusion.addOutput(tv4);
454
455 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
456 at::Tensor input0 = at::randn({1024, 256}, options);
457 at::Tensor input1 = at::randn({1024, 1024}, options);
458
459 EXPECT_THAT(
460 [&]() {
461 scheduleTranspose(&fusion, {input0, input1});
462 },
463 testing::ThrowsMessage<c10::Error>(
464 testing::HasSubstr("reference tensor")));
465}
466
467// x->broadcast--add->z
468// y->broadcast-/
469TEST_F(NVFuserTest, FusionScheduleBroadcastOnly_CUDA) {
470 for (bool contig0 : {true, false}) {
471 for (bool contig1 : {true, false}) {
472 Fusion fusion;
473 FusionGuard fg(&fusion);
474 auto tv0 = contig0 ? makeContigConcreteTensor({-1, 1, -1})
475 : makeConcreteTensor({-1, 1, -1});
476 auto tv1 = contig1 ? makeContigConcreteTensor({-1, -1, 1})
477 : makeConcreteTensor({-1, -1, 1});
478 fusion.addInput(tv0);
479 fusion.addInput(tv1);
480 auto tv2 = add(tv0, tv1);
481 fusion.addOutput(tv2);
482
483 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
484 at::Tensor input0 = at::randn({1024, 1, 256}, options);
485 at::Tensor input1 = at::randn({1024, 1024, 1}, options);
486
487 auto lparams = scheduleTranspose(&fusion, {input0, input1});
488
489 FusionExecutor fe;
490 fe.compileFusion(&fusion, {input0, input1}, lparams);
491 auto outputs = fe.runFusion({input0, input1}, lparams);
492
493 auto tv_ref = input0 + input1;
494
495 testValidate(
496 &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
497 }
498 }
499}
500
501// mermaid graph:
502// ```mermaid
503// %%{
504// init: {
505// 'theme': 'base',
506// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
507// }%%
508// graph TD
509// T0("T0(M, N, K)")
510// T1("T1(N, M, K)")
511// T2("T2(M, K, N)")
512// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
513// T1 ---> sigmoid --> T5("T5(N, M, K)")
514// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
515// T2 ----> C("add")
516// T3 --> C --> T6("T6(M, K, N)")
517// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
518// T11 --> E("add") -->T12("T12(K, M, N)")
519// T7 --> E
520// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
521// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
522// T4 --> G
523// T6 ---> sin ---> T10("T10(M, K, N)")
524// style T0 fill:lightgreen
525// style T1 fill:lightgreen
526// style T2 fill:lightgreen
527// style T12 fill:lightblue
528// style T9 fill:lightblue
529// style T10 fill:lightblue
530// ```
// Scheduler-driven version of the complex DAG in the mermaid graph above:
// three inputs (tv0..tv2) and three outputs (tv9, tv10, tv12), with
// transposes mixing two different innermost dimensions.
TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigTensor(3);
  auto tv1 = makeContigTensor(3);
  auto tv2 = makeContigTensor(3);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);
  auto tv3 = transpose(tv0, 1, 2);
  auto tv4 = transpose(tv1, 0, 1);
  auto tv5 = sigmoid(tv1);
  auto tv6 = add(tv2, tv3);
  auto tv7 = transpose(tv5, 0, 2);
  auto tv8 = add(tv4, tv0);
  auto tv9 = relu(tv8);
  fusion.addOutput(tv9);
  auto tv10 = sin(tv6);
  fusion.addOutput(tv10);
  auto tv11 = transpose(tv6, 0, 1);
  auto tv12 = add(tv7, tv11);
  fusion.addOutput(tv12);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  // Per the graph: M=512, N=1024, K=256.
  at::Tensor input0 = at::randn({512, 1024, 256}, options);
  at::Tensor input1 = at::randn({1024, 512, 256}, options);
  at::Tensor input2 = at::randn({512, 256, 1024}, options);

  auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
  auto outputs = fe.runFusion({input0, input1, input2}, lparams);

  // Eager-mode reference mirroring the fusion ops one-to-one.
  auto t3 = input0.transpose(1, 2);
  auto t4 = input1.transpose(0, 1);
  auto t5 = input1.sigmoid();
  auto t6 = input2 + t3;
  auto t7 = t5.transpose(0, 2);
  auto t8 = t4 + input0;
  auto t9 = t8.relu();
  auto t10 = t6.sin();
  auto t11 = t6.transpose(0, 1);
  auto t12 = t7 + t11;

  testValidate(
      &fusion,
      outputs,
      {input0, input1, input2},
      {t9, t10, t12},
      __LINE__,
      __FILE__);
}
585
586// mermaid graph:
587// ```mermaid
588// %%{
589// init: {
590// 'theme': 'base',
591// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
592// }%%
593// graph TD
594// T0("T0(M, N, K)")
595// T1("T1(N, M, K)")
596// T2("T2(M, K, N)")
597// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
598// T1 ---> sigmoid --> T5("T5(N, M, K)")
599// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
600// T2 ----> C("add")
601// T3 --> C --> T6("T6(M, K, N)")
602// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
603// T11 --> E("add") -->T12("T12(K, M, N)")
604// T7 --> E
605// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
606// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
607// T4 --> G
608// T6 ---> sin ---> T10("T10(M, K, N)")
609// style T0 fill:lightgreen
610// style T1 fill:lightgreen
611// style T2 fill:lightgreen
612// style T12 fill:lightblue
613// style T9 fill:lightblue
614// style T10 fill:lightblue
615// ```
// Hand-written schedule for the same DAG as
// FusionScheduleTransposeComplexDAG1 above; documents the intended tiling
// strategy: 32x32 tiles, two vectorization groups with different innermost
// dimensions, and shared-memory staging for the group-2 tensors.
TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
  // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigTensor(3);
  auto tv1 = makeContigTensor(3);
  auto tv2 = makeContigTensor(3);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);
  auto tv3 = transpose(tv0, 1, 2);
  auto tv4 = transpose(tv1, 0, 1);
  auto tv5 = sigmoid(tv1);
  auto tv6 = add(tv2, tv3);
  auto tv7 = transpose(tv5, 0, 2);
  auto tv8 = add(tv4, tv0);
  auto tv9 = relu(tv8);
  fusion.addOutput(tv9);
  auto tv10 = sin(tv6);
  fusion.addOutput(tv10);
  auto tv11 = transpose(tv6, 0, 1);
  auto tv12 = add(tv7, tv11);
  fusion.addOutput(tv12);

  // group 1: tv0, tv1, *tv9, innermost dim K
  // group 2: tv2, *tv10, tv12, innermost dim N

  // cache inputs and outputs
  auto tv0_cache = tv0->cacheAfter();
  auto tv1_cache = tv1->cacheAfter();
  auto tv2_cache = tv2->cacheAfter();
  // tv9_cache is not referenced again below; cacheBefore is called for its
  // restructuring effect on the fusion.
  auto tv9_cache = tv9->cacheBefore();
  auto tv10_cache = tv10->cacheBefore();
  auto tv12_cache = tv12->cacheBefore();

  // Step 1: Make 32x32 tiles, schedule outer dimensions
  {
    // Pick an arbitrary tensor as a reference tensor for this step. There is no
    // requirement on which group this reference tensor should belong to. Here
    // we pick tv9, which belongs to group 1.

    // Make 32x32 tile:
    // [M, N, K]
    tv9->split(1, 32);
    tv9->reorder({{2, -1}});
    tv9->split(2, 32);
    tv9->reorder({{3, -1}});
    // [M, N/32, K/32, 32(N), 32(K)]

    // merge outer dims, parallelize on BIDx, and unswitch
    tv9->merge(0);
    tv9->merge(0);
    tv9->split(0, 1);
    // [M * N/32 * K/32, 1, 32(N), 32(K)]
    tv9->axis(0)->parallelize(ParallelType::BIDx);
    tv9->axis(1)->parallelize(ParallelType::Unswitch);
    // [BIDx, Unswitch, 32(N), 32(K)]

    // propagate to the entire DAG
    MaxRootDomainInfoSpanningTree entire_dag(tv9);
    TransformPropagator tp(tv9);
    entire_dag.traverse(&tp);
    scheduler_utils::parallelizeAllLike(tv9);
  }

  constexpr int threads_per_block = 128;

  // Step 2, schedule group 2
  {
    // group 2: tv2, *tv10, tv12, innermost dim N

    tv2_cache->setMemoryType(MemoryType::Shared);
    tv10_cache->setMemoryType(MemoryType::Shared);
    tv12_cache->setMemoryType(MemoryType::Shared);

    // pick tv10 as reference tensor for group 2
    // [BIDx, Unswitch, 32(N), 32(K)]
    tv10->reorder({{-1, -2}});
    // [BIDx, Unswitch, 32(K), 32(N)]
    tv10->merge(2);
    tv10->split(2, 4);
    tv10->split(2, threads_per_block);
    tv10->axis(-1)->parallelize(ParallelType::Vectorize);
    tv10->axis(-2)->parallelize(ParallelType::TIDx);
    tv10->axis(-3)->parallelize(ParallelType::Unroll);
    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]

    // Propagate to group 2 and its cache. Note that group 2 and its cache are
    // not connected, so we need to borrow other tensors of the DAG to be able
    // to propagate. The transformations on borrowed tensors will be overwritten
    // in the next step. We can not borrow the reference tensor of group 1.
    auto all_tvs_except_ref1 = ir_utils::allTvsExcept(&fusion, {tv9});
    auto all_tvs_except_ref1_set = std::unordered_set<TensorView*>(
        all_tvs_except_ref1.begin(), all_tvs_except_ref1.end());
    SetSelector selector(all_tvs_except_ref1_set);
    MaxRootDomainInfoSpanningTree tree(tv10, &selector);
    TransformPropagator tp(tv10);
    tree.traverse(&tp);
    scheduler_utils::parallelizeAllLike(
        tv10, {tv2_cache, tv10, tv12}, {ParallelType::TIDx});
    scheduler_utils::parallelizeAllLike(
        tv10,
        {tv2_cache, tv10, tv12},
        {ParallelType::Vectorize, ParallelType::Unroll});
  }

  // Step 3, schedule group 1
  {
    // group 1: tv0, tv1, *tv9, innermost dim K
    // [BIDx, Unswitch, 32(N), 32(K)]
    tv9->merge(2);
    tv9->split(2, 4);
    tv9->split(2, threads_per_block);
    tv9->axis(-1)->parallelize(ParallelType::Vectorize);
    tv9->axis(-2)->parallelize(ParallelType::TIDx);
    tv9->axis(-3)->parallelize(ParallelType::Unroll);
    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]

    // Propagate to the entire DAG except for group 2 and its cached inputs
    auto all_tvs_except2 =
        ir_utils::allTvsExcept(&fusion, {tv2, tv2_cache, tv10, tv12});
    auto all_tvs_except2_set = std::unordered_set<TensorView*>(
        all_tvs_except2.begin(), all_tvs_except2.end());
    SetSelector selector(all_tvs_except2_set);
    MaxRootDomainInfoSpanningTree tree(tv9, &selector);
    TransformPropagator tp(tv9);
    tree.traverse(&tp);
    scheduler_utils::parallelizeAllLike(
        tv9, all_tvs_except2, {ParallelType::TIDx});
    scheduler_utils::parallelizeAllLike(
        tv9,
        {tv0_cache, tv1_cache, tv9},
        {ParallelType::Vectorize, ParallelType::Unroll});
  }

  // Inline everything as deeply as possible.
  inlineMost();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::randn({512, 1024, 256}, options);
  at::Tensor input1 = at::randn({1024, 512, 256}, options);
  at::Tensor input2 = at::randn({512, 256, 1024}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1, input2});
  auto outputs = fe.runFusion({input0, input1, input2});

  // Eager-mode reference mirroring the fusion ops one-to-one.
  auto t3 = input0.transpose(1, 2);
  auto t4 = input1.transpose(0, 1);
  auto t5 = input1.sigmoid();
  auto t6 = input2 + t3;
  auto t7 = t5.transpose(0, 2);
  auto t8 = t4 + input0;
  auto t9 = t8.relu();
  auto t10 = t6.sin();
  auto t11 = t6.transpose(0, 1);
  auto t12 = t7 + t11;

  testValidate(
      &fusion,
      outputs,
      {input0, input1, input2},
      {t9, t10, t12},
      __LINE__,
      __FILE__);
}
783
784// x->view->y
785TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) {
786 Fusion fusion;
787 FusionGuard fg(&fusion);
788
789 auto tv0 = makeContigTensor(3);
790 fusion.addInput(tv0);
791 auto tv1 = flatten(tv0, 1, 2);
792 fusion.addOutput(tv1);
793
794 TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion));
795}
796
797TEST_F(NVFuserTest, FusionTransposeSelfMapping_CUDA) {
798 std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
799 Fusion& fusion = *fusion_ptr.get();
800 FusionGuard fg(&fusion);
801
802 auto tv0 = makeContigTensor(2);
803 fusion.addInput(tv0);
804 auto tv1 = transpose(tv0, 0, 1);
805 auto tv2 = add(tv0, tv1);
806 fusion.addOutput(tv2);
807
808 EXPECT_THAT(
809 [&]() { IterDomainGraph(fusion_ptr.get()); },
810 testing::ThrowsMessage<c10::Error>(
811 testing::HasSubstr("Unsupported domain mapping detected")));
812
813 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
814 auto t0 = at::randn({5, 5}, options);
815
816 FusionExecutorCache executor_cache(std::move(fusion_ptr));
817 auto cg_outputs = executor_cache.runFusionWithInputs({t0});
818
819 auto ref = t0.transpose(0, 1) + t0;
820
821 testValidate(
822 executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
823}
824
825#if 0
826// silent wrong result
// NOTE(review): this test is compiled out by the surrounding '#if 0' because
// the fusion currently produces a silently wrong result (see comment above
// the guard); kept for reproduction until the transpose+view self-mapping
// issue is fixed.
TEST_F(NVFuserTest, FusionTransposeViewSelfMapping_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  // t3 = transpose(t0) + view(t0): both consumers remap the same input.
  auto tv0 = makeContigTensor(2);
  fusion.addInput(tv0);
  auto tv1 = transpose(tv0, 0, 1);
  auto tv2 = view(tv0, {2, 3}, {3, 2});
  auto tv3 = add(tv1, tv2);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3}, options);

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0});

  auto ref = t0.transpose(0, 1) + t0.view({3, 2});

  testValidate(
      executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}
850#endif
851
852// t0------------.
853// t2->broadcast->sub->mul->relu->t6
854// t1------------------'
855TEST_F(NVFuserTest, FusionScheduleTransposeMissingDim_CUDA) {
856 Fusion fusion;
857 FusionGuard fg(&fusion);
858
859 auto tv0 = makeContigTensor(3);
860 auto tv1 = makeContigConcreteTensor({1, -1, 1});
861 auto tv2 = makeContigTensor(1);
862 fusion.addInput(tv0);
863 fusion.addInput(tv1);
864 fusion.addInput(tv2);
865 auto tv3 = broadcast(tv2, {true, false, true});
866 auto tv4 = sub(tv0, tv3);
867 auto tv5 = mul(tv4, tv1);
868 auto tv6 = relu(tv5);
869 fusion.addOutput(tv6);
870
871 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
872 at::Tensor input0 = at::randn({512, 1024, 512}, options);
873 at::Tensor input1 = at::randn({1, 1024, 1}, options);
874 at::Tensor input2 = at::randn({1024}, options);
875
876 auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});
877
878 FusionExecutor fe;
879 fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
880 auto outputs = fe.runFusion({input0, input1, input2}, lparams);
881
882 auto t3 = input2.unsqueeze(0).unsqueeze(-1);
883 auto t4 = input0 - t3;
884 auto t5 = t4 * input1;
885 auto t6 = at::relu(t5);
886
887 testValidate(
888 &fusion, outputs, {input0, input1, input2}, {t6}, __LINE__, __FILE__);
889}
890
891// x->sin->transpose->cos->y
892TEST_F(NVFuserTest, FusionScheduleTransposeSmall_CUDA) {
893 Fusion fusion;
894 FusionGuard fg(&fusion);
895
896 auto tv0 = makeContigTensor(3);
897 fusion.addInput(tv0);
898 auto tv1 = sin(tv0);
899 auto tv2 = transpose(tv1, 1, 2);
900 auto tv3 = cos(tv2);
901 fusion.addOutput(tv3);
902
903 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
904 at::Tensor input = at::randn({1024, 2, 2}, options);
905
906 auto lparams = scheduleTranspose(&fusion, {input});
907
908 FusionExecutor fe;
909 fe.compileFusion(&fusion, {input}, lparams);
910 auto outputs = fe.runFusion({input}, lparams);
911
912 auto tv_ref = input.sin().transpose(1, 2).cos();
913
914 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
915}
916
917// x->sin->transpose->cos->y
918TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize1_CUDA) {
919 Fusion fusion;
920 FusionGuard fg(&fusion);
921
922 auto tv0 = makeContigTensor(3);
923 fusion.addInput(tv0);
924 auto tv1 = sin(tv0);
925 auto tv2 = transpose(tv1, 1, 2);
926 auto tv3 = cos(tv2);
927 fusion.addOutput(tv3);
928
929 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
930 at::Tensor input = at::randn({64 * 1024 * 1024, 2, 2}, options);
931
932 auto lparams = scheduleTranspose(&fusion, {input});
933
934 FusionExecutor fe;
935 fe.compileFusion(&fusion, {input}, lparams);
936 auto outputs = fe.runFusion({input}, lparams);
937
938 auto tv_ref = input.sin().transpose(1, 2).cos();
939
940 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
941}
942
943// x->sin->transpose->cos->y
944TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize2_CUDA) {
945 Fusion fusion;
946 FusionGuard fg(&fusion);
947
948 auto tv0 = makeContigTensor(3);
949 fusion.addInput(tv0);
950 auto tv1 = sin(tv0);
951 auto tv2 = transpose(tv1, 0, 2);
952 auto tv3 = cos(tv2);
953 fusion.addOutput(tv3);
954
955 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
956 at::Tensor input = at::randn({2, 64 * 1024 * 1024, 2}, options);
957
958 auto lparams = scheduleTranspose(&fusion, {input});
959
960 FusionExecutor fe;
961 fe.compileFusion(&fusion, {input}, lparams);
962 auto outputs = fe.runFusion({input}, lparams);
963
964 auto tv_ref = input.sin().transpose(0, 2).cos();
965
966 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
967}
968
969// x->sin->transpose->cos->y
970TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) {
971 Fusion fusion;
972 FusionGuard fg(&fusion);
973
974 auto tv0 = makeContigTensor(8);
975 fusion.addInput(tv0);
976 auto tv1 = sin(tv0);
977 auto tv2 = transpose(tv1, 4, 7);
978 auto tv3 = cos(tv2);
979 fusion.addOutput(tv3);
980
981 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
982 at::Tensor input = at::randn({1024 * 1024, 2, 2, 2, 2, 2, 2, 2}, options);
983
984 auto lparams = scheduleTranspose(&fusion, {input});
985
986 FusionExecutor fe;
987 fe.compileFusion(&fusion, {input}, lparams);
988 auto outputs = fe.runFusion({input}, lparams);
989
990 auto tv_ref = input.sin().transpose(4, 7).cos();
991
992 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
993}
994
995// x->sin->transpose->cos->y
996TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) {
997 std::array<std::vector<int64_t>, 2> shapes{
998 std::vector<int64_t>{1024 * 1024 * 128, 2},
999 std::vector<int64_t>{2, 1024 * 1024 * 128}};
1000 for (const auto& shape : shapes) {
1001 Fusion fusion;
1002 FusionGuard fg(&fusion);
1003
1004 auto tv0 = makeContigTensor(2);
1005 fusion.addInput(tv0);
1006 auto tv1 = sin(tv0);
1007 auto tv2 = transpose(tv1, 0, 1);
1008 auto tv3 = cos(tv2);
1009 fusion.addOutput(tv3);
1010
1011 auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1012 at::Tensor input = at::randn(shape, options);
1013
1014 auto lparams = scheduleTranspose(&fusion, {input});
1015
1016 FusionExecutor fe;
1017 fe.compileFusion(&fusion, {input}, lparams);
1018 auto outputs = fe.runFusion({input}, lparams);
1019
1020 auto tv_ref = input.sin().transpose(0, 1).cos();
1021
1022 testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
1023 }
1024}
1025
1026TEST_F(NVFuserTest, FusionTransposeBankConflict1_CUDA) {
1027 Fusion fusion;
1028 FusionGuard fg(&fusion);
1029
1030 auto tv0 = makeConcreteTensor({32, 32});
1031 fusion.addInput(tv0);
1032 auto tv1 = set(tv0);
1033 auto tv2 = transpose(tv1, 0, 1);
1034 auto tv3 = set(tv2);
1035 fusion.addOutput(tv3);
1036
1037 tv1->setMemoryType(MemoryType::Shared);
1038 tv1->axis(1)->parallelize(ParallelType::TIDx);
1039 tv2->axis(1)->parallelize(ParallelType::TIDx);
1040 tv3->axis(1)->parallelize(ParallelType::TIDx);
1041
1042 auto bank_conflict_info = fusion.bankConflictInfo();
1043
1044 TORCH_CHECK(!bank_conflict_info.empty());
1045 for (auto info : bank_conflict_info) {
1046 std::pair<int, int> expect{32, 0};
1047 TORCH_CHECK(info.second == expect);
1048 }
1049}
1050
1051TEST_F(NVFuserTest, FusionTransposeBankConflict2_CUDA) {
1052 Fusion fusion;
1053 FusionGuard fg(&fusion);
1054
1055 auto tv0 = makeConcreteTensor({32, 32});
1056 fusion.addInput(tv0);
1057 auto tv1 = set(tv0);
1058 auto tv2 = transpose(tv1, 0, 1);
1059 auto tv3 = set(tv2);
1060 fusion.addOutput(tv3);
1061
1062 tv1->setMemoryType(MemoryType::Shared);
1063 tv1->axis(0)->parallelize(ParallelType::TIDx);
1064 tv2->axis(0)->parallelize(ParallelType::TIDx);
1065 tv3->axis(0)->parallelize(ParallelType::TIDx);
1066
1067 auto bank_conflict_info = fusion.bankConflictInfo();
1068
1069 TORCH_CHECK(!bank_conflict_info.empty());
1070 for (auto info : bank_conflict_info) {
1071 std::pair<int, int> expect{0, 32};
1072 TORCH_CHECK(info.second == expect);
1073 }
1074}
1075
1076TEST_F(NVFuserTest, FusionTransposeBankConflict3_CUDA) {
1077 Fusion fusion;
1078 FusionGuard fg(&fusion);
1079
1080 auto tv0 = makeConcreteTensor({32, 32}, DataType::Bool);
1081 fusion.addInput(tv0);
1082 auto tv1 = set(tv0);
1083 auto tv2 = transpose(tv1, 0, 1);
1084 auto tv3 = set(tv2);
1085 fusion.addOutput(tv3);
1086
1087 tv1->setMemoryType(MemoryType::Shared);
1088 tv1->axis(1)->parallelize(ParallelType::TIDx);
1089 tv2->axis(1)->parallelize(ParallelType::TIDx);
1090 tv3->axis(1)->parallelize(ParallelType::TIDx);
1091
1092 auto bank_conflict_info = fusion.bankConflictInfo();
1093
1094 TORCH_CHECK(!bank_conflict_info.empty());
1095 for (auto info : bank_conflict_info) {
1096 std::pair<int, int> expect{8, 0};
1097 TORCH_CHECK(info.second == expect);
1098 }
1099}
1100
// Bank-conflict detection on hand-written 1-D (merged) schedules: tv1 and
// tv2 both live in shared memory with different tilings, so conflicts can
// arise on tv1's vectorized write, on tv2's read and write, and on tv3's
// read — each with a different expected pattern.
TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 32x32 tile transposed through two shared-memory buffers.
  auto tv0 = makeConcreteTensor({32, 32});
  fusion.addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = transpose(tv1, 0, 1);
  auto tv3 = set(tv2);
  fusion.addOutput(tv3);

  // Flatten, then tile as [TIDx(32), 8, Vectorize(4)].
  tv1->setMemoryType(MemoryType::Shared);
  tv1->merge(0);
  tv1->split(0, 4);
  tv1->split(0, 8);
  tv1->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1->axis(0)->parallelize(ParallelType::TIDx);
  // T1 [TIDx(32), 8, V(4)]

  // Flatten, then tile as [8, TIDx(32), 4] (serial inner axis of 4).
  tv2->setMemoryType(MemoryType::Shared);
  tv2->merge(0);
  tv2->split(0, 4);
  tv2->split(0, 32);
  tv2->axis(1)->parallelize(ParallelType::TIDx);
  // T2 [8, TIDx(32), 4]

  // Flatten, then tile as [16, TIDx(32), 2] (serial inner axis of 2).
  tv3->merge(0);
  tv3->split(0, 2);
  tv3->split(0, 32);
  tv3->axis(1)->parallelize(ParallelType::TIDx);
  // T3 [16, TIDx(32), 2]

  auto bank_conflict_info = fusion.bankConflictInfo();

  // Every reported shared-memory access must match one of these conflict
  // pairs. NOTE(review): pair ordering is presumably {read-ways,
  // write-ways} — confirm against bankConflictInfo()'s documentation.
  TORCH_CHECK(!bank_conflict_info.empty());
  for (auto info : bank_conflict_info) {
    std::pair<int, int> expect1{0, 8};
    std::pair<int, int> expect2{8, 4};
    std::pair<int, int> expect3{2, 0};
    TORCH_CHECK(
        info.second == expect1 || info.second == expect2 ||
        info.second == expect3);
  }
}
1145
1146TEST_F(NVFuserTest, FusionTransposeBankConflict5_CUDA) {
1147 Fusion fusion;
1148 FusionGuard fg(&fusion);
1149
1150 auto tv0 = makeConcreteTensor({1024, 32, 32});
1151 fusion.addInput(tv0);
1152 auto tv1 = set(tv0);
1153 auto tv2 = transpose(tv1, 1, 2);
1154 auto tv3 = set(tv2);
1155 fusion.addOutput(tv3);
1156
1157 tv1->setMemoryType(MemoryType::Shared);
1158 tv1->axis(2)->parallelize(ParallelType::TIDx);
1159 tv2->axis(2)->parallelize(ParallelType::TIDx);
1160 tv3->axis(2)->parallelize(ParallelType::TIDx);
1161 tv1->axis(0)->parallelize(ParallelType::BIDx);
1162 tv2->axis(0)->parallelize(ParallelType::BIDx);
1163 tv3->axis(0)->parallelize(ParallelType::BIDx);
1164
1165 auto bank_conflict_info = fusion.bankConflictInfo();
1166
1167 TORCH_CHECK(!bank_conflict_info.empty());
1168 for (auto info : bank_conflict_info) {
1169 std::pair<int, int> expect{32, 0};
1170 TORCH_CHECK(info.second == expect);
1171 }
1172}
1173
1174TEST_F(NVFuserTest, FusionTransposeBankConflict6_CUDA) {
1175 Fusion fusion;
1176 FusionGuard fg(&fusion);
1177
1178 auto tv0 = makeConcreteTensor({1024, 32, 32});
1179 fusion.addInput(tv0);
1180 auto tv1 = set(tv0);
1181 auto tv2 = transpose(tv1, 1, 2);
1182 auto tv3 = set(tv2);
1183 fusion.addOutput(tv3);
1184
1185 tv1->setMemoryType(MemoryType::Shared);
1186 tv1->axis(2)->parallelize(ParallelType::TIDy);
1187 tv2->axis(2)->parallelize(ParallelType::TIDy);
1188 tv3->axis(2)->parallelize(ParallelType::TIDy);
1189 tv1->axis(0)->parallelize(ParallelType::BIDx);
1190 tv2->axis(0)->parallelize(ParallelType::BIDx);
1191 tv3->axis(0)->parallelize(ParallelType::BIDx);
1192
1193 auto bank_conflict_info = fusion.bankConflictInfo();
1194
1195 TORCH_CHECK(!bank_conflict_info.empty());
1196 for (auto info : bank_conflict_info) {
1197 std::pair<int, int> expect{32, 0};
1198 TORCH_CHECK(info.second == expect);
1199 }
1200}
1201
1202TEST_F(NVFuserTest, FusionTransposeBankConflict7_CUDA) {
1203 Fusion fusion;
1204 FusionGuard fg(&fusion);
1205
1206 auto tv0 = makeConcreteTensor({1024, 8, 8});
1207 fusion.addInput(tv0);
1208 auto tv1 = set(tv0);
1209 auto tv2 = transpose(tv1, 1, 2);
1210 auto tv3 = set(tv2);
1211 fusion.addOutput(tv3);
1212
1213 tv1->setMemoryType(MemoryType::Shared);
1214 tv1->axis(1)->parallelize(ParallelType::TIDx);
1215 tv2->axis(1)->parallelize(ParallelType::TIDx);
1216 tv3->axis(1)->parallelize(ParallelType::TIDx);
1217 tv1->axis(2)->parallelize(ParallelType::TIDy);
1218 tv2->axis(2)->parallelize(ParallelType::TIDy);
1219 tv3->axis(2)->parallelize(ParallelType::TIDy);
1220 tv1->axis(0)->parallelize(ParallelType::BIDx);
1221 tv2->axis(0)->parallelize(ParallelType::BIDx);
1222 tv3->axis(0)->parallelize(ParallelType::BIDx);
1223
1224 auto bank_conflict_info = fusion.bankConflictInfo();
1225
1226 TORCH_CHECK(!bank_conflict_info.empty());
1227 for (auto info : bank_conflict_info) {
1228 std::pair<int, int> expect{0, 2};
1229 TORCH_CHECK(info.second == expect);
1230 }
1231}
1232
1233TEST_F(NVFuserTest, FusionTransposeBankConflict8_CUDA) {
1234 Fusion fusion;
1235 FusionGuard fg(&fusion);
1236
1237 auto tv0 = makeConcreteTensor({1024, 8, 8});
1238 fusion.addInput(tv0);
1239 auto tv1 = set(tv0);
1240 auto tv2 = transpose(tv1, 1, 2);
1241 auto tv3 = set(tv2);
1242 fusion.addOutput(tv3);
1243
1244 tv1->setMemoryType(MemoryType::Shared);
1245 tv1->axis(2)->parallelize(ParallelType::TIDx);
1246 tv2->axis(2)->parallelize(ParallelType::TIDy);
1247 tv3->axis(2)->parallelize(ParallelType::TIDy);
1248 tv1->axis(0)->parallelize(ParallelType::BIDx);
1249 tv2->axis(0)->parallelize(ParallelType::BIDx);
1250 tv3->axis(0)->parallelize(ParallelType::BIDx);
1251
1252 auto bank_conflict_info = fusion.bankConflictInfo();
1253
1254 // no bank confliction
1255 TORCH_CHECK(bank_conflict_info.empty());
1256}
1257
1258} // namespace jit
1259} // namespace torch
1260#endif // #if defined(USE_CUDA)
1261