1 | #pragma once |
---|---|
2 | |
3 | #include <c10/core/AutogradState.h> |
4 | #include <c10/macros/Macros.h> |
5 | |
6 | namespace c10 { |
7 | |
8 | struct C10_API GradMode { |
9 | static bool is_enabled(); |
10 | static void set_enabled(bool enabled); |
11 | }; |
12 | |
13 | // A RAII, thread local (!) guard that enables or disables grad mode upon |
14 | // construction, and sets it back to the original value upon destruction. |
15 | struct C10_API AutoGradMode { |
16 | AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { |
17 | GradMode::set_enabled(enabled); |
18 | } |
19 | ~AutoGradMode() { |
20 | GradMode::set_enabled(prev_mode); |
21 | } |
22 | bool prev_mode; |
23 | }; |
24 | |
25 | // A RAII, thread local (!) guard that stops future operations from building |
26 | // gradients. |
27 | struct C10_API NoGradGuard : public AutoGradMode { |
28 | NoGradGuard() : AutoGradMode(/*enabled=*/false) {} |
29 | }; |
30 | |
31 | // A RAII, thread local (!) guard that enables or disables forward grad mode |
32 | // upon construction, and sets it back to the original value upon destruction. |
33 | struct C10_API AutoFwGradMode { |
34 | AutoFwGradMode(bool enabled) |
35 | : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { |
36 | AutogradState::get_tls_state().set_fw_grad_mode(enabled); |
37 | } |
38 | ~AutoFwGradMode() { |
39 | AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); |
40 | } |
41 | bool prev_mode; |
42 | }; |
43 | |
44 | } // namespace c10 |
45 |