1 | #include <torch/csrc/python_headers.h> |
2 | |
3 | #include <ATen/PythonTorchFunctionTLS.h> |
4 | #include <ATen/SavedTensorHooks.h> |
5 | #include <ATen/autocast_mode.h> |
6 | #include <ATen/core/PythonFallbackKernel.h> |
7 | #include <ATen/record_function.h> |
8 | #include <c10/core/DeviceType.h> |
9 | #include <c10/core/InferenceMode.h> |
10 | #include <c10/core/ScalarType.h> |
11 | #include <c10/core/impl/PythonDispatcherTLS.h> |
12 | #include <torch/csrc/Exceptions.h> |
13 | #include <torch/csrc/autograd/autograd.h> |
14 | #include <torch/csrc/autograd/function.h> |
15 | #include <torch/csrc/autograd/grad_mode.h> |
16 | #include <torch/csrc/autograd/profiler.h> |
17 | #include <torch/csrc/autograd/profiler_python.h> |
18 | #include <torch/csrc/autograd/python_function.h> |
19 | #include <torch/csrc/autograd/python_saved_variable_hooks.h> |
20 | #include <torch/csrc/autograd/python_variable.h> |
21 | #include <torch/csrc/autograd/record_function_ops.h> |
22 | #include <torch/csrc/autograd/saved_variable.h> |
23 | #include <torch/csrc/autograd/utils/python_arg_parsing.h> |
24 | #include <torch/csrc/autograd/utils/wrap_outputs.h> |
25 | #include <torch/csrc/jit/python/pybind_utils.h> |
26 | #include <torch/csrc/profiler/collection.h> |
27 | #include <torch/csrc/profiler/kineto_shim.h> |
28 | #include <torch/csrc/utils/disable_torch_function.h> |
29 | #include <torch/csrc/utils/pybind.h> |
30 | #include <torch/csrc/utils/pycfunction_helpers.h> |
31 | #include <torch/csrc/utils/python_torch_function_mode.h> |
32 | |
33 | #include <set> |
34 | #include <unordered_set> |
35 | #include <utility> |
36 | |
37 | namespace { |
38 | |
39 | struct DisableFuncTorch { |
40 | DisableFuncTorch() |
41 | : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode), |
42 | back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {} |
43 | c10::impl::ExcludeDispatchKeyGuard front_guard_; |
44 | c10::impl::ExcludeDispatchKeyGuard back_guard_; |
45 | }; |
46 | |
47 | struct MultithreadingEnabled { |
48 | MultithreadingEnabled(bool enabled) |
49 | : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) { |
50 | c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled); |
51 | } |
52 | ~MultithreadingEnabled() { |
53 | c10::AutogradState::get_tls_state().set_multithreading_enabled(old_); |
54 | } |
55 | bool old_; |
56 | }; |
57 | |
58 | struct ViewReplayEnabled { |
59 | ViewReplayEnabled(bool enabled) |
60 | : old_(c10::AutogradState::get_tls_state().get_view_replay_enabled()) { |
61 | c10::AutogradState::get_tls_state().set_view_replay_enabled(enabled); |
62 | } |
63 | ~ViewReplayEnabled() { |
64 | c10::AutogradState::get_tls_state().set_view_replay_enabled(old_); |
65 | } |
66 | bool old_; |
67 | }; |
68 | |
// RAII guard that excludes the entire autocast dispatch keyset while it
// is alive, i.e. disables autocast for the current scope/thread.
// Exposed to Python as torch._C._DisableAutocast below.
struct DisableAutocast {
  c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
};
72 | |
73 | struct EnableTorchFunction { |
74 | EnableTorchFunction() |
75 | : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) { |
76 | at::impl::PythonTorchFunctionTLS::set_disabled_state( |
77 | at::impl::TorchFunctionDisabledState::ENABLED); |
78 | } |
79 | ~EnableTorchFunction() { |
80 | at::impl::PythonTorchFunctionTLS::set_disabled_state(old_); |
81 | } |
82 | at::impl::TorchFunctionDisabledState old_; |
83 | }; |
84 | |
85 | struct EnablePythonDispatcher { |
86 | EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) { |
87 | c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter()); |
88 | } |
89 | ~EnablePythonDispatcher() { |
90 | c10::impl::PythonDispatcherTLS::set_state(old_); |
91 | } |
92 | c10::impl::PyInterpreter* old_; |
93 | }; |
94 | |
95 | } // namespace |
96 | |
// Initializes the torch._C._autograd submodule plus several bindings that
// live directly on torch._C: profiler event/result classes, RecordFunction
// helpers, saved-tensor default hooks, and the RAII guard classes defined
// above, which back Python context managers. Called once during torch
// import. Returns Py_True on success; returns nullptr with a Python error
// already set on failure.
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
  using namespace torch::autograd::profiler;
  using namespace torch::profiler::impl;
  auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor"));
  if (!tensor_module)
    return nullptr;

  // NOTE: "leaks" THPVariableClass
  THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
  if (!THPVariableClass)
    return nullptr;

  auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
  if (!autograd_module)
    return nullptr;

  // NOTE: "leaks" Function
  THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
  if (!THPFunctionClass)
    return nullptr;

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module)
    return nullptr;
  auto _C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = _C_m.def_submodule("_autograd", "autograd bindings");

  auto parameter_module =
      THPObjectPtr(PyImport_ImportModule("torch.nn.parameter"));
  if (!parameter_module)
    return nullptr;

  // NOTE: "leaks" ParameterClass
  ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter");
  if (!ParameterClass)
    return nullptr;

  // Legacy (pre-Kineto) profiler event binding.
  py::class_<LegacyEvent>(m, "ProfilerEvent")
      .def("kind", &LegacyEvent::kindStr)
      .def("name", [](const LegacyEvent& e) { return e.name(); })
      .def("thread_id", &LegacyEvent::threadId)
      .def("fwd_thread_id", &LegacyEvent::fwdThreadId)
      .def("device", &LegacyEvent::device)
      .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
      .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
      .def("has_cuda", &LegacyEvent::hasCuda)
      .def("shapes", &LegacyEvent::shapes)
      .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
      .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
      .def("handle", &LegacyEvent::handle)
      .def("node_id", &LegacyEvent::nodeId)
      .def("is_remote", &LegacyEvent::isRemote)
      .def("sequence_nr", &LegacyEvent::sequenceNr)
      .def("stack", &LegacyEvent::stack)
      .def("scope", &LegacyEvent::scope)
      .def("correlation_id", &LegacyEvent::correlationId)
      .def("start_us", &LegacyEvent::cpuUs)
      .def("flops", &LegacyEvent::flops)
      .def("is_async", &LegacyEvent::isAsync);

  // Mirror of c10::DeviceType for the profiler's Python API.
  py::enum_<c10::DeviceType>(m, "DeviceType")
      .value("CPU", c10::DeviceType::CPU)
      .value("CUDA", c10::DeviceType::CUDA)
      .value("MKLDNN", c10::DeviceType::MKLDNN)
      .value("OPENGL", c10::DeviceType::OPENGL)
      .value("OPENCL", c10::DeviceType::OPENCL)
      .value("IDEEP", c10::DeviceType::IDEEP)
      .value("HIP", c10::DeviceType::HIP)
      .value("FPGA", c10::DeviceType::FPGA)
      .value("ORT", c10::DeviceType::ORT)
      .value("XLA", c10::DeviceType::XLA)
      .value("Vulkan", c10::DeviceType::Vulkan)
      .value("Metal", c10::DeviceType::Metal)
      .value("XPU", c10::DeviceType::XPU)
      .value("MPS", c10::DeviceType::MPS)
      .value("Meta", c10::DeviceType::Meta)
      .value("HPU", c10::DeviceType::HPU)
      .value("VE", c10::DeviceType::VE)
      .value("Lazy", c10::DeviceType::Lazy)
      .value("IPU", c10::DeviceType::IPU);

  py::class_<KinetoEvent>(m, "_KinetoEvent")
      // name of the event
      .def("name", [](const KinetoEvent& e) { return e.name(); })
      // PyTorch thread id of the start callback
      .def(
          "start_thread_id",
          [](const KinetoEvent& e) { return e.startThreadId(); })
      // PyTorch thread id of the end callback
      .def(
          "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); })
      // for events of scope BACKWARD_FUNCTION - PyTorch thread id
      // of the corresponding forward op
      .def(
          "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); })
      // together with fwd_thread_id, used to uniquely identify
      // the forward op
      .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
      // absolute start time (since unix epoch) in us
      .def("start_us", [](const KinetoEvent& e) { return e.startUs(); })
      // duration in us
      .def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); })
      // used for correlation between high-level PyTorch events
      // and low-level device events
      .def(
          "correlation_id",
          [](const KinetoEvent& e) { return e.correlationId(); })
      // shapes of input tensors
      .def("shapes", [](const KinetoEvent& e) { return e.shapes().vec(); })
      .def("dtypes", [](const KinetoEvent& e) { return e.dtypes().vec(); })
      // stack traces of the PyTorch CPU events
      .def("stack", [](const KinetoEvent& e) { return e.stack().vec(); })
      // type of the RecordFunction that generated a PyTorch CPU event
      // (op, torchscript function, user label, etc)
      .def("scope", [](const KinetoEvent& e) { return e.scope(); })
      // device number, for CPU - process id
      .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); })
      // for CUDA - stream id, for CPU - start thread id
      .def(
          "device_resource_id",
          [](const KinetoEvent& e) { return e.deviceResourceId(); })
      // device type
      .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); })
      // correlation id of a linked event
      .def(
          "linked_correlation_id",
          [](const KinetoEvent& e) { return e.linkedCorrelationId(); })
      // compute flops
      .def("flops", [](const KinetoEvent& e) { return e.flops(); })
      // Whether this is async event or not
      .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
      .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
      .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });

  m.def("_soft_assert_raises", &setSoftAssertRaises);

  py::class_<ProfilerResult>(m, "_ProfilerResult")
      .def("trace_start_us", &ProfilerResult::trace_start_us)
      .def("events", &ProfilerResult::events)
      .def("experimental_event_tree", &ProfilerResult::event_tree)
#ifdef USE_KINETO
      .def("save", &ProfilerResult::save)
#endif // USE_KINETO
      ;

  m.def(
      "_enable_profiler",
      &enableProfiler,
      py::arg("config"),
      py::arg("activities"),
      py::arg("scopes") = std::unordered_set<at::RecordScope>());
  m.def("_disable_profiler", disableProfiler);
  m.def("_prepare_profiler", prepareProfiler);
  m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set
  m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set
  m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; });

  // NOTICE: These record functions are not torch operators and may not show up
  // in TorchScript tracing, FX transforms, or operator serialization. For these
  // use cases, please use `torch.profiler.record_function`.
  // Creates a new profiling scope using RecordFunction and invokes its starting
  // callbacks.
  m.def(
      "_record_function_with_args_enter",
      [](const std::string& name, py::args args) {
        using torch::autograd::profiler::PythonRecordFunction;
        auto python_rec = c10::make_intrusive<PythonRecordFunction>(
            at::RecordScope::USER_SCOPE);
        auto* rec = &python_rec->record;
        if (rec->isActive()) {
          if (rec->needsInputs()) {
            // Convert each Python argument to an IValue so the profiler can
            // record input metadata.
            auto iv_inputs = std::vector<c10::IValue>();
            for (const auto& arg : args) {
              iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg));
            }
            rec->before(
                name,
                c10::ArrayRef<const c10::IValue>(
                    iv_inputs.data(), iv_inputs.size()));
          } else {
            rec->before(name);
          }
        }
        // Hand the scope back to Python; its lifetime keeps the record open.
        return torch::jit::toPyObject(std::move(python_rec));
      });

  // Ends the profiling scope created with record_function_with_param_enter.
  m.def("_record_function_with_args_exit", [](const py::object& obj) {
    using torch::autograd::profiler::PythonRecordFunction;
    auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);

    // We don't actually need to do anything with handle just need to persist
    // the lifetime until now.
    python_record->record.end();
  });

  m.def("_supported_activities", []() {
    std::set<ActivityType> activities{ActivityType::CPU};
#if defined(USE_KINETO) && \
    (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
    if (at::getNumGPUs() > 0) {
      activities.insert(ActivityType::CUDA);
    }
#endif
    return activities;
  });

  m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
    // NOTE(review): `vc` is a copy of the tensor's VariableVersion;
    // presumably it shares the underlying counter so set_version mutates
    // the tensor's version in place -- confirm against VariableVersion.
    auto vc = torch::autograd::impl::version_counter(t);
    vc.set_version(i);
  });

  // Legacy (pre-Kineto) profiler entry points.
  m.def("_enable_profiler_legacy", enableProfilerLegacy);
  py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
      .def(py::init<bool, bool>());
  m.def(
      "_disable_profiler_legacy",
      disableProfilerLegacy,
      py::arg("profiler_disable_options") = ProfilerDisableOptions());
  m.def("_profiler_enabled", profilerEnabled);
  m.def("_profiler_type", torch::profiler::impl::profilerType);
  m.def("_enable_record_function", [](bool enable) {
    at::enableRecordFunction(enable);
  });
  m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
    auto cb =
        at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb(
            sampling_prob);
    if (is_global) {
      at::addGlobalCallback(cb);
    } else {
      at::addThreadLocalCallback(cb);
    }
  });
  m.def("_clear_callbacks", []() { at::clearCallbacks(); });
  // Saved-tensor default hooks (torch.autograd.graph.saved_tensors_hooks).
  m.def(
      "_saved_tensors_hooks_is_enabled",
      at::SavedTensorDefaultHooks::is_enabled);
  m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
  m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
  m.def(
      "_saved_tensors_hooks_get_disabled_error_message",
      at::SavedTensorDefaultHooks::get_disabled_error_message);
  m.def(
      "_push_saved_tensors_default_hooks",
      [](py::function& pack_hook, py::function& unpack_hook) {
        torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
            pack_hook, unpack_hook);
      });
  m.def("_pop_saved_tensors_default_hooks", []() {
    torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
  });

  _C_m.def(
      "_register_py_class_for_device",
      [](const std::string& device, py::object python_type_class) {
        auto cls = python_type_class.ptr();
        registerPythonTensorClass(device, cls);
      });

  _C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });

  // RAII guard classes (defined in the anonymous namespace above) exposed
  // as context-manager building blocks on torch._C.
  py::class_<c10::InferenceMode>(_C_m, "_InferenceMode").def(py::init<bool>());

  py::class_<at::impl::RestorePythonTLSSnapshot>(
      _C_m, "_RestorePythonTLSSnapshot")
      .def(py::init<>());

  py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
      .def(py::init<>());
  py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
      .def(py::init<>());
  py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
      .def(py::init<>());
  py::class_<c10::impl::DisablePythonDispatcher>(
      _C_m, "_DisablePythonDispatcher")
      .def(py::init<>());
  py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
  py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
      .def(py::init<bool>());
  // NOTE(review): std::move here binds to py::class_'s `handle scope`
  // parameter, so _C_m is not actually moved-from and stays valid below --
  // confirm against the pybind11 py::class_ constructor signature.
  py::class_<DisableAutocast>(std::move(_C_m), "_DisableAutocast")
      .def(py::init<>());
  py::class_<ViewReplayEnabled>(_C_m, "_ViewReplayEnabled")
      .def(py::init<bool>());
  py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
      .def(py::init([]() -> torch::autograd::SavedVariable {
        // SavedTensor instances come only from C++; direct construction
        // from Python is intentionally rejected.
        TORCH_CHECK(
            false,
            "Trying to create a SavedTensor object from Python is forbidden.");
      }))
      .def(
          "register_hooks",
          [](torch::autograd::SavedVariable& s,
             py::function& pack_hook,
             py::function& unpack_hook) {
            // Because we use a py::object, pybind will increment the refcount
            // of the hook functions for us
            s.register_hooks(
                std::make_unique<torch::autograd::PySavedVariableHooks>(
                    pack_hook, unpack_hook));
          });

  torch::autograd::profiler::python_tracer::init();
  Py_RETURN_TRUE;
}
402 | |
403 | namespace torch { |
404 | namespace autograd { |
405 | |
406 | static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) { |
407 | HANDLE_TH_ERRORS |
408 | if (!PyBool_Check(arg)) { |
409 | throw TypeError("enabled must be a bool (got %s)" , Py_TYPE(arg)->tp_name); |
410 | } |
411 | at::autocast::set_enabled(arg == Py_True); |
412 | Py_RETURN_NONE; |
413 | END_HANDLE_TH_ERRORS |
414 | } |
415 | |
416 | static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) { |
417 | HANDLE_TH_ERRORS |
418 | if (at::autocast::is_enabled()) { |
419 | Py_RETURN_TRUE; |
420 | } else { |
421 | Py_RETURN_FALSE; |
422 | } |
423 | END_HANDLE_TH_ERRORS |
424 | } |
425 | |
426 | static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) { |
427 | HANDLE_TH_ERRORS |
428 | if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() || |
429 | at::autocast::is_xpu_enabled()) { |
430 | Py_RETURN_TRUE; |
431 | } else { |
432 | Py_RETURN_FALSE; |
433 | } |
434 | END_HANDLE_TH_ERRORS |
435 | } |
436 | |
437 | static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { |
438 | HANDLE_TH_ERRORS |
439 | if (!PyBool_Check(arg)) { |
440 | throw TypeError("enabled must be a bool (got %s)" , Py_TYPE(arg)->tp_name); |
441 | } |
442 | at::autocast::set_cpu_enabled(arg == Py_True); |
443 | Py_RETURN_NONE; |
444 | END_HANDLE_TH_ERRORS |
445 | } |
446 | |
447 | static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { |
448 | HANDLE_TH_ERRORS |
449 | if (at::autocast::is_cpu_enabled()) { |
450 | Py_RETURN_TRUE; |
451 | } else { |
452 | Py_RETURN_FALSE; |
453 | } |
454 | END_HANDLE_TH_ERRORS |
455 | } |
456 | |
457 | static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { |
458 | HANDLE_TH_ERRORS |
459 | if (!THPDtype_Check(arg)) { |
460 | throw TypeError( |
461 | "dtype must be a torch.dtype (got %s)" , Py_TYPE(arg)->tp_name); |
462 | } |
463 | at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type; |
464 | at::autocast::set_autocast_gpu_dtype(targetType); |
465 | Py_RETURN_NONE; |
466 | END_HANDLE_TH_ERRORS |
467 | } |
468 | |
469 | static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { |
470 | HANDLE_TH_ERRORS |
471 | if (!THPDtype_Check(arg)) { |
472 | throw TypeError( |
473 | "dtype must be a torch.dtype (got %s)" , Py_TYPE(arg)->tp_name); |
474 | } |
475 | at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type; |
476 | at::autocast::set_autocast_cpu_dtype(targetType); |
477 | Py_RETURN_NONE; |
478 | END_HANDLE_TH_ERRORS |
479 | } |
480 | |
481 | static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { |
482 | HANDLE_TH_ERRORS |
483 | at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype(); |
484 | auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); |
485 | Py_INCREF(dtype); |
486 | return dtype; |
487 | END_HANDLE_TH_ERRORS |
488 | } |
489 | |
490 | static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { |
491 | HANDLE_TH_ERRORS |
492 | at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype(); |
493 | auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); |
494 | Py_INCREF(dtype); |
495 | return dtype; |
496 | END_HANDLE_TH_ERRORS |
497 | } |
498 | |
// Binding for torch._C.clear_autocast_cache: forwards to
// at::autocast::clear_cache() (presumably dropping cached autocast
// conversions -- see at::autocast) and returns None.
static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  at::autocast::clear_cache();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
505 | |
// Binding for torch._C.autocast_increment_nesting: bumps the autocast
// nesting counter and returns the new depth as a Python int.
static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::autocast::increment_nesting());
  END_HANDLE_TH_ERRORS
}
511 | |
// Binding for torch._C.autocast_decrement_nesting: decrements the autocast
// nesting counter and returns the new depth as a Python int.
static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::autocast::decrement_nesting());
  END_HANDLE_TH_ERRORS
}
517 | |
518 | static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { |
519 | HANDLE_TH_ERRORS |
520 | if (at::autocast::is_autocast_cache_enabled()) { |
521 | Py_RETURN_TRUE; |
522 | } else { |
523 | Py_RETURN_FALSE; |
524 | } |
525 | END_HANDLE_TH_ERRORS |
526 | } |
527 | |
528 | static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { |
529 | HANDLE_TH_ERRORS |
530 | if (!PyBool_Check(arg)) { |
531 | throw TypeError("enabled must be a bool (got %s)" , Py_TYPE(arg)->tp_name); |
532 | } |
533 | at::autocast::set_autocast_cache_enabled(arg == Py_True); |
534 | Py_RETURN_NONE; |
535 | END_HANDLE_TH_ERRORS |
536 | } |
537 | |
538 | static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) { |
539 | HANDLE_TH_ERRORS |
540 | if (!PyBool_Check(arg)) { |
541 | throw TypeError("enabled must be a bool (got %s)" , Py_TYPE(arg)->tp_name); |
542 | } |
543 | GradMode::set_enabled(arg == Py_True); |
544 | Py_RETURN_NONE; |
545 | END_HANDLE_TH_ERRORS |
546 | } |
547 | |
548 | static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) { |
549 | HANDLE_TH_ERRORS |
550 | if (GradMode::is_enabled()) { |
551 | Py_RETURN_TRUE; |
552 | } else { |
553 | Py_RETURN_FALSE; |
554 | } |
555 | END_HANDLE_TH_ERRORS |
556 | } |
557 | |
558 | static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { |
559 | HANDLE_TH_ERRORS |
560 | if (!PyBool_Check(arg)) { |
561 | throw TypeError("enabled must be a bool (got %s)" , Py_TYPE(arg)->tp_name); |
562 | } |
563 | c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True); |
564 | Py_RETURN_NONE; |
565 | END_HANDLE_TH_ERRORS |
566 | } |
567 | |
568 | static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { |
569 | HANDLE_TH_ERRORS |
570 | if (c10::AutogradState::get_tls_state().get_fw_grad_mode()) { |
571 | Py_RETURN_TRUE; |
572 | } else { |
573 | Py_RETURN_FALSE; |
574 | } |
575 | END_HANDLE_TH_ERRORS |
576 | } |
577 | |
578 | static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) { |
579 | HANDLE_TH_ERRORS |
580 | if (c10::InferenceMode::is_enabled()) { |
581 | Py_RETURN_TRUE; |
582 | } else { |
583 | Py_RETURN_FALSE; |
584 | } |
585 | END_HANDLE_TH_ERRORS |
586 | } |
587 | |
// Binding for torch._C.set_anomaly_enabled(enabled, check_nan=True):
// parses the two bools with PythonArgParser and forwards them to
// AnomalyMode::set_enabled. Returns None.
static PyObject* set_anomaly_mode_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "set_anomaly_enabled(bool enabled, bool check_nan=True)",
  });
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  AnomalyMode::set_enabled(r.toBool(0), r.toBool(1));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
602 | |
603 | static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) { |
604 | HANDLE_TH_ERRORS |
605 | if (AnomalyMode::is_enabled()) { |
606 | Py_RETURN_TRUE; |
607 | } else { |
608 | Py_RETURN_FALSE; |
609 | } |
610 | END_HANDLE_TH_ERRORS |
611 | } |
612 | |
613 | static PyObject* is_anomaly_check_nan_enabled( |
614 | PyObject* _unused, |
615 | PyObject* arg) { |
616 | HANDLE_TH_ERRORS |
617 | if (AnomalyMode::should_check_nan()) { |
618 | Py_RETURN_TRUE; |
619 | } else { |
620 | Py_RETURN_FALSE; |
621 | } |
622 | END_HANDLE_TH_ERRORS |
623 | } |
624 | |
// Binding for torch._C._enter_dual_level: opens a new forward-AD dual
// level and returns its index as a Python int.
static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  // It is unlikely that the depth of forward nesting will overflow int64_t so
  // we just static cast here.
  return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
  END_HANDLE_TH_ERRORS
}
632 | |
633 | static PyObject* python_exit_dual_level( |
634 | PyObject* _unused, |
635 | PyObject* args, |
636 | PyObject* kwargs) { |
637 | HANDLE_TH_ERRORS |
638 | static PythonArgParser parser({"exit_dual_level(int64_t level)" }); |
639 | |
640 | ParsedArgs<1> parsed_args; |
641 | auto _r = parser.parse(args, kwargs, parsed_args); |
642 | |
643 | auto idx = _r.toInt64(0); |
644 | // Make sure the given index is valid before casting it |
645 | TORCH_CHECK(idx >= 0, "Dual level must be a positive number." ); |
646 | forward_ad::exit_dual_level(static_cast<uint64_t>(idx)); |
647 | Py_RETURN_NONE; |
648 | END_HANDLE_TH_ERRORS |
649 | } |
650 | |
651 | static PyObject* is_torch_function_mode_enabled( |
652 | PyObject* _unused, |
653 | PyObject* _unused2) { |
654 | HANDLE_TH_ERRORS |
655 | if (at::impl::torch_function_mode_enabled()) { |
656 | Py_RETURN_TRUE; |
657 | } else { |
658 | Py_RETURN_FALSE; |
659 | } |
660 | END_HANDLE_TH_ERRORS |
661 | } |
662 | |
// Binding for torch._C._push_on_torch_function_stack: pushes a
// __torch_function__ mode object onto the thread-local stack.
// A None argument is deliberately a no-op. Returns None.
static PyObject* push_on_torch_function_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    // The new reference created here is presumably owned (and eventually
    // decref'd) by the SafePyObject -- confirm against c10::SafePyObject.
    Py_INCREF(arg);
    at::impl::PythonTorchFunctionTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
675 | |
// Binding for torch._C._pop_torch_function_stack: pops the top
// __torch_function__ mode off the thread-local stack and returns it
// (new reference).
static PyObject* pop_torch_function_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
  // ptr() hands back a borrowed pointer; INCREF before returning to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
686 | |
// Binding for torch._C._get_function_stack_at(level): returns the
// __torch_function__ mode at index `level` of the thread-local stack
// without removing it (new reference).
static PyObject* get_function_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
  // ptr() hands back a borrowed pointer; INCREF before returning to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
704 | |
705 | static PyObject* len_torch_function_stack( |
706 | PyObject* _unused, |
707 | PyObject* _unused2) { |
708 | HANDLE_TH_ERRORS |
709 | const auto len = at::impl::PythonTorchFunctionTLS::stack_len(); |
710 | return utils::wrap(static_cast<int64_t>(len)); |
711 | END_HANDLE_TH_ERRORS |
712 | } |
713 | |
// Binding for torch._C._push_on_torch_dispatch_stack: pushes a
// __torch_dispatch__ mode object onto the thread-local stack.
// A None argument is deliberately a no-op. Returns None.
static PyObject* push_on_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    // The new reference created here is presumably owned (and eventually
    // decref'd) by the SafePyObject -- confirm against c10::SafePyObject.
    Py_INCREF(arg);
    c10::impl::TorchDispatchModeTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
726 | |
// Binding for torch._C._pop_torch_dispatch_stack: pops the top
// __torch_dispatch__ mode off the thread-local stack and returns it
// (new reference).
static PyObject* pop_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack();
  // ptr() hands back a borrowed pointer; INCREF before returning to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
737 | |
// Binding for torch._C._get_dispatch_stack_at(level): returns the
// __torch_dispatch__ mode at index `level` of the thread-local stack
// without removing it (new reference).
static PyObject* get_dispatch_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
  // ptr() hands back a borrowed pointer; INCREF before returning to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
755 | |
756 | static PyObject* len_torch_dispatch_stack( |
757 | PyObject* _unused, |
758 | PyObject* _unused2) { |
759 | HANDLE_TH_ERRORS |
760 | const auto len = c10::impl::TorchDispatchModeTLS::stack_len(); |
761 | return utils::wrap(static_cast<int64_t>(len)); |
762 | END_HANDLE_TH_ERRORS |
763 | } |
764 | |
// autograd methods on torch._C
// Method table registered onto the torch._C module via python_functions()
// below. Each entry is {python_name, c_function, calling_convention,
// docstring}; METH_O takes one positional arg, METH_NOARGS none, and
// METH_VARARGS|METH_KEYWORDS goes through castPyCFunctionWithKeywords.
static PyMethodDef methods[] = { // NOLINT
    {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
    {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
    {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
    {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
    {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
    {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr},
    {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
    {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
    {"autocast_increment_nesting",
     autocast_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"autocast_decrement_nesting",
     autocast_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"is_autocast_cache_enabled",
     is_autocast_cache_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
    {"set_anomaly_enabled",
     castPyCFunctionWithKeywords(set_anomaly_mode_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
    {"is_anomaly_check_nan_enabled",
     is_anomaly_check_nan_enabled,
     METH_NOARGS,
     nullptr},
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_torch_function_mode_enabled",
     is_torch_function_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_function_stack",
     push_on_torch_function_stack,
     METH_O,
     nullptr},
    {"_pop_torch_function_stack",
     pop_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_get_function_stack_at",
     castPyCFunctionWithKeywords(get_function_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_function_stack",
     len_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_dispatch_stack",
     push_on_torch_dispatch_stack,
     METH_O,
     nullptr},
    {"_pop_torch_dispatch_stack",
     pop_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    {"_get_dispatch_stack_at",
     castPyCFunctionWithKeywords(get_dispatch_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_dispatch_stack",
     len_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    // Sentinel terminating the table (required by CPython).
    {nullptr, nullptr, 0, nullptr}};
849 | |
// Exposes the method table above so the caller assembling the torch._C
// module can register these functions on it.
PyMethodDef* python_functions() {
  return methods;
}
853 | |
854 | } // namespace autograd |
855 | } // namespace torch |
856 | |