#include <ATen/ATen.h>
#include <ATen/cuda/CUDAConfig.h>
#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#endif
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/Sleep.h>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#ifdef USE_NCCL
#include <torch/csrc/cuda/python_nccl.h>
#endif
#include <c10/util/CallOnce.h>
#include <c10/util/irange.h>

#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/python_comm.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

#include <array>
#include <chrono>
#include <iostream>
#include <sstream>
#include <thread>
#include <unordered_map>

#ifndef WIN32
#include <pthread.h>
#endif

using namespace torch;

static bool in_bad_fork = false; // True for children forked after cuda init

#ifndef WIN32
// Called in the forked child if cuda has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_cuda_init(true);
}
#endif

// Should be called before the first cuda call.
// Note: This is distinct from initExtension because a stub cuda implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}

////////////////////////////////////////////////////////////////////////////////
// CUDA management methods
////////////////////////////////////////////////////////////////////////////////

void THCPModule_setDevice(int device) {
  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
}

PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to setDevice");
  int64_t device = THPUtils_unpackLong(arg);

  torch::utils::cuda_lazy_init();
  THCPModule_setDevice(device);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

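// Set the current device to `arg` and return the index of the device that was
// current before the call ("exchange" semantics). A negative index is treated
// as a no-op and simply returns -1, so callers can pass "no device" through
// without touching CUDA state or initializing a context.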
PyObject* THCPModule_exchangeDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
  int64_t device = THPUtils_unpackLong(arg);
  if (device < 0) {
    return THPUtils_packInt32(-1);
  }

  torch::utils::cuda_lazy_init();
  auto current_device = c10::cuda::current_device();
  if (current_device != device) {
    THCPModule_setDevice(device);
  }

  return THPUtils_packInt32(static_cast<int>(current_device));
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  torch::utils::cuda_lazy_init();
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  auto device = static_cast<int>(c10::cuda::current_device());
  return THPUtils_packInt32(device);
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* arg1 = nullptr;
  PyObject* arg2 = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &arg1, &arg2)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "can_device_peer_access",
        1,
        "(int device, int peer_device);");
    return nullptr;
  }
  THPUtils_assert(
      THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer");
  THPUtils_assert(
      THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer");
  int64_t device = THPUtils_unpackLong(arg1);
  int64_t peer_device = THPUtils_unpackLong(arg2);

  torch::utils::cuda_lazy_init();
  auto can_access = at::cuda::canDeviceAccessPeer(device, peer_device);
  return PyBool_FromLong(can_access);
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  poison_fork();
  return THPUtils_packUInt64(at::cuda::device_count());
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  poison_fork();
#ifdef CUDA_ARCH_FLAGS
  static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS);
  return THPUtils_packString(flags);
#else
  Py_RETURN_NONE;
#endif
  END_HANDLE_TH_ERRORS
}

static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(in_bad_fork);
  END_HANDLE_TH_ERRORS
}

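// The current/default stream is returned to Python as a (stream_id,
// device_index, device_type) 3-tuple rather than as a stream object; the
// Python wrapper can use these three values to reconstruct a
// torch.cuda.Stream without this file depending on the Python Stream type.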
PyObject* THCPModule_getCurrentStream_wrap(
    PyObject* /* unused */,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
  int64_t device = THPUtils_unpackLong(device_index);
  auto stream = at::cuda::getCurrentCUDAStream(device);
  PyObject* output_tuple = PyTuple_New(3);
  PyTuple_SetItem(
      output_tuple, 0, THPUtils_packInt64(static_cast<int64_t>(stream.id())));
  PyTuple_SetItem(
      output_tuple,
      1,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_index())));
  PyTuple_SetItem(
      output_tuple,
      2,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_type())));
  return output_tuple;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getCurrentStream_raw(
    PyObject* /* unused */,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
  int64_t device = THPUtils_unpackLong(device_index);
  return PyLong_FromVoidPtr(at::cuda::getCurrentCUDAStream(device).stream());
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getDefaultStream_wrap(
    PyObject* /* unused */,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(device_index), "invalid argument to getDefaultStream");
  int64_t device = THPUtils_unpackLong(device_index);
  auto stream = at::cuda::getDefaultCUDAStream(device);
  PyObject* output_tuple = PyTuple_New(3);
  PyTuple_SetItem(
      output_tuple, 0, THPUtils_packInt64(static_cast<int64_t>(stream.id())));
  PyTuple_SetItem(
      output_tuple,
      1,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_index())));
  PyTuple_SetItem(
      output_tuple,
      2,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_type())));
  return output_tuple;
  END_HANDLE_TH_ERRORS
}

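// Inverse of the packing above: accepts the (stream_id, device_index,
// device_type) triple as keyword arguments, rebuilds the CUDAStream via
// unpack3, and makes it current (switching the active device first if
// needed). Exposed as torch._C._cuda_setStream; a call from Python looks
// roughly like
//   torch._C._cuda_setStream(stream_id=s.stream_id,
//                            device_index=s.device_index,
//                            device_type=s.device_type)
// where `s` is assumed to be a torch.cuda.Stream-like object exposing those
// three attributes.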
PyObject* THCPModule_setStream_wrap(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  int64_t stream_id = 0;
  int64_t device_index = 0;
  int64_t device_type = 0;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  constexpr const char* kwlist[] = {
      "stream_id", "device_index", "device_type", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|LLL",
          const_cast<char**>(kwlist),
          &stream_id,
          &device_index,
          &device_type)) {
    // Argument parsing failed; the Python error is already set, so propagate
    // it instead of silently continuing with default-initialized values.
    return nullptr;
  }

  auto stream = at::cuda::CUDAStream::unpack3(
      stream_id, device_index, static_cast<c10::DeviceType>(device_type));

  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  auto device = static_cast<int>(c10::cuda::current_device());
  if (device != stream.device_index()) {
    THCPModule_setDevice(stream.device_index());
  }
  at::cuda::setCurrentCUDAStream(stream);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getCompiledVersion(PyObject* self, PyObject* noargs) {
#if defined(USE_ROCM)
  return THPUtils_packInt64((int64_t)ROCM_VERSION);
#else
  return THPUtils_packInt64((int64_t)CUDA_VERSION);
#endif
}

PyObject* THCPModule_cudaHostAllocator(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  c10::Allocator* allocator = at::cuda::getCachingHostAllocator();
  return PyLong_FromVoidPtr(allocator);
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaCachingAllocator_raw_alloc(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* size_o = nullptr;
  PyObject* stream_o = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &size_o, &stream_o)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "caching_allocator_alloc",
        1,
        "(ssize_t size, intptr_t stream);");
    return nullptr;
  }
  auto size = PyLong_AsSsize_t(size_o);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  cudaStream_t stream = static_cast<cudaStream_t>(PyLong_AsVoidPtr(stream_o));
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  void* mem =
      c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, stream);
  return PyLong_FromVoidPtr(mem);
  END_HANDLE_TH_ERRORS
}

// Unpack a PyObject to at::Scalar, throw an exception if it fails
at::Scalar as_scalar(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
  // currently handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(arg)) {
    return THPVariable_Unpack(arg).item();
  }

  if (THPUtils_checkLong(arg)) {
    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
  }

  if (PyBool_Check(arg)) {
    return at::Scalar(THPUtils_unpackBool(arg));
  }

  if (PyComplex_Check(arg)) {
    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }
  return at::Scalar(THPUtils_unpackDouble(arg));
}

// Entrypoint for the callable created by torch.cuda.jiterator
// See jiterator.py for more details
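// Positional arguments mirror the Python-side call: (code_string, kernel_name,
// return_by_ref, num_outputs, tensors_tuple[, scalar_kwargs]). Scalar keyword
// arguments are forwarded to the kernel as extra at::Scalar arguments via
// as_scalar() above.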
PyObject* THCPModule_cudaJiteratorCompileAndLaunchKernel(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS

  PyObject* code_string_o = nullptr;
  PyObject* kernel_name_o = nullptr;
  PyObject* return_by_ref_o = nullptr;
  PyObject* num_outputs_o = nullptr;
  PyObject* tensors_o = nullptr;
  PyObject* kwargs_o = nullptr;
  if (!PyArg_ParseTuple(
          args,
          "OOOOO|O",
          &code_string_o,
          &kernel_name_o,
          &return_by_ref_o,
          &num_outputs_o,
          &tensors_o,
          &kwargs_o)) {
    return nullptr;
  }

  const std::string code_string = THPUtils_unpackString(code_string_o);
  const std::string kernel_name = THPUtils_unpackString(kernel_name_o);
  const bool return_by_ref = THPUtils_unpackBool(return_by_ref_o);
  const int num_outputs = static_cast<int>(THPUtils_unpackLong(num_outputs_o));

  THPUtils_assert(
      PyTuple_Check(tensors_o),
      "tensors argument is expected to "
      "be a tuple, but got %s",
      THPUtils_typename(tensors_o));
  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o);

  c10::SmallVector<at::Tensor> tensors;
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors_o, i);
    THPUtils_assert(
        THPVariable_Check(_tensor),
        "%d of input tensors tuple is not a Tensor",
        i);

    tensors.emplace_back(THPVariable_Unpack(_tensor));
  }

  c10::SmallVector<at::Scalar> extra_args;
  PyObject* key = nullptr;
  PyObject* value = nullptr;
  Py_ssize_t pos = 0;
  while (PyDict_Next(kwargs_o, &pos, &key, &value)) {
    extra_args.emplace_back(as_scalar(value));
  }

  c10::SmallVector<at::Tensor> outputs = at::cuda::CompileAndLaunchKernel(
      code_string,
      kernel_name,
      num_outputs,
      tensors,
      extra_args,
      return_by_ref);

  if (num_outputs == 1) {
    return THPVariable_Wrap(outputs[0]);
  } else {
    PyObject* output_tuple = PyTuple_New(num_outputs);
    for (int i = 0; i < num_outputs; ++i) {
      PyTuple_SetItem(output_tuple, i, THPVariable_Wrap(outputs[i]));
    }
    return output_tuple;
  }

  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaCachingAllocator_raw_delete(
    PyObject* _unused,
    PyObject* obj) {
  HANDLE_TH_ERRORS
  void* mem_ptr = PyLong_AsVoidPtr(obj);
  c10::cuda::CUDACachingAllocator::raw_delete(mem_ptr);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
    PyObject* _unused,
    PyObject* env) {
  HANDLE_TH_ERRORS
  c10::cuda::CUDACachingAllocator::setAllocatorSettings(
      THPUtils_unpackString(env));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::cuda::CUDACachingAllocator::name());
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaSynchronize(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  c10::cuda::device_synchronize();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaIPCCollect(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  torch::CudaIPCCollect();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaSleep(PyObject* _unused, PyObject* cycles) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(cycles), "torch.cuda._sleep(): expected 'int'");
  at::cuda::sleep(THPUtils_unpackLong(cycles));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// We need to ensure that a thread NEVER loses the GIL while it holds the CUDA
// mutex. Otherwise another thread might be scheduled and try to e.g. allocate
// a new tensor, which would cause a deadlock. It's enough to have a single
// global, because it can only be set once (cudaMutex is not recursive) by the
// thread that owns the mutex (obviously there can be only one such thread).
static PyGILState_STATE cudaMutexGILState;

PyObject* THCPModule_cudaLockMutex(PyObject* module, PyObject* noargs) {
  auto mutex = c10::cuda::getFreeMutex();
  // This has to be a busy loop because we **absolutely need to** hold the GIL
  // or it's a recipe for a deadlock otherwise (if we let other Python threads
  // run while we have the cudaMutex, but not the GIL, they might try to e.g.
  // free a CUDA tensor and acquire the cudaMutex without giving up the GIL,
  // because it happens deep within THC).
  while (true) {
    if (mutex->try_lock())
      break;
    {
      pybind11::gil_scoped_release no_gil;
      std::this_thread::sleep_for(std::chrono::microseconds(10));
    }
  }

  cudaMutexGILState = PyGILState_Ensure();
  Py_RETURN_NONE;
}

PyObject* THCPModule_cudaUnlockMutex(PyObject* module, PyObject* noargs) {
  auto mutex = c10::cuda::getFreeMutex();
  PyGILState_Release(cudaMutexGILState);
  mutex->unlock();
  Py_RETURN_NONE;
}

PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(arg), "invalid argument to has_primary_context");
  int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
  if (at::cuda::detail::hasPrimaryContext(device_index)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* fraction_o = nullptr;
  PyObject* device_o = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &fraction_o, &device_o)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "set_memory_fraction",
        1,
        "(double fraction, int device);");
    return nullptr;
  }
  double fraction = PyFloat_AsDouble(fraction_o);
  int64_t device = PyLong_AsLongLong(device_o);

  c10::cuda::CUDACachingAllocator::setMemoryFraction(fraction, device);
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  c10::cuda::CUDACachingAllocator::emptyCache();
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

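// Returns the caching allocator statistics for one device as a nested dict:
// top-level counters (num_alloc_retries, num_ooms, max_split_size) plus, for
// each Stat/StatArray below, a {"current", "peak", "allocated", "freed"} dict
// keyed by pool ("all", "small_pool", "large_pool"). The Python-side
// torch.cuda.memory_stats() then flattens this structure into dotted keys.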
PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(arg), "invalid argument to memory_allocated");
  const int device = (int)THPUtils_unpackLong(arg);

  using c10::cuda::CUDACachingAllocator::DeviceStats;
  using c10::cuda::CUDACachingAllocator::Stat;
  using c10::cuda::CUDACachingAllocator::StatArray;
  using c10::cuda::CUDACachingAllocator::StatType;

  const auto statToDict = [](const Stat& stat) {
    py::dict dict;

    dict["current"] = stat.current;
    dict["peak"] = stat.peak;
    dict["allocated"] = stat.allocated;
    dict["freed"] = stat.freed;
    return dict;
  };

  const auto statArrayToDict = [=](const StatArray& statArray) {
    const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)>
        statTypeNames = {"all", "small_pool", "large_pool"};
    py::dict dict;
    for (const auto i : c10::irange(statTypeNames.size())) {
      dict[statTypeNames[i]] = statToDict(statArray[i]);
    }
    return dict;
  };

  const DeviceStats stats =
      c10::cuda::CUDACachingAllocator::getDeviceStats(device);

  py::dict result;
  result["num_alloc_retries"] = stats.num_alloc_retries;
  result["num_ooms"] = stats.num_ooms;
  result["max_split_size"] = stats.max_split_size;
  result["allocation"] = statArrayToDict(stats.allocation);
  result["segment"] = statArrayToDict(stats.segment);
  result["active"] = statArrayToDict(stats.active);
  result["inactive_split"] = statArrayToDict(stats.inactive_split);
  result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes);
  result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
  result["active_bytes"] = statArrayToDict(stats.active_bytes);
  result["inactive_split_bytes"] = statArrayToDict(stats.inactive_split_bytes);
  result["requested_bytes"] = statArrayToDict(stats.requested_bytes);
  result["oversize_allocations"] = statToDict(stats.oversize_allocations);
  result["oversize_segments"] = statToDict(stats.oversize_segments);

  return result.release().ptr();
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_resetAccumulatedMemoryStats(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(arg),
      "invalid argument to reset_accumulated_memory_stats");
  const int device = (int)THPUtils_unpackLong(arg);
  c10::cuda::CUDACachingAllocator::resetAccumulatedStats(device);
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats");
  const int device = (int)THPUtils_unpackLong(arg);
  c10::cuda::CUDACachingAllocator::resetPeakStats(device);
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

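// StackContext captures the current Python call stack (and optionally a C++
// backtrace) so the caching allocator can attach it to allocation history and
// trace entries. gather()/gather_with_cpp() are handed to recordHistory() as
// context-creation callbacks in registerCudaDeviceProperties() below.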
struct Frame {
  PyCodeObject* code;
  int lasti;
};

struct StackContext : public c10::cuda::CUDACachingAllocator::Context {
  std::vector<Frame> frames;
  // Empty if cpp traces weren't enabled
  std::string cpp_frames;
  ~StackContext() {
    py::gil_scoped_acquire acquire;
    for (auto& f : frames) {
      Py_XDECREF((PyObject*)f.code);
    }
  }
  static std::shared_ptr<StackContext> _gather() {
    py::gil_scoped_acquire acquire;
    auto r = std::make_shared<StackContext>();
    PyFrameObject* f = PyEval_GetFrame();
    Py_XINCREF(f);
    while (f) {
      r->frames.emplace_back(Frame{PyFrame_GetCode(f), PyFrame_GetLasti(f)});
      auto f_back = PyFrame_GetBack(f);
      Py_XDECREF(f);
      f = f_back;
    }
    return r;
  }
  static std::shared_ptr<c10::cuda::CUDACachingAllocator::Context> gather() {
    return _gather();
  }
  static std::shared_ptr<c10::cuda::CUDACachingAllocator::Context>
  gather_with_cpp() {
    auto r = _gather();
    r->cpp_frames = c10::get_backtrace();
    return std::move(r);
  }
};

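// Builds the allocator snapshot exposed as torch._C._cuda_memorySnapshot(): a
// dict with "segments" (one entry per allocator segment, each listing its
// blocks and, when history recording is enabled, their allocation stacks) and
// "device_traces" (per-device lists of allocator trace events).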
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS

  using c10::cuda::CUDACachingAllocator::BlockInfo;
  using c10::cuda::CUDACachingAllocator::History;
  using c10::cuda::CUDACachingAllocator::SegmentInfo;

  py::str device_s = "device";
  py::str address_s = "address";
  py::str total_size_s = "total_size";
  py::str allocated_size_s = "allocated_size";
  py::str active_size_s = "active_size";
  py::str requested_size_s = "requested_size";
  py::str stream_s = "stream";
  py::str segment_type_s = "segment_type";
  py::str large_s = "large";
  py::str small_s = "small";
  py::str size_s = "size";
  py::str state_s = "state";
  py::str active_allocated_s = "active_allocated";
  py::str active_pending_free_s = "active_pending_free";
  py::str inactive_s = "inactive";
  py::str addr_s = "addr";
  py::str real_size_s = "real_size";
  py::str filename_s = "filename";
  py::str name_s = "name";
  py::str line_s = "line";
  py::str frames_s = "frames";
  py::str cpp_frames_s = "cpp_frames";
  py::str history_s = "history";
  py::str blocks_s = "blocks";

  std::unordered_map<StackContext*, py::list> cached_frames;
  const auto get_frames = [&](StackContext* sc) -> py::list {
    auto it = cached_frames.find(sc);
    if (it != cached_frames.end()) {
      return it->second;
    }
    py::list frames;
    for (auto& f : sc->frames) {
      py::dict frame;
      frame[filename_s] =
          py::reinterpret_borrow<py::object>(f.code->co_filename);
      frame[name_s] = py::reinterpret_borrow<py::object>(f.code->co_name);
      frame[line_s] = PyCode_Addr2Line(f.code, f.lasti);
      frames.append(std::move(frame));
    }
    cached_frames.insert({sc, frames});
    return frames;
  };

  const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
    py::dict segmentDict;
    segmentDict[device_s] = segmentInfo.device;
    segmentDict[address_s] = segmentInfo.address;
    segmentDict[total_size_s] = segmentInfo.total_size;
    segmentDict[allocated_size_s] = segmentInfo.allocated_size;
    segmentDict[active_size_s] = segmentInfo.active_size;
    segmentDict[requested_size_s] = segmentInfo.requested_size;
    // we want the python objects to pickle easily so use an int to
    // represent the stream rather than a torch.cuda.stream object
    segmentDict[stream_s] = int64_t(segmentInfo.stream);
    segmentDict[segment_type_s] = (segmentInfo.is_large ? large_s : small_s);

    py::list blocks;
    for (const auto& blockInfo : segmentInfo.blocks) {
      py::dict blockDict;
      blockDict[size_s] = blockInfo.size;
      blockDict[requested_size_s] = blockInfo.requested_size;
      blockDict[state_s] =
          (blockInfo.allocated
               ? active_allocated_s
               : (blockInfo.active ? active_pending_free_s : inactive_s));
      if (blockInfo.history.size()) {
        py::list history;
        for (const History& h : blockInfo.history) {
          py::dict history_entry;
          history_entry[addr_s] = (int64_t)h.addr;
          history_entry[real_size_s] = h.real_size;
          if (h.context) {
            auto sc = (StackContext*)h.context.get();
            history_entry[frames_s] = get_frames(sc);
            if (!sc->cpp_frames.empty()) {
              history_entry[cpp_frames_s] = py::cast(sc->cpp_frames);
            }
          }
          history.append(std::move(history_entry));
        }
        blockDict[history_s] = std::move(history);
      }
      blocks.append(blockDict);
    }
    segmentDict[blocks_s] = blocks;

    return segmentDict;
  };

  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
  py::list segments;

  for (const auto& segmentInfo : snapshot.segments) {
    segments.append(segmentInfoToDict(segmentInfo));
  }

  py::list traces;
  py::str action_s = "action";
  py::str alloc_s = "alloc";
  py::str free_requested_s = "free_requested";
  py::str free_completed_s = "free_completed";
  py::str segment_alloc_s = "segment_alloc";
  py::str segment_free_s = "segment_free";
  py::str snapshot_s = "snapshot";
  py::str oom_s = "oom";
  py::str device_free_s = "device_free";

  using namespace c10::cuda::CUDACachingAllocator;

  auto action_to_str = [&](TraceEntry::Action action) {
    switch (action) {
      case TraceEntry::ALLOC:
        return alloc_s;
      case TraceEntry::FREE_REQUESTED:
        return free_requested_s;
      case TraceEntry::FREE_COMPLETED:
        return free_completed_s;
      case TraceEntry::SEGMENT_ALLOC:
        return segment_alloc_s;
      case TraceEntry::SEGMENT_FREE:
        return segment_free_s;
      case TraceEntry::OOM:
        return oom_s;
      case TraceEntry::SNAPSHOT:
        return snapshot_s;
    }
    throw std::runtime_error("unreachable");
  };

  for (const auto& traceInfo : snapshot.device_traces) {
    py::list trace;
    for (const auto& te : traceInfo) {
      py::dict trace_entry;
      if (te.context_) {
        // without further compression frames can get really large on dump
        auto sc = (StackContext*)te.context_.get();
        trace_entry[frames_s] = get_frames(sc);
        if (!sc->cpp_frames.empty()) {
          trace_entry[cpp_frames_s] = py::cast(sc->cpp_frames);
        }
      }
      trace_entry[action_s] = action_to_str(te.action_);
      trace_entry[TraceEntry::OOM == te.action_ ? device_free_s : addr_s] =
          te.addr_;
      trace_entry[size_s] = te.size_;
      trace_entry[stream_s] = int64_t(te.stream_);
      trace.append(trace_entry);
    }
    traces.append(trace);
  }

  py::dict result;
  result["segments"] = segments;
  result["device_traces"] = traces;

  return result.release().ptr();
  END_HANDLE_TH_ERRORS
}

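// Registers a Python callable to be invoked when the caching allocator reports
// an out-of-memory condition. The observer is called with the GIL held and
// four integer arguments: (device, alloc, device_allocated, device_free),
// matching the lambda parameters below.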
PyObject* THCPModule_attachOutOfMemoryObserver(
    PyObject* _unused,
    PyObject* observer) {
  HANDLE_TH_ERRORS
  Py_XINCREF(observer);
  auto obs = [observer](
                 int64_t device,
                 int64_t alloc,
                 int64_t device_allocated,
                 int64_t device_free) {
    py::gil_scoped_acquire g;
    PyObject* result = PyObject_CallFunction(
        observer, "LLLL", device, alloc, device_allocated, device_free);
    if (!result) {
      throw py::error_already_set();
    }
    Py_XDECREF(result);
  };
  c10::cuda::CUDACachingAllocator::attachOutOfMemoryObserver(std::move(obs));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaSetSyncDebugMode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_ONCE(
      "Synchronization debug mode is a prototype feature and does not yet detect all "
      "synchronizing operations");
  THPUtils_assert(
      THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode");
  int64_t debug_mode = THPUtils_unpackLong(arg);
  TORCH_CHECK(
      debug_mode >= 0 && debug_mode <= 2,
      "invalid value of debug_mode, expected one of 0,1,2");
  c10::cuda::SyncDebugMode l;
  switch (debug_mode) {
    case 0:
      l = c10::cuda::SyncDebugMode::L_DISABLED;
      break;
    case 1:
      l = c10::cuda::SyncDebugMode::L_WARN;
      break;
    case 2:
      l = c10::cuda::SyncDebugMode::L_ERROR;
      break;
    default:
      l = c10::cuda::SyncDebugMode::L_DISABLED;
      break; // can't happen
  }
  c10::cuda::warning_state().set_sync_debug_mode(l);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto debug_mode = c10::cuda::warning_state().get_sync_debug_mode();
  switch (debug_mode) {
    case c10::cuda::SyncDebugMode::L_DISABLED:
      return THPUtils_packInt32(0);
    case c10::cuda::SyncDebugMode::L_WARN:
      return THPUtils_packInt32(1);
    case c10::cuda::SyncDebugMode::L_ERROR:
      return THPUtils_packInt32(2);
    default:
      return THPUtils_packInt32(-1); // can't happen
  }
  END_HANDLE_TH_ERRORS
}

////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////

static void registerCudaDeviceProperties(PyObject* module) {
  // Add the _CudaDeviceProperties class to torch._C
  auto m = py::handle(module).cast<py::module>();
  py::class_<cudaDeviceProp>(m, "_CudaDeviceProperties")
      .def_readonly("name", &cudaDeviceProp::name)
      .def_readonly("major", &cudaDeviceProp::major)
      .def_readonly("minor", &cudaDeviceProp::minor)
      .def_readonly("is_multi_gpu_board", &cudaDeviceProp::isMultiGpuBoard)
      .def_readonly("is_integrated", &cudaDeviceProp::integrated)
      .def_readonly(
          "multi_processor_count", &cudaDeviceProp::multiProcessorCount)
      .def_readonly("total_memory", &cudaDeviceProp::totalGlobalMem)
      .def("__repr__", [](const cudaDeviceProp& prop) {
        std::ostringstream stream;
        stream << "_CudaDeviceProperties(name='" << prop.name
               << "', major=" << prop.major << ", minor=" << prop.minor
               << ", total_memory=" << prop.totalGlobalMem / (1024 * 1024)
               << "MB, multi_processor_count=" << prop.multiProcessorCount
               << ")";
        return stream.str();
      });

  m.def(
      "_cuda_recordMemoryHistory",
      [](bool enabled,
         bool record_context,
         bool record_context_cpp,
         Py_ssize_t alloc_trace_max_entries,
         bool alloc_trace_record_context) {
        c10::cuda::CUDACachingAllocator::recordHistory(
            enabled,
            record_context ? (record_context_cpp ? StackContext::gather_with_cpp
                                                 : StackContext::gather)
                           : nullptr,
            alloc_trace_max_entries,
            alloc_trace_record_context);
      });
}

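// Bindings for swapping the CUDA caching allocator for a user-provided one.
// The set_*_fn hooks take raw function addresses as uint64_t and
// reinterpret_cast them to the expected signatures, so callers are expected to
// pass plain C function pointers (e.g. addresses obtained via ctypes from a
// shared library); no lifetime management of those pointers happens here.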
static void registerCudaPluggableAllocator(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<
      c10::cuda::CUDACachingAllocator::CUDAAllocator,
      std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>>(
      m, "_cuda_CUDAAllocator");
  m.def("_cuda_getAllocator", []() {
    return py::cast(torch::cuda::CUDAPluggableAllocator::getCurrentAllocator());
  });

  m.def(
      "_cuda_changeCurrentAllocator",
      [](std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
             allocator) {
        torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator);
      });
  py::class_<
      torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator,
      c10::cuda::CUDACachingAllocator::CUDAAllocator,
      std::shared_ptr<
          torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>>(
      m, "_CUDAPluggableAllocator")
      .def(
          "set_init_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(int);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_init_fn(func);
          })
      .def(
          "set_reset_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void();
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_reset_fn(func);
          })
      .def(
          "set_memory_fraction_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(double, int);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_memory_fraction_fn(func);
          })
      .def(
          "set_base_alloc_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void*(void*, size_t*);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_base_alloc_fn(func);
          })
      .def(
          "set_record_stream_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(void*, cudaStream_t);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_record_stream_fn(func);
          })
      .def(
          "set_capture_begin_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType =
                void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_capture_begin_fn(func);
          })
      .def(
          "set_capture_about_to_end_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(int, c10::cuda::CaptureId_t);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_capture_about_to_end_fn(func);
          })
      .def(
          "set_capture_ended_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(int, c10::cuda::CaptureId_t);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_capture_ended_fn(func);
          })
      .def(
          "set_capture_destroy_fn",
          [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self,
             uint64_t func_ptr) {
            using FuncType = void(int, c10::cuda::MempoolId_t);
            std::function<FuncType> func =
                reinterpret_cast<FuncType*>(func_ptr);
            self.set_capture_destroy_fn(func);
          });
  m.def("_cuda_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) {
    using MallocFuncType = void*(size_t, int, cudaStream_t);
    using FreeFuncType = void(void*, size_t, int, cudaStream_t);
    std::function<MallocFuncType> malloc_fn =
        reinterpret_cast<MallocFuncType*>(malloc_ptr);
    std::function<FreeFuncType> free_fn =
        reinterpret_cast<FreeFuncType*>(free_ptr);
    return torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
        malloc_fn, free_fn);
  });
}

static void bindGetDeviceProperties(PyObject* module) {
  // Add method to torch.cuda
  auto m = py::handle(module).cast<py::module>();
  m.def(
      "_get_device_properties",
      [](int device) -> cudaDeviceProp* {
        return at::cuda::getDeviceProperties(device);
      },
      py::return_value_policy::reference);
}

// Callback for python part. Used for additional initialization of python
// classes
static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
#if C10_ASAN_ENABLED
  TORCH_WARN(
      "torch.cuda: your pytorch binary has address sanitizer (asan) built in, "
      "asan is currently not compatible with torch.cuda module, "
      "you might get unexpected behavior (eg. out of memory, crash, etc.), "
      "please rebuild pytorch without asan if you need to use this module");
#endif
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  at::globalContext().lazyInitCUDA();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
  if (!m)
    throw python_error();

  bool has_half = true;

  auto set_module_attr = [&](const char* name, PyObject* v) {
    // PyObject_SetAttrString doesn't steal reference. So no need to incref.
    if (PyObject_SetAttrString(m, name, v) < 0) {
      throw python_error();
    }
  };

  set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
  set_module_attr("has_half", has_half ? Py_True : Py_False);

  auto num_gpus = c10::cuda::device_count();
  auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
  for (const auto i : c10::irange(num_gpus)) {
    auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(
        at::cuda::detail::getDefaultCUDAGenerator(i));
    // This reference is meant to be given away, so no need to incref here.
    PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
  }
  set_module_attr("default_generators", default_cuda_generators);
  bindGetDeviceProperties(m);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_getCurrentBlasHandle_wrap(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  return PyLong_FromVoidPtr(handle);
  END_HANDLE_TH_ERRORS
}

static PyObject* THCPModule_clearBlasWorkspaces_wrap(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::cuda::clearCublasWorkspaces();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_rocm_is_backward_pass(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
#if USE_ROCM
  if (at::ROCmBackwardPassGuard::is_backward_pass()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
#else
  Py_RETURN_FALSE;
#endif
  END_HANDLE_TH_ERRORS
}

static PyObject* THCPModule_isCurrentStreamCapturing_wrap(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  // If there's no cuda context, at::cuda::currentStreamCaptureStatus returns
  // CaptureStatus::None without initializing a context.
  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
    Py_RETURN_FALSE;
  } else {
    Py_RETURN_TRUE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) {
  THPUtils_assert(
      THPUtils_checkLong(arg),
      "set_benchmark_limit_cudnn expects an int, "
      "but got %s",
      THPUtils_typename(arg));
  auto benchmark_limit = static_cast<int>(THPUtils_unpackLong(arg));
#if defined(USE_ROCM)
  TORCH_WARN_ONCE(
      "cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
#endif
#if AT_CUDNN_ENABLED()
#if HAS_CUDNN_V8()
  at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
#else
  TORCH_WARN_ONCE(
      "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
#endif
#endif
  Py_RETURN_NONE;
}

PyObject* THCPModule_benchmarkLimitCuDNN(PyObject* _unused, PyObject* noargs) {
  return THPUtils_packInt32(at::globalContext().benchmarkLimitCuDNN());
}

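// Method table registered on torch._C: each entry maps a "_cuda_*" name to one
// of the wrappers defined above. These are private bindings; the public API is
// the torch.cuda Python module, which wraps them.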
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _THCPModule_methods[] = {
    {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr},
    {"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr},
    {"_cuda_exchangeDevice", THCPModule_exchangeDevice, METH_O, nullptr},
    {"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
    {"_cuda_getDeviceCount",
     THCPModule_getDeviceCount_wrap,
     METH_NOARGS,
     nullptr},
    {"_cuda_canDeviceAccessPeer",
     THCPModule_canDeviceAccessPeer_wrap,
     METH_VARARGS,
     nullptr},
    {"_cuda_getArchFlags", THCPModule_getArchFlags, METH_NOARGS, nullptr},
    {"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr},
    {"_cuda_getCurrentStream",
     THCPModule_getCurrentStream_wrap,
     METH_O,
     nullptr},
    {"_cuda_getCurrentRawStream",
     THCPModule_getCurrentStream_raw,
     METH_O,
     nullptr},
    {"_cuda_getDefaultStream",
     THCPModule_getDefaultStream_wrap,
     METH_O,
     nullptr},
    {"_cuda_getCurrentBlasHandle",
     THCPModule_getCurrentBlasHandle_wrap,
     METH_NOARGS,
     nullptr},
    {"_cuda_clearCublasWorkspaces",
     THCPModule_clearBlasWorkspaces_wrap,
     METH_NOARGS,
     nullptr},
    {"_cuda_isCurrentStreamCapturing",
     THCPModule_isCurrentStreamCapturing_wrap,
     METH_NOARGS,
     nullptr},
    {"_cuda_setStream",
     castPyCFunctionWithKeywords(THCPModule_setStream_wrap),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_cuda_getCompiledVersion",
     THCPModule_getCompiledVersion,
     METH_NOARGS,
     nullptr},
    {"_cuda_hasPrimaryContext", THCPModule_hasPrimaryContext, METH_O, nullptr},
    {"_cuda_setMemoryFraction",
     THCPModule_setMemoryFraction,
     METH_VARARGS,
     nullptr},
    {"_cuda_emptyCache", THCPModule_emptyCache, METH_NOARGS, nullptr},
    {"_cuda_memoryStats", THCPModule_memoryStats, METH_O, nullptr},
    {"_cuda_resetAccumulatedMemoryStats",
     THCPModule_resetAccumulatedMemoryStats,
     METH_O,
     nullptr},
    {"_cuda_resetPeakMemoryStats",
     THCPModule_resetPeakMemoryStats,
     METH_O,
     nullptr},
    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
    {"_cuda_attach_out_of_memory_observer",
     THCPModule_attachOutOfMemoryObserver,
     METH_O,
     nullptr},
    {"_cuda_cudaHostAllocator",
     THCPModule_cudaHostAllocator,
     METH_NOARGS,
     nullptr},
    {"_cuda_cudaCachingAllocator_raw_alloc",
     THCPModule_cudaCachingAllocator_raw_alloc,
     METH_VARARGS,
     nullptr},
    {"_cuda_cudaCachingAllocator_raw_delete",
     THCPModule_cudaCachingAllocator_raw_delete,
     METH_O,
     nullptr},
    {"_cuda_cudaCachingAllocator_set_allocator_settings",
     THCPModule_cudaCachingAllocator_set_allocator_settings,
     METH_O,
     nullptr},
    {"_cuda_getAllocatorBackend",
     THCPModule_getAllocatorBackend,
     METH_NOARGS,
     nullptr},
    {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
    {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
    {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},
    {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr},
    {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr},
    {"_cuda_set_sync_debug_mode",
     THCPModule_cudaSetSyncDebugMode,
     METH_O,
     nullptr},
    {"_cuda_get_sync_debug_mode",
     THCPModule_cudaGetSyncDebugMode,
     METH_NOARGS,
     nullptr},
    {"_cuda_jiterator_compile_and_launch_kernel",
     THCPModule_cudaJiteratorCompileAndLaunchKernel,
     METH_VARARGS,
     nullptr},
    {"_cuda_get_cudnn_benchmark_limit",
     THCPModule_benchmarkLimitCuDNN,
     METH_NOARGS,
     nullptr},
    {"_cuda_set_cudnn_benchmark_limit",
     THCPModule_setBenchmarkLimitCuDNN,
     METH_O,
     nullptr},
#ifdef USE_NCCL
    {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
    {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},
    {"_nccl_init_rank", THCPModule_nccl_init_rank, METH_VARARGS, nullptr},
    {"_nccl_reduce", THCPModule_nccl_reduce, METH_VARARGS, nullptr},
    {"_nccl_all_reduce", THCPModule_nccl_all_reduce, METH_VARARGS, nullptr},
    {"_nccl_broadcast", THCPModule_nccl_broadcast, METH_VARARGS, nullptr},
    {"_nccl_all_gather", THCPModule_nccl_all_gather, METH_VARARGS, nullptr},
    {"_nccl_reduce_scatter",
     THCPModule_nccl_reduce_scatter,
     METH_VARARGS,
     nullptr},
#endif
    {"_rocm_is_backward_pass",
     THCPModule_rocm_is_backward_pass,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyMethodDef* THCPModule_methods() {
  return _THCPModule_methods;
}

namespace torch {
namespace cuda {

namespace shared {

void initCudartBindings(PyObject* module);
void initNvtxBindings(PyObject* module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
void initCudnnBindings(PyObject* module);
#endif

} // namespace shared

void initModule(PyObject* module) {
  python::initCommMethods(module);
  // As weird as it seems, this file is also compiled for ROCm,
  // so this condition might not always be true...
  shared::initCudartBindings(module);
  shared::initNvtxBindings(module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
  shared::initCudnnBindings(module);
#endif
  registerCudaDeviceProperties(module);
  registerCudaPluggableAllocator(module);
}

} // namespace cuda
} // namespace torch