1 | |
2 | #pragma once |
3 | |
4 | #include <ATen/ArrayRef.h> |
5 | #include <ATen/FunctionalStorageImpl.h> |
6 | #include <ATen/core/IListRef.h> |
7 | #include <ATen/core/List.h> |
8 | #include <ATen/core/boxing/BoxedKernel.h> |
9 | #include <ATen/core/boxing/impl/boxing.h> |
10 | #include <ATen/core/dispatch/Dispatcher.h> |
11 | |
12 | #include <c10/core/DispatchKey.h> |
13 | |
14 | namespace at { |
15 | |
16 | // Note [Functionalization Pass In Core] |
17 | // The Functionalization pass is used to remove aliasing from a pytorch program. |
18 | // |
19 | // This is useful for backends that don't support aliasing, like XLA and Vulkan. |
20 | // It's also necessary in order to remove mutation from a program, which is |
21 | // needed in Functorch. |
22 | // |
23 | // Consider this program: |
24 | // a = torch.ones(...) |
25 | // b = a.view(...) |
26 | // b.add_(1) |
27 | // |
28 | // In this program, b is meant to alias with a due to the use of view(). At the |
29 | // end of the program, both a and b are full of 2's. However, backends that |
30 | // don't support aliasing aren't able to correctly implement the view() |
31 | // operator. Instead, they can opt into the Functionalization pass, which will |
32 | // sit between the user and the backend, and provide the necessary aliasing |
33 | // logic. |
34 | // |
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//
//   a = torch.ones(...)
//   b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                         # XLA/Vulkan can implement this!
//   b.add_(1)
//   a.add_(1)  # Our functionalization pass machinery knows that a and b are
//              # aliased - it applies b's mutation to a too.
42 | // |
43 | // So, how does the functionalization pass keep track of which tensors are |
44 | // aliased? The pass works by wrapping EVERY tensor in the program inside of a |
45 | // FunctionalTensorWrapper, which knows about its alias'd tensors. |
46 | // |
47 | // See Note [Functionalization: Alias Removal] for details on the aliasing |
48 | // machinery. See Note [Functionalization: Mutation Removal] for details on |
49 | // mutation removal. |
50 | struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { |
51 | explicit FunctionalTensorWrapper(const Tensor& value); |
52 | // Additional constructor to create a FunctionalTensorWrapper directly from an |
53 | // underlying tensor that was created from a view. For example, the code b = |
54 | // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a, |
55 | // view1_meta) |
56 | explicit FunctionalTensorWrapper( |
57 | const Tensor& view_value, |
58 | const FunctionalTensorWrapper* base, |
59 | functionalization::ViewMeta meta); |
60 | |
61 | // Get the underlying, actual tensor, that doesn't know anything about |
62 | // functionalization. |
63 | const Tensor& value() const { |
64 | return value_; |
65 | }; |
66 | // The concept of "level" is only ever important to functorch; it's exposed |
67 | // here as more of a hook for functorch to use. |
68 | int64_t level() const { |
69 | return level_; |
70 | }; |
71 | void set_level(int64_t level) { |
72 | level_ = level; |
73 | } |
74 | |
75 | // Sync's the underlying tensor with its alias, if it's out of date. This |
76 | // involves two steps: 1) Apply any pending updates/mutations to the alias 2) |
77 | // Replay the views (if any) to regenerate the current tensor off of the |
78 | // updated alias. |
79 | void sync_(); |
80 | // Performs step (1) of the sync. This is its own public API because it's |
81 | // needed by view_inplace ops like transpose_. See Note [Functionalization |
82 | // Pass - Inplace View Ops] |
83 | void regenerate_from_base(); |
84 | // Performs step (2) of the sync. This is its own public API because it's |
85 | // needed by functorch. functorch wants to make sure that all input tensors to |
86 | // a functionalized program have been properly synced so it can properly |
87 | // propagate mutations to inputs. It can't just call sync_(), because the |
88 | // FunctionalTensorWrapper will look like it has no aliases and sync_ will be |
89 | // a noop. We use the reference count on storage_ to determine if the wrapper |
90 | // is aliased, and by the time functorch is ready to propagate updates to |
91 | // inputs, any intermediate views of the input created by the program will |
92 | // have been deallocated. This function also returns whether or not the base |
93 | // actually had any updates to apply. |
94 | bool apply_updates(); |
95 | // Takes the current state of value_ and snapshots it, sending it as a pending |
96 | // update to the alias. |
97 | void commit_update(); |
98 | // When any tensor is mutated, the tensor increments its alias's "generation". |
99 | // Separately, each tensor maintains its own "generation" counter, which is |
100 | // used to determine if it's up-to-date with its alias. The act of syncing a |
101 | // tensor will set a tensor's generation equal to its alias's generation. |
102 | bool is_up_to_date() const; |
103 | // Freezes the storage of this tensor, preventing subsequent mutations |
104 | void freeze_storage() const; |
105 | // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects |
106 | // describing the series of view ops that ran to generate the current tensor |
107 | // from the base tensor. This method is used by inplace-view ops like |
108 | // transpose_. It appends a ViewMeta to the existing stack, and refreshes the |
109 | // tensor by replaying the views off of the alias. |
110 | void mutate_view_meta(at::functionalization::ViewMeta meta); |
111 | |
112 | // The functionalization pass can be used to remove mutations. |
113 | // It does so by replacing any mutation op with it's corresponding |
114 | // out-of-place op, followed by a call to replace_(). e.g: |
115 | // |
116 | // a.add_(1) |
117 | // |
118 | // will turn into: |
119 | // |
120 | // tmp = a.add(1) |
121 | // a.replace_(tmp) |
122 | // |
123 | // replace_() swaps out the wrapped tensor, value_, with tmp. |
124 | void replace_(const Tensor& other); |
125 | |
126 | // See Note[resize_() in functionalization pass] |
127 | void maybe_replace_storage(const Tensor& other); |
128 | |
129 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
130 | const c10::VariableVersion& version_counter, |
131 | bool allow_tensor_metadata_change) const override; |
132 | |
133 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
134 | c10::VariableVersion&& version_counter, |
135 | bool allow_tensor_metadata_change) const override; |
136 | |
137 | ~FunctionalTensorWrapper() override = default; |
138 | |
139 | // FunctionalTensorWrapper overrides all custom size/stride function, |
140 | // so that if the inner tensor has a custom implementation |
141 | // we make sure to call that implementation. |
142 | at::IntArrayRef sizes_custom() const override; |
143 | at::IntArrayRef strides_custom() const override; |
144 | int64_t dim_custom() const override; |
145 | int64_t numel_custom() const override; |
146 | bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |
147 | c10::SymIntArrayRef sym_sizes_custom() const override; |
148 | c10::SymInt sym_size_custom(int64_t d) const override; |
149 | c10::SymIntArrayRef sym_strides_custom() const override; |
150 | c10::SymInt sym_storage_offset_custom() const override; |
151 | c10::Device device_custom() const override; |
152 | |
153 | private: |
154 | const char* tensorimpl_type_name() const override; |
155 | void set_constructor_metadata(); |
156 | functionalization::FunctionalStorageImpl* functional_storage_impl() const; |
157 | |
158 | // This is used to re-implement shallow_copy_and_detach for |
159 | // FunctionalTensorWrapper. The implementation is identical, but we just need |
160 | // to return a subclass instead of a plain TensorImpl. |
161 | // TODO: maybe it's possible to arrange for that to happen automatically |
162 | // without an override here? |
163 | template <typename VariableVersion> |
164 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core( |
165 | VariableVersion&& version_counter, |
166 | bool allow_tensor_metadata_change) const; |
167 | |
168 | // Note that value is not taken by reference: internally, the wrapper will |
169 | // change the value tensor that it points to over time. |
170 | Tensor value_; |
171 | int64_t level_; |
172 | |
173 | size_t generation_ = 0; |
174 | std::vector<at::functionalization::ViewMeta> view_metas_; |
175 | }; |
176 | |
177 | // Utility functions for the functionalization pass. |
178 | |
179 | namespace functionalization { |
180 | namespace impl { |
181 | |
182 | TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( |
183 | const Tensor& tensor) { |
184 | auto functional_impl = |
185 | static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl()); |
186 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr); |
187 | return functional_impl; |
188 | } |
189 | |
// Returns whether the given tensor(s) are wrapped in a FunctionalTensorWrapper.
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

// Wraps the given tensor(s) in FunctionalTensorWrapper(s).
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
    const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

// Freezes the wrapper's storage, preventing subsequent mutations
// (see FunctionalTensorWrapper::freeze_storage).
TORCH_API void freeze_functional_tensor(const Tensor& tensor);

// Unwraps FunctionalTensorWrapper(s), returning the underlying tensor(s).
// When assert_functional is true, a non-functional input is an error.
TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
    const c10::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

// Syncs the tensor(s) with their alias, if out of date
// (see FunctionalTensorWrapper::sync_).
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
// NOTE(review): this overload takes t_list by value, unlike the others —
// presumably fine since c10::List is a cheap refcounted handle; confirm.
TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);

// Swaps out the tensor wrapped by functional_tensor with other
// (see FunctionalTensorWrapper::replace_).
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

// Snapshots value_ and sends it as a pending update to the alias
// (see FunctionalTensorWrapper::commit_update).
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

// Wraps view_to_wrap in a FunctionalTensorWrapper that records `meta` as the
// view relationship to `base`. out_idx selects the output for multi-output
// view ops.
Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta);

// Appends `meta` to self's view stack and refreshes it
// (see FunctionalTensorWrapper::mutate_view_meta).
void mutate_view_meta(const Tensor& self, functionalization::ViewMeta meta);

// Copies sizes/strides/storage_offset from meta_out onto out.
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);

// ~~~~~ TLS used in functionalization ~~~~~

TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
248 | |
249 | class TORCH_API FunctionalizationReapplyViewsGuard { |
250 | public: |
251 | FunctionalizationReapplyViewsGuard(bool reapply_views) |
252 | : prev_(getFunctionalizationReapplyViewsTLS()) { |
253 | setFunctionalizationReapplyViewsTLS(reapply_views); |
254 | } |
255 | |
256 | ~FunctionalizationReapplyViewsGuard() { |
257 | setFunctionalizationReapplyViewsTLS(prev_); |
258 | } |
259 | |
260 | FunctionalizationReapplyViewsGuard( |
261 | const FunctionalizationReapplyViewsGuard&) = delete; |
262 | FunctionalizationReapplyViewsGuard operator=( |
263 | const FunctionalizationReapplyViewsGuard&) = delete; |
264 | FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) = |
265 | delete; |
266 | FunctionalizationReapplyViewsGuard operator=( |
267 | FunctionalizationReapplyViewsGuard&&) = delete; |
268 | |
269 | private: |
270 | bool prev_; |
271 | }; |
272 | |
273 | } // namespace impl |
274 | |
// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

// Primary template; never instantiated directly. Only the partial
// specialization below (which unpacks the schema's function type into
// ReturnType(ParameterTypes...)) provides call().
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};
283 | |
284 | template <class Op, bool symint, class ReturnType, class... ParameterTypes> |
285 | struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final { |
286 | static ReturnType call( |
287 | typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) { |
288 | using FuncType = ReturnType( |
289 | typename c10::maybe_keep_symint<symint, ParameterTypes>::type...); |
290 | auto op = c10::Dispatcher::singleton() |
291 | .findSchemaOrThrow( |
292 | (const char*)Op::name, (const char*)Op::overload_name) |
293 | .typed<FuncType>(); |
294 | |
295 | return c10::impl::BoxedKernelWrapper<FuncType>::call( |
296 | c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(), |
297 | op, |
298 | // BoxedKernelWrapper knows to ignore this keyset argument, |
299 | // because functionalize_op_helper doesn't take in a DispatchKeySet |
300 | c10::DispatchKeySet(), |
301 | args...); |
302 | } |
303 | }; |
304 | |
// Convenience aliases over _functionalize_aten_op. The symint variant keeps
// SymInt parameters in the signature (via c10::maybe_keep_symint) instead of
// the non-symint form.
template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;

template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
312 | |
313 | } // namespace functionalization |
314 | } // namespace at |
315 | |