#pragma once

#include <bitset>
#include <ostream>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
9
10namespace at {
11
// Maximum number of tensor dimensions supported under vmap; this is also the
// width of the per-dim bitset produced by createBatchDimBitset below.
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps. It is the width of the bitset
// produced by createVmapLevelsBitset below.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack (inline storage of
// the SmallVector used by BatchDims). Most people will probably use <= 5
// nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;
23
// A BatchDim is a (level, dim) pair naming one "private" dimension created by
// vmap. `dim` is a *physical* dimension index — i.e. an index into the
// underlying physical tensor that is being vmapped over — and `level`
// identifies which (possibly nested) vmap invocation introduced it.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim)
      : physical_dim_(dim), vmap_level_(level) {}

  // Physical dimension index on the underlying tensor.
  int64_t dim() const {
    return physical_dim_;
  }

  // Identifier of the vmap invocation that created this batch dim.
  int64_t level() const {
    return vmap_level_;
  }

 private:
  int64_t physical_dim_;
  int64_t vmap_level_;
};
43
// Owning list of batch dims; stores up to kBatchDimsStackSize entries inline
// on the stack (see the constant's comment above).
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
// Non-owning view over a contiguous sequence of BatchDims.
using BatchDimsRef = ArrayRef<BatchDim>;
46
47// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
48// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
49// BatchedTensorImpl.
50//
51// The batch dimensions are treated as being "private"; they are not
52// user-visible. For example, in the following Tensor,
53// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
54// dimensions 0 and 1 are batch dimensions.
55//
56// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
57// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
58// tensor.
59struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
60 explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
61
62 // Returns a reference to BatchDims that represent which dimensions of this
63 // tensor are private.
64 BatchDimsRef bdims() const {
65 return bdims_;
66 }
67
68 // BatchedTensorImpl wraps a Tensor
69 const Tensor& value() const {
70 return value_;
71 };
72
73 // Given a public dimension index, return the dimension index in the
74 // underlying value() tensor. For example, if we have
75 // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
76 // dim=2)])
77 // bt.actualDim(0) -> 1
78 // bt.actualDim(1) -> 3
79 // bt.actualDim(2) -> Error
80 int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
81
82 // We have to override this because we opted into CustomStrides
83 IntArrayRef strides_custom() const override;
84 // Override a bunch of methods inherited from TensorImpl to return error
85 // messages.
86 bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
87 void set_size(int64_t dim, int64_t new_size) override;
88 void set_stride(int64_t dim, int64_t new_stride) override;
89 void set_storage_offset(int64_t storage_offset) override;
90#ifdef DEBUG
91 bool has_storage() const override;
92#endif
93
94 private:
95 // see NOTE: [BatchedTensorImpl levels invariant]
96 void checkInvariants() const;
97 const char* tensorimpl_type_name() const override;
98
99 Tensor value_;
100
101 // Note: [BatchedTensorImpl levels invariant]
102 // There is an invariant that the BatchDims must be stored in increasing
103 // `level` order. That is, for i < j, bdims_[i].level must be less than
104 // bdims_[j].level.
105 BatchDims bdims_;
106};
107
108// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
109// BatchedTensorImpl.
110inline bool isBatchedTensor(const Tensor& tensor) {
111 return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
112}
113
114// It is unsafe to call this on a Tensor that is not backed by a
115// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
116inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
117 return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
118}
119
120inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
121 if (!isBatchedTensor(tensor)) {
122 return nullptr;
123 }
124 return unsafeGetBatchedImpl(std::move(tensor));
125}
126
127// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
128inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
129 BatchDimsRef bdims) {
130 std::bitset<kVmapMaxTensorDims> is_bdim;
131 for (const auto& bdim : bdims) {
132 is_bdim.set(bdim.dim());
133 }
134 return is_bdim;
135}
136
137// Creates a bitset for all of the levels present in `bdims`
138inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
139 std::bitset<kVmapNumLevels> result;
140 for (const auto& bdim : bdims) {
141 result.set(bdim.level());
142 }
143 return result;
144}
145
146inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
147 out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
148 return out;
149}
150
// Use this to construct a BatchedTensor from a regular Tensor.
// NB: `bdims` must satisfy NOTE: [BatchedTensorImpl levels invariant]
// (stored in increasing `level` order).
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);

// Adds a batch dim to `tensor`, returning a BatchedTensor. `dim` is a
// physical dimension index and `level` identifies the vmap invocation.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
160
161} // namespace at
162