1#pragma once
2
3#include <c10/core/impl/LocalDispatchKeySet.h>
4#include <c10/macros/Export.h>
5#include <c10/macros/Macros.h>
6
7// NOTE [Tracing Mode Switches]
8//
9// Historically, tracing function was controlled by two switches:
10//
11// - `AutoDispatchBelowADInplaceOrView` guard
12//
13// Tracing function used to be script-generated inside `VariableType_*.cpp`
14// kernels, sharing the same `Autograd` dispatch key with autograd function.
15// Therefore, before tracing function was moved out of VariableType,
16// `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a
17// side effect of disabling `Autograd` dispatching.
18//
19// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
20//
21// It stores tracing data in a `TracingState` object in TLS. If the
22// `TracingState` object in TLS is `null`, then tracing is paused.
23//
24// The `TracingState` object is created in `tracer::trace()` - the main
25// entrance of tracing function. It's temporarily set to `null` inside
26// generated VariableType (now TraceType) to bypass tracing for intermediate
27// ops (ops being called by other ops). After the intermediate op call
28// finishes it's set back to the original `TracingState` object.
29//
// The `TracingState` object in TLS can also be read/written via its Python
31// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
32// which are also exposed as `TORCH_API`.
33//
34// Two new switches were introduced since tracing function was moved out of
35// VariableType:
36//
37// - `tracer::impl::set_dispatch_enabled()` API
38//
// Unlike the special `Autograd` dispatch key, which is included in the
// dispatch key set by default, the `Tracer` dispatch key is off by default.
// The dispatching switch can be toggled via this new API.
42//
43// - `tracer::impl::NoTracerDispatchMode` guard
44//
45// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
46// after tracing was moved out of VariableType.
47//
48// Before tracing function was moved out of VariableType, tracing was enabled
49// when the following conditions are satisfied:
50//
51// 1) `TracingState` object in TLS != null;
52// - Either inside the execution scope of `tracer::trace()`, or
53// - Eagerly called `setTracingState()` with non-null object.
54// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
55//
56// After:
57//
58// 1) `TracingState` object in TLS != null;
59// 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
// 3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
61//
62// [TODOs]
63//
64// - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()`
65//
66// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
67// to keep the semantics exactly the same as before - it's confusing to keep
68// both switches, though. We should consider simplifying/limiting the exposed
69// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
70// these two can be unified.
71//
72// - `AutoDispatchBelowADInplaceOrView` v.s.
73// `tracer::impl::NoTracerDispatchMode`
74//
75// We don't need to always set both guards together to keep semantics
// unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
// we don't need to set the new tracer guard:
78//
79// * Script-generated VariableType kernels. The guard is not necessary as
80// tracing is already disabled explicitly by `setTracingState(null)` in
81// generated TraceType kernels - we could keep it as is or use the new guard
82// instead.
83//
84// * Custom ops. Will be handled by fallback kernel for `Tracer`.
85//
86// * Functions that are not likely to be called in tracing context (no python
87// binding / not an operator), e.g.: all mobile forward() wrappers, test
// binaries, etc.
89//
90// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
91// It's not necessary as tracing is off by default.
92//
// For the remaining cases we might need to have both:
94//
95// * Functions that might be reachable from eager mode python (especially
96// factory methods), e.g.:
97// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
98// Without the new guard it will add `aten::empty` to the traced graph.
99//
100// * Some manually maintained functions, e.g.:
101// `torch/csrc/autograd/VariableTypeManual.cpp`.
102// Set the new guard if it's not obvious whether `setTracingState(null)`
103// has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
104// guard.
105//
// We might need to tweak the usage of the new guard to optimize/fix things.
107// It should only affect the correctness of tracing function, because the
108// guard is essentially no-op when the master `setTracingState()` switch is
109// off.
110
111namespace at {
// TODO: move this from `at::` to `torch::jit::` after
113// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
114
115namespace tracer {
116namespace impl {
117
118static inline bool is_dispatch_enabled() {
119 return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
120 !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
121}
122
123static inline void set_dispatch_enabled(bool enabled) {
124 TORCH_INTERNAL_ASSERT(
125 !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
126 "Cannot enable tracing within the scope of NoTracerDispatchMode!");
127 c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
128}
129
130struct NoTracerDispatchMode {
131 c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
132};
133
134} // namespace impl
135} // namespace tracer
136} // namespace at
137