1#include <torch/csrc/python_headers.h>
2
3#include <ATen/PythonTorchFunctionTLS.h>
4#include <ATen/SavedTensorHooks.h>
5#include <ATen/autocast_mode.h>
6#include <ATen/core/PythonFallbackKernel.h>
7#include <ATen/record_function.h>
8#include <c10/core/DeviceType.h>
9#include <c10/core/InferenceMode.h>
10#include <c10/core/ScalarType.h>
11#include <c10/core/impl/PythonDispatcherTLS.h>
12#include <torch/csrc/Exceptions.h>
13#include <torch/csrc/autograd/autograd.h>
14#include <torch/csrc/autograd/function.h>
15#include <torch/csrc/autograd/grad_mode.h>
16#include <torch/csrc/autograd/profiler.h>
17#include <torch/csrc/autograd/profiler_python.h>
18#include <torch/csrc/autograd/python_function.h>
19#include <torch/csrc/autograd/python_saved_variable_hooks.h>
20#include <torch/csrc/autograd/python_variable.h>
21#include <torch/csrc/autograd/record_function_ops.h>
22#include <torch/csrc/autograd/saved_variable.h>
23#include <torch/csrc/autograd/utils/python_arg_parsing.h>
24#include <torch/csrc/autograd/utils/wrap_outputs.h>
25#include <torch/csrc/jit/python/pybind_utils.h>
26#include <torch/csrc/profiler/collection.h>
27#include <torch/csrc/profiler/kineto_shim.h>
28#include <torch/csrc/utils/disable_torch_function.h>
29#include <torch/csrc/utils/pybind.h>
30#include <torch/csrc/utils/pycfunction_helpers.h>
31#include <torch/csrc/utils/python_torch_function_mode.h>
32
33#include <set>
34#include <unordered_set>
35#include <utility>
36
37namespace {
38
// RAII guard that excludes both functorch dynamic-layer dispatch keys
// (front and back mode) for its lifetime; while alive, functorch's
// dynamic-layer interpretation of ops is skipped. Exposed to Python below
// as torch._C._DisableFuncTorch.
struct DisableFuncTorch {
  DisableFuncTorch()
      : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode),
        back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {}
  c10::impl::ExcludeDispatchKeyGuard front_guard_;
  c10::impl::ExcludeDispatchKeyGuard back_guard_;
};
46
47struct MultithreadingEnabled {
48 MultithreadingEnabled(bool enabled)
49 : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
50 c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled);
51 }
52 ~MultithreadingEnabled() {
53 c10::AutogradState::get_tls_state().set_multithreading_enabled(old_);
54 }
55 bool old_;
56};
57
58struct ViewReplayEnabled {
59 ViewReplayEnabled(bool enabled)
60 : old_(c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
61 c10::AutogradState::get_tls_state().set_view_replay_enabled(enabled);
62 }
63 ~ViewReplayEnabled() {
64 c10::AutogradState::get_tls_state().set_view_replay_enabled(old_);
65 }
66 bool old_;
67};
68
// RAII guard that excludes every autocast dispatch key for its lifetime,
// disabling autocast re-dispatch. Exposed below as torch._C._DisableAutocast.
struct DisableAutocast {
  c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
};
72
// RAII guard that force-enables __torch_function__ handling by setting the
// thread-local disabled-state to ENABLED, saving the previous state and
// restoring it on destruction. Exposed below as torch._C._EnableTorchFunction.
struct EnableTorchFunction {
  EnableTorchFunction()
      : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) {
    at::impl::PythonTorchFunctionTLS::set_disabled_state(
        at::impl::TorchFunctionDisabledState::ENABLED);
  }
  ~EnableTorchFunction() {
    at::impl::PythonTorchFunctionTLS::set_disabled_state(old_);
  }
  // Previous disabled-state, restored in the destructor.
  at::impl::TorchFunctionDisabledState old_;
};
84
// RAII guard that installs this interpreter as the thread-local Python
// dispatcher, saving whatever dispatcher state was active before and
// restoring it on destruction. Exposed below as torch._C._EnablePythonDispatcher.
struct EnablePythonDispatcher {
  EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
    c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
  }
  ~EnablePythonDispatcher() {
    c10::impl::PythonDispatcherTLS::set_state(old_);
  }
  // Previous dispatcher (may be null), restored in the destructor.
  c10::impl::PyInterpreter* old_;
};
94
95} // namespace
96
97PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
98 using namespace torch::autograd::profiler;
99 using namespace torch::profiler::impl;
100 auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor"));
101 if (!tensor_module)
102 return nullptr;
103
104 // NOTE: "leaks" THPVariableClass
105 THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
106 if (!THPVariableClass)
107 return nullptr;
108
109 auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
110 if (!autograd_module)
111 return nullptr;
112
113 // NOTE: "leaks" Function
114 THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
115 if (!THPFunctionClass)
116 return nullptr;
117
118 auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
119 if (!torch_C_module)
120 return nullptr;
121 auto _C_m = py::handle(torch_C_module).cast<py::module>();
122 auto m = _C_m.def_submodule("_autograd", "autograd bindings");
123
124 auto parameter_module =
125 THPObjectPtr(PyImport_ImportModule("torch.nn.parameter"));
126 if (!parameter_module)
127 return nullptr;
128
129 // NOTE: "leaks" ParameterClass
130 ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter");
131 if (!ParameterClass)
132 return nullptr;
133
134 py::class_<LegacyEvent>(m, "ProfilerEvent")
135 .def("kind", &LegacyEvent::kindStr)
136 .def("name", [](const LegacyEvent& e) { return e.name(); })
137 .def("thread_id", &LegacyEvent::threadId)
138 .def("fwd_thread_id", &LegacyEvent::fwdThreadId)
139 .def("device", &LegacyEvent::device)
140 .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
141 .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
142 .def("has_cuda", &LegacyEvent::hasCuda)
143 .def("shapes", &LegacyEvent::shapes)
144 .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
145 .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
146 .def("handle", &LegacyEvent::handle)
147 .def("node_id", &LegacyEvent::nodeId)
148 .def("is_remote", &LegacyEvent::isRemote)
149 .def("sequence_nr", &LegacyEvent::sequenceNr)
150 .def("stack", &LegacyEvent::stack)
151 .def("scope", &LegacyEvent::scope)
152 .def("correlation_id", &LegacyEvent::correlationId)
153 .def("start_us", &LegacyEvent::cpuUs)
154 .def("flops", &LegacyEvent::flops)
155 .def("is_async", &LegacyEvent::isAsync);
156
157 py::enum_<c10::DeviceType>(m, "DeviceType")
158 .value("CPU", c10::DeviceType::CPU)
159 .value("CUDA", c10::DeviceType::CUDA)
160 .value("MKLDNN", c10::DeviceType::MKLDNN)
161 .value("OPENGL", c10::DeviceType::OPENGL)
162 .value("OPENCL", c10::DeviceType::OPENCL)
163 .value("IDEEP", c10::DeviceType::IDEEP)
164 .value("HIP", c10::DeviceType::HIP)
165 .value("FPGA", c10::DeviceType::FPGA)
166 .value("ORT", c10::DeviceType::ORT)
167 .value("XLA", c10::DeviceType::XLA)
168 .value("Vulkan", c10::DeviceType::Vulkan)
169 .value("Metal", c10::DeviceType::Metal)
170 .value("XPU", c10::DeviceType::XPU)
171 .value("MPS", c10::DeviceType::MPS)
172 .value("Meta", c10::DeviceType::Meta)
173 .value("HPU", c10::DeviceType::HPU)
174 .value("VE", c10::DeviceType::VE)
175 .value("Lazy", c10::DeviceType::Lazy)
176 .value("IPU", c10::DeviceType::IPU);
177
178 py::class_<KinetoEvent>(m, "_KinetoEvent")
179 // name of the event
180 .def("name", [](const KinetoEvent& e) { return e.name(); })
181 // PyTorch thread id of the start callback
182 .def(
183 "start_thread_id",
184 [](const KinetoEvent& e) { return e.startThreadId(); })
185 // PyTorch thread id of the end callback
186 .def(
187 "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); })
188 // for events of scope BACKWARD_FUNCTION - PyTorch thread id
189 // of the corresponding forward op
190 .def(
191 "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); })
192 // together with fwd_thread_id, used to uniquely identify
193 // the forward op
194 .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
195 // absolute start time (since unix epoch) in us
196 .def("start_us", [](const KinetoEvent& e) { return e.startUs(); })
197 // duration in us
198 .def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); })
199 // used for correlation between high-level PyTorch events
200 // and low-level device events
201 .def(
202 "correlation_id",
203 [](const KinetoEvent& e) { return e.correlationId(); })
204 // shapes of input tensors
205 .def("shapes", [](const KinetoEvent& e) { return e.shapes().vec(); })
206 .def("dtypes", [](const KinetoEvent& e) { return e.dtypes().vec(); })
207 // stack traces of the PyTorch CPU events
208 .def("stack", [](const KinetoEvent& e) { return e.stack().vec(); })
209 // type of the RecordFunction that generated a PyTorch CPU event
210 // (op, torchscript function, user label, etc)
211 .def("scope", [](const KinetoEvent& e) { return e.scope(); })
212 // device number, for CPU - process id
213 .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); })
214 // for CUDA - stream id, for CPU - start thread id
215 .def(
216 "device_resource_id",
217 [](const KinetoEvent& e) { return e.deviceResourceId(); })
218 // device type
219 .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); })
220 // correlation id of a linked event
221 .def(
222 "linked_correlation_id",
223 [](const KinetoEvent& e) { return e.linkedCorrelationId(); })
224 // compute flops
225 .def("flops", [](const KinetoEvent& e) { return e.flops(); })
226 // Whether this is async event or not
227 .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
228 .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
229 .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });
230
231 m.def("_soft_assert_raises", &setSoftAssertRaises);
232
233 py::class_<ProfilerResult>(m, "_ProfilerResult")
234 .def("trace_start_us", &ProfilerResult::trace_start_us)
235 .def("events", &ProfilerResult::events)
236 .def("experimental_event_tree", &ProfilerResult::event_tree)
237#ifdef USE_KINETO
238 .def("save", &ProfilerResult::save)
239#endif // USE_KINETO
240 ;
241
242 m.def(
243 "_enable_profiler",
244 &enableProfiler,
245 py::arg("config"),
246 py::arg("activities"),
247 py::arg("scopes") = std::unordered_set<at::RecordScope>());
248 m.def("_disable_profiler", disableProfiler);
249 m.def("_prepare_profiler", prepareProfiler);
250 m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set
251 m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set
252 m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; });
253
254 // NOTICE: These record functions are not torch operators and may not show up
255 // in TorchScript tracing, FX transforms, or operator serialization. For these
256 // use cases, please use `torch.profiler.record_function`.
257 // Creates a new profiling scope using RecordFunction and invokes its starting
258 // callbacks.
259 m.def(
260 "_record_function_with_args_enter",
261 [](const std::string& name, py::args args) {
262 using torch::autograd::profiler::PythonRecordFunction;
263 auto python_rec = c10::make_intrusive<PythonRecordFunction>(
264 at::RecordScope::USER_SCOPE);
265 auto* rec = &python_rec->record;
266 if (rec->isActive()) {
267 if (rec->needsInputs()) {
268 auto iv_inputs = std::vector<c10::IValue>();
269 for (const auto& arg : args) {
270 iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg));
271 }
272 rec->before(
273 name,
274 c10::ArrayRef<const c10::IValue>(
275 iv_inputs.data(), iv_inputs.size()));
276 } else {
277 rec->before(name);
278 }
279 }
280 return torch::jit::toPyObject(std::move(python_rec));
281 });
282
283 // Ends the profiling scope created with record_function_with_param_enter.
284 m.def("_record_function_with_args_exit", [](const py::object& obj) {
285 using torch::autograd::profiler::PythonRecordFunction;
286 auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);
287
288 // We don't actually need to do anything with handle just need to persist
289 // the lifetime until now.
290 python_record->record.end();
291 });
292
293 m.def("_supported_activities", []() {
294 std::set<ActivityType> activities{ActivityType::CPU};
295#if defined(USE_KINETO) && \
296 (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
297 if (at::getNumGPUs() > 0) {
298 activities.insert(ActivityType::CUDA);
299 }
300#endif
301 return activities;
302 });
303
304 m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
305 auto vc = torch::autograd::impl::version_counter(t);
306 vc.set_version(i);
307 });
308
309 m.def("_enable_profiler_legacy", enableProfilerLegacy);
310 py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
311 .def(py::init<bool, bool>());
312 m.def(
313 "_disable_profiler_legacy",
314 disableProfilerLegacy,
315 py::arg("profiler_disable_options") = ProfilerDisableOptions());
316 m.def("_profiler_enabled", profilerEnabled);
317 m.def("_profiler_type", torch::profiler::impl::profilerType);
318 m.def("_enable_record_function", [](bool enable) {
319 at::enableRecordFunction(enable);
320 });
321 m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
322 auto cb =
323 at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb(
324 sampling_prob);
325 if (is_global) {
326 at::addGlobalCallback(cb);
327 } else {
328 at::addThreadLocalCallback(cb);
329 }
330 });
331 m.def("_clear_callbacks", []() { at::clearCallbacks(); });
332 m.def(
333 "_saved_tensors_hooks_is_enabled",
334 at::SavedTensorDefaultHooks::is_enabled);
335 m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
336 m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
337 m.def(
338 "_saved_tensors_hooks_get_disabled_error_message",
339 at::SavedTensorDefaultHooks::get_disabled_error_message);
340 m.def(
341 "_push_saved_tensors_default_hooks",
342 [](py::function& pack_hook, py::function& unpack_hook) {
343 torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
344 pack_hook, unpack_hook);
345 });
346 m.def("_pop_saved_tensors_default_hooks", []() {
347 torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
348 });
349
350 _C_m.def(
351 "_register_py_class_for_device",
352 [](const std::string& device, py::object python_type_class) {
353 auto cls = python_type_class.ptr();
354 registerPythonTensorClass(device, cls);
355 });
356
357 _C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
358
359 py::class_<c10::InferenceMode>(_C_m, "_InferenceMode").def(py::init<bool>());
360
361 py::class_<at::impl::RestorePythonTLSSnapshot>(
362 _C_m, "_RestorePythonTLSSnapshot")
363 .def(py::init<>());
364
365 py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
366 .def(py::init<>());
367 py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
368 .def(py::init<>());
369 py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
370 .def(py::init<>());
371 py::class_<c10::impl::DisablePythonDispatcher>(
372 _C_m, "_DisablePythonDispatcher")
373 .def(py::init<>());
374 py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
375 py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
376 .def(py::init<bool>());
377 py::class_<DisableAutocast>(std::move(_C_m), "_DisableAutocast")
378 .def(py::init<>());
379 py::class_<ViewReplayEnabled>(_C_m, "_ViewReplayEnabled")
380 .def(py::init<bool>());
381 py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
382 .def(py::init([]() -> torch::autograd::SavedVariable {
383 TORCH_CHECK(
384 false,
385 "Trying to create a SavedTensor object from Python is forbidden.");
386 }))
387 .def(
388 "register_hooks",
389 [](torch::autograd::SavedVariable& s,
390 py::function& pack_hook,
391 py::function& unpack_hook) {
392 // Because we use a py::object, pybind will increment the refcount
393 // of the hook functions for us
394 s.register_hooks(
395 std::make_unique<torch::autograd::PySavedVariableHooks>(
396 pack_hook, unpack_hook));
397 });
398
399 torch::autograd::profiler::python_tracer::init();
400 Py_RETURN_TRUE;
401}
402
403namespace torch {
404namespace autograd {
405
406static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) {
407 HANDLE_TH_ERRORS
408 if (!PyBool_Check(arg)) {
409 throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
410 }
411 at::autocast::set_enabled(arg == Py_True);
412 Py_RETURN_NONE;
413 END_HANDLE_TH_ERRORS
414}
415
416static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
417 HANDLE_TH_ERRORS
418 if (at::autocast::is_enabled()) {
419 Py_RETURN_TRUE;
420 } else {
421 Py_RETURN_FALSE;
422 }
423 END_HANDLE_TH_ERRORS
424}
425
426static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
427 HANDLE_TH_ERRORS
428 if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
429 at::autocast::is_xpu_enabled()) {
430 Py_RETURN_TRUE;
431 } else {
432 Py_RETURN_FALSE;
433 }
434 END_HANDLE_TH_ERRORS
435}
436
437static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
438 HANDLE_TH_ERRORS
439 if (!PyBool_Check(arg)) {
440 throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
441 }
442 at::autocast::set_cpu_enabled(arg == Py_True);
443 Py_RETURN_NONE;
444 END_HANDLE_TH_ERRORS
445}
446
447static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
448 HANDLE_TH_ERRORS
449 if (at::autocast::is_cpu_enabled()) {
450 Py_RETURN_TRUE;
451 } else {
452 Py_RETURN_FALSE;
453 }
454 END_HANDLE_TH_ERRORS
455}
456
457static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
458 HANDLE_TH_ERRORS
459 if (!THPDtype_Check(arg)) {
460 throw TypeError(
461 "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name);
462 }
463 at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
464 at::autocast::set_autocast_gpu_dtype(targetType);
465 Py_RETURN_NONE;
466 END_HANDLE_TH_ERRORS
467}
468
469static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
470 HANDLE_TH_ERRORS
471 if (!THPDtype_Check(arg)) {
472 throw TypeError(
473 "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name);
474 }
475 at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
476 at::autocast::set_autocast_cpu_dtype(targetType);
477 Py_RETURN_NONE;
478 END_HANDLE_TH_ERRORS
479}
480
481static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
482 HANDLE_TH_ERRORS
483 at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
484 auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
485 Py_INCREF(dtype);
486 return dtype;
487 END_HANDLE_TH_ERRORS
488}
489
490static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
491 HANDLE_TH_ERRORS
492 at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
493 auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
494 Py_INCREF(dtype);
495 return dtype;
496 END_HANDLE_TH_ERRORS
497}
498
// Python binding: clear the autocast weight-cast cache. Returns None.
static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  at::autocast::clear_cache();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
505
506static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) {
507 HANDLE_TH_ERRORS
508 return THPUtils_packInt64(at::autocast::increment_nesting());
509 END_HANDLE_TH_ERRORS
510}
511
512static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) {
513 HANDLE_TH_ERRORS
514 return THPUtils_packInt64(at::autocast::decrement_nesting());
515 END_HANDLE_TH_ERRORS
516}
517
518static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
519 HANDLE_TH_ERRORS
520 if (at::autocast::is_autocast_cache_enabled()) {
521 Py_RETURN_TRUE;
522 } else {
523 Py_RETURN_FALSE;
524 }
525 END_HANDLE_TH_ERRORS
526}
527
528static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
529 HANDLE_TH_ERRORS
530 if (!PyBool_Check(arg)) {
531 throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
532 }
533 at::autocast::set_autocast_cache_enabled(arg == Py_True);
534 Py_RETURN_NONE;
535 END_HANDLE_TH_ERRORS
536}
537
538static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) {
539 HANDLE_TH_ERRORS
540 if (!PyBool_Check(arg)) {
541 throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
542 }
543 GradMode::set_enabled(arg == Py_True);
544 Py_RETURN_NONE;
545 END_HANDLE_TH_ERRORS
546}
547
548static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) {
549 HANDLE_TH_ERRORS
550 if (GradMode::is_enabled()) {
551 Py_RETURN_TRUE;
552 } else {
553 Py_RETURN_FALSE;
554 }
555 END_HANDLE_TH_ERRORS
556}
557
558static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
559 HANDLE_TH_ERRORS
560 if (!PyBool_Check(arg)) {
561 throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
562 }
563 c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True);
564 Py_RETURN_NONE;
565 END_HANDLE_TH_ERRORS
566}
567
568static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
569 HANDLE_TH_ERRORS
570 if (c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
571 Py_RETURN_TRUE;
572 } else {
573 Py_RETURN_FALSE;
574 }
575 END_HANDLE_TH_ERRORS
576}
577
578static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
579 HANDLE_TH_ERRORS
580 if (c10::InferenceMode::is_enabled()) {
581 Py_RETURN_TRUE;
582 } else {
583 Py_RETURN_FALSE;
584 }
585 END_HANDLE_TH_ERRORS
586}
587
// Python binding: set_anomaly_enabled(enabled, check_nan=True).
// Enables/disables autograd anomaly detection; `check_nan` controls whether
// backward outputs are additionally checked for NaNs.
static PyObject* set_anomaly_mode_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Parser is static: the signature is parsed once and reused across calls.
  static PythonArgParser parser({
      "set_anomaly_enabled(bool enabled, bool check_nan=True)",
  });
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  AnomalyMode::set_enabled(r.toBool(0), r.toBool(1));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
602
603static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) {
604 HANDLE_TH_ERRORS
605 if (AnomalyMode::is_enabled()) {
606 Py_RETURN_TRUE;
607 } else {
608 Py_RETURN_FALSE;
609 }
610 END_HANDLE_TH_ERRORS
611}
612
613static PyObject* is_anomaly_check_nan_enabled(
614 PyObject* _unused,
615 PyObject* arg) {
616 HANDLE_TH_ERRORS
617 if (AnomalyMode::should_check_nan()) {
618 Py_RETURN_TRUE;
619 } else {
620 Py_RETURN_FALSE;
621 }
622 END_HANDLE_TH_ERRORS
623}
624
// Python binding: enter a new forward-AD dual level and return its index.
static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  // It is unlikely that the depth of forward nesting will overflow int64_t so
  // we just static cast here.
  return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
  END_HANDLE_TH_ERRORS
}
632
// Python binding: exit_dual_level(level). Exits the given forward-AD dual
// level; rejects negative level indices before the unsigned cast.
static PyObject* python_exit_dual_level(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Parser is static: the signature is parsed once and reused across calls.
  static PythonArgParser parser({"exit_dual_level(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  // Make sure the given index is valid before casting it
  TORCH_CHECK(idx >= 0, "Dual level must be a positive number.");
  forward_ad::exit_dual_level(static_cast<uint64_t>(idx));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
650
651static PyObject* is_torch_function_mode_enabled(
652 PyObject* _unused,
653 PyObject* _unused2) {
654 HANDLE_TH_ERRORS
655 if (at::impl::torch_function_mode_enabled()) {
656 Py_RETURN_TRUE;
657 } else {
658 Py_RETURN_FALSE;
659 }
660 END_HANDLE_TH_ERRORS
661}
662
// Python binding: push a __torch_function__ mode object onto the TLS stack.
// A None argument is silently ignored. Returns None.
static PyObject* push_on_torch_function_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    // INCREF before handing the pointer to SafePyObject, which takes over
    // this reference for the lifetime of the stack entry.
    Py_INCREF(arg);
    at::impl::PythonTorchFunctionTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
675
// Python binding: pop the top __torch_function__ mode off the TLS stack and
// return it (as a new reference).
static PyObject* pop_torch_function_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
  // ptr() yields a borrowed pointer; INCREF so the caller gets a new
  // reference that survives `mode` being destroyed at end of scope.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
686
// Python binding: get_stack_at(level). Returns (as a new reference) the
// __torch_function__ mode stored at the given stack index without popping it.
static PyObject* get_function_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Parser is static: the signature is parsed once and reused across calls.
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
  // ptr() yields a borrowed pointer; INCREF to hand out a new reference.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
704
705static PyObject* len_torch_function_stack(
706 PyObject* _unused,
707 PyObject* _unused2) {
708 HANDLE_TH_ERRORS
709 const auto len = at::impl::PythonTorchFunctionTLS::stack_len();
710 return utils::wrap(static_cast<int64_t>(len));
711 END_HANDLE_TH_ERRORS
712}
713
// Python binding: push a __torch_dispatch__ mode object onto the TLS stack.
// A None argument is silently ignored. Returns None.
static PyObject* push_on_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    // INCREF before handing the pointer to SafePyObject, which takes over
    // this reference for the lifetime of the stack entry.
    Py_INCREF(arg);
    c10::impl::TorchDispatchModeTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
726
// Python binding: pop the top __torch_dispatch__ mode off the TLS stack and
// return it (as a new reference).
static PyObject* pop_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack();
  // ptr() yields a borrowed pointer; INCREF so the caller gets a new
  // reference that survives `mode` being destroyed at end of scope.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
737
// Python binding: get_stack_at(level). Returns (as a new reference) the
// __torch_dispatch__ mode stored at the given stack index without popping it.
static PyObject* get_dispatch_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Parser is static: the signature is parsed once and reused across calls.
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
  // ptr() yields a borrowed pointer; INCREF to hand out a new reference.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
755
756static PyObject* len_torch_dispatch_stack(
757 PyObject* _unused,
758 PyObject* _unused2) {
759 HANDLE_TH_ERRORS
760 const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
761 return utils::wrap(static_cast<int64_t>(len));
762 END_HANDLE_TH_ERRORS
763}
764
// autograd methods on torch._C. Registered via python_functions() below;
// public (no-underscore) names are part of the stable torch API surface,
// underscore-prefixed names are internal.
static PyMethodDef methods[] = { // NOLINT
    // -- grad / inference mode -------------------------------------------
    {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
     METH_NOARGS,
     nullptr},
    // -- autocast --------------------------------------------------------
    {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
    {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
    {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
    {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
    {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
    {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr},
    {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
    {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
    {"autocast_increment_nesting",
     autocast_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"autocast_decrement_nesting",
     autocast_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"is_autocast_cache_enabled",
     is_autocast_cache_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
    // -- anomaly detection -----------------------------------------------
    {"set_anomaly_enabled",
     castPyCFunctionWithKeywords(set_anomaly_mode_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
    {"is_anomaly_check_nan_enabled",
     is_anomaly_check_nan_enabled,
     METH_NOARGS,
     nullptr},
    // -- forward-AD dual levels ------------------------------------------
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // -- __torch_function__ mode stack -----------------------------------
    {"_is_torch_function_mode_enabled",
     is_torch_function_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_function_stack",
     push_on_torch_function_stack,
     METH_O,
     nullptr},
    {"_pop_torch_function_stack",
     pop_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_get_function_stack_at",
     castPyCFunctionWithKeywords(get_function_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_function_stack",
     len_torch_function_stack,
     METH_NOARGS,
     nullptr},
    // -- __torch_dispatch__ mode stack -----------------------------------
    {"_push_on_torch_dispatch_stack",
     push_on_torch_dispatch_stack,
     METH_O,
     nullptr},
    {"_pop_torch_dispatch_stack",
     pop_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    {"_get_dispatch_stack_at",
     castPyCFunctionWithKeywords(get_dispatch_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_dispatch_stack",
     len_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    // Sentinel entry terminating the table (required by CPython).
    {nullptr, nullptr, 0, nullptr}};
849
// Returns the method table above so it can be registered on torch._C.
PyMethodDef* python_functions() {
  return methods;
}
853
854} // namespace autograd
855} // namespace torch
856