#pragma once

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <c10/core/DispatchKey.h>

namespace at {

// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch program.
//
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
// It's also necessary in order to remove mutation from a program, which is
// needed in Functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its aliased tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from an
  // underlying tensor that was created from a view. For example, the code
  //   b = a.view1()
  // will generate a constructor call to FunctionalTensorWrapper(b, a, view1_meta).
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      functionalization::ViewMeta meta);

  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  }
  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  }
  void set_level(int64_t level) {
    level_ = level;
  }

  // Syncs the underlying tensor with its alias, if it's out of date. This
  // involves two steps:
  // 1) Apply any pending updates/mutations to the alias.
  // 2) Replay the views (if any) to regenerate the current tensor off of the
  //    updated alias.
  void sync_();
  // Performs step (2) of the sync: replaying the views. This is its own public
  // API because it's needed by view_inplace ops like transpose_.
  // See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync: applying pending updates. This is its own
  // public API because it's needed by functorch. functorch wants to make sure
  // that all input tensors to a functionalized program have been properly
  // synced so it can properly propagate mutations to inputs. It can't just call
  // sync_(), because the FunctionalTensorWrapper will look like it has no
  // aliases and sync_ will be a no-op. We use the reference count on storage_
  // to determine if the wrapper is aliased, and by the time functorch is ready
  // to propagate updates to inputs, any intermediate views of the input created
  // by the program will have been deallocated. This function also returns
  // whether or not the base actually had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a pending
  // update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's "generation".
  // Separately, each tensor maintains its own "generation" counter, which is
  // used to determine if it's up-to-date with its alias. The act of syncing a
  // tensor will set a tensor's generation equal to its alias's generation.
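  // For example (illustrative sketch, not part of the API contract): if a and b
  // share an alias whose generation is N, then b.add_(1) bumps the alias to
  // generation N+1; a's own counter is still N, so is_up_to_date() on a returns
  // false until a is synced.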
  bool is_up_to_date() const;
  // Freezes the storage of this tensor, preventing subsequent mutations
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
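  // For example (illustrative sketch): under functionalization, a.transpose_(0, 1)
  // appends a transpose ViewMeta to a's view_metas_ stack and then recomputes
  // a's value_ by replaying all recorded views off of the updated alias.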
  void mutate_view_meta(at::functionalization::ViewMeta meta);

  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g.:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other);

  // See Note [resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // FunctionalTensorWrapper overrides all custom size/stride functions,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;

 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just need
  // to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_;

  size_t generation_ = 0;
  std::vector<at::functionalization::ViewMeta> view_metas_;
};

// Utility functions for the functionalization pass.

namespace functionalization {
namespace impl {

TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
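
// Usage sketch (illustrative; callers are expected to have already checked
// isFunctionalTensor(t)):
//   auto* wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
//   const Tensor& inner = wrapper->value();
// The cast is unchecked, hence "unsafe": calling this on a tensor that is not
// wrapped by the functionalization pass is undefined behavior.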

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
    const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API void freeze_functional_tensor(const Tensor& tensor);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
    const c10::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
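
// Round-trip sketch (illustrative only): wrap a plain tensor before running a
// functionalized region, then unwrap the result afterwards:
//   Tensor a = at::ones({2, 2});
//   Tensor wrapped = at::functionalization::impl::to_functional_tensor(a);
//   // ... run ops on `wrapped` with the Functionalize dispatch key set ...
//   Tensor unwrapped = at::functionalization::impl::from_functional_tensor(wrapped);
// from_functional_tensor asserts that its input is a functional tensor unless
// assert_functional is set to false.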

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta);

void mutate_view_meta(const Tensor& self, functionalization::ViewMeta meta);

void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);

// ~~~~~ TLS used in functionalization ~~~~~

TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }

  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  bool prev_;
};
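
// Usage sketch (illustrative): the guard is a standard RAII helper. It saves
// the current TLS value, overrides it for the enclosing scope, and restores
// the previous value on destruction:
//   {
//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(true);
//     // ... code here sees getFunctionalizationReapplyViewsTLS() == true ...
//   }  // previous TLS value restored here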

} // namespace impl

// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      (const char*)Op::name, (const char*)Op::overload_name)
                  .typed<FuncType>();

    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};

template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;

template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
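
// Usage sketch (illustrative; assumes the ATEN_OP/ATEN_OP2 operator-handle
// macros from ATen/Operators.h): route an out-of-place composite op through
// functionalize_op_helper so that any mutations/views inside its decomposition
// are functionalized:
//   at::Tensor result = at::functionalization::
//       functionalize_aten_op<ATEN_OP(matmul)>::call(self, other);
// Use functionalize_aten_op_symint instead to keep SymInt arguments symbolic.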

} // namespace functionalization
} // namespace at