1#pragma once
2
3#include <ATen/core/ivalue.h>
4#include <ATen/core/operator_name.h>
5#include <c10/macros/Export.h>
6#include <c10/util/Optional.h>
7#include <c10/util/SmallVector.h>
8#include <c10/util/variant.h>
9
#include <array>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <vector>
14
namespace c10 {
// Forward declaration only. OperatorHandle is not referenced in the visible
// part of this header. NOTE(review): presumably kept for code that includes
// this header; confirm before removing.
class TORCH_API OperatorHandle;
} // namespace c10
18
19namespace at {
20
21// Kind of record function scope;
22enum class C10_API_ENUM RecordScope : uint8_t {
23 // c10/ATen ops, autograd nodes
24 FUNCTION = 0,
25 // Functions/nodes called from the autograd
26 BACKWARD_FUNCTION,
27 // TorchScript functions, methods
28 TORCHSCRIPT_FUNCTION,
29 // Kernel Function dtype Tag
30 KERNEL_FUNCTION_DTYPE,
31 // Torchbind custom class,
32 CUSTOM_CLASS,
33 // Generic Build Feature
34 BUILD_FEATURE,
35 // Kernel Function dtype Tag
36 LITE_INTERPRETER,
37 // User defined scope (e.g. with record_function())
38 USER_SCOPE,
39 // Scopes for static runtime, a specialized TorchScript interpreter
40 STATIC_RUNTIME_OP,
41 STATIC_RUNTIME_MODEL,
42 NUM_SCOPES, // must be the last in the list
43};
44
45} // namespace at
46
47namespace std {
48template <>
49struct hash<at::RecordScope> {
50 size_t operator()(const at::RecordScope& sc) const {
51 return static_cast<std::size_t>(sc);
52 }
53};
54} // namespace std
55
56namespace at {
57
58struct TORCH_API StringView {
59 StringView() : StringView(nullptr) {}
60 explicit StringView(const char* str_ptr)
61 : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
62 explicit StringView(std::string str)
63 : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
64 str_ptr_(owned_str_ptr_->c_str()) {}
65
66 const char* str() const {
67 return str_ptr_;
68 }
69
70 friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
71 os << dt.str();
72 return os;
73 }
74
75 friend bool operator==(const StringView& lhs, const StringView& rhs) {
76 return strcmp(lhs.str(), rhs.str()) == 0;
77 }
78
79 friend bool operator!=(const StringView& lhs, const StringView& rhs) {
80 return !(lhs == rhs);
81 }
82
83 private:
84 std::shared_ptr<std::string> owned_str_ptr_;
85 const char* str_ptr_;
86};
87
// Inline capacity of the c10::SmallVector containers that hold per-call
// callback data (CallbackHandles, ObserverContextList, StepCallbacks).
// It is a soft limit: presumably more callbacks are still allowed and simply
// spill out of the inline storage.
constexpr std::size_t kSoftLimitCallbacks{4};
90
// Base class for per-callback state. A start callback may return a subclass
// instance, which RecordFunction keeps (see ObserverContextList) and passes
// to the matching end callback. Only subclasses can be instantiated, because
// the constructor is protected. The destructor is virtual so subclasses are
// destroyed correctly through a std::unique_ptr<ObserverContext>.
struct ObserverContext {
  virtual ~ObserverContext() = default;

 protected:
  ObserverContext() = default;
};
99
100typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
101typedef c10::SmallVector<std::unique_ptr<ObserverContext>, kSoftLimitCallbacks>
102 ObserverContextList;
103typedef uint64_t RecordFunctionHandle;
104struct RecordFunction;
105
106//
107// PyTorch callbacks/observers API:
108//
109
110/**
111 * RecordFunctionCallback represents a pair of callbacks to be used with
112 * RecordFunction, members:
113 * start, end - the callbacks to run when entering and exiting the scope;
114 * optionally, the start callback may return an ObserverContext which will
115 * be passed to the end callback, use appropriate constructor accordingly.
116 * needs_inputs - whether the callbacks need the inputs passed from the
117 * observed function/range; NOTE: passing the inputs incurs an additional
118 * overhead; sampling_probability - if not 1.0, then the callback is
119 * probabilistically sampled to run; NOTE: start and end callbacks always run as
120 * a pair and are sampled together; scopes - types of scopes to execute the
121 * callbacks on (see RecordScope); passing empty set means the callbacks will be
122 * executed for all possible scope types should_run - optional function that
123 * returns whether this callback should run; overwrites the effect of setting
124 * sampling_probability
125 */
126class TORCH_API RecordFunctionCallback {
127 public:
128 using StartCallback =
129 std::unique_ptr<ObserverContext> (*)(const RecordFunction&);
130 using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
131
132 // This interface supports observers that require passing an ObserverContext
133 // between start and end callbacks.
134 explicit RecordFunctionCallback(
135 StartCallback start,
136 EndCallback end = nullptr)
137 : start_(start), end_(end) {
138 scopes_.fill(true);
139 }
140
141 RecordFunctionCallback& needsInputs(bool needs_inputs) {
142 needs_inputs_ = needs_inputs;
143 return *this;
144 }
145
146 RecordFunctionCallback& needsOutputs(bool needs_outputs) {
147 needs_outputs_ = needs_outputs;
148 return *this;
149 }
150
151 RecordFunctionCallback& needsIds(bool needs_ids) {
152 needs_ids_ = needs_ids;
153 return *this;
154 }
155
156 RecordFunctionCallback& samplingProb(double sampling_prob) {
157 TORCH_CHECK(
158 sampling_prob >= 0.0 && sampling_prob <= 1.0,
159 "Invalid sampling probability");
160 sampling_prob_ = sampling_prob;
161 return *this;
162 }
163
164 RecordFunctionCallback& scopes(
165 const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
166 if (!scopes.empty()) {
167 scopes_.fill(false);
168 for (auto sc : scopes) {
169 scopes_[static_cast<size_t>(sc)] = true;
170 }
171 } else {
172 scopes_.fill(true);
173 }
174 return *this;
175 }
176
177 bool needsInputs() const {
178 return needs_inputs_;
179 }
180
181 bool needsOutputs() const {
182 return needs_outputs_;
183 }
184
185 bool needsIds() const {
186 return needs_ids_;
187 }
188
189 double samplingProb() const {
190 return sampling_prob_;
191 }
192
193 bool checkScope(RecordScope sc) const {
194 return scopes_[(size_t)sc];
195 }
196
197 StartCallback start() const {
198 return start_;
199 }
200
201 EndCallback end() const {
202 return end_;
203 }
204
205 private:
206 StartCallback start_;
207 EndCallback end_;
208 double sampling_prob_ = 1.0;
209 std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
210 bool needs_inputs_ = false;
211 bool needs_outputs_ = false;
212 bool needs_ids_ = false;
213};
214
215// Notes:
216// - two types of callbacks are provided: thread local and global
217// - thread local callbacks are added/removed only for the given thread
218// and are stored locally for each thread and separately from the list
219// of the global callbacks
220// - global callbacks are stored in a single per process list and are
221// invoked by every RecordFunction, in addition to the thread local
222// callbacks specific to the given thread
223// - we allow the added callbacks to be sampled, by specifying a sampling
224// probability for each callback pair, if the start callback is
225// not picked to run, the corresponding end callback won't be called
226// - a typical use case for the global callbacks is passive monitoring
227// in the background (e.g. fleet-wide monitoring), without focusing on
228// the specific piece of code
229// - in contrast, thread local callbacks are enabled locally, on demand,
230// for the specific piece of code (range) and are not sampled
231// - a typical use case for thread local callbacks is profiler and code
232// execution tracer
233// - note, thread local callbacks are automatically propagated with
234// ThreadLocalState across JIT continuations and async tasks (at::launch)
235
// Opaque identifier of a registered callback, as returned by
// addThreadLocalCallback() / addGlobalCallback().
using CallbackHandle = uint64_t;

// Sentinel handle value denoting "no callback".
constexpr CallbackHandle INVALID_CALLBACK_HANDLE = 0;
239
240// It is unnecessary to use atomic operations for enabling
241// thread-local function callbacks. Moreover, it prevents saving to
242// ThreadLocalState because std::atomic is non-copyable.
243struct RecordFunctionCallbacksEntry {
244 RecordFunctionCallbacksEntry(RecordFunctionCallback&& cb, CallbackHandle h)
245 : callback_(cb), handle_(h) {}
246
247 RecordFunctionCallback callback_;
248 bool enabled_{true};
249 CallbackHandle handle_;
250};
251
252// Holds pairs (callbacks, unique_id)
253using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;
254
255// Generated by the callback managers to determine which functions to run.
256struct StepCallbacks {
257 StepCallbacks() = default;
258 StepCallbacks(uint64_t thread_id, RecordScope scope)
259 : thread_id_{thread_id}, scope_{scope} {}
260
261 bool empty() const {
262 return callbacks_.empty();
263 }
264
265 struct StartEndPair {
266 RecordFunctionCallback::StartCallback start_;
267 RecordFunctionCallback::EndCallback end_;
268 };
269
270 using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>;
271
272 StartEndPairs callbacks_;
273 uint64_t thread_id_{0};
274 RecordScope scope_{RecordScope::FUNCTION};
275 bool needs_inputs_{false};
276 bool needs_outputs_{false};
277 bool needs_ids_{false};
278};
279
280struct TORCH_API RecordFunction {
281 // Default constructor is used with before function called afterwards:
282 // scope - record scope that this function tracks
283 // pre_sampled - whether this RecordFunction was already pre-sampled with
284 // kLowProb probability
285 explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION);
286 explicit RecordFunction(StepCallbacks&& step_callbacks);
287
288 template <typename F>
289 void before(
290 F fn,
291 c10::ArrayRef<const c10::IValue> args,
292 int64_t current_sequence_nr = -1) {
293 if (!isActive()) {
294 return;
295 }
296 inputs_ = args;
297 before(fn, current_sequence_nr);
298 }
299
300 template <typename F>
301 void before(
302 F fn,
303 const std::vector<IValue>* args,
304 int64_t current_sequence_nr = -1) {
305 before(
306 std::move(fn),
307 c10::ArrayRef<const c10::IValue>(args->data(), args->size()),
308 current_sequence_nr);
309 }
310
311 // Destructor calls end callbacks
312 virtual ~RecordFunction();
313
314 RecordFunction(const RecordFunction&) = delete;
315 RecordFunction& operator=(const RecordFunction&) = delete;
316
317 const char* name() const;
318
319 int64_t seqNr() const {
320 return sequence_nr_;
321 }
322
323 c10::ArrayRef<const IValue> inputs() const {
324#ifndef NDEBUG
325 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
326 inputs_valid_, "Called inputs() outside RecordFunction start callback");
327#endif
328 return inputs_;
329 }
330
331 const std::vector<c10::IValue>& outputs() const {
332 return outputs_;
333 }
334
335 void setOutputs(std::vector<c10::IValue>&& outputs) {
336 outputs_ = std::move(outputs);
337 }
338
339 void setOutputs(c10::ArrayRef<c10::IValue> outputs) {
340 outputs_ = outputs.vec();
341 }
342
343 size_t num_inputs() const;
344 size_t num_outputs() const;
345
346 // Retrieves the thread_id that this RecordFunction ran start callbacks with.
347 // Useful for writing thread safe end callbacks that may be potentially
348 // executed in a different thread (async ops)
349 uint64_t threadId() const {
350 return step_callbacks_.thread_id_;
351 }
352
353 // For backward functions - thread id of the corresponding forward function,
354 // or zero otherwise;
355 // used alongside with sequence number to correlate backward functions with
356 // the forward ones
357 uint64_t forwardThreadId() const {
358 return fwd_thread_id_;
359 }
360
361 void setForwardThreadId(uint64_t thread_id) {
362 fwd_thread_id_ = thread_id;
363 }
364
365 RecordScope scope() const {
366 return step_callbacks_.scope_;
367 }
368
369 // Returns logical thread_id for the current thread
370 static uint64_t currentThreadId();
371
372 // Internal functions, do not use directly;
373 // used in python's context manager
374
375 // before functions initialize RecordFunction members and call
376 // start callbacks
377 using schema_ref_t = std::reference_wrapper<const c10::FunctionSchema>;
378 void before(const char* name, int64_t sequence_nr = -1);
379 void before(std::string name, int64_t sequence_nr = -1);
380 void before(schema_ref_t schema, int64_t sequence_nr = -1);
381
382 // Sets node ID for distributed profiling
383 static void setDefaultNodeId(int64_t defaultNodeId);
384 // Gets node ID for distributed profiling
385 static int64_t getDefaultNodeId();
386
387 // Calls end callbacks. After end(), accessors will no longer provide useful
388 // results.
389 void end();
390
391 // Internal-only, used only force async event for distributed events
392 // profiling.
393 void _setAsync();
394
395 // Returns whether this RecordFunction corresponds to an async event orn ot.
396 bool isAsync() const;
397
398 // Internal-only, used to denote out variant used for Static Runtime execution
399 void _setStaticRuntimeOutVariant();
400 bool isStaticRuntimeOutVariant() const;
401
402 RecordFunctionHandle handle() const {
403 return handle_;
404 }
405
406 c10::optional<OperatorName> operator_name() const;
407
408 // This method returns a copy of the FunctionSchema and can be expensive.
409 c10::optional<FunctionSchema> operator_schema() const;
410
411 void setHandle(RecordFunctionHandle handle) {
412 handle_ = handle;
413 }
414
415 // Whether this RecordFunction runs any callbacks.
416 bool isActive() const {
417 return !step_callbacks_.empty();
418 }
419
420 bool needsInputs() const {
421 return step_callbacks_.needs_inputs_;
422 }
423
424 bool needsOutputs() const {
425 return step_callbacks_.needs_outputs_;
426 }
427
428 int64_t debugHandle() const {
429 return debug_handle_;
430 }
431
432 void setDebugHandle(int64_t debug_handle) {
433 debug_handle_ = debug_handle;
434 }
435
436 void invalidateInputs() {
437#ifndef NDEBUG
438 inputs_valid_ = false;
439#endif
440 }
441
442 private:
443 void runStartCallbacks();
444
445 StepCallbacks step_callbacks_;
446
447 // In cases when RecordFunction might be active but we chose not to
448 // use the observers (e.g. operator is not observed), this boolean
449 // flag is used to check whether the start callbacks were called
450 bool called_start_callbacks_ = false;
451
452#ifndef NDEBUG
453 bool inputs_valid_ = false;
454#endif
455
456 // Stores various ObserverContext objects with event metadata for callbacks.
457 ObserverContextList ctx_;
458
459 c10::variant<std::string, schema_ref_t> fn_;
460
461 int64_t sequence_nr_ = -1;
462 c10::ArrayRef<const IValue> inputs_;
463 std::vector<c10::IValue> outputs_;
464
465 // For backward functions - thread id of the the forward function
466 uint64_t fwd_thread_id_ = 0;
467
468 // Unique id for this RecordFunction, used in callbacks to track start
469 // and end of ranges
470 RecordFunctionHandle handle_{0};
471
472 // Whether this record_function corresponds to an async event or not. Async
473 // events can complete in different threads or follow a future-like pattern
474 // of use.
475 bool is_async_{false};
476
477 // Debug handles are used for lazy annotation of module hierarchy
478 // and callstack.
479 // This is specifically is useful for mobile runtime, where generated
480 // debug handles can be lazily symbolicated using debug information
481 int64_t debug_handle_{-1};
482
483 // Whether this RecordFunction is used for an out variant run with
484 // Static Runtime
485 bool is_static_runtime_out_variant_{false};
486};
487
// Returns the callbacks selected to run for a RecordFunction of `scope`
// (see StepCallbacks). NOTE(review): presumably combines the global and
// thread-local callbacks and applies sampling; confirm in the .cpp.
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

// Variant of getStepCallbacks(). Per its name it presumably returns
// c10::nullopt when no callbacks would run; TODO confirm.
TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(
    RecordScope scope);
492
493namespace detail {
494template <typename Inputs, typename F, typename... Args>
495void record_function_with_scope(
496 RecordFunction& guard,
497 F fn,
498 const Inputs& inputs,
499 Args&&... args) {
500 if (guard.needsInputs()) {
501 guard.before(
502 fn,
503 c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
504 std::forward<Args>(args)...);
505 } else {
506 guard.before(fn, std::forward<Args>(args)...);
507 }
508}
509
510template <typename Inputs, typename F, typename... Args>
511void record_function_with_scope_and_debug_handle(
512 RecordFunction& guard,
513 F fn,
514 int64_t debug_handle,
515 const Inputs& inputs,
516 Args&&... args) {
517 guard.setDebugHandle(debug_handle);
518 if (guard.needsInputs()) {
519 guard.before(
520 fn,
521 c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
522 std::forward<Args>(args)...);
523 } else {
524 guard.before(fn, std::forward<Args>(args)...);
525 }
526}
527
528template <typename F, typename... Args>
529void record_function_with_scope(
530 RecordFunction& guard,
531 F fn,
532 c10::ArrayRef<const c10::IValue> inputs,
533 Args&&... args) {
534 return record_function_with_scope<
535 c10::ArrayRef<const c10::IValue>,
536 F,
537 Args...>(guard, std::move(fn), inputs, std::forward<Args>(args)...);
538}
539
540template <typename F, typename... Args>
541void record_function_with_scope_and_debug_handle(
542 RecordFunction& guard,
543 F fn,
544 int64_t debug_handle,
545 c10::ArrayRef<const c10::IValue> inputs,
546 Args&&... args) {
547 return record_function_with_scope_and_debug_handle<
548 c10::ArrayRef<const c10::IValue>,
549 F,
550 Args...>(
551 guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...);
552}
553
554} // namespace detail
555
// Declares a local `at::RecordFunction guard` for `scope`. If any callbacks
// are active, it calls before() with `fn` and, only when a callback needs
// them, `inputs`. Optional trailing argument: the function's seq_no.
// Because it declares `guard` in the enclosing scope, use it at most once per
// scope. The end callbacks run when `guard` is destroyed.
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
  at::RecordFunction guard(scope);                         \
  if (guard.isActive()) {                                  \
    ::at::detail::record_function_with_scope(              \
        guard, fn, inputs, ##__VA_ARGS__);                 \
  }
563
// Like RECORD_FUNCTION_WITH_SCOPE, but also records `outputs` when a callback
// needs them. `outputs` must therefore already be available at this point.
// To attach outputs after the call instead, use RECORD_OUTPUTS. Unlike
// RECORD_FUNCTION_WITH_SCOPE, `inputs` is passed to guard.before() as-is.
#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \
    scope, fn, inputs, outputs, ...)              \
  at::RecordFunction guard(scope);                \
  if (guard.isActive()) {                         \
    if (guard.needsInputs()) {                    \
      guard.before(fn, inputs, ##__VA_ARGS__);    \
    } else {                                      \
      guard.before(fn, ##__VA_ARGS__);            \
    }                                             \
    if (guard.needsOutputs()) {                   \
      guard.setOutputs(outputs);                  \
    }                                             \
  }
577
// Records a c10/ATen op or autograd node (RecordScope::FUNCTION). The optional
// trailing argument is the function's seq_no.
#define RECORD_FUNCTION(fn, inputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE(            \
      at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__)

// Records a TorchScript function or method (RecordScope::TORCHSCRIPT_FUNCTION).
#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

// RECORD_FUNCTION variant that also records `outputs`.
#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS(                          \
      at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__)

// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
#define RECORD_USER_SCOPE(fn) \
  RECORD_FUNCTION_WITH_SCOPE( \
      at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})

// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs)
597
// Helper macro to pass in a debug handle that is used to post-process events.
// Like RECORD_FUNCTION_WITH_SCOPE, it declares a local `guard`.
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \
    scope, fn, debug_handle, inputs, ...)          \
  at::RecordFunction guard(scope);                 \
  if (guard.isActive()) {                          \
    ::at::detail::record_function_with_scope_and_debug_handle( \
        guard, fn, debug_handle, inputs, ##__VA_ARGS__);       \
  }

// Helper macro to record LITE_INTERPRETER scope events with debug handles.
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \
    fn, debug_handle, inputs)                           \
  RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(            \
      at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs)
613
// Bookend to the RECORD_FUNCTION macros. Use this after the kernel
// launch to let the profiler bind the outputs to the op that produced
// them. Note that guard is declared by RECORD_FUNCTION so this macro
// needs to be called from the same scope as RECORD_FUNCTION.
// `outputs` may be any range whose elements convert to c10::IValue. It is
// evaluated only when a callback needs outputs, and then exactly once. The
// old form called begin()/end() on two separate evaluations, which pairs
// iterators of two different temporaries when `outputs` is a call that
// returns by value.
#define RECORD_OUTPUTS(outputs)                                   \
  if (guard.needsOutputs()) {                                     \
    auto&& record_outputs_ref_ = (outputs);                       \
    guard.setOutputs(std::vector<c10::IValue>(                    \
        record_outputs_ref_.begin(), record_outputs_ref_.end())); \
  }
623
/**
 * addThreadLocalCallback adds a thread local callback to run with
 * RecordFunction and returns a handle to use with removeCallback,
 * disableCallback and reenableCallback.
 */
TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb);

/**
 * hasThreadLocalCallbacks returns whether there are callbacks registered
 * with addThreadLocalCallback
 */
TORCH_API bool hasThreadLocalCallbacks();

/**
 * clearThreadLocalCallbacks removes all thread local callbacks
 */
TORCH_API void clearThreadLocalCallbacks();
640
/**
 * addGlobalCallback adds a global callback to run with RecordFunction and
 * returns a handle to use with removeCallback.
 *
 * WARNING: call this only during program initialization.
 */
TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb);

/**
 * removeCallback removes a callback given the handle returned by
 * addThreadLocalCallback or addGlobalCallback.
 *
 * WARNING: no other code may run concurrently with this call.
 */
TORCH_API void removeCallback(CallbackHandle handle);

/**
 * Prevent the given callback from executing. If handle is invalid,
 * does nothing.
 */
TORCH_API void disableCallback(CallbackHandle handle);

/**
 * Allow the given callback, previously disabled with disableCallback, to
 * execute again. If handle is invalid, does nothing.
 */
TORCH_API void reenableCallback(CallbackHandle handle);

/**
 * hasGlobalCallbacks returns whether there are global callbacks
 * registered with addGlobalCallback
 */
TORCH_API bool hasGlobalCallbacks();

/**
 * clearGlobalCallbacks removes all global callbacks
 */
TORCH_API void clearGlobalCallbacks();

// Combined versions covering both thread local and global callbacks.
TORCH_API bool hasCallbacks();
TORCH_API void clearCallbacks();
682
/**
 * enableRecordFunction enables (or, with `false`, disables) RecordFunction
 * thread locally. Prefer RecordFunctionGuard, which restores the previous
 * setting automatically.
 */
TORCH_API void enableRecordFunction(bool enable = true);

/**
 * isRecordFunctionEnabled returns whether RecordFunction
 * is enabled thread locally
 */
TORCH_API bool isRecordFunctionEnabled();
693
694class TORCH_API RecordFunctionGuard {
695 public:
696 explicit RecordFunctionGuard(bool is_enabled = true)
697 : prev_value_(isRecordFunctionEnabled()) {
698 enableRecordFunction(is_enabled);
699 }
700
701 virtual ~RecordFunctionGuard() {
702 enableRecordFunction(prev_value_);
703 }
704
705 private:
706 bool prev_value_ = false;
707};
708
709class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
710 public:
711 DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
712 ~DisableRecordFunctionGuard() override = default;
713};
714
715struct TORCH_API RecordFunctionTLS {
716 // Thread local vector of callbacks, holds pairs (callbacks, unique_id);
717 // must be sorted in increasing handles order
718 RecordFunctionCallbacks sorted_tls_callbacks_;
719
720 bool tls_record_function_enabled_ = true;
721};
722
// Returns the current thread's RecordFunction state (see RecordFunctionTLS).
TORCH_API const RecordFunctionTLS& get_record_function_tls_();

// Replaces the current thread's RecordFunction state with `tls`.
// NOTE(review): presumably used when propagating state via ThreadLocalState
// (JIT continuations, at::launch); confirm.
TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);

// Testing only. NOTE(review): presumably seeds the RNG used to sample
// callbacks with sampling_prob < 1.0; confirm in the .cpp.
TORCH_API void set_record_function_seed_for_testing(uint32_t seed);
728
729} // namespace at
730