#include <ATen/LegacyBatchedTensorImpl.h>

#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace at {

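// Wraps `value` and records its batch dims. The public (logical) sizes and
// strides exposed by this TensorImpl are those of `value` with the batch
// dims removed.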
BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  : TensorImpl(
      c10::DispatchKeySet(DispatchKey::Batched),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , bdims_(std::move(bdims))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  set_storage_access_should_throw();
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  checkInvariants();

  const auto public_dims = value_.dim() - bdims_.size();
  const auto value_sizes = value_.sizes();
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (const auto dim : c10::irange(public_dims)) {
    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
  storage_offset_ = value_.storage_offset();
  refresh_numel();
  refresh_contiguous();
}

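// Maps a logical (public) dimension index to the corresponding dimension of
// the underlying `value_` tensor by skipping over the batch dims.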
int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
    const auto ndim = sizes_and_strides_.size();
    dim = maybe_wrap_dim(dim, ndim);
  }
  auto is_bdim = createBatchDimBitset(bdims_);

  // Example: assume dim = 3, and is_bdim = 10010011000...
  // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
  // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
  // to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
  // The answer to that is index 5.
  //
  // TODO(rzou): the PDEP instruction does exactly this
  // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
  // but it might require newer (>= ~2015) CPUs. We should clean this up
  // if/when we have dropped support for older CPUs.
  int64_t non_bdim_count = 0;
  for (const auto actual_dim : c10::irange(kVmapMaxTensorDims)) {
    if (is_bdim[actual_dim]) {
      continue;
    }
    if (non_bdim_count == dim) {
      return actual_dim;
    }
    non_bdim_count++;
  }
  // If we hit this assert, then that means
  // `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number
  // of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should
  // never be hit.
  TORCH_INTERNAL_ASSERT(false);
}

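// Invariant: the batch dims in `bdims_` must be sorted by strictly increasing
// `level`.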
void BatchedTensorImpl::checkInvariants() const {
  int64_t prev_level = -1;
  for (const auto& bdim : bdims_) {
    TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
    prev_level = bdim.level();
  }
}

// The following are publicly exposed as methods of Tensor

IntArrayRef BatchedTensorImpl::strides_custom() const {
  return strides_default();
}

// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: querying is_contiguous inside of vmap for memory_format ",
      "other than torch.contiguous_format");
  return is_contiguous_;
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

const char* BatchedTensorImpl::tensorimpl_type_name() const {
  return "BatchedTensorImpl";
}

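// Wraps `tensor` in a BatchedTensorImpl with the given batch dims. `tensor`
// must not already be batched; to add another level to a BatchedTensor, use
// addBatchDim below.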
Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
  TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  auto tensor_dim = tensor.dim();
  TORCH_CHECK(
      tensor_dim <= kVmapMaxTensorDims,
      "vmap only supports tensors of dimensionality up to ", kVmapMaxTensorDims,
      "; got a tensor with dim ", tensor_dim);
  TORCH_INTERNAL_ASSERT(
      std::all_of(bdims.begin(), bdims.end(),
          [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
      "We only support up to ", kVmapNumLevels, " nested vmaps");
  return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
}

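// Adds a batch dim at `level` over dimension `dim` of `tensor`. If `tensor`
// is already batched, the new batch dim is appended to its existing batch
// dims, with `dim` translated to a dim of the underlying value tensor.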
Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
  const auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    BatchDims bdims;
    bdims.emplace_back(level, dim);
    return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
  }
  BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
  auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
  new_bdims.emplace_back(level, actual_bdim);
  return makeBatched(batched->value(), std::move(new_bdims));
}

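// Checks whether the in-place operation `self.op_(other)` is compatible with
// vmap: `other` must not be batched at any vmap level that `self` is not,
// since the result could not be written back into `self` without adding
// batch dims to it. For example, `self` at levels {1, 2} with `other` at
// level {1} is compatible; `other` at levels {1, 3} is not.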
bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
  const auto* other_batched = maybeGetBatchedImpl(other);
  if (!other_batched) {
    return true;
  }
  const auto* self_batched = maybeGetBatchedImpl(self);
  if (!self_batched) {
    // self is not batched but other is batched
    return false;
  }
  auto self_levels = createVmapLevelsBitset(self_batched->bdims());
  auto other_levels = createVmapLevelsBitset(other_batched->bdims());
  return self_levels == (self_levels | other_levels);
}

} // namespace at