1 | #pragma once |
2 | |
3 | #include <c10/core/impl/LocalDispatchKeySet.h> |
4 | #include <c10/macros/Export.h> |
5 | #include <c10/macros/Macros.h> |
6 | |
7 | // NOTE [Tracing Mode Switches] |
8 | // |
9 | // Historically, tracing function was controlled by two switches: |
10 | // |
11 | // - `AutoDispatchBelowADInplaceOrView` guard |
12 | // |
13 | // Tracing function used to be script-generated inside `VariableType_*.cpp` |
14 | // kernels, sharing the same `Autograd` dispatch key with autograd function. |
15 | // Therefore, before tracing function was moved out of VariableType, |
16 | // `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a |
17 | // side effect of disabling `Autograd` dispatching. |
18 | // |
19 | // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h` |
20 | // |
21 | // It stores tracing data in a `TracingState` object in TLS. If the |
22 | // `TracingState` object in TLS is `null`, then tracing is paused. |
23 | // |
24 | // The `TracingState` object is created in `tracer::trace()` - the main |
25 | // entrance of tracing function. It's temporarily set to `null` inside |
26 | // generated VariableType (now TraceType) to bypass tracing for intermediate |
27 | // ops (ops being called by other ops). After the intermediate op call |
28 | // finishes it's set back to the original `TracingState` object. |
29 | // |
30 | // The `TracingState` object in TLS can also be read/written via its Python |
31 | // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, |
32 | // which are also exposed as `TORCH_API`. |
33 | // |
34 | // Two new switches were introduced since tracing function was moved out of |
35 | // VariableType: |
36 | // |
37 | // - `tracer::impl::set_dispatch_enabled()` API |
38 | // |
39 | // Unlike the special `Autograd` dispatch key which is included in dispatch |
40 | // key set by default, `Tracer` dispatch key is off by default. The |
41 | // dispatching switch can be toggled via this new API. |
42 | // |
43 | // - `tracer::impl::NoTracerDispatchMode` guard |
44 | // |
45 | // It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView` |
46 | // after tracing was moved out of VariableType. |
47 | // |
48 | // Before tracing function was moved out of VariableType, tracing was enabled |
49 | // when the following conditions are satisfied: |
50 | // |
51 | // 1) `TracingState` object in TLS != null; |
52 | // - Either inside the execution scope of `tracer::trace()`, or |
53 | // - Eagerly called `setTracingState()` with non-null object. |
54 | // 2) Not inside `AutoDispatchBelowADInplaceOrView` scope; |
55 | // |
56 | // After: |
57 | // |
58 | // 1) `TracingState` object in TLS != null; |
59 | // 2) Has called `tracer::impl::set_dispatch_enabled(true)`; |
60 | // 3) Not inside `tracer::impl::NoTracerDispatchMode` scope; |
61 | // |
62 | // [TODOs] |
63 | // |
64 | // - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()` |
65 | // |
66 | // Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()` |
67 | // to keep the semantics exactly the same as before - it's confusing to keep |
68 | // both switches, though. We should consider simplifying/limiting the exposed |
69 | // `setTracingState()` Python/C++ APIs (and other APIs calling it) so that |
70 | // these two can be unified. |
71 | // |
72 | // - `AutoDispatchBelowADInplaceOrView` v.s. |
73 | // `tracer::impl::NoTracerDispatchMode` |
74 | // |
75 | // We don't need to always set both guards together to keep semantics |
76 | // unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView` |
77 | // we don't need to set the new tracer guard: |
78 | // |
79 | // * Script-generated VariableType kernels. The guard is not necessary as |
80 | // tracing is already disabled explicitly by `setTracingState(null)` in |
81 | // generated TraceType kernels - we could keep it as is or use the new guard |
82 | // instead. |
83 | // |
84 | // * Custom ops. Will be handled by fallback kernel for `Tracer`. |
85 | // |
86 | // * Functions that are not likely to be called in tracing context (no python |
87 | // binding / not an operator), e.g.: all mobile forward() wrappers, test |
88 | // binaries, etc. |
89 | // |
90 | // * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp. |
91 | // It's not necessary as tracing is off by default. |
92 | // |
93 | // For the remaining cases we might need to have both: |
94 | // |
95 | // * Functions that might be reachable from eager mode python (especially |
96 | // factory methods), e.g.: |
97 | // `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`. |
98 | // Without the new guard it will add `aten::empty` to the traced graph. |
99 | // |
100 | // * Some manually maintained functions, e.g.: |
101 | // `torch/csrc/autograd/VariableTypeManual.cpp`. |
102 | // Set the new guard if it's not obvious whether `setTracingState(null)` |
103 | // has been called before it reaches the `AutoDispatchBelowADInplaceOrView` |
104 | // guard. |
105 | // |
106 | // We might need to tweak the usage of the new guard to optimize/fix things. |
107 | // It should only affect the correctness of tracing function, because the |
108 | // guard is essentially no-op when the master `setTracingState()` switch is |
109 | // off. |
110 | |
111 | namespace at { |
112 | // TODO: move this from `at::` to `jit::torch::` after |
113 | // `aten/src/ATen/cpp_custom_type_hack.h` is removed. |
114 | |
115 | namespace tracer { |
116 | namespace impl { |
117 | |
118 | static inline bool is_dispatch_enabled() { |
119 | return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) && |
120 | !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer); |
121 | } |
122 | |
123 | static inline void set_dispatch_enabled(bool enabled) { |
124 | TORCH_INTERNAL_ASSERT( |
125 | !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer), |
126 | "Cannot enable tracing within the scope of NoTracerDispatchMode!" ); |
127 | c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled); |
128 | } |
129 | |
130 | struct NoTracerDispatchMode { |
131 | c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer}; |
132 | }; |
133 | |
134 | } // namespace impl |
135 | } // namespace tracer |
136 | } // namespace at |
137 | |