1 | #if defined(USE_CUDA) |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <arith.h> |
5 | #include <codegen.h> |
6 | #include <disjoint_set.h> |
7 | #include <executor.h> |
8 | #include <executor_launch_params.h> |
9 | #include <expr_evaluator.h> |
10 | #include <fusion.h> |
11 | #include <fusion_segmenter.h> |
12 | #include <grouped_reduction.h> |
13 | #include <inlining.h> |
14 | #include <ir_all_nodes.h> |
15 | #include <ir_builder.h> |
16 | #include <ir_graphviz.h> |
17 | #include <ir_iostream.h> |
18 | #include <ir_utils.h> |
19 | #include <iter_visitor.h> |
20 | #include <kernel_cache.h> |
21 | #include <kernel_expr_evaluator.h> |
22 | #include <kernel_ir.h> |
23 | #include <lower2device.h> |
24 | #include <mutator.h> |
25 | #include <register_interface.h> |
26 | #include <root_domain_map.h> |
27 | #include <scheduler/all_schedulers.h> |
28 | #include <scheduler/reduction_utils.h> |
29 | #include <scheduler/utils.h> |
30 | #include <test/test_gpu_validator.h> |
31 | #include <test/test_utils.h> |
32 | #include <transform_replay.h> |
33 | #include <transform_rfactor.h> |
34 | |
35 | // fuser and IR parser |
36 | #include <ATen/cuda/CUDAContext.h> |
37 | #include <ATen/cuda/Exceptions.h> |
38 | #include <c10/cuda/CUDAStream.h> |
39 | |
40 | #include <algorithm> |
41 | #include <iostream> |
42 | |
43 | // Tests go in torch::jit |
44 | namespace torch { |
45 | namespace jit { |
46 | |
47 | using namespace torch::jit::fuser::cuda; |
48 | using namespace at::indexing; |
49 | |
50 | namespace { |
51 | |
52 | class KernelExprVisitor : private kir::IrVisitor { |
53 | public: |
54 | static std::vector<Expr*> getAllExprs(const kir::Kernel* kernel) { |
55 | KernelExprVisitor visitor(kernel); |
56 | return visitor.all_exprs_; |
57 | } |
58 | |
59 | private: |
60 | KernelExprVisitor(const kir::Kernel* kernel) { |
61 | handle(kernel->topLevelExprs()); |
62 | } |
63 | |
64 | using kir::IrVisitor::handle; |
65 | |
66 | void handle(Expr* expr) final { |
67 | all_exprs_.push_back(expr); |
68 | kir::IrVisitor::handle(expr); |
69 | } |
70 | |
71 | private: |
72 | std::vector<Expr*> all_exprs_; |
73 | }; |
74 | |
75 | void validateNoParallelBroadcastExist(kir::Kernel* kernel) { |
76 | for (auto expr : KernelExprVisitor::getAllExprs(kernel)) { |
77 | BroadcastOp* bc = dynamic_cast<BroadcastOp*>(expr); |
78 | if (bc == nullptr) { |
79 | auto grid_bc = dynamic_cast<kir::GridBroadcast*>(expr); |
80 | if (grid_bc != nullptr) { |
81 | std::cerr << "Grid broadcast: " << grid_bc->toString(); |
82 | bc = grid_bc->broadcast_op(); |
83 | } |
84 | } |
85 | if (bc == nullptr) { |
86 | continue; |
87 | } |
88 | TORCH_CHECK( |
89 | kernel->summary().broadcast_parallel_types.at(bc).none(), |
90 | "Parallel broadcast should not exist but was found: " , |
91 | bc->toString()); |
92 | } |
93 | } |
94 | |
95 | } // namespace |
96 | |
97 | TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { |
98 | const int nx = 999; |
99 | const int tidx = 128; |
100 | const int bidx = 4; |
101 | |
102 | if (ceilDiv(nx, tidx) > deviceSMCount()) { |
103 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
104 | } |
105 | |
106 | Fusion fusion; |
107 | FusionGuard fg(&fusion); |
108 | |
109 | auto tv0 = makeSymbolicTensor(1); |
110 | fusion.addInput(tv0); |
111 | |
112 | auto tv1 = sum(tv0, {0}); |
113 | auto tv2 = broadcast(tv1, {true}); |
114 | auto tv3 = add(tv0, tv2); |
115 | |
116 | fusion.addOutput(tv3); |
117 | |
118 | tv3->split(0, tidx); |
119 | tv3->split(0, bidx); |
120 | tv3->split(0, 1); // unswitch |
121 | TransformPropagator propagator(tv3); |
122 | MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); |
123 | |
124 | tv3->axis(0)->parallelize(ParallelType::BIDy); |
125 | tv3->axis(2)->parallelize(ParallelType::BIDx); |
126 | tv3->axis(3)->parallelize(ParallelType::TIDx); |
127 | scheduler_utils::parallelizeAllLike(tv3); |
128 | |
129 | // Just to make sure fused_reduction and work buffers are allocated |
130 | // uniquely |
131 | tv1->axis(1)->parallelize(ParallelType::Unswitch); |
132 | |
133 | GpuLower gpulw(&fusion); |
134 | validateNoParallelBroadcastExist(gpulw.kernel()); |
135 | |
136 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
137 | at::manual_seed(0); |
138 | auto t0 = at::randn({nx}, options); |
139 | |
140 | FusionExecutor fe; |
141 | fe.compileFusion(&fusion, {t0}); |
142 | auto cg_outputs = fe.runFusion({t0}); |
143 | |
144 | auto ref = sum(t0).unsqueeze(0) + t0; |
145 | |
146 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
147 | } |
148 | |
149 | TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) { |
150 | const int nx = 99; |
151 | const int tidx = 32; |
152 | |
153 | if (ceilDiv(nx, tidx) > deviceSMCount()) { |
154 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
155 | } |
156 | |
157 | Fusion fusion; |
158 | FusionGuard fg(&fusion); |
159 | |
160 | auto tv0 = makeSymbolicTensor(1); |
161 | fusion.addInput(tv0); |
162 | |
163 | auto tv1 = sum(tv0, {0}); |
164 | auto tv2 = broadcast(tv1, {true}); |
165 | auto tv3 = add(tv0, tv2); |
166 | |
167 | fusion.addOutput(tv3); |
168 | |
169 | tv3->split(0, tidx); |
170 | TransformPropagator propagator(tv3); |
171 | MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); |
172 | |
173 | tv3->axis(0)->parallelize(ParallelType::BIDx); |
174 | tv3->axis(1)->parallelize(ParallelType::TIDx); |
175 | scheduler_utils::parallelizeAllLike(tv3, {tv2}); |
176 | |
177 | // Broadcast on TIDy instead of TIDx. This still uses the fused |
178 | // reduction as it's broadcast on BIDx as well. Since TIDy is not |
179 | // predicated, the broadcast becomes a set op. |
180 | tv1->axis(0)->parallelize(ParallelType::BIDx); |
181 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
182 | |
183 | GpuLower gpulw(&fusion); |
184 | validateNoParallelBroadcastExist(gpulw.kernel()); |
185 | |
186 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
187 | at::manual_seed(0); |
188 | auto t0 = at::randn({nx}, options); |
189 | |
190 | FusionExecutor fe; |
191 | fe.compileFusion(&fusion, {t0}); |
192 | auto cg_outputs = fe.runFusion({t0}); |
193 | |
194 | auto ref = sum(t0).unsqueeze(0) + t0; |
195 | |
196 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
197 | } |
198 | |
199 | // Grid reduction with serial non-reduction axis. The global work |
200 | // buffer is double buffered. |
201 | TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) { |
202 | const int nx = 100; |
203 | const int ny = 5000; |
204 | const int tidx = 128; |
205 | |
206 | if (ceilDiv(ny, tidx) > deviceSMCount()) { |
207 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
208 | } |
209 | |
210 | Fusion fusion; |
211 | FusionGuard fg(&fusion); |
212 | |
213 | auto tv0 = makeSymbolicTensor(2); |
214 | fusion.addInput(tv0); |
215 | |
216 | auto tv1 = sum(tv0, {1}); |
217 | auto tv2 = broadcast(tv1, {false, true}); |
218 | auto tv3 = add(tv0, tv2); |
219 | |
220 | fusion.addOutput(tv3); |
221 | |
222 | tv3->split(1, tidx); |
223 | TransformPropagator propagator(tv3); |
224 | MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); |
225 | |
226 | tv0->computeAt(tv3, 1); |
227 | |
228 | tv3->axis(1)->parallelize(ParallelType::BIDx); |
229 | tv3->axis(2)->parallelize(ParallelType::TIDx); |
230 | scheduler_utils::parallelizeAllLike(tv3); |
231 | |
232 | GpuLower gpulw(&fusion); |
233 | validateNoParallelBroadcastExist(gpulw.kernel()); |
234 | |
235 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
236 | at::manual_seed(0); |
237 | auto t0 = at::randn({nx, ny}, options); |
238 | |
239 | FusionExecutor fe; |
240 | fe.compileFusion(&fusion, {t0}); |
241 | auto cg_outputs = fe.runFusion({t0}); |
242 | |
243 | auto ref = sum(t0, {1}).unsqueeze(-1) + t0; |
244 | |
245 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
246 | } |
247 | |
248 | // Indirect reduction and broadcast |
249 | TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) { |
250 | const int nx = 999; |
251 | const int tidx = 128; |
252 | |
253 | if (ceilDiv(nx, tidx) > deviceSMCount()) { |
254 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
255 | } |
256 | |
257 | Fusion fusion; |
258 | FusionGuard fg(&fusion); |
259 | |
260 | auto tv0 = makeSymbolicTensor(1); |
261 | fusion.addInput(tv0); |
262 | |
263 | auto tv1 = sum(tv0, {0}); |
264 | auto tv2 = add(tv1, IrBuilder::create<Double>(1)); |
265 | auto tv3 = broadcast(tv2, {true}); |
266 | auto tv4 = add(tv0, tv3); |
267 | |
268 | fusion.addOutput(tv4); |
269 | |
270 | tv4->split(0, tidx); |
271 | TransformPropagator propagator(tv4); |
272 | MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); |
273 | |
274 | tv4->axis(0)->parallelize(ParallelType::BIDx); |
275 | tv4->axis(1)->parallelize(ParallelType::TIDx); |
276 | scheduler_utils::parallelizeAllLike(tv4); |
277 | |
278 | GpuLower gpulw(&fusion); |
279 | validateNoParallelBroadcastExist(gpulw.kernel()); |
280 | |
281 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
282 | at::manual_seed(0); |
283 | auto t0 = at::randn({nx}, options); |
284 | |
285 | FusionExecutor fe; |
286 | fe.compileFusion(&fusion, {t0}); |
287 | auto cg_outputs = fe.runFusion({t0}); |
288 | |
289 | auto ref = (sum(t0) + 1).unsqueeze(0) + t0; |
290 | |
291 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
292 | } |
293 | |
294 | // Unused block dimension in the kernel |
295 | TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) { |
296 | const int nx = 999; |
297 | const int tidx = 128; |
298 | const int iter = 2; |
299 | const int bdimx = 9; // One more than required by the reduction |
300 | const int bdimy = 3; // Want an unused dimension |
301 | |
302 | // Going to bump the bdimx count for this test, ignor |
303 | if (bdimx * bdimy > deviceSMCount()) { |
304 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
305 | } |
306 | |
307 | Fusion fusion; |
308 | FusionGuard fg(&fusion); |
309 | |
310 | // Didn't setup this test with inlining for register usage, so just leave the |
311 | // iter dimension concrete |
312 | auto tv0 = makeConcreteTensor({iter, -1}); |
313 | fusion.addInput(tv0); |
314 | |
315 | auto tv1 = sum(tv0, {1}); |
316 | auto tv2 = add(tv1, IrBuilder::create<Double>(1)); |
317 | auto tv3 = broadcast(tv2, {false, true}); |
318 | auto tv4 = add(tv0, tv3); |
319 | |
320 | fusion.addOutput(tv4); |
321 | |
322 | // Dummy op to mess with parallelization |
323 | auto tv5 = makeSymbolicTensor(2); |
324 | fusion.addInput(tv5); |
325 | auto tv6 = set(tv5); |
326 | fusion.addOutput(tv6); |
327 | |
328 | // Setup the reduction |
329 | tv4->split(1, tidx); |
330 | TransformPropagator propagator(tv4); |
331 | MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); |
332 | |
333 | tv4->axis(1)->parallelize(ParallelType::BIDx); |
334 | tv4->axis(2)->parallelize(ParallelType::TIDx); |
335 | scheduler_utils::parallelizeAllLike(tv4); |
336 | |
337 | tv6->axis(0)->parallelize(ParallelType::BIDy); |
338 | tv6->axis(1)->parallelize(ParallelType::BIDx); |
339 | |
340 | GpuLower gpulw(&fusion); |
341 | validateNoParallelBroadcastExist(gpulw.kernel()); |
342 | |
343 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
344 | at::manual_seed(0); |
345 | auto t0 = at::randn({iter, nx}, options); |
346 | auto t5 = at::randn({bdimy, bdimx}, options); |
347 | |
348 | FusionExecutor fe; |
349 | fe.compileFusion(&fusion, {t0, t5}); |
350 | auto cg_outputs = fe.runFusion({t0, t5}); |
351 | |
352 | auto ref = (sum(t0, {1}) + 1).unsqueeze(-1) + t0; |
353 | |
354 | testValidate(&fusion, cg_outputs, {t0, t5}, {ref, t5}, __LINE__, __FILE__); |
355 | } |
356 | |
357 | TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) { |
358 | Fusion fusion; |
359 | FusionGuard fg(&fusion); |
360 | |
361 | std::vector<int64_t> shape({99, 200}); |
362 | |
363 | const int vec = 4; |
364 | const int tidx = 32; |
365 | const int tidy = 8; |
366 | const int bdimx = ceilDiv(shape[1], vec * tidx); |
367 | const int bdimy = ceilDiv(shape[0], tidy); |
368 | |
369 | if (bdimx * bdimy > deviceSMCount()) { |
370 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
371 | } |
372 | |
373 | auto tv0 = makeSymbolicTensor(2); |
374 | fusion.addInput(tv0); |
375 | |
376 | auto tv1 = set(tv0); |
377 | auto tv2 = sum(tv1, {0}); |
378 | auto tv3 = broadcast(tv2, {true, false}); |
379 | auto tv4 = add(tv0, tv3); |
380 | fusion.addOutput(tv4); |
381 | |
382 | tv1->split(1, vec); |
383 | tv1->split(1, tidx); |
384 | tv1->split(0, tidy); |
385 | TransformPropagator propagator(tv1); |
386 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
387 | |
388 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
389 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
390 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
391 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
392 | |
393 | scheduler_utils::parallelizeAllLike(tv1); |
394 | |
395 | tv1->axis(4)->parallelize(ParallelType::Vectorize); |
396 | |
397 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
398 | at::manual_seed(0); |
399 | auto t0 = at::randn(shape, options); |
400 | |
401 | FusionExecutor fe; |
402 | fe.compileFusion(&fusion, {t0}); |
403 | auto outputs = fe.runFusion({t0}); |
404 | |
405 | auto t0_double = t0.to(at::kDouble); |
406 | auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); |
407 | |
408 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
409 | } |
410 | |
411 | TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) { |
412 | const int nx = 999; |
413 | const int tidx = 128; |
414 | |
415 | if (ceilDiv(nx, tidx) > deviceSMCount()) { |
416 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
417 | } |
418 | |
419 | Fusion fusion; |
420 | FusionGuard fg(&fusion); |
421 | |
422 | auto tv0 = makeSymbolicTensor(1); |
423 | fusion.addInput(tv0); |
424 | |
425 | auto tvs = Welford(tv0, {0}); |
426 | auto tv2 = broadcast(tvs.avg, {true}); |
427 | auto tv3 = broadcast(tvs.var_sum, {true}); |
428 | auto tv4 = add(tv0, tv2); |
429 | auto tv5 = add(tv4, tv3); |
430 | |
431 | fusion.addOutput(tv5); |
432 | |
433 | tv5->split(0, tidx); |
434 | TransformPropagator propagator(tv5); |
435 | MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); |
436 | |
437 | tv5->axis(0)->parallelize(ParallelType::BIDx); |
438 | tv5->axis(1)->parallelize(ParallelType::TIDx); |
439 | scheduler_utils::parallelizeAllLike(tv5); |
440 | |
441 | GpuLower gpulw(&fusion); |
442 | validateNoParallelBroadcastExist(gpulw.kernel()); |
443 | |
444 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
445 | at::manual_seed(0); |
446 | auto t0 = at::randn({nx}, options); |
447 | |
448 | FusionExecutor fe; |
449 | fe.compileFusion(&fusion, {t0}); |
450 | auto cg_outputs = fe.runFusion({t0}); |
451 | |
452 | auto ref = |
453 | (t0.mean({0}).unsqueeze(0) + t0) + t0.var({0}, false).unsqueeze(0) * nx; |
454 | |
455 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
456 | } |
457 | |
458 | // Grid welford reduction with serial non-reduction axis. The global |
459 | // work buffer is double buffered. |
460 | TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) { |
461 | const int nx = 100; |
462 | const int ny = 5000; |
463 | const int tidx = 128; |
464 | |
465 | if (ceilDiv(ny, tidx) > deviceSMCount()) { |
466 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
467 | } |
468 | |
469 | Fusion fusion; |
470 | FusionGuard fg(&fusion); |
471 | |
472 | auto tv0 = makeSymbolicTensor(2); |
473 | fusion.addInput(tv0); |
474 | |
475 | auto tvs = Welford(tv0, {1}); |
476 | auto tv2 = broadcast(tvs.avg, {false, true}); |
477 | auto tv3 = add(tv0, tv2); |
478 | |
479 | fusion.addOutput(tv3); |
480 | |
481 | tv3->split(1, tidx); |
482 | TransformPropagator propagator(tv3); |
483 | MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); |
484 | |
485 | tv0->computeAt(tv3, 1); |
486 | |
487 | tv3->axis(1)->parallelize(ParallelType::BIDx); |
488 | tv3->axis(2)->parallelize(ParallelType::TIDx); |
489 | scheduler_utils::parallelizeAllLike(tv3); |
490 | |
491 | // There must be no parallel broadcast |
492 | GpuLower gpulw(&fusion); |
493 | validateNoParallelBroadcastExist(gpulw.kernel()); |
494 | |
495 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
496 | at::manual_seed(0); |
497 | auto t0 = at::randn({nx, ny}, options); |
498 | |
499 | FusionExecutor fe; |
500 | fe.compileFusion(&fusion, {t0}); |
501 | auto cg_outputs = fe.runFusion({t0}); |
502 | |
503 | auto ref = (sum(t0, {1}) / ny).unsqueeze(-1) + t0; |
504 | |
505 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); |
506 | } |
507 | |
508 | // Persistent batchnorm. Uses the fused reduction for grid welford and |
509 | // broadcast. |
510 | TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) { |
511 | const std::vector<int64_t> input_shape{256, 2048, 14, 14}; |
512 | |
513 | std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>(); |
514 | Fusion& fusion = *fusion_ptr.get(); |
515 | FusionGuard fg(&fusion); |
516 | |
517 | auto tv0 = makeSymbolicTensor(4, DataType::Half); |
518 | fusion.addInput(tv0); |
519 | auto tv1 = makeSymbolicTensor(1, DataType::Half); |
520 | fusion.addInput(tv1); |
521 | auto tv2 = makeSymbolicTensor(1, DataType::Half); |
522 | fusion.addInput(tv2); |
523 | auto tv3 = makeSymbolicTensor(1, DataType::Float); |
524 | fusion.addInput(tv3); |
525 | auto tv4 = makeSymbolicTensor(1, DataType::Float); |
526 | fusion.addInput(tv4); |
527 | |
528 | auto d34 = IrBuilder::create<Double>(1); |
529 | auto tv5 = castOp(DataType::Float, tv0); |
530 | auto tv6 = castOp(DataType::Float, tv1); |
531 | auto tv7 = castOp(DataType::Float, tv2); |
532 | auto tvs = Welford(tv5, {0, 2, 3}); |
533 | auto tv8 = tvs.avg; |
534 | auto tv9 = tvs.var_sum; |
535 | auto tv10 = tvs.n; |
536 | auto tv11 = mul(tv8, IrBuilder::create<Double>(0.1)); |
537 | auto tv12 = mul(tv3, d34); |
538 | auto tv13 = add(tv12, tv11); |
539 | auto d43 = IrBuilder::create<Double>(0.5); |
540 | auto tv14 = mul(tv9, d43); |
541 | auto tv15 = mul(tv14, IrBuilder::create<Double>(0.1)); |
542 | auto tv16 = mul(tv4, d34); |
543 | auto tv17 = add(tv16, tv15); |
544 | auto tv18 = broadcast(tv8, {true, false, true, true}); |
545 | auto tv19 = sub(tv5, tv18); |
546 | auto tv20 = mul(tv9, d43); |
547 | auto tv21 = add(tv20, IrBuilder::create<Double>(0.0001)); |
548 | auto tv22 = rsqrt(tv21); |
549 | auto tv23 = broadcast(tv22, {true, false, true, true}); |
550 | auto tv24 = mul(tv19, tv23); |
551 | auto tv25 = broadcast(tv6, {true, false, true, true}); |
552 | auto tv26 = mul(tv24, tv25); |
553 | auto tv27 = broadcast(tv7, {true, false, true, true}); |
554 | auto tv28 = add(tv26, tv27); |
555 | auto tv29 = castOp(DataType::Half, tv28); |
556 | fusion.addOutput(tv13); |
557 | fusion.addOutput(tv17); |
558 | fusion.addOutput(tv29); |
559 | |
560 | auto tv0_cache = tv0->cacheAfter(); |
561 | auto tv1_cache = tv1->cacheAfter(); |
562 | auto tv2_cache = tv2->cacheAfter(); |
563 | auto tv3_cache = tv3->cacheAfter(); |
564 | auto tv4_cache = tv4->cacheAfter(); |
565 | |
566 | auto tv13_cache = tv13->cacheBefore(); |
567 | auto tv17_cache = tv17->cacheBefore(); |
568 | auto tv29_cache = tv29->cacheBefore(); |
569 | |
570 | tv0->split(1, NamedScalar::getParallelDim(ParallelType::BIDx), false); |
571 | tv0->split(0, NamedScalar::getParallelDim(ParallelType::BIDy), false); |
572 | tv0->split(1, 8, false); |
573 | tv0->split(2, 8, false); |
574 | tv0->merge(-2, -1); |
575 | tv0->split(-1, 2); |
576 | tv0->split(-2, 1, false); |
577 | tv0->split(-2, 1, false); |
578 | tv0->reorder( |
579 | {{4, 0}, |
580 | {5, 1}, |
581 | {0, 2}, |
582 | {3, 3}, |
583 | {8, 4}, |
584 | {1, 5}, |
585 | {7, 6}, |
586 | {2, 7}, |
587 | {9, 8}, |
588 | {6, 9}}); |
589 | |
590 | TransformPropagator propagator(tv0); |
591 | MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); |
592 | |
593 | ir_utils::rfactorHelper(tvs.avg, {-5, -4, -3, -2, -1}); |
594 | |
595 | tv0->computeAt(tv29, 2); |
596 | tv1->computeAt(tv29, 2); |
597 | tv2->computeAt(tv29, 2); |
598 | tv3->computeAt(tv13, 2); |
599 | tv4->computeAt(tv17, 2); |
600 | |
601 | tv29->axis(0)->parallelize(ParallelType::BIDx); |
602 | tv29->axis(2)->parallelize(ParallelType::BIDy); |
603 | tv29->axis(3)->parallelize(ParallelType::TIDz); |
604 | tv29->axis(4)->parallelize(ParallelType::TIDx); |
605 | scheduler_utils::parallelizeAllLike(tv29); |
606 | |
607 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
608 | auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
609 | at::manual_seed(0); |
610 | auto t0 = at::randn(input_shape, options_half); |
611 | auto t1 = at::randn(input_shape[1], options_half); |
612 | auto t2 = at::randn(input_shape[1], options_half); |
613 | auto t3 = at::randn(input_shape[1], options); |
614 | auto t4 = at::randn(input_shape[1], options); |
615 | std::vector<IValue> aten_inputs = {t0, t1, t2, t3, t4}; |
616 | |
617 | GpuLower gpulw(&fusion); |
618 | validateNoParallelBroadcastExist(gpulw.kernel()); |
619 | |
620 | FusionExecutor fe; |
621 | LaunchParams launch_params(2, 2, -1, -1, -1, -1); |
622 | fe.compileFusion(&fusion, aten_inputs, launch_params); |
623 | auto cg_outputs = fe.runFusion(aten_inputs, launch_params); |
624 | |
625 | auto t5 = t0.to(at::kFloat); |
626 | auto t6 = t1.to(at::kFloat); |
627 | auto t7 = t2.to(at::kFloat); |
628 | auto t8 = t5.mean({0, 2, 3}); |
629 | auto t9 = t5.var({0, 2, 3}, false) * input_shape[0] * input_shape[2] * |
630 | input_shape[3]; |
631 | auto t11 = t8 * 0.1; |
632 | auto t12 = t3 * 1; |
633 | auto t13 = t12 + t11; |
634 | auto t14 = t9 * 0.5; |
635 | auto t15 = t14 * 0.1; |
636 | auto t16 = t4 * 1; |
637 | auto t17 = t16 + t15; |
638 | auto t18 = t8.unsqueeze(0).unsqueeze(-1).unsqueeze(-1); |
639 | auto t19 = t5 - t18; |
640 | auto t20 = t9 * 0.5; |
641 | auto t21 = t20 + 0.0001; |
642 | auto t22 = rsqrt(t21); |
643 | auto t23 = t22.unsqueeze(0).unsqueeze(-1).unsqueeze(-1); |
644 | auto t24 = t19 * t23; |
645 | auto t25 = t6.unsqueeze(0).unsqueeze(-1).unsqueeze(-1); |
646 | auto t26 = t24 * t25; |
647 | auto t27 = t7.unsqueeze(0).unsqueeze(-1).unsqueeze(-1); |
648 | auto t28 = t26 + t27; |
649 | auto t29 = t28.to(at::kHalf); |
650 | |
651 | testValidate( |
652 | &fusion, |
653 | cg_outputs, |
654 | aten_inputs, |
655 | {t13, t17, t29}, |
656 | __LINE__, |
657 | __FILE__, |
658 | "" , |
659 | launch_params); |
660 | } |
661 | |
662 | // Simple grouped reduction |
663 | TEST_F(NVFuserTest, FusionGroupedReduction1_CUDA) { |
664 | Fusion fusion; |
665 | FusionGuard fg(&fusion); |
666 | |
667 | auto tv0 = makeSymbolicTensor(2); |
668 | fusion.addInput(tv0); |
669 | |
670 | auto tv1 = sum(tv0, {1}); |
671 | auto tv2 = sum(tv0, {1}); |
672 | auto tv3 = add(tv1, tv2); |
673 | fusion.addOutput(tv3); |
674 | |
675 | groupReductions({tv1, tv2}); |
676 | |
677 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
678 | tv2->axis(1)->parallelize(ParallelType::TIDx); |
679 | scheduler_utils::parallelizeAllLike(tv2); |
680 | |
681 | std::vector<int64_t> shape({99, 999}); |
682 | |
683 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
684 | |
685 | auto t0 = at::randn(shape, options); |
686 | |
687 | FusionExecutor fe; |
688 | fe.compileFusion(&fusion, {t0}); |
689 | auto outputs = fe.runFusion({t0}); |
690 | |
691 | auto ref = t0.sum({1}) * 2; |
692 | |
693 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
694 | } |
695 | |
696 | // Grouping reductions with different ops |
697 | TEST_F(NVFuserTest, FusionGroupedReduction2_CUDA) { |
698 | Fusion fusion; |
699 | FusionGuard fg(&fusion); |
700 | |
701 | auto tv0 = makeSymbolicTensor(2); |
702 | fusion.addInput(tv0); |
703 | |
704 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
705 | auto tv2 = sum(tv1, {1}); |
706 | |
707 | auto tv3 = add(tv0, IrBuilder::create<Double>(2)); |
708 | auto tv4 = max(tv3, {1}); |
709 | |
710 | auto tv5 = add(tv2, tv4); |
711 | fusion.addOutput(tv5); |
712 | |
713 | groupReductions({tv2, tv4}); |
714 | |
715 | tv2->split(1, 128); |
716 | TransformPropagator propagator(tv2); |
717 | MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); |
718 | |
719 | tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); |
720 | |
721 | // tv4 is automatically parallelized in the same way |
722 | tv2->axis(0)->parallelize(ParallelType::BIDy); |
723 | tv2->axis(1)->parallelize(ParallelType::BIDx); |
724 | tv2->axis(2)->parallelize(ParallelType::TIDx); |
725 | |
726 | std::vector<int64_t> shape({99, 999}); |
727 | |
728 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
729 | |
730 | auto t0 = at::randn(shape, options); |
731 | |
732 | FusionExecutor fe; |
733 | fe.compileFusion(&fusion, {t0}); |
734 | auto outputs = fe.runFusion({t0}); |
735 | |
736 | auto ref = (t0 + 1).sum({1}) + std::get<0>((t0 + 2).max(1)); |
737 | |
738 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
739 | } |
740 | |
741 | // Grouped reduction with different types |
742 | TEST_F(NVFuserTest, FusionGroupedReduction3_CUDA) { |
743 | Fusion fusion; |
744 | FusionGuard fg(&fusion); |
745 | |
746 | auto tv0 = makeSymbolicTensor(2); |
747 | fusion.addInput(tv0); |
748 | |
749 | auto tv1 = sum(tv0, {1}); |
750 | |
751 | auto tv2 = castOp(DataType::Double, tv0); |
752 | auto tv3 = sum(tv2, {1}); |
753 | auto tv4 = castOp(DataType::Float, tv3); |
754 | |
755 | auto tv5 = add(tv1, tv4); |
756 | fusion.addOutput(tv5); |
757 | |
758 | groupReductions({tv1, tv3}); |
759 | tv1->split(1, 128); |
760 | TransformPropagator propagator(tv1); |
761 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
762 | |
763 | tv0->computeAt(tv5, -1, ComputeAtMode::MostInlined); |
764 | |
765 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
766 | tv1->axis(1)->parallelize(ParallelType::BIDx); |
767 | tv1->axis(2)->parallelize(ParallelType::TIDx); |
768 | |
769 | std::vector<int64_t> shape({99, 999}); |
770 | |
771 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
772 | |
773 | auto t0 = at::randn(shape, options); |
774 | |
775 | FusionExecutor fe; |
776 | fe.compileFusion(&fusion, {t0}); |
777 | auto outputs = fe.runFusion({t0}); |
778 | |
779 | auto ref = t0.sum({1}) + t0.to(c10::kDouble).sum({1}).to(c10::kFloat); |
780 | |
781 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
782 | } |
783 | |
784 | // Testing validation |
785 | TEST_F(NVFuserTest, FusionGroupedReduction4_CUDA) { |
786 | Fusion fusion; |
787 | FusionGuard fg(&fusion); |
788 | |
789 | auto tv0 = makeSymbolicTensor(2); |
790 | fusion.addInput(tv0); |
791 | auto tv1 = makeSymbolicTensor(2); |
792 | fusion.addInput(tv1); |
793 | |
794 | auto tv2 = sum(tv0, {1}); |
795 | auto tv3 = sum(tv1, {1}); |
796 | auto tv4 = add(tv2, tv3); |
797 | fusion.addOutput(tv4); |
798 | |
799 | // Invalid grouping as tv2 and tv3 are not guaranteed to have the |
800 | // same shape |
801 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
802 | ASSERT_ANY_THROW(groupReductions({tv2, tv3})); |
803 | } |
804 | |
805 | // Testing validation |
806 | TEST_F(NVFuserTest, FusionGroupedReduction5_CUDA) { |
807 | Fusion fusion; |
808 | FusionGuard fg(&fusion); |
809 | |
810 | auto tv0 = makeSymbolicTensor(2); |
811 | fusion.addInput(tv0); |
812 | |
813 | auto tv1 = sum(tv0, {1}); |
814 | auto tv2 = sum(tv0, {1}); |
815 | auto tv3 = add(tv1, tv2); |
816 | fusion.addOutput(tv3); |
817 | |
818 | tv1->split(1, 128); |
819 | tv2->split(1, 64); |
820 | |
821 | // Invalid grouping as tv1 and tv2 don't have the same |
822 | // transformations |
823 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
824 | ASSERT_ANY_THROW(groupReductions({tv1, tv2})); |
825 | } |
826 | |
827 | // Grouping 3 reductions |
828 | TEST_F(NVFuserTest, FusionGroupedReduction6_CUDA) { |
829 | Fusion fusion; |
830 | FusionGuard fg(&fusion); |
831 | |
832 | auto tv0 = makeSymbolicTensor(2); |
833 | fusion.addInput(tv0); |
834 | |
835 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
836 | auto tv2 = sum(tv1, {1}); |
837 | |
838 | auto tv3 = add(tv0, IrBuilder::create<Double>(2)); |
839 | auto tv4 = sum(tv3, {1}); |
840 | |
841 | auto tv5 = add(tv0, IrBuilder::create<Double>(3)); |
842 | auto tv6 = sum(tv5, {1}); |
843 | |
844 | auto tv7 = add(add(tv2, tv4), tv6); |
845 | |
846 | fusion.addOutput(tv7); |
847 | |
848 | groupReductions({tv2, tv4, tv6}); |
849 | |
850 | // There's no runtime grid reduction function that can take more |
851 | // than 2 inputs, yet. |
852 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
853 | tv2->axis(1)->parallelize(ParallelType::TIDx); |
854 | |
855 | scheduler_utils::parallelizeAllLike(tv2); |
856 | |
857 | std::vector<int64_t> shape({99, 999}); |
858 | |
859 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
860 | |
861 | auto t0 = at::randn(shape, options); |
862 | |
863 | FusionExecutor fe; |
864 | fe.compileFusion(&fusion, {t0}); |
865 | auto outputs = fe.runFusion({t0}); |
866 | |
867 | auto ref = (t0 + 1).sum({1}) + (t0 + 2).sum({1}) + (t0 + 3).sum({1}); |
868 | |
869 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
870 | } |
871 | |
872 | TEST_F(NVFuserTest, FusionGroupedReduction7_CUDA) { |
873 | Fusion fusion; |
874 | FusionGuard fg(&fusion); |
875 | |
876 | auto tv0 = makeSymbolicTensor(2); |
877 | fusion.addInput(tv0); |
878 | |
879 | auto tv1 = sum(tv0, {1}); |
880 | auto tv2 = broadcast(tv1, {false, true}); |
881 | auto tv3 = add(tv0, tv2); |
882 | auto tv4 = sum(tv3, {1}); |
883 | fusion.addOutput(tv4); |
884 | |
885 | // Invalid grouping as tv3 depends on tv1 |
886 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
887 | ASSERT_ANY_THROW(groupReductions({tv1, tv4})); |
888 | } |
889 | |
890 | // Grouping rfactor'ed reductions |
891 | TEST_F(NVFuserTest, FusionGroupedReductionRfactor1_CUDA) { |
892 | Fusion fusion; |
893 | FusionGuard fg(&fusion); |
894 | |
895 | auto tv0 = makeSymbolicTensor(1); |
896 | fusion.addInput(tv0); |
897 | |
898 | auto tv1 = sum(tv0, {0}); |
899 | auto tv2 = sum(tv0, {0}); |
900 | auto tv3 = add(tv1, tv2); |
901 | fusion.addOutput(tv3); |
902 | |
903 | const size_t gdimx = 10; |
904 | const size_t bdimx = 128; |
905 | |
906 | tv1->split(0, gdimx, false); |
907 | tv1->split(1, bdimx); |
908 | auto tv1_rf = tv1->rFactor({1}); |
909 | |
910 | tv2->split(0, gdimx, false); |
911 | tv2->split(1, bdimx); |
912 | auto tv2_rf = tv2->rFactor({1}); |
913 | |
914 | groupReductions({tv1_rf, tv2_rf}); |
915 | groupReductions({tv1, tv2}); |
916 | |
917 | tv1_rf->axis(0)->parallelize(ParallelType::BIDx); |
918 | tv1_rf->axis(2)->parallelize(ParallelType::TIDx); |
919 | |
920 | scheduler_utils::parallelizeAllLike(tv1_rf); |
921 | |
922 | std::vector<int64_t> shape({12345}); |
923 | |
924 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
925 | |
926 | auto t0 = at::randn(shape, options); |
927 | |
928 | FusionExecutor fe; |
929 | fe.compileFusion(&fusion, {t0}); |
930 | auto outputs = fe.runFusion({t0}); |
931 | |
932 | auto ref = t0.sum({0}) * 2; |
933 | |
934 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
935 | } |
936 | |
937 | // Rfactoring grouped reductions |
938 | TEST_F(NVFuserTest, FusionGroupedReductionRfactor2_CUDA) { |
939 | Fusion fusion; |
940 | FusionGuard fg(&fusion); |
941 | |
942 | auto tv0 = makeSymbolicTensor(1); |
943 | fusion.addInput(tv0); |
944 | |
945 | auto tv1 = sum(tv0, {0}); |
946 | auto tv2 = sum(tv0, {0}); |
947 | auto tv3 = add(tv1, tv2); |
948 | fusion.addOutput(tv3); |
949 | |
950 | groupReductions({tv1, tv2}); |
951 | |
952 | const size_t gdimx = 10; |
953 | const size_t bdimx = 128; |
954 | |
955 | tv1->split(0, gdimx, false); |
956 | tv1->split(1, bdimx); |
957 | |
958 | // This should rfactor tv2 as well |
959 | auto rf_tvs = tv1->rFactor({1}, {tv1, tv2}); |
960 | auto tv1_rf = rf_tvs.at(0); |
961 | |
962 | tv1_rf->axis(0)->parallelize(ParallelType::BIDx); |
963 | tv1_rf->axis(2)->parallelize(ParallelType::TIDx); |
964 | |
965 | scheduler_utils::parallelizeAllLike(tv1_rf); |
966 | |
967 | std::vector<int64_t> shape({12345}); |
968 | |
969 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
970 | |
971 | auto t0 = at::randn(shape, options); |
972 | |
973 | FusionExecutor fe; |
974 | fe.compileFusion(&fusion, {t0}); |
975 | auto outputs = fe.runFusion({t0}); |
976 | |
977 | auto ref = t0.sum({0}) * 2; |
978 | |
979 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
980 | } |
981 | |
982 | // Group reductions of tensors that have computeAt positions set |
983 | TEST_F(NVFuserTest, FusionGroupedReductionAfterComputeAt_CUDA) { |
984 | Fusion fusion; |
985 | FusionGuard fg(&fusion); |
986 | auto tv0 = makeSymbolicTensor(2); |
987 | fusion.addInput(tv0); |
988 | |
989 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
990 | auto tv2 = sum(tv1, {1}); |
991 | auto tv3 = sum(tv1, {1}); |
992 | auto tv4 = add(tv2, tv3); |
993 | fusion.addOutput(tv4); |
994 | |
995 | const size_t bdimx = 128; |
996 | |
997 | tv2->split(1, bdimx); |
998 | auto tv2_rf = tv2->rFactor({1}); |
999 | tv2_rf->reorder({{1, 2}}); |
1000 | |
1001 | tv3->split(1, bdimx); |
1002 | auto tv3_rf = tv3->rFactor({1}); |
1003 | tv3_rf->reorder({{1, 2}}); |
1004 | |
1005 | tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); |
1006 | |
1007 | groupReductions({tv2_rf, tv3_rf}); |
1008 | groupReductions({tv2, tv3}); |
1009 | |
1010 | tv2->axis(1)->parallelize(ParallelType::TIDx); |
1011 | scheduler_utils::parallelizeAllLike(tv2); |
1012 | |
1013 | std::vector<int64_t> shape({3, 1234}); |
1014 | |
1015 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1016 | |
1017 | auto t0 = at::randn(shape, options); |
1018 | |
1019 | FusionExecutor fe; |
1020 | fe.compileFusion(&fusion, {t0}); |
1021 | auto outputs = fe.runFusion({t0}); |
1022 | |
1023 | auto ref = (t0 + 1).sum({1}) * 2; |
1024 | |
1025 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1026 | } |
1027 | |
1028 | TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) { |
1029 | Fusion fusion; |
1030 | FusionGuard fg(&fusion); |
1031 | |
1032 | auto tv0 = makeSymbolicTensor(1); |
1033 | fusion.addInput(tv0); |
1034 | |
1035 | auto tv1 = sum(tv0, {0}); |
1036 | auto tv2 = broadcast(tv1, {true}); |
1037 | auto tv3 = sum(tv0, {0}); |
1038 | auto tv4 = broadcast(tv3, {true}); |
1039 | auto tv5 = add(tv0, tv2); |
1040 | auto tv6 = add(tv5, tv4); |
1041 | fusion.addOutput(tv6); |
1042 | |
1043 | groupReductions({tv1, tv3}); |
1044 | |
1045 | tv2->split(0, 128); |
1046 | TransformPropagator propagator(tv2); |
1047 | MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); |
1048 | |
1049 | tv2->axis(0)->parallelize(ParallelType::BIDx); |
1050 | tv2->axis(1)->parallelize(ParallelType::TIDx); |
1051 | scheduler_utils::parallelizeAllLike(tv2); |
1052 | |
1053 | std::vector<int64_t> shape({999}); |
1054 | |
1055 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1056 | |
1057 | auto t0 = at::randn(shape, options); |
1058 | |
1059 | FusionExecutor fe; |
1060 | fe.compileFusion(&fusion, {t0}); |
1061 | auto outputs = fe.runFusion({t0}); |
1062 | |
1063 | auto t3 = t0.sum({0}).unsqueeze(-1); |
1064 | auto ref = t0 + t3 + t3; |
1065 | |
1066 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1067 | } |
1068 | |
1069 | // Grid reductionso of different types |
1070 | TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) { |
1071 | Fusion fusion; |
1072 | FusionGuard fg(&fusion); |
1073 | |
1074 | auto tv0 = makeSymbolicTensor(2); |
1075 | fusion.addInput(tv0); |
1076 | |
1077 | auto tv1 = sum(tv0, {1}); |
1078 | auto tv2 = broadcast(tv1, {false, true}); |
1079 | |
1080 | auto tv3 = castOp(DataType::Double, tv0); |
1081 | auto tv4 = sum(tv3, {1}); |
1082 | auto tv5 = broadcast(tv4, {false, true}); |
1083 | auto tv6 = castOp(DataType::Float, tv5); |
1084 | |
1085 | auto tv7 = add(tv0, tv2); |
1086 | auto tv8 = add(tv7, tv6); |
1087 | fusion.addOutput(tv8); |
1088 | |
1089 | const int tidx = 512; |
1090 | groupReductions({tv1, tv4}); |
1091 | tv1->split(1, tidx); |
1092 | TransformPropagator propagator(tv1); |
1093 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
1094 | |
1095 | tv0->computeAt(tv8, -1, ComputeAtMode::MostInlined); |
1096 | |
1097 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
1098 | tv1->axis(1)->parallelize(ParallelType::BIDx); |
1099 | tv1->axis(2)->parallelize(ParallelType::TIDx); |
1100 | scheduler_utils::parallelizeAllLike(tv1); |
1101 | |
1102 | std::vector<int64_t> shape({10, 999}); |
1103 | |
1104 | if (shape.at(0) * ceilDiv(shape.at(1), tidx) > deviceSMCount()) { |
1105 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
1106 | } |
1107 | |
1108 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1109 | |
1110 | auto t0 = at::randn(shape, options); |
1111 | |
1112 | FusionExecutor fe; |
1113 | fe.compileFusion(&fusion, {t0}); |
1114 | auto outputs = fe.runFusion({t0}); |
1115 | |
1116 | auto t2 = t0.sum({1}).unsqueeze(-1); |
1117 | auto t6 = t0.to(c10::kDouble).sum({1}).unsqueeze(-1).to(c10::kFloat); |
1118 | auto ref = t0 + t2 + t6; |
1119 | |
1120 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1121 | } |
1122 | |
1123 | // Grouping 3 grid allreduces |
1124 | TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) { |
1125 | Fusion fusion; |
1126 | FusionGuard fg(&fusion); |
1127 | |
1128 | auto tv0 = makeSymbolicTensor(1); |
1129 | fusion.addInput(tv0); |
1130 | |
1131 | auto tv1 = sum(tv0, {0}); |
1132 | auto tv2 = broadcast(tv1, {true}); |
1133 | auto tv3 = div(tv0, tv2); |
1134 | auto tv4 = max(tv0, {0}); |
1135 | auto tv5 = broadcast(tv4, {true}); |
1136 | auto tv6 = div(tv0, tv5); |
1137 | auto tv7 = min(tv0, {0}); |
1138 | auto tv8 = broadcast(tv7, {true}); |
1139 | auto tv9 = sub(tv0, tv8); |
1140 | fusion.addOutput(tv3); |
1141 | fusion.addOutput(tv6); |
1142 | fusion.addOutput(tv9); |
1143 | |
1144 | groupReductions({tv1, tv4, tv7}); |
1145 | |
1146 | tv1->split(0, 128); |
1147 | TransformPropagator propagator(tv1); |
1148 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
1149 | |
1150 | tv1->axis(0)->parallelize(ParallelType::BIDx); |
1151 | tv1->axis(1)->parallelize(ParallelType::TIDx); |
1152 | scheduler_utils::parallelizeAllLike(tv1); |
1153 | |
1154 | std::vector<int64_t> shape({999}); |
1155 | |
1156 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1157 | |
1158 | auto t0 = at::randn(shape, options); |
1159 | |
1160 | FusionExecutor fe; |
1161 | fe.compileFusion(&fusion, {t0}); |
1162 | auto outputs = fe.runFusion({t0}); |
1163 | |
1164 | auto t3 = t0 / t0.sum({0}).unsqueeze(0); |
1165 | auto t6 = t0 / std::get<0>(t0.max(0)).unsqueeze(0); |
1166 | auto t9 = t0 - std::get<0>(t0.min(0)).unsqueeze(0); |
1167 | |
1168 | testValidate(fe.kernel(), outputs, {t0}, {t3, t6, t9}, __LINE__, __FILE__); |
1169 | } |
1170 | |
1171 | // Grouping 8 grid allreduces |
1172 | TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) { |
1173 | Fusion fusion; |
1174 | FusionGuard fg(&fusion); |
1175 | |
1176 | const int num_reductions = 8; |
1177 | |
1178 | auto tv0 = makeSymbolicTensor(1); |
1179 | fusion.addInput(tv0); |
1180 | |
1181 | auto tv_sum = tv0; |
1182 | std::vector<TensorView*> reduction_tvs; |
1183 | |
1184 | for (int i = 0; i < num_reductions; ++i) { |
1185 | auto reduction = sum(add(tv0, IrBuilder::create<Double>(i)), {0}); |
1186 | reduction_tvs.push_back(reduction); |
1187 | auto avg = div(reduction, tv0->axis(0)->extent()); |
1188 | auto bc = broadcast(avg, {true}); |
1189 | tv_sum = add(tv_sum, bc); |
1190 | } |
1191 | |
1192 | fusion.addOutput(tv_sum); |
1193 | |
1194 | groupReductions(reduction_tvs); |
1195 | |
1196 | auto reduction_tv = reduction_tvs.at(0); |
1197 | |
1198 | reduction_tv->split(0, 128); |
1199 | TransformPropagator propagator(reduction_tv); |
1200 | MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); |
1201 | |
1202 | reduction_tv->axis(0)->parallelize(ParallelType::BIDx); |
1203 | reduction_tv->axis(1)->parallelize(ParallelType::TIDx); |
1204 | scheduler_utils::parallelizeAllLike(reduction_tv); |
1205 | |
1206 | std::vector<int64_t> shape({999}); |
1207 | |
1208 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1209 | |
1210 | auto t0 = at::randn(shape, options); |
1211 | |
1212 | FusionExecutor fe; |
1213 | fe.compileFusion(&fusion, {t0}); |
1214 | auto outputs = fe.runFusion({t0}); |
1215 | |
1216 | at::Tensor ref = t0; |
1217 | for (int i = 0; i < num_reductions; ++i) { |
1218 | auto reduction = sum(add(t0, i), {0}); |
1219 | auto avg = reduction / t0.sizes()[0]; |
1220 | auto bc = avg.unsqueeze(0); |
1221 | ref = add(ref, bc); |
1222 | } |
1223 | |
1224 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1225 | } |
1226 | |
1227 | // Variation of FusionGroupAllreduce5_CUDA but with different |
1228 | // types. Exercise grouped allreduces with different types. |
1229 | TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) { |
1230 | Fusion fusion; |
1231 | FusionGuard fg(&fusion); |
1232 | |
1233 | auto tv0 = makeSymbolicTensor(1, DataType::Float); |
1234 | fusion.addInput(tv0); |
1235 | auto tv1 = sum(tv0, {0}); |
1236 | auto tv2 = broadcast(tv1, {true}); |
1237 | auto tv3 = div(tv0, tv2); |
1238 | |
1239 | auto tv4 = makeSymbolicTensor(1, DataType::Double); |
1240 | fusion.addInput(tv4); |
1241 | auto tv5 = sum(tv4, {0}); |
1242 | auto tv6 = broadcast(tv5, {true}); |
1243 | auto tv7 = div(tv4, tv6); |
1244 | |
1245 | auto tv8 = makeSymbolicTensor(1, DataType::Int); |
1246 | fusion.addInput(tv8); |
1247 | auto tv9 = sum(tv8, {0}); |
1248 | auto tv10 = broadcast(tv9, {true}); |
1249 | auto tv11 = div(tv8, tv10); |
1250 | |
1251 | auto out = add( |
1252 | add(castOp(DataType::Double, tv3), tv7), castOp(DataType::Double, tv11)); |
1253 | |
1254 | fusion.addOutput(out); |
1255 | |
1256 | groupReductions({tv1, tv5, tv9}); |
1257 | |
1258 | tv1->split(0, 128); |
1259 | TransformPropagator propagator(tv1); |
1260 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
1261 | |
1262 | tv1->axis(0)->parallelize(ParallelType::BIDx); |
1263 | tv1->axis(1)->parallelize(ParallelType::TIDx); |
1264 | scheduler_utils::parallelizeAllLike(tv1); |
1265 | |
1266 | std::vector<int64_t> shape({999}); |
1267 | |
1268 | auto options_float = |
1269 | at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1270 | auto options_double = |
1271 | at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); |
1272 | auto options_long = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); |
1273 | |
1274 | auto t0 = at::randn(shape, options_float); |
1275 | auto t4 = at::randn(shape, options_double); |
1276 | auto t8 = torch::randint(0, 1000, shape, options_long); |
1277 | std::vector<IValue> aten_inputs = {t0, t4, t8}; |
1278 | |
1279 | std::vector<at::indexing::TensorIndex> indices({at::indexing::Slice(0, 10)}); |
1280 | |
1281 | FusionExecutor fe; |
1282 | fe.compileFusion(&fusion, aten_inputs); |
1283 | auto outputs = fe.runFusion(aten_inputs); |
1284 | |
1285 | auto t3 = t0 / t0.sum({0}).unsqueeze(0).to(at::kDouble); |
1286 | auto t7 = t4 / t4.sum({0}).unsqueeze(0); |
1287 | auto t11 = t8 / t8.sum({0}).unsqueeze(0).to(at::kDouble); |
1288 | auto ref = t3 + t7 + t11; |
1289 | |
1290 | testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__); |
1291 | } |
1292 | |
1293 | // Persistent batchnorm backward with grouped allreduce |
1294 | TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) { |
1295 | const std::vector<int64_t> shape({64, 1024, 14, 14}); |
1296 | |
1297 | Fusion fusion; |
1298 | FusionGuard fg(&fusion); |
1299 | |
1300 | auto input = makeContigTensor(4); |
1301 | fusion.addInput(input); |
1302 | auto grad_output = makeContigTensor(4); |
1303 | fusion.addInput(grad_output); |
1304 | auto weight = makeContigTensor(1); |
1305 | fusion.addInput(weight); |
1306 | auto save_mean = makeContigTensor(1); |
1307 | fusion.addInput(save_mean); |
1308 | auto save_invstd = makeContigTensor(1); |
1309 | fusion.addInput(save_invstd); |
1310 | |
1311 | const bool kTraining = true; |
1312 | const bool channels_last = false; |
1313 | |
1314 | const size_t kNumberOfDims = |
1315 | TensorDomain::noReductions(input->getMaybeRFactorDomain()).size(); |
1316 | size_t c_axis = channels_last ? kNumberOfDims - 1 : 1; |
1317 | |
1318 | std::vector<int> reduction_axes; |
1319 | std::vector<bool> broadcast_mask(kNumberOfDims, false); |
1320 | Val* num_features = nullptr; |
1321 | for (const auto axis : c10::irange(kNumberOfDims)) { |
1322 | if (axis != c_axis) { |
1323 | reduction_axes.push_back(axis); |
1324 | broadcast_mask[axis] = true; |
1325 | if (num_features == nullptr) { |
1326 | num_features = |
1327 | castOp(DataType::Double, input->domain()->domain()[axis]->extent()); |
1328 | } else { |
1329 | num_features = |
1330 | mul(num_features, input->domain()->domain()[axis]->extent()); |
1331 | } |
1332 | } |
1333 | } |
1334 | |
1335 | auto mean = save_mean; |
1336 | auto invstd = save_invstd; |
1337 | |
1338 | mean = broadcast(mean, broadcast_mask); |
1339 | |
1340 | auto norm = reciprocal(num_features); |
1341 | |
1342 | auto grad_output_sum = sum(grad_output, reduction_axes); |
1343 | auto dot_p = sum(mul(grad_output, sub(input, mean)), reduction_axes); |
1344 | |
1345 | auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask); |
1346 | |
1347 | auto proj_scale = |
1348 | broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask); |
1349 | |
1350 | TensorView* grad_scale = nullptr; |
1351 | |
1352 | if (weight == nullptr) { |
1353 | grad_scale = |
1354 | mul(broadcast(invstd, broadcast_mask), |
1355 | IrBuilder::create<Double>(input->container(), 1)); |
1356 | } else { |
1357 | grad_scale = mul( |
1358 | broadcast(invstd, broadcast_mask), broadcast(weight, broadcast_mask)); |
1359 | } |
1360 | |
1361 | TensorView* grad_input = nullptr; |
1362 | if (kTraining) { |
1363 | auto proj = mul(sub(input, mean), proj_scale); |
1364 | grad_input = mul(sub(sub(grad_output, proj), grad_mean), grad_scale); |
1365 | } else { |
1366 | grad_input = mul(grad_output, grad_scale); |
1367 | } |
1368 | |
1369 | fusion.addOutput(grad_input); |
1370 | |
1371 | // Scheduling strategy |
1372 | // 1. Cache inputs |
1373 | // 2. Group the reductions (automatically fused with broadcasts) |
1374 | // 3. Merge HW and vectorize with the outer parallelized by TIDx |
1375 | // 4. Split N by TIDy with the outer parallelized by BIDx and |
1376 | // inner by TIDy |
1377 | // 5. Split C by BIDy and let the outer be the serial outermost loop |
1378 | |
1379 | auto input_cache = input->cacheAfter(); |
1380 | auto grad_output_cache = grad_output->cacheAfter(); |
1381 | auto weight_cache = weight->cacheAfter(); |
1382 | auto save_mean_cache = save_mean->cacheAfter(); |
1383 | auto save_invstd_cache = save_invstd->cacheAfter(); |
1384 | |
1385 | // Group the two reductions |
1386 | groupReductions({grad_output_sum, dot_p}); |
1387 | |
1388 | // Transform grad_input to: [C/bidy, N/tidy, tidy, bidy, HW/vec_width, |
1389 | // vec_width] |
1390 | const int tidy = 8; |
1391 | const int bidy = 4; |
1392 | const int bidx = ceilDiv(shape[0], (int64_t)tidy); |
1393 | const int vec_width = 4; |
1394 | TORCH_CHECK( |
1395 | (shape[2] * shape[3]) % vec_width == 0, |
1396 | "Invalid vector width: " , |
1397 | vec_width); |
1398 | |
1399 | grad_input->merge(-2, -1); |
1400 | grad_input->split(-1, vec_width); |
1401 | |
1402 | grad_input->split(0, tidy); |
1403 | grad_input->split(2, bidy); |
1404 | TORCH_CHECK( |
1405 | grad_input->nDims() == 6, |
1406 | "Unexpected number of dimensions: " , |
1407 | grad_input->toString()); |
1408 | |
1409 | grad_input->reorder({{2, 0}, {0, 1}, {1, 2}}); |
1410 | |
1411 | grad_input->axis(1)->parallelize(ParallelType::BIDx); |
1412 | grad_input->axis(2)->parallelize(ParallelType::TIDy); |
1413 | grad_input->axis(3)->parallelize(ParallelType::BIDy); |
1414 | grad_input->axis(4)->parallelize(ParallelType::TIDx); |
1415 | |
1416 | TransformPropagator propagator(grad_input); |
1417 | MaxRootDomainInfoSpanningTree(grad_input).traverse(&propagator); |
1418 | |
1419 | auto rf_tensors = grad_output_sum->rFactor( |
1420 | {-1}, std::vector<TensorView*>({grad_output_sum, dot_p})); |
1421 | |
1422 | for (auto fusion_input : |
1423 | ir_utils::filterByType<TensorView>(fusion.inputs())) { |
1424 | fusion_input->computeAt(grad_input, 1); |
1425 | } |
1426 | |
1427 | // Parallelization |
1428 | scheduler_utils::parallelizeAllLike(grad_input); |
1429 | input_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1430 | grad_output_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1431 | |
1432 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1433 | auto at_input = at::randn(shape, options); |
1434 | auto at_grad_output = at::randn(shape, options); |
1435 | auto at_weight = at::randn({shape[c_axis]}, options); |
1436 | auto at_save_mean = at::randn({shape[c_axis]}, options); |
1437 | auto at_save_invstd = at::randn({shape[c_axis]}, options); |
1438 | std::vector<IValue> aten_inputs( |
1439 | {at_input, at_grad_output, at_weight, at_save_mean, at_save_invstd}); |
1440 | |
1441 | GpuLower gpulw(&fusion); |
1442 | validateNoParallelBroadcastExist(gpulw.kernel()); |
1443 | |
1444 | FusionExecutor fe; |
1445 | fe.compileFusion(&fusion, aten_inputs); |
1446 | |
1447 | if (bidx * bidy > deviceSMCount()) { |
1448 | GTEST_SKIP() << "Not enough SMs to run this test" ; |
1449 | } |
1450 | |
1451 | auto outputs = fe.runFusion(aten_inputs); |
1452 | |
1453 | std::vector<int64_t> at_reduction_axes; |
1454 | std::copy( |
1455 | reduction_axes.begin(), |
1456 | reduction_axes.end(), |
1457 | std::back_inserter(at_reduction_axes)); |
1458 | |
1459 | // MSVC bug on lambda non-capture of const integral type |
1460 | // https://developercommunity.visualstudio.com/t/lambda-fails-to-implicitly-capture-constexpr-value/610504 |
1461 | auto at_bcast = [=](const auto& tensor) { |
1462 | if (channels_last) { |
1463 | tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0); |
1464 | } else { |
1465 | return tensor.unsqueeze(0).unsqueeze(-1).unsqueeze(-1); |
1466 | } |
1467 | }; |
1468 | |
1469 | auto at_mean = at_save_mean; |
1470 | const auto& at_invstd = at_save_invstd; |
1471 | at_mean = at_bcast(at_mean); |
1472 | auto at_norm = 1.0f / static_cast<float>(shape[0] * shape[2] * shape[3]); |
1473 | |
1474 | auto at_grad_output_sum = sum(at_grad_output, at_reduction_axes); |
1475 | auto at_dot_p = |
1476 | sum(mul(at_grad_output, sub(at_input, at_mean)), at_reduction_axes); |
1477 | |
1478 | auto at_grad_mean = at_bcast(at_grad_output_sum * at_norm); |
1479 | |
1480 | auto at_proj_scale = at_bcast((at_dot_p * at_norm) * (at_invstd * at_invstd)); |
1481 | |
1482 | at::Tensor at_grad_scale; |
1483 | |
1484 | if (weight == nullptr) { |
1485 | at_grad_scale = at_bcast(at_invstd); |
1486 | } else { |
1487 | at_grad_scale = at_bcast(at_invstd) * at_bcast(at_weight); |
1488 | } |
1489 | |
1490 | at::Tensor at_grad_input; |
1491 | if (kTraining) { |
1492 | auto at_proj = (at_input - at_mean) * at_proj_scale; |
1493 | at_grad_input = (at_grad_output - at_proj - at_grad_mean) * at_grad_scale; |
1494 | } else { |
1495 | at_grad_input = at_grad_output * at_grad_scale; |
1496 | } |
1497 | |
1498 | testValidate( |
1499 | fe.kernel(), outputs, aten_inputs, {at_grad_input}, __LINE__, __FILE__); |
1500 | } |
1501 | |
1502 | TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) { |
1503 | Fusion fusion; |
1504 | FusionGuard fg(&fusion); |
1505 | |
1506 | auto tv0 = makeSymbolicTensor(2); |
1507 | fusion.addInput(tv0); |
1508 | |
1509 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
1510 | auto tv2 = sum(tv1, {0}); |
1511 | |
1512 | auto tv3 = add(tv0, IrBuilder::create<Double>(2)); |
1513 | auto tv4 = sum(tv3, {0}); |
1514 | |
1515 | auto tv5 = add(tv2, tv4); |
1516 | fusion.addOutput(tv5); |
1517 | |
1518 | groupReductions({tv2, tv4}); |
1519 | |
1520 | auto tv0_cache = tv0->cacheAfter(); |
1521 | |
1522 | const int vec = 2; |
1523 | const int tidx = 64; |
1524 | const int tidy = 8; |
1525 | |
1526 | tv2->split(1, vec); |
1527 | tv2->split(1, tidx); |
1528 | |
1529 | tv2->split(0, tidy); |
1530 | TransformPropagator propagator(tv2); |
1531 | MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); |
1532 | |
1533 | tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1534 | |
1535 | tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); |
1536 | |
1537 | tv2->axis(0)->parallelize(ParallelType::BIDy); |
1538 | tv2->axis(1)->parallelize(ParallelType::TIDy); |
1539 | tv2->axis(2)->parallelize(ParallelType::BIDx); |
1540 | tv2->axis(3)->parallelize(ParallelType::TIDx); |
1541 | |
1542 | scheduler_utils::parallelizeAllLike(tv2); |
1543 | |
1544 | std::vector<int64_t> shape({99, 999}); |
1545 | |
1546 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1547 | at::manual_seed(0); |
1548 | |
1549 | auto t0 = at::randn(shape, options); |
1550 | |
1551 | FusionExecutor fe; |
1552 | fe.compileFusion(&fusion, {t0}); |
1553 | auto outputs = fe.runFusion({t0}); |
1554 | |
1555 | auto t0_double = t0.to(at::kDouble); |
1556 | auto ref = (t0_double + 1).sum({0}) + (t0_double + 2).sum({0}); |
1557 | |
1558 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1559 | } |
1560 | |
1561 | // Channels-last batch norm with vectorization. Relies on re-entrant |
1562 | // GroupedGridReduction |
1563 | TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) { |
1564 | Fusion fusion; |
1565 | FusionGuard fg(&fusion); |
1566 | |
1567 | const std::vector<int64_t> shape({64, 14, 14, 32}); |
1568 | |
1569 | auto tv0 = makeContigTensor(4, DataType::Half); |
1570 | fusion.addInput(tv0); |
1571 | auto tv1 = makeContigTensor(4, DataType::Half); |
1572 | fusion.addInput(tv1); |
1573 | auto tv2 = makeContigTensor(1); |
1574 | fusion.addInput(tv2); |
1575 | |
1576 | std::vector<int> reduction_axes({0, 1, 2}); |
1577 | std::vector<bool> broadcast_mask({true, true, true, false}); |
1578 | |
1579 | auto tv3 = castOp(DataType::Float, tv0); |
1580 | auto tv4 = castOp(DataType::Float, tv1); |
1581 | |
1582 | auto tv5 = sum(tv3, reduction_axes); |
1583 | |
1584 | auto tv6 = broadcast(tv2, broadcast_mask); |
1585 | auto tv7 = sub(tv4, tv6); |
1586 | auto tv8 = mul(tv3, tv7); |
1587 | auto tv9 = sum(tv8, reduction_axes); |
1588 | |
1589 | auto tv10 = castOp(DataType::Half, tv5); |
1590 | auto tv11 = castOp(DataType::Half, tv9); |
1591 | |
1592 | fusion.addOutput(tv10); |
1593 | fusion.addOutput(tv11); |
1594 | |
1595 | groupReductions({tv5, tv9}); |
1596 | |
1597 | // Applies the outer-reduction schedule |
1598 | const int64_t num_channels = shape.back(); |
1599 | const int64_t vector = 2; |
1600 | TORCH_CHECK(num_channels % vector == 0); |
1601 | // Use at most 32 TIDx threads |
1602 | const int64_t tidx = std::min<int64_t>(32l, num_channels / vector); |
1603 | const auto bidx = ceilDiv(num_channels, tidx * vector); |
1604 | |
1605 | const int64_t tidy = 8; |
1606 | const auto bidy = ceilDiv( |
1607 | at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 4, bidx); |
1608 | |
1609 | auto tv0_cache = tv0->cacheAfter(); |
1610 | auto tv1_cache = tv1->cacheAfter(); |
1611 | |
1612 | auto ref = tv5; |
1613 | |
1614 | // Move the reduction domains inner positions |
1615 | ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}}); |
1616 | |
1617 | // Parallelizing the reduction domains |
1618 | ref->merge(2, 3); |
1619 | ref->merge(1, 2); |
1620 | ref->split(1, tidy); |
1621 | ref->split(1, bidy, false); |
1622 | |
1623 | // Parallelizing the iteration domains |
1624 | ref->split(0, vector); |
1625 | ref->split(0, tidx); |
1626 | |
1627 | // Move the vector axis to the innermost position |
1628 | ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}}); |
1629 | // Move the serial reduction to the right of the vector axis |
1630 | ref->reorder({{3, 4}, {4, 3}}); |
1631 | |
1632 | TransformPropagator propagator(ref); |
1633 | MaxRootDomainInfoSpanningTree(ref).traverse(&propagator); |
1634 | |
1635 | auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); |
1636 | auto tv5_rf = rf_tvs.at(0); |
1637 | auto tv9_rf = rf_tvs.at(1); |
1638 | |
1639 | tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort); |
1640 | tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort); |
1641 | tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort); |
1642 | tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort); |
1643 | |
1644 | ref = tv5_rf; |
1645 | |
1646 | ref->axis(0)->parallelize(ParallelType::BIDx); |
1647 | ref->axis(1)->parallelize(ParallelType::TIDx); |
1648 | ref->axis(2)->parallelize(ParallelType::BIDy); |
1649 | ref->axis(3)->parallelize(ParallelType::TIDy); |
1650 | ref->axis(4)->parallelize(ParallelType::Serial); |
1651 | ref->axis(5)->parallelize(ParallelType::Serial); |
1652 | |
1653 | scheduler_utils::parallelizeAllLike(ref); |
1654 | |
1655 | tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1656 | tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1657 | |
1658 | auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1659 | auto options_float = |
1660 | at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1661 | auto t0 = at::randn(shape, options_half); |
1662 | auto t1 = at::randn(shape, options_half); |
1663 | auto t2 = at::randn({shape.back()}, options_float); |
1664 | std::vector<IValue> aten_inputs({t0, t1, t2}); |
1665 | |
1666 | FusionExecutor fe; |
1667 | fe.compileFusion(&fusion, aten_inputs); |
1668 | auto outputs = fe.runFusion(aten_inputs); |
1669 | |
1670 | auto t0_double = t0.to(at::kDouble); |
1671 | auto t1_double = t1.to(at::kDouble); |
1672 | auto t2_double = t2.to(at::kDouble); |
1673 | |
1674 | std::vector<int64_t> at_reduction_axes( |
1675 | {reduction_axes.begin(), reduction_axes.end()}); |
1676 | auto t5 = t0_double.sum(at_reduction_axes); |
1677 | auto t8 = t0_double * |
1678 | (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0)); |
1679 | auto t9 = t8.sum(at_reduction_axes); |
1680 | |
1681 | testValidate(fe.kernel(), outputs, aten_inputs, {t5, t9}, __LINE__, __FILE__); |
1682 | } |
1683 | |
1684 | // Test the grouped grid allreduce with BN-like outer reductions |
1685 | TEST_F( |
1686 | NVFuserTest, |
1687 | FusionGroupedReductionPersistentChannelsLastBatchNormLike_CUDA) { |
1688 | Fusion fusion; |
1689 | FusionGuard fg(&fusion); |
1690 | |
1691 | const std::vector<int64_t> shape({64, 14, 14, 32}); |
1692 | |
1693 | auto tv0 = makeContigTensor(4, DataType::Half); |
1694 | fusion.addInput(tv0); |
1695 | auto tv1 = makeContigTensor(4, DataType::Half); |
1696 | fusion.addInput(tv1); |
1697 | auto tv2 = makeContigTensor(1); |
1698 | fusion.addInput(tv2); |
1699 | |
1700 | std::vector<int> reduction_axes({0, 1, 2}); |
1701 | std::vector<bool> broadcast_mask({true, true, true, false}); |
1702 | |
1703 | auto tv3 = castOp(DataType::Float, tv0); |
1704 | auto tv4 = castOp(DataType::Float, tv1); |
1705 | |
1706 | auto tv5 = sum(tv3, reduction_axes); |
1707 | |
1708 | auto tv6 = broadcast(tv2, broadcast_mask); |
1709 | auto tv7 = sub(tv4, tv6); |
1710 | auto tv8 = mul(tv3, tv7); |
1711 | auto tv9 = sum(tv8, reduction_axes); |
1712 | |
1713 | auto tv10 = broadcast(tv5, broadcast_mask); |
1714 | auto tv11 = add(tv3, tv10); |
1715 | |
1716 | auto tv12 = broadcast(tv9, broadcast_mask); |
1717 | auto tv13 = add(tv4, tv12); |
1718 | |
1719 | auto tv14 = castOp(DataType::Half, tv11); |
1720 | auto tv15 = castOp(DataType::Half, tv13); |
1721 | |
1722 | fusion.addOutput(tv14); |
1723 | fusion.addOutput(tv15); |
1724 | |
1725 | groupReductions({tv5, tv9}); |
1726 | |
1727 | // Applies the outer-reduction schedule |
1728 | const int64_t num_channels = shape.back(); |
1729 | const int64_t vector = 2; |
1730 | TORCH_CHECK(num_channels % vector == 0); |
1731 | // Use at most 32 TIDx threads |
1732 | const int64_t tidx = std::min<int64_t>(32l, num_channels / vector); |
1733 | const auto bidx = ceilDiv(num_channels, tidx * vector); |
1734 | |
1735 | const int64_t tidy = 8; |
1736 | const int64_t reduction_work_per_thread = 8; |
1737 | |
1738 | auto tv0_cache = tv0->cacheAfter(); |
1739 | auto tv1_cache = tv1->cacheAfter(); |
1740 | |
1741 | auto ref = tv5; |
1742 | |
1743 | // Move the reduction domains inner positions |
1744 | ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}}); |
1745 | |
1746 | // Parallelizing the reduction domains |
1747 | ref->merge(2, 3); |
1748 | ref->merge(1, 2); |
1749 | ref->split(1, tidy); |
1750 | ref->split(1, reduction_work_per_thread); |
1751 | |
1752 | // Parallelizing the iteration domains |
1753 | ref->split(0, vector); |
1754 | ref->split(0, tidx); |
1755 | |
1756 | // Move the vector axis to the innermost position |
1757 | ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}}); |
1758 | // Move the serial reduction to the right of the vector axis |
1759 | ref->reorder({{3, 4}, {4, 3}}); |
1760 | |
1761 | TransformPropagator propagator(ref); |
1762 | MaxRootDomainInfoSpanningTree(ref).traverse(&propagator); |
1763 | |
1764 | auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); |
1765 | auto tv5_rf = rf_tvs.at(0); |
1766 | auto tv9_rf = rf_tvs.at(1); |
1767 | |
1768 | tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort); |
1769 | tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort); |
1770 | tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort); |
1771 | tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort); |
1772 | |
1773 | ref = tv5_rf; |
1774 | |
1775 | ref->axis(0)->parallelize(ParallelType::BIDx); |
1776 | ref->axis(1)->parallelize(ParallelType::TIDx); |
1777 | ref->axis(2)->parallelize(ParallelType::BIDy); |
1778 | ref->axis(3)->parallelize(ParallelType::TIDy); |
1779 | ref->axis(4)->parallelize(ParallelType::Serial); |
1780 | ref->axis(5)->parallelize(ParallelType::Serial); |
1781 | |
1782 | scheduler_utils::parallelizeAllLike(ref); |
1783 | |
1784 | tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1785 | tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); |
1786 | |
1787 | tv5->axis(-1)->parallelize(ParallelType::Group); |
1788 | |
1789 | auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); |
1790 | auto options_float = |
1791 | at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1792 | auto t0 = at::randn(shape, options_half); |
1793 | auto t1 = at::randn(shape, options_half); |
1794 | auto t2 = at::randn({shape.back()}, options_float); |
1795 | std::vector<IValue> aten_inputs({t0, t1, t2}); |
1796 | |
1797 | FusionExecutor fe; |
1798 | fe.compileFusion(&fusion, aten_inputs); |
1799 | auto outputs = fe.runFusion(aten_inputs); |
1800 | |
1801 | auto t0_double = t0.to(at::kDouble); |
1802 | auto t1_double = t1.to(at::kDouble); |
1803 | auto t2_double = t2.to(at::kDouble); |
1804 | |
1805 | std::vector<int64_t> at_reduction_axes( |
1806 | {reduction_axes.begin(), reduction_axes.end()}); |
1807 | auto t5 = t0_double.sum(at_reduction_axes); |
1808 | auto t8 = t0_double * |
1809 | (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0)); |
1810 | auto t9 = t8.sum(at_reduction_axes); |
1811 | |
1812 | auto t10 = t5.unsqueeze(0).unsqueeze(0).unsqueeze(0); |
1813 | auto t11 = t0_double + t10; |
1814 | auto t12 = t9.unsqueeze(0).unsqueeze(0).unsqueeze(0); |
1815 | auto t13 = t1_double + t12; |
1816 | |
1817 | testValidate( |
1818 | fe.kernel(), outputs, aten_inputs, {t11, t13}, __LINE__, __FILE__); |
1819 | } |
1820 | |
1821 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) { |
1822 | Fusion fusion; |
1823 | FusionGuard fg(&fusion); |
1824 | |
1825 | auto tv0 = makeSymbolicTensor(2); |
1826 | fusion.addInput(tv0); |
1827 | |
1828 | auto tv1 = set(tv0); |
1829 | auto tv2 = sum(tv1, {0}); |
1830 | auto tv3 = broadcast(tv2, {true, false}); |
1831 | auto tv4 = add(tv0, tv3); |
1832 | fusion.addOutput(tv4); |
1833 | |
1834 | const int vec = 2; |
1835 | const int tidx = 32; |
1836 | const int tidy = 8; |
1837 | |
1838 | tv1->split(1, vec); |
1839 | tv1->split(1, tidx); |
1840 | tv1->split(0, tidy); |
1841 | TransformPropagator propagator(tv1); |
1842 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
1843 | |
1844 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
1845 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
1846 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
1847 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
1848 | |
1849 | scheduler_utils::parallelizeAllLike(tv1); |
1850 | |
1851 | tv2->axis(4)->parallelize(ParallelType::Group); |
1852 | |
1853 | // Make sure the reduction expr is converted to GroupedGridReduciton |
1854 | // and the non-reduction domains of the output TV are either |
1855 | // grouped or parallelized |
1856 | GpuLower gpulw(&fusion); |
1857 | bool validated = false; |
1858 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
1859 | auto grouped_grid_reduction = |
1860 | dynamic_cast<kir::GroupedGridReduction*>(expr); |
1861 | if (grouped_grid_reduction == nullptr) { |
1862 | continue; |
1863 | } |
1864 | auto out = ir_utils::getTvOutput(grouped_grid_reduction); |
1865 | for (auto out_axis : out->domain()->domain()) { |
1866 | auto out_axis_pt = out_axis->getParallelType(); |
1867 | TORCH_CHECK( |
1868 | isParallelTypeThread(out_axis_pt) || |
1869 | out_axis_pt == ParallelType::Group, |
1870 | "Invalid parallel type of the reduction tensor: " , |
1871 | out_axis_pt, |
1872 | ". Reduction output tensor: " , |
1873 | out->toString()); |
1874 | } |
1875 | validated = true; |
1876 | } |
1877 | TORCH_CHECK( |
1878 | validated, "Invalid lowered kernel. No GroupedGridReduction found." ); |
1879 | |
1880 | std::vector<int64_t> shape({99, 101}); |
1881 | |
1882 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1883 | at::manual_seed(0); |
1884 | auto t0 = at::randn(shape, options); |
1885 | |
1886 | FusionExecutor fe; |
1887 | fe.compileFusion(&fusion, {t0}); |
1888 | auto outputs = fe.runFusion({t0}); |
1889 | |
1890 | auto t0_double = t0.to(at::kDouble); |
1891 | auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); |
1892 | |
1893 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1894 | } |
1895 | |
1896 | // Test grouping of two domains |
1897 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) { |
1898 | Fusion fusion; |
1899 | FusionGuard fg(&fusion); |
1900 | |
1901 | auto tv0 = makeSymbolicTensor(2); |
1902 | fusion.addInput(tv0); |
1903 | |
1904 | auto tv1 = set(tv0); |
1905 | auto tv2 = sum(tv1, {0}); |
1906 | auto tv3 = broadcast(tv2, {true, false}); |
1907 | auto tv4 = add(tv0, tv3); |
1908 | fusion.addOutput(tv4); |
1909 | |
1910 | const int vec1 = 2; |
1911 | const int vec2 = 3; |
1912 | const int tidx = 16; |
1913 | const int tidy = 8; |
1914 | |
1915 | tv1->split(1, vec1); |
1916 | tv1->split(1, vec2); |
1917 | tv1->split(1, tidx); |
1918 | tv1->split(0, tidy); |
1919 | TransformPropagator propagator(tv1); |
1920 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
1921 | |
1922 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
1923 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
1924 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
1925 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
1926 | |
1927 | scheduler_utils::parallelizeAllLike(tv1); |
1928 | |
1929 | tv2->axis(4)->parallelize(ParallelType::Group); |
1930 | tv2->axis(5)->parallelize(ParallelType::Group); |
1931 | |
1932 | std::vector<int64_t> shape({99, 129}); |
1933 | |
1934 | // Make sure the reduction expr is converted to GroupedGridReduciton |
1935 | // and the non-reduction domains of the output TV are either |
1936 | // grouped or parallelized |
1937 | GpuLower gpulw(&fusion); |
1938 | bool validated = false; |
1939 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
1940 | auto grouped_grid_reduction = |
1941 | dynamic_cast<kir::GroupedGridReduction*>(expr); |
1942 | if (grouped_grid_reduction == nullptr) { |
1943 | continue; |
1944 | } |
1945 | auto out = ir_utils::getTvOutput(grouped_grid_reduction); |
1946 | for (auto out_axis : out->domain()->domain()) { |
1947 | auto out_axis_pt = out_axis->getParallelType(); |
1948 | TORCH_CHECK( |
1949 | isParallelTypeThread(out_axis_pt) || |
1950 | out_axis_pt == ParallelType::Group, |
1951 | "Invalid parallel type of the reduction tensor: " , |
1952 | out_axis_pt, |
1953 | ". Reduction output tensor: " , |
1954 | out->toString()); |
1955 | } |
1956 | validated = true; |
1957 | } |
1958 | TORCH_CHECK( |
1959 | validated, "Invalid lowered kernel. No GroupedGridReduction found." ); |
1960 | |
1961 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
1962 | at::manual_seed(0); |
1963 | auto t0 = at::randn(shape, options); |
1964 | |
1965 | FusionExecutor fe; |
1966 | fe.compileFusion(&fusion, {t0}); |
1967 | auto outputs = fe.runFusion({t0}); |
1968 | |
1969 | auto t0_double = t0.to(at::kDouble); |
1970 | auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); |
1971 | |
1972 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
1973 | } |
1974 | |
1975 | // Group both expressions and iterations |
1976 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) { |
1977 | Fusion fusion; |
1978 | FusionGuard fg(&fusion); |
1979 | |
1980 | auto tv0 = makeSymbolicTensor(2); |
1981 | fusion.addInput(tv0); |
1982 | |
1983 | auto tv1 = add(tv0, IrBuilder::create<Double>(1)); |
1984 | auto tv2 = sum(tv1, {0}); |
1985 | auto tv3 = broadcast(tv2, {true, false}); |
1986 | auto tv4 = add(tv1, tv3); |
1987 | |
1988 | auto tv5 = add(tv0, IrBuilder::create<Double>(2)); |
1989 | auto tv6 = sum(tv5, {0}); |
1990 | auto tv7 = broadcast(tv6, {true, false}); |
1991 | auto tv8 = add(tv5, tv7); |
1992 | |
1993 | auto tv9 = add(tv4, tv8); |
1994 | fusion.addOutput(tv9); |
1995 | |
1996 | groupReductions({tv2, tv6}); |
1997 | |
1998 | const int vec = 2; |
1999 | const int tidx = 32; |
2000 | const int tidy = 8; |
2001 | |
2002 | tv1->split(1, vec); |
2003 | tv1->split(1, tidx); |
2004 | tv1->split(0, tidy); |
2005 | TransformPropagator propagator(tv1); |
2006 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
2007 | |
2008 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
2009 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
2010 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
2011 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
2012 | |
2013 | scheduler_utils::parallelizeAllLike(tv1); |
2014 | |
2015 | tv2->axis(4)->parallelize(ParallelType::Group); |
2016 | |
2017 | // Make sure the reduction expr is converted to GroupedGridReduciton |
2018 | // and the non-reduction domains of the output TV are either |
2019 | // grouped or parallelized |
2020 | GpuLower gpulw(&fusion); |
2021 | bool validated = false; |
2022 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
2023 | auto grouped_grid_reduction = |
2024 | dynamic_cast<kir::GroupedGridReduction*>(expr); |
2025 | if (grouped_grid_reduction == nullptr) { |
2026 | continue; |
2027 | } |
2028 | auto out = ir_utils::getTvOutput(grouped_grid_reduction); |
2029 | for (auto out_axis : out->domain()->domain()) { |
2030 | auto out_axis_pt = out_axis->getParallelType(); |
2031 | TORCH_CHECK( |
2032 | isParallelTypeThread(out_axis_pt) || |
2033 | out_axis_pt == ParallelType::Group, |
2034 | "Invalid parallel type of the reduction tensor: " , |
2035 | out_axis_pt, |
2036 | ". Reduction output tensor: " , |
2037 | out->toString()); |
2038 | } |
2039 | validated = true; |
2040 | } |
2041 | TORCH_CHECK( |
2042 | validated, "Invalid lowered kernel. No GroupedGridReduction found." ); |
2043 | |
2044 | std::vector<int64_t> shape({99, 101}); |
2045 | |
2046 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2047 | at::manual_seed(0); |
2048 | auto t0 = at::randn(shape, options); |
2049 | |
2050 | FusionExecutor fe; |
2051 | fe.compileFusion(&fusion, {t0}); |
2052 | auto outputs = fe.runFusion({t0}); |
2053 | |
2054 | auto t0_double = t0.to(at::kDouble); |
2055 | auto t4 = t0_double + 1 + (t0_double + 1).sum({0}).unsqueeze(0); |
2056 | auto t8 = t0_double + 2 + (t0_double + 2).sum({0}).unsqueeze(0); |
2057 | auto ref = t4 + t8; |
2058 | |
2059 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
2060 | } |
2061 | |
2062 | // ParallelType::Group with computeAt |
2063 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { |
2064 | Fusion fusion; |
2065 | FusionGuard fg(&fusion); |
2066 | |
2067 | auto tv0 = makeSymbolicTensor(2); |
2068 | fusion.addInput(tv0); |
2069 | |
2070 | auto tv1 = set(tv0); |
2071 | auto tv2 = sum(tv1, {0}); |
2072 | auto tv3 = broadcast(tv2, {true, false}); |
2073 | auto tv4 = add(tv0, tv3); |
2074 | fusion.addOutput(tv4); |
2075 | |
2076 | const int vec = 2; |
2077 | const int tidx = 32; |
2078 | const int tidy = 8; |
2079 | |
2080 | tv2->reorder({{0, 1}}); |
2081 | tv2->split(0, vec); |
2082 | tv2->split(0, tidx); |
2083 | tv2->split(-1, tidy); |
2084 | |
2085 | TransformPropagator propagator(tv2); |
2086 | MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); |
2087 | |
2088 | tv2->axis(2)->parallelize(ParallelType::Group); |
2089 | |
2090 | // This should avoid inlining the grouped domain |
2091 | tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); |
2092 | |
2093 | TORCH_CHECK( |
2094 | tv1->getComputeAtPosition() == 2, |
2095 | "Invalid computeAt position: " , |
2096 | tv1->toString()); |
2097 | TORCH_CHECK( |
2098 | tv2->getComputeAtPosition() == 2, |
2099 | "Invalid computeAt position: " , |
2100 | tv2->toString()); |
2101 | |
2102 | tv4->axis(0)->parallelize(ParallelType::BIDx); |
2103 | tv4->axis(1)->parallelize(ParallelType::TIDx); |
2104 | |
2105 | for (auto tv : ir_utils::allTvs(&fusion)) { |
2106 | tv->axis(-2)->parallelize(ParallelType::BIDy); |
2107 | tv->axis(-1)->parallelize(ParallelType::TIDy); |
2108 | } |
2109 | |
2110 | // Make sure the reduction expr is converted to GroupedGridReduciton |
2111 | // and the non-reduction domains of the output TV are either |
2112 | // grouped or parallelized |
2113 | GpuLower gpulw(&fusion); |
2114 | bool validated = false; |
2115 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
2116 | auto grouped_grid_reduction = |
2117 | dynamic_cast<kir::GroupedGridReduction*>(expr); |
2118 | if (grouped_grid_reduction == nullptr) { |
2119 | continue; |
2120 | } |
2121 | auto out = ir_utils::getTvOutput(grouped_grid_reduction); |
2122 | for (auto out_axis : out->domain()->domain()) { |
2123 | auto out_axis_pt = out_axis->getParallelType(); |
2124 | TORCH_CHECK( |
2125 | isParallelTypeThread(out_axis_pt) || |
2126 | out_axis_pt == ParallelType::Group, |
2127 | "Invalid parallel type of the reduction tensor: " , |
2128 | out_axis_pt, |
2129 | ". Reduction output tensor: " , |
2130 | out->toString()); |
2131 | } |
2132 | validated = true; |
2133 | } |
2134 | TORCH_CHECK( |
2135 | validated, "Invalid lowered kernel. No GroupedGridReduction found." ); |
2136 | |
2137 | std::vector<int64_t> shape({99, 101}); |
2138 | |
2139 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2140 | at::manual_seed(0); |
2141 | auto t0 = at::randn(shape, options); |
2142 | |
2143 | FusionExecutor fe; |
2144 | fe.compileFusion(&fusion, {t0}); |
2145 | auto outputs = fe.runFusion({t0}); |
2146 | |
2147 | auto t0_double = t0.to(at::kDouble); |
2148 | auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); |
2149 | |
2150 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
2151 | } |
2152 | |
2153 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford1_CUDA) { |
2154 | Fusion fusion; |
2155 | FusionGuard fg(&fusion); |
2156 | |
2157 | auto tv0 = makeSymbolicTensor(2); |
2158 | fusion.addInput(tv0); |
2159 | |
2160 | auto tv1 = set(tv0); |
2161 | auto tv2 = Welford(tv1, {0}).avg; |
2162 | auto tv3 = broadcast(tv2, {true, false}); |
2163 | auto tv4 = add(tv0, tv3); |
2164 | fusion.addOutput(tv4); |
2165 | |
2166 | const int vec = 2; |
2167 | const int tidx = 32; |
2168 | const int tidy = 8; |
2169 | |
2170 | tv1->split(1, vec); |
2171 | tv1->split(1, tidx); |
2172 | tv1->split(0, tidy); |
2173 | TransformPropagator propagator(tv1); |
2174 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
2175 | |
2176 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
2177 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
2178 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
2179 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
2180 | |
2181 | scheduler_utils::parallelizeAllLike(tv1); |
2182 | |
2183 | tv2->axis(4)->parallelize(ParallelType::Group); |
2184 | |
2185 | // Make sure the reduction expr is converted to GroupedGridReduciton |
2186 | // and the non-reduction domains of the output TV are either |
2187 | // grouped or parallelized |
2188 | GpuLower gpulw(&fusion); |
2189 | bool validated = false; |
2190 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
2191 | auto grouped_grid_reduction = dynamic_cast<kir::GroupedGridWelford*>(expr); |
2192 | if (grouped_grid_reduction == nullptr) { |
2193 | continue; |
2194 | } |
2195 | validated = true; |
2196 | } |
2197 | TORCH_CHECK( |
2198 | validated, "Invalid lowered kernel. No GroupedGridWelford found." ); |
2199 | |
2200 | std::vector<int64_t> shape({99, 101}); |
2201 | |
2202 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2203 | at::manual_seed(0); |
2204 | auto t0 = at::randn(shape, options); |
2205 | |
2206 | FusionExecutor fe; |
2207 | fe.compileFusion(&fusion, {t0}); |
2208 | auto outputs = fe.runFusion({t0}); |
2209 | |
2210 | auto t0_double = t0.to(at::kDouble); |
2211 | auto ref = t0_double + t0_double.mean({0}).unsqueeze(0); |
2212 | |
2213 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
2214 | } |
2215 | |
2216 | // Test grouping of two domains |
2217 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford2_CUDA) { |
2218 | Fusion fusion; |
2219 | FusionGuard fg(&fusion); |
2220 | |
2221 | auto tv0 = makeSymbolicTensor(2); |
2222 | fusion.addInput(tv0); |
2223 | |
2224 | auto tv1 = set(tv0); |
2225 | auto tv2 = Welford(tv1, {0}).avg; |
2226 | auto tv3 = broadcast(tv2, {true, false}); |
2227 | auto tv4 = add(tv0, tv3); |
2228 | fusion.addOutput(tv4); |
2229 | |
2230 | const int vec1 = 2; |
2231 | const int vec2 = 3; |
2232 | const int tidx = 16; |
2233 | const int tidy = 8; |
2234 | |
2235 | tv1->split(1, vec1); |
2236 | tv1->split(1, vec2); |
2237 | tv1->split(1, tidx); |
2238 | tv1->split(0, tidy); |
2239 | TransformPropagator propagator(tv1); |
2240 | MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); |
2241 | |
2242 | tv1->axis(0)->parallelize(ParallelType::BIDy); |
2243 | tv1->axis(1)->parallelize(ParallelType::TIDy); |
2244 | tv1->axis(2)->parallelize(ParallelType::BIDx); |
2245 | tv1->axis(3)->parallelize(ParallelType::TIDx); |
2246 | |
2247 | scheduler_utils::parallelizeAllLike(tv1); |
2248 | |
2249 | tv2->axis(4)->parallelize(ParallelType::Group); |
2250 | tv2->axis(5)->parallelize(ParallelType::Group); |
2251 | |
2252 | std::vector<int64_t> shape({99, 129}); |
2253 | |
2254 | // Make sure the reduction expr is converted to GroupedGridReduciton |
2255 | // and the non-reduction domains of the output TV are either |
2256 | // grouped or parallelized |
2257 | GpuLower gpulw(&fusion); |
2258 | bool validated = false; |
2259 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
2260 | auto grouped_grid_reduction = dynamic_cast<kir::GroupedGridWelford*>(expr); |
2261 | if (grouped_grid_reduction == nullptr) { |
2262 | continue; |
2263 | } |
2264 | validated = true; |
2265 | } |
2266 | TORCH_CHECK( |
2267 | validated, "Invalid lowered kernel. No GroupedGridWelford found." ); |
2268 | |
2269 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2270 | at::manual_seed(0); |
2271 | auto t0 = at::randn(shape, options); |
2272 | |
2273 | FusionExecutor fe; |
2274 | fe.compileFusion(&fusion, {t0}); |
2275 | auto outputs = fe.runFusion({t0}); |
2276 | |
2277 | auto t0_double = t0.to(at::kDouble); |
2278 | auto ref = t0_double + t0_double.mean({0}).unsqueeze(0); |
2279 | |
2280 | testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); |
2281 | } |
2282 | |
2283 | // Follows the pattern of persistent outer grid welford in batchnorm |
2284 | TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { |
2285 | struct Params { |
2286 | int N; |
2287 | int H; |
2288 | int W; |
2289 | int C; |
2290 | int tidx; |
2291 | int tidy; |
2292 | int vect; |
2293 | int persistent_buffer; |
2294 | int bidx; |
2295 | }; |
2296 | |
2297 | auto test = [](const Params& params) { |
2298 | Fusion fusion; |
2299 | FusionGuard fg(&fusion); |
2300 | |
2301 | std::vector<bool> bcast_pattern{true, true, true, false}; |
2302 | std::vector<int> reduction_dims{2, 1, 0}; |
2303 | |
2304 | auto tv0 = makeSymbolicTensor(4); |
2305 | fusion.addInput(tv0); |
2306 | |
2307 | auto tv1 = set(tv0); |
2308 | auto tvs = Welford(tv1, reduction_dims); |
2309 | auto tv2 = tvs.avg; |
2310 | auto tv3 = tvs.var_sum; |
2311 | auto tv4 = tvs.n; |
2312 | auto tv5 = broadcast(tv2, bcast_pattern); |
2313 | auto tv6 = broadcast(tv3, bcast_pattern); |
2314 | auto tv7 = broadcast(tv4, bcast_pattern); |
2315 | auto tv8 = sub(tv1, tv5); |
2316 | auto tv9 = add(tv8, tv6); |
2317 | // auto tv10 = div(tv9, tv7); |
2318 | // fusion.addOutput(tv10); |
2319 | fusion.addOutput(tv9); |
2320 | |
2321 | // Schedule the fusion as it will be done by the persistent |
2322 | // scheduler |
2323 | |
2324 | auto input_cache = tv1; |
2325 | auto output_cache = tv9->cacheBefore(); |
2326 | |
2327 | auto transform_ref = tv2; |
2328 | |
2329 | transform_ref->merge(0)->merge(0); |
2330 | |
2331 | int reduction_pos = 1; |
2332 | |
2333 | transform_ref->split(0, params.tidy); |
2334 | ++reduction_pos; |
2335 | transform_ref->axis(1)->parallelize(ParallelType::TIDy); |
2336 | |
2337 | // Persistent buffer |
2338 | transform_ref->split(0, params.persistent_buffer); |
2339 | ++reduction_pos; |
2340 | |
2341 | // Unswitch |
2342 | transform_ref->split(0, 1); |
2343 | ++reduction_pos; |
2344 | transform_ref->axis(1)->parallelize(ParallelType::Unswitch); |
2345 | |
2346 | transform_ref->axis(0)->parallelize(ParallelType::BIDy); |
2347 | |
2348 | transform_ref->split(reduction_pos, params.vect); |
2349 | transform_ref->axis(reduction_pos + 1) |
2350 | ->parallelize(ParallelType::Vectorize); |
2351 | |
2352 | transform_ref->split(reduction_pos, params.tidx); |
2353 | transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::TIDx); |
2354 | transform_ref->split(reduction_pos, params.bidx); |
2355 | transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::BIDx); |
2356 | |
2357 | auto transform_ref_rf = |
2358 | reduction_scheduler_utils::sortAndRFactor(transform_ref); |
2359 | |
2360 | TransformPropagator propagator(transform_ref_rf); |
2361 | MaxRootDomainInfoSpanningTree(transform_ref_rf).traverse(&propagator); |
2362 | |
2363 | int vec_id = std::distance( |
2364 | transform_ref_rf->domain()->domain().begin(), |
2365 | std::find_if( |
2366 | transform_ref_rf->domain()->domain().begin(), |
2367 | transform_ref_rf->domain()->domain().end(), |
2368 | [](auto id) { |
2369 | return id->getParallelType() == ParallelType::Vectorize; |
2370 | })); |
2371 | transform_ref_rf->axis(vec_id)->parallelize(ParallelType::Serial); |
2372 | |
2373 | int unswitch_id = std::distance( |
2374 | transform_ref_rf->domain()->domain().begin(), |
2375 | std::find_if( |
2376 | transform_ref_rf->domain()->domain().begin(), |
2377 | transform_ref_rf->domain()->domain().end(), |
2378 | [](auto id) { |
2379 | return id->getParallelType() == ParallelType::Unswitch; |
2380 | })); |
2381 | transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Serial); |
2382 | |
2383 | scheduler_utils::parallelizeAllLike( |
2384 | transform_ref_rf, ir_utils::allTvs(&fusion)); |
2385 | |
2386 | ParallelType vec_pt = ParallelType::Vectorize; |
2387 | tv1->axis(vec_id)->parallelize(vec_pt); |
2388 | tv9->axis(vec_id)->parallelize(vec_pt); |
2389 | |
2390 | transform_ref->axis(vec_id)->parallelize(ParallelType::Group); |
2391 | |
2392 | transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Unswitch); |
2393 | |
2394 | inlineMost(); |
2395 | |
2396 | // Make sure the reduction expr is converted to GroupedGridReduciton |
2397 | // and the non-reduction domains of the output TV are either |
2398 | // grouped or parallelized |
2399 | GpuLower gpulw(&fusion); |
2400 | bool validated = false; |
2401 | for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { |
2402 | auto grouped_grid_reduction = |
2403 | dynamic_cast<kir::GroupedGridWelford*>(expr); |
2404 | validated = true; |
2405 | } |
2406 | TORCH_CHECK( |
2407 | validated, "Invalid lowered kernel. No GroupedGridWelford found." ); |
2408 | |
2409 | auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
2410 | at::manual_seed(0); |
2411 | |
2412 | const std::vector<int64_t> input_shape{ |
2413 | params.N, params.H, params.W, params.C}; |
2414 | auto t0 = at::randn(input_shape, options); |
2415 | |
2416 | FusionExecutor fe; |
2417 | fe.compileFusion(&fusion, {t0}); |
2418 | |
2419 | // Skip the rest of this test size if the required number of SMs |
2420 | // exceeds the available SM count |
2421 | const auto num_required_sms = params.bidx * |
2422 | ceilDiv(ceilDiv(params.N * params.H * params.W, params.tidy), |
2423 | params.persistent_buffer); |
2424 | if (num_required_sms > deviceSMCount()) { |
2425 | return; |
2426 | } |
2427 | |
2428 | auto cg_outputs = fe.runFusion({t0}); |
2429 | |
2430 | auto t1 = t0.to(at::kDouble); |
2431 | auto t2 = t1.mean({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0); |
2432 | auto t3 = |
2433 | at::var(t1, {0, 1, 2}, false).unsqueeze(0).unsqueeze(0).unsqueeze(0); |
2434 | auto t4 = params.N * params.H * params.W; |
2435 | auto ref = (t1 - t2 + (t3 * t4)); |
2436 | |
2437 | testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__, "" ); |
2438 | }; |
2439 | |
2440 | std::vector<Params> base_params; |
2441 | base_params.push_back({256, 7, 7, 1, 8, 32, 2, 32, 4}); |
2442 | base_params.push_back({256, 7, 7, 1, 16, 16, 4, 50, 4}); |
2443 | base_params.push_back({128, 7, 7, 1, 16, 16, 4, 32, 4}); |
2444 | base_params.push_back({128, 14, 14, 1, 16, 16, 4, 32, 1}); |
2445 | base_params.push_back({128, 14, 14, 1, 16, 16, 2, 64, 2}); |
2446 | base_params.push_back({128, 14, 14, 1, 8, 32, 4, 50, 4}); |
2447 | base_params.push_back({128, 14, 14, 1, 8, 32, 2, 50, 4}); |
2448 | |
2449 | std::vector<Params> param_vec; |
2450 | for (const auto base_p : base_params) { |
2451 | for (const auto c_dim : {512, 1024, 2048}) { |
2452 | auto tmp = base_p; |
2453 | tmp.C = c_dim; |
2454 | param_vec.push_back(tmp); |
2455 | } |
2456 | } |
2457 | |
2458 | for (const auto& params : param_vec) { |
2459 | test(params); |
2460 | } |
2461 | } |
2462 | |
2463 | } // namespace jit |
2464 | } // namespace torch |
2465 | #endif // #if defined(USE_CUDA) |
2466 | |