#include <torch/library.h>
#include <ATen/RedispatchFunctions.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
#include <c10/core/SymIntArrayRef.h>

#include <utility>

namespace at {

// NOTE: [What is a batching rule?]
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
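//
// Conceptually, for x and y vmapped over dim 0, the result satisfies
// vmap(torch.add)(x, y)[i] == torch.add(x[i], y[i]) for each index i along
// the vmapped dimension.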

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)
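//
// For illustration only, a batching rule for a hypothetical unary operator
// `my_op(Tensor, Scalar)` (not a real ATen op) that follows steps (1)-(4)
// with the MultiBatchVmapTransform might look like:
//
//   Tensor my_op_batching_rule(const Tensor& self, const Scalar& alpha) {
//     // (1) logical BatchedTensor -> view on the physical tensor
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // (2) nothing to translate: a pointwise op has no dim/shape arguments
//     // (3) call the at:: operation on the physical tensor
//     auto result = at::my_op(self_physical.tensor(), alpha);
//     // (4) wrap the physical result back up as a BatchedTensor
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }
//
// Real instances of this pattern appear throughout this file (e.g.
// clamp_min_batching_rule below).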

// Note: [Future plans]
// The API for writing a batching rule isn't stable. In the future, we'd like
// to think about the problem of translating these batching rules to TorchScript.
// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
// if not use the same mechanism. In order to accomplish that we might have to
// do some refactoring.

// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
    // and instead returns a new scalar tensor (this also happens for dim=-1).
    // If the following happens:
    // >>> x = torch.randn(B0)  # the per-examples are all scalars
    // >>> vmap(partial(torch.sum, dim=0), x)
    // then we replicate the behavior of sum(scalar_tensor, dim=0).
    if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
      return self.clone();
    }
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(opt_dims);
  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  if (logical_tensor.dim() > 0) {
    return false;
  }
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  if (batched) {
    return false;
  }
  return true;
}

template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  //
  // At a per-example level, we are multiplying FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are multiplying CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
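  //
  // In the motivating example above, at::native::result_type of the logical
  // FloatTensor[10] and the 0-dim DoubleTensor is kFloat, so the code below
  // casts the 0-dim double operand down to float before broadcasting and
  // calling Func.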
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(logical_self), std::move(logical_other)});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto self_physical_dim = self_physical.tensor().dim();

  TORCH_CHECK(self_physical_dim <= static_cast<int64_t>(size_physical.size()),
       "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
       "must be greater than or equal to the number of dimensions in the tensor (",
       /*logical dim*/self.dim(), ")");

  if (self_physical_dim == static_cast<int64_t>(size_physical.size())) {
    auto result = self_physical.tensor().expand(size_physical, implicit);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast<int64_t>(size_physical.size()));
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
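  //
  // In other words, for the [B0, 3] -> [B0, 2, 3] example, the code below
  // builds view_shape = [B0, 1, 3] (batch sizes copied to the front, example
  // sizes shifted right by the number of extra dims) and then expands that
  // view to size_physical = [B0, 2, 3].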
  auto self_physical_size = self_physical.tensor().sizes();
  auto extra_dims = size_physical.size() - self_physical_dim;
  VmapDimVector view_shape(size_physical.size(), 1);
  std::copy(self_physical_size.begin(),
            self_physical_size.begin() + self_physical.numBatchDims(),
            view_shape.begin());
  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
            self_physical_size.end(),
            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor clamp_batching_rule(const Tensor& self, const optional<Scalar>& min, const optional<Scalar>& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp(self_physical.tensor(), min, max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_min(self_physical.tensor(), min);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_max(self_physical.tensor(), max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
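  // For example, unsqueeze(dim=-1) on a logical tensor of shape [3] wraps the
  // dim to maybe_wrap_dim(-1, 1 + 1) == 1, so the physical dim used below is
  // numBatchDims() + 1.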
  auto dim_physical =
      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  auto result = self_physical.tensor().unsqueeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().fill_(value);
  return self;
}

Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
  auto value_batched = isBatchedTensor(value);

  if (value_batched) {
    auto physical_args =
        BroadcastingVmapTransform::logicalToPhysical({self, value});
    physical_args[0].tensor().copy_(physical_args[1].tensor());
  } else {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    self_physical.tensor().fill_(value);
  }
  return self;
}

Tensor& zero_inplace_batching_rule(Tensor &self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().zero_();
  return self;
}

Tensor squeeze_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_sizes = self_physical.tensor().sizes();

  // Don't squeeze the batch dims!
  VmapDimVector squeezed_sizes;
  int64_t num_batch_dims = self_physical.numBatchDims();
  squeezed_sizes.insert(
      squeezed_sizes.end(),
      physical_sizes.begin(),
      physical_sizes.begin() + num_batch_dims);
  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
    if (*it != 1) {
      squeezed_sizes.push_back(*it);
    }
  }

  auto result = self_physical.tensor().view(squeezed_sizes);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().squeeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = self_physical.tensor().squeeze(dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_batching_rule(const Tensor& self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Batched Diagonal View
  auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  auto result = at::sum(self_diag, -1);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  // Batched Diagonal View
  auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  // Append a dimension of size one to the grad output
  auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
  grad_input_diag.copy_(grad_physical_tensor);
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1), x)
  // then we replicate this behavior.
  if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return self;
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim0_physical = self_physical.getPhysicalDim(dim0);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);

  VmapDimVector all_dims_physical;
  all_dims_physical.reserve(self_physical.tensor().dim());
  for (const auto bdim : c10::irange(self_physical.numBatchDims())) {
    all_dims_physical.push_back(bdim);
  }
  all_dims_physical.insert(
      all_dims_physical.end(),
      dims_physical.begin(),
      dims_physical.end());
  auto result = self_physical.tensor().permute(all_dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().select(dim_physical, index);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
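  // For example, with input_sizes = [2, 3, 5] and 2 batch dims, a logical
  // dim of -1 wraps to 2 and maps to physical dim 2 + 2 == 4.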
  return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
}

Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(
    const Tensor& self,
    int64_t dim,
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto dim2_physical = self_physical.getPhysicalDim(dim2);
  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto source_physical = self_physical.getPhysicalDims(source);
  auto destination_physical = self_physical.getPhysicalDims(destination);
  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto shape_physical = self_physical.getPhysicalShape(shape);
  auto result = self_physical.tensor().reshape(shape_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().unfold(dim_physical, size, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
      "than torch.contiguous_format");
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = physical_view.tensor().contiguous(memory_format);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto result = self_physical.tensor().view(size_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
  // guard against the user passing in a batch of scalar tensors with batch
  // size equal to 2.
  TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions");
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::view_as_complex(self_physical.tensor());
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

// Checks that the smallest batch stride is greater than or equal to the largest
// example stride. This is something we can support but we choose not to because
// it's potentially error prone.
static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
  auto smallest_batch_stride = std::min_element(
      physical_strides.begin(), physical_strides.begin() + num_batch_dims);
  auto largest_example_stride = std::max_element(
      physical_strides.begin() + num_batch_dims, physical_strides.end());
  if (largest_example_stride == physical_strides.end()) {
    // No example dimensions
    return;
  }
  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
    "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
    "vmapped over are at the front of the tensor (in memory layout). When they are ",
    "not at the front of the tensor, this operation can be error prone so we "
    "actively discourage it; please file us a bug report and/or try to ",
    "express the as_strided operation in terms of PyTorch view operations");
}

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
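// For example (assuming storage_size_for == 1 + sum((size - 1) * stride) for
// nonzero sizes): sizes=[2, 3], strides=[3, 1], storage_offset=7 gives
// 1 + (2-1)*3 + (3-1)*1 + 7 == 13.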
static optional<int64_t> maximum_indexable_location(
    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return nullopt;
  }
  return result + storage_offset;
}

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// are within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
  return reshape_batching_rule(self, sizes);
}

Tensor _new_zeros_with_same_feature_meta_batching_rule(
    const Tensor& self,
    const Tensor& other,
    int64_t unused_num_batch_dims) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
    "Only the 'batched grad' use case is supported in PyTorch core.");

  TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0,
    "num_batch_dims should not be explicitly passed in because it will be overridden");
  auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self);
  const auto& self_physical_tensor = self_physical_view.tensor();
  int64_t num_batch_dims = self_physical_view.numBatchDims();
  checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims);
  auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims);
  return self_physical_view.getPhysicalToLogicalMap().apply(result);
}

bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
    "Only the 'batched grad' use case is supported in PyTorch core.");
  // The _has_same_storage_numel check is skipped if the tangent is a batched
  // tensor because using as_strided to access storage locations not indexable
  // by the input tensor is not supported in vmap
  return true;
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `xs`, `y`, such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
//    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
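//
// Concretely, for xs with a single leading batch dim of size B, the batching
// rule below ends up calling
//   xs.as_strided([B] + sizes, [xs.stride(0)] + strides, offset)
// after checking that doing so is safe (see the checks below and
// NOTE: [When will the as_strided batching rule fail?]).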
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> storage_offset) {
  auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. All batch dims are at the front in memory layout (not necessary for
  //    correctness, but we are worried the user might be doing crazy things)
  // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //    is valid for a slice of the input tensor.
  //    See NOTE: [When will the as_strided batching rule fail?] for details.
  checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  at::VmapDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided(
      physical_sizes, physical_strides, storage_offset);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
//   to the front with movedim/permute. This is always valid because it just swaps
//   some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
//   do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
//   - instead of S * i, we'd have \sum_{j=0}^k S_j * I_j
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are positive
// - offset is positive
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) {
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::pow(other, self_batched->value());
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor clone_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
  // Memory format support is a little tricky because vmap is allowed to move
  // around batch dimensions and some memory formats are rank-dependent.
  // Another weird case is:
  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
  //   and N>1 batch dims?
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  if (memory_format == MemoryFormat::Contiguous) {
    // There is an ambiguity here when the batch dims are not at the front of
    // the tensor.
    // >>> x = torch.randn(3, B0, 5)
    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
    // >>> y[0].is_contiguous()
    // ???
    // Should we make the whole tensor contiguous, or should we
    // make the non-batch dims contiguous? We've chosen the latter because
    // philosophically vmap hides the batch dims and operates on a per-sample level.
    auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
    auto output_physical = at::clone(physical_view.tensor(), memory_format);
    return physical_view.getPhysicalToLogicalMap().apply(output_physical);
  }

  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::clone(self_batched->value(), memory_format);
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. e.g., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
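//
// For example, in mv_batching_rule below, when only `self` is batched we call
// at::matmul(self_physical.tensor(), other) directly instead of first
// materializing a broadcasted, batched copy of `other`.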
Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  // A shape checking API would be nice...
  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
      "mv(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // self_physical: [L, K], other_physical: [..., K]
    // We view the tensors as [L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., L, K], other_physical: [..., K]
    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor(),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor _make_dual_batching_rule(
    c10::DispatchKeySet ks,
    const Tensor& primal,
    const Tensor& tangent,
    int64_t level
) {
  DispatchKeySet after_batched_keyset =
      DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched);
  return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level);
}

Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: expected vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other_physical: [K]
    // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze
    // the last dim of the result.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self_physical: [K], other_physical: [..., K]
    // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze
    // the last dim of the result.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and
    // squeeze the last two dims of the result.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor());
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (self_batched && other_batched) {
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    // Unlike dot, matmul on the physical [..., M, K] and [..., K, N] tensors
    // already produces the correctly shaped [..., M, N] result, so there are
    // no size-1 dims to squeeze here.
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
Tensor to_dtype_layout_batching_rule(
    const Tensor& self,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory,
    bool non_blocking, bool copy,
    optional<MemoryFormat> memory_format) {
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto* input_batched = unsafeGetBatchedImpl(self);
  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto options = TensorOptions()
      .dtype(dtype)
      .layout(layout)
      .device(device)
      .pinned_memory(pin_memory);
  auto result = physical_view.tensor().new_zeros(physical_size, options);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_strided_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);

  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
  // the batch dimensions at the front of the tensor (in memory layout),
  // irrespective of whether or not they are actually at the front (in memory layout)
  // in the original `self` tensor. This is because when a user calls
  // `new_empty_strided` in general, the `strides` they provide are for a new
  // tensor and have no relation to the strides of the original tensor.
  //
  // So, the physical shape of the result should be ([B0, B1, B2] + size),
  // but what about the physical strides?
  //
  // We're actually free to pick whatever stride we want:
  // e.g., for size=[5, 3], stride=[0, 1], we could decide to
  // use
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
  //
  // Let's select some reasonable strides such that:
  // - The batch dims are "contiguous" with respect to each other
  // - if empty_strided(size, stride) would have created a contiguous Tensor,
  //   then this new physical Tensor (with batch dims) is also contiguous
  //
  // Let S be the size of the storage if one were to construct a tensor
  // with `size` and `stride` via empty_strided(size, stride).
  // Then the physical sizes/strides should be:
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
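  //
  // As a concrete (hypothetical) example: with batch shape [B0, B1, B2] = [2, 3, 4]
  // and size=[5, 3], stride=[3, 1], we'd have S = 1 + (5-1)*3 + (3-1)*1 = 15, so
  // the code below produces physical strides
  // [3*4*15, 4*15, 15, 3, 1] = [180, 60, 15, 3, 1].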
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
        "new_empty_strided(sizes, strides): dimensionality of sizes (",
        size.size(), ") must match dimensionality of strides (",
        stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

TORCH_LIBRARY_IMPL(_, Batched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

TORCH_LIBRARY_IMPL(aten, Batched, m) {
  // NB: Ideally we would like some operators, like size.int, to "fallthrough"
  // to the underlying implementation. However, because a BatchedTensor is a
  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
  // here is to just directly call the underlying implementation.
  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
  m.impl("_add_batch_dim", native::_add_batch_dim);
  m.impl("_remove_batch_dim", native::_remove_batch_dim);
  m.impl("_make_dual", _make_dual_batching_rule);
  m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule);
  m.impl("is_same_size", native::is_same_size);
  m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);

  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("is_complex", native::is_complex);

  // inplace operations
  m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
  m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
  m.impl("zero_", zero_inplace_batching_rule);

  // view operations
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("chunk", chunk_batching_rule);
  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("movedim.intlist", movedim_batching_rule);
  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
  // There is another variant of narrow. However, we don't
  // want to support the other variant yet because it isn't documented...
  m.impl("narrow", native::narrow_symint); // composite wrt autograd
  m.impl("numpy_T", native::numpy_T); // composite wrt autograd
  m.impl("matrix_H", native::matrix_H); // composite wrt autograd
  m.impl("mT", native::mT); // composite wrt autograd
  m.impl("mH", native::mH); // composite wrt autograd
  m.impl("permute", permute_batching_rule);
  m.impl("reshape", reshape_batching_rule);
  m.impl("_reshape_alias", _reshape_alias_batching_rule);
  m.impl("reshape_as", native::reshape_as); // composite wrt autograd
  m.impl("select.int", select_batching_rule);
  m.impl("slice.Tensor", slice_batching_rule);
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split.sizes", split_with_sizes_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("squeeze", squeeze_batching_rule);
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("squeeze.dims", squeeze_dims_batching_rule);
  m.impl("t", native::t); // composite wrt autograd
  m.impl("trace", trace_batching_rule);
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd

  // clamp operations
  m.impl("clamp", clamp_batching_rule);
  m.impl("clamp_min", clamp_min_batching_rule);
  m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(conj_physical);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
  { \
    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
    m.impl(name, unwrap_and_call_method< \
        to_type, &Tensor::to, __VA_ARGS__>);\
  }
  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional<MemoryFormat>)
  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
#undef TO_BATCHING_RULE
  m.impl("clone", clone_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);

#define BINARY_POINTWISE(op) \
  m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);
#define BINARY_POINTWISE_VA(op, ...) \
  { \
    using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
    using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \
    m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
    m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, const Scalar&, __VA_ARGS__>); \
  }
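  // For example, BINARY_POINTWISE(mul) expands to:
  //   m.impl("mul.Tensor", binary_pointwise_batching_rule<TensorTensorType, at::mul>);
  //   m.impl("mul.Scalar", unwrap_and_call<TensorScalarType, at::mul, const Scalar&>);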

  BINARY_POINTWISE_VA(add, const Scalar&);
  BINARY_POINTWISE_VA(sub, const Scalar&);
  BINARY_POINTWISE_VA(rsub, const Scalar&);
  BINARY_POINTWISE(mul);
  BINARY_POINTWISE(div);
  {
    using Binop = Tensor (*)(const Tensor&, const Tensor&, c10::optional<c10::string_view>);
    using Unop = Tensor (*)(const Tensor&, const Scalar&, c10::optional<c10::string_view>);
    m.impl("div.Tensor_mode", binary_pointwise_batching_rule<Binop, at::div, c10::optional<c10::string_view>>);
    m.impl("div.Scalar_mode", unwrap_and_call<Unop, at::div, const Scalar&, c10::optional<c10::string_view>>);
  }

  // at::pow has three out-of-place overloads
  m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, const Scalar&>);
  m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);

  m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  m.impl(
      "threshold_backward",
      binary_pointwise_batching_rule<
          TensorTensorScalarType,
          at::threshold_backward,
          const Scalar&>);

  // for at::result_type, call the native::result_type implementation.
  // We don't have to do anything special because native::result_type operates
  // on the logical shape of the tensors.
  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));

#undef BINARY_POINTWISE_VA
#undef BINARY_POINTWISE


#define TRIVIAL_OP(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
  // complex number view operators
  TRIVIAL_OP(imag)
  TRIVIAL_OP(real);
  TRIVIAL_OP(view_as_real);
  TRIVIAL_OP(conj);
  TRIVIAL_OP(_conj);
  TRIVIAL_OP(resolve_conj);
  TRIVIAL_OP(resolve_neg);
  m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL_OP

  // matmul-like operators
  m.impl("mv", mv_batching_rule);
  m.impl("dot", dot_batching_rule);
  m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);

  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);

  // backward operators
  m.impl("select_backward", select_backward_batching_rule);
  m.impl("slice_backward", slice_backward_batching_rule);
  m.impl("trace_backward", trace_backward_batching_rule);
  m.impl("diagonal_backward", diagonal_backward_batching_rule);

  // Tensor.new_* operators
  m.impl("new_empty", new_empty_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
  m.impl("new_zeros", new_zeros_batching_rule);

  m.impl("contiguous", contiguous_batching_rule);

  // Comparison ops
#define COMPARISON_POINTWISE(op) \
  m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
  m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>);

  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

#undef COMPARISON_POINTWISE
}

} // namespace at