1#include <gmock/gmock-matchers.h>
2#include <gtest/gtest.h>
3
4#include <ATen/cuda/CUDAGeneratorImpl.h>
5#include <c10/util/Optional.h>
6#include <arith.h>
7#include <fusion.h>
8#include <ir_all_nodes.h>
9#include <kernel_cache.h>
10#include <scheduler/all_schedulers.h>
11#include <test/test_gpu_validator.h>
12#include <test/test_utils.h>
13#include <ATen/cuda/CUDAGraphsUtils.cuh>
14
15#include <cassert>
16#include <type_traits>
17
18#include <curand.h>
19#include <curand_kernel.h>
20#include <curand_philox4x32_x.h>
21
22// Tests go in torch::jit
23namespace torch {
24namespace jit {
25
26using namespace torch::jit::fuser::cuda;
27
28namespace {
29
// Reference cuRAND uniform generator used to validate nvFuser's RNG output.
//
// Each thread computes its flat 1-D global index `tid`, initializes a
// Philox4_32_10 state with (seed, subsequence = tid, offset) unpacked from
// `philox_args`, and draws one vector of uniforms:
//   - T == double: curand_uniform2_double -> 2 values per thread
//   - T == float:  curand_uniform4        -> 4 values per thread
// Thread `tid` writes them to output[tid * N + k] (N = 2 or 4, 0 <= k < N),
// skipping any index >= size. The caller must launch at least
// ceil(size / N) threads as a 1-D grid of 1-D blocks. No shared memory is used.
template <typename T>
__global__ void generate_uniform_kernel(
    T* output,
    int64_t size,
    PhiloxCudaState philox_args) {
  static_assert(
      std::is_same<T, float>::value || std::is_same<T, double>::value,
      "generate_uniform_kernel only supports float and double");

  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and would silently wrap for huge launches.
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);

  // The condition is a compile-time constant. Both branches are still
  // compiled for every T (hence the implicit float<->double conversions in
  // the dead branch), but only one ever executes.
  if (std::is_same<T, double>::value) {
    double2 result = curand_uniform2_double(&state);
    if (tid * 2 < size) {
      output[tid * 2] = result.x;
    }
    if (tid * 2 + 1 < size) {
      output[tid * 2 + 1] = result.y;
    }
  } else {
    float4 result = curand_uniform4(&state);
    if (tid * 4 < size) {
      output[tid * 4] = result.x;
    }
    if (tid * 4 + 1 < size) {
      output[tid * 4 + 1] = result.y;
    }
    if (tid * 4 + 2 < size) {
      output[tid * 4 + 2] = result.z;
    }
    if (tid * 4 + 3 < size) {
      output[tid * 4 + 3] = result.w;
    }
  }
}
67
// Launches generate_uniform_kernel<T> on the current CUDA stream to fill the
// first `size` elements of `result`. The kernel writes output[i] linearly, so
// `result` is expected to be a contiguous CUDA tensor of element type T.
// The per-thread element count mirrors the kernel: 4 for float
// (curand_uniform4), 2 for double (curand_uniform2_double).
// Throws (via TORCH_CHECK) if the launch reports an error.
template <typename T>
void launch_generate_uniform_kernel(
    at::Tensor& result,
    int64_t size,
    PhiloxCudaState rng_engine_inputs) {
  constexpr int64_t block = 128;
  constexpr int64_t elems_per_thread = std::is_same<T, double>::value ? 2 : 4;
  constexpr int64_t block_elems = block * elems_per_thread;
  const int64_t grid = (size + block_elems - 1) / block_elems;
  generate_uniform_kernel<<<
      grid,
      block,
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<T>(), size, rng_engine_inputs);
  // Kernel launches do not return an error code; query it explicitly so a bad
  // launch configuration fails loudly instead of leaving `result` garbage.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(
      launch_err == cudaSuccess,
      "generate_uniform_kernel launch failed: ",
      cudaGetErrorString(launch_err));
}

// Returns a 1-D tensor on CUDA device 0 holding `size` uniform random values
// of `dtype` (only kFloat and kDouble are supported). The tests in this file
// use it as the cuRAND reference for nvFuser's rand()/uniform() output.
//
// Seed and offset come from the default CUDA generator. The generator is
// advanced with philox_cuda_state(4) on every call, including size == 0, so
// generator side effects do not depend on `size`. Reseeding with
// at::manual_seed before the call makes the output reproducible.
at::Tensor generate_uniform(int64_t size, at::ScalarType dtype) {
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
  auto result = at::empty({size}, options);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(
      c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(4);
  }

  TORCH_CHECK(
      dtype == kFloat || dtype == kDouble,
      "generate_uniform: expected kFloat or kDouble, got ",
      dtype);

  if (size == 0) {
    // Nothing to fill; a zero-block grid is an invalid launch configuration.
    return result;
  }

  if (dtype == kFloat) {
    launch_generate_uniform_kernel<float>(result, size, rng_engine_inputs);
  } else {
    launch_generate_uniform_kernel<double>(result, size, rng_engine_inputs);
  }
  return result;
}
105
106} // namespace
107
// Checks nvFuser's rand() against the cuRAND reference for float and double
// outputs of a runtime-sized 1-D tensor. The sizes include values that are not
// multiples of the per-thread vector width, plus one large case.
TEST_F(NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
  auto owned_fusion = std::make_unique<Fusion>();
  Fusion* fusion = owned_fusion.get();
  FusionGuard fg(fusion);

  Int* numel = IrBuilder::create<Int>();
  fusion->addInput(numel);
  TensorView* rand_f32 = rand({numel}, DataType::Float);
  TensorView* rand_f64 = rand({numel}, DataType::Double);
  fusion->addOutput(rand_f32);
  fusion->addOutput(rand_f64);

  FusionExecutorCache executor_cache(std::move(owned_fusion));

  const int64_t sizes[] = {16, 1024, 10001, 10002, 10003, 100000, 10000001};
  for (const int64_t size : sizes) {
    // Same seed for both sides. The reference draws float first, then
    // double, matching the fusion's output order.
    at::manual_seed(0);
    auto fused = executor_cache.runFusionWithInputs({size});

    at::manual_seed(0);
    auto expected_f32 = generate_uniform(size, kFloat);
    auto expected_f64 = generate_uniform(size, kDouble);

    testValidate(
        executor_cache.fusion(),
        fused,
        {size},
        {expected_f32, expected_f64},
        __LINE__,
        __FILE__);
  }
}
134
// Checks a manually scheduled rand_like() against the cuRAND reference.
// The output's single axis is split by 8, the outer axis is bound to TIDx,
// and the random tensor is computed at that outer axis.
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) {
  const int64_t size = 128;
  const auto dtype = kFloat;

  auto owned_fusion = std::make_unique<Fusion>();
  Fusion* fusion = owned_fusion.get();
  FusionGuard fg(fusion);

  TensorView* input_tv = makeSymbolicTensor(1, aten_to_data_type(dtype));
  fusion->addInput(input_tv);
  auto random_tv = rand_like(input_tv);
  auto output_tv = set(random_tv);
  fusion->addOutput(output_tv);

  // [I] -> [ceilDiv(I, 8) (TIDx), 8]
  output_tv->split(0, 8);
  output_tv->axis(0)->parallelize(ParallelType::TIDx);
  random_tv->computeAt(output_tv, 1);

  const auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
  at::Tensor input = at::zeros({size}, options);

  FusionExecutor fe;
  fe.compileFusion(fusion, {input});

  at::manual_seed(0);
  auto fused = fe.runFusion({input});

  at::manual_seed(0);
  auto expected = generate_uniform(size, dtype);

  testValidate(fusion, {fused[0]}, {input}, {expected}, __LINE__, __FILE__);
}
168
// Checks a 4-D rand() whose extents are runtime Int inputs against the cuRAND
// reference. For extents {10, 10, 10, 10}, the fusion output must equal 10000
// reference uniforms viewed as a [10, 10, 10, 10] tensor. No explicit
// scheduling calls are made in this test.
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand2_CUDA) {
#ifdef FBCODE_CAFFE2
  GTEST_SKIP() << "Fails accuracy on V100 32gb";
#endif
  auto dtype = kFloat;
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  Int* size1 = IrBuilder::create<Int>();
  Int* size2 = IrBuilder::create<Int>();
  Int* size3 = IrBuilder::create<Int>();
  Int* size4 = IrBuilder::create<Int>();
  fusion->addInput(size1);
  fusion->addInput(size2);
  fusion->addInput(size3);
  fusion->addInput(size4);
  // Derive the fusion dtype from `dtype` so it cannot drift from the reference.
  TensorView* tv0 =
      rand({size1, size2, size3, size4}, aten_to_data_type(dtype));
  fusion->addOutput(tv0);

  FusionExecutor fe;
  fe.compileFusion(fusion, {10, 10, 10, 10});

  at::manual_seed(0);
  auto cg_outputs = fe.runFusion({10, 10, 10, 10});
  auto out = cg_outputs[0];

  at::manual_seed(0);
  auto ref = generate_uniform(10000, dtype).view({10, 10, 10, 10});

  testValidate(fusion, {out}, {10, 10, 10, 10}, {ref}, __LINE__, __FILE__);
}
203
// rand_like() of a [5, 1] input is added to a [5, 5] tensor, so the random
// values are broadcast along dim 1 and every output column must be identical.
// Runs with the automatic scheduler for both float and double.
TEST_F(NVFuserTest, FusionBroadcastingRNG_CUDA) {
  for (auto dtype : {kFloat, kDouble}) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    auto fusion = fusion_ptr.get();
    FusionGuard fg(fusion);

    TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype));
    TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype));
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    auto tv2 = rand_like(tv0);
    auto tv3 = add(tv1, tv2);
    auto tv4 = add(tv0, tv3);
    fusion->addOutput(tv4);

    FusionExecutorCache fec(std::move(fusion_ptr));

    auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
    at::Tensor t0 = at::zeros({5, 1}, options);
    at::Tensor t1 = at::zeros({5, 5}, options);

    auto cg_outputs = fec.runFusionWithInputs({t0, t1});
    auto out = cg_outputs[0];
    // Every column must match column 0 exactly.
    for (int64_t col = 1; col < out.size(1); ++col) {
      TORCH_CHECK((out.select(1, 0) == out.select(1, col)).all().item<bool>());
    }
  }
}
233
// Broadcasts one random value (rand_like of a [1] tensor) across a symbolic
// 1-D input. Checks the result against a single cuRAND reference value
// expanded to the input's size, for several sizes and for float and double.
TEST_F(NVFuserTest, FusionBroadcastingRNG2_CUDA) {
  const int64_t sizes[] = {16, 1024, 10001, 10002, 10003, 100000, 10000001};
  const at::ScalarType dtypes[] = {kFloat, kDouble};
  for (const int64_t size : sizes) {
    for (const at::ScalarType dtype : dtypes) {
      auto owned_fusion = std::make_unique<Fusion>();
      Fusion* fusion = owned_fusion.get();
      FusionGuard fg(fusion);

      const auto nvf_dtype = aten_to_data_type(dtype);
      TensorView* scalar_tv = makeConcreteTensor({1}, nvf_dtype);
      TensorView* vector_tv = makeSymbolicTensor(1, nvf_dtype);
      fusion->addInput(scalar_tv);
      fusion->addInput(vector_tv);
      auto noise_tv = rand_like(scalar_tv);
      auto sum_tv = add(vector_tv, noise_tv);
      fusion->addOutput(sum_tv);

      FusionExecutorCache executor_cache(std::move(owned_fusion));

      const auto options =
          at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
      at::Tensor scalar_in = at::zeros({1}, options);
      at::Tensor vector_in = at::zeros({size}, options);

      at::manual_seed(0);
      auto fused = executor_cache.runFusionWithInputs({scalar_in, vector_in});

      at::manual_seed(0);
      auto expected = generate_uniform(1, dtype).expand_as(vector_in);

      testValidate(
          executor_cache.fusion(),
          {fused[0]},
          {scalar_in, vector_in},
          {expected},
          __LINE__,
          __FILE__);
    }
  }
}
266
// Same broadcast-RNG check as FusionBroadcastingRNG_CUDA, but scheduled with
// the transpose scheduler (heuristics from scheduleTranspose). rand_like() of
// a [5, 1] input is broadcast across a [5, 5] tensor, so every output column
// must be identical. Runs for float and double.
TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) {
  for (auto dtype : {kFloat, kDouble}) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    auto fusion = fusion_ptr.get();
    FusionGuard fg(fusion);

    TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype));
    TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype));
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    auto tv2 = rand_like(tv0);
    auto tv3 = add(tv1, tv2);
    auto tv4 = add(tv0, tv3);
    fusion->addOutput(tv4);

    auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
    at::Tensor t0 = at::zeros({5, 1}, options);
    at::Tensor t1 = at::zeros({5, 5}, options);

    auto lparams = scheduleTranspose(fusion, {t0, t1});

    FusionExecutor fe;
    fe.compileFusion(fusion, {t0, t1}, lparams);
    auto cg_outputs = fe.runFusion({t0, t1}, lparams);
    auto out = cg_outputs[0];

    // Every column must match column 0 exactly.
    for (int64_t col = 1; col < out.size(1); ++col) {
      TORCH_CHECK((out.select(1, 0) == out.select(1, col)).all().item<bool>());
    }
  }
}
299
// Regression test for https://github.com/csarofeen/pytorch/issues/1926.
// The transpose scheduler is driven with a non-square tile
// (tile_size1 = 8, tile_size2 = 4). The [5, 1] random tensor must still be
// broadcast so that every column of the [5, 5] output is identical.
TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
  auto owned_fusion = std::make_unique<Fusion>();
  Fusion* fusion = owned_fusion.get();
  FusionGuard fg(fusion);

  TensorView* narrow_tv = makeConcreteTensor({5, 1});
  TensorView* wide_tv = makeConcreteTensor({5, 5});
  fusion->addInput(narrow_tv);
  fusion->addInput(wide_tv);
  auto noise_tv = rand_like(narrow_tv);
  auto partial_tv = add(wide_tv, noise_tv);
  auto result_tv = add(narrow_tv, partial_tv);
  fusion->addOutput(result_tv);

  const auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  at::Tensor narrow_in = at::zeros({5, 1}, options);
  at::Tensor wide_in = at::zeros({5, 5}, options);

  TransposeParams params;
  params.tile_size1 = 8;
  params.tile_size2 = 4;
  scheduleTranspose(fusion, params);

  FusionExecutor fe;
  fe.compileFusion(fusion, {narrow_in, wide_in});
  auto fused = fe.runFusion({narrow_in, wide_in});
  auto out = fused[0];

  auto first_col = out.select(1, 0);
  for (int64_t col = 1; col < out.size(1); ++col) {
    TORCH_CHECK((first_col == out.select(1, col)).all().item<bool>());
  }
}
334
// Checks uniform(low, high) for float and double outputs, with low and high
// passed as runtime Double inputs (-1.0 and 1.0). The reference is the cuRAND
// uniform sample rescaled as u * 2 - 1.
TEST_F(NVFuserTest, FusionUniform_CUDA) {
  auto owned_fusion = std::make_unique<Fusion>();
  Fusion* fusion = owned_fusion.get();
  FusionGuard fg(fusion);

  Int* numel = IrBuilder::create<Int>();
  Double* lower = IrBuilder::create<Double>();
  Double* upper = IrBuilder::create<Double>();
  fusion->addInput(numel);
  fusion->addInput(lower);
  fusion->addInput(upper);
  TensorView* uniform_f32 = uniform({numel}, lower, upper, DataType::Float);
  TensorView* uniform_f64 = uniform({numel}, lower, upper, DataType::Double);
  fusion->addOutput(uniform_f32);
  fusion->addOutput(uniform_f64);

  FusionExecutorCache executor_cache(std::move(owned_fusion));

  const int64_t sizes[] = {16, 1024, 10001, 10002, 10003, 100000, 10000001};
  for (const int64_t size : sizes) {
    at::manual_seed(0);
    auto fused = executor_cache.runFusionWithInputs({size, -1.0, 1.0});

    // Reference draws float first, then double, matching the output order.
    at::manual_seed(0);
    auto expected_f32 = generate_uniform(size, kFloat) * 2 - 1;
    auto expected_f64 = generate_uniform(size, kDouble) * 2 - 1;

    testValidate(
        executor_cache.fusion(),
        fused,
        {size, -1.0, 1.0},
        {expected_f32, expected_f64},
        __LINE__,
        __FILE__);
  }
}
370
// Applies rand_like() to the result of a reduction. A [2, 3] zero tensor is
// summed over dim 0, uniform noise of the reduced shape is added, and the
// result is compared with the cuRAND reference.
TEST_F(NVFuserTest, FusionRandLikeReduction_CUDA) {
  const auto dtype = kFloat;
  auto owned_fusion = std::make_unique<Fusion>();
  Fusion* fusion = owned_fusion.get();
  FusionGuard fg(fusion);

  TensorView* input_tv = makeSymbolicTensor(2, aten_to_data_type(dtype));
  fusion->addInput(input_tv);
  auto reduced_tv = sum(input_tv, {0});
  auto noise_tv = rand_like(reduced_tv);
  auto result_tv = add(reduced_tv, noise_tv);
  fusion->addOutput(result_tv);

  FusionExecutorCache executor_cache(std::move(owned_fusion));

  const auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
  at::Tensor input = at::zeros({2, 3}, options);

  at::manual_seed(0);
  auto fused = executor_cache.runFusionWithInputs({input});

  at::manual_seed(0);
  auto reduced_ref = input.sum(0);
  auto noise_ref = generate_uniform(3, dtype).expand_as(reduced_ref);
  auto expected = reduced_ref.add(noise_ref);

  testValidate(
      executor_cache.fusion(),
      {fused[0]},
      {input},
      {expected},
      __LINE__,
      __FILE__);
}
400
401} // namespace jit
402} // namespace torch
403