1 | #include <c10/util/Optional.h> |
2 | #include <sys/types.h> |
3 | #include <torch/csrc/python_headers.h> |
4 | |
5 | #ifndef _MSC_VER |
6 | #include <sys/socket.h> |
7 | #endif |
8 | |
9 | #include <ATen/ATen.h> |
10 | #include <ATen/DLConvertor.h> |
11 | #include <ATen/ExpandUtils.h> |
12 | #include <ATen/LegacyVmapMode.h> |
13 | #include <ATen/LinalgBackend.h> |
14 | #include <ATen/Parallel.h> |
15 | #include <ATen/Utils.h> |
16 | #include <ATen/core/Vitals.h> |
17 | #include <ATen/dlpack.h> |
18 | #include <ATen/native/ConvUtils.h> |
19 | #include <c10/core/DispatchKeySet.h> |
20 | #include <c10/util/Logging.h> |
21 | #include <c10/util/irange.h> |
22 | #include <libshm.h> |
23 | #include <pybind11/pybind11.h> |
24 | #include <pybind11/stl.h> |
25 | #include <torch/csrc/THConcat.h> |
26 | #include <torch/csrc/utils/pybind.h> |
#include <cstdlib>
#include <list>
#include <string>
#include <unordered_map>
29 | |
30 | #include <ATen/ThreadLocalPythonObjects.h> |
31 | #include <torch/csrc/DataLoader.h> |
32 | #include <torch/csrc/Device.h> |
33 | #include <torch/csrc/Dtype.h> |
34 | #include <torch/csrc/DynamicTypes.h> |
35 | #include <torch/csrc/Generator.h> |
36 | #include <torch/csrc/Layout.h> |
37 | #include <torch/csrc/MemoryFormat.h> |
38 | #include <torch/csrc/QScheme.h> |
39 | #include <torch/csrc/Stream.h> |
40 | #include <torch/csrc/THP.h> |
41 | #include <torch/csrc/TypeInfo.h> |
42 | #include <torch/csrc/api/include/torch/python/init.h> |
43 | #include <torch/csrc/autograd/python_cpp_function.h> |
44 | #include <torch/csrc/autograd/python_enum_tag.h> |
45 | #include <torch/csrc/autograd/python_fft_functions.h> |
46 | #include <torch/csrc/autograd/python_function.h> |
47 | #include <torch/csrc/autograd/python_legacy_variable.h> |
48 | #include <torch/csrc/autograd/python_linalg_functions.h> |
49 | #include <torch/csrc/autograd/python_nested_functions.h> |
50 | #include <torch/csrc/autograd/python_nn_functions.h> |
51 | #include <torch/csrc/autograd/python_return_types.h> |
52 | #include <torch/csrc/autograd/python_sparse_functions.h> |
53 | #include <torch/csrc/autograd/python_special_functions.h> |
54 | #include <torch/csrc/autograd/python_variable.h> |
55 | #include <torch/csrc/dynamo/init.h> |
56 | #include <torch/csrc/functorch/init.h> |
57 | #include <torch/csrc/jit/python/init.h> |
58 | #include <torch/csrc/jit/python/python_ir.h> |
59 | #include <torch/csrc/jit/python/python_tracer.h> |
60 | #include <torch/csrc/jit/serialization/pickler.h> |
61 | #include <torch/csrc/lazy/python/init.h> |
62 | #include <torch/csrc/monitor/python_init.h> |
63 | #include <torch/csrc/multiprocessing/init.h> |
64 | #include <torch/csrc/onnx/init.h> |
65 | #include <torch/csrc/profiler/python/init.h> |
66 | #include <torch/csrc/tensor/python_tensor.h> |
67 | #include <torch/csrc/utils/disable_torch_function.h> |
68 | #include <torch/csrc/utils/init.h> |
69 | #include <torch/csrc/utils/pycfunction_helpers.h> |
70 | #include <torch/csrc/utils/python_arg_parser.h> |
71 | #include <torch/csrc/utils/python_compat.h> |
72 | #include <torch/csrc/utils/python_dispatch.h> |
73 | #include <torch/csrc/utils/python_strings.h> |
74 | #include <torch/csrc/utils/tensor_dtypes.h> |
75 | #include <torch/csrc/utils/tensor_layouts.h> |
76 | #include <torch/csrc/utils/tensor_memoryformats.h> |
77 | #include <torch/csrc/utils/tensor_new.h> |
78 | #include <torch/csrc/utils/tensor_numpy.h> |
79 | #include <torch/csrc/utils/tensor_qschemes.h> |
80 | |
81 | #ifdef USE_DISTRIBUTED |
82 | #ifdef USE_C10D |
83 | #include <torch/csrc/distributed/autograd/python_autograd.h> |
84 | #include <torch/csrc/distributed/c10d/c10d.h> |
85 | #include <torch/csrc/distributed/rpc/rpc.h> |
86 | #include <torch/csrc/distributed/rpc/testing/testing.h> |
87 | #endif |
88 | #endif |
89 | |
90 | #if defined(USE_MPS) |
91 | #include <ATen/mps/MPSDevice.h> |
92 | #endif |
93 | |
94 | #if defined(USE_VALGRIND) |
95 | #include <callgrind.h> |
96 | #endif |
97 | |
98 | namespace py = pybind11; |
99 | |
100 | PyObject* module; |
101 | |
102 | THPGenerator* THPDefaultCPUGenerator = nullptr; |
103 | |
104 | //////////////////////////////////////////////////////////////////////////////// |
105 | //////////////////////////////////////////////////////////////////////////////// |
106 | |
107 | static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) { |
108 | static std::vector<std::string> names; |
109 | |
110 | THPObjectPtr types(PySequence_Fast(arg, "expected a sequence" )); |
111 | if (!types) |
112 | return nullptr; |
113 | |
114 | // NOLINTNEXTLINE(bugprone-branch-clone) |
115 | auto num_classes = PySequence_Fast_GET_SIZE(types.get()); |
116 | names.reserve(names.size() + num_classes); |
117 | for (Py_ssize_t i = 0; i < num_classes; i++) { |
118 | PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i); |
119 | THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject" ); |
120 | PyTypeObject* type = (PyTypeObject*)obj; |
121 | |
122 | THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__" )); |
123 | if (!module_name) |
124 | return nullptr; |
125 | THPUtils_assert( |
126 | THPUtils_checkString(module_name.get()), |
127 | "expected __module__ to be a string" ); |
128 | std::string name = THPUtils_unpackString(module_name.get()); |
129 | names.emplace_back(name + "." + type->tp_name); |
130 | type->tp_name = names.back().c_str(); |
131 | } |
132 | Py_RETURN_NONE; |
133 | } |
134 | // |
135 | // Callback for python part. Used for additional initialization of python |
136 | // classes |
137 | static PyObject* THPModule_initExtension( |
138 | PyObject* _unused, |
139 | PyObject* shm_manager_path) { |
140 | HANDLE_TH_ERRORS |
141 | if (!THPUtils_checkString(shm_manager_path)) { |
142 | THPUtils_setError( |
143 | "initialization error - expected bytes/string object as shm_manager_path!" ); |
144 | return nullptr; |
145 | } |
146 | torch::utils::initializeLayouts(); |
147 | torch::utils::initializeMemoryFormats(); |
148 | torch::utils::initializeQSchemes(); |
149 | torch::utils::initializeDtypes(); |
150 | torch::tensors::initialize_python_bindings(); |
151 | std::string path = THPUtils_unpackString(shm_manager_path); |
152 | libshm_init(path.c_str()); |
153 | |
154 | auto module = THPObjectPtr(PyImport_ImportModule("torch" )); |
155 | if (!module) |
156 | throw python_error(); |
157 | |
158 | THPStorage_postInit(module); |
159 | THPAutograd_initFunctions(); |
160 | Py_RETURN_NONE; |
161 | END_HANDLE_TH_ERRORS |
162 | } |
163 | |
164 | // The idea behind these two functions is to make it easy to test if we are |
165 | // built with ASAN: they're designed not to crash if ASAN is not enabled, but |
166 | // to trigger ASAN if it is enabled. This lets us run a "canary" tests which |
167 | // checks if our build environment is misconfigured. |
168 | |
169 | static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) { |
170 | THPUtils_assert( |
171 | THPUtils_checkLong(arg), |
172 | "crash_if_csrc_asan expects an int, " |
173 | "but got %s" , |
174 | THPUtils_typename(arg)); |
175 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays) |
176 | volatile char x[3]; |
177 | x[THPUtils_unpackInt(arg)] = 0; |
178 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
179 | return THPUtils_packInt32(x[0]); |
180 | } |
181 | |
182 | static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) { |
183 | THPUtils_assert( |
184 | THPUtils_checkLong(arg), |
185 | "crash_if_csrc_ubsan expects an int, " |
186 | "but got %s" , |
187 | THPUtils_typename(arg)); |
188 | int32_t x = THPUtils_unpackInt(arg); |
189 | double y = 1.0 / x; |
190 | return THPUtils_packInt32((int)y); |
191 | } |
192 | |
193 | static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) { |
194 | // This code shoud work perfectly fine, as vtables are idential for Foo and |
195 | // Baz unless rtti and ubsan are enabled |
196 | struct Foo { |
197 | virtual int bar() = 0; |
198 | virtual ~Foo() = default; |
199 | }; |
200 | struct Baz { |
201 | virtual int bar() { |
202 | return 17; |
203 | } |
204 | virtual ~Baz() = default; |
205 | }; |
206 | Baz x{}; |
207 | auto y = static_cast<Foo*>(static_cast<void*>(&x)); |
208 | auto rc = y->bar(); |
209 | return THPUtils_packInt32(rc); |
210 | } |
211 | |
212 | static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) { |
213 | THPUtils_assert( |
214 | THPUtils_checkLong(arg), |
215 | "crash_if_aten_asan expects an int, " |
216 | "but got %s" , |
217 | THPUtils_typename(arg)); |
218 | return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg))); |
219 | } |
220 | |
221 | static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) { |
222 | return THPUtils_packInt32(at::get_num_threads()); |
223 | } |
224 | |
225 | static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) { |
226 | THPUtils_assert( |
227 | THPUtils_checkLong(arg), |
228 | "set_num_threads expects an int, " |
229 | "but got %s" , |
230 | THPUtils_typename(arg)); |
231 | int nthreads = (int)THPUtils_unpackLong(arg); |
232 | THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer" ); |
233 | at::set_num_threads(nthreads); |
234 | Py_RETURN_NONE; |
235 | } |
236 | |
237 | static PyObject* THPModule_getNumInteropThreads( |
238 | PyObject* module, |
239 | PyObject* noargs) { |
240 | return THPUtils_packInt32(at::get_num_interop_threads()); |
241 | } |
242 | |
243 | static PyObject* THPModule_setNumInteropThreads( |
244 | PyObject* module, |
245 | PyObject* arg) { |
246 | THPUtils_assert( |
247 | THPUtils_checkLong(arg), |
248 | "set_num_interop_threads expects an int, " |
249 | "but got %s" , |
250 | THPUtils_typename(arg)); |
251 | int nthreads = (int)THPUtils_unpackLong(arg); |
252 | THPUtils_assert( |
253 | nthreads > 0, "set_num_interop_threads expects a positive integer" ); |
254 | at::set_num_interop_threads(nthreads); |
255 | Py_RETURN_NONE; |
256 | } |
257 | |
258 | PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) { |
259 | HANDLE_TH_ERRORS |
260 | torch::tensors::py_set_default_tensor_type(type); |
261 | Py_RETURN_NONE; |
262 | END_HANDLE_TH_ERRORS |
263 | } |
264 | |
265 | PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) { |
266 | HANDLE_TH_ERRORS |
267 | torch::tensors::py_set_default_dtype(dtype); |
268 | Py_RETURN_NONE; |
269 | END_HANDLE_TH_ERRORS |
270 | } |
271 | |
272 | PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { |
273 | // adds a __doc__ string to a function, similar to numpy's arr_add_docstring |
274 | static std::vector<std::string> all_docs; |
275 | PyObject* obj = nullptr; |
276 | PyObject* doc_obj = nullptr; |
277 | if (!PyArg_ParseTuple(args, "OO" , &obj, &doc_obj)) { |
278 | return nullptr; |
279 | } |
280 | |
281 | const char* doc_str = "<invalid string>" ; |
282 | if (THPUtils_checkString(doc_obj)) { |
283 | all_docs.push_back(THPUtils_unpackString(doc_obj)); |
284 | doc_str = all_docs.back().c_str(); |
285 | } |
286 | |
287 | if (Py_TYPE(obj) == &PyCFunction_Type) { |
288 | PyCFunctionObject* f = (PyCFunctionObject*)obj; |
289 | if (f->m_ml->ml_doc) { |
290 | return PyErr_Format( |
291 | PyExc_RuntimeError, |
292 | "function '%s' already has a docstring" , |
293 | f->m_ml->ml_name); |
294 | } |
295 | f->m_ml->ml_doc = doc_str; |
296 | } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor" ) == 0) { |
297 | PyMethodDescrObject* m = (PyMethodDescrObject*)obj; |
298 | if (m->d_method->ml_doc) { |
299 | return PyErr_Format( |
300 | PyExc_RuntimeError, |
301 | "method '%s' already has a docstring" , |
302 | m->d_method->ml_name); |
303 | } |
304 | m->d_method->ml_doc = doc_str; |
305 | } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor" ) == 0) { |
306 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) |
307 | PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj; |
308 | if (m->d_getset->doc) { |
309 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) |
310 | return PyErr_Format( |
311 | PyExc_RuntimeError, |
312 | "attribute '%s' already has a docstring" , |
313 | m->d_getset->name); |
314 | } |
315 | // This field is not const for python < 3.7 yet the content is |
316 | // never modified. |
317 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
318 | m->d_getset->doc = const_cast<char*>(doc_str); |
319 | } else if (Py_TYPE(obj) == &PyType_Type) { |
320 | PyTypeObject* t = (PyTypeObject*)obj; |
321 | if (t->tp_doc) { |
322 | return PyErr_Format( |
323 | PyExc_RuntimeError, "Type '%s' already has a docstring" , t->tp_name); |
324 | } |
325 | t->tp_doc = doc_str; |
326 | } else { |
327 | return PyErr_Format( |
328 | PyExc_TypeError, |
329 | "don't know how to add docstring to type '%s'" , |
330 | Py_TYPE(obj)->tp_name); |
331 | } |
332 | |
333 | Py_INCREF(obj); |
334 | return obj; |
335 | } |
336 | |
337 | PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) { |
338 | HANDLE_TH_ERRORS |
339 | Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0; |
340 | THPUtils_assert(num_args == 2, "expected exactly 2 arguments" ); |
341 | PyObject* arg1 = PyTuple_GET_ITEM(args, 0); |
342 | THPUtils_assert(THPSize_Check(arg1), "expected a torch.Size as argument 1" ); |
343 | PyObject* arg2 = PyTuple_GET_ITEM(args, 1); |
344 | THPUtils_assert(THPSize_Check(arg2), "expected a torch.Size as argument 2" ); |
345 | |
346 | auto size1 = THPUtils_unpackLongs(arg1); |
347 | auto size2 = THPUtils_unpackLongs(arg2); |
348 | auto sizes = at::infer_size(size1, size2); |
349 | return THPSize_NewFromSizes(sizes.size(), sizes.data()); |
350 | END_HANDLE_TH_ERRORS |
351 | } |
352 | |
353 | static PyObject* THPModule_setBackcompatBroadcastWarn( |
354 | PyObject* module, |
355 | PyObject* arg) { |
356 | THPUtils_assert( |
357 | PyBool_Check(arg), |
358 | "set_backcompat_broadcast_warn expects a bool, " |
359 | "but got %s" , |
360 | THPUtils_typename(arg)); |
361 | setBackCompatBroadcastWarn(arg == Py_True); |
362 | Py_RETURN_NONE; |
363 | } |
364 | |
365 | static PyObject* THPModule_getBackcompatBroadcastWarn( |
366 | PyObject* module, |
367 | PyObject* noargs) { |
368 | if (getBackCompatBroadcastWarn()) |
369 | Py_RETURN_TRUE; |
370 | else |
371 | Py_RETURN_FALSE; |
372 | } |
373 | |
374 | static PyObject* THPModule_setBackcompatKeepdimWarn( |
375 | PyObject* module, |
376 | PyObject* arg) { |
377 | THPUtils_assert( |
378 | PyBool_Check(arg), |
379 | "set_backcompat_keepdim_warn expects a bool, " |
380 | "but got %s" , |
381 | THPUtils_typename(arg)); |
382 | setBackCompatKeepdimWarn(arg == Py_True); |
383 | Py_RETURN_NONE; |
384 | } |
385 | |
386 | static PyObject* THPModule_getBackcompatKeepdimWarn( |
387 | PyObject* module, |
388 | PyObject* noargs) { |
389 | if (getBackCompatKeepdimWarn()) |
390 | Py_RETURN_TRUE; |
391 | else |
392 | Py_RETURN_FALSE; |
393 | } |
394 | |
395 | PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) { |
396 | #ifdef USE_DISTRIBUTED |
397 | Py_RETURN_TRUE; |
398 | #else |
399 | Py_RETURN_FALSE; |
400 | #endif |
401 | } |
402 | |
403 | static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) { |
404 | HANDLE_TH_ERRORS |
405 | return THPUtils_packString(at::show_config()); |
406 | END_HANDLE_TH_ERRORS |
407 | } |
408 | |
409 | static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) { |
410 | HANDLE_TH_ERRORS |
411 | return THPUtils_packString(at::get_cxx_flags()); |
412 | END_HANDLE_TH_ERRORS |
413 | } |
414 | |
415 | static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) { |
416 | HANDLE_TH_ERRORS |
417 | return THPUtils_packString(at::get_parallel_info()); |
418 | END_HANDLE_TH_ERRORS |
419 | } |
420 | |
421 | void DLPack_Capsule_Destructor(PyObject* data) { |
422 | if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor" ))) { |
423 | // early out, see DLPack spec: if a consuming library sets the capsule |
424 | // name to something else, they own it and we don't need to do anything |
425 | return; |
426 | } |
427 | HANDLE_TH_ERRORS |
428 | // Causes overheads for validity checks again, but this case is rare |
429 | // since consuming libraries should rename the capsule according to spec. |
430 | // Note that this cannot set a python error (we checked validity above), |
431 | // so we don't need to handle python error state here. |
432 | DLManagedTensor* dlMTensor = |
433 | (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor" ); |
434 | // the dlMTensor has not been consumed, call deleter ourselves. |
435 | // DLPack spec mentions that deleter may be NULL, but deleter from |
436 | // `at::toDLPack` is never NULL, so no need for an additional check here. |
437 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
438 | dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor)); |
439 | END_HANDLE_TH_ERRORS_RET() |
440 | } |
441 | |
442 | PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { |
443 | HANDLE_TH_ERRORS |
444 | THPUtils_assert(THPVariable_Check(data), "data must be a Tensor" ); |
445 | DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data)); |
446 | return PyCapsule_New(dlMTensor, "dltensor" , DLPack_Capsule_Destructor); |
447 | END_HANDLE_TH_ERRORS |
448 | } |
449 | |
450 | PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { |
451 | using namespace torch::autograd; |
452 | HANDLE_TH_ERRORS |
453 | auto tensor = torch::utils::tensor_fromDLPack(data); |
454 | return THPVariable_Wrap(tensor); |
455 | END_HANDLE_TH_ERRORS |
456 | } |
457 | |
458 | PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { |
459 | HANDLE_TH_ERRORS |
460 | size_t frames_to_skip; |
461 | size_t maximum_number_of_frames; |
462 | if (!PyArg_ParseTuple( |
463 | args, "LL" , &frames_to_skip, &maximum_number_of_frames)) { |
464 | return nullptr; |
465 | } |
466 | return THPUtils_packString( |
467 | c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true)); |
468 | END_HANDLE_TH_ERRORS |
469 | } |
470 | static PyObject* THModule_rename_privateuse1_backend( |
471 | PyObject* _unused, |
472 | PyObject* arg) { |
473 | HANDLE_TH_ERRORS |
474 | THPUtils_assert( |
475 | THPUtils_checkString(arg), |
476 | "_rename_privateuse1_backend expects a str, " |
477 | "but got %s" , |
478 | THPUtils_typename(arg)); |
479 | const std::string backend_name = THPUtils_unpackString(arg); |
480 | c10::register_privateuse1_backend(backend_name); |
481 | Py_RETURN_NONE; |
482 | END_HANDLE_TH_ERRORS |
483 | } |
484 | |
485 | PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { |
486 | THPUtils_assert( |
487 | PyBool_Check(arg), |
488 | "set_allow_tf32_cublas expects a bool, " |
489 | "but got %s" , |
490 | THPUtils_typename(arg)); |
491 | at::globalContext().setAllowTF32CuDNN(arg == Py_True); |
492 | Py_RETURN_NONE; |
493 | } |
494 | |
495 | PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) { |
496 | if (at::globalContext().allowTF32CuDNN()) |
497 | Py_RETURN_TRUE; |
498 | else |
499 | Py_RETURN_FALSE; |
500 | } |
501 | |
502 | PyObject* THPModule_setFloat32MatmulPrecision( |
503 | PyObject* _unused, |
504 | PyObject* arg) { |
505 | THPUtils_assert( |
506 | THPUtils_checkString(arg), |
507 | "set_float32_matmul_precision expects a str, " |
508 | "but got %s" , |
509 | THPUtils_typename(arg)); |
510 | std::string s = THPUtils_unpackString(arg); |
511 | at::globalContext().setFloat32MatmulPrecision(s); |
512 | Py_RETURN_NONE; |
513 | } |
514 | |
515 | PyObject* THPModule_float32MatmulPrecision( |
516 | PyObject* _unused, |
517 | PyObject* noargs) { |
518 | std::string s = "highest" ; |
519 | auto p = at::globalContext().float32MatmulPrecision(); |
520 | if (p == at::Float32MatmulPrecision::HIGH) { |
521 | s = "high" ; |
522 | } else if (p == at::Float32MatmulPrecision::MEDIUM) { |
523 | s = "medium" ; |
524 | } |
525 | return THPUtils_packString(s); |
526 | } |
527 | PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) { |
528 | THPUtils_assert( |
529 | PyBool_Check(arg), |
530 | "set_sdp_use_math expects a bool, " |
531 | "but got %s" , |
532 | THPUtils_typename(arg)); |
533 | at::globalContext().setSDPUseFlash(arg == Py_True); |
534 | Py_RETURN_NONE; |
535 | } |
536 | PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) { |
537 | if (at::globalContext().userEnabledFlashSDP()) |
538 | Py_RETURN_TRUE; |
539 | else |
540 | Py_RETURN_FALSE; |
541 | } |
542 | PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) { |
543 | THPUtils_assert( |
544 | PyBool_Check(arg), |
545 | "set_sdp_use_math expects a bool, " |
546 | "but got %s" , |
547 | THPUtils_typename(arg)); |
548 | at::globalContext().setSDPUseMemEfficient(arg == Py_True); |
549 | Py_RETURN_NONE; |
550 | } |
551 | PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) { |
552 | if (at::globalContext().userEnabledMemEfficientSDP()) |
553 | Py_RETURN_TRUE; |
554 | else |
555 | Py_RETURN_FALSE; |
556 | } |
557 | PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) { |
558 | THPUtils_assert( |
559 | PyBool_Check(arg), |
560 | "set_sdp_use_math expects a bool, " |
561 | "but got %s" , |
562 | THPUtils_typename(arg)); |
563 | at::globalContext().setSDPUseMath(arg == Py_True); |
564 | Py_RETURN_NONE; |
565 | } |
566 | PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) { |
567 | if (at::globalContext().userEnabledMathSDP()) |
568 | Py_RETURN_TRUE; |
569 | else |
570 | Py_RETURN_FALSE; |
571 | } |
572 | PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) { |
573 | THPUtils_assert( |
574 | PyBool_Check(arg), |
575 | "set_enabled_cudnn expects a bool, " |
576 | "but got %s" , |
577 | THPUtils_typename(arg)); |
578 | at::globalContext().setUserEnabledCuDNN(arg == Py_True); |
579 | Py_RETURN_NONE; |
580 | } |
581 | |
582 | PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) { |
583 | if (at::globalContext().userEnabledCuDNN()) |
584 | Py_RETURN_TRUE; |
585 | else |
586 | Py_RETURN_FALSE; |
587 | } |
588 | |
589 | PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) { |
590 | THPUtils_assert( |
591 | PyBool_Check(arg), |
592 | "set_enabled_mkldnn expects a bool, " |
593 | "but got %s" , |
594 | THPUtils_typename(arg)); |
595 | at::globalContext().setUserEnabledMkldnn(arg == Py_True); |
596 | Py_RETURN_NONE; |
597 | } |
598 | |
599 | PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) { |
600 | if (at::globalContext().userEnabledMkldnn()) |
601 | Py_RETURN_TRUE; |
602 | else |
603 | Py_RETURN_FALSE; |
604 | } |
605 | |
606 | PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) { |
607 | HANDLE_TH_ERRORS |
608 | THPUtils_assert( |
609 | PyBool_Check(arg), |
610 | "set_deterministic_cudnn expects a bool, " |
611 | "but got %s" , |
612 | THPUtils_typename(arg)); |
613 | at::globalContext().setDeterministicCuDNN(arg == Py_True); |
614 | Py_RETURN_NONE; |
615 | END_HANDLE_TH_ERRORS |
616 | } |
617 | |
618 | PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) { |
619 | if (at::globalContext().deterministicCuDNN()) |
620 | Py_RETURN_TRUE; |
621 | else |
622 | Py_RETURN_FALSE; |
623 | } |
624 | |
625 | PyObject* THPModule_setDeterministicAlgorithms( |
626 | PyObject* _unused, |
627 | PyObject* args, |
628 | PyObject* kwargs) { |
629 | HANDLE_TH_ERRORS |
630 | static torch::PythonArgParser parser( |
631 | {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)" }); |
632 | torch::ParsedArgs<2> parsed_args{}; |
633 | auto r = parser.parse(args, kwargs, parsed_args); |
634 | bool mode = r.toBool(0); |
635 | bool warn_only = r.toBool(1); |
636 | at::globalContext().setDeterministicAlgorithms(mode, warn_only); |
637 | Py_RETURN_NONE; |
638 | END_HANDLE_TH_ERRORS |
639 | } |
640 | |
641 | PyObject* THPModule_deterministicAlgorithms( |
642 | PyObject* _unused, |
643 | PyObject* noargs) { |
644 | if (at::globalContext().deterministicAlgorithms()) { |
645 | Py_RETURN_TRUE; |
646 | } |
647 | Py_RETURN_FALSE; |
648 | } |
649 | |
650 | PyObject* THPModule_deterministicAlgorithmsWarnOnly( |
651 | PyObject* _unused, |
652 | PyObject* noargs) { |
653 | if (at::globalContext().deterministicAlgorithmsWarnOnly()) { |
654 | Py_RETURN_TRUE; |
655 | } |
656 | Py_RETURN_FALSE; |
657 | } |
658 | |
659 | PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) { |
660 | THPUtils_assert( |
661 | PyBool_Check(arg), |
662 | "setWarnOnlyOnce expects a bool, " |
663 | "but got %s" , |
664 | THPUtils_typename(arg)); |
665 | c10::WarningUtils::set_warnAlways(arg == Py_True); |
666 | Py_RETURN_NONE; |
667 | } |
668 | |
669 | PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) { |
670 | if (c10::WarningUtils::get_warnAlways()) { |
671 | Py_RETURN_TRUE; |
672 | } |
673 | Py_RETURN_FALSE; |
674 | } |
675 | |
676 | // Used only for testing C++ to Python warning translations. |
677 | PyObject* THPModule_warn(PyObject* _unused, PyObject* noargs) { |
678 | HANDLE_TH_ERRORS |
679 | TORCH_WARN("Test message for TORCH_WARN" ); |
680 | Py_RETURN_NONE; |
681 | END_HANDLE_TH_ERRORS |
682 | } |
683 | |
684 | // Used only for testing C++ to Python warning translations. |
685 | PyObject* THPModule_warnDeprecation(PyObject* _unused, PyObject* noargs) { |
686 | HANDLE_TH_ERRORS |
687 | TORCH_WARN_DEPRECATION("Test message for TORCH_WARN_DEPRECATION" ); |
688 | Py_RETURN_NONE; |
689 | END_HANDLE_TH_ERRORS |
690 | } |
691 | |
692 | PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) { |
693 | THPUtils_assert( |
694 | PyBool_Check(arg), |
695 | "set_benchmark_cudnn expects a bool, " |
696 | "but got %s" , |
697 | THPUtils_typename(arg)); |
698 | at::globalContext().setBenchmarkCuDNN(arg == Py_True); |
699 | Py_RETURN_NONE; |
700 | } |
701 | |
702 | PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) { |
703 | if (at::globalContext().benchmarkCuDNN()) { |
704 | Py_RETURN_TRUE; |
705 | } |
706 | Py_RETURN_FALSE; |
707 | } |
708 | |
709 | PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) { |
710 | THPUtils_assert( |
711 | PyBool_Check(arg), |
712 | "set_allow_tf32_cublas expects a bool, " |
713 | "but got %s" , |
714 | THPUtils_typename(arg)); |
715 | at::globalContext().setAllowTF32CuBLAS(arg == Py_True); |
716 | Py_RETURN_NONE; |
717 | } |
718 | |
719 | PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) { |
720 | if (at::globalContext().allowTF32CuBLAS()) { |
721 | Py_RETURN_TRUE; |
722 | } |
723 | Py_RETURN_FALSE; |
724 | } |
725 | |
726 | PyObject* THPModule_setAllowFP16ReductionCuBLAS( |
727 | PyObject* _unused, |
728 | PyObject* arg) { |
729 | THPUtils_assert( |
730 | PyBool_Check(arg), |
731 | "set_allow_fp16_reduction_cublas expects a bool, " |
732 | "but got %s" , |
733 | THPUtils_typename(arg)); |
734 | at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True); |
735 | Py_RETURN_NONE; |
736 | } |
737 | |
738 | PyObject* THPModule_allowFP16ReductionCuBLAS( |
739 | PyObject* _unused, |
740 | PyObject* noargs) { |
741 | if (at::globalContext().allowFP16ReductionCuBLAS()) { |
742 | Py_RETURN_TRUE; |
743 | } |
744 | Py_RETURN_FALSE; |
745 | } |
746 | |
747 | PyObject* THPModule_setAllowBF16ReductionCuBLAS( |
748 | PyObject* _unused, |
749 | PyObject* arg) { |
750 | THPUtils_assert( |
751 | PyBool_Check(arg), |
752 | "set_allow_bf16_reduction_cublas expects a bool, " |
753 | "but got %s" , |
754 | THPUtils_typename(arg)); |
755 | at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True); |
756 | Py_RETURN_NONE; |
757 | } |
758 | |
759 | PyObject* THPModule_allowBF16ReductionCuBLAS( |
760 | PyObject* _unused, |
761 | PyObject* noargs) { |
762 | if (at::globalContext().allowBF16ReductionCuBLAS()) { |
763 | Py_RETURN_TRUE; |
764 | } |
765 | Py_RETURN_FALSE; |
766 | } |
767 | |
768 | PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) { |
769 | THPUtils_assert( |
770 | PyBool_Check(arg), |
771 | "flush_denormal expects a bool, " |
772 | "but got %s" , |
773 | THPUtils_typename(arg)); |
774 | if (!at::globalContext().setFlushDenormal(arg == Py_True)) { |
775 | Py_RETURN_FALSE; |
776 | }; |
777 | Py_RETURN_TRUE; |
778 | } |
779 | |
780 | PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) { |
781 | HANDLE_TH_ERRORS |
782 | auto scalar_type = torch::tensors::get_default_scalar_type(); |
783 | auto dtype = (PyObject*)torch::getTHPDtype(scalar_type); |
784 | Py_INCREF(dtype); |
785 | return dtype; |
786 | END_HANDLE_TH_ERRORS |
787 | } |
788 | |
789 | PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) { |
790 | HANDLE_TH_ERRORS |
791 | return THPUtils_packString(c10::DeviceTypeName( |
792 | dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()), |
793 | /*lower_case=*/true)); |
794 | END_HANDLE_TH_ERRORS |
795 | } |
796 | |
797 | PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) { |
798 | THPUtils_assert( |
799 | THPUtils_checkLong(arg), |
800 | "set_qengine expects an int, " |
801 | "but got %s" , |
802 | THPUtils_typename(arg)); |
803 | HANDLE_TH_ERRORS |
804 | auto qengine = static_cast<int>(THPUtils_unpackLong(arg)); |
805 | at::globalContext().setQEngine(static_cast<at::QEngine>(qengine)); |
806 | Py_RETURN_NONE; |
807 | END_HANDLE_TH_ERRORS |
808 | } |
809 | |
810 | PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) { |
811 | return THPUtils_packInt64(static_cast<int>(at::globalContext().qEngine())); |
812 | } |
813 | |
814 | PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) { |
815 | auto qengines = at::globalContext().supportedQEngines(); |
816 | auto list = THPObjectPtr(PyList_New(qengines.size())); |
817 | if (!list) |
818 | return nullptr; |
819 | for (const auto i : c10::irange(qengines.size())) { |
820 | PyObject* i64 = THPUtils_packInt64(static_cast<int>(qengines[i])); |
821 | if (!i64) |
822 | return nullptr; |
823 | PyList_SET_ITEM(list.get(), i, i64); |
824 | } |
825 | return list.release(); |
826 | } |
827 | |
828 | PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) { |
829 | if (at::globalContext().isXNNPACKAvailable()) |
830 | Py_RETURN_TRUE; |
831 | else |
832 | Py_RETURN_FALSE; |
833 | } |
834 | |
835 | PyObject* THPModule_setCheckSparseTensorInvariants( |
836 | PyObject* _unused, |
837 | PyObject* arg) { |
838 | THPUtils_assert( |
839 | PyBool_Check(arg), |
840 | "set_check_sparse_tensor_invariants expects a bool, " |
841 | "but got %s" , |
842 | THPUtils_typename(arg)); |
843 | at::globalContext().setCheckSparseTensorInvariants(arg == Py_True); |
844 | Py_RETURN_NONE; |
845 | } |
846 | |
847 | PyObject* THPModule_checkSparseTensorInvariants( |
848 | PyObject* _unused, |
849 | PyObject* noargs) { |
850 | if (at::globalContext().checkSparseTensorInvariants()) |
851 | Py_RETURN_TRUE; |
852 | else |
853 | Py_RETURN_FALSE; |
854 | } |
855 | |
// Exposed as torch._C._will_engine_execute_node (METH_O). Given a grad_fn
// (either a Python-defined THPFunction or a C++-backed THPCppFunction),
// returns True iff the currently running backward pass will execute that
// node. Errors if called outside of a backward pass.
PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  bool isTHPFunction = THPFunction_Check(arg);
  bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
  THPUtils_assert(
      isTHPFunction || isTHPCppFunction,
      "_will_engine_execute_node expects an grad_fn, "
      "but got %s",
      THPUtils_typename(arg));
  // exec_info is only available while a graph task is running, so a null
  // result means we are not inside a backward pass.
  const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
  THPUtils_assert(
      exec_info,
      "_get_should_execute_nodes should only be called during the backward pass");
  // Extract the raw Node* from either wrapper. For THPFunction, cdata is a
  // weak_ptr, so node_sp keeps the Node alive while we use the raw pointer.
  torch::autograd::Node* node;
  std::shared_ptr<torch::autograd::Node> node_sp;
  if (isTHPFunction) {
    node_sp = ((THPFunction*)arg)->cdata.lock();
    node = node_sp.get();
  } else {
    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
  }
  // First requirement: the node must be part of the graph currently being
  // executed.
  const auto nodes_in_graph =
      torch::autograd::get_current_graph_task_nodes_in_graph();
  bool ret = nodes_in_graph->find(node) != nodes_in_graph->end();
  if (ret && !exec_info->empty()) {
    // A non-empty exec_info means only a subset of nodes should run
    // (autograd.grad()-style call); consult this node's entry.
    auto it = exec_info->find(node);
    if (it == exec_info->end() || !it->second.should_execute()) {
      ret = false;
    } else {
      // topological_nr() == 0 identifies a leaf; captures_ set means we are
      // in an autograd.grad() call, which this API does not support for
      // leaves.
      TORCH_CHECK(
          !(node->topological_nr() == 0 && it->second.captures_),
          "A leaf node was passed to _will_engine_execute_node but we are "
          "currently running autograd.grad(). This is currently not supported.");
    }
  }
  if (ret) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
898 | |
// Exposed as torch._C._current_graph_task_execution_order (METH_NOARGS):
// returns the list of autograd Nodes in the order the current graph task
// will execute them. Errors if no backward pass is running.
PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::vector<torch::autograd::Node*> nodes =
      torch::autograd::get_current_graph_task_execution_order();
  // An empty result is how the autograd layer signals "no active task".
  TORCH_CHECK(
      nodes.size(),
      "_current_graph_task_execution_order should only be called during the backward pass");
  auto list = THPObjectPtr(PyList_New(nodes.size()));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(nodes.size())) {
    // This node is guaranteed to be alive since the backward is still running
    PyObject* pyobj_node =
        torch::autograd::functionToPyObject(nodes[i]->getptr());
    // PyList_SET_ITEM steals the reference produced by functionToPyObject.
    PyList_SET_ITEM(list.get(), i, pyobj_node);
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}
920 | |
921 | PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) { |
922 | HANDLE_TH_ERRORS |
923 | return THPUtils_packInt64(torch::autograd::get_current_graph_task_id()); |
924 | END_HANDLE_TH_ERRORS |
925 | } |
926 | |
927 | PyObject* THPModule_getCurrentNode(PyObject* _unused, PyObject* noargs) { |
928 | HANDLE_TH_ERRORS |
929 | return torch::autograd::functionToPyObject( |
930 | torch::autograd::get_current_node()); |
931 | END_HANDLE_TH_ERRORS |
932 | } |
933 | |
// Exposed as torch._C._set_default_mobile_cpu_allocator (METH_NOARGS):
// switches the global context to the mobile CPU allocator. Pair with
// _unset_default_mobile_cpu_allocator to restore the previous allocator.
PyObject* THPModule_setDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().setDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
942 | |
// Exposed as torch._C._unset_default_mobile_cpu_allocator (METH_NOARGS):
// reverts the allocator change made by _set_default_mobile_cpu_allocator.
PyObject* THPModule_unsetDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().unsetDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
951 | |
// Exposed as torch._C._vmapmode_increment_nesting (METH_NOARGS): enters one
// more level of legacy vmap mode and returns the new nesting depth.
static PyObject* THPModule_vmapmode_increment_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
  END_HANDLE_TH_ERRORS
}
959 | |
// Exposed as torch._C._vmapmode_decrement_nesting (METH_NOARGS): leaves one
// level of legacy vmap mode and returns the new nesting depth.
static PyObject* THPModule_vmapmode_decrement_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting());
  END_HANDLE_TH_ERRORS
}
967 | |
// Exposed as torch._C._debug_only_display_vmap_fallback_warnings (METH_O):
// toggles the global flag that controls whether vmap fallback warnings are
// shown. Accepts exactly a Python bool.
static PyObject* THPModule_set_display_vmap_fallback_warnings_mode(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  // THPUtils_assert reports a Python error for non-bool arguments.
  THPUtils_assert(
      PyBool_Check(arg),
      "enabled must be a bool, "
      "but got %s",
      THPUtils_typename(arg));
  at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
981 | |
982 | static PyObject* THPModule_are_vmap_fallback_warnings_enabled( |
983 | PyObject* _unused, |
984 | PyObject* arg) { |
985 | HANDLE_TH_ERRORS |
986 | if (at::globalContext().areVmapFallbackWarningsEnabled()) { |
987 | Py_RETURN_TRUE; |
988 | } else { |
989 | Py_RETURN_FALSE; |
990 | } |
991 | END_HANDLE_TH_ERRORS |
992 | } |
993 | |
// Method table for the torch._C extension module. Entries are
// {python_name, C function, calling convention, docstring}; docstrings are
// nullptr here and attached later from Python via _add_docstr. Underscore
// names are private implementation details of the torch package.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays)
static PyMethodDef TorchMethods[] = {
    // --- module bootstrap / diagnostics ---
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
    {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
    {"_init_names", THPModule_initNames, METH_O, nullptr},
    {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
    {"_set_default_tensor_type",
     THPModule_setDefaultTensorType,
     METH_O,
     nullptr},
    {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr},
    {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr},
    // Sanitizer self-tests: deliberately crash to verify ASAN/UBSAN wiring.
    {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr},
    {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr},
    {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr},
    {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr},
    {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr},
    {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
    {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
    // --- backwards-compatibility warning toggles ---
    {"_set_backcompat_broadcast_warn",
     THPModule_setBackcompatBroadcastWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_broadcast_warn",
     THPModule_getBackcompatBroadcastWarn,
     METH_NOARGS,
     nullptr},
    {"_set_backcompat_keepdim_warn",
     THPModule_setBackcompatKeepdimWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_keepdim_warn",
     THPModule_getBackcompatKeepdimWarn,
     METH_NOARGS,
     nullptr},
    // --- threading ---
    {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr},
    {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr},
    {"get_num_interop_threads",
     THPModule_getNumInteropThreads,
     METH_NOARGS,
     nullptr},
    {"set_num_interop_threads",
     THPModule_setNumInteropThreads,
     METH_O,
     nullptr},
    // --- scaled-dot-product-attention backend selection ---
    {"_get_flash_sdp_enabled",
     THPModule_userEnabledFlashSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
    // NOTE(review): this handler lacks the THPModule_ prefix used by its
    // siblings — presumably just a naming inconsistency in its definition.
    {"_get_mem_efficient_sdp_enabled",
     userEnabledMemEfficientSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_mem_efficient",
     THPModule_setSDPUseMemEfficient,
     METH_O,
     nullptr},
    {"_get_math_sdp_enabled",
     THPModule_userEnabledMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    // --- cuDNN / MKLDNN / TF32 / determinism knobs ---
    {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr},
    {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
    {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
    {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
    {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
    {"_get_cudnn_deterministic",
     THPModule_deterministicCuDNN,
     METH_NOARGS,
     nullptr},
    {"_set_cudnn_deterministic",
     THPModule_setDeterministicCuDNN,
     METH_O,
     nullptr},
    {"_get_deterministic_algorithms",
     THPModule_deterministicAlgorithms,
     METH_NOARGS,
     nullptr},
    {"_get_deterministic_algorithms_warn_only",
     THPModule_deterministicAlgorithmsWarnOnly,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_algorithms",
     castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // --- warnings ---
    {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
    {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
    {"_warn", THPModule_warn, METH_NOARGS, nullptr},
    {"_warn_deprecation", THPModule_warnDeprecation, METH_NOARGS, nullptr},
    // --- cuBLAS precision knobs ---
    {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
    {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
    {"_get_float32_matmul_precision",
     THPModule_float32MatmulPrecision,
     METH_NOARGS,
     nullptr},
    {"_set_float32_matmul_precision",
     THPModule_setFloat32MatmulPrecision,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_allowBF16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    // --- legacy vmap mode ---
    {"_vmapmode_increment_nesting",
     THPModule_vmapmode_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"_vmapmode_decrement_nesting",
     THPModule_vmapmode_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"_debug_only_display_vmap_fallback_warnings",
     THPModule_set_display_vmap_fallback_warnings_mode,
     METH_O,
     nullptr},
    {"_debug_only_are_vmap_fallback_warnings_enabled",
     THPModule_are_vmap_fallback_warnings_enabled,
     METH_NOARGS,
     nullptr},
    // --- interop, backends, quantization ---
    {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
    {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
    {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
    {"_rename_privateuse1_backend",
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
    {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr},
    {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
    {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
    {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
    // --- sparse invariants ---
    {"_set_check_sparse_tensor_invariants",
     THPModule_setCheckSparseTensorInvariants,
     METH_O,
     nullptr},
    {"_check_sparse_tensor_invariants",
     THPModule_checkSparseTensorInvariants,
     METH_NOARGS,
     nullptr},
    // --- autograd graph-task introspection ---
    {"_will_engine_execute_node",
     THPModule_willEngineExecuteNode,
     METH_O,
     nullptr},
    {"_current_graph_task_execution_order",
     THPModule_getCurrentGraphTaskExecutionOrder,
     METH_NOARGS,
     nullptr},
    {"_current_graph_task_id",
     THPModule_getCurrentGraphTaskId,
     METH_NOARGS,
     nullptr},
    {"_current_autograd_node", THPModule_getCurrentNode, METH_NOARGS, nullptr},
    // --- mobile allocator ---
    {"_set_default_mobile_cpu_allocator",
     THPModule_setDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    {"_unset_default_mobile_cpu_allocator",
     THPModule_unsetDefaultMobileCPUAllocator,
     METH_NOARGS,
     nullptr},
    // --- __torch_function__ / __torch_dispatch__ machinery ---
    {"_is_torch_function_enabled",
     THPModule_isEnabledTorchFunction,
     METH_NOARGS,
     nullptr},
    {"_disabled_torch_function_impl",
     THPModule_disable_torch_function,
     METH_VARARGS,
     nullptr},
    {"_disabled_torch_dispatch_impl",
     THPModule_disable_torch_dispatch,
     METH_VARARGS,
     nullptr},
    {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
    {"_has_torch_function_unary",
     THPModule_has_torch_function_unary,
     METH_O,
     nullptr},
    {"_has_torch_function_variadic",
     (PyCFunction)(void (*)(void))THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    // Sentinel entry terminating the table.
    {nullptr, nullptr, 0, nullptr}};
1198 | |
// Forward declarations for the CUDA stream/event/graph Python-type
// initializers. Declared unconditionally, but only called below when the
// build has USE_CUDA.
void THCPStream_init(PyObject* module);
void THCPEvent_init(PyObject* module);
void THCPGraph_init(PyObject* module);

#ifdef USE_CUDA
// Method table and module initializer provided by the CUDA bindings.
PyMethodDef* THCPModule_methods();
namespace torch {
namespace cuda {

void initModule(PyObject* module);

}
} // namespace torch
#endif

#ifdef USE_ITT
// Intel ITT (instrumentation & tracing) profiler bindings.
namespace torch {
namespace profiler {
void initIttBindings(PyObject* module);
} // namespace profiler
} // namespace torch
#endif

namespace torch {
void initVerboseBindings(PyObject* module);
} // namespace torch

// Accumulated method table for torch._C; populated in initModule() from
// TorchMethods plus the per-subsystem method tables. Must outlive the
// module, hence file-level static storage.
static std::vector<PyMethodDef> methods;
1227 | |
1228 | // In Python we can't use the trick of C10_LOG_API_USAGE_ONCE |
1229 | // Guaranteed to be invoked from Python under GIL, no locking on map needed |
1230 | static void LogAPIUsageOnceFromPython(const std::string& event) { |
1231 | static std::unordered_set<std::string> seen; |
1232 | if (!seen.count(event)) { |
1233 | seen.insert(event); |
1234 | c10::LogAPIUsage(event); |
1235 | } |
1236 | } |
1237 | |
1238 | // Weak reference to tensor, used to test a tensor isn't leaked |
1239 | class WeakTensorRef { |
1240 | c10::weak_intrusive_ptr<c10::TensorImpl> weakref_; |
1241 | |
1242 | public: |
1243 | WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {} |
1244 | |
1245 | bool expired() { |
1246 | return weakref_.expired(); |
1247 | } |
1248 | }; |
1249 | |
1250 | extern "C" |
1251 | #ifdef _WIN32 |
1252 | __declspec(dllexport) |
1253 | #endif |
1254 | TORCH_API PyObject* initModule(); |
1255 | // separate decl and defn for msvc error C2491 |
1256 | PyObject* initModule() { |
1257 | HANDLE_TH_ERRORS |
1258 | |
1259 | c10::initLogging(); |
1260 | |
1261 | at::internal::lazy_init_num_threads(); |
1262 | |
1263 | C10_LOG_API_USAGE_ONCE("torch.python.import" ); |
1264 | |
1265 | // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) |
1266 | #define ASSERT_TRUE(cmd) \ |
1267 | if (!(cmd)) \ |
1268 | return nullptr |
1269 | |
1270 | THPUtils_addPyMethodDefs(methods, TorchMethods); |
1271 | THPUtils_addPyMethodDefs(methods, DataLoaderMethods); |
1272 | THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions()); |
1273 | THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions()); |
1274 | #ifdef USE_CUDA |
1275 | THPUtils_addPyMethodDefs(methods, THCPModule_methods()); |
1276 | #endif |
1277 | #if defined(USE_DISTRIBUTED) && defined(USE_C10D) |
1278 | THPUtils_addPyMethodDefs( |
1279 | methods, torch::distributed::c10d::python_functions()); |
1280 | #ifndef _WIN32 |
1281 | THPUtils_addPyMethodDefs( |
1282 | methods, torch::distributed::rpc::python_functions()); |
1283 | THPUtils_addPyMethodDefs( |
1284 | methods, torch::distributed::autograd::python_functions()); |
1285 | THPUtils_addPyMethodDefs( |
1286 | methods, torch::distributed::rpc::testing::python_functions()); |
1287 | #endif |
1288 | #endif |
1289 | |
1290 | static struct PyModuleDef torchmodule = { |
1291 | PyModuleDef_HEAD_INIT, "torch._C" , nullptr, -1, methods.data()}; |
1292 | ASSERT_TRUE(module = PyModule_Create(&torchmodule)); |
1293 | ASSERT_TRUE(THPGenerator_init(module)); |
1294 | ASSERT_TRUE(THPException_init(module)); |
1295 | THPSize_init(module); |
1296 | THPDtype_init(module); |
1297 | THPDTypeInfo_init(module); |
1298 | THPLayout_init(module); |
1299 | THPMemoryFormat_init(module); |
1300 | THPQScheme_init(module); |
1301 | THPDevice_init(module); |
1302 | THPStream_init(module); |
1303 | ASSERT_TRUE(THPVariable_initModule(module)); |
1304 | ASSERT_TRUE(THPFunction_initModule(module)); |
1305 | ASSERT_TRUE(THPEngine_initModule(module)); |
1306 | // NOTE: We need to be able to access OperatorExportTypes from ONNX for use in |
1307 | // the export side of JIT, so this ONNX init needs to appear before the JIT |
1308 | // init. |
1309 | torch::onnx::initONNXBindings(module); |
1310 | torch::autograd::initEnumTag(module); |
1311 | torch::jit::initJITBindings(module); |
1312 | torch::monitor::initMonitorBindings(module); |
1313 | torch::impl::dispatch::initDispatchBindings(module); |
1314 | torch::dynamo::initDynamoBindings(module); |
1315 | torch::functorch::impl::initFuncTorchBindings(module); |
1316 | torch::throughput_benchmark::initThroughputBenchmarkBindings(module); |
1317 | torch::autograd::initReturnTypes(module); |
1318 | torch::autograd::initNNFunctions(module); |
1319 | torch::autograd::initFFTFunctions(module); |
1320 | torch::autograd::initLinalgFunctions(module); |
1321 | torch::autograd::initNestedFunctions(module); |
1322 | torch::autograd::initSparseFunctions(module); |
1323 | torch::autograd::initSpecialFunctions(module); |
1324 | torch::autograd::init_legacy_variable(module); |
1325 | torch::profiler::initPythonBindings(module); |
1326 | torch::python::init_bindings(module); |
1327 | torch::lazy::initLazyBindings(module); |
1328 | #ifdef USE_ITT |
1329 | torch::profiler::initIttBindings(module); |
1330 | #endif |
1331 | #ifdef USE_CUDA |
1332 | torch::cuda::initModule(module); |
1333 | #endif |
1334 | torch::initVerboseBindings(module); |
1335 | ASSERT_TRUE(THPStorage_init(module)); |
1336 | |
1337 | #ifdef USE_CUDA |
1338 | // This will only initialise base classes and attach them to library namespace |
1339 | // They won't be ready for real usage until importing cuda module, that will |
1340 | // complete the process (but it defines Python classes before calling back |
1341 | // into C, so these lines have to execute first).. |
1342 | THCPStream_init(module); |
1343 | THCPEvent_init(module); |
1344 | THCPGraph_init(module); |
1345 | #endif |
1346 | |
1347 | auto set_module_attr = |
1348 | [&](const char* name, PyObject* v, bool incref = true) { |
1349 | // PyModule_AddObject steals reference |
1350 | if (incref) { |
1351 | Py_INCREF(v); |
1352 | } |
1353 | return PyModule_AddObject(module, name, v) == 0; |
1354 | }; |
1355 | |
1356 | #if defined(USE_CUDNN) || defined(USE_ROCM) |
1357 | PyObject* has_cudnn = Py_True; |
1358 | #else |
1359 | PyObject* has_cudnn = Py_False; |
1360 | #endif |
1361 | ASSERT_TRUE(set_module_attr("has_cudnn" , has_cudnn)); |
1362 | |
1363 | #if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() |
1364 | PyObject* has_spectral = Py_True; |
1365 | #else |
1366 | PyObject* has_spectral = Py_False; |
1367 | #endif |
1368 | ASSERT_TRUE(set_module_attr("has_spectral" , has_spectral)); |
1369 | |
1370 | // force ATen to initialize because it handles |
1371 | // setting up TH Errors so that they throw C++ exceptions |
1372 | at::init(); |
1373 | |
1374 | // Automatically translate errors thrown from pybind11 functions |
1375 | py::register_exception_translator([](std::exception_ptr e) { // NOLINT |
1376 | try { |
1377 | if (e) { |
1378 | std::rethrow_exception(e); |
1379 | } |
1380 | } |
1381 | CATCH_TH_ERRORS() |
1382 | }); |
1383 | |
1384 | auto py_module = py::reinterpret_borrow<py::module>(module); |
1385 | py_module.def("_demangle" , &c10::demangle); |
1386 | py_module.def("_log_api_usage_once" , &LogAPIUsageOnceFromPython); |
1387 | |
1388 | py_module.def("vitals_enabled" , &at::vitals::torchVitalEnabled); |
1389 | py_module.def( |
1390 | "set_vital" , |
1391 | [](const std::string& vital, |
1392 | const std::string& attr, |
1393 | const std::string value) { |
1394 | return at::vitals::VitalsAPI.setVital(vital, attr, value); |
1395 | }); |
1396 | py_module.def( |
1397 | "read_vitals" , []() { return at::vitals::VitalsAPI.readVitals(); }); |
1398 | |
1399 | py_module.def( |
1400 | "init_num_threads" , |
1401 | torch::wrap_pybind_function(at::init_num_threads), |
1402 | R"( |
1403 | init_num_threads() |
1404 | |
1405 | Initializes the number of parallel threads used on the current thread. |
1406 | |
1407 | Call this whenever a new thread is created in order to propagate values from |
1408 | :func:`torch.set_num_threads` onto the new thread. |
1409 | )" ); |
1410 | |
1411 | ASSERT_TRUE( |
1412 | set_module_attr("has_openmp" , at::hasOpenMP() ? Py_True : Py_False)); |
1413 | ASSERT_TRUE(set_module_attr("has_mkl" , at::hasMKL() ? Py_True : Py_False)); |
1414 | ASSERT_TRUE( |
1415 | set_module_attr("has_lapack" , at::hasLAPACK() ? Py_True : Py_False)); |
1416 | |
1417 | py_module.def("_valgrind_supported_platform" , []() { |
1418 | #if defined(USE_VALGRIND) |
1419 | return true; |
1420 | #else |
1421 | return false; |
1422 | #endif |
1423 | }); |
1424 | |
1425 | py_module.def("_valgrind_toggle" , []() { |
1426 | #if defined(USE_VALGRIND) |
1427 | CALLGRIND_TOGGLE_COLLECT; |
1428 | #else |
1429 | TORCH_CHECK(false, "Valgrind is not supported." ); |
1430 | #endif |
1431 | }); |
1432 | |
1433 | py_module.def("_valgrind_toggle_and_dump_stats" , []() { |
1434 | #if defined(USE_VALGRIND) |
1435 | // NB: If we don't toggle collect around dump stats, callgrind_annotate |
1436 | // won't process the results correctly. Specifically, |
1437 | // `callgrind_annotate --inclusive=no` will be almost completely empty. |
1438 | CALLGRIND_TOGGLE_COLLECT; |
1439 | CALLGRIND_DUMP_STATS; |
1440 | #else |
1441 | TORCH_CHECK(false, "Valgrind is not supported." ); |
1442 | #endif |
1443 | }); |
1444 | |
1445 | py::class_<WeakTensorRef>(py_module, "_WeakTensorRef" ) |
1446 | .def(py::init([](py::object tensor) { |
1447 | return WeakTensorRef(THPVariable_Unpack(tensor.ptr())); |
1448 | })) |
1449 | .def("expired" , &WeakTensorRef::expired); |
1450 | |
1451 | py::enum_<at::native::ConvBackend>(py_module, "_ConvBackend" ) |
1452 | .value("CudaDepthwise2d" , at::native::ConvBackend::CudaDepthwise2d) |
1453 | .value("CudaDepthwise3d" , at::native::ConvBackend::CudaDepthwise3d) |
1454 | .value("Cudnn" , at::native::ConvBackend::Cudnn) |
1455 | .value("CudnnTranspose" , at::native::ConvBackend::CudnnTranspose) |
1456 | .value("Empty" , at::native::ConvBackend::Empty) |
1457 | .value("Miopen" , at::native::ConvBackend::Miopen) |
1458 | .value("MiopenDepthwise" , at::native::ConvBackend::MiopenDepthwise) |
1459 | .value("MiopenTranspose" , at::native::ConvBackend::MiopenTranspose) |
1460 | .value("Mkldnn" , at::native::ConvBackend::Mkldnn) |
1461 | .value("MkldnnEmpty" , at::native::ConvBackend::MkldnnEmpty) |
1462 | .value("NnpackSpatial" , at::native::ConvBackend::NnpackSpatial) |
1463 | .value("Overrideable" , at::native::ConvBackend::Overrideable) |
1464 | .value("Slow2d" , at::native::ConvBackend::Slow2d) |
1465 | .value("Slow3d" , at::native::ConvBackend::Slow3d) |
1466 | .value("SlowDilated2d" , at::native::ConvBackend::SlowDilated2d) |
1467 | .value("SlowDilated3d" , at::native::ConvBackend::SlowDilated3d) |
1468 | .value("SlowTranspose2d" , at::native::ConvBackend::SlowTranspose2d) |
1469 | .value("SlowTranspose3d" , at::native::ConvBackend::SlowTranspose3d) |
1470 | .value( |
1471 | "Winograd3x3Depthwise" , at::native::ConvBackend::Winograd3x3Depthwise) |
1472 | .value("Xnnpack2d" , at::native::ConvBackend::Xnnpack2d) |
1473 | .value("Mps" , at::native::ConvBackend::Mps) |
1474 | .value("MpsTranspose," , at::native::ConvBackend::MpsTranspose); |
1475 | |
1476 | py_module.def( |
1477 | "_select_conv_backend" , |
1478 | [](const at::Tensor& input, |
1479 | const at::Tensor& weight, |
1480 | const c10::optional<at::Tensor>& bias_opt, |
1481 | at::IntArrayRef stride_, |
1482 | at::SymIntArrayRef padding_, |
1483 | at::IntArrayRef dilation_, |
1484 | bool transposed_, |
1485 | at::SymIntArrayRef output_padding_, |
1486 | int64_t groups_) { |
1487 | return at::native::select_conv_backend( |
1488 | input, |
1489 | weight, |
1490 | bias_opt, |
1491 | stride_, |
1492 | padding_, |
1493 | dilation_, |
1494 | transposed_, |
1495 | output_padding_, |
1496 | groups_, |
1497 | c10::nullopt); |
1498 | }, |
1499 | py::arg("input" ), |
1500 | py::arg("weight" ), |
1501 | py::arg("bias" ), |
1502 | py::arg("stride" ), |
1503 | py::arg("padding" ), |
1504 | py::arg("dilation" ), |
1505 | py::arg("transposed" ), |
1506 | py::arg("output_padding" ), |
1507 | py::arg("groups" )); |
1508 | |
1509 | // overload for bias_sizes_opt/backward TODO: figure out default value |
1510 | py_module.def( |
1511 | "_select_conv_backend" , |
1512 | [](const at::Tensor& input, |
1513 | const at::Tensor& weight, |
1514 | const c10::optional<at::Tensor>& bias, |
1515 | at::IntArrayRef stride_, |
1516 | at::SymIntArrayRef padding_, |
1517 | at::IntArrayRef dilation_, |
1518 | bool transposed_, |
1519 | at::SymIntArrayRef output_padding_, |
1520 | int64_t groups_, |
1521 | c10::optional<std::vector<c10::SymInt>> bias_sizes_opt) { |
1522 | c10::OptionalArrayRef<c10::SymInt> ref = c10::nullopt; |
1523 | if (bias_sizes_opt) { |
1524 | ref = (*bias_sizes_opt); |
1525 | } |
1526 | return at::native::select_conv_backend( |
1527 | input, |
1528 | weight, |
1529 | bias, |
1530 | stride_, |
1531 | padding_, |
1532 | dilation_, |
1533 | transposed_, |
1534 | output_padding_, |
1535 | groups_, |
1536 | ref); |
1537 | }, |
1538 | py::arg("input" ), |
1539 | py::arg("weight" ), |
1540 | py::arg("bias" ), |
1541 | py::arg("stride" ), |
1542 | py::arg("padding" ), |
1543 | py::arg("dilation" ), |
1544 | py::arg("transposed" ), |
1545 | py::arg("output_padding" ), |
1546 | py::arg("groups" ), |
1547 | py::arg("bias_sizes" )); |
1548 | |
1549 | py_module.def( |
1550 | "_conv_determine_backend_memory_format" , |
1551 | at::native::_determine_backend_memory_format); |
1552 | |
1553 | py::enum_<at::LinalgBackend>(py_module, "_LinalgBackend" ) |
1554 | .value("Default" , at::LinalgBackend::Default) |
1555 | .value("Cusolver" , at::LinalgBackend::Cusolver) |
1556 | .value("Magma" , at::LinalgBackend::Magma); |
1557 | |
1558 | py_module.def("_set_linalg_preferred_backend" , [](at::LinalgBackend b) { |
1559 | at::globalContext().setLinalgPreferredBackend(b); |
1560 | }); |
1561 | py_module.def("_get_linalg_preferred_backend" , []() { |
1562 | return at::globalContext().linalgPreferredBackend(); |
1563 | }); |
1564 | |
1565 | py_module.def("_stash_obj_in_tls" , [](std::string key, py::handle arg) { |
1566 | at::impl::ThreadLocalPythonObjects::get_state().set( |
1567 | key, |
1568 | std::make_shared<c10::SafePyObject>(arg.ptr(), getPyInterpreter())); |
1569 | }); |
1570 | |
1571 | py_module.def("_get_obj_in_tls" , [](std::string key) -> py::handle { |
1572 | auto safe_pyobject = |
1573 | at::impl::ThreadLocalPythonObjects::get_state().get(key); |
1574 | auto obj = safe_pyobject->ptr(getPyInterpreter()); |
1575 | return py::handle(obj); |
1576 | }); |
1577 | |
1578 | py_module.def("_is_key_in_tls" , [](std::string key) -> bool { |
1579 | return at::impl::ThreadLocalPythonObjects::get_state().contains(key); |
1580 | }); |
1581 | |
1582 | #ifdef USE_CUDA |
1583 | PyObject* has_cuda = Py_True; |
1584 | #else |
1585 | PyObject* has_cuda = Py_False; |
1586 | #endif |
1587 | |
1588 | #ifdef USE_MPS |
1589 | PyObject* has_mps = Py_True; |
1590 | #else |
1591 | PyObject* has_mps = Py_False; |
1592 | #endif |
1593 | |
1594 | ASSERT_TRUE(set_module_attr("has_cuda" , has_cuda)); |
1595 | ASSERT_TRUE(set_module_attr("has_mps" , has_mps)); |
1596 | py_module.def("_is_mps_available" , []() { return at::hasMPS(); }); |
1597 | py_module.def("_is_mps_on_macos_13_or_newer" , []() { |
1598 | #ifdef USE_MPS |
1599 | return at::mps::is_macos_13_or_newer(); |
1600 | #else |
1601 | return false; |
1602 | #endif |
1603 | }); |
1604 | |
1605 | ASSERT_TRUE( |
1606 | set_module_attr("has_mkldnn" , at::hasMKLDNN() ? Py_True : Py_False)); |
1607 | |
1608 | #ifdef _GLIBCXX_USE_CXX11_ABI |
1609 | ASSERT_TRUE(set_module_attr( |
1610 | "_GLIBCXX_USE_CXX11_ABI" , _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False)); |
1611 | #else |
1612 | ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI" , Py_False)); |
1613 | #endif |
1614 | |
1615 | // See note [Pybind11 ABI constants] |
1616 | #define SET_STR_DEFINE(name) \ |
1617 | ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name))) |
1618 | |
1619 | #ifdef PYBIND11_COMPILER_TYPE |
1620 | SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); |
1621 | #else |
1622 | ASSERT_TRUE( |
1623 | set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); |
1624 | #endif |
1625 | |
1626 | #ifdef PYBIND11_STDLIB |
1627 | SET_STR_DEFINE(PYBIND11_STDLIB); |
1628 | #else |
1629 | ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None)); |
1630 | #endif |
1631 | |
1632 | #ifdef PYBIND11_BUILD_ABI |
1633 | SET_STR_DEFINE(PYBIND11_BUILD_ABI); |
1634 | #else |
1635 | ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None)); |
1636 | #endif |
1637 | #undef SET_STR_DEFINE |
1638 | |
1639 | py_module.def( |
1640 | "_set_conj" , [](const at::Tensor& x, bool conj) { x._set_conj(conj); }); |
1641 | py_module.def( |
1642 | "_set_neg" , [](const at::Tensor& x, bool neg) { x._set_neg(neg); }); |
1643 | py_module.def("_get_tensor_metadata" , &torch::jit::getTensorMetadata); |
1644 | py_module.def( |
1645 | "_set_tensor_metadata" , |
1646 | static_cast<void (*)( |
1647 | const at::Tensor&, std::unordered_map<std::string, bool>)>( |
1648 | torch::jit::setTensorMetadata)); |
1649 | py_module.def("_dispatch_key_set" , [](const at::Tensor& x) { |
1650 | return toString(x.key_set()); |
1651 | }); |
1652 | py_module.def( |
1653 | "_has_storage" , [](const at::Tensor& x) { return x.has_storage(); }); |
1654 | |
1655 | py_module.def("_set_meta_in_tls_dispatch_include" , [](bool meta_in_tls) { |
1656 | auto local_keyset = c10::impl::tls_local_dispatch_key_set(); |
1657 | c10::DispatchKeySet key_set({at::DispatchKey::Meta}); |
1658 | if (meta_in_tls) { |
1659 | local_keyset.included_ = local_keyset.included_ | key_set; |
1660 | } else { |
1661 | local_keyset.included_ = |
1662 | local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit); |
1663 | } |
1664 | c10::impl::_force_tls_local_dispatch_key_set(local_keyset); |
1665 | }); |
1666 | |
1667 | py_module.def("_meta_in_tls_dispatch_include" , []() { |
1668 | auto local_keyset = c10::impl::tls_local_dispatch_key_set(); |
1669 | return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit); |
1670 | }); |
1671 | |
1672 | py_module.def("_dump_local_tls_set" , []() { |
1673 | auto local_keyset = c10::impl::tls_local_dispatch_key_set(); |
1674 | std::cout << "Included: " << toString(local_keyset.included_) << "\n" ; |
1675 | std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n" ; |
1676 | }); |
1677 | |
1678 | py_module.def( |
1679 | "_should_allow_numbers_as_tensors" , [](const std::string& name) { |
1680 | return torch::should_allow_numbers_as_tensors(name); |
1681 | }); |
1682 | |
1683 | const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); |
1684 | THPDefaultCPUGenerator = |
1685 | (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); |
1686 | // This reference is meant to be given away, so no need to incref here. |
1687 | ASSERT_TRUE(set_module_attr( |
1688 | "default_generator" , |
1689 | (PyObject*)THPDefaultCPUGenerator, |
1690 | /* incref= */ false)); |
1691 | ASSERT_TRUE(set_module_attr( |
1692 | "DisableTorchFunctionSubclass" , |
1693 | (PyObject*)THPModule_DisableTorchFunctionSubclassType(), |
1694 | /* incref= */ false)); |
1695 | ASSERT_TRUE(set_module_attr( |
1696 | "DisableTorchFunction" , |
1697 | (PyObject*)THPModule_DisableTorchFunctionType(), |
1698 | /* incref= */ false)); |
1699 | torch::set_disabled_torch_function_impl( |
1700 | PyObject_GetAttrString(module, "_disabled_torch_function_impl" )); |
1701 | ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); |
1702 | torch::set_disabled_torch_dispatch_impl( |
1703 | PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl" )); |
1704 | ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr); |
1705 | return module; |
1706 | END_HANDLE_TH_ERRORS |
1707 | } |
1708 | |
// Checks that the _C shared library isn't initialized multiple times. This
// can happen if the same csrc files are compiled into multiple shared
// libraries.
inline void pytorch_duplicate_guard() {
  // Flip-once flag; a second call means the library was loaded twice.
  // (Was an `int` with a stray empty statement `;` after the assignment.)
  static bool initialized = false;
  if (initialized) {
    fprintf(stderr, "pytorch: _C shared library re-initialized\n");
    abort();
  }
  initialized = true;
}
1721 | |
// Runs pytorch_duplicate_guard() during static initialization, so the
// duplicate-load check executes as soon as this shared library is loaded.
struct call_duplicate_guard {
  call_duplicate_guard() {
    pytorch_duplicate_guard();
  }
};

// Global whose only purpose is its constructor's side effect.
static call_duplicate_guard _call_duplicate_guard;
1729 | |