1 | #include <torch/csrc/cuda/THCP.h> |
---|---|
2 | #include <torch/csrc/python_headers.h> |
3 | #include <cstdarg> |
4 | #include <string> |
5 | |
6 | #ifdef USE_CUDA |
7 | // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use |
8 | // whatever the current stream of the device the input is associated with was. |
9 | std::vector<c10::optional<at::cuda::CUDAStream>> |
10 | THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) { |
11 | if (!PySequence_Check(obj)) { |
12 | throw std::runtime_error( |
13 | "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList"); |
14 | } |
15 | THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr)); |
16 | if (seq.get() == nullptr) { |
17 | throw std::runtime_error( |
18 | "expected PySequence, but got "+ std::string(THPUtils_typename(obj))); |
19 | } |
20 | |
21 | std::vector<c10::optional<at::cuda::CUDAStream>> streams; |
22 | Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); |
23 | for (Py_ssize_t i = 0; i < length; i++) { |
24 | PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i); |
25 | |
26 | if (PyObject_IsInstance(stream, THCPStreamClass)) { |
27 | // Spicy hot reinterpret cast!! |
28 | streams.emplace_back(at::cuda::CUDAStream::unpack3( |
29 | (reinterpret_cast<THCPStream*>(stream))->stream_id, |
30 | (reinterpret_cast<THCPStream*>(stream))->device_index, |
31 | static_cast<c10::DeviceType>( |
32 | (reinterpret_cast<THCPStream*>(stream))->device_type))); |
33 | } else if (stream == Py_None) { |
34 | streams.emplace_back(); |
35 | } else { |
36 | // NOLINTNEXTLINE(bugprone-throw-keyword-missing) |
37 | std::runtime_error( |
38 | "Unknown data type found in stream list. Need torch.cuda.Stream or None"); |
39 | } |
40 | } |
41 | return streams; |
42 | } |
43 | |
44 | #endif |
45 |