#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>

#include <c10/util/irange.h>

#include <ATen/core/TorchDispatchUtils.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>

#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <utility>
#include <vector>

namespace torch {
namespace autograd {

namespace {

template <typename F>
void _foreach_tensor(
    F fn,
    torch::jit::Stack* stack,
    size_t stack_start,
    size_t size) {
  // Enumerate over tensors in a stack, including ones in TensorLists
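  // For example, for a stack holding {Tensor a, int 3, TensorList [b, c]},
  // fn is invoked as fn(0, 0, a), fn(1, 2, b), fn(2, 2, c): idx_tensor
  // counts tensors seen so far, idx_arg is the argument's stack position.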
  int idx_tensor = 0;
  for (const auto idx_arg : c10::irange(size)) {
    auto& ivalue = (*stack)[stack_start + idx_arg];
    if (ivalue.isTensor()) { // true for optional tensor that has value
      const auto& tensor = ivalue.toTensor();
      fn(idx_tensor, idx_arg, tensor);
      idx_tensor++;
    } else if (ivalue.isTensorList()) {
      for (const auto& iv : ivalue.toListRef()) {
        const auto& tensor = iv.toTensor();
        fn(idx_tensor, idx_arg, tensor);
        idx_tensor++;
      }
    }
  }
}

} // namespace

void autogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic of a VariableType NotImplemented kernel
  // See gen_variable_type.py
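  // At a high level, this kernel:
  //   1. collects the input tensors that require grad,
  //   2. attaches a NotImplemented grad_fn to the graph if grad is needed,
  //   3. redispatches below autograd to run the actual kernel, and
  //   4. wires the outputs into the graph via set_history/rebase_history.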
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;
  const bool grad_mode = GradMode::is_enabled();
  std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;

  // Keep track of which outputs are the result of in-place modification
  // so we can rebase_history if necessary
  std::vector<bool> is_inplace_output(num_returns, false);
  bool any_is_inplace_output = false;
  std::vector<bool> is_aliased_output(num_returns, false);
  int aliased_output_idx = -1;
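
  // In schema alias annotations, 'Tensor(a!)' marks a written-to (mutable)
  // alias and 'Tensor(a)' a non-write alias (a view). For example, in-place
  // ops like "aten::add_" return their mutated input as 'Tensor(a!)', while
  // view ops like "aten::view" return 'Tensor(a)'.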
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
        is_inplace_output[i] = true;
        any_is_inplace_output = true;
      } else {
        TORCH_CHECK(
            aliased_output_idx == -1,
            "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
            "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
            "Please rewrite your function as a composite function.");
        aliased_output_idx = i;
      }
      is_aliased_output[i] = true;
    }
  }

  int aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
    }
  }
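
  // If present, the single non-write aliased input is the base of the view
  // relationship (e.g., `self` in view ops).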
  size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (grad_mode && t.requires_grad()) {
          tensors_requiring_grad_on_stack.push_back(&t);
        }
        num_tensor_inputs++;
        TORCH_CHECK_NOT_IMPLEMENTED(
            !isFwGradDefined(t),
            "Trying to use forward AD with ",
            op_name,
            " that does not support it.");
      },
      stack,
      stack_start,
      num_arguments);

  const bool any_requires_grad = !tensors_requiring_grad_on_stack.empty();

  _foreach_tensor(
      [&](size_t _, size_t i, const at::Tensor& t) {
        if (schema.is_mutable({c10::SchemaArgType::input, i})) {
          check_inplace(t, any_requires_grad);
        }
      },
      stack,
      stack_start,
      num_arguments);

  std::shared_ptr<NotImplemented> grad_fn;
  if (any_requires_grad) {
    grad_fn = std::shared_ptr<NotImplemented>(
        new NotImplemented(op_name), deleteNode);
    grad_fn->set_next_edges(
        collect_next_edges(tensors_requiring_grad_on_stack));
  }
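  // The NotImplemented node is an Error node: if backward ever reaches it,
  // it raises a "derivative for <op_name> is not implemented" error instead
  // of computing a gradient.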

#ifndef NDEBUG
  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
  auto stack_args_copy =
      std::vector<c10::IValue>(stack->begin() + stack_start, stack->end());
  std::vector<c10::intrusive_ptr<c10::TensorImpl>> impl_saved;
  impl_saved.reserve(num_tensor_inputs);
  std::vector<c10::optional<c10::Storage>> storage_saved;
  storage_saved.reserve(num_tensor_inputs);
  _foreach_tensor(
      [&](size_t idx, size_t _, const at::Tensor& t) {
        storage_saved.push_back(
            t.has_storage() ? c10::optional<c10::Storage>(t.storage())
                            : c10::nullopt);
        impl_saved.push_back(t.getIntrusivePtr());
      },
      &stack_args_copy,
      0,
      num_arguments);
#endif
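  // Redispatch below this kernel. If the op is a view or mutates its inputs,
  // only the autograd keys are masked out, so the ADInplaceOrView kernel
  // (e.g., the boxed fallback below) still runs to set up view metadata and
  // bump version counters; otherwise ADInplaceOrView is skipped as well.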
  if (aliased_input_idx != -1 || any_is_inplace_output) {
    at::AutoDispatchBelowAutograd guard;
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
  } else {
    // If neither in-place nor view
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }
#ifndef NDEBUG
  _foreach_tensor(
      [&](size_t idx_tensor, size_t _, const at::Tensor& t) {
        if (storage_saved.at(idx_tensor).has_value())
          TORCH_INTERNAL_ASSERT(
              storage_saved.at(idx_tensor).value().is_alias_of(t.storage()),
              op_name);
        if (impl_saved.at(idx_tensor))
          TORCH_INTERNAL_ASSERT(
              impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
      },
      &stack_args_copy,
      0,
      num_arguments);
  _foreach_tensor(
      [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
        if (at::impl::tensor_has_dispatch(t) ||
            at::impl::dispatch_mode_enabled())
          return;
        if (!is_inplace_output[idx_ret])
          TORCH_INTERNAL_ASSERT(
              t.use_count() <= 1, op_name); // Okay to return undefined tensor
        // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors,
        // each of which shares the storage of a hidden, intermediate 1D
        // Tensor created inside the CUDA implementation. This is because the
        // reference implementation in the nvidia/apex repo returns that 1D
        // Tensor, with each element representing the norm of the
        // corresponding input Tensor, whereas here we want to return the
        // same number of Tensors as the input TensorList. See
        // https://github.com/pytorch/pytorch/issues/93940
        if (!is_aliased_output[idx_ret] && t.has_storage() &&
            op_name != "aten::_foreach_norm")
          TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
      },
      stack,
      stack->size() - num_returns,
      num_returns);
  // There should be only a single base-view pair; make sure their storages
  // are aliased.
  if (aliased_input_idx != -1 && aliased_output_idx != -1) {
    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
    const c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];
    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
    TORCH_INTERNAL_ASSERT(
        aliased_output_iv.isTensor() || aliased_output_iv.isTensorList(),
        op_name);
    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
    if (aliased_input.has_storage()) {
      if (aliased_output_iv.isTensor()) {
        const at::Tensor& aliased_output = aliased_output_iv.toTensor();
        TORCH_INTERNAL_ASSERT(
            aliased_input.storage().is_alias_of(aliased_output.storage()),
            op_name);
      } else {
        const auto aliased_output_vec = aliased_output_iv.toTensorVector();
        for (const auto& aliased_output : aliased_output_vec) {
          TORCH_INTERNAL_ASSERT(
              aliased_input.storage().is_alias_of(aliased_output.storage()),
              op_name);
        }
      }
    }
  }
#endif

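  // Hook the outputs into the autograd graph: in-place outputs rebase their
  // existing history onto grad_fn, while fresh outputs simply record grad_fn
  // as their gradient edge.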
  if (any_requires_grad) {
    _foreach_tensor(
        [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
          if (isDifferentiableType(t.scalar_type())) {
            if (is_inplace_output[idx_ret]) {
              // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
              rebase_history(const_cast<at::Tensor&>(t), grad_fn);
            } else {
              // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
              set_history(const_cast<at::Tensor&>(t), grad_fn);
            }
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction autogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedFallbackImpl>();
}

void autogradNotImplementedInplaceOrViewFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic from the ADInplaceOrViewType kernel:
  // - see gen_inplace_or_view_type.py
  // - this should only be used with autogradNotImplementedFallback above
  // - For more information, see
  //   https://pytorch.org/tutorials/advanced/dispatcher
  //
  // NOTE [ Limitations of ADInplaceOrView boxed kernel ]
  //
  // This kernel should only be used together with the
  // autogradNotImplementedFallback kernel above, because we need some extra
  // logic to ensure that, even if we do in-place operations on views created
  // in this kernel, the proper "derivative is not implemented" error is
  // still raised.
  //
  // Just like the codegened kernel, we try to enforce some things:
  // - For views: we enforce that the view relationship is between the first
  //   input and the first output (which may be either a Tensor or a vector
  //   of Tensors).
  // - For in-place (TODO?): enforce that the same op cannot be both a view
  //   and in-place; that is not allowed in the gen_inplace_or_view logic.
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  at::Tensor aliased_input;

  int64_t aliased_output_idx = -1;
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i}) &&
        !schema.is_mutable({c10::SchemaArgType::output, i})) {
      TORCH_CHECK(
          aliased_output_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_output_idx = i;
    }
  }

  int64_t aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
      const c10::IValue& aliased_input_iv =
          (*stack)[stack_start + i]; // get a reference to an ivalue on the
                                     // stack
      TORCH_CHECK(aliased_input_iv.isTensor());
      aliased_input =
          aliased_input_iv.toTensor(); // TODO: Can we avoid saving this
                                       // tensor and incurring the refcount
                                       // bump?
    }
  }
  // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above
  TORCH_CHECK(
      (aliased_input_idx == -1 && aliased_output_idx == -1) ||
          (aliased_input_idx == 0 && aliased_output_idx == 0),
      "Fallback ADInplaceOrView kernel can only create view relationships between the first "
      "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
  const bool is_view = aliased_input_idx != -1;

  {
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }

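  // Bump the version counter of every mutated output so that autograd can
  // detect when a tensor saved for backward has since been modified in-place.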
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_mutable({c10::SchemaArgType::output, i})) {
      increment_version((*stack)[stack->size() - num_returns + i].toTensor());
    }
  }

  if (is_view) {
    c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];
    if (aliased_output_iv.isTensorList()) {
      auto aliased_output = aliased_output_iv.toTensorVector();
      // We only allow rebasing of the history if we return a single Tensor,
      // which is why we don't have to care about the view_func logic below.
      // See NOTE [ View + Inplace detection ] for more details about this
      // logic.
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensors=*/aliased_output,
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE
                                            : CreationMeta::NO_GRAD_MODE));
      // ^ We unnecessarily pass in creation_meta even when the type is not
      // differentiable, but we don't have that information here anyway.
      stack->at(stack->size() - num_returns + aliased_output_idx) = result;
    } else {
      TORCH_CHECK(aliased_output_iv.isTensor());
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensor=*/std::move(aliased_output_iv).toTensor(),
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* view_func=*/
          [op_name = op_name](const at::Tensor&) {
            // We always need this view_func because otherwise, if we do an
            // in-place op on this view, we would implicitly use
            // AsStridedBackward instead of the NotImplemented node. For the
            // cross-dtype/non-strided cases, we would create something like
            // this anyway.
            TORCH_CHECK(
                false,
                "Mutating the view ",
                op_name,
                " which does not have a derivative implemented is forbidden.");
            return at::Tensor();
          },
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT
                                            : CreationMeta::NO_GRAD_MODE));
      stack->at(stack->size() - num_returns + aliased_output_idx) =
          std::move(result);
    }
  }
}

torch::CppFunction autogradNotImplementedInplaceOrViewFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedInplaceOrViewFallbackImpl>();
}
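
// Example registration (illustrative sketch; `myops` and `myadd` are
// hypothetical extension names). Per the NOTE above, the two fallbacks are
// meant to be installed together, one under each dispatch key:
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("myadd", torch::autograd::autogradNotImplementedFallback());
//   }
//   TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
//     m.impl(
//         "myadd",
//         torch::autograd::autogradNotImplementedInplaceOrViewFallback());
//   }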

} // namespace autograd
} // namespace torch