1#include <c10/util/Optional.h>
2#include <sys/types.h>
3#include <torch/csrc/python_headers.h>
4
5#ifndef _MSC_VER
6#include <sys/socket.h>
7#endif
8
9#include <ATen/ATen.h>
10#include <ATen/DLConvertor.h>
11#include <ATen/ExpandUtils.h>
12#include <ATen/LegacyVmapMode.h>
13#include <ATen/LinalgBackend.h>
14#include <ATen/Parallel.h>
15#include <ATen/Utils.h>
16#include <ATen/core/Vitals.h>
17#include <ATen/dlpack.h>
18#include <ATen/native/ConvUtils.h>
19#include <c10/core/DispatchKeySet.h>
20#include <c10/util/Logging.h>
21#include <c10/util/irange.h>
22#include <libshm.h>
23#include <pybind11/pybind11.h>
24#include <pybind11/stl.h>
25#include <torch/csrc/THConcat.h>
26#include <torch/csrc/utils/pybind.h>
#include <cstdlib>
#include <list>
#include <unordered_map>
29
30#include <ATen/ThreadLocalPythonObjects.h>
31#include <torch/csrc/DataLoader.h>
32#include <torch/csrc/Device.h>
33#include <torch/csrc/Dtype.h>
34#include <torch/csrc/DynamicTypes.h>
35#include <torch/csrc/Generator.h>
36#include <torch/csrc/Layout.h>
37#include <torch/csrc/MemoryFormat.h>
38#include <torch/csrc/QScheme.h>
39#include <torch/csrc/Stream.h>
40#include <torch/csrc/THP.h>
41#include <torch/csrc/TypeInfo.h>
42#include <torch/csrc/api/include/torch/python/init.h>
43#include <torch/csrc/autograd/python_cpp_function.h>
44#include <torch/csrc/autograd/python_enum_tag.h>
45#include <torch/csrc/autograd/python_fft_functions.h>
46#include <torch/csrc/autograd/python_function.h>
47#include <torch/csrc/autograd/python_legacy_variable.h>
48#include <torch/csrc/autograd/python_linalg_functions.h>
49#include <torch/csrc/autograd/python_nested_functions.h>
50#include <torch/csrc/autograd/python_nn_functions.h>
51#include <torch/csrc/autograd/python_return_types.h>
52#include <torch/csrc/autograd/python_sparse_functions.h>
53#include <torch/csrc/autograd/python_special_functions.h>
54#include <torch/csrc/autograd/python_variable.h>
55#include <torch/csrc/dynamo/init.h>
56#include <torch/csrc/functorch/init.h>
57#include <torch/csrc/jit/python/init.h>
58#include <torch/csrc/jit/python/python_ir.h>
59#include <torch/csrc/jit/python/python_tracer.h>
60#include <torch/csrc/jit/serialization/pickler.h>
61#include <torch/csrc/lazy/python/init.h>
62#include <torch/csrc/monitor/python_init.h>
63#include <torch/csrc/multiprocessing/init.h>
64#include <torch/csrc/onnx/init.h>
65#include <torch/csrc/profiler/python/init.h>
66#include <torch/csrc/tensor/python_tensor.h>
67#include <torch/csrc/utils/disable_torch_function.h>
68#include <torch/csrc/utils/init.h>
69#include <torch/csrc/utils/pycfunction_helpers.h>
70#include <torch/csrc/utils/python_arg_parser.h>
71#include <torch/csrc/utils/python_compat.h>
72#include <torch/csrc/utils/python_dispatch.h>
73#include <torch/csrc/utils/python_strings.h>
74#include <torch/csrc/utils/tensor_dtypes.h>
75#include <torch/csrc/utils/tensor_layouts.h>
76#include <torch/csrc/utils/tensor_memoryformats.h>
77#include <torch/csrc/utils/tensor_new.h>
78#include <torch/csrc/utils/tensor_numpy.h>
79#include <torch/csrc/utils/tensor_qschemes.h>
80
81#ifdef USE_DISTRIBUTED
82#ifdef USE_C10D
83#include <torch/csrc/distributed/autograd/python_autograd.h>
84#include <torch/csrc/distributed/c10d/c10d.h>
85#include <torch/csrc/distributed/rpc/rpc.h>
86#include <torch/csrc/distributed/rpc/testing/testing.h>
87#endif
88#endif
89
90#if defined(USE_MPS)
91#include <ATen/mps/MPSDevice.h>
92#endif
93
94#if defined(USE_VALGRIND)
95#include <callgrind.h>
96#endif
97
namespace py = pybind11;

// The torch._C extension module object.
// NOTE(review): presumably assigned during module init elsewhere in this
// file — confirm before relying on it being non-null.
PyObject* module;

// Process-wide default CPU generator wrapper exposed to Python.
THPGenerator* THPDefaultCPUGenerator = nullptr;
103
104////////////////////////////////////////////////////////////////////////////////
105////////////////////////////////////////////////////////////////////////////////
106
107static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
108 static std::vector<std::string> names;
109
110 THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
111 if (!types)
112 return nullptr;
113
114 // NOLINTNEXTLINE(bugprone-branch-clone)
115 auto num_classes = PySequence_Fast_GET_SIZE(types.get());
116 names.reserve(names.size() + num_classes);
117 for (Py_ssize_t i = 0; i < num_classes; i++) {
118 PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
119 THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject");
120 PyTypeObject* type = (PyTypeObject*)obj;
121
122 THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
123 if (!module_name)
124 return nullptr;
125 THPUtils_assert(
126 THPUtils_checkString(module_name.get()),
127 "expected __module__ to be a string");
128 std::string name = THPUtils_unpackString(module_name.get());
129 names.emplace_back(name + "." + type->tp_name);
130 type->tp_name = names.back().c_str();
131 }
132 Py_RETURN_NONE;
133}
134//
135// Callback for python part. Used for additional initialization of python
136// classes
// Second-stage initialization invoked from Python (torch/__init__.py).
// Registers the Python objects for layouts/dtypes/etc., starts the shared
// memory manager, and runs post-import hooks that need the `torch` module.
static PyObject* THPModule_initExtension(
    PyObject* _unused,
    PyObject* shm_manager_path) {
  HANDLE_TH_ERRORS
  if (!THPUtils_checkString(shm_manager_path)) {
    THPUtils_setError(
        "initialization error - expected bytes/string object as shm_manager_path!");
    return nullptr;
  }
  // Order matters: dtypes/tensor bindings build on layouts & formats.
  torch::utils::initializeLayouts();
  torch::utils::initializeMemoryFormats();
  torch::utils::initializeQSchemes();
  torch::utils::initializeDtypes();
  torch::tensors::initialize_python_bindings();
  // Initialize libshm with the path to the shared-memory manager binary.
  std::string path = THPUtils_unpackString(shm_manager_path);
  libshm_init(path.c_str());

  auto module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!module)
    throw python_error();

  // These hooks require the fully imported `torch` module.
  THPStorage_postInit(module);
  THPAutograd_initFunctions();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
163
// The idea behind these functions is to make it easy to test whether we are
// built with ASAN: they're designed not to crash if ASAN is not enabled, but
// to trigger ASAN if it is enabled. This lets us run "canary" tests which
// check whether our build environment is misconfigured.
168
// Deliberate out-of-bounds stack write at a caller-chosen index; an index
// outside [0, 2] should be caught by ASAN when enabled. Do NOT "fix" the
// bounds — the bug is the point of this test hook.
static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) {
  THPUtils_assert(
      THPUtils_checkLong(arg),
      "crash_if_csrc_asan expects an int, "
      "but got %s",
      THPUtils_typename(arg));
  // `volatile` keeps the compiler from optimizing the access away.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
  volatile char x[3];
  x[THPUtils_unpackInt(arg)] = 0;
  // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
  return THPUtils_packInt32(x[0]);
}
181
// Deliberate UB canary: with arg == 0 the division yields infinity, and the
// cast of that value back to int is undefined behavior that UBSAN
// (float-cast-overflow) reports. Do NOT "fix" — the UB is the point.
static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
  THPUtils_assert(
      THPUtils_checkLong(arg),
      "crash_if_csrc_ubsan expects an int, "
      "but got %s",
      THPUtils_typename(arg));
  int32_t x = THPUtils_unpackInt(arg);
  double y = 1.0 / x;
  return THPUtils_packInt32((int)y);
}
192
// Deliberate vptr-sanitizer canary: casts an unrelated (but layout-
// compatible) type through void* and calls a virtual on it.
static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) {
  // This code should work perfectly fine, as vtables are identical for Foo
  // and Baz unless rtti and ubsan are enabled
  struct Foo {
    virtual int bar() = 0;
    virtual ~Foo() = default;
  };
  struct Baz {
    virtual int bar() {
      return 17;
    }
    virtual ~Baz() = default;
  };
  Baz x{};
  // Not a static_cast between related types: -fsanitize=vptr flags this.
  auto y = static_cast<Foo*>(static_cast<void*>(&x));
  auto rc = y->bar();
  return THPUtils_packInt32(rc);
}
211
212static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) {
213 THPUtils_assert(
214 THPUtils_checkLong(arg),
215 "crash_if_aten_asan expects an int, "
216 "but got %s",
217 THPUtils_typename(arg));
218 return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg)));
219}
220
221static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) {
222 return THPUtils_packInt32(at::get_num_threads());
223}
224
225static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
226 THPUtils_assert(
227 THPUtils_checkLong(arg),
228 "set_num_threads expects an int, "
229 "but got %s",
230 THPUtils_typename(arg));
231 int nthreads = (int)THPUtils_unpackLong(arg);
232 THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer");
233 at::set_num_threads(nthreads);
234 Py_RETURN_NONE;
235}
236
237static PyObject* THPModule_getNumInteropThreads(
238 PyObject* module,
239 PyObject* noargs) {
240 return THPUtils_packInt32(at::get_num_interop_threads());
241}
242
243static PyObject* THPModule_setNumInteropThreads(
244 PyObject* module,
245 PyObject* arg) {
246 THPUtils_assert(
247 THPUtils_checkLong(arg),
248 "set_num_interop_threads expects an int, "
249 "but got %s",
250 THPUtils_typename(arg));
251 int nthreads = (int)THPUtils_unpackLong(arg);
252 THPUtils_assert(
253 nthreads > 0, "set_num_interop_threads expects a positive integer");
254 at::set_num_interop_threads(nthreads);
255 Py_RETURN_NONE;
256}
257
// Python binding for torch.set_default_tensor_type(); thin forwarder to
// the tensors subsystem.
PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_tensor_type(type);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
264
// Python binding for torch.set_default_dtype(); thin forwarder to the
// tensors subsystem.
PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_dtype(dtype);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
271
// torch._C._add_docstr(obj, doc): installs a docstring on a C-level
// callable/descriptor/type after the fact. Returns obj (new reference) on
// success; raises if obj already has a docstring or is an unsupported kind.
PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
  // adds a __doc__ string to a function, similar to numpy's arr_add_docstring
  // The strings are kept alive for the life of the process: CPython stores
  // the raw `const char*` we hand it below.
  static std::vector<std::string> all_docs;
  PyObject* obj = nullptr;
  PyObject* doc_obj = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
    return nullptr;
  }

  // Non-string doc_obj is tolerated and installs this placeholder text.
  const char* doc_str = "<invalid string>";
  if (THPUtils_checkString(doc_obj)) {
    all_docs.push_back(THPUtils_unpackString(doc_obj));
    doc_str = all_docs.back().c_str();
  }

  // Dispatch on the kind of C-level object being documented.
  if (Py_TYPE(obj) == &PyCFunction_Type) {
    // Plain C function (e.g. module-level builtin).
    PyCFunctionObject* f = (PyCFunctionObject*)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "function '%s' already has a docstring",
          f->m_ml->ml_name);
    }
    f->m_ml->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
    // Unbound method of a C type.
    PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
    if (m->d_method->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "method '%s' already has a docstring",
          m->d_method->ml_name);
    }
    m->d_method->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
    // Property-like getset descriptor of a C type.
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
    PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
    if (m->d_getset->doc) {
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg)
      return PyErr_Format(
          PyExc_RuntimeError,
          "attribute '%s' already has a docstring",
          m->d_getset->name);
    }
    // This field is not const for python < 3.7 yet the content is
    // never modified.
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    m->d_getset->doc = const_cast<char*>(doc_str);
  } else if (Py_TYPE(obj) == &PyType_Type) {
    // A C type object itself.
    PyTypeObject* t = (PyTypeObject*)obj;
    if (t->tp_doc) {
      return PyErr_Format(
          PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);
    }
    t->tp_doc = doc_str;
  } else {
    return PyErr_Format(
        PyExc_TypeError,
        "don't know how to add docstring to type '%s'",
        Py_TYPE(obj)->tp_name);
  }

  // Return a new reference to obj (caller convention of _add_docstr).
  Py_INCREF(obj);
  return obj;
}
336
337PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) {
338 HANDLE_TH_ERRORS
339 Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0;
340 THPUtils_assert(num_args == 2, "expected exactly 2 arguments");
341 PyObject* arg1 = PyTuple_GET_ITEM(args, 0);
342 THPUtils_assert(THPSize_Check(arg1), "expected a torch.Size as argument 1");
343 PyObject* arg2 = PyTuple_GET_ITEM(args, 1);
344 THPUtils_assert(THPSize_Check(arg2), "expected a torch.Size as argument 2");
345
346 auto size1 = THPUtils_unpackLongs(arg1);
347 auto size2 = THPUtils_unpackLongs(arg2);
348 auto sizes = at::infer_size(size1, size2);
349 return THPSize_NewFromSizes(sizes.size(), sizes.data());
350 END_HANDLE_TH_ERRORS
351}
352
353static PyObject* THPModule_setBackcompatBroadcastWarn(
354 PyObject* module,
355 PyObject* arg) {
356 THPUtils_assert(
357 PyBool_Check(arg),
358 "set_backcompat_broadcast_warn expects a bool, "
359 "but got %s",
360 THPUtils_typename(arg));
361 setBackCompatBroadcastWarn(arg == Py_True);
362 Py_RETURN_NONE;
363}
364
365static PyObject* THPModule_getBackcompatBroadcastWarn(
366 PyObject* module,
367 PyObject* noargs) {
368 if (getBackCompatBroadcastWarn())
369 Py_RETURN_TRUE;
370 else
371 Py_RETURN_FALSE;
372}
373
374static PyObject* THPModule_setBackcompatKeepdimWarn(
375 PyObject* module,
376 PyObject* arg) {
377 THPUtils_assert(
378 PyBool_Check(arg),
379 "set_backcompat_keepdim_warn expects a bool, "
380 "but got %s",
381 THPUtils_typename(arg));
382 setBackCompatKeepdimWarn(arg == Py_True);
383 Py_RETURN_NONE;
384}
385
386static PyObject* THPModule_getBackcompatKeepdimWarn(
387 PyObject* module,
388 PyObject* noargs) {
389 if (getBackCompatKeepdimWarn())
390 Py_RETURN_TRUE;
391 else
392 Py_RETURN_FALSE;
393}
394
// Reports whether this build was compiled with distributed support
// (compile-time USE_DISTRIBUTED flag).
PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) {
#ifdef USE_DISTRIBUTED
  Py_RETURN_TRUE;
#else
  Py_RETURN_FALSE;
#endif
}
402
// Returns the ATen build-configuration summary string (torch.__config__).
static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::show_config());
  END_HANDLE_TH_ERRORS
}
408
// Returns the C++ compiler flags ATen was built with.
static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_cxx_flags());
  END_HANDLE_TH_ERRORS
}
414
// Returns a human-readable description of ATen's parallelization settings.
static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_parallel_info());
  END_HANDLE_TH_ERRORS
}
420
// Capsule destructor for tensors exported via toDLPack: frees the
// DLManagedTensor only if no consumer ever claimed the capsule.
void DLPack_Capsule_Destructor(PyObject* data) {
  if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) {
    // early out, see DLPack spec: if a consuming library sets the capsule
    // name to something else, they own it and we don't need to do anything
    return;
  }
  HANDLE_TH_ERRORS
  // Causes overheads for validity checks again, but this case is rare
  // since consuming libraries should rename the capsule according to spec.
  // Note that this cannot set a python error (we checked validity above),
  // so we don't need to handle python error state here.
  DLManagedTensor* dlMTensor =
      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
  // the dlMTensor has not been consumed, call deleter ourselves.
  // DLPack spec mentions that deleter may be NULL, but deleter from
  // `at::toDLPack` is never NULL, so no need for an additional check here.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor));
  END_HANDLE_TH_ERRORS_RET()
}
441
442PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
443 HANDLE_TH_ERRORS
444 THPUtils_assert(THPVariable_Check(data), "data must be a Tensor");
445 DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
446 return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
447 END_HANDLE_TH_ERRORS
448}
449
450PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
451 using namespace torch::autograd;
452 HANDLE_TH_ERRORS
453 auto tensor = torch::utils::tensor_fromDLPack(data);
454 return THPVariable_Wrap(tensor);
455 END_HANDLE_TH_ERRORS
456}
457
458PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
459 HANDLE_TH_ERRORS
460 size_t frames_to_skip;
461 size_t maximum_number_of_frames;
462 if (!PyArg_ParseTuple(
463 args, "LL", &frames_to_skip, &maximum_number_of_frames)) {
464 return nullptr;
465 }
466 return THPUtils_packString(
467 c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
468 END_HANDLE_TH_ERRORS
469}
470static PyObject* THModule_rename_privateuse1_backend(
471 PyObject* _unused,
472 PyObject* arg) {
473 HANDLE_TH_ERRORS
474 THPUtils_assert(
475 THPUtils_checkString(arg),
476 "_rename_privateuse1_backend expects a str, "
477 "but got %s",
478 THPUtils_typename(arg));
479 const std::string backend_name = THPUtils_unpackString(arg);
480 c10::register_privateuse1_backend(backend_name);
481 Py_RETURN_NONE;
482 END_HANDLE_TH_ERRORS
483}
484
485PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
486 THPUtils_assert(
487 PyBool_Check(arg),
488 "set_allow_tf32_cublas expects a bool, "
489 "but got %s",
490 THPUtils_typename(arg));
491 at::globalContext().setAllowTF32CuDNN(arg == Py_True);
492 Py_RETURN_NONE;
493}
494
495PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
496 if (at::globalContext().allowTF32CuDNN())
497 Py_RETURN_TRUE;
498 else
499 Py_RETURN_FALSE;
500}
501
502PyObject* THPModule_setFloat32MatmulPrecision(
503 PyObject* _unused,
504 PyObject* arg) {
505 THPUtils_assert(
506 THPUtils_checkString(arg),
507 "set_float32_matmul_precision expects a str, "
508 "but got %s",
509 THPUtils_typename(arg));
510 std::string s = THPUtils_unpackString(arg);
511 at::globalContext().setFloat32MatmulPrecision(s);
512 Py_RETURN_NONE;
513}
514
515PyObject* THPModule_float32MatmulPrecision(
516 PyObject* _unused,
517 PyObject* noargs) {
518 std::string s = "highest";
519 auto p = at::globalContext().float32MatmulPrecision();
520 if (p == at::Float32MatmulPrecision::HIGH) {
521 s = "high";
522 } else if (p == at::Float32MatmulPrecision::MEDIUM) {
523 s = "medium";
524 }
525 return THPUtils_packString(s);
526}
527PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) {
528 THPUtils_assert(
529 PyBool_Check(arg),
530 "set_sdp_use_math expects a bool, "
531 "but got %s",
532 THPUtils_typename(arg));
533 at::globalContext().setSDPUseFlash(arg == Py_True);
534 Py_RETURN_NONE;
535}
536PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
537 if (at::globalContext().userEnabledFlashSDP())
538 Py_RETURN_TRUE;
539 else
540 Py_RETURN_FALSE;
541}
542PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
543 THPUtils_assert(
544 PyBool_Check(arg),
545 "set_sdp_use_math expects a bool, "
546 "but got %s",
547 THPUtils_typename(arg));
548 at::globalContext().setSDPUseMemEfficient(arg == Py_True);
549 Py_RETURN_NONE;
550}
551PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
552 if (at::globalContext().userEnabledMemEfficientSDP())
553 Py_RETURN_TRUE;
554 else
555 Py_RETURN_FALSE;
556}
557PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
558 THPUtils_assert(
559 PyBool_Check(arg),
560 "set_sdp_use_math expects a bool, "
561 "but got %s",
562 THPUtils_typename(arg));
563 at::globalContext().setSDPUseMath(arg == Py_True);
564 Py_RETURN_NONE;
565}
566PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
567 if (at::globalContext().userEnabledMathSDP())
568 Py_RETURN_TRUE;
569 else
570 Py_RETURN_FALSE;
571}
572PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) {
573 THPUtils_assert(
574 PyBool_Check(arg),
575 "set_enabled_cudnn expects a bool, "
576 "but got %s",
577 THPUtils_typename(arg));
578 at::globalContext().setUserEnabledCuDNN(arg == Py_True);
579 Py_RETURN_NONE;
580}
581
582PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) {
583 if (at::globalContext().userEnabledCuDNN())
584 Py_RETURN_TRUE;
585 else
586 Py_RETURN_FALSE;
587}
588
589PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) {
590 THPUtils_assert(
591 PyBool_Check(arg),
592 "set_enabled_mkldnn expects a bool, "
593 "but got %s",
594 THPUtils_typename(arg));
595 at::globalContext().setUserEnabledMkldnn(arg == Py_True);
596 Py_RETURN_NONE;
597}
598
599PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) {
600 if (at::globalContext().userEnabledMkldnn())
601 Py_RETURN_TRUE;
602 else
603 Py_RETURN_FALSE;
604}
605
606PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) {
607 HANDLE_TH_ERRORS
608 THPUtils_assert(
609 PyBool_Check(arg),
610 "set_deterministic_cudnn expects a bool, "
611 "but got %s",
612 THPUtils_typename(arg));
613 at::globalContext().setDeterministicCuDNN(arg == Py_True);
614 Py_RETURN_NONE;
615 END_HANDLE_TH_ERRORS
616}
617
618PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) {
619 if (at::globalContext().deterministicCuDNN())
620 Py_RETURN_TRUE;
621 else
622 Py_RETURN_FALSE;
623}
624
// torch.use_deterministic_algorithms(mode, *, warn_only=False):
// globally forces (or merely warns about) nondeterministic algorithms.
PyObject* THPModule_setDeterministicAlgorithms(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Parser is static: the signature is parsed once per process.
  static torch::PythonArgParser parser(
      {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"});
  torch::ParsedArgs<2> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  bool mode = r.toBool(0);
  bool warn_only = r.toBool(1);
  at::globalContext().setDeterministicAlgorithms(mode, warn_only);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
640
641PyObject* THPModule_deterministicAlgorithms(
642 PyObject* _unused,
643 PyObject* noargs) {
644 if (at::globalContext().deterministicAlgorithms()) {
645 Py_RETURN_TRUE;
646 }
647 Py_RETURN_FALSE;
648}
649
650PyObject* THPModule_deterministicAlgorithmsWarnOnly(
651 PyObject* _unused,
652 PyObject* noargs) {
653 if (at::globalContext().deterministicAlgorithmsWarnOnly()) {
654 Py_RETURN_TRUE;
655 }
656 Py_RETURN_FALSE;
657}
658
659PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
660 THPUtils_assert(
661 PyBool_Check(arg),
662 "setWarnOnlyOnce expects a bool, "
663 "but got %s",
664 THPUtils_typename(arg));
665 c10::WarningUtils::set_warnAlways(arg == Py_True);
666 Py_RETURN_NONE;
667}
668
669PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) {
670 if (c10::WarningUtils::get_warnAlways()) {
671 Py_RETURN_TRUE;
672 }
673 Py_RETURN_FALSE;
674}
675
// Used only for testing C++ to Python warning translations.
// Emits a single TORCH_WARN so the Python-side warning machinery can be
// exercised end to end.
PyObject* THPModule_warn(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN("Test message for TORCH_WARN");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
683
// Used only for testing C++ to Python warning translations.
// Same as THPModule_warn, but exercises the deprecation-warning path.
PyObject* THPModule_warnDeprecation(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION("Test message for TORCH_WARN_DEPRECATION");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
691
692PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) {
693 THPUtils_assert(
694 PyBool_Check(arg),
695 "set_benchmark_cudnn expects a bool, "
696 "but got %s",
697 THPUtils_typename(arg));
698 at::globalContext().setBenchmarkCuDNN(arg == Py_True);
699 Py_RETURN_NONE;
700}
701
702PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
703 if (at::globalContext().benchmarkCuDNN()) {
704 Py_RETURN_TRUE;
705 }
706 Py_RETURN_FALSE;
707}
708
709PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) {
710 THPUtils_assert(
711 PyBool_Check(arg),
712 "set_allow_tf32_cublas expects a bool, "
713 "but got %s",
714 THPUtils_typename(arg));
715 at::globalContext().setAllowTF32CuBLAS(arg == Py_True);
716 Py_RETURN_NONE;
717}
718
719PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) {
720 if (at::globalContext().allowTF32CuBLAS()) {
721 Py_RETURN_TRUE;
722 }
723 Py_RETURN_FALSE;
724}
725
726PyObject* THPModule_setAllowFP16ReductionCuBLAS(
727 PyObject* _unused,
728 PyObject* arg) {
729 THPUtils_assert(
730 PyBool_Check(arg),
731 "set_allow_fp16_reduction_cublas expects a bool, "
732 "but got %s",
733 THPUtils_typename(arg));
734 at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
735 Py_RETURN_NONE;
736}
737
738PyObject* THPModule_allowFP16ReductionCuBLAS(
739 PyObject* _unused,
740 PyObject* noargs) {
741 if (at::globalContext().allowFP16ReductionCuBLAS()) {
742 Py_RETURN_TRUE;
743 }
744 Py_RETURN_FALSE;
745}
746
747PyObject* THPModule_setAllowBF16ReductionCuBLAS(
748 PyObject* _unused,
749 PyObject* arg) {
750 THPUtils_assert(
751 PyBool_Check(arg),
752 "set_allow_bf16_reduction_cublas expects a bool, "
753 "but got %s",
754 THPUtils_typename(arg));
755 at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
756 Py_RETURN_NONE;
757}
758
759PyObject* THPModule_allowBF16ReductionCuBLAS(
760 PyObject* _unused,
761 PyObject* noargs) {
762 if (at::globalContext().allowBF16ReductionCuBLAS()) {
763 Py_RETURN_TRUE;
764 }
765 Py_RETURN_FALSE;
766}
767
768PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
769 THPUtils_assert(
770 PyBool_Check(arg),
771 "flush_denormal expects a bool, "
772 "but got %s",
773 THPUtils_typename(arg));
774 if (!at::globalContext().setFlushDenormal(arg == Py_True)) {
775 Py_RETURN_FALSE;
776 };
777 Py_RETURN_TRUE;
778}
779
780PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) {
781 HANDLE_TH_ERRORS
782 auto scalar_type = torch::tensors::get_default_scalar_type();
783 auto dtype = (PyObject*)torch::getTHPDtype(scalar_type);
784 Py_INCREF(dtype);
785 return dtype;
786 END_HANDLE_TH_ERRORS
787}
788
789PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) {
790 HANDLE_TH_ERRORS
791 return THPUtils_packString(c10::DeviceTypeName(
792 dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()),
793 /*lower_case=*/true));
794 END_HANDLE_TH_ERRORS
795}
796
797PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) {
798 THPUtils_assert(
799 THPUtils_checkLong(arg),
800 "set_qengine expects an int, "
801 "but got %s",
802 THPUtils_typename(arg));
803 HANDLE_TH_ERRORS
804 auto qengine = static_cast<int>(THPUtils_unpackLong(arg));
805 at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
806 Py_RETURN_NONE;
807 END_HANDLE_TH_ERRORS
808}
809
810PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) {
811 return THPUtils_packInt64(static_cast<int>(at::globalContext().qEngine()));
812}
813
814PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) {
815 auto qengines = at::globalContext().supportedQEngines();
816 auto list = THPObjectPtr(PyList_New(qengines.size()));
817 if (!list)
818 return nullptr;
819 for (const auto i : c10::irange(qengines.size())) {
820 PyObject* i64 = THPUtils_packInt64(static_cast<int>(qengines[i]));
821 if (!i64)
822 return nullptr;
823 PyList_SET_ITEM(list.get(), i, i64);
824 }
825 return list.release();
826}
827
828PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
829 if (at::globalContext().isXNNPACKAvailable())
830 Py_RETURN_TRUE;
831 else
832 Py_RETURN_FALSE;
833}
834
835PyObject* THPModule_setCheckSparseTensorInvariants(
836 PyObject* _unused,
837 PyObject* arg) {
838 THPUtils_assert(
839 PyBool_Check(arg),
840 "set_check_sparse_tensor_invariants expects a bool, "
841 "but got %s",
842 THPUtils_typename(arg));
843 at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
844 Py_RETURN_NONE;
845}
846
847PyObject* THPModule_checkSparseTensorInvariants(
848 PyObject* _unused,
849 PyObject* noargs) {
850 if (at::globalContext().checkSparseTensorInvariants())
851 Py_RETURN_TRUE;
852 else
853 Py_RETURN_FALSE;
854}
855
// torch._C._will_engine_execute_node(grad_fn): during a backward pass,
// reports whether the autograd engine will actually execute the given node.
// Must only be called while a graph task is active.
PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  bool isTHPFunction = THPFunction_Check(arg);
  bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
  THPUtils_assert(
      isTHPFunction || isTHPCppFunction,
      "_will_engine_execute_node expects an grad_fn, "
      "but got %s",
      THPUtils_typename(arg));
  // Non-null exec_info implies we are inside a backward pass.
  const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
  THPUtils_assert(
      exec_info,
      "_get_should_execute_nodes should only be called during the backward pass");
  // Extract the underlying Node* from either a Python-defined or a C++
  // grad_fn wrapper. node_sp keeps the Python node alive for this scope.
  torch::autograd::Node* node;
  std::shared_ptr<torch::autograd::Node> node_sp;
  if (isTHPFunction) {
    node_sp = ((THPFunction*)arg)->cdata.lock();
    node = node_sp.get();
  } else {
    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
  }
  // First requirement: the node must be part of the current graph at all.
  const auto nodes_in_graph =
      torch::autograd::get_current_graph_task_nodes_in_graph();
  bool ret = nodes_in_graph->find(node) != nodes_in_graph->end();
  if (ret && !exec_info->empty()) {
    // Non-empty exec_info means a filtered run (e.g. autograd.grad with
    // specific inputs): the node must additionally be marked for execution.
    auto it = exec_info->find(node);
    if (it == exec_info->end() || !it->second.should_execute()) {
      ret = false;
    } else {
      TORCH_CHECK(
          !(node->topological_nr() == 0 && it->second.captures_),
          "A leaf node was passed to _will_engine_execute_node but we are "
          "currently running autograd.grad(). This is currently not supported.");
    }
  }
  if (ret) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
898
// torch._C._current_graph_task_execution_order(): returns the nodes of the
// currently running backward pass in execution order, as a Python list.
// Must only be called while a backward pass is in progress.
PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::vector<torch::autograd::Node*> nodes =
      torch::autograd::get_current_graph_task_execution_order();
  // An empty vector means no graph task is active.
  TORCH_CHECK(
      nodes.size(),
      "_current_graph_task_execution_order should only be called during the backward pass");
  auto list = THPObjectPtr(PyList_New(nodes.size()));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(nodes.size())) {
    // This node is guaranteed to be alive since the backward is still running
    PyObject* pyobj_node =
        torch::autograd::functionToPyObject(nodes[i]->getptr());
    // PyList_SET_ITEM steals the reference to pyobj_node.
    PyList_SET_ITEM(list.get(), i, pyobj_node);
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}
920
921PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
922 HANDLE_TH_ERRORS
923 return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
924 END_HANDLE_TH_ERRORS
925}
926
927PyObject* THPModule_getCurrentNode(PyObject* _unused, PyObject* noargs) {
928 HANDLE_TH_ERRORS
929 return torch::autograd::functionToPyObject(
930 torch::autograd::get_current_node());
931 END_HANDLE_TH_ERRORS
932}
933
// Swaps the process-wide CPU allocator for the mobile-optimized one; undo
// with _unset_default_mobile_cpu_allocator.
PyObject* THPModule_setDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().setDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
942
// Restores the CPU allocator that was in place before
// _set_default_mobile_cpu_allocator.
PyObject* THPModule_unsetDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().unsetDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
951
952static PyObject* THPModule_vmapmode_increment_nesting(
953 PyObject* _unused,
954 PyObject* arg) {
955 HANDLE_TH_ERRORS
956 return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
957 END_HANDLE_TH_ERRORS
958}
959
960static PyObject* THPModule_vmapmode_decrement_nesting(
961 PyObject* _unused,
962 PyObject* arg) {
963 HANDLE_TH_ERRORS
964 return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting());
965 END_HANDLE_TH_ERRORS
966}
967
968static PyObject* THPModule_set_display_vmap_fallback_warnings_mode(
969 PyObject* _unused,
970 PyObject* arg) {
971 HANDLE_TH_ERRORS
972 THPUtils_assert(
973 PyBool_Check(arg),
974 "enabled must be a bool, "
975 "but got %s",
976 THPUtils_typename(arg));
977 at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True);
978 Py_RETURN_NONE;
979 END_HANDLE_TH_ERRORS
980}
981
982static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
983 PyObject* _unused,
984 PyObject* arg) {
985 HANDLE_TH_ERRORS
986 if (at::globalContext().areVmapFallbackWarningsEnabled()) {
987 Py_RETURN_TRUE;
988 } else {
989 Py_RETURN_FALSE;
990 }
991 END_HANDLE_TH_ERRORS
992}
993
// Method table for the torch._C extension module. Each entry maps a Python
// name to a C binding; the table is handed to PyModule_Create via
// THPUtils_addPyMethodDefs in initModule() below.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays)
static PyMethodDef TorchMethods[] = {
    // Extension bootstrap, build-config introspection and crash/sanitizer
    // test hooks.
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
    {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
    {"_init_names", THPModule_initNames, METH_O, nullptr},
    {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
    {"_set_default_tensor_type",
     THPModule_setDefaultTensorType,
     METH_O,
     nullptr},
    {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr},
    {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr},
    {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr},
    {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr},
    {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr},
    {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr},
    {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr},
    {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
    {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
    // Backwards-compatibility warning toggles.
    {"_set_backcompat_broadcast_warn",
     THPModule_setBackcompatBroadcastWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_broadcast_warn",
     THPModule_getBackcompatBroadcastWarn,
     METH_NOARGS,
     nullptr},
    {"_set_backcompat_keepdim_warn",
     THPModule_setBackcompatKeepdimWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_keepdim_warn",
     THPModule_getBackcompatKeepdimWarn,
     METH_NOARGS,
     nullptr},
    // Intra-op / inter-op thread-pool controls.
    {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr},
    {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr},
    {"get_num_interop_threads",
     THPModule_getNumInteropThreads,
     METH_NOARGS,
     nullptr},
    {"set_num_interop_threads",
     THPModule_setNumInteropThreads,
     METH_O,
     nullptr},
    // Scaled-dot-product-attention backend selection (flash / mem-efficient
    // / math).
    {"_get_flash_sdp_enabled",
     THPModule_userEnabledFlashSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
    {"_get_mem_efficient_sdp_enabled",
     userEnabledMemEfficientSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_mem_efficient",
     THPModule_setSDPUseMemEfficient,
     METH_O,
     nullptr},
    {"_get_math_sdp_enabled",
     THPModule_userEnabledMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    // cuDNN / MKLDNN toggles and TF32/determinism knobs.
    {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr},
    {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
    {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
    {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
    {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
    {"_get_cudnn_deterministic",
     THPModule_deterministicCuDNN,
     METH_NOARGS,
     nullptr},
    {"_set_cudnn_deterministic",
     THPModule_setDeterministicCuDNN,
     METH_O,
     nullptr},
    {"_get_deterministic_algorithms",
     THPModule_deterministicAlgorithms,
     METH_NOARGS,
     nullptr},
    {"_get_deterministic_algorithms_warn_only",
     THPModule_deterministicAlgorithmsWarnOnly,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_algorithms",
     castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Warning behavior.
    {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
    {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
    {"_warn", THPModule_warn, METH_NOARGS, nullptr},
    {"_warn_deprecation", THPModule_warnDeprecation, METH_NOARGS, nullptr},
    // cuBLAS precision knobs.
    {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
    {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
    {"_get_float32_matmul_precision",
     THPModule_float32MatmulPrecision,
     METH_NOARGS,
     nullptr},
    {"_set_float32_matmul_precision",
     THPModule_setFloat32MatmulPrecision,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_allowBF16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    // vmap nesting / fallback-warning debug hooks (see definitions above).
    {"_vmapmode_increment_nesting",
     THPModule_vmapmode_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"_vmapmode_decrement_nesting",
     THPModule_vmapmode_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"_debug_only_display_vmap_fallback_warnings",
     THPModule_set_display_vmap_fallback_warnings_mode,
     METH_O,
     nullptr},
    {"_debug_only_are_vmap_fallback_warnings_enabled",
     THPModule_are_vmap_fallback_warnings_enabled,
     METH_NOARGS,
     nullptr},
    // Interop (DLPack), diagnostics, backend/device configuration.
    {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
    {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
    {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
    {"_rename_privateuse1_backend",
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
    // Quantization engine selection and XNNPACK availability.
    {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr},
    {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
    {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
    {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
    // Sparse tensor invariant checking.
    {"_set_check_sparse_tensor_invariants",
     THPModule_setCheckSparseTensorInvariants,
     METH_O,
     nullptr},
    {"_check_sparse_tensor_invariants",
     THPModule_checkSparseTensorInvariants,
     METH_NOARGS,
     nullptr},
    // Autograd graph-task introspection.
    {"_will_engine_execute_node",
     THPModule_willEngineExecuteNode,
     METH_O,
     nullptr},
    {"_current_graph_task_execution_order",
     THPModule_getCurrentGraphTaskExecutionOrder,
     METH_NOARGS,
     nullptr},
    {"_current_graph_task_id",
     THPModule_getCurrentGraphTaskId,
     METH_NOARGS,
     nullptr},
    {"_current_autograd_node", THPModule_getCurrentNode, METH_NOARGS, nullptr},
    // Mobile CPU allocator toggles (see definitions above).
    {"_set_default_mobile_cpu_allocator",
     THPModule_setDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_unset_default_mobile_cpu_allocator",
     THPModule_unsetDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    // __torch_function__ / __torch_dispatch__ machinery.
    {"_is_torch_function_enabled",
     THPModule_isEnabledTorchFunction,
     METH_NOARGS,
     nullptr},
    {"_disabled_torch_function_impl",
     THPModule_disable_torch_function,
     METH_VARARGS,
     nullptr},
    {"_disabled_torch_dispatch_impl",
     THPModule_disable_torch_dispatch,
     METH_VARARGS,
     nullptr},
    {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
    {"_has_torch_function_unary",
     THPModule_has_torch_function_unary,
     METH_O,
     nullptr},
    {"_has_torch_function_variadic",
     (PyCFunction)(void (*)(void))THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    // Sentinel terminating the table.
    {nullptr, nullptr, 0, nullptr}};
1198
1199void THCPStream_init(PyObject* module);
1200void THCPEvent_init(PyObject* module);
1201void THCPGraph_init(PyObject* module);
1202
1203#ifdef USE_CUDA
1204PyMethodDef* THCPModule_methods();
1205namespace torch {
1206namespace cuda {
1207
1208void initModule(PyObject* module);
1209
1210}
1211} // namespace torch
1212#endif
1213
1214#ifdef USE_ITT
1215namespace torch {
1216namespace profiler {
1217void initIttBindings(PyObject* module);
1218} // namespace profiler
1219} // namespace torch
1220#endif
1221
1222namespace torch {
1223void initVerboseBindings(PyObject* module);
1224} // namespace torch
1225
1226static std::vector<PyMethodDef> methods;
1227
1228// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
1229// Guaranteed to be invoked from Python under GIL, no locking on map needed
1230static void LogAPIUsageOnceFromPython(const std::string& event) {
1231 static std::unordered_set<std::string> seen;
1232 if (!seen.count(event)) {
1233 seen.insert(event);
1234 c10::LogAPIUsage(event);
1235 }
1236}
1237
1238// Weak reference to tensor, used to test a tensor isn't leaked
1239class WeakTensorRef {
1240 c10::weak_intrusive_ptr<c10::TensorImpl> weakref_;
1241
1242 public:
1243 WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {}
1244
1245 bool expired() {
1246 return weakref_.expired();
1247 }
1248};
1249
1250extern "C"
1251#ifdef _WIN32
1252 __declspec(dllexport)
1253#endif
1254 TORCH_API PyObject* initModule();
1255// separate decl and defn for msvc error C2491
1256PyObject* initModule() {
1257 HANDLE_TH_ERRORS
1258
1259 c10::initLogging();
1260
1261 at::internal::lazy_init_num_threads();
1262
1263 C10_LOG_API_USAGE_ONCE("torch.python.import");
1264
1265// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
1266#define ASSERT_TRUE(cmd) \
1267 if (!(cmd)) \
1268 return nullptr
1269
1270 THPUtils_addPyMethodDefs(methods, TorchMethods);
1271 THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
1272 THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
1273 THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
1274#ifdef USE_CUDA
1275 THPUtils_addPyMethodDefs(methods, THCPModule_methods());
1276#endif
1277#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
1278 THPUtils_addPyMethodDefs(
1279 methods, torch::distributed::c10d::python_functions());
1280#ifndef _WIN32
1281 THPUtils_addPyMethodDefs(
1282 methods, torch::distributed::rpc::python_functions());
1283 THPUtils_addPyMethodDefs(
1284 methods, torch::distributed::autograd::python_functions());
1285 THPUtils_addPyMethodDefs(
1286 methods, torch::distributed::rpc::testing::python_functions());
1287#endif
1288#endif
1289
1290 static struct PyModuleDef torchmodule = {
1291 PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()};
1292 ASSERT_TRUE(module = PyModule_Create(&torchmodule));
1293 ASSERT_TRUE(THPGenerator_init(module));
1294 ASSERT_TRUE(THPException_init(module));
1295 THPSize_init(module);
1296 THPDtype_init(module);
1297 THPDTypeInfo_init(module);
1298 THPLayout_init(module);
1299 THPMemoryFormat_init(module);
1300 THPQScheme_init(module);
1301 THPDevice_init(module);
1302 THPStream_init(module);
1303 ASSERT_TRUE(THPVariable_initModule(module));
1304 ASSERT_TRUE(THPFunction_initModule(module));
1305 ASSERT_TRUE(THPEngine_initModule(module));
1306 // NOTE: We need to be able to access OperatorExportTypes from ONNX for use in
1307 // the export side of JIT, so this ONNX init needs to appear before the JIT
1308 // init.
1309 torch::onnx::initONNXBindings(module);
1310 torch::autograd::initEnumTag(module);
1311 torch::jit::initJITBindings(module);
1312 torch::monitor::initMonitorBindings(module);
1313 torch::impl::dispatch::initDispatchBindings(module);
1314 torch::dynamo::initDynamoBindings(module);
1315 torch::functorch::impl::initFuncTorchBindings(module);
1316 torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
1317 torch::autograd::initReturnTypes(module);
1318 torch::autograd::initNNFunctions(module);
1319 torch::autograd::initFFTFunctions(module);
1320 torch::autograd::initLinalgFunctions(module);
1321 torch::autograd::initNestedFunctions(module);
1322 torch::autograd::initSparseFunctions(module);
1323 torch::autograd::initSpecialFunctions(module);
1324 torch::autograd::init_legacy_variable(module);
1325 torch::profiler::initPythonBindings(module);
1326 torch::python::init_bindings(module);
1327 torch::lazy::initLazyBindings(module);
1328#ifdef USE_ITT
1329 torch::profiler::initIttBindings(module);
1330#endif
1331#ifdef USE_CUDA
1332 torch::cuda::initModule(module);
1333#endif
1334 torch::initVerboseBindings(module);
1335 ASSERT_TRUE(THPStorage_init(module));
1336
1337#ifdef USE_CUDA
1338 // This will only initialise base classes and attach them to library namespace
1339 // They won't be ready for real usage until importing cuda module, that will
1340 // complete the process (but it defines Python classes before calling back
1341 // into C, so these lines have to execute first)..
1342 THCPStream_init(module);
1343 THCPEvent_init(module);
1344 THCPGraph_init(module);
1345#endif
1346
1347 auto set_module_attr =
1348 [&](const char* name, PyObject* v, bool incref = true) {
1349 // PyModule_AddObject steals reference
1350 if (incref) {
1351 Py_INCREF(v);
1352 }
1353 return PyModule_AddObject(module, name, v) == 0;
1354 };
1355
1356#if defined(USE_CUDNN) || defined(USE_ROCM)
1357 PyObject* has_cudnn = Py_True;
1358#else
1359 PyObject* has_cudnn = Py_False;
1360#endif
1361 ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn));
1362
1363#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
1364 PyObject* has_spectral = Py_True;
1365#else
1366 PyObject* has_spectral = Py_False;
1367#endif
1368 ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));
1369
1370 // force ATen to initialize because it handles
1371 // setting up TH Errors so that they throw C++ exceptions
1372 at::init();
1373
1374 // Automatically translate errors thrown from pybind11 functions
1375 py::register_exception_translator([](std::exception_ptr e) { // NOLINT
1376 try {
1377 if (e) {
1378 std::rethrow_exception(e);
1379 }
1380 }
1381 CATCH_TH_ERRORS()
1382 });
1383
1384 auto py_module = py::reinterpret_borrow<py::module>(module);
1385 py_module.def("_demangle", &c10::demangle);
1386 py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);
1387
1388 py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled);
1389 py_module.def(
1390 "set_vital",
1391 [](const std::string& vital,
1392 const std::string& attr,
1393 const std::string value) {
1394 return at::vitals::VitalsAPI.setVital(vital, attr, value);
1395 });
1396 py_module.def(
1397 "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); });
1398
1399 py_module.def(
1400 "init_num_threads",
1401 torch::wrap_pybind_function(at::init_num_threads),
1402 R"(
1403init_num_threads()
1404
1405Initializes the number of parallel threads used on the current thread.
1406
1407Call this whenever a new thread is created in order to propagate values from
1408:func:`torch.set_num_threads` onto the new thread.
1409)");
1410
1411 ASSERT_TRUE(
1412 set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
1413 ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
1414 ASSERT_TRUE(
1415 set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
1416
1417 py_module.def("_valgrind_supported_platform", []() {
1418#if defined(USE_VALGRIND)
1419 return true;
1420#else
1421 return false;
1422#endif
1423 });
1424
1425 py_module.def("_valgrind_toggle", []() {
1426#if defined(USE_VALGRIND)
1427 CALLGRIND_TOGGLE_COLLECT;
1428#else
1429 TORCH_CHECK(false, "Valgrind is not supported.");
1430#endif
1431 });
1432
1433 py_module.def("_valgrind_toggle_and_dump_stats", []() {
1434#if defined(USE_VALGRIND)
1435 // NB: If we don't toggle collect around dump stats, callgrind_annotate
1436 // won't process the results correctly. Specifically,
1437 // `callgrind_annotate --inclusive=no` will be almost completely empty.
1438 CALLGRIND_TOGGLE_COLLECT;
1439 CALLGRIND_DUMP_STATS;
1440#else
1441 TORCH_CHECK(false, "Valgrind is not supported.");
1442#endif
1443 });
1444
1445 py::class_<WeakTensorRef>(py_module, "_WeakTensorRef")
1446 .def(py::init([](py::object tensor) {
1447 return WeakTensorRef(THPVariable_Unpack(tensor.ptr()));
1448 }))
1449 .def("expired", &WeakTensorRef::expired);
1450
1451 py::enum_<at::native::ConvBackend>(py_module, "_ConvBackend")
1452 .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d)
1453 .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d)
1454 .value("Cudnn", at::native::ConvBackend::Cudnn)
1455 .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose)
1456 .value("Empty", at::native::ConvBackend::Empty)
1457 .value("Miopen", at::native::ConvBackend::Miopen)
1458 .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise)
1459 .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose)
1460 .value("Mkldnn", at::native::ConvBackend::Mkldnn)
1461 .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty)
1462 .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial)
1463 .value("Overrideable", at::native::ConvBackend::Overrideable)
1464 .value("Slow2d", at::native::ConvBackend::Slow2d)
1465 .value("Slow3d", at::native::ConvBackend::Slow3d)
1466 .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d)
1467 .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d)
1468 .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d)
1469 .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d)
1470 .value(
1471 "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise)
1472 .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d)
1473 .value("Mps", at::native::ConvBackend::Mps)
1474 .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose);
1475
1476 py_module.def(
1477 "_select_conv_backend",
1478 [](const at::Tensor& input,
1479 const at::Tensor& weight,
1480 const c10::optional<at::Tensor>& bias_opt,
1481 at::IntArrayRef stride_,
1482 at::SymIntArrayRef padding_,
1483 at::IntArrayRef dilation_,
1484 bool transposed_,
1485 at::SymIntArrayRef output_padding_,
1486 int64_t groups_) {
1487 return at::native::select_conv_backend(
1488 input,
1489 weight,
1490 bias_opt,
1491 stride_,
1492 padding_,
1493 dilation_,
1494 transposed_,
1495 output_padding_,
1496 groups_,
1497 c10::nullopt);
1498 },
1499 py::arg("input"),
1500 py::arg("weight"),
1501 py::arg("bias"),
1502 py::arg("stride"),
1503 py::arg("padding"),
1504 py::arg("dilation"),
1505 py::arg("transposed"),
1506 py::arg("output_padding"),
1507 py::arg("groups"));
1508
1509 // overload for bias_sizes_opt/backward TODO: figure out default value
1510 py_module.def(
1511 "_select_conv_backend",
1512 [](const at::Tensor& input,
1513 const at::Tensor& weight,
1514 const c10::optional<at::Tensor>& bias,
1515 at::IntArrayRef stride_,
1516 at::SymIntArrayRef padding_,
1517 at::IntArrayRef dilation_,
1518 bool transposed_,
1519 at::SymIntArrayRef output_padding_,
1520 int64_t groups_,
1521 c10::optional<std::vector<c10::SymInt>> bias_sizes_opt) {
1522 c10::OptionalArrayRef<c10::SymInt> ref = c10::nullopt;
1523 if (bias_sizes_opt) {
1524 ref = (*bias_sizes_opt);
1525 }
1526 return at::native::select_conv_backend(
1527 input,
1528 weight,
1529 bias,
1530 stride_,
1531 padding_,
1532 dilation_,
1533 transposed_,
1534 output_padding_,
1535 groups_,
1536 ref);
1537 },
1538 py::arg("input"),
1539 py::arg("weight"),
1540 py::arg("bias"),
1541 py::arg("stride"),
1542 py::arg("padding"),
1543 py::arg("dilation"),
1544 py::arg("transposed"),
1545 py::arg("output_padding"),
1546 py::arg("groups"),
1547 py::arg("bias_sizes"));
1548
1549 py_module.def(
1550 "_conv_determine_backend_memory_format",
1551 at::native::_determine_backend_memory_format);
1552
1553 py::enum_<at::LinalgBackend>(py_module, "_LinalgBackend")
1554 .value("Default", at::LinalgBackend::Default)
1555 .value("Cusolver", at::LinalgBackend::Cusolver)
1556 .value("Magma", at::LinalgBackend::Magma);
1557
1558 py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) {
1559 at::globalContext().setLinalgPreferredBackend(b);
1560 });
1561 py_module.def("_get_linalg_preferred_backend", []() {
1562 return at::globalContext().linalgPreferredBackend();
1563 });
1564
1565 py_module.def("_stash_obj_in_tls", [](std::string key, py::handle arg) {
1566 at::impl::ThreadLocalPythonObjects::get_state().set(
1567 key,
1568 std::make_shared<c10::SafePyObject>(arg.ptr(), getPyInterpreter()));
1569 });
1570
1571 py_module.def("_get_obj_in_tls", [](std::string key) -> py::handle {
1572 auto safe_pyobject =
1573 at::impl::ThreadLocalPythonObjects::get_state().get(key);
1574 auto obj = safe_pyobject->ptr(getPyInterpreter());
1575 return py::handle(obj);
1576 });
1577
1578 py_module.def("_is_key_in_tls", [](std::string key) -> bool {
1579 return at::impl::ThreadLocalPythonObjects::get_state().contains(key);
1580 });
1581
1582#ifdef USE_CUDA
1583 PyObject* has_cuda = Py_True;
1584#else
1585 PyObject* has_cuda = Py_False;
1586#endif
1587
1588#ifdef USE_MPS
1589 PyObject* has_mps = Py_True;
1590#else
1591 PyObject* has_mps = Py_False;
1592#endif
1593
1594 ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
1595 ASSERT_TRUE(set_module_attr("has_mps", has_mps));
1596 py_module.def("_is_mps_available", []() { return at::hasMPS(); });
1597 py_module.def("_is_mps_on_macos_13_or_newer", []() {
1598#ifdef USE_MPS
1599 return at::mps::is_macos_13_or_newer();
1600#else
1601 return false;
1602#endif
1603 });
1604
1605 ASSERT_TRUE(
1606 set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
1607
1608#ifdef _GLIBCXX_USE_CXX11_ABI
1609 ASSERT_TRUE(set_module_attr(
1610 "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
1611#else
1612 ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
1613#endif
1614
1615// See note [Pybind11 ABI constants]
1616#define SET_STR_DEFINE(name) \
1617 ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name)))
1618
1619#ifdef PYBIND11_COMPILER_TYPE
1620 SET_STR_DEFINE(PYBIND11_COMPILER_TYPE);
1621#else
1622 ASSERT_TRUE(
1623 set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None));
1624#endif
1625
1626#ifdef PYBIND11_STDLIB
1627 SET_STR_DEFINE(PYBIND11_STDLIB);
1628#else
1629 ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None));
1630#endif
1631
1632#ifdef PYBIND11_BUILD_ABI
1633 SET_STR_DEFINE(PYBIND11_BUILD_ABI);
1634#else
1635 ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None));
1636#endif
1637#undef SET_STR_DEFINE
1638
1639 py_module.def(
1640 "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); });
1641 py_module.def(
1642 "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); });
1643 py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata);
1644 py_module.def(
1645 "_set_tensor_metadata",
1646 static_cast<void (*)(
1647 const at::Tensor&, std::unordered_map<std::string, bool>)>(
1648 torch::jit::setTensorMetadata));
1649 py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
1650 return toString(x.key_set());
1651 });
1652 py_module.def(
1653 "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
1654
1655 py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
1656 auto local_keyset = c10::impl::tls_local_dispatch_key_set();
1657 c10::DispatchKeySet key_set({at::DispatchKey::Meta});
1658 if (meta_in_tls) {
1659 local_keyset.included_ = local_keyset.included_ | key_set;
1660 } else {
1661 local_keyset.included_ =
1662 local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
1663 }
1664 c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
1665 });
1666
1667 py_module.def("_meta_in_tls_dispatch_include", []() {
1668 auto local_keyset = c10::impl::tls_local_dispatch_key_set();
1669 return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
1670 });
1671
1672 py_module.def("_dump_local_tls_set", []() {
1673 auto local_keyset = c10::impl::tls_local_dispatch_key_set();
1674 std::cout << "Included: " << toString(local_keyset.included_) << "\n";
1675 std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
1676 });
1677
1678 py_module.def(
1679 "_should_allow_numbers_as_tensors", [](const std::string& name) {
1680 return torch::should_allow_numbers_as_tensors(name);
1681 });
1682
1683 const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
1684 THPDefaultCPUGenerator =
1685 (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
1686 // This reference is meant to be given away, so no need to incref here.
1687 ASSERT_TRUE(set_module_attr(
1688 "default_generator",
1689 (PyObject*)THPDefaultCPUGenerator,
1690 /* incref= */ false));
1691 ASSERT_TRUE(set_module_attr(
1692 "DisableTorchFunctionSubclass",
1693 (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
1694 /* incref= */ false));
1695 ASSERT_TRUE(set_module_attr(
1696 "DisableTorchFunction",
1697 (PyObject*)THPModule_DisableTorchFunctionType(),
1698 /* incref= */ false));
1699 torch::set_disabled_torch_function_impl(
1700 PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
1701 ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr);
1702 torch::set_disabled_torch_dispatch_impl(
1703 PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl"));
1704 ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr);
1705 return module;
1706 END_HANDLE_TH_ERRORS
1707}
1708
// Checks that the _C shared library isn't initialized multiple times. This
// can happen if the same csrc files are compiled into multiple shared
// libraries.
inline void pytorch_duplicate_guard() {
  // bool instead of int-as-flag; also drops a stray empty statement that
  // followed the assignment.
  static bool initialized = false;
  if (initialized) {
    fprintf(stderr, "pytorch: _C shared library re-initialized\n");
    abort();
  }
  initialized = true;
}
1721
// Helper whose constructor runs the duplicate-load check; instantiated as a
// file-scope static below so the check fires once at library load time.
struct call_duplicate_guard {
  call_duplicate_guard() {
    pytorch_duplicate_guard();
  }
};

// Static initializer: constructing this object performs the check.
static call_duplicate_guard _call_duplicate_guard;
1729