#pragma once

#include <cstdint>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Bitset.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <ATen/core/Variadic.h>
#include <ATen/core/stack.h>

namespace c10 {

namespace impl {

// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
// DispatchKey should be, taking into account TLS, and skipping backends which
// fall through.
//
// Unlike Tensor::key_set(), the value of this on a tensor can change depending
// on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
static inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    // The key mask lets us eliminate (by zero entries) keys which should not
    // be considered for dispatch. There are two cases when we use this:
    //
    //   - If an operator's dispatch table contains a fallthrough entry, we
    //     should bypass it entirely when finding the key
    //   - If a user invokes with redispatch, the mask lets us
    //     zero out the key at which the user asked us to stop.
    //
    // These excluded backends are NOT tracked in the TLS, but must be applied
    // AFTER TLS (since the backend may have been introduced for consideration
    // by the included TLS), which is why you have to pass them in to this
    // function (as opposed to just applying it to the input 'ks').
    DispatchKeySet key_mask
) {
  c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
  // TODO: It's a bit irritating that we have to do logical ORs here, it would
  // be nice to only do one. Can always_included be folded into the TLS? Well,
  // it's a bit troublesome, because fastpath TLS access requires the type of
  // the TLS in question to be zero-initialized, so you don't actually win
  // anything in that case.
  return (((ks | local.included_) - local.excluded_) & key_mask);
}
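// Illustrative sketch with hypothetical values: suppose a tensor carries
// ks = {CPU, AutogradCPU}, TLS currently has included_ = {} and
// excluded_ = {AutogradCPU} (e.g. because autograd has been locally
// disabled), and key_mask is FULL. Then
//
//   ((ks | included_) - excluded_) & key_mask == {CPU}
//
// and the highest-priority remaining key, CPU, is what gets dispatched to.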

}

namespace detail {
  // A small gadget to extract the DispatchKeySet from types which are known
  // to have it. Used to extract dispatch keys from unboxed calls.
  struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
    DispatchKeySet ts;
    void operator()(const at::Tensor& x) {
      ts = ts | x.key_set();
    }
    void operator()(const c10::optional<at::Tensor>& x) {
      if (x.has_value()) {
        ts = ts | x->key_set();
      }
    }
    void operator()(at::ArrayRef<at::Tensor> xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    // Tensor?[] translates to this case.
    void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
      for (c10::optional<at::Tensor> x : xs) {
        if (x.has_value()) {
          ts = ts | x.value().key_set();
        }
      }
    }
    // Structured Tensor[] translates to this case
    void operator()(at::ITensorListRef xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    [[noreturn]] void operator()(at::ArrayRef<c10::optional<at::Tensor>>) {
      // Just checking that the handling of Tensor?[] didn't change.
      TORCH_INTERNAL_ASSERT(false);
    }
    void operator()(const at::Generator& gen) {
      if (gen.defined()) {
        ts = ts | gen.key_set();
      }
    }
    void operator()(const c10::optional<at::Generator>& gen) {
      if (gen.has_value() && gen->defined()) {
        ts = ts | gen->key_set();
      }
    }
    template <typename T>
    void operator()(const T&) {
      // do nothing
    }
  };

  // NB: take by const reference (Don't do universal forwarding here! You
  // don't want to move into this function!)
  template <typename... Args>
  DispatchKeySet multi_dispatch_key_set(const Args&... args) {
    return MultiDispatchKeySet().apply(args...).ts;
  }
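  // Illustrative sketch of what the gadget above computes: for an unboxed call
  // with arguments (Tensor a, double scale, Tensor b), e.g.
  //
  //   auto ks = detail::multi_dispatch_key_set(a, scale, b);
  //
  // the double hits the catch-all overload and contributes nothing, so ks ends
  // up equal to a.key_set() | b.key_set().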
}

/**
 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
 * a list of arguments for an operator call.
 *
 * The instance is specific for a certain operator as:
 *  - In boxed dispatch, different operators have different ways to extract
 *    the dispatch key (e.g. different numbers of arguments), and we precompute
 *    the stack locations we should look at; and
 *  - In all dispatch, some backends should be excluded from dispatch because
 *    they have been registered as fallthrough. The set of excluded backends
 *    varies from operator to operator, as some operators may have overridden
 *    the fallthrough with custom behavior.
 *
 * Note: this should maintain an implementation identical to the Python
 * dispatcher key extraction logic at pytorch/torch/dispatcher.py.
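 *
 * Illustrative usage sketch (hypothetical call sequence; the real call sites
 * live in the Dispatcher):
 *
 *   auto extractor = DispatchKeyExtractor::make(schema);
 *   DispatchKeySet ks = extractor.getDispatchKeySetBoxed(&stack);
 *   DispatchKey k = ks.highestPriorityTypeId();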
 */
struct TORCH_API DispatchKeyExtractor final {
public:
  static DispatchKeyExtractor make(const FunctionSchema& schema) {
    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
  }

  static DispatchKeyExtractor makeUninitialized() {
    return DispatchKeyExtractor(c10::utils::bitset());
  }

  void registerSchema(const FunctionSchema& schema) {
    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
  }
  void deregisterSchema() {
    dispatch_arg_indices_reverse_ = c10::utils::bitset();
  }

  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
    DispatchKeySet ks;
    dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
      if (C10_LIKELY(ivalue.isTensor())) {
        // NB: Take care not to introduce a refcount bump (there's
        // no safe toTensorRef method, alas)
        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
        for (const at::Tensor& tensor : ivalue.toTensorList()) {
          ks = ks | tensor.key_set();
        }
      }
      // Tensor?[] translates to a c10::List<IValue> so we need to peek inside
      else if (C10_UNLIKELY(ivalue.isList())) {
        for (const auto& elt : ivalue.toListRef()) {
          if (elt.isTensor()) {
            ks = ks | elt.toTensor().key_set();
          }
        }
      }
    });
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
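  // Illustrative sketch of the boxed path above, for a hypothetical schema
  //
  //   add(Tensor self, Tensor other, Scalar alpha) -> Tensor
  //
  // Bits 1 and 2 of dispatch_arg_indices_reverse_ are set (the two Tensors,
  // counted from the top of the stack), so something like
  //
  //   torch::jit::Stack stack;
  //   torch::jit::push(stack, self, other, alpha);
  //   DispatchKeySet ks = extractor.getDispatchKeySetBoxed(&stack);
  //
  // produces a ks built from self.key_set() | other.key_set(), then masked by
  // the non-fallthrough keys and TLS via computeDispatchKeySet.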

  template<class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
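  // Illustrative sketch of the unboxed path above: the same computation without
  // a stack, e.g.
  //
  //   DispatchKeySet ks = extractor.getDispatchKeySetUnboxed(self, other, alpha);
  //
  // where non-Tensor arguments such as alpha contribute nothing to the key set.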

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
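  // Illustrative sketch (the behavior lives in DispatchKeyExtractor.cpp): if an
  // operator registers a fallthrough kernel for a single backend's key, e.g.
  //
  //   extractor.setOperatorHasFallthroughForKey(DispatchKey::AutogradCPU, true);
  //
  // the per-backend non-fallthrough sets diverge, requiresBitsetPerBackend_ is
  // flipped to true, and the getters above take the per-backend path.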

  std::string dumpState() const;
  void checkInvariants(const FunctionSchema& schema) const;

private:
  static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
    TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
        "The function schema has ", schema.arguments().size(),
        " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
    c10::utils::bitset dispatch_arg_indices_reverse;
    for (const auto index : c10::irange(schema.arguments().size())) {
      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
          schema.arguments()[index].type()->isSubtypeOf(*ListType::ofTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(*ListType::ofOptionalTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(*OptionalType::ofTensor())) {
        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
      }
    }
    return dispatch_arg_indices_reverse;
  }
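  // Illustrative sketch: for a hypothetical schema
  //
  //   foo(Tensor a, int b, Tensor[] c) -> Tensor
  //
  // arguments 0 and 2 are dispatch-relevant, so bits 2 (== 3 - 1 - 0) and
  // 0 (== 3 - 1 - 2) of the returned bitset are set; the int sets nothing.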

  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
  : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
  , nonFallthroughKeys_(DispatchKeySet::FULL)
  , requiresBitsetPerBackend_(false) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // this is a bitset that has ones for each argument index which has to be
  // considered for dispatch. This avoids having to iterate over the stack
  // to find all the tensors. The bits are stored in reverse order, i.e.
  // if dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
  // the top of the stack (i.e. the i-th last argument of the function)
  // is relevant for dispatch.
  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
  // means you must do the fallthrough.
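  // Worked example (hypothetical op): for foo(Tensor a, Scalar b), argument a
  // is the second-from-last argument, so bit 1 is set and bit 0 (for b) is not.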
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of functionality keys for which the operator does NOT have a fallthrough kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have a fallthrough kernel,
  // defined PER BACKEND. This is only needed if we know that the operator has a different
  // set of fallthroughs defined for some backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_.
  bool requiresBitsetPerBackend_;
};

}