#pragma once

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/core/IListRef.h>

namespace at {

// This file contains abstractions used for transforming *logical* vmap
// arguments into *physical* arguments. (Keep reading for definitions of these
// terms).

// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0, but the physical dim is dim 1
// (the first non-batch dimension).
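//
// In code, the BatchedTensor above could be built roughly as follows. This is
// a sketch, not code from this file: `makeBatched` and `BatchDim` are the
// helpers declared in ATen/LegacyBatchedTensorImpl.h (included above).
//   Tensor physical = at::ones({2, 3, 4});
//   Tensor logical = makeBatched(
//       physical,
//       {BatchDim(/*level=*/1, /*dim=*/0), BatchDim(/*level=*/2, /*dim=*/2)});
//   // logical.sizes() reports [3]; a reduction over logical dim 0 inside
//   // `func` corresponds to physical dim 1.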

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec =
    SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical results back to logical
// arguments. (See the sketch after MultiBatchVmapTransform below.)

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
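
// As a concrete illustration of the pattern described in
// NOTE: [What is a VmapTransform?], a batching rule for sum might use this
// transform roughly as follows. This is a sketch in the spirit of the legacy
// batching rules (see BatchingRegistrations.cpp), not code from this file:
//   Tensor sum_batching_rule(
//       const Tensor& self, IntArrayRef dims, bool keepdim,
//       std::optional<ScalarType> dtype) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto dims_physical = self_physical.getPhysicalDims(dims);
//     auto result =
//         at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }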

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch
//   dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
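
// For instance, a binary pointwise batching rule might use this transform
// roughly as follows (again a sketch, not code from this file):
//   Tensor mul_batching_rule(const Tensor& self, const Tensor& other) {
//     auto physical_args =
//         BroadcastingVmapTransform::logicalToPhysical({self, other});
//     auto result =
//         at::mul(physical_args[0].tensor(), physical_args[1].tensor());
//     return physical_args[0].getPhysicalToLogicalMap().apply(result);
//   }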

// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// set bit of `levels` specifies the maximum number of nested vmaps we are in
// at this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// the rightmost set bit of `levels` is 3, indicating that the number of nested
// vmaps is less than or equal to 3.
//   bitset: 010100
//               ^
//               |
//   levels: 012345
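//
// In code, this could be constructed roughly as follows (a sketch; the tensor
// and levels mirror the example above):
//   std::bitset<kVmapNumLevels> levels;
//   levels[1] = true;  // vmap level 1
//   levels[3] = true;  // vmap level 3
//   VmapPhysicalView physical_view(at::ones({2, 3, 4, 5, 6}), levels);
//   // levels.count() == 2 == physical_view.numBatchDims()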
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  }

  Tensor& tensor() {
    return tensor_;
  }
  const Tensor& tensor() const {
    return tensor_;
  }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the number of set bits in `levels` tells us that the first
  // two dimensions of `tensor_` are batch dimensions, so a logical dim of `n`
  // is actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by pre-pending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};
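
// A rough usage sketch tying the above together (hypothetical names; the
// sizes follow the getPhysicalDims example above):
//   auto physical_view =
//       MultiBatchVmapTransform::logicalToPhysical(logical_tensor);
//   int64_t physical_dim = physical_view.getPhysicalDim(0);     // == 2
//   auto physical_shape = physical_view.getPhysicalShape({5});  // {2, 3, 5}
//   // ... call at:: operators on physical_view.tensor() ...
//   Tensor logical_result =
//       physical_view.getPhysicalToLogicalMap().apply(physical_result);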

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
      : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2),(dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};
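
// In code, the apply() example above would look roughly like this (a sketch;
// the tensor and level choices are hypothetical):
//   auto x = at::ones({2, 3, 5, 7});
//   std::bitset<kVmapNumLevels> levels;
//   levels[2] = true;
//   levels[4] = true;
//   Tensor logical = VmapPhysicalToLogicalMap(levels).apply(x);
//   // `logical` is a BatchedTensor viewing x with batch dims 0 (level 2)
//   // and 1 (level 4); its logical size is [5, 7].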

} // namespace at