#include <ATen/LegacyBatchedTensorImpl.h>

#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace at {

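// Constructs the "public" (logical) view of `value`: the batch dims listed in
// `bdims` are hidden, so sizes()/strides() expose only the non-batch dims of
// the underlying tensor.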
BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  : TensorImpl(
      c10::DispatchKeySet(DispatchKey::Batched),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , bdims_(std::move(bdims))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  set_storage_access_should_throw();
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  checkInvariants();

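  // Compute the public sizes/strides by dropping every batch dim from
  // value_'s sizes/strides: `dim` indexes the public view, `actual_dim`
  // indexes the underlying value_ tensor.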
  const auto public_dims = value_.dim() - bdims_.size();
  const auto value_sizes = value_.sizes();
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (const auto dim : c10::irange(public_dims)) {
    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
  storage_offset_ = value_.storage_offset();
  refresh_numel();
  refresh_contiguous();
}

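// Maps a dimension of the public (batch-dims-stripped) view back to the
// corresponding dimension of the underlying value_ tensor.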
int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
    const auto ndim = sizes_and_strides_.size();
    dim = maybe_wrap_dim(dim, ndim);
  }
  auto is_bdim = createBatchDimBitset(bdims_);

  // Example: assume dim = 3, and is_bdim = 10010011000...
  // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
  // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
  // to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
  // The answer to that is index 5.
  //
  // TODO(rzou): the PDEP instruction does exactly this
  // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
  // but it might require newer (>= ~2015) CPUs. We should clean this up
  // if/when we have dropped support for older CPUs.
  int64_t non_bdim_count = 0;
  for (const auto actual_dim : c10::irange(kVmapMaxTensorDims)) {
    if (is_bdim[actual_dim]) {
      continue;
    }
    if (non_bdim_count == dim) {
      return actual_dim;
    }
    non_bdim_count++;
  }
  // If we hit this assert, it means that `non_bdim_count` plus the number of
  // batch dims exceeds kVmapMaxTensorDims. We restrict the number of dims a
  // BatchedTensorImpl can have to kVmapMaxTensorDims, so this should never
  // be hit.
  TORCH_INTERNAL_ASSERT(false);
}

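// Invariant: bdims_ must be sorted by strictly increasing vmap level.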
void BatchedTensorImpl::checkInvariants() const {
  int64_t prev_level = -1;
  for (const auto& bdim : bdims_) {
    TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
    prev_level = bdim.level();
  }
}

// The following are publicly exposed as methods of Tensor

IntArrayRef BatchedTensorImpl::strides_custom() const {
  return strides_default();
}

// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: querying is_contiguous inside of vmap for memory_format ",
      "other than torch.contiguous_format");
  return is_contiguous_;
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

const char* BatchedTensorImpl::tensorimpl_type_name() const {
  return "BatchedTensorImpl";
}

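// Wraps `tensor` (which must not already be batched) in a BatchedTensorImpl
// with the given batch dims.
//
// Rough usage sketch (hypothetical shapes, following the emplace_back style
// used in addBatchDim below):
//   Tensor x = at::randn({2, 3, 5});
//   BatchDims bdims;
//   bdims.emplace_back(/*level=*/0, /*dim=*/0);
//   Tensor bx = makeBatched(x, std::move(bdims));
//   // bx.sizes() is [3, 5]; dim 0 of x is hidden as the level-0 batch dim.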
Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
  TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  auto tensor_dim = tensor.dim();
  TORCH_CHECK(
      tensor_dim <= kVmapMaxTensorDims,
      "vmap only supports tensors of dimensionality up to ", kVmapMaxTensorDims,
      "; got a tensor with dim ", tensor_dim);
  TORCH_INTERNAL_ASSERT(
      std::all_of(bdims.begin(), bdims.end(),
          [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
      "We only support up to ", kVmapNumLevels, " nested vmaps");
  return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
}

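// Adds a batch dim at `level` to `tensor`. If `tensor` is already batched,
// `dim` refers to its public (batch-dims-stripped) view, so it is first
// translated to a physical dim of the underlying value() via actualDim().
// The new level is appended last, so it must be greater than every existing
// level to satisfy checkInvariants().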
Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
  const auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    BatchDims bdims;
    bdims.emplace_back(level, dim);
    return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
  }
  BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
  auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
  new_bdims.emplace_back(level, actual_bdim);
  return makeBatched(batched->value(), std::move(new_bdims));
}

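// Checks whether an in-place operation self.op_(other) is allowed inside vmap:
// it is compatible iff every vmap level present in `other` is also present in
// `self` (i.e. other's levels are a subset of self's). Otherwise the in-place
// op would have to grow `self` to hold other's extra batch dims, which an
// in-place write cannot do.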
bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
  const auto* other_batched = maybeGetBatchedImpl(other);
  if (!other_batched) {
    return true;
  }
  const auto* self_batched = maybeGetBatchedImpl(self);
  if (!self_batched) {
    // self is not batched but other is batched
    return false;
  }
  auto self_levels = createVmapLevelsBitset(self_batched->bdims());
  auto other_levels = createVmapLevelsBitset(other_batched->bdims());
  return self_levels == (self_levels | other_levels);
}

} // namespace at