#include <ATen/FunctionalTensorWrapper.h>

#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>

#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_to_copy.h>
#endif

namespace at {

void FunctionalTensorWrapper::set_constructor_metadata() {
  TORCH_INTERNAL_ASSERT(value_.defined());
  // Note: "level" is a concept that we don't know how to compute in core.
  // For now I'm retroactively setting this in functorch,
  // but once Open Multiple Dispatch lands we should be able to calculate this in core.
  level_ = -1;
  // mirror all of the generic tensor metadata onto the wrapper
  copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
  refresh_numel();
  refresh_contiguous();
  storage_access_should_throw_ = false;
  // In general, the sizes/stride metadata on a tensor can change as it is mutated,
  // and these changes need to be reflected in the metadata of the wrapper.
  set_allow_tensor_metadata_change(true);
  key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
  // All of the keys corresponding to functorch transforms should not be copied over.
  // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
  // to participate in the functorch transforms.
  key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
  // We override a bunch of _custom(), so make sure they get called
  // TODO: metadata copying may not actually be necessary then
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
  set_custom_device(true);
}

FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
  : c10::TensorImpl(
      c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)),
      c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(),
      value.dtype()
    ),
    value_(value)
{
  set_constructor_metadata();
}

void FunctionalTensorWrapper::freeze_storage() const {
  functional_storage_impl()->freeze();
}

// Note [Functionalization: Alias Removal]
// When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)',
// we link `b` and `a` to a shared Alias object to preserve the aliasing relationship.
//
// How do we do that?
//
// Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl.
// It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor.
// When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl.
//
// As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them.
// When the user requests a tensor that's had a view taken, we check if it's up to date.
// If it's not up to date, we first replay all of the queued-up mutations onto the alias, and then re-apply the current view
// on top of the newly updated alias.
//
// Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly?
// This behavior was taken from pytorch/xla, which inspired the alias-removal logic.
// One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them,
// but never uses the other views later in the program (in which case we'll never update the alias).
// It also has downsides though: repeatedly applying mutations to the same view without syncing
// will silently use up more and more memory as more mutations are queued up.
//
// Corresponding diagram:
//
// b = a.view(...)
//
//        a                                                            b
//        |                                                            |     If the user asks for b and it's out of date,
//       \/                                                           \/     we regenerate b by replaying its views from the alias.
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |  FunctionalTensorWrapper  |                       |  FunctionalTensorWrapper  |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
// |    value    |   storage   |                       |   storage   |    value    |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - .
//        |               \                                 /                |
//        |                 \                             /                  |
//        |                   . - - - - - - - - - - - - .                    |
//        |                   |  FunctionalStorageImpl  |                    |
//        |                   . - - - - - - - - - - - - .                    |
//        |                   |          Alias          |                    |
//        |                   . - - - - - - - - - - - - .                    |
//        |                  /    mutations to a or b                        |
//        |                 /     are queued onto Alias                      |
//        |                /                                                 |
//       \/               /                                                 \/
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - - - .
// |         TensorImpl        |                       |           TensorImpl          |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - - - .
// |    value    |   storage   |                       |   storage   |      value      |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - - - .
//        |                                                                   |
//        |                                                                   |
//        |                                                                   |
//        |     In this picture the two tensor views have their own           |
//        |     storages, but backends like functorch are allowed             |
//       \/     to re-alias underneath the pass                              \/
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - - - .
// |    underlying_storage     |                       |      underlying_storage       |
// . - - - - - - - - - - - - - .                       . - - - - - - - - - - - - - - - .
//
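// A rough user-level sketch of the behavior described above
// (python-style pseudocode; the exact sync points depend on how the pass is invoked):
//
//   a = torch.ones(4)     <- a wraps its own FunctionalStorageImpl
//   b = a.view(2, 2)      <- b shares a's FunctionalStorageImpl and records a ViewMeta
//   b.add_(1)             <- the mutation is queued on the shared Alias; a is now stale
//   use(a)                <- the next time a is used/synced, the queued mutation is replayed
//                            onto the alias and a is regenerated from it
//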
// This constructor is only used by view ops.
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, functionalization::ViewMeta meta)
  : c10::TensorImpl(
      c10::DispatchKeySet(DispatchKey::Functionalize),
      view_value.dtype(),
      view_value.device()
    ),
    value_(view_value)
{
  set_constructor_metadata();
  // Copy the original tensor's ViewMeta vector and push the current one.
  if (!base->view_metas_.empty()) {
    view_metas_ = base->view_metas_; // copy
  }
  view_metas_.push_back(meta);
  storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}

functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
  return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}

void FunctionalTensorWrapper::commit_update() {
  auto storage_impl = functional_storage_impl();
  storage_impl->add_update(value_, view_metas_);
  // As an optimization, we used to mark the tensor here as "up-to-date".
  // That way, code like:
  //   x = torch.ones(1'000'000)
  //   x[0].add_(1)
  // doesn't result in an unnecessary materialization of the base.
  // However, that optimization temporarily leaves the slice with incorrect
  // strides/storage_offset, and DCE should handle that optimization anyway.
  // generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::is_up_to_date() const {
  auto alias_generation = functional_storage_impl()->generation();
  return generation_ == alias_generation;
}

// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta meta) {
  view_metas_.push_back(meta);
  // Note [Functionalization Pass - Inplace View Ops]
  // So, these ops are special - they're mutation AND view ops. They get special codegen.
  // An example is transpose_, e.g. `a.transpose_()`.
  // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
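  //
  // A rough sketch of the user-level effect (python-style pseudocode, illustrative only):
  //
  //   a = torch.ones(2, 3)
  //   a.transpose_(0, 1)   <- no replace_() happens here; instead the transpose ViewMeta is
  //                           appended to a's view_metas_, and a's value_ is recomputed via the
  //                           view's forward_fn below, so future syncs replay the transpose too.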
  value_ = meta.forward_fn(value_, meta.out_index);
}

// Note [Functionalization: Mutation Removal]
// Mutation removal is used to take a program like this:
//
//   a.add_(b)
//
// and replace it with a slightly different program that has the same semantics:
//
//   tmp = a.add(b)
//   a.replace_(tmp)
//
// where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend.
// This is useful for backends that aren't able to handle certain types of mutations, like functorch.
//
// Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program:
//
// Before:
//   tensor.add_(batched_tensor)
//
// After:
//   tmp = tensor.add(batched_tensor)
//   tensor.replace_(tmp)
//
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorWrapper, which wraps the underlying tensor.
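//
// Concretely (a sketch of the invariant, not new behavior): after
//
//   tmp = tensor.add(batched_tensor)
//   tensor.replace_(tmp)
//
// the FunctionalTensorWrapper that the caller holds is the same object as before; only its
// inner value_ now points at tmp. Outer code (and outer wrappers like functorch's
// BatchedTensorImpl) therefore never observe the mutation directly.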
void FunctionalTensorWrapper::replace_(const Tensor& other) {
  // TODO: going to need to change this if we want nested functionalize() transforms.
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
  value_ = other;
  // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
  // We need to propagate that metadata mutation to the wrapper (new size).
  set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
  if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
    // .to() should not re-entrantly go through functionalization.
    at::AutoDispatchSkipFunctionalize guard;
    // and we want _to_copy() to show up in the graph, not the composite .to() operator
    // (this can happen if autograd has already run by the time we enter this code)
    value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
  }
}

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
  // Note [resize_() in functionalization pass]
  // resize_() is a special operator in functionalization because it can reallocate its underlying storage.
  // This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
  //
  // However, functionalization currently bans the following code:
  //   a = torch.ones(2)
  //   b = a.view(2)
  //   b.resize_(4) # b is a view tensor, that we are trying to increase the storage size of
  //
  // Why is this code difficult to handle?
  // The functionalization pass currently keeps aliases in sync by making the following assumptions:
  // - The "base" tensor always refers to "all of the data"
  // - Whenever you have b = view_op(a), "b" should always refer to a subset of "a"'s memory.
  //
  // The code above breaks that assumption: b.resize_(4) actually needs to update "a"
  // to tell it that it is now actually some slice of a pre-existing larger storage.
  // We also can no longer re-generate "b" fully from "a", since "a" now refers to a slice of "b"'s data.
  //
  // This is probably fixable in theory, but:
  // - the fix would likely complicate the functionalization logic quite a bit.
  // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
  // - resize_() can also give you weird results today if you try to resize_() a weirdly strided tensor.
  //
  // Given all of the above, for now we're just banning the above usage.
  TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  TORCH_CHECK(view_metas_.empty(), "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  // If this tensor is not a view (and has no outstanding views taken out on it),
  // then it's safe to throw out the old storage and replace it with the new, larger one.
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
  value_ = other;
  generation_ = 0;
  // And update the metadata on the wrapper to reflect the new sizes and strides
  set_sizes_and_strides(value_.sizes(), value_.strides());
  refresh_numel();
  // (Technically we should be guaranteed that the tensor was already contiguous,
  // since it's guaranteed not to have been a view. It doesn't hurt to run this though.)
  refresh_contiguous();
}

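// A rough sketch of what a single sync does, in terms of the member functions below
// (these are the names used in this file):
//
//   b.sync_()
//     -> b.apply_updates()          // replay all queued mutations onto the alias's base
//     -> b.regenerate_from_base()   // re-run b's ViewMetas on the updated base, then
//                                   // b.replace_(result) and refresh b's generation_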
void FunctionalTensorWrapper::sync_() {
  if (is_up_to_date()) {
    return;
  }
  apply_updates();
  regenerate_from_base();
}

void FunctionalTensorWrapper::regenerate_from_base() {
  at::AutoDispatchSkipFunctionalize guard;
  auto storage_impl = functional_storage_impl();
  auto t = storage_impl->base();
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  // Reapply views to get the viewed tensor from the base in alias_
  for (auto& view_meta: view_metas_) {
    t = view_meta.forward_fn(t, view_meta.out_index);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  replace_(t);
  generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::apply_updates() {
  // Apply all updates on alias_
  auto storage_impl = functional_storage_impl();
  return storage_impl->apply_updates();
}

const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
  return "FunctionalTensorWrapper";
}

template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
  }

  auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  impl->level_ = level_;
  impl->generation_ = generation_;
  impl->view_metas_ = view_metas_;
  impl->refresh_numel();
  impl->refresh_contiguous();
  return impl;
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

c10::Device FunctionalTensorWrapper::device_custom() const {
  return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sizes();
}
at::IntArrayRef FunctionalTensorWrapper::strides_custom() const {
  return value_.unsafeGetTensorImpl()->strides();
}
int64_t FunctionalTensorWrapper::dim_custom() const {
  return value_.unsafeGetTensorImpl()->dim();
}
int64_t FunctionalTensorWrapper::numel_custom() const {
  return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_strides_custom() const {
  return value_.unsafeGetTensorImpl()->sym_strides();
}
c10::SymInt FunctionalTensorWrapper::sym_size_custom(int64_t d) const {
  return value_.unsafeGetTensorImpl()->sym_size(d);
}
c10::SymInt FunctionalTensorWrapper::sym_storage_offset_custom() const {
  return value_.unsafeGetTensorImpl()->sym_storage_offset();
}

namespace functionalization {
namespace impl {

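// A minimal usage sketch of the wrapping/unwrapping helpers below (illustrative C++ only,
// not a registered kernel; it assumes the caller manages the Functionalize dispatch key/TLS):
//
//   at::Tensor t = at::ones({2, 2});
//   at::Tensor wrapped = at::functionalization::impl::to_functional_tensor(t);
//   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(wrapped));
//   // ... run ops on `wrapped` with functionalization enabled ...
//   at::functionalization::impl::sync(wrapped);
//   at::Tensor unwrapped = at::functionalization::impl::from_functional_tensor(wrapped);
//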
Tensor to_functional_tensor(const Tensor& tensor) {
  // Note [Wrapped Numbers <> Functionalization]
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor));
  return at::detail::make_tensor<FunctionalTensorWrapper>(tensor);
}
c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor) {
  if (tensor.has_value()) {
    return c10::make_optional<Tensor>(to_functional_tensor(*tensor));
  }
  return c10::nullopt;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
  c10::List<c10::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_functional_tensor(t_list[i]));
  }
  return outputs;
}
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outputs.push_back(to_functional_tensor(tensor));
  }
  return outputs;
}

Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
  // Note [Wrapped Numbers <> Functionalization]
  if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  if (isFunctionalTensor(tensor)) {
    auto impl = unsafeGetFunctionalWrapper(tensor);
    return impl->value();
  } else {
    // If the current tensor is not functional, then raise an error
    // if assert_functional is true. Otherwise, return the input.
    TORCH_INTERNAL_ASSERT(!assert_functional);
    return tensor;
  }
}
c10::optional<Tensor> from_functional_tensor(const c10::optional<Tensor>& t, bool assert_functional) {
  if (t.has_value()) {
    return c10::make_optional<Tensor>(from_functional_tensor(*t, assert_functional));
  }
  return c10::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
    // it on a non-functional input,
    // but from_functional_tensor(TensorList) can receive a list containing both
    // functional and non-functional tensors.
    // Example of when that can happen: torch.cat([functional_input_tensor, global_state_tensor]).
    // When that happens, we're okay with only unwrapping the functional tensors.
    outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
  }
  return outputs;
}
c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
  c10::List<c10::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
  }
  return outputs;
}

void sync(const Tensor& t) {
  if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
    // Note [Wrapped Numbers <> Functionalization]
    // Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors)
    // get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher.
    // That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway.
    return;
  }
  // Not every tensor that hits a functionalization kernel is necessarily a functional tensor.
  // For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel
  // to sync xla_tensor, but not cpu_tensor.
  if (!at::functionalization::impl::isFunctionalTensor(t)) {
    return;
  }
  auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
  functional_impl->sync_();
}
void sync(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    sync(*t);
  }
}
void sync(ITensorListRef t_list) {
  for (const auto& t : t_list) {
    sync(t);
  }
}
void sync(const c10::List<c10::optional<Tensor>> t_list) {
  for (const auto i : c10::irange(t_list.size())) {
    sync(t_list[i]);
  }
}

void replace_(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}

void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (const auto i : c10::irange(functional_tensor.size())) {
    (void)i; // Suppress unused variable warning
    replace_(*functional_tensor_it++, *other_it++);
  }
}

void commit_update(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}

void commit_update(ITensorListRef functional_tensor) {
  for (const auto& t : functional_tensor) {
    commit_update(t);
  }
}

bool isFunctionalTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}

bool isFunctionalTensor(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    return isFunctionalTensor(*t);
  } else {
    return false;
  }
}

bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
  if (t_list.empty()) return false;
  auto functional_count = 0;
  for (const auto i : c10::irange(t_list.size())) {
    if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
    if (isFunctionalTensor(t_list[i])) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
  if (list.size() == 0) return false;
  auto functional_count = 0;
  for (const auto& tensor : list) {
    if (!tensor.defined()) continue;
    if (isFunctionalTensor(tensor)) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

bool isFunctionalTensor(ITensorListRef list) {
  return isFunctionalTensorIListRef(list);
}

void freeze_functional_tensor(const Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
  functional_base_impl->freeze_storage();
}

Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
  if (out_idx != 0) {
    // Note [out_idx in ViewMeta]
    // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
    // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
    meta = meta.to_out_idx(out_idx);
  }
  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}

std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta) {
  std::vector<Tensor> outputs(view_to_wrap.size());
  int64_t i = 0;
  for (const auto& tensor : view_to_wrap) {
    outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
    i++;
  }
  return outputs;
}

void mutate_view_meta(const at::Tensor& self, functionalization::ViewMeta meta) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  self_impl->mutate_view_meta(std::move(meta));
}

// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} op's reference implementation with meta tensors.
// The output meta tensor's stride info serves as a reference for what the correct strides should be.
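//
// A rough sketch of the intended call pattern (hypothetical names; the real callers are the
// generated functionalization kernels):
//
//   // run the view op's reference implementation on meta tensors to get correct stride info
//   at::Tensor reference_out = some_view_op(self_meta, ...);   // hypothetical call
//   // then copy the resulting sizes/strides/storage_offset onto the actual output
//   set_sizes_strides_offset(out, reference_out);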
void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) {
  out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sym_sizes(), reference_out.sym_strides(), reference_out.sym_storage_offset());
}

void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) {
  TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size());
  for (const auto i : c10::irange(reference_outs.size())) {
    set_sizes_strides_offset(outs[i], reference_outs[i]);
  }
}

thread_local bool _functionalizationReapplyViews;

bool getFunctionalizationReapplyViewsTLS() {
  return _functionalizationReapplyViews;
}
void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
  _functionalizationReapplyViews = reapply_views;
}

} // namespace impl


// Given an **out-of-place** op that might internally call view/inplace ops,
// this function will "functionalize" it.
// That is, it will call the operator, but remove any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  // Wrap all tensor-like inputs into FunctionalTensorWrappers.
  // When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (t.defined()) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
            "The composite op functionalization fallback expects its inputs all not to be functional tensors");
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
          "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[arguments_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
          "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[arguments_begin + idx] = t_new;
    }
  }

  {
    // Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
    // the output in a functional tensor based on TLS.
    // In this code, we're re-entrantly entering functionalization in the same call-stack,
    // so we need to manually fix up TLS as if it hadn't already been called.
    auto curr_tls = c10::impl::tls_local_dispatch_key_set();
    auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
    tls_reenable_functionalize.set_included(curr_tls.included_);
    tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
    c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
    // We should probably provide a way to directly call a kernel registered to
    // the `CompositeExplicitAutograd` key.
    // We can't do that today, so this should be a reasonably good proxy
    // (It won't work in cases where an op has both a CompositeExplicitAutograd kernel
    // AND a dedicated meta kernel, but that probably shouldn't ever happen).
    op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
  }

  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      at::functionalization::impl::sync(t);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      at::functionalization::impl::sync(tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      at::functionalization::impl::sync(opt_tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}

} // namespace functionalization
} // namespace at