1 | #pragma once |
2 | |
3 | #include <ATen/LegacyBatchedTensorImpl.h> |
4 | #include <ATen/core/IListRef.h> |
5 | |
6 | namespace at { |
7 | |
8 | // This file contains abstractions used for transforming *logical* vmap |
9 | // arguments into *physical* arguments. (Keep reading for definitions of these |
10 | // terms). |
11 | |
12 | // NOTE: [Logical vs physical args] |
13 | // Consider the following vmap. |
14 | // vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) |
15 | // This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], |
16 | // with batch dims 0 and 2: |
17 | // BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) |
18 | // |
19 | // We say the *logical* view of the tensor has size [3] -- tensors inside |
20 | // `func` appear to have size [3]. |
21 | // However, the *physical* underlying tensor (the one passed to vmap) has size |
22 | // [2, 3, 4]. |
23 | // |
24 | // This notion of logical vs physical also extends to non-tensor arguments. |
25 | // Consider the previous tensor; let's assume the user called |
26 | // `torch.sum(tensor, dim=0)` inside of `func`. Then the logical |
27 | // dimension they are reducing over is dim 0 but the physical dim is dim 1 |
28 | // (the first non-batch dimension) |
29 | |
30 | // Forward declared; see NOTE: [What is a VmapPhysicalView?] |
31 | struct VmapPhysicalView; |
32 | |
// Most PyTorch operators take 4 or fewer inputs, so a SmallVector of this
// inline capacity avoids a heap allocation in the common case of one
// physical view per operator input.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec =
    SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// Pytorch generally advertises good performance for <= 5 dims.
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
43 | |
// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.
51 | |
52 | // VmapTransform for operators that take tensors with multiple batch dims. |
53 | // Given one or more logical views on Tensors, `logicalToPhysical` |
54 | // permutes all of the batch dims to the front of the tensor, aligns |
55 | // and expands the batch dims to match each other (according to their `level`), |
56 | // and returns a VmapPhysicalView on the tensor(s). |
struct TORCH_API MultiBatchVmapTransform {
  // Single-input variant: returns a physical view of `logical_tensor` with
  // its batch dims permuted to the front.
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  // Multi-input variant: additionally aligns and expands the batch dims of
  // all inputs to match each other (by `level`), returning one physical
  // view per input.
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
61 | |
62 | // VmapTransform for operators that broadcast all inputs. |
63 | // Given some logical views on Tensors, `logicalToPhysical`: |
64 | // - permutes all of the batch dims to the front of the tensors |
65 | // - aligns all the batch dims to the collective levels of all of the tensors. |
66 | // If a tensor does not have a batch dim for a vmap level, then it receives |
67 | // a size-one dimension for said level. |
68 | // - aligns the non-batch dims to have the same dimensionality, adding extra |
69 | // size-1 dimensions in between the batch dimensions and the non-batch |
70 | // dimensions so that the batch dimensions are lined up from the right. |
71 | // |
72 | // For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch |
73 | // dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap |
74 | // tensors of size (B, 1, 2) and (B, 3, 2). |
75 | // |
76 | // Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns |
77 | // VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't |
78 | // actually *need* to return a tensor of size (1, 2) for the second tensor |
79 | // because the broadcasting operation takes care of that for us, but we do |
80 | // it anyways to keep things simple. |
struct TORCH_API BroadcastingVmapTransform {
  // Returns one physical view per logical tensor, with batch dims moved to
  // the front and both batch and non-batch dims aligned across all inputs
  // (see the struct-level comment above for the exact alignment rules).
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
84 | |
85 | // Forward declared, if you're reading this file head to toe, don't worry about |
86 | // it yet. |
87 | struct VmapPhysicalToLogicalMap; |
88 | |
89 | // NOTE: [What is a VmapPhysicalView?] |
90 | // VmapPhysicalView represents a physical view on a Tensor. |
91 | // |
92 | // One can use it to further convert logical dimension indices, logical shapes, |
93 | // and more to their physical variants, or convert a new (physical) tensor into |
94 | // a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). |
95 | // |
96 | // VmapPhysicalView stores a physical tensor with all of its batch dimensions at |
97 | // the front and some levels that correspond to said batch dimensions. |
98 | // |
99 | // The levels bitset specifies which vmap levels correspond to the batch |
100 | // dimensions at the front of the tensor. In particular, the number of set bits |
101 | // corresponds to the number of batch dimensions on `tensor` and the rightmost |
102 | // bit of `levels` specifies the maximum number of nested vmaps we are in at |
103 | // this point in time. |
104 | // For example, given: |
105 | // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) |
106 | // |
107 | // Rightmost bit of `levels` is 3 indicating the number of nested vmaps less |
108 | // than or equal to 3. |
109 | // bitset: 010100 |
110 | // ^ |
111 | // | |
112 | // levels: 012345 |
113 | struct TORCH_API VmapPhysicalView { |
114 | VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels) |
115 | : levels_(levels), tensor_(tensor) { |
116 | TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor)); |
117 | } |
118 | |
119 | Tensor& tensor() { |
120 | return tensor_; |
121 | } |
122 | const Tensor& tensor() const { |
123 | return tensor_; |
124 | } |
125 | |
126 | // Maps logical dim indices to physical dim indices. Also does dim wrapping. |
127 | // |
128 | // For example, given: |
129 | // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) |
130 | // |
131 | // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. |
132 | // This is because the size of levels tell us that the first two dimensions |
133 | // of `tensor_` are batch dimensions, so a logical dim of `n` is actually |
134 | // a physical dim of `n + 2`. |
135 | VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const; |
136 | int64_t getPhysicalDim(int64_t logical_dim) const; |
137 | |
138 | // Returns a VmapPhysicalToLogicalMap object. This can be used for |
139 | // mapping a physical tensor to a new logical tensor (BatchedTensor) |
140 | VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; |
141 | |
142 | // Maps a logical shape to a physical shape by pre-pending the batch |
143 | // sizes to the logical shape. |
144 | VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; |
145 | |
146 | int64_t numBatchDims() const; |
147 | |
148 | private: |
149 | int64_t numLogicalDims() const; |
150 | |
151 | std::bitset<kVmapNumLevels> levels_; |
152 | Tensor tensor_; |
153 | }; |
154 | |
155 | // Convenience struct used for mapping a physical tensor (a non-BatchedTensor) |
156 | // to a logical one (BatchedTensor). It holds some levels that are used to do |
157 | // the mapping and assumes that the batch dimensions in the physical tensor all |
158 | // occur at the front of the tensor. |
struct TORCH_API VmapPhysicalToLogicalMap {
  // `levels` is the set of vmap levels corresponding to the batch dims at
  // the front of the physical tensors this map will be applied to.
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
      : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};
181 | |
182 | } // namespace at |
183 | |