#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at {

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (const auto idx : c10::irange(static_cast<int64_t>(bdims.size()))) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
  }
  return true;
}

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
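// For example, a physical tensor of size [2, B0, 3, B1] with batch dims at
// (physical) dims 1 and 3 is permuted to size [B0, B1, 2, 3]: the batch dims
// come first, in the order they appear in `bdims`, followed by the remaining
// dims in their original order.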
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  VmapDimVector permutation(sizes.size(), 0);
  permutation.reserve(sizes.size());
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t idx = 0;
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  for (const auto ptr : c10::irange(sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}

int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

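// Maps logical dims to physical dims. Since all batch dims sit at the front of
// the physical tensor, a (wrapped) logical dim `i` corresponds to physical dim
// `i + numBatchDims()`. If `opt_logical_dims` is empty or not provided, all
// logical dims are mapped.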
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
    auto logical_dims = opt_logical_dims.value();
    for (auto dim : logical_dims) {
      result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
    }
  } else {
    for (int64_t dim = 0; dim < logical_ndim; dim++) {
      result.push_back(dim + numBatchDims());
    }
  }
  return result;
}

int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}

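// Prepends this physical view's batch sizes to `logical_shape`. For example,
// with two batch dims of sizes B0 and B1 and logical_shape [3, 5], this
// returns [B0, B1, 3, 5].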
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

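// Given a bitset of vmap levels, constructs BatchDims that map those levels,
// in increasing level order, to the front dims 0, 1, 2, ... of a tensor.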
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims bdims;
  int64_t dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (!levels_bitset[level]) {
      continue;
    }
    bdims.emplace_back(level, dim++);
  }
  return bdims;
}

// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor, std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (batched) {
    return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
  }
  return {self, 0};
}

// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and exactly `requested_example_dim` non-batch dimensions.
//
// This function is useful in preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
// batch dimensions, we must align the batch dimensions and non-batch dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
//    level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, B1, 2, 3].
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  Tensor physical_tensor;
  std::bitset<kVmapNumLevels> tensor_levels;
  std::tie(physical_tensor, tensor_levels) = getPhysicalTensorAndLevels(self);

  TORCH_INTERNAL_ASSERT(
      (tensor_levels | requested_levels) == requested_levels,
      "`requested_levels` must be a superset of `self`'s levels");

  auto physical_sizes = physical_tensor.sizes();

  const auto tensor_example_dim = (
      static_cast<int64_t>(physical_sizes.size())
      - /*num_batch_dims*/static_cast<int64_t>(tensor_levels.count())
  );
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
    // Optimization: no need to do another view if the physical tensor is
    // already the correct shape
    return physical_tensor;
  }

  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);

  // Align the example dims (non-bdim dims) first.
  // aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
  std::copy(
      physical_sizes.rbegin(),
      physical_sizes.rbegin() + tensor_example_dim,
      aligned_sizes.rbegin());

  // Align the bdims.
  int64_t level = 0;
  int64_t tensor_dim = 0;
  for (const auto bdim : c10::irange(requested_levels.count())) {
    // Determine the level of the bdim
    while (!requested_levels[level]) level++;
    if (tensor_levels[level]) {
      aligned_sizes[bdim] = physical_sizes[tensor_dim++];
    }
    level++;
  }
  return physical_tensor.view(aligned_sizes);
}

// The algorithm is as follows:
// 1. Figure out the collective set of vmap levels across `logical_tensors`.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has output batch size equal
//    to `batch_sizes`.
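//
// For example, take logical tensors of sizes [2] (batched at level 0 with
// batch size B0) and [3] (batched at levels 0 and 1 with batch sizes B0, B1).
// The collective levels are {0, 1}, the aligned physical tensors have sizes
// [B0, 1, 2] and [B0, B1, 3], batch_sizes is [B0, B1], and the resulting
// physical views have sizes [B0, B1, 2] and [B0, B1, 3].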
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // Figure out all of the collective vmap levels in `logical_tensors`.
  std::bitset<kVmapNumLevels> collective_levels;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (batched) {
      collective_levels |= createVmapLevelsBitset(batched->bdims());
    }
  }

  // Populate physical_tensors.
  // This contains a list of regular (non-Batched) Tensors where all of the
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
  int64_t num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
    auto physical_tensor = alignBatchDimsAtFront(
        logical_tensor, collective_levels, requested_example_dim);
    physical_tensors.push_back(std::move(physical_tensor));
  }

  // Compute batch_sizes
  VmapDimVector batch_sizes(num_batch_dims, 1);
  for (const auto& physical_tensor : physical_tensors) {
    auto physical_sizes = physical_tensor.sizes();
    for (const auto dim : c10::irange(num_batch_dims)) {
      if (physical_sizes[dim] != 1) {
        batch_sizes[dim] = physical_sizes[dim];
      }
    }
  }

  // Expand each physical_tensor so that it has batch sizes `batch_sizes`
  VmapPhysicalViewVec result;
  for (const auto& physical_tensor : physical_tensors) {
    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
    auto physical_sizes = physical_tensor.sizes();
    expanded_size.insert(
        expanded_size.end(),
        physical_sizes.begin() + num_batch_dims,
        physical_sizes.end());
    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
  }
  return result;
}

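// Returns the union of vmap levels across `logical_tensors`, along with the
// largest logical (non-batch) dimensionality among them.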
static std::pair<std::bitset<kVmapNumLevels>, int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
  std::bitset<kVmapNumLevels> levels;
  int64_t largest_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(tensor);
    if (batched) {
      levels = levels | createVmapLevelsBitset(batched->bdims());
    }
    auto tensor_logical_dim = /*logical dim*/tensor.dim();
    if (tensor_logical_dim > largest_logical_dim) {
      largest_logical_dim = tensor_logical_dim;
    }
  }
  return { levels, largest_logical_dim };
}

VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(
      logical_tensors.size() == 2,
      "This function has only been tested for two tensors. Please add more tests ",
      "before removing this check.");

  VmapPhysicalViewVec result;

  std::bitset<kVmapNumLevels> levels;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t largest_logical_dim;
  std::tie(levels, largest_logical_dim) = getLevelsAndLargestLogicalDim(logical_tensors);

  for (const auto& tensor : logical_tensors) {
    // NB: It's possible that we didn't actually need to align `tensor`.
    // For example, when adding two tensors of size (B, 2) and (3, 2), where
    // the first Tensor is a BatchedTensor with batch dim B and the second is
    // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
    // However, the view on the second tensor is unnecessary: broadcasting
    // semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
    //
    // If this unnecessary view is a problem, consider optimizing it away in
    // the future. This may involve creating a new type of VmapPhysicalView.
    auto aligned = alignBatchDimsAtFront(tensor, levels, largest_logical_dim);
    result.emplace_back(std::move(aligned), levels);
  }
  return result;
}

VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

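// Wraps `physical_tensor` back into a BatchedTensor, treating its front dims
// as batch dims for the levels recorded in this map.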
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (auto& physical_tensor : physical_tensors) {
    physical_tensor = apply(physical_tensor);
  }
}

} // namespace at