#include <ATen/LegacyVmapMode.h>

#include <cassert>
2
3namespace at {
4namespace impl {
5
// Vmap nesting depth for the calling thread (one counter per thread, since
// the variable is thread_local). It starts at 0, meaning "not inside any
// vmap". VmapMode::increment_nesting() raises it by one and
// VmapMode::decrement_nesting() lowers it by one.
thread_local int64_t VmapMode_current_vmap_level{0};
7
8int64_t VmapMode::current_vmap_level() {
9 return VmapMode_current_vmap_level;
10}
11
12int64_t VmapMode::increment_nesting() {
13 VmapMode_current_vmap_level++;
14 if (VmapMode_current_vmap_level == 1) {
15 c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
16 }
17 return VmapMode_current_vmap_level;
18}
19
20int64_t VmapMode::decrement_nesting() {
21 VmapMode_current_vmap_level--;
22 if (VmapMode_current_vmap_level == 0) {
23 c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
24 }
25 return VmapMode_current_vmap_level;
26}
27} // namespace impl
28} // namespace at
29