1 | #ifdef USE_CUDA |
2 | |
3 | #include <cmath> |
4 | #include <sstream> |
5 | #include <stdexcept> |
6 | |
7 | #include <gtest/gtest.h> |
8 | |
9 | #include <test/cpp/tensorexpr/test_base.h> |
10 | |
11 | #include <test/cpp/tensorexpr/padded_buffer.h> |
12 | #include <torch/csrc/jit/tensorexpr/cuda_codegen.h> |
13 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
14 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
15 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
16 | #include <torch/csrc/jit/testing/file_check.h> |
17 | |
20 | #include <c10/cuda/CUDACachingAllocator.h> |
21 | #include <c10/util/Half.h> |
22 | #include <c10/util/irange.h> |
23 | |
24 | namespace torch { |
25 | namespace jit { |
26 | using namespace torch::jit::tensorexpr; |
28 | |
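// Element-wise addition of two {num_iter, block_count, block_size} buffers of
// the given element type. The block_count loop is bound to blockIdx.x and the
// block_size loop to threadIdx.x; the device result is compared against a
// host-computed reference.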
29 | template <typename ctype> |
30 | static void testCudaTestVectorAdd01_impl() { |
31 | const int num_iter = 3; |
32 | const int block_count = 16; |
33 | const int block_size = 128; |
34 | Dtype dtype = ToDtype<ctype>(); |
  BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype);
  BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype);
  Tensor c = Compute(
      "c",
39 | { |
40 | num_iter, |
41 | block_count, |
42 | block_size, |
43 | }, |
44 | [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { |
45 | return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); |
46 | }); |
47 | LoopNest l({c}); |
48 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
49 | loops[1]->set_gpu_block_index(0); |
50 | loops[2]->set_gpu_thread_index(0); |
51 | l.prepareForCodegen(); |
52 | StmtPtr stmt = l.root_stmt(); |
53 | CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); |
54 | const int N = block_count * block_size * num_iter; |
55 | PaddedBuffer<ctype> a_v(N); |
56 | PaddedBuffer<ctype> b_v(N); |
57 | PaddedBuffer<ctype> c_v(N); |
58 | PaddedBuffer<ctype> c_ref(N); |
59 | |
60 | for (const auto i : c10::irange(N)) { |
61 | a_v(i) = ctype(i); |
62 | b_v(i) = ctype(i * 3 + 7); |
63 | c_ref(i) = a_v(i) + b_v(i); |
64 | } |
65 | |
66 | // TODO: move gpu support into PaddedBuffer |
67 | ctype* a_dev = nullptr; |
68 | C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); |
69 | ctype* b_dev = nullptr; |
70 | C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); |
71 | ctype* c_dev = nullptr; |
72 | C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); |
73 | C10_CUDA_CHECK( |
74 | cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); |
75 | C10_CUDA_CHECK( |
76 | cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); |
77 | C10_CUDA_CHECK( |
78 | cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); |
79 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
80 | |
81 | cuda_cg(c_dev, a_dev, b_dev); |
82 | |
83 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
84 | C10_CUDA_CHECK( |
85 | cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); |
86 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
87 | |
88 | ExpectAllNear(c_v, c_ref, 1e-5); |
89 | |
90 | C10_CUDA_CHECK(cudaFree(a_dev)); |
91 | C10_CUDA_CHECK(cudaFree(b_dev)); |
92 | C10_CUDA_CHECK(cudaFree(c_dev)); |
93 | } |
94 | |
static float sigmoid(float x) {
  return 1.0f / (1.0f + expf(-x));
}
98 | |
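// Applies the sigmoid twice to each element on the GPU and compares against
// the host sigmoid() helper above.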
99 | TEST(Cuda, Sigmoid_CUDA) { |
100 | const int num_iter = 3; |
101 | const int block_count = 16; |
102 | const int block_size = 128; |
103 | Dtype dtype = ToDtype<float>(); |
  BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype);
  Tensor c = Compute(
      "c",
107 | { |
108 | num_iter, |
109 | block_count, |
110 | block_size, |
111 | }, |
112 | [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { |
113 | return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); |
114 | }); |
115 | LoopNest l({c}); |
116 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
117 | loops[1]->set_gpu_block_index(0); |
118 | loops[2]->set_gpu_thread_index(0); |
119 | l.prepareForCodegen(); |
120 | StmtPtr stmt = l.root_stmt(); |
121 | CudaCodeGen cuda_cg(stmt, c, a_buf); |
122 | const int N = block_count * block_size * num_iter; |
123 | PaddedBuffer<float> a_v(N); |
124 | PaddedBuffer<float> c_v(N); |
125 | PaddedBuffer<float> c_ref(N); |
126 | |
127 | for (const auto i : c10::irange(N)) { |
128 | a_v(i) = float(i); |
129 | c_ref(i) = sigmoid(sigmoid(a_v(i))); |
130 | } |
131 | |
132 | // TODO: move gpu support into PaddedBuffer |
133 | float* a_dev = nullptr; |
134 | C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); |
135 | float* c_dev = nullptr; |
136 | C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); |
137 | C10_CUDA_CHECK( |
138 | cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
139 | C10_CUDA_CHECK( |
140 | cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
141 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
142 | |
143 | cuda_cg(c_dev, a_dev); |
144 | |
145 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
146 | C10_CUDA_CHECK( |
147 | cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); |
148 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
149 | |
150 | ExpectAllNear(c_v, c_ref, 1e-5); |
151 | |
152 | C10_CUDA_CHECK(cudaFree(a_dev)); |
153 | C10_CUDA_CHECK(cudaFree(c_dev)); |
154 | } |
155 | |
156 | TEST(Cuda, TestVectorAdd01_CUDA) { |
157 | // floating types. |
158 | testCudaTestVectorAdd01_impl<float>(); |
159 | testCudaTestVectorAdd01_impl<at::Half>(); |
160 | testCudaTestVectorAdd01_impl<double>(); |
161 | |
162 | // integer types. |
163 | testCudaTestVectorAdd01_impl<int8_t>(); |
164 | testCudaTestVectorAdd01_impl<uint8_t>(); |
165 | testCudaTestVectorAdd01_impl<int16_t>(); |
166 | testCudaTestVectorAdd01_impl<int32_t>(); |
167 | testCudaTestVectorAdd01_impl<int64_t>(); |
168 | } |
169 | |
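// Flat 1-D element-wise addition with a split loop: splitWithMask divides the
// single loop by block_size, the outer loop is bound to blockIdx.x and the
// inner loop to threadIdx.x. Roughly, the resulting structure is:
//   for n_outer in 0..ceil(N / block_size):    // blockIdx.x
//     for n_inner in 0..block_size:            // threadIdx.x
//       if n_outer * block_size + n_inner < N: // mask, needed only when N is
//         c[...] = a[...] + b[...]             // not a multiple of block_size
// Called below with N = 1024 (exact multiple) and N = 1030 (needs the mask).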
170 | static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { |
  BufHandle a_buf("a", {N}, kFloat);
  BufHandle b_buf("b", {N}, kFloat);
  Tensor c = Compute("c", {N}, [&](const VarHandle& n) {
174 | return a_buf.load(n) + b_buf.load(n); |
175 | }); |
176 | LoopNest l({c}); |
177 | ForPtr n_inner; |
178 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
179 | l.splitWithMask(loops[0], block_size, &n_inner); |
180 | loops[0]->set_gpu_block_index(0); |
181 | n_inner->set_gpu_thread_index(0); |
182 | l.prepareForCodegen(); |
183 | StmtPtr stmt = l.root_stmt(); |
184 | CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); |
185 | PaddedBuffer<float> a_v(N); |
186 | PaddedBuffer<float> b_v(N); |
187 | PaddedBuffer<float> c_v(N); |
188 | PaddedBuffer<float> c_ref(N); |
189 | |
190 | for (const auto i : c10::irange(N)) { |
191 | a_v(i) = i; |
192 | b_v(i) = i * 3 + 7; |
193 | c_ref(i) = a_v(i) + b_v(i); |
194 | } |
195 | |
196 | // TODO: move gpu support into PaddedBuffer |
197 | float* a_dev = nullptr; |
198 | C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); |
199 | float* b_dev = nullptr; |
200 | C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); |
201 | float* c_dev = nullptr; |
202 | C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); |
203 | C10_CUDA_CHECK( |
204 | cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
205 | C10_CUDA_CHECK( |
206 | cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
207 | C10_CUDA_CHECK( |
208 | cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
209 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
210 | |
211 | cuda_cg(c_dev, a_dev, b_dev); |
212 | |
213 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
214 | C10_CUDA_CHECK( |
215 | cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); |
216 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
217 | |
218 | ExpectAllNear(c_v, c_ref, 1e-5); |
219 | |
220 | C10_CUDA_CHECK(cudaFree(a_dev)); |
221 | C10_CUDA_CHECK(cudaFree(b_dev)); |
222 | C10_CUDA_CHECK(cudaFree(c_dev)); |
223 | } |
224 | |
225 | TEST(Cuda, TestVectorAdd02_CUDA) { |
226 | testCudaTestVectorAdd02_impl(1024, 128); |
227 | testCudaTestVectorAdd02_impl(1030, 128); |
228 | } |
229 | |
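// Casts a Half input buffer to Float on the GPU and checks that the values
// read back unchanged.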
230 | TEST(Cuda, HalfCast_CUDA) { |
231 | auto half = ToDtype<at::Half>(); |
  BufHandle a("a", {4}, half);
  Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
234 | return Cast::make(kFloat, a.load(i)); |
235 | }); |
236 | |
237 | LoopNest l({b}); |
238 | l.prepareForCodegen(); |
239 | StmtPtr s = l.root_stmt(); |
240 | CudaCodeGen cg(s, {a, b}); |
241 | |
242 | std::vector<at::Half> aData(4, 2.0f); |
243 | std::vector<float> bData(4, 0.0f); |
244 | at::Half* aDev = nullptr; |
245 | float* bDev = nullptr; |
246 | auto aSize = aData.size() * sizeof(aData[0]); |
247 | auto bSize = bData.size() * sizeof(bData[0]); |
248 | |
249 | C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); |
250 | C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); |
251 | C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); |
252 | C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); |
253 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
254 | |
255 | cg.call({aDev, bDev}); |
256 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
257 | |
258 | C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); |
259 | C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); |
260 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
261 | |
262 | assertAllEqual(bData, 2.0f); |
263 | |
264 | C10_CUDA_CHECK(cudaFree(aDev)); |
265 | C10_CUDA_CHECK(cudaFree(bDev)); |
266 | } |
267 | |
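// Element-wise 2-D addition where the extents m and n are passed to the kernel
// as arguments rather than compiled in; exercised with several shapes below.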
268 | TEST(Cuda, DynamicShape2D_CUDA) { |
269 | auto testWithSize = [](int32_t M, int32_t N) { |
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    BufHandle a("a", {m, n}, kFloat);
    BufHandle b("b", {m, n}, kFloat);
    Tensor c =
        Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
276 | return a.load(i, j) + b.load(i, j); |
277 | }); |
278 | LoopNest l({c}); |
279 | l.prepareForCodegen(); |
280 | StmtPtr s = l.root_stmt(); |
281 | CudaCodeGen cg(s, {a, b, c, m, n}); |
282 | |
283 | std::vector<float> aData(M * N, 1.0f); |
284 | std::vector<float> bData(M * N, 2.0f); |
285 | std::vector<float> cData(M * N, 0.0f); |
286 | float* aDev = nullptr; |
287 | float* bDev = nullptr; |
288 | float* cDev = nullptr; |
289 | C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); |
290 | C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); |
291 | C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); |
292 | C10_CUDA_CHECK(cudaMemcpy( |
293 | aDev, |
294 | aData.data(), |
295 | aData.size() * sizeof(aData[0]), |
296 | cudaMemcpyHostToDevice)); |
297 | C10_CUDA_CHECK(cudaMemcpy( |
298 | bDev, |
299 | bData.data(), |
300 | bData.size() * sizeof(bData[0]), |
301 | cudaMemcpyHostToDevice)); |
302 | C10_CUDA_CHECK(cudaMemcpy( |
303 | cDev, |
304 | cData.data(), |
305 | cData.size() * sizeof(cData[0]), |
306 | cudaMemcpyHostToDevice)); |
307 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
308 | |
309 | cg.call({aDev, bDev, cDev, M, N}); |
310 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
311 | |
312 | C10_CUDA_CHECK(cudaMemcpy( |
313 | cData.data(), |
314 | cDev, |
315 | cData.size() * sizeof(cData[0]), |
316 | cudaMemcpyDeviceToHost)); |
317 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
318 | |
319 | ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7); |
320 | |
321 | C10_CUDA_CHECK(cudaFree(aDev)); |
322 | C10_CUDA_CHECK(cudaFree(bDev)); |
323 | C10_CUDA_CHECK(cudaFree(cDev)); |
324 | }; |
325 | testWithSize(32, 32); |
326 | testWithSize(1, 16); |
327 | testWithSize(27, 13); |
328 | } |
329 | |
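// Fills a buffer with the kRand intrinsic and sanity-checks that all values
// lie in [0, 1) and that the first three sample moments are close to those of
// a uniform distribution.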
330 | TEST(Cuda, TestRand01_CUDA) { |
331 | const int num_iter = 3; |
332 | const int block_count = 16; |
333 | const int block_size = 128; |
  Tensor c = Compute(
      "c",
336 | { |
337 | num_iter, |
338 | block_count, |
339 | block_size, |
340 | }, |
341 | [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { |
342 | return Intrinsics::make(IntrinsicsOp::kRand, kFloat); |
343 | }); |
344 | LoopNest l({c}); |
345 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
346 | loops[1]->set_gpu_block_index(0); |
347 | loops[2]->set_gpu_thread_index(0); |
348 | l.prepareForCodegen(); |
349 | StmtPtr stmt = l.root_stmt(); |
350 | CudaCodeGen cuda_cg(stmt, c); |
351 | const int N = block_count * block_size * num_iter; |
352 | PaddedBuffer<float> c_v(N); |
353 | |
354 | // TODO: move gpu support into PaddedBuffer |
355 | float* c_dev = nullptr; |
356 | C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); |
357 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
358 | |
359 | cuda_cg(c_dev); |
360 | |
361 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
362 | C10_CUDA_CHECK( |
363 | cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); |
364 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
365 | |
366 | float sum1 = 0; |
367 | float sum2 = 0; |
368 | float sum3 = 0; |
369 | for (const auto i : c10::irange(N)) { |
370 | float v = c_v.data()[i]; |
371 | sum1 += v; |
372 | sum2 += v * v; |
373 | sum3 += v * v * v; |
374 | ASSERT_TRUE(v >= 0 && v < 1); |
375 | } |
376 | sum1 /= N; |
377 | sum2 /= N; |
378 | sum3 /= N; |
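  // The k-th raw moment of a uniform [0, 1) variable is 1 / (k + 1), so the
  // averages above should approach 1/2, 1/3 and 1/4.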
379 | float sum1_mean = 1.f / 2; |
380 | float sum2_mean = 1.f / 3; |
381 | float sum3_mean = 1.f / 4; |
382 | |
383 | ASSERT_NEAR(sum1, sum1_mean, 2e-2); |
384 | ASSERT_NEAR(sum2, sum2_mean, 2e-2); |
385 | ASSERT_NEAR(sum3, sum3_mean, 2e-2); |
386 | C10_CUDA_CHECK(cudaFree(c_dev)); |
387 | } |
388 | |
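// Doubles a 1-D buffer whose extent n is passed as a kLong kernel argument.
// The loop is split by 1024, with the outer loop bound to blockIdx.x and the
// inner loop to threadIdx.x.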
389 | TEST(Cuda, DynamicShapeSplit_CUDA) { |
390 | constexpr int64_t N = 4096; |
  VarHandle n("n", kLong);
  BufHandle a("a", {n}, kFloat);
  Tensor b =
      Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
395 | LoopNest l({b}); |
396 | ForPtr inner; |
397 | std::vector<ForPtr> loops = l.getLoopStmtsFor(b); |
398 | l.splitWithMask(loops[0], 1024, &inner); |
399 | loops[0]->set_gpu_block_index(0); |
400 | inner->set_gpu_thread_index(0); |
401 | StmtPtr s = l.root_stmt(); |
402 | CudaCodeGen cg(s, {a, b, n}); |
403 | |
404 | std::vector<float> aData(N, 1.0f); |
405 | std::vector<float> bData(N, 1.0f); |
406 | float* aDev = nullptr; |
407 | float* bDev = nullptr; |
408 | C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); |
409 | C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); |
410 | C10_CUDA_CHECK(cudaMemcpy( |
411 | aDev, |
412 | aData.data(), |
413 | aData.size() * sizeof(aData[0]), |
414 | cudaMemcpyHostToDevice)); |
415 | C10_CUDA_CHECK(cudaMemcpy( |
416 | bDev, |
417 | bData.data(), |
      bData.size() * sizeof(bData[0]),
419 | cudaMemcpyHostToDevice)); |
420 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
421 | |
422 | cg.call({aDev, bDev, N}); |
423 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
424 | |
425 | C10_CUDA_CHECK(cudaMemcpy( |
426 | bData.data(), |
427 | bDev, |
      bData.size() * sizeof(bData[0]),
429 | cudaMemcpyDeviceToHost)); |
430 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
431 | |
432 | ExpectAllNear(bData, std::vector<float>(N, 2.0f), 1e-7); |
433 | |
434 | C10_CUDA_CHECK(cudaFree(aDev)); |
435 | C10_CUDA_CHECK(cudaFree(bDev)); |
436 | } |
437 | |
438 | TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { |
439 | const static int N = 1024; |
  BufHandle data_buf("data", {N}, kFloat);
  BufHandle output_buf("output", {1}, kFloat);
442 | |
443 | // The test adds the following code for trivial reduction: |
444 | // for (const auto bidx : c10::irange(1)) { // blockIdx.x |
445 | // for (const auto tidx : c10::irange(1)) { // threadIdx.x |
446 | // output[0] = 0.f; |
447 | // for (const auto i1 : c10::irange(1024)) { |
448 | // output[0] = output[0] + data[i1]; |
449 | // } |
450 | // } |
451 | // } |
452 | |
453 | StorePtr init_store = output_buf.store({0}, 0.f); |
454 | VarHandle i1("i1" , kInt); |
455 | ExprHandle load_data = Load::make(data_buf, {i1}); |
456 | ExprHandle load_output = Load::make(output_buf, {0}); |
457 | ExprHandle add_value = load_output + load_data; |
458 | StorePtr store_output = output_buf.store({0}, add_value); |
459 | ForPtr for_output = For::make(i1, 0, N, store_output); |
460 | StmtPtr reduce_block = Block::make({init_store, for_output}); |
461 | VarHandle thread_idx("tidx" , kInt); |
462 | LoopOptions thread_idx_options; |
463 | thread_idx_options.set_gpu_thread_index(0); |
464 | ForPtr thread_idx_loop = |
465 | For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); |
466 | VarHandle block_idx("bidx" , kInt); |
467 | LoopOptions block_idx_options; |
468 | block_idx_options.set_gpu_block_index(0); |
469 | ForPtr block_idx_loop = |
470 | For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); |
471 | |
472 | CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); |
473 | PaddedBuffer<float> data_v(N); |
  PaddedBuffer<float> output_v(1, "output_v");
  PaddedBuffer<float> output_ref(1, "output_ref");
476 | |
477 | output_ref(0) = 0; |
478 | for (const auto i : c10::irange(N)) { |
479 | data_v(i) = i; |
480 | output_ref(0) += data_v(i); |
481 | } |
482 | |
483 | float* data_dev = nullptr; |
484 | C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); |
485 | C10_CUDA_CHECK(cudaMemcpy( |
486 | data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
487 | float* output_dev = nullptr; |
488 | C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); |
489 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
490 | |
491 | cuda_cg(data_dev, output_dev); |
492 | |
493 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
494 | C10_CUDA_CHECK(cudaMemcpy( |
495 | output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); |
496 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
497 | |
498 | ExpectAllNear(output_v, output_ref, 1e-5); |
499 | |
500 | C10_CUDA_CHECK(cudaFree(data_dev)); |
501 | C10_CUDA_CHECK(cudaFree(output_dev)); |
502 | } |
503 | |
504 | TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { |
505 | const static int N = 1024; |
506 | |
507 | // This test does the following reduction: |
508 | // clang-format off |
509 | // for b in 0..1 // block-idx |
510 | // for t in 0..1024: // thread-idx |
511 | // if t < 1: |
512 | // b[0] = 0 |
513 | // // implied sync_threads |
514 | // for t in 0..1024: // thread-idx |
515 | // b[0] = b[0] + a[t] // implied atomic |
516 | // clang-format on |
517 | |
  BufHandle a_buf("a", {N}, kFloat);
  BufHandle b_buf("b", {1}, kFloat);
520 | |
521 | StorePtr init_store = b_buf.store({0}, 0.f); |
  VarHandle t("t", kInt);
  VarHandle b("b", kInt);
524 | |
525 | // for t in 0..1024: // thread-idx |
526 | // if t < 1: |
527 | // b[0] = 0 |
528 | ExprHandle cond_t_lt_1 = |
529 | CompareSelect::make(t, 1, CompareSelectOperation::kLT); |
530 | CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); |
531 | LoopOptions thread_idx_options; |
532 | thread_idx_options.set_gpu_thread_index(0); |
533 | ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); |
534 | |
535 | // for t in 0..1024: // thread-idx |
536 | // b[0] = b[0] + a[t] // implied atomic |
537 | ExprHandle load_a = Load::make(a_buf, {t}); |
538 | ExprHandle load_b = Load::make(b_buf, {0}); |
539 | ExprHandle add_value = load_b + load_a; |
540 | StorePtr store_b = b_buf.store({0}, add_value); |
541 | ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); |
542 | |
543 | StmtPtr reduce_block = Block::make({for_init, for_b}); |
544 | |
545 | VarHandle block_idx("bidx" , kInt); |
546 | LoopOptions block_idx_options; |
547 | block_idx_options.set_gpu_block_index(0); |
548 | ForPtr block_idx_loop = |
549 | For::make(block_idx, 0, 1, reduce_block, block_idx_options); |
550 | |
551 | CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); |
552 | PaddedBuffer<float> a_v(N); |
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");
555 | |
556 | b_ref(0) = 0; |
557 | for (const auto i : c10::irange(N)) { |
558 | a_v(i) = i; |
559 | b_ref(0) += a_v(i); |
560 | } |
561 | |
562 | float* a_dev = nullptr; |
563 | C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); |
564 | C10_CUDA_CHECK( |
565 | cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); |
566 | float* b_dev = nullptr; |
567 | C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); |
568 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
569 | |
570 | cuda_cg(a_dev, b_dev); |
571 | |
572 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
573 | C10_CUDA_CHECK( |
574 | cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); |
575 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
576 | |
577 | ExpectAllNear(b_v, b_ref, 1e-5); |
578 | |
579 | C10_CUDA_CHECK(cudaFree(a_dev)); |
580 | C10_CUDA_CHECK(cudaFree(b_dev)); |
581 | } |
582 | |
583 | TEST(Cuda, NoThreadIdxWrite_1_CUDA) { |
584 | // This test does the following reduction: |
585 | // |
586 | // for k in 0..1: // block-idx |
587 | // a[0] = 0 |
588 | // for n in 0..2: |
589 | // a[0] = a[0] + n |
590 | // for m in 0..1024: // thread-idx |
591 | // b[m] = m |
592 | // a[1] = 1 |
593 | // for l in 0..2: |
  //       a[1] = a[1] + l
595 | // |
  // Note that the statements not covered by the thread-idx loop are expected
  // to be masked so that only a single thread executes them.
598 | |
599 | const static int N = 1024; |
  BufHandle a_buf("a", {2}, kFloat);
  BufHandle b_buf("b", {N}, kFloat);

  VarHandle k("k", kInt);
  VarHandle l("l", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
607 | |
608 | // a[0] = 0 |
609 | // for n in 0..2: |
610 | // a[0] = a[0] + n |
611 | StorePtr store_a0_0 = a_buf.store({0}, 0.f); |
612 | ExprHandle load_a0 = Load::make(a_buf, {0}); |
613 | ExprHandle v1 = load_a0 + n; |
614 | StorePtr store_a0_v1 = a_buf.store({0}, v1); |
615 | ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); |
616 | |
617 | // for m in 0..1024: // thread-idx |
618 | // b[m] = m |
619 | StorePtr store_bm_m = b_buf.store({m}, m + 0.f); |
620 | LoopOptions thread_idx_options; |
621 | thread_idx_options.set_gpu_thread_index(0); |
622 | ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); |
623 | |
624 | // a[1] = 1 |
625 | // for l in 0..2: |
626 | // a[1] = a[1] + l |
627 | StorePtr store_a1_1 = a_buf.store({1}, 1.f); |
628 | ExprHandle load_a1 = a_buf.load(1); |
629 | ExprHandle v2 = load_a1 + l; |
630 | StorePtr store_a1_v2 = a_buf.store({1}, v2); |
631 | ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); |
632 | |
633 | StmtPtr reduce_block = |
634 | Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); |
635 | |
636 | VarHandle block_idx("bidx" , kInt); |
637 | LoopOptions block_idx_options; |
638 | block_idx_options.set_gpu_block_index(0); |
639 | ForPtr block_idx_loop = |
640 | For::make(block_idx, 0, 1, reduce_block, block_idx_options); |
641 | |
642 | CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); |
643 | PaddedBuffer<float> a_v(2); |
  PaddedBuffer<float> b_v(N, "b_v");
  PaddedBuffer<float> a_ref(2, "a_ref");
  PaddedBuffer<float> b_ref(N, "b_ref");
647 | |
648 | a_ref(0) = 0; |
649 | for (const auto i : c10::irange(2)) { |
650 | a_ref(0) += i; |
651 | } |
652 | a_ref(1) = a_ref(0) + 1; |
653 | for (const auto i : c10::irange(N)) { |
654 | b_ref(i) = i; |
655 | } |
656 | |
657 | // TODO: add check of the generated code. |
658 | float* a_dev = nullptr; |
659 | C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); |
660 | float* b_dev = nullptr; |
661 | C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); |
662 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
663 | |
664 | cuda_cg(a_dev, b_dev); |
665 | |
666 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
667 | C10_CUDA_CHECK( |
668 | cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); |
669 | C10_CUDA_CHECK( |
670 | cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); |
671 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
672 | |
673 | ExpectAllNear(a_v, a_ref, 1e-5); |
674 | ExpectAllNear(b_v, b_ref, 1e-5); |
675 | |
676 | C10_CUDA_CHECK(cudaFree(a_dev)); |
677 | C10_CUDA_CHECK(cudaFree(b_dev)); |
678 | } |
679 | |
680 | TEST(Cuda, SharedMemReduce_1_CUDA) { |
681 | // FIXME: this test is flaky in CI. |
682 | // This test does the following: |
683 | // for k in 0..1: // block-idx |
684 | // alloc(c, 64) |
685 | // for n in 0..64: // thread-idx |
686 | // c(n) = 0 |
687 | // for m in 0..128: |
688 | // for n in 0..64: // thread_idx |
689 | // c(n) = c(n) + a(k, m, n) |
690 | // b(k) = 0 |
691 | // for n in 0..64: // thread_idx |
692 | // b(k) = b(k) + c(n) |
693 | // free(c) |
694 | |
695 | const int M = 128; |
696 | const int N = 64; |
697 | const int kTotalSize = M * N; |
698 | LoopOptions thread_idx_opt; |
699 | thread_idx_opt.set_gpu_thread_index(0); |
700 | LoopOptions block_idx_opt; |
701 | block_idx_opt.set_gpu_block_index(0); |
702 | |
  BufHandle a("a", {1, M, N}, kFloat);
  BufHandle b("b", {1}, kFloat);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
708 | |
709 | std::vector<StmtPtr> block; |
710 | std::vector<ExprPtr> dims; |
711 | dims.push_back(ExprHandle(N).node()); |
712 | BufHandle c{alloc<Buf>("c" , dims, kFloat)}; |
713 | { |
714 | // alloc(c, 64); |
715 | AllocatePtr alloc = Allocate::make(c); |
716 | block.push_back(alloc); |
717 | } |
718 | |
719 | { |
720 | // for n in 0..64: // thread-idx |
721 | // c(n) = 0 |
722 | StorePtr store_cn_0 = Store::make(c, {n}, 0.f); |
723 | ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); |
724 | block.push_back(loop_n1); |
725 | } |
726 | |
727 | { |
728 | // for m in 0..128: |
729 | // for n in 0..64: // thread_idx |
730 | // c(n) = c(n) + a(k, m, n) |
731 | ExprHandle load_cn = Load::make(kFloat, c, {n}); |
732 | ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); |
733 | ExprHandle v_add = load_cn + a_kmn; |
734 | StorePtr store_cn_v = Store::make(c, {n}, v_add); |
735 | ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); |
736 | ForPtr loop_m1 = For::make(m, 0, M, loop_n2); |
737 | block.push_back(loop_m1); |
738 | } |
739 | |
740 | { |
741 | // b(k) = 0 |
742 | // for n in 0..64: // thread_idx |
743 | // b(k) = b(k) + c(n) |
744 | StorePtr store_bk_0 = b.store({k}, 0.f); |
745 | block.push_back(store_bk_0); |
746 | ExprHandle load_bk = b.load(k); |
747 | ExprHandle load_cn = Load::make(kFloat, c, {n}); |
748 | ExprHandle v_add = load_bk + load_cn; |
749 | StorePtr store_bk = b.store({k}, v_add); |
750 | ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); |
751 | block.push_back(loop_n3); |
752 | } |
753 | |
754 | { |
755 | // free(c) |
756 | FreePtr free_stmt = Free::make(c); |
757 | block.push_back(free_stmt); |
758 | } |
759 | |
760 | BlockPtr reduce_body = Block::make(block); |
761 | ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); |
762 | |
763 | // TODO: check the generated code for correctness. |
764 | CudaCodeGen cuda_cg(loop_k1, a, b); |
765 | |
766 | std::ostringstream oss; |
767 | oss << *cuda_cg.stmt(); |
768 | |
  // Check the expected reduction structure: each thread accumulates into its
  // own c, the initialization of b is masked to a single thread, and the
  // per-thread partial sums are combined with atomicAdd, with __syncthreads()
  // barriers in between.
770 | const std::string& verification_pattern = |
771 | R"IR( |
772 | # CHECK: c_1 = 0 |
773 | # CHECK: for (int m = 0; m < 128 |
774 | # CHECK: c_1 = c_1 + |
775 | # CHECK: __syncthreads(); |
776 | # CHECK: if (threadIdx.x<1 |
777 | # CHECK: b[blockIdx.x] = |
778 | # CHECK: __syncthreads(); |
779 | # CHECK: atomicAdd(&b[blockIdx.x], c_1) |
)IR";
781 | |
782 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
783 | |
  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");
787 | |
788 | b_ref(0) = 0; |
789 | for (const auto i : c10::irange(M)) { |
790 | for (const auto j : c10::irange(N)) { |
791 | int v = i + j; |
792 | a_v(0, i, j) = v; |
793 | b_ref(0) += v; |
794 | } |
795 | } |
796 | |
797 | float* a_dev = nullptr; |
798 | C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); |
799 | C10_CUDA_CHECK(cudaMemcpy( |
800 | a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); |
801 | float* b_dev = nullptr; |
802 | C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); |
803 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
804 | |
805 | cuda_cg(a_dev, b_dev); |
806 | |
807 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
808 | C10_CUDA_CHECK( |
809 | cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); |
810 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
811 | |
812 | ExpectAllNear(b_v, b_ref, 1e-5); |
813 | |
814 | C10_CUDA_CHECK(cudaFree(a_dev)); |
815 | C10_CUDA_CHECK(cudaFree(b_dev)); |
816 | } |
817 | |
818 | TEST(Cuda, LocalMemReduce_1_CUDA) { |
819 | // This test does the following: |
820 | // for k in 0..1: // block-idx |
821 | // b(k) = 0 |
822 | // for n in 0..64: // thread-idx |
823 | // alloc(c, 1) |
824 | // c(0) = 0 |
825 | // for m in 0..128: |
826 | // c(0) = c(0) + a(k, m, n) |
827 | // b(k) = b(k) + c(0) |
828 | // free(c) |
829 | |
830 | const int M = 128; |
831 | const int N = 64; |
832 | const int kTotalSize = M * N; |
833 | LoopOptions thread_idx_opt; |
834 | thread_idx_opt.set_gpu_thread_index(0); |
835 | LoopOptions block_idx_opt; |
836 | block_idx_opt.set_gpu_block_index(0); |
837 | |
  BufHandle a("a", {1, M, N}, kFloat);
  BufHandle b("b", {1}, kFloat);
  VarHandle k("k", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle c{
      alloc<Buf>("c", std::vector<ExprPtr>({alloc<IntImm>(1)}), kFloat)};
846 | std::vector<StmtPtr> block_k; |
847 | { |
848 | // b(k) = 0 |
849 | StorePtr store_bk_0 = b.store({k}, 0.f); |
850 | block_k.push_back(store_bk_0); |
851 | } |
852 | std::vector<StmtPtr> block_n; |
853 | { |
854 | // alloc(c, 1); |
855 | AllocatePtr alloc = Allocate::make(c); |
856 | block_n.push_back(alloc); |
857 | } |
858 | { |
859 | // c(0) = 0 |
860 | StorePtr store_c0_0 = Store::make(c, {0}, 0.f); |
861 | block_n.push_back(store_c0_0); |
862 | } |
863 | { |
864 | // for m in 0..128: |
865 | // c(0) = c(0) + a(k, m, n) |
866 | ExprHandle load_c0 = Load::make(kFloat, c, {0}); |
867 | ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); |
868 | ExprHandle v_add = load_c0 + a_kmn; |
869 | StorePtr store_c0_v = Store::make(c, {0}, v_add); |
870 | ForPtr loop_m = For::make(m, 0, M, store_c0_v); |
871 | block_n.push_back(loop_m); |
872 | } |
873 | { |
874 | // b(k) = b(k) + c(0) |
875 | ExprHandle load_bk = b.load(k); |
876 | ExprHandle load_c0 = Load::make(kFloat, c, {0}); |
877 | ExprHandle v_add = load_bk + load_c0; |
878 | StorePtr store_bk = b.store({k}, v_add); |
879 | block_n.push_back(store_bk); |
880 | } |
881 | { |
882 | // free(c) |
883 | FreePtr free_stmt = Free::make(c); |
884 | block_n.push_back(free_stmt); |
885 | } |
886 | { |
887 | BlockPtr block_n_stmt = Block::make(block_n); |
888 | ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); |
889 | block_k.push_back(for_n); |
890 | } |
891 | BlockPtr block_k_stmt = Block::make(block_k); |
892 | ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); |
893 | |
894 | CudaCodeGen cuda_cg(loop_k, a, b); |
  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");
898 | |
899 | b_ref(0) = 0; |
900 | for (const auto i : c10::irange(M)) { |
901 | for (const auto j : c10::irange(N)) { |
902 | int v = i + j; |
903 | a_v(0, i, j) = v; |
904 | b_ref(0) += v; |
905 | } |
906 | } |
907 | |
908 | float* a_dev = nullptr; |
909 | C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); |
910 | C10_CUDA_CHECK(cudaMemcpy( |
911 | a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); |
912 | float* b_dev = nullptr; |
913 | C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); |
914 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
915 | |
916 | cuda_cg(a_dev, b_dev); |
917 | |
918 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
919 | C10_CUDA_CHECK( |
920 | cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); |
921 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
922 | |
923 | ExpectAllNear(b_v, b_ref, 1e-5); |
924 | |
925 | C10_CUDA_CHECK(cudaFree(a_dev)); |
926 | C10_CUDA_CHECK(cudaFree(b_dev)); |
927 | } |
928 | |
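// Exercises arithmetic on Half values through a chain of casts:
// b = half(2 * a), c = float(half(42) + b), d = half(c). With a filled with
// 2.0, c should read back as 46.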
929 | TEST(Cuda, HalfSupport_CUDA) { |
930 | auto half = ToDtype<at::Half>(); |
  BufHandle a("a", {4}, half);
  Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
933 | return Cast::make(half, ExprHandle(2.0f) * a.load(i)); |
934 | }); |
935 | |
936 | Tensor c = Compute("c" , {4}, [&](const VarHandle& i) { |
937 | return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); |
938 | }); |
939 | |
940 | Tensor d = Compute("d" , {4}, [&](const VarHandle& i) { |
941 | return Cast::make(half, c.load(i)); |
942 | }); |
943 | |
944 | LoopNest l({b, c, d}); |
945 | l.prepareForCodegen(); |
946 | StmtPtr s = l.root_stmt(); |
947 | CudaCodeGen cg(s, {a, b, c, d}); |
948 | |
949 | std::vector<at::Half> aData(4, 2.0f); |
950 | std::vector<float> cData(4, 0.0f); |
951 | std::vector<at::Half> dData(4, 0.0f); |
952 | at::Half* aDev = nullptr; |
953 | at::Half* bDev = nullptr; |
  float* cDev = nullptr;
955 | at::Half* dDev = nullptr; |
956 | auto aSize = aData.size() * sizeof(aData[0]); |
957 | auto bSize = aData.size() * sizeof(aData[0]); |
958 | auto cSize = cData.size() * sizeof(float); |
959 | auto dSize = dData.size() * sizeof(dData[0]); |
960 | |
961 | C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); |
962 | C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); |
963 | C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); |
964 | C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); |
965 | C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); |
966 | C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); |
967 | C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); |
968 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
969 | |
970 | cg.call({aDev, bDev, cDev, dDev}); |
971 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
972 | |
973 | C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); |
974 | C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); |
975 | C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); |
976 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
977 | |
978 | assertAllEqual(cData, 46.0f); |
979 | |
980 | C10_CUDA_CHECK(cudaFree(aDev)); |
981 | C10_CUDA_CHECK(cudaFree(bDev)); |
982 | C10_CUDA_CHECK(cudaFree(cDev)); |
983 | C10_CUDA_CHECK(cudaFree(dDev)); |
984 | } |
985 | |
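// Checks that Half loads are promoted to Float for the Max computation and
// only narrowed back to Half on the final store, as verified by the FileCheck
// pattern below; then runs the relu kernel and checks that positive inputs
// pass through unchanged.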
986 | TEST(Cuda, HalfPropagation_CUDA) { |
987 | auto half = ToDtype<at::Half>(); |
  BufHandle a("a", {4}, half);
  Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
990 | return Max::make(a.load(i), ExprHandle(alloc<HalfImm>(0)), true); |
991 | }); |
992 | |
993 | LoopNest l({relu}); |
994 | l.prepareForCodegen(); |
995 | StmtPtr s = l.root_stmt(); |
996 | CudaCodeGen cg(s, {a, relu}); |
997 | |
998 | std::ostringstream oss; |
999 | oss << *cg.stmt(); |
1000 | |
1001 | // Check the types used by the Max are Float. |
1002 | const std::string& verification_pattern = |
1003 | R"IR( |
1004 | # CHECK: for ( |
1005 | # CHECK: float v = float(a[i]); |
1006 | # CHECK: relu[i] = half(Max(v, 0.f |
# CHECK: })IR";
1008 | |
1009 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1010 | |
1011 | std::vector<at::Half> aData(4, 2.0f); |
1012 | std::vector<at::Half> reluData(4, 0.0f); |
1013 | at::Half* aDev = nullptr; |
1014 | at::Half* reluDev = nullptr; |
1015 | auto aSize = aData.size() * sizeof(aData[0]); |
1016 | auto reluSize = reluData.size() * sizeof(reluData[0]); |
1017 | |
1018 | C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); |
1019 | C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); |
1020 | C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); |
1021 | C10_CUDA_CHECK( |
1022 | cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); |
1023 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1024 | |
1025 | cg.call({aDev, reluDev}); |
1026 | C10_CUDA_CHECK( |
1027 | cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); |
1028 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1029 | |
1030 | assertAllEqual(aData, reluData); |
1031 | |
1032 | C10_CUDA_CHECK(cudaFree(aDev)); |
1033 | C10_CUDA_CHECK(cudaFree(reluDev)); |
1034 | } |
1035 | |
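// Like HalfPropagation, but the kernel also takes a Half buffer argument (b)
// that it never reads. The unused Half argument should not disturb codegen;
// the relu result should still equal the (positive) float input.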
1036 | TEST(Cuda, UnusedHalfArgument_CUDA) { |
  BufHandle a("a", {4}, kFloat);
  auto half = ToDtype<at::Half>();
  BufHandle b("b", {4}, half);
  Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
1041 | return Max::make(a.load(i), ExprHandle(alloc<FloatImm>(0)), true); |
1042 | }); |
1043 | |
1044 | LoopNest l({relu}); |
1045 | l.prepareForCodegen(); |
1046 | StmtPtr s = l.root_stmt(); |
1047 | CudaCodeGen cg(s, {a, b, relu}); |
1048 | |
1049 | std::ostringstream oss; |
1050 | oss << *cg.stmt(); |
1051 | |
1052 | // Check the types used by the Max are Float. |
1053 | const std::string& verification_pattern = |
1054 | R"IR( |
1055 | # CHECK: for ( |
1056 | # CHECK: float v = a[i]; |
1057 | # CHECK: relu[i] = Max(v, 0.f |
# CHECK: })IR";
1059 | |
1060 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1061 | |
  // Sanity check.
1063 | std::vector<float> aData(4, 2.0f); |
1064 | std::vector<at::Half> bData(4, 2.0f); |
1065 | std::vector<float> reluData(4, 0.0f); |
  float* aDev = nullptr;
  at::Half* bDev = nullptr;
  float* reluDev = nullptr;
1069 | auto aSize = aData.size() * sizeof(aData[0]); |
1070 | auto bSize = bData.size() * sizeof(bData[0]); |
1071 | auto reluSize = reluData.size() * sizeof(reluData[0]); |
1072 | |
1073 | C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); |
1074 | C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); |
1075 | C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); |
1076 | C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); |
1077 | C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); |
1078 | C10_CUDA_CHECK( |
1079 | cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); |
1080 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1081 | |
1082 | cg.call({aDev, bDev, reluDev}); |
1083 | C10_CUDA_CHECK( |
1084 | cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); |
1085 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1086 | |
1087 | assertAllEqual(aData, reluData); |
1088 | |
1089 | C10_CUDA_CHECK(cudaFree(aDev)); |
1090 | C10_CUDA_CHECK(cudaFree(bDev)); |
1091 | C10_CUDA_CHECK(cudaFree(reluDev)); |
1092 | } |
1093 | |
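// Computes c[i] = (i < 10 ? a[i] + b[i] : b[i]) with the 12-element loop bound
// to blockIdx.x; a has only 10 elements, so the conditional guards the loads
// from a.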
1094 | TEST(Cuda, PrioritizeDependents_CUDA) { |
  BufHandle a("a", {10}, kFloat);
  BufHandle b("b", {12}, kFloat);
  BufHandle c("c", {12}, kFloat);
1098 | |
1099 | LoopOptions block_idx_opt; |
1100 | block_idx_opt.set_gpu_block_index(0); |
1101 | |
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
1104 | |
1105 | /* |
1106 | * for (const auto i : c10::irange(12)) { |
1107 | * c[i] = (i < 10 ? a[i] + b[i] : b[i]); |
1108 | * } |
1109 | */ |
1110 | ExprHandle load_a = a.load({i}); |
1111 | ExprHandle load_b = b.load({i}); |
1112 | ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); |
1113 | ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); |
1114 | |
1115 | ForPtr loop = |
1116 | For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); |
1117 | |
1118 | CudaCodeGen cuda_cg(loop, a, b, c); |
1119 | |
  PaddedBuffer<float> a_v(10, "a_v");
  PaddedBuffer<float> b_v(12, "b_v");
  PaddedBuffer<float> c_v(12, "c_v");
  PaddedBuffer<float> c_ref(12, "c_ref");
1124 | |
1125 | for (const auto i : c10::irange(10)) { |
1126 | a_v(i) = i * 100; |
1127 | b_v(i) = i; |
1128 | c_v(i) = 0; |
1129 | } |
1130 | |
1131 | for (const auto i : c10::irange(10, 12)) { |
1132 | b_v(i) = i; |
1133 | c_v(i) = 0; |
1134 | } |
1135 | |
1136 | float* a_dev = nullptr; |
1137 | float* b_dev = nullptr; |
1138 | float* c_dev = nullptr; |
1139 | C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); |
1140 | C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); |
1141 | C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); |
1142 | |
1143 | C10_CUDA_CHECK(cudaMemcpy( |
1144 | a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); |
1145 | C10_CUDA_CHECK(cudaMemcpy( |
1146 | b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); |
1147 | |
1148 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1149 | |
1150 | cuda_cg(a_dev, b_dev, c_dev); |
1151 | |
1152 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1153 | C10_CUDA_CHECK(cudaMemcpy( |
1154 | c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); |
1155 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1156 | |
1157 | for (const auto i : c10::irange(12)) { |
1158 | if (i < 10) { |
1159 | c_ref(i) = i + i * 100; |
1160 | } else { |
1161 | c_ref(i) = i; |
1162 | } |
1163 | } |
1164 | |
1165 | ExpectAllNear(c_v, c_ref, 1e-5); |
1166 | } |
1167 | |
1168 | /// Tests the case where there are two loops which have different extents bound |
1169 | /// to the same block dimension. We must mask the smaller extent loop body. |
1170 | TEST(Cuda, MaskBlockDim_CUDA) { |
1171 | int A_SIZE = 100; |
1172 | int B_SIZE = 50; |
  BufHandle a_buf("a", {A_SIZE}, kFloat);
  BufHandle b_buf("b", {B_SIZE}, kFloat);
  Tensor c = Compute(
      "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
  Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
1178 | return a_buf.load(i) + b_buf.load(i); |
1179 | }); |
1180 | |
1181 | LoopNest l({c, d}); |
1182 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1183 | loops[0]->set_gpu_block_index(0); |
1184 | loops = l.getLoopStmtsFor(d); |
1185 | loops[0]->set_gpu_block_index(0); |
1186 | |
1187 | l.prepareForCodegen(); |
1188 | StmtPtr stmt = l.root_stmt(); |
1189 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
1190 | |
1191 | std::ostringstream oss; |
1192 | oss << *cuda_cg.stmt(); |
1193 | |
1194 | // Check the c write is not masked, but the d write is. |
1195 | const std::string& verification_pattern = |
1196 | R"IR( |
1197 | # CHECK-NOT: if (blockIdx |
1198 | # CHECK: c[blockIdx.x] = |
1199 | # CHECK: if (blockIdx.x<50 |
# CHECK: d[blockIdx.x] =)IR";
1201 | |
1202 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1203 | |
1204 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1205 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1206 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE))); |
1207 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(1))); |
1208 | |
1209 | // Sanity check that the kernel works. |
1210 | PaddedBuffer<float> a_v(A_SIZE); |
1211 | PaddedBuffer<float> b_v(B_SIZE); |
1212 | PaddedBuffer<float> c_v(A_SIZE); |
1213 | PaddedBuffer<float> d_v(B_SIZE); |
1214 | |
1215 | PaddedBuffer<float> c_ref(A_SIZE); |
1216 | PaddedBuffer<float> d_ref(B_SIZE); |
1217 | |
1218 | for (const auto i : c10::irange(A_SIZE)) { |
1219 | a_v(i) = (float)i; |
1220 | c_ref(i) = (float)(i + 10); |
1221 | } |
1222 | |
1223 | for (const auto i : c10::irange(B_SIZE)) { |
1224 | b_v(i) = (float)(B_SIZE - i); |
1225 | d_ref(i) = a_v(i) + b_v(i); |
1226 | } |
1227 | |
1228 | float* a_dev = nullptr; |
1229 | C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); |
1230 | float* b_dev = nullptr; |
1231 | C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); |
1232 | float* c_dev = nullptr; |
1233 | C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); |
1234 | float* d_dev = nullptr; |
1235 | C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); |
1236 | C10_CUDA_CHECK(cudaMemcpy( |
1237 | a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1238 | C10_CUDA_CHECK(cudaMemcpy( |
1239 | b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1240 | C10_CUDA_CHECK(cudaMemcpy( |
1241 | c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1242 | C10_CUDA_CHECK(cudaMemcpy( |
1243 | d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1244 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1245 | |
1246 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
1247 | |
1248 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1249 | C10_CUDA_CHECK(cudaMemcpy( |
1250 | c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1251 | C10_CUDA_CHECK(cudaMemcpy( |
1252 | d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1253 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1254 | |
1255 | ExpectAllNear(c_v, c_ref, 1e-5); |
1256 | ExpectAllNear(d_v, d_ref, 1e-5); |
1257 | |
1258 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1259 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1260 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1261 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1262 | } |
1263 | |
/// Tests the case with two loops which have different extents bound to the
/// same thread dimension. As above, the write with the smaller extent should
/// be masked. But this time we also need a __syncthreads() between them.
1267 | TEST(Cuda, MaskThreadDim_CUDA) { |
1268 | int A_SIZE = 50; |
1269 | int B_SIZE = 100; |
  BufHandle a_buf("a", {A_SIZE}, kFloat);
  BufHandle b_buf("b", {B_SIZE}, kFloat);
  Tensor c = Compute(
      "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
  Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
1275 | return a_buf.load(i / 2) + b_buf.load(i); |
1276 | }); |
1277 | |
1278 | LoopNest l({c, d}); |
1279 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1280 | loops[0]->set_gpu_thread_index(0); |
1281 | loops = l.getLoopStmtsFor(d); |
1282 | loops[0]->set_gpu_thread_index(0); |
1283 | |
1284 | l.prepareForCodegen(); |
1285 | StmtPtr stmt = l.root_stmt(); |
1286 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
1287 | |
1288 | std::ostringstream oss; |
1289 | oss << *cuda_cg.stmt(); |
1290 | |
1291 | // Check the c write is masked, but the d write is not. |
1292 | const std::string& verification_pattern = |
1293 | R"IR( |
1294 | # CHECK: if (threadIdx.x<50 |
1295 | # CHECK: c[threadIdx.x] = |
1296 | # CHECK: __syncthreads(); |
1297 | # CHECK-NOT: if (threadIdx.x |
# CHECK: d[threadIdx.x] =)IR";
1299 | |
1300 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1301 | |
1302 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1303 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1304 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(1))); |
1305 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(B_SIZE))); |
1306 | |
1307 | PaddedBuffer<float> a_v(A_SIZE); |
1308 | PaddedBuffer<float> b_v(B_SIZE); |
1309 | PaddedBuffer<float> c_v(A_SIZE); |
1310 | PaddedBuffer<float> d_v(B_SIZE); |
1311 | |
1312 | PaddedBuffer<float> c_ref(A_SIZE); |
1313 | PaddedBuffer<float> d_ref(B_SIZE); |
1314 | |
1315 | for (const auto i : c10::irange(A_SIZE)) { |
1316 | a_v(i) = (float)i; |
1317 | c_ref(i) = (float)(i + 10); |
1318 | } |
1319 | |
1320 | for (const auto i : c10::irange(B_SIZE)) { |
1321 | b_v(i) = (float)(B_SIZE - i); |
1322 | d_ref(i) = a_v(i / 2) + b_v(i); |
1323 | } |
1324 | |
1325 | float* a_dev = nullptr; |
1326 | C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); |
1327 | float* b_dev = nullptr; |
1328 | C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); |
1329 | float* c_dev = nullptr; |
1330 | C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); |
1331 | float* d_dev = nullptr; |
1332 | C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); |
1333 | C10_CUDA_CHECK(cudaMemcpy( |
1334 | a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1335 | C10_CUDA_CHECK(cudaMemcpy( |
1336 | b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1337 | C10_CUDA_CHECK(cudaMemcpy( |
1338 | c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1339 | C10_CUDA_CHECK(cudaMemcpy( |
1340 | d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1341 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1342 | |
1343 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
1344 | |
1345 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1346 | C10_CUDA_CHECK(cudaMemcpy( |
1347 | c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1348 | C10_CUDA_CHECK(cudaMemcpy( |
1349 | d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1350 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1351 | |
1352 | ExpectAllNear(c_v, c_ref, 1e-5); |
1353 | ExpectAllNear(d_v, d_ref, 1e-5); |
1354 | |
1355 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1356 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1357 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1358 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1359 | } |
1360 | |
1361 | /// Tests the case where there are two loops, and each is bound to a different |
1362 | /// block dimension. In this case all writes should be masked since they occur |
1363 | /// in distinct dimensions. |
1364 | // Note: this is an extremely dumb pattern which we should never see, but is a |
1365 | // useful edge case to make sure we've got things covered. |
1366 | TEST(Cuda, MaskMultiBlockDim_CUDA) { |
1367 | int A_SIZE = 100; |
1368 | int B_SIZE = 50; |
  BufHandle a_buf("a", {A_SIZE}, kFloat);
  BufHandle b_buf("b", {B_SIZE}, kFloat);
  Tensor c = Compute(
      "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
  Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
1374 | return a_buf.load(i) + b_buf.load(i); |
1375 | }); |
1376 | |
1377 | LoopNest l({c, d}); |
1378 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1379 | loops[0]->set_gpu_block_index(0); |
1380 | loops = l.getLoopStmtsFor(d); |
1381 | loops[0]->set_gpu_block_index(1); |
1382 | |
1383 | l.prepareForCodegen(); |
1384 | StmtPtr stmt = l.root_stmt(); |
1385 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
1386 | |
1387 | std::ostringstream oss; |
1388 | oss << *cuda_cg.stmt(); |
1389 | |
1390 | // Write to c should be masked against y, write to d against x. |
1391 | const std::string& verification_pattern = |
1392 | R"IR( |
1393 | # CHECK: if (blockIdx.y<1 |
1394 | # CHECK: c[blockIdx.x] = |
1395 | # CHECK: if (blockIdx.x<1 |
# CHECK: d[blockIdx.y] =)IR";
1397 | |
1398 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1399 | |
1400 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1401 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1402 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE))); |
1403 | ASSERT_TRUE(exprEquals(blockExtents[1], alloc<IntImm>(B_SIZE))); |
1404 | |
1405 | PaddedBuffer<float> a_v(A_SIZE); |
1406 | PaddedBuffer<float> b_v(B_SIZE); |
1407 | PaddedBuffer<float> c_v(A_SIZE); |
1408 | PaddedBuffer<float> d_v(B_SIZE); |
1409 | |
1410 | PaddedBuffer<float> c_ref(A_SIZE); |
1411 | PaddedBuffer<float> d_ref(B_SIZE); |
1412 | |
1413 | for (const auto i : c10::irange(A_SIZE)) { |
1414 | a_v(i) = (float)i; |
1415 | c_ref(i) = (float)(i + 10); |
1416 | } |
1417 | |
1418 | for (const auto i : c10::irange(B_SIZE)) { |
1419 | b_v(i) = (float)(B_SIZE - i); |
1420 | d_ref(i) = a_v(i) + b_v(i); |
1421 | } |
1422 | |
1423 | float* a_dev = nullptr; |
1424 | C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); |
1425 | float* b_dev = nullptr; |
1426 | C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); |
1427 | float* c_dev = nullptr; |
1428 | C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); |
1429 | float* d_dev = nullptr; |
1430 | C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); |
1431 | C10_CUDA_CHECK(cudaMemcpy( |
1432 | a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1433 | C10_CUDA_CHECK(cudaMemcpy( |
1434 | b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1435 | C10_CUDA_CHECK(cudaMemcpy( |
1436 | c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1437 | C10_CUDA_CHECK(cudaMemcpy( |
1438 | d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1439 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1440 | |
1441 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
1442 | |
1443 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1444 | C10_CUDA_CHECK(cudaMemcpy( |
1445 | c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1446 | C10_CUDA_CHECK(cudaMemcpy( |
1447 | d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1448 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1449 | |
1450 | ExpectAllNear(c_v, c_ref, 1e-5); |
1451 | ExpectAllNear(d_v, d_ref, 1e-5); |
1452 | |
1453 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1454 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1455 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1456 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1457 | } |
1458 | |
/// Tests the case where the block index and the thread index are bound to
/// different loops. In this instance both stores should be masked since they
/// are distinct.
1462 | // Note: this is an extremely dumb pattern which we should never see, but is a |
1463 | // useful edge case to make sure we've got things covered. |
1464 | TEST(Cuda, MaskBlockAndThreadDim_CUDA) { |
1465 | int A_SIZE = 100; |
1466 | int B_SIZE = 50; |
  BufHandle a_buf("a", {A_SIZE}, kFloat);
  BufHandle b_buf("b", {B_SIZE}, kFloat);
  Tensor c = Compute(
      "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
  Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
1472 | return a_buf.load(i) + b_buf.load(i); |
1473 | }); |
1474 | |
1475 | LoopNest l({c, d}); |
1476 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1477 | loops[0]->set_gpu_block_index(0); |
1478 | loops = l.getLoopStmtsFor(d); |
1479 | loops[0]->set_gpu_thread_index(0); |
1480 | |
1481 | l.prepareForCodegen(); |
1482 | StmtPtr stmt = l.root_stmt(); |
1483 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
1484 | |
1485 | std::ostringstream oss; |
1486 | oss << *cuda_cg.stmt(); |
1487 | |
1488 | const std::string& verification_pattern = |
1489 | R"IR( |
1490 | # CHECK: if (threadIdx.x<1 |
1491 | # CHECK: c[blockIdx.x] = |
1492 | # CHECK: } |
1493 | # CHECK: if (blockIdx.x<1 |
# CHECK: d[threadIdx.x] =)IR";
1495 | |
1496 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1497 | |
1498 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1499 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1500 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE))); |
1501 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(B_SIZE))); |
1502 | |
1503 | PaddedBuffer<float> a_v(A_SIZE); |
1504 | PaddedBuffer<float> b_v(B_SIZE); |
1505 | PaddedBuffer<float> c_v(A_SIZE); |
1506 | PaddedBuffer<float> d_v(B_SIZE); |
1507 | |
1508 | PaddedBuffer<float> c_ref(A_SIZE); |
1509 | PaddedBuffer<float> d_ref(B_SIZE); |
1510 | |
1511 | for (const auto i : c10::irange(A_SIZE)) { |
1512 | a_v(i) = (float)i; |
1513 | c_ref(i) = (float)(i + 10); |
1514 | } |
1515 | |
1516 | for (const auto i : c10::irange(B_SIZE)) { |
1517 | b_v(i) = (float)(B_SIZE - i); |
1518 | d_ref(i) = a_v(i) + b_v(i); |
1519 | } |
1520 | |
1521 | float* a_dev = nullptr; |
1522 | C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); |
1523 | float* b_dev = nullptr; |
1524 | C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); |
1525 | float* c_dev = nullptr; |
1526 | C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); |
1527 | float* d_dev = nullptr; |
1528 | C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); |
1529 | C10_CUDA_CHECK(cudaMemcpy( |
1530 | a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1531 | C10_CUDA_CHECK(cudaMemcpy( |
1532 | b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1533 | C10_CUDA_CHECK(cudaMemcpy( |
1534 | c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1535 | C10_CUDA_CHECK(cudaMemcpy( |
1536 | d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); |
1537 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1538 | |
1539 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
1540 | |
1541 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1542 | C10_CUDA_CHECK(cudaMemcpy( |
1543 | c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1544 | C10_CUDA_CHECK(cudaMemcpy( |
1545 | d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); |
1546 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1547 | |
1548 | ExpectAllNear(c_v, c_ref, 1e-5); |
1549 | ExpectAllNear(d_v, d_ref, 1e-5); |
1550 | |
1551 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1552 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1553 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1554 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1555 | } |
1556 | |
/// Tests the case where the loop nest has two loops of depth two: each with the
/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In
/// this case every write whose extent is smaller than the maximum extent should
/// be masked.
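///
/// A rough sketch of the expected kernel body (assumed shape; the exact index
/// arithmetic and simplification may differ):
///
///   // gridDim.x = 10 (OUTER_SIZE), blockDim.x = 100 (A_SIZE)
///   C[threadIdx.x + 100 * blockIdx.x] = 2.f * a[threadIdx.x + 100 * blockIdx.x];
///   __syncthreads();
///   if (threadIdx.x < 50) {
///     D[threadIdx.x + 50 * blockIdx.x] =
///         C[2 * threadIdx.x + 100 * blockIdx.x] + b[threadIdx.x + 50 * blockIdx.x];
///   }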
1560 | TEST(Cuda, MaskMultiDim_CUDA) { |
1561 | int OUTER_SIZE = 10; |
1562 | int A_SIZE = 100; |
1563 | int B_SIZE = 50; |
1564 | BufHandle a_buf("a" , {OUTER_SIZE, A_SIZE}, kFloat); |
1565 | BufHandle b_buf("b" , {OUTER_SIZE, B_SIZE}, kFloat); |
1566 | Tensor c = Compute( |
1567 | "C" , {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
1568 | return ExprHandle(2) * a_buf.load(i, j); |
1569 | }); |
1570 | Tensor d = Compute( |
1571 | "D" , {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
1572 | return c.load(i, j * 2) + b_buf.load(i, j); |
1573 | }); |
1574 | |
1575 | LoopNest l({c, d}); |
1576 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1577 | loops[0]->set_gpu_block_index(0); |
1578 | loops[1]->set_gpu_thread_index(0); |
1579 | loops = l.getLoopStmtsFor(d); |
1580 | loops[0]->set_gpu_block_index(0); |
1581 | loops[1]->set_gpu_thread_index(0); |
1582 | |
1583 | l.prepareForCodegen(); |
1584 | StmtPtr stmt = l.root_stmt(); |
1585 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
1586 | |
1587 | std::ostringstream oss; |
1588 | oss << *cuda_cg.stmt(); |
1589 | |
1590 | // The write to D should be masked, but not the write to C. |
1591 | const std::string& verification_pattern = |
1592 | R"IR( |
1593 | # CHECK-NOT: if ( |
1594 | # CHECK: C[threadIdx.x + 100 * blockIdx.x] = |
1595 | # CHECK: __syncthreads(); |
1596 | # CHECK: if (threadIdx.x<50 |
# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR";
1598 | |
1599 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1600 | |
1601 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1602 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1603 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE))); |
1604 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE))); |
1605 | |
1606 | PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE); |
1607 | PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE); |
1608 | PaddedBuffer<float> c_v(OUTER_SIZE, A_SIZE); |
1609 | PaddedBuffer<float> d_v(OUTER_SIZE, B_SIZE); |
1610 | |
1611 | PaddedBuffer<float> c_ref(OUTER_SIZE, A_SIZE); |
1612 | PaddedBuffer<float> d_ref(OUTER_SIZE, B_SIZE); |
1613 | |
1614 | for (const auto o : c10::irange(OUTER_SIZE)) { |
1615 | for (const auto i : c10::irange(A_SIZE)) { |
1616 | a_v(o, i) = (float)i; |
1617 | c_ref(o, i) = (float)(i * 2); |
1618 | } |
1619 | } |
1620 | |
1621 | for (const auto o : c10::irange(OUTER_SIZE)) { |
1622 | for (const auto i : c10::irange(B_SIZE)) { |
1623 | b_v(o, i) = (float)(B_SIZE - i); |
1624 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
1625 | } |
1626 | } |
1627 | |
1628 | float* a_dev = nullptr; |
1629 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
1630 | float* b_dev = nullptr; |
1631 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
1632 | float* c_dev = nullptr; |
1633 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
1634 | float* d_dev = nullptr; |
1635 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
1636 | C10_CUDA_CHECK(cudaMemcpy( |
1637 | a_dev, |
1638 | a_v.data(), |
1639 | OUTER_SIZE * A_SIZE * sizeof(float), |
1640 | cudaMemcpyHostToDevice)); |
1641 | C10_CUDA_CHECK(cudaMemcpy( |
1642 | b_dev, |
1643 | b_v.data(), |
1644 | OUTER_SIZE * B_SIZE * sizeof(float), |
1645 | cudaMemcpyHostToDevice)); |
1646 | C10_CUDA_CHECK(cudaMemcpy( |
1647 | c_dev, |
1648 | c_v.data(), |
1649 | OUTER_SIZE * A_SIZE * sizeof(float), |
1650 | cudaMemcpyHostToDevice)); |
1651 | C10_CUDA_CHECK(cudaMemcpy( |
1652 | d_dev, |
1653 | d_v.data(), |
1654 | OUTER_SIZE * B_SIZE * sizeof(float), |
1655 | cudaMemcpyHostToDevice)); |
1656 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1657 | |
1658 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
1659 | |
1660 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1661 | C10_CUDA_CHECK(cudaMemcpy( |
1662 | c_v.data(), |
1663 | c_dev, |
1664 | OUTER_SIZE * A_SIZE * sizeof(float), |
1665 | cudaMemcpyDeviceToHost)); |
1666 | C10_CUDA_CHECK(cudaMemcpy( |
1667 | d_v.data(), |
1668 | d_dev, |
1669 | OUTER_SIZE * B_SIZE * sizeof(float), |
1670 | cudaMemcpyDeviceToHost)); |
1671 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1672 | |
1673 | ExpectAllNear(c_v, c_ref, 1e-5); |
1674 | ExpectAllNear(d_v, d_ref, 1e-5); |
1675 | |
1676 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1677 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1678 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1679 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1680 | } |
1681 | |
// Tests the case where loop extents are symbolic and not known at compile time.
// In this case both stores must be masked against their own loop extent, since
// the launch extent is the maximum of the two and either loop may be the
// smaller one.
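//
// A rough sketch of the expected kernel body (assumed shape; the generated
// code also casts the flattened indices to int64_t, as the pattern below
// shows):
//
//   // gridDim.x = OUTER_SIZE, blockDim.x = max(A_SIZE, B_SIZE)
//   if (threadIdx.x < A_SIZE) {
//     C[A_SIZE * blockIdx.x + threadIdx.x] = 2.f * a[...];
//   }
//   __syncthreads();
//   if (threadIdx.x < B_SIZE) {
//     D[B_SIZE * blockIdx.x + threadIdx.x] = C[...] + b[...];
//   }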
1685 | TEST(Cuda, MaskMultiDimSymbolic_CUDA) { |
1686 | VarHandle OUTER_SIZE("OUTER_SIZE" , kLong); |
1687 | VarHandle A_SIZE("A_SIZE" , kLong); |
1688 | VarHandle B_SIZE("B_SIZE" , kLong); |
1689 | BufHandle a_buf("a" , {OUTER_SIZE, A_SIZE}, kFloat); |
1690 | BufHandle b_buf("b" , {OUTER_SIZE, B_SIZE}, kFloat); |
1691 | Tensor c = Compute( |
1692 | "C" , {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
1693 | return ExprHandle(2) * a_buf.load(i, j); |
1694 | }); |
1695 | Tensor d = Compute( |
1696 | "D" , {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
1697 | return c.load(i, j * 2) + b_buf.load(i, j); |
1698 | }); |
1699 | |
1700 | LoopNest l({c, d}); |
1701 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
1702 | loops[0]->set_gpu_block_index(0); |
1703 | loops[1]->set_gpu_thread_index(0); |
1704 | loops = l.getLoopStmtsFor(d); |
1705 | loops[0]->set_gpu_block_index(0); |
1706 | loops[1]->set_gpu_thread_index(0); |
1707 | |
1708 | l.prepareForCodegen(); |
1709 | StmtPtr stmt = l.root_stmt(); |
1710 | CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); |
1711 | |
1712 | std::ostringstream oss; |
1713 | oss << *cuda_cg.stmt(); |
1714 | |
1715 | // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both. |
1716 | const std::string& verification_pattern = |
1717 | R"IR( |
1718 | # CHECK: if (threadIdx.x<A_SIZE |
1719 | # CHECK: C[A_SIZE * int64_t(blockIdx.x) + int64_t(threadIdx.x)] = |
1720 | # CHECK: __syncthreads(); |
1721 | # CHECK: if (threadIdx.x<B_SIZE |
# CHECK: D[B_SIZE * int64_t(blockIdx.x) + int64_t(threadIdx.x)] =)IR";
1723 | |
1724 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1725 | |
1726 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1727 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1728 | ASSERT_TRUE(exprEquals(blockExtents[0], OUTER_SIZE.node())); |
1729 | ASSERT_TRUE(exprEquals( |
1730 | threadExtents[0], alloc<Max>(A_SIZE.node(), B_SIZE.node(), true))); |
1731 | |
1732 | int64_t OUTER_EXTENT = 10; |
1733 | int64_t A_EXTENT = 100; |
1734 | int64_t B_EXTENT = 50; |
1735 | |
1736 | PaddedBuffer<float> a_v(OUTER_EXTENT, A_EXTENT); |
1737 | PaddedBuffer<float> b_v(OUTER_EXTENT, B_EXTENT); |
1738 | PaddedBuffer<float> c_v(OUTER_EXTENT, A_EXTENT); |
1739 | PaddedBuffer<float> d_v(OUTER_EXTENT, B_EXTENT); |
1740 | |
1741 | PaddedBuffer<float> c_ref(OUTER_EXTENT, A_EXTENT); |
1742 | PaddedBuffer<float> d_ref(OUTER_EXTENT, B_EXTENT); |
1743 | |
1744 | for (const auto o : c10::irange(OUTER_EXTENT)) { |
1745 | for (const auto i : c10::irange(A_EXTENT)) { |
1746 | a_v(o, i) = (float)i; |
1747 | c_ref(o, i) = (float)(i * 2); |
1748 | } |
1749 | } |
1750 | |
1751 | for (const auto o : c10::irange(OUTER_EXTENT)) { |
1752 | for (const auto i : c10::irange(B_EXTENT)) { |
1753 | b_v(o, i) = (float)(B_EXTENT - i); |
1754 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
1755 | } |
1756 | } |
1757 | |
1758 | float* a_dev = nullptr; |
1759 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); |
1760 | float* b_dev = nullptr; |
1761 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); |
1762 | float* c_dev = nullptr; |
1763 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); |
1764 | float* d_dev = nullptr; |
1765 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); |
1766 | C10_CUDA_CHECK(cudaMemcpy( |
1767 | a_dev, |
1768 | a_v.data(), |
1769 | OUTER_EXTENT * A_EXTENT * sizeof(float), |
1770 | cudaMemcpyHostToDevice)); |
1771 | C10_CUDA_CHECK(cudaMemcpy( |
1772 | b_dev, |
1773 | b_v.data(), |
1774 | OUTER_EXTENT * B_EXTENT * sizeof(float), |
1775 | cudaMemcpyHostToDevice)); |
1776 | C10_CUDA_CHECK(cudaMemcpy( |
1777 | c_dev, |
1778 | c_v.data(), |
1779 | OUTER_EXTENT * A_EXTENT * sizeof(float), |
1780 | cudaMemcpyHostToDevice)); |
1781 | C10_CUDA_CHECK(cudaMemcpy( |
1782 | d_dev, |
1783 | d_v.data(), |
1784 | OUTER_EXTENT * B_EXTENT * sizeof(float), |
1785 | cudaMemcpyHostToDevice)); |
1786 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1787 | |
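  // The symbolic extents are passed as ordinary kernel arguments, in the same
  // order they were given to the CudaCodeGen constructor.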
1788 | cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev); |
1789 | |
1790 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1791 | C10_CUDA_CHECK(cudaMemcpy( |
1792 | c_v.data(), |
1793 | c_dev, |
1794 | OUTER_EXTENT * A_EXTENT * sizeof(float), |
1795 | cudaMemcpyDeviceToHost)); |
1796 | C10_CUDA_CHECK(cudaMemcpy( |
1797 | d_v.data(), |
1798 | d_dev, |
1799 | OUTER_EXTENT * B_EXTENT * sizeof(float), |
1800 | cudaMemcpyDeviceToHost)); |
1801 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1802 | |
1803 | ExpectAllNear(c_v, c_ref, 1e-5); |
1804 | ExpectAllNear(d_v, d_ref, 1e-5); |
1805 | |
1806 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1807 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1808 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1809 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1810 | } |
1811 | |
1812 | // Tests the case where two loops are fused at a common parent loop, which is |
1813 | // bound to the block dimension. Internally the inner loops have different |
1814 | // extents but are bound to the same thread dimension. The smaller loop should |
1815 | // be masked. |
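//
// The expected kernel shape matches MaskMultiDim above, except that c and d
// are plain buffer arguments rather than Tensors: the store to c is unmasked,
// followed by __syncthreads() and a store to d masked with threadIdx.x < 50.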
1816 | TEST(Cuda, MaskCompoundInnerLoop_CUDA) { |
1817 | int OUTER_SIZE = 10; |
1818 | int A_SIZE = 100; |
1819 | int B_SIZE = 50; |
1820 | BufHandle a_buf("a" , {OUTER_SIZE, A_SIZE}, kFloat); |
1821 | BufHandle b_buf("b" , {OUTER_SIZE, B_SIZE}, kFloat); |
1822 | BufHandle c_buf("c" , {OUTER_SIZE, A_SIZE}, kFloat); |
1823 | BufHandle d_buf("d" , {OUTER_SIZE, B_SIZE}, kFloat); |
1824 | |
1825 | // Can't build this using Compute and transforms yet. |
1826 | LoopOptions blockBound; |
1827 | blockBound.set_gpu_block_index(0); |
1828 | LoopOptions threadBound; |
1829 | threadBound.set_gpu_thread_index(0); |
1830 | VarHandle i("i" , kInt); |
1831 | VarHandle j("j" , kInt); |
1832 | VarHandle k("k" , kInt); |
1833 | |
1834 | StmtPtr stmt = For::make( |
1835 | i, |
1836 | 0, |
1837 | OUTER_SIZE, |
1838 | Block::make( |
1839 | {For::make( |
1840 | j, |
1841 | 0, |
1842 | A_SIZE, |
1843 | c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), |
1844 | threadBound), |
1845 | For::make( |
1846 | k, |
1847 | 0, |
1848 | B_SIZE, |
1849 | d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), |
1850 | threadBound)}), |
1851 | blockBound); |
1852 | |
1853 | stmt = FlattenIndexes(stmt); |
1854 | stmt = IRSimplifier::simplify(stmt); |
1855 | |
1856 | CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); |
1857 | |
1858 | std::ostringstream oss; |
1859 | oss << *cuda_cg.stmt(); |
1860 | |
1861 | // The write to D should be masked, but not the write to C. |
1862 | const std::string& verification_pattern = |
1863 | R"IR( |
1864 | # CHECK-NOT: if ( |
1865 | # CHECK: c[threadIdx.x + 100 * blockIdx.x] = |
1866 | # CHECK: __syncthreads(); |
1867 | # CHECK: if (threadIdx.x<50 |
# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR";
1869 | |
1870 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
1871 | |
1872 | auto blockExtents = cuda_cg.gpu_block_extents(); |
1873 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
1874 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE))); |
1875 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE))); |
1876 | |
1877 | PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE); |
1878 | PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE); |
1879 | PaddedBuffer<float> c_v(OUTER_SIZE, A_SIZE); |
1880 | PaddedBuffer<float> d_v(OUTER_SIZE, B_SIZE); |
1881 | |
1882 | PaddedBuffer<float> c_ref(OUTER_SIZE, A_SIZE); |
1883 | PaddedBuffer<float> d_ref(OUTER_SIZE, B_SIZE); |
1884 | |
1885 | for (const auto o : c10::irange(OUTER_SIZE)) { |
1886 | for (const auto i : c10::irange(A_SIZE)) { |
1887 | a_v(o, i) = (float)i; |
1888 | c_ref(o, i) = (float)(i * 2); |
1889 | } |
1890 | for (const auto i : c10::irange(B_SIZE)) { |
1891 | b_v(o, i) = (float)(B_SIZE - i); |
1892 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
1893 | } |
1894 | } |
1895 | |
1896 | float* a_dev = nullptr; |
1897 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
1898 | float* b_dev = nullptr; |
1899 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
1900 | float* c_dev = nullptr; |
1901 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
1902 | float* d_dev = nullptr; |
1903 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
1904 | C10_CUDA_CHECK(cudaMemcpy( |
1905 | a_dev, |
1906 | a_v.data(), |
1907 | OUTER_SIZE * A_SIZE * sizeof(float), |
1908 | cudaMemcpyHostToDevice)); |
1909 | C10_CUDA_CHECK(cudaMemcpy( |
1910 | b_dev, |
1911 | b_v.data(), |
1912 | OUTER_SIZE * B_SIZE * sizeof(float), |
1913 | cudaMemcpyHostToDevice)); |
1914 | C10_CUDA_CHECK(cudaMemcpy( |
1915 | c_dev, |
1916 | c_v.data(), |
1917 | OUTER_SIZE * A_SIZE * sizeof(float), |
1918 | cudaMemcpyHostToDevice)); |
1919 | C10_CUDA_CHECK(cudaMemcpy( |
1920 | d_dev, |
1921 | d_v.data(), |
1922 | OUTER_SIZE * B_SIZE * sizeof(float), |
1923 | cudaMemcpyHostToDevice)); |
1924 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1925 | |
1926 | cuda_cg(a_dev, b_dev, c_dev, d_dev); |
1927 | |
1928 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1929 | C10_CUDA_CHECK(cudaMemcpy( |
1930 | c_v.data(), |
1931 | c_dev, |
1932 | OUTER_SIZE * A_SIZE * sizeof(float), |
1933 | cudaMemcpyDeviceToHost)); |
1934 | C10_CUDA_CHECK(cudaMemcpy( |
1935 | d_v.data(), |
1936 | d_dev, |
1937 | OUTER_SIZE * B_SIZE * sizeof(float), |
1938 | cudaMemcpyDeviceToHost)); |
1939 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
1940 | |
1941 | ExpectAllNear(c_v, c_ref, 1e-5); |
1942 | ExpectAllNear(d_v, d_ref, 1e-5); |
1943 | |
1944 | C10_CUDA_CHECK(cudaFree(a_dev)); |
1945 | C10_CUDA_CHECK(cudaFree(b_dev)); |
1946 | C10_CUDA_CHECK(cudaFree(c_dev)); |
1947 | C10_CUDA_CHECK(cudaFree(d_dev)); |
1948 | } |
1949 | |
// Tests the case with two loops fused into a common parent, which is not bound
// to any block or thread dimension - however its two inner loops are bound to
// the first thread dimension. This should work just like the MaskThreadDim test
// where the bigger loop is unmasked but the smaller one is masked.
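//
// A rough sketch of the expected kernel (assumed shape): since no loop is
// bound to a block index, only one block is launched and the outer loop stays
// serial inside the kernel:
//
//   // gridDim.x = 1, blockDim.x = 100 (A_SIZE)
//   for (int i = 0; i < 10; i++) {
//     c[threadIdx.x + 100 * i] = 2.f * a[threadIdx.x + 100 * i];
//     __syncthreads();
//     if (threadIdx.x < 50) {
//       d[threadIdx.x + 50 * i] = c[2 * threadIdx.x + 100 * i] + b[threadIdx.x + 50 * i];
//     }
//   }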
1954 | TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { |
1955 | int OUTER_SIZE = 10; |
1956 | int A_SIZE = 100; |
1957 | int B_SIZE = 50; |
1958 | BufHandle a_buf("a" , {OUTER_SIZE, A_SIZE}, kFloat); |
1959 | BufHandle b_buf("b" , {OUTER_SIZE, B_SIZE}, kFloat); |
1960 | BufHandle c_buf("c" , {OUTER_SIZE, A_SIZE}, kFloat); |
1961 | BufHandle d_buf("d" , {OUTER_SIZE, B_SIZE}, kFloat); |
1962 | |
1963 | // Can't build this using Compute and transforms yet. |
1964 | LoopOptions blockBound; |
1965 | blockBound.set_gpu_block_index(0); |
1966 | LoopOptions threadBound; |
1967 | threadBound.set_gpu_thread_index(0); |
1968 | VarHandle i("i" , kInt); |
1969 | VarHandle j("j" , kInt); |
1970 | VarHandle k("k" , kInt); |
1971 | |
1972 | StmtPtr stmt = For::make( |
1973 | i, |
1974 | 0, |
1975 | OUTER_SIZE, |
1976 | Block::make( |
1977 | {For::make( |
1978 | j, |
1979 | 0, |
1980 | A_SIZE, |
1981 | c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), |
1982 | threadBound), |
1983 | For::make( |
1984 | k, |
1985 | 0, |
1986 | B_SIZE, |
1987 | d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), |
1988 | threadBound)})); |
1989 | |
1990 | stmt = FlattenIndexes(stmt); |
1991 | stmt = IRSimplifier::simplify(stmt); |
1992 | |
1993 | CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); |
1994 | |
1995 | std::ostringstream oss; |
1996 | oss << *cuda_cg.stmt(); |
1997 | |
  // The outer loop remains a serial for loop; the write to D is masked but the
  // write to C is not.
1999 | const std::string& verification_pattern = |
2000 | R"IR( |
2001 | # CHECK: for (int i = 0; i < 10 |
2002 | # CHECK-NOT: if ( |
2003 | # CHECK: c[threadIdx.x + 100 * i] = |
2004 | # CHECK: __syncthreads(); |
2005 | # CHECK: if (threadIdx.x<50 |
# CHECK: d[threadIdx.x + 50 * i] =)IR";
2007 | |
2008 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2009 | |
2010 | auto blockExtents = cuda_cg.gpu_block_extents(); |
2011 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
2012 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(1))); |
2013 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE))); |
2014 | |
2015 | PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE); |
2016 | PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE); |
2017 | PaddedBuffer<float> c_v(OUTER_SIZE, A_SIZE); |
2018 | PaddedBuffer<float> d_v(OUTER_SIZE, B_SIZE); |
2019 | |
2020 | PaddedBuffer<float> c_ref(OUTER_SIZE, A_SIZE); |
2021 | PaddedBuffer<float> d_ref(OUTER_SIZE, B_SIZE); |
2022 | |
2023 | for (const auto o : c10::irange(OUTER_SIZE)) { |
2024 | for (const auto i : c10::irange(A_SIZE)) { |
2025 | a_v(o, i) = (float)i; |
2026 | c_ref(o, i) = (float)(i * 2); |
2027 | } |
2028 | for (const auto i : c10::irange(B_SIZE)) { |
2029 | b_v(o, i) = (float)(B_SIZE - i); |
2030 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
2031 | } |
2032 | } |
2033 | |
2034 | float* a_dev = nullptr; |
2035 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
2036 | float* b_dev = nullptr; |
2037 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
2038 | float* c_dev = nullptr; |
2039 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
2040 | float* d_dev = nullptr; |
2041 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
2042 | C10_CUDA_CHECK(cudaMemcpy( |
2043 | a_dev, |
2044 | a_v.data(), |
2045 | OUTER_SIZE * A_SIZE * sizeof(float), |
2046 | cudaMemcpyHostToDevice)); |
2047 | C10_CUDA_CHECK(cudaMemcpy( |
2048 | b_dev, |
2049 | b_v.data(), |
2050 | OUTER_SIZE * B_SIZE * sizeof(float), |
2051 | cudaMemcpyHostToDevice)); |
2052 | C10_CUDA_CHECK(cudaMemcpy( |
2053 | c_dev, |
2054 | c_v.data(), |
2055 | OUTER_SIZE * A_SIZE * sizeof(float), |
2056 | cudaMemcpyHostToDevice)); |
2057 | C10_CUDA_CHECK(cudaMemcpy( |
2058 | d_dev, |
2059 | d_v.data(), |
2060 | OUTER_SIZE * B_SIZE * sizeof(float), |
2061 | cudaMemcpyHostToDevice)); |
2062 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2063 | |
2064 | cuda_cg(a_dev, b_dev, c_dev, d_dev); |
2065 | |
2066 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2067 | C10_CUDA_CHECK(cudaMemcpy( |
2068 | c_v.data(), |
2069 | c_dev, |
2070 | OUTER_SIZE * A_SIZE * sizeof(float), |
2071 | cudaMemcpyDeviceToHost)); |
2072 | C10_CUDA_CHECK(cudaMemcpy( |
2073 | d_v.data(), |
2074 | d_dev, |
2075 | OUTER_SIZE * B_SIZE * sizeof(float), |
2076 | cudaMemcpyDeviceToHost)); |
2077 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2078 | |
2079 | ExpectAllNear(c_v, c_ref, 1e-5); |
2080 | ExpectAllNear(d_v, d_ref, 1e-5); |
2081 | |
2082 | C10_CUDA_CHECK(cudaFree(a_dev)); |
2083 | C10_CUDA_CHECK(cudaFree(b_dev)); |
2084 | C10_CUDA_CHECK(cudaFree(c_dev)); |
2085 | C10_CUDA_CHECK(cudaFree(d_dev)); |
2086 | } |
2087 | |
// Tests the case with two loop nests, each of which is bound to the same block
// size, but with the internal loops bound to different thread ranks (i.e. x and
// y). In this case both bodies must be masked against the other dimension being
// > 0.
// Note: this is a bit degenerate; no one would actually write this for
// performance.
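//
// A rough sketch of the expected kernel body (assumed shape):
//
//   // gridDim.x = 10, blockDim.x = 30 (A_SIZE), blockDim.y = 15 (B_SIZE)
//   if (threadIdx.y < 1) {
//     C[threadIdx.x + 30 * blockIdx.x] = 2.f * a[...];
//   }
//   __syncthreads();
//   if (threadIdx.x < 1) {
//     D[threadIdx.y + 15 * blockIdx.x] = C[...] + b[...];
//   }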
2092 | TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { |
2093 | int OUTER_SIZE = 10; |
2094 | int A_SIZE = 30; |
2095 | int B_SIZE = 15; |
2096 | BufHandle a_buf("a" , {OUTER_SIZE, A_SIZE}, kFloat); |
2097 | BufHandle b_buf("b" , {OUTER_SIZE, B_SIZE}, kFloat); |
2098 | Tensor c = Compute( |
2099 | "C" , {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
2100 | return ExprHandle(2) * a_buf.load(i, j); |
2101 | }); |
2102 | Tensor d = Compute( |
2103 | "D" , {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
2104 | return c.load(i, j * 2) + b_buf.load(i, j); |
2105 | }); |
2106 | |
2107 | LoopNest l({c, d}); |
2108 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
2109 | loops[0]->set_gpu_block_index(0); |
2110 | loops[1]->set_gpu_thread_index(0); |
2111 | loops = l.getLoopStmtsFor(d); |
2112 | loops[0]->set_gpu_block_index(0); |
2113 | loops[1]->set_gpu_thread_index(1); |
2114 | |
2115 | l.prepareForCodegen(); |
2116 | StmtPtr stmt = l.root_stmt(); |
2117 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
2118 | |
2119 | std::ostringstream oss; |
2120 | oss << *cuda_cg.stmt(); |
2121 | |
  // Both stores are masked against the other thread dimension being < 1.
2123 | const std::string& verification_pattern = |
2124 | R"IR( |
2125 | # CHECK: if (threadIdx.y<1 |
2126 | # CHECK: C[threadIdx.x + 30 * blockIdx.x] = |
2127 | # CHECK: __syncthreads(); |
2128 | # CHECK: if (threadIdx.x<1 |
# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR";
2130 | |
2131 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2132 | |
2133 | auto blockExtents = cuda_cg.gpu_block_extents(); |
2134 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
2135 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE))); |
2136 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE))); |
2137 | |
2138 | PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE); |
2139 | PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE); |
2140 | PaddedBuffer<float> c_v(OUTER_SIZE, A_SIZE); |
2141 | PaddedBuffer<float> d_v(OUTER_SIZE, B_SIZE); |
2142 | |
2143 | PaddedBuffer<float> c_ref(OUTER_SIZE, A_SIZE); |
2144 | PaddedBuffer<float> d_ref(OUTER_SIZE, B_SIZE); |
2145 | |
2146 | for (const auto o : c10::irange(OUTER_SIZE)) { |
2147 | for (const auto i : c10::irange(A_SIZE)) { |
2148 | a_v(o, i) = (float)i; |
2149 | c_ref(o, i) = (float)(i * 2); |
2150 | } |
2151 | } |
2152 | |
2153 | for (const auto o : c10::irange(OUTER_SIZE)) { |
2154 | for (const auto i : c10::irange(B_SIZE)) { |
2155 | b_v(o, i) = (float)(B_SIZE - i); |
2156 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
2157 | } |
2158 | } |
2159 | |
2160 | float* a_dev = nullptr; |
2161 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
2162 | float* b_dev = nullptr; |
2163 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
2164 | float* c_dev = nullptr; |
2165 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); |
2166 | float* d_dev = nullptr; |
2167 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); |
2168 | C10_CUDA_CHECK(cudaMemcpy( |
2169 | a_dev, |
2170 | a_v.data(), |
2171 | OUTER_SIZE * A_SIZE * sizeof(float), |
2172 | cudaMemcpyHostToDevice)); |
2173 | C10_CUDA_CHECK(cudaMemcpy( |
2174 | b_dev, |
2175 | b_v.data(), |
2176 | OUTER_SIZE * B_SIZE * sizeof(float), |
2177 | cudaMemcpyHostToDevice)); |
2178 | C10_CUDA_CHECK(cudaMemcpy( |
2179 | c_dev, |
2180 | c_v.data(), |
2181 | OUTER_SIZE * A_SIZE * sizeof(float), |
2182 | cudaMemcpyHostToDevice)); |
2183 | C10_CUDA_CHECK(cudaMemcpy( |
2184 | d_dev, |
2185 | d_v.data(), |
2186 | OUTER_SIZE * B_SIZE * sizeof(float), |
2187 | cudaMemcpyHostToDevice)); |
2188 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2189 | |
2190 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
2191 | |
2192 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2193 | C10_CUDA_CHECK(cudaMemcpy( |
2194 | c_v.data(), |
2195 | c_dev, |
2196 | OUTER_SIZE * A_SIZE * sizeof(float), |
2197 | cudaMemcpyDeviceToHost)); |
2198 | C10_CUDA_CHECK(cudaMemcpy( |
2199 | d_v.data(), |
2200 | d_dev, |
2201 | OUTER_SIZE * B_SIZE * sizeof(float), |
2202 | cudaMemcpyDeviceToHost)); |
2203 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2204 | |
2205 | ExpectAllNear(c_v, c_ref, 1e-5); |
2206 | ExpectAllNear(d_v, d_ref, 1e-5); |
2207 | |
2208 | C10_CUDA_CHECK(cudaFree(a_dev)); |
2209 | C10_CUDA_CHECK(cudaFree(b_dev)); |
2210 | C10_CUDA_CHECK(cudaFree(c_dev)); |
2211 | C10_CUDA_CHECK(cudaFree(d_dev)); |
2212 | } |
2213 | |
// Tests the case with two loop nests, each bound to both the block and thread
// dimensions, but with the second loop nest smaller in both cases - its store
// must be masked in both the block and the thread dimension.
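//
// A rough sketch of the expected kernel body (assumed shape):
//
//   // gridDim.x = 10 (OUTER_A_SIZE), blockDim.x = 30 (A_SIZE)
//   C[threadIdx.x + 30 * blockIdx.x] = 2.f * a[...];
//   __syncthreads();
//   if (blockIdx.x < 5) {
//     if (threadIdx.x < 15) {
//       D[threadIdx.x + 15 * blockIdx.x] = C[...] + b[...];
//     }
//   }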
2217 | TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { |
2218 | int OUTER_A_SIZE = 10; |
2219 | int OUTER_B_SIZE = 5; |
2220 | int A_SIZE = 30; |
2221 | int B_SIZE = 15; |
2222 | BufHandle a_buf("a" , {OUTER_A_SIZE, A_SIZE}, kFloat); |
2223 | BufHandle b_buf("b" , {OUTER_B_SIZE, B_SIZE}, kFloat); |
2224 | Tensor c = Compute( |
2225 | "C" , {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
2226 | return ExprHandle(2) * a_buf.load(i, j); |
2227 | }); |
2228 | Tensor d = Compute( |
2229 | "D" , {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { |
2230 | return c.load(i, j * 2) + b_buf.load(i, j); |
2231 | }); |
2232 | |
2233 | LoopNest l({c, d}); |
2234 | std::vector<ForPtr> loops = l.getLoopStmtsFor(c); |
2235 | loops[0]->set_gpu_block_index(0); |
2236 | loops[1]->set_gpu_thread_index(0); |
2237 | loops = l.getLoopStmtsFor(d); |
2238 | loops[0]->set_gpu_block_index(0); |
2239 | loops[1]->set_gpu_thread_index(0); |
2240 | |
2241 | l.prepareForCodegen(); |
2242 | StmtPtr stmt = l.root_stmt(); |
2243 | CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); |
2244 | |
2245 | std::ostringstream oss; |
2246 | oss << *cuda_cg.stmt(); |
2247 | |
2248 | // The write to D should be masked twice, but not the write to C. |
2249 | const std::string& verification_pattern = |
2250 | R"IR( |
2251 | # CHECK-NOT: if ( |
2252 | # CHECK: C[threadIdx.x + 30 * blockIdx.x] = |
2253 | # CHECK: __syncthreads(); |
2254 | # CHECK: if (blockIdx.x<5 |
2255 | # CHECK: if (threadIdx.x<15 |
# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR";
2257 | |
2258 | torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
2259 | |
2260 | auto blockExtents = cuda_cg.gpu_block_extents(); |
2261 | auto threadExtents = cuda_cg.gpu_thread_extents(); |
2262 | ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_A_SIZE))); |
2263 | ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE))); |
2264 | |
2265 | PaddedBuffer<float> a_v(OUTER_A_SIZE, A_SIZE); |
2266 | PaddedBuffer<float> b_v(OUTER_B_SIZE, B_SIZE); |
2267 | PaddedBuffer<float> c_v(OUTER_A_SIZE, A_SIZE); |
2268 | PaddedBuffer<float> d_v(OUTER_B_SIZE, B_SIZE); |
2269 | |
2270 | PaddedBuffer<float> c_ref(OUTER_A_SIZE, A_SIZE); |
2271 | PaddedBuffer<float> d_ref(OUTER_B_SIZE, B_SIZE); |
2272 | |
2273 | for (const auto o : c10::irange(OUTER_A_SIZE)) { |
2274 | for (const auto i : c10::irange(A_SIZE)) { |
2275 | a_v(o, i) = (float)i; |
2276 | c_ref(o, i) = (float)(i * 2); |
2277 | } |
2278 | } |
2279 | |
2280 | for (const auto o : c10::irange(OUTER_B_SIZE)) { |
2281 | for (const auto i : c10::irange(B_SIZE)) { |
2282 | b_v(o, i) = (float)(B_SIZE - i); |
2283 | d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); |
2284 | } |
2285 | } |
2286 | |
2287 | float* a_dev = nullptr; |
2288 | C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); |
2289 | float* b_dev = nullptr; |
2290 | C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); |
2291 | float* c_dev = nullptr; |
2292 | C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); |
2293 | float* d_dev = nullptr; |
2294 | C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); |
2295 | C10_CUDA_CHECK(cudaMemcpy( |
2296 | a_dev, |
2297 | a_v.data(), |
2298 | OUTER_A_SIZE * A_SIZE * sizeof(float), |
2299 | cudaMemcpyHostToDevice)); |
2300 | C10_CUDA_CHECK(cudaMemcpy( |
2301 | b_dev, |
2302 | b_v.data(), |
2303 | OUTER_B_SIZE * B_SIZE * sizeof(float), |
2304 | cudaMemcpyHostToDevice)); |
2305 | C10_CUDA_CHECK(cudaMemcpy( |
2306 | c_dev, |
2307 | c_v.data(), |
2308 | OUTER_A_SIZE * A_SIZE * sizeof(float), |
2309 | cudaMemcpyHostToDevice)); |
2310 | C10_CUDA_CHECK(cudaMemcpy( |
2311 | d_dev, |
2312 | d_v.data(), |
2313 | OUTER_B_SIZE * B_SIZE * sizeof(float), |
2314 | cudaMemcpyHostToDevice)); |
2315 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2316 | |
2317 | cuda_cg(c_dev, d_dev, a_dev, b_dev); |
2318 | |
2319 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2320 | C10_CUDA_CHECK(cudaMemcpy( |
2321 | c_v.data(), |
2322 | c_dev, |
2323 | OUTER_A_SIZE * A_SIZE * sizeof(float), |
2324 | cudaMemcpyDeviceToHost)); |
2325 | C10_CUDA_CHECK(cudaMemcpy( |
2326 | d_v.data(), |
2327 | d_dev, |
2328 | OUTER_B_SIZE * B_SIZE * sizeof(float), |
2329 | cudaMemcpyDeviceToHost)); |
2330 | C10_CUDA_CHECK(cudaDeviceSynchronize()); |
2331 | |
2332 | ExpectAllNear(c_v, c_ref, 1e-5); |
2333 | ExpectAllNear(d_v, d_ref, 1e-5); |
2334 | |
2335 | C10_CUDA_CHECK(cudaFree(a_dev)); |
2336 | C10_CUDA_CHECK(cudaFree(b_dev)); |
2337 | C10_CUDA_CHECK(cudaFree(c_dev)); |
2338 | C10_CUDA_CHECK(cudaFree(d_dev)); |
2339 | } |
2340 | |
2341 | } // namespace jit |
2342 | } // namespace torch |
2343 | |
2344 | #endif |
2345 | |