1 | #include <torch/library.h> |
2 | #include <ATen/RedispatchFunctions.h> |
3 | #include <ATen/LegacyVmapTransforms.h> |
4 | #include <ATen/LegacyBatchedFallback.h> |
5 | #include <ATen/native/ResizeCommon.h> |
6 | #include <ATen/ATen.h> |
7 | #include <ATen/core/IListRef.h> |
8 | #include <c10/util/irange.h> |
9 | #include <c10/core/SymIntArrayRef.h> |
10 | |
11 | #include <utility> |
12 | |
13 | namespace at { |
14 | |
15 | // NOTE: [What is a batching rule?] |
16 | // |
17 | // A *batching rule* implements the logic of how to call an operator on inputs |
18 | // that have zero or more additional batch dimensions. When one does a vmap, the |
19 | // dimension(s) being vmap'ed over get recorded as batch dimensions. |
20 | // |
21 | // For example, vmap(torch.add)(x, y) |
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
24 | // 3. and then runs `torch.add(batched_x, batched_y)`. |
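//
// Inside step 3, the Batched dispatch key on batched_x and batched_y routes the
// call to the corresponding batching rule registered below (or to the slow
// for-loop fallback if no rule exists); the rule sees the logical, per-example
// shapes, with the batch dimensions hidden inside the BatchedTensor wrappers.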
25 | |
26 | // NOTE: [When should I add a batching rule?] |
27 | // When you are adding a new operator, you'll need to add a batching rule so |
28 | // that vmap can work efficiently with said operator. If you do not, we'll attempt |
29 | // to generate a slow fallback for the batching rule. |
30 | |
31 | // NOTE: [How to write batching rules?] |
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
34 | // |
35 | // First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology. |
36 | // |
37 | // At a high level, what a batching rule does is the following: |
38 | // 1. Converts (logical) BatchedTensors to views on physical tensors. |
39 | // 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical |
40 | // arguments that correspond to the physical tensors. |
41 | // 3. Calls at:: operations on the physical tensors and arguments to produce |
42 | // some physical results. |
43 | // 4. Converts physical results back to BatchedTensors. |
44 | // |
45 | // Steps 1, 2, and 4 differ for operators with different batching behaviors. When |
46 | // writing a new batching rule, please select a VmapTransform that matches the |
47 | // batching behavior of your operation. The VmapTransform provides helper functions |
48 | // to do steps (1), (2), and (4). |
49 | // (see NOTE: [What is an VmapTransform?] in VmapTransforms.h) |
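
// The sketch below is purely illustrative (it is not registered with the
// dispatcher anywhere in this file); it shows steps 1-4 for a simple unary op
// such as at::sin, where there are no dim/shape arguments to translate.
[[maybe_unused]] static Tensor example_sin_batching_rule(const Tensor& self) {
  // Steps 1 & 2: convert the logical BatchedTensor to a view on a physical tensor.
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Step 3: call the at:: operation on the physical tensor.
  auto result = at::sin(self_physical.tensor());
  // Step 4: convert the physical result back into a BatchedTensor.
  return self_physical.getPhysicalToLogicalMap().apply(result);
}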
50 | |
51 | // Note: [Future plans] |
52 | // The API for writing a batching rule isn't stable. In the future, we'd like |
53 | // to think about the problem of translating these batching rules to TorchScript. |
54 | // Ideally batching rules in eager mode vs TorchScript would look pretty similar, |
55 | // if not use the same mechanism. In order to accomplish that we might have to |
56 | // do some refactoring. |
57 | |
58 | // PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor. |
59 | static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { |
60 | return dim == 0 || dim == -1; |
61 | } |
62 | |
63 | Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) { |
64 | if (opt_dims.has_value()) { |
65 | auto dims = opt_dims.value(); |
66 | // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail |
67 | // and instead returns a new scalar tensor (this also happens for dim=-1) |
68 | // If the following happens: |
69 | // >>> x = torch.randn(B0) # the per-examples are all scalars |
70 | // >>> vmap(partial(torch.sum, dim=0), x) |
71 | // then we replicate the behavior of sum(scalar_tensor, dim=0). |
72 | if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { |
73 | return self.clone(); |
74 | } |
75 | } |
76 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
77 | auto dims_physical = self_physical.getPhysicalDims(opt_dims); |
78 | auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype); |
79 | return self_physical.getPhysicalToLogicalMap().apply(result); |
80 | } |
81 | |
82 | bool isPhysicalScalarTensor(const Tensor& logical_tensor) { |
83 | if (logical_tensor.dim() > 0) { |
84 | return false; |
85 | } |
86 | auto* batched = maybeGetBatchedImpl(logical_tensor); |
87 | if (batched) { |
88 | return false; |
89 | } |
90 | return true; |
91 | } |
92 | |
93 | template <typename F, F Func, typename... ExtraArgs> |
94 | Tensor binary_pointwise_batching_rule( |
95 | const Tensor& self, const Tensor& other, ExtraArgs... args) { |
96 | if (self.dim() > 0 && other.dim() > 0) { |
97 | auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); |
98 | auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); |
99 | return physical_args[0].getPhysicalToLogicalMap().apply(result); |
100 | } |
101 | if (isPhysicalScalarTensor(self)) { |
102 | auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); |
103 | auto result = Func(self, other_physical.tensor(), args...); |
104 | return other_physical.getPhysicalToLogicalMap().apply(result); |
105 | } |
106 | if (isPhysicalScalarTensor(other)) { |
107 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
108 | auto result = Func(self_physical.tensor(), other, args...); |
109 | return self_physical.getPhysicalToLogicalMap().apply(result); |
110 | } |
111 | |
112 | // At this point, we know at least one of the operands is a logical Scalar tensor. |
113 | // Here we must emulate TensorIterator's special behavior on Scalars. |
114 | // |
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  //
  // At a per-example level, we are multiplying FloatTensor[10] by DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
122 | // This means we cannot directly pass the physical tensors (x and y) to |
123 | // TensorIterator (if we did, it would promote them to DoubleTensor). |
124 | // |
125 | // FIXME(rzou): I didn't want to go down the slippery slope of emulating |
126 | // everything TensorIterator does (it would be better to refactor out the |
127 | // TensorIterator logic). The one thing that this code doesn't handle |
128 | // is cross-device logical scalar tensors. |
129 | // cpu_tensor = torch.randn(3) |
130 | // cuda_tensor = torch.randn(3, 10, device='cuda') |
131 | // vmap(torch.mul)(cpu_tensor, cuda_tensor) |
132 | // |
  // At a per-example level, we are multiplying CPUTensor[] by CUDATensor[10].
134 | // TensorIterator allows for this cross-device operation because one of the |
135 | // tensors is a Scalar CPU tensor. However, the following code will throw an |
136 | // error in that case. I don't expect to see many use cases for this, so |
137 | // this is probably fine as-is. |
138 | auto logical_self = self; |
139 | auto logical_other = other; |
140 | auto result_type = at::native::result_type(logical_self, logical_other); |
141 | if (logical_self.scalar_type() != result_type) { |
142 | logical_self = logical_self.to(result_type); |
143 | } |
144 | if (logical_other.scalar_type() != result_type) { |
145 | logical_other = logical_other.to(result_type); |
146 | } |
147 | auto physical_args = BroadcastingVmapTransform::logicalToPhysical( |
148 | {std::move(logical_self), std::move(logical_other)}); |
149 | auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); |
150 | return physical_args[0].getPhysicalToLogicalMap().apply(result); |
151 | } |
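// For reference, this template is instantiated with a concrete function-pointer
// type when an op is registered. A rough, illustrative sketch (the alias name
// here is made up):
//
//   using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
//   m.impl("add.Tensor",
//          binary_pointwise_batching_rule<TensorTensorScalarType, at::add, const Scalar&>);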
152 | |
153 | Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) { |
154 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
155 | auto size_physical = self_physical.getPhysicalShape(size); |
156 | auto self_physical_dim = self_physical.tensor().dim(); |
157 | |
158 | TORCH_CHECK(self_physical_dim <= static_cast<int64_t>(size_physical.size()), |
159 | "expand: the number of sizes provided (" , /*logical*/size.size(), ") " , |
160 | "must be greater or equal to the number of dimensions in the tensor (" , |
161 | /*logical dim*/self.dim(), ")" ); |
162 | |
163 | if (self_physical_dim == static_cast<int64_t>(size_physical.size())) { |
164 | auto result = self_physical.tensor().expand(size_physical, implicit); |
165 | return self_physical.getPhysicalToLogicalMap().apply(result); |
166 | } |
167 | |
168 | TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast<int64_t>(size_physical.size())); |
169 | // Here, we know we are expanding a (logical) tensor to a larger number |
170 | // of dimensions. We have to be careful because we can't call expand directly |
171 | // due to the presence of batch dimensions. |
172 | // |
173 | // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]). |
174 | // The result should be a tensor of size [B0, 2, 3]. |
175 | // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3] |
176 | // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and |
177 | // then expand. |
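  // Concretely, for that example: the physical size is [B0, 3] and the physical
  // target size is [B0, 2, 3], so view_shape below starts out as [1, 1, 1],
  // receives the batch sizes at the front ([B0, 1, 1]) and the per-example
  // sizes at the back ([B0, 1, 3]); expanding then produces [B0, 2, 3].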
178 | auto self_physical_size = self_physical.tensor().sizes(); |
  auto extra_dims = size_physical.size() - self_physical_dim;
180 | VmapDimVector view_shape(size_physical.size(), 1); |
181 | std::copy(self_physical_size.begin(), |
182 | self_physical_size.begin() + self_physical.numBatchDims(), |
183 | view_shape.begin()); |
184 | std::copy(self_physical_size.begin() + self_physical.numBatchDims(), |
185 | self_physical_size.end(), |
186 | view_shape.begin() + self_physical.numBatchDims() + extra_dims); |
187 | auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit); |
188 | return self_physical.getPhysicalToLogicalMap().apply(result); |
189 | } |
190 | |
191 | std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { |
192 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
193 | auto dim_physical = self_physical.getPhysicalDim(dim); |
194 | auto result = at::chunk(self_physical.tensor(), chunks, dim_physical); |
195 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
196 | return result; |
197 | } |
198 | |
199 | Tensor clamp_batching_rule(const Tensor& self, const optional<Scalar>& min, const optional<Scalar>& max) { |
200 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
201 | auto result = at::clamp(self_physical.tensor(), min, max); |
202 | return self_physical.getPhysicalToLogicalMap().apply(result); |
203 | } |
204 | |
205 | Tensor clamp_min_batching_rule(const Tensor& self, const Scalar& min) { |
206 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
207 | auto result = at::clamp_min(self_physical.tensor(), min); |
208 | return self_physical.getPhysicalToLogicalMap().apply(result); |
209 | } |
210 | |
211 | Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) { |
212 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
213 | auto result = at::clamp_max(self_physical.tensor(), max); |
214 | return self_physical.getPhysicalToLogicalMap().apply(result); |
215 | } |
216 | |
217 | std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) { |
218 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
219 | auto dim_physical = self_physical.getPhysicalDim(dim); |
220 | auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical); |
221 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
222 | return result; |
223 | } |
224 | |
225 | std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) { |
226 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
227 | auto dim_physical = self_physical.getPhysicalDim(dim); |
228 | auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical); |
229 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
230 | return result; |
231 | } |
232 | |
233 | Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { |
234 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
235 | // NB: unsqueeze has some special handling of its `dim` argument so we can't call |
236 | // self_physical.getPhysicalDim directly. In particular, native::unsqueeze |
237 | // wraps the dim to (the logical dimension) + 1, so we need to do that here too. |
238 | // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413 |
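  // For example, with one batch dim and a logical 2-D per-example tensor,
  // unsqueeze(-1) wraps -1 against logical_dim + 1 = 3 to get 2, so the
  // physical dim is numBatchDims() + 2 = 3.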
239 | auto dim_physical = |
240 | self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1); |
241 | auto result = self_physical.tensor().unsqueeze(dim_physical); |
242 | return self_physical.getPhysicalToLogicalMap().apply(result); |
243 | } |
244 | |
245 | Tensor& fill_inplace_scalar_batching_rule(Tensor& self, const Scalar& value) { |
246 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
247 | self_physical.tensor().fill_(value); |
248 | return self; |
249 | } |
250 | |
251 | Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { |
252 | auto value_batched = isBatchedTensor(value); |
253 | |
254 | if (value_batched) { |
255 | auto physical_args = |
256 | BroadcastingVmapTransform::logicalToPhysical({self, value}); |
257 | physical_args[0].tensor().copy_(physical_args[1].tensor()); |
258 | } else { |
259 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
260 | self_physical.tensor().fill_(value); |
261 | } |
262 | return self; |
263 | } |
264 | |
265 | Tensor& zero_inplace_batching_rule(Tensor &self) { |
266 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
267 | self_physical.tensor().zero_(); |
268 | return self; |
269 | } |
270 | |
271 | Tensor squeeze_batching_rule(const Tensor& self) { |
272 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
273 | auto physical_sizes = self_physical.tensor().sizes(); |
274 | |
275 | // Don't squeeze the batch dims! |
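  // e.g. a physical tensor of size [B0, 1, 3, 1] (one batch dim) is viewed
  // below as size [B0, 3].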
276 | VmapDimVector squeezed_sizes; |
277 | int64_t num_batch_dims = self_physical.numBatchDims(); |
278 | squeezed_sizes.insert( |
279 | squeezed_sizes.end(), |
280 | physical_sizes.begin(), |
281 | physical_sizes.begin() + num_batch_dims); |
282 | for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) { |
283 | if (*it != 1) { |
284 | squeezed_sizes.push_back(*it); |
285 | } |
286 | } |
287 | |
288 | auto result = self_physical.tensor().view(squeezed_sizes); |
289 | return self_physical.getPhysicalToLogicalMap().apply(result); |
290 | } |
291 | |
292 | Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) { |
293 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
294 | auto dim_physical = self_physical.getPhysicalDim(dim); |
295 | auto result = self_physical.tensor().squeeze(dim_physical); |
296 | return self_physical.getPhysicalToLogicalMap().apply(result); |
297 | } |
298 | |
299 | Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) { |
300 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
301 | auto dims_physical = self_physical.getPhysicalDims(dims); |
302 | auto result = self_physical.tensor().squeeze(dims_physical); |
303 | return self_physical.getPhysicalToLogicalMap().apply(result); |
304 | } |
305 | |
306 | Tensor trace_batching_rule(const Tensor& self) { |
307 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
308 | // Batched Diagonal View |
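  // For a physical tensor of size [..., N, N] this produces [..., N]; summing
  // over the last dim then gives the per-example traces.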
309 | auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1); |
310 | auto result = at::sum(self_diag, -1); |
311 | return self_physical.getPhysicalToLogicalMap().apply(result); |
312 | } |
313 | |
314 | Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) { |
315 | auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); |
316 | auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); |
317 | // Batched Diagonal View |
318 | auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1); |
319 | // Append a dimension of size one to the grad output |
320 | auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1); |
321 | grad_input_diag.copy_(grad_physical_tensor); |
322 | return grad_physical.getPhysicalToLogicalMap().apply(grad_input); |
323 | } |
324 | |
325 | Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) { |
326 | // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works |
327 | // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: |
328 | // >>> x = torch.randn(B0) # the per-examples are all scalars |
329 | // >>> vmap(lambda x: x.transpose(0, -1), x) |
330 | // then we replicate this behavior. |
331 | if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && |
332 | is_allowed_dim_on_scalar_tensor(dim1)) { |
333 | return self; |
334 | } |
335 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
336 | auto dim0_physical = self_physical.getPhysicalDim(dim0); |
337 | auto dim1_physical = self_physical.getPhysicalDim(dim1); |
338 | auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical); |
339 | return self_physical.getPhysicalToLogicalMap().apply(result); |
340 | } |
341 | |
342 | Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { |
343 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
344 | auto dims_physical = self_physical.getPhysicalDims(dims); |
345 | |
346 | VmapDimVector all_dims_physical; |
347 | all_dims_physical.reserve(self_physical.tensor().dim()); |
348 | for (const auto bdim : c10::irange(self_physical.numBatchDims())) { |
349 | all_dims_physical.push_back(bdim); |
350 | } |
351 | all_dims_physical.insert( |
352 | all_dims_physical.end(), |
353 | dims_physical.begin(), |
354 | dims_physical.end()); |
355 | auto result = self_physical.tensor().permute(all_dims_physical); |
356 | return self_physical.getPhysicalToLogicalMap().apply(result); |
357 | } |
358 | |
359 | Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) { |
360 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
361 | auto dim_physical = self_physical.getPhysicalDim(dim); |
362 | auto result = self_physical.tensor().select(dim_physical, index); |
363 | return self_physical.getPhysicalToLogicalMap().apply(result); |
364 | } |
365 | |
366 | static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) { |
367 | return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims; |
368 | } |
369 | |
370 | Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { |
371 | auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); |
372 | auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); |
373 | auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); |
374 | grad_input.select(physical_dim, index).copy_(grad_physical.tensor()); |
375 | return grad_physical.getPhysicalToLogicalMap().apply(grad_input); |
376 | } |
377 | |
378 | Tensor slice_batching_rule( |
379 | const Tensor& self, |
380 | int64_t dim, |
381 | c10::optional<int64_t> start, |
382 | c10::optional<int64_t> end, |
383 | int64_t step) { |
384 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
385 | auto dim_physical = self_physical.getPhysicalDim(dim); |
386 | auto result = self_physical.tensor().slice(dim_physical, start, end, step); |
387 | return self_physical.getPhysicalToLogicalMap().apply(result); |
388 | } |
389 | |
390 | Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { |
391 | auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); |
392 | auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); |
393 | auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); |
394 | grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor()); |
395 | return grad_physical.getPhysicalToLogicalMap().apply(grad_input); |
396 | } |
397 | |
398 | Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { |
399 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
400 | auto dim1_physical = self_physical.getPhysicalDim(dim1); |
401 | auto dim2_physical = self_physical.getPhysicalDim(dim2); |
402 | auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical); |
403 | return self_physical.getPhysicalToLogicalMap().apply(result); |
404 | } |
405 | |
406 | Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { |
407 | auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); |
408 | auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); |
409 | auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims()); |
410 | auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims()); |
411 | grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor()); |
412 | return grad_physical.getPhysicalToLogicalMap().apply(grad_input); |
413 | } |
414 | |
415 | Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) { |
416 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
417 | auto source_physical = self_physical.getPhysicalDims(source); |
418 | auto destination_physical = self_physical.getPhysicalDims(destination); |
419 | auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical); |
420 | return self_physical.getPhysicalToLogicalMap().apply(result); |
421 | } |
422 | |
423 | Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) { |
424 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
425 | auto shape_physical = self_physical.getPhysicalShape(shape); |
426 | auto result = self_physical.tensor().reshape(shape_physical); |
427 | return self_physical.getPhysicalToLogicalMap().apply(result); |
428 | } |
429 | |
430 | std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) { |
431 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
432 | auto dim_physical = self_physical.getPhysicalDim(dim); |
433 | auto result = at::split(self_physical.tensor(), split_size, dim_physical); |
434 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
435 | return result; |
436 | } |
437 | |
438 | std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { |
439 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
440 | auto dim_physical = self_physical.getPhysicalDim(dim); |
441 | auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical); |
442 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
443 | return result; |
444 | } |
445 | |
446 | std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) { |
447 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
448 | auto dim_physical = self_physical.getPhysicalDim(dim); |
449 | auto result = at::unbind(self_physical.tensor(), dim_physical); |
450 | self_physical.getPhysicalToLogicalMap().applyInplace(result); |
451 | return result; |
452 | } |
453 | |
454 | Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) { |
455 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
456 | auto dim_physical = self_physical.getPhysicalDim(dim); |
457 | auto result = self_physical.tensor().unfold(dim_physical, size, step); |
458 | return self_physical.getPhysicalToLogicalMap().apply(result); |
459 | } |
460 | |
461 | Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) { |
462 | TORCH_CHECK(memory_format == MemoryFormat::Contiguous, |
463 | "NYI: Tensor.contiguous(...) inside of vmap for memory_format other " , |
464 | "than torch.contiguous_format" ); |
465 | auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); |
466 | auto result = physical_view.tensor().contiguous(memory_format); |
467 | return physical_view.getPhysicalToLogicalMap().apply(result); |
468 | } |
469 | |
470 | Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { |
471 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
472 | auto size_physical = self_physical.getPhysicalShape(size); |
473 | auto result = self_physical.tensor().view(size_physical); |
474 | return self_physical.getPhysicalToLogicalMap().apply(result); |
475 | } |
476 | |
477 | Tensor view_as_complex_batching_rule(const Tensor& self) { |
478 | // guard against the user passing in a batch of scalar tensors with batch |
479 | // size equal to 2. |
480 | TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions" ); |
481 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
482 | auto result = at::view_as_complex(self_physical.tensor()); |
483 | return self_physical.getPhysicalToLogicalMap().apply(result); |
484 | } |
485 | |
486 | // Checks that the smallest batch stride is greater than the largest example |
487 | // stride. This is something we can support but we choose not to because it's |
488 | // potentially error prone. |
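// For example, a physical tensor of size [B0, 3] with strides [3, 1] passes the
// check (batch stride 3 >= example stride 1), while strides of [1, B0] would
// fail for B0 > 1.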
489 | static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) { |
490 | auto smallest_batch_stride = std::min_element( |
491 | physical_strides.begin(), physical_strides.begin() + num_batch_dims); |
492 | auto largest_example_stride = std::max_element( |
493 | physical_strides.begin() + num_batch_dims, physical_strides.end()); |
494 | if (largest_example_stride == physical_strides.end()) { |
495 | // No example dimensions |
496 | return; |
497 | } |
498 | TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride, |
499 | "vmap: Calling Tensor.as_strided is not supported unless the batch dims being " , |
500 | "vmapped over are at the front of the tensor (in memory layout). When they are " , |
501 | "not at the front of the tensor this operation can be error prone so we " |
502 | "actively discourage it; please file us a bug report and/or try to " , |
503 | "express the as_strided operation in terms of PyTorch view operations" ); |
504 | } |
505 | |
506 | // given (sizes, strides, storage_offset) returns the maximum location that |
507 | // can be indexed (or nullopt if such a location doesn't exist, e.g., tensors |
508 | // with zero-size dims). |
509 | static optional<int64_t> maximum_indexable_location( |
510 | IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { |
511 | auto result = native::storage_size_for(sizes, strides); |
512 | if (result == 0) { |
513 | return nullopt; |
514 | } |
515 | return result + storage_offset; |
516 | } |
517 | |
518 | // Let x be the "first slice" of physical_tensor. |
519 | // This checks that the range of possible memory locations accessible by |
520 | // x.as_strided(sizes, strides, maybe_storage_offset) |
521 | // are within the bounds of possible memory locations accessible by x. |
522 | static void checkBasicAsStridedValidForSlice( |
523 | const Tensor& physical_tensor, |
524 | int64_t num_batch_dims, |
525 | IntArrayRef sizes, |
526 | IntArrayRef strides, |
527 | optional<int64_t> maybe_storage_offset) { |
528 | auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims); |
529 | auto slice_strides = physical_tensor.strides().slice(num_batch_dims); |
530 | auto base_offset = physical_tensor.storage_offset(); |
531 | |
532 | auto storage_offset = maybe_storage_offset.value_or(base_offset); |
533 | |
534 | auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset); |
535 | auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset); |
536 | |
537 | if (!max_as_strided_loc.has_value()) { |
538 | return; |
539 | } |
540 | if (!max_slice_loc.has_value()) { |
541 | TORCH_CHECK(false, |
542 | "result = tensor.as_strided(" , sizes, "," , strides, "," , storage_offset, ")" , |
543 | "can access memory outside of `tensor`. `tensor` has no storage but the " , |
544 | "passed-in (size, stride, storage_offset) imply a result with some storage. " , |
545 | "This is not supported inside of vmap, please try to rewrite the " , |
546 | "`as_strided` call as a sequence of PyTorch view operations" ); |
547 | } |
548 | |
549 | TORCH_CHECK( |
550 | *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset, |
551 | "result = tensor.as_strided(" , sizes, "," , strides, "," , storage_offset, ")" , |
552 | "can access memory outside of `tensor`. `result` can access some" , |
553 | "memory in range [" , storage_offset, ", " , *max_as_strided_loc, "], but " , |
554 | "`tensor` can only access some memory in range [" , base_offset, ", " , |
555 | *max_slice_loc, "]. This is not supported inside of vmap, please try to" , |
556 | "rewrite the `as_strided` call as a sequence of PyTorch view operations" ); |
557 | } |
558 | |
559 | Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) { |
560 | return reshape_batching_rule(self, sizes); |
561 | } |
562 | |
563 | Tensor _new_zeros_with_same_feature_meta_batching_rule( |
564 | const Tensor& self, |
565 | const Tensor& other, |
566 | int64_t unused_num_batch_dims) { |
567 | TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other), |
568 | "Only the 'batched grad' use case is supported in PyTorch core." ); |
569 | |
570 | TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0, |
571 | "num_batch_dims should not be explicitly passed in because it will be overridden" ); |
572 | auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self); |
573 | const auto& self_physical_tensor = self_physical_view.tensor(); |
574 | int64_t num_batch_dims = self_physical_view.numBatchDims(); |
575 | checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims); |
576 | auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims); |
577 | return self_physical_view.getPhysicalToLogicalMap().apply(result); |
578 | } |
579 | |
580 | bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) { |
581 | TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other), |
582 | "Only the 'batched grad' use case is supported in PyTorch core." ); |
583 | // The _has_same_storage_numel check is skipped if the tangent is a batched |
584 | // tensor because using as_strided to access storage locations not indexable |
585 | // by the input tensor is not supported in vmap |
586 | return true; |
587 | } |
588 | |
589 | // What are the semantics of as_strided inside of vmap? |
590 | // y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs) |
// This returns a view `y` on `xs`, such that each y[i] has:
592 | // - sizes: `sizes` |
593 | // - strides: `strides` |
594 | // - storage_offset: offset + i * x.stride(batch_dim) |
595 | // |
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
//    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
600 | // |
601 | // Note that this *may* be different from actually running as_strided |
602 | // in a for-loop. This is due to how as_strided takes in `offset` to be |
603 | // an *absolute* offset. As an example, consider: |
604 | // >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1) |
605 | // >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)] |
606 | // Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))! |
607 | // However, we consider the above for-loop comprehension to be a user error: |
608 | // a user should have written the following if they wanted to use as_strided |
609 | // in a per-sample way: |
610 | // >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)] |
611 | Tensor as_strided_batching_rule( |
612 | const Tensor& tensor, |
613 | IntArrayRef sizes, |
614 | IntArrayRef strides, |
615 | optional<int64_t> storage_offset) { |
616 | auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor); |
617 | auto num_batch_dims = physical_view.numBatchDims(); |
618 | auto physical_sizes = physical_view.getPhysicalShape(sizes); |
619 | const auto& physical_tensor = physical_view.tensor(); |
620 | |
621 | // We can't rely on the physical as_strided call to do this for us because |
622 | // we do some sanity checks on the size/strides before calling into as_strided. |
623 | TORCH_CHECK(sizes.size() == strides.size(), |
624 | "Tensor.as_strided(size, stride, ...): size and stride must have the " , |
625 | "same length! Got size " , sizes, " and stride " , strides); |
626 | |
627 | // Sanity checks: |
628 | // 1. All batch dims are at the front in memory layout (not necessary for |
629 | // correctness, but we are worried the user might be doing crazy things) |
630 | // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset()) |
631 | // is valid for a slice of the input tensor. |
632 | // See Note: [When will the as_strided batching rule fail?] for details. |
633 | checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims); |
634 | checkBasicAsStridedValidForSlice( |
635 | physical_tensor, num_batch_dims, sizes, strides, storage_offset); |
636 | |
637 | // physical_strides = physical tensor's batch strides + (logical) strides |
638 | auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims); |
639 | at::VmapDimVector physical_strides; |
640 | physical_strides.reserve(num_batch_dims + strides.size()); |
641 | physical_strides.insert( |
642 | physical_strides.end(), batch_strides.begin(), batch_strides.end()); |
643 | physical_strides.insert( |
644 | physical_strides.end(), strides.begin(), strides.end()); |
645 | |
646 | // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) |
647 | // is valid for all i, then it turns out that |
648 | // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds |
649 | // and creates a tensor y such that each y[i] references the same memory |
650 | // locations as zi. See NOTE: [When will the as_strided batching rule fail?] |
651 | auto result = physical_view.tensor().as_strided( |
652 | physical_sizes, physical_strides, storage_offset); |
653 | return physical_view.getPhysicalToLogicalMap().apply(result); |
654 | } |
655 | |
656 | // NOTE: [When will the as_strided batching rule fail?] |
657 | // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) |
658 | // is valid for all i, then it turns out that |
659 | // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and |
660 | // creates a tensor y such that each y[i] refers to the same memory as zi. |
661 | // |
662 | // Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()). |
663 | // Furthermore, let's say that as a part of being "valid" this as_strided call |
664 | // does not return a result that can index memory not indexable by xs[i]. |
665 | // |
666 | // WLOG, assume that there's only one batch dim and it is at the front of the |
667 | // `xs` tensor. Let B be the batch size and S be the stride of the batch dim. |
668 | // - If the batch dim isn't at the front of the tensor, then we can just move it |
669 | // to the front with movedim/permute. This is always valid because it just swaps |
670 | // some strides around. |
671 | // - This proof also works for tensors with multiple batch dims. We just have to |
672 | // do a little accounting: |
673 | // - instead of [B], we'd have [B0, B1, ..., Bk]. |
674 | // - instead of [S], we'd have [S0, S1, ..., Sk]. |
675 | // - instead of i, we'd have a list of indices [I0, I1, ..., Ik] |
676 | // - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i |
677 | // |
678 | // [Equation 1] |
679 | // xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has: |
680 | // - sizes: sizes |
681 | // - strides: strides |
682 | // - offset: offset + S * i |
683 | // |
684 | // x.as_strided itself checks that: |
685 | // - (sizes, strides, offset) are in bounds for `x`'s storage. |
// - strides are non-negative
// - offset is non-negative
688 | // |
689 | // Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) |
690 | // is valid, then |
691 | // ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage. |
692 | // |
693 | // If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset) |
694 | // won't error out. So all we need to check is that the memory locations are |
695 | // what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important) |
696 | // |
697 | // xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to |
698 | // xs.as_strided([B] + sizes, [S] + strides, offset) |
699 | // |
700 | // xs.as_strided([B] + sizes, [S] + strides, offset) has: |
701 | // - sizes: [B] + sizes |
702 | // - strides: [S] + strides |
703 | // - offset: offset |
704 | // |
705 | // xs.as_strided([B] + sizes, [S] + strides, offset)[i] has: |
706 | // - sizes: sizes |
707 | // - strides: strides |
708 | // - offset: offset + S * i |
709 | // These memory locations are exactly the same as what we got for [Equation 1], |
710 | // so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid. |
711 | // |
712 | // [Hand-wavy proof of Claim 1] |
713 | // Part of our definition of being valid is that xs[i].as_strided(...) |
714 | // must return a tensor that only uses memory indexable by xs[i]. |
715 | // This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies: |
716 | // offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] |
717 | // <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) |
718 | // (the largest-index memory location of xs[i].as_strided(...) must be \leq |
719 | // the largest-index memory location of xs[i]) |
720 | // |
721 | // Fiddling that inequality gives us: |
722 | // offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] |
723 | // <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) |
724 | // |
725 | // offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] |
726 | // <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) |
727 | // |
728 | // offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] |
729 | // <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) |
730 | // |
731 | // offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] |
732 | // <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) |
733 | // (the largest-index memory location of xs.as_strided(size, stride, offset) |
734 | // is \leq than the largest-index memory location of xs) |
735 | // Under the assumptions we've made, the lower bound (lowest indexed memory) |
736 | // is trivially within the storage. |
737 | // |
738 | // Therefore ([B] + sizes, [S] + strides, offset) are in bounds for |
739 | // `xs`'s storage. |
740 | |
741 | template <typename F, F Func, typename... ExtraArgs> |
742 | Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) { |
743 | auto* input_batched = unsafeGetBatchedImpl(input); |
744 | auto output_physical = Func(input_batched->value(), args...); |
745 | auto old_bdims = input_batched->bdims(); |
746 | return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); |
747 | } |
748 | |
749 | template <typename F, F Func, typename... ExtraArgs> |
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
751 | auto* input_batched = unsafeGetBatchedImpl(input); |
752 | auto output_physical = (input_batched->value().*Func)(extra_args...); |
753 | auto old_bdims = input_batched->bdims(); |
754 | return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); |
755 | } |
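// Like binary_pointwise_batching_rule, these two templates are instantiated
// with concrete (member) function-pointer types at registration time, roughly
// along the lines of (illustrative only):
//
//   m.impl("pow.Tensor_Scalar",
//          unwrap_and_call<Tensor (*)(const Tensor&, const Scalar&), at::pow, const Scalar&>);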
756 | |
757 | Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) { |
758 | auto* self_batched = unsafeGetBatchedImpl(self); |
759 | auto output_physical = at::pow(other, self_batched->value()); |
760 | auto old_bdims = self_batched->bdims(); |
761 | return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); |
762 | } |
763 | |
764 | Tensor clone_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) { |
765 | // Memory format support is a little tricky because vmap is allowed to move |
766 | // around batch dimensions and some memory formats are rank-dependent. |
767 | // Another weird case is: |
768 | // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we |
769 | // allow the user to clone a Tensor with 3 logical dimensions and 1 batch |
770 | // dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims |
771 | // and N>1 batch dims? |
772 | TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve |
773 | || memory_format == MemoryFormat::Contiguous, |
774 | "NYI: Tensor.clone(memory_format) inside vmap is only supported with " , |
775 | "memory_format torch.preserve_format or torch.contiguous_format (got " , |
776 | *memory_format, ")" ); |
777 | |
778 | if (memory_format == MemoryFormat::Contiguous) { |
779 | // There is an ambiguity here when the batch dims are not at the front of |
780 | // the tensor. |
781 | // >>> x = torch.randn(3, B0, 5) |
782 | // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x) |
783 | // >>> y[0].is_contiguous() |
784 | // ??? |
785 | // Should we make the whole tensor contiguous, or should we |
786 | // make the non-batch dims contiguous? We've chosen the latter because |
787 | // philosophically vmap hides the batch dims and operates on a per-sample level. |
788 | auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); |
789 | auto output_physical = at::clone(physical_view.tensor(), memory_format); |
790 | return physical_view.getPhysicalToLogicalMap().apply(output_physical); |
791 | } |
792 | |
793 | TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve); |
794 | auto* self_batched = unsafeGetBatchedImpl(self); |
795 | auto output_physical = at::clone(self_batched->value(), memory_format); |
796 | auto old_bdims = self_batched->bdims(); |
797 | return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); |
798 | } |
799 | |
800 | // Note [Batching rules for matmul-like operators] |
801 | // at::matmul doesn't "de-expand" arguments to get better performance (maybe |
802 | // it should). In the batching rules for matmul-like operators (dot, mv, mm), |
803 | // we should be careful not to expand any unnecessary dimensions. e.g., if |
804 | // only one of the two arguments is a BatchedTensor, then we should try |
805 | // not to expand batch dimensions onto the other arg. |
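// For example, in mv_batching_rule below, when only `self` is batched we call
// at::matmul on the physical [..., M, K] tensor and the unbatched [K] vector
// directly, rather than first expanding `other` to [..., K].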
806 | Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { |
807 | auto self_batched = isBatchedTensor(self); |
808 | auto other_batched = isBatchedTensor(other); |
809 | |
810 | // A shape checking API would be nice... |
811 | TORCH_CHECK(self.dim() == 2 && other.dim() == 1, |
812 | "mv(self, other): Shape mismatch: expected matrix " |
813 | "(got `self` of size " , self.sizes(), ") " , |
814 | "and vector (got `other` of size " , other.sizes(), ")" ); |
815 | |
816 | // See Note [Batching rules for matmul-like operators] for why we have cases |
817 | if (self_batched && !other_batched) { |
818 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
819 | auto result = at::matmul(self_physical.tensor(), other); |
820 | return self_physical.getPhysicalToLogicalMap().apply(result); |
821 | } |
822 | if (!self_batched && other_batched) { |
823 | // self_physical: [L, K], other_physical: [..., K] |
824 | // We view the tensors as [L, K], [..., K, 1], perform matmul to get |
825 | // a tensor of size [..., L, 1], and unsqueeze the last dim. |
826 | auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); |
827 | auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1)); |
828 | return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); |
829 | } |
830 | if (self_batched && other_batched) { |
831 | // self_physical: [..., L, K], other_physical: [..., K] |
832 | // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get |
833 | // a tensor of size [..., L, 1], and unsqueeze the last dim. |
834 | auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); |
835 | auto result = at::matmul( |
836 | physical_args[0].tensor(), |
837 | physical_args[1].tensor().unsqueeze(-1)); |
838 | return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1)); |
839 | } |
840 | TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor" ); |
841 | } |
842 | |
843 | Tensor _make_dual_batching_rule( |
844 | c10::DispatchKeySet ks, |
845 | const Tensor& primal, |
846 | const Tensor& tangent, |
847 | int64_t level |
848 | ) { |
849 | DispatchKeySet after_batched_keyset = |
850 | DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched); |
851 | return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level); |
852 | } |
853 | |
854 | Tensor dot_batching_rule(const Tensor& self, const Tensor& other) { |
855 | auto self_batched = isBatchedTensor(self); |
856 | auto other_batched = isBatchedTensor(other); |
857 | |
858 | TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1, |
859 | "dot(self, other): Shape mismatch: vector " |
860 | "(got `self` of size " , self.sizes(), ") " , |
861 | "and vector (got `other` of size " , other.sizes(), ")" ); |
862 | |
863 | // See Note [Batching rules for matmul-like operators] for why we have cases |
864 | if (self_batched && !other_batched) { |
865 | // self_physical: [..., K], other_physical: [K] |
866 | // View the tensors as [..., 1, K] and [K], perform matmul, and unsqueeze. |
867 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
868 | auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other); |
869 | return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); |
870 | } |
871 | if (!self_batched && other_batched) { |
872 | // self_physical: [K], other_physical: [..., K] |
873 | // View the tensors as [K] and [..., K, 1], perform matmul, and unsqueeze. |
874 | auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); |
875 | auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1)); |
876 | return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); |
877 | } |
878 | if (self_batched && other_batched) { |
879 | // self_physical: [..., K], other_physical: [..., K] |
880 | // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and unsqueeze. |
881 | auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); |
882 | auto result = at::matmul( |
883 | physical_args[0].tensor().unsqueeze(-2), |
884 | physical_args[1].tensor().unsqueeze(-1)); |
885 | return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1)); |
886 | } |
887 | TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor" ); |
888 | } |
889 | |
890 | Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) { |
891 | TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3, |
892 | "bmm(self, other): Shape mismatch: expected 3D `self` " |
893 | "(got `self` of size " , self.sizes(), ") " , |
894 | "and 3D `other` (got `other` of size " , other.sizes(), ")" ); |
895 | |
896 | auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); |
897 | auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); |
898 | return physical_args[0].getPhysicalToLogicalMap().apply(result); |
899 | } |
900 | |
901 | Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { |
902 | auto self_batched = isBatchedTensor(self); |
903 | auto other_batched = isBatchedTensor(other); |
904 | |
905 | TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2, |
906 | "mm(self, other): Shape mismatch: expected matrix " |
907 | "(got `self` of size " , self.sizes(), ") " , |
908 | "and matrix (got `other` of size " , other.sizes(), ")" ); |
909 | |
910 | // See Note [Batching rules for matmul-like operators] for why we have cases |
911 | if (self_batched && !other_batched) { |
912 | auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); |
913 | auto result = at::matmul(self_physical.tensor(), other); |
914 | return self_physical.getPhysicalToLogicalMap().apply(result); |
915 | } |
916 | if (!self_batched && other_batched) { |
917 | auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); |
918 | auto result = at::matmul(self, other_physical.tensor()); |
919 | return other_physical.getPhysicalToLogicalMap().apply(result); |
920 | } |
921 | if (self_batched && other_batched) { |
922 | auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); |
923 | auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); |
    // Unlike dot/mv above, matmul on 2-D per-example inputs already produces
    // the final [..., M, N] result; there are no dummy size-1 dims to squeeze.
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
925 | } |
926 | TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor" ); |
927 | } |
928 | |
929 | Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { |
930 | auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); |
931 | auto physical_tensors = fmap( |
932 | physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); |
933 | TORCH_INTERNAL_ASSERT( |
934 | !tensors.empty(), "The dispatcher should not have dispatched here otherwise." ); |
935 | auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim)); |
936 | return physical_views[0].getPhysicalToLogicalMap().apply(result); |
937 | } |
938 | |
939 | Tensor stack_batching_rule(TensorList tensors, int64_t dim) { |
940 | auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); |
941 | auto physical_tensors = fmap( |
942 | physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); |
943 | TORCH_INTERNAL_ASSERT( |
944 | !tensors.empty(), "The dispatcher should not have dispatched here otherwise." ); |
945 | // NB: stack wraps the dimensionality to (logical dim + 1), so we have to |
946 | // manually handle that here. |
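  // For example, stacking logical 1-D per-example tensors at dim=-1 wraps -1
  // against logical_dim + 1 = 2 to get 1, so the physical dim is
  // numBatchDims() + 1.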
947 | auto dim_physical = |
948 | physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1); |
949 | auto result = at::stack(physical_tensors, dim_physical); |
950 | return physical_views[0].getPhysicalToLogicalMap().apply(result); |
951 | } |
952 | |
953 | // I am quite sad that we need to register operators with exploded TensorOptions, |
954 | // even though the native:: implementations can use TensorOptions&. |
955 | // This also makes it hard to metaprogram: i.e., we can't use |
956 | // unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!) |
957 | Tensor to_dtype_layout_batching_rule( |
958 | const Tensor& self, |
959 | optional<ScalarType> dtype, |
960 | optional<Layout> layout, |
961 | optional<Device> device, |
962 | optional<bool> pin_memory, |
963 | bool non_blocking, bool copy, |
964 | optional<MemoryFormat> memory_format) { |
965 | auto options = TensorOptions() |
966 | .dtype(dtype) |
967 | .layout(layout) |
968 | .device(device) |
969 | .pinned_memory(pin_memory); |
970 | auto* input_batched = unsafeGetBatchedImpl(self); |
971 | auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format); |
972 | auto old_bdims = input_batched->bdims(); |
973 | return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); |
974 | } |
975 | |
976 | Tensor new_zeros_batching_rule( |
977 | const Tensor& self, |
978 | IntArrayRef size, |
979 | optional<ScalarType> dtype, |
980 | optional<Layout> layout, |
981 | optional<Device> device, |
982 | optional<bool> pin_memory) { |
983 | auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); |
984 | auto physical_size = physical_view.getPhysicalShape(size); |
985 | auto options = TensorOptions() |
986 | .dtype(dtype) |
987 | .layout(layout) |
988 | .device(device) |
989 | .pinned_memory(pin_memory); |
990 | auto result = physical_view.tensor().new_zeros(physical_size, options); |
991 | return physical_view.getPhysicalToLogicalMap().apply(result); |
992 | } |
993 | |
994 | Tensor new_empty_batching_rule( |
995 | const Tensor& self, |
996 | IntArrayRef size, |
997 | c10::optional<ScalarType> dtype, |
998 | c10::optional<Layout> layout, |
999 | c10::optional<Device> device, |
1000 | c10::optional<bool> pin_memory) { |
1001 | auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); |
1002 | auto physical_size = physical_view.getPhysicalShape(size); |
1003 | auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)); |
1004 | return physical_view.getPhysicalToLogicalMap().apply(result); |
1005 | } |
1006 | |
1007 | Tensor new_empty_strided_batching_rule( |
1008 | const Tensor& self, |
1009 | IntArrayRef size, |
1010 | IntArrayRef stride, |
1011 | optional<ScalarType> dtype, |
1012 | optional<Layout> layout, |
1013 | optional<Device> device, |
1014 | optional<bool> pin_memory) { |
1015 | auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); |
1016 | auto physical_size = physical_view.getPhysicalShape(size); |
1017 | |
1018 | // Let [B0, B1, B2] be the shape of the batch dims. We're going to create |
1019 | // the batch dimensions at the front of the tensor (in memory layout), |
1020 | // irrespective of whether or not they are actually at the front (in memory layout) |
1021 | // in the original `self` tensor. This is because when a user calls |
1022 | // `new_empty_strided` in general, the `strides` they provide are for a new |
1023 | // tensor and have no relation to the strides of the original tensor. |
1024 | // |
1025 | // So, the physical shape of the result should be ([B0, B1, B2] + size), |
1026 | // but what about the physical strides? |
1027 | // |
1028 | // We're actually free to pick whatever stride we want: |
1029 | // e.g., for size=[5, 3], stride=[0, 1], we could decide to |
1030 | // use |
1031 | // - physical size: [B0, B1, B2, 5, 3] |
1032 | // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1] |
1033 | // |
1034 | // Let's select some reasonable strides such that: |
1035 | // - The batch dims are "contiguous" with respect to each other |
1036 | // - if empty_strided(size, stride) would have created a contiguous Tensor, |
1037 | // then this new physical Tensor (with batch dims) is also contiguous |
1038 | // |
1039 | // Let S be the size of the storage if one were to construct a tensor |
1040 | // with `size` and `stride` via empty_strided(size, stride). |
1041 | // Then the physical sizes/strides should be: |
1042 | // - physical size: [B0, B1, B2, 5, 3] |
1043 | // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1] |
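  //
  // Worked example: for size=[5, 3], stride=[3, 1] (a contiguous layout),
  // S = storage_size_for([5, 3], [3, 1]) = 1 + 4*3 + 2*1 = 15, so the physical
  // strides become [B1*B2*15, B2*15, 15, 3, 1], which is itself contiguous.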
1044 | auto batch_shape = IntArrayRef( |
1045 | physical_view.tensor().sizes().begin(), physical_view.numBatchDims()); |
1046 | |
1047 | // physical_strides = [B1 * B2 * S, B2 * S, S] |
1048 | auto physical_strides = at::detail::defaultStrides(batch_shape); |
1049 | TORCH_CHECK(size.size() == stride.size(), |
1050 | "new_empty_strided(sizes, strides): dimensionality of sizes (" , |
1051 | size.size(), ") must match dimensionality of strides (" , |
1052 | stride.size(), ")" ); |
1053 | auto storage_size = native::storage_size_for(size, stride); |
1054 | for (auto& physical_stride : physical_strides) { |
1055 | physical_stride *= storage_size; |
1056 | } |
1057 | |
1058 | // physical_strides = [B1 * B2 * S, B2 * S, S] + strides |
1059 | physical_strides.insert(physical_strides.end(), stride.begin(), stride.end()); |
1060 | |
1061 | auto result = physical_view.tensor().new_empty_strided( |
1062 | physical_size, physical_strides, dtype, layout, device, pin_memory); |
1063 | return physical_view.getPhysicalToLogicalMap().apply(result); |
1064 | } |
1065 | |
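// Batching rule shared by the pointwise comparison ops registered below
// (eq, ne, lt, ...). Illustrative (hypothetical) example: for
// vmap(torch.eq)(x, y) with x, y of shape [B0, 3], the BroadcastingVmapTransform
// moves both inputs to a common physical shape, the comparison runs on the
// physical tensors, and the result is wrapped back into a BatchedTensor.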
1066 | template <typename F, F Func> |
1067 | Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) { |
1068 | auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); |
1069 | auto result = Func(physical_args[0].tensor(), physical_args[1].tensor()); |
1070 | return physical_args[0].getPhysicalToLogicalMap().apply(result); |
1071 | } |
1072 | |
1073 | TORCH_LIBRARY_IMPL(_, Batched, m) { |
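  // Any operator without an explicit batching rule hits this boxed fallback,
  // which effectively runs the operator once per batch element and stacks the
  // results, so unregistered ops still work under vmap, just slowly.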
1074 | m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>()); |
1075 | } |
1076 | |
1077 | TORCH_LIBRARY_IMPL(aten, Batched, m) { |
1078 | // NB: Ideally we would like some operators, like size.int, to "fallthrough" |
1079 | // to the underlying implementation. However, because a BatchedTensor is a |
1080 | // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution |
1081 | // here is to just directly call the underlying implementation. |
1082 | m.impl("size.int" , static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size)); |
1083 | m.impl("_add_batch_dim" , native::_add_batch_dim); |
1084 | m.impl("_remove_batch_dim" , native::_remove_batch_dim); |
1085 | m.impl("_make_dual" , _make_dual_batching_rule); |
1086 | m.impl("_has_same_storage_numel" , _has_same_storage_numel_batching_rule); |
1087 | m.impl("is_same_size" , native::is_same_size); |
1088 | m.impl("_new_zeros_with_same_feature_meta" , _new_zeros_with_same_feature_meta_batching_rule); |
1089 | |
1090 | m.impl("sum.dim_IntList" , sum_batching_rule); |
1091 | m.impl("is_complex" , native::is_complex); |
1092 | |
1093 | // inplace operations |
1094 | m.impl("fill_.Scalar" , fill_inplace_scalar_batching_rule); |
1095 | m.impl("fill_.Tensor" , fill_inplace_tensor_batching_rule); |
1096 | m.impl("zero_" , zero_inplace_batching_rule); |
1097 | |
1098 | // view operations |
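  // Entries annotated "composite wrt autograd" are registered with their
  // native:: implementations: these are composites that decompose into other
  // ops, so re-dispatching through them hits batching rules that already exist
  // (e.g. expand_as lowers to expand).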
1099 | m.impl("as_strided" , as_strided_batching_rule); |
1100 | m.impl("chunk" , chunk_batching_rule); |
1101 | m.impl("tensor_split.sections" , tensor_split_sections_batching_rule); |
1102 | m.impl("tensor_split.indices" , tensor_split_indices_batching_rule); |
1103 | m.impl("diagonal" , diagonal_batching_rule); |
1104 | m.impl("expand" , expand_batching_rule); |
1105 | m.impl("expand_as" , native::expand_as); // composite wrt autograd |
1106 | m.impl("movedim.intlist" , movedim_batching_rule); |
1107 | m.impl("movedim.int" , static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd |
  // There is another variant of narrow. However, we don't want to support
  // the other variant yet because it isn't documented...
1110 | m.impl("narrow" , native::narrow_symint); // composite wrt autograd |
1111 | m.impl("numpy_T" , native::numpy_T); // composite wrt autograd |
1112 | m.impl("matrix_H" , native::matrix_H); // composite wrt autograd |
1113 | m.impl("mT" , native::mT); // composite wrt autograd |
1114 | m.impl("mH" , native::mH); // composite wrt autograd |
1115 | m.impl("permute" , permute_batching_rule); |
1116 | m.impl("reshape" , reshape_batching_rule); |
1117 | m.impl("_reshape_alias" , _reshape_alias_batching_rule); |
1118 | m.impl("reshape_as" , native::reshape_as); // composite wrt autograd |
1119 | m.impl("select.int" , select_batching_rule); |
1120 | m.impl("slice.Tensor" , slice_batching_rule); |
1121 | m.impl("split.Tensor" , split_batching_rule); |
1122 | m.impl("split.sizes" , split_with_sizes_batching_rule); |
1123 | m.impl("split_with_sizes" , split_with_sizes_batching_rule); |
1124 | m.impl("squeeze" , squeeze_batching_rule); |
1125 | m.impl("squeeze.dim" , squeeze_dim_batching_rule); |
1126 | m.impl("squeeze.dims" , squeeze_dims_batching_rule); |
1127 | m.impl("t" , native::t); // composite wrt autograd |
1128 | m.impl("trace" , trace_batching_rule); |
1129 | m.impl("transpose.int" , transpose_int_batching_rule); |
1130 | m.impl("unbind.int" , unbind_batching_rule); |
1131 | m.impl("unfold" , unfold_batching_rule); |
1132 | m.impl("unsqueeze" , unsqueeze_batching_rule); |
1133 | m.impl("view" , view_batching_rule); |
1134 | m.impl("view_as" , native::view_as); // composite wrt autograd |
1135 | |
1136 | // clamp operations |
1137 | m.impl("clamp" , clamp_batching_rule); |
1138 | m.impl("clamp_min" , clamp_min_batching_rule); |
1139 | m.impl("clamp_max" , clamp_max_batching_rule); |
1140 | |
1141 | // unary pointwise, out-of-place, no additional arguments. |
1142 | #define UNARY_POINTWISE(op) m.impl(#op, \ |
1143 | unwrap_and_call<Tensor (*)(const Tensor&), at::op>); |
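  // For instance, UNARY_POINTWISE(abs) expands (roughly) to:
  //   m.impl("abs", unwrap_and_call<Tensor (*)(const Tensor&), at::abs>);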
1144 | UNARY_POINTWISE(abs); |
1145 | UNARY_POINTWISE(acos); |
1146 | UNARY_POINTWISE(asin); |
1147 | UNARY_POINTWISE(atan); |
1148 | UNARY_POINTWISE(ceil); |
1149 | UNARY_POINTWISE(cos); |
1150 | UNARY_POINTWISE(cosh); |
1151 | UNARY_POINTWISE(conj_physical); |
1152 | UNARY_POINTWISE(digamma); |
1153 | UNARY_POINTWISE(exp); |
1154 | UNARY_POINTWISE(expm1); |
1155 | UNARY_POINTWISE(floor); |
1156 | UNARY_POINTWISE(frac); |
1157 | UNARY_POINTWISE(lgamma); |
1158 | UNARY_POINTWISE(log); |
1159 | UNARY_POINTWISE(log10); |
1160 | UNARY_POINTWISE(log1p); |
1161 | UNARY_POINTWISE(log2); |
1162 | UNARY_POINTWISE(neg); |
1163 | UNARY_POINTWISE(reciprocal); |
1164 | UNARY_POINTWISE(relu); |
1165 | UNARY_POINTWISE(round); |
1166 | UNARY_POINTWISE(rsqrt); |
1167 | UNARY_POINTWISE(sigmoid); |
1168 | UNARY_POINTWISE(sign); |
1169 | UNARY_POINTWISE(sin); |
1170 | UNARY_POINTWISE(sinh); |
1171 | UNARY_POINTWISE(sqrt); |
1172 | UNARY_POINTWISE(tan); |
1173 | UNARY_POINTWISE(tanh); |
1174 | UNARY_POINTWISE(trunc); |
1175 | #undef UNARY_POINTWISE |
1176 | #define TO_BATCHING_RULE(name, ...) \ |
1177 | { \ |
1178 | using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \ |
1179 | m.impl(name, unwrap_and_call_method< \ |
1180 | to_type, &Tensor::to, __VA_ARGS__>);\ |
1181 | } |
1182 | TO_BATCHING_RULE("to.device" , Device, ScalarType, bool, bool, optional<MemoryFormat>) |
1183 | TO_BATCHING_RULE("to.dtype" , ScalarType, bool, bool, optional<MemoryFormat>) |
1184 | TO_BATCHING_RULE("to.other" , const Tensor&, bool, bool, optional<MemoryFormat>) |
1185 | m.impl("to.dtype_layout" , to_dtype_layout_batching_rule); |
1186 | #undef TO_BATCHING_RULE |
1187 | m.impl("clone" , clone_batching_rule); |
1188 | |
1189 | using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&); |
1190 | using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&); |
1191 | using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&); |
1192 | |
1193 | #define BINARY_POINTWISE(op) \ |
1194 | m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \ |
1195 | m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>); |
1196 | #define BINARY_POINTWISE_VA(op, ...) \ |
1197 | { \ |
1198 | using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \ |
1199 | using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \ |
1200 | m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \ |
1201 | m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, const Scalar&, __VA_ARGS__>); \ |
1202 | } |
1203 | |
1204 | BINARY_POINTWISE_VA(add, const Scalar&); |
1205 | BINARY_POINTWISE_VA(sub, const Scalar&); |
1206 | BINARY_POINTWISE_VA(rsub, const Scalar&); |
1207 | BINARY_POINTWISE(mul); |
1208 | BINARY_POINTWISE(div); |
1209 | { |
1210 | using Binop = Tensor (*)(const Tensor&, const Tensor&, c10::optional<c10::string_view>); |
1211 | using Unop = Tensor (*)(const Tensor&, const Scalar&, c10::optional<c10::string_view>); |
1212 | m.impl("div.Tensor_mode" , binary_pointwise_batching_rule<Binop, at::div, c10::optional<c10::string_view>>); |
1213 | m.impl("div.Scalar_mode" , unwrap_and_call<Unop, at::div, const Scalar&, c10::optional<c10::string_view>>); |
1214 | } |
1215 | |
1216 | // at::pow has three out-of-place overloads |
1217 | m.impl("pow.Tensor_Tensor" , binary_pointwise_batching_rule<TensorTensorType, at::pow>); |
1218 | m.impl("pow.Tensor_Scalar" , unwrap_and_call<TensorScalarType, at::pow, const Scalar&>); |
1219 | m.impl("pow.Scalar" , pow_scalar_Tensor_batching_rule); |
1220 | |
1221 | m.impl("sigmoid_backward" , binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>); |
1222 | m.impl( |
1223 | "threshold_backward" , |
1224 | binary_pointwise_batching_rule< |
1225 | TensorTensorScalarType, |
1226 | at::threshold_backward, |
1227 | const Scalar&>); |
1228 | |
1229 | // for at::result_type, call the native::result_type implementation. |
1230 | // We don't have to do anything special because native::result_type operates |
1231 | // on the logical shape of the tensors. |
1232 | m.impl("result_type.Tensor" , static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type)); |
1233 | m.impl("result_type.Scalar" , static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type)); |
1234 | m.impl("result_type.Scalar_Tensor" , static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type)); |
1235 | m.impl("result_type.Scalar_Scalar" , static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type)); |
1236 | |
1237 | #undef BINARY_POINTWISE_VA |
1238 | #undef BINARY_POINTWISE |
1239 | |
1240 | |
1241 | #define TRIVIAL_OP(op) m.impl(#op, \ |
1242 | unwrap_and_call<Tensor (*)(const Tensor&), at::op>); |
1243 | // complex number view operators |
  TRIVIAL_OP(imag);
1245 | TRIVIAL_OP(real); |
1246 | TRIVIAL_OP(view_as_real); |
1247 | TRIVIAL_OP(conj); |
1248 | TRIVIAL_OP(_conj); |
1249 | TRIVIAL_OP(resolve_conj); |
1250 | TRIVIAL_OP(resolve_neg); |
1251 | m.impl("view_as_complex" , view_as_complex_batching_rule); |
#undef TRIVIAL_OP
1253 | |
1254 | // matmul-like operators |
1255 | m.impl("mv" , mv_batching_rule); |
1256 | m.impl("dot" , dot_batching_rule); |
1257 | m.impl("bmm" , bmm_batching_rule); |
1258 | m.impl("mm" , mm_batching_rule); |
1259 | |
1260 | // cat/stack |
1261 | m.impl("cat" , cat_batching_rule); |
1262 | m.impl("stack" , stack_batching_rule); |
1263 | |
1264 | // backward operators |
1265 | m.impl("select_backward" , select_backward_batching_rule); |
1266 | m.impl("slice_backward" , slice_backward_batching_rule); |
1267 | m.impl("trace_backward" , trace_backward_batching_rule); |
1268 | m.impl("diagonal_backward" , diagonal_backward_batching_rule); |
1269 | |
1270 | // Tensor.new_* operators |
1271 | m.impl("new_empty" , new_empty_batching_rule); |
1272 | m.impl("new_empty_strided" , new_empty_strided_batching_rule); |
1273 | m.impl("new_zeros" , new_zeros_batching_rule); |
1274 | |
1275 | m.impl("contiguous" , contiguous_batching_rule); |
1276 | |
1277 | // Comparison ops |
1278 | #define COMPARISON_POINTWISE(op) \ |
1279 | m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \ |
1280 | m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, const Scalar&>); |
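  // For instance, COMPARISON_POINTWISE(eq) roughly expands to:
  //   m.impl("eq.Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::eq>);
  //   m.impl("eq.Scalar", unwrap_and_call<TensorScalarType, at::eq, const Scalar&>);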
1281 | |
1282 | COMPARISON_POINTWISE(eq); |
1283 | COMPARISON_POINTWISE(gt); |
1284 | COMPARISON_POINTWISE(ge); |
1285 | COMPARISON_POINTWISE(le); |
1286 | COMPARISON_POINTWISE(lt); |
1287 | COMPARISON_POINTWISE(ne); |
1288 | |
1289 | #undef COMPARISON_POINTWISE |
1290 | } |
1291 | |
1292 | } // namespace at |
1293 | |