1 | |
2 | #include <ATen/FunctionalInverses.h> |
3 | |
4 | #include <ATen/ATen.h> |
5 | #include <ATen/ExpandUtils.h> |
6 | #include <ATen/WrapDimUtilsMulti.h> |
7 | |
8 | #include <utility> |

namespace at {
10 | namespace functionalization { |
11 | |
12 | // This logic is similar to autograd code for view backwards calls. |
13 | // We can't easily share it though, because (eventually) these functions |
14 | // will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`. |
15 | |
16 | Tensor permute_copy_inverse(const Tensor& self, IntArrayRef dims, bool reapply_views) { |
17 | // invert the permutation |
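  // E.g. for dims = {2, 0, 1}: output dim i came from input dim dims[i], so the inverse
  // permutation must map dims[i] back to i, giving dims_ = {1, 2, 0}.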
18 | auto ndims = dims.size(); |
19 | std::vector<int64_t> dims_(ndims); |
20 | for(const auto i : c10::irange(ndims)) { |
21 | dims_[at::maybe_wrap_dim(dims[i], ndims)] = i; |
22 | } |
23 | if (reapply_views) { |
24 | return at::permute(self, dims_); |
25 | } else { |
26 | return at::permute_copy(self, dims_); |
27 | } |
28 | } |
29 | |
30 | Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, bool reapply_views) { |
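  // Re-insert (unsqueeze) every size-1 dim of the original shape that squeeze() removed.
  // E.g. for sizes = {2, 1, 3, 1}, a squeezed view of shape {2, 3} is unsqueezed at dims 1
  // and 3, restoring shape {2, 1, 3, 1}.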
31 | auto result = self; |
32 | |
33 | int64_t nDims = sizes.size(); |
34 | for(const auto dim : c10::irange(nDims)) { |
35 | if (sizes[dim] == 1) { |
36 | if (reapply_views) { |
37 | result = at::unsqueeze(result, dim); |
38 | } else { |
39 | result = at::unsqueeze_copy(result, dim); |
40 | } |
41 | } |
42 | } |
43 | return result; |
44 | } |
45 | |
46 | Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) { |
47 | const auto ndim = sizes.size(); |
48 | const auto mask = at::dim_list_to_bitset(dim, ndim); |
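  // dim_list_to_bitset() wraps any negative dims and marks each requested dim in a bitset of
  // length ndim, so the loop below only re-inserts dims that were both requested and squeezed.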
  // In NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
  // unsqueezing it in the backward.
51 | if (ndim == 0) { |
52 | return self; |
53 | } |
54 | |
55 | Tensor result = self; |
56 | for (const auto d : c10::irange(ndim)) { |
57 | if (mask.test(d) && sizes[d] == 1) { |
58 | if (reapply_views) { |
59 | result = at::unsqueeze(result, d); |
60 | } else { |
61 | result = at::unsqueeze_copy(result, d); |
62 | } |
63 | } |
64 | } |
65 | return result; |
66 | } |
67 | |
68 | // Note [Functionalization Pass: View Inverses]. |
69 | // This file contains the implementation of each "view inverse". |
// These aren't really true inverses in the mathematical sense: each view inverse describes how to undo
71 | // the original view (although it takes in different arguments). |
72 | // |
73 | // E.g. Below is an example of a program that has alias operations removed, and the role that view inverses play: |
74 | // |
75 | // normal program with views and mutations: |
76 | // view1 = input1.view_op(args...) |
// view1.add_(1) (perform a mutation on the view, which should also modify input1)
//
// version of the program with no aliasing, that instead uses view_inverse functions:
80 | // view_copy1 = input1.view_copy_op(args...) |
81 | // view_copy1.add_(1) (perform a mutation on view_copy1. At this point, input1 is NOT modified) |
82 | // x = view_op_inverse(input1, view_copy1, args...) |
83 | // |
84 | // at this point, input1 and x should be equal |
85 | // |
86 | // Note that input1 is also passed as an argument to view_op_inverse in the above example. |
87 | // This isn't actually required for most view operators: it's only required for view ops |
88 | // where you can't figure out what the size of the base tensor is given just the view tensor and arguments. |
89 | // Examples are slice/select/scatter/squeeze/as_strided. |
90 | // We happen to be passing in the base tensor in all cases, mostly to make the codegen simpler. |
91 | // But you'll see below that the "base" argument is ignored by most view_inverse implementations. |
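//
// As a concrete (hypothetical) example using permute:
//   base      = at::ones({2, 3});
//   view_copy = at::permute_copy(base, {1, 0});   // shape {3, 2}
//   view_copy = view_copy.add(1);                 // base is NOT modified yet
//   base      = FunctionalInverses::permute_copy_inverse(base, view_copy, /*reapply_views=*/false, {1, 0});
// permute_copy_inverse applies the reverse permutation to view_copy, so base ends up
// with the mutated values at its original shape {2, 3}.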
92 | |
93 | // ---------------------------------------------------------- |
94 | // Implementations of each view_inverse() function are below. |
95 | // One of these needs to be implemented for every existing non-composite view operator. |
96 | // The codegen automatically generates the corresponding function declaration. |
97 | // ---------------------------------------------------------- |
98 | |
99 | Tensor FunctionalInverses::_fw_primal_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, int64_t level) { |
100 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported." ); |
101 | return Tensor(); |
102 | } |
103 | |
104 | Tensor FunctionalInverses::_make_dual_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, const at::Tensor& tangent, int64_t level) { |
105 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported." ); |
106 | return Tensor(); |
107 | } |
108 | |
109 | Tensor FunctionalInverses::view_as_real_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
110 | if (reapply_views) { |
111 | return at::view_as_complex(mutated_view); |
112 | } else { |
113 | return at::view_as_complex_copy(mutated_view); |
114 | } |
115 | } |
116 | |
117 | Tensor FunctionalInverses::view_as_complex_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
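    // resolve_conj() materializes any pending lazy conjugation first, because
    // view_as_real() is not supported on tensors with the conjugate bit set.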
118 | if (reapply_views) { |
119 | return at::view_as_real(mutated_view.resolve_conj()); |
120 | } else { |
121 | return at::view_as_real_copy(mutated_view.resolve_conj()); |
122 | } |
123 | } |
124 | |
125 | Tensor FunctionalInverses::_conj_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
126 | if (reapply_views) { |
127 | return at::_conj(mutated_view); |
128 | } else { |
129 | return at::_conj_copy(mutated_view); |
130 | } |
131 | } |
132 | |
133 | Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
134 | if (reapply_views) { |
135 | return at::_neg_view(mutated_view); |
136 | } else { |
137 | return at::_neg_view_copy(mutated_view); |
138 | } |
139 | } |
140 | |
141 | Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) { |
142 | // Pessimism: we can't reapply views for as_strided_scatter. |
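    // as_strided_scatter() embeds mutated_view back into a copy of base at the given
    // size/stride/storage_offset, so the output never aliases the inputs.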
143 | return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset)); |
144 | } |
145 | |
146 | Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t offset, int64_t dim1, int64_t dim2) { |
    // Pessimism: we can't reapply views for diagonal_scatter.
148 | return base.diagonal_scatter(mutated_view, offset, dim1, dim2); |
149 | } |
150 | |
151 | Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) { |
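    // sum_to() reduces any broadcasted dims of mutated_view back down to base's shape;
    // e.g. a base of shape {3, 1} that was expanded to {3, 4} is summed over dim 1 here.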
152 | return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views); |
153 | } |
154 | |
155 | Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) { |
156 | return at::functionalization::permute_copy_inverse(mutated_view, dims, reapply_views); |
157 | } |
158 | |
159 | Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride) { |
    // Note that I'm directly calling reshape(), and ignoring the `stride` argument.
    // _reshape_alias() isn't available from user code, and is an implementation detail of reshape().
    // Specifically, passing in the argument strides directly can get us into trouble in cases like:
    // b = a[0]; c = b.reshape(...); c.add_(1); print(a)
    // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides,
    // the call would fail because `mutated_view` doesn't have enough bytes of storage.
166 | if (reapply_views) { |
167 | return at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides()); |
168 | } else { |
169 | return at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides()); |
170 | } |
171 | } |
172 | |
173 | Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) { |
    // Pessimism: we can't reapply views for select_scatter.
175 | return base.select_scatter_symint(mutated_view, dim, std::move(index)); |
176 | } |
177 | |
178 | Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
179 | // the functionalization pass doesn't care about autograd metadata - as a view, I think detach() is just an identity function |
180 | return mutated_view; |
181 | } |
182 | |
183 | Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
184 | return mutated_view; |
185 | } |
186 | |
187 | Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) { |
188 | // Pessimism: we can't reapply views for slice_scatter. |
189 | return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step)); |
190 | } |
191 | |
192 | Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { |
193 | // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. |
    // For functionalization, we only have one of the tensors from the TensorList output by split(), and we want to layer it
    // on top of the base tensor.
196 | // For autograd, we have all of the tensors outputted by split() and we just want to stack them. |
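    // E.g. (hypothetical numbers) with dim_size = 8, split_size = 3, mutated_view_idx = 2:
    // start = 6, end = 9, clamped to end = 8, so mutated_view is scattered into base at [6, 8) along `dim`.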
197 | dim = at::maybe_wrap_dim(dim, base.dim()); |
198 | auto dim_size = base.sym_size(dim); |
199 | auto start = split_size * mutated_view_idx; |
200 | auto end = split_size + start; |
201 | if (end > dim_size) end = dim_size; |
202 | // Pessimism: we can't reapply views for slice_scatter. |
203 | return base.slice_scatter_symint(mutated_view, dim, start, end, 1); |
204 | } |
205 | |
206 | Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) { |
207 | dim = at::maybe_wrap_dim(dim, base.dim()); |
208 | auto dim_size = base.sym_size(dim); |
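    // Accumulate the sizes of the preceding splits to find where this view starts along `dim`.
    // E.g. (hypothetical numbers) with split_sizes = {2, 3, 3} and mutated_view_idx = 1: start = 2, end = 5.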
209 | c10::SymInt start = 0; |
210 | for (auto i = 0; i < mutated_view_idx; ++i) { |
211 | start += split_sizes[i]; |
212 | } |
213 | auto end = start + split_sizes[mutated_view_idx]; |
214 | if (end > dim_size) end = dim_size; |
215 | // Pessimism: we can't reapply views for slice_scatter. |
216 | return base.slice_scatter_symint(mutated_view, dim, start, end, 1); |
217 | } |
218 | |
219 | Tensor FunctionalInverses::squeeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
220 | return unsqueeze_copy_to(mutated_view, base.sym_sizes(), reapply_views); |
221 | } |
222 | |
223 | Tensor FunctionalInverses::squeeze_copy_dim_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) { |
224 | return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views); |
225 | } |
226 | |
227 | Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) { |
228 | return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views); |
229 | } |
230 | |
231 | Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
232 | if (reapply_views) { |
233 | return at::t(mutated_view); |
234 | } else { |
235 | return at::t_copy(mutated_view); |
236 | } |
237 | } |
238 | |
239 | Tensor FunctionalInverses::transpose_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim0, int64_t dim1) { |
240 | if (reapply_views) { |
241 | return transpose(mutated_view, dim0, dim1); |
242 | } else { |
243 | return transpose_copy(mutated_view, dim0, dim1); |
244 | } |
245 | } |
246 | |
247 | Tensor FunctionalInverses::_nested_view_from_buffer_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, const Tensor& nested_size_tensor, const Tensor& nested_stride_tensor, IntArrayRef offsets) { |
248 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization" ); |
249 | return Tensor(); |
250 | } |
251 | |
252 | Tensor FunctionalInverses::unsqueeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) { |
253 | if (reapply_views) { |
254 | return at::squeeze(mutated_view, dim); |
255 | } else { |
256 | return at::squeeze_copy(mutated_view, dim); |
257 | } |
258 | } |
259 | |
260 | Tensor FunctionalInverses::_indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
261 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
262 | return Tensor(); |
263 | } |
264 | |
265 | Tensor FunctionalInverses::_values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
266 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
267 | return Tensor(); |
268 | } |
269 | |
270 | Tensor FunctionalInverses::indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
271 | TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
272 | return Tensor(); |
273 | } |
274 | |
275 | Tensor FunctionalInverses::values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
276 | TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
277 | return Tensor(); |
278 | } |
279 | |
280 | Tensor FunctionalInverses::_sparse_broadcast_to_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) { |
281 | TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
282 | return Tensor(); |
283 | } |
284 | |
285 | Tensor FunctionalInverses::crow_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { |
286 | TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
287 | return Tensor(); |
288 | } |
289 | |
290 | Tensor FunctionalInverses::col_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { |
291 | TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
292 | return Tensor(); |
293 | } |
294 | |
295 | Tensor FunctionalInverses::ccol_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { |
296 | TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
297 | return Tensor(); |
298 | } |
299 | |
300 | Tensor FunctionalInverses::row_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { |
301 | TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization" ); |
302 | return Tensor(); |
303 | } |
304 | |
305 | Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, int64_t dim) { |
306 | dim = at::maybe_wrap_dim(dim, base.sizes().size()); |
307 | // Pessimism: we can't reapply views for select_scatter. |
308 | return base.select_scatter(mutated_view, dim, mutated_view_idx); |
309 | } |
310 | |
311 | Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) { |
312 | if (reapply_views) { |
313 | return mutated_view.view_symint(base.sym_sizes()); |
314 | } else { |
315 | return at::view_copy_symint(mutated_view, base.sym_sizes()); |
316 | } |
317 | } |
318 | |
319 | |
320 | Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) { |
321 | if (reapply_views) { |
322 | return mutated_view.view(base.scalar_type()); |
323 | } else { |
324 | return at::view_copy(mutated_view, base.scalar_type()); |
325 | } |
326 | } |
327 | |
328 | Tensor FunctionalInverses::unfold_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dimension, int64_t size, int64_t step) { |
    // I think autograd and the functionalization pass want the exact same thing here, but we'd need to test to confirm.
330 | // unfold_backward() is safe to use here because it is NOT a view op. |
331 | // (note: technically, "reapply_views" won't do anything here and we'll have an extra memory copy. |
332 | // We'd need to add an aliasing version of unfold_backward to fix that though). |
333 | return unfold_backward(mutated_view, base.sizes(), dimension, size, step); |
334 | } |
335 | |
336 | Tensor FunctionalInverses::alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { |
337 | if (reapply_views) { |
338 | return at::alias(mutated_view); |
339 | } else { |
340 | return at::alias_copy(mutated_view); |
341 | } |
342 | } |
343 | |
} // namespace functionalization
} // namespace at
346 | |