1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | #pragma once |
7 | #include <ATen/Tensor.h> |
8 | #include <ATen/functorch/BatchedTensorImpl.h> |
9 | #include <ATen/functorch/DynamicLayer.h> |
10 | |
11 | // NOTE: [vmap plumbing] |
12 | // |
13 | // Here's how "batching rules" work. |
14 | // - we register kernels to the Batched key |
15 | // - these kernels have the same signatures as the original operators. |
16 | // For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel |
17 | // must also accept a Tensor |
// - However, it is more natural for users to write a batching rule like the
//   following: sin_batch_rule(Tensor self, optional<int64_t> self_bdim)
// - There is a code-generated layer (the "plumbing") that wraps the
//   user-defined batching rule (e.g. sin_batch_rule) in a kernel that can be
//   registered to the Batched key.
//
// The plumbing is responsible for wrapping a batching rule into a form that may
// be registered as the kernel for the Batched key.
26 | |
27 | namespace at { namespace functorch { |
28 | |
29 | void vmap_check_escaped(const optional<DynamicLayer> &layer, const char* what); |
30 | |
31 | // Create a BatchedTensor given a tensor, bdim, and level |
32 | TORCH_API Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64_t level); |
33 | |
34 | // Given a Tensor that may or may not be a BatchedTensor, unwrap it. |
35 | // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level |
36 | // doesn't match, then this returns (tensor, nullopt). |
37 | // Otherwise, it returns (unwrap(tensor), bdim). |
38 | TORCH_API std::tuple<Tensor, c10::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level); |
39 | |
40 | // Creates a vector of BatchedTensor |
41 | TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, optional<int64_t> bdim, int64_t level); |
42 | |
43 | // Returns True if ANY tensor in tensors is batched at level |
44 | TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level); |
45 | TORCH_API bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>> maybe_tensors, int64_t level); |
46 | TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level); |
47 | TORCH_API bool isBatchedAtLevel(const c10::optional<Tensor>& maybe_tensor, int64_t level); |
48 | |
49 | // Convenience helper. Returns true if any tensor is batched at level |
50 | TORCH_API bool areAnyBatchedAtLevel(ArrayRef<optional<Tensor>> maybe_tensors, int64_t level); |
51 | |
52 | inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) { |
53 | if (ivalue.isTensor()) { |
54 | auto maybe_level = maybeCurrentDynamicLayer(); |
55 | TORCH_INTERNAL_ASSERT(maybe_level.has_value()); |
56 | auto current_level = maybe_level->layerId(); |
57 | return isBatchedAtLevel(ivalue.toTensor(), current_level); |
58 | } |
59 | // TODO: should really check this |
60 | return false; |
61 | } |
62 | |
63 | }} |
64 | |