1 | #include <ATen/record_function.h> |
2 | #include <ATen/core/dispatch/Dispatcher.h> |
3 | #include <c10/macros/Macros.h> |
4 | #include <c10/util/ThreadLocal.h> |
5 | #include <c10/util/overloaded.h> |
6 | |
7 | #include <algorithm> |
8 | #include <cstdlib> |
9 | #include <random> |
10 | |
11 | namespace at { |
12 | |
13 | namespace { |
14 | |
15 | // Used to generate unique callback handles |
16 | CallbackHandle next_unique_callback_handle() { |
17 | static std::atomic<uint64_t> unique_cb_id {1}; |
18 | return CallbackHandle(unique_cb_id++); |
19 | } |
20 | |
21 | RecordFunctionHandle next_unique_record_function_handle() { |
22 | static std::atomic<uint64_t> unique_rf_id {1}; |
23 | return RecordFunctionHandle(unique_rf_id++); |
24 | } |
25 | |
// Node id reported by RecordFunction; -1 until explicitly set via
// RecordFunction::setDefaultNodeId.
std::atomic<int64_t> defaultNodeId(-1);

// Enumerates thread ids logically;
// note: std::this_thread::get_id may return potentially
// reused thread id
std::atomic<uint64_t> next_thread_id_ {0};
thread_local uint64_t current_thread_id_ = 0;

// Number of distinct RecordScope values; used to size the per-scope caches.
static constexpr size_t NumRecordScopes =
    static_cast<size_t>(RecordScope::NUM_SCOPES);
36 | |
37 | RecordFunctionCallbacks::iterator findCallback( |
38 | RecordFunctionCallbacks& entries, |
39 | CallbackHandle handle) { |
40 | auto match_handle = [handle](const auto& el) { return el.handle_ == handle; }; |
41 | return std::find_if(entries.begin(), entries.end(), match_handle); |
42 | } |
43 | |
44 | c10::optional<RecordFunctionCallback> ( |
45 | RecordFunctionCallbacks& entries, |
46 | CallbackHandle handle) { |
47 | auto it = findCallback(entries, handle); |
48 | if (it == entries.end()) { |
49 | return c10::nullopt; |
50 | } |
51 | auto out = it->callback_; |
52 | entries.erase(it); |
53 | return out; |
54 | } |
55 | |
56 | // ============================================================================ |
57 | // == Callback manager ======================================================== |
58 | // ============================================================================ |
59 | // The high level idea of the RecordFunction callback machinery is based on the |
60 | // observation that the set of callbacks to be run changes infrequently. |
61 | // However, in order to reuse the active set we have to be able to invalidate |
62 | // when the active set changes. There are three events that can change which |
63 | // callbacks should be run: |
64 | // 1) The set of global callbacks changes |
65 | // 2) The set of local callbacks changes |
66 | // 3) A sampling callback is present, and should run on this iteration |
67 | // |
68 | // Global callbacks rely on thread local replication and an atomic version |
69 | // counter to maintain consistency. Whenever we change the set of active global |
70 | // callbacks (add / remove / enable / disable) the `GlobalCallbackManager` |
71 | // increments the version number and updates the global state while holding |
72 | // a mutex. The local callback manager snapshots the global callbacks and |
73 | // lazily rebuilds by comparing`GlobalCallbackManager::version()` (which is |
74 | // a simple atomic read) to the version of the last rebuild. In the |
75 | // overwhelmingly common case that they match it can reuse the existing |
76 | // snapshot. Otherwise it must call the much more expensive (and locked) |
77 | // `GlobalCallbackManager::getSnapshot()`. |
78 | // |
79 | // Handling changes to the thread local callbacks is trivial; functions that |
80 | // change them can simply force a cache rebuild for that thread after the |
81 | // changes are made. |
82 | // |
83 | // Sampling is by far the most challenging to handle efficiently. In general |
84 | // sampling callbacks are expected to have very low frequency. (e.g. 1 per |
85 | // million) Random number generation is rather expensive, so flipping a coin on |
86 | // every call for every sampling callback is wasteful. We can significantly |
87 | // reduce this cost by noting that the number of failures of a Bernoulli random |
88 | // variable is a geometric distribution, and thus we can sample the geometric |
89 | // distribution to determine the next time a callback should run. This reduces |
90 | // the cost from a random sample to a simple integer decrement. |
91 | // |
92 | // We can further note that Bernoulli samples are independent. (In contrast to, |
93 | // say, sampling without replacement.) This means that we can generate a |
94 | // counter for each scope that a given callback supports and then decrement the |
95 | // counter corresponding to the RecordScope being called. Conceptually, this is |
96 | // analogous to flipping different coins with the same probability. By sharding |
97 | // on RecordScope, we can consolidate the decrement to a single shared counter |
98 | // and update individual counters during rebuild. |
99 | |
// Process-wide registry of global RecordFunction callbacks. Every mutation
// takes `update_mutex_` and bumps `version_`, so per-thread caches
// (LocalCallbackManager) can detect staleness with a single relaxed atomic
// read of `version()`.
class GlobalCallbackManager {
 public:
  static GlobalCallbackManager& get(); // Singleton

 private:
  GlobalCallbackManager() = default;

 public:
  // Sentinel that never matches a real snapshot; `version_` starts at
  // NoVersion + 1, so a cache initialized to NoVersion always rebuilds.
  static constexpr size_t NoVersion = 0;
  using snapshot_t = std::pair<size_t, RecordFunctionCallbacks>;

  // Locking?
  size_t version() const; // No
  snapshot_t getSnapshot() const; // Yes
  CallbackHandle addCallback(RecordFunctionCallback cb); // Yes
  void setCallbackEnabled(CallbackHandle handle, bool enabled); // Yes
  void removeCallback(CallbackHandle handle); // Yes
  void clearCallbacks(); // Yes

 private:
  std::atomic<size_t> version_{NoVersion + 1};
  RecordFunctionCallbacks global_callbacks_; // Source of truth.
  mutable std::mutex update_mutex_;
};
124 | |
// Per-(thread, RecordScope) cache of the callbacks that should run on the
// next RecordFunction invocation, including the countdown bookkeeping for
// sampled callbacks described in the "Callback manager" note above.
class CacheEntry {
 public:
  CacheEntry() = default;
  CacheEntry(std::mt19937* generator, RecordScope scope);

  // The caller is expected to check `GlobalCallbackManager::get().version()'
  // and call CacheEntry::update() if necessary.
  StepCallbacks getActiveCallbacks();
  c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();

  // Full rebuild. (E.g. during registration)
  void update(const std::vector<RecordFunctionCallback>& callbacks);

 private:
  struct CallbackAndCounter {
    RecordFunctionCallback callback_;

    // `-1` indicates that a callback is not sampled.
    int tries_left_{-1};
  };

  // Shared countdown/rebuild logic for both getActiveCallbacks variants.
  C10_ALWAYS_INLINE void getActiveCallbacksImpl();

  // Recomputes `active_callbacks_` and `sampling_countdown_` from
  // `callbacks_`.
  void rebuildActiveCallbacks();
  // Draws the number of calls until a callback with sampling probability `p`
  // next fires (geometric distribution, always >= 1).
  int sampleTries(double p) const;

  // std::mt19937 is quite large, so all scopes share the same generator.
  std::mt19937* generator_{nullptr};

  // Includes sampling callbacks which are waiting to run.
  c10::SmallVector<CallbackAndCounter, kSoftLimitCallbacks> callbacks_;
  RecordScope scope_{RecordScope::FUNCTION};

  // The set handed out by getActiveCallbacks(); valid for
  // `sampling_countdown_` more calls.
  StepCallbacks active_callbacks_;

  // For managing sampling callbacks
  int sampling_countdown_{0};
  int steps_for_this_update_{0};
};
164 | |
// Thread-local manager that merges the global callback snapshot with this
// thread's own callbacks and caches the result per RecordScope. It lazily
// resynchronizes with GlobalCallbackManager by comparing version numbers.
class LocalCallbackManager {
 public:
  static LocalCallbackManager& get(); // Singleton

 private:
  LocalCallbackManager();

 public:
  const RecordFunctionTLS& getTLS() const;
  StepCallbacks getActiveCallbacks(const RecordScope scope);
  c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);

  void setTLS(const RecordFunctionTLS& tls);
  // Reseeds the sampling RNG (used by tests for determinism).
  void seed(uint32_t seed);
  CallbackHandle addCallback(RecordFunctionCallback callback);
  bool setCallbackEnabled(CallbackHandle handle, bool enabled);
  bool removeCallback(CallbackHandle handle);
  void clearCallbacks();

 private:
  // Rebuilds all caches iff the global version has changed.
  void rebuildActiveCallbacksIfNeeded();

  void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);

  // Rebuilds only the scopes `callback` participates in (falls back to
  // rebuild_all when the global snapshot is newer than our cache).
  void rebuild_callback_scopes(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordFunctionCallback& callback);

  void rebuild_scope(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordScope scope);

  // Source of truth.
  RecordFunctionTLS registered_callbacks_;

  // Runtime cache.
  size_t global_version_{GlobalCallbackManager::NoVersion};
  std::array<CacheEntry, NumRecordScopes> active_callbacks_;
  std::mt19937 generator_{};
};
205 | |
206 | // ============================================================================ |
207 | // == GlobalCallbackManager: Implementation =================================== |
208 | // ============================================================================ |
209 | GlobalCallbackManager& GlobalCallbackManager::get() { |
210 | static GlobalCallbackManager manager; |
211 | return manager; |
212 | } |
213 | |
214 | size_t GlobalCallbackManager::version() const { |
215 | return version_.load(std::memory_order_relaxed); |
216 | } |
217 | |
218 | std::pair<size_t, RecordFunctionCallbacks> GlobalCallbackManager::getSnapshot() const { |
219 | std::lock_guard<std::mutex> guard(update_mutex_); |
220 | return {version_.load(std::memory_order_seq_cst), global_callbacks_}; |
221 | } |
222 | |
223 | CallbackHandle GlobalCallbackManager::addCallback(RecordFunctionCallback cb) { |
224 | std::lock_guard<std::mutex> guard(update_mutex_); |
225 | ++version_; |
226 | auto handle = next_unique_callback_handle(); |
227 | global_callbacks_.emplace_back(std::move(cb), handle); |
228 | return handle; |
229 | } |
230 | |
231 | void GlobalCallbackManager::setCallbackEnabled( |
232 | CallbackHandle handle, |
233 | bool enabled) { |
234 | std::lock_guard<std::mutex> guard(update_mutex_); |
235 | auto it = findCallback(global_callbacks_, handle); |
236 | if (it != global_callbacks_.end()) { |
237 | if (it->enabled_ != enabled) { |
238 | ++version_; |
239 | it->enabled_ = enabled; |
240 | } |
241 | } else { |
242 | LOG(WARNING) << "Requested callback is not found" ; |
243 | } |
244 | } |
245 | |
246 | void GlobalCallbackManager::removeCallback(CallbackHandle handle) { |
247 | std::lock_guard<std::mutex> guard(update_mutex_); |
248 | if (extractCallback(global_callbacks_, handle).has_value()) { |
249 | ++version_; |
250 | } else { |
251 | LOG(WARNING) << "Requested callback is not found" ; |
252 | } |
253 | } |
254 | |
255 | void GlobalCallbackManager::clearCallbacks() { |
256 | std::lock_guard<std::mutex> guard(update_mutex_); |
257 | ++version_; |
258 | global_callbacks_.clear(); |
259 | } |
260 | |
261 | // ============================================================================ |
262 | // == CacheEntry: Implementation ============================================== |
263 | // ============================================================================ |
CacheEntry::CacheEntry(std::mt19937* generator, RecordScope scope)
    : generator_{generator}, scope_{scope} {
  // Establish a valid (initially empty) active set and countdown state.
  rebuildActiveCallbacks();
}
268 | |
269 | void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) { |
270 | callbacks_.clear(); |
271 | callbacks_.reserve(callbacks.size()); |
272 | for (const auto& callback : callbacks) { |
273 | const auto p = callback.samplingProb(); |
274 | callbacks_.push_back({callback, p < 1.0 ? sampleTries(p) : -1}); |
275 | } |
276 | |
277 | rebuildActiveCallbacks(); |
278 | } |
279 | |
// Decrements the shared sampling countdown; when it expires, charges the
// elapsed steps to every pending sampled callback, rebuilds the active set,
// and re-draws counters for sampled callbacks that just fired. The result is
// left in `active_callbacks_`.
void CacheEntry::getActiveCallbacksImpl() {
  // We rebuild the active set when `sampling_countdown_` reaches zero, so if it
  // reaches zero at the start of this function something has gone wrong.
  TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);

  if (C10_UNLIKELY(!(--sampling_countdown_))) {
    // Use inferred steps to update sampled callbacks.
    for (auto& i : callbacks_) {
      if (i.tries_left_ > 0) {
        TORCH_INTERNAL_ASSERT(i.tries_left_ >= steps_for_this_update_);
        i.tries_left_ -= steps_for_this_update_;
      }
    }

    // Determine which callbacks to run and for how long.
    rebuildActiveCallbacks();

    // Resample any sampled callbacks that ran this call.
    for (auto& i : callbacks_) {
      if (!i.tries_left_) {
        i.tries_left_ = sampleTries(i.callback_.samplingProb());
      }
    }
  }
}
305 | |
// Returns (a copy of) the callbacks to run for this call.
StepCallbacks CacheEntry::getActiveCallbacks() {
  getActiveCallbacksImpl();
  return active_callbacks_;
}

// As getActiveCallbacks(), but returns c10::nullopt when nothing is active so
// callers can skip RecordFunction setup. The empty case is hinted as likely
// because most calls run with no observers installed.
c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
  getActiveCallbacksImpl();
  if (C10_LIKELY(active_callbacks_.empty())) {
    return c10::nullopt;
  }
  return active_callbacks_;
}
318 | |
// Recomputes `active_callbacks_` from `callbacks_`: unsampled callbacks are
// always included; a sampled callback is included only when its counter has
// reached zero. Also derives `sampling_countdown_` — the number of calls the
// new active set remains valid for before the next rebuild.
void CacheEntry::rebuildActiveCallbacks() {
  // We could store thread ID in CacheEntry, but rebuilds are infrequent and
  // this saves us from having to plumb it through.
  const auto thread_id = RecordFunction::currentThreadId();
  active_callbacks_ = StepCallbacks(thread_id, scope_);

  sampling_countdown_ = std::numeric_limits<int>::max();
  for (const auto& i : callbacks_) {
    if (i.tries_left_ < 0) {
      // Callback is not sampled. Unconditionally push.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});

    } else if (i.tries_left_ == 0) {
      // Callback is sampled and we have reached a sampling event. Push and
      // set `sampling_countdown_` to one so we trigger a rebuild after one call.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});
      sampling_countdown_ = 1;

    } else {
      // Callback is sampled and we have not reached sampling event. Set
      // `sampling_countdown_` to rebuild when it is time for this callback to
      // execute.
      sampling_countdown_ = std::min(sampling_countdown_, i.tries_left_);
    }
    active_callbacks_.needs_inputs_ |= i.callback_.needsInputs();
    active_callbacks_.needs_outputs_ |= i.callback_.needsOutputs();
    active_callbacks_.needs_ids_ |= i.callback_.needsIds();
  }
  // Record how many steps this active set covers so that the next rebuild can
  // charge them against the pending sampled callbacks.
  steps_for_this_update_ = sampling_countdown_;
}
351 | |
352 | int CacheEntry::sampleTries(double p) const { |
353 | TORCH_INTERNAL_ASSERT(generator_ != nullptr); |
354 | TORCH_INTERNAL_ASSERT(p > 0.0 && p <= 1.0); |
355 | |
356 | // The geometric distribution returns the number of failures. We add one to |
357 | // also account for the call where we succeed. |
358 | return std::geometric_distribution<int>(p)(*generator_) + 1; |
359 | } |
360 | |
361 | // ============================================================================ |
362 | // == LocalCallbackManager: Implementation ==================================== |
363 | // ============================================================================ |
// One LocalCallbackManager per thread. Some platforms prefer c10's custom
// TLS implementation over `thread_local`; both branches are equivalent here.
LocalCallbackManager& LocalCallbackManager::get() {
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static c10::ThreadLocal<LocalCallbackManager> manager;
  return manager.get();
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static thread_local LocalCallbackManager manager;
  return manager;
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
}
373 | |
374 | LocalCallbackManager::LocalCallbackManager() { |
375 | for (auto i : c10::irange(NumRecordScopes)) { |
376 | active_callbacks_[i] = CacheEntry(&generator_, static_cast<RecordScope>(i)); |
377 | } |
378 | rebuild_all(GlobalCallbackManager::get().getSnapshot()); |
379 | } |
380 | |
381 | const RecordFunctionTLS& LocalCallbackManager::getTLS() const { |
382 | return registered_callbacks_; |
383 | } |
384 | |
385 | void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() { |
386 | const auto global_version = GlobalCallbackManager::get().version(); |
387 | if (C10_UNLIKELY(global_version != global_version_)) { |
388 | rebuild_all(GlobalCallbackManager::get().getSnapshot()); |
389 | } |
390 | } |
391 | |
392 | StepCallbacks LocalCallbackManager::getActiveCallbacks( |
393 | const RecordScope scope) { |
394 | rebuildActiveCallbacksIfNeeded(); |
395 | return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks(); |
396 | } |
397 | |
398 | c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty( |
399 | const RecordScope scope) { |
400 | rebuildActiveCallbacksIfNeeded(); |
401 | return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty(); |
402 | } |
403 | |
// Replaces this thread's registered callbacks wholesale and rebuilds every
// per-scope cache against the current global snapshot.
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
  registered_callbacks_ = tls;
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

// Reseeds the RNG used to draw sampling counters (testing hook).
void LocalCallbackManager::seed(uint32_t seed) {
  generator_.seed(seed);
}
412 | |
413 | CallbackHandle LocalCallbackManager::addCallback( |
414 | RecordFunctionCallback callback) { |
415 | auto handle = next_unique_callback_handle(); |
416 | auto& callbacks = registered_callbacks_.sorted_tls_callbacks_; |
417 | callbacks.emplace_back(std::move(callback), handle); |
418 | rebuild_callback_scopes( |
419 | GlobalCallbackManager::get().getSnapshot(), callbacks.back().callback_); |
420 | return handle; |
421 | } |
422 | |
423 | bool LocalCallbackManager::setCallbackEnabled( |
424 | CallbackHandle handle, |
425 | bool enabled) { |
426 | auto it = findCallback(registered_callbacks_.sorted_tls_callbacks_, handle); |
427 | auto found = (it != registered_callbacks_.sorted_tls_callbacks_.end()); |
428 | if (found && it->enabled_ != enabled) { |
429 | it->enabled_ = enabled; |
430 | rebuild_callback_scopes( |
431 | GlobalCallbackManager::get().getSnapshot(), it->callback_); |
432 | } |
433 | return found; |
434 | } |
435 | |
436 | bool LocalCallbackManager::removeCallback(CallbackHandle handle) { |
437 | auto& callbacks = registered_callbacks_.sorted_tls_callbacks_; |
438 | auto callback = extractCallback(callbacks, handle); |
439 | if (callback.has_value()) { |
440 | rebuild_callback_scopes( |
441 | GlobalCallbackManager::get().getSnapshot(), *callback); |
442 | } |
443 | return callback.has_value(); |
444 | } |
445 | |
446 | void LocalCallbackManager::clearCallbacks() { |
447 | registered_callbacks_.sorted_tls_callbacks_.clear(); |
448 | rebuild_all(GlobalCallbackManager::get().getSnapshot()); |
449 | } |
450 | |
451 | void LocalCallbackManager::rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot) { |
452 | global_version_ = global_snapshot.first; |
453 | for (auto i : c10::irange(NumRecordScopes)) { |
454 | rebuild_scope(global_snapshot, static_cast<RecordScope>(i)); |
455 | } |
456 | } |
457 | |
458 | void LocalCallbackManager::rebuild_callback_scopes( |
459 | const GlobalCallbackManager::snapshot_t& global_snapshot, |
460 | const RecordFunctionCallback& callback) { |
461 | if (global_snapshot.first == global_version_) { |
462 | // Only rebuild scopes associated with `callback` |
463 | for (auto i : c10::irange(NumRecordScopes)) { |
464 | if (callback.checkScope(static_cast<RecordScope>(i))) { |
465 | rebuild_scope(global_snapshot, static_cast<RecordScope>(i)); |
466 | } |
467 | } |
468 | } else { |
469 | rebuild_all(global_snapshot); |
470 | } |
471 | } |
472 | |
473 | void LocalCallbackManager::rebuild_scope( |
474 | const GlobalCallbackManager::snapshot_t& global_snapshot, |
475 | const RecordScope scope) { |
476 | std::vector<RecordFunctionCallback> callbacks; |
477 | if (registered_callbacks_.tls_record_function_enabled_) { |
478 | auto populate_callbacks = |
479 | [&](const RecordFunctionCallbacks& raw_callbacks) { |
480 | for (const auto& i : raw_callbacks) { |
481 | if (i.enabled_ && i.callback_.checkScope(scope) && |
482 | i.callback_.samplingProb() > 0) { |
483 | callbacks.push_back(i.callback_); |
484 | } |
485 | } |
486 | }; |
487 | populate_callbacks(global_snapshot.second); |
488 | populate_callbacks(registered_callbacks_.sorted_tls_callbacks_); |
489 | } |
490 | active_callbacks_[static_cast<size_t>(scope)].update(callbacks); |
491 | } |
492 | |
493 | // ============================================================================ |
494 | // == Callback execution ====================================================== |
495 | // ============================================================================ |
// Logs (but does not rethrow) an exception that escaped an observer callback.
void logTryRunCallbackError(const char* what, const char* name) {
  LOG(WARNING) << "Exception in RecordFunction callback: " << what
               << " , for the range " << name;
}

// Runs the start half (is_start == true) or end half (is_start == false) of a
// callback pair. Any exception is swallowed and logged so that one faulty
// observer cannot crash the observed program. Returns false on failure.
template <bool is_start>
C10_ALWAYS_INLINE bool tryRunCallback(
    const StepCallbacks::StartEndPair callback_ptrs,
    const RecordFunction& rf,
    std::unique_ptr<ObserverContext>& ctx) {
  try {
    if (is_start && callback_ptrs.start_) {
      // The start callback produces the context passed to the matching end.
      ctx = callback_ptrs.start_(rf);
    }

    if (!is_start && callback_ptrs.end_) {
      callback_ptrs.end_(rf, ctx.get());
    }

    return true;
  } catch (const std::exception& e) {
    logTryRunCallbackError(e.what(), rf.name());
    return false;
  } catch (...) {
    logTryRunCallbackError("unknown" , rf.name());
    return false;
  }
}
524 | |
525 | } // namespace |
526 | |
// Convenience ctor: fetches the active callbacks for `scope` on this thread.
RecordFunction::RecordFunction(RecordScope scope)
    : RecordFunction(getStepCallbacks(scope)) {}

RecordFunction::RecordFunction(StepCallbacks&& step_callbacks)
    : step_callbacks_{std::move(step_callbacks)} {
  // One ObserverContext slot per callback, populated by runStartCallbacks().
  ctx_.resize(step_callbacks_.callbacks_.size());
  if (step_callbacks_.needs_ids_) {
    setHandle(next_unique_record_function_handle());
  }
}
537 | |
538 | void RecordFunction::runStartCallbacks() { |
539 | for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) { |
540 | tryRunCallback</*is_start=*/true>( |
541 | step_callbacks_.callbacks_[i], *this, ctx_[i]); |
542 | } |
543 | called_start_callbacks_ = true; |
544 | } |
545 | |
546 | void RecordFunction::end() { |
547 | if (called_start_callbacks_) { |
548 | for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) { |
549 | tryRunCallback</*is_start=*/false>( |
550 | step_callbacks_.callbacks_[i], *this, ctx_[i]); |
551 | } |
552 | step_callbacks_.callbacks_.clear(); |
553 | } |
554 | } |
555 | |
// Returns the range's name: the stored string for plain ranges, or the
// operator name for schema-backed ranges. The pointer borrows from state
// owned by this RecordFunction (or the schema) — valid only while it lives.
const char* RecordFunction::name() const {
  return c10::visit(
      c10::overloaded(
          [](const std::string& name) { return name.c_str(); },
          [](const schema_ref_t schema) {
            return schema.get().name().c_str();
          }),
      fn_);
}

// Input count: the recorded inputs for plain named ranges, or the schema's
// declared argument count for operator ranges.
size_t RecordFunction::num_inputs() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) { return inputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().arguments().size();
          }),
      fn_);
}

// Output count: recorded outputs for plain ranges, or the schema's declared
// return count for operator ranges.
size_t RecordFunction::num_outputs() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) { return outputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().returns().size();
          }),
      fn_);
}
585 | |
// Operator name, available only for schema-backed (operator) ranges;
// c10::nullopt for plain named ranges.
c10::optional<OperatorName> RecordFunction::operator_name() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) -> c10::optional<OperatorName> {
            return c10::nullopt;
          },
          [](const schema_ref_t schema) -> c10::optional<OperatorName> {
            return schema.get().operator_name();
          }),
      fn_);
}

// Full operator schema (a copy), available only for schema-backed ranges.
c10::optional<c10::FunctionSchema> RecordFunction::operator_schema() const {
  return c10::visit(
      c10::overloaded(
          [&](const std::string&) -> c10::optional<c10::FunctionSchema> {
            return c10::nullopt;
          },
          [](const schema_ref_t schema) -> c10::optional<c10::FunctionSchema> {
            return schema.get();
          }),
      fn_);
}
609 | |
// Public entry point: callbacks to run for `scope` on the current thread.
StepCallbacks getStepCallbacks(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacks(scope);
}

// As getStepCallbacks, but c10::nullopt when there is nothing to run.
c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}

// Accessors for this thread's RecordFunction TLS state (used to save/restore
// observer state across thread boundaries).
const RecordFunctionTLS& get_record_function_tls_() {
  return LocalCallbackManager::get().getTLS();
}

void set_record_function_tls_(const RecordFunctionTLS& tls) {
  LocalCallbackManager::get().setTLS(tls);
}
625 | |
626 | namespace { |
627 | bool anyEnabled(const RecordFunctionCallbacks& callbacks) { |
628 | return std::any_of(callbacks.begin(), callbacks.end(), [](const auto& cb) { |
629 | return cb.enabled_; |
630 | }); |
631 | } |
632 | } // namespace |
633 | |
634 | bool hasCallbacks() { |
635 | return hasThreadLocalCallbacks() || hasGlobalCallbacks(); |
636 | } |
637 | |
638 | bool hasGlobalCallbacks() { |
639 | return anyEnabled(GlobalCallbackManager::get().getSnapshot().second); |
640 | } |
641 | |
642 | bool hasThreadLocalCallbacks() { |
643 | return anyEnabled(get_record_function_tls_().sorted_tls_callbacks_); |
644 | } |
645 | |
646 | CallbackHandle addThreadLocalCallback( |
647 | RecordFunctionCallback cb) { |
648 | return LocalCallbackManager::get().addCallback(std::move(cb)); |
649 | } |
650 | |
651 | CallbackHandle addGlobalCallback( |
652 | RecordFunctionCallback cb) { |
653 | return GlobalCallbackManager::get().addCallback(std::move(cb)); |
654 | } |
655 | |
656 | void removeCallback(CallbackHandle handle) { |
657 | if (!LocalCallbackManager::get().removeCallback(handle)) { |
658 | GlobalCallbackManager::get().removeCallback(handle); |
659 | } |
660 | } |
661 | |
662 | void disableCallback(CallbackHandle handle) { |
663 | if (!LocalCallbackManager::get().setCallbackEnabled(handle, false)) { |
664 | GlobalCallbackManager::get().setCallbackEnabled(handle, false); |
665 | } |
666 | } |
667 | |
668 | void reenableCallback(CallbackHandle handle) { |
669 | if (!LocalCallbackManager::get().setCallbackEnabled(handle, true)) { |
670 | GlobalCallbackManager::get().setCallbackEnabled(handle, true); |
671 | } |
672 | } |
673 | |
674 | void clearGlobalCallbacks() { |
675 | GlobalCallbackManager::get().clearCallbacks(); |
676 | } |
677 | |
678 | void clearThreadLocalCallbacks() { |
679 | LocalCallbackManager::get().clearCallbacks(); |
680 | } |
681 | |
682 | void clearCallbacks() { |
683 | clearGlobalCallbacks(); |
684 | clearThreadLocalCallbacks(); |
685 | } |
686 | |
687 | bool isRecordFunctionEnabled() { |
688 | return LocalCallbackManager::get().getTLS().tls_record_function_enabled_; |
689 | } |
690 | |
691 | void enableRecordFunction(bool enable) { |
692 | auto tls = LocalCallbackManager::get().getTLS(); |
693 | if (tls.tls_record_function_enabled_ != enable) { |
694 | tls.tls_record_function_enabled_ = enable; |
695 | LocalCallbackManager::get().setTLS(tls); |
696 | } |
697 | } |
698 | |
699 | void set_record_function_seed_for_testing(uint32_t seed) { |
700 | LocalCallbackManager::get().seed(seed); |
701 | } |
702 | |
703 | /* static */ |
704 | uint64_t RecordFunction::currentThreadId() { |
705 | if (!current_thread_id_) { |
706 | // happens only once per thread |
707 | current_thread_id_ = ++next_thread_id_; |
708 | } |
709 | return current_thread_id_; |
710 | } |
711 | |
712 | void RecordFunction::before(const char* name, int64_t sequence_nr) { |
713 | fn_ = name; |
714 | sequence_nr_ = sequence_nr; |
715 | |
716 | #ifndef NDEBUG |
717 | inputs_valid_ = true; |
718 | #endif |
719 | runStartCallbacks(); |
720 | invalidateInputs(); |
721 | } |
722 | |
723 | void RecordFunction::before(std::string name, int64_t sequence_nr) { |
724 | fn_ = std::move(name); |
725 | sequence_nr_ = sequence_nr; |
726 | |
727 | #ifndef NDEBUG |
728 | inputs_valid_ = true; |
729 | #endif |
730 | runStartCallbacks(); |
731 | invalidateInputs(); |
732 | } |
733 | |
734 | void RecordFunction::before( |
735 | RecordFunction::schema_ref_t schema, |
736 | int64_t sequence_nr) { |
737 | sequence_nr_ = sequence_nr; |
738 | fn_ = schema; |
739 | |
740 | #ifndef NDEBUG |
741 | inputs_valid_ = true; |
742 | #endif |
743 | runStartCallbacks(); |
744 | invalidateInputs(); |
745 | } |
746 | |
747 | /* static */ void RecordFunction::setDefaultNodeId(int64_t newDefaultNodeId) { |
748 | TORCH_CHECK(newDefaultNodeId >= 0, "setDefaultNodeId expects an id >= 0." ); |
749 | defaultNodeId = newDefaultNodeId; |
750 | } |
751 | |
752 | /* static */ int64_t RecordFunction::getDefaultNodeId() { |
753 | return defaultNodeId; |
754 | } |
755 | |
756 | RecordFunction::~RecordFunction() { |
757 | end(); |
758 | } |
759 | |
760 | void RecordFunction::_setAsync() { |
761 | is_async_ = true; |
762 | } |
763 | |
764 | bool RecordFunction::isAsync() const { |
765 | return is_async_; |
766 | } |
767 | |
768 | void RecordFunction::_setStaticRuntimeOutVariant() { |
769 | if (isActive()) { |
770 | is_static_runtime_out_variant_ = true; |
771 | } |
772 | } |
773 | |
774 | bool RecordFunction::isStaticRuntimeOutVariant() const { |
775 | if (isActive()) { |
776 | return is_static_runtime_out_variant_; |
777 | } |
778 | return false; |
779 | } |
780 | } // namespace at |
781 | |