// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <bitset>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {
namespace functorch {

using Tensor = at::Tensor;

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchedTensorImpl holds an underlying Tensor and a single batch dim.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimension is treated as being "private"; it is not user-visible.
// For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
// dimension 0 is the batch dimension.
//
// bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 1 in the underlying ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);

  // Returns the batch dimension of this tensor
  int64_t bdim() const { return bdim_; }

  // Returns the vmap level of this tensor
  int64_t level() const { return level_; }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 2
  // bt.actualDim(2) -> 3
  // bt.actualDim(3) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensors; the way it
  // accomplishes this is a hack where it modifies the level of a BatchedTensor
  // to match the level of the current vmap transform.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in the batching rules for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_)
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  int64_t level_;
  int64_t bdim_;
};

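// A rough sketch of how the accessors above relate to one another, assuming
// `bt` is a BatchedTensor wrapping ones(2, 3, 5, 7) at lvl=1, dim=0 (see
// makeBatched and maybeGetBatchedImpl below for how `bt` and `impl` would be
// obtained); for illustration only:
//
//   BatchedTensorImpl* impl = maybeGetBatchedImpl(bt);
//   impl->value().sizes();   // (2, 3, 5, 7): the physical, wrapped tensor
//   impl->bdim();            // 0
//   impl->level();           // 1
//   impl->actualDim(2);      // 3: public dim 2 is physical dim 3
//   bt.sizes();              // (3, 5, 7): the batch dim is not user-visible
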
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(std::move(tensor));
}
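
// For example, a batching rule might unwrap its argument like this hypothetical
// sketch (`my_batching_rule` is a made-up name, not part of this header):
//
//   Tensor my_batching_rule(const Tensor& self) {
//     if (auto* batched = maybeGetBatchedImpl(self)) {
//       const Tensor& physical = batched->value();
//       // ... operate on `physical`, keeping track of batched->bdim() ...
//       return makeBatched(physical, batched->bdim(), batched->level());
//     }
//     return self;  // not a BatchedTensor; nothing to unwrap
//   }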

// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  is_bdim.set(dim);
  return is_bdim;
}
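
// For example, the bitset can be used to visit only the non-batch dimensions of
// a physical tensor (a sketch; `physical` and `bdim` are placeholders):
//
//   auto is_bdim = createBatchDimBitset(bdim);
//   for (int64_t d = 0; d < physical.dim(); ++d) {
//     if (is_bdim.test(d)) continue;  // skip the batch dim
//     // ... handle physical dimension d ...
//   }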

// Creates a bitset for the given level
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  std::bitset<kVmapNumLevels> result;
  result.set(level);
  return result;
}

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
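
// For example (a sketch for illustration; assumes the standard at::ones factory):
//
//   Tensor t = at::ones({2, 3, 5, 7});
//   Tensor bt = makeBatched(t, /*dim=*/0, /*level=*/1);
//   // bt now dispatches through FuncTorchBatched; bt.sizes() is (3, 5, 7)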

// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
// TODO: should probably contain more (or all?) backend keys
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return key_set & to_propagate;
}
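
// For example, code that constructs a wrapper might seed the new wrapper's key
// set from the wrapped tensor (a sketch; `value` is a placeholder for the tensor
// being wrapped):
//
//   DispatchKeySet wrapper_keys = DispatchKeySet(DispatchKey::FuncTorchBatched)
//       | getKeysToPropagateToWrapper(value);
//   // pass `wrapper_keys` as the key_set argument of BatchedTensorImpl's constructor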

} // namespace functorch
} // namespace at