#pragma once

#include <c10/core/impl/LocalDispatchKeySet.h>

#include <cstdint>

5 | namespace at { |
6 | namespace impl { |
7 | |
8 | // VmapMode contains a thread local count of how many nested vmaps |
9 | // we are currently inside. That number is known as the `vmap level`. |
10 | // VmapMode is used in the implementation of the Python `torch.vmap` API. |
11 | // |
12 | // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. |
13 | |
14 | struct TORCH_API VmapMode { |
15 | // Returns the vmap level, aka the count of how many nested vmaps we're in. |
16 | static int64_t current_vmap_level(); |
17 | |
18 | // Increment the count of nested vmaps. If this causes the vmap level to be |
19 | // greater than 0, then it enables DispatchKey::VmapMode on all tensors. |
20 | static int64_t increment_nesting(); |
21 | |
22 | // Decrements the count of nested vmaps. If this causes the vmap level to be |
23 | // equal to 0, then it disables DispatchKey::VmapMode on all tensors. |
24 | static int64_t decrement_nesting(); |
25 | }; |
26 | |
27 | } // namespace impl |
28 | } // namespace at |
29 | |