1#include <torch/csrc/cuda/THCP.h>
2#include <torch/csrc/python_headers.h>
3#include <cstdarg>
4#include <string>
5
6#ifdef USE_CUDA
7// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
8// whatever the current stream of the device the input is associated with was.
9std::vector<c10::optional<at::cuda::CUDAStream>>
10THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
11 if (!PySequence_Check(obj)) {
12 throw std::runtime_error(
13 "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
14 }
15 THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
16 if (seq.get() == nullptr) {
17 throw std::runtime_error(
18 "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
19 }
20
21 std::vector<c10::optional<at::cuda::CUDAStream>> streams;
22 Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
23 for (Py_ssize_t i = 0; i < length; i++) {
24 PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);
25
26 if (PyObject_IsInstance(stream, THCPStreamClass)) {
27 // Spicy hot reinterpret cast!!
28 streams.emplace_back(at::cuda::CUDAStream::unpack3(
29 (reinterpret_cast<THCPStream*>(stream))->stream_id,
30 (reinterpret_cast<THCPStream*>(stream))->device_index,
31 static_cast<c10::DeviceType>(
32 (reinterpret_cast<THCPStream*>(stream))->device_type)));
33 } else if (stream == Py_None) {
34 streams.emplace_back();
35 } else {
36 // NOLINTNEXTLINE(bugprone-throw-keyword-missing)
37 std::runtime_error(
38 "Unknown data type found in stream list. Need torch.cuda.Stream or None");
39 }
40 }
41 return streams;
42}
43
44#endif
45