#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>

#include <c10/util/irange.h>

#include <ATen/core/TorchDispatchUtils.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>

#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <utility>
#include <vector>

namespace torch {
namespace autograd {

namespace {

template <typename F>
void _foreach_tensor(
    F fn,
    torch::jit::Stack* stack,
    size_t stack_start,
    size_t size) {
  // Enumerate over tensors in a stack, including ones in TensorLists
  int idx_tensor = 0;
  for (const auto idx_arg : c10::irange(size)) {
    auto& ivalue = (*stack)[stack_start + idx_arg];
    if (ivalue.isTensor()) { // true for optional tensor that has value
      const auto& tensor = ivalue.toTensor();
      fn(idx_tensor, idx_arg, tensor);
      idx_tensor++;
    } else if (ivalue.isTensorList()) {
      for (const auto& iv : ivalue.toListRef()) {
        const auto& tensor = iv.toTensor();
        fn(idx_tensor, idx_arg, tensor);
        idx_tensor++;
      }
    }
  }
}

} // namespace

void autogradNotImplementedFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic of a VariableType NotImplemented kernel
  // See gen_variable_type.py
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;
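  // `stack_start` is the index of the first argument: boxed kernels receive
  // their arguments as the top `num_arguments` entries of the stack, and the
  // redispatch below replaces them with the op's returns.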
  const bool grad_mode = GradMode::is_enabled();
  std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;

  // Keep track of which outputs are the result of in-place modification
  // so we can rebase_history if necessary
  std::vector<bool> is_inplace_output(num_returns, false);
  bool any_is_inplace_output = false;
  std::vector<bool> is_aliased_output(num_returns, false);
  int aliased_output_idx = -1;

  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i})) {
      if (schema.is_mutable({c10::SchemaArgType::output, i})) {
        is_inplace_output[i] = true;
        any_is_inplace_output = true;
      } else {
        TORCH_CHECK(
            aliased_output_idx == -1,
            "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
            "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
            "Please rewrite your function as a composite function.");
        aliased_output_idx = i;
      }
      is_aliased_output[i] = true;
    }
  }

  int aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
    }
  }

  size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks
  _foreach_tensor(
      [&](size_t _, size_t idx_arg, const at::Tensor& t) {
        if (grad_mode && t.requires_grad()) {
          tensors_requiring_grad_on_stack.push_back(&t);
        }
        num_tensor_inputs++;
        TORCH_CHECK_NOT_IMPLEMENTED(
            !isFwGradDefined(t),
            "Trying to use forward AD with ",
            op_name,
            " that does not support it.");
      },
      stack,
      stack_start,
      num_arguments);

  const bool any_requires_grad = !tensors_requiring_grad_on_stack.empty();

  _foreach_tensor(
      [&](size_t _, size_t i, const at::Tensor& t) {
        if (schema.is_mutable({c10::SchemaArgType::input, i})) {
          check_inplace(t, any_requires_grad);
        }
      },
      stack,
      stack_start,
      num_arguments);

  std::shared_ptr<NotImplemented> grad_fn;
  if (any_requires_grad) {
    grad_fn = std::shared_ptr<NotImplemented>(
        new NotImplemented(op_name), deleteNode);
    grad_fn->set_next_edges(
        collect_next_edges(tensors_requiring_grad_on_stack));
  }
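// In debug builds, snapshot each tensor input's TensorImpl and Storage so
// that, after the redispatch, we can assert that the called kernel did not
// swap them out from under us (see the NOTE referenced below).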
#ifndef NDEBUG
  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
  auto stack_args_copy =
      std::vector<c10::IValue>(stack->begin() + stack_start, stack->end());
  std::vector<c10::intrusive_ptr<c10::TensorImpl>> impl_saved;
  impl_saved.reserve(num_tensor_inputs);
  std::vector<c10::optional<c10::Storage>> storage_saved;
  storage_saved.reserve(num_tensor_inputs);
  _foreach_tensor(
      [&](size_t idx, size_t _, const at::Tensor& t) {
        storage_saved.push_back(
            t.has_storage() ? c10::optional<c10::Storage>(t.storage())
                            : c10::nullopt);
        impl_saved.push_back(t.getIntrusivePtr());
      },
      &stack_args_copy,
      0,
      num_arguments);
#endif
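  // If the op is in-place or returns a view, only drop below the Autograd
  // keys so that the ADInplaceOrView kernel (e.g. the fallback below) still
  // runs and can handle version bumps and view metadata; otherwise drop
  // below ADInplaceOrView as well, since there is no aliasing to track.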
  if (aliased_input_idx != -1 || any_is_inplace_output) {
    at::AutoDispatchBelowAutograd guard;
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
  } else {
    // If neither in-place nor view
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }
#ifndef NDEBUG
  _foreach_tensor(
      [&](size_t idx_tensor, size_t _, const at::Tensor& t) {
        if (storage_saved.at(idx_tensor).has_value())
          TORCH_INTERNAL_ASSERT(
              storage_saved.at(idx_tensor).value().is_alias_of(t.storage()),
              op_name);
        if (impl_saved.at(idx_tensor))
          TORCH_INTERNAL_ASSERT(
              impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
      },
      &stack_args_copy,
      0,
      num_arguments);
  _foreach_tensor(
      [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
        if (at::impl::tensor_has_dispatch(t) ||
            at::impl::dispatch_mode_enabled())
          return;
        if (!is_inplace_output[idx_ret])
          TORCH_INTERNAL_ASSERT(
              t.use_count() <= 1, op_name); // Okay to return undefined tensor
        // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors,
        // each of which shares the storage of a hidden, intermediate 1D
        // Tensor created inside the CUDA implementation. This mirrors the
        // reference implementation in the nvidia/apex repo, which returns
        // that 1D Tensor whose elements are the norms of the corresponding
        // input Tensors, whereas here we want to return the same number of
        // Tensors as the input TensorList. See
        // https://github.com/pytorch/pytorch/issues/93940
        if (!is_aliased_output[idx_ret] && t.has_storage() &&
            op_name != "aten::_foreach_norm")
          TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
      },
      stack,
      stack->size() - num_returns,
      num_returns);
  // There should be only a single base-view pair; make sure their storage is
  // aliased.
  if (aliased_input_idx != -1 && aliased_output_idx != -1) {
    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
    const c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];
    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
    TORCH_INTERNAL_ASSERT(
        aliased_output_iv.isTensor() || aliased_output_iv.isTensorList(),
        op_name);
    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
    if (aliased_input.has_storage()) {
      if (aliased_output_iv.isTensor()) {
        const at::Tensor& aliased_output = aliased_output_iv.toTensor();
        TORCH_INTERNAL_ASSERT(
            aliased_input.storage().is_alias_of(aliased_output.storage()),
            op_name);
      } else {
        const auto aliased_output_vec = aliased_output_iv.toTensorVector();
        for (const auto& aliased_output : aliased_output_vec) {
          TORCH_INTERNAL_ASSERT(
              aliased_input.storage().is_alias_of(aliased_output.storage()),
              op_name);
        }
      }
    }
  }
#endif

  if (any_requires_grad) {
    _foreach_tensor(
        [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
          if (isDifferentiableType(t.scalar_type())) {
            if (is_inplace_output[idx_ret]) {
              // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
              rebase_history(const_cast<at::Tensor&>(t), grad_fn);
            } else {
              // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
              set_history(const_cast<at::Tensor&>(t), grad_fn);
            }
          }
        },
        stack,
        stack->size() - num_returns,
        num_returns);
  }
}

torch::CppFunction autogradNotImplementedFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedFallbackImpl>();
}
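
// Illustrative sketch (not part of this file): an operator without a
// derivative formula can register this fallback for its Autograd key so that
// differentiating through it raises a clear "not implemented" error instead
// of silently detaching. The namespace and operator name below are
// hypothetical:
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("my_op", torch::autograd::autogradNotImplementedFallback());
//   }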

void autogradNotImplementedInplaceOrViewFallbackImpl(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mimics a subset of the logic from the ADInplaceOrViewType kernel:
  // - see gen_inplace_or_view_type.py
  // - this should only be used with autogradNotImplementedFallback above
  // - For more information see
  //   https://pytorch.org/tutorials/advanced/dispatcher
  //
  // NOTE [ Limitations of ADInplaceOrView boxed kernel ]
  //
  // This kernel should only be used together with the
  // autogradNotImplementedFallback kernel because we need extra logic to
  // ensure that, even if we modify in-place a view created inside this
  // kernel, the proper "derivative is not implemented" error is still raised.
  //
  // Just like the codegened kernel, we try to enforce some things:
  // - For views: we enforce that the view relationship is between the first
  //   input and the first output (which may be either a Tensor or a vector
  //   of Tensors).
  // - For in-place (TODO?): enforce that the same op cannot be both a view
  //   and in-place; that is not allowed in the gen_inplace_or_view logic.
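  //
  // For example, a hypothetical view op with schema
  //   "my_view(Tensor(a) self, int dim) -> Tensor(a)"
  // satisfies this: the first input and the first output share the non-write
  // alias annotation '(a)'.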
  const auto& schema = op.schema();
  const auto& op_name = schema.operator_name().name;
  const auto num_arguments = schema.arguments().size();
  const auto num_returns = schema.returns().size();
  const auto stack_start = stack->size() - num_arguments;

  at::Tensor aliased_input;

  int64_t aliased_output_idx = -1;
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_aliasing({c10::SchemaArgType::output, i}) &&
        !schema.is_mutable({c10::SchemaArgType::output, i})) {
      TORCH_CHECK(
          aliased_output_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_output_idx = i;
    }
  }

  int64_t aliased_input_idx = -1;
  for (const auto i : c10::irange(num_arguments)) {
    if (schema.is_aliasing({c10::SchemaArgType::input, i}) &&
        !schema.is_mutable({c10::SchemaArgType::input, i})) {
      TORCH_CHECK(
          aliased_input_idx == -1,
          "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a "
          "non-write alias annotation (i.e., 'Tensor(a)'). "
          "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
          "Please rewrite your function as a composite function.");
      aliased_input_idx = i;
      const c10::IValue& aliased_input_iv =
          (*stack)[stack_start + i]; // get a reference to an ivalue on the
                                     // stack
      TORCH_CHECK(aliased_input_iv.isTensor());
      aliased_input =
          aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor
                                       // and incurring the refcount bump?
    }
  }
  // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above
  TORCH_CHECK(
      (aliased_input_idx == -1 && aliased_output_idx == -1) ||
          (aliased_input_idx == 0 && aliased_output_idx == 0),
      "Fallback ADInplaceOrView kernel can only create view relationships between the first "
      "input and the first output (the output can be a vector of tensors). Please change the "
      "order of your operator's parameters so that this is the case.");
  const bool is_view = aliased_input_idx != -1;

  {
    at::AutoDispatchBelowADInplaceOrView guard;
    op.redispatchBoxed(
        dispatch_keys & c10::after_ADInplaceOrView_keyset, stack);
  }

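  // Bump the version counter of every mutated output so that autograd's
  // saved-variable checks can detect the in-place modification.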
  for (const auto i : c10::irange(num_returns)) {
    if (schema.is_mutable({c10::SchemaArgType::output, i})) {
      increment_version((*stack)[stack->size() - num_returns + i].toTensor());
    }
  }

  if (is_view) {
    c10::IValue& aliased_output_iv =
        (*stack)[stack->size() - num_returns + aliased_output_idx];
    if (aliased_output_iv.isTensorList()) {
      auto aliased_output = aliased_output_iv.toTensorVector();
      // We only allow rebasing of the history if we return a single Tensor,
      // which is why we don't have to care about the view_func logic below.
      // See NOTE [ View + Inplace detection ] for more details about this
      // logic.
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensors=*/aliased_output,
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE
                                            : CreationMeta::NO_GRAD_MODE));
      // ^ We pass in creation meta unnecessarily even if the outputs are not
      // of a differentiable type, but we don't have that information here
      // anyway.
      stack->at(stack->size() - num_returns + aliased_output_idx) = result;
    } else {
      TORCH_CHECK(aliased_output_iv.isTensor());
      auto result = as_view(
          /* base=*/aliased_input,
          /* tensor=*/std::move(aliased_output_iv).toTensor(),
          /* is_bw_differentiable=*/true,
          /* is_fw_differentiable=*/true,
          /* view_func=*/
          [op_name = op_name](const at::Tensor&) {
            // We always need this view_func because otherwise, if we do an
            // in-place op on this view, we would implicitly use
            // AsStridedBackward instead of the NotImplemented node. For the
            // cross-dtype/non-strided cases, we would create something like
            // this anyway.
            TORCH_CHECK(
                false,
                "Mutating the view ",
                op_name,
                " which does not have a derivative implemented is forbidden.");
            return at::Tensor();
          },
          /* creation_meta=*/
          InferenceMode::is_enabled()
              ? CreationMeta::INFERENCE_MODE
              : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT
                                            : CreationMeta::NO_GRAD_MODE));
      stack->at(stack->size() - num_returns + aliased_output_idx) =
          std::move(result);
    }
  }
}

torch::CppFunction autogradNotImplementedInplaceOrViewFallback() {
  return torch::CppFunction::makeFromBoxedFunction<
      &autogradNotImplementedInplaceOrViewFallbackImpl>();
}
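
// Illustrative sketch (not part of this file): this fallback is meant to be
// paired with autogradNotImplementedFallback above, registered for the
// ADInplaceOrView key of the same (hypothetical) operator:
//
//   TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
//     m.impl(
//         "my_op",
//         torch::autograd::autogradNotImplementedInplaceOrViewFallback());
//   }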

} // namespace autograd
} // namespace torch