// The clang-tidy job complains that it cannot find cudnn.h without this guard.
// This file should only be compiled when this condition holds, so the guard is
// safe.
4 | #if defined(USE_CUDNN) || defined(USE_ROCM) |
5 | #include <torch/csrc/utils/pybind.h> |
6 | |
7 | #include <array> |
8 | #include <tuple> |
9 | |
namespace {

// A library version split into its (major, minor, patch) components.
using version_tuple = std::tuple<std::size_t, std::size_t, std::size_t>;

} // namespace
13 | |
14 | #ifdef USE_CUDNN |
15 | #include <cudnn.h> |
16 | |
17 | namespace { |
18 | |
19 | version_tuple getCompileVersion() { |
20 | return version_tuple(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); |
21 | } |
22 | |
23 | version_tuple getRuntimeVersion() { |
24 | #ifndef USE_STATIC_CUDNN |
25 | auto version = cudnnGetVersion(); |
26 | auto major = version / 1000; |
27 | auto minor = (version % 1000) / 100; |
28 | auto patch = version % 10; |
29 | return version_tuple(major, minor, patch); |
30 | #else |
31 | return getCompileVersion(); |
32 | #endif |
33 | } |
34 | |
35 | size_t getVersionInt() { |
36 | #ifndef USE_STATIC_CUDNN |
37 | return cudnnGetVersion(); |
38 | #else |
39 | return CUDNN_VERSION; |
40 | #endif |
41 | } |
42 | |
43 | } // namespace |
44 | #elif defined(USE_ROCM) |
45 | #include <miopen/miopen.h> |
46 | #include <miopen/version.h> |
47 | |
48 | namespace { |
49 | |
50 | version_tuple getCompileVersion() { |
51 | return version_tuple( |
52 | MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH); |
53 | } |
54 | |
55 | version_tuple getRuntimeVersion() { |
56 | // MIOpen doesn't include runtime version info before 2.3.0 |
57 | #if (MIOPEN_VERSION_MAJOR > 2) || \ |
58 | (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2) |
59 | size_t major, minor, patch; |
60 | miopenGetVersion(&major, &minor, &patch); |
61 | return version_tuple(major, minor, patch); |
62 | #else |
63 | return getCompileVersion(); |
64 | #endif |
65 | } |
66 | |
67 | size_t getVersionInt() { |
68 | // miopen version is MAJOR*1000000 + MINOR*1000 + PATCH |
69 | size_t major, minor, patch; |
70 | std::tie(major, minor, patch) = getRuntimeVersion(); |
71 | return major * 1000000 + minor * 1000 + patch; |
72 | } |
73 | |
74 | } // namespace |
75 | #endif |
76 | |
77 | namespace torch { |
78 | namespace cuda { |
79 | namespace shared { |
80 | |
81 | void initCudnnBindings(PyObject* module) { |
82 | auto m = py::handle(module).cast<py::module>(); |
83 | |
84 | auto cudnn = m.def_submodule("_cudnn" , "libcudnn.so bindings" ); |
85 | |
86 | py::enum_<cudnnRNNMode_t>(cudnn, "RNNMode" ) |
87 | .value("rnn_relu" , CUDNN_RNN_RELU) |
88 | .value("rnn_tanh" , CUDNN_RNN_TANH) |
89 | .value("lstm" , CUDNN_LSTM) |
90 | .value("gru" , CUDNN_GRU); |
91 | |
92 | // The runtime version check in python needs to distinguish cudnn from miopen |
93 | #ifdef USE_CUDNN |
94 | cudnn.attr("is_cuda" ) = true; |
95 | #else |
96 | cudnn.attr("is_cuda" ) = false; |
97 | #endif |
98 | |
99 | cudnn.def("getRuntimeVersion" , getRuntimeVersion); |
100 | cudnn.def("getCompileVersion" , getCompileVersion); |
101 | cudnn.def("getVersionInt" , getVersionInt); |
102 | } |
103 | |
104 | } // namespace shared |
105 | } // namespace cuda |
106 | } // namespace torch |
107 | #endif |
108 | |