1#include <ATen/ATen.h>
2#include <ATen/core/dispatch/Dispatcher.h>
3#include <ATen/core/op_registration/op_registration.h>
4#include <ATen/native/UnaryOps.h>
5#include <ATen/NativeFunctions.h>
6#include <c10/util/irange.h>
7#include <torch/library.h>
8#include <ATen/native/MathBitFallThroughLists.h>
9
10namespace at {
11
12 // TODO: add a note explaining the design decisions
13 // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors
14 void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
15 const auto& arguments = op.schema().arguments();
16 const auto num_arguments = arguments.size();
17 const auto stack_start = stack->size() - num_arguments;
18
19 c10::optional<bool> is_write;
20 for (const auto i : c10::irange(num_arguments)) {
21 const auto& alias_info = arguments[i].alias_info();
22 if (alias_info != nullptr) {
23 if (is_write.has_value()) {
24 TORCH_CHECK(*is_write == alias_info->isWrite(),
25 "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
26 "ZeroTensor fallback doesn't work for operators with a mix "
27 "mutable and non-mutable inputs that alias with outputs, "
28 "this must be implemented manually. "
29 "If you got this error on a core op, please report a bug to PyTorch.");
30 } else {
31 is_write = alias_info->isWrite();
32 }
33 }
34 }
35
36 if (is_write.has_value() && !*is_write) {
37 // We assume that view operators automatically handle the ZeroTensor bit
38 // correctly by propagating the dispatch key in key_set.
39 // This is not necessarily always right, so you should test these cases.
40 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
41 return;
42 }
43
44 for (const auto i : c10::irange(num_arguments)) {
45 auto& ivalue = (*stack)[stack_start + i];
46 if (!(ivalue.isTensor() || ivalue.isTensorList())) {
47 continue;
48 }
49 const auto& argument = arguments[i];
50 bool mut_arg = false;
51
52 if (argument.alias_info()) {
53 // Was already tested by is_write loop above
54 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
55 mut_arg = true;
56 }
57
58 if (ivalue.isTensor()) {
59 auto tensor = std::move(ivalue).toTensor();
60 if (tensor._is_zerotensor()) {
61 TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
62 "obtained using .clone() if you want a mutable tensor.");
63 tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
64 }
65 (*stack)[stack_start + i] = std::move(tensor);
66 } else if (ivalue.isTensorList()) {
67 auto tensors = std::move(ivalue).toTensorList();
68 for(const auto j : c10::irange(tensors.size())) {
69 const Tensor& tensor = tensors[j];
70 if (tensor._is_zerotensor()) {
71 // TODO: assert requires_grad=False
72 //_like should not propagate zerotensor dispatch key
73 TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
74 "obtained using .clone() if you want a mutable tensor.");
75 tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
76 }
77 }
78 (*stack)[stack_start + i] = std::move(tensors);
79 }
80 }
81
82 op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
83 }
84
85
86 TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
87 m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
88 }
89
90 TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
91 m.impl("zeros_like", torch::CppFunction::makeFallthrough());
92 m.impl("mul.Scalar", torch::CppFunction::makeFallthrough());
93 m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
94 m.impl("copy_", torch::CppFunction::makeFallthrough());
95 m.impl("clone", torch::CppFunction::makeFallthrough());
96 m.impl("dot", torch::CppFunction::makeFallthrough());
97 m.impl("vdot", torch::CppFunction::makeFallthrough());
98 // The functions in the list below have a specific registeration in native_functions.yaml and
99 // do not use the fallback.
100 // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
101 // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
102 // m.impl("linalg_cross", torch::CppFunction::makeFallthrough());
103
104 TORCH_VIEW_FNS(m)
105 TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
106 }
107} // namespace at
108