1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #pragma once |
8 | |
9 | #include <bitset> |
10 | #include <utility> |
11 | |
12 | #include <ATen/ArrayRef.h> |
13 | #include <ATen/SmallVector.h> |
14 | #include <ATen/Tensor.h> |
15 | |
16 | namespace at { |
17 | namespace functorch { |
18 | |
19 | using Tensor = at::Tensor; |
20 | |
// Maximum number of dimensions a tensor participating in vmap may have.
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
// (Also used to size the per-dimension bitset in createBatchDimBitset below.)
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;
32 | |
33 | // A BatchedTensorImpl holds an underlying Tensor and a single batch dim |
34 | // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a |
35 | // BatchedTensorImpl. |
36 | // |
37 | // The batch dimensions are treated as being "private"; they are not user-visible. |
38 | // For example, in the following Tensor, |
39 | // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) |
40 | // dimension 0 is batch dimension. |
41 | // |
// bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 1 in the underlying ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);

  // Returns the index of the (private) batch dimension in the underlying
  // value() tensor.
  int64_t bdim() const { return bdim_; }

  // Returns the vmap level that this tensor's batch dimension belongs to.
  int64_t level() const { return level_; }

  // BatchedTensorImpl wraps a Tensor; this returns the underlying tensor.
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 2
  // bt.actualDim(2) -> 3
  // bt.actualDim(3) -> Error
  // NOTE(review): wrap_dim presumably wraps negative `dim` into the public
  // range before translating — confirm against the .cpp definition.
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  // Re-syncs this impl's cached tensor metadata. Must be called after
  // mutating bdim_ via unsafe_set_bdim (see the NB there); presumably also
  // needed after _unsafe_set_level — confirm against the .cpp definition.
  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
  // accomplishes this is a hack where it is able to modify the levels of
  // BatchedTensor to match the level of the current vmap transform.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in batching rule for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_)
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  // The wrapped (underlying) tensor; its dim `bdim_` is the hidden batch dim.
  Tensor value_;

  // vmap level the batch dimension belongs to.
  int64_t level_;
  // Index of the batch dimension inside value_.
  int64_t bdim_;
};
103 | |
104 | // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a |
105 | // BatchedTensorImpl. |
106 | inline bool isBatchedTensor(const Tensor& tensor) { |
107 | return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched); |
108 | } |
109 | |
110 | // It is unsafe to call this on a Tensor that is not backed by a |
111 | // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. |
112 | inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) { |
113 | return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
114 | } |
115 | |
116 | inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) { |
117 | if (!isBatchedTensor(tensor)) { |
118 | return nullptr; |
119 | } |
120 | return unsafeGetBatchedImpl(std::move(tensor)); |
121 | } |
122 | |
123 | // Returns a bitset. If bit i is set, then that means dim i is a batchdim. |
124 | inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) { |
125 | std::bitset<kVmapMaxTensorDims> is_bdim; |
126 | is_bdim.set(dim); |
127 | return is_bdim; |
128 | } |
129 | |
130 | // Creates a bitset for the given level |
131 | inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) { |
132 | std::bitset<kVmapNumLevels> result; |
133 | result.set(level); |
134 | return result; |
135 | } |
136 | |
// Use this to construct a BatchedTensor from a regular Tensor.
// `dim` is the index of the batch dimension and `level` the vmap level
// (presumably matching the BatchedTensorImpl constructor's dim/level —
// confirm in the .cpp).
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
142 | |
// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
// Serves as the default mask for getKeysToPropagateToWrapper below.
// TODO: should probably contain more (or all?) backend keys
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});
154 | |
155 | inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { |
156 | auto key_set = tensor.unsafeGetTensorImpl()->key_set(); |
157 | return key_set & kKeysToPropagateToWrapper; |
158 | } |
159 | |
160 | } |
161 | } |
162 | |