#include <ATen/FunctionalTensorWrapper.h>

#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>

#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_to_copy.h>
#endif

namespace at {

void FunctionalTensorWrapper::set_constructor_metadata() {
  TORCH_INTERNAL_ASSERT(value_.defined());
  // Note: "level" is a concept that we don't know how to compute in core.
  // For now I'm retroactively setting this in functorch,
  // but once Open Multiple Dispatch lands we should be able to calculate this in core.
  level_ = -1;
  // mirror all of the generic tensor metadata onto the wrapper
  copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
  refresh_numel();
  refresh_contiguous();
  storage_access_should_throw_ = false;
  // In general, the sizes/stride metadata on a tensor can change as it is mutated,
  // and these changes need to be reflected in the metadata of the wrapper.
  set_allow_tensor_metadata_change(true);
  key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
  // All of the keys corresponding to functorch transforms should not be copied over.
  // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
  // to participate in the functorch transforms.
  key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
  // We override a bunch of _custom(), so make sure they get called
  // TODO: metadata copying may not actually be necessary then
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
  set_custom_device(true);
}

FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
  : c10::TensorImpl(
      c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)),
      c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(),
      value.dtype()
    ),
    value_(value)
{
  set_constructor_metadata();
}

void FunctionalTensorWrapper::freeze_storage() const {
  functional_storage_impl()->freeze();
}

// Note [Functionalization: Alias Removal]
// When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)',
// we link `b` and `a` to a shared Alias object to preserve the aliasing relationship.
//
// How do we do that?
//
// Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl.
// It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor.
// When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl.
//
// As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them.
// When the user requests a tensor that's had a view taken, we check if it's up to date.
// If it's not up to date, we first replay all of the queued up mutations onto the alias, and then re-apply the current view
// on top of the newly updated alias.
//
// Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly?
// This behavior was taken from pytorch/xla, which the alias-removal logic was inspired by.
// One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them,
// but never uses the other views later in the program (in which case we'll never update the alias).
// It also has downsides though: repeatedly applying mutations to the same view without syncing
// will silently use up more and more memory as more mutations are queued up.
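//
// As a rough sketch of the user-visible behavior described above (the names here are purely
// illustrative):
//
//   a = torch.ones(4)
//   b = a.view(2, 2)
//   b.add_(1)   # the mutation is queued on the shared Alias; `a` is not updated yet
//   a.sum()     # using `a` syncs it: queued mutations are replayed onto the alias,
//               # and `a` is then regenerated from the updated base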
//
// Corresponding diagram:
//
// b = a.view(...)
//
//        a                                                      b
//        |                                                      |     If the user asks for b and it's out of date,
//       \/                                                     \/     we regenerate b by replaying its views from the alias.
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - .
// |  FunctionalTensorWrapper  |                     |  FunctionalTensorWrapper  |
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - .
// |    value    |   storage   |                     |   storage   |    value    |
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - .
//          |          \                                   /               |
//          |            \                               /                 |
//          |              . - - - - - - - - - - - - .                     |
//          |              |  FunctionalStorageImpl  |                     |
//          |              . - - - - - - - - - - - - .                     |
//          |              |          Alias          |                     |
//          |              . - - - - - - - - - - - - .                     |
//          |             /    mutations to a or b                         |
//          |           /      are queued onto Alias                       |
//          |         /                                                    |
//         \/       /                                                     \/
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - - - .
// |         TensorImpl        |                     |           TensorImpl          |
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - - - .
// |    value    |   storage   |                     |   storage   |      value      |
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - - - .
//          |                                                              |
//          |                                                              |
//          |                                                              |
//          |    In this picture the two tensor views have their own       |
//          |    storages, but backends like functorch are allowed to      |
//         \/    re-alias underneath the pass                             \/
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - - - .
// |    underlying_storage     |                     |      underlying_storage       |
// . - - - - - - - - - - - - - .                     . - - - - - - - - - - - - - - - .
//
// This constructor is only used by view ops.
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, functionalization::ViewMeta meta)
  : c10::TensorImpl(
      c10::DispatchKeySet(DispatchKey::Functionalize),
      view_value.dtype(),
      view_value.device()
    ),
    value_(view_value)
{
  set_constructor_metadata();
  // Copy the original tensor's ViewMeta vector and push the current one.
  if (!base->view_metas_.empty()) {
    view_metas_ = base->view_metas_;  // copy
  }
  view_metas_.push_back(meta);
  storage_ = base->storage_;  // alias this tensor's storage with the base tensor's
}

functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
  return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}

void FunctionalTensorWrapper::commit_update() {
  auto storage_impl = functional_storage_impl();
  storage_impl->add_update(value_, view_metas_);
  // As an optimization, we used to mark the tensor here as "up-to-date".
  // That way, code like:
  //   x = torch.ones(1'000'000)
  //   x[0].add_(1)
  // doesn't result in an unnecessary materialization of the base.
  // This optimization results in the slice temporarily having incorrect
  // stride/storage_offset though, and DCE should handle that optimization anyway.
  // generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::is_up_to_date() const {
  auto alias_generation = functional_storage_impl()->generation();
  return generation_ == alias_generation;
}

// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta meta) {
  view_metas_.push_back(meta);
  // Note [Functionalization Pass - Inplace View Ops]
  // So, these ops are special - they're mutation AND view ops. They get special codegen.
  // An example is transpose_, e.g. `a.transpose_()`
  // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
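  // As a rough, illustrative sketch of the intended semantics: after functionalization,
  //   a.transpose_(0, 1)
  // behaves like
  //   a = a.transpose(0, 1)
  // except that `a` keeps aliasing (and staying in sync with) its original storage.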
  value_ = meta.forward_fn(value_, meta.out_index);
}

// Note [Functionalization: Mutation Removal]
// Mutation removal is used to take a program like this:
//
//   a.add_(b)
//
// and replace it with a slightly different program that has the same semantics:
//
//   tmp = a.add(b)
//   a.replace_(tmp)
//
// where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend.
// This is useful for backends that aren't able to handle certain types of mutations, like functorch.
//
// Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program:
//
// Before:
//   tensor.add_(batched_tensor)
//
// After:
//   tmp = tensor.add(batched_tensor)
//   tensor.replace_(tmp)
//
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorWrapper, which wraps the underlying tensor.
void FunctionalTensorWrapper::replace_(const Tensor& other) {
  // TODO: going to need to change this if we want nested functionalize() transforms.
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
  value_ = other;
  // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
  // We need to propagate that metadata mutation to the wrapper (new size).
  set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
  if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
    // .to() should not re-entrantly go through functionalization.
    at::AutoDispatchSkipFunctionalize guard;
    // and we want _to_copy() to show up in the graph, not the composite .to() operator
    // (this can happen if autograd has already run by the time we enter this code)
    value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
  }
}

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
  // Note [resize_() in functionalization pass]
  // resize_() is a special operator in functionalization because it can reallocate its underlying storage.
  // This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
  //
  // However, functionalization currently bans the following code:
  //   a = torch.ones(2)
  //   b = a.view(2)
  //   b.resize_(4)  # b is a view tensor, that we are trying to increase the storage size of
  //
  // Why is this code difficult to handle?
  // The functionalization pass currently keeps aliases in sync by making the following assumptions:
  // - The “base” tensor always refers to “all of the data”
  // - Whenever you have b = view_op(a), “b” should always refer to a subset of “a”'s memory.
  //
  // The code above breaks that assumption: b.resize_(4) actually needs to update "a"
  // to tell it that it is now actually some slice of a pre-existing larger storage.
  // We also can no longer re-generate "b" fully from "a", since "a" now refers to a slice of "b"'s data.
  //
  // This is probably fixable in theory, but:
  // - the fix would likely complicate the functionalization logic quite a bit.
  // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
  // - resize_() can also give you weird results today if you try to resize_() a weirdly strided tensor.
  //
  // Given all of the above, for now we're just banning the above usage.
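  //
  // For reference, the supported pattern looks roughly like this (illustrative sketch):
  //   out = torch.empty(0)
  //   torch.add(a, b, out=out)  # `out` is resized from 0 elements up to the result size
  // Here `out` is not a view and has no outstanding views, so the checks below are satisfied.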
  TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  TORCH_CHECK(view_metas_.empty(), "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  // If this tensor is not a view (and has no outstanding views taken out on it),
  // then it's safe to throw out the old storage and replace it with the new, larger one.
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
  value_ = other;
  generation_ = 0;
  // And update the metadata on the wrapper to reflect the new sizes and strides
  set_sizes_and_strides(value_.sizes(), value_.strides());
  refresh_numel();
  // (Technically we should be guaranteed that the tensor was already contiguous,
  // since it's guaranteed not to have been a view. Doesn't hurt to run it though.)
  refresh_contiguous();
}

void FunctionalTensorWrapper::sync_() {
  if (is_up_to_date()) {
    return;
  }
  apply_updates();
  regenerate_from_base();
}

void FunctionalTensorWrapper::regenerate_from_base() {
  at::AutoDispatchSkipFunctionalize guard;
  auto storage_impl = functional_storage_impl();
  auto t = storage_impl->base();
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  // Reapply views to get the viewed tensor from the base in alias_
  for (auto& view_meta : view_metas_) {
    t = view_meta.forward_fn(t, view_meta.out_index);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  replace_(t);
  generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::apply_updates() {
  // Apply all updates on alias_
  auto storage_impl = functional_storage_impl();
  return storage_impl->apply_updates();
}

const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
  return "FunctionalTensorWrapper";
}


template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
  }

  auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  impl->level_ = level_;
  impl->generation_ = generation_;
  impl->view_metas_ = view_metas_;
  impl->refresh_numel();
  impl->refresh_contiguous();
  return impl;
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

c10::Device FunctionalTensorWrapper::device_custom() const {
  return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sizes();
}
at::IntArrayRef FunctionalTensorWrapper::strides_custom() const {
  return value_.unsafeGetTensorImpl()->strides();
}
int64_t FunctionalTensorWrapper::dim_custom() const {
  return value_.unsafeGetTensorImpl()->dim();
}
int64_t FunctionalTensorWrapper::numel_custom() const {
  return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_strides_custom() const {
  return value_.unsafeGetTensorImpl()->sym_strides();
}
c10::SymInt FunctionalTensorWrapper::sym_size_custom(int64_t d) const {
  return value_.unsafeGetTensorImpl()->sym_size(d);
}
c10::SymInt FunctionalTensorWrapper::sym_storage_offset_custom() const {
  return value_.unsafeGetTensorImpl()->sym_storage_offset();
}

namespace functionalization {
namespace impl {

Tensor to_functional_tensor(const Tensor& tensor) {
  // Note [Wrapped Numbers <> Functionalization]
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor));
  return at::detail::make_tensor<FunctionalTensorWrapper>(tensor);
}
c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor) {
  if (tensor.has_value()) {
    return c10::make_optional<Tensor>(to_functional_tensor(*tensor));
  }
  return c10::nullopt;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
  c10::List<c10::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_functional_tensor(t_list[i]));
  }
  return outputs;
}
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outputs.push_back(to_functional_tensor(tensor));
  }
  return outputs;
}

Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
  // Note [Wrapped Numbers <> Functionalization]
  if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  if (isFunctionalTensor(tensor)) {
    auto impl = unsafeGetFunctionalWrapper(tensor);
    return impl->value();
  } else {
    // If the current tensor is not functional, then raise an error
    // if assert_functional is true. Otherwise, return the input.
    TORCH_INTERNAL_ASSERT(!assert_functional);
    return tensor;
  }
}
c10::optional<Tensor> from_functional_tensor(const c10::optional<Tensor>& t, bool assert_functional) {
  if (t.has_value()) {
    return c10::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
  }
  return c10::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
    // it on a non-functional input,
    // but from_functional_tensor(TensorList) can receive a list containing both
    // functional and non-functional tensors.
    // Example of when that can happen: torch.cat(functional_input_tensor, global_state_tensor).
    // When that happens, we're okay with only unwrapping the functional tensors.
    outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
  }
  return outputs;
}
c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
  c10::List<c10::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
  }
  return outputs;
}

void sync(const Tensor& t) {
  if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
    // Note [Wrapped Numbers <> Functionalization]
    // Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors)
    // get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher.
    // That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway.
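    // For example (roughly): in `t + 1.5`, the scalar 1.5 becomes a 0-dim wrapped-number
    // tensor that is created without going through the dispatcher, so it never gets a
    // FunctionalTensorWrapper and there is nothing for us to sync here.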
    return;
  }
  // Not every tensor that hits a functionalization kernel is necessarily a functional tensor.
  // For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel
  // to sync xla_tensor, but not cpu_tensor.
  if (!at::functionalization::impl::isFunctionalTensor(t)) {
    return;
  }
  auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
  functional_impl->sync_();
}
void sync(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    sync(*t);
  }
}
void sync(ITensorListRef t_list) {
  for (const auto& t : t_list) {
    sync(t);
  }
}
void sync(const c10::List<c10::optional<Tensor>> t_list) {
  for (const auto i : c10::irange(t_list.size())) {
    sync(t_list[i]);
  }
}


void replace_(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}

void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (const auto i : c10::irange(functional_tensor.size())) {
    (void)i;  // Suppress unused variable warning
    replace_(*functional_tensor_it++, *other_it++);
  }
}

void commit_update(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}

void commit_update(ITensorListRef functional_tensor) {
  for (const auto& t : functional_tensor) {
    commit_update(t);
  }
}

bool isFunctionalTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}


bool isFunctionalTensor(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    return isFunctionalTensor(*t);
  } else {
    return false;
  }
}

bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
  if (t_list.empty()) return false;
  auto functional_count = 0;
  for (const auto i : c10::irange(t_list.size())) {
    if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
    if (isFunctionalTensor(t_list[i])) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
  if (list.size() == 0) return false;
  auto functional_count = 0;
  for (const auto& tensor : list) {
    if (!tensor.defined()) continue;
    if (isFunctionalTensor(tensor)) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

bool isFunctionalTensor(ITensorListRef list) {
  return isFunctionalTensorIListRef(list);
}

void freeze_functional_tensor(const Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
  functional_base_impl->freeze_storage();
}

Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
  if (out_idx != 0) {
    // Note [out_idx in ViewMeta]
    // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
    // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
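    // As a rough example: for `b0, b1 = a.split(2)`, both outputs carry the same split
    // ViewMeta, but with out_index 0 and 1 respectively, so that replaying (or reversing)
    // the view knows which of the op's outputs it corresponds to.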
    meta = meta.to_out_idx(out_idx);
  }
  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}

std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
  std::vector<Tensor> outputs(view_to_wrap.size());
  int64_t i = 0;
  for (const auto& tensor : view_to_wrap) {
    outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
    i++;
  }
  return outputs;
}

void mutate_view_meta(const at::Tensor& self, functionalization::ViewMeta meta) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  self_impl->mutate_view_meta(std::move(meta));
}

// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementation with meta tensors.
// The output meta tensor's stride info serves as a reference for what the correct strides should be.
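// As a rough, illustrative sketch: for something like `b = a.expand(sizes)`, the generated
// functionalization kernel also runs expand() on a meta tensor carrying `a`'s metadata, and
// then copies the resulting sizes/strides/storage_offset onto `b` via the helper below.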
void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) {
  out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sym_sizes(), reference_out.sym_strides(), reference_out.sym_storage_offset());
}

void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) {
  TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size());
  for (const auto i : c10::irange(reference_outs.size())) {
    set_sizes_strides_offset(outs[i], reference_outs[i]);
  }
}

thread_local bool _functionalizationReapplyViews;

bool getFunctionalizationReapplyViewsTLS() {
  return _functionalizationReapplyViews;
}
void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
  _functionalizationReapplyViews = reapply_views;
}

} // namespace impl


// Given an **out-of-place** op that might internally call view/inplace ops,
// this function will "functionalize" it.
// That is, it will call the operator, but remove any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
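// As a rough sketch of the kind of composite kernel this helps with (names here are illustrative):
//   Tensor out = at::zeros(result_sizes, options);
//   out.narrow(0, start, length).copy_(input);  // mutates a view of `out`
// After running such a kernel through this helper, a backend tracing the program only sees the
// functional (out-of-place) equivalents of those view/mutation ops.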
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  // Wrap all tensor-like inputs into FunctionalTensorWrappers.
  // When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (t.defined()) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
          "The composite op functionalization fallback expects its inputs all not to be functional tensors");
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[arguments_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[arguments_begin + idx] = t_new;
    }
  }

  {
    // Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
    // the output in a functional tensor based on TLS.
    // In this code, we're re-entrantly entering functionalization in the same call-stack,
    // so we need to manually fix up TLS as if it hadn't already been called.
    auto curr_tls = c10::impl::tls_local_dispatch_key_set();
    auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
    tls_reenable_functionalize.set_included(curr_tls.included_);
    tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
    c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
    // Ideally we would provide a way to directly call a kernel registered to
    // the `CompositeExplicitAutograd` key.
    // We can't do that today, so redispatching to the Meta key below should be a reasonably good proxy
    // (it won't work in cases where an op has both a CompositeExplicitAutograd kernel
    // AND a dedicated meta kernel, but that probably shouldn't ever happen).
    op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
  }

  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      at::functionalization::impl::sync(t);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      at::functionalization::impl::sync(tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      at::functionalization::impl::sync(opt_tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}



} // namespace functionalization
} // namespace at