#include <ATen/record_function.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocal.h>
#include <c10/util/irange.h>
#include <c10/util/overloaded.h>

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <random>

namespace at {

namespace {

// Used to generate unique callback handles
CallbackHandle next_unique_callback_handle() {
  static std::atomic<uint64_t> unique_cb_id {1};
  return CallbackHandle(unique_cb_id++);
}

RecordFunctionHandle next_unique_record_function_handle() {
  static std::atomic<uint64_t> unique_rf_id {1};
  return RecordFunctionHandle(unique_rf_id++);
}

std::atomic<int64_t> defaultNodeId(-1);

// Enumerates thread ids logically; note that std::this_thread::get_id may
// return an id that the OS has already reused for another thread.
std::atomic<uint64_t> next_thread_id_ {0};
thread_local uint64_t current_thread_id_ = 0;

static constexpr size_t NumRecordScopes =
    static_cast<size_t>(RecordScope::NUM_SCOPES);

RecordFunctionCallbacks::iterator findCallback(
    RecordFunctionCallbacks& entries,
    CallbackHandle handle) {
  auto match_handle = [handle](const auto& el) { return el.handle_ == handle; };
  return std::find_if(entries.begin(), entries.end(), match_handle);
}

c10::optional<RecordFunctionCallback> extractCallback(
    RecordFunctionCallbacks& entries,
    CallbackHandle handle) {
  auto it = findCallback(entries, handle);
  if (it == entries.end()) {
    return c10::nullopt;
  }
  auto out = it->callback_;
  entries.erase(it);
  return out;
}
// ============================================================================
// == Callback manager ========================================================
// ============================================================================
// The high level idea of the RecordFunction callback machinery is based on the
// observation that the set of callbacks to be run changes infrequently.
// However, in order to reuse the active set we have to be able to invalidate
// it when the active set changes. There are three events that can change which
// callbacks should be run:
// 1) The set of global callbacks changes
// 2) The set of local callbacks changes
// 3) A sampling callback is present, and should run on this iteration
//
// Global callbacks rely on thread local replication and an atomic version
// counter to maintain consistency. Whenever we change the set of active global
// callbacks (add / remove / enable / disable) the `GlobalCallbackManager`
// increments the version number and updates the global state while holding
// a mutex. The local callback manager snapshots the global callbacks and
// lazily rebuilds by comparing `GlobalCallbackManager::version()` (which is
// a simple atomic read) to the version of the last rebuild. In the
// overwhelmingly common case that they match it can reuse the existing
// snapshot. Otherwise it must call the much more expensive (and locked)
// `GlobalCallbackManager::getSnapshot()`.
//
// Handling changes to the thread local callbacks is trivial; functions that
// change them can simply force a cache rebuild for that thread after the
// changes are made.
//
// Sampling is by far the most challenging to handle efficiently. In general
// sampling callbacks are expected to run very infrequently (e.g. once per
// million calls). Random number generation is rather expensive, so flipping a
// coin on every call for every sampling callback is wasteful. We can
// significantly reduce this cost by noting that the number of failures of a
// Bernoulli random variable follows a geometric distribution, so we can sample
// the geometric distribution once to determine the next call on which a
// callback should run. This reduces the cost from a random sample per call to
// a simple integer decrement.
//
// We can further note that Bernoulli samples are independent. (In contrast to,
// say, sampling without replacement.) This means that we can generate a
// counter for each scope that a given callback supports and then decrement the
// counter corresponding to the RecordScope being called. Conceptually, this is
// analogous to flipping different coins with the same probability. By sharding
// on RecordScope, we can consolidate the decrement to a single shared counter
// and update individual counters during rebuild.
class GlobalCallbackManager {
 public:
  static GlobalCallbackManager& get(); // Singleton

 private:
  GlobalCallbackManager() = default;

 public:
  static constexpr size_t NoVersion = 0;
  using snapshot_t = std::pair<size_t, RecordFunctionCallbacks>;

  // Locking?
  size_t version() const;                                       // No
  snapshot_t getSnapshot() const;                               // Yes
  CallbackHandle addCallback(RecordFunctionCallback cb);        // Yes
  void setCallbackEnabled(CallbackHandle handle, bool enabled); // Yes
  void removeCallback(CallbackHandle handle);                   // Yes
  void clearCallbacks();                                        // Yes

 private:
  std::atomic<size_t> version_{NoVersion + 1};
  RecordFunctionCallbacks global_callbacks_; // Source of truth.
  mutable std::mutex update_mutex_;
};
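
// Intended use (a sketch, not a prescribed API contract): consumers cache a
// snapshot and use the cheap `version()` read to detect staleness, paying for
// the locked `getSnapshot()` only when the version has moved.
// (`cached_version_` and `cached_callbacks_` are hypothetical names.)
//
//   if (GlobalCallbackManager::get().version() != cached_version_) {
//     std::tie(cached_version_, cached_callbacks_) =
//         GlobalCallbackManager::get().getSnapshot();
//   }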

class CacheEntry {
 public:
  CacheEntry() = default;
  CacheEntry(std::mt19937* generator, RecordScope scope);

  // The caller is expected to check `GlobalCallbackManager::get().version()`
  // and call `CacheEntry::update()` if necessary.
  StepCallbacks getActiveCallbacks();
  c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();

  // Full rebuild. (E.g. during registration)
  void update(const std::vector<RecordFunctionCallback>& callbacks);

 private:
  struct CallbackAndCounter {
    RecordFunctionCallback callback_;

    // `-1` indicates that a callback is not sampled.
    int tries_left_{-1};
  };

  C10_ALWAYS_INLINE void getActiveCallbacksImpl();

  void rebuildActiveCallbacks();
  int sampleTries(double p) const;

  // std::mt19937 is quite large, so all scopes share the same generator.
  std::mt19937* generator_{nullptr};

  // Includes sampling callbacks which are waiting to run.
  c10::SmallVector<CallbackAndCounter, kSoftLimitCallbacks> callbacks_;
  RecordScope scope_{RecordScope::FUNCTION};

  StepCallbacks active_callbacks_;

  // For managing sampling callbacks
  int sampling_countdown_{0};
  int steps_for_this_update_{0};
};
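
// Worked example with hypothetical numbers: suppose `callbacks_` holds
// counters {-1, 3, 7}. The first callback is unsampled and always runs; the
// other two are sampled. `rebuildActiveCallbacks()` sets
// `sampling_countdown_` to min(3, 7) == 3 and `steps_for_this_update_` to 3.
// Three calls later the countdown reaches zero, both sampled counters are
// decremented by three (to 0 and 4), the callback that hit zero joins the
// active set for exactly one call, and its counter is then resampled.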

class LocalCallbackManager {
 public:
  static LocalCallbackManager& get(); // Singleton

 private:
  LocalCallbackManager();

 public:
  const RecordFunctionTLS& getTLS() const;
  StepCallbacks getActiveCallbacks(const RecordScope scope);
  c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);

  void setTLS(const RecordFunctionTLS& tls);
  void seed(uint32_t seed);
  CallbackHandle addCallback(RecordFunctionCallback callback);
  bool setCallbackEnabled(CallbackHandle handle, bool enabled);
  bool removeCallback(CallbackHandle handle);
  void clearCallbacks();

 private:
  void rebuildActiveCallbacksIfNeeded();

  void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);

  void rebuild_callback_scopes(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordFunctionCallback& callback);

  void rebuild_scope(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordScope scope);

  // Source of truth.
  RecordFunctionTLS registered_callbacks_;

  // Runtime cache.
  size_t global_version_{GlobalCallbackManager::NoVersion};
  std::array<CacheEntry, NumRecordScopes> active_callbacks_;
  std::mt19937 generator_{};
};
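
// How the rebuild helpers above relate (sketch):
//
//   rebuild_all             -> rebuild_scope for every RecordScope
//   rebuild_callback_scopes -> rebuild_scope only for the scopes that the
//                              changed callback participates in, falling back
//                              to rebuild_all if the global version has moved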

// ============================================================================
// == GlobalCallbackManager: Implementation ===================================
// ============================================================================
GlobalCallbackManager& GlobalCallbackManager::get() {
  static GlobalCallbackManager manager;
  return manager;
}

size_t GlobalCallbackManager::version() const {
  return version_.load(std::memory_order_relaxed);
}

std::pair<size_t, RecordFunctionCallbacks> GlobalCallbackManager::getSnapshot() const {
  std::lock_guard<std::mutex> guard(update_mutex_);
  return {version_.load(std::memory_order_seq_cst), global_callbacks_};
}

CallbackHandle GlobalCallbackManager::addCallback(RecordFunctionCallback cb) {
  std::lock_guard<std::mutex> guard(update_mutex_);
  ++version_;
  auto handle = next_unique_callback_handle();
  global_callbacks_.emplace_back(std::move(cb), handle);
  return handle;
}

void GlobalCallbackManager::setCallbackEnabled(
    CallbackHandle handle,
    bool enabled) {
  std::lock_guard<std::mutex> guard(update_mutex_);
  auto it = findCallback(global_callbacks_, handle);
  if (it != global_callbacks_.end()) {
    if (it->enabled_ != enabled) {
      ++version_;
      it->enabled_ = enabled;
    }
  } else {
    LOG(WARNING) << "Requested callback was not found";
  }
}

void GlobalCallbackManager::removeCallback(CallbackHandle handle) {
  std::lock_guard<std::mutex> guard(update_mutex_);
  if (extractCallback(global_callbacks_, handle).has_value()) {
    ++version_;
  } else {
    LOG(WARNING) << "Requested callback was not found";
  }
}

void GlobalCallbackManager::clearCallbacks() {
  std::lock_guard<std::mutex> guard(update_mutex_);
  ++version_;
  global_callbacks_.clear();
}

// ============================================================================
// == CacheEntry: Implementation ==============================================
// ============================================================================
CacheEntry::CacheEntry(std::mt19937* generator, RecordScope scope)
    : generator_{generator}, scope_{scope} {
  rebuildActiveCallbacks();
}

void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
  callbacks_.clear();
  callbacks_.reserve(callbacks.size());
  for (const auto& callback : callbacks) {
    const auto p = callback.samplingProb();
    callbacks_.push_back({callback, p < 1.0 ? sampleTries(p) : -1});
  }

  rebuildActiveCallbacks();
}

void CacheEntry::getActiveCallbacksImpl() {
  // We rebuild the active set when `sampling_countdown_` reaches zero, so if
  // it is zero at the start of this function something has gone wrong.
  TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);

  if (C10_UNLIKELY(!(--sampling_countdown_))) {
    // Use inferred steps to update sampled callbacks.
    for (auto& i : callbacks_) {
      if (i.tries_left_ > 0) {
        TORCH_INTERNAL_ASSERT(i.tries_left_ >= steps_for_this_update_);
        i.tries_left_ -= steps_for_this_update_;
      }
    }

    // Determine which callbacks to run and for how long.
    rebuildActiveCallbacks();

    // Resample any sampled callbacks that ran this call.
    for (auto& i : callbacks_) {
      if (!i.tries_left_) {
        i.tries_left_ = sampleTries(i.callback_.samplingProb());
      }
    }
  }
}

StepCallbacks CacheEntry::getActiveCallbacks() {
  getActiveCallbacksImpl();
  return active_callbacks_;
}

c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
  getActiveCallbacksImpl();
  if (C10_LIKELY(active_callbacks_.empty())) {
    return c10::nullopt;
  }
  return active_callbacks_;
}

void CacheEntry::rebuildActiveCallbacks() {
  // We could store thread ID in CacheEntry, but rebuilds are infrequent and
  // this saves us from having to plumb it through.
  const auto thread_id = RecordFunction::currentThreadId();
  active_callbacks_ = StepCallbacks(thread_id, scope_);

  sampling_countdown_ = std::numeric_limits<int>::max();
  for (const auto& i : callbacks_) {
    if (i.tries_left_ < 0) {
      // Callback is not sampled. Unconditionally push.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});

    } else if (i.tries_left_ == 0) {
      // Callback is sampled and we have reached a sampling event. Push and
      // set `sampling_countdown_` to one so we trigger a rebuild after one call.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});
      sampling_countdown_ = 1;

    } else {
      // Callback is sampled and we have not reached a sampling event. Set
      // `sampling_countdown_` to rebuild when it is time for this callback to
      // execute.
      sampling_countdown_ = std::min(sampling_countdown_, i.tries_left_);
    }
    active_callbacks_.needs_inputs_ |= i.callback_.needsInputs();
    active_callbacks_.needs_outputs_ |= i.callback_.needsOutputs();
    active_callbacks_.needs_ids_ |= i.callback_.needsIds();
  }
  steps_for_this_update_ = sampling_countdown_;
}

int CacheEntry::sampleTries(double p) const {
  TORCH_INTERNAL_ASSERT(generator_ != nullptr);
  TORCH_INTERNAL_ASSERT(p > 0.0 && p <= 1.0);

  // The geometric distribution returns the number of failures. We add one to
  // also account for the call where we succeed.
  return std::geometric_distribution<int>(p)(*generator_) + 1;
}
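
// For intuition: with p = 0.5 the geometric draw above returns 0 with
// probability 0.5, 1 with probability 0.25, and so on, so `sampleTries`
// yields 1, 2, 3, ... with mean 1 / p == 2. A result of 1 means the callback
// fires on the very next call.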

// ============================================================================
// == LocalCallbackManager: Implementation ====================================
// ============================================================================
LocalCallbackManager& LocalCallbackManager::get() {
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static c10::ThreadLocal<LocalCallbackManager> manager;
  return manager.get();
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static thread_local LocalCallbackManager manager;
  return manager;
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
}

LocalCallbackManager::LocalCallbackManager() {
  for (auto i : c10::irange(NumRecordScopes)) {
    active_callbacks_[i] = CacheEntry(&generator_, static_cast<RecordScope>(i));
  }
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
  return registered_callbacks_;
}

void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
  const auto global_version = GlobalCallbackManager::get().version();
  if (C10_UNLIKELY(global_version != global_version_)) {
    rebuild_all(GlobalCallbackManager::get().getSnapshot());
  }
}

StepCallbacks LocalCallbackManager::getActiveCallbacks(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
}

c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
}

void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
  registered_callbacks_ = tls;
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

void LocalCallbackManager::seed(uint32_t seed) {
  generator_.seed(seed);
}

CallbackHandle LocalCallbackManager::addCallback(
    RecordFunctionCallback callback) {
  auto handle = next_unique_callback_handle();
  auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
  callbacks.emplace_back(std::move(callback), handle);
  rebuild_callback_scopes(
      GlobalCallbackManager::get().getSnapshot(), callbacks.back().callback_);
  return handle;
}

bool LocalCallbackManager::setCallbackEnabled(
    CallbackHandle handle,
    bool enabled) {
  auto it = findCallback(registered_callbacks_.sorted_tls_callbacks_, handle);
  auto found = (it != registered_callbacks_.sorted_tls_callbacks_.end());
  if (found && it->enabled_ != enabled) {
    it->enabled_ = enabled;
    rebuild_callback_scopes(
        GlobalCallbackManager::get().getSnapshot(), it->callback_);
  }
  return found;
}

bool LocalCallbackManager::removeCallback(CallbackHandle handle) {
  auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
  auto callback = extractCallback(callbacks, handle);
  if (callback.has_value()) {
    rebuild_callback_scopes(
        GlobalCallbackManager::get().getSnapshot(), *callback);
  }
  return callback.has_value();
}

void LocalCallbackManager::clearCallbacks() {
  registered_callbacks_.sorted_tls_callbacks_.clear();
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

void LocalCallbackManager::rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot) {
  global_version_ = global_snapshot.first;
  for (auto i : c10::irange(NumRecordScopes)) {
    rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
  }
}

void LocalCallbackManager::rebuild_callback_scopes(
    const GlobalCallbackManager::snapshot_t& global_snapshot,
    const RecordFunctionCallback& callback) {
  if (global_snapshot.first == global_version_) {
    // Only rebuild scopes associated with `callback`.
    for (auto i : c10::irange(NumRecordScopes)) {
      if (callback.checkScope(static_cast<RecordScope>(i))) {
        rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
      }
    }
  } else {
    rebuild_all(global_snapshot);
  }
}

void LocalCallbackManager::rebuild_scope(
    const GlobalCallbackManager::snapshot_t& global_snapshot,
    const RecordScope scope) {
  std::vector<RecordFunctionCallback> callbacks;
  if (registered_callbacks_.tls_record_function_enabled_) {
    auto populate_callbacks =
        [&](const RecordFunctionCallbacks& raw_callbacks) {
          for (const auto& i : raw_callbacks) {
            if (i.enabled_ && i.callback_.checkScope(scope) &&
                i.callback_.samplingProb() > 0) {
              callbacks.push_back(i.callback_);
            }
          }
        };
    populate_callbacks(global_snapshot.second);
    populate_callbacks(registered_callbacks_.sorted_tls_callbacks_);
  }
  active_callbacks_[static_cast<size_t>(scope)].update(callbacks);
}

// ============================================================================
// == Callback execution ======================================================
// ============================================================================
void logTryRunCallbackError(const char* what, const char* name) {
  LOG(WARNING) << "Exception in RecordFunction callback: " << what
               << ", for the range " << name;
}

template <bool is_start>
C10_ALWAYS_INLINE bool tryRunCallback(
    const StepCallbacks::StartEndPair callback_ptrs,
    const RecordFunction& rf,
    std::unique_ptr<ObserverContext>& ctx) {
  try {
    if (is_start && callback_ptrs.start_) {
      ctx = callback_ptrs.start_(rf);
    }

    if (!is_start && callback_ptrs.end_) {
      callback_ptrs.end_(rf, ctx.get());
    }

    return true;
  } catch (const std::exception& e) {
    logTryRunCallbackError(e.what(), rf.name());
    return false;
  } catch (...) {
    logTryRunCallbackError("unknown", rf.name());
    return false;
  }
}

} // namespace

RecordFunction::RecordFunction(RecordScope scope)
    : RecordFunction(getStepCallbacks(scope)) {}

RecordFunction::RecordFunction(StepCallbacks&& step_callbacks)
    : step_callbacks_{std::move(step_callbacks)} {
  ctx_.resize(step_callbacks_.callbacks_.size());
  if (step_callbacks_.needs_ids_) {
    setHandle(next_unique_record_function_handle());
  }
}
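
// A sketch of direct use (most callers go through the RECORD_FUNCTION family
// of macros declared in record_function.h instead; "my_op" is a hypothetical
// name):
//
//   {
//     at::RecordFunction guard(at::RecordScope::FUNCTION);
//     if (guard.isActive()) {
//       guard.before("my_op"); // runs start callbacks
//     }
//     // ... do work; end callbacks run when `guard` is destroyed.
//   }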

void RecordFunction::runStartCallbacks() {
  for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
    tryRunCallback</*is_start=*/true>(
        step_callbacks_.callbacks_[i], *this, ctx_[i]);
  }
  called_start_callbacks_ = true;
}

void RecordFunction::end() {
  if (called_start_callbacks_) {
    for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
      tryRunCallback</*is_start=*/false>(
          step_callbacks_.callbacks_[i], *this, ctx_[i]);
    }
    step_callbacks_.callbacks_.clear();
  }
}

const char* RecordFunction::name() const {
  return c10::visit(
      c10::overloaded(
          [](const std::string& name) { return name.c_str(); },
          [](const schema_ref_t schema) {
            return schema.get().name().c_str();
          }),
      fn_);
}

size_t RecordFunction::num_inputs() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) { return inputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().arguments().size();
          }),
      fn_);
}

size_t RecordFunction::num_outputs() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) { return outputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().returns().size();
          }),
      fn_);
}

c10::optional<OperatorName> RecordFunction::operator_name() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) -> c10::optional<OperatorName> {
            return c10::nullopt;
          },
          [](const schema_ref_t schema) -> c10::optional<OperatorName> {
            return schema.get().operator_name();
          }),
      fn_);
}

c10::optional<c10::FunctionSchema> RecordFunction::operator_schema() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) -> c10::optional<c10::FunctionSchema> {
            return c10::nullopt;
          },
          [](const schema_ref_t schema) -> c10::optional<c10::FunctionSchema> {
            return schema.get();
          }),
      fn_);
}

StepCallbacks getStepCallbacks(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacks(scope);
}

c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}

const RecordFunctionTLS& get_record_function_tls_() {
  return LocalCallbackManager::get().getTLS();
}

void set_record_function_tls_(const RecordFunctionTLS& tls) {
  LocalCallbackManager::get().setTLS(tls);
}

namespace {
bool anyEnabled(const RecordFunctionCallbacks& callbacks) {
  return std::any_of(callbacks.begin(), callbacks.end(), [](const auto& cb) {
    return cb.enabled_;
  });
}
} // namespace

bool hasCallbacks() {
  return hasThreadLocalCallbacks() || hasGlobalCallbacks();
}

bool hasGlobalCallbacks() {
  return anyEnabled(GlobalCallbackManager::get().getSnapshot().second);
}

bool hasThreadLocalCallbacks() {
  return anyEnabled(get_record_function_tls_().sorted_tls_callbacks_);
}

CallbackHandle addThreadLocalCallback(
    RecordFunctionCallback cb) {
  return LocalCallbackManager::get().addCallback(std::move(cb));
}

CallbackHandle addGlobalCallback(
    RecordFunctionCallback cb) {
  return GlobalCallbackManager::get().addCallback(std::move(cb));
}
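
// Registration sketch (the builder-style setters such as `samplingProb` are
// assumed from RecordFunctionCallback's declaration in record_function.h):
//
//   auto handle = at::addGlobalCallback(
//       at::RecordFunctionCallback(
//           [](const at::RecordFunction& fn)
//               -> std::unique_ptr<at::ObserverContext> {
//             return nullptr; // start callback
//           },
//           [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
//             // end callback
//           })
//           .samplingProb(0.01));
//   // ... later:
//   at::removeCallback(handle);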

void removeCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().removeCallback(handle)) {
    GlobalCallbackManager::get().removeCallback(handle);
  }
}

void disableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, false)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, false);
  }
}

void reenableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, true)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, true);
  }
}

void clearGlobalCallbacks() {
  GlobalCallbackManager::get().clearCallbacks();
}

void clearThreadLocalCallbacks() {
  LocalCallbackManager::get().clearCallbacks();
}

void clearCallbacks() {
  clearGlobalCallbacks();
  clearThreadLocalCallbacks();
}

bool isRecordFunctionEnabled() {
  return LocalCallbackManager::get().getTLS().tls_record_function_enabled_;
}

void enableRecordFunction(bool enable) {
  auto tls = LocalCallbackManager::get().getTLS();
  if (tls.tls_record_function_enabled_ != enable) {
    tls.tls_record_function_enabled_ = enable;
    LocalCallbackManager::get().setTLS(tls);
  }
}

void set_record_function_seed_for_testing(uint32_t seed) {
  LocalCallbackManager::get().seed(seed);
}

/* static */
uint64_t RecordFunction::currentThreadId() {
  if (!current_thread_id_) {
    // Happens only once per thread.
    current_thread_id_ = ++next_thread_id_;
  }
  return current_thread_id_;
}

void RecordFunction::before(const char* name, int64_t sequence_nr) {
  fn_ = name;
  sequence_nr_ = sequence_nr;

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}

void RecordFunction::before(std::string name, int64_t sequence_nr) {
  fn_ = std::move(name);
  sequence_nr_ = sequence_nr;

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}

void RecordFunction::before(
    RecordFunction::schema_ref_t schema,
    int64_t sequence_nr) {
  sequence_nr_ = sequence_nr;
  fn_ = schema;

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}

/* static */ void RecordFunction::setDefaultNodeId(int64_t newDefaultNodeId) {
  TORCH_CHECK(newDefaultNodeId >= 0, "setDefaultNodeId expects an id >= 0.");
  defaultNodeId = newDefaultNodeId;
}

/* static */ int64_t RecordFunction::getDefaultNodeId() {
  return defaultNodeId;
}

RecordFunction::~RecordFunction() {
  end();
}

void RecordFunction::_setAsync() {
  is_async_ = true;
}

bool RecordFunction::isAsync() const {
  return is_async_;
}

void RecordFunction::_setStaticRuntimeOutVariant() {
  if (isActive()) {
    is_static_runtime_out_variant_ = true;
  }
}

bool RecordFunction::isStaticRuntimeOutVariant() const {
  if (isActive()) {
    return is_static_runtime_out_variant_;
  }
  return false;
}
} // namespace at