1 | /* |
2 | Provides the implementations of CUDA BLAS function templates. |
3 | */ |
4 | |
5 | #include <ATen/ATen.h> |
6 | #include <ATen/cuda/CUDABlas.h> |
7 | #include <ATen/cuda/Exceptions.h> |
8 | #include <c10/cuda/CUDAFunctions.h> |
9 | #include <c10/macros/Export.h> |
10 | #include <c10/util/irange.h> |
11 | |
// cublasLt was introduced in CUDA 10.1, but we enable it only for CUDA 11.1+,
// which also added bf16 support
14 | #if !defined(USE_ROCM) && !defined(_MSC_VER) |
15 | #include <cublasLt.h> |
16 | #endif |
17 | |
18 | #ifdef USE_ROCM |
19 | #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) |
20 | #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) |
21 | #endif |
22 | |
#define CUDABLAS_POSINT_CHECK(FD, X)         \
  TORCH_CHECK(                               \
      (X > 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X \
      " must be positive and no greater than ", \
      INT_MAX,                               \
      " but got ",                           \
      X)
31 | |
#define CUDABLAS_NONNEGINT_CHECK(FD, X)       \
  TORCH_CHECK(                                \
      (X >= 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X  \
      " must be non-negative and no greater than ", \
      INT_MAX,                                \
      " but got ",                            \
      X)
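
// For example, CUDABLAS_POSINT_CHECK(gemm<float>, lda) fails with a message
// like "at::cuda::blas::gemm<float> argument lda must be positive and no
// greater than 2147483647 but got 0".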
40 | |
41 | namespace { |
42 | |
43 | static cublasOperation_t _cublasOpFromChar(char op) { |
44 | switch (op) { |
45 | case 'n': |
46 | case 'N': |
47 | return CUBLAS_OP_N; |
48 | case 't': |
49 | case 'T': |
50 | return CUBLAS_OP_T; |
51 | case 'c': |
52 | case 'C': |
53 | return CUBLAS_OP_C; |
54 | } |
  AT_ERROR(
      "_cublasOpFromChar input should be 'n'/'N', 't'/'T', or 'c'/'C' but got `", op, "`");
57 | } |
58 | |
59 | static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) { |
  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).

  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of the trans value). In level 3, the sizes (m, n, k)
  // specify the sizes of op(A), op(B), where op depends on the trans
  // values.
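  // For example, a gemv on an m x 1 matrix (n == 1) may legally arrive with a
  // degenerate lda (e.g. 0) from a caller whose stride is never used; we bump
  // it to max(m, 1) so cuBLAS's lda >= max(1, m) requirement holds.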
69 | if (n <= 1) |
70 | *lda = std::max<int64_t>(m, 1); |
71 | } |
72 | |
73 | static void _cublasAdjustLdLevel3( |
74 | char transa, |
75 | char transb, |
76 | int64_t m, |
77 | int64_t n, |
78 | int64_t k, |
79 | int64_t* lda, |
80 | int64_t* ldb, |
81 | int64_t* ldc) { |
82 | bool transa_ = ((transa != 'n') && (transa != 'N')); |
83 | bool transb_ = ((transb != 'n') && (transb != 'N')); |
84 | |
  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).
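  // For example, with transa == 't' and (m, n, k) == (1, 4, 4), op(A) is
  // m x k, so A is stored as k x m and cuBLAS requires lda >= k; since
  // m <= 1 the caller's lda may be degenerate, so we bump it to
  // max(k, 1) == 4.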
88 | if (n <= 1) |
89 | *ldc = std::max<int64_t>(m, 1); |
90 | |
91 | if (transa_) { |
92 | if (m <= 1) |
93 | *lda = std::max<int64_t>(k, 1); |
94 | } else { |
95 | if (k <= 1) |
96 | *lda = std::max<int64_t>(m, 1); |
97 | } |
98 | |
99 | if (transb_) { |
100 | if (k <= 1) |
101 | *ldb = std::max<int64_t>(n, 1); |
102 | } else { |
103 | if (n <= 1) |
104 | *ldb = std::max<int64_t>(k, 1); |
105 | } |
106 | } |
107 | } // anonymous namespace |
108 | |
109 | namespace at { |
110 | namespace cuda { |
111 | namespace blas { |
112 | |
113 | /* LEVEL 3 BLAS FUNCTIONS */ |
114 | |
115 | #ifndef USE_ROCM |
116 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11020 |
117 | #define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx |
118 | #else |
119 | // Workaround for https://github.com/pytorch/pytorch/issues/45724 |
120 | cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle, |
121 | cublasOperation_t transa, |
122 | cublasOperation_t transb, |
123 | int m, |
124 | int n, |
125 | int k, |
126 | const void *alpha, |
127 | const void *A, |
128 | cudaDataType Atype, |
129 | int lda, |
130 | long long int strideA, |
131 | const void *B, |
132 | cudaDataType Btype, |
133 | int ldb, |
134 | long long int strideB, |
135 | const void *beta, |
136 | void *C, |
137 | cudaDataType Ctype, |
138 | int ldc, |
139 | long long int strideC, |
140 | int64_t batchCount, |
141 | cudaDataType computeType, |
142 | cublasGemmAlgo_t algo) |
143 | { |
144 | cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); |
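  // The linked issue only manifests on compute capability 7.x devices, so
  // everything else can call straight through. On 7.x we instead split the
  // batch into chunks of 63 * 1024 to stay below the batch count that
  // triggers the bug.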
145 | if (prop->major != 7) { |
146 | return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo); |
147 | } |
  cublasStatus_t result = CUBLAS_STATUS_SUCCESS;  // defined even when batchCount == 0
149 | constexpr int64_t split = 63 * 1024; |
150 | for(int64_t i = 0; i < batchCount; i += split) { |
151 | int64_t count = std::min<int64_t>(split, batchCount - i); |
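    // Strides are in elements, and every caller of this fix passes a 2-byte
    // element type (fp16/bf16), hence the "* 2" in the byte offsets below.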
152 | result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, |
153 | (char *)A + i * strideA * 2, Atype, lda, strideA, |
154 | (char *)B + i * strideB * 2, Btype, ldb, strideB, |
155 | beta, |
156 | (char *)C + i * strideC * 2, Ctype, ldc, strideC, |
157 | (int)count, computeType, algo); |
158 | TORCH_CUDABLAS_CHECK(result); |
159 | } |
160 | return result; |
161 | } |
162 | #endif |
163 | #endif |
164 | |
165 | #define GEMM_CHECK_ARGVALUES(Dtype) \ |
166 | do { \ |
167 | CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \ |
168 | CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \ |
169 | CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \ |
170 | CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda); \ |
171 | CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb); \ |
172 | CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \ |
173 | } while (0) |
174 | |
175 | #define BGEMM_CHECK_ARGVALUES(Dtype) \ |
176 | do { \ |
177 | CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m); \ |
178 | CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n); \ |
179 | CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k); \ |
180 | CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda); \ |
181 | CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb); \ |
182 | CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc); \ |
183 | CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \ |
184 | } while (0) |
185 | |
186 | template <> |
187 | void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) { |
188 | // See Note [Writing Nondeterministic Operations] |
189 | globalContext().alertCuBLASConfigNotDeterministic(); |
190 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
191 | cublasOperation_t opa = _cublasOpFromChar(transa); |
192 | cublasOperation_t opb = _cublasOpFromChar(transb); |
193 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
194 | BGEMM_CHECK_ARGVALUES(double); |
195 | TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched( |
196 | handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); |
197 | } |
198 | |
199 | template <> |
200 | void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) { |
201 | // See Note [Writing Nondeterministic Operations] |
202 | globalContext().alertCuBLASConfigNotDeterministic(); |
203 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
204 | cublasOperation_t opa = _cublasOpFromChar(transa); |
205 | cublasOperation_t opb = _cublasOpFromChar(transb); |
206 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
207 | BGEMM_CHECK_ARGVALUES(float); |
208 | TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched( |
209 | handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); |
210 | } |
211 | |
212 | template <> |
213 | void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) { |
214 | // See Note [Writing Nondeterministic Operations] |
215 | globalContext().alertCuBLASConfigNotDeterministic(); |
216 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
217 | cublasOperation_t opa = _cublasOpFromChar(transa); |
218 | cublasOperation_t opb = _cublasOpFromChar(transb); |
219 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
220 | BGEMM_CHECK_ARGVALUES(c10::complex<double>); |
221 | TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched( |
222 | handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a), |
223 | lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta), |
224 | reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches)); |
225 | } |
226 | |
227 | template <> |
228 | void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) { |
229 | // See Note [Writing Nondeterministic Operations] |
230 | globalContext().alertCuBLASConfigNotDeterministic(); |
231 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
232 | cublasOperation_t opa = _cublasOpFromChar(transa); |
233 | cublasOperation_t opb = _cublasOpFromChar(transb); |
234 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
235 | BGEMM_CHECK_ARGVALUES(c10::complex<float>); |
236 | TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched( |
237 | handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a), |
238 | lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta), |
239 | reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches)); |
240 | } |
241 | |
242 | template <> |
243 | void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { |
244 | // See Note [Writing Nondeterministic Operations] |
245 | globalContext().alertCuBLASConfigNotDeterministic(); |
246 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
247 | cublasOperation_t opa = _cublasOpFromChar(transa); |
248 | cublasOperation_t opb = _cublasOpFromChar(transb); |
249 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
250 | BGEMM_CHECK_ARGVALUES(at::Half); |
251 | float falpha = alpha; |
252 | float fbeta = beta; |
253 | #ifdef USE_ROCM |
254 | int flag = 0; |
255 | #if USE_GEMM_FLAGS_FP16_ALT_IMPL |
256 | flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; |
257 | #endif |
258 | TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, |
259 | (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea, |
260 | b, rocblas_datatype_f16_r, (int)ldb, strideb, |
261 | (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec, |
262 | c, rocblas_datatype_f16_r, (int)ldc, stridec, |
263 | (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, |
264 | 0, flag)); |
265 | #else |
266 | cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); |
  if (prop->major >= 5) {
268 | TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix( |
269 | handle, opa, opb, m, n, k, |
270 | (void*)(&falpha), a, CUDA_R_16F, lda, stridea, |
271 | b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), |
272 | c, CUDA_R_16F, ldc, stridec, |
273 | num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |
274 | } else { |
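    // cublasGemmStridedBatchedEx requires compute capability 5.0+; on older
    // devices fall back to a loop of single fp16 gemms, which have their own
    // pre-5.0 path in gemm<at::Half> below.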
275 | for (const auto i : c10::irange(num_batches)) { |
276 | at::cuda::blas::gemm<at::Half>( |
277 | transa, transb, |
278 | m, n, k, |
279 | alpha, (a + i * stridea), lda, |
280 | (b + i * strideb), ldb, beta, |
281 | (c + i * stridec), ldc); |
282 | } |
283 | } |
284 | #endif // USE_ROCM |
285 | } |
286 | |
287 | template <> |
288 | void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { |
289 | // See Note [Writing Nondeterministic Operations] |
290 | globalContext().alertCuBLASConfigNotDeterministic(); |
291 | BGEMM_CHECK_ARGVALUES(at::BFloat16); |
292 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
293 | cublasOperation_t opa = _cublasOpFromChar(transa); |
294 | cublasOperation_t opb = _cublasOpFromChar(transb); |
295 | const float falpha = alpha; |
296 | const float fbeta = beta; |
297 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
298 | |
299 | #if !defined(USE_ROCM) |
300 | TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle, |
301 | opa, opb, (int)m, (int)n, (int)k, |
302 | (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, |
303 | b, CUDA_R_16BF, (int)ldb, strideb, |
304 | (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, |
305 | (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |
306 | #else |
307 | TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, |
308 | (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea, |
309 | b, rocblas_datatype_bf16_r, (int)ldb, strideb, |
310 | (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec, |
311 | c, rocblas_datatype_bf16_r, (int)ldc, stridec, |
312 | (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, |
313 | 0, 0, NULL, NULL)); |
314 | #endif // !defined(USE_ROCM) |
315 | } |
316 | |
317 | template <> |
318 | void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) { |
319 | // See Note [Writing Nondeterministic Operations] |
320 | globalContext().alertCuBLASConfigNotDeterministic(); |
321 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
322 | cublasOperation_t opa = _cublasOpFromChar(transa); |
323 | cublasOperation_t opb = _cublasOpFromChar(transb); |
324 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
325 | GEMM_CHECK_ARGVALUES(double); |
326 | TORCH_CUDABLAS_CHECK(cublasDgemm( |
327 | handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); |
328 | } |
329 | |
330 | template <> |
331 | void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) { |
332 | // See Note [Writing Nondeterministic Operations] |
333 | globalContext().alertCuBLASConfigNotDeterministic(); |
334 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
335 | cublasOperation_t opa = _cublasOpFromChar(transa); |
336 | cublasOperation_t opb = _cublasOpFromChar(transb); |
337 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
338 | GEMM_CHECK_ARGVALUES(float); |
339 | TORCH_CUDABLAS_CHECK(cublasSgemm( |
340 | handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); |
341 | } |
342 | |
343 | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
344 | template <> |
345 | void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) { |
346 | // See Note [Writing Nondeterministic Operations] |
347 | globalContext().alertCuBLASConfigNotDeterministic(); |
348 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
349 | cublasOperation_t opa = _cublasOpFromChar(transa); |
350 | cublasOperation_t opb = _cublasOpFromChar(transb); |
351 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
352 | GEMM_CHECK_ARGVALUES(c10::complex<double>); |
353 | TORCH_CUDABLAS_CHECK(cublasZgemm( |
354 | handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a), |
355 | lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta), |
356 | reinterpret_cast<cuDoubleComplex*>(c), ldc)); |
357 | } |
358 | #endif |
359 | |
360 | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
361 | template <> |
362 | void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) { |
363 | // See Note [Writing Nondeterministic Operations] |
364 | globalContext().alertCuBLASConfigNotDeterministic(); |
365 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
366 | cublasOperation_t opa = _cublasOpFromChar(transa); |
367 | cublasOperation_t opb = _cublasOpFromChar(transb); |
368 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
369 | GEMM_CHECK_ARGVALUES(c10::complex<float>); |
370 | TORCH_CUDABLAS_CHECK(cublasCgemm( |
371 | handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a), |
372 | lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta), |
373 | reinterpret_cast<cuComplex*>(c), ldc)); |
374 | } |
375 | #endif |
376 | |
377 | template <> |
378 | void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) { |
379 | // See Note [Writing Nondeterministic Operations] |
380 | globalContext().alertCuBLASConfigNotDeterministic(); |
381 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
382 | cublasOperation_t opa = _cublasOpFromChar(transa); |
383 | cublasOperation_t opb = _cublasOpFromChar(transb); |
384 | float falpha = alpha; |
385 | float fbeta = beta; |
386 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
387 | GEMM_CHECK_ARGVALUES(at::Half); |
388 | #ifdef USE_ROCM |
389 | int flag = 0; |
390 | #if USE_GEMM_FLAGS_FP16_ALT_IMPL |
391 | flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; |
392 | #endif |
393 | TORCH_CUDABLAS_CHECK(rocblas_gemm_ex( |
394 | handle, |
395 | opa, |
396 | opb, |
397 | m, |
398 | n, |
399 | k, |
400 | &falpha, |
401 | a, |
402 | rocblas_datatype_f16_r, |
403 | lda, |
404 | b, |
405 | rocblas_datatype_f16_r, |
406 | ldb, |
407 | &fbeta, |
408 | c, |
409 | rocblas_datatype_f16_r, |
410 | ldc, |
411 | c, |
412 | rocblas_datatype_f16_r, |
413 | ldc, |
414 | rocblas_datatype_f32_r, |
415 | rocblas_gemm_algo_standard, |
416 | 0, |
417 | flag)); |
418 | #else |
419 | cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); |
420 | if (prop->major >= 5) { |
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
      // Disallow fp16 reductions that could lead to unexpected overflow issues.
      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
427 | TORCH_CUDABLAS_CHECK(cublasGemmEx( |
428 | handle, |
429 | opa, |
430 | opb, |
431 | m, |
432 | n, |
433 | k, |
434 | &falpha, |
435 | a, |
436 | CUDA_R_16F, |
437 | lda, |
438 | b, |
439 | CUDA_R_16F, |
440 | ldb, |
441 | &fbeta, |
442 | c, |
443 | CUDA_R_16F, |
444 | ldc, |
445 | CUDA_R_32F, |
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
447 | TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); |
448 | } else { |
449 | TORCH_CUDABLAS_CHECK(cublasSgemmEx( |
450 | handle, |
451 | opa, |
452 | opb, |
453 | m, |
454 | n, |
455 | k, |
456 | &falpha, |
457 | a, |
458 | CUDA_R_16F, |
459 | lda, |
460 | b, |
461 | CUDA_R_16F, |
462 | ldb, |
463 | &fbeta, |
464 | c, |
465 | CUDA_R_16F, |
466 | ldc)); |
467 | } |
468 | #endif |
469 | } |
470 | |
471 | #ifdef USE_ROCM |
472 | template <> |
473 | void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { |
474 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
475 | cublasOperation_t opa = _cublasOpFromChar(transa); |
476 | cublasOperation_t opb = _cublasOpFromChar(transb); |
477 | float falpha = alpha; |
478 | float fbeta = beta; |
479 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
480 | GEMM_CHECK_ARGVALUES(at::BFloat16); |
481 | TORCH_CUDABLAS_CHECK(rocblas_gemm_ex( |
482 | handle, |
483 | opa, |
484 | opb, |
485 | m, |
486 | n, |
487 | k, |
488 | &falpha, |
489 | a, |
490 | rocblas_datatype_bf16_r, |
491 | lda, |
492 | b, |
493 | rocblas_datatype_bf16_r, |
494 | ldb, |
495 | &fbeta, |
496 | c, |
497 | rocblas_datatype_bf16_r, |
498 | ldc, |
499 | c, |
500 | rocblas_datatype_bf16_r, |
501 | ldc, |
502 | rocblas_datatype_f32_r, |
503 | rocblas_gemm_algo_standard, |
504 | 0, |
505 | 0)); |
506 | } |
507 | #endif |
508 | |
509 | #if !defined(USE_ROCM) |
510 | template <> |
511 | void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { |
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
513 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
514 | cublasOperation_t opa = _cublasOpFromChar(transa); |
515 | cublasOperation_t opb = _cublasOpFromChar(transb); |
516 | float falpha = alpha; |
517 | float fbeta = beta; |
518 | _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); |
519 | GEMM_CHECK_ARGVALUES(at::BFloat16); |
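  // As in gemm<at::Half> above, optionally forbid cuBLAS from using
  // reduced-precision (here bf16) accumulation for reductions.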
520 | cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH; |
521 | if (!at::globalContext().allowBF16ReductionCuBLAS()) { |
522 | cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); |
523 | } |
524 | TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags)); |
525 | TORCH_CUDABLAS_CHECK(cublasGemmEx( |
526 | handle, |
527 | opa, |
528 | opb, |
529 | m, |
530 | n, |
531 | k, |
532 | &falpha, |
533 | a, |
534 | CUDA_R_16BF, |
535 | lda, |
536 | b, |
537 | CUDA_R_16BF, |
538 | ldb, |
539 | &fbeta, |
540 | c, |
541 | CUDA_R_16BF, |
542 | ldc, |
543 | CUDA_R_32F, |
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
545 | TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); |
546 | } |
547 | #endif // !defined(USE_ROCM) |
548 | |
549 | #if !defined(USE_ROCM) && !defined(_MSC_VER) |
550 | |
551 | namespace { |
// Following the pattern of CuSparseDescriptor.
// Defined here for now because this is the only place the cublasLt interface
// is used, but these can be moved to a header once the cublasLt interface is
// used in multiple places.
556 | template <typename T, cublasStatus_t (*destructor)(T*)> |
557 | struct CuBlasLtDeleter { |
558 | void operator()(T* x) { |
559 | if (x != nullptr) { |
560 | TORCH_CUDABLAS_CHECK(destructor(x)); |
561 | } |
562 | } |
563 | }; |
564 | |
565 | template <typename T, cublasStatus_t (*destructor)(T*)> |
566 | class CuBlasLtDescriptor { |
567 | public: |
568 | T* descriptor() const { |
569 | return descriptor_.get(); |
570 | } |
571 | T* descriptor() { |
572 | return descriptor_.get(); |
573 | } |
574 | |
575 | protected: |
576 | std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_; |
577 | }; |
578 | |
579 | class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor< |
580 | cublasLtMatmulDescOpaque_t, |
581 | &cublasLtMatmulDescDestroy> { |
582 | public: |
583 | CuBlasLtMatmulDescriptor( |
584 | cublasComputeType_t compute_type, |
585 | cudaDataType_t scale_type) { |
586 | cublasLtMatmulDesc_t raw_descriptor = nullptr; |
587 | TORCH_CUDABLAS_CHECK( |
588 | cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); |
589 | descriptor_.reset(raw_descriptor); |
590 | } |
591 | }; |
592 | |
593 | class CuBlasLtMatrixLayout : public CuBlasLtDescriptor< |
594 | cublasLtMatrixLayoutOpaque_t, |
595 | &cublasLtMatrixLayoutDestroy> { |
596 | public: |
597 | CuBlasLtMatrixLayout( |
598 | cudaDataType_t type, |
599 | uint64_t rows, |
600 | uint64_t cols, |
601 | int64_t ld) { |
602 | cublasLtMatrixLayout_t raw_descriptor = nullptr; |
603 | TORCH_CUDABLAS_CHECK( |
604 | cublasLtMatrixLayoutCreate(&raw_descriptor, type, rows, cols, ld)); |
605 | descriptor_.reset(raw_descriptor); |
606 | } |
607 | }; |
608 | |
609 | class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< |
610 | cublasLtMatmulPreferenceOpaque_t, |
611 | &cublasLtMatmulPreferenceDestroy> { |
612 | public: |
613 | CuBlasLtMatmulPreference() { |
614 | cublasLtMatmulPreference_t raw_descriptor = nullptr; |
615 | TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor)); |
616 | descriptor_.reset(raw_descriptor); |
617 | } |
618 | }; |
619 | } // namespace |
620 | |
621 | template <typename Dtype> |
622 | void gemm_and_bias( |
623 | bool transpose_mat1, |
624 | bool transpose_mat2, |
625 | int64_t m, |
626 | int64_t n, |
627 | int64_t k, |
628 | at::opmath_type<Dtype> alpha_val, |
629 | const Dtype* mat1_ptr, |
630 | int64_t mat1_ld, |
631 | const Dtype* mat2_ptr, |
632 | int64_t mat2_ld, |
633 | const Dtype* bias, |
634 | Dtype* result_ptr, |
635 | int64_t result_ld, |
636 | GEMMAndBiasActivationEpilogue activation) { |
637 | using opmath_t = at::opmath_type<Dtype>; |
638 | opmath_t beta_val = 0; // bias is added in epilogue |
639 | |
640 | cudaDataType_t abcType = CUDA_R_32F; |
641 | cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; |
642 | cudaDataType_t scaleType = CUDA_R_32F; |
643 | if (std::is_same<Dtype, double>::value) { |
644 | abcType = CUDA_R_64F; |
645 | computeType = CUBLAS_COMPUTE_64F; |
646 | scaleType = CUDA_R_64F; |
647 | } else if (std::is_same<Dtype, float>::value) { |
648 | if (at::globalContext().allowTF32CuBLAS()) { |
649 | computeType = CUBLAS_COMPUTE_32F_FAST_TF32; |
650 | } |
651 | abcType = CUDA_R_32F; |
652 | } else if (std::is_same<Dtype, at::Half>::value) { |
653 | abcType = CUDA_R_16F; |
654 | } else if (std::is_same<Dtype, at::BFloat16>::value) { |
655 | abcType = CUDA_R_16BF; |
656 | } |
657 | |
658 | CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); |
659 | cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N; |
660 | TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute( |
661 | computeDesc.descriptor(), |
662 | CUBLASLT_MATMUL_DESC_TRANSA, |
663 | &transa, |
664 | sizeof(transa))); |
665 | cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; |
666 | TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute( |
667 | computeDesc.descriptor(), |
668 | CUBLASLT_MATMUL_DESC_TRANSB, |
669 | &transb, |
670 | sizeof(transb))); |
671 | cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; |
672 | if (activation == GEMMAndBiasActivationEpilogue::RELU) { |
673 | epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; |
674 | } else if (activation == GEMMAndBiasActivationEpilogue::GELU) { |
675 | #if CUDA_VERSION >= 11040 |
676 | epilogue = CUBLASLT_EPILOGUE_GELU_BIAS; |
677 | #endif |
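    // Note: CUBLASLT_EPILOGUE_GELU_BIAS requires CUDA 11.4+; on older
    // toolkits a requested GELU silently falls back to the bias-only epilogue.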
678 | } |
679 | TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute( |
680 | computeDesc.descriptor(), |
681 | CUBLASLT_MATMUL_DESC_EPILOGUE, |
682 | &epilogue, |
683 | sizeof(epilogue))); |
684 | TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute( |
685 | computeDesc.descriptor(), |
686 | CUBLASLT_MATMUL_DESC_BIAS_POINTER, |
687 | &bias, |
688 | sizeof(Dtype*))); |
689 | |
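  // cublasLt layouts describe the operands as stored (column-major), so a
  // transposed input's stored shape has its rows and columns swapped relative
  // to op(mat1)/op(mat2).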
690 | CuBlasLtMatrixLayout Adesc( |
691 | abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld); |
692 | CuBlasLtMatrixLayout Bdesc( |
693 | abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld); |
694 | CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld); |
695 | |
696 | CuBlasLtMatmulPreference preference; |
697 | // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind |
698 | // setting this to 1M. |
699 | size_t workspaceSize = 1024 * 1024; |
700 | TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute( |
701 | preference.descriptor(), |
702 | CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, |
703 | &workspaceSize, |
704 | sizeof(workspaceSize))); |
705 | |
706 | auto workspace = at::empty( |
707 | {static_cast<int64_t>(workspaceSize)}, |
708 | at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)); |
709 | |
710 | cublasLtMatmulHeuristicResult_t heuristicResult = {}; |
711 | int returnedResult = 0; |
712 | cublasLtHandle_t ltHandle = |
713 | reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle()); |
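  // A valid cublasHandle_t is documented to be usable wherever a
  // cublasLtHandle_t is expected, so we reuse the regular handle via the cast
  // above instead of creating a separate Lt handle.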
714 | TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( |
715 | ltHandle, |
716 | computeDesc.descriptor(), |
717 | Adesc.descriptor(), |
718 | Bdesc.descriptor(), |
719 | Cdesc.descriptor(), |
720 | Cdesc.descriptor(), |
721 | preference.descriptor(), |
722 | 1, |
723 | &heuristicResult, |
724 | &returnedResult)); |
725 | if (returnedResult == 0) { |
726 | TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); |
727 | } |
728 | |
729 | cublasStatus_t cublasStatus = cublasLtMatmul( |
730 | ltHandle, |
731 | computeDesc.descriptor(), |
732 | &alpha_val, |
733 | mat1_ptr, |
734 | Adesc.descriptor(), |
735 | mat2_ptr, |
736 | Bdesc.descriptor(), |
737 | &beta_val, |
738 | result_ptr, |
739 | Cdesc.descriptor(), |
740 | result_ptr, |
741 | Cdesc.descriptor(), |
742 | &heuristicResult.algo, |
743 | workspace.data_ptr(), |
744 | workspaceSize, |
745 | at::cuda::getCurrentCUDAStream()); |
746 | TORCH_CHECK( |
747 | cublasStatus == CUBLAS_STATUS_SUCCESS, |
748 | "CUDA error: " , |
749 | at::cuda::blas::_cublasGetErrorEnum(cublasStatus), |
750 | " when calling cublasLtMatmul with transpose_mat1 " , |
751 | transpose_mat1, |
752 | " transpose_mat2 " , |
753 | transpose_mat2, |
754 | " m " , |
755 | m, |
756 | " n " , |
757 | n, |
758 | " k " , |
759 | k, |
760 | " mat1_ld " , |
761 | mat1_ld, |
762 | " mat2_ld " , |
763 | mat2_ld, |
764 | " result_ld " , |
765 | result_ld, |
766 | " abcType " , |
767 | abcType, |
768 | " computeType " , |
769 | computeType, |
770 | " scaleType " , |
771 | scaleType); |
772 | } |
773 | |
774 | template void gemm_and_bias( |
775 | bool transpose_mat1, |
776 | bool transpose_mat2, |
777 | int64_t m, |
778 | int64_t n, |
779 | int64_t k, |
780 | at::opmath_type<double> alpha_val, |
781 | const double* mat1_ptr, |
782 | int64_t mat1_ld, |
783 | const double* mat2_ptr, |
784 | int64_t mat2_ld, |
785 | const double* bias, |
786 | double* result_ptr, |
787 | int64_t result_ld, |
788 | GEMMAndBiasActivationEpilogue activation); |
789 | |
790 | template void gemm_and_bias( |
791 | bool transpose_mat1, |
792 | bool transpose_mat2, |
793 | int64_t m, |
794 | int64_t n, |
795 | int64_t k, |
796 | at::opmath_type<float> alpha_val, |
797 | const float* mat1_ptr, |
798 | int64_t mat1_ld, |
799 | const float* mat2_ptr, |
800 | int64_t mat2_ld, |
801 | const float* bias, |
802 | float* result_ptr, |
803 | int64_t result_ld, |
804 | GEMMAndBiasActivationEpilogue activation); |
805 | |
806 | template void gemm_and_bias( |
807 | bool transpose_mat1, |
808 | bool transpose_mat2, |
809 | int64_t m, |
810 | int64_t n, |
811 | int64_t k, |
812 | at::opmath_type<at::Half> alpha_val, |
813 | const at::Half* mat1_ptr, |
814 | int64_t mat1_ld, |
815 | const at::Half* mat2_ptr, |
816 | int64_t mat2_ld, |
817 | const at::Half* bias, |
818 | at::Half* result_ptr, |
819 | int64_t result_ld, |
820 | GEMMAndBiasActivationEpilogue activation); |
821 | |
822 | template void gemm_and_bias( |
823 | bool transpose_mat1, |
824 | bool transpose_mat2, |
825 | int64_t m, |
826 | int64_t n, |
827 | int64_t k, |
828 | at::opmath_type<at::BFloat16> alpha_val, |
829 | const at::BFloat16* mat1_ptr, |
830 | int64_t mat1_ld, |
831 | const at::BFloat16* mat2_ptr, |
832 | int64_t mat2_ld, |
833 | const at::BFloat16* bias, |
834 | at::BFloat16* result_ptr, |
835 | int64_t result_ld, |
836 | GEMMAndBiasActivationEpilogue activation); |
837 | #endif // !defined(USE_ROCM) && !defined(_MSC_VER) |
838 | |
839 | template <> |
840 | void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) { |
841 | TORCH_CUDABLAS_CHECK(cublasStrsm( |
842 | handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb)); |
843 | } |
844 | |
845 | template <> |
846 | void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double)) { |
847 | TORCH_CUDABLAS_CHECK(cublasDtrsm( |
848 | handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb)); |
849 | } |
850 | |
851 | template <> |
852 | void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) { |
853 | TORCH_CUDABLAS_CHECK(cublasCtrsm( |
854 | handle, |
855 | side, |
856 | uplo, |
857 | trans, |
858 | diag, |
859 | m, |
860 | n, |
861 | reinterpret_cast<const cuComplex*>(alpha), |
862 | reinterpret_cast<const cuComplex*>(A), |
863 | lda, |
864 | reinterpret_cast<cuComplex*>(B), |
865 | ldb)); |
866 | } |
867 | |
868 | template <> |
869 | void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) { |
870 | TORCH_CUDABLAS_CHECK(cublasZtrsm( |
871 | handle, |
872 | side, |
873 | uplo, |
874 | trans, |
875 | diag, |
876 | m, |
877 | n, |
878 | reinterpret_cast<const cuDoubleComplex*>(alpha), |
879 | reinterpret_cast<const cuDoubleComplex*>(A), |
880 | lda, |
881 | reinterpret_cast<cuDoubleComplex*>(B), |
882 | ldb)); |
883 | } |
884 | |
885 | template <> |
886 | void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) { |
887 | TORCH_CUDABLAS_CHECK(cublasStrsmBatched( |
888 | handle, |
889 | side, |
890 | uplo, |
891 | trans, |
892 | diag, |
893 | m, |
894 | n, |
895 | alpha, |
896 | A, |
897 | lda, |
898 | B, |
899 | ldb, |
900 | batchCount)); |
901 | } |
902 | |
903 | template <> |
904 | void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) { |
905 | TORCH_CUDABLAS_CHECK(cublasDtrsmBatched( |
906 | handle, |
907 | side, |
908 | uplo, |
909 | trans, |
910 | diag, |
911 | m, |
912 | n, |
913 | alpha, |
914 | A, |
915 | lda, |
916 | B, |
917 | ldb, |
918 | batchCount)); |
919 | } |
920 | |
921 | template <> |
922 | void trsmBatched<c10::complex<float>>( |
923 | CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) { |
924 | TORCH_CUDABLAS_CHECK(cublasCtrsmBatched( |
925 | handle, |
926 | side, |
927 | uplo, |
928 | trans, |
929 | diag, |
930 | m, |
931 | n, |
932 | reinterpret_cast<const cuComplex*>(alpha), |
933 | reinterpret_cast<cuComplex**>(A), |
934 | lda, |
935 | reinterpret_cast<cuComplex**>(B), |
936 | ldb, |
937 | batchCount)); |
938 | } |
939 | |
940 | template <> |
941 | void trsmBatched<c10::complex<double>>( |
942 | CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) { |
943 | TORCH_CUDABLAS_CHECK(cublasZtrsmBatched( |
944 | handle, |
945 | side, |
946 | uplo, |
947 | trans, |
948 | diag, |
949 | m, |
950 | n, |
951 | reinterpret_cast<const cuDoubleComplex*>(alpha), |
952 | reinterpret_cast<cuDoubleComplex**>(A), |
953 | lda, |
954 | reinterpret_cast<cuDoubleComplex**>(B), |
955 | ldb, |
956 | batchCount)); |
957 | } |
958 | |
959 | /* LEVEL 2 BLAS FUNCTIONS */ |
960 | |
961 | #define GEMV_CHECK_ARGVALUES(Dtype) \ |
962 | do { \ |
963 | CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \ |
964 | CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \ |
965 | CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda); \ |
966 | CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \ |
967 | CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \ |
968 | } while (0) |
969 | |
970 | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
971 | template <> |
972 | void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) { |
973 | // See Note [Writing Nondeterministic Operations] |
974 | globalContext().alertCuBLASConfigNotDeterministic(); |
975 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
976 | cublasOperation_t op = _cublasOpFromChar(trans); |
977 | _cublasAdjustLdLevel2(m, n, &lda); |
978 | GEMV_CHECK_ARGVALUES(c10::complex<double>); |
979 | TORCH_CUDABLAS_CHECK( |
980 | cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a), |
981 | lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta), |
982 | reinterpret_cast<cuDoubleComplex*>(y), incy)); |
983 | } |
984 | #endif |
985 | |
986 | #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000) |
987 | template <> |
988 | void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) { |
  // gemv is bandwidth bound and does not benefit from TF32, but the
  // precision loss from TF32 still occurs, so we disable it here.
991 | NoTF32Guard disable_tf32; |
992 | // See Note [Writing Nondeterministic Operations] |
993 | globalContext().alertCuBLASConfigNotDeterministic(); |
994 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
995 | cublasOperation_t op = _cublasOpFromChar(trans); |
996 | _cublasAdjustLdLevel2(m, n, &lda); |
997 | GEMV_CHECK_ARGVALUES(c10::complex<float>); |
998 | TORCH_CUDABLAS_CHECK( |
999 | cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a), |
1000 | lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta), |
1001 | reinterpret_cast<cuComplex*>(y), incy)); |
1002 | } |
1003 | #endif |
1004 | |
1005 | template <> |
1006 | void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) { |
1007 | // See Note [Writing Nondeterministic Operations] |
1008 | globalContext().alertCuBLASConfigNotDeterministic(); |
1009 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
1010 | cublasOperation_t op = _cublasOpFromChar(trans); |
1011 | _cublasAdjustLdLevel2(m, n, &lda); |
1012 | GEMV_CHECK_ARGVALUES(double); |
1013 | TORCH_CUDABLAS_CHECK( |
1014 | cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy)); |
1015 | } |
1016 | |
1017 | template <> |
1018 | void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) { |
  // gemv is bandwidth bound and does not benefit from TF32, but the
  // precision loss from TF32 still occurs, so we disable it here.
1021 | NoTF32Guard disable_tf32; |
1022 | // See Note [Writing Nondeterministic Operations] |
1023 | globalContext().alertCuBLASConfigNotDeterministic(); |
1024 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
1025 | cublasOperation_t op = _cublasOpFromChar(trans); |
1026 | _cublasAdjustLdLevel2(m, n, &lda); |
1027 | GEMV_CHECK_ARGVALUES(float); |
1028 | TORCH_CUDABLAS_CHECK( |
1029 | cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy)); |
1030 | } |
1031 | |
1032 | template <> |
1033 | void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) { |
1034 | // In general, cublas regards matrices as column-major. |
1035 | // The cublasS/Dgemv usages in cuda::blas::gemv<float>/<double> above |
1036 | // require that external blas::gemv callers obey the following convention: |
1037 | // |
1038 | // If "a" is row-major with shape (output, summed) in blas::gemv's caller, |
1039 | // caller interprets it as column-major with shape (summed, output), passes |
1040 | // summed and output respectively to our local vars m, n, and requests that cublas |
1041 | // internally transpose ("trans") the column-major interpretation of a. |
1042 | // |
1043 | // There's no such thing as "cublasHalfgemv", so here we hack gemv with a gemm. |
1044 | // However, we must allow the same calling convention, because the caller shouldn't |
1045 | // have to swap args based on whether it's calling blas::gemv<at::Half> or <float>. |
1046 | |
1047 | bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N); |
1048 | if (trans_bool) { |
1049 | std::swap(m, n); |
1050 | } |
1051 | // After swap, local vars m, n contain the output and summed sizes respectively, |
1052 | // regardless of whether "a" was row-major or column-major in gemv<>'s caller. |
1053 | |
  // To handle the possibility that incy > 1, we interpret the vector y as a
  // column-major matrix with one row (shape (1, output)) and leading dim incy.
  // trans(a)*x would compute a matrix with one column (shape (output, 1)),
  // which wouldn't match y. So instead, we interpret x similarly to y, as a
  // column-major matrix with one row (shape (1, summed)) and leading dim incx.
  // The gemm then carries out x*transpose(trans(a)) to produce a matrix with
  // one row (shape (1, output)), matching y.
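  // For example, y with output size 3 and incy == 2 is viewed as a 1 x 3
  // column-major matrix with leading dimension 2, so the gemm writes to
  // y[0], y[2], and y[4], exactly as a strided gemv would.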
1060 | char trans_flipped = (trans_bool ? 'n' : 't'); |
1061 | gemm<at::Half>( |
1062 | 'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy); |
1063 | } |
1064 | |
1065 | template <> |
1066 | void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) { |
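  // Same gemm-based emulation as gemv<at::Half> above; see that comment for
  // the shape and leading-dimension reasoning.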
1067 | bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N); |
1068 | if (trans_bool) { |
1069 | std::swap(m, n); |
1070 | } |
1071 | char trans_flipped = (trans_bool ? 'n' : 't'); |
1072 | gemm<at::BFloat16>( |
1073 | 'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy); |
1074 | } |
1075 | |
1076 | /* LEVEL 1 BLAS FUNCTIONS */ |
1077 | |
1078 | template <> |
1079 | void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) { |
1080 | TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result)); |
1081 | } |
1082 | |
1083 | template <> |
1084 | void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) { |
1085 | TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result)); |
1086 | } |
1087 | |
1088 | template <> |
1089 | void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) { |
1090 | TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x), |
1091 | incx, reinterpret_cast<const cuDoubleComplex*>(y), incy, |
1092 | reinterpret_cast<cuDoubleComplex*>(result))); |
1093 | } |
1094 | |
1095 | template <> |
1096 | void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) { |
1097 | TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x), |
1098 | incx, reinterpret_cast<const cuComplex*>(y), incy, |
1099 | reinterpret_cast<cuComplex*>(result))); |
1100 | } |
1101 | |
1102 | template <> |
1103 | void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) { |
1104 | #if !defined(USE_ROCM) |
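  // Compute the fp16 dot product with fp32 accumulation: the trailing
  // CUDA_R_32F selects the execution (compute) type for cublasDotEx.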
1105 | TORCH_CUDABLAS_CHECK(cublasDotEx( |
1106 | handle, |
1107 | n, |
1108 | x, |
1109 | CUDA_R_16F, |
1110 | incx, |
1111 | y, |
1112 | CUDA_R_16F, |
1113 | incy, |
1114 | result, |
1115 | CUDA_R_16F, |
1116 | CUDA_R_32F)); |
1117 | #elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000 |
1118 | TORCH_CUDABLAS_CHECK(rocblas_hdot( |
1119 | handle, |
1120 | n, |
1121 | reinterpret_cast<const rocblas_half*>(x), |
1122 | incx, |
1123 | reinterpret_cast<const rocblas_half*>(y), |
1124 | incy, |
1125 | reinterpret_cast<rocblas_half*>(result))); |
1126 | #else |
  AT_ERROR("dot<at::Half> requires ROCm 2.10+ (ROCM_VERSION >= 21000)");
1128 | #endif |
1129 | } |
1130 | |
1131 | template <> |
1132 | void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) { |
1133 | #if !defined(USE_ROCM) |
1134 | TORCH_CUDABLAS_CHECK(cublasDotEx( |
1135 | handle, |
1136 | n, |
1137 | x, |
1138 | CUDA_R_16BF, |
1139 | incx, |
1140 | y, |
1141 | CUDA_R_16BF, |
1142 | incy, |
1143 | result, |
1144 | CUDA_R_16BF, |
1145 | CUDA_R_32F)); |
1146 | #elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000 |
1147 | TORCH_CUDABLAS_CHECK(rocblas_bfdot( |
1148 | handle, |
1149 | n, |
1150 | reinterpret_cast<const rocblas_bfloat16*>(x), |
1151 | incx, |
1152 | reinterpret_cast<const rocblas_bfloat16*>(y), |
1153 | incy, |
1154 | reinterpret_cast<rocblas_bfloat16*>(result))); |
1155 | #else |
  AT_ERROR("dot<at::BFloat16> requires ROCm 2.10+ (ROCM_VERSION >= 21000)");
1157 | #endif |
1158 | } |
1159 | |
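// vdot is the conjugated dot product, sum_i conj(x_i) * y_i, implemented via
// cublas{C,Z}dotc, whereas dot<c10::complex<...>> above uses the
// unconjugated cublas{C,Z}dotu.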
1160 | template <> |
1161 | void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) { |
1162 | TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x), |
1163 | incx, reinterpret_cast<const cuComplex*>(y), incy, |
1164 | reinterpret_cast<cuComplex*>(result))); |
1165 | } |
1166 | |
1167 | template <> |
1168 | void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) { |
1169 | TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x), |
1170 | incx, reinterpret_cast<const cuDoubleComplex*>(y), incy, |
1171 | reinterpret_cast<cuDoubleComplex*>(result))); |
1172 | } |
1173 | |
// This guard blocks use of getrsBatched, geqrfBatched, getrfBatched, and
// gelsBatched on platforms other than CUDA.
1175 | #ifdef CUDART_VERSION |
1176 | |
1177 | template <> |
1178 | void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) { |
1179 | TORCH_CUDABLAS_CHECK(cublasSgetrsBatched( |
1180 | handle, |
1181 | trans, |
1182 | n, |
1183 | nrhs, |
1184 | dA_array, |
1185 | lda, |
1186 | ipiv_array, |
1187 | dB_array, |
1188 | ldb, |
1189 | info_array, |
1190 | batchsize)); |
1191 | } |
1192 | |
1193 | template <> |
1194 | void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double)) { |
1195 | TORCH_CUDABLAS_CHECK(cublasDgetrsBatched( |
1196 | handle, |
1197 | trans, |
1198 | n, |
1199 | nrhs, |
1200 | dA_array, |
1201 | lda, |
1202 | ipiv_array, |
1203 | dB_array, |
1204 | ldb, |
1205 | info_array, |
1206 | batchsize)); |
1207 | } |
1208 | |
1209 | template <> |
1210 | void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>)) { |
1211 | TORCH_CUDABLAS_CHECK(cublasCgetrsBatched( |
1212 | handle, |
1213 | trans, |
1214 | n, |
1215 | nrhs, |
1216 | reinterpret_cast<cuComplex**>(dA_array), |
1217 | lda, |
1218 | ipiv_array, |
1219 | reinterpret_cast<cuComplex**>(dB_array), |
1220 | ldb, |
1221 | info_array, |
1222 | batchsize)); |
1223 | } |
1224 | |
1225 | template <> |
1226 | void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) { |
1227 | TORCH_CUDABLAS_CHECK(cublasZgetrsBatched( |
1228 | handle, |
1229 | trans, |
1230 | n, |
1231 | nrhs, |
1232 | reinterpret_cast<cuDoubleComplex**>(dA_array), |
1233 | lda, |
1234 | ipiv_array, |
1235 | reinterpret_cast<cuDoubleComplex**>(dB_array), |
1236 | ldb, |
1237 | info_array, |
1238 | batchsize)); |
1239 | } |
1240 | |
1241 | template <> |
1242 | void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) { |
1243 | TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched( |
1244 | handle, m, n, A_array, lda, tau_array, info, batchsize)); |
1245 | } |
1246 | |
1247 | template <> |
1248 | void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) { |
1249 | TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched( |
1250 | handle, m, n, A_array, lda, tau_array, info, batchsize)); |
1251 | } |
1252 | |
1253 | template <> |
1254 | void geqrfBatched<c10::complex<float>>( |
1255 | CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) { |
1256 | TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched( |
1257 | handle, |
1258 | m, |
1259 | n, |
1260 | reinterpret_cast<cuComplex**>(A_array), |
1261 | lda, |
1262 | reinterpret_cast<cuComplex**>(tau_array), |
1263 | info, |
1264 | batchsize)); |
1265 | } |
1266 | |
1267 | template <> |
1268 | void geqrfBatched<c10::complex<double>>( |
1269 | CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) { |
1270 | TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched( |
1271 | handle, |
1272 | m, |
1273 | n, |
1274 | reinterpret_cast<cuDoubleComplex**>(A_array), |
1275 | lda, |
1276 | reinterpret_cast<cuDoubleComplex**>(tau_array), |
1277 | info, |
1278 | batchsize)); |
1279 | } |
1280 | |
1281 | template <> |
1282 | void getrfBatched<double>( |
1283 | int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { |
1284 | auto handle = at::cuda::getCurrentCUDABlasHandle(); |
1285 | TORCH_CUDABLAS_CHECK(cublasDgetrfBatched( |
1286 | handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); |
1287 | } |
1288 | |
1289 | template <> |
1290 | void getrfBatched<float>( |
1291 | int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { |
1292 | auto handle = at::cuda::getCurrentCUDABlasHandle(); |
1293 | TORCH_CUDABLAS_CHECK(cublasSgetrfBatched( |
1294 | handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); |
1295 | } |
1296 | |
1297 | template <> |
1298 | void getrfBatched<c10::complex<double>>( |
1299 | int n, |
1300 | c10::complex<double>** dA_array, |
1301 | int ldda, |
1302 | int* ipiv_array, |
1303 | int* info_array, |
1304 | int batchsize) { |
1305 | auto handle = at::cuda::getCurrentCUDABlasHandle(); |
1306 | TORCH_CUDABLAS_CHECK(cublasZgetrfBatched( |
1307 | handle, |
1308 | n, |
1309 | reinterpret_cast<cuDoubleComplex**>(dA_array), |
1310 | ldda, |
1311 | ipiv_array, |
1312 | info_array, |
1313 | batchsize)); |
1314 | } |
1315 | |
1316 | template <> |
1317 | void getrfBatched<c10::complex<float>>( |
1318 | int n, |
1319 | c10::complex<float>** dA_array, |
1320 | int ldda, |
1321 | int* ipiv_array, |
1322 | int* info_array, |
1323 | int batchsize) { |
1324 | auto handle = at::cuda::getCurrentCUDABlasHandle(); |
1325 | TORCH_CUDABLAS_CHECK(cublasCgetrfBatched( |
1326 | handle, |
1327 | n, |
1328 | reinterpret_cast<cuComplex**>(dA_array), |
1329 | ldda, |
1330 | ipiv_array, |
1331 | info_array, |
1332 | batchsize)); |
1333 | } |
1334 | |
1335 | |
1336 | template <> |
1337 | void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) { |
1338 | TORCH_CUDABLAS_CHECK(cublasDgelsBatched( |
1339 | handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize)); |
1340 | } |
1341 | |
1342 | template <> |
1343 | void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float)) { |
1344 | TORCH_CUDABLAS_CHECK(cublasSgelsBatched( |
1345 | handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize)); |
1346 | } |
1347 | |
1348 | template <> |
1349 | void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>)) { |
1350 | TORCH_CUDABLAS_CHECK(cublasZgelsBatched( |
1351 | handle, trans, |
1352 | m, n, nrhs, |
1353 | reinterpret_cast<cuDoubleComplex**>(dA_array), |
1354 | ldda, |
1355 | reinterpret_cast<cuDoubleComplex**>(dC_array), |
1356 | lddc, |
1357 | info, |
1358 | devInfoArray, |
1359 | batchSize)); |
1360 | } |
1361 | |
1362 | template <> |
1363 | void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>)) { |
1364 | TORCH_CUDABLAS_CHECK(cublasCgelsBatched( |
1365 | handle, trans, |
1366 | m, n, nrhs, |
1367 | reinterpret_cast<cuComplex**>(dA_array), |
1368 | ldda, |
1369 | reinterpret_cast<cuComplex**>(dC_array), |
1370 | lddc, |
1371 | info, |
1372 | devInfoArray, |
1373 | batchSize)); |
1374 | } |
1375 | |
1376 | #endif // CUDART_VERSION |
1377 | |
1378 | } // namespace blas |
1379 | } // namespace cuda |
1380 | } // namespace at |
1381 | |