1#pragma once
2
3#include <c10/core/impl/LocalDispatchKeySet.h>
4
5namespace at {
6namespace impl {
7
// VmapMode tracks a thread-local count of how many nested vmaps we are
// currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the C++ API for torch.vmap. No C++ API for vmap exists yet.
13
14struct TORCH_API VmapMode {
15 // Returns the vmap level, aka the count of how many nested vmaps we're in.
16 static int64_t current_vmap_level();
17
18 // Increment the count of nested vmaps. If this causes the vmap level to be
19 // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
20 static int64_t increment_nesting();
21
22 // Decrements the count of nested vmaps. If this causes the vmap level to be
23 // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
24 static int64_t decrement_nesting();
25};
26
27} // namespace impl
28} // namespace at
29