/*
  Provides the implementations of CUDA BLAS function templates.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>

// cublasLt was introduced in CUDA 10.1, but we enable it only as of CUDA
// 11.1, which also added bf16 support.
#if !defined(USE_ROCM) && !defined(_MSC_VER)
#include <cublasLt.h>
#endif

#ifdef USE_ROCM
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#endif

#define CUDABLAS_POSINT_CHECK(FD, X)         \
  TORCH_CHECK(                               \
      (X > 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X \
      " must be positive and at most ",      \
      INT_MAX,                               \
      " but got ",                           \
      X)

#define CUDABLAS_NONNEGINT_CHECK(FD, X)      \
  TORCH_CHECK(                               \
      (X >= 0 && X <= INT_MAX),              \
      "at::cuda::blas::" #FD " argument " #X \
      " must be non-negative and at most ",  \
      INT_MAX,                               \
      " but got ",                           \
      X)

namespace {

static cublasOperation_t _cublasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return CUBLAS_OP_N;
    case 't':
    case 'T':
      return CUBLAS_OP_T;
    case 'c':
    case 'C':
      return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // Note: leading dimensions are generally checked to be > 0 and at least
  // as large as the result requires (even if the value won't be used).

  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3, the sizes (m, n, k)
  // specify the sizes of op(A), op(B), where the ops depend on the
  // trans values.
  if (n <= 1)
    *lda = std::max<int64_t>(m, 1);
}
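
// A worked example of the adjustment above (illustration only): with
// m = 4, n = 1, the column stride of "a" is never taken, so a caller may
// legitimately pass lda = 0.
//
//   int64_t m = 4, n = 1, lda = 0;
//   _cublasAdjustLdLevel2(m, n, &lda);
//   // lda is now max(m, 1) == 4, which satisfies cuBLAS's own
//   // lda >= max(1, m) argument check.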

static void _cublasAdjustLdLevel3(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc) {
  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  // Note: leading dimensions are generally checked to be > 0 and at least
  // as large as the result requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
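
// A worked example of the level-3 adjustment (illustration only): for
// C = op(A) * op(B) with transa = 't', transb = 'n', m = 1, n = 5, k = 3,
// op(A) is 1 x 3, so A itself is a single column and its column stride is
// never taken.
//
//   int64_t lda = 1, ldb = 3, ldc = 1;
//   _cublasAdjustLdLevel3('t', 'n', /*m=*/1, /*n=*/5, /*k=*/3, &lda, &ldb, &ldc);
//   // lda is bumped to max(k, 1) == 3 (transa set, m <= 1);
//   // ldb and ldc are left untouched since n > 1 and k > 1.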
} // anonymous namespace

namespace at {
namespace cuda {
namespace blas {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef USE_ROCM
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11020
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const void *alpha,
    const void *A,
    cudaDataType Atype,
    int lda,
    long long int strideA,
    const void *B,
    cudaDataType Btype,
    int ldb,
    long long int strideB,
    const void *beta,
    void *C,
    cudaDataType Ctype,
    int ldc,
    long long int strideC,
    int64_t batchCount,
    cudaDataType computeType,
    cublasGemmAlgo_t algo)
{
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major != 7) {
    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
  }
  // On compute capability 7.x devices, split large batches into smaller
  // chunks so each call avoids the failure described in the issue above.
  cublasStatus_t result = CUBLAS_STATUS_SUCCESS;
  constexpr int64_t split = 63 * 1024;
  for (int64_t i = 0; i < batchCount; i += split) {
    int64_t count = std::min<int64_t>(split, batchCount - i);
    // The strides are given in elements; the only dtypes routed through this
    // fix in this file (Half/BFloat16) are 2 bytes each, hence the "* 2"
    // when computing byte offsets.
    result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
        (char *)A + i * strideA * 2, Atype, lda, strideA,
        (char *)B + i * strideB * 2, Btype, ldb, strideB,
        beta,
        (char *)C + i * strideC * 2, Ctype, ldc, strideC,
        (int)count, computeType, algo);
    TORCH_CUDABLAS_CHECK(result);
  }
  return result;
}
#endif
#endif
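
// Illustration of the batch splitting above: with batchCount = 100000 on a
// compute capability 7.x device, the loop issues two calls, one covering
// 63 * 1024 = 64512 batches and one covering the remaining 35488, so each
// call's batch count stays under 64K, which the split size is chosen to
// guarantee.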

#define GEMM_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc);  \
  } while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype)                   \
  do {                                                 \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m);         \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n);         \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k);         \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda);          \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb);          \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc);          \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
  } while (0)

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::Half);
  float falpha = alpha;
  float fbeta = beta;
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
      b, rocblas_datatype_f16_r, (int)ldb, strideb,
      (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
      c, rocblas_datatype_f16_r, (int)ldc, stridec,
      (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
      0, flag));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
    TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
        handle, opa, opb, m, n, k,
        (void*)(&falpha), a, CUDA_R_16F, lda, stridea,
        b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
        c, CUDA_R_16F, ldc, stridec,
        num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  } else {
    // On devices with compute capability < 5.0, fall back to one gemm call
    // per batch entry.
    for (const auto i : c10::irange(num_batches)) {
      at::cuda::blas::gemm<at::Half>(
          transa, transb,
          m, n, k,
          alpha, (a + i * stridea), lda,
          (b + i * strideb), ldb, beta,
          (c + i * stridec), ldc);
    }
  }
#endif // USE_ROCM
}

template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  const float falpha = alpha;
  const float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if !defined(USE_ROCM)
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
      opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
      b, CUDA_R_16BF, (int)ldb, strideb,
      (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
      (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#else
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
      b, rocblas_datatype_bf16_r, (int)ldb, strideb,
      (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
      0, 0, NULL, NULL));
#endif // !defined(USE_ROCM)
}

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  template <>
  void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
    // See Note [Writing Nondeterministic Operations]
    globalContext().alertCuBLASConfigNotDeterministic();
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cublasOperation_t opa = _cublasOpFromChar(transa);
    cublasOperation_t opb = _cublasOpFromChar(transb);
    _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
    GEMM_CHECK_ARGVALUES(c10::complex<double>);
    TORCH_CUDABLAS_CHECK(cublasZgemm(
        handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
        lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
        reinterpret_cast<cuDoubleComplex*>(c), ldc));
  }
#endif

#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  template <>
  void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
    // See Note [Writing Nondeterministic Operations]
    globalContext().alertCuBLASConfigNotDeterministic();
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cublasOperation_t opa = _cublasOpFromChar(transa);
    cublasOperation_t opb = _cublasOpFromChar(transb);
    _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
    GEMM_CHECK_ARGVALUES(c10::complex<float>);
    TORCH_CUDABLAS_CHECK(cublasCgemm(
        handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
        lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
        reinterpret_cast<cuComplex*>(c), ldc));
  }
#endif

template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      rocblas_datatype_f16_r,
      lda,
      b,
      rocblas_datatype_f16_r,
      ldb,
      &fbeta,
      c,
      rocblas_datatype_f16_r,
      ldc,
      c,
      rocblas_datatype_f16_r,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      flag));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DFALT_TENSOR_OP));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSgemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc));
  }
#endif
}

#ifdef USE_ROCM
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      rocblas_datatype_bf16_r,
      lda,
      b,
      rocblas_datatype_bf16_r,
      ldb,
      &fbeta,
      c,
      rocblas_datatype_bf16_r,
      ldc,
      c,
      rocblas_datatype_bf16_r,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      0));
}
#endif

#if !defined(USE_ROCM)
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
  }
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      CUDA_R_16BF,
      lda,
      b,
      CUDA_R_16BF,
      ldb,
      &fbeta,
      c,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DFALT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif // !defined(USE_ROCM)

#if !defined(USE_ROCM) && !defined(_MSC_VER)

namespace {
// Following the pattern of CuSparseDescriptor.
// Defined here for now because this is the only place the cublasLt interface
// is used, but these can be moved to a header once the cublasLt interface is
// used in multiple places.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
                                     cublasLtMatmulDescOpaque_t,
                                     &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(
      cublasComputeType_t compute_type,
      cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
};

class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
                                 cublasLtMatrixLayoutOpaque_t,
                                 &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      int64_t ld) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, rows, cols, ld));
    descriptor_.reset(raw_descriptor);
  }
};

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
                                     cublasLtMatmulPreferenceOpaque_t,
                                     &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
};
} // namespace
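
// Usage sketch for the RAII wrappers above (mirrors what gemm_and_bias does
// below): the raw cublasLt handle is created in the constructor and
// destroyed automatically when the wrapper goes out of scope.
//
//   CuBlasLtMatmulDescriptor desc(CUBLAS_COMPUTE_32F, CUDA_R_32F);
//   cublasOperation_t op = CUBLAS_OP_N;
//   TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
//       desc.descriptor(), CUBLASLT_MATMUL_DESC_TRANSA, &op, sizeof(op)));
//   // ... pass desc.descriptor() to cublasLtMatmul ...; no explicit destroy.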

template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {
  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = 0; // bias is added in epilogue

  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if (std::is_same<Dtype, double>::value) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if (std::is_same<Dtype, float>::value) {
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
    abcType = CUDA_R_32F;
  } else if (std::is_same<Dtype, at::Half>::value) {
    abcType = CUDA_R_16F;
  } else if (std::is_same<Dtype, at::BFloat16>::value) {
    abcType = CUDA_R_16BF;
  }

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_TRANSA,
      &transa,
      sizeof(transa)));
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_TRANSB,
      &transb,
      sizeof(transb)));
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
  }
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_EPILOGUE,
      &epilogue,
      sizeof(epilogue)));
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
      &bias,
      sizeof(Dtype*)));

  CuBlasLtMatrixLayout Adesc(
      abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
  CuBlasLtMatrixLayout Bdesc(
      abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = 1024 * 1024;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
      preference.descriptor(),
      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize,
      sizeof(workspaceSize)));

  auto workspace = at::empty(
      {static_cast<int64_t>(workspaceSize)},
      at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));

  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.data_ptr(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " abcType ",
      abcType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
}
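
// Usage sketch (illustration only; the pointer names are placeholders):
// compute result = ReLU(mat1 @ mat2 + bias) for column-major float matrices,
// where mat1 is m x k, mat2 is k x n, bias has m elements, and result is
// m x n:
//
//   at::cuda::blas::gemm_and_bias<float>(
//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
//       m, n, k, /*alpha_val=*/1.0f,
//       mat1_ptr, /*mat1_ld=*/m,
//       mat2_ptr, /*mat2_ld=*/k,
//       bias_ptr, result_ptr, /*result_ld=*/m,
//       GEMMAndBiasActivationEpilogue::RELU);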

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<double> alpha_val,
    const double* mat1_ptr,
    int64_t mat1_ld,
    const double* mat2_ptr,
    int64_t mat2_ld,
    const double* bias,
    double* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<float> alpha_val,
    const float* mat1_ptr,
    int64_t mat1_ld,
    const float* mat2_ptr,
    int64_t mat2_ld,
    const float* bias,
    float* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::Half> alpha_val,
    const at::Half* mat1_ptr,
    int64_t mat1_ld,
    const at::Half* mat2_ptr,
    int64_t mat2_ld,
    const at::Half* bias,
    at::Half* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::BFloat16> alpha_val,
    const at::BFloat16* mat1_ptr,
    int64_t mat1_ld,
    const at::BFloat16* mat2_ptr,
    int64_t mat2_ld,
    const at::BFloat16* bias,
    at::BFloat16* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)

template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<const cuComplex*>(A),
      lda,
      reinterpret_cast<cuComplex*>(B),
      ldb));
}

template <>
void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<const cuDoubleComplex*>(A),
      lda,
      reinterpret_cast<cuDoubleComplex*>(B),
      ldb));
}

template <>
void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<float>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<cuComplex**>(A),
      lda,
      reinterpret_cast<cuComplex**>(B),
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<double>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<cuDoubleComplex**>(A),
      lda,
      reinterpret_cast<cuDoubleComplex**>(B),
      ldb,
      batchCount));
}

/* LEVEL 2 BLAS FUNCTIONS */

#define GEMV_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
  } while (0)

#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
  template <>
  void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
    // See Note [Writing Nondeterministic Operations]
    globalContext().alertCuBLASConfigNotDeterministic();
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cublasOperation_t op = _cublasOpFromChar(trans);
    _cublasAdjustLdLevel2(m, n, &lda);
    GEMV_CHECK_ARGVALUES(c10::complex<double>);
    TORCH_CUDABLAS_CHECK(
        cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
        lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
        reinterpret_cast<cuDoubleComplex*>(y), incy));
  }
#endif

#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
  // gemv is bandwidth bound and does not benefit from TF32, but the
  // precision loss from TF32 still happens. So we disable it here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(
      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(y), incy));
}
#endif

template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(
      cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
  // gemv is bandwidth bound and does not benefit from TF32, but the
  // precision loss from TF32 still happens. So we disable it here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(
      cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
  // In general, cublas regards matrices as column-major.
  // The cublasS/Dgemv usages in cuda::blas::gemv<float>/<double> above
  // require that external blas::gemv callers obey the following convention:
  //
  // If "a" is row-major with shape (output, summed) in blas::gemv's caller,
  // the caller interprets it as column-major with shape (summed, output),
  // passes summed and output respectively to our local vars m, n, and
  // requests that cublas internally transpose ("trans") the column-major
  // interpretation of a.
  //
  // There's no such thing as "cublasHalfgemv", so here we hack gemv with a
  // gemm. However, we must honor the same calling convention, because the
  // caller shouldn't have to swap args based on whether it's calling
  // blas::gemv<at::Half> or <float>.

  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  // After the swap, local vars m, n contain the output and summed sizes
  // respectively, regardless of whether "a" was row-major or column-major
  // in gemv<>'s caller.

  // To handle the possibility incy > 1, we interpret vector y as a
  // column-major matrix with one row (shape (1, output)) and leading dim
  // incy. trans(a)*x would compute a matrix with one column (shape
  // (output, 1)), which wouldn't match y. So instead, we interpret x
  // similarly to y, as a column-major matrix with one row (shape
  // (1, summed)) and leading dim incx. The gemm then carries out
  // x*transpose(trans(a)) to produce a matrix with one row (shape
  // (1, output)), matching y.
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::Half>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
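
// Shape sanity check for the trick above (illustration only): the gemm call
// computes the 1 x m product
//   y(1 x m) = x(1 x n) * op'(a)(n x m),
// i.e. y[j] = sum_i x[i] * op'(a)[i][j]. Because op' is the caller's trans
// flipped (and m, n were swapped when trans was set), op'(a)[i][j] equals
// the caller's op(a)[j][i], so y[j] = sum_i op(a)[j][i] * x[i], which is
// exactly the gemv result y = op(a) * x.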

template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
  // Same gemv-via-gemm trick as gemv<at::Half> above; see the comments there.
  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::BFloat16>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}

/* LEVEL 1 BLAS FUNCTIONS */

template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
      incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
      reinterpret_cast<cuDoubleComplex*>(result)));
}

template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
      incx, reinterpret_cast<const cuComplex*>(y), incy,
      reinterpret_cast<cuComplex*>(result)));
}

template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
#if !defined(USE_ROCM)
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16F,
      incx,
      y,
      CUDA_R_16F,
      incy,
      result,
      CUDA_R_16F,
      CUDA_R_32F));
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
  TORCH_CUDABLAS_CHECK(rocblas_hdot(
      handle,
      n,
      reinterpret_cast<const rocblas_half*>(x),
      incx,
      reinterpret_cast<const rocblas_half*>(y),
      incy,
      reinterpret_cast<rocblas_half*>(result)));
#else
  AT_ERROR("Cublas_Hdot requires CUDA 8.0+");
#endif
}

template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
#if !defined(USE_ROCM)
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16BF,
      incx,
      y,
      CUDA_R_16BF,
      incy,
      result,
      CUDA_R_16BF,
      CUDA_R_32F));
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
  TORCH_CUDABLAS_CHECK(rocblas_bfdot(
      handle,
      n,
      reinterpret_cast<const rocblas_bfloat16*>(x),
      incx,
      reinterpret_cast<const rocblas_bfloat16*>(y),
      incy,
      reinterpret_cast<rocblas_bfloat16*>(result)));
#else
  AT_ERROR("Cublas_bfdot requires CUDA 11.0+");
#endif
}

template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x),
      incx, reinterpret_cast<const cuComplex*>(y), incy,
      reinterpret_cast<cuComplex*>(result)));
}

template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
      incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
      reinterpret_cast<cuDoubleComplex*>(result)));
}

// This guard blocks use of getrsBatched, geqrfBatched, getrfBatched, and
// gelsBatched on platforms other than CUDA.
#ifdef CUDART_VERSION

template <>
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuDoubleComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

template <>
void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuComplex**>(A_array),
      lda,
      reinterpret_cast<cuComplex**>(tau_array),
      info,
      batchsize));
}

template <>
void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuDoubleComplex**>(A_array),
      lda,
      reinterpret_cast<cuDoubleComplex**>(tau_array),
      info,
      batchsize));
}

template <>
void getrfBatched<double>(
    int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasDgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<float>(
    int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<c10::complex<double>>(
    int n,
    c10::complex<double>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasZgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}

template <>
void getrfBatched<c10::complex<float>>(
    int n,
    c10::complex<float>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasCgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}

template <>
void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuDoubleComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

template <>
void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

#endif // CUDART_VERSION

} // namespace blas
} // namespace cuda
} // namespace at