#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>

#include <utility>
#endif

namespace {
  void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
    const auto& schema = op.schema();
    TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");
    const auto num_arguments = schema.arguments().size();
    const auto arguments_begin = stack->size() - num_arguments;
    auto arguments = torch::jit::last(stack, num_arguments);

    auto any_functional_inputs = false;
    auto any_tensor_inputs = false;
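    // Unwrap any functional tensor (or tensor list) arguments in place on the stack,
    // syncing pending updates first so the redispatched kernel below sees current data.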
    for (uint64_t idx = 0; idx < num_arguments; ++idx) {
      const auto& ivalue = arguments[idx];
      if (ivalue.isTensor()) {
        any_tensor_inputs = true;
        const auto& t = ivalue.toTensor();
        if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(t);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isTensorList()) {
        any_tensor_inputs = true;
        auto tensors = ivalue.toTensorList();
        if (at::functionalization::impl::isFunctionalTensor(tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      } else if (ivalue.isOptionalTensorList()) {
        any_tensor_inputs = true;
        auto opt_tensors = ivalue.toOptionalTensorList();
        if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
          any_functional_inputs = true;
          at::functionalization::impl::sync(opt_tensors);
          auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
          (*stack)[arguments_begin + idx] = t_new;
        }
      }
    }
    // we should wrap the output if any inputs were wrapped,
    // OR if we're hitting a factory function (with no tensor inputs)
    auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
    {
      at::AutoDispatchSkipFunctionalize guard;
      op.callBoxed(stack);
    }
    const auto num_returns = schema.returns().size();
    const auto returns_begin = stack->size() - num_returns;
    auto returns = torch::jit::last(stack, num_returns);

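    // If needed, wrap the tensor outputs back into FunctionalTensorWrappers before
    // returning them on the stack.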
    for (const auto idx : c10::irange(num_returns)) {
      const auto& ivalue = returns[idx];
      if (ivalue.isTensor() && should_wrap_outputs) {
        const auto& t = ivalue.toTensor();
        if (!t.defined()) continue;
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isTensorList() && should_wrap_outputs) {
        auto tensors = ivalue.toTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
        (*stack)[returns_begin + idx] = t_new;
      } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
        auto opt_tensors = ivalue.toOptionalTensorList();
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
        (*stack)[returns_begin + idx] = t_new;
      }
    }
  }
}

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
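// For example (illustrative only): resizing a [2, 2] tensor to [4, 4] has to grow the
// underlying storage, so it is handled as a mutation; resizing it to [1, 2] fits in the
// existing storage, so it is modeled as an as_strided-style view into the original buffer.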
const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    self_.resize_(size, memory_format);
    return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }
  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_() in functionalization pass]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize_ as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    }
  );
  at::functionalization::impl::mutate_view_meta(self, std::move(view_meta));
  return self;
}

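// lift(), lift_fresh(), and lift_fresh_copy() take plain (non-functional) tensors that are
// entering the functionalized region, redispatch the underlying op with functionalization
// disabled, and hand back the result wrapped in a FunctionalTensorWrapper.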
at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

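// Only XLA and Lazy target devices are treated as "opted into" functionalization here;
// copying to any other device can end the functionalization pass (see _to_copy_functionalize below).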
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
  // If the target device is empty, then the output tensor should be on the same device as the input
  auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
  return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}

// Note: I only need this kernel because the to.dtype/to.dtype_layout overloads call this,
// so we skip the op above. We should probably get rid of this though.
at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

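  // Redispatch below the Functionalize key so the backend's _to_copy runs on the unwrapped tensor.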
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}

// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix _unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
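// (For context: one place _unsafe_view shows up is reshape(), which, when it cannot
// return a view, clones the input and calls _unsafe_view on the temporary clone.)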
at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

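  // The first lambda replays the view on an (updated) base; the second maps a mutated view
  // back onto the base by viewing it with the base's original sizes.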
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    }
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}

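// Register the boxed fallback for all ops without an explicit Functionalize kernel,
// plus the special-cased kernels defined above.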
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
}