#pragma once

#include <ATen/SequenceNumber.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/LeftRight.h>
#include <list>
#include <mutex>
#include <condition_variable>
#include <type_traits>

#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>

namespace c10 {

TORCH_API bool show_dispatch_trace();
TORCH_API void dispatch_trace_nesting_incr();
TORCH_API void dispatch_trace_nesting_decr();
TORCH_API int64_t dispatch_trace_nesting_value();

struct DispatchTraceNestingGuard {
  DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); }
  ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); }
};

class TORCH_API OperatorHandle;
template<class FuncType> class TypedOperatorHandle;

/**
 * Implement this interface and register your instance with the dispatcher
 * to get notified when operators are registered or deregistered with
 * the dispatcher.
 *
 * NB: registration events only occur when a 'def' occurs; we don't trigger
 * on 'impl' or 'fallback' calls.
 */
class TORCH_API OpRegistrationListener {
public:
  virtual ~OpRegistrationListener();

  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
};
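
// Example (illustrative sketch, not part of the API): a listener that logs
// every operator def as it is registered. The name LoggingListener is
// hypothetical; registration goes through Dispatcher::addRegistrationListener,
// declared further below.
//
//   struct LoggingListener final : public c10::OpRegistrationListener {
//     void onOperatorRegistered(const c10::OperatorHandle& op) override {
//       std::cout << "registered: " << op.operator_name() << std::endl;
//     }
//     void onOperatorDeregistered(const c10::OperatorHandle& op) override {
//       std::cout << "deregistered: " << op.operator_name() << std::endl;
//     }
//   };
//
//   // RAII handle; the listener is removed when `handle` is destroyed.
//   auto handle = c10::Dispatcher::singleton().addRegistrationListener(
//       std::make_unique<LoggingListener>());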

namespace detail {
class RegistrationListenerList;
}
class SchemaRegistrationHandleRAII;

/**
 * Top-level dispatch interface for dispatching via the dynamic dispatcher.
 * Most end users shouldn't use this directly; if you're trying to register
 * ops, look in op_registration.
 */
class TORCH_API Dispatcher final {
private:
  // For direct access to backend fallback information
  friend class impl::OperatorEntry;

  struct OperatorDef final {
    explicit OperatorDef(OperatorName&& op_name)
      : op(std::move(op_name)) {}

    impl::OperatorEntry op;

    // These refer to the number of outstanding RegistrationHandleRAII
    // for this operator. def_count reflects only def() registrations
    // (in the new world, this should only ever be 1, but old style
    // registrations may register the schema multiple times, which
    // will increase this count). def_and_impl_count reflects the number
    // of combined def() and impl() registrations. When the last def() gets
    // unregistered, we must immediately call the Deregistered listeners, but we
    // must not actually delete the handle as there are other outstanding RAII
    // destructors which will try to destruct and they had better still have a
    // working operator handle in this case.
    size_t def_count = 0;
    size_t def_and_impl_count = 0;
  };
  friend class OperatorHandle;
  template<class> friend class TypedOperatorHandle;

public:
  ~Dispatcher();

  // Implementation note: this class abstracts over the fact that we have per-operator
  // dispatch tables. This could be easily adjusted to have a single global hash
  // table.
  static Dispatcher& realSingleton();

  C10_ALWAYS_INLINE static Dispatcher& singleton() {
#if !defined C10_MOBILE
    // Implemented inline so that steady-state code needn't incur
    // function-call overhead. We can't just inline `realSingleton`
    // because the function-local static would get duplicated across
    // all DSOs that include & use this header, leading to multiple
    // singleton instances.
    static Dispatcher& s = realSingleton();
    return s;
#else
    // For C10_MOBILE, we should never inline a static function that
    // has a static member, since the generated code calls
    // __cxa_guard_acquire and __cxa_guard_release which help
    // implement exactly once semantics for the initialization of the
    // static Dispatcher& s above (for the non-mobile case). That
    // additional code when duplicated across all operator stubs
    // for every backend results in a lot of additional code
    // being generated by the compiler.
    return realSingleton();
#endif
  }

  // ------------------------------------------------------------------------
  //
  // Accessing operators by schema
  //
  // ------------------------------------------------------------------------

  /**
   * Looks for an operator schema with the given name and overload name
   * and returns it if it is registered WITH A SCHEMA.
   * Returns nullopt otherwise.
   */
  c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);

  /**
   * Variant of findSchema that results in less code generated at the call site.
   * It (1) takes const char* pointer rather than OperatorName (so we skip
   * generating std::string constructor calls at the call site), and (2)
   * it raises an exception if the operator is not found (so we skip
   * generating exception raising code at the call site).
   *
   * Irritatingly, we still have to generate the handful of instructions
   * for dealing with an exception being thrown during static initialization
   * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
   * could avoid this code too, but as the name of the function suggests,
   * it does throw exceptions.
   */
  OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
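
  // Example (illustrative sketch): looking up an operator by name. The schema
  // name "aten::mul" with overload "Tensor" is used purely for illustration.
  //
  //   // Throws if the operator (or its schema) has not been registered yet.
  //   c10::OperatorHandle op =
  //       c10::Dispatcher::singleton().findSchemaOrThrow("aten::mul", "Tensor");
  //
  //   // Non-throwing variant; returns nullopt if no schema is registered.
  //   c10::optional<c10::OperatorHandle> maybe_op =
  //       c10::Dispatcher::singleton().findSchema({"aten::mul", "Tensor"});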

  // Like findSchema, but also returns OperatorHandle even if there is no schema
  c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);

  // Returns a list of all operator names present in the operatorLookupTable_
  const std::vector<OperatorName> getAllOpNames();

  // ------------------------------------------------------------------------
  //
  // Invoking operators
  //
  // ------------------------------------------------------------------------

  template<class Return, class... Args>
  Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;


  template<class Return, class... Args>
  static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);

  // Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculation.
  // This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
  // Note that this version of redispatch treats the given DispatchKeySet *as is*, and does NOT mask out the highest priority key.
  // See Note [Plumbing Keys Through The Dispatcher]
  template<class Return, class... Args>
  Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const;

  // Invoke an operator via the boxed calling convention using an IValue stack
  void callBoxed(const OperatorHandle& op, Stack* stack) const;
  void callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const;

  // TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
  // See Note [Plumbing Keys Through The Dispatcher]
  void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;

  bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
    auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
    if (dispatch_ix < 0) return false;
    return backendFallbackKernels_[dispatch_ix].kernel.isValid();
  }
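
  // Example (illustrative sketch): invoking an operator through the boxed API.
  // Assumes a handle `op` to a binary Tensor op (e.g. obtained via
  // findSchemaOrThrow above); arguments are pushed onto an IValue stack in
  // declaration order, and outputs replace them on return.
  //
  //   torch::jit::Stack stack;
  //   torch::jit::push(stack, self_tensor, other_tensor);
  //   c10::Dispatcher::singleton().callBoxed(op, &stack);
  //   at::Tensor result = torch::jit::pop(stack).toTensor();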

  // Used by torchdeploy/multipy for multiple interpreters racing.
  void waitForDef(const FunctionSchema& schema);
  void waitForImpl(const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key);

  // ------------------------------------------------------------------------
  //
  // Performing registrations (NON user public; use op_registration)
  //
  // ------------------------------------------------------------------------

  /**
   * Register a new operator schema.
   *
   * If a schema with the same operator name and overload name already exists,
   * this function will check that both schemas are exactly identical.
   */
  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});

  /**
   * Register a kernel to the dispatch table for an operator.
   * If dispatch_key is nullopt, then this registers a fallback kernel.
   *
   * @return A RAII object that manages the lifetime of the registration.
   *         Once that object is destructed, the kernel will be deregistered.
   */
  // NB: steals the inferred function schema, as we may need to hold on to
  // it for a bit until the real schema turns up
  RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);

  /**
   * Register a new operator by name.
   */
  RegistrationHandleRAII registerName(OperatorName op_name);

  /**
   * Register a fallback kernel for a backend.
   * If an operator is called but there is no concrete kernel for the dispatch
   * key of the given operator arguments, it will check if there is such a
   * fallback kernel for the given dispatch key and, if yes, call that one.
   */
  RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);

  /**
   * Used to register whenever there is a TORCH_LIBRARY declaration in the
   * frontend API. Only one such registration is permitted per namespace in a
   * program, so we raise an error if this is called again for the same
   * namespace.
   */
  RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
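
  // Example (illustrative sketch): the registration functions above are not
  // meant to be called directly; the usual entry points are the TORCH_LIBRARY
  // macros, which funnel into registerLibrary/registerDef/registerImpl. The
  // namespace "myops", the schema, and myadd_cpu are all hypothetical.
  //
  //   TORCH_LIBRARY(myops, m) {
  //     m.def("myadd(Tensor self, Tensor other) -> Tensor");
  //   }
  //
  //   TORCH_LIBRARY_IMPL(myops, CPU, m) {
  //     m.impl("myadd", &myadd_cpu);
  //   }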

  // ------------------------------------------------------------------------
  //
  // Listeners on registrations
  //
  // ------------------------------------------------------------------------

  /**
   * Add a listener that gets called whenever a new op is registered or an existing
   * op is deregistered. Immediately after registering, this listener gets called
   * for all previously registered ops, so it can be used to keep track of ops
   * registered with this dispatcher.
   */
  RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);

  void checkInvariants() const;

  // ------------------------------------------------------------------------
  //
  // Assertions
  //
  // ------------------------------------------------------------------------

  /**
   * For testing purposes.
   * Returns a list of all operators that were created through calls to registerImpl(),
   * without any corresponding calls to registerDef(). After static initialization
   * is done this is almost certainly a bug, as the created OperatorHandle won't have
   * any schema associated with it and users calling the op through the dispatcher
   * won't be able to access it.
   *
   * Note that we cannot enforce this invariant "as we go" during static initialization,
   * due to undefined static initialization order: we have no guarantees over the order
   * in which .def() and .impl() calls are registered in the dispatcher at static
   * initialization time. So this function should only be called after static initialization.
   */
  std::vector<OperatorHandle> findDanglingImpls() const;
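
  // Example (illustrative sketch): a test-time check that no impl() was left
  // without a matching def() once static initialization has completed.
  //
  //   auto dangling = c10::Dispatcher::singleton().findDanglingImpls();
  //   TORCH_CHECK(
  //       dangling.empty(),
  //       "found ", dangling.size(), " impls without a corresponding def()");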

  /**
   * Useful for inspecting global Dispatcher registration state.
   * Returns the names of all operators with a kernel registered for the specified DispatchKey.
   * If no DispatchKey is specified, it returns all registered operators.
   */
  std::vector<OperatorName> getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const;
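
  // Example (illustrative sketch): listing all operators that have a CUDA
  // kernel registered.
  //
  //   auto cuda_ops = c10::Dispatcher::singleton()
  //       .getRegistrationsForDispatchKey(c10::DispatchKey::CUDA);
  //   for (const auto& name : cuda_ops) {
  //     std::cout << name << std::endl;
  //   }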

private:
  Dispatcher();

  static int64_t sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, c10::ArrayRef<const c10::IValue> args);

  OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
  OperatorHandle findOrRegisterName_(const OperatorName& op_name);

  void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterImpl_(
      const OperatorHandle& op,
      const OperatorName& op_name,
      c10::optional<DispatchKey> dispatch_key,
      impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
  void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterFallback_(DispatchKey dispatchKey);
  void deregisterLibrary_(const std::string& ns);
  void cleanup(const OperatorHandle& op, const OperatorName& op_name);
  void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);

  std::list<OperatorDef> operators_;
#if !defined(C10_MOBILE)
  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#else
  RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#endif
  // Map from namespace to debug string (saying, e.g., where the library was defined)
  ska::flat_hash_map<std::string, std::string> libraries_;

  std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

  std::unique_ptr<detail::RegistrationListenerList> listeners_;

  // This mutex protects concurrent access to the dispatcher
  std::mutex mutex_;

  // This condition variable gets notified whenever we add a new def/impl to the
  // dispatch table. This is primarily used by multipy/torchdeploy, when
  // we have multiple interpreters trying to register to the dispatch table.
  // In this situation, whenever the non-primary interpreter would have tried
  // to register to the dispatch table, instead it will check to see if the
  // expected registration has already been made, and if it hasn't, wait on
  // this condition variable to see if it was just racing with the primary
  // interpreter.
  //
  // We expect it to be rare for there to be any waiters on this condition
  // variable. This is mostly just to help give better diagnostics if
  // something goes horribly wrong.
  std::condition_variable cond_var_;
};

/**
 * This is a handle to an operator schema registered with the dispatcher.
 * This handle can be used to register kernels with the dispatcher or
 * to look up a kernel for a certain set of arguments.
 */
class TORCH_API OperatorHandle {
  template <typename T> friend struct std::hash;

public:
  OperatorHandle(OperatorHandle&&) noexcept = default;
  OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
  OperatorHandle(const OperatorHandle&) = default;
  OperatorHandle& operator=(const OperatorHandle&) = default;
  // NOLINTNEXTLINE(performance-trivially-destructible)
  ~OperatorHandle();

  const OperatorName& operator_name() const {
    return operatorDef_->op.operator_name();
  }

  bool hasSchema() const {
    return operatorDef_->op.hasSchema();
  }

  const FunctionSchema& schema() const {
    return operatorDef_->op.schema();
  }

  const std::string& debug() const {
    return operatorDef_->op.debug();
  }

  std::string dumpState() const {
    return operatorDef_->op.dumpState();
  }

  bool hasKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasKernelForDispatchKey(k);
  }

  bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
    return operatorDef_->op.hasKernelForAnyDispatchKey(k);
  }

  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
  }

  std::string dumpComputedTable() const {
    return operatorDef_->op.dumpComputedTable();
  }

  void checkInvariants() const {
    return operatorDef_->op.checkInvariants();
  }

  c10::ArrayRef<at::Tag> getTags() const {
    return operatorDef_->op.getTags();
  }

  bool hasTag(const at::Tag& tag) const {
    for (const auto& tag_ : getTags()) {
      if (tag == tag_) {
        return true;
      }
    }
    return false;
  }
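
  // Example (illustrative sketch): querying operator tags on a handle. The
  // specific tag value shown (at::Tag::pointwise) is only an assumption about
  // which tags the generated enum contains in a given build.
  //
  //   if (op.hasTag(at::Tag::pointwise)) {
  //     // op was declared with a pointwise tag in its schema definition
  //   }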

  template<class FuncType>
  TypedOperatorHandle<FuncType> typed() const {
    // NB: This assert is not 100% sound: you can retrieve a typed() operator
    // handle prior to ANY C++ signature being registered on the operator
    // and the check will say everything is OK (at which point you can then
    // smuggle in a kernel that is typed incorrectly). For everything
    // in core library this won't happen, because all the static registrations
    // will be done by the time a typed() handle is acquired.
#if !defined C10_MOBILE
    operatorDef_->op.assertSignatureIsCorrect<FuncType>();
#endif
    return TypedOperatorHandle<FuncType>(operatorIterator_);
  }
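
  // Example (illustrative sketch): obtaining a typed handle and calling the
  // operator unboxed. The signature must match the registered schema exactly;
  // "aten::mul.Tensor" and the tensors `self`/`other` are placeholders used
  // purely for illustration.
  //
  //   static auto op = c10::Dispatcher::singleton()
  //       .findSchemaOrThrow("aten::mul", "Tensor")
  //       .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
  //   at::Tensor out = op.call(self, other);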

  void callBoxed(Stack* stack) const {
    c10::Dispatcher::singleton().callBoxed(*this, stack);
  }

  void callBoxed(Stack& stack) const {
    callBoxed(&stack);
  }

  void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
    c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
  }

  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
  }

  template <typename F>
  PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
    return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
  }

  bool operator==(const OperatorHandle& other) const {
    return operatorDef_ == other.operatorDef_;
  }

  bool operator!=(const OperatorHandle& other) const {
    return operatorDef_ != other.operatorDef_;
  }

private:
  explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
    : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
  friend class Dispatcher;
  template<class> friend class TypedOperatorHandle;

  // Storing a direct pointer to the OperatorDef even though we
  // already have the iterator saves an instruction in the critical
  // dispatch path. The iterator is effectively a
  // pointer-to-std::list-node, and (at least in libstdc++'s
  // implementation) the element is at an offset 16 bytes from that,
  // because the prev/next pointers come first in the list node
  // struct. So, an add instruction would be necessary to convert from the
  // iterator to an OperatorDef*.
  Dispatcher::OperatorDef* operatorDef_;

  // We need to store this iterator in order to make
  // Dispatcher::cleanup() fast -- it runs a lot on program
  // termination (and presumably library unloading).
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

/**
 * This is a handle to an operator schema registered with the dispatcher.
 * It holds the same information as an OperatorHandle, but it is templated
 * on the operator arguments and allows calling the operator in an
 * unboxed way.
 */
template<class FuncType>
class TypedOperatorHandle final {
  static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
};
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
  TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle(const TypedOperatorHandle&) = default;
  TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
    return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
  }

private:
  explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
    : OperatorHandle(operatorIterator) {}
  friend class OperatorHandle;
};

namespace detail {
template <class... Args> inline void unused_arg_(const Args&...) {}

// CaptureKernelCall is intended to capture return values from Dispatcher
// unboxed kernel calls. A record function may request to get outputs from the
// kernel calls. For boxed kernels this is straightforward: the returned values
// are on the stack object, and the stack can be passed to record functions.
// For unboxed kernels, we need to handle different kinds of return values,
// cache them temporarily, then release the values for the actual function
// call return.
template <typename ReturnType>
struct CaptureKernelCall {
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<ReturnType(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args)
      // Calls the kernel and captures the result in output_.
      : output_{kernel.template call<ReturnType, Args...>(
            op,
            dispatchKeySet,
            std::forward<Args>(args)...)} {}
  // Wraps the return values in a Stack.
  Stack getOutputs() {
    Stack stack;
    impl::push_outputs<ReturnType, false>::copy(output_, &stack);
    return stack;
  }
  // Since we are returning output_, we don't expect output_ to be used
  // afterward. Copy elision and RVO do not apply to class data members,
  // so we use move semantics to avoid copies where possible.
  ReturnType release() && {
    return std::move(output_);
  }

 private:
  ReturnType output_;
};

// Handle the lvalue reference differently since it should not be moved.
template <>
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
  return output_;
}

// Handle the case where the kernel returns void.
template <>
struct CaptureKernelCall<void> {
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<void(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args) {
    // Call the kernel; there is no return value to capture.
    kernel.template call<void, Args...>(
        op, dispatchKeySet, std::forward<Args>(args)...);
  }
  Stack getOutputs() {
    return Stack();
  }
  void release() && {}
};

} // namespace detail

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
  // If callbacks need inputs, we box the arguments and pass them to the guard.
  // Note: for perf reasons we wouldn't want to prematurely box the arguments.
  at::RecordFunction guard(std::move(stepCallbacks));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
  auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
  auto& schema = op.schema();
  auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
  if (guard.needsInputs()) {
    constexpr auto num_boxed_args = impl::boxed_size<Args...>();
    // If we used std::array<IValue, num_boxed_args> here, we would
    // have to spend time default constructing the IValues in
    // boxedArgs. aligned_storage has no such requirement.
    // Max to avoid a zero-size array.
    std::aligned_storage_t<sizeof(IValue), alignof(IValue)> boxedArgs[std::max(num_boxed_args, static_cast<size_t>(1))];
    // For debugging only; could be removed (but the compiler will do
    // that for us and it's nice to have the extra assurance of
    // correctness from our debug builds).
    int lastArgIdx = 0;
    impl::boxArgsToStack(boxedArgs, lastArgIdx, args...);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(lastArgIdx == num_boxed_args);
    // I don't *think* we need std::launder here, because IValue has
    // no subclasses and no const or reference fields. (We also
    // couldn't use it even if we wanted to because we are currently
    // stuck on C++14 rather than C++17, but we could do a backport
    // similar to folly::launder if needed.)
    runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(reinterpret_cast<IValue*>(boxedArgs), num_boxed_args));
    for (size_t ii = 0; ii < num_boxed_args; ++ii) {
      reinterpret_cast<IValue*>(&boxedArgs[ii])->~IValue();
    }
  } else {
    runRecordFunction(guard, schema_ref, dispatchKey);
  }

  if (C10_UNLIKELY(guard.needsOutputs())) {
    // Calls the kernel and captures the output temporarily to pass to
    // RecordFunction.
    detail::CaptureKernelCall<Return> captureKernelCall(
        kernel, op, dispatchKeySet, std::forward<Args>(args)...);
    guard.setOutputs(captureKernelCall.getOutputs());
    // Releases the captured output to return to the caller.
    return std::move(captureKernelCall).release();
  }

  // keeping the guard alive while executing the kernel
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
    return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
  }
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[callBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
    at::RecordFunction guard(std::move(*step_callbacks));
    auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
    auto& schema = op.schema();
    auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
    guard.needsInputs() ? runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
                        : runRecordFunction(guard, schema_ref, dispatchKey);

    // keeping the guard alive while executing the kernel
    kernel.callBoxed(op, dispatchKeySet, stack);

    if (C10_UNLIKELY(guard.needsOutputs())) {
      guard.setOutputs(*stack);
    }
    return;
  }
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
  kernel.callBoxed(op, dispatchKeySet, stack);
}

// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  // We still compute this as we're obligated to pass it on to the internal
  // kernel, if it is a boxed fallback
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
  const auto& kernel = ([&]() {
    if (op.hasKernelForDispatchKey(dk)) {
      return entry.kernelForDispatchKey(dk);
    } else {
      auto idx = getDispatchTableIndexForDispatchKey(dk);
      TORCH_INTERNAL_ASSERT(idx >= 0);
      return backendFallbackKernels_[idx].kernel;
    }
  })();
  kernel.callBoxed(op, dispatchKeySet, stack);
}

inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatchBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
  return kernel.callBoxed(op, dispatchKeySet, stack);
}

} // namespace c10

namespace std {

template <>
struct hash<c10::OperatorHandle> {
  size_t operator()(c10::OperatorHandle op) const noexcept {
    return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
  }
};

} // namespace std