#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at {

// Checks if the batch dims in `bdims` appear at the front of the tensor,
// in order (i.e., bdims[i] lives at dimension i).
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (const auto idx : c10::irange(static_cast<int64_t>(bdims.size()))) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
  }
  return true;
}
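
// Example for areBdimsAtFrontInOrder (hypothetical bdims, for illustration):
// {(lvl=0, dim=0), (lvl=1, dim=1)} is at the front in order, whereas
// {(lvl=0, dim=1), (lvl=1, dim=0)} and {(lvl=0, dim=0), (lvl=1, dim=2)}
// are not.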

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  VmapDimVector permutation(sizes.size(), 0);
  const auto is_bdim = createBatchDimBitset(bdims);
  // First place the batch dims at the front, in the order they appear in
  // `bdims`...
  int64_t idx = 0;
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  // ...then append the remaining (non-batch) dims in their original order.
  for (const auto ptr : c10::irange(sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}
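
// Example for permuteBatchDimsToFront (hypothetical values): for a physical
// tensor of size [2, B0, B1] with bdims {(lvl=0, dim=1), (lvl=1, dim=2)},
// the computed permutation is {1, 2, 0} and the result has size [B0, B1, 2].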

VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}
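
// Usage sketch (illustrative only): for a BatchedTensor of logical size [3]
// wrapping a physical tensor of size [B0, 3] batched at level 0, the returned
// VmapPhysicalView holds the [B0, 3] physical tensor and a levels bitset with
// only bit 0 set.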

int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
    auto logical_dims = opt_logical_dims.value();
    for (auto dim : logical_dims) {
      result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
    }
  } else {
    for (const auto dim : c10::irange(logical_ndim)) {
      result.push_back(dim + numBatchDims());
    }
  }
  return result;
}

int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}
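
// Example for getPhysicalDim(s) (hypothetical values): with 2 batch dims and
// 3 logical dims (5 physical dims), getPhysicalDim(-1) == 4 and
// getPhysicalDims({0, 2}) returns {2, 4}: each logical dim is wrapped and
// then shifted past the batch dims.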

VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}
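
// Example for getPhysicalShape (hypothetical values): if the physical tensor
// has sizes [B0, B1, 2, 3] with two batch dims, getPhysicalShape({5}) returns
// [B0, B1, 5]: the batch sizes are kept and the logical shape replaces the
// example (non-batch) sizes.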

static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims bdims;
  int64_t dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (!levels_bitset[level]) {
      continue;
    }
    bdims.emplace_back(level, dim++);
  }
  return bdims;
}
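
// Example for computeFrontBatchDimsFromLevels (hypothetical values): a levels
// bitset with bits 1 and 3 set yields bdims {(lvl=1, dim=0), (lvl=3, dim=1)}.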

// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor, std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (batched) {
    return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
  }
  return {self, 0};
}

// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and `requested_example_dim` non-batch dimensions.
//
// This function is useful for preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are
// batch dimensions, we must align the batch dimensions and non-batch dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
//    level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, B1, 2, 3].
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);

  TORCH_INTERNAL_ASSERT(
      (tensor_levels | requested_levels) == requested_levels,
      "`requested_levels` must be a superset of `self`'s levels");

  auto physical_sizes = physical_tensor.sizes();

  const auto tensor_example_dim = (
      static_cast<int64_t>(physical_sizes.size())
      - /*num_batch_dims*/static_cast<int64_t>(tensor_levels.count())
  );
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
    // Optimization: no need to do another view if the physical tensor is
    // already the correct shape.
    return physical_tensor;
  }

  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);

  // Align the example dims (non-batch dims) first:
  // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
  std::copy(
      physical_sizes.rbegin(),
      physical_sizes.rbegin() + tensor_example_dim,
      aligned_sizes.rbegin());

  // Align the batch dims: walk the requested levels in order, consuming a
  // physical size whenever the tensor actually has that level.
  int64_t level = 0;
  int64_t tensor_dim = 0;
  for (const auto bdim : c10::irange(requested_levels.count())) {
    // Determine the level of this bdim
    while (!requested_levels[level]) level++;
    if (tensor_levels[level]) {
      aligned_sizes[bdim] = physical_sizes[tensor_dim++];
    }
    level++;
  }
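  // For example (hypothetical values): with requested_levels = {0, 1, 3},
  // tensor_levels = {1}, and physical batch sizes [B1, ...], the loop visits
  // levels 0, 1, 3 and fills aligned batch sizes [1, B1, 1]; only level 1
  // consumes a physical size.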
  return physical_tensor.view(aligned_sizes);
}

// The algorithm is as follows:
// 1. Figure out what the collective levels of `logical_tensors` are.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has batch sizes equal to
//    `batch_sizes`.
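//
// For example (hypothetical values): given two BatchedTensors whose physical
// tensors have sizes [B0, 3] (batched at level 0) and [B1, 2, 3] (batched at
// level 1), the collective levels are {0, 1}; step 2 aligns the physical
// tensors to [B0, 1, 3] and [1, B1, 2, 3]; step 3 computes
// batch_sizes = [B0, B1]; and step 4 expands them to [B0, B1, 3] and
// [B0, B1, 2, 3].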
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // Figure out all of the collective vmap levels in `logical_tensors`.
  std::bitset<kVmapNumLevels> collective_levels;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (batched) {
      collective_levels |= createVmapLevelsBitset(batched->bdims());
    }
  }

  // Populate physical_tensors.
  // This contains a list of regular (non-Batched) Tensors where all of the
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
  int64_t num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
    auto physical_tensor = alignBatchDimsAtFront(
        logical_tensor, collective_levels, requested_example_dim);
    physical_tensors.push_back(std::move(physical_tensor));
  }

  // Compute batch_sizes.
  VmapDimVector batch_sizes(num_batch_dims, 1);
  for (const auto& physical_tensor : physical_tensors) {
    auto physical_sizes = physical_tensor.sizes();
    for (const auto dim : c10::irange(num_batch_dims)) {
      if (physical_sizes[dim] != 1) {
        batch_sizes[dim] = physical_sizes[dim];
      }
    }
  }

  // Expand each physical_tensor so that it has batch sizes `batch_sizes`.
  VmapPhysicalViewVec result;
  for (const auto& physical_tensor : physical_tensors) {
    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
    auto physical_sizes = physical_tensor.sizes();
    expanded_size.insert(
        expanded_size.end(),
        physical_sizes.begin() + num_batch_dims,
        physical_sizes.end());
    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
  }
  return result;
}

static std::pair<std::bitset<kVmapNumLevels>, int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
  std::bitset<kVmapNumLevels> levels;
  int64_t largest_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      levels = levels | createVmapLevelsBitset(batched->bdims());
    }
    auto tensor_logical_dim = /*logical dim*/tensor.dim();
    if (tensor_logical_dim > largest_logical_dim) {
      largest_logical_dim = tensor_logical_dim;
    }
  }
  return { levels, largest_logical_dim };
}
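
// Example for getLevelsAndLargestLogicalDim (hypothetical values): given a
// BatchedTensor of logical size [2] batched at level 0 and a regular tensor
// of size [3, 2], this returns the levels bitset {0} and
// largest_logical_dim = 2.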

VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(
      logical_tensors.size() == 2,
      "This function has only been tested for two tensors. Please add more tests ",
      "before removing this check ");

  VmapPhysicalViewVec result;

  auto [levels, largest_logical_dim] = getLevelsAndLargestLogicalDim(logical_tensors);

  for (const auto& tensor : logical_tensors) {
    // NB: It's possible that we didn't actually need to align `tensor`.
    // For example, when adding two tensors of size (B, 2) and (3, 2), where
    // the first Tensor is a BatchedTensor with batch dim B and the second is
    // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
    // However, the view on the second tensor is unnecessary: broadcasting
    // semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
    //
    // If this unnecessary view is a problem, consider optimizing it away in
    // the future. This may involve creating a new type of VmapPhysicalView.
    auto aligned = alignBatchDimsAtFront(tensor, levels, largest_logical_dim);
    result.emplace_back(std::move(aligned), levels);
  }
  return result;
}

VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}
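
// Example for VmapPhysicalToLogicalMap::apply (hypothetical values): with
// levels_ = {0, 1}, applying the map to a physical tensor of size [B0, B1, 3]
// produces a BatchedTensor with bdims {(lvl=0, dim=0), (lvl=1, dim=1)} and
// logical size [3].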

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (auto& physical_tensor : physical_tensors) {
    physical_tensor = apply(physical_tensor);
  }
}

} // namespace at