1 | #include <gmock/gmock-matchers.h> |
2 | #include <gtest/gtest.h> |
3 | |
4 | #include <ATen/cuda/CUDAGeneratorImpl.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <arith.h> |
7 | #include <fusion.h> |
8 | #include <ir_all_nodes.h> |
9 | #include <kernel_cache.h> |
10 | #include <scheduler/all_schedulers.h> |
11 | #include <test/test_gpu_validator.h> |
12 | #include <test/test_utils.h> |
13 | #include <ATen/cuda/CUDAGraphsUtils.cuh> |
14 | |
15 | #include <cassert> |
16 | #include <type_traits> |
17 | |
18 | #include <curand.h> |
19 | #include <curand_kernel.h> |
20 | #include <curand_philox4x32_x.h> |
21 | |
22 | // Tests go in torch::jit |
23 | namespace torch { |
24 | namespace jit { |
25 | |
26 | using namespace torch::jit::fuser::cuda; |
27 | |
28 | namespace { |
29 | |
// Reference uniform random number kernel built directly on cuRAND's Philox
// generator. The tests in this file compare fusion outputs against it.
//
// Each thread initializes a Philox4_32_10 state from the unpacked philox
// arguments, using its global thread index as the subsequence, draws a single
// vector of values, and stores them to consecutive output elements:
//   - T == double: 2 values per thread (curand_uniform2_double)
//   - T == float:  4 values per thread (curand_uniform4)
//
// Launch layout: 1D grid of 1D blocks. Every store is individually
// bounds-checked against `size`, so the launch only has to cover
// ceil(size / values_per_thread) threads. No shared memory is used.
template <typename T>
__global__ void generate_uniform_kernel(
    T* output,
    int64_t size,
    PhiloxCudaState philox_args) {
  // Only the two element types handled below are supported. Reject anything
  // else at compile time instead of relying on a device-side assert, which is
  // compiled out when NDEBUG is defined.
  static_assert(
      std::is_same<T, float>::value || std::is_same<T, double>::value,
      "generate_uniform_kernel only supports float and double");

  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for very large launches.
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);

  if (std::is_same<T, double>::value) {
    // Two doubles per thread.
    const double2 result = curand_uniform2_double(&state);
    const int64_t base = tid * 2;
    if (base < size) {
      output[base] = result.x;
    }
    if (base + 1 < size) {
      output[base + 1] = result.y;
    }
  } else {
    // Four floats per thread.
    const float4 result = curand_uniform4(&state);
    const int64_t base = tid * 4;
    if (base < size) {
      output[base] = result.x;
    }
    if (base + 1 < size) {
      output[base + 1] = result.y;
    }
    if (base + 2 < size) {
      output[base + 2] = result.z;
    }
    if (base + 3 < size) {
      output[base + 3] = result.w;
    }
  }
}
67 | |
// Builds a 1D tensor of `size` uniform random values on CUDA device 0 by
// launching generate_uniform_kernel on the current CUDA stream.
//
// Parameters:
//   size  - number of elements to generate
//   dtype - kFloat or kDouble
// Returns the filled tensor. For size == 0 an empty tensor is returned
// without launching a kernel.
// Throws c10::Error (via TORCH_CHECK) when dtype is unsupported or the kernel
// launch fails.
//
// The Philox state comes from the default CUDA generator; the tests in this
// file call at::manual_seed() right before calling this function so that the
// reference and the fusion start from the same generator state.
at::Tensor generate_uniform(int64_t size, at::ScalarType dtype) {
  // Validate up front so an unsupported dtype does not advance the generator.
  TORCH_CHECK(
      dtype == kFloat || dtype == kDouble,
      "generate_uniform only supports kFloat and kDouble");

  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
  auto result = at::empty({size}, options);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(
      c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // NOTE(review): the argument 4 is presumably the per-thread Philox offset
    // increment (one 4x32-bit draw per thread) -- confirm against
    // CUDAGeneratorImpl::philox_cuda_state.
    rng_engine_inputs = gen->philox_cuda_state(4);
  }

  // A grid with zero blocks is an invalid launch configuration, and there is
  // nothing to fill anyway.
  if (size == 0) {
    return result;
  }

  // Values written per thread must match the kernel: 4 floats or 2 doubles.
  constexpr int64_t block = 128;
  const int64_t elems_per_thread = (dtype == kFloat) ? 4 : 2;
  const int64_t block_elems = block * elems_per_thread;
  const int64_t grid = (size + block_elems - 1) / block_elems;

  if (dtype == kFloat) {
    generate_uniform_kernel<<<
        grid,
        block,
        0,
        at::cuda::getCurrentCUDAStream()>>>(
        result.data_ptr<float>(), size, rng_engine_inputs);
  } else {
    generate_uniform_kernel<<<
        grid,
        block,
        0,
        at::cuda::getCurrentCUDAStream()>>>(
        result.data_ptr<double>(), size, rng_engine_inputs);
  }

  // Kernel launches do not return a status; configuration errors are only
  // visible through the runtime's last-error slot. Surface them here instead
  // of handing back an unfilled tensor.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(
      launch_status == cudaSuccess,
      "generate_uniform_kernel launch failed: ",
      cudaGetErrorString(launch_status));

  return result;
}
105 | |
106 | } // namespace |
107 | |
// Runs a fusion producing a float and a double rand() tensor through
// FusionExecutorCache and validates both outputs against generate_uniform()
// for several 1D sizes.
TEST_F(NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
  auto fusion_owner = std::make_unique<Fusion>();
  Fusion* fusion = fusion_owner.get();
  FusionGuard guard(fusion);

  // The extent is a scalar fusion input shared by both random tensors.
  Int* extent = IrBuilder::create<Int>();
  fusion->addInput(extent);
  TensorView* float_rand = rand({extent}, DataType::Float);
  TensorView* double_rand = rand({extent}, DataType::Double);
  fusion->addOutput(float_rand);
  fusion->addOutput(double_rand);

  FusionExecutorCache executor_cache(std::move(fusion_owner));

  for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
    // Seed identically before the fusion run and before the reference run.
    at::manual_seed(0);
    auto fusion_outputs = executor_cache.runFusionWithInputs({size});

    at::manual_seed(0);
    // Reference order matters: float first, then double, as in the fusion.
    auto expected_float = generate_uniform(size, kFloat);
    auto expected_double = generate_uniform(size, kDouble);

    testValidate(
        executor_cache.fusion(),
        fusion_outputs,
        {size},
        {expected_float, expected_double},
        __LINE__,
        __FILE__);
  }
}
134 | |
// Validates a hand-scheduled rand_like fusion, compiled and run with
// FusionExecutor, against the generate_uniform() reference.
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) {
  const int64_t num_elems = 128;
  const auto scalar_type = kFloat;

  auto fusion_owner = std::make_unique<Fusion>();
  Fusion* fusion = fusion_owner.get();
  FusionGuard guard(fusion);

  TensorView* in_tv = makeSymbolicTensor(1, aten_to_data_type(scalar_type));
  fusion->addInput(in_tv);
  auto rand_tv = rand_like(in_tv);
  auto out_tv = set(rand_tv);
  fusion->addOutput(out_tv);

  // Manual schedule: split axis 0 of the output by 8, bind axis 0 of the
  // result to TIDx, then compute the random tensor at position 1 of the
  // output.
  out_tv->split(0, 8);
  out_tv->axis(0)->parallelize(ParallelType::TIDx);
  rand_tv->computeAt(out_tv, 1);

  auto tensor_opts =
      at::TensorOptions().dtype(scalar_type).device(at::kCUDA, 0);
  at::Tensor input = at::zeros({num_elems}, tensor_opts);

  FusionExecutor executor;
  executor.compileFusion(fusion, {input});

  // Seed identically before the fusion run and before the reference run.
  at::manual_seed(0);
  auto fusion_outputs = executor.runFusion({input});
  auto actual = fusion_outputs[0];

  at::manual_seed(0);
  auto expected = generate_uniform(num_elems, scalar_type);

  testValidate(fusion, {actual}, {input}, {expected}, __LINE__, __FILE__);
}
168 | |
// Validates a 4D rand() whose extents are all scalar fusion inputs against
// the generate_uniform() reference. The reference is generated as a flat
// 10000-element tensor and viewed as {10, 10, 10, 10} to match the output.
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand2_CUDA) {
#ifdef FBCODE_CAFFE2
  GTEST_SKIP() << "Fails accuracy on V100 32gb";
#endif
  auto dtype = kFloat;
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  // One scalar input per output dimension.
  Int* size1 = IrBuilder::create<Int>();
  Int* size2 = IrBuilder::create<Int>();
  Int* size3 = IrBuilder::create<Int>();
  Int* size4 = IrBuilder::create<Int>();
  fusion->addInput(size1);
  fusion->addInput(size2);
  fusion->addInput(size3);
  fusion->addInput(size4);
  TensorView* tv0 = rand({size1, size2, size3, size4}, DataType::Float);
  fusion->addOutput(tv0);

  FusionExecutor fe;
  fe.compileFusion(fusion, {10, 10, 10, 10});

  // Seed identically before the fusion run and before the reference run.
  at::manual_seed(0);
  auto cg_outputs = fe.runFusion({10, 10, 10, 10});
  auto out = cg_outputs[0];

  at::manual_seed(0);
  auto ref = generate_uniform(10000, dtype).view({10, 10, 10, 10});

  testValidate(fusion, {out}, {10, 10, 10, 10}, {ref}, __LINE__, __FILE__);
}
203 | |
// Checks that a rand_like() of a {5, 1} tensor, once added to a {5, 5}
// tensor, yields the same value in every column of a row. Both inputs are
// zero, so the output consists of the random values alone. Runs for float
// and double.
TEST_F(NVFuserTest, FusionBroadcastingRNG_CUDA) {
  for (auto dtype : {kFloat, kDouble}) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    auto fusion = fusion_ptr.get();
    FusionGuard fg(fusion);

    TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype));
    TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype));
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    auto tv2 = rand_like(tv0);
    auto tv3 = add(tv1, tv2);
    auto tv4 = add(tv0, tv3);
    fusion->addOutput(tv4);

    FusionExecutorCache fec(std::move(fusion_ptr));

    auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
    at::Tensor t0 = at::zeros({5, 1}, options);
    at::Tensor t1 = at::zeros({5, 5}, options);

    auto cg_outputs = fec.runFusionWithInputs({t0, t1});
    auto out = cg_outputs[0];

    // Every column must equal column 0. The bound comes from the {5, 5}
    // input so a too-narrow output still fails inside select().
    const auto first_col = out.select(1, 0);
    for (int64_t col = 1; col < t1.size(1); ++col) {
      TORCH_CHECK(
          (first_col == out.select(1, col)).all().item<bool>(),
          "column ",
          col,
          " differs from column 0 for dtype ",
          dtype);
    }
  }
}
233 | |
// For several sizes and both float and double: draws one random value via
// rand_like() on a {1} tensor, adds it to a zero tensor of `size` elements,
// and validates against a single reference value expanded to that size.
TEST_F(NVFuserTest, FusionBroadcastingRNG2_CUDA) {
  for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
    for (auto dtype : {kFloat, kDouble}) {
      auto fusion_owner = std::make_unique<Fusion>();
      Fusion* fusion = fusion_owner.get();
      FusionGuard guard(fusion);

      TensorView* scalar_like_tv =
          makeConcreteTensor({1}, aten_to_data_type(dtype));
      TensorView* vector_tv = makeSymbolicTensor(1, aten_to_data_type(dtype));
      fusion->addInput(scalar_like_tv);
      fusion->addInput(vector_tv);
      auto rand_tv = rand_like(scalar_like_tv);
      auto sum_tv = add(vector_tv, rand_tv);
      fusion->addOutput(sum_tv);

      FusionExecutorCache executor_cache(std::move(fusion_owner));

      auto tensor_opts = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
      at::Tensor t0 = at::zeros({1}, tensor_opts);
      at::Tensor t1 = at::zeros({size}, tensor_opts);

      // Seed identically before the fusion run and before the reference run.
      at::manual_seed(0);
      auto fusion_outputs = executor_cache.runFusionWithInputs({t0, t1});
      auto actual = fusion_outputs[0];

      at::manual_seed(0);
      auto expected = generate_uniform(1, dtype).expand_as(t1);

      testValidate(
          executor_cache.fusion(),
          {actual},
          {t0, t1},
          {expected},
          __LINE__,
          __FILE__);
    }
  }
}
266 | |
// Same broadcast check as FusionBroadcastingRNG_CUDA, but the fusion is
// scheduled with scheduleTranspose() and run through FusionExecutor with the
// launch parameters that scheduler returns. Both inputs are zero, so the
// output consists of the random values alone. Runs for float and double.
TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) {
  for (auto dtype : {kFloat, kDouble}) {
    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
    auto fusion = fusion_ptr.get();
    FusionGuard fg(fusion);

    TensorView* tv0 = makeConcreteTensor({5, 1}, aten_to_data_type(dtype));
    TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype));
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    auto tv2 = rand_like(tv0);
    auto tv3 = add(tv1, tv2);
    auto tv4 = add(tv0, tv3);
    fusion->addOutput(tv4);

    auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
    at::Tensor t0 = at::zeros({5, 1}, options);
    at::Tensor t1 = at::zeros({5, 5}, options);

    auto lparams = scheduleTranspose(fusion, {t0, t1});

    FusionExecutor fe;
    fe.compileFusion(fusion, {t0, t1}, lparams);
    auto cg_outputs = fe.runFusion({t0, t1}, lparams);
    auto out = cg_outputs[0];

    // Every column must equal column 0. The bound comes from the {5, 5}
    // input so a too-narrow output still fails inside select().
    const auto first_col = out.select(1, 0);
    for (int64_t col = 1; col < t1.size(1); ++col) {
      TORCH_CHECK(
          (first_col == out.select(1, col)).all().item<bool>(),
          "column ",
          col,
          " differs from column 0 for dtype ",
          dtype);
    }
  }
}
299 | |
// Broadcast RNG check using scheduleTranspose() with explicit, non-square
// tile sizes (8 x 4).
// Regression test for https://github.com/csarofeen/pytorch/issues/1926
TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
  auto fusion_owner = std::make_unique<Fusion>();
  Fusion* fusion = fusion_owner.get();
  FusionGuard guard(fusion);

  TensorView* column_tv = makeConcreteTensor({5, 1});
  TensorView* matrix_tv = makeConcreteTensor({5, 5});
  fusion->addInput(column_tv);
  fusion->addInput(matrix_tv);
  auto rand_tv = rand_like(column_tv);
  auto partial_sum_tv = add(matrix_tv, rand_tv);
  auto result_tv = add(column_tv, partial_sum_tv);
  fusion->addOutput(result_tv);

  auto tensor_opts = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::zeros({5, 1}, tensor_opts);
  at::Tensor t1 = at::zeros({5, 5}, tensor_opts);

  // Force a non-square tile instead of using computed heuristics.
  TransposeParams tile_params;
  tile_params.tile_size1 = 8;
  tile_params.tile_size2 = 4;
  scheduleTranspose(fusion, tile_params);

  FusionExecutor executor;
  executor.compileFusion(fusion, {t0, t1});
  auto fusion_outputs = executor.runFusion({t0, t1});
  auto out = fusion_outputs[0];

  // Both inputs are zero, so each row holds one broadcast random value:
  // columns 1..4 must match column 0.
  TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
}
334 | |
// Validates uniform(low, high) for float and double outputs. The fusion is
// run with low = -1.0 and high = 1.0, and the reference is generate_uniform()
// scaled as `x * 2 - 1`.
TEST_F(NVFuserTest, FusionUniform_CUDA) {
  auto fusion_owner = std::make_unique<Fusion>();
  Fusion* fusion = fusion_owner.get();
  FusionGuard guard(fusion);

  // Scalar inputs: tensor extent and the two bounds of the distribution.
  Int* extent = IrBuilder::create<Int>();
  Double* lower_bound = IrBuilder::create<Double>();
  Double* upper_bound = IrBuilder::create<Double>();
  fusion->addInput(extent);
  fusion->addInput(lower_bound);
  fusion->addInput(upper_bound);
  TensorView* float_uniform =
      uniform({extent}, lower_bound, upper_bound, DataType::Float);
  TensorView* double_uniform =
      uniform({extent}, lower_bound, upper_bound, DataType::Double);
  fusion->addOutput(float_uniform);
  fusion->addOutput(double_uniform);

  FusionExecutorCache executor_cache(std::move(fusion_owner));

  for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
    // Seed identically before the fusion run and before the reference run.
    at::manual_seed(0);
    auto fusion_outputs =
        executor_cache.runFusionWithInputs({size, -1.0, 1.0});

    at::manual_seed(0);
    // Reference order matters: float first, then double, as in the fusion.
    auto expected_float = generate_uniform(size, kFloat) * 2 - 1;
    auto expected_double = generate_uniform(size, kDouble) * 2 - 1;

    testValidate(
        executor_cache.fusion(),
        fusion_outputs,
        {size, -1.0, 1.0},
        {expected_float, expected_double},
        __LINE__,
        __FILE__);
  }
}
370 | |
// Validates rand_like() applied to the result of a reduction: the fusion sums
// a 2D input over axis 0, adds rand_like() of that sum, and is compared with
// the same computation done with ATen and generate_uniform().
TEST_F(NVFuserTest, FusionRandLikeReduction_CUDA) {
  const auto scalar_type = kFloat;

  auto fusion_owner = std::make_unique<Fusion>();
  Fusion* fusion = fusion_owner.get();
  FusionGuard guard(fusion);

  TensorView* in_tv = makeSymbolicTensor(2, aten_to_data_type(scalar_type));
  fusion->addInput(in_tv);
  auto reduced_tv = sum(in_tv, {0});
  auto rand_tv = rand_like(reduced_tv);
  auto result_tv = add(reduced_tv, rand_tv);
  fusion->addOutput(result_tv);

  FusionExecutorCache executor_cache(std::move(fusion_owner));

  auto tensor_opts =
      at::TensorOptions().dtype(scalar_type).device(at::kCUDA, 0);
  at::Tensor input = at::zeros({2, 3}, tensor_opts);

  // Seed identically before the fusion run and before the reference run.
  at::manual_seed(0);
  auto fusion_outputs = executor_cache.runFusionWithInputs({input});
  auto actual = fusion_outputs[0];

  at::manual_seed(0);
  // ATen reference: reduce over axis 0, then add 3 reference random values.
  auto reduced_ref = input.sum(0);
  auto rand_ref = generate_uniform(3, scalar_type).expand_as(reduced_ref);
  auto expected = reduced_ref.add(rand_ref);

  testValidate(
      executor_cache.fusion(),
      {actual},
      {input},
      {expected},
      __LINE__,
      __FILE__);
}
400 | |
401 | } // namespace jit |
402 | } // namespace torch |
403 | |