#include <ATen/FunctionalStorageImpl.h>

#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at {
namespace functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, out_idx);
}
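
// A rough sketch of what the pair of functions carried by a ViewMeta looks like.
// The signatures here are inferred from the forward_fn/reverse_fn call sites in
// apply_update() below, and view1/view1_inverse are the hypothetical ops from
// the example in the note that follows:
//
//   ViewMeta view1_meta(
//       /*forward_fn=*/[](const at::Tensor& base, int64_t out_index) {
//         return base.view1();  // replay the view off of a given base
//       },
//       /*reverse_fn=*/[](const at::Tensor& base, const at::Tensor& mutated_view,
//                         int64_t out_index) {
//         // scatter the mutated view back into the given base
//         return view1_inverse(base, mutated_view, out_index);
//       },
//       /*out_index=*/0);  // which output of the view op this meta refers to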

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is the output index; these are all single-output views, so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated storage.
// storage.base_ = t
const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.empty()) return t;

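  // First, replay the chain of views forward off of the base so that the
  // intermediate tensors (base, a, b in the example above) are available
  // when running the view inverses below.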
  std::vector<at::Tensor> tmp_values;
  tmp_values.reserve(update.view_metas.size());
  tmp_values.push_back(base);
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we need to, we could apply this optimization and only bother computing tmp_values
    // for the view ops that actually require it.
    tmp_values.push_back(std::move(next_view));
  }
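  // Now walk the view chain backwards, applying each view's inverse to
  // propagate the mutation from the innermost view back toward the base.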
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}


c10::SymInt get_nbytes(const Tensor& value) {
  // The functionalization story when wrapping tensors that don't have storage
  // is a bit wonky, but fortunately for some models (e.g., dlrm) we never
  // actually perform mutations on these tensors, so you never really get
  // called out on it. For now, functionalization still creates "storages"
  // for these tensors (which is wrong), but we don't give them any space.
  // A more proper fix would be to have a SparseFunctionalTensorWrapper that
  // models sparse correctly.
  if (value.is_sparse()) {
    return 0;
  }
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet, though.
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
  }
  // XLA storage objects also do not properly track nbytes.
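  // As a concrete illustration (assuming the usual "largest reachable element
  // plus one, times itemsize" computation): a contiguous float32 tensor with
  // sizes [2, 3], strides [3, 1] and storage_offset 0 needs
  // (0 + 1 + (2 - 1) * 3 + (3 - 1) * 1) * 4 = 24 bytes.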
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}

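// Note that the wrapped-up storage never owns real data: it is constructed below
// with a null DataPtr and the meta allocator, and only advertises a byte size.
// The actual data lives in base_ (plus any queued updates until they are applied).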
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(base),
      DataPtr{nullptr, base.device()},
      GetAllocator(kMeta),
      /*resizable=*/true
    ),
    base_(base)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}

void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
  updates_.push_back({updated_val, metas});
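  // Bump the generation so that outstanding FunctionalTensorWrappers (which
  // compare their own generation against the storage's) can tell they are
  // out of date and need to be synced.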
  generation_++;
}

bool FunctionalStorageImpl::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data : updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}

} // namespace functionalization
} // namespace at