1 | #pragma once |
2 | |
3 | #include <ATen/SequenceNumber.h> |
4 | #include <ATen/core/boxing/KernelFunction.h> |
5 | #include <ATen/core/boxing/impl/boxing.h> |
6 | #include <ATen/core/dispatch/OperatorEntry.h> |
7 | #include <ATen/core/dispatch/CppSignature.h> |
8 | #include <ATen/core/dispatch/RegistrationHandleRAII.h> |
9 | #include <ATen/record_function.h> |
10 | #include <c10/util/Exception.h> |
11 | #include <c10/util/LeftRight.h> |
12 | #include <list> |
13 | #include <mutex> |
14 | #include <condition_variable> |
15 | #include <type_traits> |
16 | |
17 | #include <ATen/core/grad_mode.h> |
18 | #include <ATen/core/enum_tag.h> |
19 | |
20 | namespace c10 { |
21 | |
22 | TORCH_API bool show_dispatch_trace(); |
23 | TORCH_API void dispatch_trace_nesting_incr(); |
24 | TORCH_API void dispatch_trace_nesting_decr(); |
25 | TORCH_API int64_t dispatch_trace_nesting_value(); |
26 | |
// RAII guard that tracks how deeply nested the current dispatcher invocation
// is, so dispatch-trace output (see show_dispatch_trace()) can be indented by
// nesting depth. Increments the counter on construction, decrements on scope
// exit.
struct DispatchTraceNestingGuard {
  DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); }
  ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); }
};
31 | |
32 | class TORCH_API OperatorHandle; |
33 | template<class FuncType> class TypedOperatorHandle; |
34 | |
35 | /** |
36 | * Implement this interface and register your instance with the dispatcher |
37 | * to get notified when operators are registered or deregistered with |
38 | * the dispatcher. |
39 | * |
40 | * NB: registration events only occur when a 'def' occurs; we don't trigger |
41 | * on 'impl' or 'fallback' calls. |
42 | */ |
class TORCH_API OpRegistrationListener {
public:
  virtual ~OpRegistrationListener();

  // Invoked when an operator schema is registered (a def() occurs).
  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
  // Invoked when the last def() registration for an operator is removed.
  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
};
50 | |
51 | namespace detail { |
52 | class RegistrationListenerList; |
53 | } |
54 | class SchemaRegistrationHandleRAII; |
55 | |
56 | /** |
57 | * Top-level dispatch interface for dispatching via the dynamic dispatcher. |
58 | * Most end users shouldn't use this directly; if you're trying to register |
59 | * ops look in op_registration |
60 | */ |
class TORCH_API Dispatcher final {
private:
  // For direct access to backend fallback information
  friend class impl::OperatorEntry;

  // Bookkeeping record for one operator. Lives in the operators_ std::list
  // below, so OperatorHandles can hold stable iterators/pointers to it.
  struct OperatorDef final {
    explicit OperatorDef(OperatorName&& op_name)
    : op(std::move(op_name)) {}

    impl::OperatorEntry op;

    // These refer to the number of outstanding RegistrationHandleRAII
    // for this operator. def_count reflects only def() registrations
    // (in the new world, this should only ever be 1, but old style
    // registrations may register the schema multiple times, which
    // will increase this count). def_and_impl_count reflects the number
    // of combined def() and impl() registrations. When the last def() gets
    // unregistered, we must immediately call the Deregistered listeners, but we
    // must not actually delete the handle as there are other outstanding RAII
    // destructors which will try to destruct and they had better still have a
    // working operator handle in this case
    size_t def_count = 0;
    size_t def_and_impl_count = 0;
  };
  friend class OperatorHandle;
  template<class> friend class TypedOperatorHandle;

public:
  ~Dispatcher();

  // Implementation note: this class abstracts over the fact that we have per-operator
  // dispatch tables. This could be easily adjusted to have a single global hash
  // table.
  // Out-of-line singleton accessor; the function-local static it contains is
  // defined in exactly one translation unit, avoiding duplicate singletons
  // across DSOs.
  static Dispatcher& realSingleton();

  C10_ALWAYS_INLINE static Dispatcher& singleton() {
#if !defined C10_MOBILE
    // Implemented inline so that steady-state code needn't incur
    // function-call overhead. We can't just inline `realSingleton`
    // because the function-local static would get duplicated across
    // all DSOs that include & use this header, leading to multiple
    // singleton instances.
    static Dispatcher& s = realSingleton();
    return s;
#else
    // For C10_MOBILE, we should never inline a static function that
    // has a static member, since the generated code calls
    // __cxa_guard_acquire and __cxa_guard_release which help
    // implement exactly once semantics for the initialization of the
    // static Dispatcher& s above (for the non-mobile case). That
    // additional code when duplicated across all operator stubs
    // for every backend results in a lot of additional code
    // being generated by the compiler.
    return realSingleton();
#endif
  }

  // ------------------------------------------------------------------------
  //
  // Accessing operators by schema
  //
  // ------------------------------------------------------------------------

  /**
   * Looks for an operator schema with the given name and overload name
   * and returns it if it is registered WITH A SCHEMA.
   * Returns nullopt otherwise.
   */
  c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);

  /**
   * Variant of findSchema that results in less code generated at the call site.
   * It (1) takes const char* pointer rather than OperatorName (so we skip
   * generating std::string constructor calls at the call site), and (2)
   * it raises an exception if the operator is not found (so we skip
   * generating exception raising code at the call site)
   *
   * Irritatingly, we still have to generate the handful of instructions
   * for dealing with an exception being thrown during static initialization
   * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
   * could avoid this code too, but as the name of the function suggests,
   * it does throw exceptions.
   */
  OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);

  // Like findSchema, but also returns OperatorHandle even if there is no schema
  c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);

  // Returns a list of all operator names present in the operatorLookupTable_
  const std::vector<OperatorName> getAllOpNames();

  // ------------------------------------------------------------------------
  //
  // Invoking operators
  //
  // ------------------------------------------------------------------------

  // Unboxed fast-path call: extracts the dispatch key set from args and
  // dispatches to the matching kernel. Defined inline below.
  template<class Return, class... Args>
  Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;


  // Slow path taken by call() when profiling/RecordFunction step callbacks
  // are active for an observed op; boxes inputs/outputs as the callbacks
  // request. Defined inline below.
  template<class Return, class... Args>
  static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);

  // Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculation.
  // This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
  // Note that this version of redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask out the highest priority key.
  // See Note [Plumbing Keys Through The Dispatcher]
  template<class Return, class... Args>
  Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const;

  // Invoke an operator via the boxed calling convention using an IValue stack
  void callBoxed(const OperatorHandle& op, Stack* stack) const;
  void callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const;

  // TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
  // See Note [Plumbing Keys Through The Dispatcher]
  void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;

  // True iff a valid backend fallback kernel has been registered for dk.
  bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
    auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
    if (dispatch_ix < 0) return false;
    return backendFallbackKernels_[dispatch_ix].kernel.isValid();
  }

  // Used by torchdeploy/multipy for multiple interpreters racing.
  void waitForDef(const FunctionSchema& schema);
  void waitForImpl(const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key);

  // ------------------------------------------------------------------------
  //
  // Performing registrations (NON user public; use op_registration)
  //
  // ------------------------------------------------------------------------

  /**
   * Register a new operator schema.
   *
   * If a schema with the same operator name and overload name already exists,
   * this function will check that both schemas are exactly identical.
   */
  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});

  /**
   * Register a kernel to the dispatch table for an operator.
   * If dispatch_key is nullopt, then this registers a fallback kernel.
   *
   * @return A RAII object that manages the lifetime of the registration.
   *         Once that object is destructed, the kernel will be deregistered.
   */
  // NB: steals the inferred function schema, as we may need to hold on to
  // it for a bit until the real schema turns up
  RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);

  /**
   * Register a new operator by name.
   */
  RegistrationHandleRAII registerName(OperatorName op_name);

  /**
   * Register a fallback kernel for a backend.
   * If an operator is called but there is no concrete kernel for the dispatch
   * key of the given operator arguments, it will check if there is such a
   * fallback kernel for the given dispatch key and, if yes, call that one.
   */
  RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);

  /**
   * Use to register whenever we had a TORCH_LIBRARY declaration in the frontend
   * API. These invocations are only permitted once per program, so we raise
   * an error if this is called again for the same namespace.
   */
  RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);

  // ------------------------------------------------------------------------
  //
  // Listeners on registrations
  //
  // ------------------------------------------------------------------------

  /**
   * Add a listener that gets called whenever a new op is registered or an existing
   * op is deregistered. Immediately after registering, this listener gets called
   * for all previously registered ops, so it can be used to keep track of ops
   * registered with this dispatcher.
   */
  RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);

  void checkInvariants() const;

  //
  // ------------------------------------------------------------------------
  //
  // Assertions
  //
  // ------------------------------------------------------------------------

  /**
   * For testing purposes.
   * Returns a list of all operators that were created through calls to registerImpl(),
   * without any corresponding calls to registerDef(). After static initialization
   * is done this is almost certainly a bug, as the created OperatorHandle won't have
   * any schema associated with it and users calling the op through the dispatcher
   * won't be able to access it
   *
   * Note that we cannot enforce this invariant "as we go" during static initialization,
   * due to undefined static initialization order- we have no guarantees over the order
   * in which .def() and .impl() calls are registered in the dispatcher at static
   * initialization time. So this function should only be called after static initialization.
   */
  std::vector<OperatorHandle> findDanglingImpls() const;

  /**
   * Useful for inspecting global Dispatcher registration state.
   * Returns the names of all operators with a kernel registered for the specified DispatchKey.
   * If no DispatchKey is specified, it returns all registered operators.
   */
  std::vector<OperatorName> getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const;

private:
  Dispatcher();

  // Helpers used by the profiling slow paths to fire RecordFunction with (and
  // without) boxed inputs; see callWithDispatchKeySlowPath / callBoxed.
  static int64_t sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, c10::ArrayRef<const c10::IValue> args);

  OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
  OperatorHandle findOrRegisterName_(const OperatorName& op_name);

  // Deregistration counterparts of the register* methods above; invoked by
  // the RegistrationHandleRAII objects those methods return.
  void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterImpl_(
    const OperatorHandle& op,
    const OperatorName& op_name,
    c10::optional<DispatchKey> dispatch_key,
    impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
  void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterFallback_(DispatchKey dispatchKey);
  void deregisterLibrary_(const std::string& ns);
  void cleanup(const OperatorHandle& op, const OperatorName& op_name);
  void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);

  // std::list so that insert/erase never invalidate the iterators that
  // OperatorHandle stores (see comment in OperatorHandle).
  std::list<OperatorDef> operators_;
#if !defined(C10_MOBILE)
  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#else
  RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#endif
  // Map from namespace to debug string (saying, e.g., where the library was defined)
  ska::flat_hash_map<std::string, std::string> libraries_;

  // Backend fallback kernels, indexed by getDispatchTableIndexForDispatchKey.
  std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

  std::unique_ptr<detail::RegistrationListenerList> listeners_;

  // This mutex protects concurrent access to the dispatcher
  std::mutex mutex_;

  // This condition variable gets notified whenever we add a new def/impl to the
  // dispatch table. This is primarily used by multipy/torchdeploy, when
  // we have multiple interpreters trying to register to the dispatch table.
  // In this situation, whenever the non-primary interpreter would have tried
  // to register to the dispatch table, instead it will check to see if the
  // expected registration has already been made, and if it hasn't, wait on
  // this condition variable to see if it was just racing with the primary
  // interpreter.
  //
  // We expect it to be rare for there to be any waiters on this condition
  // variable. This is mostly just to help give better diagnostics if
  // something goes horribly wrong
  std::condition_variable cond_var_;
};
332 | |
333 | /** |
334 | * This is a handle to an operator schema registered with the dispatcher. |
335 | * This handle can be used to register kernels with the dispatcher or |
336 | * to lookup a kernel for a certain set of arguments. |
337 | */ |
class TORCH_API OperatorHandle {
  // std::hash<OperatorHandle> (defined at the bottom of this file) needs
  // access to the private operatorDef_ pointer for identity hashing.
  template <typename T> friend struct std::hash;

public:
  OperatorHandle(OperatorHandle&&) noexcept = default;
  OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
  OperatorHandle(const OperatorHandle&) = default;
  OperatorHandle& operator=(const OperatorHandle&) = default;
  // NOLINTNEXTLINE(performance-trivially-destructible)
  ~OperatorHandle();

  const OperatorName& operator_name() const {
    return operatorDef_->op.operator_name();
  }

  // True iff a schema has been registered for this operator (a def()
  // occurred, not just impl()/registerName).
  bool hasSchema() const {
    return operatorDef_->op.hasSchema();
  }

  const FunctionSchema& schema() const {
    return operatorDef_->op.schema();
  }

  // Debug string recorded at registration time (e.g. where the op was defined).
  const std::string& debug() const {
    return operatorDef_->op.debug();
  }

  std::string dumpState() const {
    return operatorDef_->op.dumpState();
  }

  bool hasKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasKernelForDispatchKey(k);
  }

  bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
    return operatorDef_->op.hasKernelForAnyDispatchKey(k);
  }

  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
  }

  std::string dumpComputedTable() const {
    return operatorDef_->op.dumpComputedTable();
  }

  void checkInvariants() const {
    return operatorDef_->op.checkInvariants();
  }

  c10::ArrayRef<at::Tag> getTags() const {
    return operatorDef_->op.getTags();
  }

  // Linear scan over the operator's tags; tag lists are expected to be tiny.
  bool hasTag(const at::Tag& tag) const {
    for(const auto& tag_: getTags()) {
      if (tag == tag_) {
        return true;
      }
    }
    return false;
  }

  // Returns a handle typed on FuncType for unboxed calling; see
  // TypedOperatorHandle below.
  template<class FuncType>
  TypedOperatorHandle<FuncType> typed() const {
    // NB: This assert is not 100% sound: you can retrieve a typed() operator
    // handle prior to ANY C++ signature being registered on the operator
    // and the check will say everything is OK (at which point you can then
    // smuggle in a kernel that is typed incorrectly). For everything
    // in core library this won't happen, because all the static registrations
    // will be done by the time a typed() handle is acquired.
#if !defined C10_MOBILE
    operatorDef_->op.assertSignatureIsCorrect<FuncType>();
#endif
    return TypedOperatorHandle<FuncType>(operatorIterator_);
  }

  void callBoxed(Stack* stack) const {
    c10::Dispatcher::singleton().callBoxed(*this, stack);
  }

  void callBoxed(Stack& stack) const {
    callBoxed(&stack);
  }

  void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
    c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
  }

  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
  }

  template <typename F>
  PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
    return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
  }

  // Identity comparison: two handles are equal iff they refer to the same
  // OperatorDef record in the dispatcher.
  bool operator==(const OperatorHandle& other) const {
    return operatorDef_ == other.operatorDef_;
  }

  bool operator!=(const OperatorHandle& other) const {
    return operatorDef_ != other.operatorDef_;
  }

private:
  explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
  : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
  friend class Dispatcher;
  template<class> friend class TypedOperatorHandle;

  // Storing a direct pointer to the OperatorDef even though we
  // already have the iterator saves an instruction in the critical
  // dispatch path. The iterator is effectively a
  // pointer-to-std::list-node, and (at least in libstdc++'s
  // implementation) the element is at an offset 16 bytes from that,
  // because the prev/next pointers come first in the list node
  // struct. So, an add instruction would be necessary to convert from the
  // iterator to an OperatorDef*.
  Dispatcher::OperatorDef* operatorDef_;

  // We need to store this iterator in order to make
  // Dispatcher::cleanup() fast -- it runs a lot on program
  // termination (and presumably library unloading).
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};
466 | |
467 | /** |
468 | * This is a handle to an operator schema registered with the dispatcher. |
469 | * It holds the same information as an OperatorHandle, but it is templated |
470 | * on the operator arguments and allows calling the operator in an |
471 | * unboxed way. |
472 | */ |
// Primary template: only ever instantiated when FuncType is not a plain
// function type (Return(Args...)); its sole purpose is to produce a clear
// compile-time diagnostic via the static_assert. The real implementation is
// the partial specialization below.
template<class FuncType>
class TypedOperatorHandle final {
  static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type" );
};
// Specialization for genuine function types; adds unboxed call/redispatch on
// top of the untyped OperatorHandle interface.
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
  TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle(const TypedOperatorHandle&) = default;
  TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
    return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
  }

private:
  // Only constructible via OperatorHandle::typed().
  explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
  : OperatorHandle(operatorIterator) {}
  friend class OperatorHandle;
};
500 | |
501 | namespace detail { |
// No-op sink for a parameter pack; used to suppress a false-positive
// unused-parameter warning in gcc 5 (see Dispatcher::call / redispatch).
template <class... Args> inline void unused_arg_(const Args&...) {}
503 | |
// CaptureKernelCall is intended to capture return values from Dispatcher
// unboxed kernel calls. A record function may request to get outputs from the
// kernel calls. For boxed kernels, it's straightforward, the returned values
// are in the stack object. The stack can be passed to record functions. For
// unboxed kernels, we need to handle different kinds of return values, cache
// them temporarily, then release the values for the actual function call
// return.
template <typename ReturnType>
struct CaptureKernelCall {
  // Invokes `kernel` immediately and stores its result in output_.
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<ReturnType(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args)
      // Calls the kernel and capture the result in output_.
      : output_{kernel.template call<ReturnType, Args...>(
            op,
            dispatchKeySet,
            std::forward<Args>(args)...)} {}
  // Wraps the return values in a Stack (copies output_; the cached value is
  // still needed for release()).
  Stack getOutputs() {
    Stack stack;
    impl::push_outputs<ReturnType, false>::copy(output_, &stack);
    return stack;
  }
  // Since we are returning the output_, we don't expect the output_ to be used
  // afterward. Copy elision and RVO do not apply to class data members. Using
  // move semantic to avoid copies when possible.
  ReturnType release() && {
    return std::move(output_);
  }

 private:
  ReturnType output_;
};
540 | |
// Handle the lvalue reference differently since it should not be moved:
// for a Tensor& return, output_ is itself a reference, so we pass it
// straight through.
template <>
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
  return output_;
}
546 | |
// Handle case where the kernel returns void: there is nothing to cache,
// getOutputs() yields an empty stack and release() is a no-op.
template <>
struct CaptureKernelCall<void> {
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<void(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args) {
    // Calling the kernel and no need to capture void.
    kernel.template call<void, Args...>(
        op, dispatchKeySet, std::forward<Args>(args)...);
  }
  Stack getOutputs() {
    return Stack();
  }
  void release() && {}
};
565 | |
566 | } // namespace detail |
567 | |
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
//
// Profiling slow path for Dispatcher::call: runs the RecordFunction machinery
// (boxing inputs and/or capturing outputs only if the callbacks ask for them)
// and then invokes the kernel while the guard is alive.
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
  // If callbacks need inputs, we box the arguments and pass them to the guard.
  // Note: For perf reasons we wouldn't want to prematurely box the arguments.
  at::RecordFunction guard(std::move(stepCallbacks));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
  auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
  auto& schema = op.schema();
  auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
  if (guard.needsInputs()) {
    constexpr auto num_boxed_args = impl::boxed_size<Args...>();
    // If we used std::array<IValue, num_boxed_args> here, we would
    // have to spend time default constructing the IValues in
    // boxedArgs. aligned_storage has no such requirement.
    // Max to avoid zero-size array.
    std::aligned_storage_t<sizeof(IValue), alignof(IValue)> boxedArgs[std::max(num_boxed_args, static_cast<size_t>(1))];
    // For debugging only; could be removed (but the compiler will do
    // that for us and it's nice to have the extra assurance of
    // correctness from our debug builds).
    int lastArgIdx = 0;
    impl::boxArgsToStack(boxedArgs, lastArgIdx, args...);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(lastArgIdx == num_boxed_args);
    // I don't *think* we need std::launder here, because IValue has
    // no subclasses and no const or reference fields. (We also
    // couldn't use it even if we wanted to because we are currently
    // stuck on C++14 rather than C++17, but we could do a backport
    // similar to folly::launder if needed.)
    runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(reinterpret_cast<IValue *>(boxedArgs), num_boxed_args));
    // Manually destroy the placement-constructed IValues now that the
    // record function no longer needs them.
    for (size_t ii = 0; ii < num_boxed_args; ++ii) {
      reinterpret_cast<IValue *>(&boxedArgs[ii])->~IValue();
    }
  } else {
    runRecordFunction(guard, schema_ref, dispatchKey);
  }

  if (C10_UNLIKELY(guard.needsOutputs())) {
    // Calls the kernel and capture the output temporarily to pass to
    // RecordFunction.
    detail::CaptureKernelCall<Return> captureKernelCall(
        kernel, op, dispatchKeySet, std::forward<Args>(args)...);
    guard.setOutputs(captureKernelCall.getOutputs());
    // Releases the captured output to return to caller.
    return std::move(captureKernelCall).release();
  }

  // keeping the guard alive while executing the kernel
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
617 | |
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
//
// Unboxed fast-path entry point: compute the dispatch key set from the
// arguments, look up the kernel, and call it. Diverts to
// callWithDispatchKeySlowPath only when profiling callbacks are active for an
// observed op.
template<class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
    return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
641 | |
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
//
// Like call(), but uses the caller-supplied DispatchKeySet verbatim instead
// of extracting one from the arguments, and never runs RecordFunction.
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
658 | |
// Boxed entry point: dispatch key set is extracted from the IValue stack;
// outputs are left on the stack. Handles RecordFunction profiling inline
// (inputs/outputs are already boxed, so no slow-path helper is needed).
inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[callBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
    at::RecordFunction guard(std::move(*step_callbacks));
    auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
    auto& schema = op.schema();
    auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
    // At this point the stack still holds the inputs, so it can be handed to
    // the record function directly when inputs are requested.
    guard.needsInputs() ? runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
                        : runRecordFunction(guard, schema_ref, dispatchKey);

    // keeping the guard alive while executing the kernel
    kernel.callBoxed(op, dispatchKeySet, stack);

    // After the kernel ran, the stack holds the outputs.
    if (C10_UNLIKELY(guard.needsOutputs())) {
      guard.setOutputs(*stack);
    }
    return;
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
  kernel.callBoxed(op, dispatchKeySet, stack);
}
693 | |
// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
//
// Forces dispatch to the kernel registered for exactly `dk` (or the backend
// fallback for dk when the op has no kernel there), bypassing normal key
// resolution.
inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  // We still compute this as we're obligated to pass it on to the internal
  // kernel, if it is a boxed fallback
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
  // NOTE(review): the lambda's return type is deduced, so it returns a
  // KernelFunction by value; the const auto& binding extends the temporary's
  // lifetime, but the selected kernel is copied here. If
  // kernelForDispatchKey returns a reference, a `-> const KernelFunction&`
  // trailing return type would avoid the copy — verify before changing.
  const auto& kernel = ([&]() {
    if (op.hasKernelForDispatchKey(dk)) {
      return entry.kernelForDispatchKey(dk);
    } else {
      auto idx = getDispatchTableIndexForDispatchKey(dk);
      TORCH_INTERNAL_ASSERT(idx >= 0);
      return backendFallbackKernels_[idx].kernel;
    }
  })();
  kernel.callBoxed(op, dispatchKeySet, stack);
}
712 | |
// Boxed analogue of redispatch(): uses the caller-supplied DispatchKeySet
// verbatim and skips RecordFunction instrumentation.
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatchBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
  return kernel.callBoxed(op, dispatchKeySet, stack);
}
727 | |
728 | } // namespace c10 |
729 | |
730 | namespace std { |
731 | |
732 | template <> |
733 | struct hash<c10::OperatorHandle> { |
734 | size_t operator()(c10::OperatorHandle op) const noexcept { |
735 | return std::hash<void*>{}(static_cast<void*>(op.operatorDef_)); |
736 | } |
737 | }; |
738 | |
739 | } // namespace std |
740 | |