// Without this guard, the clang-tidy job complains that it cannot find cudnn.h.
// This file should only be compiled when this condition holds, so the guard
// is safe.
4#if defined(USE_CUDNN) || defined(USE_ROCM)
5#include <torch/csrc/utils/pybind.h>
6
7#include <array>
8#include <tuple>
9
namespace {
// (major, minor, patch) triple returned by the version helpers in this file.
using version_tuple = std::tuple<std::size_t, std::size_t, std::size_t>;
} // namespace
13
14#ifdef USE_CUDNN
15#include <cudnn.h>
16
17namespace {
18
19version_tuple getCompileVersion() {
20 return version_tuple(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
21}
22
23version_tuple getRuntimeVersion() {
24#ifndef USE_STATIC_CUDNN
25 auto version = cudnnGetVersion();
26 auto major = version / 1000;
27 auto minor = (version % 1000) / 100;
28 auto patch = version % 10;
29 return version_tuple(major, minor, patch);
30#else
31 return getCompileVersion();
32#endif
33}
34
35size_t getVersionInt() {
36#ifndef USE_STATIC_CUDNN
37 return cudnnGetVersion();
38#else
39 return CUDNN_VERSION;
40#endif
41}
42
43} // namespace
44#elif defined(USE_ROCM)
45#include <miopen/miopen.h>
46#include <miopen/version.h>
47
48namespace {
49
50version_tuple getCompileVersion() {
51 return version_tuple(
52 MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
53}
54
55version_tuple getRuntimeVersion() {
56 // MIOpen doesn't include runtime version info before 2.3.0
57#if (MIOPEN_VERSION_MAJOR > 2) || \
58 (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2)
59 size_t major, minor, patch;
60 miopenGetVersion(&major, &minor, &patch);
61 return version_tuple(major, minor, patch);
62#else
63 return getCompileVersion();
64#endif
65}
66
67size_t getVersionInt() {
68 // miopen version is MAJOR*1000000 + MINOR*1000 + PATCH
69 size_t major, minor, patch;
70 std::tie(major, minor, patch) = getRuntimeVersion();
71 return major * 1000000 + minor * 1000 + patch;
72}
73
74} // namespace
75#endif
76
77namespace torch {
78namespace cuda {
79namespace shared {
80
81void initCudnnBindings(PyObject* module) {
82 auto m = py::handle(module).cast<py::module>();
83
84 auto cudnn = m.def_submodule("_cudnn", "libcudnn.so bindings");
85
86 py::enum_<cudnnRNNMode_t>(cudnn, "RNNMode")
87 .value("rnn_relu", CUDNN_RNN_RELU)
88 .value("rnn_tanh", CUDNN_RNN_TANH)
89 .value("lstm", CUDNN_LSTM)
90 .value("gru", CUDNN_GRU);
91
92 // The runtime version check in python needs to distinguish cudnn from miopen
93#ifdef USE_CUDNN
94 cudnn.attr("is_cuda") = true;
95#else
96 cudnn.attr("is_cuda") = false;
97#endif
98
99 cudnn.def("getRuntimeVersion", getRuntimeVersion);
100 cudnn.def("getCompileVersion", getCompileVersion);
101 cudnn.def("getVersionInt", getVersionInt);
102}
103
104} // namespace shared
105} // namespace cuda
106} // namespace torch
107#endif
108