1 | #pragma once |
2 | |
3 | #include <ATen/Tensor.h> |
4 | |
5 | namespace at { |
6 | namespace functionalization { |
7 | |
8 | // See Note [Functionalization Pass In Core] |
9 | |
10 | // ViewMeta is a class used by the functionalization pass to navigate between |
11 | // a base tensor and a view tensor. |
12 | // For example, if I call `b = a.view1(...)` |
13 | // the functionalization pass will generate and store a ViewMeta on b that looks |
14 | // like: |
15 | // |
16 | // ViewMeta( |
17 | // [<captures>](const Tensor& base, int64_t mutated_view_idx) { |
18 | // return base.view1(...); |
19 | // }, |
20 | // [<captures>](const at::Tensor& base, const at::Tensor& mutated_view, |
21 | // int64_t mutated_view_idx) -> at::Tensor { |
22 | // return at::functionalization::impl::view1_inverse(base, mutated_view, |
23 | // ...); |
24 | // } |
25 | // |
26 | // The forward_fn lambda describes how to replay view1 on a tensor. |
27 | // |
// The reverse_fn lambda describes how, given a tensor that is already a view,
// to get the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details.
31 | struct ViewMeta { |
32 | ViewMeta( |
33 | std::function<Tensor(const Tensor&, int64_t)> forward, |
34 | std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse, |
35 | int64_t out_idx = 0) |
36 | : forward_fn(std::move(forward)), |
37 | reverse_fn(std::move(reverse)), |
38 | out_index(out_idx) {} |
39 | |
40 | std::function<Tensor(const Tensor&, int64_t)> forward_fn; |
41 | std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn; |
42 | // See Note [out_idx in ViewMeta] |
43 | int64_t out_index; |
44 | |
45 | // Returns a copy of the current ViewMeta, if out_idx matches the current |
46 | // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse |
47 | // functions, but a new out index. |
48 | ViewMeta to_out_idx(int64_t out_idx); |
49 | }; |
50 | |
51 | // FunctionalStorageImpl is a subclass of StorageImpl used by the |
52 | // functionalization pass. It has no underlying data (similar to meta storage). |
53 | // It also knows how to reflect mutations to tensors in the absence of a valid |
54 | // data pointer. |
55 | // |
56 | // A storage represents the state shared by (potentially multiple) views of the |
57 | // same tensor. For example, in the following code: |
58 | // |
59 | // b = a.view1(...) |
60 | // c = b.view2(...) |
61 | // b.add_(1) |
62 | // --> storage.add_update(b, {view1_meta}) |
63 | // |
64 | // The call to add_(1) will result in a call to alias.add_update(b, |
65 | // {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose |
66 | // c is used in an expression (e.g. you try to print c, or pass it to an |
67 | // operator). Doing so will involve "syncing" c. First we apply any pending |
68 | // updates to the alias, and then we regenerate c by replaying its views off of |
69 | // the updated alias. E.g: |
70 | // |
71 | // print(str(c)) |
72 | // --> c.sync_() |
73 | // --> alias.apply_updates() // after this, the alias will be updated to |
74 | // reflect the mutation to b |
struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
 public:
  // A single pending mutation: the new tensor value written to some view,
  // together with the chain of ViewMetas describing how that view was
  // created off of the base.
  struct Update {
    const at::Tensor new_val;
    const std::vector<ViewMeta> view_metas;
  };

  explicit FunctionalStorageImpl(const Tensor& value);

  // Queues up a mutation onto the alias: `updated_val` is the mutated view
  // tensor, and `view_metas` is the chain of view ops relating it to the
  // base. Each queued mutation increments generation_ (see below).
  void add_update(
      const Tensor& updated_val,
      const std::vector<ViewMeta>& view_metas);
  // Replays all pending updates_ onto base_.
  // NOTE(review): the bool presumably reports whether any update was applied;
  // confirm against the definition in the .cpp.
  bool apply_updates();
  const Tensor& base() {
    return base_;
  }
  size_t generation() const {
    return generation_;
  }
  // Marks the storage as immutable; see frozen_ below.
  void freeze() {
    frozen_ = true;
  }

  ~FunctionalStorageImpl() override = default;

 private:
  // NB: base_ should always point to a tensor BELOW the current
  // functionalization layer. This is mainly to avoid reference cycles. e.g.
  // given `b = a.view(...)`, both a.storage_ and b.storage_ are a
  // FunctionalStorageImpl containing an Alias, which contains a Tensor
  // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
  // should point not to a, but to a's unwrapped value, `a.value_`. See Note
  // [Functionalization: Alias Removal] for a diagram that shows this
  // visually.
  at::Tensor base_;
  // Mutations queued by add_update() that have not yet been replayed onto
  // base_ by apply_updates().
  std::vector<Update> updates_;
  // generation_ gets incremented every time a mutation is queued onto the
  // alias. It is used to determine if a given tensor is "up to date", or if it
  // needs to be regenerated from the alias.
  size_t generation_ = 0;
  // If frozen, no more mutations are allowed on this storage. Once frozen, a
  // storage cannot be unfrozen.
  bool frozen_ = false;
};
119 | |
120 | } // namespace functionalization |
121 | } // namespace at |
122 | |