#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
#include <ATen/ops/lift.h>
#include <ATen/ops/lift_fresh.h>
#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#include <ATen/ops/_unsafe_view.h>
#endif

// std::move is used in both configurations below, so <utility> is included unconditionally.
#include <utility>

namespace {
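// Boxed fallback that handles every op without alias annotations (i.e. functionally "pure" ops)
// under the Functionalize dispatch key. It unwraps any FunctionalTensorWrapper inputs on the
// stack (syncing pending updates first), redispatches to the backend with functionalization
// disabled, and re-wraps tensor outputs when any input was functional, or when the op is a
// factory function with no tensor inputs at all.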
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

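  // Unwrap functional tensor (and tensor-list) arguments in place on the stack, syncing any
  // pending updates first, and track whether we saw any tensor inputs / any functional inputs.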
  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      const auto& t = ivalue.toTensor();
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isOptionalTensorList()) {
      any_tensor_inputs = true;
      auto opt_tensors = ivalue.toOptionalTensorList();
      if (at::functionalization::impl::isFunctionalTensor(opt_tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(opt_tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }
  // We should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs).
  auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
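  // Redispatch to the underlying backend kernel with the Functionalize key excluded,
  // so the real computation runs on the unwrapped inputs.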
  {
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }
  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

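  // Wrap tensor (and tensor-list) returns back up into FunctionalTensorWrappers where needed.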
  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor() && should_wrap_outputs) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList() && should_wrap_outputs) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList() && should_wrap_outputs) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
} // namespace

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more details
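//
// e.g. (illustrative only, not part of this kernel):
//   auto x = at::ones({2});  x.resize_({4});  // needs more storage   -> handled as a storage-swapping mutation
//   auto y = at::ones({4});  y.resize_({2});  // fits in old storage  -> handled as taking an as_strided view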
const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
  // First unwrap the tensor arguments
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  // Case 1: arguments are not functional tensors, so we no-op and redispatch.
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    self_.resize_(size, memory_format);
    return self;
  }

  // Case 2: actually functionalize resize_()
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::resize(self_, size, memory_format);
  }

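  // Figure out whether the requested size needs more bytes than the existing storage holds.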
  auto itemsize = self.dtype().itemsize();
  auto storage_offset = self.storage_offset();
  auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
  auto needs_resize_storage = new_size_bytes > self.storage().nbytes();

  if (needs_resize_storage) {
    // If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
    // See Note [resize_ in Functionalization]
    auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
    func_impl->maybe_replace_storage(tmp_output);
    // See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views),
    // so we don't need to treat the output of resize_() as a view tensor.
    return self;
  }

  // Otherwise, we know that we're resizing to a smaller size.
  // resize_() is effectively a view operator.
  // The output of resizing is equivalent to taking a slice of a larger tensor.
  // We have to emulate this "slicing" with an as_strided call.
  auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
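  // The ViewMeta below records how to recompute this "view" from its base (forward lambda)
  // and how to scatter mutations made through the view back into the base (reverse lambda).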
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      if (reapply_views) {
        return base.as_strided(size, c10::contiguous_strides(size));
      } else {
        return at::as_strided_copy(base, size, c10::contiguous_strides(size));
      }
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
    }
  );
  at::functionalization::impl::mutate_view_meta(self, std::move(view_meta));
  return self;
}


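// lift()/lift_fresh()/lift_fresh_copy() take a plain (non-functional) tensor and hand back
// a FunctionalTensorWrapper, so that freshly created tensors can participate in the pass.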
at::Tensor lift_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::lift_fresh_copy(self);
  return at::functionalization::impl::to_functional_tensor(out);
}

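// Only XLA and Lazy device types opt into the functionalization pass; _to_copy below uses
// this to decide whether the pass should "end" when copying across a device boundary.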
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
  // If no target device was specified, the output tensor ends up on the same device as the input.
  auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
  return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy;
}

// Note: we only need this because the to.dtype/to.dtype_layout overloads call this, so we skip the op above.
// We should probably get rid of this eventually, though.
at::Tensor _to_copy_functionalize(
    const at::Tensor & self,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    // sync any pending updates
    at::functionalization::impl::sync(self);
    // pass the unwrapped tensor to the backend
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }

  at::AutoDispatchSkipFunctionalize guard;
  auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format);

  // Special case: if the Functionalize key is not in TLS, we assume that we're running
  // on a lazy backend (LTC).
  // In that case, if we're copying to a non-functionalize-enabled device,
  // then the functionalization pass should "end". We need to sync any updates on the input
  // tensor, but we shouldn't wrap the output.
  if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {
    if (!device_opted_into_functionalization(self.device(), device)) {
      return out;
    }
  }
  return at::functionalization::impl::to_functional_tensor(out);
}


// Why is _unsafe_view special-cased here?
// Basically just to satisfy autograd's debug asserts.
// The situation:
// - _unsafe_view's autograd kernel has debug asserts to confirm
//   that the input and output alias storage.
// - _unsafe_view's schema in native_functions.yaml
//   does not contain alias annotations, so it advertises as non-aliasing.
// - functionalization will then treat _unsafe_view like a non-aliasing op.
//   Specifically, autograd will redispatch to functionalization's
//   boxed fallback kernel, which creates a new FunctionalTensorWrapper output
//   that does **not** alias storage with the input, tripping the assert.
// The kernel written here just manually reifies the aliasing relationship.
//
// Another way to handle this would be to fix _unsafe_view's alias annotations
// in native_functions.yaml, but I think this would be a pessimization.
// The idea with _unsafe_view is that you're guaranteed that the input
// is a temporary, and don't actually have to worry about propagating
// mutations between the input and output.
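// (A typical caller first materializes a fresh temporary, e.g. a contiguous clone, and then
// immediately views it with _unsafe_view, so nothing else can alias the input.)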
at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
  if (!at::functionalization::impl::isFunctionalTensor(self)) {
    at::AutoDispatchSkipFunctionalize guard;
    return at::_unsafe_view_symint(self, size);
  }

  auto self_ = at::functionalization::impl::from_functional_tensor(self);
  at::Tensor tmp_output;
  {
    at::AutoDispatchSkipFunctionalize guard;
    tmp_output = at::_unsafe_view_symint(self_, size);
  }

  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
    [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(base, size);
    },
    [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
    }
  );

  auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
  // See Note [Propagating strides in the functionalization pass]
  // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}

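// The boxed fallback handles any op that doesn't have a dedicated Functionalize kernel registered.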
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}

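// These ops need the hand-written kernels defined above instead of the generic fallback.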
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  m.impl("resize_", TORCH_FN(resize__functionalization));
  m.impl("lift", TORCH_FN(lift_functionalize));
  m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
  m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
  m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
  m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
}