1 | #pragma once |
2 | |
3 | #include <bitset> |
4 | #include <utility> |
5 | |
6 | #include <ATen/ArrayRef.h> |
7 | #include <ATen/SmallVector.h> |
8 | #include <ATen/Tensor.h> |
9 | |
10 | namespace at { |
11 | |
12 | // We assume this in a few other places in the codebase, |
13 | // but there isn't a centralized definition. |
14 | constexpr int64_t kVmapMaxTensorDims = 64; |
15 | |
16 | // The valid vmap levels range from [0, 64). This effectively means that we |
17 | // support a maximum of 64 nested vmaps. |
18 | constexpr int64_t kVmapNumLevels = 64; |
19 | |
20 | // Store this number of elements of BatchDims on the stack. Most people will |
21 | // probably use <= 5 nested vmaps, but adjust this number as necessary. |
22 | constexpr int64_t kBatchDimsStackSize = 5; |
23 | |
24 | // a BatchDim represents a "private" dimension on a Tensor created inside of |
25 | // vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension |
26 | // is being vmap'ed over and the `level` being an identifier for which vmap |
27 | // said dimension was created inside. The `dim` corresponds to a "physical |
28 | // dim" - it is a dimension index on the underlying physical tensor that is |
29 | // being vmapped over. |
30 | struct BatchDim { |
31 | BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {} |
32 | int64_t dim() const { |
33 | return dim_; |
34 | } |
35 | int64_t level() const { |
36 | return level_; |
37 | } |
38 | |
39 | private: |
40 | int64_t dim_; |
41 | int64_t level_; |
42 | }; |
43 | |
44 | using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>; |
45 | using BatchDimsRef = ArrayRef<BatchDim>; |
46 | |
47 | // A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim |
48 | // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a |
49 | // BatchedTensorImpl. |
50 | // |
51 | // The batch dimensions are treated as being "private"; they are not |
52 | // user-visible. For example, in the following Tensor, |
53 | // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)]) |
54 | // dimensions 0 and 1 are batch dimensions. |
55 | // |
56 | // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) |
57 | // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) |
58 | // tensor. |
59 | struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { |
60 | explicit BatchedTensorImpl(Tensor value, BatchDims bdims); |
61 | |
62 | // Returns a reference to BatchDims that represent which dimensions of this |
63 | // tensor are private. |
64 | BatchDimsRef bdims() const { |
65 | return bdims_; |
66 | } |
67 | |
68 | // BatchedTensorImpl wraps a Tensor |
69 | const Tensor& value() const { |
70 | return value_; |
71 | }; |
72 | |
73 | // Given a public dimension index, return the dimension index in the |
74 | // underlying value() tensor. For example, if we have |
75 | // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, |
76 | // dim=2)]) |
77 | // bt.actualDim(0) -> 1 |
78 | // bt.actualDim(1) -> 3 |
79 | // bt.actualDim(2) -> Error |
80 | int64_t actualDim(int64_t dim, bool wrap_dim = true) const; |
81 | |
82 | // We have to override this because we opted into CustomStrides |
83 | IntArrayRef strides_custom() const override; |
84 | // Override a bunch of methods inherited from TensorImpl to return error |
85 | // messages. |
86 | bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |
87 | void set_size(int64_t dim, int64_t new_size) override; |
88 | void set_stride(int64_t dim, int64_t new_stride) override; |
89 | void set_storage_offset(int64_t storage_offset) override; |
90 | #ifdef DEBUG |
91 | bool has_storage() const override; |
92 | #endif |
93 | |
94 | private: |
95 | // see NOTE: [BatchedTensorImpl levels invariant] |
96 | void checkInvariants() const; |
97 | const char* tensorimpl_type_name() const override; |
98 | |
99 | Tensor value_; |
100 | |
101 | // Note: [BatchedTensorImpl levels invariant] |
102 | // There is an invariant that the BatchDims must be stored in increasing |
103 | // `level` order. That is, for i < j, bdims_[i].level must be less than |
104 | // bdims_[j].level. |
105 | BatchDims bdims_; |
106 | }; |
107 | |
108 | // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a |
109 | // BatchedTensorImpl. |
110 | inline bool isBatchedTensor(const Tensor& tensor) { |
111 | return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched); |
112 | } |
113 | |
114 | // It is unsafe to call this on a Tensor that is not backed by a |
115 | // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. |
116 | inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) { |
117 | return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
118 | } |
119 | |
120 | inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) { |
121 | if (!isBatchedTensor(tensor)) { |
122 | return nullptr; |
123 | } |
124 | return unsafeGetBatchedImpl(std::move(tensor)); |
125 | } |
126 | |
127 | // Returns a bitset. If bit i is set, then that means dim i is a batchdim. |
128 | inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset( |
129 | BatchDimsRef bdims) { |
130 | std::bitset<kVmapMaxTensorDims> is_bdim; |
131 | for (const auto& bdim : bdims) { |
132 | is_bdim.set(bdim.dim()); |
133 | } |
134 | return is_bdim; |
135 | } |
136 | |
137 | // Creates a bitset for all of the levels present in `bdims` |
138 | inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) { |
139 | std::bitset<kVmapNumLevels> result; |
140 | for (const auto& bdim : bdims) { |
141 | result.set(bdim.level()); |
142 | } |
143 | return result; |
144 | } |
145 | |
146 | inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) { |
147 | out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")" ; |
148 | return out; |
149 | } |
150 | |
151 | // Use this to construct a BatchedTensor from a regular Tensor |
152 | TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims); |
153 | |
154 | // Adds a batch dim to `tensor`, returning a BatchedTensor |
155 | TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim); |
156 | |
157 | // Checks if an inplace operation on self and other is "vmap compatible". |
158 | // See NOTE: [vmap-incompatible in-place operations] for the definition of this. |
159 | TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other); |
160 | |
161 | } // namespace at |
162 | |