#include <ATen/FunctionalStorageImpl.h>

#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at {
namespace functionalization {

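// Returns a copy of this ViewMeta that points at a different output index of
// the same (multi-output) view op. The forward/reverse lambdas are reused
// unchanged; only out_index differs.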
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is the output index; these are all single-output views, so it's always 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated storage.
// storage.base_ = t
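//
// apply_update() below implements exactly this: it takes a single queued
// Update together with the current base and returns the tensor that should
// become the new base (i.e. <mutated_base>).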
const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.empty()) return t;

  std::vector<at::Tensor> tmp_values({base});
  tmp_values.reserve(update.view_metas.size());
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we needed to, we could probably optimize this and only bother computing tmp_values
    // for those view ops that actually require it.
    tmp_values.push_back(std::move(next_view));
  }
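  // Now walk the view chain backwards: each reverse_fn scatters the mutated
  // values in `t` back into the corresponding saved base in tmp_values, until
  // `t` finally represents the fully updated storage.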
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}
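// Computes the number of bytes that the functional "storage" should report
// for `value`. Functionalization never allocates real data for this storage
// (see the null DataPtr and meta allocator in the constructor below); the
// result is only used for the StorageImpl's reported size.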
c10::SymInt get_nbytes(const Tensor& value) {
  // The functionalization story when wrapping tensors that don't have storage
  // is a bit wonky, but fortunately for some models (e.g., dlrm) we never
  // actually perform mutations on these tensors, so you never really get
  // called out on it. For now, functionalization still creates "storages"
  // for these tensors (which is wrong), but we don't give them any space.
  // A more proper fix would be to have a SparseFunctionalTensorWrapper that
  // models sparse correctly.
  if (value.is_sparse()) {
    return 0;
  }
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet, though.
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
  }
  // XLA storage objects also do not properly track nbytes.
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}
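// The functional storage wraps `base` but owns no real allocation: it reports
// get_nbytes(base) bytes, carries a null DataPtr on base's device, and is
// backed by the meta allocator.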
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(base),
      DataPtr{nullptr, base.device()},
      GetAllocator(kMeta),
      /*resizable=*/true
    ),
    base_(base)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}
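// Records a pending mutation: `updated_val` is the new value of the mutated
// view, and `metas` is the chain of ViewMetas describing how that view was
// created from the base. The write-back onto base_ is deferred until
// apply_updates() runs; bumping generation_ gives aliases of this storage a
// way to check whether they are up to date.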
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
  updates_.push_back({updated_val, metas});
  generation_++;
}
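// Replays every queued update onto base_, in order, and returns whether any
// updates were applied, so that callers (the syncing logic described in the
// Note above) know whether the base actually changed.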
bool FunctionalStorageImpl::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data : updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}

} // namespace functionalization
} // namespace at