#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

  // TODO: add a note explaining the design decisions
  // ZeroTensors are designed to be immutable. Thus, we error out when an
  // in-place operation is performed on a ZeroTensor.
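  //
  // A rough behavioral sketch (Python-level, illustrative only; relies on the
  // private torch._efficientzerotensor constructor):
  //   z = torch._efficientzerotensor(3)   # z._is_zerotensor() -> True
  //   z.add_(1)                           # error: ZeroTensors are immutable
  //   z.clone().add_(1)                   # ok: clone() materializes a mutable copy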
  void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    const auto stack_start = stack->size() - num_arguments;
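    // Boxed calling convention: the op's arguments occupy the top
    // num_arguments slots of the stack, starting at stack_start.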

    c10::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      const auto& alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          TORCH_CHECK(*is_write == alias_info->isWrite(),
            "Unsupported operator for ZeroTensorFallback: ", op.schema().name(),
            ". ZeroTensor fallback doesn't work for operators with a mix of "
            "mutable and non-mutable inputs that alias with outputs; "
            "this must be implemented manually. "
            "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }
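
    // is_write is unset when no argument aliases an output, true for
    // in-place/out= variants, and false for view-like ops.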
    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the ZeroTensor bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
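      // DispatchKeySet(FULL_AFTER, ZeroTensor) contains only the keys strictly
      // after ZeroTensor, so this redispatch cannot re-enter the fallback.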
      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
      return;
    }

    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];
      bool mut_arg = false;

      if (argument.alias_info()) {
        // Was already tested by the is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
        mut_arg = true;
      }

      if (ivalue.isTensor()) {
        auto tensor = std::move(ivalue).toTensor();
        if (tensor._is_zerotensor()) {
          TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
            "obtained using .clone() if you want a mutable tensor.");
          tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
        }
        (*stack)[stack_start + i] = std::move(tensor);
      } else if (ivalue.isTensorList()) {
        auto tensors = std::move(ivalue).toTensorList();
        for (const auto j : c10::irange(tensors.size())) {
          const Tensor& tensor = tensors[j];
          if (tensor._is_zerotensor()) {
            // TODO: assert requires_grad=False
            // _like functions should not propagate the ZeroTensor dispatch key
            TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
              "obtained using .clone() if you want a mutable tensor.");
            tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
          }
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }

    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
  }
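
  // Register zeroTensorFallback as the boxed fallback for every op that lands
  // on the ZeroTensor dispatch key; the fallthroughs below opt specific ops out.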
  TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
    m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
  }

  TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
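    // makeFallthrough() makes the dispatcher skip the ZeroTensor key for these
    // ops and run the next kernel in line, which receives the ZeroTensor
    // inputs unmaterialized.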
91 | m.impl("zeros_like" , torch::CppFunction::makeFallthrough()); |
92 | m.impl("mul.Scalar" , torch::CppFunction::makeFallthrough()); |
93 | m.impl("add.Scalar" , torch::CppFunction::makeFallthrough()); |
94 | m.impl("copy_" , torch::CppFunction::makeFallthrough()); |
95 | m.impl("clone" , torch::CppFunction::makeFallthrough()); |
96 | m.impl("dot" , torch::CppFunction::makeFallthrough()); |
97 | m.impl("vdot" , torch::CppFunction::makeFallthrough()); |
98 | // The functions in the list below have a specific registeration in native_functions.yaml and |
99 | // do not use the fallback. |
    // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
    // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
    // m.impl("linalg_cross", torch::CppFunction::makeFallthrough());
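
    // TORCH_VIEW_FNS and TENSOR_UTILITIES_AND_CONSTRUCTORS (defined in
    // MathBitFallThroughLists.h) register fallthroughs in bulk for view ops
    // and for tensor constructors/utility functions.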
    TORCH_VIEW_FNS(m)
    TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
  }
} // namespace at