#include <ATen/FunctionalInverses.h>

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/WrapDimUtilsMulti.h>

#include <utility>

namespace at {
namespace functionalization {

// This logic is similar to autograd code for view backwards calls.
// We can't easily share it though, because (eventually) these functions
// will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`.

Tensor permute_copy_inverse(const Tensor& self, IntArrayRef dims, bool reapply_views) {
  // invert the permutation
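  // e.g. for dims = {2, 0, 1}, this builds dims_ = {1, 2, 0};
  // applying permute(dims_) to the permuted view restores the original dim order.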
  auto ndims = dims.size();
  std::vector<int64_t> dims_(ndims);
  for(const auto i : c10::irange(ndims)) {
    dims_[at::maybe_wrap_dim(dims[i], ndims)] = i;
  }
  if (reapply_views) {
    return at::permute(self, dims_);
  } else {
    return at::permute_copy(self, dims_);
  }
}

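// Re-inserts every size-1 dimension of `sizes` into `self` (the squeezed view).
// e.g. for sizes = {2, 1, 3, 1} and a view of shape {2, 3}, the result has shape {2, 1, 3, 1}.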
Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, bool reapply_views) {
  auto result = self;

  int64_t nDims = sizes.size();
  for(const auto dim : c10::irange(nDims)) {
    if (sizes[dim] == 1) {
      if (reapply_views) {
        result = at::unsqueeze(result, dim);
      } else {
        result = at::unsqueeze_copy(result, dim);
      }
    }
  }
  return result;
}

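// Overload for squeeze(dim)/squeeze(dims): only re-inserts the size-1 dimensions
// of `sizes` that appear in `dim`, leaving the other dimensions of the view untouched.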
Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) {
  const auto ndim = sizes.size();
  const auto mask = at::dim_list_to_bitset(dim, ndim);
  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
  // unsqueezing in the backward.
  if (ndim == 0) {
    return self;
  }

  Tensor result = self;
  for (const auto d : c10::irange(ndim)) {
    if (mask.test(d) && sizes[d] == 1) {
      if (reapply_views) {
        result = at::unsqueeze(result, d);
      } else {
        result = at::unsqueeze_copy(result, d);
      }
    }
  }
  return result;
}

// Note [Functionalization Pass: View Inverses].
// This file contains the implementation of each "view inverse".
// These aren't really true inverses in the mathematical sense: each view inverse describes how to undo
// the original view (although it takes in different arguments).
//
// E.g. below is an example of a program that has its alias operations removed, and the role that view inverses play:
//
// normal program with views and mutations:
// view1 = input1.view_op(args...)
// view1.add_(1) (perform a mutation on the view, which should also modify input1)

// version of the program with no aliasing, that instead uses view_inverse functions:
// view_copy1 = input1.view_copy_op(args...)
// view_copy1.add_(1) (perform a mutation on view_copy1. At this point, input1 is NOT modified)
// x = view_op_inverse(input1, view_copy1, args...)
//
// at this point, input1 and x should be equal
//
// Note that input1 is also passed as an argument to view_op_inverse in the above example.
// This isn't actually required for most view operators: it's only required for view ops
// where you can't figure out what the size of the base tensor is given just the view tensor and arguments.
// Examples are slice/select/scatter/squeeze/as_strided.
// We happen to be passing in the base tensor in all cases, mostly to make the codegen simpler.
// But you'll see below that the "base" argument is ignored by most view_inverse implementations.
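//
// As a concrete example, for a slice:
//   view_copy1 = at::slice_copy(input1, dim, start, end, step)
//   view_copy1.add_(1)
//   x = input1.slice_scatter(view_copy1, dim, start, end, step)
// slice_scatter writes the mutated slice back into (a copy of) input1,
// which is exactly what slice_copy_Tensor_inverse() below does.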

// ----------------------------------------------------------
// Implementations of each view_inverse() function are below.
// One of these needs to be implemented for every existing non-composite view operator.
// The codegen automatically generates the corresponding function declaration.
// ----------------------------------------------------------

Tensor FunctionalInverses::_fw_primal_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, int64_t level) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported.");
  return Tensor();
}

Tensor FunctionalInverses::_make_dual_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, const at::Tensor& tangent, int64_t level) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported.");
  return Tensor();
}

Tensor FunctionalInverses::view_as_real_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::view_as_complex(mutated_view);
  } else {
    return at::view_as_complex_copy(mutated_view);
  }
}

Tensor FunctionalInverses::view_as_complex_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::view_as_real(mutated_view.resolve_conj());
  } else {
    return at::view_as_real_copy(mutated_view.resolve_conj());
  }
}

Tensor FunctionalInverses::_conj_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::_conj(mutated_view);
  } else {
    return at::_conj_copy(mutated_view);
  }
}

Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::_neg_view(mutated_view);
  } else {
    return at::_neg_view_copy(mutated_view);
  }
}

Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
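  // as_strided is one of the view ops where the base is genuinely required (see the note above):
  // the base's sizes/strides can't be recovered from the strided view alone.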
  // Pessimism: we can't reapply views for as_strided_scatter.
  return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset));
}

Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t offset, int64_t dim1, int64_t dim2) {
  // Pessimism: we can't reapply views for diagonal_scatter.
  return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
}

Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
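  // Undo the expand by reducing the mutated view back down to base's shape:
  // sum_to sums over any dims that were broadcast by the original expand.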
  return at::sum_to(mutated_view, base.sym_sizes(), /*always_return_non_view=*/!reapply_views);
}

Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
  return at::functionalization::permute_copy_inverse(mutated_view, dims, reapply_views);
}

Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride) {
  // Note that we ignore the size/stride arguments here and instead reconstruct the base from base's own sizes/strides.
  // _reshape_alias() isn't available from user code; it's an implementation detail of reshape().
  // Specifically, passing in the original view's strides directly can get us into trouble in cases like:
  //   b = a[0]; c = b.reshape(...); c.add_(1); print(a)
  // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides,
  // the call would fail because `mutated_view` doesn't have enough bytes of storage.
  if (reapply_views) {
    return at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides());
  } else {
    return at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides());
  }
}

Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) {
  // Pessimism: we can't reapply views for select_scatter.
  return base.select_scatter_symint(mutated_view, dim, std::move(index));
}

Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  // The functionalization pass doesn't care about autograd metadata; as a view, detach() is effectively just an identity function here.
  return mutated_view;
}

Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  return mutated_view;
}

Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) {
  // Pessimism: we can't reapply views for slice_scatter.
  return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step));
}

Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
  // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
  // For functionalization, we only have one of the tensors from the TensorList output by split(), and we want to layer it
  // on top of the base tensor.
  // For autograd, we have all of the tensors output by split() and we just want to stack them.
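  // e.g. for split_size = 3 and mutated_view_idx = 2, we scatter into the slice [6, 9) of the base
  // (clamped to the dimension size when the last chunk is smaller than split_size).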
  dim = at::maybe_wrap_dim(dim, base.dim());
  auto dim_size = base.sym_size(dim);
  auto start = split_size * mutated_view_idx;
  auto end = split_size + start;
  if (end > dim_size) end = dim_size;
  // Pessimism: we can't reapply views for slice_scatter.
  return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
}

Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
  dim = at::maybe_wrap_dim(dim, base.dim());
  auto dim_size = base.sym_size(dim);
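  // The start of chunk i is the sum of all preceding split sizes,
  // e.g. split_sizes = {2, 3, 4} and mutated_view_idx = 2 yields the slice [5, 9).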
  c10::SymInt start = 0;
  for (auto i = 0; i < mutated_view_idx; ++i) {
    start += split_sizes[i];
  }
  auto end = start + split_sizes[mutated_view_idx];
  if (end > dim_size) end = dim_size;
  // Pessimism: we can't reapply views for slice_scatter.
  return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
}

Tensor FunctionalInverses::squeeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  return unsqueeze_copy_to(mutated_view, base.sym_sizes(), reapply_views);
}

Tensor FunctionalInverses::squeeze_copy_dim_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) {
  return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
}

Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) {
  return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
}

Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::t(mutated_view);
  } else {
    return at::t_copy(mutated_view);
  }
}

Tensor FunctionalInverses::transpose_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim0, int64_t dim1) {
  if (reapply_views) {
    return transpose(mutated_view, dim0, dim1);
  } else {
    return transpose_copy(mutated_view, dim0, dim1);
  }
}

Tensor FunctionalInverses::_nested_view_from_buffer_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, const Tensor& nested_size_tensor, const Tensor& nested_stride_tensor, IntArrayRef offsets) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::unsqueeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) {
  if (reapply_views) {
    return at::squeeze(mutated_view, dim);
  } else {
    return at::squeeze_copy(mutated_view, dim);
  }
}

Tensor FunctionalInverses::_indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::_values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::_sparse_broadcast_to_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::crow_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::col_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::ccol_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::row_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, int64_t dim) {
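  // unbind() produces base.size(dim) slices along `dim`; select_scatter writes the single
  // mutated slice back at index `mutated_view_idx`.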
  dim = at::maybe_wrap_dim(dim, base.sizes().size());
  // Pessimism: we can't reapply views for select_scatter.
  return base.select_scatter(mutated_view, dim, mutated_view_idx);
}

Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
  if (reapply_views) {
    return mutated_view.view_symint(base.sym_sizes());
  } else {
    return at::view_copy_symint(mutated_view, base.sym_sizes());
  }
}

Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
  if (reapply_views) {
    return mutated_view.view(base.scalar_type());
  } else {
    return at::view_copy(mutated_view, base.scalar_type());
  }
}

Tensor FunctionalInverses::unfold_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dimension, int64_t size, int64_t step) {
  // I think autograd and the functionalization pass want the exact same thing here, but we'd need to test to confirm.
  // unfold_backward() is safe to use here because it is NOT a view op.
  // (Note: technically, "reapply_views" won't do anything here and we'll have an extra memory copy.
  // We'd need to add an aliasing version of unfold_backward() to fix that though.)
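  // e.g. a 1-D base of size 8 unfolded along dim 0 with size=2, step=2 yields a view of shape {4, 2};
  // unfold_backward() maps those (non-overlapping) windows back into a tensor of shape {8}.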
  return unfold_backward(mutated_view, base.sizes(), dimension, size, step);
}

Tensor FunctionalInverses::alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
  if (reapply_views) {
    return at::alias(mutated_view);
  } else {
    return at::alias_copy(mutated_view);
  }
}

} // namespace functionalization
} // namespace at